| """ |
| Model loading and tokenization utilities. |
| |
| Supports: |
| - Local loading with optional quantization (4-bit, 8-bit) |
| - Multiple model sizes (8B for prototyping, 70B for production) |
| - Consistent tokenization across scenarios |
| """ |
|
|
| import torch |
| from typing import Optional, Dict, Any, Tuple |
|
|
|
|
def load_model_and_tokenizer(
    model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    quantize: Optional[str] = None,
    device_map: str = "auto",
    attn_implementation: Optional[str] = None,
) -> Tuple[Any, Any]:
    """
    Load a HuggingFace causal language model and its tokenizer.

    The model is requested in bfloat16, optionally quantized through
    bitsandbytes, and switched to eval mode before it is returned. If the
    tokenizer defines no pad token, the EOS token is reused as the pad token.

    Args:
        model_name: HF model ID
        quantize: '4bit', '8bit', or None for no quantization
        device_map: device placement strategy
        attn_implementation: 'flash_attention_2', 'sdpa', or None to let
            the library pick its default

    Returns:
        (model, tokenizer) tuple

    Raises:
        ValueError: if quantize is not one of None, '4bit', '8bit'
    """
    # Reject unknown modes up front: a typo such as "4-bit" would otherwise
    # fall through both branches below and silently load the model
    # unquantized, which for large checkpoints usually means running out of
    # memory long after the mistake was made.
    if quantize not in (None, "4bit", "8bit"):
        raise ValueError(
            f"Unsupported quantize value {quantize!r}; "
            "expected '4bit', '8bit', or None"
        )

    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Fall back to EOS so padding works for tokenizers without a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    kwargs: Dict[str, Any] = {
        "device_map": device_map,
        "torch_dtype": torch.bfloat16,
    }

    if quantize == "4bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif quantize == "8bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    # Only forward the attention backend when the caller chose one.
    if attn_implementation:
        kwargs["attn_implementation"] = attn_implementation

    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    model.eval()

    return model, tokenizer
|
|
|
|
def get_token_ids(tokenizer, tokens: list) -> Dict[str, int]:
    """
    Map each target token (e.g. an aggregation function name) to a token ID.

    For every target, candidate spellings are tried in order: the token
    itself, the token with a leading space, the upper-cased token with a
    leading space, and the upper-cased token. The first spelling that
    encodes to at least one ID is used, and only the first ID of that
    encoding is kept. Targets for which no spelling yields any ID are left
    out of the result.
    """
    resolved: Dict[str, int] = {}
    for name in tokens:
        shouted = name.upper()
        candidates = (name, f" {name}", f" {shouted}", shouted)
        for spelling in candidates:
            encoded = tokenizer.encode(spelling, add_special_tokens=False)
            if len(encoded) == 0:
                continue
            # Multi-token encodings are represented by their first piece.
            resolved[name] = encoded[0]
            break
    return resolved
|
|
|
|
def get_logit_probs(model, tokenizer, prompt: str, target_tokens: list) -> Dict[str, float]:
    """
    Score a set of candidate completions at the next-token position.

    The prompt is run through the model once and the logits at the last
    position of the first sequence are read. The logits of the target
    tokens are then softmaxed among themselves, so the returned
    probabilities are normalized over the targets only, not over the whole
    vocabulary.

    Args:
        model: loaded HF model
        tokenizer: corresponding tokenizer
        prompt: input prompt text
        target_tokens: list of target completions (e.g., ['MAX', 'AVG', 'MEDIAN'])

    Returns:
        Dict mapping token -> probability. Targets that get_token_ids does
        not resolve to an ID are absent from the dict.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(model.device)

    with torch.no_grad():
        next_token_logits = model(**encoded).logits[0, -1, :]

    vocab_index = get_token_ids(tokenizer, target_tokens)

    # Pull the selected logits out as Python floats, then renormalize over
    # just those entries.
    selected = [next_token_logits[idx].item() for idx in vocab_index.values()]
    distribution = torch.softmax(torch.tensor(selected), dim=0)

    return dict(zip(vocab_index, distribution.tolist()))
|
|
|
|
def get_logit_difference(
    model, tokenizer, prompt: str,
    positive_token: str, negative_token: str
) -> float:
    """
    Compute logit difference: logit(positive) - logit(negative).

    This is the primary metric for circuit analysis:
    - Positive values → model prefers positive_token
    - Negative values → model prefers negative_token

    The logits are read at the last position of the first sequence after a
    single forward pass over the prompt. Token strings are mapped to IDs
    with get_token_ids, so a multi-token string is represented by its first
    token.

    Args:
        model: loaded HF model
        tokenizer: corresponding tokenizer
        prompt: input prompt text
        positive_token: token whose logit is the minuend
        negative_token: token whose logit is the subtrahend

    Returns:
        The difference as a Python float.

    Raises:
        IndexError: if either token cannot be resolved to a token ID.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]

    pos_ids = get_token_ids(tokenizer, [positive_token])
    neg_ids = get_token_ids(tokenizer, [negative_token])

    # Same exception type as indexing an empty list, but with a message
    # that says which token failed to resolve.
    for token, ids in ((positive_token, pos_ids), (negative_token, neg_ids)):
        if not ids:
            raise IndexError(f"No token ID could be resolved for {token!r}")

    pos_logit = logits[list(pos_ids.values())[0]]
    neg_logit = logits[list(neg_ids.values())[0]]

    if logits.dtype in (torch.float16, torch.bfloat16):
        # Subtracting in half precision rounds the result to very few
        # significant bits (about 8 for bfloat16), which can distort small
        # differences between large logits. Upcast the two scalars first;
        # float32/float64 logits are left untouched.
        pos_logit = pos_logit.float()
        neg_logit = neg_logit.float()

    return (pos_logit - neg_logit).item()
|
|