mirror of
https://github.com/karl0ss/bazarr-ai-sub-generator.git
synced 2025-04-26 14:59:21 +01:00
commit
408fcd085c
@ -1,53 +0,0 @@
|
||||
import argparse
|
||||
from faster_whisper import available_models
|
||||
from utils.constants import LANGUAGE_CODES
|
||||
from main import process
|
||||
from utils.convert import str2bool, str2timeinterval
|
||||
|
||||
|
||||
def main():
    """
    Main entry point for the script.

    Builds the command line parser, parses the arguments into a plain
    dictionary and hands that dictionary to ``process``, which performs
    transcription or translation based on the specified task.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # The default is a real int so it matches the declared ``type``.
    parser.add_argument(
        "--audio_channel", default=0, type=int, help="audio channel index to use"
    )
    # Help texts are built from adjacent literals instead of a backslash
    # continuation, so the source indentation does not end up in the string.
    parser.add_argument(
        "--sample_interval",
        type=str2timeinterval,
        default=None,
        help=(
            "generate subtitles for a specific "
            "fragment of the video (e.g. 01:02:05-01:03:45)"
        ),
    )
    parser.add_argument(
        "--model",
        default="small",
        choices=available_models(),
        help="name of the Whisper model to use",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["cpu", "cuda", "auto"],
        help='Device to use for computation ("cpu", "cuda", "auto")',
    )
    # "default" (the declared default) and "auto" are listed so the default
    # value can also be passed explicitly on the command line.
    parser.add_argument(
        "--compute_type",
        type=str,
        default="default",
        choices=[
            "default",
            "auto",
            "int8",
            "int8_float32",
            "int8_float16",
            "int8_bfloat16",
            "int16",
            "float16",
            "bfloat16",
            "float32",
        ],
        help=(
            "Type to use for computation. "
            "See https://opennmt.net/CTranslate2/quantization.html."
        ),
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=5,
        help="model parameter, tweak to increase accuracy",
    )
    parser.add_argument(
        "--no_speech_threshold",
        type=float,
        default=0.6,
        help="model parameter, tweak to increase accuracy",
    )
    parser.add_argument(
        "--condition_on_previous_text",
        type=str2bool,
        default=True,
        help="model parameter, tweak to increase accuracy",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help=(
            "whether to perform X->X speech recognition ('transcribe') "
            "or X->English translation ('translate')"
        ),
    )
    parser.add_argument(
        "--language",
        type=str,
        default="auto",
        choices=LANGUAGE_CODES,
        help=(
            "What is the origin language of the video? "
            "If unset, it is detected automatically."
        ),
    )

    # ``process`` receives (and pops from) a plain dict of the options.
    args = vars(parser.parse_args())

    process(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,25 +0,0 @@
|
||||
import requests
|
||||
import json
|
||||
import configparser
|
||||
config = configparser.RawConfigParser()
|
||||
config.read('config.cfg')
|
||||
|
||||
token = config._sections['sonarr']['token']
|
||||
base_url = config._sections['sonarr']['url']
|
||||
|
||||
def update_show_in_soarr(show_id):
    """
    Ask Sonarr to refresh a series.

    Posts a ``RefreshSeries`` command for ``show_id`` to the
    ``/api/v3/command`` endpoint of the configured Sonarr instance and
    prints whether the request was accepted.

    Args:
        show_id: Sonarr series id, sent as ``seriesId`` in the command body.
    """
    url = f"{base_url}/api/v3/command"

    payload = json.dumps({"name": "RefreshSeries", "seriesId": show_id})
    headers = {
        "Content-Type": "application/json",
        "X-Api-Key": token,
    }

    response = requests.request("POST", url, headers=headers, data=payload)

    # Only a non-error status counts as success; the previous "!= 404"
    # check also reported success for 401, 500 and other failures.
    if response.ok:
        print("Updated show in Sonarr")
    else:
        print(f"Failed to update show in Sonarr (HTTP {response.status_code})")
|
99
bazarr-ai-sub-generator/cli.py
Normal file
99
bazarr-ai-sub-generator/cli.py
Normal file
@ -0,0 +1,99 @@
|
||||
import argparse
|
||||
from faster_whisper import available_models
|
||||
from utils.constants import LANGUAGE_CODES
|
||||
from main import process
|
||||
from utils.convert import str2bool, str2timeinterval
|
||||
|
||||
|
||||
def main():
    """
    Main entry point for the script.

    Builds the command line parser, parses the arguments into a plain
    dictionary and hands that dictionary to ``process``, which performs
    transcription or translation based on the specified task.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # The default is a real int so it matches the declared ``type``.
    parser.add_argument(
        "--audio_channel", default=0, type=int, help="audio channel index to use"
    )
    # Help texts are built from adjacent literals instead of a backslash
    # continuation, so the source indentation does not end up in the string.
    parser.add_argument(
        "--sample_interval",
        type=str2timeinterval,
        default=None,
        help=(
            "generate subtitles for a specific "
            "fragment of the video (e.g. 01:02:05-01:03:45)"
        ),
    )
    parser.add_argument(
        "--model",
        default="small",
        choices=available_models(),
        help="name of the Whisper model to use",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["cpu", "cuda", "auto"],
        help='Device to use for computation ("cpu", "cuda", "auto")',
    )
    # "default" (the declared default) and "auto" are listed so the default
    # value can also be passed explicitly on the command line.
    parser.add_argument(
        "--compute_type",
        type=str,
        default="default",
        choices=[
            "default",
            "auto",
            "int8",
            "int8_float32",
            "int8_float16",
            "int8_bfloat16",
            "int16",
            "float16",
            "bfloat16",
            "float32",
        ],
        help=(
            "Type to use for computation. "
            "See https://opennmt.net/CTranslate2/quantization.html."
        ),
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=5,
        help="model parameter, tweak to increase accuracy",
    )
    parser.add_argument(
        "--no_speech_threshold",
        type=float,
        default=0.6,
        help="model parameter, tweak to increase accuracy",
    )
    parser.add_argument(
        "--condition_on_previous_text",
        type=str2bool,
        default=True,
        help="model parameter, tweak to increase accuracy",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help=(
            "whether to perform X->X speech recognition ('transcribe') "
            "or X->English translation ('translate')"
        ),
    )
    parser.add_argument(
        "--language",
        type=str,
        default="auto",
        choices=LANGUAGE_CODES,
        help=(
            "What is the origin language of the video? "
            "If unset, it is detected automatically."
        ),
    )

    # ``process`` receives (and pops from) a plain dict of the options.
    args = vars(parser.parse_args())

    process(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -5,7 +5,7 @@ import time
|
||||
from utils.files import filename, write_srt
|
||||
from utils.ffmpeg import get_audio, add_subtitles_to_mp4
|
||||
from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
|
||||
from utils.sonarr import update_show_in_soarr
|
||||
from utils.sonarr import update_show_in_sonarr
|
||||
from utils.whisper import WhisperAI
|
||||
|
||||
|
||||
@ -13,11 +13,12 @@ def process(args: dict):
|
||||
model_name: str = args.pop("model")
|
||||
language: str = args.pop("language")
|
||||
sample_interval: str = args.pop("sample_interval")
|
||||
audio_channel: str = args.pop('audio_channel')
|
||||
audio_channel: str = args.pop("audio_channel")
|
||||
|
||||
if model_name.endswith(".en"):
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model, forcing English detection.")
|
||||
f"{model_name} is an English-only model, forcing English detection."
|
||||
)
|
||||
args["language"] = "en"
|
||||
# if translate task used and language argument is set, then use it
|
||||
elif language != "auto":
|
||||
@ -29,29 +30,30 @@ def process(args: dict):
|
||||
model_args["compute_type"] = args.pop("compute_type")
|
||||
|
||||
list_of_episodes_needing_subtitles = get_wanted_episodes()
|
||||
print(f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles.")
|
||||
for episode in list_of_episodes_needing_subtitles['data']:
|
||||
print(
|
||||
f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles."
|
||||
)
|
||||
for episode in list_of_episodes_needing_subtitles["data"]:
|
||||
print(f"Processing {episode['seriesTitle']} - {episode['episode_number']}")
|
||||
episode_data = get_episode_details(episode['sonarrEpisodeId'])
|
||||
audios = get_audio([episode_data['path']], audio_channel, sample_interval)
|
||||
episode_data = get_episode_details(episode["sonarrEpisodeId"])
|
||||
audios = get_audio([episode_data["path"]], audio_channel, sample_interval)
|
||||
subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
|
||||
|
||||
add_subtitles_to_mp4(subtitles)
|
||||
update_show_in_soarr(episode['sonarrSeriesId'])
|
||||
update_show_in_sonarr(episode["sonarrSeriesId"])
|
||||
time.sleep(5)
|
||||
sync_series()
|
||||
|
||||
|
||||
def get_subtitles(audio_paths: list, output_dir: str,
|
||||
model_args: dict, transcribe_args: dict):
|
||||
def get_subtitles(
|
||||
audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
|
||||
):
|
||||
model = WhisperAI(model_args, transcribe_args)
|
||||
|
||||
subtitles_path = {}
|
||||
|
||||
for path, audio_path in audio_paths.items():
|
||||
print(
|
||||
f"Generating subtitles for {filename(path)}... This might take a while."
|
||||
)
|
||||
print(f"Generating subtitles for {filename(path)}... This might take a while.")
|
||||
srt_path = os.path.join(output_dir, f"{filename(path)}.srt")
|
||||
|
||||
segments = model.transcribe(audio_path)
|
@ -1,19 +1,18 @@
|
||||
import requests
|
||||
import configparser
|
||||
config = configparser.RawConfigParser()
|
||||
config.read('config.cfg')
|
||||
|
||||
token = config._sections['bazarr']['token']
|
||||
base_url = config._sections['bazarr']['url']
|
||||
config = configparser.RawConfigParser()
|
||||
config.read("config.cfg")
|
||||
|
||||
token = config._sections["bazarr"]["token"]
|
||||
base_url = config._sections["bazarr"]["url"]
|
||||
|
||||
|
||||
def get_wanted_episodes():
|
||||
url = f"{base_url}/api/episodes/wanted"
|
||||
|
||||
payload = {}
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'X-API-KEY': token
|
||||
}
|
||||
headers = {"accept": "application/json", "X-API-KEY": token}
|
||||
|
||||
response = requests.request("GET", url, headers=headers, data=payload)
|
||||
|
||||
@ -24,24 +23,18 @@ def get_episode_details(episode_id: str):
|
||||
url = f"{base_url}/api/episodes?episodeid%5B%5D={episode_id}"
|
||||
|
||||
payload = {}
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'X-API-KEY': token
|
||||
}
|
||||
headers = {"accept": "application/json", "X-API-KEY": token}
|
||||
|
||||
response = requests.request("GET", url, headers=headers, data=payload)
|
||||
return response.json()['data'][0]
|
||||
return response.json()["data"][0]
|
||||
|
||||
|
||||
def sync_series():
    """
    Trigger the ``update_series`` system task on the configured Bazarr server.

    Sends a POST to ``/api/system/tasks`` and prints a confirmation when the
    server answers with HTTP 204; any other status is ignored silently.
    """
    endpoint = f"{base_url}/api/system/tasks?taskid=update_series"
    request_headers = {"accept": "application/json", "X-API-KEY": token}

    # The request carries no body; an empty mapping is passed as the data.
    reply = requests.request("POST", endpoint, headers=request_headers, data={})

    if reply.status_code != 204:
        return
    print("Updated Bazarr")
|
@ -8,37 +8,42 @@ def str2bool(string: str):
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
|
||||
raise ValueError(
|
||||
f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
|
||||
def str2timeinterval(string: str):
    """
    Convert a ``start-end`` string into a ``[start, end]`` pair of seconds.

    Args:
        string: Interval written as ``HH:mm:ss-HH:mm:ss``, ``HH:mm-HH:mm`` or
            ``ss-ss``; ``None`` is passed through unchanged.

    Returns:
        ``None`` when ``string`` is ``None``, otherwise a two-element list
        holding the values ``try_parse_timestamp`` produced for each side.

    Raises:
        ValueError: If the string does not consist of exactly two
            ``-`` separated parts, if either part cannot be parsed, or if
            the start is not strictly lower than the end.
    """
    if string is None:
        return None

    # A string without "-" splits into a single part, so this one check
    # covers both "no separator" and "too many separators".
    intervals = string.split("-")
    if len(intervals) != 2:
        raise ValueError(
            f"Expected time interval HH:mm:ss-HH:mm:ss or HH:mm-HH:mm or ss-ss, got {string}"
        )

    start = try_parse_timestamp(intervals[0])
    end = try_parse_timestamp(intervals[1])

    # An unparsable side comes back as None; comparing it below would raise
    # an opaque TypeError, so report the expected format instead.
    if start is None or end is None:
        raise ValueError(
            f"Expected time interval HH:mm:ss-HH:mm:ss or HH:mm-HH:mm or ss-ss, got {string}"
        )

    if start >= end:
        raise ValueError(
            f"Expected time interval end to be higher than start, got {start} >= {end}"
        )

    return [start, end]
|
||||
|
||||
|
||||
def time_to_timestamp(string: str):
|
||||
split_time = string.split(':')
|
||||
if len(split_time) == 0 or len(split_time) > 3 or not all(x.isdigit() for x in split_time):
|
||||
raise ValueError(
|
||||
f"Expected HH:mm:ss or HH:mm or ss, got {string}")
|
||||
split_time = string.split(":")
|
||||
if (
|
||||
len(split_time) == 0
|
||||
or len(split_time) > 3
|
||||
or not all(x.isdigit() for x in split_time)
|
||||
):
|
||||
raise ValueError(f"Expected HH:mm:ss or HH:mm or ss, got {string}")
|
||||
|
||||
if len(split_time) == 1:
|
||||
return int(split_time[0])
|
||||
@ -50,22 +55,21 @@ def time_to_timestamp(string: str):
|
||||
|
||||
|
||||
def try_parse_timestamp(string: str):
    """
    Parse ``string`` using the first timestamp pattern that matches.

    The patterns ``%H:%M:%S``, ``%H:%M`` and ``%S`` are tried in that order
    through ``parse_timestamp``.

    Returns:
        The value from the first pattern for which ``parse_timestamp`` does
        not return ``None``; otherwise whatever the final ``%S`` attempt
        returns (which may itself be ``None``).
    """
    for pattern in ("%H:%M:%S", "%H:%M"):
        seconds = parse_timestamp(string, pattern)
        if seconds is not None:
            return seconds

    # Last resort: the result of the seconds-only pattern is returned as is.
    return parse_timestamp(string, "%S")
|
||||
|
||||
|
||||
def parse_timestamp(string: str, pattern: str):
    """
    Parse ``string`` with a ``strptime`` pattern into a number of seconds.

    Args:
        string: Text to parse, e.g. ``"01:02:05"``.
        pattern: ``datetime.strptime`` format, e.g. ``"%H:%M:%S"``.

    Returns:
        The hour, minute and second fields of the parsed value expressed as
        a total number of seconds (int), or ``None`` when ``string`` does not
        match ``pattern`` or is not a string.
    """
    try:
        date = datetime.strptime(string, pattern)
    except (ValueError, TypeError):
        # ValueError: text does not match the pattern; TypeError: non-string
        # input. Anything else is a real error and is allowed to propagate.
        return None

    delta = timedelta(hours=date.hour, minutes=date.minute, seconds=date.second)
    return int(delta.total_seconds())
|
@ -15,20 +15,18 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
|
||||
|
||||
ffmpeg_input_args = {}
|
||||
if sample_interval is not None:
|
||||
ffmpeg_input_args['ss'] = str(sample_interval[0])
|
||||
ffmpeg_input_args["ss"] = str(sample_interval[0])
|
||||
|
||||
ffmpeg_output_args = {}
|
||||
ffmpeg_output_args['acodec'] = "pcm_s16le"
|
||||
ffmpeg_output_args['ac'] = "1"
|
||||
ffmpeg_output_args['ar'] = "16k"
|
||||
ffmpeg_output_args['map'] = "0:a:" + str(audio_channel_index)
|
||||
ffmpeg_output_args["acodec"] = "pcm_s16le"
|
||||
ffmpeg_output_args["ac"] = "1"
|
||||
ffmpeg_output_args["ar"] = "16k"
|
||||
ffmpeg_output_args["map"] = "0:a:" + str(audio_channel_index)
|
||||
if sample_interval is not None:
|
||||
ffmpeg_output_args['t'] = str(
|
||||
sample_interval[1] - sample_interval[0])
|
||||
ffmpeg_output_args["t"] = str(sample_interval[1] - sample_interval[0])
|
||||
|
||||
ffmpeg.input(path, **ffmpeg_input_args).output(
|
||||
output_path,
|
||||
**ffmpeg_output_args
|
||||
output_path, **ffmpeg_output_args
|
||||
).run(quiet=True, overwrite_output=True)
|
||||
|
||||
audio_paths[path] = output_path
|
||||
@ -37,19 +35,25 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list):
|
||||
|
||||
|
||||
def add_subtitles_to_mp4(subtitles: dict):
    """
    Mux a generated subtitle file into its video and clean up temp files.

    Only the first entry of ``subtitles`` is processed. The video is renamed
    to ``<name>_edit``, remuxed together with the subtitle track (streams
    copied, subtitles encoded as ``mov_text`` and tagged ``language=eng``)
    and written back under the original name, with a ``.mkv`` extension
    switched to ``.mp4``. Afterwards the renamed source, the subtitle file
    and the ``.wav`` file sharing the subtitle file's base name are removed.

    Args:
        subtitles: Mapping of video path to subtitle file path.

    Raises:
        IndexError: If ``subtitles`` is empty.
    """
    input_file = list(subtitles.keys())[0]
    subtitle_file = subtitles[input_file]

    # Swap only a real ".mkv" extension; a plain str.replace would also
    # rewrite ".mkv" occurring elsewhere in the path.
    stem, extension = os.path.splitext(input_file)
    output_file = stem + ".mp4" if extension == ".mkv" else input_file

    backup_file = input_file + "_edit"
    os.rename(input_file, backup_file)

    muxed = False
    try:
        input_stream = ffmpeg.input(backup_file)
        subtitle_stream = ffmpeg.input(subtitle_file)

        # Combine input video and subtitle
        output = ffmpeg.output(
            input_stream,
            subtitle_stream,
            output_file,
            c="copy",
            **{"c:s": "mov_text"},
            **{"metadata:s:s:0": "language=eng"},
        )
        ffmpeg.run(output, quiet=True, overwrite_output=True)
        muxed = True
    finally:
        if not muxed:
            # Muxing failed or was interrupted: drop any partial output and
            # put the source video back under its original name.
            if os.path.exists(output_file):
                os.remove(output_file)
            os.rename(backup_file, input_file)

    os.remove(backup_file)
    # remove tempfiles
    os.remove(subtitle_file)
    os.remove(os.path.splitext(subtitle_file)[0] + ".wav")
|
@ -2,6 +2,7 @@ import os
|
||||
from typing import Iterator, TextIO
|
||||
from .convert import format_timestamp
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
for i, segment in enumerate(transcript, start=1):
|
||||
print(
|
||||
@ -13,5 +14,6 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def filename(path: str):
    """Return the final component of ``path`` with its last extension removed."""
    base_name = os.path.basename(path)
    stem, _extension = os.path.splitext(base_name)
    return stem
|
24
bazarr-ai-sub-generator/utils/sonarr.py
Normal file
24
bazarr-ai-sub-generator/utils/sonarr.py
Normal file
@ -0,0 +1,24 @@
|
||||
import requests
|
||||
import json
|
||||
import configparser
|
||||
|
||||
config = configparser.RawConfigParser()
|
||||
config.read("config.cfg")
|
||||
|
||||
token = config._sections["sonarr"]["token"]
|
||||
base_url = config._sections["sonarr"]["url"]
|
||||
|
||||
|
||||
def update_show_in_sonarr(show_id):
    """
    Ask Sonarr to refresh a series.

    Posts a ``RefreshSeries`` command for ``show_id`` to the
    ``/api/v3/command`` endpoint of the configured Sonarr instance and
    prints whether the request was accepted.

    Args:
        show_id: Sonarr series id, sent as ``seriesId`` in the command body.
    """
    url = f"{base_url}/api/v3/command"

    payload = json.dumps({"name": "RefreshSeries", "seriesId": show_id})
    headers = {
        "Content-Type": "application/json",
        "X-Api-Key": token,
    }

    response = requests.request("POST", url, headers=headers, data=payload)

    # Only a non-error status counts as success; the previous "!= 404"
    # check also reported success for 401, 500 and other failures.
    if response.ok:
        print("Updated show in Sonarr")
    else:
        print(f"Failed to update show in Sonarr (HTTP {response.status_code})")
|
@ -2,6 +2,7 @@ import warnings
|
||||
import faster_whisper
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# pylint: disable=R0903
|
||||
class WhisperAI:
|
||||
"""
|
Loading…
x
Reference in New Issue
Block a user