Xianbao QIAN patrickvonplaten committed on
Commit
9be2463
0 Parent(s):

Duplicate from diffusers/sd-to-diffusers

Browse files

Co-authored-by: Patrick von Platen <patrickvonplaten@users.noreply.huggingface.co>

Files changed (6) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +198 -0
  4. hf_utils.py +50 -0
  5. requirements.txt +9 -0
  6. utils.py +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SD To Diffusers
3
+ emoji: 🎨➡️🧨
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: diffusers/sd-to-diffusers
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library, third-party, and project-local imports.
import os
import subprocess
from huggingface_hub import HfApi, upload_folder
import gradio as gr
import hf_utils
import utils
from safetensors import safe_open
import torch

# Clone the diffusers repo at import time so the conversion script
# (./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py)
# is available on disk when convert_and_push runs.
subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers", "diffs"])
11
+
12
def error_str(error, title="Error"):
    """Render *error* as a Markdown section headed by *title*.

    Returns an empty string when *error* is falsy so the UI shows nothing.
    """
    if not error:
        return ""
    return f"""#### {title}
{error}"""
15
+
16
def on_token_change(token):
    """Refresh the model list when the user enters/changes their HF token.

    Returns Gradio updates for: model group visibility, the model radio
    (choices + default selection), the Load button visibility, and the
    error output (empty string when the lookup succeeded).
    """
    model_names, error = hf_utils.get_my_model_names(token)
    if model_names:
        # "Other" lets the user type an arbitrary repo id instead.
        model_names.append("Other")

    has_models = bool(model_names)
    default_choice = model_names[0] if has_models else None
    return (
        gr.update(visible=has_models),
        gr.update(choices=model_names, value=default_choice),
        gr.update(visible=has_models),
        gr.update(value=error_str(error)),
    )
22
+
23
def url_to_model_id(model_id_str):
    """Normalize a model reference to the "namespace/name" repo id form.

    Accepts either a plain repo id ("user/model"), which passes through
    unchanged, or a full Hub URL ("https://huggingface.co/user/model").
    """
    if not model_id_str.startswith("https://huggingface.co/"):
        return model_id_str
    # Split once and drop any trailing slash first: the original double-split
    # turned "https://huggingface.co/user/model/" into "model/" + "".
    parts = model_id_str.rstrip("/").split("/")
    return parts[-2] + "/" + parts[-1]
25
+
26
def get_ckpt_names(token, radio_model_names, input_model):
    """List the .ckpt/.safetensors files in the chosen repo.

    Returns (error_markdown, checkpoint-radio update, convert-group update).
    """
    model_id = (
        url_to_model_id(input_model)
        if radio_model_names == "Other"
        else radio_model_names
    )

    if token == "" or model_id == "":
        return (
            error_str("Please enter both a token and a model name.", title="Invalid input"),
            gr.update(choices=[]),
            gr.update(visible=False),
        )

    try:
        api = HfApi(token=token)
        # Only checkpoint-style weight files are convertible.
        weight_suffixes = (".ckpt", ".safetensors")
        ckpt_files = [
            name
            for name in api.list_repo_files(repo_id=model_id)
            if name.endswith(weight_suffixes)
        ]

        if not ckpt_files:
            return (
                error_str("No checkpoint files found in the model repo."),
                gr.update(choices=[]),
                gr.update(visible=False),
            )

        return (
            None,
            gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True),
            gr.update(visible=True),
        )

    except Exception as e:
        return error_str(e), gr.update(choices=[]), None
44
+
45
def convert_and_push(radio_model_names, input_model, ckpt_name, sd_version, token, path_in_repo, ema, safetensors):
    """Download a checkpoint, convert it to the Diffusers layout, and open a PR.

    Args mirror the Gradio widgets: the model choice (radio or free text),
    the checkpoint filename, the SD version radio, the user's token, the
    target path inside the repo, and the "ema"/"safetensors" option strings.

    Returns a Markdown string for the UI: a success message linking to the
    created pull request, or a formatted error.
    """
    import glob  # stdlib; hoisted from mid-function so it's visible up front

    extract_ema = ema == "ema"

    # `is None` instead of `== None` (identity check for the None singleton).
    if sd_version is None:
        return error_str("You must select a stable diffusion version.", title="Invalid input")

    model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names

    try:
        # url_to_model_id is idempotent, so re-normalizing is harmless.
        model_id = url_to_model_id(model_id)

        # 1. Download the checkpoint file, pinned to a specific revision.
        ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)

        if safetensors == "yes":
            # Re-serialize safetensors weights as a torch .ckpt so the
            # conversion script (which torch.load's its input) can read them.
            tensors = {}
            with safe_open(ckpt_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)

            new_checkpoint_path = "/".join(ckpt_path.split("/")[:-1] + ["model_safe.ckpt"])
            torch.save(tensors, new_checkpoint_path)
            ckpt_path = new_checkpoint_path
            print("Converting ckpt_path", ckpt_path)

        print(ckpt_path)

        # 2. Run the conversion script from the repo cloned at import time.
        # NOTE(review): sd_version is validated above but never forwarded to
        # the script — presumably it infers the config; confirm upstream.
        os.makedirs(model_id, exist_ok=True)
        run_command = [
            "python3",
            "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
            "--checkpoint_path",
            ckpt_path,
            "--dump_path",
            model_id,
        ]
        if extract_ema:
            run_command.append("--extract_ema")
        subprocess.run(run_command)

        # 3. Push the converted weights back to the model repo as a PR.
        commit_message = "Add Diffusers weights"
        upload_folder(
            folder_path=model_id,
            repo_id=model_id,
            path_in_repo=path_in_repo,
            token=token,
            create_pr=True,
            commit_message=commit_message,
            commit_description=f"Add Diffusers weights converted from checkpoint `{ckpt_name}` in revision {revision}",
        )

        # 4. Clean up: drop the cached download, the converted folder, and any
        # yaml config files the conversion script left in the working dir.
        hf_utils.delete_file(revision)
        subprocess.run(["rm", "-rf", model_id.split('/')[0]])
        for f in glob.glob("*.yaml*"):
            subprocess.run(["rm", "-rf", f])

        # (The original had an unreachable `return "Done"` after this return;
        # removed.)
        return f"""Successfully converted the checkpoint and opened a PR to add the weights to the model repo.
You can view and merge the PR [here]({hf_utils.get_pr_url(HfApi(token=token), model_id, commit_message)})."""

    except Exception as e:
        return error_str(e)
112
+
113
+
114
# Markdown blurb rendered at the top of the UI.
DESCRIPTION = """### Convert a stable diffusion checkpoint to Diffusers🧨
With this space, you can easily convert a CompVis stable diffusion checkpoint to Diffusers and automatically create a pull request to the model repo.
You can choose to convert a checkpoint from one of your own models, or from any other model on the Hub.
You can skip the queue by running the app in the colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/qunash/f0f3152c5851c0c477b68b7b98d547fe/convert-sd-to-diffusers.ipynb)"""

with gr.Blocks() as demo:

    gr.Markdown(DESCRIPTION)
    with gr.Row():

        # Left column: token entry and model selection (step 1).
        with gr.Column(scale=11):
            with gr.Column():
                gr.Markdown("## 1. Load model info")
                input_token = gr.Textbox(
                    max_lines=1,
                    type="password",
                    label="Enter your Hugging Face token",
                    placeholder="READ permission is sufficient"
                )
                gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens)")
                # Hidden until a token is entered and the model list loads.
                with gr.Group(visible=False) as group_model:
                    radio_model_names = gr.Radio(label="Choose a model")
                    # Free-text repo id/URL, shown only when "Other" is picked.
                    input_model = gr.Textbox(
                        max_lines=1,
                        label="Model name or URL",
                        placeholder="username/model_name",
                        visible=False,
                    )

            btn_get_ckpts = gr.Button("Load", visible=False)

        # Right column: conversion options (step 2), revealed after Load.
        with gr.Column(scale=10):
            with gr.Column(visible=False) as group_convert:
                gr.Markdown("## 2. Convert to Diffusers🧨")
                radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
                path_in_repo = gr.Textbox(label="Path where the weights will be saved", placeholder="Leave empty for root folder")
                ema = gr.Radio(label="Extract EMA or non-EMA?", choices=["ema", "non-ema"])
                safetensors = gr.Radio(label="Extract from safetensors", choices=["yes", "no"], value="no")
                radio_sd_version = gr.Radio(label="Choose the model version", choices=["v1", "v2", "v2.1"])
                gr.Markdown("Conversion may take a few minutes.")
                btn_convert = gr.Button("Convert & Push")

    error_output = gr.Markdown(label="Output")

    # Entering a token refreshes the model list (queue bypassed for snappy UI).
    input_token.change(
        fn=on_token_change,
        inputs=input_token,
        outputs=[group_model, radio_model_names, btn_get_ckpts, error_output],
        queue=False,
        scroll_to_output=True)

    # Selecting "Other" reveals the free-text model field.
    radio_model_names.change(
        lambda x: gr.update(visible=x == "Other"),
        inputs=radio_model_names,
        outputs=input_model,
        queue=False,
        scroll_to_output=True)

    # "Load" lists the checkpoint files in the chosen repo.
    btn_get_ckpts.click(
        fn=get_ckpt_names,
        inputs=[input_token, radio_model_names, input_model],
        outputs=[error_output, radio_ckpts, group_convert],
        scroll_to_output=True,
        queue=False
    )

    # "Convert & Push" runs the full download → convert → PR pipeline
    # (queued, since it can take minutes).
    btn_convert.click(
        fn=convert_and_push,
        inputs=[radio_model_names, input_model, radio_ckpts, radio_sd_version, input_token, path_in_repo, ema, safetensors],
        outputs=error_output,
        scroll_to_output=True
    )

    # gr.Markdown("""<img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/imgs/diffusers_library.jpg" width="150"/>""")
    gr.HTML("""
    <div style="border-top: 1px solid #303030;">
      <br>
      <p>Space by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>
      <a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 45px !important;width: 162px !important;" ></a><br><br>
      <p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.sd-to-diffusers" alt="visitors"></p>
    </div>
    """)

# Enable the request queue, then launch (shared publicly when run in Colab).
demo.queue()
demo.launch(debug=True, share=utils.is_google_colab())
hf_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import get_hf_file_metadata, hf_hub_url, hf_hub_download, scan_cache_dir, whoami, list_models
2
+
3
+
4
def get_my_model_names(token):
    """Return ([model ids owned by the token's user], error).

    On success the error slot is None; on failure the list is empty and the
    exception object is returned for the caller to display.
    """
    try:
        user = whoami(token=token)
        models = list_models(author=user["name"], use_auth_token=token)
        names = [model.modelId for model in models]
    except Exception as exc:
        return [], exc
    return names, None
+
14
def download_file(repo_id: str, filename: str, token: str):
    """Download a file from a repo on the Hugging Face Hub.

    The download is pinned to the commit hash resolved from the file's
    metadata, so path and revision are guaranteed to match.

    Returns:
        file_path (:obj:`str`): The path to the downloaded file.
        revision (:obj:`str`): The commit hash of the file.
    """
    url = hf_hub_url(repo_id=repo_id, filename=filename)
    revision = get_hf_file_metadata(url, token=token).commit_hash
    file_path = hf_hub_download(
        repo_id=repo_id, filename=filename, revision=revision, token=token
    )
    return file_path, revision
28
+
29
def delete_file(revision: str):
    """Evict every cached file belonging to *revision* from the local HF cache.

    Args:
        revision (:obj:`str`): The commit hash whose cached files are removed.
    Returns:
        None
    """
    delete_plan = scan_cache_dir().delete_revisions(revision)
    delete_plan.execute()
38
+
39
def get_pr_url(api, repo_id, title):
    """Find the open pull request on *repo_id* whose title equals *title*.

    Returns the discussion URL, or None when the listing fails or no open PR
    with that title exists.
    """
    try:
        discussions = api.get_repo_discussions(repo_id=repo_id)
    except Exception:
        return None
    for discussion in discussions:
        matches = (
            discussion.status == "open"
            and discussion.is_pull_request
            and discussion.title == title
        )
        if matches:
            return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
    return None
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/huggingface_hub@main
2
+ git+https://github.com/huggingface/diffusers.git
3
+ torch
4
+ #transformers
5
+ git+https://github.com/huggingface/transformers
6
+ pytorch_lightning
7
+ OmegaConf
8
+ ftfy
9
+ safetensors
utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def is_google_colab():
    """Return True when running inside Google Colab.

    Detection: the `google.colab` module is importable only in Colab runtimes.
    """
    try:
        import google.colab  # noqa: F401  # presence check only
        return True
    except ImportError:
        # Narrowed from a bare `except:` so unrelated errors (e.g.
        # KeyboardInterrupt, SystemExit) are no longer swallowed.
        return False