Selaa lähdekoodia

Merge pull request #6 from karl0ss/reworked

Reworked
Karl0ss 8 kuukautta sitten
vanhempi
sitoutus
bf069d1fb4

+ 21 - 4
.vscode/launch.json

@@ -5,16 +5,33 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python: Current File",
+            "name": "Python Debugger: Current File",
             "type": "python",
             "request": "launch",
             "program": "${file}",
             "console": "integratedTerminal",
             "justMyCode": false,
+            "env": {
+                "CUDA_VISIBLE_DEVICES": "1",
+                "LD_LIBRARY_PATH": "/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cublas/lib:/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cudnn/lib"
+            },
             "args": [
                 "--model",
-                "base",
-            ],
+                "base"
+            ]
+        },
+        {
+            "name": "Current (withenv)",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${workspaceFolder}/run_with_env.sh",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+            "args": [
+                "${file}",
+                "--model",
+                "base"
+            ]
         }
     ]
-}
+}

+ 29 - 48
bazarr-ai-sub-generator/cli.py

@@ -15,16 +15,16 @@ def main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    parser.add_argument(
-        "--audio_channel", default="0", type=int, help="audio channel index to use"
-    )
-    parser.add_argument(
-        "--sample_interval",
-        type=str2timeinterval,
-        default=None,
-        help="generate subtitles for a specific \
-                              fragment of the video (e.g. 01:02:05-01:03:45)",
-    )
+    # parser.add_argument(
+    #     "--audio_channel", default="0", type=int, help="audio channel index to use"
+    # )
+    # parser.add_argument(
+    #     "--sample_interval",
+    #     type=str2timeinterval,
+    #     default=None,
+    #     help="generate subtitles for a specific \
+    #                           fragment of the video (e.g. 01:02:05-01:03:45)",
+    # )
     parser.add_argument(
         "--model",
         default="small",
@@ -38,46 +38,27 @@ def main():
         choices=["cpu", "cuda", "auto"],
         help='Device to use for computation ("cpu", "cuda", "auto")',
     )
+    # parser.add_argument(
+    #     "--compute_type",
+    #     type=str,
+    #     default="default",
+    #     choices=[
+    #         "int8",
+    #         "int8_float32",
+    #         "int8_float16",
+    #         "int8_bfloat16",
+    #         "int16",
+    #         "float16",
+    #         "bfloat16",
+    #         "float32",
+    #     ],
+    #     help="Type to use for computation. \
+    #                           See https://opennmt.net/CTranslate2/quantization.html.",
+    # )
     parser.add_argument(
-        "--compute_type",
+        "--show",
         type=str,
-        default="default",
-        choices=[
-            "int8",
-            "int8_float32",
-            "int8_float16",
-            "int8_bfloat16",
-            "int16",
-            "float16",
-            "bfloat16",
-            "float32",
-        ],
-        help="Type to use for computation. \
-                              See https://opennmt.net/CTranslate2/quantization.html.",
-    )
-    parser.add_argument(
-        "--beam_size",
-        type=int,
-        default=5,
-        help="model parameter, tweak to increase accuracy",
-    )
-    parser.add_argument(
-        "--no_speech_threshold",
-        type=float,
-        default=0.6,
-        help="model parameter, tweak to increase accuracy",
-    )
-    parser.add_argument(
-        "--condition_on_previous_text",
-        type=str2bool,
-        default=True,
-        help="model parameter, tweak to increase accuracy",
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        default="transcribe",
-        choices=["transcribe", "translate"],
+        default=None,
         help="whether to perform X->X speech recognition ('transcribe') \
                               or X->English translation ('translate')",
     )

+ 17 - 15
bazarr-ai-sub-generator/main.py

@@ -6,15 +6,16 @@ from utils.files import filename, write_srt
 from utils.ffmpeg import get_audio, add_subtitles_to_mp4
 from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
 from utils.sonarr import update_show_in_sonarr
+# from utils.faster_whisper import WhisperAI
 from utils.whisper import WhisperAI
-
+from utils.decorator import measure_time
 
 def process(args: dict):
+    
     model_name: str = args.pop("model")
     language: str = args.pop("language")
-    sample_interval: str = args.pop("sample_interval")
-    audio_channel: str = args.pop("audio_channel")
-
+    show: str = args.pop("show")
+    
     if model_name.endswith(".en"):
         warnings.warn(
             f"{model_name} is an English-only model, forcing English detection."
@@ -25,26 +26,27 @@ def process(args: dict):
         args["language"] = language
 
     model_args = {}
-    model_args["model_size_or_path"] = model_name
     model_args["device"] = args.pop("device")
-    model_args["compute_type"] = args.pop("compute_type")
-
-    list_of_episodes_needing_subtitles = get_wanted_episodes()
+    
+    list_of_episodes_needing_subtitles = get_wanted_episodes(show)
     print(
         f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles."
     )
     for episode in list_of_episodes_needing_subtitles["data"]:
         print(f"Processing {episode['seriesTitle']} - {episode['episode_number']}")
         episode_data = get_episode_details(episode["sonarrEpisodeId"])
-        audios = get_audio([episode_data["path"]], audio_channel, sample_interval)
-        subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
-
-        add_subtitles_to_mp4(subtitles)
-        update_show_in_sonarr(episode["sonarrSeriesId"])
-        time.sleep(5)
-        sync_series()
+        try:
+            audios = get_audio([episode_data["path"]], 0, None)
+            subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
 
+            add_subtitles_to_mp4(subtitles)
+            update_show_in_sonarr(episode["sonarrSeriesId"])
+            time.sleep(5)
+            sync_series()
+        except Exception as ex:
+            print(f"skipping file due to - {ex}")
 
+@measure_time
 def get_subtitles(
     audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
 ):

+ 7 - 3
bazarr-ai-sub-generator/utils/bazarr.py

@@ -8,15 +8,19 @@ token = config._sections["bazarr"]["token"]
 base_url = config._sections["bazarr"]["url"]
 
 
-def get_wanted_episodes():
+def get_wanted_episodes(show: str=None):
     url = f"{base_url}/api/episodes/wanted"
 
     payload = {}
     headers = {"accept": "application/json", "X-API-KEY": token}
 
     response = requests.request("GET", url, headers=headers, data=payload)
-
-    return response.json()
+    
+    data = response.json()
+    if show != None:
+        data['data'] = [item for item in data['data'] if item['seriesTitle'] == show]
+        data['total'] = len(data['data'])
+    return data
 
 
 def get_episode_details(episode_id: str):

+ 13 - 0
bazarr-ai-sub-generator/utils/decorator.py

@@ -0,0 +1,13 @@
+import time
+from datetime import timedelta
+
+def measure_time(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        duration = end_time - start_time
+        human_readable_duration = str(timedelta(seconds=duration))
+        print(f"Function '{func.__name__}' executed in: {human_readable_duration}")
+        return result
+    return wrapper

+ 68 - 0
bazarr-ai-sub-generator/utils/faster_whisper.py

@@ -0,0 +1,68 @@
+import warnings
+import faster_whisper
+from tqdm import tqdm
+
+
+# pylint: disable=R0903
+class WhisperAI:
+    """
+    Wrapper class for the Whisper speech recognition model with additional functionality.
+
+    This class provides a high-level interface for transcribing audio files using the Whisper
+    speech recognition model. It encapsulates the model instantiation and transcription process,
+    allowing users to easily transcribe audio files and iterate over the resulting segments.
+
+    Usage:
+    ```python
+    whisper = WhisperAI(model_args, transcribe_args)
+
+    # Transcribe an audio file and iterate over the segments
+    for segment in whisper.transcribe(audio_path):
+        # Process each transcription segment
+        print(segment)
+    ```
+
+    Args:
+    - model_args: Arguments to pass to WhisperModel initialize method
+        - model_size_or_path (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu", "cuda", "auto").
+        - compute_type (str): The type to use for computation.
+            See https://opennmt.net/CTranslate2/quantization.html.
+    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
+
+    Attributes:
+    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - transcribe_args (dict): Additional arguments used for transcribe method.
+
+    Methods:
+    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    """
+
+    def __init__(self, model_args: dict, transcribe_args: dict):
+        # self.model = faster_whisper.WhisperModel(**model_args)
+        model_size = "base"
+        self.model = faster_whisper.WhisperModel(model_size, device="cuda")
+        self.transcribe_args = transcribe_args
+
+    def transcribe(self, audio_path: str):
+        """
+        Transcribes the specified audio file and yields the resulting segments.
+
+        Args:
+        - audio_path (str): The path to the audio file for transcription.
+
+        Yields:
+        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        """
+        warnings.filterwarnings("ignore")
+        segments, info = self.model.transcribe(audio_path, beam_size=5)
+        warnings.filterwarnings("default")
+
+        # Same precision as the Whisper timestamps.
+        total_duration = round(info.duration, 2)
+
+        with tqdm(total=total_duration, unit=" seconds") as pbar:
+            for segment in segments:
+                yield segment
+                pbar.update(segment.end - segment.start)
+            pbar.update(0)

+ 3 - 3
bazarr-ai-sub-generator/utils/files.py

@@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
     for i, segment in enumerate(transcript, start=1):
         print(
             f"{i}\n"
-            f"{format_timestamp(segment.start, always_include_hours=True)} --> "
-            f"{format_timestamp(segment.end, always_include_hours=True)}\n"
-            f"{segment.text.strip().replace('-->', '->')}\n",
+            f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
+            f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
+            f"{segment['text'].strip().replace('-->', '->')}\n",
             file=file,
             flush=True,
         )

+ 35 - 16
bazarr-ai-sub-generator/utils/whisper.py

@@ -1,9 +1,9 @@
 import warnings
-import faster_whisper
+import torch
+import whisper
 from tqdm import tqdm
 
 
-# pylint: disable=R0903
 class WhisperAI:
     """
     Wrapper class for the Whisper speech recognition model with additional functionality.
@@ -23,23 +23,35 @@ class WhisperAI:
     ```
 
     Args:
-    - model_args: Arguments to pass to WhisperModel initialize method
-        - model_size_or_path (str): The name of the Whisper model to use.
-        - device (str): The device to use for computation ("cpu", "cuda", "auto").
-        - compute_type (str): The type to use for computation.
-            See https://opennmt.net/CTranslate2/quantization.html.
+    - model_args (dict): Arguments to pass to Whisper model initialization
+        - model_size (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu" or "cuda").
     - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 
     Attributes:
-    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - model (whisper.Whisper): The underlying Whisper speech recognition model.
+    - device (torch.device): The device to use for computation.
     - transcribe_args (dict): Additional arguments used for transcribe method.
 
     Methods:
-    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
     """
 
     def __init__(self, model_args: dict, transcribe_args: dict):
-        self.model = faster_whisper.WhisperModel(**model_args)
+        """
+        Initializes the WhisperAI instance.
+
+        Args:
+        - model_args (dict): Arguments to initialize the Whisper model.
+        - transcribe_args (dict): Additional arguments for the transcribe method.
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(device)
+        # Set device for computation
+        self.device = torch.device(device)
+        # Load the Whisper model with the specified size
+        self.model = whisper.load_model("base.en").to(self.device)
+        # Store the additional transcription arguments
         self.transcribe_args = transcribe_args
 
     def transcribe(self, audio_path: str):
@@ -50,17 +62,24 @@ class WhisperAI:
         - audio_path (str): The path to the audio file for transcription.
 
         Yields:
-        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        - dict: An individual transcription segment.
         """
+        # Suppress warnings during transcription
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Load and transcribe the audio file
+        result = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Restore default warning behavior
         warnings.filterwarnings("default")
 
-        # Same precision as the Whisper timestamps.
-        total_duration = round(info.duration, 2)
+        # Calculate the total duration from the segments
+        total_duration = max(segment["end"] for segment in result["segments"])
 
+        # Create a progress bar with the total duration of the audio file
         with tqdm(total=total_duration, unit=" seconds") as pbar:
-            for segment in segments:
+            for segment in result["segments"]:
+                # Yield each transcription segment
                 yield segment
-                pbar.update(segment.end - segment.start)
+                # Update the progress bar with the duration of the current segment
+                pbar.update(segment["end"] - segment["start"])
+            # Ensure the progress bar reaches 100% upon completion
             pbar.update(0)

+ 8 - 2
requirements.txt

@@ -1,3 +1,9 @@
-faster-whisper==0.10.0
 tqdm==4.56.0
-ffmpeg-python==0.2.0
+ffmpeg-python==0.2.0
+git+https://github.com/openai/whisper.git
+faster-whisper
+nvidia-cublas-cu12
+nvidia-cudnn-cu12
+nvidia-cublas-cu11
+nvidia-cudnn-cu11
+ctranslate2==3.24.0

+ 0 - 1
setup.py

@@ -7,7 +7,6 @@ setup(
     py_modules=["bazarr-ai-sub-generator"],
     author="Karl Hudgell",
     install_requires=[
-        'faster-whisper',
         'tqdm',
         'ffmpeg-python'
     ],