gpu support by flag

This commit is contained in:
Karl Hudgell 2024-12-14 15:49:36 +00:00
parent 02071420ff
commit 8205f6ec73

View File

@ -6,16 +6,13 @@ from concurrent.futures import ThreadPoolExecutor
import onnxruntime as ort import onnxruntime as ort
# Suppress ONNX Runtime logging to show only critical errors def process_images_in_batch(batch, use_gpu):
ort.set_default_logger_severity(3) # 0 = verbose, 1 = info, 2 = warning, 3 = error, 4 = fatal
def process_images_in_batch(batch):
""" """
Process a batch of images: remove their backgrounds and save the results. Process a batch of images: remove their backgrounds and save the results.
Args: Args:
batch (list): List of tuples (input_path, output_path). batch (list): List of tuples (input_path, output_path).
use_gpu (bool): Whether to enable GPU support.
Returns: Returns:
int: Number of images successfully processed in this batch. int: Number of images successfully processed in this batch.
@ -24,14 +21,20 @@ def process_images_in_batch(batch):
for input_path, output_path in batch: for input_path, output_path in batch:
try: try:
with Image.open(input_path) as img: with Image.open(input_path) as img:
output = remove(img) # This will use GPU if ONNX Runtime is GPU-enabled # Initialize ONNX session options with GPU support if required
session_options = ort.SessionOptions()
providers = ["CUDAExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
ort.set_default_logger_severity(3) # Suppress non-critical logging
# Initialize the rembg remove function with appropriate providers
output = remove(img, session_options=session_options, providers=providers)
output.save(output_path) output.save(output_path)
success_count += 1 success_count += 1
except Exception as e: except Exception as e:
print(f"Error processing {input_path}: {str(e)}") print(f"Error processing {input_path}: {str(e)}")
return success_count return success_count
def remove_bg_from_files_in_dir(directory, max_workers=2, batch_size=5): def remove_bg_from_files_in_dir(directory, max_workers=2, batch_size=3, use_gpu=False):
""" """
Process all JPG, JPEG, and PNG images in the given directory and its subfolders using parallel processing and GPU. Process all JPG, JPEG, and PNG images in the given directory and its subfolders using parallel processing and GPU.
@ -39,6 +42,7 @@ def remove_bg_from_files_in_dir(directory, max_workers=2, batch_size=5):
directory (str): Path to the directory containing images. directory (str): Path to the directory containing images.
max_workers (int): Maximum number of threads to use for parallel processing. max_workers (int): Maximum number of threads to use for parallel processing.
batch_size (int): Number of images to process per batch. batch_size (int): Number of images to process per batch.
use_gpu (bool): Whether to enable GPU support.
Returns: Returns:
int: The number of images successfully processed. int: The number of images successfully processed.
@ -60,8 +64,11 @@ def remove_bg_from_files_in_dir(directory, max_workers=2, batch_size=5):
batches = [files_to_process[i:i + batch_size] for i in range(0, len(files_to_process), batch_size)] batches = [files_to_process[i:i + batch_size] for i in range(0, len(files_to_process), batch_size)]
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
with tqdm(total=len(files_to_process), desc="Processing images", unit="image") as pbar: with tqdm(total=len(files_to_process), desc="Removing Backgrounds", unit="image") as pbar:
futures = {executor.submit(process_images_in_batch, batch): batch for batch in batches} futures = {
executor.submit(process_images_in_batch, batch, use_gpu): batch
for batch in batches
}
for future in futures: for future in futures:
processed_count += future.result() processed_count += future.result()