ryanyen22's picture
Add utils/model_utils.py
fddaf33 verified
"""
Model loading and tokenization utilities.
Supports:
- Local loading with optional quantization (4-bit, 8-bit)
- Multiple model sizes (8B for prototyping, 70B for production)
- Consistent tokenization across scenarios
"""
import torch
from typing import Optional, Dict, Any, Tuple
def load_model_and_tokenizer(
    model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    quantize: Optional[str] = None,  # '4bit', '8bit', None
    device_map: str = "auto",
    attn_implementation: Optional[str] = None,
) -> Tuple[Any, Any]:
    """
    Load a HuggingFace causal LM and its tokenizer, ready for inference.

    The tokenizer's pad token falls back to its EOS token when none is
    defined, and the model is switched to eval mode before being returned.

    Args:
        model_name: HF model ID
        quantize: '4bit' (NF4 with double quantization), '8bit', or None/''
            for unquantized weights loaded as bfloat16
        device_map: device placement strategy forwarded to from_pretrained
        attn_implementation: 'flash_attention_2', 'sdpa', or None to let the
            library pick its default
    Returns:
        (model, tokenizer) tuple
    Raises:
        ValueError: if quantize is a non-empty value other than '4bit'/'8bit'
    """
    # Validate before the heavy transformers import and any weight download:
    # a typo such as '4-bit' used to be ignored silently, loading the
    # unquantized model instead of the requested one.
    supported_modes = ("4bit", "8bit")
    if quantize and quantize not in supported_modes:
        raise ValueError(
            f"Unsupported quantize value {quantize!r}; "
            f"expected one of {supported_modes} or None"
        )

    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    kwargs: Dict[str, Any] = {
        "device_map": device_map,
        "torch_dtype": torch.bfloat16,
    }
    if quantize == "4bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif quantize == "8bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    if attn_implementation:
        kwargs["attn_implementation"] = attn_implementation

    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    model.eval()
    return model, tokenizer
def get_token_ids(tokenizer, tokens: list) -> Dict[str, int]:
    """
    Get token IDs for a list of target tokens (aggregation functions).

    For each target, several surface forms are tried in order: the token
    as given, with a leading space, upper-cased with a leading space, and
    upper-cased. The first form that encodes to exactly one token wins.
    If every form is multi-token, the first sub-token of the first form
    that encodes to anything is used instead.

    Args:
        tokenizer: object exposing encode(text, add_special_tokens=False)
        tokens: target strings to resolve
    Returns:
        Dict mapping each resolvable target string -> token ID. Targets for
        which no form produces any token are omitted from the result.
    """
    token_ids: Dict[str, int] = {}
    for token in tokens:
        fallback_id = None
        for variant in (token, f" {token}", f" {token.upper()}", token.upper()):
            ids = tokenizer.encode(variant, add_special_tokens=False)
            if len(ids) == 1:
                # A single-token form represents the whole target; stop here.
                token_ids[token] = ids[0]
                break
            if len(ids) > 1 and fallback_id is None:
                # Remember the first sub-token in case no form is single-token.
                fallback_id = ids[0]
        else:
            # No single-token form exists: fall back to the first sub-token.
            if fallback_id is not None:
                token_ids[token] = fallback_id
    return token_ids
def get_logit_probs(model, tokenizer, prompt: str, target_tokens: list) -> Dict[str, float]:
    """
    Get probability distribution over target tokens at the next-token position.

    The probabilities are a softmax over the target tokens' logits only, so
    they sum to 1 across the returned targets rather than across the full
    vocabulary.

    Args:
        model: loaded HF model
        tokenizer: corresponding tokenizer
        prompt: input prompt text
        target_tokens: list of target completions (e.g., ['MAX', 'AVG', 'MEDIAN'])
    Returns:
        Dict mapping token -> probability. Targets that get_token_ids cannot
        resolve are absent; an empty dict is returned when there is nothing
        to score.
    """
    # Nothing to score: skip tokenization and the forward pass entirely.
    if not target_tokens:
        return {}
    token_ids = get_token_ids(tokenizer, target_tokens)
    if not token_ids:
        return {}

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # last token position

    # Gather all target logits with one indexing op instead of one .item()
    # call (and device sync) per target, then softmax in float32 on CPU.
    index = torch.tensor(list(token_ids.values()), dtype=torch.long, device=logits.device)
    target_logits = logits[index].to(device="cpu", dtype=torch.float32)
    probs = torch.softmax(target_logits, dim=0).tolist()
    return dict(zip(token_ids, probs))
def get_logit_difference(
    model, tokenizer, prompt: str,
    positive_token: str, negative_token: str
) -> float:
    """
    Compute logit difference: logit(positive) - logit(negative).

    Both logits are read at the next-token position (after the last prompt
    token). This is the primary metric for circuit analysis:
    - Positive values → model prefers positive_token
    - Negative values → model prefers negative_token

    Args:
        model: loaded HF model
        tokenizer: corresponding tokenizer
        prompt: input prompt text
        positive_token: target whose logit is the minuend
        negative_token: target whose logit is the subtrahend
    Returns:
        The difference as a Python float.
    Raises:
        IndexError: if either token cannot be mapped to a token ID
    """
    # Resolve both targets first so an unresolvable token fails before the
    # (expensive) forward pass.
    token_ids = get_token_ids(tokenizer, [positive_token, negative_token])
    for token in (positive_token, negative_token):
        if token not in token_ids:
            # IndexError is what the bare list indexing raised here before;
            # keep the type so existing handlers still work.
            raise IndexError(f"Could not resolve a token ID for {token!r}")

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # last token position

    # Convert each logit to a Python float before subtracting: subtracting
    # the tensors directly rounds the result to the logits' dtype, which
    # loses precision when the model emits low-precision (e.g. bfloat16) logits.
    pos_logit = logits[token_ids[positive_token]].item()
    neg_logit = logits[token_ids[negative_token]].item()
    return pos_logit - neg_logit