text chunker

Karl 2025-09-23 15:44:13 +01:00
parent 36ff97e44a
commit fd999ec1e6


@@ -5,37 +5,58 @@ class CLIPTextChunker:
     """
     Utility class for chunking text to fit within CLIP's token limits.
     CLIP models typically have a maximum sequence length of 77 tokens.
-    Using a conservative limit of 60 tokens to account for special tokens.
+    Using a conservative limit of 70 tokens to account for special tokens.
     """
 
-    def __init__(self, max_tokens: int = 60):
+    def __init__(self, max_tokens: int = 70):
         """
         Initialize the text chunker.
 
         Args:
-            max_tokens (int): Maximum number of tokens per chunk (default: 60 for CLIP, being conservative)
+            max_tokens (int): Maximum number of tokens per chunk (default: 70 for CLIP, being conservative)
         """
         self.max_tokens = max_tokens
+        self._tokenizer = None
+
+    @property
+    def tokenizer(self):
+        """Lazily load the CLIP tokenizer"""
+        if self._tokenizer is None:
+            try:
+                from transformers import CLIPTokenizer
+                self._tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+            except ImportError:
+                # Fall back to character-based estimation if transformers is not available
+                self._tokenizer = None
+        return self._tokenizer
 
-    def estimate_token_count(self, text: str) -> int:
+    def get_token_count(self, text: str) -> int:
         """
-        Estimate the number of tokens in a text string.
-        Uses character count as a simple proxy for token count.
+        Get the actual token count for a text string using the CLIP tokenizer.
 
         Args:
             text (str): Input text
 
         Returns:
-            int: Estimated token count (using character count as proxy)
+            int: Actual token count
         """
-        # Simple approach: use character count as a proxy for token count
-        # This is much more reliable than trying to estimate actual tokens
-        return len(text)
+        if self.tokenizer is None:
+            # Fall back to character count if the tokenizer is not available
+            return len(text)
+        tokens = self.tokenizer(
+            text,
+            padding=False,
+            truncation=False,
+            return_tensors=None
+        )
+        return len(tokens["input_ids"])
 
     def chunk_text(self, text: str, preserve_sentences: bool = True) -> List[str]:
         """
         Chunk text into smaller pieces that fit within the token limit.
-        Uses character count as a simple and reliable approach.
+        Uses actual CLIP tokenization for accuracy.
 
         Args:
             text (str): Input text to chunk
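
Note on this hunk: character count is a poor proxy for CLIP's BPE token count (typical English runs several characters per token, so the old 60-character budget actually produced far fewer than 60 tokens). A minimal sketch of what get_token_count now measures, assuming transformers is installed; the example string is invented:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text = "Ultra realistic headshot of a male soccer player"

# The returned input_ids include the <|startoftext|> and <|endoftext|>
# special tokens, so the count already covers them; a 70-token budget
# therefore stays under CLIP's 77-token maximum.
ids = tokenizer(text, padding=False, truncation=False, return_tensors=None)["input_ids"]
print(len(text), "chars vs", len(ids), "tokens")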
@@ -47,26 +68,29 @@ class CLIPTextChunker:
         if not text.strip():
             return []
 
-        if self.estimate_token_count(text) <= self.max_tokens:
+        if self.get_token_count(text) <= self.max_tokens:
             return [text]
 
         chunks = []
         words = text.split()
         current_chunk = []
-        current_length = 0
+        current_tokens = 0
 
         for word in words:
             word_with_space = word + " "
-            # If adding this word would exceed the limit, start a new chunk
-            if current_length + len(word_with_space) > self.max_tokens and current_chunk:
-                # Join the current chunk and add it
+            # Check whether adding this word would exceed the limit
+            test_chunk = " ".join(current_chunk + [word])
+            test_tokens = self.get_token_count(test_chunk)
+            if test_tokens > self.max_tokens and current_chunk:
+                # The current chunk is complete; add it
                 chunks.append(" ".join(current_chunk))
                 current_chunk = [word]
-                current_length = len(word_with_space)
+                current_tokens = self.get_token_count(word)
             else:
                 current_chunk.append(word)
-                current_length += len(word_with_space)
+                current_tokens = test_tokens
 
         # Add the last chunk if it exists
         if current_chunk:
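
The loop now re-tokenizes the growing chunk on every word instead of summing per-word character lengths, trading quadratic re-tokenization cost for correctness at the chunk boundary. Hypothetical usage, with names matching the diff and an invented prompt:

chunker = CLIPTextChunker(max_tokens=70)
long_prompt = " ".join(["ultra detailed stadium background"] * 40)

# Each chunk should come back within the limit under real CLIP tokenization,
# barring a single word that by itself tokenizes above max_tokens.
for i, chunk in enumerate(chunker.chunk_text(long_prompt)):
    print(i, chunker.get_token_count(chunk))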
@@ -85,54 +109,72 @@ class CLIPTextChunker:
         Returns:
             List[str]: List of prioritized chunks
         """
-        # First, try to create chunks that include essential information
-        essential_chunks = []
+        # If the text fits within the limit, return it as-is
+        if self.get_token_count(text) <= self.max_tokens:
+            return [text]
+
+        # Look for key phrases near the beginning that should be preserved
+        first_chunk = ""
+        remaining_text = text
 
+        # Try to find essential info near the beginning
         for info in essential_info:
             if info in text:
-                # Create a chunk focused on this essential info
                 info_index = text.find(info)
-                start = max(0, info_index - 50)  # Include some context before
-                end = min(len(text), info_index + len(info) + 50)  # Include some context after
-                context = text[start:end]
+                # If the essential info is near the beginning, include it
+                if info_index < 100:  # Within the first 100 characters
+                    # Take from the start up to and including the essential info
+                    end_pos = min(len(text), info_index + len(info) + 30)  # Include some context after
+                    candidate_chunk = text[:end_pos]
 
-                chunk = self.chunk_text(context)[0]  # Take the first (most relevant) chunk
-                if chunk not in essential_chunks:
-                    essential_chunks.append(chunk)
+                    # Ensure the candidate chunk ends at a word boundary
+                    last_space = candidate_chunk.rfind(" ")
+                    if last_space > 0:
+                        candidate_chunk = candidate_chunk[:last_space]
 
-        # If we have too many essential chunks, combine them
-        if len(essential_chunks) > 1:
-            combined = " ".join(essential_chunks)
-            if self.estimate_token_count(combined) <= self.max_tokens:
-                return [combined]
-            else:
-                # Need to reduce the combined chunk
-                return self.chunk_text(combined)
+                    # Use the basic chunking to enforce the token limit
+                    if self.get_token_count(candidate_chunk) <= self.max_tokens:
+                        # Use chunk_text to get a properly bounded chunk
+                        temp_chunks = self.chunk_text(candidate_chunk)
+                        if temp_chunks:
+                            first_chunk = temp_chunks[0]
+                            remaining_text = text[len(first_chunk):]
+                            break
 
-        return essential_chunks if essential_chunks else self.chunk_text(text)
+        # If we found a good first chunk, use it
+        if first_chunk and self.get_token_count(first_chunk) <= self.max_tokens:
+            chunks = [first_chunk]
+            # Add the remaining text as additional chunks if needed
+            if remaining_text.strip():
+                chunks.extend(self.chunk_text(remaining_text))
+            return chunks
+
+        # Fall back to regular chunking
+        return self.chunk_text(text)
 
 
-def chunk_prompt_for_clip(prompt: str, max_tokens: int = 60) -> List[str]:
+def chunk_prompt_for_clip(prompt: str, max_tokens: int = 70) -> List[str]:
     """
     Convenience function to chunk a prompt for CLIP processing.
-    Uses a conservative 60 token limit to be safe.
+    Uses a conservative 70 token limit to be safe.
 
     Args:
         prompt (str): The prompt to chunk
-        max_tokens (int): Maximum tokens per chunk (default: 60 for safety)
+        max_tokens (int): Maximum tokens per chunk (default: 70 for safety)
 
     Returns:
         List[str]: List of prompt chunks
     """
     chunker = CLIPTextChunker(max_tokens=max_tokens)
 
-    # Define essential information that should be preserved
+    # Define essential information that should be preserved (matching the actual prompt format)
     essential_info = [
-        "Ultra-realistic close-up headshot",
+        "Ultra realistic headshot",
         "male soccer player",
         "looking at the camera",
         "facing the camera",
-        "confident expression",
-        "soccer jersey"
+        "Olive skinned",
+        "transparent background"
     ]
 
     return chunker.create_priority_chunks(prompt, essential_info)
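
Hypothetical end-to-end call; the prompt below is invented so that several essential_info phrases land within the first 100 characters, which is the window create_priority_chunks scans:

prompt = (
    "Ultra realistic headshot of a male soccer player, Olive skinned, "
    "facing the camera, looking at the camera, transparent background, "
    "studio lighting, highly detailed"
)
chunks = chunk_prompt_for_clip(prompt, max_tokens=70)
# With "Ultra realistic headshot" at index 0, the first chunk should carry
# the key phrases; the rest of the prompt follows in later chunks.
print(chunks)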