Instructions for using TuKoResearch/AuriStreamParallel-base with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
  - Transformers
How to use TuKoResearch/AuriStreamParallel-base with Transformers:

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("TuKoResearch/AuriStreamParallel-base", dtype="auto")
```

A fuller, hedged end-to-end example follows after this list.

- Notebooks
  - Google Colab
  - Kaggle
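For a quick smoke test, the sketch below loads the model and runs one forward pass on dummy token ids. It is an illustrative sketch, not an official recipe from the model card: the `trust_remote_code=True` flag and the 256-token dummy input are assumptions, and random ids are used in place of the model's own discrete input tokens.

```python
# Illustrative sketch; assumes trust_remote_code is needed for the custom architecture.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "TuKoResearch/AuriStreamParallel-base",
    dtype="auto",
    trust_remote_code=True,
)
model.eval()

# Dummy token ids drawn uniformly from the vocabulary; the sequence length
# should be (or will be truncated to) a multiple of the model's group_size.
input_ids = torch.randint(0, model.config.vocab_size, (1, 256))

with torch.no_grad():
    out = model(input_ids=input_ids)

print(out.logits.shape)  # (1, 256, vocab_size): one logit vector per input position
```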
| """ | |
| AuriStream Parallel model for HuggingFace Transformers. | |
| """ | |
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import CausalLMOutput | |
| from .configuration_auristream_parallel import AuriStreamParallelConfig | |
class RMSNorm(nn.Module):
    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        out = self._norm(x.float()).type_as(x)
        return out * self.weight if self.weight is not None else out

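# RMSNorm normalizes by the root-mean-square of the last dimension,
#   y = x / sqrt(mean(x**2) + eps) * weight,
# i.e. LayerNorm without mean-centering; the `bias` argument is accepted for
# interface symmetry but no bias parameter is created or applied.
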
class Rotary(nn.Module):
    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        seq_len = x.shape[1]
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq).to(x.device)
        return freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], dim=3)

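# Shape sketch for the two helpers above (illustrative dimensions, not the
# released configuration):
#   rot = Rotary(dim=64)                    # head_dim = 64 -> 32 inverse frequencies
#   q = torch.randn(2, 16, 8, 64)           # (batch, time, heads, head_dim)
#   cos, sin = rot(q)                       # each (1, 16, 1, 32), broadcast over batch and heads
#   q_rot = apply_rotary_emb(q, cos, sin)   # same shape as q, each half-dimension pair rotated
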
class BidirectionalSelfAttention(nn.Module):
    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.attn_dropout = nn.Dropout(config.dropout)  # defined but not applied in the SDPA call below
        self.rotary = None
        if getattr(config, "use_rope", True):
            rope_theta = getattr(config, "rope_theta", 10000.0) or 10000.0
            self.rotary = Rotary(self.head_dim, base=rope_theta)

    def forward(self, x):
        bsz, tsz, channels = x.size()
        # Project to queries, keys, and values in one matmul, then split per head.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        q = q.view(bsz, tsz, self.n_head, self.head_dim)
        k = k.view(bsz, tsz, self.n_head, self.head_dim)
        v = v.view(bsz, tsz, self.n_head, self.head_dim)
        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)
        # Fully bidirectional attention: every position attends to every other position.
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            is_causal=False,
        )
        y = y.transpose(1, 2).contiguous().view(bsz, tsz, channels)
        return self.c_proj(y)

class MLP(nn.Module):
    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.act = nn.SiLU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        return self.dropout(x)

class Block(nn.Module):
    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.attn = BidirectionalSelfAttention(config)
        self.mlp = MLP(config)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x):
        # Pre-norm residual block: normalize, transform, then add back the input.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class AuriStreamPreTrainedModel(PreTrainedModel):
    config_class = AuriStreamParallelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

class AuriStreamModel(AuriStreamPreTrainedModel):
    """HF-compatible AuriStream Parallel model."""

    config_class = AuriStreamParallelConfig

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__(config)
        self.config = config
        self.group_size = int(getattr(config, "group_size", 4))
        grouped_seq_len = max(1, config.seq_len // self.group_size)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = None
        if not getattr(config, "use_rope", True):
            # Learned absolute positions over grouped positions, used only when RoPE is disabled.
            self.wpe = nn.Embedding(grouped_seq_len, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = RMSNorm(config.n_embd, bias=config.bias)
        # Fuse each group of `group_size` token embeddings into a single transformer position.
        self.group_in_proj = nn.Linear(self.group_size * config.n_embd, config.n_embd, bias=False)
        # One output head per slot within a group; together they predict the whole group in parallel.
        self.parallel_heads = nn.ModuleList(
            [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(self.group_size)]
        )
        self.apply(self._init_weights)
        # Scale the residual-projection init with depth (GPT-2-style).
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def _group_embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        bsz, tsz = input_ids.shape
        if tsz % self.group_size != 0:
            raise ValueError(
                f"Sequence length {tsz} must be divisible by group_size={self.group_size}"
            )
        tok_emb = self.wte(input_ids)
        # (B, T, C) -> (B, T/G, G*C) -> (B, T/G, C): concatenate each group's embeddings and project down.
        grouped = tok_emb.view(bsz, tsz // self.group_size, self.group_size, self.config.n_embd)
        grouped = grouped.reshape(bsz, tsz // self.group_size, self.group_size * self.config.n_embd)
        x = self.group_in_proj(grouped)
        if self.wpe is not None:
            pos = torch.arange(x.size(1), device=input_ids.device)
            x = x + self.wpe(pos)
        return self.drop(x)

    def _decode_parallel_logits(self, x: torch.Tensor) -> torch.Tensor:
        per_head = [head(x) for head in self.parallel_heads]
        logits = torch.stack(per_head, dim=2)  # (B, T_g, G, V)
        bsz, tg, gsz, vsz = logits.shape
        return logits.reshape(bsz, tg * gsz, vsz)

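    # Shape walk-through (illustrative numbers, not the released configuration):
    # with group_size G = 4, a 512-token input is fused by _group_embed into
    # T_g = 128 transformer positions; each of the G parallel heads predicts one
    # slot of its group, so the stacked (B, 128, 4, V) logits flatten back to
    # (B, 512, V), one logit vector per original input token.
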
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        # Accept `seq`/`tgt` as aliases for `input_ids`/`labels`.
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt
        if input_ids is None:
            raise ValueError("input_ids (or seq) must be provided")
        # Truncate to the largest length divisible by group_size.
        usable_len = (input_ids.shape[1] // self.group_size) * self.group_size
        if usable_len <= 0:
            raise ValueError(
                f"Input sequence length {input_ids.shape[1]} is too short for group_size={self.group_size}"
            )
        if usable_len != input_ids.shape[1]:
            input_ids = input_ids[:, :usable_len]
            if labels is not None:
                labels = labels[:, :usable_len]
        x = self._group_embed(input_ids)
        all_hidden_states = ()
        if output_hidden_states:
            all_hidden_states = (x,)
        for block in self.h:
            x = block(x)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (x,)
        x = self.ln_f(x)
        logits = self._decode_parallel_logits(x)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=getattr(self.config, "ignore_index", -100),
            )
        if not return_dict:
            out = (logits,)
            if output_hidden_states:
                out = out + (all_hidden_states,)
            return ((loss,) + out) if loss is not None else out
        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=None,
        )
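As a usage note, the forward pass accepts labels for a cross-entropy loss and can return the per-layer grouped hidden states. The sketch below exercises the classes above directly; it is a minimal illustration, and the config field values shown are placeholders, not the released checkpoint's settings.

```python
# Minimal sketch with assumed placeholder config values (not the released settings).
import torch

config = AuriStreamParallelConfig(
    vocab_size=1024, n_embd=256, n_head=4, n_layer=2,
    seq_len=512, group_size=4, dropout=0.0, bias=False,
)
model = AuriStreamModel(config).eval()

ids = torch.randint(0, config.vocab_size, (1, 64))  # length divisible by group_size
out = model(input_ids=ids, labels=ids, output_hidden_states=True)

print(out.loss)                # scalar cross-entropy over all 64 positions
print(out.logits.shape)        # (1, 64, 1024)
print(len(out.hidden_states))  # n_layer + 1 grouped hidden states, each (1, 16, 256)
```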