File size: 6,202 Bytes
16b19cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e876713
16b19cc
 
 
 
 
 
 
 
 
 
 
387fb9c
16b19cc
 
 
387fb9c
16b19cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570016e
16b19cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d827a2c
4230709
16b19cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import numpy as np
import pandas as pd
import matplotlib.cm as cm
import html
from torch.nn.functional import softmax
import torch
from matplotlib.colors import LinearSegmentedColormap

# Colour ramp for shading tokens by score. Each channel is a list of
# [x, value_below_x, value_above_x] anchor rows as consumed by
# LinearSegmentedColormap:
#   x = 0.0 -> r=0.8, g=0.0, b=0.0, alpha=1.0 (opaque dark red)
#   x = 1.0 -> r=1.0, g=1.0, b=1.0, alpha=0.0 (fully transparent white)
# so low scores are highlighted strongly and high scores fade out.
cdict = {
    'red': [
        [0.0, 0.8, 0.8],
        [1.0, 1.0, 1.0],
    ],
    'green': [
        [0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
    ],
    'blue': [
        [0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
    ],
    'alpha': [
        [0.0, 1.0, 1.0],
        [1.0, 0.0, 0.0],
    ],
}

# 256-entry lookup table built from the ramp above.
cmap = LinearSegmentedColormap('codemap', segmentdata=cdict, N=256)

def value2rgba(x, cmap=cmap, alpha_mult=1.0):
    """Convert a score to an ``(r, g, b, a)`` tuple for a CSS ``rgba()`` value.

    Args:
        x: value passed to ``cmap`` (the module-level colormap by default).
        cmap: colormap callable; ``cmap(x)`` must return four channels
            (r, g, b, alpha), presumably floats in [0, 1].
        alpha_mult: factor applied to the colormap's alpha channel.

    Returns:
        A tuple ``(r, g, b, a)``. r, g and b are Python ints: the channel is
        scaled by 255 and truncated, not rounded. ``a`` is a Python float.
    """
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    # Coerce to a plain float. The colormap channel may be a numpy scalar,
    # and callers render this tuple with str(). Under NumPy >= 2 that prints
    # numpy scalars as "np.float64(...)", which is invalid CSS.
    a = float(c[-1] * alpha_mult)
    return tuple(rgb.tolist() + [a])

def highlight_token_scores(tokens, scores, sep=' ', **kwargs):
    """Render tokens as an HTML fragment, shading each token by its score.

    Each token becomes a ``<span>`` with two properties:

    - its background colour is ``value2rgba(score, alpha_mult=0.8, **kwargs)``;
    - its ``title`` tooltip shows the score to three decimals.

    Token text is HTML-escaped, so arbitrary source code can be embedded.

    Args:
        tokens: decoded token strings, in order.
        scores: one number per token. Pairs are taken with ``zip``, so extra
            items in either sequence are ignored.
        sep: string inserted between consecutive spans.
        **kwargs: forwarded to ``value2rgba``. ``alpha_mult`` defaults to 0.8
            but may be overridden here.

    Returns:
        ``'<p><code><FONT COLOR=black>' + spans + '</FONT></code></p>'``.
    """
    kwargs.setdefault("alpha_mult", 0.8)
    spans = []
    for token, score in zip(tokens, scores):
        text = html.escape(token).replace("\n", " \n")
        # Join the components with str() instead of calling str() on the
        # tuple. str(tuple) uses repr(), and under NumPy >= 2 a numpy scalar
        # reprs as "np.float64(...)", which breaks the CSS.
        rgba = "(" + ", ".join(str(v) for v in value2rgba(score, **kwargs)) + ")"
        spans.append(
            f'<span title="{score:.3f}" style="background-color: rgba{rgba};">{text}</span>'
        )
    # Close exactly the tags opened above, innermost first.
    return '<p><code><FONT COLOR=black>' + sep.join(spans) + '</FONT></code></p>'

def color_dataframe(row):
    """Return one CSS string per cell of ``row``, for ``Styler.apply(..., axis=1)``.

    The "tokens" and "scores" cells get a background colour computed from
    ``row["scores"]`` with ``value2rgba(..., alpha_mult=0.8)``. Every other
    cell gets ``"background-color: None"``.
    """
    rgba = value2rgba(row["scores"], alpha_mult=0.8)
    # Format each component with str() rather than str(tuple). Under
    # NumPy >= 2, the repr of a numpy scalar is "np.float64(...)", which is
    # not valid CSS.
    highlighted = "background-color: rgba(" + ", ".join(str(v) for v in rgba) + ")"
    return [
        highlighted if key in {"tokens", "scores"} else "background-color: None"
        for key in row.index
    ]

@st.cache_resource
def load_tokenizer(model_ckpt):
    """Load the tokenizer for ``model_ckpt`` via ``AutoTokenizer.from_pretrained``.

    Wrapped in ``st.cache_resource``, so Streamlit reruns reuse the cached
    object for the same checkpoint instead of reloading it.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    return tokenizer

@st.cache_resource
def load_model(model_ckpt):
    """Load the causal language model for ``model_ckpt``.

    Uses ``AutoModelForCausalLM.from_pretrained``. Wrapped in
    ``st.cache_resource``, so Streamlit reruns reuse the cached model for the
    same checkpoint.
    """
    return AutoModelForCausalLM.from_pretrained(model_ckpt)

def calculate_scores(probs, token_ids):
    """Score each token against the tokens the model ranked above it.

    The distribution at position ``i - 1`` (``probs[i - 1]``) judges the
    actual token ``token_ids[i]``. The vocabulary is ranked by probability
    in descending order. If the actual token sits at 0-based rank ``r``,
    its score is::

        p(token) / mean(p of the r + 1 most likely tokens, itself included)

    The top-ranked token therefore scores 1.0. Lower-ranked tokens score
    less, and every score lies in [0, 1].

    Args:
        probs: array of shape ``(seq_len, vocab_size)`` holding the
            next-token probabilities at each position.
        token_ids: array of shape ``(seq_len,)`` holding the input ids.

    Returns:
        A tuple ``(scores, sorted_ids)``:

        - ``scores`` is a list of length ``seq_len``. Its first entry is
          ``1.0``, because the first token has no prediction.
        - ``sorted_ids`` is a ``(seq_len - 1, vocab_size)`` array of token
          ids, ordered from most to least likely at each position.
    """
    # Row k of `probs` predicts token k + 1, so align the two sequences.
    probs = np.asarray(probs)[:-1]
    targets = np.asarray(token_ids)[1:]

    sorted_ids = np.argsort(probs, axis=-1)[:, ::-1]
    # Reuse the argsort instead of sorting the whole vocabulary a second time.
    sorted_probs = np.take_along_axis(probs, sorted_ids, axis=-1)

    rows = np.arange(len(targets))
    # 0-based rank of the actual token in each row's sorted order.
    rank = np.argmax(sorted_ids == targets[:, None], axis=-1)
    token_probs = sorted_probs[rows, rank]
    # Mean probability of the tokens ranked at or above the actual token.
    cumulative_probs = np.cumsum(sorted_probs, axis=-1)[rows, rank] / (rank + 1)

    scores = token_probs / cumulative_probs
    return [1.] + list(scores), sorted_ids

def calculate_loss(logits, labels):
    """Per-token cross-entropy, rescaled so that higher means less surprising.

    ``logits[..., i, :]`` predicts ``labels[..., i + 1]``. Each loss is
    divided by the largest loss in the input and subtracted from 1. The
    worst-predicted token therefore maps to 0.0, and a zero-loss token
    maps to 1.0.

    Args:
        logits: tensor of shape ``(..., seq_len, vocab_size)``. Any leading
            dims are flattened together with the sequence dimension.
        labels: tensor of token ids with shape ``logits.shape[:-1]``.

    Returns:
        A list whose first entry is ``1.0``, a stand-in for the first token,
        which has no prediction. It is followed by one rescaled value per
        predicted token.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    if loss.numel() == 0:
        # Single-token input: nothing is predicted, and max() of an empty
        # tensor would raise.
        return [1.]
    # Clamp the denominator so an all-zero loss gives 1.0 instead of 0/0 = NaN.
    max_loss = torch.clamp(loss.max(), min=torch.finfo(loss.dtype).tiny)
    norm_loss = 1 - loss / max_loss
    return [1.] + list(norm_loss.detach().cpu().numpy())

default_code = """\
from torch import nn
from transformers import Model

class Transformer:
    def __init__(config):
        self.model = Model(config)

    def forward(inputs):
        return self.model(inputs)"""

solution_code = """\
from torch import nn
from transformers import Model

class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()
        self.config = config
        self.model = Model(config)

    def forward(self, inputs):
        return self.model(inputs)
"""

st.set_page_config(page_icon=':parrot:', layout="wide")

np.random.seed(42)
model_ckpt = "codeparrot/codeparrot-small"
# Longest input, in tokens, that is sent to the model; longer inputs are truncated.
max_tokens = 1024
tokenizer = load_tokenizer(model_ckpt)
model = load_model(model_ckpt)
st.markdown("<h1 style='text-align: center;'>CodeParrot 🦜</h1>", unsafe_allow_html=True)
st.markdown('##')

col1, col2 = st.columns(2)

col1.subheader("Edit code")
# A non-empty but collapsed label keeps the layout unchanged and avoids
# Streamlit's empty-label accessibility warning.
code = col1.text_area(
    label="Edit code", label_visibility="collapsed", value=default_code, height=220,
).strip()
inputs = tokenizer(code, return_tensors='pt')
# The tokenizer returns a (1, seq_len) batch. Work on the single row so the
# length check counts tokens, not batch entries.
input_ids = inputs["input_ids"][0]
if input_ids.shape[0] > max_tokens:
    st.warning(f"Your input is longer than the maximum {max_tokens} tokens and will be truncated.")
    # Truncate before the forward pass, as the warning promises.
    input_ids = input_ids[:max_tokens]
if input_ids.numel() == 0:
    # Nothing to score (e.g. an empty editor); stop before calling the model.
    st.info("Enter some code to highlight.")
    st.stop()
token_list = [tokenizer.decode(t) for t in input_ids]

with torch.no_grad():
    logits = model(input_ids=input_ids.unsqueeze(0)).logits[0]
    probs = softmax(logits, dim=-1)

loss = calculate_loss(logits, input_ids)
norm_probs, sorted_token_ids = calculate_scores(probs.numpy(), input_ids.numpy())

st.sidebar.title("Info:")
st.sidebar.markdown("This demo uses CodeParrot to highlight the parts of code with low probability. Since CodeParrot is an autoregressive model the tokens at the beginning tend to have a lower probability. E.g. the model can't know what you want to import because it has no access to information later in the code. However, as you can see in the example on the right it still can highlight bugs or unconventional naming.\n\nAt the bottom of the page is an example of how a better solution might look like. Try to copy paste it and press **CMD + Enter** to update the highlighting.")
st.sidebar.title("Settings:")
highlight_mode = st.sidebar.radio("Highlight mode:", ["Probability heuristics", "Scaled loss per token"])
scores = norm_probs if highlight_mode == "Probability heuristics" else loss

suggestion_threshold = st.sidebar.slider("Suggestion threshold", 0.0, 1.0, 0.2)

col2.subheader("Highlighted code")
col2.markdown('##')
html_string = highlight_token_scores(token_list, scores, sep="")
col2.markdown(html_string, unsafe_allow_html=True)
col2.markdown('##')

st.subheader("Model suggestions")
# Column "top-k" holds the model's k-th most likely token at each position.
# The first token has no prediction, so it gets a placeholder.
top_k = {
    f"top-{i+1}": ["No prediction for first token"]
    + [repr(tokenizer.decode(idx)) for idx in sorted_token_ids[:, i]]
    for i in range(5)
}
df = pd.DataFrame({"tokens": [repr(t) for t in token_list], "scores": scores, **top_k})
df.index.name = "position"
# Chain reset_index() rather than mutating the .loc slice in place.
df_filter = df.loc[df["scores"] <= suggestion_threshold].reset_index()
df_filter = df_filter[["tokens", "scores", "position", "top-1", "top-2", "top-3", "top-4", "top-5",]]
st.dataframe(df_filter.style.apply(color_dataframe, axis=1))

st.markdown('##')

st.subheader("Possible solution")
st.code(solution_code)