Merge pull request #6 from karl0ss/reworked

Reworked
Karl0ss 2024-07-16 08:31:56 +01:00 committed by GitHub
commit bf069d1fb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 201 additions and 92 deletions

.vscode/launch.json vendored
View File

@@ -5,16 +5,33 @@
     "version": "0.2.0",
    "configurations": [
        {
-            "name": "Python: Current File",
+            "name": "Python Debugger: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": false,
+            "env": {
+                "CUDA_VISIBLE_DEVICES": "1",
+                "LD_LIBRARY_PATH": "/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cublas/lib:/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cudnn/lib"
+            },
            "args": [
                "--model",
-                "base",
-            ],
+                "base"
+            ]
+        },
+        {
+            "name": "Current (withenv)",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${workspaceFolder}/run_with_env.sh",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+            "args": [
+                "${file}",
+                "--model",
+                "base"
+            ]
        }
    ]
 }
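
The new `env` block mirrors the library-path workaround commonly recommended for faster-whisper/CTranslate2 when cuBLAS and cuDNN come from pip wheels. As a hedged sketch (not code from this commit), the same directories can be discovered programmatically, assuming the nvidia-cublas/nvidia-cudnn wheels from requirements.txt are installed:

# Sketch only: print the LD_LIBRARY_PATH entries for the pip-installed CUDA libraries.
# Mirrors the upstream faster-whisper recommendation; not part of this repository.
import os
import nvidia.cublas.lib
import nvidia.cudnn.lib

lib_dirs = [
    os.path.dirname(nvidia.cublas.lib.__file__),
    os.path.dirname(nvidia.cudnn.lib.__file__),
]
print(":".join(lib_dirs))  # paste into LD_LIBRARY_PATH (e.g. in a wrapper such as run_with_env.sh)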

View File

@@ -15,16 +15,16 @@ def main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    parser.add_argument(
-        "--audio_channel", default="0", type=int, help="audio channel index to use"
-    )
-    parser.add_argument(
-        "--sample_interval",
-        type=str2timeinterval,
-        default=None,
-        help="generate subtitles for a specific \
-            fragment of the video (e.g. 01:02:05-01:03:45)",
-    )
+    # parser.add_argument(
+    #     "--audio_channel", default="0", type=int, help="audio channel index to use"
+    # )
+    # parser.add_argument(
+    #     "--sample_interval",
+    #     type=str2timeinterval,
+    #     default=None,
+    #     help="generate subtitles for a specific \
+    #         fragment of the video (e.g. 01:02:05-01:03:45)",
+    # )
     parser.add_argument(
         "--model",
         default="small",
@@ -38,46 +38,27 @@ def main():
         choices=["cpu", "cuda", "auto"],
         help='Device to use for computation ("cpu", "cuda", "auto")',
     )
+    # parser.add_argument(
+    #     "--compute_type",
+    #     type=str,
+    #     default="default",
+    #     choices=[
+    #         "int8",
+    #         "int8_float32",
+    #         "int8_float16",
+    #         "int8_bfloat16",
+    #         "int16",
+    #         "float16",
+    #         "bfloat16",
+    #         "float32",
+    #     ],
+    #     help="Type to use for computation. \
+    #         See https://opennmt.net/CTranslate2/quantization.html.",
+    # )
     parser.add_argument(
-        "--compute_type",
+        "--show",
         type=str,
-        default="default",
-        choices=[
-            "int8",
-            "int8_float32",
-            "int8_float16",
-            "int8_bfloat16",
-            "int16",
-            "float16",
-            "bfloat16",
-            "float32",
-        ],
-        help="Type to use for computation. \
-            See https://opennmt.net/CTranslate2/quantization.html.",
-    )
-    parser.add_argument(
-        "--beam_size",
-        type=int,
-        default=5,
-        help="model parameter, tweak to increase accuracy",
-    )
-    parser.add_argument(
-        "--no_speech_threshold",
-        type=float,
-        default=0.6,
-        help="model parameter, tweak to increase accuracy",
-    )
-    parser.add_argument(
-        "--condition_on_previous_text",
-        type=str2bool,
-        default=True,
-        help="model parameter, tweak to increase accuracy",
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        default="transcribe",
-        choices=["transcribe", "translate"],
+        default=None,
         help="whether to perform X->X speech recognition ('transcribe') \
             or X->English translation ('translate')",
     )
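
With this hunk the CLI drops the model-tuning flags and gains a `--show` filter. A minimal, illustrative parse follows; this is not the repo's actual `main()` (the real parser defines more flags, e.g. --language), and the defaults shown are placeholders:

import argparse

# Illustrative subset of the reworked flags.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="small")
parser.add_argument("--device", default="auto", choices=["cpu", "cuda", "auto"])
parser.add_argument("--show", type=str, default=None)

args = parser.parse_args(["--model", "base", "--show", "Some Show"])
print(vars(args))  # {'model': 'base', 'device': 'auto', 'show': 'Some Show'}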

View File

@@ -6,15 +6,16 @@ from utils.files import filename, write_srt
 from utils.ffmpeg import get_audio, add_subtitles_to_mp4
 from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
 from utils.sonarr import update_show_in_sonarr
+# from utils.faster_whisper import WhisperAI
 from utils.whisper import WhisperAI
+from utils.decorator import measure_time
 
 
 def process(args: dict):
     model_name: str = args.pop("model")
     language: str = args.pop("language")
-    sample_interval: str = args.pop("sample_interval")
-    audio_channel: str = args.pop("audio_channel")
+    show: str = args.pop("show")
 
     if model_name.endswith(".en"):
         warnings.warn(
             f"{model_name} is an English-only model, forcing English detection."
@@ -25,26 +26,27 @@ def process(args: dict):
     args["language"] = language
 
     model_args = {}
-    model_args["model_size_or_path"] = model_name
     model_args["device"] = args.pop("device")
-    model_args["compute_type"] = args.pop("compute_type")
 
-    list_of_episodes_needing_subtitles = get_wanted_episodes()
+    list_of_episodes_needing_subtitles = get_wanted_episodes(show)
     print(
         f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles."
     )
     for episode in list_of_episodes_needing_subtitles["data"]:
         print(f"Processing {episode['seriesTitle']} - {episode['episode_number']}")
         episode_data = get_episode_details(episode["sonarrEpisodeId"])
-        audios = get_audio([episode_data["path"]], audio_channel, sample_interval)
-        subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
-
-        add_subtitles_to_mp4(subtitles)
-        update_show_in_sonarr(episode["sonarrSeriesId"])
-        time.sleep(5)
-        sync_series()
+        try:
+            audios = get_audio([episode_data["path"]], 0, None)
+            subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
+
+            add_subtitles_to_mp4(subtitles)
+            update_show_in_sonarr(episode["sonarrSeriesId"])
+            time.sleep(5)
+            sync_series()
+        except Exception as ex:
+            print(f"skipping file due to - {ex}")
 
 
+@measure_time
 def get_subtitles(
     audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
 ):

View File

@@ -8,15 +8,19 @@ token = config._sections["bazarr"]["token"]
 base_url = config._sections["bazarr"]["url"]
 
 
-def get_wanted_episodes():
+def get_wanted_episodes(show: str=None):
     url = f"{base_url}/api/episodes/wanted"
 
     payload = {}
     headers = {"accept": "application/json", "X-API-KEY": token}
 
     response = requests.request("GET", url, headers=headers, data=payload)
-    return response.json()
+    data = response.json()
+    if show != None:
+        data['data'] = [item for item in data['data'] if item['seriesTitle'] == show]
+        data['total'] = len(data['data'])
+    return data
 
 
 def get_episode_details(episode_id: str):
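
For illustration, this is how the new `show` parameter narrows the Bazarr response; the payload below is invented sample data, but the filtering matches the lines added above:

data = {
    "data": [
        {"seriesTitle": "Show A", "episode_number": "S01E01"},
        {"seriesTitle": "Show B", "episode_number": "S02E05"},
    ],
    "total": 2,
}
show = "Show A"
if show is not None:  # the commit itself spells this `show != None`
    data["data"] = [item for item in data["data"] if item["seriesTitle"] == show]
    data["total"] = len(data["data"])
print(data["total"])  # 1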

View File

@@ -0,0 +1,13 @@
+import time
+from datetime import timedelta
+
+def measure_time(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        duration = end_time - start_time
+        human_readable_duration = str(timedelta(seconds=duration))
+        print(f"Function '{func.__name__}' executed in: {human_readable_duration}")
+        return result
+    return wrapper
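
A short usage sketch for the new decorator; `slow_task` is a placeholder function, not something in the repo:

import time
from utils.decorator import measure_time

@measure_time
def slow_task():
    time.sleep(1.5)

slow_task()
# Prints something like: Function 'slow_task' executed in: 0:00:01.500123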

View File

@@ -0,0 +1,68 @@
+import warnings
+import faster_whisper
+from tqdm import tqdm
+
+
+# pylint: disable=R0903
+class WhisperAI:
+    """
+    Wrapper class for the Whisper speech recognition model with additional functionality.
+
+    This class provides a high-level interface for transcribing audio files using the Whisper
+    speech recognition model. It encapsulates the model instantiation and transcription process,
+    allowing users to easily transcribe audio files and iterate over the resulting segments.
+
+    Usage:
+    ```python
+    whisper = WhisperAI(model_args, transcribe_args)
+
+    # Transcribe an audio file and iterate over the segments
+    for segment in whisper.transcribe(audio_path):
+        # Process each transcription segment
+        print(segment)
+    ```
+
+    Args:
+    - model_args: Arguments to pass to WhisperModel initialize method
+        - model_size_or_path (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu", "cuda", "auto").
+        - compute_type (str): The type to use for computation.
+            See https://opennmt.net/CTranslate2/quantization.html.
+    - transcribe_args (dict): Additional arguments to pass to the transcribe method.
+
+    Attributes:
+    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - transcribe_args (dict): Additional arguments used for transcribe method.
+
+    Methods:
+    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    """
+
+    def __init__(self, model_args: dict, transcribe_args: dict):
+        # self.model = faster_whisper.WhisperModel(**model_args)
+        model_size = "base"
+        self.model = faster_whisper.WhisperModel(model_size, device="cuda")
+        self.transcribe_args = transcribe_args
+
+    def transcribe(self, audio_path: str):
+        """
+        Transcribes the specified audio file and yields the resulting segments.
+
+        Args:
+        - audio_path (str): The path to the audio file for transcription.
+
+        Yields:
+        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        """
+        warnings.filterwarnings("ignore")
+        segments, info = self.model.transcribe(audio_path, beam_size=5)
+        warnings.filterwarnings("default")
+
+        # Same precision as the Whisper timestamps.
+        total_duration = round(info.duration, 2)
+
+        with tqdm(total=total_duration, unit=" seconds") as pbar:
+            for segment in segments:
+                yield segment
+                pbar.update(segment.end - segment.start)
+            pbar.update(0)
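
Judging by the commented-out import in the process() hunk above, this new module lives at utils/faster_whisper.py. A hypothetical invocation follows; the audio path is a placeholder, and since model size and device are currently hard-coded in __init__, model_args is effectively unused:

from utils.faster_whisper import WhisperAI  # module path assumed from the commented import above

whisper_ai = WhisperAI(model_args={}, transcribe_args={})
for segment in whisper_ai.transcribe("/tmp/episode_audio.wav"):
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")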

View File

@@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
     for i, segment in enumerate(transcript, start=1):
         print(
             f"{i}\n"
-            f"{format_timestamp(segment.start, always_include_hours=True)} --> "
-            f"{format_timestamp(segment.end, always_include_hours=True)}\n"
-            f"{segment.text.strip().replace('-->', '->')}\n",
+            f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
+            f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
+            f"{segment['text'].strip().replace('-->', '->')}\n",
             file=file,
             flush=True,
         )
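
write_srt now indexes segments as dicts rather than objects, matching the shape openai-whisper returns. An invented example of the expected input, with roughly the SRT block it would produce (exact timestamp formatting depends on the existing format_timestamp helper):

segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 4.0, "text": " General Kenobi."},
]
# write_srt(iter(segments), file=some_file) would emit numbered blocks roughly like:
# 1
# 00:00:00,000 --> 00:00:02,500
# Hello there.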

View File

@@ -1,9 +1,9 @@
 import warnings
-import faster_whisper
+import torch
+import whisper
 from tqdm import tqdm
 
 
-# pylint: disable=R0903
 class WhisperAI:
     """
     Wrapper class for the Whisper speech recognition model with additional functionality.
@@ -23,23 +23,35 @@ class WhisperAI:
     ```
 
     Args:
-    - model_args: Arguments to pass to WhisperModel initialize method
-        - model_size_or_path (str): The name of the Whisper model to use.
-        - device (str): The device to use for computation ("cpu", "cuda", "auto").
-        - compute_type (str): The type to use for computation.
-            See https://opennmt.net/CTranslate2/quantization.html.
+    - model_args (dict): Arguments to pass to Whisper model initialization
+        - model_size (str): The name of the Whisper model to use.
+        - device (str): The device to use for computation ("cpu" or "cuda").
     - transcribe_args (dict): Additional arguments to pass to the transcribe method.
 
     Attributes:
-    - model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
+    - model (whisper.Whisper): The underlying Whisper speech recognition model.
+    - device (torch.device): The device to use for computation.
     - transcribe_args (dict): Additional arguments used for transcribe method.
 
     Methods:
-    - transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
+    - transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
     """
 
     def __init__(self, model_args: dict, transcribe_args: dict):
-        self.model = faster_whisper.WhisperModel(**model_args)
+        """
+        Initializes the WhisperAI instance.
+
+        Args:
+        - model_args (dict): Arguments to initialize the Whisper model.
+        - transcribe_args (dict): Additional arguments for the transcribe method.
+        """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(device)
+        # Set device for computation
+        self.device = torch.device(device)
+
+        # Load the Whisper model with the specified size
+        self.model = whisper.load_model("base.en").to(self.device)
+
+        # Store the additional transcription arguments
         self.transcribe_args = transcribe_args
def transcribe(self, audio_path: str): def transcribe(self, audio_path: str):
@@ -50,17 +62,24 @@ class WhisperAI:
         - audio_path (str): The path to the audio file for transcription.
 
         Yields:
-        - faster_whisper.TranscriptionSegment: An individual transcription segment.
+        - dict: An individual transcription segment.
         """
+        # Suppress warnings during transcription
         warnings.filterwarnings("ignore")
-        segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Load and transcribe the audio file
+        result = self.model.transcribe(audio_path, **self.transcribe_args)
+        # Restore default warning behavior
         warnings.filterwarnings("default")
 
-        # Same precision as the Whisper timestamps.
-        total_duration = round(info.duration, 2)
+        # Calculate the total duration from the segments
+        total_duration = max(segment["end"] for segment in result["segments"])
 
+        # Create a progress bar with the total duration of the audio file
         with tqdm(total=total_duration, unit=" seconds") as pbar:
-            for segment in segments:
+            for segment in result["segments"]:
+                # Yield each transcription segment
                 yield segment
-                pbar.update(segment.end - segment.start)
+                # Update the progress bar with the duration of the current segment
+                pbar.update(segment["end"] - segment["start"])
+            # Ensure the progress bar reaches 100% upon completion
             pbar.update(0)
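
The reworked transcribe() iterates the dict that openai-whisper's model.transcribe() returns instead of a (segments, info) tuple. A minimal sketch of the result shape the code above relies on (values invented):

result = {
    "text": " full transcript ...",
    "language": "en",
    "segments": [
        {"id": 0, "start": 0.0, "end": 3.2, "text": " first segment"},
        {"id": 1, "start": 3.2, "end": 7.8, "text": " second segment"},
    ],
}

total_duration = max(segment["end"] for segment in result["segments"])  # 7.8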

View File

@@ -1,3 +1,9 @@
-faster-whisper==0.10.0
 tqdm==4.56.0
 ffmpeg-python==0.2.0
+git+https://github.com/openai/whisper.git
+faster-whisper
+nvidia-cublas-cu12
+nvidia-cudnn-cu12
+nvidia-cublas-cu11
+nvidia-cudnn-cu11
+ctranslate2==3.24.0

View File

@@ -7,7 +7,6 @@ setup(
     py_modules=["bazarr-ai-sub-generator"],
     author="Karl Hudgell",
     install_requires=[
-        'faster-whisper',
         'tqdm',
         'ffmpeg-python'
     ],