Dataset: `stanfordnlp/snli`
Lightweight NLI classification heads for ModernBERT-large, published as `jmccardle/modernbert-nli-heads`. The encoder stays frozen, so its embeddings remain compatible with the base model. The heads are only 2.3MB of weights; the base model is pulled from Hugging Face automatically.
```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

# Download the task heads (2.3MB)
weights_path = hf_hub_download("jmccardle/modernbert-nli-heads", "task_heads.pt")

# Load the frozen base model and its tokenizer from Hugging Face
encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-large")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")

# Build the task heads (layer shapes must match the saved weights)
nli_hidden = nn.Sequential(
    nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1)
)
nli_output = nn.Linear(512, 3)
abstention_head = nn.Sequential(
    nn.Linear(515, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(0.1), nn.Linear(128, 2)
)

# Load the weights; each head's tensors are stored under its own key prefix
task_heads = torch.load(weights_path, map_location="cpu")

def strip_prefix(state, prefix):
    """Keep the entries whose keys start with `prefix`, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

nli_hidden.load_state_dict(strip_prefix(task_heads, "nli_hidden."))
nli_output.load_state_dict(strip_prefix(task_heads, "nli_output."))
abstention_head.load_state_dict(strip_prefix(task_heads, "abstention_head."))

# Inference mode: disables dropout in the heads
for module in (encoder, nli_hidden, nli_output, abstention_head):
    module.eval()
```
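The manual loading code stops once the weights are in place. A minimal sketch of the forward pass, following this card's architecture diagram ([CLS] vector into the NLI hidden layer, then the output layer; hidden state and logits concatenated into the 515-dim abstention input), might look like the following. The helper names `classify_from_cls` and `pair_cls_embedding` are hypothetical, and encoding premise and hypothesis as a single tokenizer text pair is an assumption; `load_model.py` may prepare inputs differently.

```python
import torch
import torch.nn as nn

LABELS = ["entailment", "neutral", "contradiction"]  # label order given in this card

@torch.no_grad()
def classify_from_cls(cls_vec, nli_hidden, nli_output, abstention_head):
    """Run the task heads on a (batch, 1024) [CLS] tensor.

    Returns (probs, abstain): softmax probabilities over LABELS with shape
    (batch, 3), and a boolean (batch,) mask that is True where the head abstains.
    """
    hidden = nli_hidden(cls_vec)                   # (batch, 512)
    logits = nli_output(hidden)                    # (batch, 3)
    # Abstention input: hidden state concatenated with the NLI logits (512 + 3 = 515)
    abstain_logits = abstention_head(torch.cat([hidden, logits], dim=-1))  # (batch, 2)
    probs = logits.softmax(dim=-1)
    abstain = abstain_logits.argmax(dim=-1) == 1   # index 1 = "uncertain"
    return probs, abstain

def pair_cls_embedding(encoder, tokenizer, premise, hypothesis):
    """Hypothetical helper: encode one premise/hypothesis pair, return its [CLS] vector."""
    inputs = tokenizer(premise, hypothesis, return_tensors="pt")
    with torch.no_grad():
        out = encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    return out.last_hidden_state[:, 0]             # position 0 is [CLS]; shape (1, 1024)
```

Usage with the objects built above: `probs, abstain = classify_from_cls(pair_cls_embedding(encoder, tokenizer, premise, hypothesis), nli_hidden, nli_output, abstention_head)`, then read the label as `LABELS[probs.argmax(dim=-1).item()]`.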
Or use the provided `load_model.py` for a cleaner interface:

```python
# load_model.py ships in this repo; copy it into your project
from load_model import load_modernbert_nli, predict_with_abstention

model, tokenizer = load_modernbert_nli("task_heads.pt")  # local path to the downloaded heads
result = predict_with_abstention(
    model, tokenizer,
    premise="A man is playing guitar on stage.",
    hypothesis="A person is making music.",
)
# {'label': 'entailment', 'confidence': 0.788, 'abstain': False, 'uncertainty': 0.32}
```
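The model returned by `load_modernbert_nli` takes a `mode` argument that selects one of four outputs:

```python
import torch

# Tokenize the input; here a premise/hypothesis pair (assumed convention for the NLI
# modes; for the embedding modes a single text is typical)
inputs = tokenizer("A man is playing guitar on stage.", "A person is making music.", return_tensors="pt")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

with torch.no_grad():
    # 1. Bi-encoder embeddings (semantic search)
    embeddings = model(input_ids, attention_mask, mode="embed")  # (batch, 1024)
    # 2. Late interaction (ColBERT-style reranking)
    token_reps = model(input_ids, attention_mask, mode="late_interaction")  # (batch, seq_len, 1024)
    # 3. NLI classification; labels: 0=entailment, 1=neutral, 2=contradiction
    logits = model(input_ids, attention_mask, mode="nli")  # (batch, 3)
    # 4. NLI with abstention
    nli_logits, abstention_logits = model(input_ids, attention_mask, mode="abstention")
    should_abstain = abstention_logits.argmax(dim=-1) == 1  # index 1 = "uncertain"
```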
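The two embedding modes are consumed differently. Bi-encoder search compares one vector per text with cosine similarity, while ColBERT-style late interaction scores a query against a document with MaxSim: each query token takes its best match among the document tokens, and those maxima are summed. A plain-Python sketch, assuming the outputs have been converted to lists of floats (e.g. with `.tolist()`) and that padding positions (`attention_mask == 0`) have already been dropped from the `late_interaction` output; `cosine` and `maxsim` are hypothetical helper names.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors (e.g. mode="embed" outputs)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def maxsim(query_tokens, doc_tokens):
    """ColBERT-style late-interaction score.

    For each query token vector, take its highest cosine similarity among the
    document token vectors, then sum those maxima over all query tokens.
    """
    return sum(max(cosine(q, d) for d in doc_tokens) for q in query_tokens)
```

To rerank candidates, compute `maxsim(query_tokens, doc_tokens)` for each document and sort in descending order. Because the score is a sum over query tokens, it is comparable across documents for the same query but not across queries of different lengths.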
NLI head:

| Metric | Value |
|---|---|
| Training Accuracy | 70.8% |
| Validation Accuracy | ~75-80% |
| Parameters | 527K |

Note: Because the encoder is frozen, accuracy tops out below what full fine-tuning reaches (~90%), but the encoder's embeddings stay compatible with the base model.
Abstention head:

| Metric | Value |
|---|---|
| Accuracy | 65.5% |
| Precision | 44.6% |
| Recall | 76.6% |
| F1 | 56.3% |
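The abstention numbers read most naturally as a binary detector of NLI errors. A minimal sketch of how such metrics are computed, assuming the positive class is "the head abstained on an example the NLI head got wrong"; the card does not state its exact evaluation protocol, and `abstention_metrics` is a hypothetical helper.

```python
def abstention_metrics(abstained, was_wrong):
    """Score an abstention signal against NLI mistakes.

    abstained[i] is True if the head abstained on example i; was_wrong[i] is True
    if the NLI prediction for example i was incorrect (the positive class).
    """
    pairs = list(zip(abstained, was_wrong))
    tp = sum(1 for a, w in pairs if a and w)          # abstained on an error
    fp = sum(1 for a, w in pairs if a and not w)      # abstained on a correct prediction
    fn = sum(1 for a, w in pairs if not a and w)      # missed an error
    tn = sum(1 for a, w in pairs if not a and not w)  # correctly stayed confident
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```

As a consistency check, plugging the table's precision and recall into F1 = 2PR/(P+R) gives about 56.4%, which matches the reported 56.3% up to rounding of P and R.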
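What this means in practice: treating abstention on a misclassified example as the positive class, a recall of 76.6% means the head flags roughly three out of four NLI errors, while a precision of 44.6% means somewhat fewer than half of its abstentions fall on actual errors.

The abstention head outperforms simple confidence thresholding because it uses semantic features from the hidden state, not just the logit entropy. In testing, it caught 5 errors that a 50% confidence threshold would have missed.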
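The confidence-thresholding baseline the head is compared against can be sketched in plain Python. This assumes the baseline takes the softmax of the three NLI logits and abstains when the top probability falls below 0.5 (the threshold quoted above); `softmax`, `entropy`, and `threshold_abstain` are hypothetical helper names. The entropy helper is included because the comparison is framed in terms of logit entropy.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter, less confident distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def threshold_abstain(logits, threshold=0.5):
    """Baseline: abstain when the top softmax probability is below `threshold`."""
    return max(softmax(logits)) < threshold
```

This baseline sees only three numbers per example, whereas the learned head also receives the 512-dimensional hidden state, which is what the card credits for its advantage.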
```python
# Zero-shot routing: score the query against one hypothesis per category
categories = {
    "code": "This is a programming-related request",
    "factual": "This is a request for factual information",
    "creative": "This is a request for creative content",
}

def route_query(query):
    results = []
    for name, hypothesis in categories.items():
        # The query is the premise; each category description is a hypothesis
        result = predict_with_abstention(model, tokenizer, query, hypothesis)
        results.append((name, result))

    # Pick the highest entailment score, respecting abstention.
    # Relies on a per-label "probs" dict in the result (not shown in the sample output above).
    confident_results = [(n, r) for n, r in results if not r["abstain"]]
    if confident_results:
        return max(confident_results, key=lambda x: x[1]["probs"]["entailment"])
    else:
        return None, "uncertain"  # All categories abstained
```
```python
def validate_fact(source: str, claim: str) -> dict:
    # The source text is the premise; the claim is the hypothesis
    result = predict_with_abstention(model, tokenizer, source, claim)
    return {
        "supported": result["label"] == "entailment",
        "contradicted": result["label"] == "contradiction",
        "uncertain": result["abstain"],
        "confidence": result["confidence"],
    }
```
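Because `validate_fact` reports `uncertain` independently of the label, a result can be both "supported" and "uncertain" at once. A small, hypothetical helper `fact_verdict` that collapses the dict into a single verdict, letting abstention take precedence, might look like this; the verdict strings are illustrative choices, not part of the card's API.

```python
def fact_verdict(check: dict) -> str:
    """Collapse a validate_fact() result into one label; abstention wins ties."""
    if check["uncertain"]:
        return "needs review"            # the abstention head flagged this pair
    if check["supported"]:
        return "supported"               # NLI label was "entailment"
    if check["contradicted"]:
        return "contradicted"            # NLI label was "contradiction"
    return "not enough information"      # NLI label was "neutral"
```

Checking abstention first means a confident-looking "entailment" on a flagged pair is routed to review rather than accepted.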
```
ModernBERT-large (394.8M params, frozen)
        ↓
[CLS] token (1024 dim)
        ↓
┌─────────────────────────────────┐
│ NLI Hidden (525K params)        │
│ Linear(1024→512) + LN + GELU    │
└─────────────────────────────────┘
        │
        ├── NLI Output (1.5K params)
        │   Linear(512→3) → [ent, neu, con]
        │
        └── Abstention Head (67K params)
            Concat([hidden, logits]) → 515 dim
            Linear(515→128) + LN + GELU
            Linear(128→2) → [confident, uncertain]
```
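The parameter counts in the diagram follow from the layer shapes: a `Linear(n_in→n_out)` layer has n_in·n_out weights plus n_out biases, a LayerNorm over d features has 2d parameters (scale and shift), and GELU and Dropout have none. A quick check in plain Python (variable names are illustrative; the size estimate assumes float32 storage):

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector

def layernorm_params(dim):
    return 2 * dim  # learnable scale + shift

n_nli_hidden = linear_params(1024, 512) + layernorm_params(512)
n_nli_output = linear_params(512, 3)
n_abstention = (
    linear_params(512 + 3, 128)  # input: hidden state concatenated with the 3 logits
    + layernorm_params(128)
    + linear_params(128, 2)
)
n_total = n_nli_hidden + n_nli_output + n_abstention
size_mib = n_total * 4 / 2**20  # 4 bytes per float32 parameter

print(n_nli_hidden, n_nli_output, n_abstention, n_total)
print(f"{size_mib:.2f} MiB")
```

The hidden and output layers together come to 527,363 parameters, matching the 527K in the NLI table; all three heads total 593,925 parameters, about 2.27 MiB in float32, in line with the 2.3MB `task_heads.pt`.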
Files in this repository:

- `task_heads.pt` (2.3MB): PyTorch state dict with all task head weights
- `config.json`: model configuration and training metadata
- `load_model.py`: standalone loader script (copy into your project)

```bibtex
@misc{modernbert-nli-abstention,
  title={ModernBERT-NLI with Learned Abstention},
  author={[Your Name]},
  year={2024},
  url={https://huggingface.co/jmccardle/modernbert-nli-heads}
}
```
Base model: `answerdotai/ModernBERT-large`