diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
new file mode 100644
index 0000000..1c93443
--- /dev/null
+++ b/.github/workflows/pylint.yml
@@ -0,0 +1,24 @@
+name: Pylint
+
+on: [push]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9"]
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v3
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint
+        pip install -r requirements.txt
+    - name: Analysing the code with pylint
+      run: |
+        pylint --disable=C0114 --disable=C0115 --disable=C0116 $(git ls-files '*.py')
diff --git a/auto_subtitle/cli.py b/auto_subtitle/cli.py
index 25c5d97..6e030f5 100644
--- a/auto_subtitle/cli.py
+++ b/auto_subtitle/cli.py
@@ -1,9 +1,17 @@
 import argparse
 from faster_whisper import available_models
+from .utils.constants import LANGUAGE_CODES
 from .main import process
 from .utils.convert import str2bool, str2timeinterval
 
+
 def main():
+    """
+    Main entry point for the script.
+
+    Parses command line arguments, processes the inputs using the specified options,
+    and performs transcription or translation based on the specified task.
+    """
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("video", nargs="+", type=str,
@@ -11,15 +19,18 @@ def main():
     parser.add_argument("--audio_channel", default="0", type=int,
                         help="audio channel index to use")
     parser.add_argument("--sample_interval", type=str2timeinterval, default=None,
-                        help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
+                        help="generate subtitles for a specific \
+                            fragment of the video (e.g. 01:02:05-01:03:45)")
     parser.add_argument("--model", default="small",
                         choices=available_models(), help="name of the Whisper model to use")
-    parser.add_argument("--device", type=str, default="auto", choices=[
-        "cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
+    parser.add_argument("--device", type=str, default="auto",
+                        choices=["cpu", "cuda", "auto"],
+                        help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
     parser.add_argument("--compute_type", type=str, default="default", choices=[
-        "int8", "int8_float32", "int8_float16",
-        "int8_bfloat16", "int16", "float16",
-        "bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
+        "int8", "int8_float32", "int8_float16", "int8_bfloat16",
+        "int16", "float16", "bfloat16", "float32"],
+                        help="Type to use for computation. \
+                            See https://opennmt.net/CTranslate2/quantization.html.")
     parser.add_argument("--output_dir", "-o", type=str, default=".",
                         help="directory to save the outputs")
     parser.add_argument("--output_srt", type=str2bool, default=False,
@@ -32,10 +43,14 @@ def main():
                         help="model parameter, tweak to increase accuracy")
     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
                         help="model parameter, tweak to increase accuracy")
-    parser.add_argument("--task", type=str, default="transcribe", choices=[
-        "transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default="auto", choices=["auto","af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca","cs","cy","da","de","el","en","es","et","eu","fa","fi","fo","fr","gl","gu","ha","haw","he","hi","hr","ht","hu","hy","id","is","it","ja","jw","ka","kk","km","kn","ko","la","lb","ln","lo","lt","lv","mg","mi","mk","ml","mn","mr","ms","mt","my","ne","nl","nn","no","oc","pa","pl","ps","pt","ro","ru","sa","sd","si","sk","sl","sn","so","sq","sr","su","sv","sw","ta","te","tg","th","tk","tl","tr","tt","uk","ur","uz","vi","yi","yo","zh"],
-                        help="What is the origin language of the video? If unset, it is detected automatically.")
+    parser.add_argument("--task", type=str, default="transcribe",
+                        choices=["transcribe", "translate"],
+                        help="whether to perform X->X speech recognition ('transcribe') \
+                            or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default="auto",
+                        choices=LANGUAGE_CODES,
+                        help="What is the origin language of the video? \
+                            If unset, it is detected automatically.")
 
     args = parser.parse_args().__dict__
 
diff --git a/auto_subtitle/main.py b/auto_subtitle/main.py
index 26650f9..cad112f 100644
--- a/auto_subtitle/main.py
+++ b/auto_subtitle/main.py
@@ -5,6 +5,7 @@ from .utils.files import filename, write_srt
 from .utils.ffmpeg import get_audio, overlay_subtitles
 from .utils.whisper import WhisperAI
 
+
 def process(args: dict):
     model_name: str = args.pop("model")
     output_dir: str = args.pop("output_dir")
@@ -12,9 +13,7 @@ def process(args: dict):
     srt_only: bool = args.pop("srt_only")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
-    device: str = args.pop("device")
-    compute_type: str = args.pop("compute_type")
-    
+
     os.makedirs(output_dir, exist_ok=True)
 
     if model_name.endswith(".en"):
@@ -25,18 +24,26 @@ def process(args: dict):
     elif language != "auto":
         args["language"] = language
 
-    audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
-    subtitles = get_subtitles(
-        audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
-    )
+    audios = get_audio(args.pop("video"), args.pop(
+        'audio_channel'), sample_interval)
+
+    model_args = {}
+    model_args["model_size_or_path"] = model_name
+    model_args["device"] = args.pop("device")
+    model_args["compute_type"] = args.pop("compute_type")
+
+    srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir()
+    subtitles = get_subtitles(audios, srt_output_dir, model_args, args)
 
     if srt_only:
         return
 
     overlay_subtitles(subtitles, output_dir, sample_interval)
 
-def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
-    model = WhisperAI(model_name, device, compute_type, model_args)
+
+def get_subtitles(audio_paths: list, output_dir: str,
+                  model_args: dict, transcribe_args: dict):
+    model = WhisperAI(model_args, transcribe_args)
 
     subtitles_path = {}
 
@@ -44,9 +51,8 @@ def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_na
         print(
             f"Generating subtitles for {filename(path)}... This might take a while."
         )
-        srt_path = output_dir if output_srt else tempfile.gettempdir()
-        srt_path = os.path.join(srt_path, f"{filename(path)}.srt")
-        
+        srt_path = os.path.join(output_dir, f"{filename(path)}.srt")
+
         segments = model.transcribe(audio_path)
 
         with open(srt_path, "w", encoding="utf-8") as srt:
@@ -54,4 +60,4 @@ def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_na
 
         subtitles_path[path] = srt_path
 
-    return subtitles_path
\ No newline at end of file
+    return subtitles_path
diff --git a/auto_subtitle/utils/constants.py b/auto_subtitle/utils/constants.py
new file mode 100644
index 0000000..993556f
--- /dev/null
+++ b/auto_subtitle/utils/constants.py
@@ -0,0 +1,105 @@
+"""
+List of available language codes
+"""
+LANGUAGE_CODES = [
+    "af",
+    "am",
+    "ar",
+    "as",
+    "az",
+    "ba",
+    "be",
+    "bg",
+    "bn",
+    "bo",
+    "br",
+    "bs",
+    "ca",
+    "cs",
+    "cy",
+    "da",
+    "de",
+    "el",
+    "en",
+    "es",
+    "et",
+    "eu",
+    "fa",
+    "fi",
+    "fo",
+    "fr",
+    "gl",
+    "gu",
+    "ha",
+    "haw",
+    "he",
+    "hi",
+    "hr",
+    "ht",
+    "hu",
+    "hy",
+    "id",
+    "is",
+    "it",
+    "ja",
+    "jw",
+    "ka",
+    "kk",
+    "km",
+    "kn",
+    "ko",
+    "la",
+    "lb",
+    "ln",
+    "lo",
+    "lt",
+    "lv",
+    "mg",
+    "mi",
+    "mk",
+    "ml",
+    "mn",
+    "mr",
+    "ms",
+    "mt",
+    "my",
+    "ne",
+    "nl",
+    "nn",
+    "no",
+    "oc",
+    "pa",
+    "pl",
+    "ps",
+    "pt",
+    "ro",
+    "ru",
+    "sa",
+    "sd",
+    "si",
+    "sk",
+    "sl",
+    "sn",
+    "so",
+    "sq",
+    "sr",
+    "su",
+    "sv",
+    "sw",
+    "ta",
+    "te",
+    "tg",
+    "th",
+    "tk",
+    "tl",
+    "tr",
+    "tt",
+    "uk",
+    "ur",
+    "uz",
+    "vi",
+    "yi",
+    "yo",
+    "zh",
+    "yue",
+]
diff --git a/auto_subtitle/utils/convert.py b/auto_subtitle/utils/convert.py
index 7d14df5..df05529 100644
--- a/auto_subtitle/utils/convert.py
+++ b/auto_subtitle/utils/convert.py
@@ -1,23 +1,25 @@
 from datetime import datetime, timedelta
 
-def str2bool(string):
+
+def str2bool(string: str):
     string = string.lower()
     str2val = {"true": True, "false": False}
     if string in str2val:
         return str2val[string]
-    else:
-        raise ValueError(
-            f"Expected one of {set(str2val.keys())}, got {string}")
 
-def str2timeinterval(string):
+    raise ValueError(
+        f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+def str2timeinterval(string: str):
     if string is None:
         return None
-    
+
     if '-' not in string:
         raise ValueError(
             f"Expected time interval HH:mm:ss-HH:mm:ss or HH:mm-HH:mm or ss-ss, got {string}")
-    
+
     intervals = string.split('-')
     if len(intervals) != 2:
         raise ValueError(
@@ -28,42 +30,47 @@ def str2timeinterval(string):
     if start >= end:
         raise ValueError(
             f"Expected time interval end to be higher than start, got {start} >= {end}")
-    
+
     return [start, end]
 
-def time_to_timestamp(string):
+
+def time_to_timestamp(string: str):
     split_time = string.split(':')
-    if len(split_time) == 0 or len(split_time) > 3 or not all([ x.isdigit() for x in split_time ]):
+    if len(split_time) == 0 or len(split_time) > 3 or not all(x.isdigit() for x in split_time):
         raise ValueError(
             f"Expected HH:mm:ss or HH:mm or ss, got {string}")
-    
+
     if len(split_time) == 1:
         return int(split_time[0])
-    
+
     if len(split_time) == 2:
         return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60
-    
+
     return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60 + int(split_time[2])
 
-def try_parse_timestamp(string):
+
+def try_parse_timestamp(string: str):
     timestamp = parse_timestamp(string, '%H:%M:%S')
     if timestamp is not None:
         return timestamp
-    
+
     timestamp = parse_timestamp(string, '%H:%M')
     if timestamp is not None:
         return timestamp
-    
+
     return parse_timestamp(string, '%S')
 
-def parse_timestamp(string, pattern):
+
+def parse_timestamp(string: str, pattern: str):
     try:
         date = datetime.strptime(string, pattern)
-        delta = timedelta(hours=date.hour, minutes=date.minute, seconds=date.second)
+        delta = timedelta(
+            hours=date.hour, minutes=date.minute, seconds=date.second)
         return int(delta.total_seconds())
-    except:
+    except:  # pylint: disable=bare-except
         return None
 
+
 def format_timestamp(seconds: float, always_include_hours: bool = False):
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
@@ -79,4 +86,3 @@ def format_timestamp(seconds: float, always_include_hours: bool = False):
 
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
     return f"{hours_marker}{minutes:02d}:{seconds:02d},{milliseconds:03d}"
-
diff --git a/auto_subtitle/utils/ffmpeg.py b/auto_subtitle/utils/ffmpeg.py
index 0ea7f43..9f6fdd4 100644
--- a/auto_subtitle/utils/ffmpeg.py
+++ b/auto_subtitle/utils/ffmpeg.py
@@ -1,9 +1,10 @@
 import os
-import ffmpeg
 import tempfile
+import ffmpeg
 from .mytempfile import MyTempFile
 from .files import filename
 
+
 def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
     temp_dir = tempfile.gettempdir()
 
@@ -13,18 +14,19 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
         print(f"Extracting audio from {filename(path)}...")
         output_path = os.path.join(temp_dir, f"{filename(path)}.wav")
 
-        ffmpeg_input_args = dict()
+        ffmpeg_input_args = {}
         if sample_interval is not None:
             ffmpeg_input_args['ss'] = str(sample_interval[0])
 
-        ffmpeg_output_args = dict()
+        ffmpeg_output_args = {}
         ffmpeg_output_args['acodec'] = "pcm_s16le"
         ffmpeg_output_args['ac'] = "1"
         ffmpeg_output_args['ar'] = "16k"
         ffmpeg_output_args['map'] = "0:a:" + str(audio_channel_index)
         if sample_interval is not None:
-            ffmpeg_output_args['t'] = str(sample_interval[1] - sample_interval[0])
-        
+            ffmpeg_output_args['t'] = str(
+                sample_interval[1] - sample_interval[0])
+
         ffmpeg.input(path, **ffmpeg_input_args).output(
             output_path,
             **ffmpeg_output_args
@@ -34,9 +36,6 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
 
     return audio_paths
 
-def escape_windows_path(path: str):
-    return path.replace("\\", "/").replace(":", ":").replace(" ", "\\ ").replace("(", "\\(").replace(")", "\\)").replace("[", "\\[").replace("]", "\\]").replace("'", "'\\''")
-
 
 def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list):
     for path, srt_path in subtitles.items():
@@ -44,22 +43,26 @@ def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list):
 
         print(f"Adding subtitles to {filename(path)}...")
 
-        ffmpeg_input_args = dict()
+        ffmpeg_input_args = {}
         if sample_interval is not None:
             ffmpeg_input_args['ss'] = str(sample_interval[0])
 
-        ffmpeg_output_args = dict()
+        ffmpeg_output_args = {}
         if sample_interval is not None:
-            ffmpeg_output_args['t'] = str(sample_interval[1] - sample_interval[0])
+            ffmpeg_output_args['t'] = str(
+                sample_interval[1] - sample_interval[0])
 
-        # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg, so we use temp copy instead
+        # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg
+        # so we use temp copy instead
         # see: https://github.com/kkroening/ffmpeg-python/issues/745
         with MyTempFile(srt_path) as srt_temp:
             video = ffmpeg.input(path, **ffmpeg_input_args)
             audio = video.audio
 
             ffmpeg.concat(
-                video.filter('subtitles', srt_temp.tmp_file_path, force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
+                video.filter(
+                    'subtitles', srt_temp.tmp_file_path,
+                    force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
             ).output(out_path, **ffmpeg_output_args).run(quiet=True, overwrite_output=True)
 
-        print(f"Saved subtitled video to {os.path.abspath(out_path)}.")
\ No newline at end of file
+        print(f"Saved subtitled video to {os.path.abspath(out_path)}.")
diff --git a/auto_subtitle/utils/files.py b/auto_subtitle/utils/files.py
index 5caaead..8a9476b 100644
--- a/auto_subtitle/utils/files.py
+++ b/auto_subtitle/utils/files.py
@@ -13,5 +13,5 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
             flush=True,
         )
 
-def filename(path):
+def filename(path: str):
     return os.path.splitext(os.path.basename(path))[0]
diff --git a/auto_subtitle/utils/mytempfile.py b/auto_subtitle/utils/mytempfile.py
index e1dc0cf..372c74d 100644
--- a/auto_subtitle/utils/mytempfile.py
+++ b/auto_subtitle/utils/mytempfile.py
@@ -3,8 +3,25 @@ import os
 import shutil
 
 class MyTempFile:
+    """
+    A context manager for creating a temporary file in current directory, copying the content from
+    a specified file, and handling cleanup operations upon exiting the context.
+
+    Usage:
+    ```python
+    with MyTempFile(file_path) as temp_file_manager:
+        # Access the temporary file using temp_file_manager.tmp_file
+        # ...
+    # The temporary file is automatically closed and removed upon exiting the context.
+    ```
+
+    Args:
+    - file_path (str): The path to the file whose content will be copied to the temporary file.
+    """
     def __init__(self, file_path):
         self.file_path = file_path
+        self.tmp_file = None
+        self.tmp_file_path = None
 
     def __enter__(self):
         self.tmp_file = tempfile.NamedTemporaryFile('w', dir='.', delete=False)
diff --git a/auto_subtitle/utils/whisper.py b/auto_subtitle/utils/whisper.py
index 3bad24c..9d21972 100644
--- a/auto_subtitle/utils/whisper.py
+++ b/auto_subtitle/utils/whisper.py
@@ -2,20 +2,64 @@ import warnings
 import faster_whisper
 from tqdm import tqdm
 
+
+# pylint: disable=R0903
 class WhisperAI:
-    def __init__(self, model_name, device, compute_type, model_args):
-        self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
-        self.model_args = model_args
+    """
+    Wrapper class for the Whisper speech recognition model with additional functionality.
 
-    def transcribe(self, audio_path):
+    This class provides a high-level interface for transcribing audio files using the Whisper
+    speech recognition model. It encapsulates the model instantiation and transcription process,
+    allowing users to easily transcribe audio files and iterate over the resulting segments.
+
+    Usage:
+    ```python
+    whisper = WhisperAI(model_args, transcribe_args)
+
+    # Transcribe an audio file and iterate over the segments
+    for segment in whisper.transcribe(audio_path):
+        # Process each transcription segment
+        print(segment)
+    ```
+
+    Args:
+    - model_args: Arguments to pass to WhisperModel initialize method
+        - model_size_or_path (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu", "cuda", "auto").
+        - compute_type (str): The type to use for computation.
+          See https://opennmt.net/CTranslate2/quantization.html.
+    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
+
+    Attributes:
+    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - transcribe_args (dict): Additional arguments used for transcribe method.
+
+    Methods:
+    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    """
+
+    def __init__(self, model_args: dict, transcribe_args: dict):
+        self.model = faster_whisper.WhisperModel(**model_args)
+        self.transcribe_args = transcribe_args
+
+    def transcribe(self, audio_path: str):
+        """
+        Transcribes the specified audio file and yields the resulting segments.
+
+        Args:
+        - audio_path (str): The path to the audio file for transcription.
+
+        Yields:
+        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        """
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.model_args)
+        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
         warnings.filterwarnings("default")
 
-        total_duration = round(info.duration, 2)  # Same precision as the Whisper timestamps.
+        # Same precision as the Whisper timestamps.
+        total_duration = round(info.duration, 2)
 
         with tqdm(total=total_duration, unit=" seconds") as pbar:
             for segment in segments:
                 yield segment
                 pbar.update(segment.end - segment.start)
-            pbar.update(0)
\ No newline at end of file
+            pbar.update(0)
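
For quick orientation, below is a minimal usage sketch of how the pieces fit together after this patch. The module paths and the `get_audio`, `WhisperAI`, `write_srt`, and `filename` signatures come from the files changed above; the video name and option values are hypothetical placeholders, and this is only a sketch, not part of the patch itself.

```python
from auto_subtitle.utils.ffmpeg import get_audio
from auto_subtitle.utils.files import filename, write_srt
from auto_subtitle.utils.whisper import WhisperAI

# Hypothetical input video; extract audio channel 0, no sample interval.
audios = get_audio(["episode01.mp4"], 0, None)

# Model construction options now travel together in one dict and are
# unpacked straight into faster_whisper.WhisperModel(**model_args).
model_args = {
    "model_size_or_path": "small",
    "device": "auto",
    "compute_type": "default",
}
# Everything else is forwarded to the model's transcribe() call.
transcribe_args = {"task": "transcribe"}

model = WhisperAI(model_args, transcribe_args)

for video_path, audio_path in audios.items():
    # model.transcribe() yields segments; write_srt() renders them as SRT.
    with open(f"{filename(video_path)}.srt", "w", encoding="utf-8") as srt:
        write_srt(model.transcribe(audio_path), srt)
```

This mirrors what `process()` and `get_subtitles()` in `auto_subtitle/main.py` do after the refactor, minus the output-directory handling and the subtitle overlay step.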