Browse Source

Add GitHub Workflow with Pylint analyzer

Sergey Chernyaev 4 months ago
parent
commit
1c0cdb6eba

+ 24 - 0
.github/workflows/pylint.yml

@@ -0,0 +1,24 @@
+name: Pylint
+
+on: [push]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9"]
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v3
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint
+        pip install -r requirements.txt
+    - name: Analysing the code with pylint
+      run: |
+        pylint --disable=C0114 --disable=C0115 --disable=C0116 $(git ls-files '*.py')

+ 25 - 10
auto_subtitle/cli.py

@@ -1,9 +1,17 @@
 import argparse
 from faster_whisper import available_models
+from .utils.constants import LANGUAGE_CODES
 from .main import process
 from .utils.convert import str2bool, str2timeinterval
 
+
 def main():
+    """
+    Main entry point for the script.
+
+    Parses command line arguments, processes the inputs using the specified options,
+    and performs transcription or translation based on the specified task.
+    """
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("video", nargs="+", type=str,
@@ -11,15 +19,18 @@ def main():
     parser.add_argument("--audio_channel", default="0",
                         type=int, help="audio channel index to use")
     parser.add_argument("--sample_interval", type=str2timeinterval, default=None,
-                        help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
+                        help="generate subtitles for a specific \
+                              fragment of the video (e.g. 01:02:05-01:03:45)")
     parser.add_argument("--model", default="small",
                         choices=available_models(), help="name of the Whisper model to use")
-    parser.add_argument("--device", type=str, default="auto", choices=[
-                        "cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
+    parser.add_argument("--device", type=str, default="auto",
+                        choices=["cpu", "cuda", "auto"],
+                        help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
     parser.add_argument("--compute_type", type=str, default="default", choices=[
-                        "int8", "int8_float32", "int8_float16",
-                        "int8_bfloat16", "int16", "float16",
-                        "bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
+                        "int8", "int8_float32", "int8_float16", "int8_bfloat16",
+                        "int16", "float16", "bfloat16", "float32"],
+                        help="Type to use for computation. \
+                              See https://opennmt.net/CTranslate2/quantization.html.")
     parser.add_argument("--output_dir", "-o", type=str,
                         default=".", help="directory to save the outputs")
     parser.add_argument("--output_srt", type=str2bool, default=False,
@@ -32,10 +43,14 @@ def main():
                         help="model parameter, tweak to increase accuracy")
     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
                         help="model parameter, tweak to increase accuracy")
-    parser.add_argument("--task", type=str, default="transcribe", choices=[
-                        "transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default="auto", choices=["auto","af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca","cs","cy","da","de","el","en","es","et","eu","fa","fi","fo","fr","gl","gu","ha","haw","he","hi","hr","ht","hu","hy","id","is","it","ja","jw","ka","kk","km","kn","ko","la","lb","ln","lo","lt","lv","mg","mi","mk","ml","mn","mr","ms","mt","my","ne","nl","nn","no","oc","pa","pl","ps","pt","ro","ru","sa","sd","si","sk","sl","sn","so","sq","sr","su","sv","sw","ta","te","tg","th","tk","tl","tr","tt","uk","ur","uz","vi","yi","yo","zh"], 
-    help="What is the origin language of the video? If unset, it is detected automatically.")
+    parser.add_argument("--task", type=str, default="transcribe",
+                        choices=["transcribe", "translate"],
+                        help="whether to perform X->X speech recognition ('transcribe') \
+                              or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default="auto",
+                        choices=LANGUAGE_CODES,
+                        help="What is the origin language of the video? \
+                              If unset, it is detected automatically.")
 
     args = parser.parse_args().__dict__
 

+ 19 - 13
auto_subtitle/main.py

@@ -5,6 +5,7 @@ from .utils.files import filename, write_srt
 from .utils.ffmpeg import get_audio, overlay_subtitles
 from .utils.whisper import WhisperAI
 
+
 def process(args: dict):
     model_name: str = args.pop("model")
     output_dir: str = args.pop("output_dir")
@@ -12,9 +13,7 @@ def process(args: dict):
     srt_only: bool = args.pop("srt_only")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
-    device: str = args.pop("device")
-    compute_type: str = args.pop("compute_type")
-    
+
     os.makedirs(output_dir, exist_ok=True)
 
     if model_name.endswith(".en"):
@@ -25,18 +24,26 @@ def process(args: dict):
     elif language != "auto":
         args["language"] = language
 
-    audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
-    subtitles = get_subtitles(
-        audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
-    )
+    audios = get_audio(args.pop("video"), args.pop(
+        'audio_channel'), sample_interval)
+
+    model_args = {}
+    model_args["model_size_or_path"] = model_name
+    model_args["device"] = args.pop("device")
+    model_args["compute_type"] = args.pop("compute_type")
+
+    srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir()
+    subtitles = get_subtitles(audios, srt_output_dir, model_args, args)
 
     if srt_only:
         return
 
     overlay_subtitles(subtitles, output_dir, sample_interval)
 
-def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
-    model = WhisperAI(model_name, device, compute_type, model_args)
+
+def get_subtitles(audio_paths: list, output_dir: str,
+                  model_args: dict, transcribe_args: dict):
+    model = WhisperAI(model_args, transcribe_args)
 
     subtitles_path = {}
 
@@ -44,9 +51,8 @@ def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_na
         print(
             f"Generating subtitles for {filename(path)}... This might take a while."
         )
-        srt_path = output_dir if output_srt else tempfile.gettempdir()
-        srt_path = os.path.join(srt_path, f"{filename(path)}.srt")
-        
+        srt_path = os.path.join(output_dir, f"{filename(path)}.srt")
+
         segments = model.transcribe(audio_path)
 
         with open(srt_path, "w", encoding="utf-8") as srt:
@@ -54,4 +60,4 @@ def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_na
 
         subtitles_path[path] = srt_path
 
-    return subtitles_path
+    return subtitles_path

+ 105 - 0
auto_subtitle/utils/constants.py

@@ -0,0 +1,105 @@
+"""
+List of available language codes
+"""
+LANGUAGE_CODES = [
+    "af",
+    "am",
+    "ar",
+    "as",
+    "az",
+    "ba",
+    "be",
+    "bg",
+    "bn",
+    "bo",
+    "br",
+    "bs",
+    "ca",
+    "cs",
+    "cy",
+    "da",
+    "de",
+    "el",
+    "en",
+    "es",
+    "et",
+    "eu",
+    "fa",
+    "fi",
+    "fo",
+    "fr",
+    "gl",
+    "gu",
+    "ha",
+    "haw",
+    "he",
+    "hi",
+    "hr",
+    "ht",
+    "hu",
+    "hy",
+    "id",
+    "is",
+    "it",
+    "ja",
+    "jw",
+    "ka",
+    "kk",
+    "km",
+    "kn",
+    "ko",
+    "la",
+    "lb",
+    "ln",
+    "lo",
+    "lt",
+    "lv",
+    "mg",
+    "mi",
+    "mk",
+    "ml",
+    "mn",
+    "mr",
+    "ms",
+    "mt",
+    "my",
+    "ne",
+    "nl",
+    "nn",
+    "no",
+    "oc",
+    "pa",
+    "pl",
+    "ps",
+    "pt",
+    "ro",
+    "ru",
+    "sa",
+    "sd",
+    "si",
+    "sk",
+    "sl",
+    "sn",
+    "so",
+    "sq",
+    "sr",
+    "su",
+    "sv",
+    "sw",
+    "ta",
+    "te",
+    "tg",
+    "th",
+    "tk",
+    "tl",
+    "tr",
+    "tt",
+    "uk",
+    "ur",
+    "uz",
+    "vi",
+    "yi",
+    "yo",
+    "zh",
+    "yue",
+]

+ 26 - 20
auto_subtitle/utils/convert.py

@@ -1,23 +1,25 @@
 from datetime import datetime, timedelta
 
-def str2bool(string):
+
+def str2bool(string: str):
     string = string.lower()
     str2val = {"true": True, "false": False}
 
     if string in str2val:
         return str2val[string]
-    else:
-        raise ValueError(
-            f"Expected one of {set(str2val.keys())}, got {string}")
 
-def str2timeinterval(string):
+    raise ValueError(
+        f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+def str2timeinterval(string: str):
     if string is None:
         return None
-    
+
     if '-' not in string:
         raise ValueError(
             f"Expected time interval HH:mm:ss-HH:mm:ss or HH:mm-HH:mm or ss-ss, got {string}")
-    
+
     intervals = string.split('-')
     if len(intervals) != 2:
         raise ValueError(
@@ -28,42 +30,47 @@ def str2timeinterval(string):
     if start >= end:
         raise ValueError(
             f"Expected time interval end to be higher than start, got {start} >= {end}")
-    
+
     return [start, end]
 
-def time_to_timestamp(string):
+
+def time_to_timestamp(string: str):
     split_time = string.split(':')
-    if len(split_time) == 0 or len(split_time) > 3 or not all([ x.isdigit() for x in split_time ]):
+    if len(split_time) == 0 or len(split_time) > 3 or not all(x.isdigit() for x in split_time):
         raise ValueError(
             f"Expected HH:mm:ss or HH:mm or ss, got {string}")
-    
+
     if len(split_time) == 1:
         return int(split_time[0])
-    
+
     if len(split_time) == 2:
         return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60
-    
+
     return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60 + int(split_time[2])
 
-def try_parse_timestamp(string):
+
+def try_parse_timestamp(string: str):
     timestamp = parse_timestamp(string, '%H:%M:%S')
     if timestamp is not None:
         return timestamp
-    
+
     timestamp = parse_timestamp(string, '%H:%M')
     if timestamp is not None:
         return timestamp
-    
+
     return parse_timestamp(string, '%S')
 
-def parse_timestamp(string, pattern):
+
+def parse_timestamp(string: str, pattern: str):
     try:
         date = datetime.strptime(string, pattern)
-        delta = timedelta(hours=date.hour, minutes=date.minute, seconds=date.second)
+        delta = timedelta(
+            hours=date.hour, minutes=date.minute, seconds=date.second)
         return int(delta.total_seconds())
-    except:
+    except:  # pylint: disable=bare-except
         return None
 
+
 def format_timestamp(seconds: float, always_include_hours: bool = False):
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
@@ -79,4 +86,3 @@ def format_timestamp(seconds: float, always_include_hours: bool = False):
 
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
     return f"{hours_marker}{minutes:02d}:{seconds:02d},{milliseconds:03d}"
-

+ 17 - 14
auto_subtitle/utils/ffmpeg.py

@@ -1,9 +1,10 @@
 import os
-import ffmpeg
 import tempfile
+import ffmpeg
 from .mytempfile import MyTempFile
 from .files import filename
 
+
 def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
     temp_dir = tempfile.gettempdir()
 
@@ -13,18 +14,19 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
         print(f"Extracting audio from {filename(path)}...")
         output_path = os.path.join(temp_dir, f"{filename(path)}.wav")
 
-        ffmpeg_input_args = dict()
+        ffmpeg_input_args = {}
         if sample_interval is not None:
             ffmpeg_input_args['ss'] = str(sample_interval[0])
 
-        ffmpeg_output_args = dict()
+        ffmpeg_output_args = {}
         ffmpeg_output_args['acodec'] = "pcm_s16le"
         ffmpeg_output_args['ac'] = "1"
         ffmpeg_output_args['ar'] = "16k"
         ffmpeg_output_args['map'] = "0:a:" + str(audio_channel_index)
         if sample_interval is not None:
-            ffmpeg_output_args['t'] = str(sample_interval[1] - sample_interval[0])
-            
+            ffmpeg_output_args['t'] = str(
+                sample_interval[1] - sample_interval[0])
+
         ffmpeg.input(path, **ffmpeg_input_args).output(
             output_path,
             **ffmpeg_output_args
@@ -34,9 +36,6 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
 
     return audio_paths
 
-def escape_windows_path(path: str):
-    return path.replace("\\", "/").replace(":", ":").replace(" ", "\\ ").replace("(", "\\(").replace(")", "\\)").replace("[", "\\[").replace("]", "\\]").replace("'", "'\\''")
-
 
 def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list):
     for path, srt_path in subtitles.items():
@@ -44,22 +43,26 @@ def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list):
 
         print(f"Adding subtitles to {filename(path)}...")
 
-        ffmpeg_input_args = dict()
+        ffmpeg_input_args = {}
         if sample_interval is not None:
             ffmpeg_input_args['ss'] = str(sample_interval[0])
 
-        ffmpeg_output_args = dict()
+        ffmpeg_output_args = {}
         if sample_interval is not None:
-            ffmpeg_output_args['t'] = str(sample_interval[1] - sample_interval[0])
+            ffmpeg_output_args['t'] = str(
+                sample_interval[1] - sample_interval[0])
 
-        # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg, so we use temp copy instead
+        # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg
+        # so we use temp copy instead
         # see: https://github.com/kkroening/ffmpeg-python/issues/745
         with MyTempFile(srt_path) as srt_temp:
             video = ffmpeg.input(path, **ffmpeg_input_args)
             audio = video.audio
 
             ffmpeg.concat(
-                video.filter('subtitles', srt_temp.tmp_file_path, force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
+                video.filter(
+                    'subtitles', srt_temp.tmp_file_path,
+                    force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
             ).output(out_path, **ffmpeg_output_args).run(quiet=True, overwrite_output=True)
 
-        print(f"Saved subtitled video to {os.path.abspath(out_path)}.")
+        print(f"Saved subtitled video to {os.path.abspath(out_path)}.")

+ 1 - 1
auto_subtitle/utils/files.py

@@ -13,5 +13,5 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
             flush=True,
         )
 
-def filename(path):
+def filename(path: str):
     return os.path.splitext(os.path.basename(path))[0]

+ 17 - 0
auto_subtitle/utils/mytempfile.py

@@ -3,8 +3,25 @@ import os
 import shutil
 
 class MyTempFile:
+    """
+    A context manager for creating a temporary file in current directory, copying the content from
+    a specified file, and handling cleanup operations upon exiting the context.
+
+    Usage:
+    ```python
+    with MyTempFile(file_path) as temp_file_manager:
+        # Access the temporary file using temp_file_manager.tmp_file
+        # ...
+    # The temporary file is automatically closed and removed upon exiting the context.
+    ```
+
+    Args:
+    - file_path (str): The path to the file whose content will be copied to the temporary file.
+    """
     def __init__(self, file_path):
         self.file_path = file_path
+        self.tmp_file = None
+        self.tmp_file_path = None
 
     def __enter__(self):
         self.tmp_file = tempfile.NamedTemporaryFile('w', dir='.', delete=False)

+ 51 - 7
auto_subtitle/utils/whisper.py

@@ -2,20 +2,64 @@ import warnings
 import faster_whisper
 from tqdm import tqdm
 
+# pylint: disable=R0903
 class WhisperAI:
-    def __init__(self, model_name, device, compute_type, model_args):
-        self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
-        self.model_args = model_args
+    """
+    Wrapper class for the Whisper speech recognition model with additional functionality.
 
-    def transcribe(self, audio_path):
+    This class provides a high-level interface for transcribing audio files using the Whisper
+    speech recognition model. It encapsulates the model instantiation and transcription process,
+    allowing users to easily transcribe audio files and iterate over the resulting segments.
+
+    Usage:
+    ```python
+    whisper = WhisperAI(model_args, transcribe_args)
+
+    # Transcribe an audio file and iterate over the segments
+    for segment in whisper.transcribe(audio_path):
+        # Process each transcription segment
+        print(segment)
+    ```
+
+    Args:
+    - model_args: Arguments to pass to WhisperModel initialize method
+        - model_size_or_path (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu", "cuda", "auto").
+        - compute_type (str): The type to use for computation.
+            See https://opennmt.net/CTranslate2/quantization.html.
+    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
+
+    Attributes:
+    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - transcribe_args (dict): Additional arguments used for transcribe method.
+
+    Methods:
+    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    """
+
+    def __init__(self, model_args: dict, transcribe_args: dict):
+        self.model = faster_whisper.WhisperModel(**model_args)
+        self.transcribe_args = transcribe_args
+
+    def transcribe(self, audio_path: str):
+        """
+        Transcribes the specified audio file and yields the resulting segments.
+
+        Args:
+        - audio_path (str): The path to the audio file for transcription.
+
+        Yields:
+        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        """
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.model_args)
+        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
         warnings.filterwarnings("default")
 
-        total_duration = round(info.duration, 2)  # Same precision as the Whisper timestamps.
+        # Same precision as the Whisper timestamps.
+        total_duration = round(info.duration, 2)
 
         with tqdm(total=total_duration, unit=" seconds") as pbar:
             for segment in segments:
                 yield segment
                 pbar.update(segment.end - segment.start)
-            pbar.update(0)
+            pbar.update(0)