Add GitHub Workflow with Pylint analyzer

commit 1c0cdb6eba (parent fab2921954)
.github/workflows/pylint.yml (new file, 24 lines, vendored)
@@ -0,0 +1,24 @@
+name: Pylint
+
+on: [push]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9"]
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v3
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint
+        pip install -r requirements.txt
+    - name: Analysing the code with pylint
+      run: |
+        pylint --disable=C0114 --disable=C0115 --disable=C0116 $(git ls-files '*.py')
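The three disabled messages are pylint's missing-docstring checks: C0114 (missing-module-docstring), C0115 (missing-class-docstring) and C0116 (missing-function-docstring). As a minimal sketch, code like the following is what those checks would flag without the disables, even when it is otherwise clean:

    # example.py - pylint reports C0114 here (no module docstring)

    class Greeter:  # C0115: missing-class-docstring
        def greet(self, name):  # C0116: missing-function-docstring
            return f"Hello, {name}!"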
@@ -1,9 +1,17 @@
 import argparse
 from faster_whisper import available_models
+from .utils.constants import LANGUAGE_CODES
 from .main import process
 from .utils.convert import str2bool, str2timeinterval


 def main():
+    """
+    Main entry point for the script.
+
+    Parses command line arguments, processes the inputs using the specified options,
+    and performs transcription or translation based on the specified task.
+    """
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("video", nargs="+", type=str,
@@ -11,15 +19,18 @@ def main():
     parser.add_argument("--audio_channel", default="0",
                         type=int, help="audio channel index to use")
     parser.add_argument("--sample_interval", type=str2timeinterval, default=None,
-                        help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
+                        help="generate subtitles for a specific \
+                              fragment of the video (e.g. 01:02:05-01:03:45)")
     parser.add_argument("--model", default="small",
                         choices=available_models(), help="name of the Whisper model to use")
-    parser.add_argument("--device", type=str, default="auto", choices=[
-                        "cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
+    parser.add_argument("--device", type=str, default="auto",
+                        choices=["cpu", "cuda", "auto"],
+                        help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
     parser.add_argument("--compute_type", type=str, default="default", choices=[
-        "int8", "int8_float32", "int8_float16",
-        "int8_bfloat16", "int16", "float16",
-        "bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
+        "int8", "int8_float32", "int8_float16", "int8_bfloat16",
+        "int16", "float16", "bfloat16", "float32"],
+        help="Type to use for computation. \
+             See https://opennmt.net/CTranslate2/quantization.html.")
     parser.add_argument("--output_dir", "-o", type=str,
                         default=".", help="directory to save the outputs")
     parser.add_argument("--output_srt", type=str2bool, default=False,
@@ -32,10 +43,14 @@ def main():
                         help="model parameter, tweak to increase accuracy")
     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
                         help="model parameter, tweak to increase accuracy")
-    parser.add_argument("--task", type=str, default="transcribe", choices=[
-        "transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default="auto", choices=["auto","af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca","cs","cy","da","de","el","en","es","et","eu","fa","fi","fo","fr","gl","gu","ha","haw","he","hi","hr","ht","hu","hy","id","is","it","ja","jw","ka","kk","km","kn","ko","la","lb","ln","lo","lt","lv","mg","mi","mk","ml","mn","mr","ms","mt","my","ne","nl","nn","no","oc","pa","pl","ps","pt","ro","ru","sa","sd","si","sk","sl","sn","so","sq","sr","su","sv","sw","ta","te","tg","th","tk","tl","tr","tt","uk","ur","uz","vi","yi","yo","zh"],
-                        help="What is the origin language of the video? If unset, it is detected automatically.")
+    parser.add_argument("--task", type=str, default="transcribe",
+                        choices=["transcribe", "translate"],
+                        help="whether to perform X->X speech recognition ('transcribe') \
+                              or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default="auto",
+                        choices=LANGUAGE_CODES,
+                        help="What is the origin language of the video? \
+                              If unset, it is detected automatically.")

     args = parser.parse_args().__dict__
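Because process() consumes a plain dict, the CLI hands over parse_args().__dict__ rather than the Namespace itself. A small self-contained illustration of that hand-off (the file name and flag values are only examples):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("video", nargs="+", type=str)
    parser.add_argument("--task", type=str, default="transcribe",
                        choices=["transcribe", "translate"])

    # Namespace -> dict, as done at the end of main()
    args = parser.parse_args(["episode.mkv", "--task", "translate"]).__dict__
    print(args)  # {'video': ['episode.mkv'], 'task': 'translate'}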
auto_subtitle/main.py

@@ -5,6 +5,7 @@ from .utils.files import filename, write_srt
 from .utils.ffmpeg import get_audio, overlay_subtitles
 from .utils.whisper import WhisperAI

+
 def process(args: dict):
     model_name: str = args.pop("model")
     output_dir: str = args.pop("output_dir")
@@ -12,8 +13,6 @@ def process(args: dict):
     srt_only: bool = args.pop("srt_only")
     language: str = args.pop("language")
     sample_interval: str = args.pop("sample_interval")
-    device: str = args.pop("device")
-    compute_type: str = args.pop("compute_type")

     os.makedirs(output_dir, exist_ok=True)

@@ -25,18 +24,26 @@ def process(args: dict):
     elif language != "auto":
         args["language"] = language

-    audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
-    subtitles = get_subtitles(
-        audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
-    )
+    audios = get_audio(args.pop("video"), args.pop(
+        'audio_channel'), sample_interval)
+
+    model_args = {}
+    model_args["model_size_or_path"] = model_name
+    model_args["device"] = args.pop("device")
+    model_args["compute_type"] = args.pop("compute_type")
+
+    srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir()
+    subtitles = get_subtitles(audios, srt_output_dir, model_args, args)

     if srt_only:
         return

     overlay_subtitles(subtitles, output_dir, sample_interval)


-def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
-    model = WhisperAI(model_name, device, compute_type, model_args)
+def get_subtitles(audio_paths: list, output_dir: str,
+                  model_args: dict, transcribe_args: dict):
+    model = WhisperAI(model_args, transcribe_args)

     subtitles_path = {}
@@ -44,8 +51,7 @@ def get_subtitles(audio_paths: list, output_srt: bool, output_dir: str, model_name: str, device: str, compute_type: str, model_args: dict):
         print(
            f"Generating subtitles for {filename(path)}... This might take a while."
         )
-        srt_path = output_dir if output_srt else tempfile.gettempdir()
-        srt_path = os.path.join(srt_path, f"{filename(path)}.srt")
+        srt_path = os.path.join(output_dir, f"{filename(path)}.srt")

         segments = model.transcribe(audio_path)
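The refactor replaces the positional model parameters with a model_args dict that is splatted straight into the faster_whisper constructor. A hedged sketch of the equivalence (the values are illustrative):

    import faster_whisper

    # What process() now assembles...
    model_args = {
        "model_size_or_path": "small",
        "device": "auto",
        "compute_type": "default",
    }

    # ...and what WhisperAI does with it internally:
    model = faster_whisper.WhisperModel(**model_args)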
auto_subtitle/utils/constants.py (new file, 105 lines)

@@ -0,0 +1,105 @@
+"""
+List of available language codes
+"""
+LANGUAGE_CODES = [
+    "af",
+    "am",
+    "ar",
+    "as",
+    "az",
+    "ba",
+    "be",
+    "bg",
+    "bn",
+    "bo",
+    "br",
+    "bs",
+    "ca",
+    "cs",
+    "cy",
+    "da",
+    "de",
+    "el",
+    "en",
+    "es",
+    "et",
+    "eu",
+    "fa",
+    "fi",
+    "fo",
+    "fr",
+    "gl",
+    "gu",
+    "ha",
+    "haw",
+    "he",
+    "hi",
+    "hr",
+    "ht",
+    "hu",
+    "hy",
+    "id",
+    "is",
+    "it",
+    "ja",
+    "jw",
+    "ka",
+    "kk",
+    "km",
+    "kn",
+    "ko",
+    "la",
+    "lb",
+    "ln",
+    "lo",
+    "lt",
+    "lv",
+    "mg",
+    "mi",
+    "mk",
+    "ml",
+    "mn",
+    "mr",
+    "ms",
+    "mt",
+    "my",
+    "ne",
+    "nl",
+    "nn",
+    "no",
+    "oc",
+    "pa",
+    "pl",
+    "ps",
+    "pt",
+    "ro",
+    "ru",
+    "sa",
+    "sd",
+    "si",
+    "sk",
+    "sl",
+    "sn",
+    "so",
+    "sq",
+    "sr",
+    "su",
+    "sv",
+    "sw",
+    "ta",
+    "te",
+    "tg",
+    "th",
+    "tk",
+    "tl",
+    "tr",
+    "tt",
+    "uk",
+    "ur",
+    "uz",
+    "vi",
+    "yi",
+    "yo",
+    "zh",
+    "yue",
+]
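A quick hedged sanity check of the new module (import path as added in the CLI diff above):

    from auto_subtitle.utils.constants import LANGUAGE_CODES

    assert "yue" in LANGUAGE_CODES        # Cantonese is included alongside "zh"
    assert "auto" not in LANGUAGE_CODES   # cli.py keeps "auto" as the default;
                                          # argparse does not validate defaults against choices
    assert len(LANGUAGE_CODES) == 100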
auto_subtitle/utils/convert.py

@@ -1,16 +1,18 @@
 from datetime import datetime, timedelta

-def str2bool(string):
+
+def str2bool(string: str):
     string = string.lower()
     str2val = {"true": True, "false": False}

     if string in str2val:
         return str2val[string]
-    else:
-        raise ValueError(
-            f"Expected one of {set(str2val.keys())}, got {string}")
+    raise ValueError(
+        f"Expected one of {set(str2val.keys())}, got {string}")

-def str2timeinterval(string):
+
+def str2timeinterval(string: str):
     if string is None:
         return None
@@ -31,9 +33,10 @@ def str2timeinterval(string):
     return [start, end]

-def time_to_timestamp(string):
+
+def time_to_timestamp(string: str):
     split_time = string.split(':')
-    if len(split_time) == 0 or len(split_time) > 3 or not all([ x.isdigit() for x in split_time ]):
+    if len(split_time) == 0 or len(split_time) > 3 or not all(x.isdigit() for x in split_time):
         raise ValueError(
             f"Expected HH:mm:ss or HH:mm or ss, got {string}")
@@ -45,7 +48,8 @@ def time_to_timestamp(string):
     return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60 + int(split_time[2])

-def try_parse_timestamp(string):
+
+def try_parse_timestamp(string: str):
     timestamp = parse_timestamp(string, '%H:%M:%S')
     if timestamp is not None:
         return timestamp
@@ -56,14 +60,17 @@ def try_parse_timestamp(string):
     return parse_timestamp(string, '%S')

-def parse_timestamp(string, pattern):
+
+def parse_timestamp(string: str, pattern: str):
     try:
         date = datetime.strptime(string, pattern)
-        delta = timedelta(hours=date.hour, minutes=date.minute, seconds=date.second)
+        delta = timedelta(
+            hours=date.hour, minutes=date.minute, seconds=date.second)
         return int(delta.total_seconds())
-    except:
+    except:  # pylint: disable=bare-except
         return None


 def format_timestamp(seconds: float, always_include_hours: bool = False):
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
@@ -79,4 +86,3 @@ def format_timestamp(seconds: float, always_include_hours: bool = False):
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
     return f"{hours_marker}{minutes:02d}:{seconds:02d},{milliseconds:03d}"
-
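These helpers back the --sample_interval flag. A worked example of the parsing chain, with expected values derived from the code above and the flag's documented format (the interval split itself is not shown in this hunk, so the second assertion is a hedged expectation):

    from auto_subtitle.utils.convert import parse_timestamp, str2timeinterval

    # "%H:%M:%S" parse: 01:02:05 -> 1*3600 + 2*60 + 5 seconds
    assert parse_timestamp("01:02:05", "%H:%M:%S") == 3725

    # The example value from the cli.py help text resolves to [start, end] in seconds
    assert str2timeinterval("01:02:05-01:03:45") == [3725, 3825]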
auto_subtitle/utils/ffmpeg.py

@@ -1,9 +1,10 @@
 import os
-import ffmpeg
 import tempfile
+
+import ffmpeg
 from .mytempfile import MyTempFile
 from .files import filename


 def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
     temp_dir = tempfile.gettempdir()
@@ -13,17 +14,18 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
         print(f"Extracting audio from {filename(path)}...")
         output_path = os.path.join(temp_dir, f"{filename(path)}.wav")

-        ffmpeg_input_args = dict()
+        ffmpeg_input_args = {}
         if sample_interval is not None:
             ffmpeg_input_args['ss'] = str(sample_interval[0])

-        ffmpeg_output_args = dict()
+        ffmpeg_output_args = {}
         ffmpeg_output_args['acodec'] = "pcm_s16le"
         ffmpeg_output_args['ac'] = "1"
         ffmpeg_output_args['ar'] = "16k"
         ffmpeg_output_args['map'] = "0:a:" + str(audio_channel_index)
         if sample_interval is not None:
-            ffmpeg_output_args['t'] = str(sample_interval[1] - sample_interval[0])
+            ffmpeg_output_args['t'] = str(
+                sample_interval[1] - sample_interval[0])

         ffmpeg.input(path, **ffmpeg_input_args).output(
             output_path,
@@ -34,9 +36,6 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
     return audio_paths

-def escape_windows_path(path: str):
-    return path.replace("\\", "/").replace(":", ":").replace(" ", "\\ ").replace("(", "\\(").replace(")", "\\)").replace("[", "\\[").replace("]", "\\]").replace("'", "'\\''")
-

 def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list):
     for path, srt_path in subtitles.items():
@@ -44,22 +43,26 @@ def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list):

         print(f"Adding subtitles to {filename(path)}...")

-        ffmpeg_input_args = dict()
+        ffmpeg_input_args = {}
         if sample_interval is not None:
             ffmpeg_input_args['ss'] = str(sample_interval[0])

-        ffmpeg_output_args = dict()
+        ffmpeg_output_args = {}
         if sample_interval is not None:
-            ffmpeg_output_args['t'] = str(sample_interval[1] - sample_interval[0])
+            ffmpeg_output_args['t'] = str(
+                sample_interval[1] - sample_interval[0])

-        # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg, so we use temp copy instead
+        # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg
+        # so we use temp copy instead
         # see: https://github.com/kkroening/ffmpeg-python/issues/745
         with MyTempFile(srt_path) as srt_temp:
             video = ffmpeg.input(path, **ffmpeg_input_args)
             audio = video.audio

             ffmpeg.concat(
-                video.filter('subtitles', srt_temp.tmp_file_path, force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
+                video.filter(
+                    'subtitles', srt_temp.tmp_file_path,
+                    force_style="OutlineColour=&H40000000,BorderStyle=3"), audio, v=1, a=1
             ).output(out_path, **ffmpeg_output_args).run(quiet=True, overwrite_output=True)

             print(f"Saved subtitled video to {os.path.abspath(out_path)}.")
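For orientation, the ffmpeg-python call that get_audio() assembles is equivalent to this hedged standalone sketch (the file names, the 65-second seek, and the 100-second duration are illustrative; the run() options are assumed to match overlay_subtitles() above):

    import ffmpeg

    # Extraction as configured above: 16-bit PCM, mono, 16 kHz,
    # first audio stream, optionally trimmed to a sample interval.
    (
        ffmpeg
        .input("episode.mkv", ss="65")
        .output("episode.wav",
                acodec="pcm_s16le", ac="1", ar="16k", map="0:a:0", t="100")
        .run(quiet=True, overwrite_output=True)
    )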
auto_subtitle/utils/files.py

@@ -13,5 +13,5 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
         flush=True,
     )

-def filename(path):
+def filename(path: str):
     return os.path.splitext(os.path.basename(path))[0]
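filename() strips both the directory and the extension; a tiny worked example (the path is illustrative):

    import os

    def filename(path: str):
        return os.path.splitext(os.path.basename(path))[0]

    assert filename("/videos/Show.S01E01.mkv") == "Show.S01E01"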
auto_subtitle/utils/mytempfile.py

@@ -3,8 +3,25 @@ import os
 import shutil


 class MyTempFile:
+    """
+    A context manager for creating a temporary file in current directory, copying the content from
+    a specified file, and handling cleanup operations upon exiting the context.
+
+    Usage:
+    ```python
+    with MyTempFile(file_path) as temp_file_manager:
+        # Access the temporary file using temp_file_manager.tmp_file
+        # ...
+    # The temporary file is automatically closed and removed upon exiting the context.
+    ```
+
+    Args:
+    - file_path (str): The path to the file whose content will be copied to the temporary file.
+    """
     def __init__(self, file_path):
         self.file_path = file_path
+        self.tmp_file = None
+        self.tmp_file_path = None

     def __enter__(self):
         self.tmp_file = tempfile.NamedTemporaryFile('w', dir='.', delete=False)
auto_subtitle/utils/whisper.py

@@ -2,17 +2,61 @@ import warnings
 import faster_whisper
 from tqdm import tqdm


+# pylint: disable=R0903
 class WhisperAI:
-    def __init__(self, model_name, device, compute_type, model_args):
-        self.model = faster_whisper.WhisperModel(model_name, device=device, compute_type=compute_type)
-        self.model_args = model_args
-
-    def transcribe(self, audio_path):
+    """
+    Wrapper class for the Whisper speech recognition model with additional functionality.
+
+    This class provides a high-level interface for transcribing audio files using the Whisper
+    speech recognition model. It encapsulates the model instantiation and transcription process,
+    allowing users to easily transcribe audio files and iterate over the resulting segments.
+
+    Usage:
+    ```python
+    whisper = WhisperAI(model_args, transcribe_args)
+
+    # Transcribe an audio file and iterate over the segments
+    for segment in whisper.transcribe(audio_path):
+        # Process each transcription segment
+        print(segment)
+    ```
+
+    Args:
+    - model_args: Arguments to pass to WhisperModel initialize method
+        - model_size_or_path (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu", "cuda", "auto").
+        - compute_type (str): The type to use for computation.
+          See https://opennmt.net/CTranslate2/quantization.html.
+    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
+
+    Attributes:
+    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - transcribe_args (dict): Additional arguments used for transcribe method.
+
+    Methods:
+    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    """
+
+    def __init__(self, model_args: dict, transcribe_args: dict):
+        self.model = faster_whisper.WhisperModel(**model_args)
+        self.transcribe_args = transcribe_args
+
+    def transcribe(self, audio_path: str):
+        """
+        Transcribes the specified audio file and yields the resulting segments.
+
+        Args:
+        - audio_path (str): The path to the audio file for transcription.
+
+        Yields:
+        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        """
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.model_args)
+        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
         warnings.filterwarnings("default")

-        total_duration = round(info.duration, 2)  # Same precision as the Whisper timestamps.
+        # Same precision as the Whisper timestamps.
+        total_duration = round(info.duration, 2)

         with tqdm(total=total_duration, unit=" seconds") as pbar:
             for segment in segments:
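A hedged usage sketch of the new constructor signature, following the class docstring above (the model size, compute type, and audio path are illustrative):

    from auto_subtitle.utils.whisper import WhisperAI

    model_args = {"model_size_or_path": "small", "device": "cpu", "compute_type": "int8"}
    whisper = WhisperAI(model_args, {"task": "transcribe"})

    # Segments are faster_whisper segments with start/end/text attributes
    for segment in whisper.transcribe("episode.wav"):
        print(segment.start, segment.end, segment.text)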