KevinGeng commited on
Commit
11dde70
1 Parent(s): 6b41fa2

support google drive cloud service for future

Browse files
app.py CHANGED
@@ -21,6 +21,23 @@ import librosa.display
21
  import matplotlib.pyplot as plt
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # local import
25
  import sys
26
 
@@ -30,7 +47,6 @@ import lightning_module
30
  # Load automos
31
  # config_yaml = sys.argv[1]
32
  config_yaml = "config/Arthur.yaml"
33
-
34
  with open(config_yaml, "r") as f:
35
  # pdb.set_trace()
36
  try:
@@ -40,9 +56,9 @@ with open(config_yaml, "r") as f:
40
  exit()
41
 
42
  # Auto load examples
43
-
44
- with open(config["ref_txt"], "r") as f:
45
  refs = f.readlines()
 
46
  refs_ids = [x.split()[0] for x in refs]
47
  refs_txt = [" ".join(x.split()[1:]) for x in refs]
48
  ref_feature = np.loadtxt(config["ref_feature"], delimiter=",", dtype="str")
@@ -112,7 +128,7 @@ class ChangeSampleRate(nn.Module):
112
 
113
  # MOS model
114
  model = lightning_module.BaselineLightningModule.load_from_checkpoint(
115
- "./src/epoch=3-step=7459.ckpt"
116
  ).eval()
117
 
118
  # Get Speech Interval
@@ -138,11 +154,15 @@ def plot_UV(signal, audio_interv, sr):
138
  ax[1].set_ylim([-0.1, 1.1])
139
  return fig
140
 
 
 
 
 
141
  def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
142
  if audio_path == None:
143
  audio_path = _
144
  print("using ref audio as eval audio since it's empty")
145
-
146
  wav, sr = torchaudio.load(audio_path)
147
  if wav.shape[0] != 1:
148
  wav = wav[0, :]
@@ -214,6 +234,9 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
214
  "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
215
  )
216
 
 
 
 
217
  return (
218
  fig_h,
219
  predic_mos,
@@ -297,10 +320,11 @@ info = gr.Interface(
297
  if config["exp_id"] == None:
298
  config["exp_id"] = Path(config_yaml).stem
299
 
300
- ## This is the theme for the interface
301
  css = """
302
  .ref_text textarea {font-size: 40px !important}
303
  .message textarea {font-size: 40px !important}
 
304
  """
305
 
306
  my_theme = gr.themes.Default().set(
@@ -313,6 +337,50 @@ my_theme = gr.themes.Default().set(
313
  # Callback for saving the recording
314
  callback = gr.CSVLogger()
315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  with gr.Blocks(css=css, theme=my_theme) as demo:
317
  with gr.Column():
318
  with gr.Row():
@@ -450,6 +518,7 @@ with gr.Blocks(css=css, theme=my_theme) as demo:
450
  preprocess=False,
451
  api_name="flagging",
452
  )
 
453
  with gr.Row():
454
  b3 = gr.ClearButton(
455
  [
 
21
  import matplotlib.pyplot as plt
22
 
23
 
24
+ # Google cloud service
25
+ from googleapiclient.discovery import build
26
+ from google.oauth2 import service_account
27
+ from googleapiclient.http import MediaFileUpload
28
+ import datetime
29
+
30
+ # 来自Google Cloud控制台的JSON凭据文件
31
+ credentials_file = "./src/peerless-window-254907-b386b71c0d99.json"
32
+ # "./client_secret_576367903492-diuopf97kn9eh1gte3vh65errtca1o64.apps.googleusercontent.com.json"
33
+ # Google Drive API版本
34
+ api_version = 'v3'
35
+
36
+ # 创建服务对象
37
+ credentials = service_account.Credentials.from_service_account_file(
38
+ credentials_file, scopes=['https://www.googleapis.com/auth/drive'])
39
+ service = build('drive', api_version, credentials=credentials)
40
+
41
  # local import
42
  import sys
43
 
 
47
  # Load automos
48
  # config_yaml = sys.argv[1]
49
  config_yaml = "config/Arthur.yaml"
 
50
  with open(config_yaml, "r") as f:
51
  # pdb.set_trace()
52
  try:
 
56
  exit()
57
 
58
  # Auto load examples
59
+ with open(config['ref_txt'], "r") as f:
 
60
  refs = f.readlines()
61
+ # refs = np.loadtxt(config["ref_txt"], delimiter="\n", dtype="str")
62
  refs_ids = [x.split()[0] for x in refs]
63
  refs_txt = [" ".join(x.split()[1:]) for x in refs]
64
  ref_feature = np.loadtxt(config["ref_feature"], delimiter=",", dtype="str")
 
128
 
129
  # MOS model
130
  model = lightning_module.BaselineLightningModule.load_from_checkpoint(
131
+ "src/epoch=3-step=7459.ckpt"
132
  ).eval()
133
 
134
  # Get Speech Interval
 
154
  ax[1].set_ylim([-0.1, 1.1])
155
  return fig
156
 
157
+
158
+ # Evaluation model
159
+
160
+
161
  def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
162
  if audio_path == None:
163
  audio_path = _
164
  print("using ref audio as eval audio since it's empty")
165
+
166
  wav, sr = torchaudio.load(audio_path)
167
  if wav.shape[0] != 1:
168
  wav = wav[0, :]
 
234
  "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
235
  )
236
 
237
+ # Google Drive saving # TODO
238
+ click_google_saving(audio_path)
239
+
240
  return (
241
  fig_h,
242
  predic_mos,
 
320
  if config["exp_id"] == None:
321
  config["exp_id"] = Path(config_yaml).stem
322
 
323
+ ## Theme
324
  css = """
325
  .ref_text textarea {font-size: 40px !important}
326
  .message textarea {font-size: 40px !important}
327
+
328
  """
329
 
330
  my_theme = gr.themes.Default().set(
 
337
  # Callback for saving the recording
338
  callback = gr.CSVLogger()
339
 
340
+ def generate_now_time_wav():
341
+ # Get the current date and time
342
+ current_time = datetime.datetime.now()
343
+
344
+ # Format the date and time as a string
345
+ time_string = current_time.strftime("%Y-%m-%d_%H-%M-%S")
346
+
347
+ # Create the WAV file name with the formatted time
348
+ wavfile_name = f"audio_{time_string}.wav"
349
+ return wavfile_name
350
+
351
+ # Add google drive cloud saving
352
+ def click_google_saving(audio_file,
353
+
354
+ ):
355
+ # reference_id,
356
+ # reference_textbox,
357
+ # reference_PPM,
358
+ # predict_mos,
359
+ # hyp,
360
+ # wer,
361
+ # ppm,
362
+ # msg,
363
+ name = generate_now_time_wav()
364
+ # 上传文件
365
+ media = MediaFileUpload(audio_file, mimetype='audio/wav')
366
+
367
+ request = service.files().create(
368
+ media_body=media,
369
+ body={'name': name,
370
+ }
371
+ )
372
+ # 'reference_id': reference_id,
373
+ # "reference_textbox": reference_textbox,
374
+ # "reference_PPM": reference_PPM,
375
+ # "predict_mos": predict_mos,
376
+ # "hyp": hyp,
377
+ # "wer": wer,
378
+ # "ppm": ppm,
379
+ # "msg": msg
380
+ response = request.execute()
381
+ # return response.get('id')
382
+
383
+
384
  with gr.Blocks(css=css, theme=my_theme) as demo:
385
  with gr.Column():
386
  with gr.Row():
 
518
  preprocess=False,
519
  api_name="flagging",
520
  )
521
+
522
  with gr.Row():
523
  b3 = gr.ClearButton(
524
  [
local/check_data.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from googleapiclient.discovery import build
2
+ from google.oauth2 import service_account
3
+ from googleapiclient.http import MediaFileUpload
4
+ import pdb
5
+ pdb.set_trace()
6
+
7
+ import gradio as gr
8
+
9
+ # 来自Google Cloud控制台的JSON凭据文件
10
+ credentials_file = "./src/peerless-window-254907-b386b71c0d99.json"
11
+ api_version = 'v3'
12
+
13
+ # 创建服务对象
14
+ credentials = service_account.Credentials.from_service_account_file(
15
+ credentials_file, scopes=['https://www.googleapis.com/auth/drive'])
16
+ service = build('drive', api_version, credentials=credentials)
17
+
18
+ # 列出文件
19
+ results = service.files().list().execute()
20
+ files = results.get('files', [])
21
+
22
+ print(files)
23
+ from googleapiclient.http import MediaIoBaseDownload
24
+ import io
25
+
26
+ file_id = "1EqHciegNxZSyWJ9Nizo1QmRQEgTkgWCo"
27
+ # Get the file's metadata
28
+ file = service.files().get(fileId=file_id).execute()
29
+
30
+ pdb.set_trace()
31
+ request = service.files().get_media(fileId="1EqHciegNxZSyWJ9Nizo1QmRQEgTkgWCo")
32
+ with open(file['name'], 'wb') as file_obj:
33
+ downloader = MediaIoBaseDownload(file_obj, request)
34
+ done = False
35
+ while not done:
36
+ status, done = downloader.next_chunk()
37
+ print(f"Download {int(status.progress() * 100)}%.")
38
+
39
+ print(f"Downloaded: {file['name']}")
40
+
41
+ pdb.set_trace()
42
+
43
+ # print('文件ID:%s' % response.get('id'))
local/test_google_drive.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from googleapiclient.discovery import build
2
+ from google.oauth2 import service_account
3
+ from googleapiclient.http import MediaFileUpload
4
+ import pdb
5
+
6
+ import gradio as gr
7
+
8
+ # 来自Google Cloud控制台的JSON凭据文件
9
+ credentials_file = "./src/peerless-window-254907-b386b71c0d99.json"
10
+ # "./client_secret_576367903492-diuopf97kn9eh1gte3vh65errtca1o64.apps.googleusercontent.com.json"
11
+ # Google Drive API版本
12
+ api_version = 'v3'
13
+
14
+ # 创建服务对象
15
+ credentials = service_account.Credentials.from_service_account_file(
16
+ credentials_file, scopes=['https://www.googleapis.com/auth/drive'])
17
+ service = build('drive', api_version, credentials=credentials)
18
+
19
+
20
+ import gradio as gr
21
+ from transformers import pipeline
22
+ import numpy as np
23
+ import librosa
24
+ import torchaudio
25
+
26
+ import datetime
27
+
28
+ def generate_now_time_wav():
29
+ # Get the current date and time
30
+ current_time = datetime.datetime.now()
31
+
32
+ # Format the date and time as a string
33
+ time_string = current_time.strftime("%Y-%m-%d_%H-%M-%S")
34
+
35
+ # Create the WAV file name with the formatted time
36
+ wavfile_name = f"audio_{time_string}.wav"
37
+ return wavfile_name
38
+
39
+ # transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
40
+
41
+ def transcribe(audio_path):
42
+ if audio_path == None:
43
+ print("using ref audio as eval audio since it's empty")
44
+
45
+ wav, sr = torchaudio.load(audio_path)
46
+ if wav.shape[0] != 1:
47
+ wav = wav[0, :]
48
+ print(wav.shape)
49
+
50
+ name = generate_now_time_wav()
51
+ # 上传文件
52
+ media = MediaFileUpload(audio_path, mimetype='audio/wav')
53
+ request = service.files().create(
54
+ media_body=media,
55
+ body={'name': name}
56
+ )
57
+ response = request.execute()
58
+
59
+ return response.get('id')
60
+
61
+ demo = gr.Interface(
62
+ fn = transcribe,
63
+ inputs = gr.Audio(source="microphone", type='filepath'),
64
+ outputs = "text",
65
+ )
66
+ # file_path = 'data/3_michael_20230619_100/1st_session_ZOOM0015_002.wav'
67
+
68
+ # x = gr.Audio(source="upload", type='filepath'),
69
+ # pdb.set_trace()
70
+ # x = transcribe(file_path)
71
+ # pdb.set_trace()
72
+
73
+ demo.launch()
74
+
75
+ # # 要上传的文件
76
+ # file_name = '1st_session_ZOOM0015_001.wav'
77
+
78
+ # # 上传文件
79
+ # media = MediaFileUpload(file_path, mimetype='audio/wav')
80
+ # request = service.files().create(
81
+ # media_body=media,
82
+ # body={'name': file_name}
83
+ # )
84
+
85
+
86
+ # response = request.execute()
87
+
88
+ # # 列出文件
89
+ # results = service.files().list().execute()
90
+ # files = results.get('files', [])
91
+ # pdb.set_trace()
92
+
93
+ # print('文件ID:%s' % response.get('id'))
src/peerless-window-254907-b386b71c0d99.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "service_account",
3
+ "project_id": "peerless-window-254907",
4
+ "private_key_id": "b386b71c0d998879b5e47d776fba764d549a0696",
5
+ "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDSXL/Qf5fLbDyE\nDQxlJC/nJyIdcayfuYn2agTKm+9h1jitfctwlkIHtP7nvz+l692InGzFV+wxbXg0\nrwrgvL1blHE/CP6I1l7nQRcorgAOFiR6/BNb+nBVXoriHWD6kHxjfLfVMTzzqrK8\nWPGUWtLjykZpbuvscO4+Sdu+7Rgaw46+H1vKSWtoaMsAYBgpsh6uQZU7xB51zR4D\nHKDR5uihj1qfaf2k3FslGu9r0U/OHZ6c9je9yx0ttTTByJVB8JAmpSG8sAy+2BZI\nZ9bHDMiOg/CCdkLZds29cewS0RrqHwNuv1sKL7Ap7aCz98Q3jjlESWATST4x00yg\njho3wF5hAgMBAAECggEAAvD5vrJgydoW7IaEy+M8mtib9hTlAVrhM0zfMPioqAMM\nXZjzVelSFlcdfcYeczVE84NQaAddV5VGc/XR1MV0+M5pu2krg8bUe0JJsUNEB1Da\n8VdHMFNkOsfPNY0CdMcMe7xl4cf3RfDFzO5O01fwENxNwVlo9hK4d5q5Tvd0F1P9\n8X8AllWAYHfD33scX3OxEoyF99Ow9jgaH7Lapb6Z77GISBjYZZxIFhoEhFsjx4It\n4Cnci3upw1QBD9Wh4+8DzNMoGUBj/ZaDMRpFLwkDPXRD5dvx0bCgkLSM80E3q/AB\nq/Ca6/Bx8z42k/c1BPEr/qJ+kPFYPGVOnX/9AyH0uQKBgQDuJ/yiIxwJQANqZp8d\nMwEIpQh1fGTA+LrOeoanX/6iYjU2nrNiQKYW09snfORzSYuwj/Pb9fR/KiNJXS2g\nQ6QZUE7eG8dVEDnlTL55beGk4OB20jc2xGz0u5jDCXJ/rU9OC9VOe9E/Pu/x5Ipe\nIimpHfU2RysPBH+BpM1iyBjoDQKBgQDiH6eg5jcJvCHusDynQNB3KFBfGVmJM+xR\nM3LRFKK0IS2ZR90TajYofPlK80lyFEvUXEX+cGma0zqPnzVEkDBSelmo1EtzRksx\no3oisSBGQ9d4BT2JPBnRlhNdl1QuzGwln09TyH5ielDo907zm4MuvFfNrSd4xxkU\nPjxKpCyGpQKBgA+eW7keaFZK9m5h8IlvsN+qQxXBZLIrHcUwz+fmKcLogejlG4qU\nBtB0cGj0jd7psdmQd0Ozq6czUkEbdUSPaxGl7KYwWDBB8ioRkGRSSnwPq2jffHOB\nCkw6iVgxJGsvKIZLzF9rS1vEeuP4QwLNZsIKjuxSWoaPmvUbo8SYrtl5AoGBANGM\nXi6ISUbXNmbYoUypjsZt8JVAi63PFVdmoydIxULCYFxksWX1jnzU27zuWgjC8Ea6\nwA57pBHbX7CK7LU+HdnBEmeXXNhVswcsJNoTZQJYikvqJ02PCaolNosL2vKHdE0l\nJkFRUnX2Pha2YE72tYnQ9lle9m5Bq2cMCZluLOkVAoGAYfbHFACSa+ejAIcyYdW6\ncmQllLAxz8f35NJLg53+tWZvfAIyMBTY/eLJFb5X4gUA/1/PtBgTXAQOUHnK4HKw\nOkECQMes/HYWWD/mw4DYrPeeOcqBxP3b0eEOw1mbFwmigC4tRTLnD1cDc8zS2zdM\nIBCSOPWoWHqArBPZjzFDpoA=\n-----END PRIVATE KEY-----\n",
6
+ "client_email": "gradiotest@peerless-window-254907.iam.gserviceaccount.com",
7
+ "client_id": "100559289389957446034",
8
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
9
+ "token_uri": "https://oauth2.googleapis.com/token",
10
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
11
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/gradiotest%40peerless-window-254907.iam.gserviceaccount.com",
12
+ "universe_domain": "googleapis.com"
13
+ }