mirror of
https://github.com/karl0ss/bazarr-ai-sub-generator.git
synced 2025-04-26 14:59:21 +01:00
Expose more model parameters
This commit is contained in:
parent
d8a3d96f52
commit
ec1e8e2c19
@ -14,6 +14,12 @@ def main():
|
|||||||
help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
|
help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
|
||||||
parser.add_argument("--model", default="small",
|
parser.add_argument("--model", default="small",
|
||||||
choices=available_models(), help="name of the Whisper model to use")
|
choices=available_models(), help="name of the Whisper model to use")
|
||||||
|
parser.add_argument("--device", type=str, default="auto", choices=[
|
||||||
|
"cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
|
||||||
|
parser.add_argument("--compute_type", type=str, default="default", choices=[
|
||||||
|
"int8", "int8_float32", "int8_float16",
|
||||||
|
"int8_bfloat16", "int16", "float16",
|
||||||
|
"bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
|
||||||
parser.add_argument("--output_dir", "-o", type=str,
|
parser.add_argument("--output_dir", "-o", type=str,
|
||||||
default=".", help="directory to save the outputs")
|
default=".", help="directory to save the outputs")
|
||||||
parser.add_argument("--output_srt", type=str2bool, default=False,
|
parser.add_argument("--output_srt", type=str2bool, default=False,
|
||||||
|
@ -12,6 +12,8 @@ def process(args: dict):
|
|||||||
srt_only: bool = args.pop("srt_only")
|
srt_only: bool = args.pop("srt_only")
|
||||||
language: str = args.pop("language")
|
language: str = args.pop("language")
|
||||||
sample_interval: str = args.pop("sample_interval")
|
sample_interval: str = args.pop("sample_interval")
|
||||||
|
device: str = args.pop("device")
|
||||||
|
compute_type: str = args.pop("compute_type")
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
@ -25,7 +27,7 @@ def process(args: dict):
|
|||||||
|
|
||||||
audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
|
audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
|
||||||
subtitles = get_subtitles(
|
subtitles = get_subtitles(
|
||||||
audios, output_srt or srt_only, output_dir, model_name, args
|
audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
|
||||||
)
|
)
|
||||||
|
|
||||||
if srt_only:
|
if srt_only:
|
||||||
@ -33,8 +35,8 @@ def process(args: dict):
|
|||||||
|
|
||||||
overlay_subtitles(subtitles, output_dir, sample_interval)
|
overlay_subtitles(subtitles, output_dir, sample_interval)
|
||||||
|
|
||||||
def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, model_args: dict):
|
def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
|
||||||
model = WhisperAI(model_name, model_args)
|
model = WhisperAI(model_name, device, compute_type, model_args)
|
||||||
|
|
||||||
subtitles_path = {}
|
subtitles_path = {}
|
||||||
|
|
||||||
|
@ -3,8 +3,8 @@ import faster_whisper
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
class WhisperAI:
|
class WhisperAI:
|
||||||
def __init__(self, model_name, model_args):
|
def __init__(self, model_name, device, compute_type, model_args):
|
||||||
self.model = faster_whisper.WhisperModel(model_name, device="cuda", compute_type="float16")
|
self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
|
||||||
self.model_args = model_args
|
self.model_args = model_args
|
||||||
|
|
||||||
def transcribe(self, audio_path):
|
def transcribe(self, audio_path):
|
||||||
@ -18,3 +18,4 @@ class WhisperAI:
|
|||||||
for segment in segments:
|
for segment in segments:
|
||||||
yield segment
|
yield segment
|
||||||
pbar.update(segment.end - segment.start)
|
pbar.update(segment.end - segment.start)
|
||||||
|
pbar.update(0)
|
Loading…
x
Reference in New Issue
Block a user