mirror of
https://github.com/karl0ss/comfy_fm24_newgens.git
synced 2025-10-25 12:43:59 +01:00
138 lines
4.7 KiB
Python
|
|
import re
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
class CLIPTextChunker:
    """
    Utility class for chunking text to fit within CLIP's token limits.

    CLIP models typically have a maximum sequence length of 77 tokens.
    A conservative default limit of 60 is used to leave headroom for the
    special tokens CLIP adds around the prompt.

    NOTE: "tokens" are approximated by *character count* throughout this
    class (see ``estimate_token_count``), so ``max_tokens`` is really a
    character budget, not a true tokenizer count.
    """

    def __init__(self, max_tokens: int = 60):
        """
        Initialize the text chunker.

        Args:
            max_tokens (int): Maximum number of estimated tokens
                (characters — see class note) per chunk. Default is 60,
                a conservative value for CLIP's 77-token limit.
        """
        self.max_tokens = max_tokens

    def estimate_token_count(self, text: str) -> int:
        """
        Estimate the number of tokens in a text string.

        Uses character count as a simple, deterministic proxy for the
        real CLIP tokenizer count.

        Args:
            text (str): Input text

        Returns:
            int: Estimated token count (the length of ``text`` in
                characters)
        """
        # Simple approach: character count is a reliable, dependency-free
        # proxy; trying to mimic the real tokenizer here would be fragile.
        return len(text)

    def chunk_text(self, text: str, preserve_sentences: bool = True) -> List[str]:
        """
        Chunk text into pieces that fit within the token (character) limit.

        Splits on whitespace and greedily packs words into chunks so that
        each chunk's estimated token count stays within ``self.max_tokens``.
        A single word longer than the limit still becomes its own
        (oversized) chunk.

        Args:
            text (str): Input text to chunk
            preserve_sentences (bool): Accepted for API compatibility but
                currently ignored — chunking is word-based only.
                TODO: implement sentence-boundary preservation.

        Returns:
            List[str]: List of text chunks ([] for empty/whitespace input)
        """
        if not text.strip():
            return []

        # Fast path: the whole text already fits in one chunk.
        if self.estimate_token_count(text) <= self.max_tokens:
            return [text]

        chunks: List[str] = []
        current_chunk: List[str] = []
        current_length = 0  # running cost incl. one separator per word

        for word in text.split():
            # +1 accounts for the separating space added when joining.
            word_cost = len(word) + 1

            if current_length + word_cost > self.max_tokens and current_chunk:
                # Adding this word would overflow the budget: flush the
                # current chunk and start a new one with this word.
                chunks.append(" ".join(current_chunk))
                current_chunk = [word]
                current_length = word_cost
            else:
                current_chunk.append(word)
                current_length += word_cost

        # Flush any words left in the final, partially filled chunk.
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def create_priority_chunks(self, text: str, essential_info: List[str]) -> List[str]:
        """
        Create chunks with priority given to essential information.

        For each essential phrase found in ``text``, a chunk is built from
        the phrase plus up to 50 characters of context on either side. If
        several essential chunks result, they are merged into one chunk
        (re-chunked if the merge exceeds the limit).

        Args:
            text (str): Full text to chunk
            essential_info (List[str]): List of essential phrases that should be preserved

        Returns:
            List[str]: Prioritized chunks; falls back to plain
                ``chunk_text(text)`` when no essential phrase matches
        """
        essential_chunks: List[str] = []

        for info in essential_info:
            if info not in text:
                continue

            # Window of up to 50 characters of context on each side of
            # the first occurrence of the essential phrase.
            info_index = text.find(info)
            start = max(0, info_index - 50)
            end = min(len(text), info_index + len(info) + 50)
            context = text[start:end]

            # Take the first (most relevant) chunk of the context.
            # Guard the empty case so a degenerate window cannot raise
            # IndexError (the original indexed [0] unconditionally).
            context_chunks = self.chunk_text(context)
            if not context_chunks:
                continue
            chunk = context_chunks[0]
            if chunk not in essential_chunks:
                essential_chunks.append(chunk)

        # If we collected several essential chunks, try to merge them into
        # a single prompt; re-chunk the merge if it is too long.
        if len(essential_chunks) > 1:
            combined = " ".join(essential_chunks)
            if self.estimate_token_count(combined) <= self.max_tokens:
                return [combined]
            # Combined chunk is too long — split it back down.
            return self.chunk_text(combined)

        return essential_chunks if essential_chunks else self.chunk_text(text)
def chunk_prompt_for_clip(
    prompt: str,
    max_tokens: int = 60,
    essential_info: Optional[List[str]] = None,
) -> List[str]:
    """
    Convenience function to chunk a prompt for CLIP processing.

    Uses a conservative 60 token limit by default to be safe.

    Args:
        prompt (str): The prompt to chunk
        max_tokens (int): Maximum tokens per chunk (default: 60 for safety)
        essential_info (Optional[List[str]]): Phrases to preserve with
            priority. When None (the default), the stock headshot phrases
            used by the newgen generator are applied — this keeps the
            previous hard-coded behavior for existing callers.

    Returns:
        List[str]: List of prompt chunks
    """
    chunker = CLIPTextChunker(max_tokens=max_tokens)

    if essential_info is None:
        # Default essential information for the face-generation prompts.
        essential_info = [
            "Ultra-realistic close-up headshot",
            "male soccer player",
            "looking at the camera",
            "facing the camera",
            "confident expression",
            "soccer jersey",
        ]

    return chunker.create_priority_chunks(prompt, essential_info)