whisper.py 802 B

123456789101112131415161718192021
  1. import warnings
  2. import faster_whisper
  3. from tqdm import tqdm
  4. class WhisperAI:
  5. def __init__(self, model_name, device, compute_type, model_args):
  6. self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
  7. self.model_args = model_args
  8. def transcribe(self, audio_path):
  9. warnings.filterwarnings("ignore")
  10. segments, info = self.model.transcribe(audio_path, **self.model_args)
  11. warnings.filterwarnings("default")
  12. total_duration = round(info.duration, 2) # Same precision as the Whisper timestamps.
  13. with tqdm(total=total_duration, unit=" seconds") as pbar:
  14. for segment in segments:
  15. yield segment
  16. pbar.update(segment.end - segment.start)
  17. pbar.update(0)