sayakpaul and Hila committed
Commit c4b2b37
Parent: 7f3e838

Co-authored-by: Hila <hilach70@gmail.com>

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. README.md +2 -2
  2. Transformer-Explainability/BERT_explainability.ipynb +581 -0
  3. Transformer-Explainability/BERT_explainability/modules/BERT/BERT.py +748 -0
  4. Transformer-Explainability/BERT_explainability/modules/BERT/BERT_cls_lrp.py +240 -0
  5. Transformer-Explainability/BERT_explainability/modules/BERT/BERT_orig_lrp.py +748 -0
  6. Transformer-Explainability/BERT_explainability/modules/BERT/BertForSequenceClassification.py +241 -0
  7. Transformer-Explainability/BERT_explainability/modules/BERT/ExplanationGenerator.py +165 -0
  8. Transformer-Explainability/BERT_explainability/modules/__init__.py +0 -0
  9. Transformer-Explainability/BERT_explainability/modules/layers_lrp.py +352 -0
  10. Transformer-Explainability/BERT_explainability/modules/layers_ours.py +373 -0
  11. Transformer-Explainability/BERT_params/boolq.json +26 -0
  12. Transformer-Explainability/BERT_params/boolq_baas.json +26 -0
  13. Transformer-Explainability/BERT_params/boolq_bert.json +32 -0
  14. Transformer-Explainability/BERT_params/boolq_soft.json +21 -0
  15. Transformer-Explainability/BERT_params/cose_bert.json +30 -0
  16. Transformer-Explainability/BERT_params/cose_multiclass.json +35 -0
  17. Transformer-Explainability/BERT_params/esnli_bert.json +28 -0
  18. Transformer-Explainability/BERT_params/evidence_inference.json +26 -0
  19. Transformer-Explainability/BERT_params/evidence_inference_bert.json +33 -0
  20. Transformer-Explainability/BERT_params/evidence_inference_soft.json +22 -0
  21. Transformer-Explainability/BERT_params/fever.json +26 -0
  22. Transformer-Explainability/BERT_params/fever_baas.json +25 -0
  23. Transformer-Explainability/BERT_params/fever_bert.json +32 -0
  24. Transformer-Explainability/BERT_params/fever_soft.json +21 -0
  25. Transformer-Explainability/BERT_params/movies.json +26 -0
  26. Transformer-Explainability/BERT_params/movies_baas.json +26 -0
  27. Transformer-Explainability/BERT_params/movies_bert.json +32 -0
  28. Transformer-Explainability/BERT_params/movies_soft.json +21 -0
  29. Transformer-Explainability/BERT_params/multirc.json +26 -0
  30. Transformer-Explainability/BERT_params/multirc_baas.json +26 -0
  31. Transformer-Explainability/BERT_params/multirc_bert.json +32 -0
  32. Transformer-Explainability/BERT_params/multirc_soft.json +21 -0
  33. Transformer-Explainability/BERT_rationale_benchmark/__init__.py +0 -0
  34. Transformer-Explainability/BERT_rationale_benchmark/metrics.py +1007 -0
  35. Transformer-Explainability/BERT_rationale_benchmark/models/model_utils.py +186 -0
  36. Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/__init__.py +0 -0
  37. Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/bert_pipeline.py +852 -0
  38. Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_train.py +235 -0
  39. Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_utils.py +1045 -0
  40. Transformer-Explainability/BERT_rationale_benchmark/models/sequence_taggers.py +78 -0
  41. Transformer-Explainability/BERT_rationale_benchmark/utils.py +251 -0
  42. Transformer-Explainability/DeiT.PNG +0 -0
  43. Transformer-Explainability/DeiT_example.ipynb +0 -0
  44. Transformer-Explainability/LICENSE +21 -0
  45. Transformer-Explainability/README.md +153 -0
  46. Transformer-Explainability/Transformer_explainability.ipynb +0 -0
  47. Transformer-Explainability/baselines/ViT/ViT_LRP.py +535 -0
  48. Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py +107 -0
  49. Transformer-Explainability/baselines/ViT/ViT_new.py +329 -0
  50. Transformer-Explainability/baselines/ViT/ViT_orig_LRP.py +508 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Comparative Explainability
- emoji: 📈
+ emoji: 🏆
  colorFrom: red
- colorTo: indigo
+ colorTo: gray
  sdk: gradio
  sdk_version: 3.34.0
  app_file: app.py
Transformer-Explainability/BERT_explainability.ipynb ADDED
@@ -0,0 +1,581 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "name": "BERT-explainability.ipynb",
7
+ "provenance": [],
8
+ "authorship_tag": "ABX9TyOm8dIRrumd5XNcc+fntVA5",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "accelerator": "GPU"
16
+ },
17
+ "cells": [
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {
21
+ "id": "view-in-github",
22
+ "colab_type": "text"
23
+ },
24
+ "source": [
25
+ "<a href=\"https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "metadata": {
31
+ "colab": {
32
+ "base_uri": "https://localhost:8080/"
33
+ },
34
+ "id": "YCdGaMuy56TA",
35
+ "outputId": "8f802262-55eb-4366-b772-89c4756224b3"
36
+ },
37
+ "source": [
38
+ "!git clone https://github.com/hila-chefer/Transformer-Explainability.git\n",
39
+ "\n",
40
+ "import os\n",
41
+ "os.chdir(f'./Transformer-Explainability')\n",
42
+ "\n",
43
+ "!pip install -r requirements.txt\n",
44
+ "!pip install captum"
45
+ ],
46
+ "execution_count": 1,
47
+ "outputs": [
48
+ {
49
+ "output_type": "stream",
50
+ "name": "stdout",
51
+ "text": [
52
+ "fatal: destination path 'Transformer-Explainability' already exists and is not an empty directory.\n",
53
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
54
+ "Requirement already satisfied: Pillow>=8.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 1)) (9.4.0)\n",
55
+ "Requirement already satisfied: einops==0.3.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 2)) (0.3.0)\n",
56
+ "Requirement already satisfied: h5py==2.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 3)) (2.8.0)\n",
57
+ "Requirement already satisfied: imageio==2.9.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 4)) (2.9.0)\n",
58
+ "Collecting matplotlib==3.3.2\n",
59
+ " Using cached matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)\n",
60
+ "Requirement already satisfied: opencv_python in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 6)) (4.6.0.66)\n",
61
+ "Requirement already satisfied: scikit_image==0.17.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 7)) (0.17.2)\n",
62
+ "Requirement already satisfied: scipy==1.5.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 8)) (1.5.2)\n",
63
+ "Requirement already satisfied: sklearn in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 9)) (0.0.post1)\n",
64
+ "Requirement already satisfied: torch==1.7.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 10)) (1.7.0)\n",
65
+ "Requirement already satisfied: torchvision==0.8.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 11)) (0.8.1)\n",
66
+ "Requirement already satisfied: tqdm==4.51.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 12)) (4.51.0)\n",
67
+ "Requirement already satisfied: transformers==3.5.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 13)) (3.5.1)\n",
68
+ "Requirement already satisfied: utils==1.0.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 14)) (1.0.1)\n",
69
+ "Requirement already satisfied: Pygments>=2.7.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 15)) (2.14.0)\n",
70
+ "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.8/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.21.6)\n",
71
+ "Requirement already satisfied: six in /usr/local/lib/python3.8/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.15.0)\n",
72
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (1.4.4)\n",
73
+ "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (3.0.9)\n",
74
+ "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2.8.2)\n",
75
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (0.11.0)\n",
76
+ "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2022.12.7)\n",
77
+ "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (3.0)\n",
78
+ "Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (2022.10.10)\n",
79
+ "Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (1.4.1)\n",
80
+ "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (0.6)\n",
81
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (4.4.0)\n",
82
+ "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (0.16.0)\n",
83
+ "Requirement already satisfied: sacremoses in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.0.53)\n",
84
+ "Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (3.19.6)\n",
85
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (3.9.0)\n",
86
+ "Requirement already satisfied: sentencepiece==0.1.91 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.1.91)\n",
87
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (21.3)\n",
88
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (2022.6.2)\n",
89
+ "Requirement already satisfied: tokenizers==0.9.3 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.9.3)\n",
90
+ "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (2.25.1)\n",
91
+ "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (4.0.0)\n",
92
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (1.24.3)\n",
93
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (2.10)\n",
94
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 13)) (1.2.0)\n",
95
+ "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 13)) (7.1.2)\n",
96
+ "Installing collected packages: matplotlib\n",
97
+ " Attempting uninstall: matplotlib\n",
98
+ " Found existing installation: matplotlib 3.6.3\n",
99
+ " Uninstalling matplotlib-3.6.3:\n",
100
+ " Successfully uninstalled matplotlib-3.6.3\n",
101
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
102
+ "fastai 2.7.10 requires torchvision>=0.8.2, but you have torchvision 0.8.1 which is incompatible.\u001b[0m\u001b[31m\n",
103
+ "\u001b[0mSuccessfully installed matplotlib-3.3.2\n",
104
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
105
+ "Requirement already satisfied: captum in /usr/local/lib/python3.8/dist-packages (0.6.0)\n",
106
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum) (3.3.2)\n",
107
+ "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum) (1.7.0)\n",
108
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum) (1.21.6)\n",
109
+ "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (0.16.0)\n",
110
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (4.4.0)\n",
111
+ "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (0.6)\n",
112
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (0.11.0)\n",
113
+ "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (9.4.0)\n",
114
+ "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (2022.12.7)\n",
115
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (1.4.4)\n",
116
+ "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (3.0.9)\n",
117
+ "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (2.8.2)\n",
118
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib->captum) (1.15.0)\n"
119
+ ]
120
+ }
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "source": [
126
+ "!pip install captum==0.6.0\n",
127
+ "!pip install matplotlib==3.3.2"
128
+ ],
129
+ "metadata": {
130
+ "id": "zDPnh4lofcNw",
131
+ "outputId": "3d585bbc-ff3b-4a09-b5bf-57bb4d46e830",
132
+ "colab": {
133
+ "base_uri": "https://localhost:8080/"
134
+ }
135
+ },
136
+ "execution_count": 9,
137
+ "outputs": [
138
+ {
139
+ "output_type": "stream",
140
+ "name": "stdout",
141
+ "text": [
142
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
143
+ "Requirement already satisfied: captum==0.6.0 in /usr/local/lib/python3.8/dist-packages (0.6.0)\n",
144
+ "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (1.7.0)\n",
145
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (1.21.6)\n",
146
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (3.6.3)\n",
147
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (4.4.0)\n",
148
+ "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (0.16.0)\n",
149
+ "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (0.6)\n",
150
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (1.4.4)\n",
151
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (1.0.7)\n",
152
+ "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (9.4.0)\n",
153
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (2.8.2)\n",
154
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (21.3)\n",
155
+ "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (3.0.9)\n",
156
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (4.38.0)\n",
157
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (0.11.0)\n",
158
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7->matplotlib->captum==0.6.0) (1.15.0)\n",
159
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
160
+ "Collecting matplotlib==3.3.2\n",
161
+ " Using cached matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)\n",
162
+ "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (9.4.0)\n",
163
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (0.11.0)\n",
164
+ "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (1.21.6)\n",
165
+ "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (3.0.9)\n",
166
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (1.4.4)\n",
167
+ "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (2022.12.7)\n",
168
+ "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (2.8.2)\n",
169
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib==3.3.2) (1.15.0)\n",
170
+ "Installing collected packages: matplotlib\n",
171
+ " Attempting uninstall: matplotlib\n",
172
+ " Found existing installation: matplotlib 3.6.3\n",
173
+ " Uninstalling matplotlib-3.6.3:\n",
174
+ " Successfully uninstalled matplotlib-3.6.3\n",
175
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
176
+ "fastai 2.7.10 requires torchvision>=0.8.2, but you have torchvision 0.8.1 which is incompatible.\u001b[0m\u001b[31m\n",
177
+ "\u001b[0mSuccessfully installed matplotlib-3.3.2\n"
178
+ ]
179
+ }
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "metadata": {
185
+ "id": "4-XGl_Zw6Aht"
186
+ },
187
+ "source": [
188
+ "from transformers import BertTokenizer\n",
189
+ "from BERT_explainability.modules.BERT.ExplanationGenerator import Generator\n",
190
+ "from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification\n",
191
+ "from transformers import BertTokenizer\n",
192
+ "from BERT_explainability.modules.BERT.ExplanationGenerator import Generator\n",
193
+ "from transformers import AutoTokenizer\n",
194
+ "\n",
195
+ "from captum.attr import visualization\n",
196
+ "import torch"
197
+ ],
198
+ "execution_count": 10,
199
+ "outputs": []
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "metadata": {
204
+ "id": "VakYjrkC6C3S"
205
+ },
206
+ "source": [
207
+ "model = BertForSequenceClassification.from_pretrained(\"textattack/bert-base-uncased-SST-2\").to(\"cuda\")\n",
208
+ "model.eval()\n",
209
+ "tokenizer = AutoTokenizer.from_pretrained(\"textattack/bert-base-uncased-SST-2\")\n",
210
+ "# initialize the explanations generator\n",
211
+ "explanations = Generator(model)\n",
212
+ "\n",
213
+ "classifications = [\"NEGATIVE\", \"POSITIVE\"]\n"
214
+ ],
215
+ "execution_count": 11,
216
+ "outputs": []
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {
221
+ "id": "jGRp376FPOvV"
222
+ },
223
+ "source": [
224
+ "#Positive sentiment example"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "metadata": {
230
+ "id": "uSLZtv546H2z",
231
+ "colab": {
232
+ "base_uri": "https://localhost:8080/",
233
+ "height": 219
234
+ },
235
+ "outputId": "26712e90-0b77-40b0-a908-fef13dd88bcd"
236
+ },
237
+ "source": [
238
+ "# encode a sentence\n",
239
+ "text_batch = [\"This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.\"]\n",
240
+ "encoding = tokenizer(text_batch, return_tensors='pt')\n",
241
+ "input_ids = encoding['input_ids'].to(\"cuda\")\n",
242
+ "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
243
+ "\n",
244
+ "# true class is positive - 1\n",
245
+ "true_class = 1\n",
246
+ "\n",
247
+ "# generate an explanation for the input\n",
248
+ "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]\n",
249
+ "# normalize scores\n",
250
+ "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
251
+ "\n",
252
+ "# get the model classification\n",
253
+ "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
254
+ "classification = output.argmax(dim=-1).item()\n",
255
+ "# get class name\n",
256
+ "class_name = classifications[classification]\n",
257
+ "# if the classification is negative, higher explanation scores are more negative\n",
258
+ "# flip for visualization\n",
259
+ "if class_name == \"NEGATIVE\":\n",
260
+ " expl *= (-1)\n",
261
+ "\n",
262
+ "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
263
+ "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
264
+ "vis_data_records = [visualization.VisualizationDataRecord(\n",
265
+ " expl,\n",
266
+ " output[0][classification],\n",
267
+ " classification,\n",
268
+ " true_class,\n",
269
+ " true_class,\n",
270
+ " 1, \n",
271
+ " tokens,\n",
272
+ " 1)]\n",
273
+ "visualization.visualize_text(vis_data_records)"
274
+ ],
275
+ "execution_count": 12,
276
+ "outputs": [
277
+ {
278
+ "output_type": "stream",
279
+ "name": "stdout",
280
+ "text": [
281
+ "[('[CLS]', 0.0), ('this', 0.4267549514770508), ('movie', 0.30920878052711487), ('was', 0.2684089243412018), ('the', 0.33637329936027527), ('best', 0.6280889511108398), ('movie', 0.28546375036239624), ('i', 0.1863601952791214), ('have', 0.10115814208984375), ('ever', 0.1419338583946228), ('seen', 0.1898290067911148), ('!', 0.5944811105728149), ('some', 0.003896803595125675), ('scenes', 0.033401958644390106), ('were', 0.018588582053780556), ('ridiculous', 0.018908796831965446), (',', 0.0), ('but', 0.42920616269111633), ('acting', 0.43855082988739014), ('was', 0.500239372253418), ('great', 1.0), ('.', 0.014817383140325546), ('[SEP]', 0.0868983045220375)]\n"
282
+ ]
283
+ },
284
+ {
285
+ "output_type": "display_data",
286
+ "data": {
287
+ "text/plain": [
288
+ "<IPython.core.display.HTML object>"
289
+ ],
290
+ "text/html": [
291
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0; line-height:1.75\"><font color=\"black\"> this </font></mark><mark style=\"background-color: hsl(120, 75%, 85%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 84%); opacity:1.0; line-height:1.75\"><font color=\"black\"> the </font></mark><mark style=\"background-color: hsl(120, 75%, 69%); opacity:1.0; line-height:1.75\"><font color=\"black\"> best </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> have </font></mark><mark style=\"background-color: hsl(120, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ever </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0; line-height:1.75\"><font color=\"black\"> seen </font></mark><mark style=\"background-color: hsl(120, 75%, 71%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ! 
</font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> some </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> scenes </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> were </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ridiculous </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> , </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0; line-height:1.75\"><font color=\"black\"> but </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0; line-height:1.75\"><font color=\"black\"> acting </font></mark><mark style=\"background-color: hsl(120, 75%, 75%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0; line-height:1.75\"><font color=\"black\"> great </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(120, 75%, 96%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
292
+ ]
293
+ },
294
+ "metadata": {}
295
+ },
296
+ {
297
+ "output_type": "execute_result",
298
+ "data": {
299
+ "text/plain": [
300
+ "<IPython.core.display.HTML object>"
301
+ ],
302
+ "text/html": [
303
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0; line-height:1.75\"><font color=\"black\"> this </font></mark><mark style=\"background-color: hsl(120, 75%, 85%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 84%); opacity:1.0; line-height:1.75\"><font color=\"black\"> the </font></mark><mark style=\"background-color: hsl(120, 75%, 69%); opacity:1.0; line-height:1.75\"><font color=\"black\"> best </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> have </font></mark><mark style=\"background-color: hsl(120, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ever </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0; line-height:1.75\"><font color=\"black\"> seen </font></mark><mark style=\"background-color: hsl(120, 75%, 71%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ! 
</font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> some </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> scenes </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> were </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ridiculous </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> , </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0; line-height:1.75\"><font color=\"black\"> but </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0; line-height:1.75\"><font color=\"black\"> acting </font></mark><mark style=\"background-color: hsl(120, 75%, 75%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0; line-height:1.75\"><font color=\"black\"> great </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(120, 75%, 96%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
304
+ ]
305
+ },
306
+ "metadata": {},
307
+ "execution_count": 12
308
+ }
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "metadata": {
314
+ "id": "oO_k1BtSPVt3"
315
+ },
316
+ "source": [
317
+ "#Negative sentiment example"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "metadata": {
323
+ "colab": {
324
+ "base_uri": "https://localhost:8080/",
325
+ "height": 219
326
+ },
327
+ "id": "gD4xcvovI1KI",
328
+ "outputId": "e4a50a94-da4c-460e-b602-052b09cec28f"
329
+ },
330
+ "source": [
331
+ "# encode a sentence\n",
332
+ "text_batch = [\"I really didn't like this movie. Some of the actors were good, but overall the movie was boring.\"]\n",
333
+ "encoding = tokenizer(text_batch, return_tensors='pt')\n",
334
+ "input_ids = encoding['input_ids'].to(\"cuda\")\n",
335
+ "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
336
+ "\n",
337
+ "# generate an explanation for the input\n",
338
+ "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]\n",
339
+ "# normalize scores\n",
340
+ "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
341
+ "\n",
342
+ "# get the model classification\n",
343
+ "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
344
+ "classification = output.argmax(dim=-1).item()\n",
345
+ "# get class name\n",
346
+ "class_name = classifications[classification]\n",
347
+ "# if the classification is negative, higher explanation scores are more negative\n",
348
+ "# flip for visualization\n",
349
+ "if class_name == \"NEGATIVE\":\n",
350
+ " expl *= (-1)\n",
351
+ "\n",
352
+ "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
353
+ "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
354
+ "vis_data_records = [visualization.VisualizationDataRecord(\n",
355
+ " expl,\n",
356
+ " output[0][classification],\n",
357
+ " classification,\n",
358
+ " 1,\n",
359
+ " 1,\n",
360
+ " 1, \n",
361
+ " tokens,\n",
362
+ " 1)]\n",
363
+ "visualization.visualize_text(vis_data_records)"
364
+ ],
365
+ "execution_count": 13,
366
+ "outputs": [
367
+ {
368
+ "output_type": "stream",
369
+ "name": "stdout",
370
+ "text": [
371
+ "[('[CLS]', -0.0), ('i', -0.19109757244586945), ('really', -0.1888734996318817), ('didn', -0.2894313633441925), (\"'\", -0.006574898026883602), ('t', -0.36788827180862427), ('like', -0.15249046683311462), ('this', -0.18922168016433716), ('movie', -0.0404353104531765), ('.', -0.019592661410570145), ('some', -0.02311306819319725), ('of', -0.0), ('the', -0.02295113168656826), ('actors', -0.09577538073062897), ('were', -0.013370633125305176), ('good', -0.0323222391307354), (',', -0.004366681911051273), ('but', -0.05878860130906105), ('overall', -0.33596664667129517), ('the', -0.21820111572742462), ('movie', -0.05482065677642822), ('was', -0.6248231530189514), ('boring', -1.0), ('.', -0.031107747927308083), ('[SEP]', -0.052539654076099396)]\n"
372
+ ]
373
+ },
374
+ {
375
+ "output_type": "display_data",
376
+ "data": {
377
+ "text/plain": [
378
+ "<IPython.core.display.HTML object>"
379
+ ],
380
+ "text/html": [
381
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> really </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0; line-height:1.75\"><font color=\"black\"> didn </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ' </font></mark><mark style=\"background-color: hsl(0, 75%, 86%); opacity:1.0; line-height:1.75\"><font color=\"black\"> t </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0; line-height:1.75\"><font color=\"black\"> like </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> this </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . 
</font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> some </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> of </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> the </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> actors </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> were </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> good </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> , </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> but </font></mark><mark style=\"background-color: hsl(0, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> overall </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> the </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 76%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0; line-height:1.75\"><font color=\"black\"> boring </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
382
+ ]
383
+ },
384
+ "metadata": {}
385
+ },
386
+ {
387
+ "output_type": "execute_result",
388
+ "data": {
389
+ "text/plain": [
390
+ "<IPython.core.display.HTML object>"
391
+ ],
392
+ "text/html": [
393
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> really </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0; line-height:1.75\"><font color=\"black\"> didn </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ' </font></mark><mark style=\"background-color: hsl(0, 75%, 86%); opacity:1.0; line-height:1.75\"><font color=\"black\"> t </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0; line-height:1.75\"><font color=\"black\"> like </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> this </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . 
</font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> some </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> of </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> the </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> actors </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> were </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> good </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> , </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> but </font></mark><mark style=\"background-color: hsl(0, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> overall </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> the </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 76%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0; line-height:1.75\"><font color=\"black\"> boring </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
394
+ ]
395
+ },
396
+ "metadata": {},
397
+ "execution_count": 13
398
+ }
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "markdown",
403
+ "source": [
404
+ "# Choosing class for visualization example"
405
+ ],
406
+ "metadata": {
407
+ "id": "UUn2_SMPNG-Y"
408
+ }
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "source": [
413
+ "# encode a sentence\n",
414
+ "text_batch = [\"I hate that I love you.\"]\n",
415
+ "encoding = tokenizer(text_batch, return_tensors='pt')\n",
416
+ "input_ids = encoding['input_ids'].to(\"cuda\")\n",
417
+ "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
418
+ "\n",
419
+ "# true class is positive - 1\n",
420
+ "true_class = 1\n",
421
+ "\n",
422
+ "# generate an explanation for the input\n",
423
+ "target_class = 0\n",
424
+ "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]\n",
425
+ "# normalize scores\n",
426
+ "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
427
+ "\n",
428
+ "# get the model classification\n",
429
+ "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
430
+ "\n",
431
+ "# get class name\n",
432
+ "class_name = classifications[target_class]\n",
433
+ "# if the classification is negative, higher explanation scores are more negative\n",
434
+ "# flip for visualization\n",
435
+ "if class_name == \"NEGATIVE\":\n",
436
+ " expl *= (-1)\n",
437
+ "\n",
438
+ "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
439
+ "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
440
+ "vis_data_records = [visualization.VisualizationDataRecord(\n",
441
+ " expl,\n",
442
+ " output[0][classification],\n",
443
+ " classification,\n",
444
+ " true_class,\n",
445
+ " true_class,\n",
446
+ " 1, \n",
447
+ " tokens,\n",
448
+ " 1)]\n",
449
+ "visualization.visualize_text(vis_data_records)"
450
+ ],
451
+ "metadata": {
452
+ "id": "VQVmMFnzhPoV",
453
+ "outputId": "26a43f8a-340c-4821-b39c-80105a565810",
454
+ "colab": {
455
+ "base_uri": "https://localhost:8080/",
456
+ "height": 219
457
+ }
458
+ },
459
+ "execution_count": 14,
460
+ "outputs": [
461
+ {
462
+ "output_type": "stream",
463
+ "name": "stdout",
464
+ "text": [
465
+ "[('[CLS]', -0.0), ('i', -0.19790242612361908), ('hate', -1.0), ('that', -0.40287283062934875), ('i', -0.12505637109279633), ('love', -0.1307140290737152), ('you', -0.05467141419649124), ('.', -6.108225989009952e-06), ('[SEP]', -0.0)]\n"
466
+ ]
467
+ },
468
+ {
469
+ "output_type": "display_data",
470
+ "data": {
471
+ "text/plain": [
472
+ "<IPython.core.display.HTML object>"
473
+ ],
474
+ "text/html": [
475
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0; line-height:1.75\"><font color=\"black\"> hate </font></mark><mark style=\"background-color: hsl(0, 75%, 84%); opacity:1.0; line-height:1.75\"><font color=\"black\"> that </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> love </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> you </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
476
+ ]
477
+ },
478
+ "metadata": {}
479
+ },
480
+ {
481
+ "output_type": "execute_result",
482
+ "data": {
483
+ "text/plain": [
484
+ "<IPython.core.display.HTML object>"
485
+ ],
486
+ "text/html": [
487
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0; line-height:1.75\"><font color=\"black\"> hate </font></mark><mark style=\"background-color: hsl(0, 75%, 84%); opacity:1.0; line-height:1.75\"><font color=\"black\"> that </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> love </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> you </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
488
+ ]
489
+ },
490
+ "metadata": {},
491
+ "execution_count": 14
492
+ }
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "source": [
498
+ "# encode a sentence\n",
499
+ "text_batch = [\"I hate that I love you.\"]\n",
500
+ "encoding = tokenizer(text_batch, return_tensors='pt')\n",
501
+ "input_ids = encoding['input_ids'].to(\"cuda\")\n",
502
+ "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
503
+ "\n",
504
+ "# true class is positive - 1\n",
505
+ "true_class = 1\n",
506
+ "\n",
507
+ "# generate an explanation for the input\n",
508
+ "target_class = 1\n",
509
+ "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]\n",
510
+ "# normalize scores\n",
511
+ "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
512
+ "\n",
513
+ "# get the model classification\n",
514
+ "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
515
+ "\n",
516
+ "# get class name\n",
517
+ "class_name = classifications[target_class]\n",
518
+ "# if the classification is negative, higher explanation scores are more negative\n",
519
+ "# flip for visualization\n",
520
+ "if class_name == \"NEGATIVE\":\n",
521
+ " expl *= (-1)\n",
522
+ "\n",
523
+ "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
524
+ "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
525
+ "vis_data_records = [visualization.VisualizationDataRecord(\n",
526
+ " expl,\n",
527
+ " output[0][classification],\n",
528
+ " classification,\n",
529
+ " true_class,\n",
530
+ " true_class,\n",
531
+ " 1, \n",
532
+ " tokens,\n",
533
+ " 1)]\n",
534
+ "visualization.visualize_text(vis_data_records)"
535
+ ],
536
+ "metadata": {
537
+ "id": "WiQAWw0-imCg",
538
+ "outputId": "a8c66996-dcd0-4132-a8b0-2346d9bf9c7b",
539
+ "colab": {
540
+ "base_uri": "https://localhost:8080/",
541
+ "height": 219
542
+ }
543
+ },
544
+ "execution_count": 15,
545
+ "outputs": [
546
+ {
547
+ "output_type": "stream",
548
+ "name": "stdout",
549
+ "text": [
550
+ "[('[CLS]', 0.0), ('i', 0.2725590765476227), ('hate', 0.17270179092884064), ('that', 0.23211266100406647), ('i', 0.17642731964588165), ('love', 1.0), ('you', 0.2465524971485138), ('.', 0.0), ('[SEP]', 0.00015733683540020138)]\n"
551
+ ]
552
+ },
553
+ {
554
+ "output_type": "display_data",
555
+ "data": {
556
+ "text/plain": [
557
+ "<IPython.core.display.HTML object>"
558
+ ],
559
+ "text/html": [
560
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> hate </font></mark><mark style=\"background-color: hsl(120, 75%, 89%); opacity:1.0; line-height:1.75\"><font color=\"black\"> that </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0; line-height:1.75\"><font color=\"black\"> love </font></mark><mark style=\"background-color: hsl(120, 75%, 88%); opacity:1.0; line-height:1.75\"><font color=\"black\"> you </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
561
+ ]
562
+ },
563
+ "metadata": {}
564
+ },
565
+ {
566
+ "output_type": "execute_result",
567
+ "data": {
568
+ "text/plain": [
569
+ "<IPython.core.display.HTML object>"
570
+ ],
571
+ "text/html": [
572
+ "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px; padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 60%)\"></span> Negative <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(0, 75%, 100%)\"></span> Neutral <span style=\"display: inline-block; width: 10px; height: 10px; border: 1px solid; background-color: hsl(120, 75%, 50%)\"></span> Positive </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> hate </font></mark><mark style=\"background-color: hsl(120, 75%, 89%); opacity:1.0; line-height:1.75\"><font color=\"black\"> that </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> i </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0; line-height:1.75\"><font color=\"black\"> love </font></mark><mark style=\"background-color: hsl(120, 75%, 88%); opacity:1.0; line-height:1.75\"><font color=\"black\"> you </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
573
+ ]
574
+ },
575
+ "metadata": {},
576
+ "execution_count": 15
577
+ }
578
+ ]
579
+ }
580
+ ]
581
+ }
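The two word-importance tables rendered above follow the record layout of Captum's text visualization utilities (True Label, Predicted Label, Attribution Label, Attribution Score, colored token row). The snippet below is only a hedged illustration of how such a row can be produced; the Captum usage, the token list, and the score values are assumptions made for this sketch, not code taken from the notebook.

    # Minimal sketch (assumes captum is installed); values mirror the rendered example above.
    from captum.attr import visualization as viz

    tokens = ["[CLS]", "i", "hate", "that", "i", "love", "you", ".", "[SEP]"]
    scores = [0.00, 0.13, 0.08, 0.11, 0.08, 1.00, 0.12, 0.00, 0.00]  # per-token relevance, max-normalized

    record = viz.VisualizationDataRecord(
        scores,   # word attributions (drive the green/red shading)
        0.91,     # predicted probability
        0,        # predicted label
        1,        # true label
        1,        # attribution label (the class being explained)
        1.00,     # attribution score shown in the table
        tokens,   # raw tokens
        1,        # convergence delta (unused for this method)
    )
    viz.visualize_text([record])  # renders an HTML table like the output above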
Transformer-Explainability/BERT_explainability/modules/BERT/BERT.py ADDED
@@ -0,0 +1,748 @@
1
+ from __future__ import absolute_import
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from BERT_explainability.modules.layers_ours import *
8
+ from torch import nn
9
+ from transformers import BertConfig, BertPreTrainedModel, PreTrainedModel
10
+ from transformers.modeling_outputs import (BaseModelOutput,
11
+ BaseModelOutputWithPooling)
12
+
13
+ ACT2FN = {
14
+ "relu": ReLU,
15
+ "tanh": Tanh,
16
+ "gelu": GELU,
17
+ }
18
+
19
+
20
+ def get_activation(activation_string):
21
+ if activation_string in ACT2FN:
22
+ return ACT2FN[activation_string]
23
+ else:
24
+ raise KeyError(
25
+ "function {} not found in ACT2FN mapping {}".format(
26
+ activation_string, list(ACT2FN.keys())
27
+ )
28
+ )
29
+
30
+
31
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
32
+ # adding residual consideration
33
+ num_tokens = all_layer_matrices[0].shape[1]
34
+ batch_size = all_layer_matrices[0].shape[0]
35
+ eye = (
36
+ torch.eye(num_tokens)
37
+ .expand(batch_size, num_tokens, num_tokens)
38
+ .to(all_layer_matrices[0].device)
39
+ )
40
+ all_layer_matrices = [
41
+ all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
42
+ ]
43
+ all_layer_matrices = [
44
+ all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
45
+ for i in range(len(all_layer_matrices))
46
+ ]
47
+ joint_attention = all_layer_matrices[start_layer]
48
+ for i in range(start_layer + 1, len(all_layer_matrices)):
49
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
50
+ return joint_attention
51
+
52
+
53
+ class BertEmbeddings(nn.Module):
54
+ """Construct the embeddings from word, position and token_type embeddings."""
55
+
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.word_embeddings = nn.Embedding(
59
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
60
+ )
61
+ self.position_embeddings = nn.Embedding(
62
+ config.max_position_embeddings, config.hidden_size
63
+ )
64
+ self.token_type_embeddings = nn.Embedding(
65
+ config.type_vocab_size, config.hidden_size
66
+ )
67
+
68
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
69
+ # any TensorFlow checkpoint file
70
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
71
+ self.dropout = Dropout(config.hidden_dropout_prob)
72
+
73
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
74
+ self.register_buffer(
75
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
76
+ )
77
+
78
+ self.add1 = Add()
79
+ self.add2 = Add()
80
+
81
+ def forward(
82
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
83
+ ):
84
+ if input_ids is not None:
85
+ input_shape = input_ids.size()
86
+ else:
87
+ input_shape = inputs_embeds.size()[:-1]
88
+
89
+ seq_length = input_shape[1]
90
+
91
+ if position_ids is None:
92
+ position_ids = self.position_ids[:, :seq_length]
93
+
94
+ if token_type_ids is None:
95
+ token_type_ids = torch.zeros(
96
+ input_shape, dtype=torch.long, device=self.position_ids.device
97
+ )
98
+
99
+ if inputs_embeds is None:
100
+ inputs_embeds = self.word_embeddings(input_ids)
101
+ position_embeddings = self.position_embeddings(position_ids)
102
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
103
+
104
+ # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
105
+ embeddings = self.add1([token_type_embeddings, position_embeddings])
106
+ embeddings = self.add2([embeddings, inputs_embeds])
107
+ embeddings = self.LayerNorm(embeddings)
108
+ embeddings = self.dropout(embeddings)
109
+ return embeddings
110
+
111
+ def relprop(self, cam, **kwargs):
112
+ cam = self.dropout.relprop(cam, **kwargs)
113
+ cam = self.LayerNorm.relprop(cam, **kwargs)
114
+
115
+ # [inputs_embeds, position_embeddings, token_type_embeddings]
116
+ (cam) = self.add2.relprop(cam, **kwargs)
117
+
118
+ return cam
119
+
120
+
121
+ class BertEncoder(nn.Module):
122
+ def __init__(self, config):
123
+ super().__init__()
124
+ self.config = config
125
+ self.layer = nn.ModuleList(
126
+ [BertLayer(config) for _ in range(config.num_hidden_layers)]
127
+ )
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states,
132
+ attention_mask=None,
133
+ head_mask=None,
134
+ encoder_hidden_states=None,
135
+ encoder_attention_mask=None,
136
+ output_attentions=False,
137
+ output_hidden_states=False,
138
+ return_dict=False,
139
+ ):
140
+ all_hidden_states = () if output_hidden_states else None
141
+ all_attentions = () if output_attentions else None
142
+ for i, layer_module in enumerate(self.layer):
143
+ if output_hidden_states:
144
+ all_hidden_states = all_hidden_states + (hidden_states,)
145
+
146
+ layer_head_mask = head_mask[i] if head_mask is not None else None
147
+
148
+ if getattr(self.config, "gradient_checkpointing", False):
149
+
150
+ def create_custom_forward(module):
151
+ def custom_forward(*inputs):
152
+ return module(*inputs, output_attentions)
153
+
154
+ return custom_forward
155
+
156
+ layer_outputs = torch.utils.checkpoint.checkpoint(
157
+ create_custom_forward(layer_module),
158
+ hidden_states,
159
+ attention_mask,
160
+ layer_head_mask,
161
+ )
162
+ else:
163
+ layer_outputs = layer_module(
164
+ hidden_states,
165
+ attention_mask,
166
+ layer_head_mask,
167
+ output_attentions,
168
+ )
169
+ hidden_states = layer_outputs[0]
170
+ if output_attentions:
171
+ all_attentions = all_attentions + (layer_outputs[1],)
172
+
173
+ if output_hidden_states:
174
+ all_hidden_states = all_hidden_states + (hidden_states,)
175
+
176
+ if not return_dict:
177
+ return tuple(
178
+ v
179
+ for v in [hidden_states, all_hidden_states, all_attentions]
180
+ if v is not None
181
+ )
182
+ return BaseModelOutput(
183
+ last_hidden_state=hidden_states,
184
+ hidden_states=all_hidden_states,
185
+ attentions=all_attentions,
186
+ )
187
+
188
+ def relprop(self, cam, **kwargs):
189
+ # assuming output_hidden_states is False
190
+ for layer_module in reversed(self.layer):
191
+ cam = layer_module.relprop(cam, **kwargs)
192
+ return cam
193
+
194
+
195
+ # not adding relprop since this is only pooling at the end of the network; it does not affect token importance
196
+ class BertPooler(nn.Module):
197
+ def __init__(self, config):
198
+ super().__init__()
199
+ self.dense = Linear(config.hidden_size, config.hidden_size)
200
+ self.activation = Tanh()
201
+ self.pool = IndexSelect()
202
+
203
+ def forward(self, hidden_states):
204
+ # We "pool" the model by simply taking the hidden state corresponding
205
+ # to the first token.
206
+ self._seq_size = hidden_states.shape[1]
207
+
208
+ # first_token_tensor = hidden_states[:, 0]
209
+ first_token_tensor = self.pool(
210
+ hidden_states, 1, torch.tensor(0, device=hidden_states.device)
211
+ )
212
+ first_token_tensor = first_token_tensor.squeeze(1)
213
+ pooled_output = self.dense(first_token_tensor)
214
+ pooled_output = self.activation(pooled_output)
215
+ return pooled_output
216
+
217
+ def relprop(self, cam, **kwargs):
218
+ cam = self.activation.relprop(cam, **kwargs)
219
+ # print(cam.sum())
220
+ cam = self.dense.relprop(cam, **kwargs)
221
+ # print(cam.sum())
222
+ cam = cam.unsqueeze(1)
223
+ cam = self.pool.relprop(cam, **kwargs)
224
+ # print(cam.sum())
225
+
226
+ return cam
227
+
228
+
229
+ class BertAttention(nn.Module):
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ self.self = BertSelfAttention(config)
233
+ self.output = BertSelfOutput(config)
234
+ self.pruned_heads = set()
235
+ self.clone = Clone()
236
+
237
+ def prune_heads(self, heads):
238
+ if len(heads) == 0:
239
+ return
240
+ heads, index = find_pruneable_heads_and_indices(
241
+ heads,
242
+ self.self.num_attention_heads,
243
+ self.self.attention_head_size,
244
+ self.pruned_heads,
245
+ )
246
+
247
+ # Prune linear layers
248
+ self.self.query = prune_linear_layer(self.self.query, index)
249
+ self.self.key = prune_linear_layer(self.self.key, index)
250
+ self.self.value = prune_linear_layer(self.self.value, index)
251
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
252
+
253
+ # Update hyper params and store pruned heads
254
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
255
+ self.self.all_head_size = (
256
+ self.self.attention_head_size * self.self.num_attention_heads
257
+ )
258
+ self.pruned_heads = self.pruned_heads.union(heads)
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states,
263
+ attention_mask=None,
264
+ head_mask=None,
265
+ encoder_hidden_states=None,
266
+ encoder_attention_mask=None,
267
+ output_attentions=False,
268
+ ):
269
+ h1, h2 = self.clone(hidden_states, 2)
270
+ self_outputs = self.self(
271
+ h1,
272
+ attention_mask,
273
+ head_mask,
274
+ encoder_hidden_states,
275
+ encoder_attention_mask,
276
+ output_attentions,
277
+ )
278
+ attention_output = self.output(self_outputs[0], h2)
279
+ outputs = (attention_output,) + self_outputs[
280
+ 1:
281
+ ] # add attentions if we output them
282
+ return outputs
283
+
284
+ def relprop(self, cam, **kwargs):
285
+ # assuming that we don't output the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
286
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
287
+ # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
288
+ cam1 = self.self.relprop(cam1, **kwargs)
289
+ # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
290
+
291
+ return self.clone.relprop((cam1, cam2), **kwargs)
292
+
293
+
294
+ class BertSelfAttention(nn.Module):
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
298
+ config, "embedding_size"
299
+ ):
300
+ raise ValueError(
301
+ "The hidden size (%d) is not a multiple of the number of attention "
302
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
303
+ )
304
+
305
+ self.num_attention_heads = config.num_attention_heads
306
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
307
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
308
+
309
+ self.query = Linear(config.hidden_size, self.all_head_size)
310
+ self.key = Linear(config.hidden_size, self.all_head_size)
311
+ self.value = Linear(config.hidden_size, self.all_head_size)
312
+
313
+ self.dropout = Dropout(config.attention_probs_dropout_prob)
314
+
315
+ self.matmul1 = MatMul()
316
+ self.matmul2 = MatMul()
317
+ self.softmax = Softmax(dim=-1)
318
+ self.add = Add()
319
+ self.mul = Mul()
320
+ self.head_mask = None
321
+ self.attention_mask = None
322
+ self.clone = Clone()
323
+
324
+ self.attn_cam = None
325
+ self.attn = None
326
+ self.attn_gradients = None
327
+
328
+ def get_attn(self):
329
+ return self.attn
330
+
331
+ def save_attn(self, attn):
332
+ self.attn = attn
333
+
334
+ def save_attn_cam(self, cam):
335
+ self.attn_cam = cam
336
+
337
+ def get_attn_cam(self):
338
+ return self.attn_cam
339
+
340
+ def save_attn_gradients(self, attn_gradients):
341
+ self.attn_gradients = attn_gradients
342
+
343
+ def get_attn_gradients(self):
344
+ return self.attn_gradients
345
+
346
+ def transpose_for_scores(self, x):
347
+ new_x_shape = x.size()[:-1] + (
348
+ self.num_attention_heads,
349
+ self.attention_head_size,
350
+ )
351
+ x = x.view(*new_x_shape)
352
+ return x.permute(0, 2, 1, 3)
353
+
354
+ def transpose_for_scores_relprop(self, x):
355
+ return x.permute(0, 2, 1, 3).flatten(2)
356
+
357
+ def forward(
358
+ self,
359
+ hidden_states,
360
+ attention_mask=None,
361
+ head_mask=None,
362
+ encoder_hidden_states=None,
363
+ encoder_attention_mask=None,
364
+ output_attentions=False,
365
+ ):
366
+ self.head_mask = head_mask
367
+ self.attention_mask = attention_mask
368
+
369
+ h1, h2, h3 = self.clone(hidden_states, 3)
370
+ mixed_query_layer = self.query(h1)
371
+
372
+ # If this is instantiated as a cross-attention module, the keys
373
+ # and values come from an encoder; the attention mask needs to be
374
+ # such that the encoder's padding tokens are not attended to.
375
+ if encoder_hidden_states is not None:
376
+ mixed_key_layer = self.key(encoder_hidden_states)
377
+ mixed_value_layer = self.value(encoder_hidden_states)
378
+ attention_mask = encoder_attention_mask
379
+ else:
380
+ mixed_key_layer = self.key(h2)
381
+ mixed_value_layer = self.value(h3)
382
+
383
+ query_layer = self.transpose_for_scores(mixed_query_layer)
384
+ key_layer = self.transpose_for_scores(mixed_key_layer)
385
+ value_layer = self.transpose_for_scores(mixed_value_layer)
386
+
387
+ # Take the dot product between "query" and "key" to get the raw attention scores.
388
+ attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
389
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
390
+ if attention_mask is not None:
391
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
392
+ attention_scores = self.add([attention_scores, attention_mask])
393
+
394
+ # Normalize the attention scores to probabilities.
395
+ attention_probs = self.softmax(attention_scores)
396
+
397
+ self.save_attn(attention_probs)
398
+ attention_probs.register_hook(self.save_attn_gradients)
399
+
400
+ # This is actually dropping out entire tokens to attend to, which might
401
+ # seem a bit unusual, but is taken from the original Transformer paper.
402
+ attention_probs = self.dropout(attention_probs)
403
+
404
+ # Mask heads if we want to
405
+ if head_mask is not None:
406
+ attention_probs = attention_probs * head_mask
407
+
408
+ context_layer = self.matmul2([attention_probs, value_layer])
409
+
410
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
411
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
412
+ context_layer = context_layer.view(*new_context_layer_shape)
413
+
414
+ outputs = (
415
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
416
+ )
417
+ return outputs
418
+
419
+ def relprop(self, cam, **kwargs):
420
+ # Assume output_attentions == False
421
+ cam = self.transpose_for_scores(cam)
422
+
423
+ # [attention_probs, value_layer]
424
+ (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
425
+ cam1 /= 2
426
+ cam2 /= 2
427
+ if self.head_mask is not None:
428
+ # [attention_probs, head_mask]
429
+ (cam1, _) = self.mul.relprop(cam1, **kwargs)
430
+
431
+ self.save_attn_cam(cam1)
432
+
433
+ cam1 = self.dropout.relprop(cam1, **kwargs)
434
+
435
+ cam1 = self.softmax.relprop(cam1, **kwargs)
436
+
437
+ if self.attention_mask is not None:
438
+ # [attention_scores, attention_mask]
439
+ (cam1, _) = self.add.relprop(cam1, **kwargs)
440
+
441
+ # [query_layer, key_layer.transpose(-1, -2)]
442
+ (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
443
+ cam1_1 /= 2
444
+ cam1_2 /= 2
445
+
446
+ # query
447
+ cam1_1 = self.transpose_for_scores_relprop(cam1_1)
448
+ cam1_1 = self.query.relprop(cam1_1, **kwargs)
449
+
450
+ # key
451
+ cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
452
+ cam1_2 = self.key.relprop(cam1_2, **kwargs)
453
+
454
+ # value
455
+ cam2 = self.transpose_for_scores_relprop(cam2)
456
+ cam2 = self.value.relprop(cam2, **kwargs)
457
+
458
+ cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
459
+
460
+ return cam
461
+
462
+
463
+ class BertSelfOutput(nn.Module):
464
+ def __init__(self, config):
465
+ super().__init__()
466
+ self.dense = Linear(config.hidden_size, config.hidden_size)
467
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
468
+ self.dropout = Dropout(config.hidden_dropout_prob)
469
+ self.add = Add()
470
+
471
+ def forward(self, hidden_states, input_tensor):
472
+ hidden_states = self.dense(hidden_states)
473
+ hidden_states = self.dropout(hidden_states)
474
+ add = self.add([hidden_states, input_tensor])
475
+ hidden_states = self.LayerNorm(add)
476
+ return hidden_states
477
+
478
+ def relprop(self, cam, **kwargs):
479
+ cam = self.LayerNorm.relprop(cam, **kwargs)
480
+ # [hidden_states, input_tensor]
481
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
482
+ cam1 = self.dropout.relprop(cam1, **kwargs)
483
+ cam1 = self.dense.relprop(cam1, **kwargs)
484
+
485
+ return (cam1, cam2)
486
+
487
+
488
+ class BertIntermediate(nn.Module):
489
+ def __init__(self, config):
490
+ super().__init__()
491
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
492
+ if isinstance(config.hidden_act, str):
493
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]()
494
+ else:
495
+ self.intermediate_act_fn = config.hidden_act
496
+
497
+ def forward(self, hidden_states):
498
+ hidden_states = self.dense(hidden_states)
499
+ hidden_states = self.intermediate_act_fn(hidden_states)
500
+ return hidden_states
501
+
502
+ def relprop(self, cam, **kwargs):
503
+ cam = self.intermediate_act_fn.relprop(cam, **kwargs) # FIXME only ReLU
504
+ # print(cam.sum())
505
+ cam = self.dense.relprop(cam, **kwargs)
506
+ # print(cam.sum())
507
+ return cam
508
+
509
+
510
+ class BertOutput(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
514
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
515
+ self.dropout = Dropout(config.hidden_dropout_prob)
516
+ self.add = Add()
517
+
518
+ def forward(self, hidden_states, input_tensor):
519
+ hidden_states = self.dense(hidden_states)
520
+ hidden_states = self.dropout(hidden_states)
521
+ add = self.add([hidden_states, input_tensor])
522
+ hidden_states = self.LayerNorm(add)
523
+ return hidden_states
524
+
525
+ def relprop(self, cam, **kwargs):
526
+ # print("in", cam.sum())
527
+ cam = self.LayerNorm.relprop(cam, **kwargs)
528
+ # print(cam.sum())
529
+ # [hidden_states, input_tensor]
530
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
531
+ # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
532
+ cam1 = self.dropout.relprop(cam1, **kwargs)
533
+ # print(cam1.sum())
534
+ cam1 = self.dense.relprop(cam1, **kwargs)
535
+ # print("dense", cam1.sum())
536
+
537
+ # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
538
+ return (cam1, cam2)
539
+
540
+
541
+ class BertLayer(nn.Module):
542
+ def __init__(self, config):
543
+ super().__init__()
544
+ self.attention = BertAttention(config)
545
+ self.intermediate = BertIntermediate(config)
546
+ self.output = BertOutput(config)
547
+ self.clone = Clone()
548
+
549
+ def forward(
550
+ self,
551
+ hidden_states,
552
+ attention_mask=None,
553
+ head_mask=None,
554
+ output_attentions=False,
555
+ ):
556
+ self_attention_outputs = self.attention(
557
+ hidden_states,
558
+ attention_mask,
559
+ head_mask,
560
+ output_attentions=output_attentions,
561
+ )
562
+ attention_output = self_attention_outputs[0]
563
+ outputs = self_attention_outputs[
564
+ 1:
565
+ ] # add self attentions if we output attention weights
566
+
567
+ ao1, ao2 = self.clone(attention_output, 2)
568
+ intermediate_output = self.intermediate(ao1)
569
+ layer_output = self.output(intermediate_output, ao2)
570
+
571
+ outputs = (layer_output,) + outputs
572
+ return outputs
573
+
574
+ def relprop(self, cam, **kwargs):
575
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
576
+ # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
577
+ cam1 = self.intermediate.relprop(cam1, **kwargs)
578
+ # print("intermediate", cam1.sum())
579
+ cam = self.clone.relprop((cam1, cam2), **kwargs)
580
+ # print("clone", cam.sum())
581
+ cam = self.attention.relprop(cam, **kwargs)
582
+ # print("attention", cam.sum())
583
+ return cam
584
+
585
+
586
+ class BertModel(BertPreTrainedModel):
587
+ def __init__(self, config):
588
+ super().__init__(config)
589
+ self.config = config
590
+
591
+ self.embeddings = BertEmbeddings(config)
592
+ self.encoder = BertEncoder(config)
593
+ self.pooler = BertPooler(config)
594
+
595
+ self.init_weights()
596
+
597
+ def get_input_embeddings(self):
598
+ return self.embeddings.word_embeddings
599
+
600
+ def set_input_embeddings(self, value):
601
+ self.embeddings.word_embeddings = value
602
+
603
+ def forward(
604
+ self,
605
+ input_ids=None,
606
+ attention_mask=None,
607
+ token_type_ids=None,
608
+ position_ids=None,
609
+ head_mask=None,
610
+ inputs_embeds=None,
611
+ encoder_hidden_states=None,
612
+ encoder_attention_mask=None,
613
+ output_attentions=None,
614
+ output_hidden_states=None,
615
+ return_dict=None,
616
+ ):
617
+ r"""
618
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
619
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
620
+ if the model is configured as a decoder.
621
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
622
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
623
+ is used in the cross-attention if the model is configured as a decoder.
624
+ Mask values selected in ``[0, 1]``:
625
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
626
+ """
627
+ output_attentions = (
628
+ output_attentions
629
+ if output_attentions is not None
630
+ else self.config.output_attentions
631
+ )
632
+ output_hidden_states = (
633
+ output_hidden_states
634
+ if output_hidden_states is not None
635
+ else self.config.output_hidden_states
636
+ )
637
+ return_dict = (
638
+ return_dict if return_dict is not None else self.config.use_return_dict
639
+ )
640
+
641
+ if input_ids is not None and inputs_embeds is not None:
642
+ raise ValueError(
643
+ "You cannot specify both input_ids and inputs_embeds at the same time"
644
+ )
645
+ elif input_ids is not None:
646
+ input_shape = input_ids.size()
647
+ elif inputs_embeds is not None:
648
+ input_shape = inputs_embeds.size()[:-1]
649
+ else:
650
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
651
+
652
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
653
+
654
+ if attention_mask is None:
655
+ attention_mask = torch.ones(input_shape, device=device)
656
+ if token_type_ids is None:
657
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
658
+
659
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
660
+ # ourselves in which case we just need to make it broadcastable to all heads.
661
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
662
+ attention_mask, input_shape, device
663
+ )
664
+
665
+ # If a 2D or 3D attention mask is provided for the cross-attention
666
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
667
+ if self.config.is_decoder and encoder_hidden_states is not None:
668
+ (
669
+ encoder_batch_size,
670
+ encoder_sequence_length,
671
+ _,
672
+ ) = encoder_hidden_states.size()
673
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
674
+ if encoder_attention_mask is None:
675
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
676
+ encoder_extended_attention_mask = self.invert_attention_mask(
677
+ encoder_attention_mask
678
+ )
679
+ else:
680
+ encoder_extended_attention_mask = None
681
+
682
+ # Prepare head mask if needed
683
+ # 1.0 in head_mask indicate we keep the head
684
+ # attention_probs has shape bsz x n_heads x N x N
685
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
686
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
687
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
688
+
689
+ embedding_output = self.embeddings(
690
+ input_ids=input_ids,
691
+ position_ids=position_ids,
692
+ token_type_ids=token_type_ids,
693
+ inputs_embeds=inputs_embeds,
694
+ )
695
+
696
+ encoder_outputs = self.encoder(
697
+ embedding_output,
698
+ attention_mask=extended_attention_mask,
699
+ head_mask=head_mask,
700
+ encoder_hidden_states=encoder_hidden_states,
701
+ encoder_attention_mask=encoder_extended_attention_mask,
702
+ output_attentions=output_attentions,
703
+ output_hidden_states=output_hidden_states,
704
+ return_dict=return_dict,
705
+ )
706
+ sequence_output = encoder_outputs[0]
707
+ pooled_output = self.pooler(sequence_output)
708
+
709
+ if not return_dict:
710
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
711
+
712
+ return BaseModelOutputWithPooling(
713
+ last_hidden_state=sequence_output,
714
+ pooler_output=pooled_output,
715
+ hidden_states=encoder_outputs.hidden_states,
716
+ attentions=encoder_outputs.attentions,
717
+ )
718
+
719
+ def relprop(self, cam, **kwargs):
720
+ cam = self.pooler.relprop(cam, **kwargs)
721
+ # print("111111111111",cam.sum())
722
+ cam = self.encoder.relprop(cam, **kwargs)
723
+ # print("222222222222222", cam.sum())
724
+ # print("conservation: ", cam.sum())
725
+ return cam
726
+
727
+
728
+ if __name__ == "__main__":
729
+
730
+ class Config:
731
+ def __init__(
732
+ self, hidden_size, num_attention_heads, attention_probs_dropout_prob
733
+ ):
734
+ self.hidden_size = hidden_size
735
+ self.num_attention_heads = num_attention_heads
736
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
737
+
738
+ model = BertSelfAttention(Config(1024, 4, 0.1))
739
+ x = torch.rand(2, 20, 1024)
740
+ x.requires_grad_()
741
+
742
+ model.eval()
743
+
744
+ y = model.forward(x)
745
+
746
+ relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))
747
+
748
+ print(relprop[1][0].shape)
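Taken together, the save_attn / save_attn_cam hooks, the attention-gradient hook, and compute_rollout_attention defined above are the ingredients for gradient-weighted attention rollout. The function below is only a hedged sketch of how they are usually combined: model is assumed to be a sequence classifier that wraps this BertModel, exposes it as model.bert, and implements a matching relprop; the alpha=1 keyword passed to relprop is an assumed convention of layers_ours; and the normalization differs slightly from the repository's ExplanationGenerator.

    import torch

    def token_relevance(model, input_ids, attention_mask, target_class, start_layer=0):
        # Forward pass; assumes the classifier returns logits first when return_dict is False.
        logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]

        # Backpropagate the target-class score so the attention-gradient hooks fire.
        one_hot = torch.zeros_like(logits)
        one_hot[0, target_class] = 1.0
        model.zero_grad()
        (one_hot * logits).sum().backward(retain_graph=True)

        # Relevance propagation fills each BertSelfAttention's attn_cam buffer.
        model.relprop(one_hot, alpha=1)  # alpha=1 is an assumed keyword, see layers_ours

        # Gradient-weighted attention per layer, averaged over heads, then rolled out.
        cams = []
        for layer in model.bert.encoder.layer:
            cam = layer.attention.self.get_attn_cam()         # (batch, heads, seq, seq)
            grad = layer.attention.self.get_attn_gradients()  # same shape
            cams.append((grad * cam).clamp(min=0).mean(dim=1))  # (batch, seq, seq)

        rollout = compute_rollout_attention(cams, start_layer=start_layer)
        return rollout[:, 0]  # relevance of every token with respect to [CLS]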
Transformer-Explainability/BERT_explainability/modules/BERT/BERT_cls_lrp.py ADDED
@@ -0,0 +1,240 @@
1
+ from typing import Any, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from BERT_explainability.modules.BERT.BERT_orig_lrp import BertModel
6
+ from BERT_explainability.modules.layers_lrp import *
7
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
8
+ from torch.nn import CrossEntropyLoss, MSELoss
9
+ from transformers import BertPreTrainedModel
10
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.utils import logging
11
+
12
+
13
+ class BertForSequenceClassification(BertPreTrainedModel):
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.num_labels = config.num_labels
17
+
18
+ self.bert = BertModel(config)
19
+ self.dropout = Dropout(config.hidden_dropout_prob)
20
+ self.classifier = Linear(config.hidden_size, config.num_labels)
21
+
22
+ self.init_weights()
23
+
24
+ def forward(
25
+ self,
26
+ input_ids=None,
27
+ attention_mask=None,
28
+ token_type_ids=None,
29
+ position_ids=None,
30
+ head_mask=None,
31
+ inputs_embeds=None,
32
+ labels=None,
33
+ output_attentions=None,
34
+ output_hidden_states=None,
35
+ return_dict=None,
36
+ ):
37
+ r"""
38
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
39
+ Labels for computing the sequence classification/regression loss.
40
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
41
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
42
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
43
+ """
44
+ return_dict = (
45
+ return_dict if return_dict is not None else self.config.use_return_dict
46
+ )
47
+
48
+ outputs = self.bert(
49
+ input_ids,
50
+ attention_mask=attention_mask,
51
+ token_type_ids=token_type_ids,
52
+ position_ids=position_ids,
53
+ head_mask=head_mask,
54
+ inputs_embeds=inputs_embeds,
55
+ output_attentions=output_attentions,
56
+ output_hidden_states=output_hidden_states,
57
+ return_dict=return_dict,
58
+ )
59
+
60
+ pooled_output = outputs[1]
61
+
62
+ pooled_output = self.dropout(pooled_output)
63
+ logits = self.classifier(pooled_output)
64
+
65
+ loss = None
66
+ if labels is not None:
67
+ if self.num_labels == 1:
68
+ # We are doing regression
69
+ loss_fct = MSELoss()
70
+ loss = loss_fct(logits.view(-1), labels.view(-1))
71
+ else:
72
+ loss_fct = CrossEntropyLoss()
73
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
74
+
75
+ if not return_dict:
76
+ output = (logits,) + outputs[2:]
77
+ return ((loss,) + output) if loss is not None else output
78
+
79
+ return SequenceClassifierOutput(
80
+ loss=loss,
81
+ logits=logits,
82
+ hidden_states=outputs.hidden_states,
83
+ attentions=outputs.attentions,
84
+ )
85
+
86
+ def relprop(self, cam=None, **kwargs):
87
+ cam = self.classifier.relprop(cam, **kwargs)
88
+ cam = self.dropout.relprop(cam, **kwargs)
89
+ cam = self.bert.relprop(cam, **kwargs)
90
+ return cam
91
+
92
+
93
+ # this is the actual classifier we will be using
94
+ class BertClassifier(nn.Module):
95
+ """Thin wrapper around BertForSequenceClassification"""
96
+
97
+ def __init__(
98
+ self,
99
+ bert_dir: str,
100
+ pad_token_id: int,
101
+ cls_token_id: int,
102
+ sep_token_id: int,
103
+ num_labels: int,
104
+ max_length: int = 512,
105
+ use_half_precision=True,
106
+ ):
107
+ super(BertClassifier, self).__init__()
108
+ bert = BertForSequenceClassification.from_pretrained(
109
+ bert_dir, num_labels=num_labels
110
+ )
111
+ if use_half_precision:
112
+ import apex
113
+
114
+ bert = bert.half()
115
+ self.bert = bert
116
+ self.pad_token_id = pad_token_id
117
+ self.cls_token_id = cls_token_id
118
+ self.sep_token_id = sep_token_id
119
+ self.max_length = max_length
120
+
121
+ def forward(
122
+ self,
123
+ query: List[torch.tensor],
124
+ docids: List[Any],
125
+ document_batch: List[torch.tensor],
126
+ ):
127
+ assert len(query) == len(document_batch)
128
+ print(query)
129
+ # note about device management:
130
+ # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
131
+ # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
132
+ target_device = next(self.parameters()).device
133
+ cls_token = torch.tensor([self.cls_token_id]).to(
134
+ device=document_batch[0].device
135
+ )
136
+ sep_token = torch.tensor([self.sep_token_id]).to(
137
+ device=document_batch[0].device
138
+ )
139
+ input_tensors = []
140
+ position_ids = []
141
+ for q, d in zip(query, document_batch):
142
+ if len(q) + len(d) + 2 > self.max_length:
143
+ d = d[: (self.max_length - len(q) - 2)]
144
+ input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
145
+ position_ids.append(
146
+ torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))
147
+ )
148
+ bert_input = PaddedSequence.autopad(
149
+ input_tensors,
150
+ batch_first=True,
151
+ padding_value=self.pad_token_id,
152
+ device=target_device,
153
+ )
154
+ positions = PaddedSequence.autopad(
155
+ position_ids, batch_first=True, padding_value=0, device=target_device
156
+ )
157
+ (classes,) = self.bert(
158
+ bert_input.data,
159
+ attention_mask=bert_input.mask(
160
+ on=0.0, off=float("-inf"), device=target_device
161
+ ),
162
+ position_ids=positions.data,
163
+ )
164
+ assert torch.all(classes == classes) # for nans
165
+
166
+ print(input_tensors[0])
167
+ print(self.relprop()[0])
168
+
169
+ return classes
170
+
171
+ def relprop(self, cam=None, **kwargs):
172
+ return self.bert.relprop(cam, **kwargs)
173
+
174
+
175
+ if __name__ == "__main__":
176
+ import os
177
+
178
+ from transformers import BertTokenizer
179
+
180
+ class Config:
181
+ def __init__(
182
+ self,
183
+ hidden_size,
184
+ num_attention_heads,
185
+ attention_probs_dropout_prob,
186
+ num_labels,
187
+ hidden_dropout_prob,
188
+ ):
189
+ self.hidden_size = hidden_size
190
+ self.num_attention_heads = num_attention_heads
191
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
192
+ self.num_labels = num_labels
193
+ self.hidden_dropout_prob = hidden_dropout_prob
194
+
195
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
196
+ x = tokenizer.encode_plus(
197
+ "In this movie the acting is great. The movie is perfect! [sep]",
198
+ add_special_tokens=True,
199
+ max_length=512,
200
+ return_token_type_ids=False,
201
+ return_attention_mask=True,
202
+ pad_to_max_length=True,
203
+ return_tensors="pt",
204
+ truncation=True,
205
+ )
206
+
207
+ print(x["input_ids"])
208
+
209
+ model = BertForSequenceClassification.from_pretrained(
210
+ "bert-base-uncased", num_labels=2
211
+ )
212
+ model_save_file = os.path.join(
213
+ "./BERT_explainability/output_bert/movies/classifier/", "classifier.pt"
214
+ )
215
+ model.load_state_dict(torch.load(model_save_file))
216
+
217
+ # x = torch.randint(100, (2, 20))
218
+ # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
219
+ # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
220
+ # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
221
+ # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
222
+ # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
223
+ # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
224
+ # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
225
+ # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
226
+ # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
227
+ # 102, 101, 1012, 102]])
228
+ # x.requires_grad_()
229
+
230
+ model.eval()
231
+
232
+ y = model(x["input_ids"], x["attention_mask"])
233
+ print(y)
234
+
235
+ cam, _ = model.relprop()
236
+
237
+ # print(cam.shape)
238
+
239
+ cam = cam.sum(-1)
240
+ # print(cam)
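A hedged usage sketch for the LRP classifier defined in this file. The checkpoint name and sentence are placeholders (in practice the fine-tuned classifier is loaded as in the __main__ block above), alpha=1 is assumed to be the keyword accepted by the layers_lrp relprop implementations, and the target relevance is passed to relprop explicitly as a one-hot over the logits.

    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model.eval()

    enc = tokenizer("the acting is great . the movie is perfect !", return_tensors="pt")
    logits = model(enc["input_ids"], attention_mask=enc["attention_mask"], return_dict=False)[0]

    # One-hot relevance on the predicted class, propagated back to the token embeddings.
    target = logits.argmax(dim=-1).item()
    one_hot = torch.zeros_like(logits)
    one_hot[0, target] = 1.0
    (one_hot * logits).sum().backward(retain_graph=True)

    cam = model.relprop(one_hot, alpha=1)  # (batch, seq_len, hidden_size); alpha=1 is assumed
    token_scores = cam.sum(dim=-1)         # one relevance score per input token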
Transformer-Explainability/BERT_explainability/modules/BERT/BERT_orig_lrp.py ADDED
@@ -0,0 +1,748 @@
1
+ from __future__ import absolute_import
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from BERT_explainability.modules.layers_lrp import *
8
+ from torch import nn
9
+ from transformers import BertConfig, BertPreTrainedModel, PreTrainedModel
10
+ from transformers.modeling_outputs import (BaseModelOutput,
11
+ BaseModelOutputWithPooling)
12
+
13
+ ACT2FN = {
14
+ "relu": ReLU,
15
+ "tanh": Tanh,
16
+ "gelu": GELU,
17
+ }
18
+
19
+
20
+ def get_activation(activation_string):
21
+ if activation_string in ACT2FN:
22
+ return ACT2FN[activation_string]
23
+ else:
24
+ raise KeyError(
25
+ "function {} not found in ACT2FN mapping {}".format(
26
+ activation_string, list(ACT2FN.keys())
27
+ )
28
+ )
29
+
30
+
31
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
32
+ # adding residual consideration
33
+ num_tokens = all_layer_matrices[0].shape[1]
34
+ batch_size = all_layer_matrices[0].shape[0]
35
+ eye = (
36
+ torch.eye(num_tokens)
37
+ .expand(batch_size, num_tokens, num_tokens)
38
+ .to(all_layer_matrices[0].device)
39
+ )
40
+ all_layer_matrices = [
41
+ all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
42
+ ]
43
+ all_layer_matrices = [
44
+ all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
45
+ for i in range(len(all_layer_matrices))
46
+ ]
47
+ joint_attention = all_layer_matrices[start_layer]
48
+ for i in range(start_layer + 1, len(all_layer_matrices)):
49
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
50
+ return joint_attention
51
+
52
+
53
+ class BertEmbeddings(nn.Module):
54
+ """Construct the embeddings from word, position and token_type embeddings."""
55
+
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.word_embeddings = nn.Embedding(
59
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
60
+ )
61
+ self.position_embeddings = nn.Embedding(
62
+ config.max_position_embeddings, config.hidden_size
63
+ )
64
+ self.token_type_embeddings = nn.Embedding(
65
+ config.type_vocab_size, config.hidden_size
66
+ )
67
+
68
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
69
+ # any TensorFlow checkpoint file
70
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
71
+ self.dropout = Dropout(config.hidden_dropout_prob)
72
+
73
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
74
+ self.register_buffer(
75
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
76
+ )
77
+
78
+ self.add1 = Add()
79
+ self.add2 = Add()
80
+
81
+ def forward(
82
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
83
+ ):
84
+ if input_ids is not None:
85
+ input_shape = input_ids.size()
86
+ else:
87
+ input_shape = inputs_embeds.size()[:-1]
88
+
89
+ seq_length = input_shape[1]
90
+
91
+ if position_ids is None:
92
+ position_ids = self.position_ids[:, :seq_length]
93
+
94
+ if token_type_ids is None:
95
+ token_type_ids = torch.zeros(
96
+ input_shape, dtype=torch.long, device=self.position_ids.device
97
+ )
98
+
99
+ if inputs_embeds is None:
100
+ inputs_embeds = self.word_embeddings(input_ids)
101
+ position_embeddings = self.position_embeddings(position_ids)
102
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
103
+
104
+ # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
105
+ embeddings = self.add1([token_type_embeddings, position_embeddings])
106
+ embeddings = self.add2([embeddings, inputs_embeds])
107
+ embeddings = self.LayerNorm(embeddings)
108
+ embeddings = self.dropout(embeddings)
109
+ return embeddings
110
+
111
+ def relprop(self, cam, **kwargs):
112
+ cam = self.dropout.relprop(cam, **kwargs)
113
+ cam = self.LayerNorm.relprop(cam, **kwargs)
114
+
115
+ # [inputs_embeds, position_embeddings, token_type_embeddings]
116
+ (cam) = self.add2.relprop(cam, **kwargs)
117
+
118
+ return cam
119
+
120
+
121
+ class BertEncoder(nn.Module):
122
+ def __init__(self, config):
123
+ super().__init__()
124
+ self.config = config
125
+ self.layer = nn.ModuleList(
126
+ [BertLayer(config) for _ in range(config.num_hidden_layers)]
127
+ )
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states,
132
+ attention_mask=None,
133
+ head_mask=None,
134
+ encoder_hidden_states=None,
135
+ encoder_attention_mask=None,
136
+ output_attentions=False,
137
+ output_hidden_states=False,
138
+ return_dict=False,
139
+ ):
140
+ all_hidden_states = () if output_hidden_states else None
141
+ all_attentions = () if output_attentions else None
142
+ for i, layer_module in enumerate(self.layer):
143
+ if output_hidden_states:
144
+ all_hidden_states = all_hidden_states + (hidden_states,)
145
+
146
+ layer_head_mask = head_mask[i] if head_mask is not None else None
147
+
148
+ if getattr(self.config, "gradient_checkpointing", False):
149
+
150
+ def create_custom_forward(module):
151
+ def custom_forward(*inputs):
152
+ return module(*inputs, output_attentions)
153
+
154
+ return custom_forward
155
+
156
+ layer_outputs = torch.utils.checkpoint.checkpoint(
157
+ create_custom_forward(layer_module),
158
+ hidden_states,
159
+ attention_mask,
160
+ layer_head_mask,
161
+ )
162
+ else:
163
+ layer_outputs = layer_module(
164
+ hidden_states,
165
+ attention_mask,
166
+ layer_head_mask,
167
+ output_attentions,
168
+ )
169
+ hidden_states = layer_outputs[0]
170
+ if output_attentions:
171
+ all_attentions = all_attentions + (layer_outputs[1],)
172
+
173
+ if output_hidden_states:
174
+ all_hidden_states = all_hidden_states + (hidden_states,)
175
+
176
+ if not return_dict:
177
+ return tuple(
178
+ v
179
+ for v in [hidden_states, all_hidden_states, all_attentions]
180
+ if v is not None
181
+ )
182
+ return BaseModelOutput(
183
+ last_hidden_state=hidden_states,
184
+ hidden_states=all_hidden_states,
185
+ attentions=all_attentions,
186
+ )
187
+
188
+ def relprop(self, cam, **kwargs):
189
+ # assuming output_hidden_states is False
190
+ for layer_module in reversed(self.layer):
191
+ cam = layer_module.relprop(cam, **kwargs)
192
+ return cam
193
+
194
+
195
+ # not adding relprop since this is only pooling at the end of the network; it does not affect token importance
196
+ class BertPooler(nn.Module):
197
+ def __init__(self, config):
198
+ super().__init__()
199
+ self.dense = Linear(config.hidden_size, config.hidden_size)
200
+ self.activation = Tanh()
201
+ self.pool = IndexSelect()
202
+
203
+ def forward(self, hidden_states):
204
+ # We "pool" the model by simply taking the hidden state corresponding
205
+ # to the first token.
206
+ self._seq_size = hidden_states.shape[1]
207
+
208
+ # first_token_tensor = hidden_states[:, 0]
209
+ first_token_tensor = self.pool(
210
+ hidden_states, 1, torch.tensor(0, device=hidden_states.device)
211
+ )
212
+ first_token_tensor = first_token_tensor.squeeze(1)
213
+ pooled_output = self.dense(first_token_tensor)
214
+ pooled_output = self.activation(pooled_output)
215
+ return pooled_output
216
+
217
+ def relprop(self, cam, **kwargs):
218
+ cam = self.activation.relprop(cam, **kwargs)
219
+ # print(cam.sum())
220
+ cam = self.dense.relprop(cam, **kwargs)
221
+ # print(cam.sum())
222
+ cam = cam.unsqueeze(1)
223
+ cam = self.pool.relprop(cam, **kwargs)
224
+ # print(cam.sum())
225
+
226
+ return cam
227
+
228
+
229
+ class BertAttention(nn.Module):
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ self.self = BertSelfAttention(config)
233
+ self.output = BertSelfOutput(config)
234
+ self.pruned_heads = set()
235
+ self.clone = Clone()
236
+
237
+ def prune_heads(self, heads):
238
+ if len(heads) == 0:
239
+ return
240
+ heads, index = find_pruneable_heads_and_indices(
241
+ heads,
242
+ self.self.num_attention_heads,
243
+ self.self.attention_head_size,
244
+ self.pruned_heads,
245
+ )
246
+
247
+ # Prune linear layers
248
+ self.self.query = prune_linear_layer(self.self.query, index)
249
+ self.self.key = prune_linear_layer(self.self.key, index)
250
+ self.self.value = prune_linear_layer(self.self.value, index)
251
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
252
+
253
+ # Update hyper params and store pruned heads
254
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
255
+ self.self.all_head_size = (
256
+ self.self.attention_head_size * self.self.num_attention_heads
257
+ )
258
+ self.pruned_heads = self.pruned_heads.union(heads)
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states,
263
+ attention_mask=None,
264
+ head_mask=None,
265
+ encoder_hidden_states=None,
266
+ encoder_attention_mask=None,
267
+ output_attentions=False,
268
+ ):
269
+ h1, h2 = self.clone(hidden_states, 2)
270
+ self_outputs = self.self(
271
+ h1,
272
+ attention_mask,
273
+ head_mask,
274
+ encoder_hidden_states,
275
+ encoder_attention_mask,
276
+ output_attentions,
277
+ )
278
+ attention_output = self.output(self_outputs[0], h2)
279
+ outputs = (attention_output,) + self_outputs[
280
+ 1:
281
+ ] # add attentions if we output them
282
+ return outputs
283
+
284
+ def relprop(self, cam, **kwargs):
285
+ # assuming that we don't output the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
286
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
287
+ # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
288
+ cam1 = self.self.relprop(cam1, **kwargs)
289
+ # print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
290
+
291
+ return self.clone.relprop((cam1, cam2), **kwargs)
292
+
293
+
294
+ class BertSelfAttention(nn.Module):
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
298
+ config, "embedding_size"
299
+ ):
300
+ raise ValueError(
301
+ "The hidden size (%d) is not a multiple of the number of attention "
302
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
303
+ )
304
+
305
+ self.num_attention_heads = config.num_attention_heads
306
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
307
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
308
+
309
+ self.query = Linear(config.hidden_size, self.all_head_size)
310
+ self.key = Linear(config.hidden_size, self.all_head_size)
311
+ self.value = Linear(config.hidden_size, self.all_head_size)
312
+
313
+ self.dropout = Dropout(config.attention_probs_dropout_prob)
314
+
315
+ self.matmul1 = MatMul()
316
+ self.matmul2 = MatMul()
317
+ self.softmax = Softmax(dim=-1)
318
+ self.add = Add()
319
+ self.mul = Mul()
320
+ self.head_mask = None
321
+ self.attention_mask = None
322
+ self.clone = Clone()
323
+
324
+ self.attn_cam = None
325
+ self.attn = None
326
+ self.attn_gradients = None
327
+
328
+ def get_attn(self):
329
+ return self.attn
330
+
331
+ def save_attn(self, attn):
332
+ self.attn = attn
333
+
334
+ def save_attn_cam(self, cam):
335
+ self.attn_cam = cam
336
+
337
+ def get_attn_cam(self):
338
+ return self.attn_cam
339
+
340
+ def save_attn_gradients(self, attn_gradients):
341
+ self.attn_gradients = attn_gradients
342
+
343
+ def get_attn_gradients(self):
344
+ return self.attn_gradients
345
+
346
+ def transpose_for_scores(self, x):
347
+ new_x_shape = x.size()[:-1] + (
348
+ self.num_attention_heads,
349
+ self.attention_head_size,
350
+ )
351
+ x = x.view(*new_x_shape)
352
+ return x.permute(0, 2, 1, 3)
353
+
354
+ def transpose_for_scores_relprop(self, x):
355
+ return x.permute(0, 2, 1, 3).flatten(2)
356
+
357
+ def forward(
358
+ self,
359
+ hidden_states,
360
+ attention_mask=None,
361
+ head_mask=None,
362
+ encoder_hidden_states=None,
363
+ encoder_attention_mask=None,
364
+ output_attentions=False,
365
+ ):
366
+ self.head_mask = head_mask
367
+ self.attention_mask = attention_mask
368
+
369
+ h1, h2, h3 = self.clone(hidden_states, 3)
370
+ mixed_query_layer = self.query(h1)
371
+
372
+ # If this is instantiated as a cross-attention module, the keys
373
+ # and values come from an encoder; the attention mask needs to be
374
+ # such that the encoder's padding tokens are not attended to.
375
+ if encoder_hidden_states is not None:
376
+ mixed_key_layer = self.key(encoder_hidden_states)
377
+ mixed_value_layer = self.value(encoder_hidden_states)
378
+ attention_mask = encoder_attention_mask
379
+ else:
380
+ mixed_key_layer = self.key(h2)
381
+ mixed_value_layer = self.value(h3)
382
+
383
+ query_layer = self.transpose_for_scores(mixed_query_layer)
384
+ key_layer = self.transpose_for_scores(mixed_key_layer)
385
+ value_layer = self.transpose_for_scores(mixed_value_layer)
386
+
387
+ # Take the dot product between "query" and "key" to get the raw attention scores.
388
+ attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
389
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
390
+ if attention_mask is not None:
391
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
392
+ attention_scores = self.add([attention_scores, attention_mask])
393
+
394
+ # Normalize the attention scores to probabilities.
395
+ attention_probs = self.softmax(attention_scores)
396
+
397
+ self.save_attn(attention_probs)
398
+ attention_probs.register_hook(self.save_attn_gradients)
399
+
400
+ # This is actually dropping out entire tokens to attend to, which might
401
+ # seem a bit unusual, but is taken from the original Transformer paper.
402
+ attention_probs = self.dropout(attention_probs)
403
+
404
+ # Mask heads if we want to
405
+ if head_mask is not None:
406
+ attention_probs = attention_probs * head_mask
407
+
408
+ context_layer = self.matmul2([attention_probs, value_layer])
409
+
410
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
411
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
412
+ context_layer = context_layer.view(*new_context_layer_shape)
413
+
414
+ outputs = (
415
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
416
+ )
417
+ return outputs
418
+
419
+ def relprop(self, cam, **kwargs):
420
+ # Assume output_attentions == False
421
+ cam = self.transpose_for_scores(cam)
422
+
423
+ # [attention_probs, value_layer]
424
+ (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
425
+ cam1 /= 2
426
+ cam2 /= 2
427
+ if self.head_mask is not None:
428
+ # [attention_probs, head_mask]
429
+ (cam1, _) = self.mul.relprop(cam1, **kwargs)
430
+
431
+ self.save_attn_cam(cam1)
432
+
433
+ cam1 = self.dropout.relprop(cam1, **kwargs)
434
+
435
+ cam1 = self.softmax.relprop(cam1, **kwargs)
436
+
437
+ if self.attention_mask is not None:
438
+ # [attention_scores, attention_mask]
439
+ (cam1, _) = self.add.relprop(cam1, **kwargs)
440
+
441
+ # [query_layer, key_layer.transpose(-1, -2)]
442
+ (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
443
+ cam1_1 /= 2
444
+ cam1_2 /= 2
445
+
446
+ # query
447
+ cam1_1 = self.transpose_for_scores_relprop(cam1_1)
448
+ cam1_1 = self.query.relprop(cam1_1, **kwargs)
449
+
450
+ # key
451
+ cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
452
+ cam1_2 = self.key.relprop(cam1_2, **kwargs)
453
+
454
+ # value
455
+ cam2 = self.transpose_for_scores_relprop(cam2)
456
+ cam2 = self.value.relprop(cam2, **kwargs)
457
+
458
+ cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
459
+
460
+ return cam
461
+
462
+
463
+ class BertSelfOutput(nn.Module):
464
+ def __init__(self, config):
465
+ super().__init__()
466
+ self.dense = Linear(config.hidden_size, config.hidden_size)
467
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
468
+ self.dropout = Dropout(config.hidden_dropout_prob)
469
+ self.add = Add()
470
+
471
+ def forward(self, hidden_states, input_tensor):
472
+ hidden_states = self.dense(hidden_states)
473
+ hidden_states = self.dropout(hidden_states)
474
+ add = self.add([hidden_states, input_tensor])
475
+ hidden_states = self.LayerNorm(add)
476
+ return hidden_states
477
+
478
+ def relprop(self, cam, **kwargs):
479
+ cam = self.LayerNorm.relprop(cam, **kwargs)
480
+ # [hidden_states, input_tensor]
481
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
482
+ cam1 = self.dropout.relprop(cam1, **kwargs)
483
+ cam1 = self.dense.relprop(cam1, **kwargs)
484
+
485
+ return (cam1, cam2)
486
+
487
+
488
+ class BertIntermediate(nn.Module):
489
+ def __init__(self, config):
490
+ super().__init__()
491
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
492
+ if isinstance(config.hidden_act, str):
493
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]()
494
+ else:
495
+ self.intermediate_act_fn = config.hidden_act
496
+
497
+ def forward(self, hidden_states):
498
+ hidden_states = self.dense(hidden_states)
499
+ hidden_states = self.intermediate_act_fn(hidden_states)
500
+ return hidden_states
501
+
502
+ def relprop(self, cam, **kwargs):
503
+ cam = self.intermediate_act_fn.relprop(cam, **kwargs) # FIXME only ReLU
504
+ # print(cam.sum())
505
+ cam = self.dense.relprop(cam, **kwargs)
506
+ # print(cam.sum())
507
+ return cam
508
+
509
+
510
+ class BertOutput(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
514
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
515
+ self.dropout = Dropout(config.hidden_dropout_prob)
516
+ self.add = Add()
517
+
518
+ def forward(self, hidden_states, input_tensor):
519
+ hidden_states = self.dense(hidden_states)
520
+ hidden_states = self.dropout(hidden_states)
521
+ add = self.add([hidden_states, input_tensor])
522
+ hidden_states = self.LayerNorm(add)
523
+ return hidden_states
524
+
525
+ def relprop(self, cam, **kwargs):
526
+ # print("in", cam.sum())
527
+ cam = self.LayerNorm.relprop(cam, **kwargs)
528
+ # print(cam.sum())
529
+ # [hidden_states, input_tensor]
530
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
531
+ # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
532
+ cam1 = self.dropout.relprop(cam1, **kwargs)
533
+ # print(cam1.sum())
534
+ cam1 = self.dense.relprop(cam1, **kwargs)
535
+ # print("dense", cam1.sum())
536
+
537
+ # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
538
+ return (cam1, cam2)
539
+
540
+
541
+ class BertLayer(nn.Module):
542
+ def __init__(self, config):
543
+ super().__init__()
544
+ self.attention = BertAttention(config)
545
+ self.intermediate = BertIntermediate(config)
546
+ self.output = BertOutput(config)
547
+ self.clone = Clone()
548
+
549
+ def forward(
550
+ self,
551
+ hidden_states,
552
+ attention_mask=None,
553
+ head_mask=None,
554
+ output_attentions=False,
555
+ ):
556
+ self_attention_outputs = self.attention(
557
+ hidden_states,
558
+ attention_mask,
559
+ head_mask,
560
+ output_attentions=output_attentions,
561
+ )
562
+ attention_output = self_attention_outputs[0]
563
+ outputs = self_attention_outputs[
564
+ 1:
565
+ ] # add self attentions if we output attention weights
566
+
567
+ ao1, ao2 = self.clone(attention_output, 2)
568
+ intermediate_output = self.intermediate(ao1)
569
+ layer_output = self.output(intermediate_output, ao2)
570
+
571
+ outputs = (layer_output,) + outputs
572
+ return outputs
573
+
574
+ def relprop(self, cam, **kwargs):
575
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
576
+ # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
577
+ cam1 = self.intermediate.relprop(cam1, **kwargs)
578
+ # print("intermediate", cam1.sum())
579
+ cam = self.clone.relprop((cam1, cam2), **kwargs)
580
+ # print("clone", cam.sum())
581
+ cam = self.attention.relprop(cam, **kwargs)
582
+ # print("attention", cam.sum())
583
+ return cam
584
+
585
+
586
+ class BertModel(BertPreTrainedModel):
587
+ def __init__(self, config):
588
+ super().__init__(config)
589
+ self.config = config
590
+
591
+ self.embeddings = BertEmbeddings(config)
592
+ self.encoder = BertEncoder(config)
593
+ self.pooler = BertPooler(config)
594
+
595
+ self.init_weights()
596
+
597
+ def get_input_embeddings(self):
598
+ return self.embeddings.word_embeddings
599
+
600
+ def set_input_embeddings(self, value):
601
+ self.embeddings.word_embeddings = value
602
+
603
+ def forward(
604
+ self,
605
+ input_ids=None,
606
+ attention_mask=None,
607
+ token_type_ids=None,
608
+ position_ids=None,
609
+ head_mask=None,
610
+ inputs_embeds=None,
611
+ encoder_hidden_states=None,
612
+ encoder_attention_mask=None,
613
+ output_attentions=None,
614
+ output_hidden_states=None,
615
+ return_dict=None,
616
+ ):
617
+ r"""
618
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
619
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
620
+ if the model is configured as a decoder.
621
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
622
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
623
+ is used in the cross-attention if the model is configured as a decoder.
624
+ Mask values selected in ``[0, 1]``:
625
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
626
+ """
627
+ output_attentions = (
628
+ output_attentions
629
+ if output_attentions is not None
630
+ else self.config.output_attentions
631
+ )
632
+ output_hidden_states = (
633
+ output_hidden_states
634
+ if output_hidden_states is not None
635
+ else self.config.output_hidden_states
636
+ )
637
+ return_dict = (
638
+ return_dict if return_dict is not None else self.config.use_return_dict
639
+ )
640
+
641
+ if input_ids is not None and inputs_embeds is not None:
642
+ raise ValueError(
643
+ "You cannot specify both input_ids and inputs_embeds at the same time"
644
+ )
645
+ elif input_ids is not None:
646
+ input_shape = input_ids.size()
647
+ elif inputs_embeds is not None:
648
+ input_shape = inputs_embeds.size()[:-1]
649
+ else:
650
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
651
+
652
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
653
+
654
+ if attention_mask is None:
655
+ attention_mask = torch.ones(input_shape, device=device)
656
+ if token_type_ids is None:
657
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
658
+
659
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
660
+ # ourselves in which case we just need to make it broadcastable to all heads.
661
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
662
+ attention_mask, input_shape, device
663
+ )
664
+
665
+ # If a 2D or 3D attention mask is provided for the cross-attention
666
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
667
+ if self.config.is_decoder and encoder_hidden_states is not None:
668
+ (
669
+ encoder_batch_size,
670
+ encoder_sequence_length,
671
+ _,
672
+ ) = encoder_hidden_states.size()
673
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
674
+ if encoder_attention_mask is None:
675
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
676
+ encoder_extended_attention_mask = self.invert_attention_mask(
677
+ encoder_attention_mask
678
+ )
679
+ else:
680
+ encoder_extended_attention_mask = None
681
+
682
+ # Prepare head mask if needed
683
+ # 1.0 in head_mask indicates we keep the head
684
+ # attention_probs has shape bsz x n_heads x N x N
685
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
686
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
687
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
688
+
689
+ embedding_output = self.embeddings(
690
+ input_ids=input_ids,
691
+ position_ids=position_ids,
692
+ token_type_ids=token_type_ids,
693
+ inputs_embeds=inputs_embeds,
694
+ )
695
+
696
+ encoder_outputs = self.encoder(
697
+ embedding_output,
698
+ attention_mask=extended_attention_mask,
699
+ head_mask=head_mask,
700
+ encoder_hidden_states=encoder_hidden_states,
701
+ encoder_attention_mask=encoder_extended_attention_mask,
702
+ output_attentions=output_attentions,
703
+ output_hidden_states=output_hidden_states,
704
+ return_dict=return_dict,
705
+ )
706
+ sequence_output = encoder_outputs[0]
707
+ pooled_output = self.pooler(sequence_output)
708
+
709
+ if not return_dict:
710
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
711
+
712
+ return BaseModelOutputWithPooling(
713
+ last_hidden_state=sequence_output,
714
+ pooler_output=pooled_output,
715
+ hidden_states=encoder_outputs.hidden_states,
716
+ attentions=encoder_outputs.attentions,
717
+ )
718
+
719
+ def relprop(self, cam, **kwargs):
720
+ cam = self.pooler.relprop(cam, **kwargs)
721
+ # print("111111111111",cam.sum())
722
+ cam = self.encoder.relprop(cam, **kwargs)
723
+ # print("222222222222222", cam.sum())
724
+ # print("conservation: ", cam.sum())
725
+ return cam
726
+
727
+
728
+ if __name__ == "__main__":
729
+
730
+ class Config:
731
+ def __init__(
732
+ self, hidden_size, num_attention_heads, attention_probs_dropout_prob
733
+ ):
734
+ self.hidden_size = hidden_size
735
+ self.num_attention_heads = num_attention_heads
736
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
737
+
738
+ model = BertSelfAttention(Config(1024, 4, 0.1))
739
+ x = torch.rand(2, 20, 1024)
740
+ x.requires_grad_()
741
+
742
+ model.eval()
743
+
744
+ y = model.forward(x)
745
+
746
+ relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))
747
+
748
+ print(relprop[1][0].shape)
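The __main__ smoke test above only exercises BertSelfAttention. For the full model, relprop unwinds the forward pass in reverse order (pooler, then encoder, then each layer's output, intermediate and attention blocks). The following sketch is not part of the diff; it shows one plausible way to drive BertModel.relprop directly, seeding it with a relevance tensor shaped like the pooled output, which is the same shape BertForSequenceClassification.relprop (next file) hands down after its classifier and dropout steps. The bert-base-uncased checkpoint and the input sentence are placeholders.

# Sketch only (not part of this commit): relevance propagation through BertModel.
# Assumes this repo's BERT_explainability package is importable;
# "bert-base-uncased" is an example checkpoint, not one referenced by the diff.
import torch
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.BERT import BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

enc = tokenizer("a short example sentence", return_tensors="pt")
outputs = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    return_dict=False,
)
pooled = outputs[1]

# Seed relprop with a tensor shaped like the pooled output; in the classifier wrapper this
# seed is what classifier.relprop / dropout.relprop produce from a one-hot class vector.
cam = model.relprop(torch.ones_like(pooled), alpha=1)
print(cam.shape)  # expected (batch, seq_len, hidden); cf. generate_full_lrp in ExplanationGenerator.py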
Transformer-Explainability/BERT_explainability/modules/BERT/BertForSequenceClassification.py ADDED
@@ -0,0 +1,241 @@
1
+ from typing import Any, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from BERT_explainability.modules.BERT.BERT import BertModel
6
+ from BERT_explainability.modules.layers_ours import *
7
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
8
+ from torch.nn import CrossEntropyLoss, MSELoss
9
+ from transformers import BertPreTrainedModel
10
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.utils import logging
11
+
12
+
13
+ class BertForSequenceClassification(BertPreTrainedModel):
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.num_labels = config.num_labels
17
+
18
+ self.bert = BertModel(config)
19
+ self.dropout = Dropout(config.hidden_dropout_prob)
20
+ self.classifier = Linear(config.hidden_size, config.num_labels)
21
+
22
+ self.init_weights()
23
+
24
+ def forward(
25
+ self,
26
+ input_ids=None,
27
+ attention_mask=None,
28
+ token_type_ids=None,
29
+ position_ids=None,
30
+ head_mask=None,
31
+ inputs_embeds=None,
32
+ labels=None,
33
+ output_attentions=None,
34
+ output_hidden_states=None,
35
+ return_dict=None,
36
+ ):
37
+ r"""
38
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
39
+ Labels for computing the sequence classification/regression loss.
40
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
41
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
42
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
43
+ """
44
+ return_dict = (
45
+ return_dict if return_dict is not None else self.config.use_return_dict
46
+ )
47
+
48
+ outputs = self.bert(
49
+ input_ids,
50
+ attention_mask=attention_mask,
51
+ token_type_ids=token_type_ids,
52
+ position_ids=position_ids,
53
+ head_mask=head_mask,
54
+ inputs_embeds=inputs_embeds,
55
+ output_attentions=output_attentions,
56
+ output_hidden_states=output_hidden_states,
57
+ return_dict=return_dict,
58
+ )
59
+
60
+ pooled_output = outputs[1]
61
+
62
+ pooled_output = self.dropout(pooled_output)
63
+ logits = self.classifier(pooled_output)
64
+
65
+ loss = None
66
+ if labels is not None:
67
+ if self.num_labels == 1:
68
+ # We are doing regression
69
+ loss_fct = MSELoss()
70
+ loss = loss_fct(logits.view(-1), labels.view(-1))
71
+ else:
72
+ loss_fct = CrossEntropyLoss()
73
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
74
+
75
+ if not return_dict:
76
+ output = (logits,) + outputs[2:]
77
+ return ((loss,) + output) if loss is not None else output
78
+
79
+ return SequenceClassifierOutput(
80
+ loss=loss,
81
+ logits=logits,
82
+ hidden_states=outputs.hidden_states,
83
+ attentions=outputs.attentions,
84
+ )
85
+
86
+ def relprop(self, cam=None, **kwargs):
87
+ cam = self.classifier.relprop(cam, **kwargs)
88
+ cam = self.dropout.relprop(cam, **kwargs)
89
+ cam = self.bert.relprop(cam, **kwargs)
90
+ # print("conservation: ", cam.sum())
91
+ return cam
92
+
93
+
94
+ # this is the actual classifier we will be using
95
+ class BertClassifier(nn.Module):
96
+ """Thin wrapper around BertForSequenceClassification"""
97
+
98
+ def __init__(
99
+ self,
100
+ bert_dir: str,
101
+ pad_token_id: int,
102
+ cls_token_id: int,
103
+ sep_token_id: int,
104
+ num_labels: int,
105
+ max_length: int = 512,
106
+ use_half_precision=True,
107
+ ):
108
+ super(BertClassifier, self).__init__()
109
+ bert = BertForSequenceClassification.from_pretrained(
110
+ bert_dir, num_labels=num_labels
111
+ )
112
+ if use_half_precision:
113
+ import apex
114
+
115
+ bert = bert.half()
116
+ self.bert = bert
117
+ self.pad_token_id = pad_token_id
118
+ self.cls_token_id = cls_token_id
119
+ self.sep_token_id = sep_token_id
120
+ self.max_length = max_length
121
+
122
+ def forward(
123
+ self,
124
+ query: List[torch.tensor],
125
+ docids: List[Any],
126
+ document_batch: List[torch.tensor],
127
+ ):
128
+ assert len(query) == len(document_batch)
129
+ print(query)
130
+ # note about device management:
131
+ # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
132
+ # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
133
+ target_device = next(self.parameters()).device
134
+ cls_token = torch.tensor([self.cls_token_id]).to(
135
+ device=document_batch[0].device
136
+ )
137
+ sep_token = torch.tensor([self.sep_token_id]).to(
138
+ device=document_batch[0].device
139
+ )
140
+ input_tensors = []
141
+ position_ids = []
142
+ for q, d in zip(query, document_batch):
143
+ if len(q) + len(d) + 2 > self.max_length:
144
+ d = d[: (self.max_length - len(q) - 2)]
145
+ input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
146
+ position_ids.append(
147
+ torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))
148
+ )
149
+ bert_input = PaddedSequence.autopad(
150
+ input_tensors,
151
+ batch_first=True,
152
+ padding_value=self.pad_token_id,
153
+ device=target_device,
154
+ )
155
+ positions = PaddedSequence.autopad(
156
+ position_ids, batch_first=True, padding_value=0, device=target_device
157
+ )
158
+ (classes,) = self.bert(
159
+ bert_input.data,
160
+ attention_mask=bert_input.mask(
161
+ on=0.0, off=float("-inf"), device=target_device
162
+ ),
163
+ position_ids=positions.data,
164
+ )
165
+ assert torch.all(classes == classes) # for nans
166
+
167
+ print(input_tensors[0])
168
+ print(self.relprop()[0])
169
+
170
+ return classes
171
+
172
+ def relprop(self, cam=None, **kwargs):
173
+ return self.bert.relprop(cam, **kwargs)
174
+
175
+
176
+ if __name__ == "__main__":
177
+ import os
178
+
179
+ from transformers import BertTokenizer
180
+
181
+ class Config:
182
+ def __init__(
183
+ self,
184
+ hidden_size,
185
+ num_attention_heads,
186
+ attention_probs_dropout_prob,
187
+ num_labels,
188
+ hidden_dropout_prob,
189
+ ):
190
+ self.hidden_size = hidden_size
191
+ self.num_attention_heads = num_attention_heads
192
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
193
+ self.num_labels = num_labels
194
+ self.hidden_dropout_prob = hidden_dropout_prob
195
+
196
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
197
+ x = tokenizer.encode_plus(
198
+ "In this movie the acting is great. The movie is perfect! [sep]",
199
+ add_special_tokens=True,
200
+ max_length=512,
201
+ return_token_type_ids=False,
202
+ return_attention_mask=True,
203
+ pad_to_max_length=True,
204
+ return_tensors="pt",
205
+ truncation=True,
206
+ )
207
+
208
+ print(x["input_ids"])
209
+
210
+ model = BertForSequenceClassification.from_pretrained(
211
+ "bert-base-uncased", num_labels=2
212
+ )
213
+ model_save_file = os.path.join(
214
+ "./BERT_explainability/output_bert/movies/classifier/", "classifier.pt"
215
+ )
216
+ model.load_state_dict(torch.load(model_save_file))
217
+
218
+ # x = torch.randint(100, (2, 20))
219
+ # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
220
+ # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
221
+ # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
222
+ # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
223
+ # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
224
+ # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
225
+ # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
226
+ # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
227
+ # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
228
+ # 102, 101, 1012, 102]])
229
+ # x.requires_grad_()
230
+
231
+ model.eval()
232
+
233
+ y = model(x["input_ids"], x["attention_mask"])
234
+ print(y)
235
+
236
+ cam, _ = model.relprop()
237
+
238
+ # print(cam.shape)
239
+
240
+ cam = cam.sum(-1)
241
+ # print(cam)
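The __main__ block above ends after collapsing the relevance map over the hidden dimension (cam = cam.sum(-1)). The short continuation below is a sketch, not part of the diff: it aligns that per-token relevance with the word pieces, assuming the block ran successfully, i.e. the fine-tuned classifier.pt checkpoint referenced above is available.

# Sketch continuing the __main__ block above (assumes tokenizer, x and cam are already defined there).
tokens = tokenizer.convert_ids_to_tokens(x["input_ids"][0])
scores = cam[0].detach()
for tok, score, mask in zip(tokens, scores, x["attention_mask"][0]):
    if mask.item() == 1:  # skip padding positions
        print(f"{tok:>12s}  {score.item():+.4f}")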
Transformer-Explainability/BERT_explainability/modules/BERT/ExplanationGenerator.py ADDED
@@ -0,0 +1,165 @@
1
+ import argparse
2
+ import glob
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ # compute rollout between attention layers
9
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
10
+ # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
11
+ num_tokens = all_layer_matrices[0].shape[1]
12
+ batch_size = all_layer_matrices[0].shape[0]
13
+ eye = (
14
+ torch.eye(num_tokens)
15
+ .expand(batch_size, num_tokens, num_tokens)
16
+ .to(all_layer_matrices[0].device)
17
+ )
18
+ all_layer_matrices = [
19
+ all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
20
+ ]
21
+ matrices_aug = [
22
+ all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
23
+ for i in range(len(all_layer_matrices))
24
+ ]
25
+ joint_attention = matrices_aug[start_layer]
26
+ for i in range(start_layer + 1, len(matrices_aug)):
27
+ joint_attention = matrices_aug[i].bmm(joint_attention)
28
+ return joint_attention
29
+
30
+
31
+ class Generator:
32
+ def __init__(self, model):
33
+ self.model = model
34
+ self.model.eval()
35
+
36
+ def forward(self, input_ids, attention_mask):
37
+ return self.model(input_ids, attention_mask)
38
+
39
+ def generate_LRP(self, input_ids, attention_mask, index=None, start_layer=11):
40
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
41
+ kwargs = {"alpha": 1}
42
+
43
+ if index is None:
44
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
45
+
46
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
47
+ one_hot[0, index] = 1
48
+ one_hot_vector = one_hot
49
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
50
+ one_hot = torch.sum(one_hot.cuda() * output)
51
+
52
+ self.model.zero_grad()
53
+ one_hot.backward(retain_graph=True)
54
+
55
+ self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
56
+
57
+ cams = []
58
+ blocks = self.model.bert.encoder.layer
59
+ for blk in blocks:
60
+ grad = blk.attention.self.get_attn_gradients()
61
+ cam = blk.attention.self.get_attn_cam()
62
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
63
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
64
+ cam = grad * cam
65
+ cam = cam.clamp(min=0).mean(dim=0)
66
+ cams.append(cam.unsqueeze(0))
67
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
68
+ rollout[:, 0, 0] = rollout[:, 0].min()
69
+ return rollout[:, 0]
70
+
71
+ def generate_LRP_last_layer(self, input_ids, attention_mask, index=None):
72
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
73
+ kwargs = {"alpha": 1}
74
+ if index is None:
75
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
76
+
77
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
78
+ one_hot[0, index] = 1
79
+ one_hot_vector = one_hot
80
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
81
+ one_hot = torch.sum(one_hot.cuda() * output)
82
+
83
+ self.model.zero_grad()
84
+ one_hot.backward(retain_graph=True)
85
+
86
+ self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
87
+
88
+ cam = self.model.bert.encoder.layer[-1].attention.self.get_attn_cam()[0]
89
+ cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
90
+ cam[:, 0, 0] = 0
91
+ return cam[:, 0]
92
+
93
+ def generate_full_lrp(self, input_ids, attention_mask, index=None):
94
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
95
+ kwargs = {"alpha": 1}
96
+
97
+ if index is None:
98
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
99
+
100
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
101
+ one_hot[0, index] = 1
102
+ one_hot_vector = one_hot
103
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
104
+ one_hot = torch.sum(one_hot.cuda() * output)
105
+
106
+ self.model.zero_grad()
107
+ one_hot.backward(retain_graph=True)
108
+
109
+ cam = self.model.relprop(
110
+ torch.tensor(one_hot_vector).to(input_ids.device), **kwargs
111
+ )
112
+ cam = cam.sum(dim=2)
113
+ cam[:, 0] = 0
114
+ return cam
115
+
116
+ def generate_attn_last_layer(self, input_ids, attention_mask, index=None):
117
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
118
+ cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()[0]
119
+ cam = cam.mean(dim=0).unsqueeze(0)
120
+ cam[:, 0, 0] = 0
121
+ return cam[:, 0]
122
+
123
+ def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None):
124
+ self.model.zero_grad()
125
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
126
+ blocks = self.model.bert.encoder.layer
127
+ all_layer_attentions = []
128
+ for blk in blocks:
129
+ attn_heads = blk.attention.self.get_attn()
130
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
131
+ all_layer_attentions.append(avg_heads)
132
+ rollout = compute_rollout_attention(
133
+ all_layer_attentions, start_layer=start_layer
134
+ )
135
+ rollout[:, 0, 0] = 0
136
+ return rollout[:, 0]
137
+
138
+ def generate_attn_gradcam(self, input_ids, attention_mask, index=None):
139
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
140
+ kwargs = {"alpha": 1}
141
+
142
+ if index is None:
143
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
144
+
145
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
146
+ one_hot[0, index] = 1
147
+ one_hot_vector = one_hot
148
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
149
+ one_hot = torch.sum(one_hot.cuda() * output)
150
+
151
+ self.model.zero_grad()
152
+ one_hot.backward(retain_graph=True)
153
+
154
+ self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
155
+
156
+ cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()
157
+ grad = self.model.bert.encoder.layer[-1].attention.self.get_attn_gradients()
158
+
159
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
160
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
161
+ grad = grad.mean(dim=[1, 2], keepdim=True)
162
+ cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
163
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
164
+ cam[:, 0, 0] = 0
165
+ return cam[:, 0]
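A usage sketch for the Generator class above, not part of the diff. It assumes a BERT sequence classifier fine-tuned for the task at hand (the checkpoint name below is only a placeholder) and a CUDA device, since generate_LRP moves the one-hot target to the GPU internally.

# Sketch only: token-level relevance scores from generate_LRP.
import torch
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Placeholder checkpoint: substitute a classifier fine-tuned for your task.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).cuda()
explainer = Generator(model)

enc = tokenizer("In this movie the acting is great.", return_tensors="pt").to("cuda")
scores = explainer.generate_LRP(enc["input_ids"], enc["attention_mask"], start_layer=0)[0]
scores = (scores - scores.min()) / (scores.max() - scores.min())  # rescaled for display only

for tok, s in zip(tokenizer.convert_ids_to_tokens(enc["input_ids"][0]), scores):
    print(f"{tok:>10s}  {s.item():.3f}")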
Transformer-Explainability/BERT_explainability/modules/__init__.py ADDED
File without changes
Transformer-Explainability/BERT_explainability/modules/layers_lrp.py ADDED
@@ -0,0 +1,352 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = [
6
+ "forward_hook",
7
+ "Clone",
8
+ "Add",
9
+ "Cat",
10
+ "ReLU",
11
+ "GELU",
12
+ "Dropout",
13
+ "BatchNorm2d",
14
+ "Linear",
15
+ "MaxPool2d",
16
+ "AdaptiveAvgPool2d",
17
+ "AvgPool2d",
18
+ "Conv2d",
19
+ "Sequential",
20
+ "safe_divide",
21
+ "einsum",
22
+ "Softmax",
23
+ "IndexSelect",
24
+ "LayerNorm",
25
+ "AddEye",
26
+ "Tanh",
27
+ "MatMul",
28
+ "Mul",
29
+ ]
30
+
31
+
32
+ def safe_divide(a, b):
33
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
34
+ den = den + den.eq(0).type(den.type()) * 1e-9
35
+ return a / den * b.ne(0).type(b.type())
36
+
37
+
38
+ def forward_hook(self, input, output):
39
+ if type(input[0]) in (list, tuple):
40
+ self.X = []
41
+ for i in input[0]:
42
+ x = i.detach()
43
+ x.requires_grad = True
44
+ self.X.append(x)
45
+ else:
46
+ self.X = input[0].detach()
47
+ self.X.requires_grad = True
48
+
49
+ self.Y = output
50
+
51
+
52
+ def backward_hook(self, grad_input, grad_output):
53
+ self.grad_input = grad_input
54
+ self.grad_output = grad_output
55
+
56
+
57
+ class RelProp(nn.Module):
58
+ def __init__(self):
59
+ super(RelProp, self).__init__()
60
+ # if not self.training:
61
+ self.register_forward_hook(forward_hook)
62
+
63
+ def gradprop(self, Z, X, S):
64
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
65
+ return C
66
+
67
+ def relprop(self, R, alpha):
68
+ return R
69
+
70
+
71
+ class RelPropSimple(RelProp):
72
+ def relprop(self, R, alpha):
73
+ Z = self.forward(self.X)
74
+ S = safe_divide(R, Z)
75
+ C = self.gradprop(Z, self.X, S)
76
+
77
+ if torch.is_tensor(self.X) == False:
78
+ outputs = []
79
+ outputs.append(self.X[0] * C[0])
80
+ outputs.append(self.X[1] * C[1])
81
+ else:
82
+ outputs = self.X * (C[0])
83
+ return outputs
84
+
85
+
86
+ class AddEye(RelPropSimple):
87
+ # input of shape B, C, seq_len, seq_len
88
+ def forward(self, input):
89
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
90
+
91
+
92
+ class ReLU(nn.ReLU, RelProp):
93
+ pass
94
+
95
+
96
+ class Tanh(nn.Tanh, RelProp):
97
+ pass
98
+
99
+
100
+ class GELU(nn.GELU, RelProp):
101
+ pass
102
+
103
+
104
+ class Softmax(nn.Softmax, RelProp):
105
+ pass
106
+
107
+
108
+ class LayerNorm(nn.LayerNorm, RelProp):
109
+ pass
110
+
111
+
112
+ class Dropout(nn.Dropout, RelProp):
113
+ pass
114
+
115
+
116
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
117
+ pass
118
+
119
+
120
+ class LayerNorm(nn.LayerNorm, RelProp):
121
+ pass
122
+
123
+
124
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
125
+ pass
126
+
127
+
128
+ class MatMul(RelPropSimple):
129
+ def forward(self, inputs):
130
+ return torch.matmul(*inputs)
131
+
132
+
133
+ class Mul(RelPropSimple):
134
+ def forward(self, inputs):
135
+ return torch.mul(*inputs)
136
+
137
+
138
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
139
+ pass
140
+
141
+
142
+ class Add(RelPropSimple):
143
+ def forward(self, inputs):
144
+ return torch.add(*inputs)
145
+
146
+
147
+ class einsum(RelPropSimple):
148
+ def __init__(self, equation):
149
+ super().__init__()
150
+ self.equation = equation
151
+
152
+ def forward(self, *operands):
153
+ return torch.einsum(self.equation, *operands)
154
+
155
+
156
+ class IndexSelect(RelProp):
157
+ def forward(self, inputs, dim, indices):
158
+ self.__setattr__("dim", dim)
159
+ self.__setattr__("indices", indices)
160
+
161
+ return torch.index_select(inputs, dim, indices)
162
+
163
+ def relprop(self, R, alpha):
164
+ Z = self.forward(self.X, self.dim, self.indices)
165
+ S = safe_divide(R, Z)
166
+ C = self.gradprop(Z, self.X, S)
167
+
168
+ if torch.is_tensor(self.X) == False:
169
+ outputs = []
170
+ outputs.append(self.X[0] * C[0])
171
+ outputs.append(self.X[1] * C[1])
172
+ else:
173
+ outputs = self.X * (C[0])
174
+ return outputs
175
+
176
+
177
+ class Clone(RelProp):
178
+ def forward(self, input, num):
179
+ self.__setattr__("num", num)
180
+ outputs = []
181
+ for _ in range(num):
182
+ outputs.append(input)
183
+
184
+ return outputs
185
+
186
+ def relprop(self, R, alpha):
187
+ Z = []
188
+ for _ in range(self.num):
189
+ Z.append(self.X)
190
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
191
+ C = self.gradprop(Z, self.X, S)[0]
192
+
193
+ R = self.X * C
194
+
195
+ return R
196
+
197
+
198
+ class Cat(RelProp):
199
+ def forward(self, inputs, dim):
200
+ self.__setattr__("dim", dim)
201
+ return torch.cat(inputs, dim)
202
+
203
+ def relprop(self, R, alpha):
204
+ Z = self.forward(self.X, self.dim)
205
+ S = safe_divide(R, Z)
206
+ C = self.gradprop(Z, self.X, S)
207
+
208
+ outputs = []
209
+ for x, c in zip(self.X, C):
210
+ outputs.append(x * c)
211
+
212
+ return outputs
213
+
214
+
215
+ class Sequential(nn.Sequential):
216
+ def relprop(self, R, alpha):
217
+ for m in reversed(self._modules.values()):
218
+ R = m.relprop(R, alpha)
219
+ return R
220
+
221
+
222
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
223
+ def relprop(self, R, alpha):
224
+ X = self.X
225
+ beta = 1 - alpha
226
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
227
+ (
228
+ self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2)
229
+ + self.eps
230
+ ).pow(0.5)
231
+ )
232
+ Z = X * weight + 1e-9
233
+ S = R / Z
234
+ Ca = S * weight
235
+ R = self.X * (Ca)
236
+ return R
237
+
238
+
239
+ class Linear(nn.Linear, RelProp):
240
+ def relprop(self, R, alpha):
241
+ beta = alpha - 1
242
+ pw = torch.clamp(self.weight, min=0)
243
+ nw = torch.clamp(self.weight, max=0)
244
+ px = torch.clamp(self.X, min=0)
245
+ nx = torch.clamp(self.X, max=0)
246
+
247
+ def f(w1, w2, x1, x2):
248
+ Z1 = F.linear(x1, w1)
249
+ Z2 = F.linear(x2, w2)
250
+ S1 = safe_divide(R, Z1)
251
+ S2 = safe_divide(R, Z2)
252
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
253
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
254
+
255
+ return C1 + C2
256
+
257
+ activator_relevances = f(pw, nw, px, nx)
258
+ inhibitor_relevances = f(nw, pw, px, nx)
259
+
260
+ R = alpha * activator_relevances - beta * inhibitor_relevances
261
+
262
+ return R
263
+
264
+
265
+ class Conv2d(nn.Conv2d, RelProp):
266
+ def gradprop2(self, DY, weight):
267
+ Z = self.forward(self.X)
268
+
269
+ output_padding = self.X.size()[2] - (
270
+ (Z.size()[2] - 1) * self.stride[0]
271
+ - 2 * self.padding[0]
272
+ + self.kernel_size[0]
273
+ )
274
+
275
+ return F.conv_transpose2d(
276
+ DY,
277
+ weight,
278
+ stride=self.stride,
279
+ padding=self.padding,
280
+ output_padding=output_padding,
281
+ )
282
+
283
+ def relprop(self, R, alpha):
284
+ if self.X.shape[1] == 3:
285
+ pw = torch.clamp(self.weight, min=0)
286
+ nw = torch.clamp(self.weight, max=0)
287
+ X = self.X
288
+ L = (
289
+ self.X * 0
290
+ + torch.min(
291
+ torch.min(
292
+ torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
293
+ )[0],
294
+ dim=3,
295
+ keepdim=True,
296
+ )[0]
297
+ )
298
+ H = (
299
+ self.X * 0
300
+ + torch.max(
301
+ torch.max(
302
+ torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
303
+ )[0],
304
+ dim=3,
305
+ keepdim=True,
306
+ )[0]
307
+ )
308
+ Za = (
309
+ torch.conv2d(
310
+ X, self.weight, bias=None, stride=self.stride, padding=self.padding
311
+ )
312
+ - torch.conv2d(
313
+ L, pw, bias=None, stride=self.stride, padding=self.padding
314
+ )
315
+ - torch.conv2d(
316
+ H, nw, bias=None, stride=self.stride, padding=self.padding
317
+ )
318
+ + 1e-9
319
+ )
320
+
321
+ S = R / Za
322
+ C = (
323
+ X * self.gradprop2(S, self.weight)
324
+ - L * self.gradprop2(S, pw)
325
+ - H * self.gradprop2(S, nw)
326
+ )
327
+ R = C
328
+ else:
329
+ beta = alpha - 1
330
+ pw = torch.clamp(self.weight, min=0)
331
+ nw = torch.clamp(self.weight, max=0)
332
+ px = torch.clamp(self.X, min=0)
333
+ nx = torch.clamp(self.X, max=0)
334
+
335
+ def f(w1, w2, x1, x2):
336
+ Z1 = F.conv2d(
337
+ x1, w1, bias=None, stride=self.stride, padding=self.padding
338
+ )
339
+ Z2 = F.conv2d(
340
+ x2, w2, bias=None, stride=self.stride, padding=self.padding
341
+ )
342
+ S1 = safe_divide(R, Z1)
343
+ S2 = safe_divide(R, Z2)
344
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
345
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
346
+ return C1 + C2
347
+
348
+ activator_relevances = f(pw, nw, px, nx)
349
+ inhibitor_relevances = f(nw, pw, px, nx)
350
+
351
+ R = alpha * activator_relevances - beta * inhibitor_relevances
352
+ return R
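safe_divide above is the workhorse of every relprop in this file (and in layers_ours.py below): it behaves like ordinary division away from zero and returns exactly zero where the denominator is zero, which keeps the S = R / Z redistribution steps finite. A tiny check, not part of the diff:

# Sketch only: the zero-denominator guard in safe_divide.
import torch
from BERT_explainability.modules.layers_lrp import safe_divide

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([2.0, 0.0, -4.0])
print(safe_divide(a, b))  # ~tensor([ 0.5000,  0.0000, -0.7500]): plain a/b, but 0 where b == 0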
Transformer-Explainability/BERT_explainability/modules/layers_ours.py ADDED
@@ -0,0 +1,373 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = [
6
+ "forward_hook",
7
+ "Clone",
8
+ "Add",
9
+ "Cat",
10
+ "ReLU",
11
+ "GELU",
12
+ "Dropout",
13
+ "BatchNorm2d",
14
+ "Linear",
15
+ "MaxPool2d",
16
+ "AdaptiveAvgPool2d",
17
+ "AvgPool2d",
18
+ "Conv2d",
19
+ "Sequential",
20
+ "safe_divide",
21
+ "einsum",
22
+ "Softmax",
23
+ "IndexSelect",
24
+ "LayerNorm",
25
+ "AddEye",
26
+ "Tanh",
27
+ "MatMul",
28
+ "Mul",
29
+ ]
30
+
31
+
32
+ def safe_divide(a, b):
33
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
34
+ den = den + den.eq(0).type(den.type()) * 1e-9
35
+ return a / den * b.ne(0).type(b.type())
36
+
37
+
38
+ def forward_hook(self, input, output):
39
+ if type(input[0]) in (list, tuple):
40
+ self.X = []
41
+ for i in input[0]:
42
+ x = i.detach()
43
+ x.requires_grad = True
44
+ self.X.append(x)
45
+ else:
46
+ self.X = input[0].detach()
47
+ self.X.requires_grad = True
48
+
49
+ self.Y = output
50
+
51
+
52
+ def backward_hook(self, grad_input, grad_output):
53
+ self.grad_input = grad_input
54
+ self.grad_output = grad_output
55
+
56
+
57
+ class RelProp(nn.Module):
58
+ def __init__(self):
59
+ super(RelProp, self).__init__()
60
+ # if not self.training:
61
+ self.register_forward_hook(forward_hook)
62
+
63
+ def gradprop(self, Z, X, S):
64
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
65
+ return C
66
+
67
+ def relprop(self, R, alpha):
68
+ return R
69
+
70
+
71
+ class RelPropSimple(RelProp):
72
+ def relprop(self, R, alpha):
73
+ Z = self.forward(self.X)
74
+ S = safe_divide(R, Z)
75
+ C = self.gradprop(Z, self.X, S)
76
+
77
+ if torch.is_tensor(self.X) == False:
78
+ outputs = []
79
+ outputs.append(self.X[0] * C[0])
80
+ outputs.append(self.X[1] * C[1])
81
+ else:
82
+ outputs = self.X * (C[0])
83
+ return outputs
84
+
85
+
86
+ class AddEye(RelPropSimple):
87
+ # input of shape B, C, seq_len, seq_len
88
+ def forward(self, input):
89
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
90
+
91
+
92
+ class ReLU(nn.ReLU, RelProp):
93
+ pass
94
+
95
+
96
+ class GELU(nn.GELU, RelProp):
97
+ pass
98
+
99
+
100
+ class Softmax(nn.Softmax, RelProp):
101
+ pass
102
+
103
+
104
+ class Mul(RelPropSimple):
105
+ def forward(self, inputs):
106
+ return torch.mul(*inputs)
107
+
108
+
109
+ class Tanh(nn.Tanh, RelProp):
110
+ pass
111
+
112
+
113
+ class LayerNorm(nn.LayerNorm, RelProp):
114
+ pass
115
+
116
+
117
+ class Dropout(nn.Dropout, RelProp):
118
+ pass
119
+
120
+
121
+ class MatMul(RelPropSimple):
122
+ def forward(self, inputs):
123
+ return torch.matmul(*inputs)
124
+
125
+
126
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
127
+ pass
128
+
129
+
130
+ class LayerNorm(nn.LayerNorm, RelProp):
131
+ pass
132
+
133
+
134
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
135
+ pass
136
+
137
+
138
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
139
+ pass
140
+
141
+
142
+ class Add(RelPropSimple):
143
+ def forward(self, inputs):
144
+ return torch.add(*inputs)
145
+
146
+ def relprop(self, R, alpha):
147
+ Z = self.forward(self.X)
148
+ S = safe_divide(R, Z)
149
+ C = self.gradprop(Z, self.X, S)
150
+
151
+ a = self.X[0] * C[0]
152
+ b = self.X[1] * C[1]
153
+
154
+ a_sum = a.sum()
155
+ b_sum = b.sum()
156
+
157
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
158
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
159
+
160
+ a = a * safe_divide(a_fact, a.sum())
161
+ b = b * safe_divide(b_fact, b.sum())
162
+
163
+ outputs = [a, b]
164
+
165
+ return outputs
166
+
167
+
168
+ class einsum(RelPropSimple):
169
+ def __init__(self, equation):
170
+ super().__init__()
171
+ self.equation = equation
172
+
173
+ def forward(self, *operands):
174
+ return torch.einsum(self.equation, *operands)
175
+
176
+
177
+ class IndexSelect(RelProp):
178
+ def forward(self, inputs, dim, indices):
179
+ self.__setattr__("dim", dim)
180
+ self.__setattr__("indices", indices)
181
+
182
+ return torch.index_select(inputs, dim, indices)
183
+
184
+ def relprop(self, R, alpha):
185
+ Z = self.forward(self.X, self.dim, self.indices)
186
+ S = safe_divide(R, Z)
187
+ C = self.gradprop(Z, self.X, S)
188
+
189
+ if torch.is_tensor(self.X) == False:
190
+ outputs = []
191
+ outputs.append(self.X[0] * C[0])
192
+ outputs.append(self.X[1] * C[1])
193
+ else:
194
+ outputs = self.X * (C[0])
195
+ return outputs
196
+
197
+
198
+ class Clone(RelProp):
199
+ def forward(self, input, num):
200
+ self.__setattr__("num", num)
201
+ outputs = []
202
+ for _ in range(num):
203
+ outputs.append(input)
204
+
205
+ return outputs
206
+
207
+ def relprop(self, R, alpha):
208
+ Z = []
209
+ for _ in range(self.num):
210
+ Z.append(self.X)
211
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
212
+ C = self.gradprop(Z, self.X, S)[0]
213
+
214
+ R = self.X * C
215
+
216
+ return R
217
+
218
+
219
+ class Cat(RelProp):
220
+ def forward(self, inputs, dim):
221
+ self.__setattr__("dim", dim)
222
+ return torch.cat(inputs, dim)
223
+
224
+ def relprop(self, R, alpha):
225
+ Z = self.forward(self.X, self.dim)
226
+ S = safe_divide(R, Z)
227
+ C = self.gradprop(Z, self.X, S)
228
+
229
+ outputs = []
230
+ for x, c in zip(self.X, C):
231
+ outputs.append(x * c)
232
+
233
+ return outputs
234
+
235
+
236
+ class Sequential(nn.Sequential):
237
+ def relprop(self, R, alpha):
238
+ for m in reversed(self._modules.values()):
239
+ R = m.relprop(R, alpha)
240
+ return R
241
+
242
+
243
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
244
+ def relprop(self, R, alpha):
245
+ X = self.X
246
+ beta = 1 - alpha
247
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
248
+ (
249
+ self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2)
250
+ + self.eps
251
+ ).pow(0.5)
252
+ )
253
+ Z = X * weight + 1e-9
254
+ S = R / Z
255
+ Ca = S * weight
256
+ R = self.X * (Ca)
257
+ return R
258
+
259
+
260
+ class Linear(nn.Linear, RelProp):
261
+ def relprop(self, R, alpha):
262
+ beta = alpha - 1
263
+ pw = torch.clamp(self.weight, min=0)
264
+ nw = torch.clamp(self.weight, max=0)
265
+ px = torch.clamp(self.X, min=0)
266
+ nx = torch.clamp(self.X, max=0)
267
+
268
+ def f(w1, w2, x1, x2):
269
+ Z1 = F.linear(x1, w1)
270
+ Z2 = F.linear(x2, w2)
271
+ S1 = safe_divide(R, Z1 + Z2)
272
+ S2 = safe_divide(R, Z1 + Z2)
273
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
274
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
275
+
276
+ return C1 + C2
277
+
278
+ activator_relevances = f(pw, nw, px, nx)
279
+ inhibitor_relevances = f(nw, pw, px, nx)
280
+
281
+ R = alpha * activator_relevances - beta * inhibitor_relevances
282
+
283
+ return R
284
+
285
+
286
+ class Conv2d(nn.Conv2d, RelProp):
287
+ def gradprop2(self, DY, weight):
288
+ Z = self.forward(self.X)
289
+
290
+ output_padding = self.X.size()[2] - (
291
+ (Z.size()[2] - 1) * self.stride[0]
292
+ - 2 * self.padding[0]
293
+ + self.kernel_size[0]
294
+ )
295
+
296
+ return F.conv_transpose2d(
297
+ DY,
298
+ weight,
299
+ stride=self.stride,
300
+ padding=self.padding,
301
+ output_padding=output_padding,
302
+ )
303
+
304
+ def relprop(self, R, alpha):
305
+ if self.X.shape[1] == 3:
306
+ pw = torch.clamp(self.weight, min=0)
307
+ nw = torch.clamp(self.weight, max=0)
308
+ X = self.X
309
+ L = (
310
+ self.X * 0
311
+ + torch.min(
312
+ torch.min(
313
+ torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
314
+ )[0],
315
+ dim=3,
316
+ keepdim=True,
317
+ )[0]
318
+ )
319
+ H = (
320
+ self.X * 0
321
+ + torch.max(
322
+ torch.max(
323
+ torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True
324
+ )[0],
325
+ dim=3,
326
+ keepdim=True,
327
+ )[0]
328
+ )
329
+ Za = (
330
+ torch.conv2d(
331
+ X, self.weight, bias=None, stride=self.stride, padding=self.padding
332
+ )
333
+ - torch.conv2d(
334
+ L, pw, bias=None, stride=self.stride, padding=self.padding
335
+ )
336
+ - torch.conv2d(
337
+ H, nw, bias=None, stride=self.stride, padding=self.padding
338
+ )
339
+ + 1e-9
340
+ )
341
+
342
+ S = R / Za
343
+ C = (
344
+ X * self.gradprop2(S, self.weight)
345
+ - L * self.gradprop2(S, pw)
346
+ - H * self.gradprop2(S, nw)
347
+ )
348
+ R = C
349
+ else:
350
+ beta = alpha - 1
351
+ pw = torch.clamp(self.weight, min=0)
352
+ nw = torch.clamp(self.weight, max=0)
353
+ px = torch.clamp(self.X, min=0)
354
+ nx = torch.clamp(self.X, max=0)
355
+
356
+ def f(w1, w2, x1, x2):
357
+ Z1 = F.conv2d(
358
+ x1, w1, bias=None, stride=self.stride, padding=self.padding
359
+ )
360
+ Z2 = F.conv2d(
361
+ x2, w2, bias=None, stride=self.stride, padding=self.padding
362
+ )
363
+ S1 = safe_divide(R, Z1)
364
+ S2 = safe_divide(R, Z2)
365
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
366
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
367
+ return C1 + C2
368
+
369
+ activator_relevances = f(pw, nw, px, nx)
370
+ inhibitor_relevances = f(nw, pw, px, nx)
371
+
372
+ R = alpha * activator_relevances - beta * inhibitor_relevances
373
+ return R
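The main behavioural difference from layers_lrp.py is in Add.relprop (and the shared Z1 + Z2 denominator in Linear.relprop): here the two summands' relevances are rescaled so that together they carry exactly the relevance that arrived at the addition, presumably to keep totals stable across the residual connections. A small numerical check, not part of the diff:

# Sketch only: the rescaled Add rule conserves the incoming relevance across the two branches.
import torch
from BERT_explainability.modules.layers_ours import Add

torch.manual_seed(0)
add = Add()
x1 = torch.randn(2, 5, requires_grad=True)
x2 = torch.randn(2, 5, requires_grad=True)
_ = add([x1, x2])  # the forward hook records add.X for relprop

R = torch.rand(2, 5)
R1, R2 = add.relprop(R, alpha=1)
print(R.sum().item(), (R1.sum() + R2.sum()).item())  # the two totals should agree up to the 1e-9 epsilon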
Transformer-Explainability/BERT_params/boolq.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.2,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "False", "True" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.2,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
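This file and the other BERT_params configs below pair an optional embeddings block with per-stage hyperparameter blocks (evidence_identifier / evidence_classifier here, a single classifier block in the *_soft variants, evidence_token_identifier in the token-level ones). A minimal reading sketch, not part of the diff; the path is an example and the repo's pipeline scripts may consume these files differently.

# Sketch only: reading one of the parameter files.
import json

with open("Transformer-Explainability/BERT_params/boolq.json") as f:
    params = json.load(f)

print(params["evidence_classifier"]["classes"])     # ["False", "True"]
print(params["evidence_identifier"]["batch_size"])  # 768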
Transformer-Explainability/BERT_params/boolq_baas.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "start_server": 0,
3
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4
+ "max_length": 512,
5
+ "pooling_strategy": "CLS_TOKEN",
6
+ "evidence_identifier": {
7
+ "batch_size": 64,
8
+ "epochs": 3,
9
+ "patience": 10,
10
+ "lr": 1e-3,
11
+ "max_grad_norm": 1.0,
12
+ "sampling_method": "random",
13
+ "sampling_ratio": 1.0
14
+ },
15
+ "evidence_classifier": {
16
+ "classes": [ "False", "True" ],
17
+ "batch_size": 64,
18
+ "epochs": 3,
19
+ "patience": 10,
20
+ "lr": 1e-3,
21
+ "max_grad_norm": 1.0,
22
+ "sampling_method": "everything"
23
+ }
24
+ }
25
+
26
+
Transformer-Explainability/BERT_params/boolq_bert.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 10,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 50,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "False",
21
+ "True"
22
+ ],
23
+ "batch_size": 10,
24
+ "warmup_steps": 50,
25
+ "epochs": 10,
26
+ "patience": 10,
27
+ "lr": 1e-05,
28
+ "max_grad_norm": 1,
29
+ "sampling_method": "everything",
30
+ "use_half_precision": 0
31
+ }
32
+ }
Transformer-Explainability/BERT_params/boolq_soft.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "False", "True" ],
8
+ "has_query": 1,
9
+ "hidden_size": 32,
10
+ "mlp_size": 128,
11
+ "dropout": 0.2,
12
+ "batch_size": 16,
13
+ "epochs": 50,
14
+ "attention_epochs": 50,
15
+ "patience": 10,
16
+ "lr": 1e-3,
17
+ "dropout": 0.2,
18
+ "k_fraction": 0.07,
19
+ "threshold": 0.1
20
+ }
21
+ }
Transformer-Explainability/BERT_params/cose_bert.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 0,
6
+ "use_evidence_token_identifier": 1,
7
+ "evidence_token_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 0.5,
14
+ "sampling_method": "everything",
15
+ "use_half_precision": 0,
16
+ "cose_data_hack": 1
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [ "false", "true"],
20
+ "batch_size": 32,
21
+ "warmup_steps": 10,
22
+ "epochs": 10,
23
+ "patience": 10,
24
+ "lr": 1e-05,
25
+ "max_grad_norm": 0.5,
26
+ "sampling_method": "everything",
27
+ "use_half_precision": 0,
28
+ "cose_data_hack": 1
29
+ }
30
+ }
Transformer-Explainability/BERT_params/cose_multiclass.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 50,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "A",
21
+ "B",
22
+ "C",
23
+ "D",
24
+ "E"
25
+ ],
26
+ "batch_size": 10,
27
+ "warmup_steps": 50,
28
+ "epochs": 10,
29
+ "patience": 10,
30
+ "lr": 1e-05,
31
+ "max_grad_norm": 1,
32
+ "sampling_method": "everything",
33
+ "use_half_precision": 0
34
+ }
35
+ }
Transformer-Explainability/BERT_params/esnli_bert.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 0,
6
+ "use_evidence_token_identifier": 1,
7
+ "evidence_token_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "everything",
15
+ "use_half_precision": 0
16
+ },
17
+ "evidence_classifier": {
18
+ "classes": [ "contradiction", "neutral", "entailment" ],
19
+ "batch_size": 32,
20
+ "warmup_steps": 10,
21
+ "epochs": 10,
22
+ "patience": 10,
23
+ "lr": 1e-05,
24
+ "max_grad_norm": 1,
25
+ "sampling_method": "everything",
26
+ "use_half_precision": 0
27
+ }
28
+ }
Transformer-Explainability/BERT_params/evidence_inference.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/PubMed-w2v.bin",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
Transformer-Explainability/BERT_params/evidence_inference_bert.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "allenai/scibert_scivocab_uncased",
4
+ "bert_dir": "allenai/scibert_scivocab_uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 10,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "use_half_precision": 0,
16
+ "sampling_ratio": 1
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "significantly decreased",
21
+ "no significant difference",
22
+ "significantly increased"
23
+ ],
24
+ "batch_size": 10,
25
+ "warmup_steps": 10,
26
+ "epochs": 10,
27
+ "patience": 10,
28
+ "lr": 1e-05,
29
+ "max_grad_norm": 1,
30
+ "sampling_method": "everything",
31
+ "use_half_precision": 0
32
+ }
33
+ }
Transformer-Explainability/BERT_params/evidence_inference_soft.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/PubMed-w2v.bin",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
8
+ "use_token_selection": 1,
9
+ "has_query": 1,
10
+ "hidden_size": 32,
11
+ "mlp_size": 128,
12
+ "dropout": 0.2,
13
+ "batch_size": 16,
14
+ "epochs": 50,
15
+ "attention_epochs": 0,
16
+ "patience": 10,
17
+ "lr": 1e-3,
18
+ "dropout": 0.2,
19
+ "k_fraction": 0.013,
20
+ "threshold": 0.1
21
+ }
22
+ }
Transformer-Explainability/BERT_params/fever.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "SUPPORTS", "REFUTES" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-5,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
Transformer-Explainability/BERT_params/fever_baas.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "start_server": 0,
3
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4
+ "max_length": 512,
5
+ "pooling_strategy": "CLS_TOKEN",
6
+ "evidence_identifier": {
7
+ "batch_size": 64,
8
+ "epochs": 3,
9
+ "patience": 10,
10
+ "lr": 1e-3,
11
+ "max_grad_norm": 1.0,
12
+ "sampling_method": "random",
13
+ "sampling_ratio": 1.0
14
+ },
15
+ "evidence_classifier": {
16
+ "classes": [ "SUPPORTS", "REFUTES" ],
17
+ "batch_size": 64,
18
+ "epochs": 3,
19
+ "patience": 10,
20
+ "lr": 1e-3,
21
+ "max_grad_norm": 1.0,
22
+ "sampling_method": "everything"
23
+ }
24
+ }
25
+
Transformer-Explainability/BERT_params/fever_bert.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 16,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1.0,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1.0,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "SUPPORTS",
21
+ "REFUTES"
22
+ ],
23
+ "batch_size": 10,
24
+ "warmup_steps": 10,
25
+ "epochs": 10,
26
+ "patience": 10,
27
+ "lr": 1e-05,
28
+ "max_grad_norm": 1.0,
29
+ "sampling_method": "everything",
30
+ "use_half_precision": 0
31
+ }
32
+ }
Transformer-Explainability/BERT_params/fever_soft.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "SUPPORTS", "REFUTES" ],
8
+ "has_query": 1,
9
+ "hidden_size": 32,
10
+ "mlp_size": 128,
11
+ "dropout": 0.2,
12
+ "batch_size": 128,
13
+ "epochs": 50,
14
+ "attention_epochs": 50,
15
+ "patience": 10,
16
+ "lr": 1e-3,
17
+ "dropout": 0.2,
18
+ "k_fraction": 0.07,
19
+ "threshold": 0.1
20
+ }
21
+ }
Transformer-Explainability/BERT_params/movies.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-4,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "NEG", "POS" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
Transformer-Explainability/BERT_params/movies_baas.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "start_server": 0,
3
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4
+ "max_length": 512,
5
+ "pooling_strategy": "CLS_TOKEN",
6
+ "evidence_identifier": {
7
+ "batch_size": 64,
8
+ "epochs": 3,
9
+ "patience": 10,
10
+ "lr": 1e-3,
11
+ "max_grad_norm": 1.0,
12
+ "sampling_method": "random",
13
+ "sampling_ratio": 1.0
14
+ },
15
+ "evidence_classifier": {
16
+ "classes": [ "NEG", "POS" ],
17
+ "batch_size": 64,
18
+ "epochs": 3,
19
+ "patience": 10,
20
+ "lr": 1e-3,
21
+ "max_grad_norm": 1.0,
22
+ "sampling_method": "everything"
23
+ }
24
+ }
25
+
26
+
Transformer-Explainability/BERT_params/movies_bert.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 16,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 50,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "NEG",
21
+ "POS"
22
+ ],
23
+ "batch_size": 10,
24
+ "warmup_steps": 50,
25
+ "epochs": 10,
26
+ "patience": 10,
27
+ "lr": 1e-05,
28
+ "max_grad_norm": 1,
29
+ "sampling_method": "everything",
30
+ "use_half_precision": 0
31
+ }
32
+ }
Transformer-Explainability/BERT_params/movies_soft.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "NEG", "POS" ],
8
+ "has_query": 0,
9
+ "hidden_size": 32,
10
+ "mlp_size": 128,
11
+ "dropout": 0.2,
12
+ "batch_size": 16,
13
+ "epochs": 50,
14
+ "attention_epochs": 50,
15
+ "patience": 10,
16
+ "lr": 1e-3,
17
+ "dropout": 0.2,
18
+ "k_fraction": 0.07,
19
+ "threshold": 0.1
20
+ }
21
+ }
Transformer-Explainability/BERT_params/multirc.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "False", "True" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
Transformer-Explainability/BERT_params/multirc_baas.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "start_server": 0,
3
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4
+ "max_length": 512,
5
+ "pooling_strategy": "CLS_TOKEN",
6
+ "evidence_identifier": {
7
+ "batch_size": 64,
8
+ "epochs": 3,
9
+ "patience": 10,
10
+ "lr": 1e-3,
11
+ "max_grad_norm": 1.0,
12
+ "sampling_method": "random",
13
+ "sampling_ratio": 1.0
14
+ },
15
+ "evidence_classifier": {
16
+ "classes": [ "False", "True" ],
17
+ "batch_size": 64,
18
+ "epochs": 3,
19
+ "patience": 10,
20
+ "lr": 1e-3,
21
+ "max_grad_norm": 1.0,
22
+ "sampling_method": "everything"
23
+ }
24
+ }
25
+
26
+
Transformer-Explainability/BERT_params/multirc_bert.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 50,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "False",
21
+ "True"
22
+ ],
23
+ "batch_size": 32,
24
+ "warmup_steps": 50,
25
+ "epochs": 10,
26
+ "patience": 10,
27
+ "lr": 1e-05,
28
+ "max_grad_norm": 1,
29
+ "sampling_method": "everything",
30
+ "use_half_precision": 0
31
+ }
32
+ }
Transformer-Explainability/BERT_params/multirc_soft.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "False", "True" ],
8
+ "has_query": 1,
9
+ "hidden_size": 32,
10
+ "mlp_size": 128,
11
+ "dropout": 0.2,
12
+ "batch_size": 16,
13
+ "epochs": 50,
14
+ "attention_epochs": 50,
15
+ "patience": 10,
16
+ "lr": 1e-3,
17
+ "dropout": 0.2,
18
+ "k_fraction": 0.07,
19
+ "threshold": 0.1
20
+ }
21
+ }
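Note: the *_bert.json configs above are the ones consumed by the BERT pipeline code added later in this commit (see initialize_models in bert_pipeline.py). A minimal sketch of loading one of them, assuming the repository root is the working directory; the fields printed are exactly the ones the pipeline reads:

import json

with open("Transformer-Explainability/BERT_params/multirc_bert.json") as fp:
    params = json.load(fp)

# The same fields initialize_models() reads further down in this commit.
evidence_classes = {cls: i for i, cls in enumerate(params["evidence_classifier"]["classes"])}
print(params["bert_dir"], params["max_length"], evidence_classes)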
Transformer-Explainability/BERT_rationale_benchmark/__init__.py ADDED
File without changes
Transformer-Explainability/BERT_rationale_benchmark/metrics.py ADDED
@@ -0,0 +1,1007 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import pprint
6
+ from collections import Counter, defaultdict, namedtuple
7
+ from dataclasses import dataclass
8
+ from itertools import chain
9
+ from typing import Any, Callable, Dict, List, Set, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ from BERT_rationale_benchmark.utils import (Annotation, Evidence,
14
+ annotations_from_jsonl,
15
+ load_documents,
16
+ load_flattened_documents,
17
+ load_jsonl)
18
+ from scipy.stats import entropy
19
+ from sklearn.metrics import (accuracy_score, auc, average_precision_score,
20
+ classification_report, precision_recall_curve,
21
+ roc_auc_score)
22
+
23
+ logging.basicConfig(
24
+ level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
25
+ )
26
+
27
+
28
+ # start_token is inclusive, end_token is exclusive
29
+ @dataclass(eq=True, frozen=True)
30
+ class Rationale:
31
+ ann_id: str
32
+ docid: str
33
+ start_token: int
34
+ end_token: int
35
+
36
+ def to_token_level(self) -> List["Rationale"]:
37
+ ret = []
38
+ for t in range(self.start_token, self.end_token):
39
+ ret.append(Rationale(self.ann_id, self.docid, t, t + 1))
40
+ return ret
41
+
42
+ @classmethod
43
+ def from_annotation(cls, ann: Annotation) -> List["Rationale"]:
44
+ ret = []
45
+ for ev_group in ann.evidences:
46
+ for ev in ev_group:
47
+ ret.append(
48
+ Rationale(ann.annotation_id, ev.docid, ev.start_token, ev.end_token)
49
+ )
50
+ return ret
51
+
52
+ @classmethod
53
+ def from_instance(cls, inst: dict) -> List["Rationale"]:
54
+ ret = []
55
+ for rat in inst["rationales"]:
56
+ for pred in rat.get("hard_rationale_predictions", []):
57
+ ret.append(
58
+ Rationale(
59
+ inst["annotation_id"],
60
+ rat["docid"],
61
+ pred["start_token"],
62
+ pred["end_token"],
63
+ )
64
+ )
65
+ return ret
66
+
67
+
68
+ @dataclass(eq=True, frozen=True)
69
+ class PositionScoredDocument:
70
+ ann_id: str
71
+ docid: str
72
+ scores: Tuple[float]
73
+ truths: Tuple[bool]
74
+
75
+ @classmethod
76
+ def from_results(
77
+ cls,
78
+ instances: List[dict],
79
+ annotations: List[Annotation],
80
+ docs: Dict[str, List[Any]],
81
+ use_tokens: bool = True,
82
+ ) -> List["PositionScoredDocument"]:
83
+ """Creates a paired list of annotation ids/docids/predictions/truth values"""
84
+ key_to_annotation = dict()
85
+ for ann in annotations:
86
+ for ev in chain.from_iterable(ann.evidences):
87
+ key = (ann.annotation_id, ev.docid)
88
+ if key not in key_to_annotation:
89
+ key_to_annotation[key] = [False for _ in docs[ev.docid]]
90
+ if use_tokens:
91
+ start, end = ev.start_token, ev.end_token
92
+ else:
93
+ start, end = ev.start_sentence, ev.end_sentence
94
+ for t in range(start, end):
95
+ key_to_annotation[key][t] = True
96
+ ret = []
97
+ if use_tokens:
98
+ field = "soft_rationale_predictions"
99
+ else:
100
+ field = "soft_sentence_predictions"
101
+ for inst in instances:
102
+ for rat in inst["rationales"]:
103
+ docid = rat["docid"]
104
+ scores = rat[field]
105
+ key = (inst["annotation_id"], docid)
106
+ assert len(scores) == len(docs[docid])
107
+ if key in key_to_annotation:
108
+ assert len(scores) == len(key_to_annotation[key])
109
+ else:
110
+ # In case the model makes a prediction on document(s) for which ground truth evidence is not present
111
+ key_to_annotation[key] = [False for _ in docs[docid]]
112
+ ret.append(
113
+ PositionScoredDocument(
114
+ inst["annotation_id"],
115
+ docid,
116
+ tuple(scores),
117
+ tuple(key_to_annotation[key]),
118
+ )
119
+ )
120
+ return ret
121
+
122
+
123
+ def _f1(_p, _r):
124
+ if _p == 0 or _r == 0:
125
+ return 0
126
+ return 2 * _p * _r / (_p + _r)
127
+
128
+
129
+ def _keyed_rationale_from_list(
130
+ rats: List[Rationale],
131
+ ) -> Dict[Tuple[str, str], Rationale]:
132
+ ret = defaultdict(set)
133
+ for r in rats:
134
+ ret[(r.ann_id, r.docid)].add(r)
135
+ return ret
136
+
137
+
138
+ def partial_match_score(
139
+ truth: List[Rationale], pred: List[Rationale], thresholds: List[float]
140
+ ) -> List[Dict[str, Any]]:
141
+ """Computes a partial match F1
142
+
143
+ Computes an instance-level (annotation) micro- and macro-averaged F1 score.
144
+ True Positives are computed by using intersection-over-union and
145
+ thresholding the resulting intersection-over-union fraction.
146
+
147
+ Micro-average results are computed by ignoring instance level distinctions
148
+ in the TP calculation (and recall, and precision, and finally the F1 of
149
+ those numbers). Macro-average results are computed first by measuring
150
+ instance (annotation + document) precisions and recalls, averaging those,
151
+ and finally computing an F1 of the resulting average.
152
+ """
153
+
154
+ ann_to_rat = _keyed_rationale_from_list(truth)
155
+ pred_to_rat = _keyed_rationale_from_list(pred)
156
+
157
+ num_classifications = {k: len(v) for k, v in pred_to_rat.items()}
158
+ num_truth = {k: len(v) for k, v in ann_to_rat.items()}
159
+ ious = defaultdict(dict)
160
+ for k in set(ann_to_rat.keys()) | set(pred_to_rat.keys()):
161
+ for p in pred_to_rat.get(k, []):
162
+ best_iou = 0.0
163
+ for t in ann_to_rat.get(k, []):
164
+ num = len(
165
+ set(range(p.start_token, p.end_token))
166
+ & set(range(t.start_token, t.end_token))
167
+ )
168
+ denom = len(
169
+ set(range(p.start_token, p.end_token))
170
+ | set(range(t.start_token, t.end_token))
171
+ )
172
+ iou = 0 if denom == 0 else num / denom
173
+ if iou > best_iou:
174
+ best_iou = iou
175
+ ious[k][p] = best_iou
176
+ scores = []
177
+ for threshold in thresholds:
178
+ threshold_tps = dict()
179
+ for k, vs in ious.items():
180
+ threshold_tps[k] = sum(int(x >= threshold) for x in vs.values())
181
+ micro_r = (
182
+ sum(threshold_tps.values()) / sum(num_truth.values())
183
+ if sum(num_truth.values()) > 0
184
+ else 0
185
+ )
186
+ micro_p = (
187
+ sum(threshold_tps.values()) / sum(num_classifications.values())
188
+ if sum(num_classifications.values()) > 0
189
+ else 0
190
+ )
191
+ micro_f1 = _f1(micro_r, micro_p)
192
+ macro_rs = list(
193
+ threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_truth.items()
194
+ )
195
+ macro_ps = list(
196
+ threshold_tps.get(k, 0.0) / n if n > 0 else 0
197
+ for k, n in num_classifications.items()
198
+ )
199
+ macro_r = sum(macro_rs) / len(macro_rs) if len(macro_rs) > 0 else 0
200
+ macro_p = sum(macro_ps) / len(macro_ps) if len(macro_ps) > 0 else 0
201
+ macro_f1 = _f1(macro_r, macro_p)
202
+ scores.append(
203
+ {
204
+ "threshold": threshold,
205
+ "micro": {"p": micro_p, "r": micro_r, "f1": micro_f1},
206
+ "macro": {"p": macro_p, "r": macro_r, "f1": macro_f1},
207
+ }
208
+ )
209
+ return scores
210
+
211
+
212
+ def score_hard_rationale_predictions(
213
+ truth: List[Rationale], pred: List[Rationale]
214
+ ) -> Dict[str, Dict[str, float]]:
215
+ """Computes instance (annotation)-level micro/macro averaged F1s"""
216
+ scores = dict()
217
+ truth = set(truth)
218
+ pred = set(pred)
219
+ micro_prec = len(truth & pred) / len(pred)
220
+ micro_rec = len(truth & pred) / len(truth)
221
+ micro_f1 = _f1(micro_prec, micro_rec)
222
+ scores["instance_micro"] = {
223
+ "p": micro_prec,
224
+ "r": micro_rec,
225
+ "f1": micro_f1,
226
+ }
227
+
228
+ ann_to_rat = _keyed_rationale_from_list(truth)
229
+ pred_to_rat = _keyed_rationale_from_list(pred)
230
+ instances_to_scores = dict()
231
+ for k in set(ann_to_rat.keys()) | (pred_to_rat.keys()):
232
+ if len(pred_to_rat.get(k, set())) > 0:
233
+ instance_prec = len(
234
+ ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())
235
+ ) / len(pred_to_rat[k])
236
+ else:
237
+ instance_prec = 0
238
+ if len(ann_to_rat.get(k, set())) > 0:
239
+ instance_rec = len(
240
+ ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())
241
+ ) / len(ann_to_rat[k])
242
+ else:
243
+ instance_rec = 0
244
+ instance_f1 = _f1(instance_prec, instance_rec)
245
+ instances_to_scores[k] = {
246
+ "p": instance_prec,
247
+ "r": instance_rec,
248
+ "f1": instance_f1,
249
+ }
250
+ # these are calculated as sklearn would
251
+ macro_prec = sum(instance["p"] for instance in instances_to_scores.values()) / len(
252
+ instances_to_scores
253
+ )
254
+ macro_rec = sum(instance["r"] for instance in instances_to_scores.values()) / len(
255
+ instances_to_scores
256
+ )
257
+ macro_f1 = sum(instance["f1"] for instance in instances_to_scores.values()) / len(
258
+ instances_to_scores
259
+ )
260
+
261
+ f1_scores = [instance["f1"] for instance in instances_to_scores.values()]
262
+ print(macro_f1, np.argsort(f1_scores)[::-1])
263
+
264
+ scores["instance_macro"] = {
265
+ "p": macro_prec,
266
+ "r": macro_rec,
267
+ "f1": macro_f1,
268
+ }
269
+ return scores
270
+
271
+
272
+ def _auprc(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]]) -> float:
273
+ if len(preds) == 0:
274
+ return 0.0
275
+ assert len(truth.keys() & preds.keys()) == len(truth.keys())
276
+ aucs = []
277
+ for k, true in truth.items():
278
+ pred = preds[k]
279
+ true = [int(t) for t in true]
280
+ precision, recall, _ = precision_recall_curve(true, pred)
281
+ aucs.append(auc(recall, precision))
282
+ return np.average(aucs)
283
+
284
+
285
+ def _score_aggregator(
286
+ truth: Dict[Any, List[bool]],
287
+ preds: Dict[Any, List[float]],
288
+ score_function: Callable[[List[float], List[float]], float],
289
+ discard_single_class_answers: bool,
290
+ ) -> float:
291
+ if len(preds) == 0:
292
+ return 0.0
293
+ assert len(truth.keys() & preds.keys()) == len(truth.keys())
294
+ scores = []
295
+ for k, true in truth.items():
296
+ pred = preds[k]
297
+ if (all(true) or all(not x for x in true)) and discard_single_class_answers:
298
+ continue
299
+ true = [int(t) for t in true]
300
+ scores.append(score_function(true, pred))
301
+ return np.average(scores)
302
+
303
+
304
+ def score_soft_tokens(paired_scores: List[PositionScoredDocument]) -> Dict[str, float]:
305
+ truth = {(ps.ann_id, ps.docid): ps.truths for ps in paired_scores}
306
+ pred = {(ps.ann_id, ps.docid): ps.scores for ps in paired_scores}
307
+ auprc_score = _auprc(truth, pred)
308
+ ap = _score_aggregator(truth, pred, average_precision_score, True)
309
+ roc_auc = _score_aggregator(truth, pred, roc_auc_score, True)
310
+
311
+ return {
312
+ "auprc": auprc_score,
313
+ "average_precision": ap,
314
+ "roc_auc_score": roc_auc,
315
+ }
316
+
317
+
318
+ def _instances_aopc(
319
+ instances: List[dict], thresholds: List[float], key: str
320
+ ) -> Tuple[float, List[float]]:
321
+ dataset_scores = []
322
+ for inst in instances:
323
+ kls = inst["classification"]
324
+ beta_0 = inst["classification_scores"][kls]
325
+ instance_scores = []
326
+ for score in filter(
327
+ lambda x: x["threshold"] in thresholds,
328
+ sorted(inst["thresholded_scores"], key=lambda x: x["threshold"]),
329
+ ):
330
+ beta_k = score[key][kls]
331
+ delta = beta_0 - beta_k
332
+ instance_scores.append(delta)
333
+ assert len(instance_scores) == len(thresholds)
334
+ dataset_scores.append(instance_scores)
335
+ dataset_scores = np.array(dataset_scores)
336
+ # a careful reading of Samek, et al. "Evaluating the Visualization of What a Deep Neural Network Has Learned"
337
+ # and some algebra will show the reader that we can average in any of several ways and get the same result:
338
+ # over a flattened array, within an instance and then between instances, or over instances (by position) and
339
+ # then across them.
340
+ final_score = np.average(dataset_scores)
341
+ position_scores = np.average(dataset_scores, axis=0).tolist()
342
+
343
+ return final_score, position_scores
344
+
345
+
346
+ def compute_aopc_scores(instances: List[dict], aopc_thresholds: List[float]):
347
+ if aopc_thresholds is None:
348
+ aopc_thresholds = sorted(
349
+ set(
350
+ chain.from_iterable(
351
+ [x["threshold"] for x in y["thresholded_scores"]] for y in instances
352
+ )
353
+ )
354
+ )
355
+ aopc_comprehensiveness_score, aopc_comprehensiveness_points = _instances_aopc(
356
+ instances, aopc_thresholds, "comprehensiveness_classification_scores"
357
+ )
358
+ aopc_sufficiency_score, aopc_sufficiency_points = _instances_aopc(
359
+ instances, aopc_thresholds, "sufficiency_classification_scores"
360
+ )
361
+ return (
362
+ aopc_thresholds,
363
+ aopc_comprehensiveness_score,
364
+ aopc_comprehensiveness_points,
365
+ aopc_sufficiency_score,
366
+ aopc_sufficiency_points,
367
+ )
368
+
369
+
370
+ def score_classifications(
371
+ instances: List[dict],
372
+ annotations: List[Annotation],
373
+ docs: Dict[str, List[str]],
374
+ aopc_thresholds: List[float],
375
+ ) -> Dict[str, float]:
376
+ def compute_kl(cls_scores_, faith_scores_):
377
+ keys = list(cls_scores_.keys())
378
+ cls_scores_ = [cls_scores_[k] for k in keys]
379
+ faith_scores_ = [faith_scores_[k] for k in keys]
380
+ return entropy(faith_scores_, cls_scores_)
381
+
382
+ labels = list(set(x.classification for x in annotations))
383
+ label_to_int = {l: i for i, l in enumerate(labels)}
384
+ key_to_instances = {inst["annotation_id"]: inst for inst in instances}
385
+ truth = []
386
+ predicted = []
387
+ for ann in annotations:
388
+ truth.append(label_to_int[ann.classification])
389
+ inst = key_to_instances[ann.annotation_id]
390
+ predicted.append(label_to_int[inst["classification"]])
391
+ classification_scores = classification_report(
392
+ truth, predicted, output_dict=True, target_names=labels, digits=3
393
+ )
394
+ accuracy = accuracy_score(truth, predicted)
395
+ if "comprehensiveness_classification_scores" in instances[0]:
396
+ comprehensiveness_scores = [
397
+ x["classification_scores"][x["classification"]]
398
+ - x["comprehensiveness_classification_scores"][x["classification"]]
399
+ for x in instances
400
+ ]
401
+ comprehensiveness_score = np.average(comprehensiveness_scores)
402
+ else:
403
+ comprehensiveness_score = None
404
+ comprehensiveness_scores = None
405
+
406
+ if "sufficiency_classification_scores" in instances[0]:
407
+ sufficiency_scores = [
408
+ x["classification_scores"][x["classification"]]
409
+ - x["sufficiency_classification_scores"][x["classification"]]
410
+ for x in instances
411
+ ]
412
+ sufficiency_score = np.average(sufficiency_scores)
413
+ else:
414
+ sufficiency_score = None
415
+ sufficiency_scores = None
416
+
417
+ if "comprehensiveness_classification_scores" in instances[0]:
418
+ comprehensiveness_entropies = [
419
+ entropy(list(x["classification_scores"].values()))
420
+ - entropy(list(x["comprehensiveness_classification_scores"].values()))
421
+ for x in instances
422
+ ]
423
+ comprehensiveness_entropy = np.average(comprehensiveness_entropies)
424
+ comprehensiveness_kl = np.average(
425
+ list(
426
+ compute_kl(
427
+ x["classification_scores"],
428
+ x["comprehensiveness_classification_scores"],
429
+ )
430
+ for x in instances
431
+ )
432
+ )
433
+ else:
434
+ comprehensiveness_entropies = None
435
+ comprehensiveness_kl = None
436
+ comprehensiveness_entropy = None
437
+
438
+ if "sufficiency_classification_scores" in instances[0]:
439
+ sufficiency_entropies = [
440
+ entropy(list(x["classification_scores"].values()))
441
+ - entropy(list(x["sufficiency_classification_scores"].values()))
442
+ for x in instances
443
+ ]
444
+ sufficiency_entropy = np.average(sufficiency_entropies)
445
+ sufficiency_kl = np.average(
446
+ list(
447
+ compute_kl(
448
+ x["classification_scores"], x["sufficiency_classification_scores"]
449
+ )
450
+ for x in instances
451
+ )
452
+ )
453
+ else:
454
+ sufficiency_entropies = None
455
+ sufficiency_kl = None
456
+ sufficiency_entropy = None
457
+
458
+ if "thresholded_scores" in instances[0]:
459
+ (
460
+ aopc_thresholds,
461
+ aopc_comprehensiveness_score,
462
+ aopc_comprehensiveness_points,
463
+ aopc_sufficiency_score,
464
+ aopc_sufficiency_points,
465
+ ) = compute_aopc_scores(instances, aopc_thresholds)
466
+ else:
467
+ (
468
+ aopc_thresholds,
469
+ aopc_comprehensiveness_score,
470
+ aopc_comprehensiveness_points,
471
+ aopc_sufficiency_score,
472
+ aopc_sufficiency_points,
473
+ ) = (None, None, None, None, None)
474
+ if "tokens_to_flip" in instances[0]:
475
+ token_percentages = []
476
+ for ann in annotations:
477
+ # in practice, this is of size 1 for everything except e-snli
478
+ docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
479
+ inst = key_to_instances[ann.annotation_id]
480
+ tokens = inst["tokens_to_flip"]
481
+ doc_lengths = sum(len(docs[d]) for d in docids)
482
+ token_percentages.append(tokens / doc_lengths)
483
+ token_percentages = np.average(token_percentages)
484
+ else:
485
+ token_percentages = None
486
+
487
+ return {
488
+ "accuracy": accuracy,
489
+ "prf": classification_scores,
490
+ "comprehensiveness": comprehensiveness_score,
491
+ "sufficiency": sufficiency_score,
492
+ "comprehensiveness_entropy": comprehensiveness_entropy,
493
+ "comprehensiveness_kl": comprehensiveness_kl,
494
+ "sufficiency_entropy": sufficiency_entropy,
495
+ "sufficiency_kl": sufficiency_kl,
496
+ "aopc_thresholds": aopc_thresholds,
497
+ "comprehensiveness_aopc": aopc_comprehensiveness_score,
498
+ "comprehensiveness_aopc_points": aopc_comprehensiveness_points,
499
+ "sufficiency_aopc": aopc_sufficiency_score,
500
+ "sufficiency_aopc_points": aopc_sufficiency_points,
501
+ }
502
+
503
+
504
+ def verify_instance(instance: dict, docs: Dict[str, list], thresholds: Set[float]):
505
+ error = False
506
+ docids = []
507
+ # verify the internal structure of these instances is correct:
508
+ # * hard predictions are present
509
+ # * start and end tokens are valid
510
+ # * soft rationale predictions, if present, must have the same document length
511
+
512
+ for rat in instance["rationales"]:
513
+ docid = rat["docid"]
514
+ if docid not in docs:
515
+ error = True
516
+ logging.info(
517
+ f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} could not be found as a preprocessed document! Gave up on additional processing.'
518
+ )
519
+ continue
520
+ doc_length = len(docs[docid])
521
+ for h1 in rat.get("hard_rationale_predictions", []):
522
+ # verify that each token is valid
523
+ # verify that no annotations overlap
524
+ for h2 in rat.get("hard_rationale_predictions", []):
525
+ if h1 == h2:
526
+ continue
527
+ if (
528
+ len(
529
+ set(range(h1["start_token"], h1["end_token"]))
530
+ & set(range(h2["start_token"], h2["end_token"]))
531
+ )
532
+ > 0
533
+ ):
534
+ logging.info(
535
+ f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} {h1} and {h2} overlap!'
536
+ )
537
+ error = True
538
+ if h1["start_token"] > doc_length:
539
+ logging.info(
540
+ f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}'
541
+ )
542
+ error = True
543
+ if h1["end_token"] > doc_length:
544
+ logging.info(
545
+ f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}'
546
+ )
547
+ error = True
548
+ # length check for soft rationale
549
+ # note that either flattened_documents or sentence-broken documents must be passed in depending on result
550
+ soft_rationale_predictions = rat.get("soft_rationale_predictions", [])
551
+ if (
552
+ len(soft_rationale_predictions) > 0
553
+ and len(soft_rationale_predictions) != doc_length
554
+ ):
555
+ logging.info(
556
+ f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} expected classifications for {doc_length} tokens but have them for {len(soft_rationale_predictions)} tokens instead!'
557
+ )
558
+ error = True
559
+
560
+ # count that one appears per-document
561
+ docids = Counter(docids)
562
+ for docid, count in docids.items():
563
+ if count > 1:
564
+ error = True
565
+ logging.info(
566
+ f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} appears {count} times, may only appear once!'
567
+ )
568
+
569
+ classification = instance.get("classification", "")
570
+ if not isinstance(classification, str):
571
+ logging.info(
572
+ f'Error! For instance annotation={instance["annotation_id"]}, classification field {classification} is not a string!'
573
+ )
574
+ error = True
575
+ classification_scores = instance.get("classification_scores", dict())
576
+ if not isinstance(classification_scores, dict):
577
+ logging.info(
578
+ f'Error! For instance annotation={instance["annotation_id"]}, classification_scores field {classification_scores} is not a dict!'
579
+ )
580
+ error = True
581
+ comprehensiveness_classification_scores = instance.get(
582
+ "comprehensiveness_classification_scores", dict()
583
+ )
584
+ if not isinstance(comprehensiveness_classification_scores, dict):
585
+ logging.info(
586
+ f'Error! For instance annotation={instance["annotation_id"]}, comprehensiveness_classification_scores field {comprehensiveness_classification_scores} is not a dict!'
587
+ )
588
+ error = True
589
+ sufficiency_classification_scores = instance.get(
590
+ "sufficiency_classification_scores", dict()
591
+ )
592
+ if not isinstance(sufficiency_classification_scores, dict):
593
+ logging.info(
594
+ f'Error! For instance annotation={instance["annotation_id"]}, sufficiency_classification_scores field {sufficiency_classification_scores} is not a dict!'
595
+ )
596
+ error = True
597
+ if ("classification" in instance) != ("classification_scores" in instance):
598
+ logging.info(
599
+ f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide classification scores!'
600
+ )
601
+ error = True
602
+ if ("comprehensiveness_classification_scores" in instance) and not (
603
+ "classification" in instance
604
+ ):
605
+ logging.info(
606
+ f'Error! For instance annotation={instance["annotation_id"]}, when providing a comprehensiveness_classification_score, you must also provide a classification!'
607
+ )
608
+ error = True
609
+ if ("sufficiency_classification_scores" in instance) and not (
610
+ "classification_scores" in instance
611
+ ):
612
+ logging.info(
613
+ f'Error! For instance annotation={instance["annotation_id"]}, when providing a sufficiency_classification_score, you must also provide a classification score!'
614
+ )
615
+ error = True
616
+ if "thresholded_scores" in instance:
617
+ instance_thresholds = set(
618
+ x["threshold"] for x in instance["thresholded_scores"]
619
+ )
620
+ if instance_thresholds != thresholds:
621
+ error = True
622
+ logging.info(
623
+ f'Error: {instance["thresholded_scores"]} has thresholds that differ from previous thresholds: {thresholds}'
624
+ )
625
+ if (
626
+ "comprehensiveness_classification_scores" not in instance
627
+ or "sufficiency_classification_scores" not in instance
628
+ or "classification" not in instance
629
+ or "classification_scores" not in instance
630
+ ):
631
+ error = True
632
+ logging.info(
633
+ f"Error: {instance} must have comprehensiveness_classification_scores, sufficiency_classification_scores, classification, and classification_scores defined when including thresholded scores"
634
+ )
635
+ if not all(
636
+ "sufficiency_classification_scores" in x
637
+ for x in instance["thresholded_scores"]
638
+ ):
639
+ error = True
640
+ logging.info(
641
+ f"Error: {instance} must have sufficiency_classification_scores for every threshold"
642
+ )
643
+ if not all(
644
+ "comprehensiveness_classification_scores" in x
645
+ for x in instance["thresholded_scores"]
646
+ ):
647
+ error = True
648
+ logging.info(
649
+ f"Error: {instance} must have comprehensiveness_classification_scores for every threshold"
650
+ )
651
+ return error
652
+
653
+
654
+ def verify_instances(instances: List[dict], docs: Dict[str, list]):
655
+ annotation_ids = list(x["annotation_id"] for x in instances)
656
+ key_counter = Counter(annotation_ids)
657
+ multi_occurrence_annotation_ids = list(
658
+ filter(lambda kv: kv[1] > 1, key_counter.items())
659
+ )
660
+ error = False
661
+ if len(multi_occurrence_annotation_ids) > 0:
662
+ error = True
663
+ logging.info(
664
+ f"Error in instances: {len(multi_occurrence_annotation_ids)} appear multiple times in the annotations file: {multi_occurrence_annotation_ids}"
665
+ )
666
+ failed_validation = set()
667
+ instances_with_classification = list()
668
+ instances_with_soft_rationale_predictions = list()
669
+ instances_with_soft_sentence_predictions = list()
670
+ instances_with_comprehensiveness_classifications = list()
671
+ instances_with_sufficiency_classifications = list()
672
+ instances_with_thresholded_scores = list()
673
+ if "thresholded_scores" in instances[0]:
674
+ thresholds = set(x["threshold"] for x in instances[0]["thresholded_scores"])
675
+ else:
676
+ thresholds = None
677
+ for instance in instances:
678
+ instance_error = verify_instance(instance, docs, thresholds)
679
+ if instance_error:
680
+ error = True
681
+ failed_validation.add(instance["annotation_id"])
682
+ if instance.get("classification", None) != None:
683
+ instances_with_classification.append(instance)
684
+ if instance.get("comprehensiveness_classification_scores", None) != None:
685
+ instances_with_comprehensiveness_classifications.append(instance)
686
+ if instance.get("sufficiency_classification_scores", None) != None:
687
+ instances_with_sufficiency_classifications.append(instance)
688
+ has_soft_rationales = []
689
+ has_soft_sentences = []
690
+ for rat in instance["rationales"]:
691
+ if rat.get("soft_rationale_predictions", None) != None:
692
+ has_soft_rationales.append(rat)
693
+ if rat.get("soft_sentence_predictions", None) != None:
694
+ has_soft_sentences.append(rat)
695
+ if len(has_soft_rationales) > 0:
696
+ instances_with_soft_rationale_predictions.append(instance)
697
+ if len(has_soft_rationales) != len(instance["rationales"]):
698
+ error = True
699
+ logging.info(
700
+ f'Error: instance {instance["annotation_id"]} has soft rationales for some but not all reported documents!'
701
+ )
702
+ if len(has_soft_sentences) > 0:
703
+ instances_with_soft_sentence_predictions.append(instance)
704
+ if len(has_soft_sentences) != len(instance["rationales"]):
705
+ error = True
706
+ logging.info(
707
+ f'Error: instance {instance["annotation"]} has soft sentences for some but not all reported documents!'
708
+ )
709
+ if "thresholded_scores" in instance:
710
+ instances_with_thresholded_scores.append(instance)
711
+ logging.info(
712
+ f"Error in instances: {len(failed_validation)} instances fail validation: {failed_validation}"
713
+ )
714
+ if len(instances_with_classification) != 0 and len(
715
+ instances_with_classification
716
+ ) != len(instances):
717
+ logging.info(
718
+ f"Either all {len(instances)} must have a classification or none may, instead {len(instances_with_classification)} do!"
719
+ )
720
+ error = True
721
+ if len(instances_with_soft_sentence_predictions) != 0 and len(
722
+ instances_with_soft_sentence_predictions
723
+ ) != len(instances):
724
+ logging.info(
725
+ f"Either all {len(instances)} must have a sentence prediction or none may, instead {len(instances_with_soft_sentence_predictions)} do!"
726
+ )
727
+ error = True
728
+ if len(instances_with_soft_rationale_predictions) != 0 and len(
729
+ instances_with_soft_rationale_predictions
730
+ ) != len(instances):
731
+ logging.info(
732
+ f"Either all {len(instances)} must have a soft rationale prediction or none may, instead {len(instances_with_soft_rationale_predictions)} do!"
733
+ )
734
+ error = True
735
+ if len(instances_with_comprehensiveness_classifications) != 0 and len(
736
+ instances_with_comprehensiveness_classifications
737
+ ) != len(instances):
738
+ error = True
739
+ logging.info(
740
+ f"Either all {len(instances)} must have a comprehensiveness classification or none may, instead {len(instances_with_comprehensiveness_classifications)} do!"
741
+ )
742
+ if len(instances_with_sufficiency_classifications) != 0 and len(
743
+ instances_with_sufficiency_classifications
744
+ ) != len(instances):
745
+ error = True
746
+ logging.info(
747
+ f"Either all {len(instances)} must have a sufficiency classification or none may, instead {len(instances_with_sufficiency_classifications)} do!"
748
+ )
749
+ if len(instances_with_thresholded_scores) != 0 and len(
750
+ instances_with_thresholded_scores
751
+ ) != len(instances):
752
+ error = True
753
+ logging.info(
754
+ f"Either all {len(instances)} must have thresholded scores or none may, instead {len(instances_with_thresholded_scores)} do!"
755
+ )
756
+ if error:
757
+ raise ValueError(
758
+ "Some instances are invalid, please fix your formatting and try again"
759
+ )
760
+
761
+
762
+ def _has_hard_predictions(results: List[dict]) -> bool:
763
+ # assumes that we have run "verification" over the inputs
764
+ return (
765
+ "rationales" in results[0]
766
+ and len(results[0]["rationales"]) > 0
767
+ and "hard_rationale_predictions" in results[0]["rationales"][0]
768
+ and results[0]["rationales"][0]["hard_rationale_predictions"] is not None
769
+ and len(results[0]["rationales"][0]["hard_rationale_predictions"]) > 0
770
+ )
771
+
772
+
773
+ def _has_soft_predictions(results: List[dict]) -> bool:
774
+ # assumes that we have run "verification" over the inputs
775
+ return (
776
+ "rationales" in results[0]
777
+ and len(results[0]["rationales"]) > 0
778
+ and "soft_rationale_predictions" in results[0]["rationales"][0]
779
+ and results[0]["rationales"][0]["soft_rationale_predictions"] is not None
780
+ )
781
+
782
+
783
+ def _has_soft_sentence_predictions(results: List[dict]) -> bool:
784
+ # assumes that we have run "verification" over the inputs
785
+ return (
786
+ "rationales" in results[0]
787
+ and len(results[0]["rationales"]) > 0
788
+ and "soft_sentence_predictions" in results[0]["rationales"][0]
789
+ and results[0]["rationales"][0]["soft_sentence_predictions"] is not None
790
+ )
791
+
792
+
793
+ def _has_classifications(results: List[dict]) -> bool:
794
+ # assumes that we have run "verification" over the inputs
795
+ return "classification" in results[0] and results[0]["classification"] is not None
796
+
797
+
798
+ def main():
799
+ parser = argparse.ArgumentParser(
800
+ description="""Computes rationale and final class classification scores""",
801
+ formatter_class=argparse.RawTextHelpFormatter,
802
+ )
803
+ parser.add_argument(
804
+ "--data_dir",
805
+ dest="data_dir",
806
+ required=True,
807
+ help="Which directory contains a {train,val,test}.jsonl file?",
808
+ )
809
+ parser.add_argument(
810
+ "--split",
811
+ dest="split",
812
+ required=True,
813
+ help="Which of {train,val,test} are we scoring on?",
814
+ )
815
+ parser.add_argument(
816
+ "--strict",
817
+ dest="strict",
818
+ required=False,
819
+ action="store_true",
820
+ default=False,
821
+ help="Do we perform strict scoring?",
822
+ )
823
+ parser.add_argument(
824
+ "--results",
825
+ dest="results",
826
+ required=True,
827
+ help="""Results File
828
+ Contents are expected to be jsonl of:
829
+ {
830
+ "annotation_id": str, required
831
+ # these classifications *must not* overlap
832
+ "rationales": List[
833
+ {
834
+ "docid": str, required
835
+ "hard_rationale_predictions": List[{
836
+ "start_token": int, inclusive, required
837
+ "end_token": int, exclusive, required
838
+ }], optional,
839
+ # token level classifications, a value must be provided per-token
840
+ # in an ideal world, these correspond to the hard-decoding above.
841
+ "soft_rationale_predictions": List[float], optional.
842
+ # sentence level classifications, a value must be provided for every
843
+ # sentence in each document, or not at all
844
+ "soft_sentence_predictions": List[float], optional.
845
+ }
846
+ ],
847
+ # the classification the model made for the overall classification task
848
+ "classification": str, optional
849
+ # A probability distribution output by the model. We require this to be normalized.
850
+ "classification_scores": Dict[str, float], optional
851
+ # The next two fields are measures for how faithful your model is (the
852
+ # rationales it predicts are in some sense causal of the prediction), and
853
+ # how sufficient they are. We approximate a measure for comprehensiveness by
854
+ # asking that you remove the top k%% of tokens from your documents,
855
+ # running your models again, and reporting the score distribution in the
856
+ # "comprehensiveness_classification_scores" field.
857
+ # We approximate a measure of sufficiency by asking exactly the converse
858
+ # - that you provide model distributions on the removed k%% tokens.
859
+ # 'k' is determined by human rationales, and is documented in our paper.
860
+ # You should determine which of these tokens to remove based on some kind
861
+ # of information about your model: gradient based, attention based, other
862
+ # interpretability measures, etc.
863
+ # scores per class having removed k%% of the data, where k is determined by human comprehensive rationales
864
+ "comprehensiveness_classification_scores": Dict[str, float], optional
865
+ # scores per class having access to only k%% of the data, where k is determined by human comprehensive rationales
866
+ "sufficiency_classification_scores": Dict[str, float], optional
867
+ # the number of tokens required to flip the prediction - see "Is Attention Interpretable" by Serrano and Smith.
868
+ "tokens_to_flip": int, optional
869
+ "thresholded_scores": List[{
870
+ "threshold": float, required,
871
+ "comprehensiveness_classification_scores": like "classification_scores"
872
+ "sufficiency_classification_scores": like "classification_scores"
873
+ }], optional. if present, then "classification" and "classification_scores" must be present
874
+ }
875
+ When providing one of the optional fields, it must be provided for *every* instance.
876
+ The classification, classification_score, and comprehensiveness_classification_scores
877
+ must together be present for every instance or absent for every instance.
878
+ """,
879
+ )
880
+ parser.add_argument(
881
+ "--iou_thresholds",
882
+ dest="iou_thresholds",
883
+ required=False,
884
+ nargs="+",
885
+ type=float,
886
+ default=[0.5],
887
+ help="""Thresholds for IOU scoring.
888
+
889
+ These are used for "soft" or partial match scoring of rationale spans.
890
+ A span is considered a match if the size of the intersection of the prediction
891
+ and the annotation, divided by the union of the two spans, is larger than
892
+ the IOU threshold. This score can be computed for arbitrary thresholds.
893
+ """,
894
+ )
895
+ parser.add_argument(
896
+ "--score_file",
897
+ dest="score_file",
898
+ required=False,
899
+ default=None,
900
+ help="Where to write results?",
901
+ )
902
+ parser.add_argument(
903
+ "--aopc_thresholds",
904
+ nargs="+",
905
+ required=False,
906
+ type=float,
907
+ default=[0.01, 0.05, 0.1, 0.2, 0.5],
908
+ help="Thresholds for AOPC Thresholds",
909
+ )
910
+ args = parser.parse_args()
911
+ results = load_jsonl(args.results)
912
+ docids = set(
913
+ chain.from_iterable(
914
+ [rat["docid"] for rat in res["rationales"]] for res in results
915
+ )
916
+ )
917
+ docs = load_flattened_documents(args.data_dir, docids)
918
+ verify_instances(results, docs)
919
+ # load truth
920
+ annotations = annotations_from_jsonl(
921
+ os.path.join(args.data_dir, args.split + ".jsonl")
922
+ )
923
+ docids |= set(
924
+ chain.from_iterable(
925
+ (ev.docid for ev in chain.from_iterable(ann.evidences))
926
+ for ann in annotations
927
+ )
928
+ )
929
+
930
+ has_final_predictions = _has_classifications(results)
931
+ scores = dict()
932
+ if args.strict:
933
+ if not args.iou_thresholds:
934
+ raise ValueError(
935
+ "--iou_thresholds must be provided when running strict scoring"
936
+ )
937
+ if not has_final_predictions:
938
+ raise ValueError(
939
+ "We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!"
940
+ )
941
+ # TODO think about offering a sentence level version of these scores.
942
+ if _has_hard_predictions(results):
943
+ truth = list(
944
+ chain.from_iterable(Rationale.from_annotation(ann) for ann in annotations)
945
+ )
946
+ pred = list(
947
+ chain.from_iterable(Rationale.from_instance(inst) for inst in results)
948
+ )
949
+ if args.iou_thresholds is not None:
950
+ iou_scores = partial_match_score(truth, pred, args.iou_thresholds)
951
+ scores["iou_scores"] = iou_scores
952
+ # NER style scoring
953
+ rationale_level_prf = score_hard_rationale_predictions(truth, pred)
954
+ scores["rationale_prf"] = rationale_level_prf
955
+ token_level_truth = list(
956
+ chain.from_iterable(rat.to_token_level() for rat in truth)
957
+ )
958
+ token_level_pred = list(
959
+ chain.from_iterable(rat.to_token_level() for rat in pred)
960
+ )
961
+ token_level_prf = score_hard_rationale_predictions(
962
+ token_level_truth, token_level_pred
963
+ )
964
+ scores["token_prf"] = token_level_prf
965
+ else:
966
+ logging.info("No hard predictions detected, skipping rationale scoring")
967
+
968
+ if _has_soft_predictions(results):
969
+ flattened_documents = load_flattened_documents(args.data_dir, docids)
970
+ paired_scoring = PositionScoredDocument.from_results(
971
+ results, annotations, flattened_documents, use_tokens=True
972
+ )
973
+ token_scores = score_soft_tokens(paired_scoring)
974
+ scores["token_soft_metrics"] = token_scores
975
+ else:
976
+ logging.info("No soft predictions detected, skipping rationale scoring")
977
+
978
+ if _has_soft_sentence_predictions(results):
979
+ documents = load_documents(args.data_dir, docids)
980
+ paired_scoring = PositionScoredDocument.from_results(
981
+ results, annotations, documents, use_tokens=False
982
+ )
983
+ sentence_scores = score_soft_tokens(paired_scoring)
984
+ scores["sentence_soft_metrics"] = sentence_scores
985
+ else:
986
+ logging.info(
987
+ "No sentence level predictions detected, skipping sentence-level diagnostic"
988
+ )
989
+
990
+ if has_final_predictions:
991
+ flattened_documents = load_flattened_documents(args.data_dir, docids)
992
+ class_results = score_classifications(
993
+ results, annotations, flattened_documents, args.aopc_thresholds
994
+ )
995
+ scores["classification_scores"] = class_results
996
+ else:
997
+ logging.info("No classification scores detected, skipping classification")
998
+
999
+ pprint.pprint(scores)
1000
+
1001
+ if args.score_file:
1002
+ with open(args.score_file, "w") as of:
1003
+ json.dump(scores, of, indent=4, sort_keys=True)
1004
+
1005
+
1006
+ if __name__ == "__main__":
1007
+ main()
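Note: besides the CLI entry point above, the scoring helpers in metrics.py can be called directly. A minimal sketch, assuming the Transformer-Explainability directory is on PYTHONPATH and that ERASER-style annotations plus a decoded results file exist locally (both paths below are hypothetical):

from BERT_rationale_benchmark.metrics import (Rationale,
                                              score_hard_rationale_predictions)
from BERT_rationale_benchmark.utils import annotations_from_jsonl, load_jsonl

results = load_jsonl("output/movies/test_decoded.jsonl")        # hypothetical path
annotations = annotations_from_jsonl("data/movies/test.jsonl")  # hypothetical path

truth = [r for ann in annotations for r in Rationale.from_annotation(ann)]
pred = [r for inst in results for r in Rationale.from_instance(inst)]
print(score_hard_rationale_predictions(truth, pred))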
Transformer-Explainability/BERT_rationale_benchmark/models/model_utils.py ADDED
@@ -0,0 +1,186 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Set
3
+
4
+ import numpy as np
5
+ import torch
6
+ from gensim.models import KeyedVectors
7
+ from torch import nn
8
+ from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
9
+ pad_packed_sequence, pad_sequence)
10
+
11
+
12
+ @dataclass(eq=True, frozen=True)
13
+ class PaddedSequence:
14
+ """A utility class for padding variable length sequences mean for RNN input
15
+ This class is in the style of PackedSequence from the PyTorch RNN Utils,
16
+ but is somewhat more manual in approach. It provides the ability to generate masks
17
+ for outputs of the same input dimensions.
18
+ The constructor should never be called directly and should only be called via
19
+ the autopad classmethod.
20
+
21
+ We'd love to delete this, but we pad_sequence, pack_padded_sequence, and
22
+ pad_packed_sequence all require shuffling around tuples of information, and some
23
+ convenience methods using these are nice to have.
24
+ """
25
+
26
+ data: torch.Tensor
27
+ batch_sizes: torch.Tensor
28
+ batch_first: bool = False
29
+
30
+ @classmethod
31
+ def autopad(
32
+ cls, data, batch_first: bool = False, padding_value=0, device=None
33
+ ) -> "PaddedSequence":
34
+ # handle tensors of size 0 (single item)
35
+ data_ = []
36
+ for d in data:
37
+ if len(d.size()) == 0:
38
+ d = d.unsqueeze(0)
39
+ data_.append(d)
40
+ padded = pad_sequence(
41
+ data_, batch_first=batch_first, padding_value=padding_value
42
+ )
43
+ if batch_first:
44
+ batch_lengths = torch.LongTensor([len(x) for x in data_])
45
+ if any([x == 0 for x in batch_lengths]):
46
+ raise ValueError(
47
+ "Found a 0 length batch element, this can't possibly be right: {}".format(
48
+ batch_lengths
49
+ )
50
+ )
51
+ else:
52
+ # TODO actually test this codepath
53
+ batch_lengths = torch.LongTensor([len(x) for x in data])
54
+ return PaddedSequence(padded, batch_lengths, batch_first).to(device=device)
55
+
56
+ def pack_other(self, data: torch.Tensor):
57
+ return pack_padded_sequence(
58
+ data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False
59
+ )
60
+
61
+ @classmethod
62
+ def from_packed_sequence(
63
+ cls, ps: PackedSequence, batch_first: bool, padding_value=0
64
+ ) -> "PaddedSequence":
65
+ padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value)
66
+ return PaddedSequence(padded, batch_sizes, batch_first)
67
+
68
+ def cuda(self) -> "PaddedSequence":
69
+ return PaddedSequence(
70
+ self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first
71
+ )
72
+
73
+ def to(
74
+ self, dtype=None, device=None, copy=False, non_blocking=False
75
+ ) -> "PaddedSequence":
76
+ # TODO make to() support all of the torch.Tensor to() variants
77
+ return PaddedSequence(
78
+ self.data.to(
79
+ dtype=dtype, device=device, copy=copy, non_blocking=non_blocking
80
+ ),
81
+ self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking),
82
+ batch_first=self.batch_first,
83
+ )
84
+
85
+ def mask(
86
+ self, on=int(0), off=int(0), device="cpu", size=None, dtype=None
87
+ ) -> torch.Tensor:
88
+ if size is None:
89
+ size = self.data.size()
90
+ out_tensor = torch.zeros(*size, dtype=dtype)
91
+ # TODO this can be done more efficiently
92
+ out_tensor.fill_(off)
93
+ # note to self: these are probably less efficient than explicilty populating the off values instead of the on values.
94
+ if self.batch_first:
95
+ for i, bl in enumerate(self.batch_sizes):
96
+ out_tensor[i, :bl] = on
97
+ else:
98
+ for i, bl in enumerate(self.batch_sizes):
99
+ out_tensor[:bl, i] = on
100
+ return out_tensor.to(device)
101
+
102
+ def unpad(self, other: torch.Tensor) -> List[torch.Tensor]:
103
+ out = []
104
+ for o, bl in zip(other, self.batch_sizes):
105
+ out.append(o[:bl])
106
+ return out
107
+
108
+ def flip(self) -> "PaddedSequence":
109
+ return PaddedSequence(
110
+ self.data.transpose(0, 1), not self.batch_first, self.padding_value
111
+ )
112
+
113
+
114
+ def extract_embeddings(
115
+ vocab: Set[str], embedding_file: str, unk_token: str = "UNK", pad_token: str = "PAD"
116
+ ) -> (nn.Embedding, Dict[str, int], List[str]):
117
+ vocab = vocab | set([unk_token, pad_token])
118
+ if embedding_file.endswith(".bin"):
119
+ WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True)
120
+
121
+ word_to_vector = dict()
122
+ WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()])
123
+
124
+ if unk_token not in WVs:
125
+ mean_vector = np.mean(WV_matrix, axis=0)
126
+ word_to_vector[unk_token] = mean_vector
127
+ if pad_token not in WVs:
128
+ word_to_vector[pad_token] = np.zeros(WVs.vector_size)
129
+
130
+ for v in vocab:
131
+ if v in WVs:
132
+ word_to_vector[v] = WVs[v]
133
+
134
+ interner = dict()
135
+ deinterner = list()
136
+ vectors = []
137
+ count = 0
138
+ for word in [pad_token, unk_token] + sorted(
139
+ list(word_to_vector.keys() - {unk_token, pad_token})
140
+ ):
141
+ vector = word_to_vector[word]
142
+ vectors.append(np.array(vector))
143
+ interner[word] = count
144
+ deinterner.append(word)
145
+ count += 1
146
+ vectors = torch.FloatTensor(np.array(vectors))
147
+ embedding = nn.Embedding.from_pretrained(
148
+ vectors, padding_idx=interner[pad_token]
149
+ )
150
+ embedding.weight.requires_grad = False
151
+ return embedding, interner, deinterner
152
+ elif embedding_file.endswith(".txt"):
153
+ word_to_vector = dict()
154
+ vector = []
155
+ with open(embedding_file, "r") as inf:
156
+ for line in inf:
157
+ contents = line.strip().split()
158
+ word = contents[0]
159
+ vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0)
160
+ word_to_vector[word] = vector
161
+ embed_size = vector.size()
162
+ if unk_token not in word_to_vector:
163
+ mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0)
164
+ word_to_vector[unk_token] = mean_vector.unsqueeze(0)
165
+ if pad_token not in word_to_vector:
166
+ word_to_vector[pad_token] = torch.zeros(embed_size)
167
+ interner = dict()
168
+ deinterner = list()
169
+ vectors = []
170
+ count = 0
171
+ for word in [pad_token, unk_token] + sorted(
172
+ list(word_to_vector.keys() - {unk_token, pad_token})
173
+ ):
174
+ vector = word_to_vector[word]
175
+ vectors.append(vector)
176
+ interner[word] = count
177
+ deinterner.append(word)
178
+ count += 1
179
+ vectors = torch.cat(vectors, dim=0)
180
+ embedding = nn.Embedding.from_pretrained(
181
+ vectors, padding_idx=interner[pad_token]
182
+ )
183
+ embedding.weight.requires_grad = False
184
+ return embedding, interner, deinterner
185
+ else:
186
+ raise ValueError("Unable to open embeddings file {}".format(embedding_file))
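Note: PaddedSequence above is the padding/masking utility the downstream models rely on. A minimal usage sketch, assuming PyTorch is installed and the Transformer-Explainability directory is on PYTHONPATH; the token ids are illustrative only:

import torch
from BERT_rationale_benchmark.models.model_utils import PaddedSequence

seqs = [torch.tensor([101, 2023, 102]), torch.tensor([101, 102])]  # illustrative token ids
batch = PaddedSequence.autopad(seqs, batch_first=True, padding_value=0)

attention_mask = batch.mask(on=1, off=0, dtype=torch.long)  # 1 over real tokens, 0 over padding
row_lengths = batch.unpad(batch.data)                       # recover the variable-length rows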
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/__init__.py ADDED
File without changes
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/bert_pipeline.py ADDED
@@ -0,0 +1,852 @@
1
+ # TODO consider if this can be collapsed back down into the pipeline_train.py
2
+ import argparse
3
+ import json
4
+ import logging
5
+ import os
6
+ import random
7
+ from collections import OrderedDict
8
+ from itertools import chain
9
+ from typing import List, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from BERT_explainability.modules.BERT.BERT_cls_lrp import \
15
+ BertForSequenceClassification as BertForClsOrigLrp
16
+ from BERT_explainability.modules.BERT.BertForSequenceClassification import \
17
+ BertForSequenceClassification as BertForSequenceClassificationTest
18
+ from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
19
+ from BERT_rationale_benchmark.utils import (Annotation, Evidence,
20
+ load_datasets, load_documents,
21
+ write_jsonl)
22
+ from sklearn.metrics import accuracy_score
23
+ from transformers import BertForSequenceClassification, BertTokenizer
24
+
25
+ logging.basicConfig(
26
+ level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+ # let's make this more or less deterministic (not resistant to restarts)
30
+ random.seed(12345)
31
+ np.random.seed(67890)
32
+ torch.manual_seed(10111213)
33
+ torch.backends.cudnn.deterministic = True
34
+ torch.backends.cudnn.benchmark = False
35
+
36
+
37
+ import numpy as np
38
+
39
+ latex_special_token = ["!@#$%^&*()"]
40
+
41
+
42
+ def generate(text_list, attention_list, latex_file, color="red", rescale_value=False):
43
+ attention_list = attention_list[: len(text_list)]
44
+ if attention_list.max() == attention_list.min():
45
+ attention_list = torch.zeros_like(attention_list)
46
+ else:
47
+ attention_list = (
48
+ 100
49
+ * (attention_list - attention_list.min())
50
+ / (attention_list.max() - attention_list.min())
51
+ )
52
+ attention_list[attention_list < 1] = 0
53
+ attention_list = attention_list.tolist()
54
+ text_list = [text_list[i].replace("$", "") for i in range(len(text_list))]
55
+ if rescale_value:
56
+ attention_list = rescale(attention_list)
57
+ word_num = len(text_list)
58
+ text_list = clean_word(text_list)
59
+ with open(latex_file, "w") as f:
60
+ f.write(
61
+ r"""\documentclass[varwidth=150mm]{standalone}
62
+ \special{papersize=210mm,297mm}
63
+ \usepackage{color}
64
+ \usepackage{tcolorbox}
65
+ \usepackage{CJK}
66
+ \usepackage{adjustbox}
67
+ \tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
68
+ \begin{document}
69
+ \begin{CJK*}{UTF8}{gbsn}"""
70
+ + "\n"
71
+ )
72
+ string = (
73
+ r"""{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{"""
74
+ + "\n"
75
+ )
76
+ for idx in range(word_num):
77
+ # string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
78
+ # print(text_list[idx])
79
+ if "\#\#" in text_list[idx]:
80
+ token = text_list[idx].replace("\#\#", "")
81
+ string += (
82
+ "\\colorbox{%s!%s}{" % (color, attention_list[idx])
83
+ + "\\strut "
84
+ + token
85
+ + "}"
86
+ )
87
+ else:
88
+ string += (
89
+ " "
90
+ + "\\colorbox{%s!%s}{" % (color, attention_list[idx])
91
+ + "\\strut "
92
+ + text_list[idx]
93
+ + "}"
94
+ )
95
+ string += "\n}}}"
96
+ f.write(string + "\n")
97
+ f.write(
98
+ r"""\end{CJK*}
99
+ \end{document}"""
100
+ )
101
+
102
+
103
+ def clean_word(word_list):
104
+ new_word_list = []
105
+ for word in word_list:
106
+ for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
107
+ if latex_sensitive in word:
108
+ word = word.replace(latex_sensitive, "\\" + latex_sensitive)
109
+ new_word_list.append(word)
110
+ return new_word_list
111
+
112
+
113
+ def scores_per_word_from_scores_per_token(input, tokenizer, input_ids, scores_per_id):
114
+ words = tokenizer.convert_ids_to_tokens(input_ids)
115
+ words = [word.replace("##", "") for word in words]
116
+ score_per_char = []
117
+
118
+ # TODO: DELETE
119
+ input_ids_chars = []
120
+ for word in words:
121
+ if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
122
+ continue
123
+ input_ids_chars += list(word)
124
+ # TODO: DELETE
125
+
126
+ for i in range(len(scores_per_id)):
127
+ if words[i] in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
128
+ continue
129
+ score_per_char += [scores_per_id[i]] * len(words[i])
130
+
131
+ score_per_word = []
132
+ start_idx = 0
133
+ end_idx = 0
134
+ # TODO: DELETE
135
+ words_from_chars = []
136
+ for inp in input:
137
+ if start_idx >= len(score_per_char):
138
+ break
139
+ end_idx = end_idx + len(inp)
140
+ score_per_word.append(np.max(score_per_char[start_idx:end_idx]))
141
+
142
+ # TODO: DELETE
143
+ words_from_chars.append("".join(input_ids_chars[start_idx:end_idx]))
144
+
145
+ start_idx = end_idx
146
+
147
+ if words_from_chars[:-1] != input[: len(words_from_chars) - 1]:
148
+ print(words_from_chars)
149
+ print(input[: len(words_from_chars)])
150
+ print(words)
151
+ print(tokenizer.convert_ids_to_tokens(input_ids))
152
+ assert False
153
+
154
+ return torch.tensor(score_per_word)
155
+
156
+
157
+ def get_input_words(input, tokenizer, input_ids):
158
+ words = tokenizer.convert_ids_to_tokens(input_ids)
159
+ words = [word.replace("##", "") for word in words]
160
+
161
+ input_ids_chars = []
162
+ for word in words:
163
+ if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
164
+ continue
165
+ input_ids_chars += list(word)
166
+
167
+ start_idx = 0
168
+ end_idx = 0
169
+ words_from_chars = []
170
+ for inp in input:
171
+ if start_idx >= len(input_ids_chars):
172
+ break
173
+ end_idx = end_idx + len(inp)
174
+ words_from_chars.append("".join(input_ids_chars[start_idx:end_idx]))
175
+ start_idx = end_idx
176
+
177
+ if words_from_chars[:-1] != input[: len(words_from_chars) - 1]:
178
+ print(words_from_chars)
179
+ print(input[: len(words_from_chars)])
180
+ print(words)
181
+ print(tokenizer.convert_ids_to_tokens(input_ids))
182
+ assert False
183
+ return words_from_chars
184
+
185
+
186
+ def bert_tokenize_doc(
187
+ doc: List[List[str]], tokenizer, special_token_map
188
+ ) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
189
+ """Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words"""
190
+ sents = []
191
+ sent_token_spans = []
192
+ for sent in doc:
193
+ tokens = []
194
+ spans = []
195
+ start = 0
196
+ for w in sent:
197
+ if w in special_token_map:
198
+ tokens.append(w)
199
+ else:
200
+ tokens.extend(tokenizer.tokenize(w))
201
+ end = len(tokens)
202
+ spans.append((start, end))
203
+ start = end
204
+ sents.append(tokens)
205
+ sent_token_spans.append(spans)
206
+ return sents, sent_token_spans
207
+
208
+
209
+ def initialize_models(params: dict, batch_first: bool, use_half_precision=False):
210
+ assert batch_first
211
+ max_length = params["max_length"]
212
+ tokenizer = BertTokenizer.from_pretrained(params["bert_vocab"])
213
+ pad_token_id = tokenizer.pad_token_id
214
+ cls_token_id = tokenizer.cls_token_id
215
+ sep_token_id = tokenizer.sep_token_id
216
+ bert_dir = params["bert_dir"]
217
+ evidence_classes = dict(
218
+ (y, x) for (x, y) in enumerate(params["evidence_classifier"]["classes"])
219
+ )
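+ # evidence_classes maps each class name to its integer label index.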
220
+ evidence_classifier = BertForSequenceClassification.from_pretrained(
221
+ bert_dir, num_labels=len(evidence_classes)
222
+ )
223
+ word_interner = tokenizer.vocab
224
+ de_interner = tokenizer.ids_to_tokens
225
+ return evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer
226
+
227
+
228
+ BATCH_FIRST = True
229
+
230
+
231
+ def extract_docid_from_dataset_element(element):
232
+ return next(iter(element.evidences))[0].docid
233
+
234
+
235
+ def extract_evidence_from_dataset_element(element):
236
+ return next(iter(element.evidences))
237
+
238
+
239
+ def main():
240
+ parser = argparse.ArgumentParser(
241
+ description="""Trains a pipeline model.
242
+
243
+ Step 1 is evidence identification: decide whether a given sentence is evidence or not.
244
+ Step 2 is evidence classification: given an evidence sentence, classify the final outcome for the task
245
+ (e.g. sentiment or significance).
246
+
247
+ These models should be separated into two separate steps, but at the moment:
248
+ * prep data (load, intern documents, load json)
249
+ * convert data for evidence identification - in the case of training data we take all the positives and sample some
250
+ negatives
251
+ * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a
252
+ broader sampling of negative values.
253
+ * train evidence identification
254
+ * convert data for evidence classification - take all rationales + decisions and use this as input
255
+ * train evidence classification
256
+ * decode first the evidence, then run classification for each split
257
+
258
+ """,
259
+ formatter_class=argparse.RawTextHelpFormatter,
260
+ )
261
+ parser.add_argument(
262
+ "--data_dir",
263
+ dest="data_dir",
264
+ required=True,
265
+ help="Which directory contains a {train,val,test}.jsonl file?",
266
+ )
267
+ parser.add_argument(
268
+ "--output_dir",
269
+ dest="output_dir",
270
+ required=True,
271
+ help="Where shall we write intermediate models + final data to?",
272
+ )
273
+ parser.add_argument(
274
+ "--model_params",
275
+ dest="model_params",
276
+ required=True,
277
+ help="JSON file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.)",
278
+ )
279
+ args = parser.parse_args()
280
+ assert BATCH_FIRST
281
+ os.makedirs(args.output_dir, exist_ok=True)
282
+
283
+ with open(args.model_params, "r") as fp:
284
+ logger.info(f"Loading model parameters from {args.model_params}")
285
+ model_params = json.load(fp)
286
+ logger.info(f"Params: {json.dumps(model_params, indent=2, sort_keys=True)}")
287
+ train, val, test = load_datasets(args.data_dir)
288
+ docids = set(
289
+ e.docid
290
+ for e in chain.from_iterable(
291
+ chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))
292
+ )
293
+ )
294
+ documents = load_documents(args.data_dir, docids)
295
+ logger.info(f"Loaded {len(documents)} documents")
296
+ (
297
+ evidence_classifier,
298
+ word_interner,
299
+ de_interner,
300
+ evidence_classes,
301
+ tokenizer,
302
+ ) = initialize_models(model_params, batch_first=BATCH_FIRST)
303
+ logger.info(f"We have {len(word_interner)} wordpieces")
304
+ cache = os.path.join(args.output_dir, "preprocessed.pkl")
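+ # Cache the tokenized documents so repeated runs can skip re-encoding.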
305
+ if os.path.exists(cache):
306
+ logger.info(f"Loading interned documents from {cache}")
307
+ (interned_documents) = torch.load(cache)
308
+ else:
309
+ logger.info(f"Interning documents")
310
+ interned_documents = {}
311
+ for d, doc in documents.items():
312
+ encoding = tokenizer.encode_plus(
313
+ doc,
314
+ add_special_tokens=True,
315
+ max_length=model_params["max_length"],
316
+ return_token_type_ids=False,
317
+ pad_to_max_length=False,
318
+ return_attention_mask=True,
319
+ return_tensors="pt",
320
+ truncation=True,
321
+ )
322
+ interned_documents[d] = encoding
323
+ torch.save((interned_documents), cache)
324
+
325
+ evidence_classifier = evidence_classifier.cuda()
326
+ optimizer = None
327
+ scheduler = None
328
+
329
+ save_dir = args.output_dir
330
+
331
+ logging.info(f"Beginning training classifier")
332
+ evidence_classifier_output_dir = os.path.join(save_dir, "classifier")
333
+ os.makedirs(save_dir, exist_ok=True)
334
+ os.makedirs(evidence_classifier_output_dir, exist_ok=True)
335
+ model_save_file = os.path.join(evidence_classifier_output_dir, "classifier.pt")
336
+ epoch_save_file = os.path.join(
337
+ evidence_classifier_output_dir, "classifier_epoch_data.pt"
338
+ )
339
+
340
+ device = next(evidence_classifier.parameters()).device
341
+ if optimizer is None:
342
+ optimizer = torch.optim.Adam(
343
+ evidence_classifier.parameters(),
344
+ lr=model_params["evidence_classifier"]["lr"],
345
+ )
346
+ criterion = nn.CrossEntropyLoss(reduction="none")
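+ # reduction="none" keeps per-example losses; they are summed per batch and averaged per example below.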
347
+ batch_size = model_params["evidence_classifier"]["batch_size"]
348
+ epochs = model_params["evidence_classifier"]["epochs"]
349
+ patience = model_params["evidence_classifier"]["patience"]
350
+ max_grad_norm = model_params["evidence_classifier"].get("max_grad_norm", None)
351
+
352
+ class_labels = [k for k, v in sorted(evidence_classes.items(), key=lambda kv: kv[1])]
353
+
354
+ results = {
355
+ "train_loss": [],
356
+ "train_f1": [],
357
+ "train_acc": [],
358
+ "val_loss": [],
359
+ "val_f1": [],
360
+ "val_acc": [],
361
+ }
362
+ best_epoch = -1
363
+ best_val_acc = 0
364
+ best_val_loss = float("inf")
365
+ best_model_state_dict = None
366
+ start_epoch = 0
367
+ epoch_data = {}
368
+ if os.path.exists(epoch_save_file):
369
+ logging.info(f"Restoring model from {model_save_file}")
370
+ evidence_classifier.load_state_dict(torch.load(model_save_file))
371
+ epoch_data = torch.load(epoch_save_file)
372
+ start_epoch = epoch_data["epoch"] + 1
373
+ # handle finishing because patience was exceeded or we didn't get the best final epoch
374
+ if bool(epoch_data.get("done", 0)):
375
+ start_epoch = epochs
376
+ results = epoch_data["results"]
377
+ best_epoch = start_epoch
378
+ best_model_state_dict = OrderedDict(
379
+ {k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
380
+ )
381
+ logging.info(f"Restoring training from epoch {start_epoch}")
382
+ logging.info(
383
+ f"Training evidence classifier from epoch {start_epoch} until epoch {epochs}"
384
+ )
385
+ optimizer.zero_grad()
386
+ for epoch in range(start_epoch, epochs):
387
+ epoch_train_data = random.sample(train, k=len(train))
388
+ epoch_train_loss = 0
389
+ epoch_training_acc = 0
390
+ evidence_classifier.train()
391
+ logging.info(
392
+ f"Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples"
393
+ )
394
+ for batch_start in range(0, len(epoch_train_data), batch_size):
395
+ batch_elements = epoch_train_data[
396
+ batch_start : min(batch_start + batch_size, len(epoch_train_data))
397
+ ]
398
+ targets = [evidence_classes[s.classification] for s in batch_elements]
399
+ targets = torch.tensor(targets, dtype=torch.long, device=device)
400
+ samples_encoding = [
401
+ interned_documents[extract_docid_from_dataset_element(s)]
402
+ for s in batch_elements
403
+ ]
404
+ input_ids = (
405
+ torch.stack(
406
+ [
407
+ samples_encoding[i]["input_ids"]
408
+ for i in range(len(samples_encoding))
409
+ ]
410
+ )
411
+ .squeeze(1)
412
+ .to(device)
413
+ )
414
+ attention_masks = (
415
+ torch.stack(
416
+ [
417
+ samples_encoding[i]["attention_mask"]
418
+ for i in range(len(samples_encoding))
419
+ ]
420
+ )
421
+ .squeeze(1)
422
+ .to(device)
423
+ )
424
+ preds = evidence_classifier(
425
+ input_ids=input_ids, attention_mask=attention_masks
426
+ )[0]
427
+ epoch_training_acc += accuracy_score(
428
+ preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
429
+ )
430
+ loss = criterion(preds, targets.to(device=preds.device)).sum()
431
+ epoch_train_loss += loss.item()
432
+ loss.backward()
433
+ assert loss == loss # for nans
434
+ if max_grad_norm:
435
+ torch.nn.utils.clip_grad_norm_(
436
+ evidence_classifier.parameters(), max_grad_norm
437
+ )
438
+ optimizer.step()
439
+ if scheduler:
440
+ scheduler.step()
441
+ optimizer.zero_grad()
442
+ epoch_train_loss /= len(epoch_train_data)
443
+ epoch_training_acc /= len(epoch_train_data)
444
+ assert epoch_train_loss == epoch_train_loss # for nans
445
+ results["train_loss"].append(epoch_train_loss)
446
+ logging.info(f"Epoch {epoch} training loss {epoch_train_loss}")
447
+ logging.info(f"Epoch {epoch} training accuracy {epoch_training_acc}")
448
+
449
+ with torch.no_grad():
450
+ epoch_val_loss = 0
451
+ epoch_val_acc = 0
452
+ epoch_val_data = random.sample(val, k=len(val))
453
+ evidence_classifier.eval()
454
+ val_batch_size = 32
455
+ logging.info(
456
+ f"Validating with {len(epoch_val_data) // val_batch_size} batches with {len(epoch_val_data)} examples"
457
+ )
458
+ for batch_start in range(0, len(epoch_val_data), val_batch_size):
459
+ batch_elements = epoch_val_data[
460
+ batch_start : min(batch_start + val_batch_size, len(epoch_val_data))
461
+ ]
462
+ targets = [evidence_classes[s.classification] for s in batch_elements]
463
+ targets = torch.tensor(targets, dtype=torch.long, device=device)
464
+ samples_encoding = [
465
+ interned_documents[extract_docid_from_dataset_element(s)]
466
+ for s in batch_elements
467
+ ]
468
+ input_ids = (
469
+ torch.stack(
470
+ [
471
+ samples_encoding[i]["input_ids"]
472
+ for i in range(len(samples_encoding))
473
+ ]
474
+ )
475
+ .squeeze(1)
476
+ .to(device)
477
+ )
478
+ attention_masks = (
479
+ torch.stack(
480
+ [
481
+ samples_encoding[i]["attention_mask"]
482
+ for i in range(len(samples_encoding))
483
+ ]
484
+ )
485
+ .squeeze(1)
486
+ .to(device)
487
+ )
488
+ preds = evidence_classifier(
489
+ input_ids=input_ids, attention_mask=attention_masks
490
+ )[0]
491
+ epoch_val_acc += accuracy_score(
492
+ preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
493
+ )
494
+ loss = criterion(preds, targets.to(device=preds.device)).sum()
495
+ epoch_val_loss += loss.item()
496
+
497
+ epoch_val_loss /= len(val)
498
+ epoch_val_acc /= len(val)
499
+ results["val_acc"].append(epoch_val_acc)
500
+ results["val_loss"].append(epoch_val_loss)
501
+
502
+ logging.info(f"Epoch {epoch} val loss {epoch_val_loss}")
503
+ logging.info(f"Epoch {epoch} val acc {epoch_val_acc}")
504
+
505
+ if epoch_val_acc > best_val_acc or (
506
+ epoch_val_acc == best_val_acc and epoch_val_loss < best_val_loss
507
+ ):
508
+ best_model_state_dict = OrderedDict(
509
+ {k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
510
+ )
511
+ best_epoch = epoch
512
+ best_val_acc = epoch_val_acc
513
+ best_val_loss = epoch_val_loss
514
+ epoch_data = {
515
+ "epoch": epoch,
516
+ "results": results,
517
+ "best_val_acc": best_val_acc,
518
+ "done": 0,
519
+ }
520
+ torch.save(evidence_classifier.state_dict(), model_save_file)
521
+ torch.save(epoch_data, epoch_save_file)
522
+ logging.debug(
523
+ f"Epoch {epoch} new best model with val accuracy {epoch_val_acc}"
524
+ )
525
+ if epoch - best_epoch > patience:
526
+ logging.info(f"Exiting after epoch {epoch} due to no improvement")
527
+ epoch_data["done"] = 1
528
+ torch.save(epoch_data, epoch_save_file)
529
+ break
530
+
531
+ epoch_data["done"] = 1
532
+ epoch_data["results"] = results
533
+ torch.save(epoch_data, epoch_save_file)
534
+ evidence_classifier.load_state_dict(best_model_state_dict)
535
+ evidence_classifier = evidence_classifier.to(device=device)
536
+ evidence_classifier.eval()
537
+
538
+ # test
539
+
540
+ test_classifier = BertForSequenceClassificationTest.from_pretrained(
541
+ model_params["bert_dir"], num_labels=len(evidence_classes)
542
+ ).to(device)
543
+ orig_lrp_classifier = BertForClsOrigLrp.from_pretrained(
544
+ model_params["bert_dir"], num_labels=len(evidence_classes)
545
+ ).to(device)
546
+ if os.path.exists(epoch_save_file):
547
+ logging.info(f"Restoring model from {model_save_file}")
548
+ test_classifier.load_state_dict(torch.load(model_save_file))
549
+ orig_lrp_classifier.load_state_dict(torch.load(model_save_file))
550
+ test_classifier.eval()
551
+ orig_lrp_classifier.eval()
552
+ test_batch_size = 1
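+ # Explanations are generated one example at a time; targets.item() below assumes a batch of size 1.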
553
+ logging.info(
554
+ f"Testing with {len(test) // test_batch_size} batches with {len(test)} examples"
555
+ )
556
+
557
+ # explainability
558
+ explanations = Generator(test_classifier)
559
+ explanations_orig_lrp = Generator(orig_lrp_classifier)
560
+ method = "transformer_attribution"
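+ # Choose which explanation method to run; outputs are written under the matching folder name below.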
561
+ method_folder = {
562
+ "transformer_attribution": "ours",
563
+ "partial_lrp": "partial_lrp",
564
+ "last_attn": "last_attn",
565
+ "attn_gradcam": "attn_gradcam",
566
+ "lrp": "lrp",
567
+ "rollout": "rollout",
568
+ "ground_truth": "ground_truth",
569
+ "generate_all": "generate_all",
570
+ }
571
+ method_expl = {
572
+ "transformer_attribution": explanations.generate_LRP,
573
+ "partial_lrp": explanations_orig_lrp.generate_LRP_last_layer,
574
+ "last_attn": explanations_orig_lrp.generate_attn_last_layer,
575
+ "attn_gradcam": explanations_orig_lrp.generate_attn_gradcam,
576
+ "lrp": explanations_orig_lrp.generate_full_lrp,
577
+ "rollout": explanations_orig_lrp.generate_rollout,
578
+ }
579
+
580
+ os.makedirs(os.path.join(args.output_dir, method_folder[method]), exist_ok=True)
581
+
582
+ result_files = []
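+ # One results file per top-k token budget (5, 10, ..., 80); each line holds hard rationale predictions for one document.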
583
+ for i in range(5, 85, 5):
584
+ result_files.append(
585
+ open(
586
+ os.path.join(
587
+ args.output_dir, "{0}/identifier_results_{1}.json"
588
+ ).format(method_folder[method], i),
589
+ "w",
590
+ )
591
+ )
592
+
593
+ j = 0
594
+ for batch_start in range(0, len(test), test_batch_size):
595
+ batch_elements = test[
596
+ batch_start : min(batch_start + test_batch_size, len(test))
597
+ ]
598
+ targets = [evidence_classes[s.classification] for s in batch_elements]
599
+ targets = torch.tensor(targets, dtype=torch.long, device=device)
600
+ samples_encoding = [
601
+ interned_documents[extract_docid_from_dataset_element(s)]
602
+ for s in batch_elements
603
+ ]
604
+ input_ids = (
605
+ torch.stack(
606
+ [
607
+ samples_encoding[i]["input_ids"]
608
+ for i in range(len(samples_encoding))
609
+ ]
610
+ )
611
+ .squeeze(1)
612
+ .to(device)
613
+ )
614
+ attention_masks = (
615
+ torch.stack(
616
+ [
617
+ samples_encoding[i]["attention_mask"]
618
+ for i in range(len(samples_encoding))
619
+ ]
620
+ )
621
+ .squeeze(1)
622
+ .to(device)
623
+ )
624
+ preds = test_classifier(
625
+ input_ids=input_ids, attention_mask=attention_masks
626
+ )[0]
627
+
628
+ for s in batch_elements:
629
+ doc_name = extract_docid_from_dataset_element(s)
630
+ inp = documents[doc_name].split()
631
+ classification = "neg" if targets.item() == 0 else "pos"
632
+ is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
633
+ if method == "generate_all":
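+ # Build a LaTeX page that lays out the previously rendered per-method visualizations side by side.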
634
+ file_name = "{0}_{1}_{2}.tex".format(
635
+ j, classification, is_classification_correct
636
+ )
637
+ GT_global = os.path.join(
638
+ args.output_dir, "{0}/visual_results_{1}.pdf"
639
+ ).format(method_folder["ground_truth"], j)
640
+ GT_ours = os.path.join(
641
+ args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
642
+ ).format(
643
+ method_folder["transformer_attribution"],
644
+ j,
645
+ classification,
646
+ is_classification_correct,
647
+ )
648
+ CF_ours = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
649
+ method_folder["transformer_attribution"], j
650
+ )
651
+ GT_partial = os.path.join(
652
+ args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
653
+ ).format(
654
+ method_folder["partial_lrp"],
655
+ j,
656
+ classification,
657
+ is_classification_correct,
658
+ )
659
+ CF_partial = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
660
+ method_folder["partial_lrp"], j
661
+ )
662
+ GT_gradcam = os.path.join(
663
+ args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
664
+ ).format(
665
+ method_folder["attn_gradcam"],
666
+ j,
667
+ classification,
668
+ is_classification_correct,
669
+ )
670
+ CF_gradcam = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
671
+ method_folder["attn_gradcam"], j
672
+ )
673
+ GT_lrp = os.path.join(
674
+ args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
675
+ ).format(
676
+ method_folder["lrp"],
677
+ j,
678
+ classification,
679
+ is_classification_correct,
680
+ )
681
+ CF_lrp = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
682
+ method_folder["lrp"], j
683
+ )
684
+ GT_lastattn = os.path.join(
685
+ args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
686
+ ).format(
687
+ method_folder["last_attn"],
688
+ j,
689
+ classification,
690
+ is_classification_correct,
691
+ )
692
+ GT_rollout = os.path.join(
693
+ args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
694
+ ).format(
695
+ method_folder["rollout"],
696
+ j,
697
+ classification,
698
+ is_classification_correct,
699
+ )
700
+ with open(file_name, "w") as f:
701
+ f.write(
702
+ r"""\documentclass[varwidth]{standalone}
703
+ \usepackage{color}
704
+ \usepackage{tcolorbox}
705
+ \usepackage{CJK}
706
+ \tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
707
+ \begin{document}
708
+ \begin{CJK*}{UTF8}{gbsn}
709
+ {\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{
710
+ \setlength{\tabcolsep}{2pt} % Default value: 6pt
711
+ \begin{tabular}{ccc}
712
+ \includegraphics[width=0.32\linewidth]{"""
713
+ + GT_global
714
+ + """}&
715
+ \includegraphics[width=0.32\linewidth]{"""
716
+ + GT_ours
717
+ + """}&
718
+ \includegraphics[width=0.32\linewidth]{"""
719
+ + CF_ours
720
+ + """}\\\\
721
+ (a) & (b) & (c)\\\\
722
+ \includegraphics[width=0.32\linewidth]{"""
723
+ + GT_partial
724
+ + """}&
725
+ \includegraphics[width=0.32\linewidth]{"""
726
+ + CF_partial
727
+ + """}&
728
+ \includegraphics[width=0.32\linewidth]{"""
729
+ + GT_gradcam
730
+ + """}\\\\
731
+ (d) & (e) & (f)\\\\
732
+ \includegraphics[width=0.32\linewidth]{"""
733
+ + CF_gradcam
734
+ + """}&
735
+ \includegraphics[width=0.32\linewidth]{"""
736
+ + GT_lrp
737
+ + """}&
738
+ \includegraphics[width=0.32\linewidth]{"""
739
+ + CF_lrp
740
+ + """}\\\\
741
+ (g) & (h) & (i)\\\\
742
+ \includegraphics[width=0.32\linewidth]{"""
743
+ + GT_lastattn
744
+ + """}&
745
+ \includegraphics[width=0.32\linewidth]{"""
746
+ + GT_rollout
747
+ + """}&\\\\
748
+ (j) & (k)&\\\\
749
+ \end{tabular}
750
+ }}}
751
+ \end{CJK*}
752
+ \end{document}
753
+ )"""
754
+ )
755
+ j += 1
756
+ break
757
+
758
+ if method == "ground_truth":
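+ # Render the human-annotated evidence spans as a binary relevance map for reference.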
759
+ inp_cropped = get_input_words(inp, tokenizer, input_ids[0])
760
+ cam = torch.zeros(len(inp_cropped))
761
+ for evidence in extract_evidence_from_dataset_element(s):
762
+ start_idx = evidence.start_token
763
+ if start_idx >= len(cam):
764
+ break
765
+ end_idx = evidence.end_token
766
+ cam[start_idx:end_idx] = 1
767
+ generate(
768
+ inp_cropped,
769
+ cam,
770
+ (
771
+ os.path.join(
772
+ args.output_dir, "{0}/visual_results_{1}.tex"
773
+ ).format(method_folder[method], j)
774
+ ),
775
+ color="green",
776
+ )
777
+ j = j + 1
778
+ break
779
+ text = tokenizer.convert_ids_to_tokens(input_ids[0])
780
+ classification = "neg" if targets.item() == 0 else "pos"
781
+ is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
782
+ target_idx = targets.item()
783
+ cam_target = method_expl[method](
784
+ input_ids=input_ids,
785
+ attention_mask=attention_masks,
786
+ index=target_idx,
787
+ )[0]
788
+ cam_target = cam_target.clamp(min=0)
789
+ generate(
790
+ text,
791
+ cam_target,
792
+ (
793
+ os.path.join(args.output_dir, "{0}/{1}_GT_{2}_{3}.tex").format(
794
+ method_folder[method],
795
+ j,
796
+ classification,
797
+ is_classification_correct,
798
+ )
799
+ ),
800
+ )
801
+ if method in [
802
+ "transformer_attribution",
803
+ "partial_lrp",
804
+ "attn_gradcam",
805
+ "lrp",
806
+ ]:
807
+ cam_false_class = method_expl[method](
808
+ input_ids=input_ids,
809
+ attention_mask=attention_masks,
810
+ index=1 - target_idx,
811
+ )[0]
812
+ cam_false_class = cam_false_class.clamp(min=0)
813
+ generate(
814
+ text,
815
+ cam_false_class,
816
+ (
817
+ os.path.join(args.output_dir, "{0}/{1}_CF.tex").format(
818
+ method_folder[method], j
819
+ )
820
+ ),
821
+ )
822
+ cam = cam_target
823
+ cam = scores_per_word_from_scores_per_token(
824
+ inp, tokenizer, input_ids[0], cam
825
+ )
826
+ j = j + 1
827
+ doc_name = extract_docid_from_dataset_element(s)
828
+ hard_rationales = []
829
+ for res, i in enumerate(range(5, 85, 5)):
830
+ print("calculating top ", i)
831
+ _, indices = cam.topk(k=i)
832
+ for index in indices.tolist():
833
+ hard_rationales.append(
834
+ {"start_token": index, "end_token": index + 1}
835
+ )
836
+ result_dict = {
837
+ "annotation_id": doc_name,
838
+ "rationales": [
839
+ {
840
+ "docid": doc_name,
841
+ "hard_rationale_predictions": hard_rationales,
842
+ }
843
+ ],
844
+ }
845
+ result_files[res].write(json.dumps(result_dict) + "\n")
846
+
847
+ for i in range(len(result_files)):
848
+ result_files[i].close()
849
+
850
+
851
+ if __name__ == "__main__":
852
+ main()
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_train.py ADDED
@@ -0,0 +1,235 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ from itertools import chain
7
+ from typing import Set
8
+
9
+ import numpy as np
10
+ import torch
11
+ from rationale_benchmark.models.mlp import (AttentiveClassifier,
12
+ BahadanauAttention, RNNEncoder,
13
+ WordEmbedder)
14
+ from rationale_benchmark.models.model_utils import extract_embeddings
15
+ from rationale_benchmark.models.pipeline.evidence_classifier import \
16
+ train_evidence_classifier
17
+ from rationale_benchmark.models.pipeline.evidence_identifier import \
18
+ train_evidence_identifier
19
+ from rationale_benchmark.models.pipeline.pipeline_utils import decode
20
+ from rationale_benchmark.utils import (intern_annotations, intern_documents,
21
+ load_datasets, load_documents,
22
+ write_jsonl)
23
+
24
+ logging.basicConfig(
25
+ level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
26
+ )
27
+ # let's make this more or less deterministic (not resistant to restarts)
28
+ random.seed(12345)
29
+ np.random.seed(67890)
30
+ torch.manual_seed(10111213)
31
+ torch.backends.cudnn.deterministic = True
32
+ torch.backends.cudnn.benchmark = False
33
+
34
+
35
+ def initialize_models(
36
+ params: dict, vocab: Set[str], batch_first: bool, unk_token="UNK"
37
+ ):
38
+ # TODO this is obviously asking for some sort of dependency injection. implement if it saves me time.
39
+ if "embedding_file" in params["embeddings"]:
40
+ embeddings, word_interner, de_interner = extract_embeddings(
41
+ vocab, params["embeddings"]["embedding_file"], unk_token=unk_token
42
+ )
43
+ if torch.cuda.is_available():
44
+ embeddings = embeddings.cuda()
45
+ else:
46
+ raise ValueError("No 'embedding_file' found in params!")
47
+ word_embedder = WordEmbedder(embeddings, params["embeddings"]["dropout"])
48
+ query_encoder = RNNEncoder(
49
+ word_embedder,
50
+ batch_first=batch_first,
51
+ condition=False,
52
+ attention_mechanism=BahadanauAttention(word_embedder.output_dimension),
53
+ )
54
+ document_encoder = RNNEncoder(
55
+ word_embedder,
56
+ batch_first=batch_first,
57
+ condition=True,
58
+ attention_mechanism=BahadanauAttention(
59
+ word_embedder.output_dimension, query_size=query_encoder.output_dimension
60
+ ),
61
+ )
62
+ evidence_identifier = AttentiveClassifier(
63
+ document_encoder,
64
+ query_encoder,
65
+ 2,
66
+ params["evidence_identifier"]["mlp_size"],
67
+ params["evidence_identifier"]["dropout"],
68
+ )
69
+ query_encoder = RNNEncoder(
70
+ word_embedder,
71
+ batch_first=batch_first,
72
+ condition=False,
73
+ attention_mechanism=BahadanauAttention(word_embedder.output_dimension),
74
+ )
75
+ document_encoder = RNNEncoder(
76
+ word_embedder,
77
+ batch_first=batch_first,
78
+ condition=True,
79
+ attention_mechanism=BahadanauAttention(
80
+ word_embedder.output_dimension, query_size=query_encoder.output_dimension
81
+ ),
82
+ )
83
+ evidence_classes = dict(
84
+ (y, x) for (x, y) in enumerate(params["evidence_classifier"]["classes"])
85
+ )
86
+ evidence_classifier = AttentiveClassifier(
87
+ document_encoder,
88
+ query_encoder,
89
+ len(evidence_classes),
90
+ params["evidence_classifier"]["mlp_size"],
91
+ params["evidence_classifier"]["dropout"],
92
+ )
93
+ return (
94
+ evidence_identifier,
95
+ evidence_classifier,
96
+ word_interner,
97
+ de_interner,
98
+ evidence_classes,
99
+ )
100
+
101
+
102
+ def main():
103
+ parser = argparse.ArgumentParser(
104
+ description="""Trains a pipeline model.
105
+
106
+ Step 1 is evidence identification: decide whether a given sentence is evidence or not.
107
+ Step 2 is evidence classification: given an evidence sentence, classify the final outcome for the task (e.g. sentiment or significance).
108
+
109
+ These models should be separated into two separate steps, but at the moment:
110
+ * prep data (load, intern documents, load json)
111
+ * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives
112
+ * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values.
113
+ * train evidence identification
114
+ * convert data for evidence classification - take all rationales + decisions and use this as input
115
+ * train evidence classification
116
+ * decode first the evidence, then run classification for each split
117
+
118
+ """,
119
+ formatter_class=argparse.RawTextHelpFormatter,
120
+ )
121
+ parser.add_argument(
122
+ "--data_dir",
123
+ dest="data_dir",
124
+ required=True,
125
+ help="Which directory contains a {train,val,test}.jsonl file?",
126
+ )
127
+ parser.add_argument(
128
+ "--output_dir",
129
+ dest="output_dir",
130
+ required=True,
131
+ help="Where shall we write intermediate models + final data to?",
132
+ )
133
+ parser.add_argument(
134
+ "--model_params",
135
+ dest="model_params",
136
+ required=True,
137
+ help="JSON file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.)",
138
+ )
139
+ args = parser.parse_args()
140
+ BATCH_FIRST = True
141
+
142
+ with open(args.model_params, "r") as fp:
143
+ logging.debug(f"Loading model parameters from {args.model_params}")
144
+ model_params = json.load(fp)
145
+ train, val, test = load_datasets(args.data_dir)
146
+ docids = set(
147
+ e.docid
148
+ for e in chain.from_iterable(
149
+ chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))
150
+ )
151
+ )
152
+ documents = load_documents(args.data_dir, docids)
153
+ document_vocab = set(chain.from_iterable(chain.from_iterable(documents.values())))
154
+ annotation_vocab = set(
155
+ chain.from_iterable(e.query.split() for e in chain(train, val, test))
156
+ )
157
+ logging.debug(
158
+ f"Loaded {len(documents)} documents with {len(document_vocab)} unique words"
159
+ )
160
+ # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important
161
+ vocab = document_vocab | annotation_vocab
162
+ unk_token = "UNK"
163
+ (
164
+ evidence_identifier,
165
+ evidence_classifier,
166
+ word_interner,
167
+ de_interner,
168
+ evidence_classes,
169
+ ) = initialize_models(
170
+ model_params, vocab, batch_first=BATCH_FIRST, unk_token=unk_token
171
+ )
172
+ logging.debug(
173
+ f"Including annotations, we have {len(vocab)} total words in the data, with embeddings for {len(word_interner)}"
174
+ )
175
+ interned_documents = intern_documents(documents, word_interner, unk_token)
176
+ interned_train = intern_annotations(train, word_interner, unk_token)
177
+ interned_val = intern_annotations(val, word_interner, unk_token)
178
+ interned_test = intern_annotations(test, word_interner, unk_token)
179
+ assert BATCH_FIRST # for correctness of the split dimension for DataParallel
180
+ evidence_identifier, evidence_ident_results = train_evidence_identifier(
181
+ evidence_identifier.cuda(),
182
+ args.output_dir,
183
+ interned_train,
184
+ interned_val,
185
+ interned_documents,
186
+ model_params,
187
+ tensorize_model_inputs=True,
188
+ )
189
+ evidence_classifier, evidence_class_results = train_evidence_classifier(
190
+ evidence_classifier.cuda(),
191
+ args.output_dir,
192
+ interned_train,
193
+ interned_val,
194
+ interned_documents,
195
+ model_params,
196
+ class_interner=evidence_classes,
197
+ tensorize_model_inputs=True,
198
+ )
199
+ pipeline_batch_size = min(
200
+ [
201
+ model_params["evidence_classifier"]["batch_size"],
202
+ model_params["evidence_identifier"]["batch_size"],
203
+ ]
204
+ )
205
+ pipeline_results, train_decoded, val_decoded, test_decoded = decode(
206
+ evidence_identifier,
207
+ evidence_classifier,
208
+ interned_train,
209
+ interned_val,
210
+ interned_test,
211
+ interned_documents,
212
+ evidence_classes,
213
+ pipeline_batch_size,
214
+ tensorize_model_inputs=True,
215
+ )
216
+ write_jsonl(train_decoded, os.path.join(args.output_dir, "train_decoded.jsonl"))
217
+ write_jsonl(val_decoded, os.path.join(args.output_dir, "val_decoded.jsonl"))
218
+ write_jsonl(test_decoded, os.path.join(args.output_dir, "test_decoded.jsonl"))
219
+ with open(
220
+ os.path.join(args.output_dir, "identifier_results.json"), "w"
221
+ ) as ident_output, open(
222
+ os.path.join(args.output_dir, "classifier_results.json"), "w"
223
+ ) as class_output:
224
+ ident_output.write(json.dumps(evidence_ident_results))
225
+ class_output.write(json.dumps(evidence_class_results))
226
+ for k, v in pipeline_results.items():
227
+ if type(v) is dict:
228
+ for k1, v1 in v.items():
229
+ logging.info(f"Pipeline results for {k}, {k1}={v1}")
230
+ else:
231
+ logging.info(f"Pipeline results {k}\t={v}")
232
+
233
+
234
+ if __name__ == "__main__":
235
+ main()
Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_utils.py ADDED
@@ -0,0 +1,1045 @@
1
+ import itertools
2
+ import logging
3
+ from collections import defaultdict, namedtuple
4
+ from itertools import chain
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from rationale_benchmark.metrics import (PositionScoredDocument, Rationale,
11
+ partial_match_score,
12
+ score_hard_rationale_predictions,
13
+ score_soft_tokens)
14
+ from rationale_benchmark.models.model_utils import PaddedSequence
15
+ from rationale_benchmark.utils import Annotation
16
+ from sklearn.metrics import accuracy_score, classification_report
17
+
18
+ SentenceEvidence = namedtuple(
19
+ "SentenceEvidence", "kls ann_id query docid index sentence"
20
+ )
21
+
22
+
23
+ def token_annotations_to_evidence_classification(
24
+ annotations: List[Annotation],
25
+ documents: Dict[str, List[List[Any]]],
26
+ class_interner: Dict[str, int],
27
+ ) -> List[SentenceEvidence]:
28
+ ret = []
29
+ for ann in annotations:
30
+ docid_to_ev = defaultdict(list)
31
+ for evidence in ann.all_evidences():
32
+ docid_to_ev[evidence.docid].append(evidence)
33
+ for docid, evidences in docid_to_ev.items():
34
+ evidences = sorted(evidences, key=lambda ev: ev.start_token)
35
+ text = []
36
+ covered_tokens = set()
37
+ doc = list(chain.from_iterable(documents[docid]))
38
+ for evidence in evidences:
39
+ assert (
40
+ evidence.start_token >= 0
41
+ and evidence.end_token > evidence.start_token
42
+ )
43
+ assert evidence.start_token < len(doc) and evidence.end_token <= len(
44
+ doc
45
+ )
46
+ text.extend(evidence.text)
47
+ new_tokens = set(range(evidence.start_token, evidence.end_token))
48
+ if len(new_tokens & covered_tokens) > 0:
49
+ raise ValueError(
50
+ "Have overlapping token ranges covered in the evidence spans and the implementer was lazy; deal with it"
51
+ )
52
+ covered_tokens |= new_tokens
53
+ assert len(text) > 0
54
+ ret.append(
55
+ SentenceEvidence(
56
+ kls=class_interner[ann.classification],
57
+ query=ann.query,
58
+ ann_id=ann.annotation_id,
59
+ docid=docid,
60
+ index=-1,
61
+ sentence=tuple(text),
62
+ )
63
+ )
64
+ return ret
65
+
66
+
67
+ def annotations_to_evidence_classification(
68
+ annotations: List[Annotation],
69
+ documents: Dict[str, List[List[Any]]],
70
+ class_interner: Dict[str, int],
71
+ include_all: bool,
72
+ ) -> List[SentenceEvidence]:
73
+ """Converts Corpus-Level annotations to Sentence Level relevance judgments.
74
+
75
+ As this module is about a pipelined approach for evidence identification,
76
+ inputs to both an evidence identifier and an evidence classifier need to
77
+ be at the sentence level; this module converts the data into that form.
78
+
79
+ The return value is a flat list of sentence-level SentenceEvidence
80
+ annotations (one per contributing sentence, or per document sentence if include_all).
81
+ """
82
+ ret = []
83
+ for ann in annotations:
84
+ ann_id = ann.annotation_id
85
+ docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
86
+ annotations_for_doc = defaultdict(list)
87
+ for d in docids:
88
+ for index, sent in enumerate(documents[d]):
89
+ annotations_for_doc[d].append(
90
+ SentenceEvidence(
91
+ kls=class_interner[ann.classification],
92
+ query=ann.query,
93
+ ann_id=ann.annotation_id,
94
+ docid=d,
95
+ index=index,
96
+ sentence=tuple(sent),
97
+ )
98
+ )
99
+ if include_all:
100
+ ret.extend(chain.from_iterable(annotations_for_doc.values()))
101
+ else:
102
+ contributes = set()
103
+ for ev in chain.from_iterable(ann.evidences):
104
+ for index in range(ev.start_sentence, ev.end_sentence):
105
+ contributes.add(annotations_for_doc[ev.docid][index])
106
+ ret.extend(contributes)
107
+ assert len(ret) > 0
108
+ return ret
109
+
110
+
111
+ def annotations_to_evidence_identification(
112
+ annotations: List[Annotation], documents: Dict[str, List[List[Any]]]
113
+ ) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
114
+ """Converts Corpus-Level annotations to Sentence Level relevance judgments.
115
+
116
+ As this module is about a pipelined approach for evidence identification,
117
+ inputs to both an evidence identifier and an evidence classifier need to
118
+ be at the sentence level; this module converts the data into that form.
119
+
120
+ The return type is of the form
121
+ annotation id -> docid -> [sentence level annotations]
122
+ """
123
+ ret = defaultdict(dict) # annotation id -> docid -> sentences
124
+ for ann in annotations:
125
+ ann_id = ann.annotation_id
126
+ for ev_group in ann.evidences:
127
+ for ev in ev_group:
128
+ if len(ev.text) == 0:
129
+ continue
130
+ if ev.docid not in ret[ann_id]:
131
+ ret[ann.annotation_id][ev.docid] = []
132
+ # populate the document with "not evidence"; to be filled in later
133
+ for index, sent in enumerate(documents[ev.docid]):
134
+ ret[ann.annotation_id][ev.docid].append(
135
+ SentenceEvidence(
136
+ kls=0,
137
+ query=ann.query,
138
+ ann_id=ann.annotation_id,
139
+ docid=ev.docid,
140
+ index=index,
141
+ sentence=sent,
142
+ )
143
+ )
144
+ # define the evidence sections of the document
145
+ for s in range(ev.start_sentence, ev.end_sentence):
146
+ ret[ann.annotation_id][ev.docid][s] = SentenceEvidence(
147
+ kls=1,
148
+ ann_id=ann.annotation_id,
149
+ query=ann.query,
150
+ docid=ev.docid,
151
+ index=ret[ann.annotation_id][ev.docid][s].index,
152
+ sentence=ret[ann.annotation_id][ev.docid][s].sentence,
153
+ )
154
+ return ret
155
+
156
+
157
+ def annotations_to_evidence_token_identification(
158
+ annotations: List[Annotation],
159
+ source_documents: Dict[str, List[List[str]]],
160
+ interned_documents: Dict[str, List[List[int]]],
161
+ token_mapping: Dict[str, List[List[Tuple[int, int]]]],
162
+ ) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
163
+ # TODO document
164
+ # TODO should we simplify to use only source text?
165
+ ret = defaultdict(lambda: defaultdict(list)) # annotation id -> docid -> sentences
166
+ positive_tokens = 0
167
+ negative_tokens = 0
168
+ for ann in annotations:
169
+ annid = ann.annotation_id
170
+ docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
171
+ sentence_offsets = defaultdict(list) # docid -> [(start, end)]
172
+ classes = defaultdict(list) # docid -> [token is yea or nay]
173
+ for docid in docids:
174
+ start = 0
175
+ assert len(source_documents[docid]) == len(interned_documents[docid])
176
+ for whole_token_sent, wordpiece_sent in zip(
177
+ source_documents[docid], interned_documents[docid]
178
+ ):
179
+ classes[docid].extend([0 for _ in wordpiece_sent])
180
+ end = start + len(wordpiece_sent)
181
+ sentence_offsets[docid].append((start, end))
182
+ start = end
183
+ for ev in chain.from_iterable(ann.evidences):
184
+ if len(ev.text) == 0:
185
+ continue
186
+ flat_token_map = list(chain.from_iterable(token_mapping[ev.docid]))
187
+ if ev.start_token != -1:
188
+ # start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1]
189
+ start, end = (
190
+ flat_token_map[ev.start_token][0],
191
+ flat_token_map[ev.end_token - 1][1],
192
+ )
193
+ else:
194
+ start = flat_token_map[sentence_offsets[ev.start_sentence][0]][0]
195
+ end = flat_token_map[sentence_offsets[ev.end_sentence - 1][1]][1]
196
+ for i in range(start, end):
197
+ classes[ev.docid][i] = 1
198
+ for docid, offsets in sentence_offsets.items():
199
+ token_assignments = classes[docid]
200
+ positive_tokens += sum(token_assignments)
201
+ negative_tokens += len(token_assignments) - sum(token_assignments)
202
+ for s, (start, end) in enumerate(offsets):
203
+ sent = interned_documents[docid][s]
204
+ ret[annid][docid].append(
205
+ SentenceEvidence(
206
+ kls=tuple(token_assignments[start:end]),
207
+ query=ann.query,
208
+ ann_id=ann.annotation_id,
209
+ docid=docid,
210
+ index=s,
211
+ sentence=sent,
212
+ )
213
+ )
214
+ logging.info(
215
+ f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens"
216
+ )
217
+ return ret
218
+
219
+
220
+ def make_preds_batch(
221
+ classifier: nn.Module,
222
+ batch_elements: List[SentenceEvidence],
223
+ device=None,
224
+ criterion: nn.Module = None,
225
+ tensorize_model_inputs: bool = True,
226
+ ) -> Tuple[float, List[float], List[int], List[int]]:
227
+ """Batch predictions
228
+
229
+ Args:
230
+ classifier: a module that looks like an AttentiveClassifier
231
+ batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects.
232
+ device: Optional; what compute device this should run on
233
+ criterion: Optional; a loss function
234
+ tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
235
+ """
236
+ # delete any "None" padding, if any (imposed by the use of the "grouper")
237
+ batch_elements = filter(lambda x: x is not None, batch_elements)
238
+ targets, queries, sentences = zip(
239
+ *[(s.kls, s.query, s.sentence) for s in batch_elements]
240
+ )
241
+ ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
242
+ targets = torch.tensor(targets, dtype=torch.long, device=device)
243
+ if tensorize_model_inputs:
244
+ queries = [torch.tensor(q, dtype=torch.long) for q in queries]
245
+ sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
246
+ preds = classifier(queries, ids, sentences)
247
+ targets = targets.to(device=preds.device)
248
+ if criterion:
249
+ loss = criterion(preds, targets)
250
+ else:
251
+ loss = None
252
+ # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
253
+ hard_preds = torch.argmax(preds.float(), dim=-1)
254
+ return loss, preds, hard_preds, targets
255
+
256
+
257
+ def make_preds_epoch(
258
+ classifier: nn.Module,
259
+ data: List[SentenceEvidence],
260
+ batch_size: int,
261
+ device=None,
262
+ criterion: nn.Module = None,
263
+ tensorize_model_inputs: bool = True,
264
+ ):
265
+ """Predictions for more than one batch.
266
+
267
+ Args:
268
+ classifier: a module that looks like an AttentiveClassifier
269
+ data: a list of elements to make predictions over. These must be SentenceEvidence objects.
270
+ batch_size: the biggest chunk we can fit in one batch.
271
+ device: Optional; what compute device this should run on
272
+ criterion: Optional; a loss function
273
+ tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
274
+ """
275
+ epoch_loss = 0
276
+ epoch_soft_pred = []
277
+ epoch_hard_pred = []
278
+ epoch_truth = []
279
+ batches = _grouper(data, batch_size)
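+ # _grouper pads the last batch with None entries; make_preds_batch filters them out.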
280
+ classifier.eval()
281
+ for batch in batches:
282
+ loss, soft_preds, hard_preds, targets = make_preds_batch(
283
+ classifier,
284
+ batch,
285
+ device,
286
+ criterion=criterion,
287
+ tensorize_model_inputs=tensorize_model_inputs,
288
+ )
289
+ if loss is not None:
290
+ epoch_loss += loss.sum().item()
291
+ epoch_hard_pred.extend(hard_preds)
292
+ epoch_soft_pred.extend(soft_preds.cpu())
293
+ epoch_truth.extend(targets)
294
+ epoch_loss /= len(data)
295
+ epoch_hard_pred = [x.item() for x in epoch_hard_pred]
296
+ epoch_truth = [x.item() for x in epoch_truth]
297
+ return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
298
+
299
+
300
+ def make_token_preds_batch(
301
+ classifier: nn.Module,
302
+ batch_elements: List[SentenceEvidence],
303
+ token_mapping: Dict[str, List[List[Tuple[int, int]]]],
304
+ device=None,
305
+ criterion: nn.Module = None,
306
+ tensorize_model_inputs: bool = True,
307
+ ) -> Tuple[float, List[float], List[int], List[int]]:
308
+ """Batch predictions
309
+
310
+ Args:
311
+ classifier: a module that looks like an AttentiveClassifier
312
+ batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects.
313
+ device: Optional; what compute device this should run on
314
+ criterion: Optional; a loss function
315
+ tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
316
+ """
317
+ # delete any "None" padding, if any (imposed by the use of the "grouper")
318
+ batch_elements = filter(lambda x: x is not None, batch_elements)
319
+ targets, queries, sentences = zip(
320
+ *[(s.kls, s.query, s.sentence) for s in batch_elements]
321
+ )
322
+ ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
323
+ targets = PaddedSequence.autopad(
324
+ [torch.tensor(t, dtype=torch.long, device=device) for t in targets],
325
+ batch_first=True,
326
+ device=device,
327
+ )
328
+ aggregate_spans = [token_mapping[s.docid][s.index] for s in batch_elements]
329
+ if tensorize_model_inputs:
330
+ queries = [torch.tensor(q, dtype=torch.long) for q in queries]
331
+ sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
332
+ preds = classifier(queries, ids, sentences, aggregate_spans)
333
+ targets = targets.to(device=preds.device)
334
+ mask = targets.mask(on=1, off=0, device=preds.device, dtype=torch.float)
335
+ if criterion:
336
+ loss = criterion(
337
+ preds, (targets.data.to(device=preds.device) * mask).squeeze()
338
+ ).sum()
339
+ else:
340
+ loss = None
341
+ hard_preds = [
342
+ torch.round(x).to(dtype=torch.int).cpu() for x in targets.unpad(preds)
343
+ ]
344
+ targets = [[y.item() for y in x] for x in targets.unpad(targets.data.cpu())]
345
+ return loss, preds, hard_preds, targets # targets.unpad(targets.data.cpu())
346
+
347
+
348
+ # TODO fix the arguments
349
+ def make_token_preds_epoch(
350
+ classifier: nn.Module,
351
+ data: List[SentenceEvidence],
352
+ token_mapping: Dict[str, List[List[Tuple[int, int]]]],
353
+ batch_size: int,
354
+ device=None,
355
+ criterion: nn.Module = None,
356
+ tensorize_model_inputs: bool = True,
357
+ ):
358
+ """Predictions for more than one batch.
359
+
360
+ Args:
361
+ classifier: a module that looks like an AttentiveClassifier
362
+ data: a list of elements to make predictions over. These must be SentenceEvidence objects.
363
+ batch_size: the biggest chunk we can fit in one batch.
364
+ device: Optional; what compute device this should run on
365
+ criterion: Optional; a loss function
366
+ tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
367
+ """
368
+ epoch_loss = 0
369
+ epoch_soft_pred = []
370
+ epoch_hard_pred = []
371
+ epoch_truth = []
372
+ batches = _grouper(data, batch_size)
373
+ classifier.eval()
374
+ for batch in batches:
375
+ loss, soft_preds, hard_preds, targets = make_token_preds_batch(
376
+ classifier,
377
+ batch,
378
+ token_mapping,
379
+ device,
380
+ criterion=criterion,
381
+ tensorize_model_inputs=tensorize_model_inputs,
382
+ )
383
+ if loss is not None:
384
+ epoch_loss += loss.sum().item()
385
+ epoch_hard_pred.extend(hard_preds)
386
+ epoch_soft_pred.extend(soft_preds.cpu().tolist())
387
+ epoch_truth.extend(targets)
388
+ epoch_loss /= len(data)
389
+ return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
390
+
391
+
392
+ # copied from https://docs.python.org/3/library/itertools.html#itertools-recipes
393
+ def _grouper(iterable, n, fillvalue=None):
394
+ "Collect data into fixed-length chunks or blocks"
395
+ # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
396
+ args = [iter(iterable)] * n
397
+ return itertools.zip_longest(*args, fillvalue=fillvalue)
398
+
399
+
400
+ def score_rationales(
401
+ truth: List[Annotation],
402
+ documents: Dict[str, List[List[int]]],
403
+ input_data: List[SentenceEvidence],
404
+ scores: List[float],
405
+ ) -> dict:
406
+ results = {}
407
+ doc_to_sent_scores = dict() # (annid, docid) -> [sentence scores]
408
+ for sent, score in zip(input_data, scores):
409
+ k = (sent.ann_id, sent.docid)
410
+ if k not in doc_to_sent_scores:
411
+ doc_to_sent_scores[k] = [0.0 for _ in range(len(documents[sent.docid]))]
412
+ if not isinstance(score[1], float):
413
+ score[1] = score[1].item()
414
+ doc_to_sent_scores[(sent.ann_id, sent.docid)][sent.index] = score[1]
415
+ # hard rationale scoring
416
+ best_sentence = {k: np.argmax(np.array(v)) for k, v in doc_to_sent_scores.items()}
417
+ predicted_rationales = []
418
+ for (ann_id, docid), sent_idx in best_sentence.items():
419
+ start_token = sum(len(s) for s in documents[docid][:sent_idx])
420
+ end_token = start_token + len(documents[docid][sent_idx])
421
+ predicted_rationales.append(Rationale(ann_id, docid, start_token, end_token))
422
+ true_rationales = list(
423
+ chain.from_iterable(Rationale.from_annotation(rat) for rat in truth)
424
+ )
425
+
426
+ results["hard_rationale_scores"] = score_hard_rationale_predictions(
427
+ true_rationales, predicted_rationales
428
+ )
429
+ results["hard_rationale_partial_match_scores"] = partial_match_score(
430
+ true_rationales, predicted_rationales, [0.5]
431
+ )
432
+
433
+ # soft rationale scoring
434
+ instance_format = []
435
+ for (ann_id, docid), sentences in doc_to_sent_scores.items():
436
+ soft_token_predictions = []
437
+ for sent_score, sent_text in zip(sentences, documents[docid]):
438
+ soft_token_predictions.extend(sent_score for _ in range(len(sent_text)))
439
+ instance_format.append(
440
+ {
441
+ "annotation_id": ann_id,
442
+ "rationales": [
443
+ {
444
+ "docid": docid,
445
+ "soft_rationale_predictions": soft_token_predictions,
446
+ "soft_sentence_predictions": sentences,
447
+ }
448
+ ],
449
+ }
450
+ )
451
+ flattened_documents = {
452
+ k: list(chain.from_iterable(v)) for k, v in documents.items()
453
+ }
454
+ token_scoring_format = PositionScoredDocument.from_results(
455
+ instance_format, truth, flattened_documents, use_tokens=True
456
+ )
457
+ results["soft_token_scores"] = score_soft_tokens(token_scoring_format)
458
+ sentence_scoring_format = PositionScoredDocument.from_results(
459
+ instance_format, truth, documents, use_tokens=False
460
+ )
461
+ results["soft_sentence_scores"] = score_soft_tokens(sentence_scoring_format)
462
+ return results
463
+
464
+
465
+ def decode(
466
+ evidence_identifier: nn.Module,
467
+ evidence_classifier: nn.Module,
468
+ train: List[Annotation],
469
+ val: List[Annotation],
470
+ test: List[Annotation],
471
+ docs: Dict[str, List[List[int]]],
472
+ class_interner: Dict[str, int],
473
+ batch_size: int,
474
+ tensorize_model_inputs: bool,
475
+ decoding_docs: Dict[str, List[Any]] = None,
476
+ ) -> dict:
477
+ """Identifies and then classifies evidence
478
+
479
+ Args:
480
+ evidence_identifier: a module for identifying evidence statements
481
+ evidence_classifier: a module for making a classification based on evidence statements
482
+ train: A List of interned Annotations
483
+ val: A List of interned Annotations
484
+ test: A List of interned Annotations
485
+ docs: A Dict of Documents, which are interned sentences.
486
+ class_interner: Converts an Annotation's final class into ints
487
+ batch_size: how big should our batches be?
488
+ tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
489
+ """
490
+ device = None
491
+ class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
492
+ if decoding_docs is None:
493
+ decoding_docs = docs
494
+
495
+ def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
496
+ """Prepares data for evidence identification and classification.
497
+
498
+ Creates paired evaluation data, wherein each (annotation, docid, sentence, kls)
499
+ tuplet appears first as the kls determining if the sentence is evidence, and
500
+ secondarily what the overall classification for the (annotation/docid) pair is.
501
+ This allows selection based on model scores of the evidence_identifier for
502
+ input to the evidence_classifier.
503
+ """
504
+ identification_data = annotations_to_evidence_identification(data, docs)
505
+ classification_data = annotations_to_evidence_classification(
506
+ data, docs, class_interner, include_all=True
507
+ )
508
+ ann_doc_sents = defaultdict(
509
+ lambda: defaultdict(dict)
510
+ ) # ann id -> docid -> sent idx -> sent data
511
+ ret = []
512
+ for sent_ev in classification_data:
513
+ id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index]
514
+ ret.append((id_data, sent_ev))
515
+ assert id_data.ann_id == sent_ev.ann_id
516
+ assert id_data.docid == sent_ev.docid
517
+ assert id_data.index == sent_ev.index
518
+ assert len(ret) == len(classification_data)
519
+ return ret
520
+
521
+ def decode_batch(
522
+ data: List[Tuple[SentenceEvidence, SentenceEvidence]],
523
+ name: str,
524
+ score: bool = False,
525
+ annotations: List[Annotation] = None,
526
+ ) -> dict:
527
+ """Identifies evidence statements and then makes classifications based on it.
528
+
529
+ Args:
530
+ data: a paired list of SentenceEvidences, differing only in the kls field.
531
+ The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
532
+ name: a name for a results dict
533
+ """
534
+
535
+ num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
536
+ logging.info(
537
+ f"Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations"
538
+ )
539
+ identifier_data, classifier_data = zip(*data)
540
+ results = dict()
541
+ IdentificationClassificationResult = namedtuple(
542
+ "IdentificationClassificationResult",
543
+ "identification_data classification_data soft_identification hard_identification soft_classification hard_classification",
544
+ )
545
+ with torch.no_grad():
546
+ # make predictions for the evidence_identifier
547
+ evidence_identifier.eval()
548
+ evidence_classifier.eval()
549
+
550
+ (
551
+ _,
552
+ soft_identification_preds,
553
+ hard_identification_preds,
554
+ _,
555
+ ) = make_preds_epoch(
556
+ evidence_identifier,
557
+ identifier_data,
558
+ batch_size,
559
+ device,
560
+ tensorize_model_inputs=tensorize_model_inputs,
561
+ )
562
+ assert len(soft_identification_preds) == len(data)
563
+ identification_results = defaultdict(list)
564
+ for id_data, cls_data, soft_id_pred, hard_id_pred in zip(
565
+ identifier_data,
566
+ classifier_data,
567
+ soft_identification_preds,
568
+ hard_identification_preds,
569
+ ):
570
+ res = IdentificationClassificationResult(
571
+ identification_data=id_data,
572
+ classification_data=cls_data,
573
+ # 1 is p(evidence|sent,query)
574
+ soft_identification=soft_id_pred[1].float().item(),
575
+ hard_identification=hard_id_pred,
576
+ soft_classification=None,
577
+ hard_classification=False,
578
+ )
579
+ identification_results[(id_data.ann_id, id_data.docid)].append(res)
580
+
581
+ best_identification_results = {
582
+ key: max(value, key=lambda x: x.soft_identification)
583
+ for key, value in identification_results.items()
584
+ }
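+ # Keep only the highest-scoring evidence sentence per (annotation, document) pair; it is fed to the evidence classifier.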
585
+ logging.info(
586
+ f"Selected the best sentence for {len(identification_results)} examples from a total of {len(soft_identification_preds)} sentences"
587
+ )
588
+ ids, classification_data = zip(
589
+ *[
590
+ (k, v.classification_data)
591
+ for k, v in best_identification_results.items()
592
+ ]
593
+ )
594
+ (
595
+ _,
596
+ soft_classification_preds,
597
+ hard_classification_preds,
598
+ classification_truth,
599
+ ) = make_preds_epoch(
600
+ evidence_classifier,
601
+ classification_data,
602
+ batch_size,
603
+ device,
604
+ tensorize_model_inputs=tensorize_model_inputs,
605
+ )
606
+ classification_results = dict()
607
+ for eyeD, soft_class, hard_class in zip(
608
+ ids, soft_classification_preds, hard_classification_preds
609
+ ):
610
+ input_id_result = best_identification_results[eyeD]
611
+ res = IdentificationClassificationResult(
612
+ identification_data=input_id_result.identification_data,
613
+ classification_data=input_id_result.classification_data,
614
+ soft_identification=input_id_result.soft_identification,
615
+ hard_identification=input_id_result.hard_identification,
616
+ soft_classification=soft_class,
617
+ hard_classification=hard_class,
618
+ )
619
+ classification_results[eyeD] = res
620
+
621
+ if score:
622
+ truth = []
623
+ pred = []
624
+ for res in classification_results.values():
625
+ truth.append(res.classification_data.kls)
626
+ pred.append(res.hard_classification)
627
+ # results[f'{name}_f1'] = classification_report(classification_truth, pred, target_names=class_labels, output_dict=True)
628
+ results[f"{name}_f1"] = classification_report(
629
+ classification_truth,
630
+ hard_classification_preds,
631
+ target_names=class_labels,
632
+ output_dict=True,
633
+ )
634
+ results[f"{name}_acc"] = accuracy_score(
635
+ classification_truth, hard_classification_preds
636
+ )
637
+ results[f"{name}_rationale"] = score_rationales(
638
+ annotations,
639
+ decoding_docs,
640
+ identifier_data,
641
+ soft_identification_preds,
642
+ )
643
+
644
+ # turn the above results into a format suitable for scoring via the rationale scorer
645
+ # n.b. the sentence-level evidence predictions (hard and soft) are
646
+ # broadcast to the token level for scoring. The comprehensiveness class
647
+ # score is also a lie since the pipeline model above is faithful by
648
+ # design.
649
+ decoded = dict()
650
+ decoded_scores = defaultdict(list)
651
+ for (ann_id, docid), pred in classification_results.items():
652
+ sentence_prediction_scores = [
653
+ x.soft_identification
654
+ for x in identification_results[(ann_id, docid)]
655
+ ]
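+ # convert the selected sentence's position into token offsets within the full document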
656
+ sentence_start_token = sum(
657
+ len(s)
658
+ for s in decoding_docs[docid][: pred.identification_data.index]
659
+ )
660
+ sentence_end_token = sentence_start_token + len(
661
+ decoding_docs[docid][pred.classification_data.index]
662
+ )
663
+ hard_rationale_predictions = [
664
+ {
665
+ "start_token": sentence_start_token,
666
+ "end_token": sentence_end_token,
667
+ }
668
+ ]
669
+ soft_rationale_predictions = []
670
+ for sent_result in sorted(
671
+ identification_results[(ann_id, docid)],
672
+ key=lambda x: x.identification_data.index,
673
+ ):
674
+ soft_rationale_predictions.extend(
675
+ sent_result.soft_identification
676
+ for _ in range(
677
+ len(
678
+ decoding_docs[sent_result.identification_data.docid][
679
+ sent_result.identification_data.index
680
+ ]
681
+ )
682
+ )
683
+ )
684
+ if ann_id not in decoded:
685
+ decoded[ann_id] = {
686
+ "annotation_id": ann_id,
687
+ "rationales": [],
688
+ "classification": class_labels[pred.hard_classification],
689
+ "classification_scores": {
690
+ class_labels[i]: s.item()
691
+ for i, s in enumerate(pred.soft_classification)
692
+ },
693
+ # TODO this should turn into the data distribution for the predicted class
694
+ # "comprehensiveness_classification_scores": 0.0,
695
+ "truth": pred.classification_data.kls,
696
+ }
697
+ decoded[ann_id]["rationales"].append(
698
+ {
699
+ "docid": docid,
700
+ "hard_rationale_predictions": hard_rationale_predictions,
701
+ "soft_rationale_predictions": soft_rationale_predictions,
702
+ "soft_sentence_predictions": sentence_prediction_scores,
703
+ }
704
+ )
705
+ decoded_scores[ann_id].append(pred.soft_classification)
706
+
707
+ # in practice, this is always a single element operation:
708
+ # in evidence inference (prompt is really a prompt + document), fever (we split documents into two classifications), movies (you only have one opinion about a movie), or boolQ (single document prompts)
709
+ # this exists to support weird models we *might* implement for cose/esnli
710
+ for ann_id, scores_list in decoded_scores.items():
711
+ scores = torch.stack(scores_list)
712
+ score_avg = torch.mean(scores, dim=0)
713
+ # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
714
+ hard_pred = torch.argmax(score_avg.float()).item()
715
+ decoded[ann_id]["classification"] = class_labels[hard_pred]
716
+ decoded[ann_id]["classification_scores"] = {
717
+ class_labels[i]: s.item() for i, s in enumerate(score_avg)
718
+ }
719
+ return results, list(decoded.values())
720
+
721
+ test_results, test_decoded = decode_batch(prep(test), "test", score=False)
722
+ val_results, val_decoded = dict(), []
723
+ train_results, train_decoded = dict(), []
724
+ # val_results, val_decoded = decode_batch(prep(val), 'val', score=True, annotations=val)
725
+ # train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train)
726
+ return (
727
+ dict(**train_results, **val_results, **test_results),
728
+ train_decoded,
729
+ val_decoded,
730
+ test_decoded,
731
+ )
732
+
733
+
734
+ def decode_evidence_tokens_and_classify(
735
+ evidence_token_identifier: nn.Module,
736
+ evidence_classifier: nn.Module,
737
+ train: List[Annotation],
738
+ val: List[Annotation],
739
+ test: List[Annotation],
740
+ docs: Dict[str, List[List[int]]],
741
+ source_documents: Dict[str, List[List[str]]],
742
+ token_mapping: Dict[str, List[List[Tuple[int, int]]]],
743
+ class_interner: Dict[str, int],
744
+ batch_size: int,
745
+ decoding_docs: Dict[str, List[Any]],
746
+ use_cose_hack: bool = False,
747
+ ) -> dict:
748
+ """Identifies and then classifies evidence
749
+
750
+ Args:
751
+ evidence_token_identifier: a module for identifying evidence statements
752
+ evidence_classifier: a module for making a classification based on evidence statements
753
+ train: A List of interned Annotations
754
+ val: A List of interned Annotations
755
+ test: A List of interned Annotations
756
+ docs: A Dict of Documents, which are interned sentences.
757
+ class_interner: Converts an Annotation's final class into ints
758
+ batch_size: how big should our batches be?
759
+ """
760
+ device = None
761
+ class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
762
+ if decoding_docs is None:
763
+ decoding_docs = docs
764
+
765
+ def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
766
+ """Prepares data for evidence identification and classification.
767
+
768
+ Creates paired evaluation data, wherein each (annotation, docid, sentence, kls)
769
+ tuple appears first as the kls determining if the sentence is evidence, and
770
+ secondarily what the overall classification for the (annotation/docid) pair is.
771
+ This allows selection based on model scores of the evidence_token_identifier for
772
+ input to the evidence_classifier.
773
+ """
774
+ # identification_data = annotations_to_evidence_identification(data, docs)
775
+ classification_data = token_annotations_to_evidence_classification(
776
+ data, docs, class_interner
777
+ )
778
+ # annotation id -> docid -> [SentenceEvidence])
779
+ identification_data = annotations_to_evidence_token_identification(
780
+ data,
781
+ source_documents=decoding_docs,
782
+ interned_documents=docs,
783
+ token_mapping=token_mapping,
784
+ )
785
+ ann_doc_sents = defaultdict(
786
+ lambda: defaultdict(dict)
787
+ ) # ann id -> docid -> sent idx -> sent data
788
+ ret = []
789
+ for sent_ev in classification_data:
790
+ id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index]
791
+ ret.append((id_data, sent_ev))
792
+ assert id_data.ann_id == sent_ev.ann_id
793
+ assert id_data.docid == sent_ev.docid
794
+ # assert id_data.index == sent_ev.index
795
+ assert len(ret) == len(classification_data)
796
+ return ret
797
+
798
+ def decode_batch(
799
+ data: List[Tuple[SentenceEvidence, SentenceEvidence]],
800
+ name: str,
801
+ score: bool = False,
802
+ annotations: List[Annotation] = None,
803
+ class_labels: dict = class_labels,
804
+ ) -> dict:
805
+ """Identifies evidence statements and then makes classifications based on it.
806
+
807
+ Args:
808
+ data: a paired list of SentenceEvidences, differing only in the kls field.
809
+ The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
810
+ name: a name for a results dict
811
+ """
812
+
813
+ num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
814
+ logging.info(
815
+ f"Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations"
816
+ )
817
+ identifier_data, classifier_data = zip(*data)
818
+ results = dict()
819
+ with torch.no_grad():
820
+ # make predictions for the evidence_token_identifier
821
+ evidence_token_identifier.eval()
822
+ evidence_classifier.eval()
823
+
824
+ (
825
+ _,
826
+ soft_identification_preds,
827
+ hard_identification_preds,
828
+ id_preds_truth,
829
+ ) = make_token_preds_epoch(
830
+ evidence_token_identifier,
831
+ identifier_data,
832
+ token_mapping,
833
+ batch_size,
834
+ device,
835
+ tensorize_model_inputs=True,
836
+ )
837
+ assert len(soft_identification_preds) == len(data)
838
+ evidence_only_cls = []
839
+ for id_data, cls_data, soft_id_pred, hard_id_pred in zip(
840
+ identifier_data,
841
+ classifier_data,
842
+ soft_identification_preds,
843
+ hard_identification_preds,
844
+ ):
845
+ assert cls_data.ann_id == id_data.ann_id
846
+ sent = []
847
+ for start, end in token_mapping[cls_data.docid][0]:
848
+ if bool(hard_id_pred[start]):
849
+ sent.extend(id_data.sentence[start:end])
850
+ # assert len(sent) > 0
851
+ new_cls_data = SentenceEvidence(
852
+ cls_data.kls,
853
+ cls_data.ann_id,
854
+ cls_data.query,
855
+ cls_data.docid,
856
+ cls_data.index,
857
+ tuple(sent),
858
+ )
859
+ evidence_only_cls.append(new_cls_data)
860
+ (
861
+ _,
862
+ soft_classification_preds,
863
+ hard_classification_preds,
864
+ classification_truth,
865
+ ) = make_preds_epoch(
866
+ evidence_classifier,
867
+ evidence_only_cls,
868
+ batch_size,
869
+ device,
870
+ tensorize_model_inputs=True,
871
+ )
872
+
873
+ if use_cose_hack:
874
+ logging.info(
875
+ "Reformatting identification and classification results to fit COS-E"
876
+ )
877
+ grouping = 5
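+ # COS-E questions come with 5 answer choices, so consecutive groups of 5 rows belong to one example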
878
+ new_soft_identification_preds = []
879
+ new_hard_identification_preds = []
880
+ new_id_preds_truth = []
881
+ new_soft_classification_preds = []
882
+ new_hard_classification_preds = []
883
+ new_classification_truth = []
884
+ new_identifier_data = []
885
+ class_labels = []
886
+
887
+ # TODO fix the labels for COS-E
888
+ for i in range(0, len(soft_identification_preds), grouping):
889
+ cls_scores = torch.stack(
890
+ soft_classification_preds[i : i + grouping]
891
+ )
892
+ cls_scores = nn.functional.softmax(cls_scores, dim=-1)
893
+ cls_scores = cls_scores[:, 1]
894
+ choice = torch.argmax(cls_scores)
895
+ cls_labels = [
896
+ x.ann_id.split("_")[-1]
897
+ for x in evidence_only_cls[i : i + grouping]
898
+ ]
899
+ class_labels = cls_labels # we need to update the class labels because of the terrible hackery used to train this
900
+ cls_truths = [x.kls for x in evidence_only_cls[i : i + grouping]]
901
+ # cls_choice = evidence_only_cls[i + choice].ann_id.split('_')[-1]
902
+ cls_truth = np.argmax(cls_truths)
903
+ new_soft_identification_preds.append(
904
+ soft_identification_preds[i + choice]
905
+ )
906
+ new_hard_identification_preds.append(
907
+ hard_identification_preds[i + choice]
908
+ )
909
+ new_id_preds_truth.append(id_preds_truth[i + choice])
910
+ new_soft_classification_preds.append(
911
+ soft_classification_preds[i + choice]
912
+ )
913
+ new_hard_classification_preds.append(choice)
914
+ new_identifier_data.append(identifier_data[i + choice])
915
+ # new_hard_classification_preds.append(hard_classification_preds[i + choice])
916
+ # new_classification_truth.append(classification_truth[i + choice])
917
+ new_classification_truth.append(cls_truth)
918
+
919
+ soft_identification_preds = new_soft_identification_preds
920
+ hard_identification_preds = new_hard_identification_preds
921
+ id_preds_truth = new_id_preds_truth
922
+ soft_classification_preds = new_soft_classification_preds
923
+ hard_classification_preds = new_hard_classification_preds
924
+ classification_truth = new_classification_truth
925
+ identifier_data = new_identifier_data
926
+ if score:
927
+ results[f"{name}_f1"] = classification_report(
928
+ classification_truth,
929
+ hard_classification_preds,
930
+ target_names=class_labels,
931
+ output_dict=True,
932
+ )
933
+ results[f"{name}_acc"] = accuracy_score(
934
+ classification_truth, hard_classification_preds
935
+ )
936
+ results[f"{name}_token_pred_acc"] = accuracy_score(
937
+ list(chain.from_iterable(id_preds_truth)),
938
+ list(chain.from_iterable(hard_identification_preds)),
939
+ )
940
+ results[f"{name}_token_pred_f1"] = classification_report(
941
+ list(chain.from_iterable(id_preds_truth)),
942
+ list(chain.from_iterable(hard_identification_preds)),
943
+ output_dict=True,
944
+ )
945
+ # TODO for token level stuff!
946
+ soft_id_scores = [
947
+ [1 - x, x] for x in chain.from_iterable(soft_identification_preds)
948
+ ]
949
+ results[f"{name}_rationale"] = score_rationales(
950
+ annotations, decoding_docs, identifier_data, soft_id_scores
951
+ )
952
+ logging.info(f"Results: {results}")
953
+
954
+ # turn the above results into a format suitable for scoring via the rationale scorer
955
+ # n.b. the sentence-level evidence predictions (hard and soft) are
956
+ # broadcast to the token level for scoring. The comprehensiveness class
957
+ # score is also a lie since the pipeline model above is faithful by
958
+ # design.
959
+ decoded = dict()
960
+ scores = []
961
+ assert len(identifier_data) == len(soft_identification_preds)
962
+ for (
963
+ id_data,
964
+ soft_id_pred,
965
+ hard_id_pred,
966
+ soft_cls_preds,
967
+ hard_cls_pred,
968
+ ) in zip(
969
+ identifier_data,
970
+ soft_identification_preds,
971
+ hard_identification_preds,
972
+ soft_classification_preds,
973
+ hard_classification_preds,
974
+ ):
975
+ docid = id_data.docid
976
+ if use_cose_hack:
977
+ docid = "_".join(docid.split("_")[0:-1])
978
+ assert len(docid) > 0
979
+ rationales = {
980
+ "docid": docid,
981
+ "hard_rationale_predictions": [],
982
+ # token level classifications, a value must be provided per-token
983
+ # in an ideal world, these correspond to the hard-decoding above.
984
+ "soft_rationale_predictions": [],
985
+ # sentence level classifications, a value must be provided for every
986
+ # sentence in each document, or not at all
987
+ "soft_sentence_predictions": [1.0],
988
+ }
989
+ last = -1
990
+ start_span = -1
991
+ for pos, (start, _) in enumerate(token_mapping[id_data.docid][0]):
992
+ rationales["soft_rationale_predictions"].append(soft_id_pred[start])
993
+ if bool(hard_id_pred[start]):
994
+ if start_span == -1:
995
+ start_span = pos
996
+ last = pos
997
+ else:
998
+ if start_span != -1:
999
+ rationales["hard_rationale_predictions"].append(
1000
+ {
1001
+ "start_token": start_span,
1002
+ "end_token": last + 1,
1003
+ }
1004
+ )
1005
+ last = -1
1006
+ start_span = -1
1007
+ if start_span != -1:
1008
+ rationales["hard_rationale_predictions"].append(
1009
+ {
1010
+ "start_token": start_span,
1011
+ "end_token": last + 1,
1012
+ }
1013
+ )
1014
+
1015
+ ann_id = id_data.ann_id
1016
+ if use_cose_hack:
1017
+ ann_id = "_".join(ann_id.split("_")[0:-1])
1018
+ soft_cls_preds = nn.functional.softmax(soft_cls_preds, dim=-1)
1019
+ decoded[id_data.ann_id] = {
1020
+ "annotation_id": ann_id,
1021
+ "rationales": [rationales],
1022
+ "classification": class_labels[hard_cls_pred],
1023
+ "classification_scores": {
1024
+ class_labels[i]: score.item()
1025
+ for i, score in enumerate(soft_cls_preds)
1026
+ },
1027
+ }
1028
+ return results, list(decoded.values())
1029
+
1030
+ # test_results, test_decoded = dict(), []
1031
+ # val_results, val_decoded = dict(), []
1032
+ train_results, train_decoded = dict(), []
1033
+ val_results, val_decoded = decode_batch(
1034
+ prep(val), "val", score=True, annotations=val, class_labels=class_labels
1035
+ )
1036
+ test_results, test_decoded = decode_batch(
1037
+ prep(test), "test", score=False, class_labels=class_labels
1038
+ )
1039
+ # train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train, class_labels=class_labels)
1040
+ return (
1041
+ dict(**train_results, **val_results, **test_results),
1042
+ train_decoded,
1043
+ val_decoded,
1044
+ test_decoded,
1045
+ )
Transformer-Explainability/BERT_rationale_benchmark/models/sequence_taggers.py ADDED
@@ -0,0 +1,78 @@
1
+ from typing import Any, List, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from rationale_benchmark.models.model_utils import PaddedSequence
6
+ from transformers import BertModel
7
+
8
+
9
+ class BertTagger(nn.Module):
10
+ def __init__(
11
+ self,
12
+ bert_dir: str,
13
+ pad_token_id: int,
14
+ cls_token_id: int,
15
+ sep_token_id: int,
16
+ max_length: int = 512,
17
+ use_half_precision=True,
18
+ ):
19
+ super(BertTagger, self).__init__()
20
+ self.sep_token_id = sep_token_id
21
+ self.cls_token_id = cls_token_id
22
+ self.pad_token_id = pad_token_id
23
+ self.max_length = max_length
24
+ bert = BertModel.from_pretrained(bert_dir)
25
+ if use_half_precision:
26
+ import apex
27
+
28
+ bert = bert.half()
29
+ self.bert = bert
30
+ self.relevance_tagger = nn.Sequential(
31
+ nn.Linear(self.bert.config.hidden_size, 1), nn.Sigmoid()
32
+ )
33
+
34
+ def forward(
35
+ self,
36
+ query: List[torch.tensor],
37
+ docids: List[Any],
38
+ document_batch: List[torch.tensor],
39
+ aggregate_spans: List[Tuple[int, int]],
40
+ ):
41
+ assert len(query) == len(document_batch)
42
+ # Note on device management: since distributed training is enabled, the inputs to this module can be on
43
+ # *any* device (preferably CPU, since we wrap and unwrap the module). We want to keep these tensors on the
44
+ # input device (assuming CPU) for as long as possible, for cheap memory access.
45
+ target_device = next(self.parameters()).device
46
+ # cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
47
+ sep_token = torch.tensor([self.sep_token_id]).to(
48
+ device=document_batch[0].device
49
+ )
50
+ input_tensors = []
51
+ query_lengths = []
52
+ for q, d in zip(query, document_batch):
53
+ if len(q) + len(d) + 1 > self.max_length:
54
+ d = d[: (self.max_length - len(q) - 1)]
55
+ input_tensors.append(torch.cat([q, sep_token, d]))
56
+ query_lengths.append(q.size()[0])
57
+ bert_input = PaddedSequence.autopad(
58
+ input_tensors,
59
+ batch_first=True,
60
+ padding_value=self.pad_token_id,
61
+ device=target_device,
62
+ )
63
+ outputs = self.bert(
64
+ bert_input.data,
65
+ attention_mask=bert_input.mask(
66
+ on=0.0, off=float("-inf"), dtype=torch.float, device=target_device
67
+ ),
68
+ )
69
+ hidden = outputs[0]
70
+ classes = self.relevance_tagger(hidden)
71
+ ret = []
72
+ for ql, cls, doc in zip(query_lengths, classes, document_batch):
73
+ start = ql + 1
74
+ end = start + len(doc)
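+ # keep only the per-token relevance scores for the document span (skip the query tokens and the [SEP])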
75
+ ret.append(cls[ql + 1 : end])
76
+ return PaddedSequence.autopad(
77
+ ret, batch_first=True, padding_value=0, device=target_device
78
+ ).data.squeeze(dim=-1)
Transformer-Explainability/BERT_rationale_benchmark/utils.py ADDED
@@ -0,0 +1,251 @@
1
+ import json
2
+ import os
3
+ from dataclasses import asdict, dataclass, is_dataclass
4
+ from itertools import chain
5
+ from typing import Dict, FrozenSet, List, Set, Tuple, Union
6
+
7
+
8
+ @dataclass(eq=True, frozen=True)
9
+ class Evidence:
10
+ """
11
+ (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience.
12
+ Args:
13
+ text: Some representation of the evidence text
14
+ docid: Some identifier for the document
15
+ start_token: The canonical start token, inclusive
16
+ end_token: The canonical end token, exclusive
17
+ start_sentence: Best guess start sentence, inclusive
18
+ end_sentence: Best guess end sentence, exclusive
19
+ """
20
+
21
+ text: Union[str, Tuple[int], Tuple[str]]
22
+ docid: str
23
+ start_token: int = -1
24
+ end_token: int = -1
25
+ start_sentence: int = -1
26
+ end_sentence: int = -1
27
+
28
+
29
+ @dataclass(eq=True, frozen=True)
30
+ class Annotation:
31
+ """
32
+ Args:
33
+ annotation_id: unique ID for this annotation element
34
+ query: some representation of a query string
35
+ evidences: a set of "evidence groups".
36
+ Each evidence group is:
37
+ * sufficient to respond to the query (or justify an answer)
38
+ * composed of one or more Evidences
39
+ * may have multiple documents in it (depending on the dataset)
40
+ - e-snli has multiple documents
41
+ - other datasets do not
42
+ classification: str
43
+ query_type: Optional str, additional information about the query
44
+ docids: a set of docids in which one may find evidence.
45
+ """
46
+
47
+ annotation_id: str
48
+ query: Union[str, Tuple[int]]
49
+ evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]]
50
+ classification: str
51
+ query_type: str = None
52
+ docids: Set[str] = None
53
+
54
+ def all_evidences(self) -> Tuple[Evidence]:
55
+ return tuple(list(chain.from_iterable(self.evidences)))
56
+
57
+
58
+ def annotations_to_jsonl(annotations, output_file):
59
+ with open(output_file, "w") as of:
60
+ for ann in sorted(annotations, key=lambda x: x.annotation_id):
61
+ as_json = _annotation_to_dict(ann)
62
+ as_str = json.dumps(as_json, sort_keys=True)
63
+ of.write(as_str)
64
+ of.write("\n")
65
+
66
+
67
+ def _annotation_to_dict(dc):
68
+ # convenience method
69
+ if is_dataclass(dc):
70
+ d = asdict(dc)
71
+ ret = dict()
72
+ for k, v in d.items():
73
+ ret[k] = _annotation_to_dict(v)
74
+ return ret
75
+ elif isinstance(dc, dict):
76
+ ret = dict()
77
+ for k, v in dc.items():
78
+ k = _annotation_to_dict(k)
79
+ v = _annotation_to_dict(v)
80
+ ret[k] = v
81
+ return ret
82
+ elif isinstance(dc, str):
83
+ return dc
84
+ elif isinstance(dc, (set, frozenset, list, tuple)):
85
+ ret = []
86
+ for x in dc:
87
+ ret.append(_annotation_to_dict(x))
88
+ return tuple(ret)
89
+ else:
90
+ return dc
91
+
92
+
93
+ def load_jsonl(fp: str) -> List[dict]:
94
+ ret = []
95
+ with open(fp, "r") as inf:
96
+ for line in inf:
97
+ content = json.loads(line)
98
+ ret.append(content)
99
+ return ret
100
+
101
+
102
+ def write_jsonl(jsonl, output_file):
103
+ with open(output_file, "w") as of:
104
+ for js in jsonl:
105
+ as_str = json.dumps(js, sort_keys=True)
106
+ of.write(as_str)
107
+ of.write("\n")
108
+
109
+
110
+ def annotations_from_jsonl(fp: str) -> List[Annotation]:
111
+ ret = []
112
+ with open(fp, "r") as inf:
113
+ for line in inf:
114
+ content = json.loads(line)
115
+ ev_groups = []
116
+ for ev_group in content["evidences"]:
117
+ ev_group = tuple([Evidence(**ev) for ev in ev_group])
118
+ ev_groups.append(ev_group)
119
+ content["evidences"] = frozenset(ev_groups)
120
+ ret.append(Annotation(**content))
121
+ return ret
122
+
123
+
124
+ def load_datasets(
125
+ data_dir: str,
126
+ ) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]:
127
+ """Loads a training, validation, and test dataset
128
+
129
+ Each dataset is assumed to have been serialized by annotations_to_jsonl,
130
+ that is, it is a list of JSON-serialized Annotation instances.
131
+ """
132
+ train_data = annotations_from_jsonl(os.path.join(data_dir, "train.jsonl"))
133
+ val_data = annotations_from_jsonl(os.path.join(data_dir, "val.jsonl"))
134
+ test_data = annotations_from_jsonl(os.path.join(data_dir, "test.jsonl"))
135
+ return train_data, val_data, test_data
136
+
137
+
138
+ def load_documents(
139
+ data_dir: str, docids: Set[str] = None
140
+ ) -> Dict[str, List[List[str]]]:
141
+ """Loads a subset of available documents from disk.
142
+
143
+ Each document is assumed to be serialized as newline ('\n') separated sentences.
144
+ Each sentence is assumed to be space (' ') joined tokens.
145
+ """
146
+ if os.path.exists(os.path.join(data_dir, "docs.jsonl")):
147
+ assert not os.path.exists(os.path.join(data_dir, "docs"))
148
+ return load_documents_from_file(data_dir, docids)
149
+
150
+ docs_dir = os.path.join(data_dir, "docs")
151
+ res = dict()
152
+ if docids is None:
153
+ docids = sorted(os.listdir(docs_dir))
154
+ else:
155
+ docids = sorted(set(str(d) for d in docids))
156
+ for d in docids:
157
+ with open(os.path.join(docs_dir, d), "r") as inf:
158
+ res[d] = inf.read()
159
+ return res
160
+
161
+
162
+ def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]:
163
+ """Loads a subset of available documents from disk.
164
+
165
+ Returns a tokenized version of the document.
166
+ """
167
+ unflattened_docs = load_documents(data_dir, docids)
168
+ flattened_docs = dict()
169
+ for doc, unflattened in unflattened_docs.items():
170
+ flattened_docs[doc] = list(chain.from_iterable(unflattened))
171
+ return flattened_docs
172
+
173
+
174
+ def intern_documents(
175
+ documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str
176
+ ):
177
+ """
178
+ Replaces every word with its index in an embeddings file.
179
+
180
+ If a word is not found, uses the unk_token instead
181
+ """
182
+ ret = dict()
183
+ unk = word_interner[unk_token]
184
+ for docid, sentences in documents.items():
185
+ ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences]
186
+ return ret
187
+
188
+
189
+ def intern_annotations(
190
+ annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str
191
+ ):
192
+ ret = []
193
+ for ann in annotations:
194
+ ev_groups = []
195
+ for ev_group in ann.evidences:
196
+ evs = []
197
+ for ev in ev_group:
198
+ evs.append(
199
+ Evidence(
200
+ text=tuple(
201
+ [
202
+ word_interner.get(t, word_interner[unk_token])
203
+ for t in ev.text.split()
204
+ ]
205
+ ),
206
+ docid=ev.docid,
207
+ start_token=ev.start_token,
208
+ end_token=ev.end_token,
209
+ start_sentence=ev.start_sentence,
210
+ end_sentence=ev.end_sentence,
211
+ )
212
+ )
213
+ ev_groups.append(tuple(evs))
214
+ ret.append(
215
+ Annotation(
216
+ annotation_id=ann.annotation_id,
217
+ query=tuple(
218
+ [
219
+ word_interner.get(t, word_interner[unk_token])
220
+ for t in ann.query.split()
221
+ ]
222
+ ),
223
+ evidences=frozenset(ev_groups),
224
+ classification=ann.classification,
225
+ query_type=ann.query_type,
226
+ )
227
+ )
228
+ return ret
229
+
230
+
231
+ def load_documents_from_file(
232
+ data_dir: str, docids: Set[str] = None
233
+ ) -> Dict[str, List[List[str]]]:
234
+ """Loads a subset of available documents from 'docs.jsonl' file on disk.
235
+
236
+ Each document is assumed to be serialized as newline ('\n') separated sentences.
237
+ Each sentence is assumed to be space (' ') joined tokens.
238
+ """
239
+ docs_file = os.path.join(data_dir, "docs.jsonl")
240
+ documents = load_jsonl(docs_file)
241
+ documents = {doc["docid"]: doc["document"] for doc in documents}
242
+ # res = dict()
243
+ # if docids is None:
244
+ # docids = sorted(list(documents.keys()))
245
+ # else:
246
+ # docids = sorted(set(str(d) for d in docids))
247
+ # for d in docids:
248
+ # lines = documents[d].split('\n')
249
+ # tokenized = [line.strip().split(' ') for line in lines]
250
+ # res[d] = tokenized
251
+ return documents
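+
+
+ if __name__ == "__main__":
+     # Illustrative sketch, not part of the original benchmark code: round-trip a single
+     # hand-made Annotation through the JSONL helpers above. The docid, query, and token
+     # offsets below are made-up placeholder values.
+     example_evidence = Evidence(
+         text="a quiet , affecting film",
+         docid="example_review.txt",
+         start_token=10,
+         end_token=15,
+     )
+     example_annotation = Annotation(
+         annotation_id="example_review",
+         query="What is the sentiment of this review?",
+         evidences=frozenset([(example_evidence,)]),
+         classification="POS",
+     )
+     annotations_to_jsonl([example_annotation], "example_annotations.jsonl")
+     print(annotations_from_jsonl("example_annotations.jsonl"))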
Transformer-Explainability/DeiT.PNG ADDED
Transformer-Explainability/DeiT_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Transformer-Explainability/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Hila Chefer
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Transformer-Explainability/README.md ADDED
@@ -0,0 +1,153 @@
1
+ # PyTorch Implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838) [CVPR 2021]
2
+
3
+ #### Check out our new advancements- [Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers](https://github.com/hila-chefer/Transformer-MM-Explainability)!
4
+ Faster, more general, and can be applied to *any* type of attention!
5
+ Among the features:
6
+ * We remove LRP for a simple and quick solution, and prove that the great results from our first paper still hold!
7
+ * We expand our work to *any* type of Transformer: not just self-attention based encoders, but also co-attention encoders and encoder-decoders!
8
+ * We show that VQA models can actually understand both image and text and make connections!
9
+ * We use a DETR object detector and create segmentation masks from our explanations!
10
+ * We provide a colab notebook with all the examples. You can very easily add images and questions of your own!
11
+
12
+ <p align="center">
13
+ <img width="400" height="450" src="new_work.jpg">
14
+ </p>
15
+
16
+ ---
17
+ ## ViT explainability notebook:
18
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb)
19
+
20
+ ## BERT explainability notebook:
21
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb)
22
+ ---
23
+
24
+ ## Updates
25
+ April 5 2021: Check out this new [post](https://analyticsindiamag.com/compute-relevancy-of-transformer-networks-via-novel-interpretable-transformer/) about our paper! A great resource for understanding the main concepts behind our work.
26
+
27
+ March 15 2021: [A Colab notebook for BERT for sentiment analysis added!](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb)
28
+
29
+ Feb 28 2021: Our paper was accepted to CVPR 2021!
30
+
31
+ Feb 17 2021: [A Colab notebook with all examples added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb)
32
+
33
+ Jan 5 2021: [A Jupyter notebook for DeiT added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/DeiT_example.ipynb)
34
+
35
+
36
+ <p align="center">
37
+ <img width="300" height="460" src="https://github.com/hila-chefer/Transformer-Explainability/blob/main/DeiT.PNG">
38
+ </p>
39
+
40
+
41
+ ## Introduction
42
+ Official implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838).
43
+
44
+ We introduce a novel method for visualizing the classifications made by a Transformer-based model, for both vision and NLP tasks.
45
+ Our method also makes it possible to visualize explanations per class.
46
+
47
+ <p align="center">
48
+ <img width="600" height="200" src="https://github.com/hila-chefer/Transformer-Explainability/blob/main/method-page-001.jpg">
49
+ </p>
50
+ The method consists of 3 phases (a short sketch of how phases 2 and 3 combine follows the list):
51
+
52
+ 1. Calculating relevance for each attention matrix using our novel formulation of LRP.
53
+
54
+ 2. Backpropagation of gradients for each attention matrix w.r.t. the visualized class. Gradients are used to average attention heads.
55
+
56
+ 3. Layer aggregation with rollout.
57
+
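+ The sketch below is a rough illustration of phases 2 and 3, simplified from the `transformer_attribution` rule in `baselines/ViT/ViT_LRP.py`; the per-block tensor lists are illustrative placeholders, not the exact API:
+
+ ```
+ import torch
+
+ def aggregate_attribution(attn_relevance, attn_gradients, start_layer=0):
+     """attn_relevance / attn_gradients: one [heads, tokens, tokens] tensor per block."""
+     per_layer = []
+     for cam, grad in zip(attn_relevance, attn_gradients):
+         # phase 2: gradient-weighted relevance per head, positive part only, averaged over heads
+         per_layer.append((grad * cam).clamp(min=0).mean(dim=0).unsqueeze(0))
+     # phase 3: rollout, i.e. add the residual (identity) and chain the layers by matrix multiplication
+     eye = torch.eye(per_layer[0].shape[-1]).unsqueeze(0)
+     joint = per_layer[start_layer] + eye
+     for cam in per_layer[start_layer + 1:]:
+         joint = (cam + eye).bmm(joint)
+     return joint[:, 0, 1:]  # relevance of the [CLS] token w.r.t. the patch tokens
+ ```
+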
58
+ Please see our [Jupyter notebook](https://github.com/hila-chefer/Transformer-Explainability/blob/main/example.ipynb), where you can run the two class-specific examples from the paper.
59
+
60
+
61
+ ![alt text](https://github.com/hila-chefer/Transformer-Explainability/blob/main/example.PNG)
62
+
63
+ To add another input image, simply add the image to the [samples folder](https://github.com/hila-chefer/Transformer-Explainability/tree/main/samples) and use the `generate_visualization` function for your selected class of interest (using `class_index={class_idx}`); not specifying the index will visualize the top class. A sketch of such a helper follows.
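+
+ This is a minimal sketch of how such a helper can be built on top of the LRP generator shipped with this repo (assuming a CUDA device and the repository root on `PYTHONPATH`); the notebook's own helper may differ, e.g. it may also upsample and overlay the heatmap on the image:
+
+ ```
+ from baselines.ViT.ViT_LRP import vit_base_patch16_224
+ from baselines.ViT.ViT_explanation_generator import LRP
+
+ model = vit_base_patch16_224(pretrained=True).cuda()
+ attribution_generator = LRP(model)
+
+ def generate_visualization(image_tensor, class_index=None):
+     # image_tensor: a single normalized [3, 224, 224] image tensor
+     relevance = attribution_generator.generate_LRP(
+         image_tensor.unsqueeze(0).cuda(),
+         method="transformer_attribution",
+         index=class_index,  # None -> explain the top predicted class
+     ).detach()
+     # one relevance value per 16x16 patch; reshape to the 14x14 patch grid
+     return relevance.reshape(14, 14)
+ ```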
64
+
65
+ ## Credits
66
+ ViT implementation is based on:
67
+ - https://github.com/rwightman/pytorch-image-models
68
+ - https://github.com/lucidrains/vit-pytorch
69
+ - pretrained weights from: https://github.com/google-research/vision_transformer
70
+
71
+ BERT implementation is taken from the huggingface Transformers library:
72
+ https://huggingface.co/transformers/
73
+
74
+ ERASER benchmark code adapted from the ERASER GitHub implementation: https://github.com/jayded/eraserbenchmark
75
+
76
+ Text visualizations in supplementary were created using TAHV heatmap generator for text: https://github.com/jiesutd/Text-Attention-Heatmap-Visualization
77
+
78
+ ## Reproducing results on ViT
79
+
80
+ ### Section A. Segmentation Results
81
+
82
+ Example:
83
+ ```
84
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/imagenet_seg_eval.py --method transformer_attribution --imagenet-seg-path /path/to/gtsegs_ijcv.mat
85
+
86
+ ```
87
+ [Link to download dataset](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat).
88
+
89
+ In the example above we run a segmentation test with our method. Notice that you can choose which method you wish to run using the `--method` argument.
90
+ You must provide a path to imagenet segmentation data in `--imagenet-seg-path`.
91
+
92
+ ### Section B. Perturbation Results
93
+
94
+ Example:
95
+ ```
96
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/generate_visualizations.py --method transformer_attribution --imagenet-validation-path /path/to/imagenet_validation_directory
97
+ ```
98
+
99
+ Notice that you can choose to visualize by target or top class by using the `--vis-cls` argument.
100
+
101
+ Now to run the perturbation test run the following command:
102
+ ```
103
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/pertubation_eval_from_hdf5.py --method transformer_attribution
104
+ ```
105
+
106
+ Notice that you can use the `--neg` argument to run either positive or negative perturbation.
107
+
108
+ ## Reproducing results on BERT
109
+
110
+ 1. Download the pretrained weights:
111
+
112
+ - Download `classifier.zip` from https://drive.google.com/file/d/1kGMTr69UWWe70i-o2_JfjmWDQjT66xwQ/view?usp=sharing
113
+ - mkdir -p `./bert_models/movies`
114
+ - unzip classifier.zip -d ./bert_models/movies/
115
+
116
+ 2. Download the dataset pkl file:
117
+
118
+ - Download `preprocessed.pkl` from https://drive.google.com/file/d/1-gfbTj6D87KIm_u1QMHGLKSL3e93hxBH/view?usp=sharing
119
+ - mv preprocessed.pkl ./bert_models/movies
120
+
121
+ 3. Download the dataset:
122
+
123
+ - Download `movies.zip` from https://drive.google.com/file/d/11faFLGkc0hkw3wrGTYJBr1nIvkRb189F/view?usp=sharing
124
+ - unzip movies.zip -d ./data/
125
+
126
+ 4. Now you can run the model.
127
+
128
+ Example:
129
+ ```
130
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies/ --output_dir bert_models/movies/ --model_params BERT_params/movies_bert.json
131
+ ```
132
+ To control which algorithm is used for explanations, change the `method` variable in `BERT_rationale_benchmark/models/pipeline/bert_pipeline.py` (defaults to 'transformer_attribution', which is our method).
133
+ Running this command will create a directory for the method in `bert_models/movies/<method_name>`.
134
+
135
+ To run the f1 test with k, run the following command:
136
+ ```
137
+ PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/metrics.py --data_dir data/movies/ --split test --results bert_models/movies/<method_name>/identifier_results_k.json
138
+ ```
139
+
140
+ In addition, `.tex` files containing the explanations extracted for each example will be created in the method directory. These correspond to our visualizations in the supplementary material.
141
+
142
+ ## Citing our paper
143
+ If you make use of our work, please cite our paper:
144
+ ```
145
+ @InProceedings{Chefer_2021_CVPR,
146
+ author = {Chefer, Hila and Gur, Shir and Wolf, Lior},
147
+ title = {Transformer Interpretability Beyond Attention Visualization},
148
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
149
+ month = {June},
150
+ year = {2021},
151
+ pages = {782-791}
152
+ }
153
+ ```
Transformer-Explainability/Transformer_explainability.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Transformer-Explainability/baselines/ViT/ViT_LRP.py ADDED
@@ -0,0 +1,535 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from baselines.ViT.helpers import load_pretrained
7
+ from baselines.ViT.layer_helpers import to_2tuple
8
+ from baselines.ViT.weight_init import trunc_normal_
9
+ from einops import rearrange
10
+ from modules.layers_ours import *
11
+
12
+
13
+ def _cfg(url="", **kwargs):
14
+ return {
15
+ "url": url,
16
+ "num_classes": 1000,
17
+ "input_size": (3, 224, 224),
18
+ "pool_size": None,
19
+ "crop_pct": 0.9,
20
+ "interpolation": "bicubic",
21
+ "first_conv": "patch_embed.proj",
22
+ "classifier": "head",
23
+ **kwargs,
24
+ }
25
+
26
+
27
+ default_cfgs = {
28
+ # patch models
29
+ "vit_small_patch16_224": _cfg(
30
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
31
+ ),
32
+ "vit_base_patch16_224": _cfg(
33
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
34
+ mean=(0.5, 0.5, 0.5),
35
+ std=(0.5, 0.5, 0.5),
36
+ ),
37
+ "vit_large_patch16_224": _cfg(
38
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
39
+ mean=(0.5, 0.5, 0.5),
40
+ std=(0.5, 0.5, 0.5),
41
+ ),
42
+ }
43
+
44
+
45
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
46
+ # adding residual consideration
47
+ num_tokens = all_layer_matrices[0].shape[1]
48
+ batch_size = all_layer_matrices[0].shape[0]
49
+ eye = (
50
+ torch.eye(num_tokens)
51
+ .expand(batch_size, num_tokens, num_tokens)
52
+ .to(all_layer_matrices[0].device)
53
+ )
54
+ all_layer_matrices = [
55
+ all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
56
+ ]
57
+ # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
58
+ # for i in range(len(all_layer_matrices))]
59
+ joint_attention = all_layer_matrices[start_layer]
60
+ for i in range(start_layer + 1, len(all_layer_matrices)):
61
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
62
+ return joint_attention
63
+
64
+
65
+ class Mlp(nn.Module):
66
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
67
+ super().__init__()
68
+ out_features = out_features or in_features
69
+ hidden_features = hidden_features or in_features
70
+ self.fc1 = Linear(in_features, hidden_features)
71
+ self.act = GELU()
72
+ self.fc2 = Linear(hidden_features, out_features)
73
+ self.drop = Dropout(drop)
74
+
75
+ def forward(self, x):
76
+ x = self.fc1(x)
77
+ x = self.act(x)
78
+ x = self.drop(x)
79
+ x = self.fc2(x)
80
+ x = self.drop(x)
81
+ return x
82
+
83
+ def relprop(self, cam, **kwargs):
84
+ cam = self.drop.relprop(cam, **kwargs)
85
+ cam = self.fc2.relprop(cam, **kwargs)
86
+ cam = self.act.relprop(cam, **kwargs)
87
+ cam = self.fc1.relprop(cam, **kwargs)
88
+ return cam
89
+
90
+
91
+ class Attention(nn.Module):
92
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
93
+ super().__init__()
94
+ self.num_heads = num_heads
95
+ head_dim = dim // num_heads
96
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
97
+ self.scale = head_dim**-0.5
98
+
99
+ # A = Q*K^T
100
+ self.matmul1 = einsum("bhid,bhjd->bhij")
101
+ # attn = A*V
102
+ self.matmul2 = einsum("bhij,bhjd->bhid")
103
+
104
+ self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
105
+ self.attn_drop = Dropout(attn_drop)
106
+ self.proj = Linear(dim, dim)
107
+ self.proj_drop = Dropout(proj_drop)
108
+ self.softmax = Softmax(dim=-1)
109
+
110
+ self.attn_cam = None
111
+ self.attn = None
112
+ self.v = None
113
+ self.v_cam = None
114
+ self.attn_gradients = None
115
+
116
+ def get_attn(self):
117
+ return self.attn
118
+
119
+ def save_attn(self, attn):
120
+ self.attn = attn
121
+
122
+ def save_attn_cam(self, cam):
123
+ self.attn_cam = cam
124
+
125
+ def get_attn_cam(self):
126
+ return self.attn_cam
127
+
128
+ def get_v(self):
129
+ return self.v
130
+
131
+ def save_v(self, v):
132
+ self.v = v
133
+
134
+ def save_v_cam(self, cam):
135
+ self.v_cam = cam
136
+
137
+ def get_v_cam(self):
138
+ return self.v_cam
139
+
140
+ def save_attn_gradients(self, attn_gradients):
141
+ self.attn_gradients = attn_gradients
142
+
143
+ def get_attn_gradients(self):
144
+ return self.attn_gradients
145
+
146
+ def forward(self, x):
147
+ b, n, _, h = *x.shape, self.num_heads
148
+ qkv = self.qkv(x)
149
+ q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)
150
+
151
+ self.save_v(v)
152
+
153
+ dots = self.matmul1([q, k]) * self.scale
154
+
155
+ attn = self.softmax(dots)
156
+ attn = self.attn_drop(attn)
157
+
158
+ self.save_attn(attn)
159
+ attn.register_hook(self.save_attn_gradients)
160
+
161
+ out = self.matmul2([attn, v])
162
+ out = rearrange(out, "b h n d -> b n (h d)")
163
+
164
+ out = self.proj(out)
165
+ out = self.proj_drop(out)
166
+ return out
167
+
168
+ def relprop(self, cam, **kwargs):
169
+ cam = self.proj_drop.relprop(cam, **kwargs)
170
+ cam = self.proj.relprop(cam, **kwargs)
171
+ cam = rearrange(cam, "b n (h d) -> b h n d", h=self.num_heads)
172
+
173
+ # attn = A*V
174
+ (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
175
+ cam1 /= 2
176
+ cam_v /= 2
177
+
178
+ self.save_v_cam(cam_v)
179
+ self.save_attn_cam(cam1)
180
+
181
+ cam1 = self.attn_drop.relprop(cam1, **kwargs)
182
+ cam1 = self.softmax.relprop(cam1, **kwargs)
183
+
184
+ # A = Q*K^T
185
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
186
+ cam_q /= 2
187
+ cam_k /= 2
188
+
189
+ cam_qkv = rearrange(
190
+ [cam_q, cam_k, cam_v],
191
+ "qkv b h n d -> b n (qkv h d)",
192
+ qkv=3,
193
+ h=self.num_heads,
194
+ )
195
+
196
+ return self.qkv.relprop(cam_qkv, **kwargs)
197
+
198
+
199
+ class Block(nn.Module):
200
+ def __init__(
201
+ self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0
202
+ ):
203
+ super().__init__()
204
+ self.norm1 = LayerNorm(dim, eps=1e-6)
205
+ self.attn = Attention(
206
+ dim,
207
+ num_heads=num_heads,
208
+ qkv_bias=qkv_bias,
209
+ attn_drop=attn_drop,
210
+ proj_drop=drop,
211
+ )
212
+ self.norm2 = LayerNorm(dim, eps=1e-6)
213
+ mlp_hidden_dim = int(dim * mlp_ratio)
214
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
215
+
216
+ self.add1 = Add()
217
+ self.add2 = Add()
218
+ self.clone1 = Clone()
219
+ self.clone2 = Clone()
220
+
221
+ def forward(self, x):
222
+ x1, x2 = self.clone1(x, 2)
223
+ x = self.add1([x1, self.attn(self.norm1(x2))])
224
+ x1, x2 = self.clone2(x, 2)
225
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
226
+ return x
227
+
228
+ def relprop(self, cam, **kwargs):
229
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
230
+ cam2 = self.mlp.relprop(cam2, **kwargs)
231
+ cam2 = self.norm2.relprop(cam2, **kwargs)
232
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
233
+
234
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
235
+ cam2 = self.attn.relprop(cam2, **kwargs)
236
+ cam2 = self.norm1.relprop(cam2, **kwargs)
237
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
238
+ return cam
239
+
240
+
241
+ class PatchEmbed(nn.Module):
242
+ """Image to Patch Embedding"""
243
+
244
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
245
+ super().__init__()
246
+ img_size = to_2tuple(img_size)
247
+ patch_size = to_2tuple(patch_size)
248
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
249
+ self.img_size = img_size
250
+ self.patch_size = patch_size
251
+ self.num_patches = num_patches
252
+
253
+ self.proj = Conv2d(
254
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
255
+ )
256
+
257
+ def forward(self, x):
258
+ B, C, H, W = x.shape
259
+ # FIXME look at relaxing size constraints
260
+ assert (
261
+ H == self.img_size[0] and W == self.img_size[1]
262
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
263
+ x = self.proj(x).flatten(2).transpose(1, 2)
264
+ return x
265
+
266
+ def relprop(self, cam, **kwargs):
267
+ cam = cam.transpose(1, 2)
268
+ cam = cam.reshape(
269
+ cam.shape[0],
270
+ cam.shape[1],
271
+ (self.img_size[0] // self.patch_size[0]),
272
+ (self.img_size[1] // self.patch_size[1]),
273
+ )
274
+ return self.proj.relprop(cam, **kwargs)
275
+
276
+
277
+ class VisionTransformer(nn.Module):
278
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
279
+
280
+ def __init__(
281
+ self,
282
+ img_size=224,
283
+ patch_size=16,
284
+ in_chans=3,
285
+ num_classes=1000,
286
+ embed_dim=768,
287
+ depth=12,
288
+ num_heads=12,
289
+ mlp_ratio=4.0,
290
+ qkv_bias=False,
291
+ mlp_head=False,
292
+ drop_rate=0.0,
293
+ attn_drop_rate=0.0,
294
+ ):
295
+ super().__init__()
296
+ self.num_classes = num_classes
297
+ self.num_features = (
298
+ self.embed_dim
299
+ ) = embed_dim # num_features for consistency with other models
300
+ self.patch_embed = PatchEmbed(
301
+ img_size=img_size,
302
+ patch_size=patch_size,
303
+ in_chans=in_chans,
304
+ embed_dim=embed_dim,
305
+ )
306
+ num_patches = self.patch_embed.num_patches
307
+
308
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
309
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
310
+
311
+ self.blocks = nn.ModuleList(
312
+ [
313
+ Block(
314
+ dim=embed_dim,
315
+ num_heads=num_heads,
316
+ mlp_ratio=mlp_ratio,
317
+ qkv_bias=qkv_bias,
318
+ drop=drop_rate,
319
+ attn_drop=attn_drop_rate,
320
+ )
321
+ for i in range(depth)
322
+ ]
323
+ )
324
+
325
+ self.norm = LayerNorm(embed_dim)
326
+ if mlp_head:
327
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
328
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
329
+ else:
330
+ # with a single Linear layer as head, the param count within rounding of paper
331
+ self.head = Linear(embed_dim, num_classes)
332
+
333
+ # FIXME not quite sure what the proper weight init is supposed to be,
334
+ # normal / trunc normal w/ std == .02 similar to other Bert like transformers
335
+ trunc_normal_(self.pos_embed, std=0.02) # embeddings same as weights?
336
+ trunc_normal_(self.cls_token, std=0.02)
337
+ self.apply(self._init_weights)
338
+
339
+ self.pool = IndexSelect()
340
+ self.add = Add()
341
+
342
+ self.inp_grad = None
343
+
344
+ def save_inp_grad(self, grad):
345
+ self.inp_grad = grad
346
+
347
+ def get_inp_grad(self):
348
+ return self.inp_grad
349
+
350
+ def _init_weights(self, m):
351
+ if isinstance(m, nn.Linear):
352
+ trunc_normal_(m.weight, std=0.02)
353
+ if isinstance(m, nn.Linear) and m.bias is not None:
354
+ nn.init.constant_(m.bias, 0)
355
+ elif isinstance(m, nn.LayerNorm):
356
+ nn.init.constant_(m.bias, 0)
357
+ nn.init.constant_(m.weight, 1.0)
358
+
359
+ @property
360
+ def no_weight_decay(self):
361
+ return {"pos_embed", "cls_token"}
362
+
363
+ def forward(self, x):
364
+ B = x.shape[0]
365
+ x = self.patch_embed(x)
366
+
367
+ cls_tokens = self.cls_token.expand(
368
+ B, -1, -1
369
+ ) # stole cls_tokens impl from Phil Wang, thanks
370
+ x = torch.cat((cls_tokens, x), dim=1)
371
+ x = self.add([x, self.pos_embed])
372
+
373
+ x.register_hook(self.save_inp_grad)
374
+
375
+ for blk in self.blocks:
376
+ x = blk(x)
377
+
378
+ x = self.norm(x)
379
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
380
+ x = x.squeeze(1)
381
+ x = self.head(x)
382
+ return x
383
+
384
+ def relprop(
385
+ self,
386
+ cam=None,
387
+ method="transformer_attribution",
388
+ is_ablation=False,
389
+ start_layer=0,
390
+ **kwargs,
391
+ ):
392
+ # print(kwargs)
393
+ # print("conservation 1", cam.sum())
394
+ cam = self.head.relprop(cam, **kwargs)
395
+ cam = cam.unsqueeze(1)
396
+ cam = self.pool.relprop(cam, **kwargs)
397
+ cam = self.norm.relprop(cam, **kwargs)
398
+ for blk in reversed(self.blocks):
399
+ cam = blk.relprop(cam, **kwargs)
400
+
401
+ # print("conservation 2", cam.sum())
402
+ # print("min", cam.min())
403
+
404
+ if method == "full":
405
+ (cam, _) = self.add.relprop(cam, **kwargs)
406
+ cam = cam[:, 1:]
407
+ cam = self.patch_embed.relprop(cam, **kwargs)
408
+ # sum on channels
409
+ cam = cam.sum(dim=1)
410
+ return cam
411
+
412
+ elif method == "rollout":
413
+ # cam rollout
414
+ attn_cams = []
415
+ for blk in self.blocks:
416
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
417
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
418
+ attn_cams.append(avg_heads)
419
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
420
+ cam = cam[:, 0, 1:]
421
+ return cam
422
+
423
+ # our method, method name grad is legacy
424
+ elif method == "transformer_attribution" or method == "grad":
425
+ cams = []
426
+ for blk in self.blocks:
427
+ grad = blk.attn.get_attn_gradients()
428
+ cam = blk.attn.get_attn_cam()
429
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
430
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
431
+ cam = grad * cam
432
+ cam = cam.clamp(min=0).mean(dim=0)
433
+ cams.append(cam.unsqueeze(0))
434
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
435
+ cam = rollout[:, 0, 1:]
436
+ return cam
437
+
438
+ elif method == "last_layer":
439
+ cam = self.blocks[-1].attn.get_attn_cam()
440
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
441
+ if is_ablation:
442
+ grad = self.blocks[-1].attn.get_attn_gradients()
443
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
444
+ cam = grad * cam
445
+ cam = cam.clamp(min=0).mean(dim=0)
446
+ cam = cam[0, 1:]
447
+ return cam
448
+
449
+ elif method == "last_layer_attn":
450
+ cam = self.blocks[-1].attn.get_attn()
451
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
452
+ cam = cam.clamp(min=0).mean(dim=0)
453
+ cam = cam[0, 1:]
454
+ return cam
455
+
456
+ elif method == "second_layer":
457
+ cam = self.blocks[1].attn.get_attn_cam()
458
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
459
+ if is_ablation:
460
+ grad = self.blocks[1].attn.get_attn_gradients()
461
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
462
+ cam = grad * cam
463
+ cam = cam.clamp(min=0).mean(dim=0)
464
+ cam = cam[0, 1:]
465
+ return cam
466
+
467
+
468
+ def _conv_filter(state_dict, patch_size=16):
469
+ """convert patch embedding weight from manual patchify + linear proj to conv"""
470
+ out_dict = {}
471
+ for k, v in state_dict.items():
472
+ if "patch_embed.proj.weight" in k:
473
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
474
+ out_dict[k] = v
475
+ return out_dict
476
+
477
+
478
+ def vit_base_patch16_224(pretrained=False, **kwargs):
479
+ model = VisionTransformer(
480
+ patch_size=16,
481
+ embed_dim=768,
482
+ depth=12,
483
+ num_heads=12,
484
+ mlp_ratio=4,
485
+ qkv_bias=True,
486
+ **kwargs,
487
+ )
488
+ model.default_cfg = default_cfgs["vit_base_patch16_224"]
489
+ if pretrained:
490
+ load_pretrained(
491
+ model,
492
+ num_classes=model.num_classes,
493
+ in_chans=kwargs.get("in_chans", 3),
494
+ filter_fn=_conv_filter,
495
+ )
496
+ return model
497
+
498
+
499
+ def vit_large_patch16_224(pretrained=False, **kwargs):
500
+ model = VisionTransformer(
501
+ patch_size=16,
502
+ embed_dim=1024,
503
+ depth=24,
504
+ num_heads=16,
505
+ mlp_ratio=4,
506
+ qkv_bias=True,
507
+ **kwargs,
508
+ )
509
+ model.default_cfg = default_cfgs["vit_large_patch16_224"]
510
+ if pretrained:
511
+ load_pretrained(
512
+ model, num_classes=model.num_classes, in_chans=kwargs.get("in_chans", 3)
513
+ )
514
+ return model
515
+
516
+
517
+ def deit_base_patch16_224(pretrained=False, **kwargs):
518
+ model = VisionTransformer(
519
+ patch_size=16,
520
+ embed_dim=768,
521
+ depth=12,
522
+ num_heads=12,
523
+ mlp_ratio=4,
524
+ qkv_bias=True,
525
+ **kwargs,
526
+ )
527
+ model.default_cfg = _cfg()
528
+ if pretrained:
529
+ checkpoint = torch.hub.load_state_dict_from_url(
530
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
531
+ map_location="cpu",
532
+ check_hash=True,
533
+ )
534
+ model.load_state_dict(checkpoint["model"])
535
+ return model
Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py ADDED
@@ -0,0 +1,107 @@
1
+ import argparse
2
+
3
+ import numpy as np
4
+ import torch
5
+ from numpy import *
6
+
7
+
8
+ # compute rollout between attention layers
9
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
10
+ # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
11
+ num_tokens = all_layer_matrices[0].shape[1]
12
+ batch_size = all_layer_matrices[0].shape[0]
13
+ eye = (
14
+ torch.eye(num_tokens)
15
+ .expand(batch_size, num_tokens, num_tokens)
16
+ .to(all_layer_matrices[0].device)
17
+ )
18
+ all_layer_matrices = [
19
+ all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
20
+ ]
21
+ matrices_aug = [
22
+ all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
23
+ for i in range(len(all_layer_matrices))
24
+ ]
25
+ joint_attention = matrices_aug[start_layer]
26
+ for i in range(start_layer + 1, len(matrices_aug)):
27
+ joint_attention = matrices_aug[i].bmm(joint_attention)
28
+ return joint_attention
29
+
30
+
31
+ class LRP:
32
+ def __init__(self, model):
33
+ self.model = model
34
+ self.model.eval()
35
+
36
+ def generate_LRP(
37
+ self,
38
+ input,
39
+ index=None,
40
+ method="transformer_attribution",
41
+ is_ablation=False,
42
+ start_layer=0,
43
+ ):
44
+ output = self.model(input)
45
+ kwargs = {"alpha": 1}
46
+ if index is None:
47
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
48
+
49
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
50
+ one_hot[0, index] = 1
51
+ one_hot_vector = one_hot
52
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
53
+ one_hot = torch.sum(one_hot.cuda() * output)
54
+
55
+ self.model.zero_grad()
56
+ one_hot.backward(retain_graph=True)
57
+
58
+ return self.model.relprop(
59
+ torch.tensor(one_hot_vector).to(input.device),
60
+ method=method,
61
+ is_ablation=is_ablation,
62
+ start_layer=start_layer,
63
+ **kwargs
64
+ )
65
+
66
+
67
+ class Baselines:
68
+ def __init__(self, model):
69
+ self.model = model
70
+ self.model.eval()
71
+
72
+ def generate_cam_attn(self, input, index=None):
73
+ output = self.model(input.cuda(), register_hook=True)
74
+ if index == None:
75
+ index = np.argmax(output.cpu().data.numpy())
76
+
77
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
78
+ one_hot[0][index] = 1
79
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
80
+ one_hot = torch.sum(one_hot.cuda() * output)
81
+
82
+ self.model.zero_grad()
83
+ one_hot.backward(retain_graph=True)
84
+ #################### attn
85
+ grad = self.model.blocks[-1].attn.get_attn_gradients()
86
+ cam = self.model.blocks[-1].attn.get_attention_map()
87
+ cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
88
+ grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)
89
+ grad = grad.mean(dim=[1, 2], keepdim=True)
90
+ cam = (cam * grad).mean(0).clamp(min=0)
91
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
92
+
93
+ return cam
94
+ #################### attn
95
+
96
+ def generate_rollout(self, input, start_layer=0):
97
+ self.model(input)
98
+ blocks = self.model.blocks
99
+ all_layer_attentions = []
100
+ for blk in blocks:
101
+ attn_heads = blk.attn.get_attention_map()
102
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
103
+ all_layer_attentions.append(avg_heads)
104
+ rollout = compute_rollout_attention(
105
+ all_layer_attentions, start_layer=start_layer
106
+ )
107
+ return rollout[:, 0, 1:]
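A minimal sketch of driving the `LRP` generator above. The module paths are assumptions based on the repo layout, a CUDA device is required because `generate_LRP` calls `.cuda()` on the one-hot target, and the default `method="transformer_attribution"` is assumed to be handled by the LRP-enabled model's `relprop` added earlier in this commit.

```python
import torch

from baselines.ViT.ViT_LRP import vit_base_patch16_224             # assumed module path
from baselines.ViT.ViT_explanation_generator import LRP

model = vit_base_patch16_224(pretrained=True).cuda()               # downloads the ImageNet checkpoint
explainer = LRP(model)

image = torch.randn(1, 3, 224, 224).cuda()                         # stand-in for a normalized image
relevance = explainer.generate_LRP(image, method="transformer_attribution", start_layer=0)
heatmap = relevance.reshape(14, 14)                                 # 196 patch scores -> 14x14 grid
print(heatmap.shape)
```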
Transformer-Explainability/baselines/ViT/ViT_new.py ADDED
@@ -0,0 +1,329 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from baselines.ViT.helpers import load_pretrained
9
+ from baselines.ViT.layer_helpers import to_2tuple
10
+ from baselines.ViT.weight_init import trunc_normal_
11
+ from einops import rearrange
12
+
13
+
14
+ def _cfg(url="", **kwargs):
15
+ return {
16
+ "url": url,
17
+ "num_classes": 1000,
18
+ "input_size": (3, 224, 224),
19
+ "pool_size": None,
20
+ "crop_pct": 0.9,
21
+ "interpolation": "bicubic",
22
+ "first_conv": "patch_embed.proj",
23
+ "classifier": "head",
24
+ **kwargs,
25
+ }
26
+
27
+
28
+ default_cfgs = {
29
+ # patch models
30
+ "vit_small_patch16_224": _cfg(
31
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
32
+ ),
33
+ "vit_base_patch16_224": _cfg(
34
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
35
+ mean=(0.5, 0.5, 0.5),
36
+ std=(0.5, 0.5, 0.5),
37
+ ),
38
+ "vit_large_patch16_224": _cfg(
39
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
40
+ mean=(0.5, 0.5, 0.5),
41
+ std=(0.5, 0.5, 0.5),
42
+ ),
43
+ }
44
+
45
+
46
+ class Mlp(nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_features,
50
+ hidden_features=None,
51
+ out_features=None,
52
+ act_layer=nn.GELU,
53
+ drop=0.0,
54
+ ):
55
+ super().__init__()
56
+ out_features = out_features or in_features
57
+ hidden_features = hidden_features or in_features
58
+ self.fc1 = nn.Linear(in_features, hidden_features)
59
+ self.act = act_layer()
60
+ self.fc2 = nn.Linear(hidden_features, out_features)
61
+ self.drop = nn.Dropout(drop)
62
+
63
+ def forward(self, x):
64
+ x = self.fc1(x)
65
+ x = self.act(x)
66
+ x = self.drop(x)
67
+ x = self.fc2(x)
68
+ x = self.drop(x)
69
+ return x
70
+
71
+
72
+ class Attention(nn.Module):
73
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
74
+ super().__init__()
75
+ self.num_heads = num_heads
76
+ head_dim = dim // num_heads
77
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
78
+ self.scale = head_dim**-0.5
79
+
80
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
81
+ self.attn_drop = nn.Dropout(attn_drop)
82
+ self.proj = nn.Linear(dim, dim)
83
+ self.proj_drop = nn.Dropout(proj_drop)
84
+
85
+ self.attn_gradients = None
86
+ self.attention_map = None
87
+
88
+ def save_attn_gradients(self, attn_gradients):
89
+ self.attn_gradients = attn_gradients
90
+
91
+ def get_attn_gradients(self):
92
+ return self.attn_gradients
93
+
94
+ def save_attention_map(self, attention_map):
95
+ self.attention_map = attention_map
96
+
97
+ def get_attention_map(self):
98
+ return self.attention_map
99
+
100
+ def forward(self, x, register_hook=False):
101
+ b, n, _, h = *x.shape, self.num_heads
102
+
103
+ # self.save_output(x)
104
+ # x.register_hook(self.save_output_grad)
105
+
106
+ qkv = self.qkv(x)
107
+ q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)
108
+
109
+ dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
110
+
111
+ attn = dots.softmax(dim=-1)
112
+ attn = self.attn_drop(attn)
113
+
114
+ out = torch.einsum("bhij,bhjd->bhid", attn, v)
115
+
116
+ self.save_attention_map(attn)
117
+ if register_hook:
118
+ attn.register_hook(self.save_attn_gradients)
119
+
120
+ out = rearrange(out, "b h n d -> b n (h d)")
121
+ out = self.proj(out)
122
+ out = self.proj_drop(out)
123
+ return out
124
+
125
+
126
+ class Block(nn.Module):
127
+ def __init__(
128
+ self,
129
+ dim,
130
+ num_heads,
131
+ mlp_ratio=4.0,
132
+ qkv_bias=False,
133
+ drop=0.0,
134
+ attn_drop=0.0,
135
+ act_layer=nn.GELU,
136
+ norm_layer=nn.LayerNorm,
137
+ ):
138
+ super().__init__()
139
+ self.norm1 = norm_layer(dim)
140
+ self.attn = Attention(
141
+ dim,
142
+ num_heads=num_heads,
143
+ qkv_bias=qkv_bias,
144
+ attn_drop=attn_drop,
145
+ proj_drop=drop,
146
+ )
147
+ self.norm2 = norm_layer(dim)
148
+ mlp_hidden_dim = int(dim * mlp_ratio)
149
+ self.mlp = Mlp(
150
+ in_features=dim,
151
+ hidden_features=mlp_hidden_dim,
152
+ act_layer=act_layer,
153
+ drop=drop,
154
+ )
155
+
156
+ def forward(self, x, register_hook=False):
157
+ x = x + self.attn(self.norm1(x), register_hook=register_hook)
158
+ x = x + self.mlp(self.norm2(x))
159
+ return x
160
+
161
+
162
+ class PatchEmbed(nn.Module):
163
+ """Image to Patch Embedding"""
164
+
165
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
166
+ super().__init__()
167
+ img_size = to_2tuple(img_size)
168
+ patch_size = to_2tuple(patch_size)
169
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
170
+ self.img_size = img_size
171
+ self.patch_size = patch_size
172
+ self.num_patches = num_patches
173
+
174
+ self.proj = nn.Conv2d(
175
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
176
+ )
177
+
178
+ def forward(self, x):
179
+ B, C, H, W = x.shape
180
+ # FIXME look at relaxing size constraints
181
+ assert (
182
+ H == self.img_size[0] and W == self.img_size[1]
183
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
184
+ x = self.proj(x).flatten(2).transpose(1, 2)
185
+ return x
186
+
187
+
188
+ class VisionTransformer(nn.Module):
189
+ """Vision Transformer"""
190
+
191
+ def __init__(
192
+ self,
193
+ img_size=224,
194
+ patch_size=16,
195
+ in_chans=3,
196
+ num_classes=1000,
197
+ embed_dim=768,
198
+ depth=12,
199
+ num_heads=12,
200
+ mlp_ratio=4.0,
201
+ qkv_bias=False,
202
+ drop_rate=0.0,
203
+ attn_drop_rate=0.0,
204
+ norm_layer=nn.LayerNorm,
205
+ ):
206
+ super().__init__()
207
+ self.num_classes = num_classes
208
+ self.num_features = (
209
+ self.embed_dim
210
+ ) = embed_dim # num_features for consistency with other models
211
+ self.patch_embed = PatchEmbed(
212
+ img_size=img_size,
213
+ patch_size=patch_size,
214
+ in_chans=in_chans,
215
+ embed_dim=embed_dim,
216
+ )
217
+ num_patches = self.patch_embed.num_patches
218
+
219
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
220
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
221
+ self.pos_drop = nn.Dropout(p=drop_rate)
222
+
223
+ self.blocks = nn.ModuleList(
224
+ [
225
+ Block(
226
+ dim=embed_dim,
227
+ num_heads=num_heads,
228
+ mlp_ratio=mlp_ratio,
229
+ qkv_bias=qkv_bias,
230
+ drop=drop_rate,
231
+ attn_drop=attn_drop_rate,
232
+ norm_layer=norm_layer,
233
+ )
234
+ for i in range(depth)
235
+ ]
236
+ )
237
+ self.norm = norm_layer(embed_dim)
238
+
239
+ # Classifier head
240
+ self.head = (
241
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
242
+ )
243
+
244
+ trunc_normal_(self.pos_embed, std=0.02)
245
+ trunc_normal_(self.cls_token, std=0.02)
246
+ self.apply(self._init_weights)
247
+
248
+ def _init_weights(self, m):
249
+ if isinstance(m, nn.Linear):
250
+ trunc_normal_(m.weight, std=0.02)
251
+ if isinstance(m, nn.Linear) and m.bias is not None:
252
+ nn.init.constant_(m.bias, 0)
253
+ elif isinstance(m, nn.LayerNorm):
254
+ nn.init.constant_(m.bias, 0)
255
+ nn.init.constant_(m.weight, 1.0)
256
+
257
+ @torch.jit.ignore
258
+ def no_weight_decay(self):
259
+ return {"pos_embed", "cls_token"}
260
+
261
+ def forward(self, x, register_hook=False):
262
+ B = x.shape[0]
263
+ x = self.patch_embed(x)
264
+
265
+ cls_tokens = self.cls_token.expand(
266
+ B, -1, -1
267
+ ) # stole cls_tokens impl from Phil Wang, thanks
268
+ x = torch.cat((cls_tokens, x), dim=1)
269
+ x = x + self.pos_embed
270
+ x = self.pos_drop(x)
271
+
272
+ for blk in self.blocks:
273
+ x = blk(x, register_hook=register_hook)
274
+
275
+ x = self.norm(x)
276
+ x = x[:, 0]
277
+ x = self.head(x)
278
+ return x
279
+
280
+
281
+ def _conv_filter(state_dict, patch_size=16):
282
+ """convert patch embedding weight from manual patchify + linear proj to conv"""
283
+ out_dict = {}
284
+ for k, v in state_dict.items():
285
+ if "patch_embed.proj.weight" in k:
286
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
287
+ out_dict[k] = v
288
+ return out_dict
289
+
290
+
291
+ def vit_base_patch16_224(pretrained=False, **kwargs):
292
+ model = VisionTransformer(
293
+ patch_size=16,
294
+ embed_dim=768,
295
+ depth=12,
296
+ num_heads=12,
297
+ mlp_ratio=4,
298
+ qkv_bias=True,
299
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
300
+ **kwargs,
301
+ )
302
+ model.default_cfg = default_cfgs["vit_base_patch16_224"]
303
+ if pretrained:
304
+ load_pretrained(
305
+ model,
306
+ num_classes=model.num_classes,
307
+ in_chans=kwargs.get("in_chans", 3),
308
+ filter_fn=_conv_filter,
309
+ )
310
+ return model
311
+
312
+
313
+ def vit_large_patch16_224(pretrained=False, **kwargs):
314
+ model = VisionTransformer(
315
+ patch_size=16,
316
+ embed_dim=1024,
317
+ depth=24,
318
+ num_heads=16,
319
+ mlp_ratio=4,
320
+ qkv_bias=True,
321
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
322
+ **kwargs,
323
+ )
324
+ model.default_cfg = default_cfgs["vit_large_patch16_224"]
325
+ if pretrained:
326
+ load_pretrained(
327
+ model, num_classes=model.num_classes, in_chans=kwargs.get("in_chans", 3)
328
+ )
329
+ return model
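A short sketch of how this plain ViT pairs with the `Baselines` helper from `ViT_explanation_generator.py`: its `Attention` saves the attention map on every forward pass, so attention rollout needs no hooks and runs on CPU. Only the import paths are assumptions.

```python
import torch

from baselines.ViT.ViT_new import vit_base_patch16_224             # this file
from baselines.ViT.ViT_explanation_generator import Baselines

model = vit_base_patch16_224(pretrained=False)                      # pretrained=True to load timm weights
baselines = Baselines(model)

image = torch.randn(1, 3, 224, 224)
rollout = baselines.generate_rollout(image, start_layer=0)           # (1, 196) per-patch scores
print(rollout.reshape(1, 14, 14).shape)
```

`generate_cam_attn` additionally needs the gradient hooks, which this model only registers when called with `register_hook=True`, and it moves tensors to CUDA internally.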
Transformer-Explainability/baselines/ViT/ViT_orig_LRP.py ADDED
@@ -0,0 +1,508 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from baselines.ViT.helpers import load_pretrained
7
+ from baselines.ViT.layer_helpers import to_2tuple
8
+ from baselines.ViT.weight_init import trunc_normal_
9
+ from einops import rearrange
10
+ from modules.layers_lrp import *
11
+
12
+
13
+ def _cfg(url="", **kwargs):
14
+ return {
15
+ "url": url,
16
+ "num_classes": 1000,
17
+ "input_size": (3, 224, 224),
18
+ "pool_size": None,
19
+ "crop_pct": 0.9,
20
+ "interpolation": "bicubic",
21
+ "first_conv": "patch_embed.proj",
22
+ "classifier": "head",
23
+ **kwargs,
24
+ }
25
+
26
+
27
+ default_cfgs = {
28
+ # patch models
29
+ "vit_small_patch16_224": _cfg(
30
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth",
31
+ ),
32
+ "vit_base_patch16_224": _cfg(
33
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
34
+ mean=(0.5, 0.5, 0.5),
35
+ std=(0.5, 0.5, 0.5),
36
+ ),
37
+ "vit_large_patch16_224": _cfg(
38
+ url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
39
+ mean=(0.5, 0.5, 0.5),
40
+ std=(0.5, 0.5, 0.5),
41
+ ),
42
+ }
43
+
44
+
45
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
46
+ # adding residual consideration
47
+ num_tokens = all_layer_matrices[0].shape[1]
48
+ batch_size = all_layer_matrices[0].shape[0]
49
+ eye = (
50
+ torch.eye(num_tokens)
51
+ .expand(batch_size, num_tokens, num_tokens)
52
+ .to(all_layer_matrices[0].device)
53
+ )
54
+ all_layer_matrices = [
55
+ all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
56
+ ]
57
+ # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
58
+ # for i in range(len(all_layer_matrices))]
59
+ joint_attention = all_layer_matrices[start_layer]
60
+ for i in range(start_layer + 1, len(all_layer_matrices)):
61
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
62
+ return joint_attention
63
+
64
+
65
+ class Mlp(nn.Module):
66
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
67
+ super().__init__()
68
+ out_features = out_features or in_features
69
+ hidden_features = hidden_features or in_features
70
+ self.fc1 = Linear(in_features, hidden_features)
71
+ self.act = GELU()
72
+ self.fc2 = Linear(hidden_features, out_features)
73
+ self.drop = Dropout(drop)
74
+
75
+ def forward(self, x):
76
+ x = self.fc1(x)
77
+ x = self.act(x)
78
+ x = self.drop(x)
79
+ x = self.fc2(x)
80
+ x = self.drop(x)
81
+ return x
82
+
83
+ def relprop(self, cam, **kwargs):
84
+ cam = self.drop.relprop(cam, **kwargs)
85
+ cam = self.fc2.relprop(cam, **kwargs)
86
+ cam = self.act.relprop(cam, **kwargs)
87
+ cam = self.fc1.relprop(cam, **kwargs)
88
+ return cam
89
+
90
+
91
+ class Attention(nn.Module):
92
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
93
+ super().__init__()
94
+ self.num_heads = num_heads
95
+ head_dim = dim // num_heads
96
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
97
+ self.scale = head_dim**-0.5
98
+
99
+ # A = Q*K^T
100
+ self.matmul1 = einsum("bhid,bhjd->bhij")
101
+ # attn = A*V
102
+ self.matmul2 = einsum("bhij,bhjd->bhid")
103
+
104
+ self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
105
+ self.attn_drop = Dropout(attn_drop)
106
+ self.proj = Linear(dim, dim)
107
+ self.proj_drop = Dropout(proj_drop)
108
+ self.softmax = Softmax(dim=-1)
109
+
110
+ self.attn_cam = None
111
+ self.attn = None
112
+ self.v = None
113
+ self.v_cam = None
114
+ self.attn_gradients = None
115
+
116
+ def get_attn(self):
117
+ return self.attn
118
+
119
+ def save_attn(self, attn):
120
+ self.attn = attn
121
+
122
+ def save_attn_cam(self, cam):
123
+ self.attn_cam = cam
124
+
125
+ def get_attn_cam(self):
126
+ return self.attn_cam
127
+
128
+ def get_v(self):
129
+ return self.v
130
+
131
+ def save_v(self, v):
132
+ self.v = v
133
+
134
+ def save_v_cam(self, cam):
135
+ self.v_cam = cam
136
+
137
+ def get_v_cam(self):
138
+ return self.v_cam
139
+
140
+ def save_attn_gradients(self, attn_gradients):
141
+ self.attn_gradients = attn_gradients
142
+
143
+ def get_attn_gradients(self):
144
+ return self.attn_gradients
145
+
146
+ def forward(self, x):
147
+ b, n, _, h = *x.shape, self.num_heads
148
+ qkv = self.qkv(x)
149
+ q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)
150
+
151
+ self.save_v(v)
152
+
153
+ dots = self.matmul1([q, k]) * self.scale
154
+
155
+ attn = self.softmax(dots)
156
+ attn = self.attn_drop(attn)
157
+
158
+ self.save_attn(attn)
159
+ attn.register_hook(self.save_attn_gradients)
160
+
161
+ out = self.matmul2([attn, v])
162
+ out = rearrange(out, "b h n d -> b n (h d)")
163
+
164
+ out = self.proj(out)
165
+ out = self.proj_drop(out)
166
+ return out
167
+
168
+ def relprop(self, cam, **kwargs):
169
+ cam = self.proj_drop.relprop(cam, **kwargs)
170
+ cam = self.proj.relprop(cam, **kwargs)
171
+ cam = rearrange(cam, "b n (h d) -> b h n d", h=self.num_heads)
172
+
173
+ # attn = A*V
174
+ (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
175
+ cam1 /= 2
176
+ cam_v /= 2
177
+
178
+ self.save_v_cam(cam_v)
179
+ self.save_attn_cam(cam1)
180
+
181
+ cam1 = self.attn_drop.relprop(cam1, **kwargs)
182
+ cam1 = self.softmax.relprop(cam1, **kwargs)
183
+
184
+ # A = Q*K^T
185
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
186
+ cam_q /= 2
187
+ cam_k /= 2
188
+
189
+ cam_qkv = rearrange(
190
+ [cam_q, cam_k, cam_v],
191
+ "qkv b h n d -> b n (qkv h d)",
192
+ qkv=3,
193
+ h=self.num_heads,
194
+ )
195
+
196
+ return self.qkv.relprop(cam_qkv, **kwargs)
197
+
198
+
199
+ class Block(nn.Module):
200
+ def __init__(
201
+ self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0
202
+ ):
203
+ super().__init__()
204
+ self.norm1 = LayerNorm(dim, eps=1e-6)
205
+ self.attn = Attention(
206
+ dim,
207
+ num_heads=num_heads,
208
+ qkv_bias=qkv_bias,
209
+ attn_drop=attn_drop,
210
+ proj_drop=drop,
211
+ )
212
+ self.norm2 = LayerNorm(dim, eps=1e-6)
213
+ mlp_hidden_dim = int(dim * mlp_ratio)
214
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
215
+
216
+ self.add1 = Add()
217
+ self.add2 = Add()
218
+ self.clone1 = Clone()
219
+ self.clone2 = Clone()
220
+
221
+ def forward(self, x):
222
+ x1, x2 = self.clone1(x, 2)
223
+ x = self.add1([x1, self.attn(self.norm1(x2))])
224
+ x1, x2 = self.clone2(x, 2)
225
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
226
+ return x
227
+
228
+ def relprop(self, cam, **kwargs):
229
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
230
+ cam2 = self.mlp.relprop(cam2, **kwargs)
231
+ cam2 = self.norm2.relprop(cam2, **kwargs)
232
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
233
+
234
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
235
+ cam2 = self.attn.relprop(cam2, **kwargs)
236
+ cam2 = self.norm1.relprop(cam2, **kwargs)
237
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
238
+ return cam
239
+
240
+
241
+ class PatchEmbed(nn.Module):
242
+ """Image to Patch Embedding"""
243
+
244
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
245
+ super().__init__()
246
+ img_size = to_2tuple(img_size)
247
+ patch_size = to_2tuple(patch_size)
248
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
249
+ self.img_size = img_size
250
+ self.patch_size = patch_size
251
+ self.num_patches = num_patches
252
+
253
+ self.proj = Conv2d(
254
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
255
+ )
256
+
257
+ def forward(self, x):
258
+ B, C, H, W = x.shape
259
+ # FIXME look at relaxing size constraints
260
+ assert (
261
+ H == self.img_size[0] and W == self.img_size[1]
262
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
263
+ x = self.proj(x).flatten(2).transpose(1, 2)
264
+ return x
265
+
266
+ def relprop(self, cam, **kwargs):
267
+ cam = cam.transpose(1, 2)
268
+ cam = cam.reshape(
269
+ cam.shape[0],
270
+ cam.shape[1],
271
+ (self.img_size[0] // self.patch_size[0]),
272
+ (self.img_size[1] // self.patch_size[1]),
273
+ )
274
+ return self.proj.relprop(cam, **kwargs)
275
+
276
+
277
+ class VisionTransformer(nn.Module):
278
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
279
+
280
+ def __init__(
281
+ self,
282
+ img_size=224,
283
+ patch_size=16,
284
+ in_chans=3,
285
+ num_classes=1000,
286
+ embed_dim=768,
287
+ depth=12,
288
+ num_heads=12,
289
+ mlp_ratio=4.0,
290
+ qkv_bias=False,
291
+ mlp_head=False,
292
+ drop_rate=0.0,
293
+ attn_drop_rate=0.0,
294
+ ):
295
+ super().__init__()
296
+ self.num_classes = num_classes
297
+ self.num_features = (
298
+ self.embed_dim
299
+ ) = embed_dim # num_features for consistency with other models
300
+ self.patch_embed = PatchEmbed(
301
+ img_size=img_size,
302
+ patch_size=patch_size,
303
+ in_chans=in_chans,
304
+ embed_dim=embed_dim,
305
+ )
306
+ num_patches = self.patch_embed.num_patches
307
+
308
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
309
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
310
+
311
+ self.blocks = nn.ModuleList(
312
+ [
313
+ Block(
314
+ dim=embed_dim,
315
+ num_heads=num_heads,
316
+ mlp_ratio=mlp_ratio,
317
+ qkv_bias=qkv_bias,
318
+ drop=drop_rate,
319
+ attn_drop=attn_drop_rate,
320
+ )
321
+ for i in range(depth)
322
+ ]
323
+ )
324
+
325
+ self.norm = LayerNorm(embed_dim)
326
+ if mlp_head:
327
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
328
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
329
+ else:
330
+ # with a single Linear layer as head, the param count within rounding of paper
331
+ self.head = Linear(embed_dim, num_classes)
332
+
333
+ # FIXME not quite sure what the proper weight init is supposed to be,
334
+ # normal / trunc normal w/ std == .02 similar to other Bert like transformers
335
+ trunc_normal_(self.pos_embed, std=0.02) # embeddings same as weights?
336
+ trunc_normal_(self.cls_token, std=0.02)
337
+ self.apply(self._init_weights)
338
+
339
+ self.pool = IndexSelect()
340
+ self.add = Add()
341
+
342
+ self.inp_grad = None
343
+
344
+ def save_inp_grad(self, grad):
345
+ self.inp_grad = grad
346
+
347
+ def get_inp_grad(self):
348
+ return self.inp_grad
349
+
350
+ def _init_weights(self, m):
351
+ if isinstance(m, nn.Linear):
352
+ trunc_normal_(m.weight, std=0.02)
353
+ if isinstance(m, nn.Linear) and m.bias is not None:
354
+ nn.init.constant_(m.bias, 0)
355
+ elif isinstance(m, nn.LayerNorm):
356
+ nn.init.constant_(m.bias, 0)
357
+ nn.init.constant_(m.weight, 1.0)
358
+
359
+ @property
360
+ def no_weight_decay(self):
361
+ return {"pos_embed", "cls_token"}
362
+
363
+ def forward(self, x):
364
+ B = x.shape[0]
365
+ x = self.patch_embed(x)
366
+
367
+ cls_tokens = self.cls_token.expand(
368
+ B, -1, -1
369
+ ) # stole cls_tokens impl from Phil Wang, thanks
370
+ x = torch.cat((cls_tokens, x), dim=1)
371
+ x = self.add([x, self.pos_embed])
372
+
373
+ x.register_hook(self.save_inp_grad)
374
+
375
+ for blk in self.blocks:
376
+ x = blk(x)
377
+
378
+ x = self.norm(x)
379
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
380
+ x = x.squeeze(1)
381
+ x = self.head(x)
382
+ return x
383
+
384
+ def relprop(
385
+ self, cam=None, method="grad", is_ablation=False, start_layer=0, **kwargs
386
+ ):
387
+ # print(kwargs)
388
+ # print("conservation 1", cam.sum())
389
+ cam = self.head.relprop(cam, **kwargs)
390
+ cam = cam.unsqueeze(1)
391
+ cam = self.pool.relprop(cam, **kwargs)
392
+ cam = self.norm.relprop(cam, **kwargs)
393
+ for blk in reversed(self.blocks):
394
+ cam = blk.relprop(cam, **kwargs)
395
+
396
+ # print("conservation 2", cam.sum())
397
+ # print("min", cam.min())
398
+
399
+ if method == "full":
400
+ (cam, _) = self.add.relprop(cam, **kwargs)
401
+ cam = cam[:, 1:]
402
+ cam = self.patch_embed.relprop(cam, **kwargs)
403
+ # sum on channels
404
+ cam = cam.sum(dim=1)
405
+ return cam
406
+
407
+ elif method == "rollout":
408
+ # cam rollout
409
+ attn_cams = []
410
+ for blk in self.blocks:
411
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
412
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
413
+ attn_cams.append(avg_heads)
414
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
415
+ cam = cam[:, 0, 1:]
416
+ return cam
417
+
418
+ elif method == "grad":
419
+ cams = []
420
+ for blk in self.blocks:
421
+ grad = blk.attn.get_attn_gradients()
422
+ cam = blk.attn.get_attn_cam()
423
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
424
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
425
+ cam = grad * cam
426
+ cam = cam.clamp(min=0).mean(dim=0)
427
+ cams.append(cam.unsqueeze(0))
428
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
429
+ cam = rollout[:, 0, 1:]
430
+ return cam
431
+
432
+ elif method == "last_layer":
433
+ cam = self.blocks[-1].attn.get_attn_cam()
434
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
435
+ if is_ablation:
436
+ grad = self.blocks[-1].attn.get_attn_gradients()
437
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
438
+ cam = grad * cam
439
+ cam = cam.clamp(min=0).mean(dim=0)
440
+ cam = cam[0, 1:]
441
+ return cam
442
+
443
+ elif method == "last_layer_attn":
444
+ cam = self.blocks[-1].attn.get_attn()
445
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
446
+ cam = cam.clamp(min=0).mean(dim=0)
447
+ cam = cam[0, 1:]
448
+ return cam
449
+
450
+ elif method == "second_layer":
451
+ cam = self.blocks[1].attn.get_attn_cam()
452
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
453
+ if is_ablation:
454
+ grad = self.blocks[1].attn.get_attn_gradients()
455
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
456
+ cam = grad * cam
457
+ cam = cam.clamp(min=0).mean(dim=0)
458
+ cam = cam[0, 1:]
459
+ return cam
460
+
461
+
462
+ def _conv_filter(state_dict, patch_size=16):
463
+ """convert patch embedding weight from manual patchify + linear proj to conv"""
464
+ out_dict = {}
465
+ for k, v in state_dict.items():
466
+ if "patch_embed.proj.weight" in k:
467
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
468
+ out_dict[k] = v
469
+ return out_dict
470
+
471
+
472
+ def vit_base_patch16_224(pretrained=False, **kwargs):
473
+ model = VisionTransformer(
474
+ patch_size=16,
475
+ embed_dim=768,
476
+ depth=12,
477
+ num_heads=12,
478
+ mlp_ratio=4,
479
+ qkv_bias=True,
480
+ **kwargs,
481
+ )
482
+ model.default_cfg = default_cfgs["vit_base_patch16_224"]
483
+ if pretrained:
484
+ load_pretrained(
485
+ model,
486
+ num_classes=model.num_classes,
487
+ in_chans=kwargs.get("in_chans", 3),
488
+ filter_fn=_conv_filter,
489
+ )
490
+ return model
491
+
492
+
493
+ def vit_large_patch16_224(pretrained=False, **kwargs):
494
+ model = VisionTransformer(
495
+ patch_size=16,
496
+ embed_dim=1024,
497
+ depth=24,
498
+ num_heads=16,
499
+ mlp_ratio=4,
500
+ qkv_bias=True,
501
+ **kwargs,
502
+ )
503
+ model.default_cfg = default_cfgs["vit_large_patch16_224"]
504
+ if pretrained:
505
+ load_pretrained(
506
+ model, num_classes=model.num_classes, in_chans=kwargs.get("in_chans", 3)
507
+ )
508
+ return model
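A closing sketch for the original-LRP variant above. Unlike the main LRP model, this `relprop` implements the `"full"`, `"rollout"`, `"grad"`, `"last_layer"`, `"last_layer_attn"`, and `"second_layer"` methods, so a method must be chosen from that list. Module paths are assumptions, a CUDA device is required by `LRP.generate_LRP`, and the shape comments are expectations derived from the code above.

```python
import torch

from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224        # assumed module path
from baselines.ViT.ViT_explanation_generator import LRP

model = vit_base_patch16_224(pretrained=True).cuda()
explainer = LRP(model)

image = torch.randn(1, 3, 224, 224).cuda()
# "full" propagates relevance all the way to the input pixels; "grad" stops at the patch tokens
pixel_relevance = explainer.generate_LRP(image, method="full")       # expected (1, 224, 224)
patch_relevance = explainer.generate_LRP(image, method="grad")       # expected (1, 196)
print(pixel_relevance.shape, patch_relevance.shape)
```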