main.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import warnings
  3. import tempfile
  4. from .utils.files import filename, write_srt
  5. from .utils.ffmpeg import get_audio, overlay_subtitles
  6. from .utils.whisper import WhisperAI
  7. def process(args: dict):
  8. model_name: str = args.pop("model")
  9. output_dir: str = args.pop("output_dir")
  10. output_srt: bool = args.pop("output_srt")
  11. srt_only: bool = args.pop("srt_only")
  12. language: str = args.pop("language")
  13. sample_interval: str = args.pop("sample_interval")
  14. device: str = args.pop("device")
  15. compute_type: str = args.pop("compute_type")
  16. os.makedirs(output_dir, exist_ok=True)
  17. if model_name.endswith(".en"):
  18. warnings.warn(
  19. f"{model_name} is an English-only model, forcing English detection.")
  20. args["language"] = "en"
  21. # if translate task used and language argument is set, then use it
  22. elif language != "auto":
  23. args["language"] = language
  24. audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
  25. subtitles = get_subtitles(
  26. audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
  27. )
  28. if srt_only:
  29. return
  30. overlay_subtitles(subtitles, output_dir, sample_interval)
  31. def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
  32. model = WhisperAI(model_name, device, compute_type, model_args)
  33. subtitles_path = {}
  34. for path, audio_path in audio_paths.items():
  35. print(
  36. f"Generating subtitles for {filename(path)}... This might take a while."
  37. )
  38. srt_path = output_dir if output_srt else tempfile.gettempdir()
  39. srt_path = os.path.join(srt_path, f"{filename(path)}.srt")
  40. segments = model.transcribe(audio_path)
  41. with open(srt_path, "w", encoding="utf-8") as srt:
  42. write_srt(segments, file=srt)
  43. subtitles_path[path] = srt_path
  44. return subtitles_path