| from __future__ import annotations | |
| import warnings | |
| from dataclasses import dataclass | |
| from typing import NamedTuple, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn import CrossEntropyLoss | |
| from torch.utils.checkpoint import checkpoint as torch_checkpoint | |
| from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel | |
| from transformers.modeling_outputs import MaskedLMOutput | |
| from transformers.utils import ModelOutput | |
| from .configuration_recursive import RecursiveMLMConfig | |
| @dataclass | |
| class IterationMetrics(ModelOutput): | |
| """Metrics for a single iteration of recursive refinement.""" | |
| accuracy: Optional[float] = None | |
| entropy: Optional[float] = None | |
| softmax_ce: Optional[float] = None | |
| full_sequence_accuracy: Optional[float] = None | |
| min_sequence_confidence: Optional[float] = None | |
| @dataclass | |
| class RecursiveMaskedLMOutput(MaskedLMOutput): | |
| iteration_metrics: Optional[dict[int, IterationMetrics]] = None # Maps iteration index to metrics | |
| next_soft_embeds: Optional[torch.Tensor] = None # For caching between training steps | |
| all_logits: Optional[list[torch.Tensor]] = None # All T iterations' logits for trainer loss computation | |
| # Flow matching state (for distillation — compact H-dim, not V-dim) | |
| flow_noise_embed: Optional[torch.Tensor] = None # (num_masked, H) noise embedding | |
| flow_t: Optional[torch.Tensor] = None # (num_masked,) per-token time levels | |
| class SelfDistillationOutput(NamedTuple): | |
| """Output from self-distillation forward pass.""" | |
| loss: torch.Tensor # KL divergence loss (scalar, has grad) | |
| teacher_logits: torch.Tensor # For metrics/debugging (detached) | |
| student_logits: torch.Tensor # For metrics/debugging (has grad) | |
| degradation_temperature: float # Mean per-token temperature sampled | |
| teacher_entropy: float # Entropy of teacher distribution (for monitoring) | |
| student_entropy: float # Entropy of student distribution (for monitoring) | |
| agreement_rate: float # Fraction where teacher and student argmax agree | |
| class RecursiveMaskedLM(PreTrainedModel): | |
| """ | |
| Wraps any HF MLM with recursive soft-token refinement. | |
| At each step: | |
| 1. Normalize logits -> probs | |
| 2. Compute soft embeddings: probs @ embedding_weight + mask_embedding | |
| 3. Forward through MLM | |
| 4. Accumulate weighted loss | |
| """ | |
| config_class = RecursiveMLMConfig | |
| base_model_prefix = "mlm" | |
| supports_gradient_checkpointing = True | |
| def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None): | |
| super().__init__(config) | |
| if base_model is not None: | |
| # Pre-trained model provided - assign directly WITHOUT calling post_init() | |
| # to avoid reinitializing the pre-trained weights via _init_weights() | |
| self.mlm = base_model | |
| elif config.base_model_config is not None: | |
| model_type = config.base_model_config.get("model_type", "") | |
| if model_type == "llada": | |
| from .configuration_llada import LLaDAConfig | |
| from .modeling_llada import LLaDAModelLM | |
| base_config = LLaDAConfig.from_dict(config.base_model_config) | |
| self.mlm = LLaDAModelLM(base_config) | |
| else: | |
| base_config = AutoConfig.for_model(**config.base_model_config) | |
| self.mlm = AutoModelForMaskedLM.from_config(base_config) | |
| # Only call post_init() for freshly created models (needs weight init) | |
| self.post_init() | |
| else: | |
| raise ValueError("Need either base_model or config.base_model_config") | |
| @classmethod | |
| def from_mlm_pretrained( | |
| cls, | |
| mlm_name_or_path: str, | |
| num_recursions: int = 8, | |
| normalization: str = "softmax", | |
| loss_weight: str = "linear", | |
| mask_token_id: Optional[int] = None, | |
| temperature: float = 1.0, | |
| gradient_steps: Optional[int] = None, | |
| # === Convergence schedule parameters === | |
| schedule: str = "linear", | |
| causal_strength: float = 1.0, | |
| # === Effect parameters === | |
| temperature_max: float = 0.0, | |
| entropy_target_max: float = 0.0, | |
| entropy_floor_max: float = 0.0, | |
| smear_sigma_max: float = 0.0, | |
| noise_std_max: float = 0.0, | |
| iteration_rope_dim_fraction: float = 0.0, | |
| use_recursion_checkpointing: bool = True, | |
| # === Soft embedding method === | |
| soft_embedding_method: str = "softmax", | |
| soft_embedding_ema_step: float = 1.0, | |
| # === Flow matching parameters === | |
| flow_matching_enabled: bool = False, | |
| flow_matching_lambda: float = 0.5, | |
| flow_matching_t_distribution: str = "logit_normal", | |
| flow_matching_t_logit_mean: float = -0.4, | |
| flow_matching_t_logit_std: float = 1.0, | |
| flow_matching_t_min: float = 0.01, | |
| flow_matching_t_max: float = 0.99, | |
| flow_matching_mask_scale: bool = False, | |
| **model_kwargs, | |
| ) -> "RecursiveMaskedLM": | |
| """Load a pretrained MLM and wrap it for recursive refinement.""" | |
| base_model = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs) | |
| return cls.from_base_model( | |
| base_model, | |
| num_recursions=num_recursions, | |
| normalization=normalization, | |
| loss_weight=loss_weight, | |
| mask_token_id=mask_token_id, | |
| temperature=temperature, | |
| gradient_steps=gradient_steps, | |
| schedule=schedule, | |
| causal_strength=causal_strength, | |
| temperature_max=temperature_max, | |
| entropy_target_max=entropy_target_max, | |
| entropy_floor_max=entropy_floor_max, | |
| smear_sigma_max=smear_sigma_max, | |
| noise_std_max=noise_std_max, | |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, | |
| use_recursion_checkpointing=use_recursion_checkpointing, | |
| soft_embedding_method=soft_embedding_method, | |
| soft_embedding_ema_step=soft_embedding_ema_step, | |
| flow_matching_enabled=flow_matching_enabled, | |
| flow_matching_lambda=flow_matching_lambda, | |
| flow_matching_t_distribution=flow_matching_t_distribution, | |
| flow_matching_t_logit_mean=flow_matching_t_logit_mean, | |
| flow_matching_t_logit_std=flow_matching_t_logit_std, | |
| flow_matching_t_min=flow_matching_t_min, | |
| flow_matching_t_max=flow_matching_t_max, | |
| flow_matching_mask_scale=flow_matching_mask_scale, | |
| ) | |
| @classmethod | |
| def from_base_model( | |
| cls, | |
| base_model: PreTrainedModel, | |
| num_recursions: int = 8, | |
| normalization: str = "softmax", | |
| loss_weight: str = "linear", | |
| mask_token_id: Optional[int] = None, | |
| temperature: float = 1.0, | |
| gradient_steps: Optional[int] = None, | |
| # === Convergence schedule parameters === | |
| schedule: str = "linear", | |
| causal_strength: float = 1.0, | |
| # === Effect parameters === | |
| temperature_max: float = 0.0, | |
| entropy_target_max: float = 0.0, | |
| entropy_floor_max: float = 0.0, | |
| smear_sigma_max: float = 0.0, | |
| noise_std_max: float = 0.0, | |
| iteration_rope_dim_fraction: float = 0.0, | |
| use_recursion_checkpointing: bool = True, | |
| # === Soft embedding method === | |
| soft_embedding_method: str = "softmax", | |
| soft_embedding_ema_step: float = 1.0, | |
| # === Flow matching parameters === | |
| flow_matching_enabled: bool = False, | |
| flow_matching_lambda: float = 0.5, | |
| flow_matching_t_distribution: str = "logit_normal", | |
| flow_matching_t_logit_mean: float = -0.4, | |
| flow_matching_t_logit_std: float = 1.0, | |
| flow_matching_t_min: float = 0.01, | |
| flow_matching_t_max: float = 0.99, | |
| flow_matching_mask_scale: bool = False, | |
| ) -> "RecursiveMaskedLM": | |
| """Wrap an existing model for recursive refinement. | |
| Use this for models not loadable via AutoModelForMaskedLM (e.g., LLaDA). | |
| Args: | |
| base_model: The base MLM model to wrap | |
| num_recursions: Number of recursive refinement steps | |
| normalization: Normalization method for logits (softmax, stable_softmax) | |
| loss_weight: Loss weighting scheme (last_1, last_2, linear, uniform) | |
| mask_token_id: Token ID for [MASK] | |
| temperature: Temperature for softmax normalization | |
| gradient_steps: Number of final steps to backprop through | |
| schedule: Convergence schedule type ("linear" or "causal") | |
| causal_strength: How much faster early positions converge (causal only) | |
| temperature_max: Max temperature boost for uncertain positions | |
| entropy_target_max: Target entropy at progress=0 (two-sided, recommended) | |
| entropy_floor_max: Min entropy floor (one-sided) | |
| smear_sigma_max: Max Gaussian sigma for position smearing | |
| noise_std_max: Max std of Gaussian noise on logits | |
| iteration_rope_dim_fraction: Fraction of dims for iteration RoPE | |
| use_recursion_checkpointing: Enable gradient checkpointing for iterations | |
| soft_embedding_method: How to convert logits to soft embeddings | |
| soft_embedding_ema_step: EMA step size (1.0 = no EMA, <1.0 = blend with previous) | |
| flow_matching_enabled: Enable CFM-inspired flow matching framework | |
| flow_matching_lambda: Weight of distillation KL loss relative to CE | |
| flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform") | |
| flow_matching_t_logit_mean: Mean of logit-normal distribution | |
| flow_matching_t_logit_std: Std of logit-normal distribution | |
| flow_matching_t_min: Minimum time value (clamp) | |
| flow_matching_t_max: Maximum time value (clamp) | |
| flow_matching_mask_scale: Scale mask_emb by (1-t) if True, binary if False | |
| """ | |
| config = RecursiveMLMConfig.from_base_model_config( | |
| base_model.config, | |
| num_recursions=num_recursions, | |
| normalization=normalization, | |
| loss_weight=loss_weight, | |
| mask_token_id=mask_token_id, | |
| temperature=temperature, | |
| gradient_steps=gradient_steps, | |
| schedule=schedule, | |
| causal_strength=causal_strength, | |
| temperature_max=temperature_max, | |
| entropy_target_max=entropy_target_max, | |
| entropy_floor_max=entropy_floor_max, | |
| smear_sigma_max=smear_sigma_max, | |
| noise_std_max=noise_std_max, | |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, | |
| use_recursion_checkpointing=use_recursion_checkpointing, | |
| soft_embedding_method=soft_embedding_method, | |
| soft_embedding_ema_step=soft_embedding_ema_step, | |
| flow_matching_enabled=flow_matching_enabled, | |
| flow_matching_lambda=flow_matching_lambda, | |
| flow_matching_t_distribution=flow_matching_t_distribution, | |
| flow_matching_t_logit_mean=flow_matching_t_logit_mean, | |
| flow_matching_t_logit_std=flow_matching_t_logit_std, | |
| flow_matching_t_min=flow_matching_t_min, | |
| flow_matching_t_max=flow_matching_t_max, | |
| flow_matching_mask_scale=flow_matching_mask_scale, | |
| ) | |
| return cls(config, base_model=base_model) | |
| @property | |
| def embed_weight(self) -> torch.Tensor: | |
| return self.mlm.get_input_embeddings().weight | |
| def get_input_embeddings(self): | |
| return self.mlm.get_input_embeddings() | |
| def set_input_embeddings(self, value): | |
| self.mlm.set_input_embeddings(value) | |
| def get_output_embeddings(self): | |
| return self.mlm.get_output_embeddings() | |
| def set_output_embeddings(self, new_embeddings): | |
| self.mlm.set_output_embeddings(new_embeddings) | |
| def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): | |
| """Enable gradient checkpointing with settings suited to recursion. | |
| Defaults use_reentrant to False (an explicitly passed value is kept), which is required for: | |
| - Nested checkpoint calls (base model + recursion checkpointing) | |
| - Models with frozen parameters | |
| - Complex gradient flows through soft embeddings | |
| """ | |
| if gradient_checkpointing_kwargs is None: | |
| gradient_checkpointing_kwargs = {} | |
| # Default to use_reentrant=False for nested checkpointing compatibility; setdefault keeps a caller-supplied value | |
| gradient_checkpointing_kwargs.setdefault("use_reentrant", False) | |
| self.mlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs) | |
| def gradient_checkpointing_disable(self): | |
| """Disable gradient checkpointing in the underlying MLM.""" | |
| self.mlm.gradient_checkpointing_disable() | |
| def _single_iteration_checkpointable( | |
| self, | |
| soft_embeds: torch.Tensor, | |
| base_embeds: torch.Tensor, | |
| mask_pos: torch.Tensor, | |
| attention_mask: torch.Tensor, | |
| embed_weight: torch.Tensor, | |
| mask_emb: torch.Tensor, | |
| temperature: torch.Tensor, | |
| position_ids: Optional[torch.Tensor] = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Single differentiable iteration for checkpointing. | |
| This method performs one iteration of recursive refinement in a way that | |
| maintains gradient flow and is compatible with torch.utils.checkpoint. | |
| Args: | |
| soft_embeds: (B, L, H) - current soft embeddings | |
| base_embeds: (B, L, H) - original token embeddings | |
| mask_pos: (B, L) bool - which positions are masked | |
| attention_mask: (B, L) - attention mask for MLM | |
| embed_weight: (V, H) - embedding weight matrix | |
| mask_emb: (H,) - mask token embedding | |
| temperature: scalar tensor - softmax temperature | |
| position_ids: (B, L) - optional position IDs passed through to the base MLM | |
| Returns: | |
| logits: (B, L, V) - output logits from this iteration | |
| next_soft_embeds: (B, L, H) - soft embeddings for next iteration | |
| """ | |
| # Blend: use soft_embeds at masked positions, base_embeds elsewhere | |
| inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) | |
| # Forward through base MLM | |
| outputs = self.mlm( | |
| inputs_embeds=inputs_embeds, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| return_dict=True, | |
| ) | |
| logits = outputs.logits | |
| # Compute soft embeddings for next iteration (DIFFERENTIABLE - no detach!) | |
| next_soft_embeds = base_embeds.clone() | |
| if mask_pos.any(): | |
| masked_logits = logits[mask_pos] # (num_masked, V) | |
| # Convert logits to mixing weights based on soft_embedding_method | |
| if self.config.soft_embedding_method == "none": | |
| # No normalization - use raw logits directly | |
| weights = masked_logits # Differentiable! | |
| elif self.config.soft_embedding_method == "l2_normalize": | |
| # L2 normalize logits - removes softmax bottleneck for smoother gradients | |
| weights = F.normalize(masked_logits, p=2, dim=-1) # Differentiable! | |
| else: | |
| # Default: softmax normalization | |
| weights = F.softmax(masked_logits / temperature, dim=-1) # Differentiable! | |
| soft_emb = weights @ embed_weight + mask_emb # Differentiable! | |
| # Apply EMA blending with previous soft embeddings if enabled | |
| ema_step = self.config.soft_embedding_ema_step | |
| if ema_step < 1.0: | |
| prev_soft_emb = soft_embeds[mask_pos] # Previous iteration's soft embeddings | |
| soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb | |
| next_soft_embeds[mask_pos] = soft_emb | |
| return logits, next_soft_embeds | |
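The soft-embedding update above (softmax weights times the embedding matrix, plus the mask embedding, then an optional EMA blend) can be sketched for a single masked position in plain Python. This is an illustrative sketch only: the helper names `softmax`, `soft_embedding`, and `ema_blend` and the toy values are assumptions, not part of the model code, and only the default `"softmax"` method is shown.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature softmax over a list of floats (max-subtracted for stability)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_embedding(logits, embed_weight, mask_emb, temperature=1.0):
    """Compute weights @ embed_weight + mask_emb for one masked position."""
    weights = softmax(logits, temperature)
    hidden = len(mask_emb)
    return [
        sum(w * row[h] for w, row in zip(weights, embed_weight)) + mask_emb[h]
        for h in range(hidden)
    ]

def ema_blend(prev, new, ema_step):
    """(1 - ema_step) * prev + ema_step * new; ema_step = 1.0 keeps only `new`."""
    return [(1.0 - ema_step) * p + ema_step * n for p, n in zip(prev, new)]

# Hypothetical toy setup: 3-token vocabulary, 2-dimensional embeddings.
embed_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
mask_emb = [0.5, -0.5]
logits = [0.0, 0.0, 0.0]  # a uniform prediction

new_emb = soft_embedding(logits, embed_weight, mask_emb)
blended = ema_blend([0.0, 0.0], new_emb, ema_step=0.5)
```

With uniform logits the soft embedding is the mean of the embedding rows shifted by `mask_emb`; an `ema_step` below 1.0 pulls the result back toward the previous iteration's embedding.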
| def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor: | |
| """Numerically stable softmax with temperature T > 0.""" | |
| z = logits / max(T, eps) | |
| z = z - z.max(dim=dim, keepdim=True).values # subtract max | |
| z = torch.exp(z) # safe since z <= 0 | |
| z_sum = z.sum(dim=dim, keepdim=True) | |
| return z / z_sum.clamp(min=eps) | |
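To see why the maximum is subtracted before exponentiating, here is a small plain-Python sketch. The function names are illustrative assumptions. Note one difference from tensor code: Python's `math.exp` raises `OverflowError` on large inputs, whereas floating-point tensor code would typically produce `inf` and then `nan` instead.

```python
import math

def naive_softmax(logits):
    """Exponentiates raw logits; overflows for large values."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(logits, temperature=1.0, eps=1e-12):
    """Same result mathematically, but every exponent is <= 0."""
    z = [x / max(temperature, eps) for x in logits]
    m = max(z)
    exps = [math.exp(x - m) for x in z]
    total = max(sum(exps), eps)
    return [e / total for e in exps]

large = [1000.0, 999.0]
try:
    naive_softmax(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

probs = stable_softmax(large)
```

Only the differences between logits matter to softmax, so shifting by the maximum changes nothing mathematically while keeping the intermediate values bounded.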
| def normalize(self, logits: torch.Tensor) -> torch.Tensor: | |
| """Normalize logits -> mixing weights. Shape: (B, L, V) -> (B, L, V)""" | |
| norm = self.config.normalization.lower() | |
| T = self.config.temperature | |
| V = logits.shape[-1] | |
| if norm == "none": | |
| return logits | |
| if norm == "softmax": | |
| return torch.softmax(logits / T, dim=-1) | |
| if norm == "stable_softmax": | |
| return self._stable_softmax(logits, T=T, dim=-1) | |
| raise ValueError(f"Unknown normalization: {norm}") | |
| def step_weight(self, t: int, T: int) -> float: | |
| """Loss weight for step t of T.""" | |
| lw = self.config.loss_weight | |
| if lw == "linear": | |
| return (t + 1) / T | |
| if lw == "uniform": | |
| return 1.0 | |
| if lw == "last_1": | |
| return 1.0 if t == T - 1 else 0.0 | |
| if lw == "last_2": | |
| return 1.0 if T - t <= 2 else 0.0 | |
| raise ValueError(f"Unknown loss_weight: {lw}") | |
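For a concrete view of the four weighting schemes, the sketch below re-implements the same rule as a free function (the standalone signature is an assumption made for illustration) and tabulates the weights for `T = 4` steps.

```python
def step_weight(t, T, loss_weight="linear"):
    """Loss weight for 0-indexed step t out of T steps."""
    if loss_weight == "linear":
        return (t + 1) / T
    if loss_weight == "uniform":
        return 1.0
    if loss_weight == "last_1":
        return 1.0 if t == T - 1 else 0.0
    if loss_weight == "last_2":
        return 1.0 if T - t <= 2 else 0.0
    raise ValueError(f"Unknown loss_weight: {loss_weight}")

T = 4
schemes = {
    name: [step_weight(t, T, name) for t in range(T)]
    for name in ("linear", "uniform", "last_1", "last_2")
}
for name, weights in schemes.items():
    print(f"{name:8s} {weights}")
```

The weights are not normalized to sum to one, so `"uniform"` contributes a total weight of `T` while `"last_1"` contributes a total of 1.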
| # ==================== CONVERGENCE SCHEDULE SYSTEM ==================== | |
| # | |
| # The core idea: control WHEN each position is allowed to converge. | |
| # | |
| # Schedule types: | |
| # - "linear": All positions converge at the same rate | |
| # - "causal": Early positions converge first, late positions last | |
| # | |
| # Effects (mechanisms to enforce the schedule): | |
| # - temperature: Raise temperature for positions not yet allowed to converge | |
| # - entropy_floor: Force minimum entropy | |
| # - entropy_target: Force exact entropy via bisection (ARChitects-style) | |
| # - smear: Spread probability across neighboring positions | |
| # - noise: Add Gaussian noise to logits | |
| # | |
| # Each effect uses per-position "convergence progress" (0=uncertain, 1=can converge) | |
| def _compute_convergence_progress( | |
| self, | |
| iteration: int, | |
| total_iterations: int, | |
| seq_length: int, | |
| mask_positions: torch.Tensor, | |
| schedule: str = "linear", | |
| causal_strength: float = 1.0, | |
| device: Optional[torch.device] = None, | |
| dtype: Optional[torch.dtype] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Compute per-position convergence progress based on schedule. | |
| Args: | |
| iteration: Current iteration (0-indexed) | |
| total_iterations: Total number of iterations | |
| seq_length: Full sequence length L | |
| mask_positions: Position indices of masked tokens (num_masked,) | |
| schedule: "linear" or "causal" | |
| causal_strength: How much faster early positions converge (for causal schedule) | |
| Returns: | |
| progress: (num_masked,) tensor with values in [0, 1] | |
| 0 = position should be maximally uncertain | |
| 1 = position is allowed to fully converge | |
| """ | |
| base_progress = iteration / max(total_iterations - 1, 1) | |
| if schedule == "linear": | |
| return torch.full( | |
| (mask_positions.shape[0],), | |
| base_progress, | |
| device=device, | |
| dtype=dtype | |
| ) | |
| elif schedule == "causal": | |
| position_factor = mask_positions.float() / max(seq_length - 1, 1) | |
| effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor)) | |
| return effective_progress.clamp(0.0, 1.0) | |
| else: | |
| raise ValueError(f"Unknown schedule: {schedule}") | |
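The difference between the two schedules is easiest to see with numbers. The following plain-Python sketch mirrors the formula above using lists instead of tensors; the function name and the toy sizes are assumptions for illustration.

```python
def convergence_progress(iteration, total_iterations, seq_length,
                         mask_positions, schedule="linear", causal_strength=1.0):
    """Per-position progress in [0, 1] for the given masked position indices."""
    base = iteration / max(total_iterations - 1, 1)
    if schedule == "linear":
        return [base for _ in mask_positions]
    if schedule == "causal":
        progress = []
        for pos in mask_positions:
            position_factor = pos / max(seq_length - 1, 1)
            effective = base * (1.0 + causal_strength * (1.0 - position_factor))
            progress.append(min(max(effective, 0.0), 1.0))
        return progress
    raise ValueError(f"Unknown schedule: {schedule}")

# Halfway through 5 iterations, all 5 positions masked.
positions = [0, 1, 2, 3, 4]
linear = convergence_progress(2, 5, 5, positions, schedule="linear")
causal = convergence_progress(2, 5, 5, positions, schedule="causal")
```

Under the causal schedule with `causal_strength=1.0`, the first position has already reached full progress at the halfway point, while the last position sits at the same value as the linear schedule.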
| def _apply_temperature_effect( | |
| self, | |
| logits: torch.Tensor, | |
| progress: torch.Tensor, | |
| temperature_max: float, | |
| ) -> torch.Tensor: | |
| """ | |
| Apply per-position temperature scaling based on convergence progress. | |
| Low progress = high temperature (uncertain), high progress = temperature 1.0. | |
| """ | |
| if temperature_max <= 0: | |
| return logits | |
| temperature = 1.0 + temperature_max * (1.0 - progress) | |
| temperature = temperature.unsqueeze(-1) | |
| return logits / temperature | |
| def _apply_entropy_floor_effect( | |
| self, | |
| probs: torch.Tensor, | |
| progress: torch.Tensor, | |
| entropy_floor_max: float, | |
| ) -> torch.Tensor: | |
| """ | |
| Ensure minimum entropy based on convergence progress. | |
| Low progress = high entropy floor, high progress = no floor. | |
| NOTE: This is a ONE-SIDED constraint (floor only). | |
| """ | |
| if entropy_floor_max <= 0: | |
| return probs | |
| entropy_floor = entropy_floor_max * (1.0 - progress) | |
| log_probs = torch.log(probs + 1e-10) | |
| current_entropy = -(probs * log_probs).sum(dim=-1) | |
| below_floor = current_entropy < entropy_floor | |
| if not below_floor.any(): | |
| return probs | |
| logits = log_probs  # log-probabilities serve as logits for re-tempering | |
| target_ratio = entropy_floor / (current_entropy + 1e-10) | |
| temperature = torch.ones_like(current_entropy) | |
| temperature[below_floor] = target_ratio[below_floor].clamp(1.0, 10.0) | |
| scaled_probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1) | |
| result = probs.clone() | |
| result[below_floor] = scaled_probs[below_floor] | |
| return result | |
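A single-position sketch of the entropy floor in plain Python follows. The helper names are illustrative assumptions. The temperature is the clamped ratio of floor to current entropy, as in the method above; this is a heuristic, so the re-tempered distribution moves toward the floor but is not guaranteed to land exactly on it.

```python
import math

def entropy(probs):
    """Shannon entropy in nats, with a small epsilon inside the log."""
    return -sum(p * math.log(p + 1e-10) for p in probs)

def apply_entropy_floor(probs, progress, entropy_floor_max):
    """Flatten `probs` when its entropy is below the progress-dependent floor."""
    floor = entropy_floor_max * (1.0 - progress)
    current = entropy(probs)
    if entropy_floor_max <= 0 or current >= floor:
        return list(probs)
    # Heuristic temperature: ratio of floor to current entropy, clamped to [1, 10].
    temperature = min(max(floor / (current + 1e-10), 1.0), 10.0)
    scaled = [math.log(p + 1e-10) / temperature for p in probs]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

peaked = [0.9, 0.05, 0.05]
early = apply_entropy_floor(peaked, progress=0.0, entropy_floor_max=0.8)
late = apply_entropy_floor(peaked, progress=1.0, entropy_floor_max=0.8)
```

At `progress=0.0` the peaked distribution is flattened; at `progress=1.0` the floor is zero and the distribution passes through unchanged.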
| def _find_temperature_for_target_entropy( | |
| self, | |
| logits: torch.Tensor, | |
| target_entropy: torch.Tensor, | |
| tol: float = 1e-3, | |
| max_iter: int = 32, | |
| T_low: float = 1e-6, | |
| T_high_init: float = 1.0, | |
| max_T: float = 100.0, | |
| ) -> torch.Tensor: | |
| """ | |
| Find per-position temperatures that achieve exactly the target entropy. | |
| Uses bisection search, adapted from ARChitects' implementation. | |
| Args: | |
| logits: Raw logits (num_positions, V) | |
| target_entropy: Target entropy per position (num_positions,) or scalar | |
| tol: Entropy tolerance for convergence | |
| max_iter: Maximum bisection iterations | |
| T_low: Minimum temperature (near-greedy) | |
| T_high_init: Initial upper bound for search | |
| max_T: Maximum allowed temperature | |
| Returns: | |
| temperatures: (num_positions,) temperatures that achieve target entropy | |
| """ | |
| N, V = logits.shape | |
| device, dtype = logits.device, logits.dtype | |
| H_max = torch.log(torch.tensor(V, device=device, dtype=dtype)) | |
| if target_entropy.dim() == 0: | |
| target = target_entropy.expand(N).clone() | |
| else: | |
| target = target_entropy.clone() | |
| target = target.clamp(0.0, H_max) | |
| def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor: | |
| temps = temps.unsqueeze(-1).clamp(min=T_low) | |
| scaled = logits_ / temps | |
| scaled = scaled - scaled.max(dim=-1, keepdim=True).values | |
| probs = torch.softmax(scaled, dim=-1) | |
| log_probs = torch.log(probs + 1e-12) | |
| return -(probs * log_probs).sum(dim=-1) | |
| lo = torch.full((N,), T_low, device=device, dtype=dtype) | |
| hi = torch.full((N,), T_high_init, device=device, dtype=dtype) | |
| H_lo = compute_entropy(logits, lo) | |
| done_low = target <= (H_lo + tol) | |
| H_hi = compute_entropy(logits, hi) | |
| needs_expansion = (H_hi < target - tol) & ~done_low | |
| for _ in range(100): | |
| if not needs_expansion.any(): | |
| break | |
| hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T) | |
| H_hi[needs_expansion] = compute_entropy( | |
| logits[needs_expansion], hi[needs_expansion] | |
| ) | |
| needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6) | |
| can_bisect = ~done_low & (H_hi >= target - tol) | |
| for _ in range(max_iter): | |
| if not can_bisect.any(): | |
| break | |
| mid = (lo + hi) / 2.0 | |
| H_mid = compute_entropy(logits, mid) | |
| too_low = (H_mid < target) & can_bisect | |
| lo[too_low] = mid[too_low] | |
| hi[~too_low & can_bisect] = mid[~too_low & can_bisect] | |
| converged = (hi - lo) <= tol * mid.clamp(min=1.0) | |
| can_bisect = can_bisect & ~converged | |
| temps = torch.zeros(N, device=device, dtype=dtype) | |
| temps[done_low] = T_low | |
| temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0 | |
| return temps | |
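The bisection idea can be illustrated for a single position in plain Python. This is a simplified sketch under stated assumptions: it handles one logit vector rather than a batch, the names `entropy_at` and `find_temperature` are hypothetical, and the default tolerances differ from the method above. It relies on softmax entropy being non-decreasing in temperature.

```python
import math

def entropy_at(logits, temperature):
    """Entropy (nats) of softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def find_temperature(logits, target, tol=1e-6, max_iter=200,
                     t_low=1e-6, max_t=100.0):
    """Bisection search for the temperature whose entropy matches `target`."""
    target = min(max(target, 0.0), math.log(len(logits)))
    if target <= entropy_at(logits, t_low) + tol:
        return t_low  # near-greedy already meets the target
    # Grow the upper bound until it brackets the target (or hits max_t).
    hi = 1.0
    while entropy_at(logits, hi) < target and hi < max_t:
        hi = min(hi * 2.0, max_t)
    lo = t_low
    for _ in range(max_iter):
        mid = (lo + hi) / 2.0
        if entropy_at(logits, mid) < target:
            lo = mid
        else:
            hi = mid
        if hi - lo <= tol:
            break
    return (lo + hi) / 2.0

logits = [2.0, 1.0, 0.0]
temperature = find_temperature(logits, target=1.0)
achieved = entropy_at(logits, temperature)
```

For these logits the entropy at temperature 1.0 is below the target of 1.0 nat, so the search settles on a temperature between 1.0 and 2.0.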
| def _apply_target_entropy_effect( | |
| self, | |
| logits: torch.Tensor, | |
| progress: torch.Tensor, | |
| entropy_target_max: float, | |
| entropy_target_min: float = 0.0, | |
| ) -> torch.Tensor: | |
| """ | |
| Adjust temperature to achieve EXACTLY the target entropy per position. | |
| This is a TWO-SIDED constraint: both raises and lowers entropy as needed. | |
| Args: | |
| logits: Raw logits (num_masked, V) | |
| progress: Per-position convergence progress (num_masked,) | |
| entropy_target_max: Target entropy at progress=0 | |
| entropy_target_min: Target entropy at progress=1 (usually ~0) | |
| Returns: | |
| probs: Probabilities with entropy matching targets | |
| """ | |
| if entropy_target_max <= 0: | |
| return torch.softmax(logits, dim=-1) | |
| target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress | |
| temps = self._find_temperature_for_target_entropy(logits, target_entropy) | |
| temps = temps.unsqueeze(-1).clamp(min=1e-6) | |
| return torch.softmax(logits / temps, dim=-1) | |
| def _apply_smear_effect( | |
| self, | |
| probs: torch.Tensor, | |
| mask_pos: torch.Tensor, | |
| progress_full: torch.Tensor, | |
| smear_sigma_max: float, | |
| ) -> torch.Tensor: | |
| """ | |
| Apply positional smearing with per-position sigma based on progress. | |
| Low progress = high smearing, high progress = no smearing. | |
| Note: This operates on full (B, L, V) tensor because smearing mixes across positions. | |
| """ | |
| if smear_sigma_max <= 0: | |
| return probs | |
| B, L, V = probs.shape | |
| sigma_per_pos = smear_sigma_max * (1.0 - progress_full) | |
| avg_sigma = sigma_per_pos[mask_pos].mean().item() | |
| if avg_sigma < 0.1: | |
| return probs | |
| positions = torch.arange(L, device=probs.device, dtype=probs.dtype) | |
| diff = positions.unsqueeze(0) - positions.unsqueeze(1) | |
| kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2) | |
| kernel = kernel / kernel.sum(dim=1, keepdim=True) | |
| smeared = torch.einsum('ij,bjv->biv', kernel, probs) | |
| smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10) | |
| blend = progress_full.unsqueeze(-1) | |
| result = blend * probs + (1 - blend) * smeared | |
| output = probs.clone() | |
| output[mask_pos] = result[mask_pos] | |
| return output | |
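The smearing kernel is a row-normalized Gaussian over position offsets. A plain-Python sketch for one sequence is shown below; the function names and toy inputs are assumptions, and, like the method above, it uses a single sigma for every row rather than a per-position one.

```python
import math

def gaussian_kernel(length, sigma):
    """Row-normalized Gaussian weights over position offsets."""
    rows = []
    for i in range(length):
        row = [math.exp(-0.5 * ((j - i) / sigma) ** 2) for j in range(length)]
        total = sum(row)
        rows.append([v / total for v in row])
    return rows

def smear(probs, sigma):
    """Mix each position's distribution with its neighbours: kernel @ probs."""
    length, vocab = len(probs), len(probs[0])
    kernel = gaussian_kernel(length, sigma)
    return [
        [sum(kernel[i][j] * probs[j][v] for j in range(length)) for v in range(vocab)]
        for i in range(length)
    ]

# Three positions, two-token vocabulary, one-hot predictions.
probs = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
kernel = gaussian_kernel(3, sigma=1.0)
smeared = smear(probs, sigma=1.0)
```

Because each kernel row sums to one and each input row is a distribution, every smeared row is again a distribution; the extra renormalization in the method above guards against rounding drift.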
| def _apply_noise_effect( | |
| self, | |
| logits: torch.Tensor, | |
| progress: torch.Tensor, | |
| noise_std_max: float, | |
| ) -> torch.Tensor: | |
| """ | |
| Add Gaussian noise to logits based on convergence progress. | |
| Low progress = high noise, high progress = no noise. | |
| """ | |
| if noise_std_max <= 0: | |
| return logits | |
| noise_std = noise_std_max * (1.0 - progress) | |
| noise_std = noise_std.unsqueeze(-1) | |
| noise = torch.randn_like(logits) * noise_std | |
| return logits + noise | |
| def _apply_iteration_rope( | |
| self, | |
| embeds: torch.Tensor, | |
| iteration: int, | |
| total_iterations: int, | |
| dim_fraction: float = 0.25, | |
| base: float = 10000.0, | |
| ) -> torch.Tensor: | |
| """ | |
| Apply rotary embedding based on iteration progress. | |
| Uses a subset of dimensions to avoid interfering with position RoPE. | |
| """ | |
| if dim_fraction <= 0: | |
| return embeds | |
| H = embeds.shape[-1] | |
| rot_dim = int(H * dim_fraction) | |
| rot_dim = rot_dim - (rot_dim % 2) | |
| if rot_dim < 2: | |
| return embeds | |
| progress = iteration / max(total_iterations - 1, 1) | |
| inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim)) | |
| angles = progress * inv_freq * 3.14159 | |
| cos, sin = torch.cos(angles), torch.sin(angles) | |
| if embeds.dim() == 2: | |
| cos, sin = cos.unsqueeze(0), sin.unsqueeze(0) | |
| elif embeds.dim() == 3: | |
| cos = cos.unsqueeze(0).unsqueeze(0) | |
| sin = sin.unsqueeze(0).unsqueeze(0) | |
| embeds_out = embeds.clone() | |
| x1, x2 = embeds[..., -rot_dim::2], embeds[..., -rot_dim+1::2] | |
| embeds_out[..., -rot_dim::2] = x1 * cos - x2 * sin | |
| embeds_out[..., -rot_dim+1::2] = x1 * sin + x2 * cos | |
| return embeds_out | |
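The iteration rotation can be sketched on a single vector in plain Python. The function name is an assumption, and the sketch uses `math.pi` where the method above uses the literal `3.14159`. It rotates interleaved pairs in the last `rot_dim` dimensions and leaves the rest untouched.

```python
import math

def iteration_rope(vec, iteration, total_iterations, dim_fraction=0.25, base=10000.0):
    """Rotate the trailing dims of `vec` by an angle tied to iteration progress."""
    hidden = len(vec)
    rot_dim = int(hidden * dim_fraction)
    rot_dim -= rot_dim % 2  # need an even number of dims to form pairs
    if dim_fraction <= 0 or rot_dim < 2:
        return list(vec)
    progress = iteration / max(total_iterations - 1, 1)
    out = list(vec)
    start = hidden - rot_dim
    for k in range(rot_dim // 2):
        inv_freq = 1.0 / (base ** (2 * k / rot_dim))
        angle = progress * inv_freq * math.pi
        i, j = start + 2 * k, start + 2 * k + 1
        x1, x2 = vec[i], vec[j]
        out[i] = x1 * math.cos(angle) - x2 * math.sin(angle)
        out[j] = x1 * math.sin(angle) + x2 * math.cos(angle)
    return out

vec = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
first = iteration_rope(vec, iteration=0, total_iterations=4)
last = iteration_rope(vec, iteration=3, total_iterations=4)
norm = lambda v: math.sqrt(sum(x * x for x in v))
```

At iteration 0 the angle is zero and the vector is unchanged; at the final iteration the first pair is rotated by roughly half a turn. Since each pair undergoes a plane rotation, the vector norm is preserved.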
| # ==================== FLOW MATCHING ==================== | |
| def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor: | |
| """Sample per-token time levels for flow matching. | |
| Returns: | |
| t: (num_tokens,) tensor of time levels in [t_min, t_max] | |
| """ | |
| dist = self.config.flow_matching_t_distribution | |
| if dist == "logit_normal": | |
| z = torch.randn(num_tokens, device=device) | |
| z = z * self.config.flow_matching_t_logit_std + self.config.flow_matching_t_logit_mean | |
| t = torch.sigmoid(z) | |
| elif dist == "uniform": | |
| t = torch.empty(num_tokens, device=device).uniform_(0, 1) | |
| else: | |
| raise ValueError(f"Unknown flow_matching_t_distribution: {dist}") | |
| return t.clamp(self.config.flow_matching_t_min, self.config.flow_matching_t_max) | |
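A plain-Python sketch of the time sampling follows, using the standard library's `random` module in place of tensor sampling. The function name and the `seed` parameter are assumptions added for illustration; the defaults mirror the configuration defaults listed earlier.

```python
import math
import random

def sample_flow_t(num_tokens, distribution="logit_normal", logit_mean=-0.4,
                  logit_std=1.0, t_min=0.01, t_max=0.99, seed=None):
    """Draw per-token time levels and clamp them to [t_min, t_max]."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_tokens):
        if distribution == "logit_normal":
            # Gaussian in logit space, squashed to (0, 1) by a sigmoid.
            z = rng.gauss(logit_mean, logit_std)
            t = 1.0 / (1.0 + math.exp(-z))
        elif distribution == "uniform":
            t = rng.uniform(0.0, 1.0)
        else:
            raise ValueError(f"Unknown flow_matching_t_distribution: {distribution}")
        samples.append(min(max(t, t_min), t_max))
    return samples

times = sample_flow_t(1000, seed=0)
uniform_times = sample_flow_t(1000, distribution="uniform", seed=0)
```

A negative logit mean shifts the mass of the logit-normal distribution toward smaller time values, and the clamp keeps samples away from the exact endpoints 0 and 1.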
| def compute_flow_matching_distillation_loss( | |
| self, | |
| input_ids: torch.Tensor, | |
| teacher_logits: torch.Tensor, | |
| labels: torch.Tensor, | |
| flow_noise_embed: torch.Tensor, | |
| flow_t: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| position_ids: Optional[torch.Tensor] = None, | |
| ) -> SelfDistillationOutput: | |
| """ | |
| CFM flow matching distillation: teacher sees state at time t, student sees | |
| noisier state at time s < t on the same interpolation path. | |
| Both should predict the same endpoint (target token). The student must | |
| learn to refine from noisier inputs by matching the teacher's predictions. | |
| Args: | |
| input_ids: Input with [MASK] tokens at positions to predict | |
| teacher_logits: Logits from the forward pass (will be detached) | |
| labels: Target tokens at masked positions (-100 elsewhere) | |
| flow_noise_embed: (num_masked, H) noise embeddings from forward | |
| flow_t: (num_masked,) per-token time levels from forward | |
| attention_mask: Standard attention mask | |
| position_ids: Position IDs (if needed by base model) | |
| Returns: | |
| SelfDistillationOutput with loss, logits, time gap, and diagnostics | |
| """ | |
| mask_id = self.config.mask_token_id | |
| mask_pos = (input_ids == mask_id) # (B, L) | |
| device = input_ids.device | |
| num_masked = mask_pos.sum().item() | |
| if num_masked == 0: | |
| zero = torch.tensor(0.0, device=device, requires_grad=True) | |
| dummy = torch.zeros(1, device=device) | |
| return SelfDistillationOutput(zero, dummy, dummy, 0.0, 0.0, 0.0, 1.0) | |
| teacher_logits = teacher_logits.detach() | |
| embed_weight = self.embed_weight | |
| mask_emb = embed_weight[mask_id] # (H,) | |
| base_embeds = self.get_input_embeddings()(input_ids) # (B, L, H) | |
| # Target embeddings from labels | |
| target_ids = labels[mask_pos] # (num_masked,) | |
| target_embed = embed_weight[target_ids] # (num_masked, H) | |
| # Sample student time s ~ U(0, t) per token | |
| s_per_token = flow_t * torch.rand(num_masked, device=device) # (num_masked,) | |
| # Student state: same noise, earlier time (noisier) | |
| s_col = s_per_token.unsqueeze(-1).to(base_embeds.dtype) # (num_masked, 1) | |
| student_interp = (1 - s_col) * flow_noise_embed + s_col * target_embed | |
| if self.config.flow_matching_mask_scale: | |
| student_masked_embeds = student_interp + (1 - s_col) * mask_emb | |
| else: | |
| student_masked_embeds = student_interp + mask_emb | |
| # Build the full student input. It is detached, so gradients flow only through the student's forward pass. | |
| student_inputs = base_embeds.detach().clone() | |
| student_inputs[mask_pos] = student_masked_embeds.detach() | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype) | |
| student_out = self.mlm( | |
| inputs_embeds=student_inputs, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| return_dict=True, | |
| ) | |
| student_logits = student_out.logits # (B, L, V) — has gradient | |
| # KL divergence loss on masked positions | |
| t_logits = teacher_logits[mask_pos] # (num_masked, V) | |
| s_logits = student_logits[mask_pos] # (num_masked, V) | |
| teacher_probs = F.softmax(t_logits, dim=-1) | |
| student_log_probs = F.log_softmax(s_logits, dim=-1) | |
| kl_loss = F.kl_div( | |
| student_log_probs, | |
| teacher_probs, | |
| reduction="batchmean", | |
| ) | |
| # Diagnostic metrics | |
| with torch.no_grad(): | |
| teacher_log_probs = torch.log(teacher_probs + 1e-10) | |
| teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item() | |
| student_probs = F.softmax(s_logits.detach(), dim=-1) | |
| student_log_probs_det = torch.log(student_probs + 1e-10) | |
| student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item() | |
| agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item() | |
| mean_time_gap = (flow_t - s_per_token).mean().item() | |
| return SelfDistillationOutput( | |
| loss=kl_loss, | |
| teacher_logits=teacher_logits, | |
| student_logits=student_logits, | |
| degradation_temperature=mean_time_gap, | |
| teacher_entropy=teacher_entropy, | |
| student_entropy=student_entropy, | |
| agreement_rate=agreement, | |
| ) | |
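The teacher/student construction in `compute_flow_matching_distillation_loss` relies on both states lying on the same straight path from noise to target. The sketch below uses plain Python lists and hypothetical helper names (`interpolate`, `distance`). It illustrates why sampling s ~ U(0, t) always gives the student an input at least as far from the target as the teacher's.

```python
import random


def interpolate(noise, target, time):
    """Point at `time` on the straight path from noise (time=0) to target (time=1)."""
    return [(1.0 - time) * n + time * x for n, x in zip(noise, target)]


def distance(a, b):
    """Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(8)]   # stands in for flow_noise_embed
target = [rng.gauss(0.0, 1.0) for _ in range(8)]  # stands in for target_embed

t = 0.7                 # teacher time
s = t * rng.random()    # student time, s ~ U(0, t)
teacher_state = interpolate(noise, target, t)
student_state = interpolate(noise, target, s)
time_gap = t - s        # analogous to the mean_time_gap diagnostic
```

Because the distance to the target scales as (1 - time), a smaller time always means a noisier state. This is the property the distillation loss depends on.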
| # ==================== SELF-DISTILLATION (legacy) ==================== | |
| def compute_self_distillation_loss( | |
| self, | |
| input_ids: torch.Tensor, | |
| teacher_logits: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| position_ids: Optional[torch.Tensor] = None, | |
| temperature_min: Optional[float] = None, | |
| temperature_max: Optional[float] = None, | |
| temperature_distribution: Optional[str] = None, | |
| ) -> SelfDistillationOutput: | |
| """ | |
| CFM-style self-distillation: model's predictions should be consistent | |
| across different levels of input degradation. | |
| Process: | |
| 1. Take teacher logits (from standard forward pass, DETACHED) | |
| 2. Degrade: per-token random temperature → softer soft embeddings | |
| 3. Student: forward pass from degraded embeddings → logits (has grad) | |
| 4. Loss: KL(teacher || student) on masked positions | |
| Each masked token gets its own independently sampled degradation | |
| temperature, creating varied difficulty across the sequence. | |
| Args: | |
| input_ids: Input with [MASK] tokens at positions to predict | |
| teacher_logits: Pre-computed teacher logits (will be detached). | |
| Typically outputs.all_logits[0] or outputs.logits from standard forward. | |
| attention_mask: Standard attention mask | |
| position_ids: Position IDs (if needed by base model) | |
| temperature_min: Min degradation temperature (default: config value) | |
| temperature_max: Max degradation temperature (default: config value) | |
| temperature_distribution: How to sample T (default: config value) | |
| Returns: | |
| SelfDistillationOutput with loss, logits, temperature, and diagnostics | |
| """ | |
| # Resolve defaults from config | |
| temperature_min = temperature_min if temperature_min is not None else self.config.self_distillation_temperature_min | |
| temperature_max = temperature_max if temperature_max is not None else self.config.self_distillation_temperature_max | |
| temperature_distribution = temperature_distribution if temperature_distribution is not None else self.config.self_distillation_temperature_distribution | |
| mask_id = self.config.mask_token_id | |
| mask_pos = (input_ids == mask_id) # (B, L) | |
| device = input_ids.device | |
| num_masked = mask_pos.sum().item() | |
| # Handle degenerate case: no masked positions | |
| if num_masked == 0: | |
| zero = torch.tensor(0.0, device=device, requires_grad=True) | |
| dummy = torch.zeros(1, device=device) | |
| return SelfDistillationOutput(zero, dummy, dummy, 1.0, 0.0, 0.0, 1.0) | |
| # Ensure teacher logits are detached | |
| teacher_logits = teacher_logits.detach() | |
| embed_weight = self.embed_weight | |
| mask_emb = embed_weight[mask_id] # (H,) | |
| base_embeds = self.get_input_embeddings()(input_ids) # (B, L, H) | |
| # ===== STEP 1: Sample per-token degradation temperatures ===== | |
| # Each masked position gets its own temperature independently | |
| if temperature_distribution == "log_uniform": | |
| log_min = torch.tensor(temperature_min, device=device).log() | |
| log_max = torch.tensor(temperature_max, device=device).log() | |
| log_T = torch.empty(num_masked, device=device).uniform_(log_min.item(), log_max.item()) | |
| T_per_token = log_T.exp() # (num_masked,) | |
| elif temperature_distribution == "uniform": | |
| T_per_token = torch.empty(num_masked, device=device).uniform_( | |
| temperature_min, temperature_max | |
| ) # (num_masked,) | |
| else: | |
| raise ValueError(f"Unknown temperature distribution: {temperature_distribution}") | |
| T_mean = T_per_token.mean().item() | |
| # ===== STEP 2: Create degraded soft embeddings ===== | |
| # Per-token temperature scaling: each position gets its own T | |
| masked_teacher_logits = teacher_logits[mask_pos] # (num_masked, V) | |
| degraded_probs = F.softmax(masked_teacher_logits / T_per_token.unsqueeze(-1), dim=-1).to(embed_weight.dtype) | |
| degraded_soft = degraded_probs @ embed_weight + mask_emb | |
| degraded_soft_embeds = base_embeds.detach().clone() | |
| degraded_soft_embeds[mask_pos] = degraded_soft.detach() | |
| # ===== STEP 3: Student forward from degraded input ===== | |
| student_inputs = degraded_soft_embeds | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype) | |
| student_out = self.mlm( | |
| inputs_embeds=student_inputs, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| return_dict=True, | |
| ) | |
| student_logits = student_out.logits # (B, L, V) — has gradient! | |
| # ===== STEP 4: KL divergence loss on masked positions ===== | |
| t_logits = teacher_logits[mask_pos] # (num_masked, V) | |
| s_logits = student_logits[mask_pos] # (num_masked, V) | |
| teacher_probs = F.softmax(t_logits, dim=-1) | |
| student_log_probs = F.log_softmax(s_logits, dim=-1) | |
| # KL(teacher || student) = sum teacher * (log_teacher - log_student) | |
| kl_loss = F.kl_div( | |
| student_log_probs, | |
| teacher_probs, | |
| reduction="batchmean", | |
| ) | |
| # ===== STEP 5: Compute diagnostic metrics ===== | |
| with torch.no_grad(): | |
| teacher_log_probs = torch.log(teacher_probs + 1e-10) | |
| teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item() | |
| student_probs = F.softmax(s_logits.detach(), dim=-1) | |
| student_log_probs_det = torch.log(student_probs + 1e-10) | |
| student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item() | |
| agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item() | |
| return SelfDistillationOutput( | |
| loss=kl_loss, | |
| teacher_logits=teacher_logits, | |
| student_logits=student_logits, | |
| degradation_temperature=T_mean, | |
| teacher_entropy=teacher_entropy, | |
| student_entropy=student_entropy, | |
| agreement_rate=agreement, | |
| ) | |
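The legacy self-distillation path combines three ideas: log-uniform temperature sampling, temperature-softened teacher probabilities, and a KL(teacher || student) loss. The sketch below reproduces them on a single toy token in pure Python. The helper names and the temperature range (1.5 to 4.0) are assumptions for illustration only.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i)."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0.0)


def sample_log_uniform(t_min, t_max, rng):
    """Sample T so that log(T) is uniform on [log(t_min), log(t_max)]."""
    return math.exp(rng.uniform(math.log(t_min), math.log(t_max)))


rng = random.Random(0)
teacher_logits = [4.0, 1.0, 0.5, -2.0]
teacher = softmax(teacher_logits)

T = sample_log_uniform(1.5, 4.0, rng)                # per-token degradation temperature
degraded = softmax(teacher_logits, temperature=T)    # softer distribution fed to the student
gap = kl_divergence(teacher, degraded)               # positive whenever the two differ
```

A temperature above 1 flattens the distribution, so the degraded input carries less information than the teacher's prediction. The KL term is zero only when the student reproduces the teacher exactly.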
| # ==================== MAIN SOFT EMBEDDING COMPUTATION ==================== | |
| def _compute_next_soft_embeds( | |
| self, | |
| logits: torch.Tensor, | |
| mask_pos: torch.Tensor, | |
| base_embeds: torch.Tensor, | |
| prev_soft_embeds: Optional[torch.Tensor] = None, | |
| iteration: int = 0, | |
| total_iterations: int = 1, | |
| # === Schedule parameters (default to config values) === | |
| schedule: Optional[str] = None, | |
| causal_strength: Optional[float] = None, | |
| # === Effect parameters (default to config values) === | |
| temperature_max: Optional[float] = None, | |
| entropy_target_max: Optional[float] = None, | |
| entropy_floor_max: Optional[float] = None, | |
| smear_sigma_max: Optional[float] = None, | |
| noise_std_max: Optional[float] = None, | |
| iteration_rope_dim_fraction: Optional[float] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Compute soft embeddings from logits for the next iteration. | |
| This function implements a unified "convergence schedule" system that controls | |
| when each position is allowed to converge to a confident prediction. | |
| Schedule Types: | |
| "linear": All positions converge at the same rate (iteration-based only) | |
| "causal": Early positions converge first, late positions last | |
| Effects (mechanisms to enforce the schedule): | |
| temperature_max: High temperature = more uniform distribution (one-sided) | |
| entropy_target_max: Force EXACT entropy via bisection search (two-sided, recommended) | |
| entropy_floor_max: Force MINIMUM entropy (one-sided, only prevents too confident) | |
| smear_sigma_max: Spread probability across neighboring positions | |
| noise_std_max: Add Gaussian noise to logits | |
| All parameters default to their config values if not specified. | |
| Args: | |
| logits: Output logits from current iteration (B, L, V) | |
| mask_pos: Boolean mask indicating which positions are masked (B, L) | |
| base_embeds: Base token embeddings for non-masked positions (B, L, H) | |
| prev_soft_embeds: Soft embeddings from the previous iteration (B, L, H); used only for EMA blending | |
| iteration: Current iteration index (0-indexed) | |
| total_iterations: Total number of iterations | |
| Returns: | |
| Soft embeddings for next iteration (B, L, H) | |
| """ | |
| # Use config values as defaults | |
| schedule = schedule if schedule is not None else self.config.schedule | |
| causal_strength = causal_strength if causal_strength is not None else self.config.causal_strength | |
| temperature_max = temperature_max if temperature_max is not None else self.config.temperature_max | |
| entropy_target_max = entropy_target_max if entropy_target_max is not None else self.config.entropy_target_max | |
| entropy_floor_max = entropy_floor_max if entropy_floor_max is not None else self.config.entropy_floor_max | |
| smear_sigma_max = smear_sigma_max if smear_sigma_max is not None else self.config.smear_sigma_max | |
| noise_std_max = noise_std_max if noise_std_max is not None else self.config.noise_std_max | |
| iteration_rope_dim_fraction = iteration_rope_dim_fraction if iteration_rope_dim_fraction is not None else self.config.iteration_rope_dim_fraction | |
| soft_embeds = base_embeds.clone() | |
| if not mask_pos.any(): | |
| return soft_embeds.detach() | |
| B, L, V = logits.shape | |
| device, dtype = logits.device, logits.dtype | |
| # Check if any effects are enabled | |
| has_effects = ( | |
| temperature_max > 0 or | |
| entropy_target_max > 0 or | |
| entropy_floor_max > 0 or | |
| smear_sigma_max > 0 or | |
| noise_std_max > 0 or | |
| iteration_rope_dim_fraction > 0 | |
| ) | |
| if not has_effects: | |
| # Simple path: no convergence schedule effects | |
| masked_logits = logits[mask_pos] | |
| embed_weight = self.embed_weight | |
| # Convert logits to mixing weights based on soft_embedding_method | |
| if self.config.soft_embedding_method == "none": | |
| weights = masked_logits | |
| elif self.config.soft_embedding_method == "l2_normalize": | |
| weights = F.normalize(masked_logits, p=2, dim=-1) | |
| else: | |
| weights = self.normalize(masked_logits) | |
| masked_soft = weights @ embed_weight | |
| mask_emb = embed_weight[self.config.mask_token_id] | |
| masked_soft = masked_soft + mask_emb | |
| # Apply EMA blending with previous soft embeddings if enabled | |
| ema_step = self.config.soft_embedding_ema_step | |
| if ema_step < 1.0 and prev_soft_embeds is not None: | |
| prev_masked_soft = prev_soft_embeds[mask_pos] | |
| masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft | |
| soft_embeds[mask_pos] = masked_soft | |
| return soft_embeds.detach() | |
| # ========== STEP 1: Compute per-position convergence progress ========== | |
| batch_indices, position_indices = torch.where(mask_pos) | |
| progress = self._compute_convergence_progress( | |
| iteration=iteration, | |
| total_iterations=total_iterations, | |
| seq_length=L, | |
| mask_positions=position_indices, | |
| schedule=schedule, | |
| causal_strength=causal_strength, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| # Compute full (B, L) progress for smearing if needed | |
| if smear_sigma_max > 0: | |
| all_positions = torch.arange(L, device=device, dtype=dtype) | |
| progress_full = self._compute_convergence_progress( | |
| iteration=iteration, | |
| total_iterations=total_iterations, | |
| seq_length=L, | |
| mask_positions=all_positions, | |
| schedule=schedule, | |
| causal_strength=causal_strength, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| progress_full = progress_full.unsqueeze(0).expand(B, -1) | |
| # ========== STEP 2: Apply smearing (needs full tensor) ========== | |
| full_probs = self.normalize(logits) | |
| if smear_sigma_max > 0: | |
| full_probs = self._apply_smear_effect( | |
| full_probs, mask_pos, progress_full, smear_sigma_max | |
| ) | |
| # ========== STEP 3: Extract masked positions ========== | |
| masked_logits = logits[mask_pos] | |
| masked_probs = full_probs[mask_pos] | |
| # ========== STEP 4: Apply temperature effect (on logits; skipped when entropy targeting is active) ========== | |
| if temperature_max > 0 and entropy_target_max <= 0: | |
| masked_logits = self._apply_temperature_effect( | |
| masked_logits, progress, temperature_max | |
| ) | |
| masked_probs = torch.softmax(masked_logits, dim=-1) | |
| # ========== STEP 5: Apply noise effect (on log-probabilities) ========== | |
| if noise_std_max > 0: | |
| masked_logits_noisy = self._apply_noise_effect( | |
| torch.log(masked_probs + 1e-10), progress, noise_std_max | |
| ) | |
| masked_probs = torch.softmax(masked_logits_noisy, dim=-1) | |
| # ========== STEP 6: Apply entropy control ========== | |
| if entropy_target_max > 0: | |
| masked_probs = self._apply_target_entropy_effect( | |
| masked_logits, progress, entropy_target_max | |
| ) | |
| elif entropy_floor_max > 0: | |
| masked_probs = self._apply_entropy_floor_effect( | |
| masked_probs, progress, entropy_floor_max | |
| ) | |
| # ========== STEP 7: Compute soft embeddings ========== | |
| embed_weight = self.embed_weight | |
| # Convert to mixing weights based on soft_embedding_method | |
| if self.config.soft_embedding_method == "none": | |
| # No normalization - use raw logits directly | |
| weights = masked_logits | |
| elif self.config.soft_embedding_method == "l2_normalize": | |
| # L2 normalize bypasses all the softmax-based effects above | |
| weights = F.normalize(masked_logits, p=2, dim=-1) | |
| else: | |
| weights = masked_probs | |
| masked_soft = weights @ embed_weight | |
| mask_emb = embed_weight[self.config.mask_token_id] | |
| masked_soft = masked_soft + mask_emb | |
| # ========== STEP 8: Apply iteration RoPE ========== | |
| if iteration_rope_dim_fraction > 0: | |
| masked_soft = self._apply_iteration_rope( | |
| masked_soft, iteration, total_iterations, iteration_rope_dim_fraction | |
| ) | |
| # ========== STEP 8.5: Apply EMA blending ========== | |
| ema_step = self.config.soft_embedding_ema_step | |
| if ema_step < 1.0 and prev_soft_embeds is not None: | |
| prev_masked_soft = prev_soft_embeds[mask_pos] | |
| masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft | |
| # ========== STEP 9: Place back and return ========== | |
| soft_embeds[mask_pos] = masked_soft | |
| return soft_embeds.detach() | |
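The core of `_compute_next_soft_embeds` is a probability-weighted mix of embedding rows, plus the mask embedding, optionally blended with the previous iteration by an exponential moving average. The toy example below uses a three-token vocabulary and two hidden dimensions. The helper names and numbers are illustrative assumptions, not values from the model.

```python
def soft_embedding(probs, embed_rows, mask_emb):
    """Mix embedding rows by probability, then add the mask embedding."""
    hidden = len(mask_emb)
    mixed = [sum(p * row[h] for p, row in zip(probs, embed_rows)) for h in range(hidden)]
    return [m + e for m, e in zip(mixed, mask_emb)]


def ema_blend(prev, new, ema_step):
    """ema_step=1.0 keeps only the new embedding; smaller values retain more of prev."""
    return [(1.0 - ema_step) * p + ema_step * n for p, n in zip(prev, new)]


embed_rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy embedding table, V=3, H=2
mask_emb = [0.5, -0.5]

# Uniform weights reproduce the "average embedding + mask" prior used at iteration 0.
uniform = soft_embedding([1 / 3, 1 / 3, 1 / 3], embed_rows, mask_emb)
# A one-hot distribution collapses to a single token embedding plus the mask signal.
confident = soft_embedding([0.0, 1.0, 0.0], embed_rows, mask_emb)
blended = ema_blend(uniform, confident, ema_step=0.25)
```

The blend moves a quarter of the way from the uniform prior toward the confident prediction, which damps abrupt changes between iterations.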
| def _compute_iteration_metrics( | |
| self, logits: torch.Tensor, labels: torch.Tensor | |
| ) -> IterationMetrics: | |
| """ | |
| Compute token-level AND sequence-level metrics for a single iteration. | |
| Returns scalars only - no large tensor storage. | |
| Token-level metrics: | |
| - accuracy: fraction of correct token predictions | |
| - entropy: average entropy per token | |
| - softmax_ce: cross-entropy loss per token | |
| Sequence-level metrics: | |
| - full_sequence_accuracy: fraction of sequences where ALL tokens are correct | |
| - min_sequence_confidence: mean of minimum top-1 confidence per sequence | |
| """ | |
| B = logits.shape[0] | |
| # Move to CPU to avoid GPU OOM - metrics are for monitoring only | |
| logits = logits.detach().cpu().float() # float32 is sufficient for metrics | |
| target_labels = labels.detach().cpu().contiguous() | |
| mask = target_labels != -100 | |
| if mask.sum() == 0: | |
| return IterationMetrics( | |
| accuracy=0.0, | |
| entropy=0.0, | |
| softmax_ce=0.0, | |
| full_sequence_accuracy=0.0, | |
| min_sequence_confidence=0.0, | |
| ) | |
| logits = logits.contiguous() | |
| predictions = logits.argmax(dim=-1) | |
| correct = (predictions == target_labels) & mask | |
| # ===== TOKEN-LEVEL METRICS ===== | |
| # Token accuracy | |
| accuracy = (correct.sum() / mask.sum()).item() | |
| # Extract valid tokens for entropy/CE | |
| valid_logits = logits[mask] | |
| valid_labels = target_labels[mask] | |
| # Entropy (using log_softmax for numerical stability) | |
| log_probs = torch.nn.functional.log_softmax(valid_logits, dim=-1) | |
| probs = torch.exp(log_probs) | |
| entropy = -(probs * log_probs).sum(dim=-1).mean().item() | |
| # Cross-entropy | |
| softmax_ce = torch.nn.functional.cross_entropy( | |
| valid_logits, valid_labels, reduction="mean" | |
| ).item() | |
| # ===== SEQUENCE-LEVEL METRICS ===== | |
| # Check which sequences have valid tokens | |
| sequences_with_tokens = mask.any(dim=1) # (B,) | |
| num_valid_sequences = sequences_with_tokens.sum().item() | |
| if num_valid_sequences == 0: | |
| return IterationMetrics( | |
| accuracy=accuracy, | |
| entropy=entropy, | |
| softmax_ce=softmax_ce, | |
| full_sequence_accuracy=0.0, | |
| min_sequence_confidence=0.0, | |
| ) | |
| # Full sequence accuracy: all tokens in sequence must be correct | |
| num_correct_per_seq = correct.sum(dim=1) # (B,) | |
| num_tokens_per_seq = mask.sum(dim=1) # (B,) | |
| all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens | |
| full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item() | |
| # Min sequence confidence: minimum top-1 probability within each sequence | |
| probs_full = torch.softmax(logits, dim=-1) # (B, L, V) - already float32 | |
| top1_confidence = probs_full.max(dim=-1).values # (B, L) | |
| min_confidences = [] | |
| for i in range(B): | |
| if sequences_with_tokens[i]: | |
| seq_confidences = top1_confidence[i][mask[i]] # (num_tokens_in_seq,) | |
| min_confidences.append(seq_confidences.min().item()) | |
| min_seq_conf = sum(min_confidences) / len(min_confidences) if min_confidences else 0.0 | |
| return IterationMetrics( | |
| accuracy=accuracy, | |
| entropy=entropy, | |
| softmax_ce=softmax_ce, | |
| full_sequence_accuracy=full_seq_accuracy, | |
| min_sequence_confidence=min_seq_conf, | |
| ) | |
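The token-level and sequence-level metrics computed above can be checked by hand on a small batch. The sketch below is a pure-Python analogue with a hypothetical `sequence_metrics` helper. It takes predicted token IDs and top-1 confidences directly instead of logits, and treats `-100` as the ignore label, as the original does.

```python
IGNORE = -100


def sequence_metrics(predictions, labels, confidences):
    """Return (token accuracy, full-sequence accuracy, mean of per-sequence min confidence)."""
    token_correct = token_total = 0
    seq_all_correct, seq_min_conf = [], []
    for preds, labs, confs in zip(predictions, labels, confidences):
        scored = [(p == l, c) for p, l, c in zip(preds, labs, confs) if l != IGNORE]
        if not scored:  # sequences without labeled tokens are skipped entirely
            continue
        token_correct += sum(ok for ok, _ in scored)
        token_total += len(scored)
        seq_all_correct.append(all(ok for ok, _ in scored))
        seq_min_conf.append(min(c for _, c in scored))
    if token_total == 0:
        return 0.0, 0.0, 0.0
    return (token_correct / token_total,
            sum(seq_all_correct) / len(seq_all_correct),
            sum(seq_min_conf) / len(seq_min_conf))


predictions = [[1, 2, 3], [4, 5, 6]]
labels = [[1, 2, IGNORE], [4, 9, 6]]
confidences = [[0.9, 0.8, 0.1], [0.7, 0.6, 0.95]]
accuracy, full_seq_accuracy, min_seq_conf = sequence_metrics(predictions, labels, confidences)
```

Four of the five labeled tokens are correct, only the first sequence is fully correct, and the per-sequence minimum confidences are 0.8 and 0.6.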
| def _single_iteration( | |
| self, | |
| t: int, | |
| T: int, | |
| soft_embeds: torch.Tensor, | |
| base_embeds: torch.Tensor, | |
| mask_pos: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor], | |
| labels: Optional[torch.Tensor], | |
| compute_metrics: bool, | |
| position_ids: Optional[torch.Tensor] = None, | |
| **kwargs, | |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]: | |
| """ | |
| Execute a single iteration of recursive refinement. | |
| Args: | |
| t: Current iteration index (0 to T-1) | |
| T: Total number of iterations | |
| soft_embeds: Soft embeddings for mask positions | |
| base_embeds: Base token embeddings from input_ids | |
| mask_pos: Boolean mask of [MASK] positions (B, L) | |
| attention_mask: Attention mask for MLM | |
| labels: Target labels for loss computation | |
| compute_metrics: Whether to compute iteration metrics | |
| position_ids: Position IDs (if needed by base model) | |
| **kwargs: Extra keyword arguments forwarded to the base MLM | |
| Returns: | |
| logits: Output logits from MLM (B, L, V) | |
| weighted_loss: Loss weighted by step_weight(t, T), or None if no labels | |
| metrics: IterationMetrics, or None if not requested | |
| """ | |
| # Blend soft embeddings (at mask positions) with base embeddings (at non-mask positions) | |
| inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) | |
| # Forward through base MLM | |
| outputs = self.mlm( | |
| inputs_embeds=inputs_embeds, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| labels=labels, | |
| return_dict=True, | |
| **kwargs, | |
| ) | |
| # Compute weighted loss for this iteration | |
| weighted_loss = outputs.loss | |
| if labels is not None: | |
| if weighted_loss is None: | |
| # Base model doesn't compute loss (e.g., LLaDA) - compute it ourselves | |
| # Only compute loss on MASKED positions (MDLM training) | |
| masked_logits = outputs.logits[mask_pos] # (num_masked, V) | |
| masked_labels = labels[mask_pos] # (num_masked,) | |
| loss_fct = CrossEntropyLoss() # default ignore_index=-100 skips unlabeled positions | |
| weighted_loss = loss_fct(masked_logits, masked_labels) | |
| weighted_loss = weighted_loss * self.step_weight(t, T) # out-of-place: leaves outputs.loss untouched | |
| # Compute iteration metrics if requested | |
| metrics = None | |
| if compute_metrics and labels is not None: | |
| metrics = self._compute_iteration_metrics(outputs.logits, labels) | |
| return outputs.logits, weighted_loss, metrics | |
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| labels: Optional[torch.Tensor] = None, | |
| position_ids: Optional[torch.Tensor] = None, | |
| num_recursions: Optional[int] = None, | |
| compute_iteration_metrics: bool = False, | |
| use_recursion_checkpointing: Optional[bool] = None, | |
| # Parameters for single-iteration training mode (DEPRECATED) | |
| prev_soft_embeds: Optional[torch.Tensor] = None, | |
| run_set_iteration: Optional[int] = None, | |
| # === Convergence schedule parameters (None = use config defaults) === | |
| schedule: Optional[str] = None, | |
| causal_strength: Optional[float] = None, | |
| # === Effect parameters (None = use config defaults) === | |
| temperature_max: Optional[float] = None, | |
| entropy_target_max: Optional[float] = None, | |
| entropy_floor_max: Optional[float] = None, | |
| smear_sigma_max: Optional[float] = None, | |
| noise_std_max: Optional[float] = None, | |
| iteration_rope_dim_fraction: Optional[float] = None, | |
| **kwargs, | |
| ) -> RecursiveMaskedLMOutput: | |
| """ | |
| Forward with recursive refinement. | |
| Supports three modes: | |
| 1. Checkpointed mode (default): Run all T recursions with gradient checkpointing | |
| (training only). Gradients flow through the entire chain; activations are recomputed during backward. | |
| 2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations. | |
| Faster backward but higher memory. | |
| 3. Single-iteration mode (DEPRECATED - run_set_iteration is not None): Run only one | |
| iteration. Use use_recursion_checkpointing=True instead. | |
| Loss Weighting (config.loss_weight): | |
| "last_1": Only final iteration loss (enables learning convergence behavior) | |
| "last_2": Last 2 iterations | |
| "linear": All iterations, linearly weighted (default) | |
| "uniform": All iterations, uniformly weighted | |
| Recursion Checkpointing: | |
| use_recursion_checkpointing: Enable gradient checkpointing for iterations. | |
| True = checkpoint each iteration, recompute during backward (default). | |
| False = store all activations (higher memory, faster backward). | |
| Convergence Schedule Parameters: | |
| All schedule/effect parameters default to their config values if not specified. | |
| Pass explicit values to override config for this forward pass. | |
| schedule: "linear" or "causal" - controls when positions can converge | |
| causal_strength: How much faster early positions converge (causal only) | |
| temperature_max: Max temperature boost for uncertain positions | |
| entropy_target_max: Target entropy at progress=0 (two-sided, recommended) | |
| entropy_floor_max: Min entropy floor (one-sided) | |
| smear_sigma_max: Max Gaussian sigma for position smearing | |
| noise_std_max: Max std of Gaussian noise on logits | |
| iteration_rope_dim_fraction: Fraction of dims for iteration RoPE | |
| """ | |
| B, L = input_ids.shape | |
| V = self.embed_weight.shape[0] | |
| mask_id = self.config.mask_token_id | |
| if mask_id is None: | |
| raise ValueError("mask_token_id must be set") | |
| # Resolve config default for recursion checkpointing | |
| use_recursion_checkpointing = ( | |
| use_recursion_checkpointing | |
| if use_recursion_checkpointing is not None | |
| else self.config.use_recursion_checkpointing | |
| ) | |
| mask_pos = (input_ids == mask_id) # (B, L) | |
| base_embeds = self.get_input_embeddings()(input_ids) # (B, L, H) | |
| T = num_recursions or self.config.num_recursions | |
| weight_sum = sum(self.step_weight(i, T) for i in range(T)) | |
| # Bundle schedule kwargs to pass to _compute_next_soft_embeds | |
| schedule_kwargs = dict( | |
| schedule=schedule, | |
| causal_strength=causal_strength, | |
| temperature_max=temperature_max, | |
| entropy_target_max=entropy_target_max, | |
| entropy_floor_max=entropy_floor_max, | |
| smear_sigma_max=smear_sigma_max, | |
| noise_std_max=noise_std_max, | |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, | |
| ) | |
| # ===== SINGLE ITERATION MODE (DEPRECATED) ===== | |
| if run_set_iteration is not None: | |
| warnings.warn( | |
| "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, " | |
| "which provides proper gradient flow through all iterations.", | |
| DeprecationWarning, | |
| stacklevel=2, | |
| ) | |
| t = run_set_iteration | |
| # Get soft embeddings for this iteration | |
| if t == 0: | |
| # t=0: Uniform prior = average embedding (equivalent to softmax(zeros) @ embed_weight) | |
| # We compute this efficiently via embed_weight.mean() rather than creating large zero tensors | |
| soft_embeds = base_embeds.clone() | |
| if mask_pos.any(): | |
| avg_embed = self.embed_weight.mean(dim=0) # (H,) - mean over all V tokens | |
| mask_emb = self.embed_weight[mask_id] | |
| soft_embeds[mask_pos] = avg_embed + mask_emb | |
| else: | |
| if prev_soft_embeds is None: | |
| raise ValueError(f"prev_soft_embeds must be provided for iteration {t}") | |
| soft_embeds = prev_soft_embeds | |
| logits, weighted_loss, metrics = self._single_iteration( | |
| t, T, soft_embeds, base_embeds, mask_pos, | |
| attention_mask, labels, compute_iteration_metrics, | |
| position_ids=position_ids, **kwargs | |
| ) | |
| # Normalize loss by total weight sum | |
| loss = weighted_loss / weight_sum if weighted_loss is not None else None | |
| # Compute soft embeddings for next iteration (if not last) | |
| next_soft_embeds = None | |
| if t < T - 1: | |
| next_soft_embeds = self._compute_next_soft_embeds( | |
| logits, mask_pos, base_embeds, | |
| iteration=t, | |
| total_iterations=T, | |
| **schedule_kwargs, | |
| ) | |
| return RecursiveMaskedLMOutput( | |
| loss=loss, | |
| logits=logits, | |
| next_soft_embeds=next_soft_embeds, | |
| iteration_metrics={t: metrics} if metrics is not None else None, | |
| ) | |
| # ===== FULL RECURSION MODE (checkpointed or not; gradients flow through all iterations) ===== | |
| embed_weight = self.embed_weight | |
| mask_emb = embed_weight[mask_id] # (H,) | |
| # Temperature must be a tensor for checkpointing (checkpoint requires tensor inputs) | |
| temperature = torch.tensor( | |
| self.config.temperature, | |
| device=input_ids.device, | |
| dtype=base_embeds.dtype, | |
| ) | |
| # Ensure attention_mask is a tensor (required for checkpointing) | |
| if attention_mask is None: | |
| attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype) | |
| # Initialize soft embeddings for masked positions | |
| soft_embeds = base_embeds.clone() | |
| flow_noise_embed = None | |
| flow_t_per_token = None | |
| if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any(): | |
| # Flow matching: interpolate between random noise and target on the simplex | |
| num_masked = mask_pos.sum().item() | |
| V = embed_weight.shape[0] | |
| device = input_ids.device | |
| # Sample per-token time levels (logit-normal by default) | |
| flow_t_per_token = self._sample_flow_matching_t(num_masked, device) | |
| # Random noise embedding: sample on simplex, project to H-dim | |
| z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) | |
| p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype) | |
| flow_noise_embed = p_noise @ embed_weight # (num_masked, H) — compact | |
| # Target embedding from labels | |
| target_ids = labels[mask_pos] # original token IDs at masked positions | |
| target_embed = embed_weight[target_ids] # (num_masked, H) | |
| # Interpolate in embedding space | |
| t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype) # (num_masked, 1) | |
| interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed | |
| # Add mask signal (binary or scaled) | |
| if self.config.flow_matching_mask_scale: | |
| soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb | |
| else: | |
| soft_embeds[mask_pos] = interp_embed + mask_emb | |
| elif mask_pos.any(): | |
| # Standard uniform prior (average embedding + mask signal) | |
| avg_embed = embed_weight.mean(dim=0) # (H,) | |
| soft_embeds[mask_pos] = avg_embed + mask_emb | |
| iteration_metrics = {} if compute_iteration_metrics and labels is not None else None | |
| # Main recursion loop with optional checkpointing | |
| all_logits = [] | |
| for t in range(T): | |
| if self.training and use_recursion_checkpointing: | |
| # Use checkpointing: activations recomputed during backward | |
| # This maintains gradient flow while saving memory | |
| logits, soft_embeds = torch_checkpoint( | |
| self._single_iteration_checkpointable, | |
| soft_embeds, | |
| base_embeds, | |
| mask_pos, | |
| attention_mask, | |
| embed_weight, | |
| mask_emb, | |
| temperature, | |
| position_ids, | |
| use_reentrant=False, # Critical for nested checkpointing! | |
| ) | |
| else: | |
| # No checkpointing: store all activations (inference or explicit disable) | |
| logits, soft_embeds = self._single_iteration_checkpointable( | |
| soft_embeds, | |
| base_embeds, | |
| mask_pos, | |
| attention_mask, | |
| embed_weight, | |
| mask_emb, | |
| temperature, | |
| position_ids, | |
| ) | |
| all_logits.append(logits) | |
| # Compute iteration metrics if requested (no grad needed) | |
| if iteration_metrics is not None and labels is not None: | |
| with torch.no_grad(): | |
| iteration_metrics[t] = self._compute_iteration_metrics(logits, labels) | |
| # Return all logits for trainer to compute loss with proper normalization | |
| # Trainer handles: timestep-based weighting, iteration weighting, batch/sequence/token normalization | |
| return RecursiveMaskedLMOutput( | |
| loss=None, # Let trainer compute loss | |
| logits=logits, # Final logits for inference/metrics | |
| all_logits=all_logits if self.training else None, # Only needed during training | |
| iteration_metrics=iteration_metrics or None, | |
| flow_noise_embed=flow_noise_embed, # For flow matching distillation | |
| flow_t=flow_t_per_token, # For flow matching distillation | |
| ) | |
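The flow matching branch near the top of `forward` blends a noise embedding with the target embedding and then adds the mask embedding, optionally faded out by `(1 - t)`. Below is a minimal pure-Python sketch of that per-token arithmetic, using plain lists instead of tensors. The helper `interpolate_embedding` and its argument names are hypothetical and not part of the model code.

```python
def interpolate_embedding(noise, target, mask, t, scale_mask=True):
    """Blend noise and target embeddings at time t, then add the mask signal.

    Mirrors the expressions in the flow matching branch:
      interp = (1 - t) * noise + t * target
      output = interp + (1 - t) * mask   if scale_mask
      output = interp + mask             otherwise
    All vectors are plain lists of floats (illustration only).
    """
    mask_weight = (1.0 - t) if scale_mask else 1.0
    return [
        (1.0 - t) * n + t * g + mask_weight * m
        for n, g, m in zip(noise, target, mask)
    ]


# Toy 2-dimensional embeddings (made-up values).
noise = [1.0, 2.0]
target = [3.0, 5.0]
mask = [0.5, -0.5]

print(interpolate_embedding(noise, target, mask, t=0.0))  # [1.5, 1.5]
print(interpolate_embedding(noise, target, mask, t=1.0))  # [3.0, 5.0]
print(interpolate_embedding(noise, target, mask, t=1.0, scale_mask=False))  # [3.5, 4.5]
```

With the scaled mask signal, the input at `t = 1` is exactly the target embedding; with the binary mask signal, the mask embedding stays added at every `t`.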
| def _generate_flow_map( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor], | |
| position_ids: Optional[torch.Tensor], | |
| num_steps: int, | |
| ) -> torch.Tensor: | |
| """Fill in mask positions using the CFM flow map update rule. | |
| Starts from a random point on the probability simplex and iteratively | |
| moves toward the model's predictions using the flow map step rule. | |
| Args: | |
| input_ids: Input with [MASK] tokens at positions to fill | |
| attention_mask: Attention mask | |
| position_ids: Position IDs | |
| num_steps: Number of flow map steps (more steps = finer integration; with 1 step the state jumps straight to the model's first prediction) | |
| Returns: | |
| Tensor with [MASK] positions filled with predicted tokens | |
| """ | |
| mask_pos = (input_ids == self.config.mask_token_id) | |
| num_masked = mask_pos.sum().item() | |
| if num_masked == 0: | |
| return input_ids.clone() | |
| device = input_ids.device | |
| V = self.embed_weight.shape[0] | |
| embed_weight = self.embed_weight | |
| mask_emb = embed_weight[self.config.mask_token_id] | |
| base_embeds = self.get_input_embeddings()(input_ids) | |
| # Start from random simplex point | |
| noise_scale = self.config.flow_matching_noise_scale | |
| p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype) | |
| times = torch.linspace(0, 1, num_steps + 1, device=device) | |
| for i in range(num_steps): | |
| t_now = times[i] | |
| t_next = times[i + 1] | |
| step_size = (t_next - t_now) / (1 - t_now) | |
| # Mask signal (binary or scaled) | |
| if self.config.flow_matching_mask_scale: | |
| mask_signal = (1 - t_now) * mask_emb | |
| else: | |
| mask_signal = mask_emb | |
| # Project current state to embedding space | |
| embed = p @ embed_weight + mask_signal | |
| inputs_embeds = base_embeds.clone() | |
| inputs_embeds[mask_pos] = embed | |
| outputs = self.mlm( | |
| inputs_embeds=inputs_embeds, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| return_dict=True, | |
| ) | |
| pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype) | |
| # Flow map update: move toward model's prediction | |
| p = p + step_size * (pi - p) | |
| # Fix floating point drift off the simplex | |
| p = p.clamp(min=0) | |
| p = p / p.sum(dim=-1, keepdim=True) | |
| result = input_ids.clone() | |
| result[mask_pos] = p.argmax(dim=-1) | |
| return result | |
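The step rule in `_generate_flow_map` uses `step_size = (t_next - t_now) / (1 - t_now)`, so the final step always has size 1 and lands on the current prediction. The sketch below replays that update in pure Python on a single toy distribution. It assumes a hypothetical `predict` callable standing in for the model, here one that always returns the same distribution; the real model's prediction changes as `p` changes.

```python
def flow_map_steps(num_steps):
    """Step sizes (t_next - t_now) / (1 - t_now) on a uniform grid over [0, 1]."""
    times = [i / num_steps for i in range(num_steps + 1)]
    return [
        (times[i + 1] - times[i]) / (1.0 - times[i])
        for i in range(num_steps)
    ]


def flow_map_fill(p, predict, num_steps):
    """Move p toward predict(p) with the flow map rule, renormalising each step."""
    for step_size in flow_map_steps(num_steps):
        pi = predict(p)
        # p <- p + step_size * (pi - p), clamped at zero as in the method above
        p = [max(0.0, a + step_size * (b - a)) for a, b in zip(p, pi)]
        total = sum(p)
        p = [a / total for a in p]
    return p


# Hypothetical stand-in for the model: a fixed prediction over 3 tokens.
target = [0.1, 0.7, 0.2]
start = [0.5, 0.25, 0.25]

steps = flow_map_steps(4)
final = flow_map_fill(start, lambda p: target, num_steps=4)

print(steps)  # step sizes grow and end at 1.0
print(final)  # approximately equal to target
```

Because the last step size is 1, the choice of `num_steps` affects how many intermediate states the model sees, not whether the final state reaches the last prediction.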
| def generate( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| position_ids: Optional[torch.Tensor] = None, | |
| num_recursions: Optional[int] = None, | |
| # === Convergence schedule parameters (None = use config defaults) === | |
| schedule: Optional[str] = None, | |
| causal_strength: Optional[float] = None, | |
| # === Effect parameters (None = use config defaults) === | |
| temperature_max: Optional[float] = None, | |
| entropy_target_max: Optional[float] = None, | |
| entropy_floor_max: Optional[float] = None, | |
| smear_sigma_max: Optional[float] = None, | |
| noise_std_max: Optional[float] = None, | |
| iteration_rope_dim_fraction: Optional[float] = None, | |
| ) -> torch.Tensor: | |
| """Fill in mask positions via iterative refinement. | |
| When flow_matching_enabled, uses the CFM flow map update rule; the schedule and | |
| effect parameters below are then ignored. Otherwise, uses standard recursive | |
| soft-token refinement via forward(). | |
| Args: | |
| input_ids: Input token IDs with [MASK] tokens at positions to fill | |
| attention_mask: Attention mask | |
| position_ids: Optional position IDs passed through to the model | |
| num_recursions: Override number of recursions/steps (default: config value) | |
| schedule: "linear" or "causal" convergence schedule | |
| causal_strength: How much faster early positions converge (causal only) | |
| temperature_max: Max temperature boost for uncertain positions | |
| entropy_target_max: Target entropy at progress=0 (two-sided) | |
| entropy_floor_max: Min entropy floor (one-sided) | |
| smear_sigma_max: Max Gaussian sigma for position smearing | |
| noise_std_max: Max std of Gaussian noise on logits | |
| iteration_rope_dim_fraction: Fraction of dims for iteration RoPE | |
| Returns: | |
| Tensor with [MASK] positions filled with predicted tokens | |
| """ | |
| num_steps = num_recursions or self.config.num_recursions | |
| if self.config.flow_matching_enabled: | |
| return self._generate_flow_map( | |
| input_ids, attention_mask, position_ids, num_steps | |
| ) | |
| out = self.forward( | |
| input_ids, | |
| attention_mask, | |
| position_ids=position_ids, | |
| num_recursions=num_steps, | |
| schedule=schedule, | |
| causal_strength=causal_strength, | |
| temperature_max=temperature_max, | |
| entropy_target_max=entropy_target_max, | |
| entropy_floor_max=entropy_floor_max, | |
| smear_sigma_max=smear_sigma_max, | |
| noise_std_max=noise_std_max, | |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, | |
| ) | |
| result = input_ids.clone() | |
| mask_pos = (input_ids == self.config.mask_token_id) | |
| result[mask_pos] = out.logits.argmax(dim=-1)[mask_pos] | |
| return result | |
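The last three lines of `generate` copy the input and overwrite only the masked positions with the argmax of the final logits. Here is a minimal pure-Python sketch of that fill step on one sequence, using lists in place of tensors. `MASK_ID` and `fill_masks` are hypothetical names chosen for illustration; the real mask id comes from `config.mask_token_id`.

```python
MASK_ID = 99  # hypothetical mask token id (assumption for this example)


def fill_masks(input_ids, logits, mask_id=MASK_ID):
    """Replace each masked position with the index of its largest logit.

    input_ids: list of token ids for one sequence
    logits:    list of per-position score lists, one row per token
    Unmasked positions are returned unchanged.
    """
    result = list(input_ids)  # copy, like input_ids.clone()
    for i, tok in enumerate(input_ids):
        if tok == mask_id:
            row = logits[i]
            result[i] = max(range(len(row)), key=row.__getitem__)
    return result


input_ids = [5, MASK_ID, 7, MASK_ID]
logits = [
    [0.0, 0.0, 0.0],   # ignored: position is not masked
    [0.1, 0.2, 3.0],   # argmax -> 2
    [9.0, 0.0, 0.0],   # ignored: position is not masked
    [4.0, 1.0, 0.5],   # argmax -> 0
]

print(fill_masks(input_ids, logits))  # [5, 2, 7, 0]
```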