text chunking

This commit is contained in:
Karl 2025-09-23 17:08:52 +01:00
parent 6aeeb74e8f
commit e458e748a3
2 changed files with 12 additions and 12 deletions

View File

@ -8,12 +8,12 @@ class CLIPTextChunker:
Using a conservative limit of 70 tokens to account for special tokens. Using a conservative limit of 70 tokens to account for special tokens.
""" """
def __init__(self, max_tokens: int = 70): def __init__(self, max_tokens: int = 60):
""" """
Initialize the text chunker. Initialize the text chunker.
Args: Args:
max_tokens (int): Maximum number of tokens per chunk (default: 70 for CLIP, leaving buffer for special tokens) max_tokens (int): Maximum number of tokens per chunk (default: 60 for CLIP, being extra conservative)
""" """
self.max_tokens = max_tokens self.max_tokens = max_tokens
self._tokenizer = None self._tokenizer = None
@ -44,8 +44,8 @@ class CLIPTextChunker:
if self.tokenizer is None: if self.tokenizer is None:
# Fallback to character count if tokenizer not available # Fallback to character count if tokenizer not available
# CLIP tokenization is roughly 0.25-0.3 tokens per character on average # CLIP tokenization is roughly 0.25-0.3 tokens per character on average
# Use 0.25 for a more conservative estimate to avoid exceeding limits # Use 0.2 for an ultra-conservative estimate to ensure we never exceed limits
return int(len(text) * 0.25) return int(len(text) * 0.2)
tokens = self.tokenizer( tokens = self.tokenizer(
text, text,
@ -165,14 +165,14 @@ class CLIPTextChunker:
# Fallback to regular chunking # Fallback to regular chunking
return self.chunk_text(text) return self.chunk_text(text)
def chunk_prompt_for_clip(prompt: str, max_tokens: int = 70) -> List[str]: def chunk_prompt_for_clip(prompt: str, max_tokens: int = 60) -> List[str]:
""" """
Convenience function to chunk a prompt for CLIP processing. Convenience function to chunk a prompt for CLIP processing.
Uses a 70 token limit to be safe while allowing meaningful prompts. Uses a 60 token limit to be extra safe for any CLIP model.
Args: Args:
prompt (str): The prompt to chunk prompt (str): The prompt to chunk
max_tokens (int): Maximum tokens per chunk (default: 70 for CLIP compatibility) max_tokens (int): Maximum tokens per chunk (default: 60 for maximum CLIP compatibility)
Returns: Returns:
List[str]: List of prompt chunks List[str]: List of prompt chunks

View File

@ -20,7 +20,7 @@ def test_long_prompt_chunking():
print("-" * 80) print("-" * 80)
# Test the chunking # Test the chunking
chunker = CLIPTextChunker(max_tokens=70) chunker = CLIPTextChunker(max_tokens=60)
chunks = chunk_prompt_for_clip(test_prompt) chunks = chunk_prompt_for_clip(test_prompt)
print(f"Number of chunks: {len(chunks)}") print(f"Number of chunks: {len(chunks)}")
@ -35,8 +35,8 @@ def test_long_prompt_chunking():
if token_count > 77: if token_count > 77:
print(f" ❌ ERROR: Chunk {i+1} exceeds CLIP's 77 token limit!") print(f" ❌ ERROR: Chunk {i+1} exceeds CLIP's 77 token limit!")
return False return False
elif token_count > 70: elif token_count > 60:
print(f" ⚠️ WARNING: Chunk {i+1} is close to the 77 token limit") print(f" ⚠️ WARNING: Chunk {i+1} is close to the 60 token limit")
else: else:
print(f" ✅ Chunk {i+1} is within safe limits") print(f" ✅ Chunk {i+1} is within safe limits")
@ -47,7 +47,7 @@ def test_long_prompt_chunking():
def test_edge_cases(): def test_edge_cases():
"""Test edge cases for the chunking functionality.""" """Test edge cases for the chunking functionality."""
chunker = CLIPTextChunker(max_tokens=70) chunker = CLIPTextChunker(max_tokens=60)
# Test empty string # Test empty string
chunks = chunker.chunk_text("") chunks = chunker.chunk_text("")
@ -63,7 +63,7 @@ def test_edge_cases():
chunks = chunker.chunk_text(long_word) chunks = chunker.chunk_text(long_word)
# Should handle this gracefully # Should handle this gracefully
for chunk in chunks: for chunk in chunks:
assert chunker.get_token_count(chunk) <= 70, "Long word chunks should respect token limit" assert chunker.get_token_count(chunk) <= 60, "Long word chunks should respect token limit"
print("✅ Edge case tests passed!") print("✅ Edge case tests passed!")
return True return True