mirror of
				https://github.com/karl0ss/comfy_fm24_newgens.git
				synced 2025-10-25 04:33:59 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			360 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			360 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import os
 | |
| import random
 | |
| import json
 | |
| import configparser
 | |
| import pycountry
 | |
| import inflect
 | |
| import logging
 | |
| import sys
 | |
| import logging
 | |
| import logging.config
 | |
| 
 | |
| from tqdm import tqdm
 | |
| from lib.rtf_parser import RTF_Parser
 | |
| try:
 | |
|     from lib.remove_bg import remove_bg_from_file_list
 | |
|     REMBG_AVAILABLE = True
 | |
| except ImportError:
 | |
|     REMBG_AVAILABLE = False
 | |
|     print("Warning: Background removal not available")
 | |
| from lib.generate_xml import create_config_xml, append_to_config_xml
 | |
| from lib.resize_images import resize_images
 | |
| from lib.xml_reader import extract_from_values
 | |
| # Profile functions are now handled entirely by GUI
 | |
| from lib.logging import LOGGING_CONFIG
 | |
| 
 | |
| # from simple_term_menu import TerminalMenu
 | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 | |
| import torch
 | |
| from PIL import Image
 | |
| 
 | |
logging.config.dictConfig(LOGGING_CONFIG)

# Run-wide settings shared by the functions in this module.
cut = 100  # maximum number of player rows taken from the parsed RTF file
update = False
use_gpu = False  # forwarded to remove_bg_from_file_list() during post-processing
process_player = False
# Module-level default for ``force_cpu`` (generate_image() checks a global of
# this name). Honours the GUI's FM_NEWGEN_FORCE_CPU environment setting.
force_cpu = os.environ.get("FM_NEWGEN_FORCE_CPU", "false").lower() == "true"

# Load user configurations
user_config = configparser.ConfigParser()
try:
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed, so an empty result means nothing loaded.
    if not user_config.read("./user_config.cfg"):
        logging.warning("Could not read ./user_config.cfg; no profiles are available.")

    # GUI mode: use the first available profile, or fall back to the default.
    profiles = [
        section.split(":", 1)[1]
        for section in user_config.sections()
        if section.startswith("profile:")
    ]
    if profiles:
        selected_profile = profiles[0]  # Use first available profile
        logging.debug(f"Using profile '{selected_profile}'")
    else:
        selected_profile = "NewGens"  # Default fallback
        logging.debug(f"No profiles found, using default profile '{selected_profile}'")

    selected_profile = f"profile:{selected_profile}"
    output_folder = user_config[selected_profile]["output_dir"]
    logging.debug("Configuration loaded successfully.")
except KeyError as e:
    logging.error(f"Missing configuration key: {e}")
    sys.exit(1)
except configparser.Error as e:
    # Malformed file (missing section header, duplicate section, ...).
    logging.error(f"Invalid configuration file: {e}")
    sys.exit(1)

rtf = RTF_Parser()
p = inflect.engine()
 | |
| 
 | |
| 
 | |
def _load_stable_diffusion_pipeline():
    """Build the diffusers pipeline configured in user_config and return ``(pipeline, device)``."""
    logging.info("Loading Stable Diffusion model...")

    # Model selection comes from the [models] section of user_config.cfg.
    try:
        model_id = user_config["models"]["model_name"]
        model_dir = user_config["models"].get("model_dir", None)
        logging.info(f"Using model: {model_id}")
    except KeyError:
        model_id = "SG161222/Realistic_Vision_V6.0_B1"
        model_dir = None
        logging.warning(f"Model configuration not found, using default: {model_id}")

    # Read the GUI flag directly so this function does not depend on a
    # module-level ``force_cpu`` being defined.
    force_cpu = os.environ.get("FM_NEWGEN_FORCE_CPU", "false").lower() == "true"

    if torch.cuda.is_available() and not force_cpu:
        device = "cuda"
        current_device = torch.cuda.current_device()
        gpu_memory = torch.cuda.get_device_properties(current_device).total_memory / 1024**3  # GB
        logging.info(f"GPU detected: {torch.cuda.get_device_name(current_device)}")
        logging.info(f"GPU memory: {gpu_memory:.1f} GB")
        logging.info(f"Available GPU devices: {torch.cuda.device_count()}")
        logging.info(f"Using device: {device} (GPU {current_device})")
    else:
        device = "cpu"
        if force_cpu:
            logging.info("Forcing CPU usage as requested")
        else:
            logging.warning("CUDA not available, using CPU")
            logging.info("To use GPU: Install CUDA toolkit and ensure PyTorch with CUDA support is installed")
            logging.info("GPU requirements: https://pytorch.org/get-started/locally/")

    # Half precision only on GPU; CPU inference needs float32.
    load_kwargs = {
        "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
        "safety_checker": None,
        "requires_safety_checker": False,
    }
    if model_dir:
        load_kwargs["cache_dir"] = model_dir
    pipe = StableDiffusionPipeline.from_pretrained(model_id, **load_kwargs)

    # Use DPMSolverMultistepScheduler for better quality/speed balance
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    if device == "cuda":
        pipe = pipe.to("cuda")
        pipe.enable_attention_slicing()  # Reduce memory usage
        try:
            pipe.enable_xformers_memory_efficient_attention()
            logging.info("Enabled memory efficient attention")
        except Exception as e:
            # xformers is optional; keep the default attention implementation.
            logging.debug(f"Memory efficient attention not available: {e}")

    return pipe, device


def generate_image(uid, comfy_prompt):
    """Generate one image with local Stable Diffusion and save it as ``<uid>.png``.

    The pipeline is built on the first call and cached on the function object
    (``generate_image.pipeline`` / ``generate_image.device``), so later calls
    reuse the loaded model.

    Args:
        uid: Player unique ID, used as the output file name.
        comfy_prompt: Text prompt passed to the pipeline.

    Raises:
        Exception: Any failure while loading the model, generating, or saving
            is logged and re-raised.
    """
    try:
        if not hasattr(generate_image, "pipeline"):
            generate_image.pipeline, generate_image.device = _load_stable_diffusion_pipeline()
            logging.info(f"Model loaded successfully on {generate_image.device}")

        logging.debug(f"Generating image for UID: {uid}")

        # A fresh random seed per image; log it so a result can be reproduced.
        seed = random.getrandbits(32)
        generator = torch.Generator(device=generate_image.pipeline.device)
        generator.manual_seed(seed)
        logging.debug(f"Using seed {seed} for UID: {uid}")

        # Generate image with parameters similar to ComfyUI workflow
        image = generate_image.pipeline(
            comfy_prompt,
            num_inference_steps=6,
            guidance_scale=1.5,
            generator=generator,
            width=512,
            height=512,
        ).images[0]

        # os.path.join gives the same result when output_dir ends with a
        # separator and still works when it does not.
        output_path = os.path.join(user_config[selected_profile]["output_dir"], f"{uid}.png")
        image.save(output_path)
        logging.debug(f"Image generated successfully for UID: {uid}")

    except Exception as e:
        logging.error(f"Failed to generate image for UID: {uid}. Error: {e}")
        raise
 | |
| 
 | |
| 
 | |
def get_country_name(app_config, country_code):
    """Return a human-readable country name for a player's nationality code.

    Args:
        app_config: Parsed ``app_config.json`` contents.
        country_code: Nationality code from the RTF export. When no custom
            mapping applies it is looked up as an ISO 3166-1 alpha-3 code.

    Returns:
        The custom mapped name, the pycountry country name, or
        ``"Unknown Country"`` when the code cannot be resolved.
    """
    # First check if it's a custom mapping.
    # NOTE(review): this reads the "facial_characteristics" entry, which
    # generate_prompts_for_players() feeds to random.choice() as a sequence;
    # confirm this is the intended key for country overrides.
    custom_names = app_config["facial_characteristics"]
    # Only a dict can be indexed by a country code; indexing a list with a
    # string would raise TypeError.
    if isinstance(custom_names, dict) and country_code in custom_names:
        return custom_names[country_code]

    # Use pycountry for standard country codes
    country = pycountry.countries.get(alpha_3=country_code)
    if country:
        return country.name
    return "Unknown Country"
 | |
| 
 | |
| 
 | |
def generate_prompts_for_players(players, app_config):
    """Build one Stable Diffusion prompt per player.

    Args:
        players: Rows parsed from the RTF file. Indices used here: 0 UID,
            1 nationality code, 3 age, 5 hair-length code, 6 hair-colour code,
            7 skin-tone code, 8 (only logged).
        app_config: Parsed ``app_config.json``; must provide ``prompt``,
            ``facial_characteristics``, ``hair_length``, ``hair_color``,
            ``skin_tone_map`` and ``hair``.

    Returns:
        A list of ``"<uid>:<prompt>"`` strings. Players whose data cannot be
        mapped are logged and skipped.
    """
    prompts = []
    if players:
        # Loop-invariant: make sure the image output directory exists once.
        os.makedirs(output_folder, exist_ok=True)

    for player in players:
        try:
            logging.debug(f"Generating prompt for {player[0]} - {player[8]}")

            country = get_country_name(app_config, player[1])
            facial_characteristics = random.choice(app_config["facial_characteristics"])
            hair_length = app_config["hair_length"][player[5]]
            hair_colour = app_config["hair_color"][player[6]]
            skin_tone = app_config["skin_tone_map"][player[7]]
            player_age = p.number_to_words(player[3])
            # Hair-length codes above 1 get an extra random hair descriptor.
            hair_extra = random.choice(app_config["hair"]) if int(player[5]) > 1 else ""

            # Join only the non-empty parts so the prompt has no stray spaces
            # when there is no extra hair descriptor.
            hair = " ".join(str(part) for part in (hair_length, hair_colour, hair_extra) if part)

            # Format the prompt
            prompt = app_config["prompt"].format(
                skin_tone=skin_tone,
                age=player_age,
                country=country,
                facial_characteristics=facial_characteristics or "no facial hair",
                hair=hair,
            )
            logging.debug(f"Generated prompt: {prompt}")
            prompts.append(f"{player[0]}:{prompt}")
        except (KeyError, IndexError, ValueError) as e:
            # Malformed or unmapped row: skip this player instead of aborting the run.
            logging.warning(
                f"Skipping player {player[0] if player else '?'}: could not generate prompt "
                f"({type(e).__name__}: {e})"
            )
    return prompts
 | |
| 
 | |
| 
 | |
def post_process_images(
    output_folder, update, processed_players, football_manager_version
):
    """
    Finish a generation run: strip image backgrounds, then write the config XML.

    Args:
        output_folder (str): Directory that holds the generated images.
        update (bool): Truthy to append to the existing config XML via
            ``append_to_config_xml``; falsy to create it via ``create_config_xml``.
        processed_players (list): IDs of the players whose images are handled.
        football_manager_version: Passed unchanged to the XML writer.

    Raises:
        Exception: A failure while writing the XML is logged and re-raised.
            Background-removal failures are only logged as warnings.
    """
    try:
        # Image resizing (lib.resize_images) is currently disabled for this step.
        if not REMBG_AVAILABLE:
            logging.info("Background removal not available (rembg not installed). Images will have original backgrounds.")
        else:
            try:
                remove_bg_from_file_list(output_folder, processed_players, use_gpu=use_gpu)
            except Exception as bg_error:
                # Best effort: keep the original backgrounds and carry on.
                logging.warning(f"Background removal failed: {bg_error}")
            else:
                logging.debug("Background removed from images.")

        write_config_xml = append_to_config_xml if update else create_config_xml
        write_config_xml(output_folder, processed_players, football_manager_version)
        logging.debug("Configuration XML updated." if update else "Configuration XML created.")
    except Exception as e:
        logging.error(f"Post-processing failed: {e}")
        raise  # Re-raise the exception to ensure the script stops if post-processing fails.
 | |
| 
 | |
| 
 | |
def _env_flag(name):
    """Return True when environment variable ``name`` equals "true" (any case)."""
    return os.environ.get(name, "false").lower() == "true"


def _load_app_config():
    """Load app_config.json, exiting with status 1 if it is missing or invalid."""
    try:
        with open("app_config.json", "r") as f:
            app_config = json.load(f)
    except FileNotFoundError:
        logging.error("app_config.json file not found.")
        sys.exit(1)
    except json.JSONDecodeError as e:
        logging.error(f"app_config.json is not valid JSON: {e}")
        sys.exit(1)
    logging.debug("Application configuration loaded successfully.")
    return app_config


def _load_players():
    """Parse the selected profile's RTF file and return at most ``cut`` rows."""
    try:
        rtf_location = user_config[selected_profile]["rtf_file"]
    except KeyError as e:
        logging.error(f"Missing configuration key: {e}")
        sys.exit(1)
    try:
        players = rtf.parse_rtf(rtf_location)[:cut]
    except FileNotFoundError:
        logging.error(f"RTF file not found: {rtf_location}")
        sys.exit(1)
    logging.info(f"Parsed RTF file successfully. Found {len(players)} players.")
    return players


def _matches_uid(row, uid):
    """Return True when the row's first field equals ``uid`` as an integer."""
    try:
        return int(row[0]) == uid
    except (ValueError, TypeError, IndexError):
        return False


def _select_players(rtf_players, update_mode, target_uid):
    """Apply the update-mode and single-player filters to the parsed rows.

    Raises:
        FileNotFoundError: In update mode, when config.xml cannot be read.
    """
    players = rtf_players
    if update_mode:
        # Drop rows whose UID is already among the values read from config.xml.
        existing_ids = list(
            extract_from_values(f"{user_config[selected_profile]['output_dir']}config.xml")
        )
        players = [row for row in players if row[0] not in existing_ids]
    if target_uid is not None:
        players = [row for row in players if _matches_uid(row, target_uid)]
    return players


def main():
    """Generate images for the players in the selected profile's RTF file.

    Runtime options come from environment variables set by the GUI:
    FM_NEWGEN_UPDATE_MODE, FM_NEWGEN_PROCESS_PLAYER and FM_NEWGEN_FORCE_CPU
    (enabled when "true"), plus FM_NEWGEN_PLAYER_UID. Exits with status 1
    when the configuration, RTF file, or (in update mode) config.xml is unusable.
    """
    global force_cpu

    app_config = _load_app_config()
    rtf_players = _load_players()

    # Get parameters from environment variables (set by GUI)
    update_mode = _env_flag("FM_NEWGEN_UPDATE_MODE")
    process_specific_player = _env_flag("FM_NEWGEN_PROCESS_PLAYER")
    specific_player_uid = os.environ.get("FM_NEWGEN_PLAYER_UID", "")
    # Expose the flag at module level for functions that look up a global
    # ``force_cpu``.
    force_cpu = _env_flag("FM_NEWGEN_FORCE_CPU")

    target_uid = None
    if process_specific_player and specific_player_uid:
        try:
            target_uid = int(specific_player_uid)
        except ValueError:
            logging.error(f"Invalid player UID: {specific_player_uid!r}")
            sys.exit(1)

    try:
        players_to_process = _select_players(rtf_players, update_mode, target_uid)
    except FileNotFoundError:
        logging.error("config.xml file not found.")
        sys.exit(1)

    if players_to_process:
        print(f"Processing {len(players_to_process)} players")
        logging.info(f"Processing {len(players_to_process)} players")
        prompts = generate_prompts_for_players(players_to_process, app_config)

        generated_uids = set()
        for prompt in tqdm(prompts, desc="Generating Images"):
            # The UID is the text before the first colon; the prompt itself
            # may contain colons, so split only once.
            uid, comfy_prompt = prompt.split(":", 1)
            generate_image(uid, comfy_prompt)
            generated_uids.add(uid)

        # Only post-process players that actually got an image. Rows skipped
        # during prompt generation have no file on disk.
        processed_ids = [row[0] for row in players_to_process if str(row[0]) in generated_uids]
        try:
            post_process_images(
                output_folder,
                update_mode,  # append to config.xml in update mode instead of recreating it
                processed_ids,
                user_config[selected_profile]["football_manager_version"],
            )
        except Exception as e:
            logging.error(f"Post-processing failed: {e}")
    else:
        print(f"{len(players_to_process)} players processed")
        logging.info(f"{len(players_to_process)} players processed")
    logging.info("Image generation complete for players in RTF file.")


if __name__ == "__main__":
    main()
 |