# comfy_fm24_newgens/lib/text_chunker.py
import re
from typing import List, Optional
class CLIPTextChunker:
    """
    Utility class for chunking text to fit within CLIP's token limits.

    CLIP models typically have a maximum sequence length of 77 tokens.
    A conservative default limit of 70 tokens is used to leave headroom.

    Token counts come from the ``transformers`` ``CLIPTokenizer`` when that
    package can be imported; otherwise the character count of the text is
    used in place of a token count.
    """

    # create_priority_chunks only considers an essential phrase that starts
    # within this many characters of the beginning of the text.
    _PRIORITY_WINDOW_CHARS = 100
    # Number of characters of trailing context kept after the essential phrase.
    _PRIORITY_CONTEXT_CHARS = 30

    def __init__(self, max_tokens: int = 70):
        """
        Initialize the text chunker.

        Args:
            max_tokens (int): Maximum number of tokens per chunk
                (default: 70 for CLIP, being conservative)
        """
        self.max_tokens = max_tokens
        self._tokenizer = None
        # Set once the transformers import has failed, so the failing import
        # is not retried on every token count.
        self._tokenizer_unavailable = False

    @property
    def tokenizer(self):
        """
        Lazily load the CLIP tokenizer.

        Returns:
            The loaded tokenizer, or None when ``transformers`` cannot be
            imported (callers then fall back to character counts).
        """
        if self._tokenizer is None and not self._tokenizer_unavailable:
            try:
                from transformers import CLIPTokenizer
                self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            except ImportError:
                # Remember the failure: a failed import is not cached by
                # Python, so retrying it on every call would be costly.
                self._tokenizer_unavailable = True
        return self._tokenizer

    def get_token_count(self, text: str) -> int:
        """
        Get the token count for a text string using the CLIP tokenizer.

        Args:
            text (str): Input text

        Returns:
            int: Length of the tokenizer's ``input_ids`` for the text, or
                ``len(text)`` when no tokenizer is available.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            # Fallback to character count if tokenizer not available
            return len(text)
        # NOTE(review): the count presumably includes the start/end tokens the
        # tokenizer adds by default - confirm against the tokenizer docs.
        tokens = tokenizer(
            text,
            padding=False,
            truncation=False,
            return_tensors=None,
        )
        return len(tokens["input_ids"])

    def _split_oversized_word(self, word: str) -> List[str]:
        """Split one word into consecutive pieces that each fit the limit.

        Every piece holds at least one character, so the split always makes
        progress even when a single character already exceeds the limit.
        """
        pieces = []
        start = 0
        while start < len(word):
            end = start + 1
            # Grow the piece one character at a time while it still fits.
            while end < len(word) and self.get_token_count(word[start:end + 1]) <= self.max_tokens:
                end += 1
            pieces.append(word[start:end])
            start = end
        return pieces

    def chunk_text(self, text: str, preserve_sentences: bool = True) -> List[str]:
        """
        Chunk text into smaller pieces that fit within the token limit.

        Words (whitespace-separated) are packed greedily into chunks, using
        get_token_count to measure each candidate chunk. A single word that
        exceeds the limit on its own is split into pieces that fit, rather
        than being emitted as an oversized chunk.

        Args:
            text (str): Input text to chunk
            preserve_sentences (bool): Accepted for interface compatibility;
                currently not used, chunking is purely word based.

        Returns:
            List[str]: List of text chunks. Empty for blank input; the text
                itself, unchanged, when it already fits. Otherwise words are
                re-joined with single spaces.
        """
        if not text.strip():
            return []
        if self.get_token_count(text) <= self.max_tokens:
            return [text]

        chunks = []
        current = []
        for word in text.split():
            if self.get_token_count(" ".join(current + [word])) <= self.max_tokens:
                current.append(word)
                continue

            # The word does not fit: close the chunk in progress, if any.
            if current:
                chunks.append(" ".join(current))

            if self.get_token_count(word) > self.max_tokens:
                # Too long even on its own: emit the full pieces and keep the
                # last piece open so following words can still join it.
                *full_pieces, tail = self._split_oversized_word(word)
                chunks.extend(full_pieces)
                current = [tail]
            else:
                current = [word]

        # Add the last chunk if it exists
        if current:
            chunks.append(" ".join(current))
        return chunks

    def create_priority_chunks(self, text: str, essential_info: List[str]) -> List[str]:
        """
        Create chunks with priority given to essential information.

        The phrases are tried in order. The first one that starts near the
        beginning of the text, and whose prefix (text start through the phrase
        plus a little trailing context, trimmed back to a space) fits the
        token limit, defines the first chunk. The rest of the text is chunked
        with chunk_text. If no phrase qualifies, the whole text is chunked
        with chunk_text.

        Args:
            text (str): Full text to chunk
            essential_info (List[str]): List of essential phrases that should
                be preserved, in priority order

        Returns:
            List[str]: List of prioritized chunks; ``[text]`` when the text
                already fits within the limit.
        """
        # If text fits within limits, return as-is
        if self.get_token_count(text) <= self.max_tokens:
            return [text]

        for info in essential_info:
            info_index = text.find(info)
            # Skip phrases that are absent or start too far into the text.
            if info_index < 0 or info_index >= self._PRIORITY_WINDOW_CHARS:
                continue

            # Take from the start up to the phrase plus some context after it.
            end_pos = min(len(text), info_index + len(info) + self._PRIORITY_CONTEXT_CHARS)
            candidate = text[:end_pos]
            # Trim back to the last space so the chunk ends between words.
            last_space = candidate.rfind(" ")
            if last_space > 0:
                candidate = candidate[:last_space]

            if not candidate.strip() or self.get_token_count(candidate) > self.max_tokens:
                continue

            chunks = [candidate]
            # The cut happens at a space; strip it so the next chunk does not
            # start with stray whitespace.
            remainder = text[len(candidate):].strip()
            if remainder:
                chunks.extend(self.chunk_text(remainder))
            return chunks

        # Fallback to regular chunking
        return self.chunk_text(text)
def chunk_prompt_for_clip(prompt: str, max_tokens: int = 70) -> List[str]:
    """
    Split a prompt into chunks suitable for CLIP processing.

    Builds a CLIPTextChunker with the given limit and asks it for
    priority chunks, passing a fixed list of key prompt phrases.

    Args:
        prompt (str): The prompt to chunk
        max_tokens (int): Maximum tokens per chunk (default: 70 for safety)

    Returns:
        List[str]: List of prompt chunks
    """
    # Phrases that should be preserved (matching the actual prompt format);
    # their order is the order in which the chunker tries them.
    key_phrases = [
        "Ultra realistic headshot",
        "male soccer player",
        "looking at the camera",
        "facing the camera",
        "Olive skinned",
        "transparent background",
    ]
    return CLIPTextChunker(max_tokens=max_tokens).create_priority_chunks(prompt, key_phrases)