main.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import os
  2. import warnings
  3. import tempfile
  4. from .utils.files import filename, write_srt
  5. from .utils.ffmpeg import get_audio, overlay_subtitles
  6. from .utils.whisper import WhisperAI
  7. def process(args: dict):
  8. model_name: str = args.pop("model")
  9. output_dir: str = args.pop("output_dir")
  10. output_srt: bool = args.pop("output_srt")
  11. srt_only: bool = args.pop("srt_only")
  12. language: str = args.pop("language")
  13. sample_interval: str = args.pop("sample_interval")
  14. os.makedirs(output_dir, exist_ok=True)
  15. if model_name.endswith(".en"):
  16. warnings.warn(
  17. f"{model_name} is an English-only model, forcing English detection.")
  18. args["language"] = "en"
  19. # if translate task used and language argument is set, then use it
  20. elif language != "auto":
  21. args["language"] = language
  22. audios = get_audio(args.pop("video"), args.pop(
  23. 'audio_channel'), sample_interval)
  24. model_args = {}
  25. model_args["model_size_or_path"] = model_name
  26. model_args["device"] = args.pop("device")
  27. model_args["compute_type"] = args.pop("compute_type")
  28. srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir()
  29. subtitles = get_subtitles(audios, srt_output_dir, model_args, args)
  30. if srt_only:
  31. return
  32. overlay_subtitles(subtitles, output_dir, sample_interval)
  33. def get_subtitles(audio_paths: list, output_dir: str,
  34. model_args: dict, transcribe_args: dict):
  35. model = WhisperAI(model_args, transcribe_args)
  36. subtitles_path = {}
  37. for path, audio_path in audio_paths.items():
  38. print(
  39. f"Generating subtitles for {filename(path)}... This might take a while."
  40. )
  41. srt_path = os.path.join(output_dir, f"{filename(path)}.srt")
  42. segments = model.transcribe(audio_path)
  43. with open(srt_path, "w", encoding="utf-8") as srt:
  44. write_srt(segments, file=srt)
  45. subtitles_path[path] = srt_path
  46. return subtitles_path