WIP: Fix HF inference to match native MolmoWeb outputs exactly
Differences and Fixes
- Chat template format
  - Native: `User: {message} Assistant:` — see https://github.com/allenai/molmo2/blob/main/olmo/preprocessing/data_formatter.py (message_format == "role" branch)
  - HF (before): Qwen format with `im_start`/`im_end` tags
  - Fix: Rewrote chat_template.jinja to emit the native format
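For reference, a minimal Python sketch of the prompt layout the rewritten template targets. This is illustrative only, not the actual jinja: the `render_native` helper and the single-space joining between turns are assumptions here.

```python
# Hypothetical sketch of the native "role" prompt layout
# (helper name and spacing are illustrative assumptions).
def render_native(messages):
    """Render messages as 'User: {text} Assistant: {text}' turns."""
    parts = []
    for m in messages:
        role = "User" if m["role"] == "user" else "Assistant"
        parts.append(f"{role}: {m['content']}")
    # A generation prompt ends with a bare 'Assistant:' for the model to continue.
    return " ".join(parts) + " Assistant:"

prompt = render_native([{"role": "user", "content": "Describe the page."}])
```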
- Pixel normalization (CPU vs GPU float32 divergence)
  - Native: Preprocessor outputs uint8 when normalize_on_gpu=True; model normalizes on GPU — see https://github.com/allenai/molmo2/blob/main/olmo/preprocessing/image_preprocessor.py and https://github.com/allenai/molmo2/blob/main/olmo/nn/vision_backbone.py
  - HF (before): Processor normalizes on CPU (float32 / 255)
  - Fix: Processor outputs uint8 (do_normalize=False), model normalizes on GPU
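A hedged sketch of the change: keep pixels as uint8 through the processor and defer the float32 division to the model's device. The tensor shape and plain `/ 255` scaling are illustrative (the real preprocessor may also apply mean/std normalization); it is shown on CPU so the sketch runs anywhere, but on CUDA the float32 division can differ from the CPU result in the last bits, which is the divergence the fix removes.

```python
import torch

# Illustrative pixel tensor; the real processor output shape differs.
pixels_u8 = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)

# Before the fix: the processor divided on CPU in float32.
cpu_normed = pixels_u8.to(torch.float32) / 255.0

# After the fix: ship uint8 to the model's device, then normalize there.
device = torch.device("cpu")  # would be "cuda" in real inference
gpu_normed = pixels_u8.to(device).to(torch.float32) / 255.0
```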
- RoPE inv_freq computation device
  - Native: Computed on GPU at forward time — see https://github.com/allenai/molmo2/blob/main/olmo/nn/llm.py compute_rope_parameters()
  - HF (before): Computed on CPU at init, buffer moved to GPU
  - Fix: Recompute inv_freq on target device in forward()
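A minimal sketch of the fix: build the inverse frequencies directly on the target device at forward time instead of registering a CPU buffer at init and moving it. The `dim`/`base` values and function name here are illustrative, not MolmoWeb's actual config or API.

```python
import torch

# Illustrative recomputation of RoPE inverse frequencies on a target device
# (function name and dim/base values are assumptions, not the real config).
def compute_inv_freq(dim: int, base: float, device) -> torch.Tensor:
    exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
    return 1.0 / (base ** exponents)

# Called inside forward() with the device of the hidden states,
# so the arithmetic happens where the model runs.
inv_freq = compute_inv_freq(dim=64, base=10000.0, device="cpu")
```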
- SDPA attention mask vs is_causal
  - Native: attn_mask=None, is_causal=True during prefill — see https://github.com/allenai/molmo2/blob/main/olmo/nn/llm.py attention() method
  - HF (before): Explicit 4D boolean causal mask passed to SDPA
  - Fix: Drop mask during prefill so SDPA uses is_causal=True
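The two prefill paths can be sketched as below (shapes are illustrative). They are mathematically equivalent, but an explicit mask and `is_causal=True` can dispatch to different SDPA kernels, which is where small numeric differences can creep in.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq, head_dim), illustrative
k, v = torch.randn_like(q), torch.randn_like(q)

# Before: explicit 4D boolean causal mask (True = attend).
mask = torch.tril(torch.ones(5, 5, dtype=torch.bool)).view(1, 1, 5, 5)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# After: no mask during prefill; the kernel applies causality itself.
out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```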
- GQA key/value expansion method
  - Native: k.repeat_interleave(n, dim=1) before SDPA — see https://github.com/allenai/molmo2/blob/main/olmo/nn/llm.py _scaled_dot_product_attention()
  - HF (before): Passes enable_gqa=True to SDPA kernel
  - Fix: Pre-expand KV heads via repeat_interleave, bypass enable_gqa
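A sketch of the pre-expansion path (head counts and shapes are illustrative): each KV head is repeated `n_heads // n_kv_heads` times along the head dimension before SDPA, so no `enable_gqa` flag is needed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_heads, n_kv_heads, seq, head_dim = 4, 2, 5, 8  # illustrative sizes
q = torch.randn(1, n_heads, seq, head_dim)
k = torch.randn(1, n_kv_heads, seq, head_dim)
v = torch.randn(1, n_kv_heads, seq, head_dim)

# Expand each KV head consecutively so query head groups share it.
rep = n_heads // n_kv_heads
k_exp = k.repeat_interleave(rep, dim=1)
v_exp = v.repeat_interleave(rep, dim=1)

out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
```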
Recommended usage
```python
model = AutoModelForImageTextToText.from_pretrained(
    "allenai/MolmoWeb-8B",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="auto",
)
```
TBD: Before and after metrics on mind2web.
Thanks for your suggestions! We apologize for the differences between native and HF inference, which were caused by an outdated HF conversion script. We've made the following crucial changes (some based on your suggestions) and verified that the current HF checkpoint's outputs match the native outputs.
- updated the chat template (aligned with your point 1, see PR: https://huggingface.co/allenai/MolmoWeb-8B/discussions/3)
- aligned the image processing script with the latest molmo2 code base to ensure HF and native image processing produce the same image tokens (see PR: https://huggingface.co/allenai/MolmoWeb-8B/discussions/3)
- removed `token_type_ids` in HF inference: `inputs = {k: v.to("cuda") for k, v in inputs.items() if k != "token_type_ids"}` (see updated README.md)
  - HF uses it to enable bidirectional attention for image tokens; MolmoWeb is trained with causal attention only
- updated default model loading to use `torch_dtype=torch.float32` and `attn_implementation="sdpa"`
  - note that `torch_dtype="auto"` should map to `torch.float32`, and even with `attn_implementation="sdpa"`, HF's SDPA attention implementation is not exactly the same as olmo's SDPA implementation in molmo2. However, we agree that it's best practice to set these two flags explicitly.
Regarding your points 2, 3, and 5: we think these don't necessarily lead to large differences in model outputs, but they are good optimizations. We'll consider merging them if your mind2web test goes well. Looking forward to hearing from you!
Thanks again.
@zixianma02 , thanks for the quick turnaround! I'll do some careful ablations this week so we can figure out the minimal set of differences.
@zixianma02 , I ran a small ablation and found the importance of these changes as below (assuming fixed chat template):
| Config | top1 | abs_max | abs_mean | cos_sim |
|----------------------------|------|----------|------------|--------------|
| all_PR_fixes | Y | 0.000000 | 0.00000000 | 0.9999949932 |
| remove_pixel_norm | Y | 0.000072 | 0.00001029 | 0.9999949932 |
| remove_rope_device | Y | 0.000072 | 0.00001116 | 0.9999949932 |
| remove_sdpa_mask | Y | 2.945995 | 0.72856593 | 0.9999384284 |
| remove_gqa_expand | Y | 0.000065 | 0.00001033 | 0.9999949932 |
| parent_only (current main) | Y | 2.945995 | 0.72855932 | 0.9999384284 |
The differences do indeed look minimal (but real) for the other changes; I'll still run the mind2web benchmark in case you're interested in merging them.