hysts HF staff commited on
Commit
ecfdc8b
1 Parent(s): 1f333c3

Migrate from yapf to black

Browse files
Files changed (13) hide show
  1. .pre-commit-config.yaml +26 -12
  2. .style.yapf +0 -5
  3. .vscode/settings.json +21 -0
  4. app.py +27 -29
  5. app_inference.py +64 -94
  6. app_system_monitor.py +29 -30
  7. app_training.py +77 -102
  8. app_upload.py +34 -34
  9. constants.py +7 -5
  10. inference.py +12 -17
  11. trainer.py +40 -41
  12. uploader.py +23 -20
  13. utils.py +12 -15
.pre-commit-config.yaml CHANGED
@@ -1,7 +1,7 @@
1
  exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
  hooks:
6
  - id: check-executables-have-shebangs
7
  - id: check-json
@@ -9,29 +9,43 @@ repos:
9
  - id: check-shebang-scripts-are-executable
10
  - id: check-toml
11
  - id: check-yaml
12
- - id: double-quote-string-fixer
13
  - id: end-of-file-fixer
14
  - id: mixed-line-ending
15
- args: ['--fix=lf']
16
  - id: requirements-txt-fixer
17
  - id: trailing-whitespace
18
  - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
  hooks:
21
  - id: docformatter
22
- args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
  rev: 5.12.0
25
  hooks:
26
  - id: isort
 
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
  hooks:
30
  - id: mypy
31
- args: ['--ignore-missing-imports']
32
- additional_dependencies: ['types-python-slugify']
33
- - repo: https://github.com/google/yapf
34
- rev: v0.32.0
35
  hooks:
36
- - id: yapf
37
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.4.0
5
  hooks:
6
  - id: check-executables-have-shebangs
7
  - id: check-json
 
9
  - id: check-shebang-scripts-are-executable
10
  - id: check-toml
11
  - id: check-yaml
 
12
  - id: end-of-file-fixer
13
  - id: mixed-line-ending
14
+ args: ["--fix=lf"]
15
  - id: requirements-txt-fixer
16
  - id: trailing-whitespace
17
  - repo: https://github.com/myint/docformatter
18
+ rev: v1.7.5
19
  hooks:
20
  - id: docformatter
21
+ args: ["--in-place"]
22
  - repo: https://github.com/pycqa/isort
23
  rev: 5.12.0
24
  hooks:
25
  - id: isort
26
+ args: ["--profile", "black"]
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v1.5.1
29
  hooks:
30
  - id: mypy
31
+ args: ["--ignore-missing-imports"]
32
+ additional_dependencies: ["types-python-slugify", "types-requests", "types-PyYAML"]
33
+ - repo: https://github.com/psf/black
34
+ rev: 23.9.0
35
  hooks:
36
+ - id: black
37
+ language_version: python3.10
38
+ args: ["--line-length", "119"]
39
+ - repo: https://github.com/kynan/nbstripout
40
+ rev: 0.6.1
41
+ hooks:
42
+ - id: nbstripout
43
+ args: ["--extra-keys", "metadata.interpreter metadata.kernelspec cell.metadata.pycharm"]
44
+ - repo: https://github.com/nbQA-dev/nbQA
45
+ rev: 1.7.0
46
+ hooks:
47
+ - id: nbqa-black
48
+ - id: nbqa-pyupgrade
49
+ args: ["--py37-plus"]
50
+ - id: nbqa-isort
51
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "ms-python.black-formatter",
4
+ "editor.formatOnType": true,
5
+ "editor.codeActionsOnSave": {
6
+ "source.organizeImports": true
7
+ }
8
+ },
9
+ "black-formatter.args": [
10
+ "--line-length=119"
11
+ ],
12
+ "isort.args": ["--profile", "black"],
13
+ "flake8.args": [
14
+ "--max-line-length=119"
15
+ ],
16
+ "ruff.args": [
17
+ "--line-length=119"
18
+ ],
19
+ "editor.formatOnSave": true,
20
+ "files.insertFinalNewline": true
21
+ }
app.py CHANGED
@@ -15,37 +15,37 @@ from app_upload import create_upload_demo
15
  from inference import InferencePipeline
16
  from trainer import Trainer
17
 
18
- TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'
19
 
20
- ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
21
- SPACE_ID = os.getenv('SPACE_ID')
22
- GPU_DATA = getoutput('nvidia-smi')
23
- SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
24
 
25
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
26
- '''
27
 
28
  IS_SHARED_UI = SPACE_ID == ORIGINAL_SPACE_ID
29
- if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
30
  SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
31
  else:
32
- SETTINGS = 'Settings'
33
 
34
- INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
35
 
36
- CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
37
  <center>
38
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
39
  You can use "T4 small/medium" to run this demo.
40
  </center>
41
- '''
42
 
43
- HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
44
 
45
  You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>. You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
46
- '''
47
 
48
- HF_TOKEN = os.getenv('HF_TOKEN')
49
 
50
 
51
  def show_warning(warning_text: str) -> gr.Blocks:
@@ -58,33 +58,31 @@ def show_warning(warning_text: str) -> gr.Blocks:
58
  pipe = InferencePipeline(HF_TOKEN)
59
  trainer = Trainer()
60
 
61
- with gr.Blocks(css='style.css') as demo:
62
  if IS_SHARED_UI:
63
  show_warning(SHARED_UI_WARNING)
64
  elif not torch.cuda.is_available():
65
  show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
- elif 'T4' not in GPU_DATA:
67
  show_warning(INVALID_GPU_WARNING)
68
 
69
  gr.Markdown(TITLE)
70
  with gr.Tabs():
71
- with gr.TabItem('Train'):
72
- create_training_demo(trainer,
73
- pipe,
74
- disable_run_button=IS_SHARED_UI)
75
- with gr.TabItem('Run'):
76
- create_inference_demo(pipe,
77
- HF_TOKEN,
78
- disable_run_button=IS_SHARED_UI)
79
- with gr.TabItem('Upload'):
80
- gr.Markdown('''
81
  - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
82
- ''')
 
83
  create_upload_demo(disable_run_button=IS_SHARED_UI)
84
 
85
  with gr.Row():
86
- if not IS_SHARED_UI and not os.getenv('DISABLE_SYSTEM_MONITOR'):
87
- with gr.Accordion(label='System info', open=False):
88
  create_monitor_demo()
89
 
90
  if not HF_TOKEN:
 
15
  from inference import InferencePipeline
16
  from trainer import Trainer
17
 
18
+ TITLE = "# [Tune-A-Video](https://tuneavideo.github.io/)"
19
 
20
+ ORIGINAL_SPACE_ID = "Tune-A-Video-library/Tune-A-Video-Training-UI"
21
+ SPACE_ID = os.getenv("SPACE_ID")
22
+ GPU_DATA = getoutput("nvidia-smi")
23
+ SHARED_UI_WARNING = f"""## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
24
 
25
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
26
+ """
27
 
28
  IS_SHARED_UI = SPACE_ID == ORIGINAL_SPACE_ID
29
+ if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
30
  SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
31
  else:
32
+ SETTINGS = "Settings"
33
 
34
+ INVALID_GPU_WARNING = f"""## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task."""
35
 
36
+ CUDA_NOT_AVAILABLE_WARNING = f"""## Attention - Running on CPU.
37
  <center>
38
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
39
  You can use "T4 small/medium" to run this demo.
40
  </center>
41
+ """
42
 
43
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f"""The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
44
 
45
  You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>. You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
46
+ """
47
 
48
+ HF_TOKEN = os.getenv("HF_TOKEN")
49
 
50
 
51
  def show_warning(warning_text: str) -> gr.Blocks:
 
58
  pipe = InferencePipeline(HF_TOKEN)
59
  trainer = Trainer()
60
 
61
+ with gr.Blocks(css="style.css") as demo:
62
  if IS_SHARED_UI:
63
  show_warning(SHARED_UI_WARNING)
64
  elif not torch.cuda.is_available():
65
  show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif "T4" not in GPU_DATA:
67
  show_warning(INVALID_GPU_WARNING)
68
 
69
  gr.Markdown(TITLE)
70
  with gr.Tabs():
71
+ with gr.TabItem("Train"):
72
+ create_training_demo(trainer, pipe, disable_run_button=IS_SHARED_UI)
73
+ with gr.TabItem("Run"):
74
+ create_inference_demo(pipe, HF_TOKEN, disable_run_button=IS_SHARED_UI)
75
+ with gr.TabItem("Upload"):
76
+ gr.Markdown(
77
+ """
 
 
 
78
  - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
79
+ """
80
+ )
81
  create_upload_demo(disable_run_button=IS_SHARED_UI)
82
 
83
  with gr.Row():
84
+ if not IS_SHARED_UI and not os.getenv("DISABLE_SYSTEM_MONITOR"):
85
+ with gr.Accordion(label="System info", open=False):
86
  create_monitor_demo()
87
 
88
  if not HF_TOKEN:
app_inference.py CHANGED
@@ -14,7 +14,7 @@ from utils import find_exp_dirs
14
 
15
  class ModelSource(enum.Enum):
16
  HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
- LOCAL = 'Local'
18
 
19
 
20
  class InferenceUtil:
@@ -23,18 +23,13 @@ class InferenceUtil:
23
 
24
  def load_hub_model_list(self) -> dict:
25
  api = HfApi(token=self.hf_token)
26
- choices = [
27
- info.modelId
28
- for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
- ]
30
- return gr.update(choices=choices,
31
- value=choices[0] if choices else None)
32
 
33
  @staticmethod
34
  def load_local_model_list() -> dict:
35
  choices = find_exp_dirs()
36
- return gr.update(choices=choices,
37
- value=choices[0] if choices else None)
38
 
39
  def reload_model_list(self, model_source: str) -> dict:
40
  if model_source == ModelSource.HUB_LIB.value:
@@ -48,22 +43,21 @@ class InferenceUtil:
48
  try:
49
  card = InferencePipeline.get_model_card(model_id, self.hf_token)
50
  except Exception:
51
- return '', ''
52
- base_model = getattr(card.data, 'base_model', '')
53
- training_prompt = getattr(card.data, 'training_prompt', '')
54
  return base_model, training_prompt
55
 
56
- def reload_model_list_and_update_model_info(
57
- self, model_source: str) -> tuple[dict, str, str]:
58
  model_list_update = self.reload_model_list(model_source)
59
- model_list = model_list_update['choices']
60
- model_info = self.load_model_info(model_list[0] if model_list else '')
61
  return model_list_update, *model_info
62
 
63
 
64
- def create_inference_demo(pipe: InferencePipeline,
65
- hf_token: str | None = None,
66
- disable_run_button: bool = False) -> gr.Blocks:
67
  app = InferenceUtil(hf_token)
68
 
69
  with gr.Blocks() as demo:
@@ -71,84 +65,60 @@ def create_inference_demo(pipe: InferencePipeline,
71
  with gr.Column():
72
  with gr.Box():
73
  model_source = gr.Radio(
74
- label='Model Source',
75
- choices=[_.value for _ in ModelSource],
76
- value=ModelSource.HUB_LIB.value)
77
- reload_button = gr.Button('Reload Model List')
78
- model_id = gr.Dropdown(label='Model ID',
79
- choices=None,
80
- value=None)
81
- with gr.Accordion(
82
- label=
83
- 'Model info (Base model and prompt used for training)',
84
- open=False):
85
  with gr.Row():
86
- base_model_used_for_training = gr.Text(
87
- label='Base model', interactive=False)
88
- prompt_used_for_training = gr.Text(
89
- label='Training prompt', interactive=False)
90
- prompt = gr.Textbox(
91
- label='Prompt',
92
- max_lines=1,
93
- placeholder='Example: "A panda is surfing"')
94
- video_length = gr.Slider(label='Video length',
95
- minimum=4,
96
- maximum=12,
97
- step=1,
98
- value=8)
99
- fps = gr.Slider(label='FPS',
100
- minimum=1,
101
- maximum=12,
102
- step=1,
103
- value=1)
104
- seed = gr.Slider(label='Seed',
105
- minimum=0,
106
- maximum=100000,
107
- step=1,
108
- value=0)
109
- with gr.Accordion('Advanced options', open=False):
110
- num_steps = gr.Slider(label='Number of Steps',
111
- minimum=0,
112
- maximum=100,
113
- step=1,
114
- value=50)
115
- guidance_scale = gr.Slider(label='Guidance scale',
116
- minimum=0,
117
- maximum=50,
118
- step=0.1,
119
- value=7.5)
120
-
121
- run_button = gr.Button('Generate',
122
- interactive=not disable_run_button)
123
-
124
- gr.Markdown('''
125
  - After training, you can press "Reload Model List" button to load your trained model names.
126
  - It takes a few minutes to download model first.
127
  - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
128
- ''')
 
129
  with gr.Column():
130
- result = gr.Video(label='Result')
131
-
132
- model_source.change(fn=app.reload_model_list_and_update_model_info,
133
- inputs=model_source,
134
- outputs=[
135
- model_id,
136
- base_model_used_for_training,
137
- prompt_used_for_training,
138
- ])
139
- reload_button.click(fn=app.reload_model_list_and_update_model_info,
140
- inputs=model_source,
141
- outputs=[
142
- model_id,
143
- base_model_used_for_training,
144
- prompt_used_for_training,
145
- ])
146
- model_id.change(fn=app.load_model_info,
147
- inputs=model_id,
148
- outputs=[
149
- base_model_used_for_training,
150
- prompt_used_for_training,
151
- ])
 
 
 
 
 
 
152
  inputs = [
153
  model_id,
154
  prompt,
@@ -163,10 +133,10 @@ def create_inference_demo(pipe: InferencePipeline,
163
  return demo
164
 
165
 
166
- if __name__ == '__main__':
167
  import os
168
 
169
- hf_token = os.getenv('HF_TOKEN')
170
  pipe = InferencePipeline(hf_token)
171
  demo = create_inference_demo(pipe, hf_token)
172
  demo.queue(api_open=False, max_size=10).launch()
 
14
 
15
  class ModelSource(enum.Enum):
16
  HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
+ LOCAL = "Local"
18
 
19
 
20
  class InferenceUtil:
 
23
 
24
  def load_hub_model_list(self) -> dict:
25
  api = HfApi(token=self.hf_token)
26
+ choices = [info.modelId for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)]
27
+ return gr.update(choices=choices, value=choices[0] if choices else None)
 
 
 
 
28
 
29
  @staticmethod
30
  def load_local_model_list() -> dict:
31
  choices = find_exp_dirs()
32
+ return gr.update(choices=choices, value=choices[0] if choices else None)
 
33
 
34
  def reload_model_list(self, model_source: str) -> dict:
35
  if model_source == ModelSource.HUB_LIB.value:
 
43
  try:
44
  card = InferencePipeline.get_model_card(model_id, self.hf_token)
45
  except Exception:
46
+ return "", ""
47
+ base_model = getattr(card.data, "base_model", "")
48
+ training_prompt = getattr(card.data, "training_prompt", "")
49
  return base_model, training_prompt
50
 
51
+ def reload_model_list_and_update_model_info(self, model_source: str) -> tuple[dict, str, str]:
 
52
  model_list_update = self.reload_model_list(model_source)
53
+ model_list = model_list_update["choices"]
54
+ model_info = self.load_model_info(model_list[0] if model_list else "")
55
  return model_list_update, *model_info
56
 
57
 
58
+ def create_inference_demo(
59
+ pipe: InferencePipeline, hf_token: str | None = None, disable_run_button: bool = False
60
+ ) -> gr.Blocks:
61
  app = InferenceUtil(hf_token)
62
 
63
  with gr.Blocks() as demo:
 
65
  with gr.Column():
66
  with gr.Box():
67
  model_source = gr.Radio(
68
+ label="Model Source", choices=[_.value for _ in ModelSource], value=ModelSource.HUB_LIB.value
69
+ )
70
+ reload_button = gr.Button("Reload Model List")
71
+ model_id = gr.Dropdown(label="Model ID", choices=None, value=None)
72
+ with gr.Accordion(label="Model info (Base model and prompt used for training)", open=False):
 
 
 
 
 
 
73
  with gr.Row():
74
+ base_model_used_for_training = gr.Text(label="Base model", interactive=False)
75
+ prompt_used_for_training = gr.Text(label="Training prompt", interactive=False)
76
+ prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "A panda is surfing"')
77
+ video_length = gr.Slider(label="Video length", minimum=4, maximum=12, step=1, value=8)
78
+ fps = gr.Slider(label="FPS", minimum=1, maximum=12, step=1, value=1)
79
+ seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=0)
80
+ with gr.Accordion("Advanced options", open=False):
81
+ num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=100, step=1, value=50)
82
+ guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=50, step=0.1, value=7.5)
83
+
84
+ run_button = gr.Button("Generate", interactive=not disable_run_button)
85
+
86
+ gr.Markdown(
87
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  - After training, you can press "Reload Model List" button to load your trained model names.
89
  - It takes a few minutes to download model first.
90
  - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
91
+ """
92
+ )
93
  with gr.Column():
94
+ result = gr.Video(label="Result")
95
+
96
+ model_source.change(
97
+ fn=app.reload_model_list_and_update_model_info,
98
+ inputs=model_source,
99
+ outputs=[
100
+ model_id,
101
+ base_model_used_for_training,
102
+ prompt_used_for_training,
103
+ ],
104
+ )
105
+ reload_button.click(
106
+ fn=app.reload_model_list_and_update_model_info,
107
+ inputs=model_source,
108
+ outputs=[
109
+ model_id,
110
+ base_model_used_for_training,
111
+ prompt_used_for_training,
112
+ ],
113
+ )
114
+ model_id.change(
115
+ fn=app.load_model_info,
116
+ inputs=model_id,
117
+ outputs=[
118
+ base_model_used_for_training,
119
+ prompt_used_for_training,
120
+ ],
121
+ )
122
  inputs = [
123
  model_id,
124
  prompt,
 
133
  return demo
134
 
135
 
136
+ if __name__ == "__main__":
137
  import os
138
 
139
+ hf_token = os.getenv("HF_TOKEN")
140
  pipe = InferencePipeline(hf_token)
141
  demo = create_inference_demo(pipe, hf_token)
142
  demo.queue(api_open=False, max_size=10).launch()
app_system_monitor.py CHANGED
@@ -16,15 +16,12 @@ class SystemMonitor:
16
 
17
  def __init__(self):
18
  self.devices = nvitop.Device.all()
19
- self.cpu_memory_usage = collections.deque(
20
- [0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
21
- self.cpu_memory_usage_str = ''
22
- self.gpu_memory_usage = collections.deque(
23
- [0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
24
- self.gpu_util = collections.deque([0 for _ in range(self.MAX_SIZE)],
25
- maxlen=self.MAX_SIZE)
26
- self.gpu_memory_usage_str = ''
27
- self.gpu_util_str = ''
28
 
29
  def update(self) -> None:
30
  self.update_cpu()
@@ -33,7 +30,9 @@ class SystemMonitor:
33
  def update_cpu(self) -> None:
34
  memory = psutil.virtual_memory()
35
  self.cpu_memory_usage.append(memory.percent)
36
- self.cpu_memory_usage_str = f'{memory.used / 1024**3:0.2f}GiB / {memory.total / 1024**3:0.2f}GiB ({memory.percent}%)'
 
 
37
 
38
  def update_gpu(self) -> None:
39
  if not self.devices:
@@ -41,36 +40,36 @@ class SystemMonitor:
41
  device = self.devices[0]
42
  self.gpu_memory_usage.append(device.memory_percent())
43
  self.gpu_util.append(device.gpu_utilization())
44
- self.gpu_memory_usage_str = f'{device.memory_usage()} ({device.memory_percent()}%)'
45
- self.gpu_util_str = f'{device.gpu_utilization()}%'
46
 
47
  def get_json(self) -> dict[str, str]:
48
  return {
49
- 'CPU memory usage': self.cpu_memory_usage_str,
50
- 'GPU memory usage': self.gpu_memory_usage_str,
51
- 'GPU Util': self.gpu_util_str,
52
  }
53
 
54
  def get_graph_data(self) -> dict[str, list[int | float]]:
55
  return {
56
- 'index': list(range(-self.MAX_SIZE + 1, 1)),
57
- 'CPU memory usage': self.cpu_memory_usage,
58
- 'GPU memory usage': self.gpu_memory_usage,
59
- 'GPU Util': self.gpu_util,
60
  }
61
 
62
  def get_graph(self):
63
  df = pd.DataFrame(self.get_graph_data())
64
- return px.line(df,
65
- x='index',
66
- y=[
67
- 'CPU memory usage',
68
- 'GPU memory usage',
69
- 'GPU Util',
70
- ],
71
- range_y=[-5,
72
- 105]).update_layout(xaxis_title='Time',
73
- yaxis_title='Percentage')
74
 
75
 
76
  def create_monitor_demo() -> gr.Blocks:
@@ -82,6 +81,6 @@ def create_monitor_demo() -> gr.Blocks:
82
  return demo
83
 
84
 
85
- if __name__ == '__main__':
86
  demo = create_monitor_demo()
87
  demo.queue(api_open=False).launch()
 
16
 
17
  def __init__(self):
18
  self.devices = nvitop.Device.all()
19
+ self.cpu_memory_usage = collections.deque([0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
20
+ self.cpu_memory_usage_str = ""
21
+ self.gpu_memory_usage = collections.deque([0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
22
+ self.gpu_util = collections.deque([0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
23
+ self.gpu_memory_usage_str = ""
24
+ self.gpu_util_str = ""
 
 
 
25
 
26
  def update(self) -> None:
27
  self.update_cpu()
 
30
  def update_cpu(self) -> None:
31
  memory = psutil.virtual_memory()
32
  self.cpu_memory_usage.append(memory.percent)
33
+ self.cpu_memory_usage_str = (
34
+ f"{memory.used / 1024**3:0.2f}GiB / {memory.total / 1024**3:0.2f}GiB ({memory.percent}%)"
35
+ )
36
 
37
  def update_gpu(self) -> None:
38
  if not self.devices:
 
40
  device = self.devices[0]
41
  self.gpu_memory_usage.append(device.memory_percent())
42
  self.gpu_util.append(device.gpu_utilization())
43
+ self.gpu_memory_usage_str = f"{device.memory_usage()} ({device.memory_percent()}%)"
44
+ self.gpu_util_str = f"{device.gpu_utilization()}%"
45
 
46
  def get_json(self) -> dict[str, str]:
47
  return {
48
+ "CPU memory usage": self.cpu_memory_usage_str,
49
+ "GPU memory usage": self.gpu_memory_usage_str,
50
+ "GPU Util": self.gpu_util_str,
51
  }
52
 
53
  def get_graph_data(self) -> dict[str, list[int | float]]:
54
  return {
55
+ "index": list(range(-self.MAX_SIZE + 1, 1)),
56
+ "CPU memory usage": self.cpu_memory_usage,
57
+ "GPU memory usage": self.gpu_memory_usage,
58
+ "GPU Util": self.gpu_util,
59
  }
60
 
61
  def get_graph(self):
62
  df = pd.DataFrame(self.get_graph_data())
63
+ return px.line(
64
+ df,
65
+ x="index",
66
+ y=[
67
+ "CPU memory usage",
68
+ "GPU memory usage",
69
+ "GPU Util",
70
+ ],
71
+ range_y=[-5, 105],
72
+ ).update_layout(xaxis_title="Time", yaxis_title="Percentage")
73
 
74
 
75
  def create_monitor_demo() -> gr.Blocks:
 
81
  return demo
82
 
83
 
84
+ if __name__ == "__main__":
85
  demo = create_monitor_demo()
86
  demo.queue(api_open=False).launch()
app_training.py CHANGED
@@ -11,145 +11,120 @@ from inference import InferencePipeline
11
  from trainer import Trainer
12
 
13
 
14
- def create_training_demo(trainer: Trainer,
15
- pipe: InferencePipeline | None = None,
16
- disable_run_button: bool = False) -> gr.Blocks:
17
  def read_log() -> str:
18
  with open(trainer.log_file) as f:
19
  lines = f.readlines()
20
- return ''.join(lines[-10:])
21
 
22
  with gr.Blocks() as demo:
23
  with gr.Row():
24
  with gr.Column():
25
  with gr.Box():
26
- gr.Markdown('Training Data')
27
- training_video = gr.File(label='Training video')
28
- training_prompt = gr.Textbox(
29
- label='Training prompt',
30
- max_lines=1,
31
- placeholder='A man is surfing')
32
- gr.Markdown('''
33
  - Upload a video and write a `Training Prompt` that describes the video.
34
- ''')
 
35
 
36
  with gr.Column():
37
  with gr.Box():
38
- gr.Markdown('Training Parameters')
39
  with gr.Row():
40
- base_model = gr.Text(
41
- label='Base Model',
42
- value='CompVis/stable-diffusion-v1-4',
43
- max_lines=1)
44
- resolution = gr.Dropdown(choices=['512', '768'],
45
- value='512',
46
- label='Resolution',
47
- visible=False)
48
 
49
- hf_token = gr.Text(label='Hugging Face Write Token',
50
- type='password',
51
- visible=os.getenv('HF_TOKEN') is None)
52
- with gr.Accordion(label='Advanced options', open=False):
53
- num_training_steps = gr.Number(
54
- label='Number of Training Steps',
55
- value=300,
56
- precision=0)
57
- learning_rate = gr.Number(label='Learning Rate',
58
- value=0.000035)
59
  gradient_accumulation = gr.Number(
60
- label='Number of Gradient Accumulation',
61
- value=1,
62
- precision=0)
63
- seed = gr.Slider(label='Seed',
64
- minimum=0,
65
- maximum=100000,
66
- step=1,
67
- randomize=True,
68
- value=0)
69
- fp16 = gr.Checkbox(label='FP16', value=True)
70
- use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
71
- value=False)
72
- checkpointing_steps = gr.Number(
73
- label='Checkpointing Steps',
74
- value=1000,
75
- precision=0)
76
- validation_epochs = gr.Number(
77
- label='Validation Epochs', value=100, precision=0)
78
- gr.Markdown('''
79
  - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
80
  - Expected time to train a model for 300 steps: ~20 minutes with T4
81
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
82
- ''')
 
83
 
84
  with gr.Row():
85
  with gr.Column():
86
- gr.Markdown('Output Model')
87
- output_model_name = gr.Text(label='Name of your model',
88
- placeholder='The surfer man',
89
- max_lines=1)
90
  validation_prompt = gr.Text(
91
- label='Validation Prompt',
92
- placeholder=
93
- 'prompt to test the model, e.g: a dog is surfing')
94
  with gr.Column():
95
- gr.Markdown('Upload Settings')
96
  with gr.Row():
97
- upload_to_hub = gr.Checkbox(label='Upload model to Hub',
98
- value=True)
99
- use_private_repo = gr.Checkbox(label='Private', value=True)
100
- delete_existing_repo = gr.Checkbox(
101
- label='Delete existing repo of the same name',
102
- value=False)
103
  upload_to = gr.Radio(
104
- label='Upload to',
105
  choices=[_.value for _ in UploadTarget],
106
- value=UploadTarget.MODEL_LIBRARY.value)
 
107
 
108
  pause_space_after_training = gr.Checkbox(
109
- label='Pause this Space after training',
110
  value=False,
111
- interactive=bool(os.getenv('SPACE_ID')),
112
- visible=False)
113
- run_button = gr.Button('Start Training',
114
- interactive=not disable_run_button)
115
 
116
  with gr.Box():
117
- gr.Text(label='Log',
118
- value=read_log,
119
- lines=10,
120
- max_lines=10,
121
- every=1)
122
 
123
  if pipe is not None:
124
  run_button.click(fn=pipe.clear)
125
- run_button.click(fn=trainer.run,
126
- inputs=[
127
- training_video,
128
- training_prompt,
129
- output_model_name,
130
- delete_existing_repo,
131
- validation_prompt,
132
- base_model,
133
- resolution,
134
- num_training_steps,
135
- learning_rate,
136
- gradient_accumulation,
137
- seed,
138
- fp16,
139
- use_8bit_adam,
140
- checkpointing_steps,
141
- validation_epochs,
142
- upload_to_hub,
143
- use_private_repo,
144
- delete_existing_repo,
145
- upload_to,
146
- pause_space_after_training,
147
- hf_token,
148
- ])
 
 
149
  return demo
150
 
151
 
152
- if __name__ == '__main__':
153
  trainer = Trainer()
154
  demo = create_training_demo(trainer)
155
  demo.queue(api_open=False, max_size=1).launch()
 
11
  from trainer import Trainer
12
 
13
 
14
+ def create_training_demo(
15
+ trainer: Trainer, pipe: InferencePipeline | None = None, disable_run_button: bool = False
16
+ ) -> gr.Blocks:
17
  def read_log() -> str:
18
  with open(trainer.log_file) as f:
19
  lines = f.readlines()
20
+ return "".join(lines[-10:])
21
 
22
  with gr.Blocks() as demo:
23
  with gr.Row():
24
  with gr.Column():
25
  with gr.Box():
26
+ gr.Markdown("Training Data")
27
+ training_video = gr.File(label="Training video")
28
+ training_prompt = gr.Textbox(label="Training prompt", max_lines=1, placeholder="A man is surfing")
29
+ gr.Markdown(
30
+ """
 
 
31
  - Upload a video and write a `Training Prompt` that describes the video.
32
+ """
33
+ )
34
 
35
  with gr.Column():
36
  with gr.Box():
37
+ gr.Markdown("Training Parameters")
38
  with gr.Row():
39
+ base_model = gr.Text(label="Base Model", value="CompVis/stable-diffusion-v1-4", max_lines=1)
40
+ resolution = gr.Dropdown(
41
+ choices=["512", "768"], value="512", label="Resolution", visible=False
42
+ )
 
 
 
 
43
 
44
+ hf_token = gr.Text(
45
+ label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None
46
+ )
47
+ with gr.Accordion(label="Advanced options", open=False):
48
+ num_training_steps = gr.Number(label="Number of Training Steps", value=300, precision=0)
49
+ learning_rate = gr.Number(label="Learning Rate", value=0.000035)
 
 
 
 
50
  gradient_accumulation = gr.Number(
51
+ label="Number of Gradient Accumulation", value=1, precision=0
52
+ )
53
+ seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, randomize=True, value=0)
54
+ fp16 = gr.Checkbox(label="FP16", value=True)
55
+ use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=False)
56
+ checkpointing_steps = gr.Number(label="Checkpointing Steps", value=1000, precision=0)
57
+ validation_epochs = gr.Number(label="Validation Epochs", value=100, precision=0)
58
+ gr.Markdown(
59
+ """
 
 
 
 
 
 
 
 
 
 
60
  - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
61
  - Expected time to train a model for 300 steps: ~20 minutes with T4
62
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
63
+ """
64
+ )
65
 
66
  with gr.Row():
67
  with gr.Column():
68
+ gr.Markdown("Output Model")
69
+ output_model_name = gr.Text(label="Name of your model", placeholder="The surfer man", max_lines=1)
 
 
70
  validation_prompt = gr.Text(
71
+ label="Validation Prompt", placeholder="prompt to test the model, e.g: a dog is surfing"
72
+ )
 
73
  with gr.Column():
74
+ gr.Markdown("Upload Settings")
75
  with gr.Row():
76
+ upload_to_hub = gr.Checkbox(label="Upload model to Hub", value=True)
77
+ use_private_repo = gr.Checkbox(label="Private", value=True)
78
+ delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False)
 
 
 
79
  upload_to = gr.Radio(
80
+ label="Upload to",
81
  choices=[_.value for _ in UploadTarget],
82
+ value=UploadTarget.MODEL_LIBRARY.value,
83
+ )
84
 
85
  pause_space_after_training = gr.Checkbox(
86
+ label="Pause this Space after training",
87
  value=False,
88
+ interactive=bool(os.getenv("SPACE_ID")),
89
+ visible=False,
90
+ )
91
+ run_button = gr.Button("Start Training", interactive=not disable_run_button)
92
 
93
  with gr.Box():
94
+ gr.Text(label="Log", value=read_log, lines=10, max_lines=10, every=1)
 
 
 
 
95
 
96
  if pipe is not None:
97
  run_button.click(fn=pipe.clear)
98
+ run_button.click(
99
+ fn=trainer.run,
100
+ inputs=[
101
+ training_video,
102
+ training_prompt,
103
+ output_model_name,
104
+ delete_existing_repo,
105
+ validation_prompt,
106
+ base_model,
107
+ resolution,
108
+ num_training_steps,
109
+ learning_rate,
110
+ gradient_accumulation,
111
+ seed,
112
+ fp16,
113
+ use_8bit_adam,
114
+ checkpointing_steps,
115
+ validation_epochs,
116
+ upload_to_hub,
117
+ use_private_repo,
118
+ delete_existing_repo,
119
+ upload_to,
120
+ pause_space_after_training,
121
+ hf_token,
122
+ ],
123
+ )
124
  return demo
125
 
126
 
127
+ if __name__ == "__main__":
128
  trainer = Trainer()
129
  demo = create_training_demo(trainer)
130
  demo.queue(api_open=False, max_size=1).launch()
app_upload.py CHANGED
@@ -21,49 +21,49 @@ def create_upload_demo(disable_run_button: bool = False) -> gr.Blocks:
21
 
22
  with gr.Blocks() as demo:
23
  with gr.Box():
24
- gr.Markdown('Local Models')
25
- reload_button = gr.Button('Reload Model List')
26
  model_dir = gr.Dropdown(
27
- label='Model names',
28
- choices=model_dirs,
29
- value=model_dirs[0] if model_dirs else None)
30
  with gr.Box():
31
- gr.Markdown('Upload Settings')
32
  with gr.Row():
33
- use_private_repo = gr.Checkbox(label='Private', value=True)
34
- delete_existing_repo = gr.Checkbox(
35
- label='Delete existing repo of the same name', value=False)
36
- upload_to = gr.Radio(label='Upload to',
37
- choices=[_.value for _ in UploadTarget],
38
- value=UploadTarget.MODEL_LIBRARY.value)
39
- model_name = gr.Textbox(label='Model Name')
40
- hf_token = gr.Text(label='Hugging Face Write Token',
41
- type='password',
42
- visible=os.getenv('HF_TOKEN') is None)
43
- upload_button = gr.Button('Upload', interactive=not disable_run_button)
44
- gr.Markdown(f'''
45
  - You can upload your trained model to your personal profile (i.e. `https://huggingface.co/{{your_username}}/{{model_name}}`) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. `https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}`).
46
- ''')
 
47
  with gr.Box():
48
- gr.Markdown('Output message')
49
  output_message = gr.Markdown()
50
 
51
- reload_button.click(fn=load_local_model_list,
52
- inputs=None,
53
- outputs=model_dir)
54
- upload_button.click(fn=upload,
55
- inputs=[
56
- model_dir,
57
- model_name,
58
- upload_to,
59
- use_private_repo,
60
- delete_existing_repo,
61
- hf_token,
62
- ],
63
- outputs=output_message)
64
  return demo
65
 
66
 
67
- if __name__ == '__main__':
68
  demo = create_upload_demo()
69
  demo.queue(api_open=False, max_size=1).launch()
 
21
 
22
  with gr.Blocks() as demo:
23
  with gr.Box():
24
+ gr.Markdown("Local Models")
25
+ reload_button = gr.Button("Reload Model List")
26
  model_dir = gr.Dropdown(
27
+ label="Model names", choices=model_dirs, value=model_dirs[0] if model_dirs else None
28
+ )
 
29
  with gr.Box():
30
+ gr.Markdown("Upload Settings")
31
  with gr.Row():
32
+ use_private_repo = gr.Checkbox(label="Private", value=True)
33
+ delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False)
34
+ upload_to = gr.Radio(
35
+ label="Upload to", choices=[_.value for _ in UploadTarget], value=UploadTarget.MODEL_LIBRARY.value
36
+ )
37
+ model_name = gr.Textbox(label="Model Name")
38
+ hf_token = gr.Text(
39
+ label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None
40
+ )
41
+ upload_button = gr.Button("Upload", interactive=not disable_run_button)
42
+ gr.Markdown(
43
+ f"""
44
  - You can upload your trained model to your personal profile (i.e. `https://huggingface.co/{{your_username}}/{{model_name}}`) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. `https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}`).
45
+ """
46
+ )
47
  with gr.Box():
48
+ gr.Markdown("Output message")
49
  output_message = gr.Markdown()
50
 
51
+ reload_button.click(fn=load_local_model_list, inputs=None, outputs=model_dir)
52
+ upload_button.click(
53
+ fn=upload,
54
+ inputs=[
55
+ model_dir,
56
+ model_name,
57
+ upload_to,
58
+ use_private_repo,
59
+ delete_existing_repo,
60
+ hf_token,
61
+ ],
62
+ outputs=output_message,
63
+ )
64
  return demo
65
 
66
 
67
+ if __name__ == "__main__":
68
  demo = create_upload_demo()
69
  demo.queue(api_open=False, max_size=1).launch()
constants.py CHANGED
@@ -2,10 +2,12 @@ import enum
2
 
3
 
4
  class UploadTarget(enum.Enum):
5
- PERSONAL_PROFILE = 'Personal Profile'
6
- MODEL_LIBRARY = 'Tune-A-Video Library'
7
 
8
 
9
- MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
- SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
11
- URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
 
 
 
2
 
3
 
4
  class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = "Personal Profile"
6
+ MODEL_LIBRARY = "Tune-A-Video Library"
7
 
8
 
9
+ MODEL_LIBRARY_ORG_NAME = "Tune-A-Video-library"
10
+ SAMPLE_MODEL_REPO = "Tune-A-Video-library/a-man-is-surfing"
11
+ URL_TO_JOIN_MODEL_LIBRARY_ORG = (
12
+ "https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk"
13
+ )
inference.py CHANGED
@@ -13,7 +13,7 @@ from diffusers.utils.import_utils import is_xformers_available
13
  from einops import rearrange
14
  from huggingface_hub import ModelCard
15
 
16
- sys.path.append('Tune-A-Video')
17
 
18
  from tuneavideo.models.unet import UNet3DConditionModel
19
  from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
@@ -23,8 +23,7 @@ class InferencePipeline:
23
  def __init__(self, hf_token: str | None = None):
24
  self.hf_token = hf_token
25
  self.pipe = None
26
- self.device = torch.device(
27
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
28
  self.model_id = None
29
 
30
  def clear(self) -> None:
@@ -39,10 +38,9 @@ class InferencePipeline:
39
  return pathlib.Path(model_id).exists()
40
 
41
  @staticmethod
42
- def get_model_card(model_id: str,
43
- hf_token: str | None = None) -> ModelCard:
44
  if InferencePipeline.check_if_model_is_local(model_id):
45
- card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
46
  else:
47
  card_path = model_id
48
  return ModelCard.load(card_path, token=hf_token)
@@ -57,14 +55,11 @@ class InferencePipeline:
57
  return
58
  base_model_id = self.get_base_model_info(model_id, self.hf_token)
59
  unet = UNet3DConditionModel.from_pretrained(
60
- model_id,
61
- subfolder='unet',
62
- torch_dtype=torch.float16,
63
- use_auth_token=self.hf_token)
64
- pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
65
- unet=unet,
66
- torch_dtype=torch.float16,
67
- use_auth_token=self.hf_token)
68
  pipe = pipe.to(self.device)
69
  if is_xformers_available():
70
  pipe.unet.enable_xformers_memory_efficient_attention()
@@ -82,7 +77,7 @@ class InferencePipeline:
82
  guidance_scale: float,
83
  ) -> PIL.Image.Image:
84
  if not torch.cuda.is_available():
85
- raise gr.Error('CUDA is not available.')
86
 
87
  self.load_pipe(model_id)
88
 
@@ -97,10 +92,10 @@ class InferencePipeline:
97
  generator=generator,
98
  ) # type: ignore
99
 
100
- frames = rearrange(out.videos[0], 'c t h w -> t h w c')
101
  frames = (frames * 255).to(torch.uint8).numpy()
102
 
103
- out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
104
  writer = imageio.get_writer(out_file.name, fps=fps)
105
  for frame in frames:
106
  writer.append_data(frame)
 
13
  from einops import rearrange
14
  from huggingface_hub import ModelCard
15
 
16
+ sys.path.append("Tune-A-Video")
17
 
18
  from tuneavideo.models.unet import UNet3DConditionModel
19
  from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
 
23
  def __init__(self, hf_token: str | None = None):
24
  self.hf_token = hf_token
25
  self.pipe = None
26
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
27
  self.model_id = None
28
 
29
  def clear(self) -> None:
 
38
  return pathlib.Path(model_id).exists()
39
 
40
  @staticmethod
41
+ def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
 
42
  if InferencePipeline.check_if_model_is_local(model_id):
43
+ card_path = (pathlib.Path(model_id) / "README.md").as_posix()
44
  else:
45
  card_path = model_id
46
  return ModelCard.load(card_path, token=hf_token)
 
55
  return
56
  base_model_id = self.get_base_model_info(model_id, self.hf_token)
57
  unet = UNet3DConditionModel.from_pretrained(
58
+ model_id, subfolder="unet", torch_dtype=torch.float16, use_auth_token=self.hf_token
59
+ )
60
+ pipe = TuneAVideoPipeline.from_pretrained(
61
+ base_model_id, unet=unet, torch_dtype=torch.float16, use_auth_token=self.hf_token
62
+ )
 
 
 
63
  pipe = pipe.to(self.device)
64
  if is_xformers_available():
65
  pipe.unet.enable_xformers_memory_efficient_attention()
 
77
  guidance_scale: float,
78
  ) -> PIL.Image.Image:
79
  if not torch.cuda.is_available():
80
+ raise gr.Error("CUDA is not available.")
81
 
82
  self.load_pipe(model_id)
83
 
 
92
  generator=generator,
93
  ) # type: ignore
94
 
95
+ frames = rearrange(out.videos[0], "c t h w -> t h w c")
96
  frames = (frames * 255).to(torch.uint8).numpy()
97
 
98
+ out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
99
  writer = imageio.get_writer(out_file.name, fps=fps)
100
  for frame in frames:
101
  writer.append_data(frame)
trainer.py CHANGED
@@ -16,26 +16,24 @@ from omegaconf import OmegaConf
16
  from uploader import upload
17
  from utils import save_model_card
18
 
19
- sys.path.append('Tune-A-Video')
20
 
21
 
22
  class Trainer:
23
  def __init__(self):
24
- self.checkpoint_dir = pathlib.Path('checkpoints')
25
  self.checkpoint_dir.mkdir(exist_ok=True)
26
 
27
- self.log_file = pathlib.Path('log.txt')
28
  self.log_file.touch(exist_ok=True)
29
 
30
  def download_base_model(self, base_model_id: str) -> str:
31
  model_dir = self.checkpoint_dir / base_model_id
32
  if not model_dir.exists():
33
- org_name = base_model_id.split('/')[0]
34
  org_dir = self.checkpoint_dir / org_name
35
  org_dir.mkdir(exist_ok=True)
36
- subprocess.run(shlex.split(
37
- f'git clone https://huggingface.co/{base_model_id}'),
38
- cwd=org_dir)
39
  return model_dir.as_posix()
40
 
41
  def run(
@@ -63,28 +61,28 @@ class Trainer:
63
  hf_token: str,
64
  ) -> None:
65
  if not torch.cuda.is_available():
66
- raise RuntimeError('CUDA is not available.')
67
  if training_video is None:
68
- raise ValueError('You need to upload a video.')
69
  if not training_prompt:
70
- raise ValueError('The training prompt is missing.')
71
  if not validation_prompt:
72
- raise ValueError('The validation prompt is missing.')
73
 
74
  resolution = int(resolution_s)
75
 
76
  if not output_model_name:
77
- timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
78
- output_model_name = f'tune-a-video-{timestamp}'
79
  output_model_name = slugify.slugify(output_model_name)
80
 
81
  repo_dir = pathlib.Path(__file__).parent
82
- output_dir = repo_dir / 'experiments' / output_model_name
83
  if overwrite_existing_model or upload_to_hub:
84
  shutil.rmtree(output_dir, ignore_errors=True)
85
  output_dir.mkdir(parents=True)
86
 
87
- config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
88
  config.pretrained_model_path = self.download_base_model(base_model)
89
  config.output_dir = output_dir.as_posix()
90
  config.train_data.video_path = training_video.name # type: ignore
@@ -107,39 +105,40 @@ class Trainer:
107
  config.checkpointing_steps = checkpointing_steps
108
  config.validation_steps = validation_epochs
109
  config.seed = seed
110
- config.mixed_precision = 'fp16' if fp16 else ''
111
  config.use_8bit_adam = use_8bit_adam
112
 
113
- config_path = output_dir / 'config.yaml'
114
- with open(config_path, 'w') as f:
115
  OmegaConf.save(config, f)
116
 
117
- command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
118
- with open(self.log_file, 'w') as f:
119
- subprocess.run(shlex.split(command),
120
- stdout=f,
121
- stderr=subprocess.STDOUT,
122
- text=True)
123
- save_model_card(save_dir=output_dir,
124
- base_model=base_model,
125
- training_prompt=training_prompt,
126
- test_prompt=validation_prompt,
127
- test_image_dir='samples')
128
-
129
- with open(self.log_file, 'a') as f:
130
- f.write('Training completed!\n')
131
 
132
  if upload_to_hub:
133
- upload_message = upload(local_folder_path=output_dir.as_posix(),
134
- target_repo_name=output_model_name,
135
- upload_to=upload_to,
136
- private=use_private_repo,
137
- delete_existing_repo=delete_existing_repo,
138
- hf_token=hf_token)
139
- with open(self.log_file, 'a') as f:
 
 
140
  f.write(upload_message)
141
 
142
  if pause_space_after_training:
143
- if space_id := os.getenv('SPACE_ID'):
144
- api = HfApi(token=os.getenv('HF_TOKEN') or hf_token)
145
  api.pause_space(repo_id=space_id)
 
16
  from uploader import upload
17
  from utils import save_model_card
18
 
19
+ sys.path.append("Tune-A-Video")
20
 
21
 
22
  class Trainer:
23
  def __init__(self):
24
+ self.checkpoint_dir = pathlib.Path("checkpoints")
25
  self.checkpoint_dir.mkdir(exist_ok=True)
26
 
27
+ self.log_file = pathlib.Path("log.txt")
28
  self.log_file.touch(exist_ok=True)
29
 
30
  def download_base_model(self, base_model_id: str) -> str:
31
  model_dir = self.checkpoint_dir / base_model_id
32
  if not model_dir.exists():
33
+ org_name = base_model_id.split("/")[0]
34
  org_dir = self.checkpoint_dir / org_name
35
  org_dir.mkdir(exist_ok=True)
36
+ subprocess.run(shlex.split(f"git clone https://huggingface.co/{base_model_id}"), cwd=org_dir)
 
 
37
  return model_dir.as_posix()
38
 
39
  def run(
 
61
  hf_token: str,
62
  ) -> None:
63
  if not torch.cuda.is_available():
64
+ raise RuntimeError("CUDA is not available.")
65
  if training_video is None:
66
+ raise ValueError("You need to upload a video.")
67
  if not training_prompt:
68
+ raise ValueError("The training prompt is missing.")
69
  if not validation_prompt:
70
+ raise ValueError("The validation prompt is missing.")
71
 
72
  resolution = int(resolution_s)
73
 
74
  if not output_model_name:
75
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
76
+ output_model_name = f"tune-a-video-{timestamp}"
77
  output_model_name = slugify.slugify(output_model_name)
78
 
79
  repo_dir = pathlib.Path(__file__).parent
80
+ output_dir = repo_dir / "experiments" / output_model_name
81
  if overwrite_existing_model or upload_to_hub:
82
  shutil.rmtree(output_dir, ignore_errors=True)
83
  output_dir.mkdir(parents=True)
84
 
85
+ config = OmegaConf.load("Tune-A-Video/configs/man-surfing.yaml")
86
  config.pretrained_model_path = self.download_base_model(base_model)
87
  config.output_dir = output_dir.as_posix()
88
  config.train_data.video_path = training_video.name # type: ignore
 
105
  config.checkpointing_steps = checkpointing_steps
106
  config.validation_steps = validation_epochs
107
  config.seed = seed
108
+ config.mixed_precision = "fp16" if fp16 else ""
109
  config.use_8bit_adam = use_8bit_adam
110
 
111
+ config_path = output_dir / "config.yaml"
112
+ with open(config_path, "w") as f:
113
  OmegaConf.save(config, f)
114
 
115
+ command = f"accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}"
116
+ with open(self.log_file, "w") as f:
117
+ subprocess.run(shlex.split(command), stdout=f, stderr=subprocess.STDOUT, text=True)
118
+ save_model_card(
119
+ save_dir=output_dir,
120
+ base_model=base_model,
121
+ training_prompt=training_prompt,
122
+ test_prompt=validation_prompt,
123
+ test_image_dir="samples",
124
+ )
125
+
126
+ with open(self.log_file, "a") as f:
127
+ f.write("Training completed!\n")
 
128
 
129
  if upload_to_hub:
130
+ upload_message = upload(
131
+ local_folder_path=output_dir.as_posix(),
132
+ target_repo_name=output_model_name,
133
+ upload_to=upload_to,
134
+ private=use_private_repo,
135
+ delete_existing_repo=delete_existing_repo,
136
+ hf_token=hf_token,
137
+ )
138
+ with open(self.log_file, "a") as f:
139
  f.write(upload_message)
140
 
141
  if pause_space_after_training:
142
+ if space_id := os.getenv("SPACE_ID"):
143
+ api = HfApi(token=os.getenv("HF_TOKEN") or hf_token)
144
  api.pause_space(repo_id=space_id)
uploader.py CHANGED
@@ -8,24 +8,30 @@ import subprocess
8
  import slugify
9
  from huggingface_hub import HfApi
10
 
11
- from constants import (MODEL_LIBRARY_ORG_NAME, URL_TO_JOIN_MODEL_LIBRARY_ORG,
12
- UploadTarget)
 
 
 
13
 
14
 
15
  def join_model_library_org(hf_token: str) -> None:
16
  subprocess.run(
17
  shlex.split(
18
  f'curl -X POST -H "Authorization: Bearer {hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
19
- ))
 
20
 
21
 
22
- def upload(local_folder_path: str,
23
- target_repo_name: str,
24
- upload_to: str,
25
- private: bool = True,
26
- delete_existing_repo: bool = False,
27
- hf_token: str = '') -> str:
28
- hf_token = os.getenv('HF_TOKEN') or hf_token
 
 
29
  if not hf_token:
30
  raise ValueError
31
  api = HfApi(token=hf_token)
@@ -37,27 +43,24 @@ def upload(local_folder_path: str,
37
  target_repo_name = slugify.slugify(target_repo_name)
38
 
39
  if upload_to == UploadTarget.PERSONAL_PROFILE.value:
40
- organization = api.whoami()['name']
41
  elif upload_to == UploadTarget.MODEL_LIBRARY.value:
42
  organization = MODEL_LIBRARY_ORG_NAME
43
  join_model_library_org(hf_token)
44
  else:
45
  raise ValueError
46
 
47
- repo_id = f'{organization}/{target_repo_name}'
48
  if delete_existing_repo:
49
  try:
50
- api.delete_repo(repo_id, repo_type='model')
51
  except Exception:
52
  pass
53
  try:
54
- api.create_repo(repo_id, repo_type='model', private=private)
55
- api.upload_folder(repo_id=repo_id,
56
- folder_path=local_folder_path,
57
- path_in_repo='.',
58
- repo_type='model')
59
- url = f'https://huggingface.co/{repo_id}'
60
- message = f'Your model was successfully uploaded to {url}.'
61
  except Exception as e:
62
  message = str(e)
63
  return message
 
8
  import slugify
9
  from huggingface_hub import HfApi
10
 
11
+ from constants import (
12
+ MODEL_LIBRARY_ORG_NAME,
13
+ URL_TO_JOIN_MODEL_LIBRARY_ORG,
14
+ UploadTarget,
15
+ )
16
 
17
 
18
  def join_model_library_org(hf_token: str) -> None:
19
  subprocess.run(
20
  shlex.split(
21
  f'curl -X POST -H "Authorization: Bearer {hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
22
+ )
23
+ )
24
 
25
 
26
+ def upload(
27
+ local_folder_path: str,
28
+ target_repo_name: str,
29
+ upload_to: str,
30
+ private: bool = True,
31
+ delete_existing_repo: bool = False,
32
+ hf_token: str = "",
33
+ ) -> str:
34
+ hf_token = os.getenv("HF_TOKEN") or hf_token
35
  if not hf_token:
36
  raise ValueError
37
  api = HfApi(token=hf_token)
 
43
  target_repo_name = slugify.slugify(target_repo_name)
44
 
45
  if upload_to == UploadTarget.PERSONAL_PROFILE.value:
46
+ organization = api.whoami()["name"]
47
  elif upload_to == UploadTarget.MODEL_LIBRARY.value:
48
  organization = MODEL_LIBRARY_ORG_NAME
49
  join_model_library_org(hf_token)
50
  else:
51
  raise ValueError
52
 
53
+ repo_id = f"{organization}/{target_repo_name}"
54
  if delete_existing_repo:
55
  try:
56
+ api.delete_repo(repo_id, repo_type="model")
57
  except Exception:
58
  pass
59
  try:
60
+ api.create_repo(repo_id, repo_type="model", private=private)
61
+ api.upload_folder(repo_id=repo_id, folder_path=local_folder_path, path_in_repo=".", repo_type="model")
62
+ url = f"https://huggingface.co/{repo_id}"
63
+ message = f"Your model was successfully uploaded to {url}."
 
 
 
64
  except Exception as e:
65
  message = str(e)
66
  return message
utils.py CHANGED
@@ -5,14 +5,11 @@ import pathlib
5
 
6
  def find_exp_dirs() -> list[str]:
7
  repo_dir = pathlib.Path(__file__).parent
8
- exp_root_dir = repo_dir / 'experiments'
9
  if not exp_root_dir.exists():
10
  return []
11
- exp_dirs = sorted(exp_root_dir.glob('*'))
12
- exp_dirs = [
13
- exp_dir for exp_dir in exp_dirs
14
- if (exp_dir / 'model_index.json').exists()
15
- ]
16
  return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
 
18
 
@@ -20,21 +17,21 @@ def save_model_card(
20
  save_dir: pathlib.Path,
21
  base_model: str,
22
  training_prompt: str,
23
- test_prompt: str = '',
24
- test_image_dir: str = '',
25
  ) -> None:
26
- image_str = ''
27
  if test_prompt and test_image_dir:
28
- image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
  if image_paths:
30
  image_path = image_paths[-1]
31
  rel_path = image_path.relative_to(save_dir)
32
- image_str = f'''## Samples
33
  Test prompt: {test_prompt}
34
 
35
- ![{image_path.stem}]({rel_path})'''
36
 
37
- model_card = f'''---
38
  license: creativeml-openrail-m
39
  base_model: {base_model}
40
  training_prompt: {training_prompt}
@@ -59,7 +56,7 @@ inference: false
59
  ## Related papers:
60
  - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
61
  - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
62
- '''
63
 
64
- with open(save_dir / 'README.md', 'w') as f:
65
  f.write(model_card)
 
5
 
6
  def find_exp_dirs() -> list[str]:
7
  repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / "experiments"
9
  if not exp_root_dir.exists():
10
  return []
11
+ exp_dirs = sorted(exp_root_dir.glob("*"))
12
+ exp_dirs = [exp_dir for exp_dir in exp_dirs if (exp_dir / "model_index.json").exists()]
 
 
 
13
  return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
14
 
15
 
 
17
  save_dir: pathlib.Path,
18
  base_model: str,
19
  training_prompt: str,
20
+ test_prompt: str = "",
21
+ test_image_dir: str = "",
22
  ) -> None:
23
+ image_str = ""
24
  if test_prompt and test_image_dir:
25
+ image_paths = sorted((save_dir / test_image_dir).glob("*.gif"))
26
  if image_paths:
27
  image_path = image_paths[-1]
28
  rel_path = image_path.relative_to(save_dir)
29
+ image_str = f"""## Samples
30
  Test prompt: {test_prompt}
31
 
32
+ ![{image_path.stem}]({rel_path})"""
33
 
34
+ model_card = f"""---
35
  license: creativeml-openrail-m
36
  base_model: {base_model}
37
  training_prompt: {training_prompt}
 
56
  ## Related papers:
57
  - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
58
  - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
59
+ """
60
 
61
+ with open(save_dir / "README.md", "w") as f:
62
  f.write(model_card)