mirror of
https://github.com/karl0ss/bazarr-ai-sub-generator.git
synced 2025-04-26 06:49:22 +01:00
86 lines
3.3 KiB
Python
86 lines
3.3 KiB
Python
import warnings
|
|
import torch
|
|
import whisper
|
|
from tqdm import tqdm
|
|
|
|
|
|
class WhisperAI:
    """
    Wrapper class for the Whisper speech recognition model with additional functionality.

    This class provides a high-level interface for transcribing audio files using the Whisper
    speech recognition model. It encapsulates the model instantiation and transcription process,
    allowing users to easily transcribe audio files and iterate over the resulting segments.

    Usage:
    ```python
    whisper = WhisperAI(model_args, transcribe_args)

    # Transcribe an audio file and iterate over the segments
    for segment in whisper.transcribe(audio_path):
        # Process each transcription segment
        print(segment)
    ```

    Args:
    - model_args (dict): Arguments used to initialize the Whisper model. Recognized keys:
        - model_size (str): The name of the Whisper model to use. Defaults to "base.en".
        - device (str): The device to use for computation ("cpu" or "cuda"). When missing,
          unrecognized, or "cuda" without CUDA being available, the device is auto-detected.
    - transcribe_args (dict): Additional arguments to pass to the transcribe method.

    Attributes:
    - model (whisper.Whisper): The underlying Whisper speech recognition model.
    - device (torch.device): The device to use for computation.
    - transcribe_args (dict): Additional arguments used for transcribe method.

    Methods:
    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
    """

    # Model loaded when model_args does not name one.
    _DEFAULT_MODEL_SIZE = "base.en"

    def __init__(self, model_args: dict, transcribe_args: dict):
        """
        Initializes the WhisperAI instance.

        Args:
        - model_args (dict): Arguments to initialize the Whisper model. Only the
          "model_size" and "device" keys are read; other keys are ignored.
          May be None or empty.
        - transcribe_args (dict): Additional arguments for the transcribe method.
          May be None, which is treated as no extra arguments.
        """
        model_args = model_args or {}

        device = self._resolve_device(model_args.get("device"))
        print(device)
        # Set device for computation
        self.device = torch.device(device)

        # Load the requested model directly onto the target device so the
        # weights are not first placed on a different device and then moved.
        model_size = model_args.get("model_size") or self._DEFAULT_MODEL_SIZE
        self.model = whisper.load_model(model_size, device=self.device)

        # Store the additional transcription arguments
        self.transcribe_args = transcribe_args if transcribe_args is not None else {}

    @staticmethod
    def _resolve_device(requested):
        """
        Picks the device name to run on.

        Args:
        - requested: The "device" value from model_args (may be None).

        Returns:
        - str: "cpu" when explicitly requested; the requested "cuda"/"cuda:N" name when
          CUDA is available; otherwise "cuda" if CUDA is available, else "cpu".
        """
        cuda_available = torch.cuda.is_available()
        if isinstance(requested, str):
            name = requested.strip().lower()
            if name == "cpu":
                return "cpu"
            if cuda_available and (name == "cuda" or name.startswith("cuda:")):
                return name
        # Anything else (None, "auto", CUDA requested but unavailable) is auto-detected.
        return "cuda" if cuda_available else "cpu"

    def transcribe(self, audio_path: str):
        """
        Transcribes the specified audio file and yields the resulting segments.

        This is a generator: the transcription itself runs when iteration starts,
        and the full result is computed before the first segment is yielded.
        Nothing is yielded when the result contains no segments.

        Args:
        - audio_path (str): The path to the audio file for transcription.

        Yields:
        - dict: An individual transcription segment (read here via its "start"/"end" keys).
        """
        # Silence warnings only for the duration of the call. catch_warnings
        # restores the caller's filter configuration on exit, including when
        # transcription raises, instead of overwriting it.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = self.model.transcribe(audio_path, **self.transcribe_args)

        segments = result["segments"]
        if not segments:
            # No segments: nothing to report, and max() below would fail on
            # an empty sequence.
            return

        # The latest segment end is used as the progress bar total
        total_duration = max(segment["end"] for segment in segments)

        with tqdm(total=total_duration, unit=" seconds") as pbar:
            for segment in segments:
                yield segment
                # Advance to the end of this segment so gaps between segments
                # are counted as progress too.
                advance = segment["end"] - pbar.n
                if advance > 0:
                    pbar.update(advance)
            # Ensure the progress bar reaches 100% upon completion
            leftover = total_duration - pbar.n
            if leftover > 0:
                pbar.update(leftover)
|