whisper.py 751 B

1234567891011121314151617181920
  1. import warnings
  2. import faster_whisper
  3. from tqdm import tqdm
  4. class WhisperAI:
  5. def __init__(self, model_name, model_args):
  6. self.model = faster_whisper.WhisperModel(model_name, device="cuda", compute_type="float16")
  7. self.model_args = model_args
  8. def transcribe(self, audio_path):
  9. warnings.filterwarnings("ignore")
  10. segments, info = self.model.transcribe(audio_path, **self.model_args)
  11. warnings.filterwarnings("default")
  12. total_duration = round(info.duration, 2) # Same precision as the Whisper timestamps.
  13. with tqdm(total=total_duration, unit=" seconds") as pbar:
  14. for segment in segments:
  15. yield segment
  16. pbar.update(segment.end - segment.start)