PTeterwak committed on
Commit
7c1460c
·
verified ·
1 Parent(s): a6bbf87

Fix HF inference to match native OLMo outputs exactly

Browse files

Five fixes for bitwise parity: chat template format, GPU pixel normalization, RoPE on-device recomputation, SDPA causal mask, and GQA repeat_interleave.

chat_template.jinja CHANGED
@@ -1 +1 @@
1
- {% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% 
continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}{% if not (has_subtitle and loop.index == 2) and not (not has_subtitle and loop.first) %}{{ '<|im_end|>\n' }}{% endif %}{{ '<|im_start|>user\n' }}{{ text_content }}{{ '<|im_end|>\n' }}{% else %} {# assistant #}{{ '<|im_start|>assistant\n' }}{{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
 
1
+ {% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% 
continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}User: {{ text_content }}{% else %} {{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %} Assistant:{% endif %}
image_processing_molmo2.py CHANGED
@@ -29,7 +29,10 @@ def normalize_image(
29
  image: np.ndarray,
30
  image_mean: list[float],
31
  image_std: list[float],
 
32
  ) -> np.ndarray:
 
 
33
  image -= np.array(image_mean, dtype=np.float32)[None, None, :]
34
  image /= np.array(image_std, dtype=np.float32)[None, None, :]
35
  return image
@@ -110,11 +113,12 @@ def build_resized_image(
110
  image_mean: list[float],
111
  image_std: list[float],
112
  image_patch_size: int,
 
113
  ) -> tuple[np.ndarray, np.ndarray]:
114
  resized = resize_image(
115
  image, base_image_input_size, resample,
116
  )
117
- resized = normalize_image(resized, image_mean, image_std)
118
  if len(resized.shape) == 3:
119
  resized = np.expand_dims(resized, 0)
120
  crop_patch_w = base_image_input_size[1] // image_patch_size
@@ -132,6 +136,7 @@ def build_overlapping_crops(
132
  image_mean: list[float],
133
  image_std: list[float],
134
  image_patch_size: int,
 
135
  ) -> tuple[np.ndarray, np.ndarray]:
136
  """Decompose an image into a set of overlapping crops
137
 
@@ -167,7 +172,7 @@ def build_overlapping_crops(
167
  [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
168
  resample,
169
  )
170
- src = normalize_image(src, image_mean, image_std)
171
 
172
  # Now we have to split the image into crops, and track what patches came from
173
  # where in `patch_idx_arr`
@@ -259,6 +264,7 @@ def image_to_patches_and_grids(
259
  image_patch_size: int,
260
  image_pooling_w: int,
261
  image_pooling_h: int,
 
262
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
263
  """
264
  :return image_grids, the shape of each (low-res, high-res) image after pooling
@@ -284,6 +290,7 @@ def image_to_patches_and_grids(
284
  image_mean,
285
  image_std,
286
  image_patch_size,
 
287
  )
288
  pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
289
  h, w = pooling_idx.shape[:2]
@@ -297,6 +304,7 @@ def image_to_patches_and_grids(
297
  image_mean,
298
  image_std,
299
  image_patch_size,
 
300
  )
301
  crop_arr = np.concatenate([resized, crop_arr], 0)
302
 
@@ -390,6 +398,7 @@ class Molmo2ImageProcessor(BaseImageProcessor):
390
  image_mean: Optional[Union[float, list[float]]] = None,
391
  image_std: Optional[Union[float, list[float]]] = None,
392
  do_convert_rgb: Optional[bool] = None,
 
393
  max_crops: Optional[int] = None,
394
  overlap_margins: Optional[list[int]] = None,
395
  patch_size: Optional[int] = None,
@@ -448,6 +457,7 @@ class Molmo2ImageProcessor(BaseImageProcessor):
448
  image_mean = image_mean or self.image_mean
449
  image_std = image_std or self.image_std
450
  do_convert_rgb = do_convert_rgb or self.do_convert_rgb
 
451
 
452
  max_crops = max_crops or self.max_crops
453
  overlap_margins = overlap_margins or self.overlap_margins
@@ -491,6 +501,7 @@ class Molmo2ImageProcessor(BaseImageProcessor):
491
  patch_size,
492
  image_pooling_w,
493
  image_pooling_h,
 
494
  )
495
  batch_grids.append(image_grid)
496
  batch_crops.append(crops)
@@ -498,6 +509,10 @@ class Molmo2ImageProcessor(BaseImageProcessor):
498
  batch_num_crops.append(crops.shape[0])
499
 
500
  pixel_values = np.concatenate(batch_crops, 0)
 
 
 
 
501
  image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
502
  image_grids = np.concatenate(batch_grids, 0)
503
  image_num_crops = np.array(batch_num_crops)
 
29
  image: np.ndarray,
30
  image_mean: list[float],
31
  image_std: list[float],
32
+ do_normalize: bool = True,
33
  ) -> np.ndarray:
34
+ if not do_normalize:
35
+ return image
36
  image -= np.array(image_mean, dtype=np.float32)[None, None, :]
37
  image /= np.array(image_std, dtype=np.float32)[None, None, :]
38
  return image
 
113
  image_mean: list[float],
114
  image_std: list[float],
115
  image_patch_size: int,
116
+ do_normalize: bool = True,
117
  ) -> tuple[np.ndarray, np.ndarray]:
118
  resized = resize_image(
119
  image, base_image_input_size, resample,
120
  )
121
+ resized = normalize_image(resized, image_mean, image_std, do_normalize=do_normalize)
122
  if len(resized.shape) == 3:
123
  resized = np.expand_dims(resized, 0)
124
  crop_patch_w = base_image_input_size[1] // image_patch_size
 
136
  image_mean: list[float],
137
  image_std: list[float],
138
  image_patch_size: int,
139
+ do_normalize: bool = True,
140
  ) -> tuple[np.ndarray, np.ndarray]:
141
  """Decompose an image into a set of overlapping crops
142
 
 
172
  [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
173
  resample,
174
  )
175
+ src = normalize_image(src, image_mean, image_std, do_normalize=do_normalize)
176
 
177
  # Now we have to split the image into crops, and track what patches came from
178
  # where in `patch_idx_arr`
 
264
  image_patch_size: int,
265
  image_pooling_w: int,
266
  image_pooling_h: int,
267
+ do_normalize: bool = True,
268
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
269
  """
270
  :return image_grids, the shape of each (low-res, high-res) image after pooling
 
290
  image_mean,
291
  image_std,
292
  image_patch_size,
293
+ do_normalize=do_normalize,
294
  )
295
  pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
296
  h, w = pooling_idx.shape[:2]
 
304
  image_mean,
305
  image_std,
306
  image_patch_size,
307
+ do_normalize=do_normalize,
308
  )
309
  crop_arr = np.concatenate([resized, crop_arr], 0)
310
 
 
398
  image_mean: Optional[Union[float, list[float]]] = None,
399
  image_std: Optional[Union[float, list[float]]] = None,
400
  do_convert_rgb: Optional[bool] = None,
401
+ do_normalize: Optional[bool] = None,
402
  max_crops: Optional[int] = None,
403
  overlap_margins: Optional[list[int]] = None,
404
  patch_size: Optional[int] = None,
 
457
  image_mean = image_mean or self.image_mean
458
  image_std = image_std or self.image_std
459
  do_convert_rgb = do_convert_rgb or self.do_convert_rgb
460
+ do_normalize = do_normalize if do_normalize is not None else False
461
 
462
  max_crops = max_crops or self.max_crops
463
  overlap_margins = overlap_margins or self.overlap_margins
 
501
  patch_size,
502
  image_pooling_w,
503
  image_pooling_h,
504
+ do_normalize=do_normalize,
505
  )
506
  batch_grids.append(image_grid)
507
  batch_crops.append(crops)
 
509
  batch_num_crops.append(crops.shape[0])
510
 
511
  pixel_values = np.concatenate(batch_crops, 0)
512
+ if not do_normalize:
513
+ # Convert to uint8 so the model can normalize on GPU with exact
514
+ # native precision (CPU and GPU float32 /255 differ for ~half of uint8 values).
515
+ pixel_values = np.clip(pixel_values * 255 + 0.5, 0, 255).astype(np.uint8)
516
  image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
517
  image_grids = np.concatenate(batch_grids, 0)
518
  image_num_crops = np.array(batch_num_crops)
modeling_molmo2.py CHANGED
@@ -440,7 +440,11 @@ class Molmo2VisionBackbone(nn.Module):
440
 
441
  # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
442
  batch_size, num_image = images.shape[:2]
443
- images = images.to(device=self.device, dtype=self.dtype)
 
 
 
 
444
  image_features = self.encode_image(images)
445
 
446
  image_features = self.image_feature_dropout(image_features)
@@ -543,7 +547,14 @@ class Molmo2RotaryEmbedding(nn.Module):
543
 
544
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
545
  with torch.autocast(device_type=device_type, enabled=False): # Force float32
546
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 
 
 
 
 
 
 
547
  emb = torch.cat((freqs, freqs), dim=-1)
548
  cos = emb.cos() * self.attention_scaling
549
  sin = emb.sin() * self.attention_scaling
@@ -710,16 +721,36 @@ class Molmo2Attention(nn.Module):
710
  if self.config._attn_implementation != "eager":
711
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  attn_output, attn_weights = attention_interface(
714
  self,
715
  query_states,
716
  key_states,
717
  value_states,
718
- attention_mask,
719
  dropout=0.0 if not self.training else self.attention_dropout,
720
  scaling=self.scaling,
721
  **kwargs,
722
  )
 
 
723
 
724
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
725
  attn_output = self.attn_out(attn_output)
 
440
 
441
  # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
442
  batch_size, num_image = images.shape[:2]
443
+ images = images.to(device=self.device)
444
+ # Normalize pixel values on GPU: uint8 [0,255] -> float [-1,1]
445
+ # This matches native OLMo's normalize_on_gpu path exactly.
446
+ images = images.float().div_(255.0).mul_(2.0).sub_(1.0)
447
+ images = images.to(dtype=self.dtype)
448
  image_features = self.encode_image(images)
449
 
450
  image_features = self.image_feature_dropout(image_features)
 
547
 
548
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
549
  with torch.autocast(device_type=device_type, enabled=False): # Force float32
550
+ # Recompute inv_freq directly on the target device to avoid CPU/GPU
551
+ # float32 rounding differences when inv_freq is initialized on CPU.
552
+ dim = self.inv_freq.shape[0] * 2
553
+ inv_freq = 1.0 / (self.config.rope_theta ** (
554
+ torch.arange(0, dim, 2, dtype=torch.float, device=x.device) / dim
555
+ ))
556
+ seq = position_ids[0].float()
557
+ freqs = torch.einsum("i , j -> i j", seq, inv_freq).unsqueeze(0)
558
  emb = torch.cat((freqs, freqs), dim=-1)
559
  cos = emb.cos() * self.attention_scaling
560
  sin = emb.sin() * self.attention_scaling
 
721
  if self.config._attn_implementation != "eager":
722
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
723
 
724
+ # During prefill with SDPA, drop the explicit attention mask so SDPA uses
725
+ # is_causal=True internally. This matches native OLMo's behavior and avoids
726
+ # numerical differences from explicit-mask vs is_causal code paths.
727
+ sdpa_mask = attention_mask
728
+ if self.config._attn_implementation == "sdpa" and query_states.shape[2] > 1:
729
+ sdpa_mask = None
730
+
731
+ # Expand GQA key/value heads to match query heads via repeat_interleave,
732
+ # matching native OLMo's approach. This avoids the enable_gqa=True SDPA
733
+ # path which uses a different kernel and produces different float32 results.
734
+ # Temporarily set num_key_value_groups=1 so the HF SDPA wrapper doesn't
735
+ # try to handle GQA again on already-expanded tensors.
736
+ saved_groups = self.num_key_value_groups
737
+ if saved_groups > 1:
738
+ key_states = key_states.repeat_interleave(saved_groups, dim=1)
739
+ value_states = value_states.repeat_interleave(saved_groups, dim=1)
740
+ self.num_key_value_groups = 1
741
+
742
  attn_output, attn_weights = attention_interface(
743
  self,
744
  query_states,
745
  key_states,
746
  value_states,
747
+ sdpa_mask,
748
  dropout=0.0 if not self.training else self.attention_dropout,
749
  scaling=self.scaling,
750
  **kwargs,
751
  )
752
+
753
+ self.num_key_value_groups = saved_groups
754
 
755
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
756
  attn_output = self.attn_out(attn_output)