@@ -12,6 +12,8 @@ def process(args: dict):
     srt_only: bool = args.pop("srt_only")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
+    device: str = args.pop("device")
+    compute_type: str = args.pop("compute_type")
 
     os.makedirs(output_dir, exist_ok=True)
 
@@ -25,7 +27,7 @@ def process(args: dict):
 
     audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
     subtitles = get_subtitles(
-        audios, output_srt or srt_only, output_dir, model_name, args
+        audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
     )
 
     if srt_only:
@@ -33,8 +35,8 @@ def process(args: dict):
 
     overlay_subtitles(subtitles, output_dir, sample_interval)
 
 
-def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, model_args: dict):
-    model = WhisperAI(model_name, model_args)
+def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
+    model = WhisperAI(model_name, device, compute_type, model_args)
 
     subtitles_path = {}
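
For context, a minimal sketch of the receiving side, assuming WhisperAI wraps faster-whisper's WhisperModel (which accepts device and compute_type keyword arguments). Only the WhisperAI(model_name, device, compute_type, model_args) call order comes from the diff above; the class body, attribute names, and the transcribe wrapper are illustrative, not the project's actual implementation.

# Hypothetical sketch, not part of the patch.
from faster_whisper import WhisperModel

class WhisperAI:
    def __init__(self, model_name: str, device: str, compute_type: str, model_args: dict):
        # device: e.g. "cpu", "cuda", or "auto"; compute_type: e.g. "int8", "float16".
        self.model = WhisperModel(model_name, device=device, compute_type=compute_type)
        # Remaining CLI options, forwarded to transcribe() (assumed convention).
        self.model_args = model_args

    def transcribe(self, audio_path: str):
        # faster-whisper returns a generator of segments plus transcription info.
        segments, info = self.model.transcribe(audio_path, **self.model_args)
        return segments, info

Given the args.pop("device") and args.pop("compute_type") keys, the CLI presumably gains flags along the lines of --device cuda --compute_type float16; the exact flag spellings are an assumption, since the argparse setup is not part of this hunk.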