"""
RabbitForCausalLM — AutoModel-compatible wrapper for Anvaya-Rabbit.

Usage::

    pip install rtaforge transformers

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "RtaForge/Anvaya-Rabbit-2.7B", trust_remote_code=True
    )
"""

from __future__ import annotations

import torch
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

# Try the absolute import first (file imported as a top-level module), then
# fall back to the relative form (presumably the case when this file is
# loaded as part of a package, e.g. through trust_remote_code).
try:
    from configuration_rabbit import RabbitConfig
except ImportError:
    from .configuration_rabbit import RabbitConfig

# The actual network lives in the rtaforge distribution; fail with an
# actionable message instead of a bare ModuleNotFoundError.
try:
    from white_rabbit.rabbit_model import RabbitCausalLM, RabbitModelConfig
except ImportError as _e:
    raise ImportError(
        "The rtaforge package is required to load this model.\n"
        "Install it with: pip install rtaforge"
    ) from _e


class RabbitForCausalLM(PreTrainedModel, GenerationMixin):
    """``transformers`` facade over ``white_rabbit``'s ``RabbitCausalLM``.

    ``GenerationMixin`` is listed explicitly because from transformers v4.50
    ``PreTrainedModel`` no longer inherits it; without it this class would
    lose ``generate()`` despite overriding ``prepare_inputs_for_generation``.
    On older releases that still mix it in, the extra base is harmless.
    """

    config_class = RabbitConfig
    # NOTE(review): this wrapper does not wire up checkpointing itself; the
    # flag presumably relies on the inner RabbitCausalLM exposing a
    # ``gradient_checkpointing`` attribute — confirm before enabling it.
    supports_gradient_checkpointing = True

    def __init__(self, config: RabbitConfig):
        """Build the inner model from the HF config.

        Args:
            config: Provides ``vocab_size``, ``d_model`` and ``n_layers``.
        """
        super().__init__(config)
        # All weights live under this attribute, so (assuming RabbitCausalLM
        # is an nn.Module) state-dict keys carry an ``_inner.`` prefix;
        # renaming it would break loading existing checkpoints.
        self._inner = RabbitCausalLM(
            RabbitModelConfig(
                vocab_size=config.vocab_size,
                d_model=config.d_model,
                n_layers=config.n_layers,
                # NOTE(review): hard-coded, not read from ``config``.
                durga_variant="fu-64",
            )
        )

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the inner model's token-embedding module."""
        return self._inner.embed_tokens

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Replace the token embeddings and re-tie ``lm_head`` to them.

        ``lm_head.weight`` is pointed at ``value.weight`` so input and output
        embeddings stay shared after the swap.
        """
        self._inner.embed_tokens = value
        self._inner.lm_head.weight = value.weight

    def get_output_embeddings(self) -> torch.nn.Module:
        """Return the inner model's output projection (``lm_head``)."""
        return self._inner.lm_head

    def set_output_embeddings(self, value: torch.nn.Module) -> None:
        """Replace ``lm_head`` as-is.

        Unlike ``set_input_embeddings``, no weight tying is performed here.
        """
        self._inner.lm_head = value

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the inner model and wrap its result for ``transformers``.

        Args:
            input_ids: Token ids, forwarded unchanged to the inner model.
            labels: Optional targets, forwarded unchanged; any label shifting
                for next-token loss is up to the inner model (not visible here).
            **kwargs: Accepted so HF callers can pass extra arguments
                (presumably e.g. ``attention_mask``); they are ignored.

        Returns:
            ``CausalLMOutputWithPast`` holding ``logits`` and, when the inner
            model reports one, ``loss``. ``past_key_values`` is never set.
        """
        out = self._inner(input_ids=input_ids, labels=labels)
        return CausalLMOutputWithPast(loss=out.get("loss"), logits=out["logits"])

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the full token sequence to ``forward`` on every decoding step.

        No KV cache is kept, so every extra kwarg supplied during generation
        is dropped and each step recomputes the whole prefix.
        """
        return {"input_ids": input_ids}