mirror of
https://github.com/karl0ss/comfy_fm24_newgens.git
synced 2025-10-03 06:40:06 +01:00
text chunker
This commit is contained in:
parent
36ff97e44a
commit
fd999ec1e6
@ -5,37 +5,58 @@ class CLIPTextChunker:
|
||||
"""
|
||||
Utility class for chunking text to fit within CLIP's token limits.
|
||||
CLIP models typically have a maximum sequence length of 77 tokens.
|
||||
Using a conservative limit of 60 tokens to account for special tokens.
|
||||
Using a conservative limit of 70 tokens to account for special tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, max_tokens: int = 60):
|
||||
def __init__(self, max_tokens: int = 70):
|
||||
"""
|
||||
Initialize the text chunker.
|
||||
|
||||
Args:
|
||||
max_tokens (int): Maximum number of tokens per chunk (default: 60 for CLIP, being conservative)
|
||||
max_tokens (int): Maximum number of tokens per chunk (default: 70 for CLIP, being conservative)
|
||||
"""
|
||||
self.max_tokens = max_tokens
|
||||
self._tokenizer = None
|
||||
|
||||
def estimate_token_count(self, text: str) -> int:
|
||||
@property
|
||||
def tokenizer(self):
|
||||
"""Lazy load CLIP tokenizer"""
|
||||
if self._tokenizer is None:
|
||||
try:
|
||||
from transformers import CLIPTokenizer
|
||||
self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
except ImportError:
|
||||
# Fallback to character-based estimation if transformers not available
|
||||
self._tokenizer = None
|
||||
return self._tokenizer
|
||||
|
||||
def get_token_count(self, text: str) -> int:
|
||||
"""
|
||||
Estimate the number of tokens in a text string.
|
||||
Uses character count as a simple proxy for token count.
|
||||
Get the actual token count for a text string using CLIP tokenizer.
|
||||
|
||||
Args:
|
||||
text (str): Input text
|
||||
|
||||
Returns:
|
||||
int: Estimated token count (using character count as proxy)
|
||||
int: Actual token count
|
||||
"""
|
||||
# Simple approach: use character count as a proxy for token count
|
||||
# This is much more reliable than trying to estimate actual tokens
|
||||
return len(text)
|
||||
if self.tokenizer is None:
|
||||
# Fallback to character count if tokenizer not available
|
||||
return len(text)
|
||||
|
||||
tokens = self.tokenizer(
|
||||
text,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
return_tensors=None
|
||||
)
|
||||
|
||||
return len(tokens["input_ids"])
|
||||
|
||||
def chunk_text(self, text: str, preserve_sentences: bool = True) -> List[str]:
|
||||
"""
|
||||
Chunk text into smaller pieces that fit within the token limit.
|
||||
Uses character count as a simple and reliable approach.
|
||||
Uses actual CLIP tokenization for accuracy.
|
||||
|
||||
Args:
|
||||
text (str): Input text to chunk
|
||||
@ -47,26 +68,29 @@ class CLIPTextChunker:
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
if self.estimate_token_count(text) <= self.max_tokens:
|
||||
if self.get_token_count(text) <= self.max_tokens:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
words = text.split()
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
current_tokens = 0
|
||||
|
||||
for word in words:
|
||||
word_with_space = word + " "
|
||||
|
||||
# If adding this word would exceed the limit, start a new chunk
|
||||
if current_length + len(word_with_space) > self.max_tokens and current_chunk:
|
||||
# Join the current chunk and add it
|
||||
# Check if adding this word would exceed the limit
|
||||
test_chunk = " ".join(current_chunk + [word])
|
||||
test_tokens = self.get_token_count(test_chunk)
|
||||
|
||||
if test_tokens > self.max_tokens and current_chunk:
|
||||
# Current chunk is complete, add it
|
||||
chunks.append(" ".join(current_chunk))
|
||||
current_chunk = [word]
|
||||
current_length = len(word_with_space)
|
||||
current_tokens = self.get_token_count(word)
|
||||
else:
|
||||
current_chunk.append(word)
|
||||
current_length += len(word_with_space)
|
||||
current_tokens = test_tokens
|
||||
|
||||
# Add the last chunk if it exists
|
||||
if current_chunk:
|
||||
@ -85,54 +109,72 @@ class CLIPTextChunker:
|
||||
Returns:
|
||||
List[str]: List of prioritized chunks
|
||||
"""
|
||||
# First, try to create chunks that include essential information
|
||||
essential_chunks = []
|
||||
# If text fits within limits, return as-is
|
||||
if self.get_token_count(text) <= self.max_tokens:
|
||||
return [text]
|
||||
|
||||
# Find the most important essential information at the beginning
|
||||
# Look for key phrases that should be preserved
|
||||
first_chunk = ""
|
||||
remaining_text = text
|
||||
|
||||
# Try to find essential info near the beginning
|
||||
for info in essential_info:
|
||||
if info in text:
|
||||
# Create a chunk focused on this essential info
|
||||
info_index = text.find(info)
|
||||
start = max(0, info_index - 50) # Include some context before
|
||||
end = min(len(text), info_index + len(info) + 50) # Include some context after
|
||||
context = text[start:end]
|
||||
# If the essential info is near the beginning, include it
|
||||
if info_index < 100: # Within first 100 characters
|
||||
# Take from start up to and including the essential info
|
||||
end_pos = min(len(text), info_index + len(info) + 30) # Include some context after
|
||||
candidate_chunk = text[:end_pos]
|
||||
|
||||
chunk = self.chunk_text(context)[0] # Take the first (most relevant) chunk
|
||||
if chunk not in essential_chunks:
|
||||
essential_chunks.append(chunk)
|
||||
# Ensure the candidate chunk ends at a word boundary
|
||||
last_space = candidate_chunk.rfind(" ")
|
||||
if last_space > 0:
|
||||
candidate_chunk = candidate_chunk[:last_space]
|
||||
|
||||
# If we have too many essential chunks, combine them
|
||||
if len(essential_chunks) > 1:
|
||||
combined = " ".join(essential_chunks)
|
||||
if self.estimate_token_count(combined) <= self.max_tokens:
|
||||
return [combined]
|
||||
else:
|
||||
# Need to reduce the combined chunk
|
||||
return self.chunk_text(combined)
|
||||
# Use the basic chunking to ensure proper word boundaries
|
||||
if self.get_token_count(candidate_chunk) <= self.max_tokens:
|
||||
# Use chunk_text to get a properly bounded chunk
|
||||
temp_chunks = self.chunk_text(candidate_chunk)
|
||||
if temp_chunks:
|
||||
first_chunk = temp_chunks[0]
|
||||
remaining_text = text[len(first_chunk):]
|
||||
break
|
||||
|
||||
return essential_chunks if essential_chunks else self.chunk_text(text)
|
||||
# If we found a good first chunk, use it
|
||||
if first_chunk and self.get_token_count(first_chunk) <= self.max_tokens:
|
||||
chunks = [first_chunk]
|
||||
# Add remaining text as additional chunks if needed
|
||||
if remaining_text.strip():
|
||||
chunks.extend(self.chunk_text(remaining_text))
|
||||
return chunks
|
||||
|
||||
def chunk_prompt_for_clip(prompt: str, max_tokens: int = 60) -> List[str]:
|
||||
# Fallback to regular chunking
|
||||
return self.chunk_text(text)
|
||||
|
||||
def chunk_prompt_for_clip(prompt: str, max_tokens: int = 70) -> List[str]:
|
||||
"""
|
||||
Convenience function to chunk a prompt for CLIP processing.
|
||||
Uses a conservative 60 token limit to be safe.
|
||||
Uses a conservative 70 token limit to be safe.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to chunk
|
||||
max_tokens (int): Maximum tokens per chunk (default: 60 for safety)
|
||||
max_tokens (int): Maximum tokens per chunk (default: 70 for safety)
|
||||
|
||||
Returns:
|
||||
List[str]: List of prompt chunks
|
||||
"""
|
||||
chunker = CLIPTextChunker(max_tokens=max_tokens)
|
||||
|
||||
# Define essential information that should be preserved
|
||||
# Define essential information that should be preserved (matching actual prompt format)
|
||||
essential_info = [
|
||||
"Ultra-realistic close-up headshot",
|
||||
"Ultra realistic headshot",
|
||||
"male soccer player",
|
||||
"looking at the camera",
|
||||
"facing the camera",
|
||||
"confident expression",
|
||||
"soccer jersey"
|
||||
"Olive skinned",
|
||||
"transparent background"
|
||||
]
|
||||
|
||||
return chunker.create_priority_chunks(prompt, essential_info)
|
Loading…
x
Reference in New Issue
Block a user