mirror of https://github.com/karl0ss/bazarr-ai-sub-generator.git
synced 2025-04-26 06:49:22 +01:00

add cuda deps

This commit is contained in:
parent 77b28df03d
commit fde1b4d89e
.vscode/launch.json (4 lines changed, vendored)
@@ -5,8 +5,8 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python: Current File",
-            "type": "python",
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
             "request": "launch",
             "program": "${file}",
             "console": "integratedTerminal",
@@ -8,8 +8,20 @@ from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
 from utils.sonarr import update_show_in_sonarr
 from utils.whisper import WhisperAI
 
+def measure_time(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        duration = end_time - start_time
+        print(f"Function '{func.__name__}' executed in: {duration:.6f} seconds")
+        return result
+    return wrapper
+
+
+
 def process(args: dict):
 
     model_name: str = args.pop("model")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
@@ -44,7 +56,7 @@ def process(args: dict):
         time.sleep(5)
     sync_series()
 
-
+@measure_time
 def get_subtitles(
     audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
 ):
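For reference, here is a minimal, self-contained sketch of how the new `measure_time` decorator behaves once applied, as it now is to `get_subtitles` in the hunk above. Only the decorator body comes from this commit; `slow_step` and its half-second sleep are hypothetical stand-ins for a long-running call.

```python
import time


def measure_time(func):
    # Same decorator as introduced above: time the call, print the duration, pass the result through.
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        duration = end_time - start_time
        print(f"Function '{func.__name__}' executed in: {duration:.6f} seconds")
        return result
    return wrapper


@measure_time
def slow_step(seconds: float) -> str:
    # Hypothetical stand-in for a long-running step such as get_subtitles().
    time.sleep(seconds)
    return "done"


if __name__ == "__main__":
    slow_step(0.5)
    # Prints something like:
    # Function 'slow_step' executed in: 0.500123 seconds
```

Because the wrapper forwards `*args`/`**kwargs` and returns the wrapped function's result, decorating `get_subtitles` only adds the timing print; its return value is unchanged.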
@@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
     for i, segment in enumerate(transcript, start=1):
         print(
             f"{i}\n"
-            f"{format_timestamp(segment.start, always_include_hours=True)} --> "
-            f"{format_timestamp(segment.end, always_include_hours=True)}\n"
-            f"{segment.text.strip().replace('-->', '->')}\n",
+            f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
+            f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
+            f"{segment['text'].strip().replace('-->', '->')}\n",
             file=file,
             flush=True,
         )
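The hunk above switches the SRT writer from attribute access (`segment.start`) to key access (`segment['start']`) because openai-whisper's `transcribe()` returns segments as plain dicts, while faster-whisper yielded `Segment` objects with `.start`/`.end`/`.text` attributes. A self-contained sketch of the updated loop follows; the `format_timestamp` helper below is a simplified stand-in for the project's own, and the sample segments are invented.

```python
import io

# Invented segments in the dict shape openai-whisper produces.
segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 4.0, "text": " General Kenobi."},
]


def format_timestamp(seconds: float, always_include_hours: bool = True) -> str:
    # Simplified stand-in for the project's format_timestamp helper.
    milliseconds = round(seconds * 1000)
    hours, milliseconds = divmod(milliseconds, 3_600_000)
    minutes, milliseconds = divmod(milliseconds, 60_000)
    secs, milliseconds = divmod(milliseconds, 1_000)
    hours_part = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_part}{minutes:02d}:{secs:02d},{milliseconds:03d}"


srt_file = io.StringIO()
for i, segment in enumerate(segments, start=1):
    # Same pattern as the updated write_srt body: dict keys instead of attributes.
    print(
        f"{i}\n"
        f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
        f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
        f"{segment['text'].strip().replace('-->', '->')}\n",
        file=srt_file,
        flush=True,
    )
print(srt_file.getvalue())
# 1
# 00:00:00,000 --> 00:00:02,500
# Hello there.
# ...
```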
@@ -1,9 +1,9 @@
 import warnings
-import faster_whisper
+import torch
+import whisper
 from tqdm import tqdm
 
-
 # pylint: disable=R0903
 class WhisperAI:
     """
     Wrapper class for the Whisper speech recognition model with additional functionality.
@@ -23,23 +23,35 @@
     ```
 
     Args:
-    - model_args: Arguments to pass to WhisperModel initialize method
-        - model_size_or_path (str): The name of the Whisper model to use.
-        - device (str): The device to use for computation ("cpu", "cuda", "auto").
-        - compute_type (str): The type to use for computation.
-            See https://opennmt.net/CTranslate2/quantization.html.
+    - model_args (dict): Arguments to pass to Whisper model initialization
+        - model_size (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu" or "cuda").
+    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 
     Attributes:
-    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - model (whisper.Whisper): The underlying Whisper speech recognition model.
+    - device (torch.device): The device to use for computation.
+    - transcribe_args (dict): Additional arguments used for transcribe method.
 
     Methods:
-    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
     """
 
     def __init__(self, model_args: dict, transcribe_args: dict):
-        self.model = faster_whisper.WhisperModel(**model_args)
+        """
+        Initializes the WhisperAI instance.
+
+        Args:
+        - model_args (dict): Arguments to initialize the Whisper model.
+        - transcribe_args (dict): Additional arguments for the transcribe method.
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(device)
+        # Set device for computation
+        self.device = torch.device(device)
+        # Load the Whisper model with the specified size
+        self.model = whisper.load_model("base").to(self.device)
+        # Store the additional transcription arguments
         self.transcribe_args = transcribe_args
 
     def transcribe(self, audio_path: str):
@@ -50,17 +62,24 @@ class WhisperAI:
         - audio_path (str): The path to the audio file for transcription.
 
         Yields:
-        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        - dict: An individual transcription segment.
         """
+        # Suppress warnings during transcription
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Load and transcribe the audio file
+        result = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Restore default warning behavior
        warnings.filterwarnings("default")
 
-        # Same precision as the Whisper timestamps.
-        total_duration = round(info.duration, 2)
+        # Calculate the total duration from the segments
+        total_duration = max(segment["end"] for segment in result["segments"])
 
+        # Create a progress bar with the total duration of the audio file
         with tqdm(total=total_duration, unit=" seconds") as pbar:
-            for segment in segments:
+            for segment in result["segments"]:
+                # Yield each transcription segment
                 yield segment
-                pbar.update(segment.end - segment.start)
+                # Update the progress bar with the duration of the current segment
+                pbar.update(segment["end"] - segment["start"])
+            # Ensure the progress bar reaches 100% upon completion
             pbar.update(0)
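Taken together, the `utils.whisper` changes mean `WhisperAI.__init__` now selects CUDA automatically when `torch.cuda.is_available()` is true, loads the openai-whisper "base" model onto that device, and `transcribe()` yields dict segments. Below is a hedged usage sketch under those assumptions; the `transcribe_args` values and the audio path are placeholders, not taken from the commit.

```python
import torch

from utils.whisper import WhisperAI

# Placeholder arguments; in the project they are assembled by process() in the main script.
model_args = {}  # accepted but effectively unused: the model name is hard-coded to "base" in __init__
transcribe_args = {"task": "transcribe", "language": "en"}

whisper_ai = WhisperAI(model_args, transcribe_args)
# __init__ prints "cuda" when a GPU is visible to torch, otherwise "cpu".

for segment in whisper_ai.transcribe("/tmp/episode.wav"):  # placeholder path
    # Segments are now plain dicts from openai-whisper, not faster-whisper objects.
    print(segment["start"], segment["end"], segment["text"].strip())
```

One consequence of the rework: with `whisper.load_model("base")` hard-coded, whatever model name the caller passes in `model_args` no longer influences which model is loaded.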
@@ -1,3 +1,4 @@
 faster-whisper==0.10.0
 tqdm==4.56.0
-ffmpeg-python==0.2.0
+ffmpeg-python==0.2.0
+git+https://github.com/openai/whisper.git
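Since the point of the commit is to pull in CUDA-capable dependencies (openai-whisper installed from Git brings PyTorch along with it), a quick, illustrative post-install check might look like the following; it is not part of the repository, and the "base" model mirrors the value hard-coded in `WhisperAI`.

```python
import torch
import whisper

# Illustrative check that the CUDA stack pulled in by the new dependency is usable.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("base", device=device)
print("Model loaded on:", next(model.parameters()).device)
```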