whisper.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import warnings
  2. import faster_whisper
  3. from tqdm import tqdm
  4. # pylint: disable=R0903
  5. class WhisperAI:
  6. """
  7. Wrapper class for the Whisper speech recognition model with additional functionality.
  8. This class provides a high-level interface for transcribing audio files using the Whisper
  9. speech recognition model. It encapsulates the model instantiation and transcription process,
  10. allowing users to easily transcribe audio files and iterate over the resulting segments.
  11. Usage:
  12. ```python
  13. whisper = WhisperAI(model_args, transcribe_args)
  14. # Transcribe an audio file and iterate over the segments
  15. for segment in whisper.transcribe(audio_path):
  16. # Process each transcription segment
  17. print(segment)
  18. ```
  19. Args:
  20. - model_args: Arguments to pass to WhisperModel initialize method
  21. - model_size_or_path (str): The name of the Whisper model to use.
  22. - device (str): The device to use for computation ("cpu", "cuda", "auto").
  23. - compute_type (str): The type to use for computation.
  24. See https://opennmt.net/CTranslate2/quantization.html.
  25. - transcribe_args (dict): Additional arguments to pass to the transcribe method.
  26. Attributes:
  27. - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
  28. - transcribe_args (dict): Additional arguments used for transcribe method.
  29. Methods:
  30. - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
  31. """
  32. def __init__(self, model_args: dict, transcribe_args: dict):
  33. self.model = faster_whisper.WhisperModel(**model_args)
  34. self.transcribe_args = transcribe_args
  35. def transcribe(self, audio_path: str):
  36. """
  37. Transcribes the specified audio file and yields the resulting segments.
  38. Args:
  39. - audio_path (str): The path to the audio file for transcription.
  40. Yields:
  41. - faster_whisper.TranscriptionSegment: An individual transcription segment.
  42. """
  43. warnings.filterwarnings("ignore")
  44. segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
  45. warnings.filterwarnings("default")
  46. # Same precision as the Whisper timestamps.
  47. total_duration = round(info.duration, 2)
  48. with tqdm(total=total_duration, unit=" seconds") as pbar:
  49. for segment in segments:
  50. yield segment
  51. pbar.update(segment.end - segment.start)
  52. pbar.update(0)