@@ -1,9 +1,9 @@
 import warnings
-import faster_whisper
+import torch
+import whisper
 from tqdm import tqdm


-# pylint: disable=R0903
 class WhisperAI:
     """
     Wrapper class for the Whisper speech recognition model with additional functionality.
@@ -23,23 +23,35 @@ class WhisperAI:
     ```

     Args:
-    - model_args: Arguments to pass to WhisperModel initialize method
-        - model_size_or_path (str): The name of the Whisper model to use.
-        - device (str): The device to use for computation ("cpu", "cuda", "auto").
-        - compute_type (str): The type to use for computation.
-            See https://opennmt.net/CTranslate2/quantization.html.
+    - model_args (dict): Arguments to pass to Whisper model initialization.
+        - model_size (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu" or "cuda").
     - transcribe_args (dict): Additional arguments to pass to the transcribe method.

     Attributes:
-    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - model (whisper.Whisper): The underlying Whisper speech recognition model.
+    - device (torch.device): The device to use for computation.
     - transcribe_args (dict): Additional arguments used for transcribe method.

     Methods:
-    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
     """

     def __init__(self, model_args: dict, transcribe_args: dict):
-        self.model = faster_whisper.WhisperModel(**model_args)
+        """
+        Initializes the WhisperAI instance.
+
+        Args:
+        - model_args (dict): Arguments to initialize the Whisper model.
+        - transcribe_args (dict): Additional arguments for the transcribe method.
+        """
+        # Honor an explicit device from model_args, otherwise prefer CUDA when available
+        device = model_args.get("device") or ("cuda" if torch.cuda.is_available() else "cpu")
+        # Set device for computation
+        self.device = torch.device(device)
+        # Load the Whisper model with the requested size
+        self.model = whisper.load_model(model_args.get("model_size", "base")).to(self.device)
+        # Store the additional transcription arguments
         self.transcribe_args = transcribe_args

     def transcribe(self, audio_path: str):
@@ -50,17 +62,24 @@
         - audio_path (str): The path to the audio file for transcription.

         Yields:
-        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        - dict: An individual transcription segment.
         """
+        # Suppress warnings during transcription
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Load and transcribe the audio file
+        result = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Restore default warning behavior
        warnings.filterwarnings("default")

-        # Same precision as the Whisper timestamps.
-        total_duration = round(info.duration, 2)
+        # Calculate the total duration from the segments (0.0 if none were produced)
+        total_duration = max((segment["end"] for segment in result["segments"]), default=0.0)

+        # Create a progress bar with the total duration of the audio file
         with tqdm(total=total_duration, unit=" seconds") as pbar:
-            for segment in segments:
+            for segment in result["segments"]:
+                # Yield each transcription segment
                 yield segment
-                pbar.update(segment.end - segment.start)
+                # Update the progress bar with the duration of the current segment
+                pbar.update(segment["end"] - segment["start"])
+            # Ensure the progress bar reaches 100% upon completion
             pbar.update(0)
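
For reference, here is a minimal usage sketch of the rewritten class. It is an illustration, not part of the diff: the module name `whisper_ai`, the file `audio.wav`, and the chosen argument values are assumptions; `language` and `task` are standard keyword arguments of openai-whisper's `transcribe()`.

```python
# Hypothetical usage of the WhisperAI wrapper from the diff above.
from whisper_ai import WhisperAI  # assumed module name

# "model_size" names any openai-whisper checkpoint ("tiny", "base", "small", ...);
# leaving "device" unset lets the wrapper fall back to CUDA when available.
model_args = {"model_size": "small"}
transcribe_args = {"language": "en", "task": "transcribe"}

ai = WhisperAI(model_args, transcribe_args)

# Each yielded segment is a dict from whisper's result,
# with "start", "end", and "text" keys among others.
for segment in ai.transcribe("audio.wav"):
    print(f"[{segment['start']:7.2f} -> {segment['end']:7.2f}] {segment['text'].strip()}")
```

One behavioral note on the switch: faster_whisper yields segments lazily as they are decoded, so the progress bar advanced during transcription, whereas openai-whisper's `transcribe()` returns only after the whole file is processed, so with this change the bar fills essentially all at once when transcription completes.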