"""
RabbitForCausalLM — AutoModel-compatible wrapper for Anvaya-Rabbit.

Usage:

    pip install rtaforge transformers

    model = AutoModelForCausalLM.from_pretrained(
        "RtaForge/Anvaya-Rabbit-2.7B", trust_remote_code=True
    )
"""
|
|
| from __future__ import annotations |
|
|
| import torch |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| try: |
| from configuration_rabbit import RabbitConfig |
| except ImportError: |
| from .configuration_rabbit import RabbitConfig |
|
|
| try: |
| from white_rabbit.rabbit_model import RabbitCausalLM, RabbitModelConfig |
| except ImportError as _e: |
| raise ImportError( |
| "The rtaforge package is required to load this model.\n" |
| "Install it with: pip install rtaforge" |
| ) from _e |
|
|
|
|
class RabbitForCausalLM(PreTrainedModel):
    """Expose a ``RabbitCausalLM`` through the ``PreTrainedModel`` interface.

    The wrapped network is stored on ``self._inner``; every method here
    delegates to it.
    """

    config_class = RabbitConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: RabbitConfig):
        """Create the wrapped model from *config*.

        Only ``vocab_size``, ``d_model`` and ``n_layers`` are read from
        *config*; ``durga_variant`` is always ``"fu-64"``.
        """
        super().__init__(config)
        inner_config = RabbitModelConfig(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            durga_variant="fu-64",
        )
        self._inner = RabbitCausalLM(inner_config)

    def get_input_embeddings(self):
        """Return the wrapped model's ``embed_tokens`` module."""
        inner = self._inner
        return inner.embed_tokens

    def set_input_embeddings(self, value):
        """Install *value* as ``embed_tokens`` and share its weight with ``lm_head``."""
        inner = self._inner
        inner.embed_tokens = value
        # Point the output head at the same weight object as the new embedding.
        inner.lm_head.weight = value.weight

    def get_output_embeddings(self):
        """Return the wrapped model's ``lm_head`` module."""
        inner = self._inner
        return inner.lm_head

    def set_output_embeddings(self, value):
        """Replace the wrapped model's ``lm_head`` with *value*."""
        inner = self._inner
        inner.lm_head = value

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the wrapped model and repackage its result.

        Args:
            input_ids: Token ids, forwarded unchanged to the wrapped model.
            labels: Optional targets, forwarded unchanged to the wrapped model.
            **kwargs: Accepted for call-site compatibility; not forwarded.

        Returns:
            A ``CausalLMOutputWithPast`` carrying ``loss`` (looked up with
            ``.get("loss")`` on the wrapped model's result) and ``logits``
            (its ``"logits"`` entry).
        """
        result = self._inner(input_ids=input_ids, labels=labels)
        loss = result.get("loss")
        logits = result["logits"]
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the forward inputs for one generation step.

        Only ``input_ids`` is passed on; all other keyword arguments are
        discarded.
        """
        return dict(input_ids=input_ids)
|
|