Merge pull request #6 from karl0ss/reworked

Reworked
Karl0ss 2024-07-16 08:31:56 +01:00 committed by GitHub
commit bf069d1fb4
10 changed files with 201 additions and 92 deletions

.vscode/launch.json vendored
View File

@@ -5,16 +5,33 @@
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"name": "Python Debugger: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"CUDA_VISIBLE_DEVICES": "1",
"LD_LIBRARY_PATH": "/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cublas/lib:/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages/nvidia/cudnn/lib"
},
"args": [
"--model",
"base",
],
"base"
]
},
{
"name": "Current (withenv)",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/run_with_env.sh",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"${file}",
"--model",
"base"
]
}
]
}
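
The new "Current (withenv)" configuration delegates environment setup to run_with_env.sh, which is not part of this diff. A minimal sketch of what such a wrapper presumably does, written here in Python for illustration, using the paths from the env block removed above:

```python
#!/usr/bin/env python3
# Hypothetical stand-in for run_with_env.sh (the script itself is not in this
# diff). It exports the CUDA library paths the old launch config set inline,
# then re-executes the interpreter on the target file; the exec'd child
# inherits LD_LIBRARY_PATH, so the NVIDIA libraries resolve there.
import os
import sys

VENV = "/home/karl/faster-auto-subtitle/venv/lib/python3.11/site-packages"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["LD_LIBRARY_PATH"] = ":".join(
    [f"{VENV}/nvidia/cublas/lib", f"{VENV}/nvidia/cudnn/lib"]
)

# sys.argv[1] is the script to run (VS Code passes "${file}"); the rest are its args.
os.execvp(sys.executable, [sys.executable] + sys.argv[1:])
```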

View File

@@ -15,16 +15,16 @@ def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--audio_channel", default="0", type=int, help="audio channel index to use"
)
parser.add_argument(
"--sample_interval",
type=str2timeinterval,
default=None,
help="generate subtitles for a specific \
fragment of the video (e.g. 01:02:05-01:03:45)",
)
# parser.add_argument(
# "--audio_channel", default="0", type=int, help="audio channel index to use"
# )
# parser.add_argument(
# "--sample_interval",
# type=str2timeinterval,
# default=None,
# help="generate subtitles for a specific \
# fragment of the video (e.g. 01:02:05-01:03:45)",
# )
parser.add_argument(
"--model",
default="small",
@@ -38,46 +38,27 @@ def main():
choices=["cpu", "cuda", "auto"],
help='Device to use for computation ("cpu", "cuda", "auto")',
)
# parser.add_argument(
# "--compute_type",
# type=str,
# default="default",
# choices=[
# "int8",
# "int8_float32",
# "int8_float16",
# "int8_bfloat16",
# "int16",
# "float16",
# "bfloat16",
# "float32",
# ],
# help="Type to use for computation. \
# See https://opennmt.net/CTranslate2/quantization.html.",
# )
parser.add_argument(
"--compute_type",
"--show",
type=str,
default="default",
choices=[
"int8",
"int8_float32",
"int8_float16",
"int8_bfloat16",
"int16",
"float16",
"bfloat16",
"float32",
],
help="Type to use for computation. \
See https://opennmt.net/CTranslate2/quantization.html.",
)
parser.add_argument(
"--beam_size",
type=int,
default=5,
help="model parameter, tweak to increase accuracy",
)
parser.add_argument(
"--no_speech_threshold",
type=float,
default=0.6,
help="model parameter, tweak to increase accuracy",
)
parser.add_argument(
"--condition_on_previous_text",
type=str2bool,
default=True,
help="model parameter, tweak to increase accuracy",
)
parser.add_argument(
"--task",
type=str,
default="transcribe",
choices=["transcribe", "translate"],
default=None,
help="whether to perform X->X speech recognition ('transcribe') \
or X->English translation ('translate')",
)
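
Taken together, these hunks strip the faster-whisper tuning flags (--compute_type, --beam_size, --no_speech_threshold, --condition_on_previous_text) and add a --show filter. A minimal sketch of the resulting parser surface; the full definition of "--show" is cut off by the hunk, so a default of None (matching the is-None check in utils/bazarr.py) and an "auto" device default are assumed, and the series title is made up:

```python
import argparse

# Sketch of the reworked CLI surface, not the verbatim parser.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="small")
parser.add_argument("--device", default="auto", choices=["cpu", "cuda", "auto"])
parser.add_argument("--show", type=str, default=None)  # assumed default
parser.add_argument("--task", type=str, default=None)

args = parser.parse_args(["--model", "base", "--show", "Some Series"])
print(vars(args))
# {'model': 'base', 'device': 'auto', 'show': 'Some Series', 'task': None}
```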

View File

@@ -6,14 +6,15 @@ from utils.files import filename, write_srt
from utils.ffmpeg import get_audio, add_subtitles_to_mp4
from utils.bazarr import get_wanted_episodes, get_episode_details, sync_series
from utils.sonarr import update_show_in_sonarr
# from utils.faster_whisper import WhisperAI
from utils.whisper import WhisperAI
from utils.decorator import measure_time
def process(args: dict):
model_name: str = args.pop("model")
language: str = args.pop("language")
sample_interval: str = args.pop("sample_interval")
audio_channel: str = args.pop("audio_channel")
show: str = args.pop("show")
if model_name.endswith(".en"):
warnings.warn(
@@ -25,26 +26,27 @@ def process(args: dict):
args["language"] = language
model_args = {}
model_args["model_size_or_path"] = model_name
model_args["device"] = args.pop("device")
model_args["compute_type"] = args.pop("compute_type")
list_of_episodes_needing_subtitles = get_wanted_episodes()
list_of_episodes_needing_subtitles = get_wanted_episodes(show)
print(
f"Found {list_of_episodes_needing_subtitles['total']} episodes needing subtitles."
)
for episode in list_of_episodes_needing_subtitles["data"]:
print(f"Processing {episode['seriesTitle']} - {episode['episode_number']}")
episode_data = get_episode_details(episode["sonarrEpisodeId"])
audios = get_audio([episode_data["path"]], audio_channel, sample_interval)
try:
audios = get_audio([episode_data["path"]], 0, None)
subtitles = get_subtitles(audios, tempfile.gettempdir(), model_args, args)
add_subtitles_to_mp4(subtitles)
update_show_in_sonarr(episode["sonarrSeriesId"])
time.sleep(5)
sync_series()
except Exception as ex:
print(f"skipping file due to - {ex}")
@measure_time
def get_subtitles(
audio_paths: list, output_dir: str, model_args: dict, transcribe_args: dict
):
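
The body of get_subtitles falls outside the hunk. Based on the imports above (WhisperAI, filename, write_srt), it presumably keeps the upstream faster-auto-subtitle shape: one model instance, one .srt per input file, written into output_dir. A hedged reconstruction, assuming get_audio returns a {video_path: audio_path} mapping (the hunk annotates audio_paths as list, but a mapping is needed for the path association):

```python
import os

@measure_time
def get_subtitles(
    audio_paths: dict, output_dir: str, model_args: dict, transcribe_args: dict
):
    # Presumed body, not shown in this diff.
    model = WhisperAI(model_args, transcribe_args)
    subtitles_path = {}
    for path, audio_path in audio_paths.items():
        srt_path = os.path.join(output_dir, f"{filename(path)}.srt")
        segments = model.transcribe(audio_path)
        with open(srt_path, "w", encoding="utf-8") as srt:
            write_srt(segments, file=srt)
        subtitles_path[path] = srt_path
    return subtitles_path
```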

utils/bazarr.py
View File

@@ -8,7 +8,7 @@ token = config._sections["bazarr"]["token"]
base_url = config._sections["bazarr"]["url"]
def get_wanted_episodes():
def get_wanted_episodes(show: str = None):
url = f"{base_url}/api/episodes/wanted"
payload = {}
@@ -16,7 +16,11 @@ def get_wanted_episodes():
response = requests.request("GET", url, headers=headers, data=payload)
return response.json()
data = response.json()
if show is not None:
data['data'] = [item for item in data['data'] if item['seriesTitle'] == show]
data['total'] = len(data['data'])
return data
def get_episode_details(episode_id: str):
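
The new show parameter filters Bazarr's wanted-episodes payload client-side by exact series title and recomputes the total. A self-contained demonstration of that filter logic against a fabricated payload in the wanted-episodes shape:

```python
# Fabricated payload mirroring Bazarr's /api/episodes/wanted response shape.
data = {
    "data": [
        {"seriesTitle": "Show A", "episode_number": "S01E01"},
        {"seriesTitle": "Show B", "episode_number": "S02E03"},
    ],
    "total": 2,
}

show = "Show A"
if show is not None:
    data["data"] = [item for item in data["data"] if item["seriesTitle"] == show]
    data["total"] = len(data["data"])

print(data["total"])  # 1: only "Show A" episodes remain
```

Note the comparison is exact and case-sensitive on seriesTitle, so the --show value must match Bazarr's title verbatim.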

utils/decorator.py
View File

@@ -0,0 +1,13 @@
import time
from datetime import timedelta
def measure_time(func):
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
human_readable_duration = str(timedelta(seconds=duration))
print(f"Function '{func.__name__}' executed in: {human_readable_duration}")
return result
return wrapper
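
A usage sketch for the new decorator; the output format follows the print inside the wrapper:

```python
import time

@measure_time  # the decorator defined above
def slow_task():
    time.sleep(1.5)
    return "done"

slow_task()
# Function 'slow_task' executed in: 0:00:01.500123  (exact digits will vary)
```

One design note: as written, wrapper replaces the decorated function's name and docstring; decorating wrapper with functools.wraps(func) would preserve them.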

utils/faster_whisper.py
View File

@@ -0,0 +1,68 @@
import warnings
import faster_whisper
from tqdm import tqdm
# pylint: disable=R0903
class WhisperAI:
"""
Wrapper class for the Whisper speech recognition model with additional functionality.
This class provides a high-level interface for transcribing audio files using the Whisper
speech recognition model. It encapsulates the model instantiation and transcription process,
allowing users to easily transcribe audio files and iterate over the resulting segments.
Usage:
```python
whisper = WhisperAI(model_args, transcribe_args)
# Transcribe an audio file and iterate over the segments
for segment in whisper.transcribe(audio_path):
# Process each transcription segment
print(segment)
```
Args:
- model_args: Arguments to pass to WhisperModel initialize method
- model_size_or_path (str): The name of the Whisper model to use.
- device (str): The device to use for computation ("cpu", "cuda", "auto").
- compute_type (str): The type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
- transcribe_args (dict): Additional arguments to pass to the transcribe method.
Attributes:
- model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
- transcribe_args (dict): Additional arguments used for transcribe method.
Methods:
- transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
"""
def __init__(self, model_args: dict, transcribe_args: dict):
# self.model = faster_whisper.WhisperModel(**model_args)
model_size = "base"
self.model = faster_whisper.WhisperModel(model_size, device="cuda")
self.transcribe_args = transcribe_args
def transcribe(self, audio_path: str):
"""
Transcribes the specified audio file and yields the resulting segments.
Args:
- audio_path (str): The path to the audio file for transcription.
Yields:
- faster_whisper.TranscriptionSegment: An individual transcription segment.
"""
warnings.filterwarnings("ignore")
segments, info = self.model.transcribe(audio_path, beam_size=5)
warnings.filterwarnings("default")
# Same precision as the Whisper timestamps.
total_duration = round(info.duration, 2)
with tqdm(total=total_duration, unit=" seconds") as pbar:
for segment in segments:
yield segment
pbar.update(segment.end - segment.start)
pbar.update(0)
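
This faster-whisper backend stays in the tree but is commented out where it was imported above, and its constructor hard-codes a "base" model on CUDA regardless of model_args. Note also that faster-whisper segments expose attributes (segment.start), while the reworked write_srt below indexes dicts (segment["start"]); if this backend were re-enabled, a small adapter along these lines would presumably be needed (illustrative helper, not part of this diff):

```python
def as_dict_segments(segments):
    # Bridge faster-whisper's attribute-style segments to the dict shape
    # the reworked write_srt expects.
    for segment in segments:
        yield {"start": segment.start, "end": segment.end, "text": segment.text}
```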

utils/files.py
View File

@@ -7,9 +7,9 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
for i, segment in enumerate(transcript, start=1):
print(
f"{i}\n"
f"{format_timestamp(segment.start, always_include_hours=True)} --> "
f"{format_timestamp(segment.end, always_include_hours=True)}\n"
f"{segment.text.strip().replace('-->', '->')}\n",
f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
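
With segments now dict-shaped (matching openai-whisper's result["segments"]), a quick sanity check of the SRT formatting, assuming format_timestamp emits the usual SRT comma-millisecond form:

```python
import sys

# Fabricated dict-shaped segments like those the reworked utils/whisper.py yields.
segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 5.0, "text": " On --> to the next line."},
]
write_srt(iter(segments), file=sys.stdout)
# 1
# 00:00:00,000 --> 00:00:02,500
# Hello there.
#
# 2
# 00:00:02,500 --> 00:00:05,000
# On -> to the next line.
```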

utils/whisper.py
View File

@@ -1,9 +1,9 @@
import warnings
import faster_whisper
import torch
import whisper
from tqdm import tqdm
# pylint: disable=R0903
class WhisperAI:
"""
Wrapper class for the Whisper speech recognition model with additional functionality.
@@ -23,23 +23,35 @@ class WhisperAI:
```
Args:
- model_args: Arguments to pass to WhisperModel initialize method
- model_size_or_path (str): The name of the Whisper model to use.
- device (str): The device to use for computation ("cpu", "cuda", "auto").
- compute_type (str): The type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
- model_args (dict): Arguments to pass to Whisper model initialization
- model_size (str): The name of the Whisper model to use.
- device (str): The device to use for computation ("cpu" or "cuda").
- transcribe_args (dict): Additional arguments to pass to the transcribe method.
Attributes:
- model (faster_whisper.WhisperModel): The underlying Whisper speech recognition model.
- model (whisper.Whisper): The underlying Whisper speech recognition model.
- device (torch.device): The device to use for computation.
- transcribe_args (dict): Additional arguments used for transcribe method.
Methods:
- transcribe(audio_path): Transcribes an audio file and yields the resulting segments.
- transcribe(audio_path: str): Transcribes an audio file and yields the resulting segments.
"""
def __init__(self, model_args: dict, transcribe_args: dict):
self.model = faster_whisper.WhisperModel(**model_args)
"""
Initializes the WhisperAI instance.
Args:
- model_args (dict): Arguments to initialize the Whisper model.
- transcribe_args (dict): Additional arguments for the transcribe method.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Set device for computation
self.device = torch.device(device)
# Load the Whisper model with the specified size
self.model = whisper.load_model("base.en").to(self.device)
# Store the additional transcription arguments
self.transcribe_args = transcribe_args
def transcribe(self, audio_path: str):
@@ -50,17 +62,24 @@ class WhisperAI:
- audio_path (str): The path to the audio file for transcription.
Yields:
- faster_whisper.TranscriptionSegment: An individual transcription segment.
- dict: An individual transcription segment.
"""
# Suppress warnings during transcription
warnings.filterwarnings("ignore")
segments, info = self.model.transcribe(audio_path, **self.transcribe_args)
# Load and transcribe the audio file
result = self.model.transcribe(audio_path, **self.transcribe_args)
# Restore default warning behavior
warnings.filterwarnings("default")
# Same precision as the Whisper timestamps.
total_duration = round(info.duration, 2)
# Calculate the total duration from the segments
total_duration = max(segment["end"] for segment in result["segments"])
# Create a progress bar with the total duration of the audio file
with tqdm(total=total_duration, unit=" seconds") as pbar:
for segment in segments:
for segment in result["segments"]:
# Yield each transcription segment
yield segment
pbar.update(segment.end - segment.start)
# Update the progress bar with the duration of the current segment
pbar.update(segment["end"] - segment["start"])
# Ensure the progress bar reaches 100% upon completion
pbar.update(0)
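
Putting the reworked backend together with write_srt from utils/files.py, a hedged end-to-end sketch (the input path is hypothetical; model_args is accepted for API compatibility but unused here, since the model is hard-coded to "base.en"):

```python
import os
import tempfile

whisper_ai = WhisperAI(model_args={}, transcribe_args={})

audio_path = "/tmp/episode_audio.wav"  # hypothetical input file
srt_path = os.path.join(tempfile.gettempdir(), "episode.srt")
with open(srt_path, "w", encoding="utf-8") as srt:
    write_srt(whisper_ai.transcribe(audio_path), file=srt)
```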

requirements.txt
View File

@@ -1,3 +1,9 @@
faster-whisper==0.10.0
tqdm==4.56.0
ffmpeg-python==0.2.0
git+https://github.com/openai/whisper.git
faster-whisper
nvidia-cublas-cu12
nvidia-cudnn-cu12
nvidia-cublas-cu11
nvidia-cudnn-cu11
ctranslate2==3.24.0
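
The dependency changes move in opposite directions: faster-whisper loses its 0.10.0 pin while openai-whisper is installed straight from GitHub, and both CUDA 11 and CUDA 12 builds of cuBLAS/cuDNN are listed, presumably so either toolkit on the host can satisfy the pinned ctranslate2 3.24.0.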

setup.py
View File

@@ -7,7 +7,6 @@ setup(
py_modules=["bazarr-ai-sub-generator"],
author="Karl Hudgell",
install_requires=[
'faster-whisper',
'tqdm',
'ffmpeg-python'
],