add cuda deps

This commit is contained in:
Karl 2024-07-13 09:22:02 +00:00
parent 77b28df03d
commit fde1b4d89e
5 changed files with 55 additions and 23 deletions

4
.vscode/launch.json vendored
View File

@ -5,8 +5,8 @@
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "Python: Current File", "name": "Python Debugger: Current File",
"type": "python", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "${file}", "program": "${file}",
"console": "integratedTerminal", "console": "integratedTerminal",

View File

@ -8,8 +8,20 @@ from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
from utils.sonarr import update_show_in_sonarr from utils.sonarr import update_show_in_sonarr
from utils.whisper import WhisperAI from utils.whisper import WhisperAI
def measure_time(func):
    """
    Decorator that prints how long each call to the wrapped function takes.

    Args:
    - func (Callable): The function to time.

    Returns:
    - Callable: A wrapper that forwards all positional and keyword arguments
      to `func`, prints the elapsed time in seconds and returns `func`'s result.
      If `func` raises, the exception propagates and nothing is printed.
    """
    # Imported locally so the decorator does not depend on a module-level import.
    import functools

    # Keep the wrapped function's __name__/__doc__ so decorated functions stay
    # identifiable in tracebacks, logs and introspection.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high resolution; time.time() follows
        # the wall clock and can jump if the system clock is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' executed in: {duration:.6f} seconds")
        return result

    return wrapper
def process(args: dict): def process(args: dict):
model_name: str = args.pop("model") model_name: str = args.pop("model")
language: str = args.pop("language") language: str = args.pop("language")
sample_interval: str = args.pop("sample_interval") sample_interval: str = args.pop("sample_interval")
@ -44,7 +56,7 @@ def process(args: dict):
time.sleep(5) time.sleep(5)
sync_series() sync_series()
@measure_time
def get_subtitles( def get_subtitles(
audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
): ):

View File

@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
for i, segment in enumerate(transcript, start=1): for i, segment in enumerate(transcript, start=1):
print( print(
f"{i}\n" f"{i}\n"
f"{format_timestamp(segment.start, always_include_hours=True)} --> " f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
f"{format_timestamp(segment.end, always_include_hours=True)}\n" f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
f"{segment.text.strip().replace('-->', '->')}\n", f"{segment['text'].strip().replace('-->', '->')}\n",
file=file, file=file,
flush=True, flush=True,
) )

View File

@ -1,9 +1,9 @@
import warnings import warnings
import faster_whisper import torch
import whisper
from tqdm import tqdm from tqdm import tqdm
# pylint: disable=R0903
class WhisperAI: class WhisperAI:
""" """
Wrapper class for the Whisper speech recognition model with additional functionality. Wrapper class for the Whisper speech recognition model with additional functionality.
@ -23,23 +23,35 @@ class WhisperAI:
``` ```
Args: Args:
- model_args: Arguments to pass to WhisperModel initialize method - model_args (dict): Arguments to pass to Whisper model initialization
- model_size_or_path (str): The name of the Whisper model to use. - model_size (str): The name of the Whisper model to use.
- device (str): The device to use for computation ("cpu", "cuda", "auto"). - device (str): The device to use for computation ("cpu" or "cuda").
- compute_type (str): The type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
- transcribe_args (dict): Additional arguments to pass to the transcribe method. - transcribe_args (dict): Additional arguments to pass to the transcribe method.
Attributes: Attributes:
- model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model. - model (whisper.Whisper): The underlying Whisper speech recognition model.
- device (torch.device): The device to use for computation.
- transcribe_args (dict): Additional arguments used for transcribe method. - transcribe_args (dict): Additional arguments used for transcribe method.
Methods: Methods:
- transcribe(audio_path): Transcribes an audio file and yields the resulting segments. - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
""" """
def __init__(self, model_args: dict, transcribe_args: dict): def __init__(self, model_args: dict, transcribe_args: dict):
self.model = faster_whisper.WhisperModel(**model_args) """
Initializes the WhisperAI instance.
Args:
- model_args (dict): Arguments to initialize the Whisper model.
- transcribe_args (dict): Additional arguments for the transcribe method.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Set device for computation
self.device = torch.device(device)
# Load the Whisper model with the specified size
self.model = whisper.load_model("base").to(self.device)
# Store the additional transcription arguments
self.transcribe_args = transcribe_args self.transcribe_args = transcribe_args
def transcribe(self, audio_path: str): def transcribe(self, audio_path: str):
@ -50,17 +62,24 @@ class WhisperAI:
- audio_path (str): The path to the audio file for transcription. - audio_path (str): The path to the audio file for transcription.
Yields: Yields:
- faster_whisper.TranscriptionSegment: An individual transcription segment. - dict: An individual transcription segment.
""" """
# Suppress warnings during transcription
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
segments, info = self.model.transcribe(audio_path, **self.transcribe_args) # Load and transcribe the audio file
result = self.model.transcribe(audio_path, **self.transcribe_args)
# Restore default warning behavior
warnings.filterwarnings("default") warnings.filterwarnings("default")
# Same precision as the Whisper timestamps. # Calculate the total duration from the segments
total_duration = round(info.duration, 2) total_duration = max(segment["end"] for segment in result["segments"])
# Create a progress bar with the total duration of the audio file
with tqdm(total=total_duration, unit=" seconds") as pbar: with tqdm(total=total_duration, unit=" seconds") as pbar:
for segment in segments: for segment in result["segments"]:
# Yield each transcription segment
yield segment yield segment
pbar.update(segment.end - segment.start) # Update the progress bar with the duration of the current segment
pbar.update(segment["end"] - segment["start"])
# Ensure the progress bar reaches 100% upon completion
pbar.update(0) pbar.update(0)

View File

@ -1,3 +1,4 @@
faster-whisper==0.10.0 faster-whisper==0.10.0
tqdm==4.56.0 tqdm==4.56.0
ffmpeg-python==0.2.0 ffmpeg-python==0.2.0
git+https://github.com/openai/whisper.git