# diar_speech / app.py
# Aray Karjauv · 9a5fefb
# workaround "cannot import name 'Config' from 'omegaconf'"
import pip
# Install the pinned versions through `python -m pip` of the running
# interpreter: pip offers no supported in-process API, and `pip.main` is an
# internal entry point. Each pin is installed separately and a failed install
# is not fatal, matching the previous best-effort behaviour.
# NOTE: with this in place the `import pip` above is no longer used.
import subprocess
import sys

for _pinned_requirement in ("omegaconf==2.3.0", "pytorch-lightning==1.8.4"):
    subprocess.run(
        [sys.executable, "-m", "pip", "install", _pinned_requirement],
        check=False,
    )
# import os
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES']='1'
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import streamlit as st
import tempfile
import string
# file storage for streamlit < 1.11
from streamlit.in_memory_file_manager import in_memory_file_manager as file_mng
def run():
    """Process the uploaded video and render the player with its subtitles.

    Used as the ``on_click`` callback of the "Run" button. Reads the
    module-level ``video_file``, ``placeholder`` and ``progress_bar`` widgets,
    stores media through ``file_mng`` and renders the result with
    ``render_player``. Returns immediately when no file has been uploaded.
    """
    if video_file is None:
        return

    progress_bar.progress(1)
    placeholder.write("Downloading pre-trained model...")
    # Imported here rather than at module level, after the status message is
    # shown. NOTE(review): presumably importing backend is what triggers the
    # model download — confirm.
    from backend import (
        calc_speaker_percentage,
        get_speakers,
        get_subtitles,
        split_audio,
        timeline_to_vtt,
    )
    progress_bar.progress(15)

    # Register the uploaded bytes with Streamlit's file manager so the HTML
    # player can load the video by URL.
    # NOTE(review): this relies on the `in_memory_file_manager` import at the
    # top of the file; the links below point at the media file storage API
    # (`storage.load_and_get_id(...)` / `storage.get_url(...)`) found in other
    # Streamlit versions — confirm which one the deployed version provides.
    # https://docs.streamlit.io/knowledge-base/using-streamlit/where-file-uploader-store-when-deleted
    # https://github.com/streamlit/streamlit/blob/10ae0d651b18d4258e3b7cbbc9313d395a073768/lib/streamlit/elements/media.py#L204
    # https://github.com/streamlit/streamlit/blob/76d1aebccfb29e2ff6e7c2b23ef24eaa4ef5c59e/lib/streamlit/elements/media.py#L210
    # https://github.com/streamlit/streamlit/blob/9201a1980301a0cd62a5937982c410df08847a2f/lib/streamlit/runtime/media_file_storage.py
    # https://github.com/streamlit/streamlit/blob/10ae0d651b18d4258e3b7cbbc9313d395a073768/lib/streamlit/runtime/media_file_storage.py#L22
    # https://github.com/streamlit/streamlit/blob/10ae0d651b18d4258e3b7cbbc9313d395a073768/lib/streamlit/runtime/memory_media_file_storage.py#L105
    # https://github.com/streamlit/streamlit/pull/5072
    video_file.seek(0)
    video_url = file_mng.add(video_file.read(), video_file.type, "media").url

    with tempfile.TemporaryDirectory() as tmpdirname:
        # The upload is copied to a real .mp4 file only for the splitting
        # step; later steps receive the working directory, not this file.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as uploaded_fp:
            uploaded_fp.write(video_file.getbuffer())
            uploaded_fp.seek(0)
            # split audio to fit in memory
            duration = split_audio(tmpdirname, uploaded_fp)

        # Use the ACCESS_TOKEN secret when one is configured, otherwise True.
        use_auth_token = st.secrets.get("ACCESS_TOKEN") or True
        progress_bar.progress(50)

        placeholder.write("Diarisation...")
        # Called exactly once, with the auth token, and the result is kept.
        # NOTE(review): it also returns the cleaned audio path, so it
        # presumably covers noise removal as well — confirm against backend.
        speaker_diarisation, cleaned_path = get_speakers(tmpdirname, use_auth_token)
        progress_bar.progress(75)

        placeholder.write("Extracting subtitles...")
        timeline = get_subtitles(speaker_diarisation, cleaned_path)

    # Processing is finished: reset the progress bar and clear the status text.
    progress_bar.progress(0)
    placeholder.empty()

    vtt = timeline_to_vtt(timeline)
    percentages = calc_speaker_percentage(timeline, duration)
    print(vtt)

    # Store the subtitles alongside the video so the player can fetch them.
    vtt_url = file_mng.add(str.encode(vtt), "text/vtt", "downloadable").url
    print(vtt_url)

    st.markdown(render_player(video_url, vtt_url, percentages), unsafe_allow_html=True)
# Status widgets used by run(): a slot for the current step's message and a
# progress bar.
placeholder = st.empty()
progress_bar = st.progress(0)

# Load the HTML player template that render_player() fills in. It is read
# explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding.
with open("player.html", encoding="utf-8") as template_file:
    player_template = string.Template(template_file.read())
@st.cache
def render_player(video_url, vtt_url, percentages):
    """Return the player HTML with the given values substituted in.

    Fills the ``url``, ``vtt`` and ``percentages`` fields of the module-level
    ``player_template``. ``safe_substitute`` is used, so any other placeholder
    in the template is left as-is instead of raising.
    """
    template_fields = {
        "url": video_url,
        "vtt": vtt_url,
        "percentages": percentages,
    }
    return player_template.safe_substitute(template_fields)
# Upload widget; run() reads `video_file` and does nothing while it is None.
video_file = st.file_uploader(
    "Choose an MP4 file",
    type=["mp4"],
)

# Start processing when the button is clicked.
st.button(
    "Run",
    key="run",
    on_click=run,
)