leandro commited on
Commit
562d551
1 Parent(s): 6a120e4

add app and requirements

Browse files
Files changed (3) hide show
  1. .vscode/settings.json +1 -0
  2. app.py +40 -0
  3. requirements.txt +2 -0
.vscode/settings.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from transformers import pipeline
4
+
5
+ @st.cache(allow_output_mutation=True)
6
+ def load_tokenizer(model_ckpt):
7
+ return AutoTokenizer.from_pretrained(model_ckpt)
8
+
9
+ @st.cache(allow_output_mutation=True)
10
+ def load_model(model_ckpt):
11
+ model = AutoModelForCausalLM.from_pretrained(model_ckpt)
12
+ return model
13
+
14
+ st.set_page_config(page_icon=':parrot:', layout="wide")
15
+
16
+ default_code = '''\
17
+ def print_hello_world():\
18
+ '''
19
+
20
+ model_ckpt = "models/codeparrot-small"
21
+ tokenizer = load_tokenizer(model_ckpt)
22
+ model = load_model(model_ckpt)
23
+ gen_kwargs = {}
24
+
25
+ st.title("CodeParrot 🦜")
26
+ st.markdown('##')
27
+
28
+ pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
29
+ st.sidebar.header("Generation settings:")
30
+ gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
31
+ gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=16, min_value = 8, max_value=256)
32
+ if gen_kwargs["do_sample"]:
33
+ temperature = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
34
+ gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
35
+ gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.01, value = 0.95)
36
+
37
+ gen_prompt = st.text_area("Generate code with prompt:", value=default_code, height=220,).strip()
38
+ if st.button("Generate code!"):
39
+ generated_text = pipe(gen_prompt, **gen_kwargs)[0]['generated_text']
40
+ st.code(generated_text)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.12.2
2
+ torch