"""
Model loading and tokenization utilities.

Supports:
- Local loading with optional quantization (4-bit, 8-bit)
- Multiple model sizes (8B for prototyping, 70B for production)
- Consistent tokenization across scenarios
"""

import torch
from typing import Optional, Dict, Any, Tuple


# Quantization modes accepted by ``load_model_and_tokenizer``.
_SUPPORTED_QUANTIZATION = (None, "4bit", "8bit")


def load_model_and_tokenizer(
    model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    quantize: Optional[str] = None,  # '4bit', '8bit', None
    device_map: str = "auto",
    attn_implementation: Optional[str] = None,
) -> Tuple[Any, Any]:
    """
    Load a HuggingFace causal LM and its tokenizer for inference.

    Args:
        model_name: HF model ID, passed to ``from_pretrained``.
        quantize: '4bit' (NF4, double quantization, bf16 compute dtype),
            '8bit', or None for unquantized bf16 weights.
        device_map: device placement strategy forwarded to ``from_pretrained``.
        attn_implementation: 'flash_attention_2', 'sdpa', or None to keep
            the library default.

    Returns:
        (model, tokenizer) tuple. The model is switched to eval mode, and the
        tokenizer's pad token is set to its EOS token when none is defined.

    Raises:
        ValueError: if ``quantize`` is not '4bit', '8bit' or None.
    """
    # Validate before importing transformers or downloading anything, so a
    # typo such as '4-bit' fails fast instead of silently loading the model
    # in full precision.
    if quantize not in _SUPPORTED_QUANTIZATION:
        raise ValueError(
            f"Unsupported quantize={quantize!r}; expected '4bit', '8bit' or None"
        )

    # Imported inside the function so importing this module only needs torch.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    kwargs: Dict[str, Any] = {
        "device_map": device_map,
        "torch_dtype": torch.bfloat16,
    }
    if quantize == "4bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif quantize == "8bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    if attn_implementation:
        kwargs["attn_implementation"] = attn_implementation

    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    model.eval()
    return model, tokenizer


def get_token_ids(tokenizer, tokens: list) -> Dict[str, int]:
    """
    Map each target string (e.g. an aggregation function name) to one token ID.

    For each token these spellings are tried in order: as given, with a
    leading space, upper-cased with a leading space, and upper-cased. The first
    spelling that encodes to exactly ONE token wins. If no spelling is a single
    token, the first sub-token of the first spelling that encodes to anything
    is used (multi-token fallback). Tokens for which every spelling encodes to
    nothing are omitted from the result.

    NOTE(review): the bare spelling wins whenever it is a single token. If the
    prompts end right before a space-prefixed answer, the leading-space
    variant may be what the model actually emits — confirm against the
    prompt templates.
    NOTE(review): with the multi-token fallback, two different targets can
    share a first sub-token and therefore the same ID.

    Returns:
        Dict mapping token -> token ID, in the order of ``tokens``.
    """
    token_ids: Dict[str, int] = {}
    for token in tokens:
        fallback: Optional[int] = None
        for variant in (token, f" {token}", f" {token.upper()}", token.upper()):
            ids = tokenizer.encode(variant, add_special_tokens=False)
            if len(ids) == 1:
                token_ids[token] = ids[0]
                break
            if ids and fallback is None:
                fallback = ids[0]
        else:
            # No single-token spelling: fall back to the first sub-token.
            if fallback is not None:
                token_ids[token] = fallback
    return token_ids


def _next_token_logits(model, tokenizer, prompt: str) -> torch.Tensor:
    """Run one forward pass on ``prompt``; return the final position's logits (1-D)."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Batch of one prompt: index batch 0, last sequence position.
    return outputs.logits[0, -1, :]


def _lookup_token_id(tokenizer, token: str) -> int:
    """Resolve one target string via ``get_token_ids``, raising if it has no ID."""
    ids = get_token_ids(tokenizer, [token])
    if token not in ids:
        raise ValueError(f"Could not map {token!r} to a token id")
    return ids[token]


def get_logit_probs(model, tokenizer, prompt: str, target_tokens: list) -> Dict[str, float]:
    """
    Get the probability distribution over target tokens at the next-token position.

    The softmax is taken over the target tokens only, not over the full vocabulary.

    Args:
        model: loaded HF model
        tokenizer: corresponding tokenizer
        prompt: input prompt text
        target_tokens: list of target completions (e.g., ['MAX', 'AVG', 'MEDIAN'])

    Returns:
        Dict mapping token -> probability. Targets that cannot be mapped to a
        token ID are omitted; an empty dict is returned if none can be mapped.
    """
    logits = _next_token_logits(model, tokenizer, prompt)
    token_ids = get_token_ids(tokenizer, target_tokens)
    if not token_ids:
        return {}

    # Gather all target logits at once: one device transfer instead of one
    # .item() sync per token. Upcast to float32 before the softmax.
    index = torch.tensor(list(token_ids.values()), device=logits.device)
    probs = torch.softmax(logits[index].float(), dim=0)
    return dict(zip(token_ids, probs.tolist()))


def get_logit_difference(
    model, tokenizer, prompt: str, positive_token: str, negative_token: str
) -> float:
    """
    Compute the logit difference logit(positive) - logit(negative).

    This is the primary metric for circuit analysis:
    - Positive values → model prefers positive_token
    - Negative values → model prefers negative_token

    Raises:
        ValueError: if either token cannot be mapped to a token ID.
    """
    # Resolve IDs first so a bad token fails before the forward pass.
    pos_id = _lookup_token_id(tokenizer, positive_token)
    neg_id = _lookup_token_id(tokenizer, negative_token)

    logits = _next_token_logits(model, tokenizer, prompt)
    # Upcast before subtracting so the difference is not rounded to the
    # model's (possibly bf16) logit dtype.
    return (logits[pos_id].float() - logits[neg_id].float()).item()