mirror of
				https://github.com/karl0ss/bazarr-ai-sub-generator.git
				synced 2025-11-04 08:31:03 +00:00 
			
		
		
		
	
						commit
						bf069d1fb4
					
				
							
								
								
									
										25
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										25
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
								
							@ -5,16 +5,33 @@
 | 
				
			|||||||
    "version": "0.2.0",
 | 
					    "version": "0.2.0",
 | 
				
			||||||
    "configurations": [
 | 
					    "configurations": [
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            "name": "Python: Current File",
 | 
					            "name": "Python Debugger: Current File",
 | 
				
			||||||
            "type": "python",
 | 
					            "type": "python",
 | 
				
			||||||
            "request": "launch",
 | 
					            "request": "launch",
 | 
				
			||||||
            "program": "${file}",
 | 
					            "program": "${file}",
 | 
				
			||||||
            "console": "integratedTerminal",
 | 
					            "console": "integratedTerminal",
 | 
				
			||||||
            "justMyCode": false,
 | 
					            "justMyCode": false,
 | 
				
			||||||
 | 
					            "env": {
 | 
				
			||||||
 | 
					                "CUDA_VISIBLE_DEVICES": "1",
 | 
				
			||||||
 | 
					                "LD_LIBRARY_PATH": "/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cublas/lib:/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cudnn/lib"
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
            "args": [
 | 
					            "args": [
 | 
				
			||||||
                "--model",
 | 
					                "--model",
 | 
				
			||||||
                "base",
 | 
					                "base"
 | 
				
			||||||
            ],
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "Current (withenv)",
 | 
				
			||||||
 | 
					            "type": "debugpy",
 | 
				
			||||||
 | 
					            "request": "launch",
 | 
				
			||||||
 | 
					            "program": "${workspaceFolder}/run_with_env.sh",
 | 
				
			||||||
 | 
					            "console": "integratedTerminal",
 | 
				
			||||||
 | 
					            "justMyCode": false,
 | 
				
			||||||
 | 
					            "args": [
 | 
				
			||||||
 | 
					                "${file}",
 | 
				
			||||||
 | 
					                "--model",
 | 
				
			||||||
 | 
					                "base"
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -15,16 +15,16 @@ def main():
 | 
				
			|||||||
    parser = argparse.ArgumentParser(
 | 
					    parser = argparse.ArgumentParser(
 | 
				
			||||||
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
 | 
					        formatter_class=argparse.ArgumentDefaultsHelpFormatter
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    parser.add_argument(
 | 
					    # parser.add_argument(
 | 
				
			||||||
        "--audio_channel", default="0", type=int, help="audio channel index to use"
 | 
					    #     "--audio_channel", default="0", type=int, help="audio channel index to use"
 | 
				
			||||||
    )
 | 
					    # )
 | 
				
			||||||
    parser.add_argument(
 | 
					    # parser.add_argument(
 | 
				
			||||||
        "--sample_interval",
 | 
					    #     "--sample_interval",
 | 
				
			||||||
        type=str2timeinterval,
 | 
					    #     type=str2timeinterval,
 | 
				
			||||||
        default=None,
 | 
					    #     default=None,
 | 
				
			||||||
        help="generate subtitles for a specific \
 | 
					    #     help="generate subtitles for a specific \
 | 
				
			||||||
                              fragment of the video (e.g. 01:02:05-01:03:45)",
 | 
					    #                           fragment of the video (e.g. 01:02:05-01:03:45)",
 | 
				
			||||||
    )
 | 
					    # )
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        "--model",
 | 
					        "--model",
 | 
				
			||||||
        default="small",
 | 
					        default="small",
 | 
				
			||||||
@ -38,46 +38,27 @@ def main():
 | 
				
			|||||||
        choices=["cpu", "cuda", "auto"],
 | 
					        choices=["cpu", "cuda", "auto"],
 | 
				
			||||||
        help='Device to use for computation ("cpu", "cuda", "auto")',
 | 
					        help='Device to use for computation ("cpu", "cuda", "auto")',
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    # parser.add_argument(
 | 
				
			||||||
 | 
					    #     "--compute_type",
 | 
				
			||||||
 | 
					    #     type=str,
 | 
				
			||||||
 | 
					    #     default="default",
 | 
				
			||||||
 | 
					    #     choices=[
 | 
				
			||||||
 | 
					    #         "int8",
 | 
				
			||||||
 | 
					    #         "int8_float32",
 | 
				
			||||||
 | 
					    #         "int8_float16",
 | 
				
			||||||
 | 
					    #         "int8_bfloat16",
 | 
				
			||||||
 | 
					    #         "int16",
 | 
				
			||||||
 | 
					    #         "float16",
 | 
				
			||||||
 | 
					    #         "bfloat16",
 | 
				
			||||||
 | 
					    #         "float32",
 | 
				
			||||||
 | 
					    #     ],
 | 
				
			||||||
 | 
					    #     help="Type to use for computation. \
 | 
				
			||||||
 | 
					    #                           See https://opennmt.net/CTranslate2/quantization.html.",
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        "--compute_type",
 | 
					        "--show",
 | 
				
			||||||
        type=str,
 | 
					        type=str,
 | 
				
			||||||
        default="default",
 | 
					        default=None,
 | 
				
			||||||
        choices=[
 | 
					 | 
				
			||||||
            "int8",
 | 
					 | 
				
			||||||
            "int8_float32",
 | 
					 | 
				
			||||||
            "int8_float16",
 | 
					 | 
				
			||||||
            "int8_bfloat16",
 | 
					 | 
				
			||||||
            "int16",
 | 
					 | 
				
			||||||
            "float16",
 | 
					 | 
				
			||||||
            "bfloat16",
 | 
					 | 
				
			||||||
            "float32",
 | 
					 | 
				
			||||||
        ],
 | 
					 | 
				
			||||||
        help="Type to use for computation. \
 | 
					 | 
				
			||||||
                              See https://opennmt.net/CTranslate2/quantization.html.",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        "--beam_size",
 | 
					 | 
				
			||||||
        type=int,
 | 
					 | 
				
			||||||
        default=5,
 | 
					 | 
				
			||||||
        help="model parameter, tweak to increase accuracy",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        "--no_speech_threshold",
 | 
					 | 
				
			||||||
        type=float,
 | 
					 | 
				
			||||||
        default=0.6,
 | 
					 | 
				
			||||||
        help="model parameter, tweak to increase accuracy",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        "--condition_on_previous_text",
 | 
					 | 
				
			||||||
        type=str2bool,
 | 
					 | 
				
			||||||
        default=True,
 | 
					 | 
				
			||||||
        help="model parameter, tweak to increase accuracy",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        "--task",
 | 
					 | 
				
			||||||
        type=str,
 | 
					 | 
				
			||||||
        default="transcribe",
 | 
					 | 
				
			||||||
        choices=["transcribe", "translate"],
 | 
					 | 
				
			||||||
        help="whether to perform X->X speech recognition ('transcribe') \
 | 
					        help="whether to perform X->X speech recognition ('transcribe') \
 | 
				
			||||||
                              or X->English translation ('translate')",
 | 
					                              or X->English translation ('translate')",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
				
			|||||||
@ -6,15 +6,16 @@ from utils.files import filename, write_srt
 | 
				
			|||||||
from utils.ffmpeg import get_audio, add_subtitles_to_mp4
 | 
					from utils.ffmpeg import get_audio, add_subtitles_to_mp4
 | 
				
			||||||
from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
 | 
					from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
 | 
				
			||||||
from utils.sonarr import update_show_in_sonarr
 | 
					from utils.sonarr import update_show_in_sonarr
 | 
				
			||||||
 | 
					# from utils.faster_whisper import WhisperAI
 | 
				
			||||||
from utils.whisper import WhisperAI
 | 
					from utils.whisper import WhisperAI
 | 
				
			||||||
 | 
					from utils.decorator import measure_time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def process(args: dict):
 | 
					def process(args: dict):
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    model_name: str = args.pop("model")
 | 
					    model_name: str = args.pop("model")
 | 
				
			||||||
    language: str = args.pop("language")
 | 
					    language: str = args.pop("language")
 | 
				
			||||||
    sample_interval: str = args.pop("sample_interval")
 | 
					    show: str = args.pop("show")
 | 
				
			||||||
    audio_channel: str = args.pop("audio_channel")
 | 
					    
 | 
				
			||||||
 | 
					 | 
				
			||||||
    if model_name.endswith(".en"):
 | 
					    if model_name.endswith(".en"):
 | 
				
			||||||
        warnings.warn(
 | 
					        warnings.warn(
 | 
				
			||||||
            f"{model_name} is an English-only model, forcing English detection."
 | 
					            f"{model_name} is an English-only model, forcing English detection."
 | 
				
			||||||
@ -25,26 +26,27 @@ def process(args: dict):
 | 
				
			|||||||
        args["language"] = language
 | 
					        args["language"] = language
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model_args = {}
 | 
					    model_args = {}
 | 
				
			||||||
    model_args["model_size_or_path"] = model_name
 | 
					 | 
				
			||||||
    model_args["device"] = args.pop("device")
 | 
					    model_args["device"] = args.pop("device")
 | 
				
			||||||
    model_args["compute_type"] = args.pop("compute_type")
 | 
					    
 | 
				
			||||||
 | 
					    list_of_episodes_needing_subtitles = get_wanted_episodes(show)
 | 
				
			||||||
    list_of_episodes_needing_subtitles = get_wanted_episodes()
 | 
					 | 
				
			||||||
    print(
 | 
					    print(
 | 
				
			||||||
        f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles."
 | 
					        f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles."
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    for episode in list_of_episodes_needing_subtitles["data"]:
 | 
					    for episode in list_of_episodes_needing_subtitles["data"]:
 | 
				
			||||||
        print(f"Processing {episode['seriesTitle']} - {episode['episode_number']}")
 | 
					        print(f"Processing {episode['seriesTitle']} - {episode['episode_number']}")
 | 
				
			||||||
        episode_data = get_episode_details(episode["sonarrEpisodeId"])
 | 
					        episode_data = get_episode_details(episode["sonarrEpisodeId"])
 | 
				
			||||||
        audios = get_audio([episode_data["path"]], audio_channel, sample_interval)
 | 
					        try:
 | 
				
			||||||
        subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
 | 
					            audios = get_audio([episode_data["path"]], 0, None)
 | 
				
			||||||
 | 
					            subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
 | 
				
			||||||
        add_subtitles_to_mp4(subtitles)
 | 
					 | 
				
			||||||
        update_show_in_sonarr(episode["sonarrSeriesId"])
 | 
					 | 
				
			||||||
        time.sleep(5)
 | 
					 | 
				
			||||||
        sync_series()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            add_subtitles_to_mp4(subtitles)
 | 
				
			||||||
 | 
					            update_show_in_sonarr(episode["sonarrSeriesId"])
 | 
				
			||||||
 | 
					            time.sleep(5)
 | 
				
			||||||
 | 
					            sync_series()
 | 
				
			||||||
 | 
					        except Exception as ex:
 | 
				
			||||||
 | 
					            print(f"skipping file due to - {ex}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@measure_time
 | 
				
			||||||
def get_subtitles(
 | 
					def get_subtitles(
 | 
				
			||||||
    audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
 | 
					    audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
 | 
				
			|||||||
@ -8,15 +8,19 @@ token = config._sections["bazarr"]["token"]
 | 
				
			|||||||
base_url = config._sections["bazarr"]["url"]
 | 
					base_url = config._sections["bazarr"]["url"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_wanted_episodes():
 | 
					def get_wanted_episodes(show: str=None):
 | 
				
			||||||
    url = f"{base_url}/api/episodes/wanted"
 | 
					    url = f"{base_url}/api/episodes/wanted"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    payload = {}
 | 
					    payload = {}
 | 
				
			||||||
    headers = {"accept": "application/json", "X-API-KEY": token}
 | 
					    headers = {"accept": "application/json", "X-API-KEY": token}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    response = requests.request("GET", url, headers=headers, data=payload)
 | 
					    response = requests.request("GET", url, headers=headers, data=payload)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    return response.json()
 | 
					    data = response.json()
 | 
				
			||||||
 | 
					    if show != None:
 | 
				
			||||||
 | 
					        data['data'] = [item for item in data['data'] if item['seriesTitle'] == show]
 | 
				
			||||||
 | 
					        data['total'] = len(data['data'])
 | 
				
			||||||
 | 
					    return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_episode_details(episode_id: str):
 | 
					def get_episode_details(episode_id: str):
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										13
									
								
								bazarr-ai-sub-generator/utils/decorator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								bazarr-ai-sub-generator/utils/decorator.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					import time
 | 
				
			||||||
 | 
					from datetime import timedelta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def measure_time(func):
 | 
				
			||||||
 | 
					    def wrapper(*args, **kwargs):
 | 
				
			||||||
 | 
					        start_time = time.time()
 | 
				
			||||||
 | 
					        result = func(*args, **kwargs)
 | 
				
			||||||
 | 
					        end_time = time.time()
 | 
				
			||||||
 | 
					        duration = end_time - start_time
 | 
				
			||||||
 | 
					        human_readable_duration = str(timedelta(seconds=duration))
 | 
				
			||||||
 | 
					        print(f"Function '{func.__name__}' executed in: {human_readable_duration}")
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					    return wrapper
 | 
				
			||||||
							
								
								
									
										68
									
								
								bazarr-ai-sub-generator/utils/faster_whisper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								bazarr-ai-sub-generator/utils/faster_whisper.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,68 @@
 | 
				
			|||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					import faster_whisper
 | 
				
			||||||
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# pylint: disable=R0903
 | 
				
			||||||
 | 
					class WhisperAI:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Wrapper class for the Whisper speech recognition model with additional functionality.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This class provides a high-level interface for transcribing audio files using the Whisper
 | 
				
			||||||
 | 
					    speech recognition model. It encapsulates the model instantiation and transcription process,
 | 
				
			||||||
 | 
					    allowing users to easily transcribe audio files and iterate over the resulting segments.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Usage:
 | 
				
			||||||
 | 
					    ```python
 | 
				
			||||||
 | 
					    whisper = WhisperAI(model_args, transcribe_args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Transcribe an audio file and iterate over the segments
 | 
				
			||||||
 | 
					    for segment in whisper.transcribe(audio_path):
 | 
				
			||||||
 | 
					        # Process each transcription segment
 | 
				
			||||||
 | 
					        print(segment)
 | 
				
			||||||
 | 
					    ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					    - model_args: Arguments to pass to WhisperModel initialize method
 | 
				
			||||||
 | 
					        - model_size_or_path (str): The name of the Whisper model to use.
 | 
				
			||||||
 | 
					        - device (str): The device to use for computation ("cpu", "cuda", "auto").
 | 
				
			||||||
 | 
					        - compute_type (str): The type to use for computation.
 | 
				
			||||||
 | 
					            See https://opennmt.net/CTranslate2/quantization.html.
 | 
				
			||||||
 | 
					    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Attributes:
 | 
				
			||||||
 | 
					    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
 | 
				
			||||||
 | 
					    - transcribe_args (dict): Additional arguments used for transcribe method.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Methods:
 | 
				
			||||||
 | 
					    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, model_args: dict, transcribe_args: dict):
 | 
				
			||||||
 | 
					        # self.model = faster_whisper.WhisperModel(**model_args)
 | 
				
			||||||
 | 
					        model_size = "base"
 | 
				
			||||||
 | 
					        self.model = faster_whisper.WhisperModel(model_size, device="cuda")
 | 
				
			||||||
 | 
					        self.transcribe_args = transcribe_args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def transcribe(self, audio_path: str):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Transcribes the specified audio file and yields the resulting segments.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					        - audio_path (str): The path to the audio file for transcription.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Yields:
 | 
				
			||||||
 | 
					        - faster_whisper.TranscriptionSegment: An individual transcription segment.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        warnings.filterwarnings("ignore")
 | 
				
			||||||
 | 
					        segments, info = self.model.transcribe(audio_path, beam_size=5)
 | 
				
			||||||
 | 
					        warnings.filterwarnings("default")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Same precision as the Whisper timestamps.
 | 
				
			||||||
 | 
					        total_duration = round(info.duration, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with tqdm(total=total_duration, unit=" seconds") as pbar:
 | 
				
			||||||
 | 
					            for segment in segments:
 | 
				
			||||||
 | 
					                yield segment
 | 
				
			||||||
 | 
					                pbar.update(segment.end - segment.start)
 | 
				
			||||||
 | 
					            pbar.update(0)
 | 
				
			||||||
@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
 | 
				
			|||||||
    for i, segment in enumerate(transcript, start=1):
 | 
					    for i, segment in enumerate(transcript, start=1):
 | 
				
			||||||
        print(
 | 
					        print(
 | 
				
			||||||
            f"{i}\n"
 | 
					            f"{i}\n"
 | 
				
			||||||
            f"{format_timestamp(segment.start, always_include_hours=True)} --> "
 | 
					            f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
 | 
				
			||||||
            f"{format_timestamp(segment.end, always_include_hours=True)}\n"
 | 
					            f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
 | 
				
			||||||
            f"{segment.text.strip().replace('-->', '->')}\n",
 | 
					            f"{segment['text'].strip().replace('-->', '->')}\n",
 | 
				
			||||||
            file=file,
 | 
					            file=file,
 | 
				
			||||||
            flush=True,
 | 
					            flush=True,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -1,9 +1,9 @@
 | 
				
			|||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
import faster_whisper
 | 
					import torch
 | 
				
			||||||
 | 
					import whisper
 | 
				
			||||||
from tqdm import tqdm
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# pylint: disable=R0903
 | 
					 | 
				
			||||||
class WhisperAI:
 | 
					class WhisperAI:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Wrapper class for the Whisper speech recognition model with additional functionality.
 | 
					    Wrapper class for the Whisper speech recognition model with additional functionality.
 | 
				
			||||||
@ -23,23 +23,35 @@ class WhisperAI:
 | 
				
			|||||||
    ```
 | 
					    ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
    - model_args: Arguments to pass to WhisperModel initialize method
 | 
					    - model_args (dict): Arguments to pass to Whisper model initialization
 | 
				
			||||||
        - model_size_or_path (str): The name of the Whisper model to use.
 | 
					        - model_size (str): The name of the Whisper model to use.
 | 
				
			||||||
        - device (str): The device to use for computation ("cpu", "cuda", "auto").
 | 
					        - device (str): The device to use for computation ("cpu" or "cuda").
 | 
				
			||||||
        - compute_type (str): The type to use for computation.
 | 
					 | 
				
			||||||
            See https://opennmt.net/CTranslate2/quantization.html.
 | 
					 | 
				
			||||||
    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 | 
					    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Attributes:
 | 
					    Attributes:
 | 
				
			||||||
    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
 | 
					    - model (whisper.Whisper): The underlying Whisper speech recognition model.
 | 
				
			||||||
 | 
					    - device (torch.device): The device to use for computation.
 | 
				
			||||||
    - transcribe_args (dict): Additional arguments used for transcribe method.
 | 
					    - transcribe_args (dict): Additional arguments used for transcribe method.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Methods:
 | 
					    Methods:
 | 
				
			||||||
    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
 | 
					    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, model_args: dict, transcribe_args: dict):
 | 
					    def __init__(self, model_args: dict, transcribe_args: dict):
 | 
				
			||||||
        self.model = faster_whisper.WhisperModel(**model_args)
 | 
					        """
 | 
				
			||||||
 | 
					        Initializes the WhisperAI instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					        - model_args (dict): Arguments to initialize the Whisper model.
 | 
				
			||||||
 | 
					        - transcribe_args (dict): Additional arguments for the transcribe method.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
				
			||||||
 | 
					        print(device)
 | 
				
			||||||
 | 
					        # Set device for computation
 | 
				
			||||||
 | 
					        self.device = torch.device(device)
 | 
				
			||||||
 | 
					        # Load the Whisper model with the specified size
 | 
				
			||||||
 | 
					        self.model = whisper.load_model("base.en").to(self.device)
 | 
				
			||||||
 | 
					        # Store the additional transcription arguments
 | 
				
			||||||
        self.transcribe_args = transcribe_args
 | 
					        self.transcribe_args = transcribe_args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def transcribe(self, audio_path: str):
 | 
					    def transcribe(self, audio_path: str):
 | 
				
			||||||
@ -50,17 +62,24 @@ class WhisperAI:
 | 
				
			|||||||
        - audio_path (str): The path to the audio file for transcription.
 | 
					        - audio_path (str): The path to the audio file for transcription.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Yields:
 | 
					        Yields:
 | 
				
			||||||
        - faster_whisper.TranscriptionSegment: An individual transcription segment.
 | 
					        - dict: An individual transcription segment.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        # Suppress warnings during transcription
 | 
				
			||||||
        warnings.filterwarnings("ignore")
 | 
					        warnings.filterwarnings("ignore")
 | 
				
			||||||
        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
 | 
					        # Load and transcribe the audio file
 | 
				
			||||||
 | 
					        result = self.model.transcribe(audio_path, **self.transcribe_args)
 | 
				
			||||||
 | 
					        # Restore default warning behavior
 | 
				
			||||||
        warnings.filterwarnings("default")
 | 
					        warnings.filterwarnings("default")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Same precision as the Whisper timestamps.
 | 
					        # Calculate the total duration from the segments
 | 
				
			||||||
        total_duration = round(info.duration, 2)
 | 
					        total_duration = max(segment["end"] for segment in result["segments"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Create a progress bar with the total duration of the audio file
 | 
				
			||||||
        with tqdm(total=total_duration, unit=" seconds") as pbar:
 | 
					        with tqdm(total=total_duration, unit=" seconds") as pbar:
 | 
				
			||||||
            for segment in segments:
 | 
					            for segment in result["segments"]:
 | 
				
			||||||
 | 
					                # Yield each transcription segment
 | 
				
			||||||
                yield segment
 | 
					                yield segment
 | 
				
			||||||
                pbar.update(segment.end - segment.start)
 | 
					                # Update the progress bar with the duration of the current segment
 | 
				
			||||||
 | 
					                pbar.update(segment["end"] - segment["start"])
 | 
				
			||||||
 | 
					            # Ensure the progress bar reaches 100% upon completion
 | 
				
			||||||
            pbar.update(0)
 | 
					            pbar.update(0)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,3 +1,9 @@
 | 
				
			|||||||
faster-whisper==0.10.0
 | 
					 | 
				
			||||||
tqdm==4.56.0
 | 
					tqdm==4.56.0
 | 
				
			||||||
ffmpeg-python==0.2.0
 | 
					ffmpeg-python==0.2.0
 | 
				
			||||||
 | 
					git+https://github.com/openai/whisper.git
 | 
				
			||||||
 | 
					faster-whisper
 | 
				
			||||||
 | 
					nvidia-cublas-cu12
 | 
				
			||||||
 | 
					nvidia-cudnn-cu12
 | 
				
			||||||
 | 
					nvidia-cublas-cu11
 | 
				
			||||||
 | 
					nvidia-cudnn-cu11
 | 
				
			||||||
 | 
					ctranslate2==3.24.0
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user