
Expose more model parameters

Sergey Chernyaev, 1 year ago
parent
commit
ec1e8e2c19
3 changed files with 14 additions and 5 deletions
  1. auto_subtitle/cli.py (+6 -0)
  2. auto_subtitle/main.py (+5 -3)
  3. auto_subtitle/utils/whisper.py (+3 -2)

+ 6 - 0
auto_subtitle/cli.py

@@ -14,6 +14,12 @@ def main():
                         help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
     parser.add_argument("--model", default="small",
                         choices=available_models(), help="name of the Whisper model to use")
+    parser.add_argument("--device", type=str, default="auto", choices=[
+                        "cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
+    parser.add_argument("--compute_type", type=str, default="default", choices=[
+                        "int8", "int8_float32", "int8_float16",
+                        "int8_bfloat16", "int16", "float16",
+                        "bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
     parser.add_argument("--output_dir", "-o", type=str,
                         default=".", help="directory to save the outputs")
     parser.add_argument("--output_srt", type=str2bool, default=False,

+ 5 - 3
auto_subtitle/main.py

@@ -12,6 +12,8 @@ def process(args: dict):
     srt_only: bool = args.pop("srt_only")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
+    device: str = args.pop("device")
+    compute_type: str = args.pop("compute_type")
     
     os.makedirs(output_dir, exist_ok=True)
 
@@ -25,7 +27,7 @@ def process(args: dict):
 
     audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
     subtitles = get_subtitles(
-        audios, output_srt or srt_only, output_dir, model_name, args
+        audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
     )
 
     if srt_only:
@@ -33,8 +35,8 @@ def process(args: dict):
 
     overlay_subtitles(subtitles, output_dir, sample_interval)
 
-def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, model_args: dict):
-    model = WhisperAI(model_name, model_args)
+def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
+    model = WhisperAI(model_name, device, compute_type, model_args)
 
     subtitles_path = {}
 

+ 3 - 2
auto_subtitle/utils/whisper.py

@@ -3,8 +3,8 @@ import faster_whisper
 from tqdm import tqdm
 
 class WhisperAI:
-    def __init__(self, model_name, model_args):
-        self.model = faster_whisper.WhisperModel(model_name, device="cuda", compute_type="float16")
+    def __init__(self, model_name, device, compute_type, model_args):
+        self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
         self.model_args = model_args
 
     def transcribe(self, audio_path):
@@ -18,3 +18,4 @@ class WhisperAI:
             for segment in segments:
                 yield segment
                 pbar.update(segment.end - segment.start)
+            pbar.update(0)
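
The trailing pbar.update(0) after the loop presumably nudges tqdm into one last refresh once the segment generator is exhausted, without advancing the counter. For context, the constructor change maps directly onto faster-whisper's public API; a self-contained sketch of what the exposed parameters allow (model size, file name, and values are illustrative):

    import faster_whisper

    # Before this commit the model was pinned to device="cuda",
    # compute_type="float16"; both are now caller-controlled, e.g. CPU
    # inference with int8 quantization:
    model = faster_whisper.WhisperModel("small", device="cpu",
                                        compute_type="int8")

    # transcribe() returns a lazy generator of segments plus a metadata object.
    segments, info = model.transcribe("audio.wav")
    for segment in segments:
        print(f"[{segment.start:6.2f} -> {segment.end:6.2f}] {segment.text}")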