# comfy_fm24_newgens/comfy_fm_newgen.py

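"""Generate Football Manager NewGen face images with a local Stable Diffusion pipeline.

Parses a player list exported from FM as RTF, builds one prompt per player from
app_config.json (skin tone, age, nationality, hair and facial characteristics),
renders a 512x512 portrait per player UID, optionally strips the background,
and creates or updates the config.xml that maps images to players.
"""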

import argparse
import configparser
import json
import logging
import logging.config
import os
import random
import sys

import inflect
import pycountry
from tqdm import tqdm

from lib.rtf_parser import RTF_Parser

# Background removal is optional and depends on the rembg package
try:
    from lib.remove_bg import remove_bg_from_file_list
    REMBG_AVAILABLE = True
except ImportError:
    REMBG_AVAILABLE = False
    print("Warning: Background removal not available")

from lib.generate_xml import create_config_xml, append_to_config_xml
from lib.resize_images import resize_images
from lib.xml_reader import extract_from_values
from lib.text_chunker import chunk_prompt_for_clip

# Profile functions are now handled entirely by GUI
from lib.logging import LOGGING_CONFIG


def save_prompt_mapping(uid, prompt):
    """Save the prompt used for a specific image UID."""
    # NOTE: output_folder is a module global set from the selected profile below
    try:
        prompt_file = os.path.join(output_folder, "prompts.json")
        # Load existing prompts
        if os.path.exists(prompt_file):
            with open(prompt_file, 'r') as f:
                prompts_data = json.load(f)
        else:
            prompts_data = {}
        # Add or update the prompt for this UID
        prompts_data[uid] = prompt
        # Save back to file
        with open(prompt_file, 'w') as f:
            json.dump(prompts_data, f, indent=2)
        logging.debug(f"Saved prompt for UID {uid}")
    except Exception as e:
        logging.warning(f"Failed to save prompt for UID {uid}: {e}")


def get_prompt_for_image(uid):
    """Get the prompt used for a specific image UID."""
    try:
        prompt_file = os.path.join(output_folder, "prompts.json")
        if os.path.exists(prompt_file):
            with open(prompt_file, 'r') as f:
                prompts_data = json.load(f)
            return prompts_data.get(uid, "Prompt not found")
        return "No prompts file found"
    except Exception as e:
        logging.warning(f"Failed to load prompt for UID {uid}: {e}")
        return "Error loading prompt"
# from simple_term_menu import TerminalMenu
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
from PIL import Image
logging.config.dictConfig(LOGGING_CONFIG)

# Module-level defaults; main() reads GUI-supplied overrides from environment
# variables where applicable.
cut = 100  # Maximum number of players taken from the parsed RTF file
update = False  # Superseded by FM_NEWGEN_UPDATE_MODE in main()
use_gpu = False  # Passed to the background-removal step
process_player = False  # Currently unused; filtering is driven by FM_NEWGEN_* vars
force_cpu = False  # Overridden by FM_NEWGEN_FORCE_CPU in main()

# Load user configurations
user_config = configparser.ConfigParser()
try:
    user_config.read("./user_config.cfg")

    # GUI mode: Use default profile or first available profile
    profiles = [section.split(':', 1)[1] for section in user_config.sections() if section.startswith('profile:')]
    if profiles:
        selected_profile = profiles[0]  # Use first available profile
        logging.debug(f"Using profile '{selected_profile}'")
    else:
        selected_profile = "NewGens"  # Default fallback
        logging.debug(f"No profiles found, using default profile '{selected_profile}'")
    selected_profile = f"profile:{selected_profile}"

    output_folder = user_config[selected_profile]["output_dir"]
    logging.debug("Configuration loaded successfully.")
except KeyError as e:
    logging.error(f"Missing configuration key: {e}")
    sys.exit(1)

rtf = RTF_Parser()
p = inflect.engine()
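
# inflect converts numeric ages into words for the prompt template, e.g.
# p.number_to_words(17) -> "seventeen".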


def generate_image(uid, comfy_prompt):
    """Generate an image using local Stable Diffusion."""
    try:
        # Initialize the pipeline once and reuse it across calls
        if not hasattr(generate_image, 'pipeline'):
            logging.info("Loading Stable Diffusion model...")
            # Get model configuration
            try:
                model_id = user_config["models"]["model_name"]
                model_dir = user_config["models"].get("model_dir", None)
                logging.info(f"Using model: {model_id}")
            except KeyError:
                model_id = "SG161222/Realistic_Vision_V6.0_B1"
                model_dir = None
                logging.warning(f"Model configuration not found, using default: {model_id}")

            # Check if CUDA is available and get detailed GPU info
            if torch.cuda.is_available() and not force_cpu:
                device = "cuda"
                gpu_count = torch.cuda.device_count()
                current_device = torch.cuda.current_device()
                gpu_name = torch.cuda.get_device_name(current_device)
                gpu_memory = torch.cuda.get_device_properties(current_device).total_memory / 1024**3  # GB
                logging.info(f"GPU detected: {gpu_name}")
                logging.info(f"GPU memory: {gpu_memory:.1f} GB")
                logging.info(f"Available GPU devices: {gpu_count}")
                logging.info(f"Using device: {device} (GPU {current_device})")
            else:
                device = "cpu"
                if force_cpu:
                    logging.info("Forcing CPU usage as requested")
                else:
                    logging.warning("CUDA not available, using CPU")
                    logging.info("To use GPU: Install CUDA toolkit and ensure PyTorch with CUDA support is installed")
                    logging.info("GPU requirements: https://pytorch.org/get-started/locally/")

            # Load the pipeline
            if model_dir:
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_id,
                    cache_dir=model_dir,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    safety_checker=None,
                    requires_safety_checker=False,
                )
            else:
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    safety_checker=None,
                    requires_safety_checker=False,
                )

            # Use DPMSolverMultistepScheduler for a better quality/speed balance
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

            if device == "cuda":
                pipe = pipe.to("cuda")
                pipe.enable_attention_slicing()  # Reduce memory usage
                # Enable memory-efficient attention if xformers is installed
                try:
                    pipe.enable_xformers_memory_efficient_attention()
                    logging.info("Enabled memory efficient attention")
                except Exception:
                    pass

            generate_image.pipeline = pipe
            generate_image.device = device  # Store device for later use
            logging.info(f"Model loaded successfully on {device}")

        # Generate the image
        logging.debug(f"Generating image for UID: {uid}")

        # Seed a generator so each image gets its own random seed
        generator = torch.Generator(device=generate_image.pipeline.device)
        generator.manual_seed(random.getrandbits(32))

        # Generation parameters mirror the original ComfyUI workflow
        image = generate_image.pipeline(
            comfy_prompt,
            num_inference_steps=6,
            guidance_scale=1.5,
            generator=generator,
            width=512,
            height=512,
        ).images[0]

        # Save the image; os.path.join works whether or not output_dir has a
        # trailing slash
        output_path = os.path.join(user_config[selected_profile]['output_dir'], f"{uid}.png")
        image.save(output_path)
        logging.debug(f"Image generated successfully for UID: {uid}")
    except Exception as e:
        logging.error(f"Failed to generate image for UID: {uid}. Error: {e}")
        raise
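
# Note: the generator above is seeded randomly per image, so runs are not
# reproducible. A minimal sketch for deterministic output (assumes you record
# the seed, e.g. alongside the prompt in prompts.json):
#   generator = torch.Generator(device="cpu").manual_seed(42)
#   image = generate_image.pipeline(prompt, generator=generator).images[0]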


def get_country_name(app_config, country_code):
    """Resolve a country code to a display name, preferring custom mappings."""
    # First check if it's a custom mapping
    if country_code in app_config["facial_characteristics"]:
        return app_config["facial_characteristics"][country_code]
    # Fall back to pycountry for standard ISO 3166-1 alpha-3 codes
    country = pycountry.countries.get(alpha_3=country_code)
    if country:
        return country.name
    return "Unknown Country"


def generate_prompts_for_players(players, app_config):
    """Build one 'uid:prompt' string per player from the app configuration."""
    prompts = []
    for player in players:
        try:
            logging.debug(f"Generating prompt for {player[0]} - {player[8]}")
            os.makedirs(output_folder, exist_ok=True)

            country = get_country_name(app_config, player[1])
            facial_characteristics = random.choice(app_config["facial_characteristics"])
            hair_length = app_config["hair_length"][player[5]]
            hair_colour = app_config["hair_color"][player[6]]
            skin_tone = app_config["skin_tone_map"][player[7]]
            player_age = p.number_to_words(player[3])
            if int(player[5]) > 1:
                hair_extra = random.choice(app_config["hair"])
            else:
                hair_extra = ""

            # Format the prompt
            prompt = app_config["prompt"].format(
                skin_tone=skin_tone,
                age=player_age,
                country=country,
                facial_characteristics=facial_characteristics or "no facial hair",
                hair=f"{hair_length} {hair_colour} {hair_extra}",
            )
            logging.debug(f"Generated prompt: {prompt}")

            # Chunk the prompt if it's too long for CLIP and keep the first
            # (most important) chunk so each player yields exactly one image
            prompt_chunks = chunk_prompt_for_clip(prompt)
            final_prompt = prompt_chunks[0] if prompt_chunks else prompt

            prompts.append(f"{player[0]}:{final_prompt}")
        except KeyError as e:
            logging.warning(f"Key error while generating prompt for player: {e}")
    return prompts
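
# Each returned entry pairs a UID with its prompt text, e.g. (values
# hypothetical, exact wording comes from app_config["prompt"]):
#   "2000123456:portrait of a pale skinned seventeen year old man from Brazil,
#    short black wavy hair, no facial hair"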


def post_process_images(
    output_folder, update, processed_players, football_manager_version
):
    """
    Handle post-processing tasks for generated images.

    Args:
        output_folder (str): Path to the folder where images are stored.
        update (bool): Flag to determine if XML config should be updated.
        processed_players (list): List of processed player IDs.
        football_manager_version (str): Target Football Manager version for
            the generated config XML.
    """
    try:
        # # Resize images to desired dimensions
        # resize_images(output_folder, processed_players)
        # logging.debug("Images resized successfully.")

        # Remove background from images if available
        if REMBG_AVAILABLE:
            try:
                remove_bg_from_file_list(output_folder, processed_players, use_gpu=use_gpu)
                logging.debug("Background removed from images.")
            except Exception as e:
                logging.warning(f"Background removal failed: {e}")
        else:
            logging.info("Background removal not available (rembg not installed). Images will keep their original backgrounds.")

        # Update or create configuration XML
        if update:
            append_to_config_xml(
                output_folder, processed_players, football_manager_version
            )
            logging.debug("Configuration XML updated.")
        else:
            create_config_xml(
                output_folder, processed_players, football_manager_version
            )
            logging.debug("Configuration XML created.")
    except Exception as e:
        logging.error(f"Post-processing failed: {e}")
        raise  # Re-raise so the script stops if post-processing fails
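
# Typical call after a batch completes (a sketch; values hypothetical):
#   post_process_images("output/faces/", update=False,
#                       processed_players=["2000123456"],
#                       football_manager_version="2024")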


def main():
    """Main function for generating images."""
    # CLI parsing is disabled; the GUI drives this script through environment
    # variables. The original argparse setup is kept for reference.
    # parser = argparse.ArgumentParser(description="Generate images for country groups")
    # parser.add_argument(
    #     "--rtf_file",
    #     type=str,
    #     default=None,
    #     help="Path to the RTF file to be processed",
    # )
    # parser.add_argument(
    #     "--player_uuid",
    #     type=int,
    #     default=None,
    #     help="Player UUID to generate",
    # )
    # parser.add_argument(
    #     "--num_inference_steps",
    #     type=int,
    #     default=6,
    #     help="Number of inference steps. Defaults to 6",
    # )
    # args = parser.parse_args()
    # if not args.rtf_file:
    #     logging.error("Please pass in a RTF file as --rtf_file")
    #     sys.exit(1)

    # Load configurations
    try:
        with open("app_config.json", "r") as f:
            app_config = json.load(f)
        logging.debug("Application configuration loaded successfully.")
    except FileNotFoundError:
        logging.error("app_config.json file not found.")
        sys.exit(1)

    # Parse the RTF file
    try:
        # rtf_file = random.sample(rtf.parse_rtf(args.rtf_file), cut)
        rtf_location = user_config[selected_profile]["rtf_file"]
        rtf_file = rtf.parse_rtf(rtf_location)[:cut]
        logging.info(f"Parsed RTF file successfully. Found {len(rtf_file)} players.")
    except FileNotFoundError:
        logging.error(f"RTF file not found: {rtf_location}")
        sys.exit(1)

    # Get parameters from environment variables (set by GUI)
    update_mode = os.environ.get('FM_NEWGEN_UPDATE_MODE', 'false').lower() == 'true'
    process_specific_player = os.environ.get('FM_NEWGEN_PROCESS_PLAYER', 'false').lower() == 'true'
    specific_player_uid = os.environ.get('FM_NEWGEN_PLAYER_UID', '')

    # Check if user wants to force CPU usage
    global force_cpu
    force_cpu = os.environ.get('FM_NEWGEN_FORCE_CPU', 'false').lower() == 'true'
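
    # For reference, the GUI exports these before launching the script; the
    # same can be done from a shell (values hypothetical):
    #   FM_NEWGEN_PROCESS_PLAYER=true FM_NEWGEN_PLAYER_UID=2000123456 \
    #   python comfy_fm_newgen.py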

    # Check for processed players
    try:
        if update_mode:
            values_from_config = extract_from_values(
                os.path.join(user_config[selected_profile]['output_dir'], "config.xml")
            )
            # IDs that already have an image in config.xml
            ids_in_config = list(values_from_config)
            # Skip players whose UID already appears in config.xml
            players_to_process = [item for item in rtf_file if item[0] not in ids_in_config]
            if process_specific_player and specific_player_uid:
                players_to_process = [
                    inner_list
                    for inner_list in players_to_process
                    if int(inner_list[0]) == int(specific_player_uid)
                ]
        elif process_specific_player and specific_player_uid:
            players_to_process = [
                inner_list
                for inner_list in rtf_file
                if int(inner_list[0]) == int(specific_player_uid)
            ]
        else:
            players_to_process = rtf_file
    except FileNotFoundError:
        logging.error("config.xml file not found.")
        sys.exit(1)

    if len(players_to_process) > 0:
        print(f"Processing {len(players_to_process)} players")
        logging.info(f"Processing {len(players_to_process)} players")
        prompts = generate_prompts_for_players(players_to_process, app_config)
        for prompt in tqdm(prompts, desc="Generating Images"):
            # Split on the first colon only, since the prompt text itself may
            # contain colons
            uid, comfy_prompt = prompt.split(":", 1)
            generate_image(uid, comfy_prompt)
            # Save the prompt for this image
            save_prompt_mapping(uid, comfy_prompt)
        try:
            post_process_images(
                output_folder,
                update_mode,  # Use the GUI-supplied flag, not the module default
                [item[0] for item in players_to_process],
                user_config[selected_profile]["football_manager_version"],
            )
        except Exception as e:
            logging.error(f"Post-processing failed: {e}")
    else:
        print("No players to process")
        logging.info("No players to process")
    logging.info("Image generation complete for players in RTF file.")

if __name__ == "__main__":
main()