# cli.py (3.9 KB)

  1. import os
  2. import ffmpeg
  3. import whisper
  4. import argparse
  5. import warnings
  6. import tempfile
  7. from .utils import filename, str2bool, write_srt
  8. def main():
  9. parser = argparse.ArgumentParser(
  10. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  11. parser.add_argument("video", nargs="+", type=str,
  12. help="paths to video files to transcribe")
  13. parser.add_argument("--model", default="small",
  14. choices=whisper.available_models(), help="name of the Whisper model to use")
  15. parser.add_argument("--output_dir", "-o", type=str,
  16. default=".", help="directory to save the outputs")
  17. parser.add_argument("--output_srt", type=str2bool, default=False,
  18. help="whether to output the .srt file along with the video files")
  19. parser.add_argument("--srt_only", type=str2bool, default=False,
  20. help="only generate the .srt file and not create overlayed video")
  21. parser.add_argument("--verbose", type=str2bool, default=False,
  22. help="whether to print out the progress and debug messages")
  23. parser.add_argument("--task", type=str, default="transcribe", choices=[
  24. "transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
  25. args = parser.parse_args().__dict__
  26. model_name: str = args.pop("model")
  27. output_dir: str = args.pop("output_dir")
  28. output_srt: bool = args.pop("output_srt")
  29. srt_only: bool = args.pop("srt_only")
  30. os.makedirs(output_dir, exist_ok=True)
  31. if model_name.endswith(".en"):
  32. warnings.warn(
  33. f"{model_name} is an English-only model, forcing English detection.")
  34. args["language"] = "en"
  35. model = whisper.load_model(model_name)
  36. audios = get_audio(args.pop("video"))
  37. subtitles = get_subtitles(
  38. audios, output_srt or srt_only, output_dir, lambda audio_path: model.transcribe(audio_path, **args)
  39. )
  40. if srt_only:
  41. return
  42. # bash command to download a youtube video with `youtube-dl` and save it as `video.mp4`:
  43. # youtube-dl -f 22 -o video.mp4 https://www.youtube.com/watch?v=QH2-TGUlwu4
  44. for path, srt_path in subtitles.items():
  45. out_path = os.path.join(output_dir, f"{filename(path)}.mp4")
  46. print(f"Adding subtitles to {filename(path)}...")
  47. video = ffmpeg.input(path)
  48. audio = video.audio
  49. stderr = ffmpeg.concat(
  50. video.filter('subtitles', srt_path, force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
  51. ).output(out_path).run(quiet=True, overwrite_output=True)
  52. print(f"Saved subtitled video to {os.path.abspath(out_path)}.")
  53. def get_audio(paths):
  54. temp_dir = tempfile.gettempdir()
  55. audio_paths = {}
  56. for path in paths:
  57. print(f"Extracting audio from {filename(path)}...")
  58. output_path = os.path.join(temp_dir, f"{filename(path)}.wav")
  59. ffmpeg.input(path).output(
  60. output_path,
  61. acodec="pcm_s16le", ac=1, ar="16k"
  62. ).run(quiet=True, overwrite_output=True)
  63. audio_paths[path] = output_path
  64. return audio_paths
  65. def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, transcribe: callable):
  66. srt_path = output_dir if output_srt else tempfile.gettempdir()
  67. subtitles_path = {}
  68. for path, audio_path in audio_paths.items():
  69. srt_path = os.path.join(srt_path, f"{filename(path)}.srt")
  70. print(
  71. f"Generating subtitles for {filename(path)}... This might take a while."
  72. )
  73. warnings.filterwarnings("ignore")
  74. result = transcribe(audio_path)
  75. warnings.filterwarnings("default")
  76. with open(srt_path, "w", encoding="utf-8") as srt:
  77. write_srt(result["segments"], file=srt)
  78. subtitles_path[path] = srt_path
  79. return subtitles_path
  80. if __name__ == '__main__':
  81. main()