From ec1e8e2c1909e7151b7d5cb18b66955851b621d5 Mon Sep 17 00:00:00 2001
From: Sergey Chernyaev
Date: Fri, 5 Jan 2024 17:48:37 +0100
Subject: [PATCH] Expose more model parameters

---
 auto_subtitle/cli.py           | 6 ++++++
 auto_subtitle/main.py          | 8 +++++---
 auto_subtitle/utils/whisper.py | 5 +++--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/auto_subtitle/cli.py b/auto_subtitle/cli.py
index 2e0eac8..25c5d97 100644
--- a/auto_subtitle/cli.py
+++ b/auto_subtitle/cli.py
@@ -14,6 +14,12 @@ def main():
                         help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
     parser.add_argument("--model", default="small",
                         choices=available_models(), help="name of the Whisper model to use")
+    parser.add_argument("--device", type=str, default="auto", choices=[
+                        "cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
+    parser.add_argument("--compute_type", type=str, default="default", choices=[
+                        "int8", "int8_float32", "int8_float16",
+                        "int8_bfloat16", "int16", "float16",
+                        "bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
     parser.add_argument("--output_dir", "-o", type=str,
                         default=".", help="directory to save the outputs")
     parser.add_argument("--output_srt", type=str2bool, default=False,
diff --git a/auto_subtitle/main.py b/auto_subtitle/main.py
index c9a14fb..26650f9 100644
--- a/auto_subtitle/main.py
+++ b/auto_subtitle/main.py
@@ -12,6 +12,8 @@ def process(args: dict):
     srt_only: bool = args.pop("srt_only")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
+    device: str = args.pop("device")
+    compute_type: str = args.pop("compute_type")
 
     os.makedirs(output_dir, exist_ok=True)
 
@@ -25,7 +27,7 @@ def process(args: dict):
     audios = get_audio(args.pop("video"), args.pop('audio_channel'),
                        sample_interval)
     subtitles = get_subtitles(
-        audios, output_srt or srt_only, output_dir, model_name, args
+        audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
     )
 
     if srt_only:
@@ -33,8 +35,8 @@ def process(args: dict):
 
     overlay_subtitles(subtitles, output_dir, sample_interval)
 
-def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, model_args: dict):
-    model = WhisperAI(model_name, model_args)
+def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
+    model = WhisperAI(model_name, device, compute_type, model_args)
 
     subtitles_path = {}
 
diff --git a/auto_subtitle/utils/whisper.py b/auto_subtitle/utils/whisper.py
index a4984e1..3bad24c 100644
--- a/auto_subtitle/utils/whisper.py
+++ b/auto_subtitle/utils/whisper.py
@@ -3,8 +3,8 @@ import faster_whisper
 from tqdm import tqdm
 
 class WhisperAI:
-    def __init__(self, model_name, model_args):
-        self.model = faster_whisper.WhisperModel(model_name, device="cuda", compute_type="float16")
+    def __init__(self, model_name, device, compute_type, model_args):
+        self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
         self.model_args = model_args
 
     def transcribe(self, audio_path):
@@ -18,3 +18,4 @@ class WhisperAI:
             for segment in segments:
                 yield segment
                 pbar.update(segment.end - segment.start)
+            pbar.update(0)
\ No newline at end of file