add cuda deps

Karl, 8 months ago
commit fde1b4d89e

+ 2 - 2
.vscode/launch.json

@@ -5,8 +5,8 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python: Current File",
-            "type": "python",
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
             "request": "launch",
             "program": "${file}",
             "console": "integratedTerminal",

+ 13 - 1
bazarr-ai-sub-generator/main.py

@@ -8,8 +8,20 @@ from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
 from utils.sonarr import update_show_in_sonarr
 from utils.whisper import WhisperAI
 
+def measure_time(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        duration = end_time - start_time
+        print(f"Function '{func.__name__}' executed in: {duration:.6f} seconds")
+        return result
+    return wrapper
+
+
 
 def process(args: dict):
+    
     model_name: str = args.pop("model")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
@@ -44,7 +56,7 @@ def process(args: dict):
         time.sleep(5)
         sync_series()
 
-
+@measure_time
 def get_subtitles(
     audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
 ):
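
The new `measure_time` decorator wraps a function and logs its wall-clock runtime; in this commit it instruments `get_subtitles`. A minimal standalone sketch of the same pattern, with `functools.wraps` added (not in the commit) so the wrapped function keeps its `__name__` and docstring:

```python
import functools
import time


def measure_time(func):
    @functools.wraps(func)  # preserves func.__name__ used in the log line
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        duration = time.time() - start_time
        print(f"Function '{func.__name__}' executed in: {duration:.6f} seconds")
        return result
    return wrapper


@measure_time
def slow_add(a: int, b: int) -> int:
    time.sleep(0.5)  # stand-in for real work such as transcription
    return a + b


if __name__ == "__main__":
    slow_add(1, 2)  # prints: Function 'slow_add' executed in: 0.50... seconds
```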

+ 3 - 3
bazarr-ai-sub-generator/utils/files.py

@@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
     for i, segment in enumerate(transcript, start=1):
         print(
             f"{i}\n"
-            f"{format_timestamp(segment.start, always_include_hours=True)} --> "
-            f"{format_timestamp(segment.end, always_include_hours=True)}\n"
-            f"{segment.text.strip().replace('-->', '->')}\n",
+            f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
+            f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
+            f"{segment['text'].strip().replace('-->', '->')}\n",
             file=file,
             flush=True,
         )
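
The `write_srt` change follows from the library swap below: faster-whisper yields `Segment` objects with attributes, while openai-whisper returns plain dicts, hence `segment['start']` instead of `segment.start`. A self-contained sketch of the dict-based path; this `format_timestamp` is a simplified stand-in for the project's own helper:

```python
import sys
from typing import Iterator, TextIO


def format_timestamp(seconds: float, always_include_hours: bool = False) -> str:
    # Simplified stand-in: renders HH:MM:SS,mmm as required by the SRT format.
    ms = round(seconds * 1000)
    hours, ms = divmod(ms, 3_600_000)
    minutes, ms = divmod(ms, 60_000)
    secs, ms = divmod(ms, 1_000)
    hours_part = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_part}{minutes:02d}:{secs:02d},{ms:03d}"


def write_srt(transcript: Iterator[dict], file: TextIO):
    for i, segment in enumerate(transcript, start=1):
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


segments = [{"start": 0.0, "end": 2.5, "text": " Hello there."}]
write_srt(iter(segments), sys.stdout)
# 1
# 00:00:00,000 --> 00:00:02,500
# Hello there.
```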

+ 35 - 16
bazarr-ai-sub-generator/utils/whisper.py

@@ -1,9 +1,9 @@
 import warnings
-import faster_whisper
+import torch
+import whisper
 from tqdm import tqdm
 
 
-# pylint: disable=R0903
 class WhisperAI:
     """
     Wrapper class for the Whisper speech recognition model with additional functionality.
@@ -23,23 +23,35 @@ class WhisperAI:
     ```
 
     Args:
-    - model_args: Arguments to pass to WhisperModel initialize method
-        - model_size_or_path (str): The name of the Whisper model to use.
-        - device (str): The device to use for computation ("cpu", "cuda", "auto").
-        - compute_type (str): The type to use for computation.
-            See https://opennmt.net/CTranslate2/quantization.html.
+    - model_args (dict): Arguments to pass to Whisper model initialization
+        - model_size (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu" or "cuda").
     - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 
     Attributes:
-    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - model (whisper.Whisper): The underlying Whisper speech recognition model.
+    - device (torch.device): The device to use for computation.
     - transcribe_args (dict): Additional arguments used for transcribe method.
 
     Methods:
-    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
     """
 
     def __init__(self, model_args: dict, transcribe_args: dict):
-        self.model = faster_whisper.WhisperModel(**model_args)
+        """
+        Initializes the WhisperAI instance.
+
+        Args:
+        - model_args (dict): Arguments to initialize the Whisper model.
+        - transcribe_args (dict): Additional arguments for the transcribe method.
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(device)
+        # Set device for computation
+        self.device = torch.device(device)
+        # Load the Whisper model with the specified size
+        self.model = whisper.load_model("base").to(self.device)
+        # Store the additional transcription arguments
         self.transcribe_args = transcribe_args
 
     def transcribe(self, audio_path: str):
@@ -50,17 +62,24 @@ class WhisperAI:
         - audio_path (str): The path to the audio file for transcription.
 
         Yields:
-        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        - dict: An individual transcription segment.
         """
+        # Suppress warnings during transcription
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Load and transcribe the audio file
+        result = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Restore default warning behavior
         warnings.filterwarnings("default")
 
-        # Same precision as the Whisper timestamps.
-        total_duration = round(info.duration, 2)
+        # Calculate the total duration from the segments
+        total_duration = max(segment["end"] for segment in result["segments"])
 
+        # Create a progress bar with the total duration of the audio file
         with tqdm(total=total_duration, unit=" seconds") as pbar:
-            for segment in segments:
+            for segment in result["segments"]:
+                # Yield each transcription segment
                 yield segment
-                pbar.update(segment.end - segment.start)
+                # Update the progress bar with the duration of the current segment
+                pbar.update(segment["end"] - segment["start"])
+            # Ensure the progress bar reaches 100% upon completion
             pbar.update(0)
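
Two things worth noting in this rewrite: `model_args` is still accepted but no longer consulted (the model size is hard-coded to `"base"`), and the `max(...)` over segments would raise `ValueError` if transcription returns no segments. The core usage pattern, sketched against the public openai-whisper API (the audio path is hypothetical):

```python
import torch
import whisper

# Device selection as in the commit: prefer CUDA when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The commit hard-codes "base"; threading model_args["model_size"] through
# here would restore a configurable model size.
model = whisper.load_model("base").to(device)

# transcribe() accepts decode options such as language as keyword arguments.
result = model.transcribe("episode.wav", language="en")  # hypothetical path
for segment in result["segments"]:
    print(f'{segment["start"]:7.2f} -> {segment["end"]:7.2f}: {segment["text"]}')
```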

+ 2 - 1
requirements.txt

@@ -1,3 +1,4 @@
 faster-whisper==0.10.0
 tqdm==4.56.0
-ffmpeg-python==0.2.0
+ffmpeg-python==0.2.0
+git+https://github.com/openai/whisper.git
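
requirements.txt still pins faster-whisper even though whisper.py no longer imports it; openai-whisper installed from git pulls in torch, which supplies the CUDA support the commit title refers to. A quick sanity check that the CUDA build is actually active:

```python
import torch

# True only if a CUDA-enabled torch build and a visible GPU are present.
print(torch.__version__, torch.cuda.is_available())
```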