From fde1b4d89ef404fd6bd0b326e6d861e9ab174303 Mon Sep 17 00:00:00 2001 From: Karl Date: Sat, 13 Jul 2024 09:22:02 +0000 Subject: [PATCH] add cuda deps --- .vscode/launch.json | 4 +- bazarr-ai-sub-generator/main.py | 14 ++++++- bazarr-ai-sub-generator/utils/files.py | 6 +-- bazarr-ai-sub-generator/utils/whisper.py | 51 ++++++++++++++++-------- requirements.txt | 3 +- 5 files changed, 55 insertions(+), 23 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 94c9cc5..a5131f3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,8 +5,8 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Current File", - "type": "python", + "name": "Python Debugger: Current File", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", diff --git a/bazarr-ai-sub-generator/main.py b/bazarr-ai-sub-generator/main.py index 64074a8..d5202fa 100644 --- a/bazarr-ai-sub-generator/main.py +++ b/bazarr-ai-sub-generator/main.py @@ -8,8 +8,20 @@ from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series from utils.sonarr import update_show_in_sonarr from utils.whisper import WhisperAI +def measure_time(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + duration = end_time - start_time + print(f"Function '{func.__name__}' executed in: {duration:.6f} seconds") + return result + return wrapper + + def process(args: dict): + model_name: str = args.pop("model") language: str = args.pop("language") sample_interval: str = args.pop("sample_interval") @@ -44,7 +56,7 @@ def process(args: dict): time.sleep(5) sync_series() - +@measure_time def get_subtitles( audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict ): diff --git a/bazarr-ai-sub-generator/utils/files.py b/bazarr-ai-sub-generator/utils/files.py index ea40253..29faa08 100644 --- a/bazarr-ai-sub-generator/utils/files.py +++ b/bazarr-ai-sub-generator/utils/files.py @@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO): for i, segment in enumerate(transcript, start=1): print( f"{i}\n" - f"{format_timestamp(segment.start, always_include_hours=True)} --> " - f"{format_timestamp(segment.end, always_include_hours=True)}\n" - f"{segment.text.strip().replace('-->', '->')}\n", + f"{format_timestamp(segment['start'], always_include_hours=True)} --> " + f"{format_timestamp(segment['end'], always_include_hours=True)}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", file=file, flush=True, ) diff --git a/bazarr-ai-sub-generator/utils/whisper.py b/bazarr-ai-sub-generator/utils/whisper.py index 5e823b1..6db019c 100644 --- a/bazarr-ai-sub-generator/utils/whisper.py +++ b/bazarr-ai-sub-generator/utils/whisper.py @@ -1,9 +1,9 @@ import warnings -import faster_whisper +import torch +import whisper from tqdm import tqdm -# pylint: disable=R0903 class WhisperAI: """ Wrapper class for the Whisper speech recognition model with additional functionality. @@ -23,23 +23,35 @@ class WhisperAI: ``` Args: - - model_args: Arguments to pass to WhisperModel initialize method - - model_size_or_path (str): The name of the Whisper model to use. - - device (str): The device to use for computation ("cpu", "cuda", "auto"). - - compute_type (str): The type to use for computation. - See https://opennmt.net/CTranslate2/quantization.html. + - model_args (dict): Arguments to pass to Whisper model initialization + - model_size (str): The name of the Whisper model to use. + - device (str): The device to use for computation ("cpu" or "cuda"). - transcribe_args (dict): Additional arguments to pass to the transcribe method. Attributes: - - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model. + - model (whisper.Whisper): The underlying Whisper speech recognition model. + - device (torch.device): The device to use for computation. - transcribe_args (dict): Additional arguments used for transcribe method. Methods: - - transcribe(audio_path): Transcribes an audio file and yields the resulting segments. + - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments. """ def __init__(self, model_args: dict, transcribe_args: dict): - self.model = faster_whisper.WhisperModel(**model_args) + """ + Initializes the WhisperAI instance. + + Args: + - model_args (dict): Arguments to initialize the Whisper model. + - transcribe_args (dict): Additional arguments for the transcribe method. + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + print(device) + # Set device for computation + self.device = torch.device(device) + # Load the Whisper model with the specified size + self.model = whisper.load_model("base").to(self.device) + # Store the additional transcription arguments self.transcribe_args = transcribe_args def transcribe(self, audio_path: str): @@ -50,17 +62,24 @@ class WhisperAI: - audio_path (str): The path to the audio file for transcription. Yields: - - faster_whisper.TranscriptionSegment: An individual transcription segment. + - dict: An individual transcription segment. """ + # Suppress warnings during transcription warnings.filterwarnings("ignore") - segments, info = self.model.transcribe(audio_path, **self.transcribe_args) + # Load and transcribe the audio file + result = self.model.transcribe(audio_path, **self.transcribe_args) + # Restore default warning behavior warnings.filterwarnings("default") - # Same precision as the Whisper timestamps. - total_duration = round(info.duration, 2) + # Calculate the total duration from the segments + total_duration = max(segment["end"] for segment in result["segments"]) + # Create a progress bar with the total duration of the audio file with tqdm(total=total_duration, unit=" seconds") as pbar: - for segment in segments: + for segment in result["segments"]: + # Yield each transcription segment yield segment - pbar.update(segment.end - segment.start) + # Update the progress bar with the duration of the current segment + pbar.update(segment["end"] - segment["start"]) + # Ensure the progress bar reaches 100% upon completion pbar.update(0) diff --git a/requirements.txt b/requirements.txt index eab95da..ec34ef1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ faster-whisper==0.10.0 tqdm==4.56.0 -ffmpeg-python==0.2.0 \ No newline at end of file +ffmpeg-python==0.2.0 +git+https://github.com/openai/whisper.git