# comfy_fm24_newgens/comfy_fm_newgen.py

import argparse
import configparser
import json
import logging
import logging.config
import os
import random
import sys

import inflect
import pycountry
from tqdm import tqdm

from lib.rtf_parser import RTF_Parser

try:
    from lib.remove_bg import remove_bg_from_file_list

    REMBG_AVAILABLE = True
except ImportError:
    REMBG_AVAILABLE = False
    print("Warning: Background removal not available")

from lib.generate_xml import create_config_xml, append_to_config_xml
from lib.resize_images import resize_images
from lib.xml_reader import extract_from_values
from lib.text_chunker import chunk_prompt_for_clip
# Profile functions are now handled entirely by the GUI
from lib.logging import LOGGING_CONFIG


def save_prompt_mapping(uid, prompt):
    """Save the prompt used for a specific image UID."""
    try:
        prompt_file = f"{output_folder}/prompts.json"
        # Load existing prompts
        if os.path.exists(prompt_file):
            with open(prompt_file, "r") as f:
                prompts_data = json.load(f)
        else:
            prompts_data = {}
        # Add or update the prompt for this UID
        prompts_data[uid] = prompt
        # Save back to file
        with open(prompt_file, "w") as f:
            json.dump(prompts_data, f, indent=2)
        logging.debug(f"Saved prompt for UID {uid}")
    except Exception as e:
        logging.warning(f"Failed to save prompt for UID {uid}: {e}")


def get_prompt_for_image(uid):
    """Get the prompt used for a specific image UID."""
    try:
        prompt_file = f"{output_folder}/prompts.json"
        if os.path.exists(prompt_file):
            with open(prompt_file, "r") as f:
                prompts_data = json.load(f)
            return prompts_data.get(uid, "Prompt not found")
        return "No prompts file found"
    except Exception as e:
        logging.warning(f"Failed to load prompt for UID {uid}: {e}")
        return "Error loading prompt"
# from simple_term_menu import TerminalMenu
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
from PIL import Image
logging.config.dictConfig(LOGGING_CONFIG)
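
# Runtime defaults. `cut` caps how many rows of the RTF export are processed;
# `use_gpu` toggles GPU use for background removal; `force_cpu` and the
# update/player settings are overridden via FM_NEWGEN_* environment variables
# read in main().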
cut = 100
update = False
use_gpu = False
process_player = False
force_cpu = False
# Load user configurations
user_config = configparser.ConfigParser()
try:
    user_config.read("./user_config.cfg")

    # GUI mode: use the first available profile, or fall back to a default
    profiles = [
        section.split(":", 1)[1]
        for section in user_config.sections()
        if section.startswith("profile:")
    ]
    if profiles:
        selected_profile = profiles[0]  # Use the first available profile
        logging.debug(f"Using profile '{selected_profile}'")
    else:
        selected_profile = "NewGens"  # Default fallback
        logging.debug(f"No profiles found, using default profile '{selected_profile}'")

    selected_profile = f"profile:{selected_profile}"
    output_folder = user_config[selected_profile]["output_dir"]
    logging.debug("Configuration loaded successfully.")
except KeyError as e:
    logging.error(f"Missing configuration key: {e}")
    sys.exit(1)

rtf = RTF_Parser()
p = inflect.engine()
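# inflect spells numbers out as words, e.g. p.number_to_words(23) == "twenty-three";
# the player's age is written out in words in the prompts below.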


def generate_image(uid, comfy_prompt):
    """Generate an image using local Stable Diffusion."""
    try:
        # Initialize the pipeline once and cache it on the function object
        if not hasattr(generate_image, "pipeline"):
            logging.info("Loading Stable Diffusion model...")
            # Get model configuration
            try:
                model_id = user_config["models"]["model_name"]
                model_dir = user_config["models"].get("model_dir", None)
                logging.info(f"Using model: {model_id}")
            except KeyError:
                model_id = "SG161222/Realistic_Vision_V6.0_B1"
                model_dir = None
                logging.warning(f"Model configuration not found, using default: {model_id}")

            # Check if CUDA is available and log detailed GPU info
            if torch.cuda.is_available() and not force_cpu:
                device = "cuda"
                gpu_count = torch.cuda.device_count()
                current_device = torch.cuda.current_device()
                gpu_name = torch.cuda.get_device_name(current_device)
                gpu_memory = torch.cuda.get_device_properties(current_device).total_memory / 1024**3  # GB
                logging.info(f"GPU detected: {gpu_name}")
                logging.info(f"GPU memory: {gpu_memory:.1f} GB")
                logging.info(f"Available GPU devices: {gpu_count}")
                logging.info(f"Using device: {device} (GPU {current_device})")
            else:
                device = "cpu"
                if force_cpu:
                    logging.info("Forcing CPU usage as requested")
                else:
                    logging.warning("CUDA not available, using CPU")
                    logging.info("To use GPU: install the CUDA toolkit and a CUDA-enabled PyTorch build")
                    logging.info("GPU requirements: https://pytorch.org/get-started/locally/")

            # Load the pipeline; float16 halves memory on GPU, float32 is safer on CPU
            if model_dir:
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_id,
                    cache_dir=model_dir,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    safety_checker=None,
                    requires_safety_checker=False,
                )
            else:
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    safety_checker=None,
                    requires_safety_checker=False,
                )
            # Use DPMSolverMultistepScheduler for a better quality/speed balance
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
            if device == "cuda":
                pipe = pipe.to("cuda")
                pipe.enable_attention_slicing()  # Reduce memory usage
                # Enable memory-efficient attention if xformers is installed
                try:
                    pipe.enable_xformers_memory_efficient_attention()
                    logging.info("Enabled memory efficient attention")
                except Exception:
                    pass
            generate_image.pipeline = pipe
            generate_image.device = device  # Store device for later use
            logging.info(f"Model loaded successfully on {device}")

        # Generate the image
        logging.debug(f"Generating image for UID: {uid}")
        # Seed the generator with a fresh random seed for each image
        generator = torch.Generator(device=generate_image.pipeline.device)
        generator.manual_seed(random.getrandbits(32))
        # Generation parameters mirror the original ComfyUI workflow
        image = generate_image.pipeline(
            comfy_prompt,
            num_inference_steps=6,
            guidance_scale=1.5,
            generator=generator,
            width=512,
            height=512,
        ).images[0]
        # Save the image
        output_path = f"{user_config[selected_profile]['output_dir']}{uid}.png"
        image.save(output_path)
        logging.debug(f"Image generated successfully for UID: {uid}")
    except Exception as e:
        logging.error(f"Failed to generate image for UID: {uid}. Error: {e}")
        raise
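
# Note: generate_image() draws a random, unrecorded seed, so runs are not
# reproducible. A minimal (hypothetical) extension would persist the seed
# alongside the prompt:
#   seed = random.getrandbits(32)
#   generator.manual_seed(seed)
#   save_prompt_mapping(uid, f"{comfy_prompt} [seed={seed}]")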

def get_country_name(app_config, country_code):
    # First check if it's a custom mapping in the app config
    if country_code in app_config["facial_characteristics"]:
        return app_config["facial_characteristics"][country_code]
    # Use pycountry for standard ISO 3166-1 alpha-3 codes
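    # e.g. pycountry.countries.get(alpha_3="ESP").name == "Spain"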
    country = pycountry.countries.get(alpha_3=country_code)
    if country:
        return country.name
    return "Unknown Country"


def generate_prompts_for_players(players, app_config):
    """Build one prompt per player from the parsed RTF data and the app config."""
    prompts = []
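    # Assumed RTF row layout, inferred from the indices used below:
    #   player[0]=UID, player[1]=nationality code, player[3]=age,
    #   player[5]=hair length, player[6]=hair colour, player[7]=skin tone,
    #   player[8]=player name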
    for player in players:
        try:
            logging.debug(f"Generating prompt for {player[0]} - {player[8]}")
            os.makedirs(output_folder, exist_ok=True)
            country = get_country_name(app_config, player[1])
            facial_characteristics = random.choice(app_config["facial_characteristics"])
            hair_length = app_config["hair_length"][player[5]]
            hair_colour = app_config["hair_color"][player[6]]
            skin_tone = app_config["skin_tone_map"][player[7]]
            player_age = p.number_to_words(player[3])
            if int(player[5]) > 1:
                hair_extra = random.choice(app_config["hair"])
            else:
                hair_extra = ""
            # Format the prompt
            prompt = app_config["prompt"].format(
                skin_tone=skin_tone,
                age=player_age,
                country=country,
                facial_characteristics=facial_characteristics or "no facial hair",
                hair=f"{hair_length} {hair_colour} {hair_extra}",
            )
            logging.debug(f"Generated prompt: {prompt}")

            # Chunk the prompt if it's too long for CLIP
            prompt_chunks = chunk_prompt_for_clip(prompt)
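            # (CLIP's text encoder accepts at most 77 tokens; anything beyond
            # that is silently truncated, hence the chunking above.)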
            # Use the first chunk (the most important part) for image generation;
            # this ensures each player generates exactly one image
            final_prompt = prompt_chunks[0] if prompt_chunks else prompt
            prompts.append(f"{player[0]}:{final_prompt}")
        except KeyError as e:
            logging.warning(f"Key error while generating prompt for player: {e}")
    return prompts


def post_process_images(
    output_folder, update, processed_players, football_manager_version
):
    """
    Handles post-processing tasks for generated images.

    Args:
        output_folder (str): Path to the folder where images are stored.
        update (bool): Flag to determine if the XML config should be appended to
            rather than recreated.
        processed_players (list): List of processed player IDs.
        football_manager_version (str): Football Manager version written into the
            XML config.
    """
    try:
        # # Resize images to desired dimensions
        # resize_images(output_folder, processed_players)
        # logging.debug("Images resized successfully.")

        # Remove background from images if available
        if REMBG_AVAILABLE:
            try:
                remove_bg_from_file_list(output_folder, processed_players, use_gpu=use_gpu)
                logging.debug("Background removed from images.")
            except Exception as e:
                logging.warning(f"Background removal failed: {e}")
        else:
            logging.info(
                "Background removal not available (rembg not installed). "
                "Images will keep their original backgrounds."
            )

        # Update or create the configuration XML
        if update:
            append_to_config_xml(
                output_folder, processed_players, football_manager_version
            )
            logging.debug("Configuration XML updated.")
        else:
            create_config_xml(
                output_folder, processed_players, football_manager_version
            )
            logging.debug("Configuration XML created.")
    except Exception as e:
        logging.error(f"Post-processing failed: {e}")
        raise  # Re-raise so the script stops if post-processing fails


def main():
    """Main function for generating images."""
    # parser = argparse.ArgumentParser(description="Generate images for country groups")
    # parser.add_argument(
    #     "--rtf_file",
    #     type=str,
    #     default=None,
    #     help="Path to the RTF file to be processed",
    # )
    # parser.add_argument(
    #     "--player_uuid",
    #     type=int,
    #     default=None,
    #     help="Player UUID to generate",
    # )
    # parser.add_argument(
    #     "--num_inference_steps",
    #     type=int,
    #     default=6,
    #     help="Number of inference steps. Defaults to 6",
    # )
    # args = parser.parse_args()
    # if not args.rtf_file:
    #     logging.error("Please pass in a RTF file as --rtf_file")
    #     sys.exit(1)

    # Load configurations
    try:
        with open("app_config.json", "r") as f:
            app_config = json.load(f)
        logging.debug("Application configuration loaded successfully.")
    except FileNotFoundError:
        logging.error("app_config.json file not found.")
        sys.exit(1)

    # Parse the RTF file
    try:
        # rtf_file = random.sample(rtf.parse_rtf(args.rtf_file), cut)
        rtf_location = user_config[selected_profile]["rtf_file"]
        rtf_file = rtf.parse_rtf(rtf_location)[:cut]
        logging.info(f"Parsed RTF file successfully. Found {len(rtf_file)} players.")
    except FileNotFoundError:
        logging.error(f"RTF file not found: {rtf_location}")
        sys.exit(1)

    # Get parameters from environment variables (set by the GUI)
    update_mode = os.environ.get("FM_NEWGEN_UPDATE_MODE", "false").lower() == "true"
    process_specific_player = (
        os.environ.get("FM_NEWGEN_PROCESS_PLAYER", "false").lower() == "true"
    )
    specific_player_uid = os.environ.get("FM_NEWGEN_PLAYER_UID", "")
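    # Example shell equivalent of what the GUI sets up (illustrative UID):
    #   FM_NEWGEN_UPDATE_MODE=true FM_NEWGEN_PROCESS_PLAYER=true \
    #   FM_NEWGEN_PLAYER_UID=2000123456 python comfy_fm_newgen.py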

    # Check whether the user wants to force CPU usage
    global force_cpu
    force_cpu = os.environ.get("FM_NEWGEN_FORCE_CPU", "false").lower() == "true"

    # Work out which players still need processing
    try:
        if update_mode:
            values_from_config = extract_from_values(
                f"{user_config[selected_profile]['output_dir']}config.xml"
            )
            # UIDs already present in config.xml
            ids_in_config = list(values_from_config)
            # Skip players whose UID already appears in the config
            players_to_process = [
                item for item in rtf_file if item[0] not in ids_in_config
            ]
            if process_specific_player and specific_player_uid:
                players_to_process = [
                    inner_list
                    for inner_list in players_to_process
                    if int(inner_list[0]) == int(specific_player_uid)
                ]
        elif process_specific_player and specific_player_uid:
            players_to_process = [
                inner_list
                for inner_list in rtf_file
                if int(inner_list[0]) == int(specific_player_uid)
            ]
        else:
            players_to_process = rtf_file
    except FileNotFoundError:
        logging.error("config.xml file not found.")
        sys.exit(1)

    if len(players_to_process) > 0:
        print(f"Processing {len(players_to_process)} players")
        logging.info(f"Processing {len(players_to_process)} players")
        prompts = generate_prompts_for_players(players_to_process, app_config)
        for prompt in tqdm(prompts, desc="Generating Images"):
            # Split on the first colon only, in case the prompt itself contains one
            uid, comfy_prompt = prompt.split(":", 1)
            generate_image(uid, comfy_prompt)
            # Save the prompt for this image
            save_prompt_mapping(uid, comfy_prompt)
        try:
            post_process_images(
                output_folder,
                update_mode,  # env-driven flag (the module-level `update` is never set)
                [item[0] for item in players_to_process],
                user_config[selected_profile]["football_manager_version"],
            )
        except Exception as e:
            logging.error(f"Post-processing failed: {e}")
    else:
        print("No new players to process")
        logging.info("No new players to process")
    logging.info("Image generation complete for players in RTF file.")


if __name__ == "__main__":
    main()