WIP: Fix HF inference to match native MolmoWeb outputs exactly

#2
by PTeterwak - opened

Differences and Fixes

  1. Chat template format
  2. Pixel normalization (CPU vs GPU float32 divergence)
  3. RoPE inv_freq computation device
  4. SDPA attention mask vs is_causal
  5. GQA key/value expansion method
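As an illustration of point 5: the two common ways of expanding grouped key/value heads to the full query-head count are mathematically equivalent, but they can lay tensors out differently and dispatch to different downstream kernels. A minimal sketch (shapes are illustrative, not MolmoWeb's actual config):

```python
import torch

# Hypothetical GQA layer: 2 KV heads shared across 8 query heads.
batch, n_kv_heads, seq, head_dim = 1, 2, 4, 8
n_rep = 4  # query heads per KV head
kv = torch.randn(batch, n_kv_heads, seq, head_dim)

# Method A: repeat each KV head along the head dimension.
a = kv.repeat_interleave(n_rep, dim=1)

# Method B (HF-style): insert an axis, expand, then reshape.
b = kv[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq, head_dim)
b = b.reshape(batch, n_kv_heads * n_rep, seq, head_dim)

# Both produce the same heads in the same order.
print(torch.equal(a, b))
```

Even though the expanded tensors are identical, the surrounding ops (contiguity, matmul order) can still produce tiny floating-point drift, which is what the ablation below measures.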

Recommended usage

```python
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "allenai/MolmoWeb-8B",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="auto",
)
```

TBD: Before and after metrics on mind2web.

PTeterwak changed pull request title from Fix HF inference to match native OLMo outputs exactly to WIP: Fix HF inference to match native OLMo outputs exactly


PTeterwak changed pull request status to closed
PTeterwak changed pull request status to open
PTeterwak changed pull request title from WIP: Fix HF inference to match native OLMo outputs exactly to WIP: Fix HF inference to match native MolmoWeb outputs exactly

Thanks for your suggestions! We apologize for the discrepancies between native and HF inference, which were caused by an outdated HF conversion script. We've made the following crucial changes (some based on your suggestions) and verified that the current HF checkpoint's outputs match the native outputs.

  1. Updated the chat template (aligned with your point 1; see PR: https://huggingface.co/allenai/MolmoWeb-8B/discussions/3)
  2. Aligned the image processing script with the latest molmo2 code base so that HF and native image processing produce the same image tokens (see PR: https://huggingface.co/allenai/MolmoWeb-8B/discussions/3)
  3. Removed `token_type_ids` in HF inference: `inputs = {k: v.to("cuda") for k, v in inputs.items() if k != "token_type_ids"}` (see the updated README.md)
  • HF uses it to enable bidirectional attention over image tokens; MolmoWeb is trained with causal attention only
  4. Updated default model loading to use `torch_dtype=torch.float32` and `attn_implementation="sdpa"`
  • Note that `torch_dtype="auto"` should already map to `torch.float32`, and even with `attn_implementation="sdpa"`, HF's SDPA attention implementation is not exactly the same as OLMo's SDPA implementation in molmo2. However, we agree that it's best practice to set these two flags explicitly.

Regarding your points 2, 3, and 5: we think these don't necessarily lead to large differences in model outputs, but they are good optimizations. We'll consider merging them if your mind2web test goes well. Looking forward to hearing from you!

Thanks again.

@zixianma02, thanks for the quick turnaround! I'll do some careful ablations this week so we can figure out the minimal set of differences.

@zixianma02, I ran a small ablation and measured the impact of each change as follows (assuming a fixed chat template):

| Config                     | top1 | abs_max  | abs_mean   | cos_sim      |
|----------------------------|------|----------|------------|--------------|
| all_PR_fixes               | Y    | 0.000000 | 0.00000000 | 0.9999949932 |
| remove_pixel_norm          | Y    | 0.000072 | 0.00001029 | 0.9999949932 |
| remove_rope_device         | Y    | 0.000072 | 0.00001116 | 0.9999949932 |
| remove_sdpa_mask           | Y    | 2.945995 | 0.72856593 | 0.9999384284 |
| remove_gqa_expand          | Y    | 0.000065 | 0.00001033 | 0.9999949932 |
| parent_only (current main) | Y    | 2.945995 | 0.72855932 | 0.9999384284 |

For the other changes the difference does indeed look minimal (but real). I'll still run the mind2web benchmark in case you're interested in merging them.
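For reference, divergence metrics like those in the table can be computed from two logit tensors roughly as below. This is a sketch: `native` and `hf` are synthetic stand-ins, not the actual model outputs, and `logit_divergence` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def logit_divergence(a: torch.Tensor, b: torch.Tensor) -> dict:
    """Compare two logit tensors of shape (batch, seq, vocab)."""
    diff = (a - b).abs()
    cos = F.cosine_similarity(a.flatten(), b.flatten(), dim=0)
    return {
        "top1": bool((a.argmax(-1) == b.argmax(-1)).all()),  # same greedy token everywhere
        "abs_max": diff.max().item(),
        "abs_mean": diff.mean().item(),
        "cos_sim": cos.item(),
    }

# Synthetic example: perturb one tensor slightly to mimic a small implementation gap.
torch.manual_seed(0)
native = torch.randn(1, 10, 32000)
hf = native + 1e-5 * torch.randn_like(native)
print(logit_divergence(native, hf))
```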

Cannot merge
This branch has merge conflicts in the following files:
  • README.md
  • chat_template.jinja
