Fix HF inference to match native OLMo outputs exactly
Browse files

Five fixes for bitwise parity: chat template format, GPU pixel normalization, RoPE on-device recomputation, SDPA causal mask, and GQA repeat_interleave.
- chat_template.jinja +1 -1
- image_processing_molmo2.py +17 -2
- modeling_molmo2.py +34 -3
chat_template.jinja
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
{% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% 
continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}
|
|
|
|
| 1 |
+
{% set DEMO_STYLES = ['point_count','pointing','cosyn_point','user_qa','long_caption','short_caption','video_long_caption','video_short_caption','video_point_track_per_frame','video_point_track_start_end','video_point_track_all_frames','video_single_point_track_start_end','video_transcript','video_clip_caption_start_end','video_clip_caption_start_end_in_seconds','video_clip_transcript_start_end','video_clip_transcript_start_end_in_seconds','video_frame_caption_timestamp','video_frame_caption_timestamp_in_seconds','correction_qa','text_sft','video_point','video_point_count','video_count','video_count_point','multi_image_pointing','multi_image_counting','multi_image_point_then_count','multi_image_count_then_point','demo','a_okvqa_mc','ai2_diagram_no_letter','ai2_diagram','science_qa','multi_image_mc','multi_image_mc_exp','mantis_instruct_mc','video_multiple_choice','video_multiple_choice_count_without_pointing','video_multiple_choice_multiple_correct','video_multiple_choice_w_subtitle'] %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set has_subtitle = messages and messages[0]['role'].lower() == 'subtitle' %}{% for message in messages %}{% if message['content'] is not string %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}{% set video_count.value = video_count.value + 1 %}{% endif %}{% endfor %}{% endif %}{% endfor %}{% if image_count.value == 1 %}{{ '<|image|>' }}{% elif image_count.value > 1 %}{% for i in range(image_count.value) %}{{ 'Image ' ~ (i + 1) ~ '<|image|>' }}{% endfor %}{% endif %}{% for _ in range(video_count.value) %}{{ '<|video|>' }}{% endfor %}{% if has_subtitle %}{{ messages[0]['content'] }}{% endif %}{% for message in messages %}{% set role = message['role'].lower() %}{% if role == 'subtitle' %}{% 
continue %}{% endif %}{% set conv_index = loop.index - (1 if has_subtitle else 0) %}{%- if (conv_index % 2 == 1 and role != 'user') or (conv_index % 2 == 0 and role != 'assistant') -%}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{%- endif -%}{% if message['content'] is string %}{% set text_content = message['content'] %}{% else %}{% set m = namespace(text='') %}{% for content in message['content'] %}{% if content['type'] == 'text' %}{% if content['style'] is defined and content['style'] not in DEMO_STYLES %}{% set seg = content['style'] ~ ': ' ~ content['text'] %}{% else %}{% set seg = content['text'] %}{% endif %}{% set m.text = m.text ~ ('' if not m.text else ' ') ~ seg %}{% endif %}{% endfor %}{% set text_content = m.text %}{% endif %}{% if role == 'user' %}User: {{ text_content }}{% else %} {{ text_content }}{% endif %}{% endfor %}{% if add_generation_prompt %} Assistant:{% endif %}
|
image_processing_molmo2.py
CHANGED
|
@@ -29,7 +29,10 @@ def normalize_image(
|
|
| 29 |
image: np.ndarray,
|
| 30 |
image_mean: list[float],
|
| 31 |
image_std: list[float],
|
|
|
|
| 32 |
) -> np.ndarray:
|
|
|
|
|
|
|
| 33 |
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
| 34 |
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
| 35 |
return image
|
|
@@ -110,11 +113,12 @@ def build_resized_image(
|
|
| 110 |
image_mean: list[float],
|
| 111 |
image_std: list[float],
|
| 112 |
image_patch_size: int,
|
|
|
|
| 113 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 114 |
resized = resize_image(
|
| 115 |
image, base_image_input_size, resample,
|
| 116 |
)
|
| 117 |
-
resized = normalize_image(resized, image_mean, image_std)
|
| 118 |
if len(resized.shape) == 3:
|
| 119 |
resized = np.expand_dims(resized, 0)
|
| 120 |
crop_patch_w = base_image_input_size[1] // image_patch_size
|
|
@@ -132,6 +136,7 @@ def build_overlapping_crops(
|
|
| 132 |
image_mean: list[float],
|
| 133 |
image_std: list[float],
|
| 134 |
image_patch_size: int,
|
|
|
|
| 135 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 136 |
"""Decompose an image into a set of overlapping crops
|
| 137 |
|
|
@@ -167,7 +172,7 @@ def build_overlapping_crops(
|
|
| 167 |
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
|
| 168 |
resample,
|
| 169 |
)
|
| 170 |
-
src = normalize_image(src, image_mean, image_std)
|
| 171 |
|
| 172 |
# Now we have to split the image into crops, and track what patches came from
|
| 173 |
# where in `patch_idx_arr`
|
|
@@ -259,6 +264,7 @@ def image_to_patches_and_grids(
|
|
| 259 |
image_patch_size: int,
|
| 260 |
image_pooling_w: int,
|
| 261 |
image_pooling_h: int,
|
|
|
|
| 262 |
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 263 |
"""
|
| 264 |
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
|
@@ -284,6 +290,7 @@ def image_to_patches_and_grids(
|
|
| 284 |
image_mean,
|
| 285 |
image_std,
|
| 286 |
image_patch_size,
|
|
|
|
| 287 |
)
|
| 288 |
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
| 289 |
h, w = pooling_idx.shape[:2]
|
|
@@ -297,6 +304,7 @@ def image_to_patches_and_grids(
|
|
| 297 |
image_mean,
|
| 298 |
image_std,
|
| 299 |
image_patch_size,
|
|
|
|
| 300 |
)
|
| 301 |
crop_arr = np.concatenate([resized, crop_arr], 0)
|
| 302 |
|
|
@@ -390,6 +398,7 @@ class Molmo2ImageProcessor(BaseImageProcessor):
|
|
| 390 |
image_mean: Optional[Union[float, list[float]]] = None,
|
| 391 |
image_std: Optional[Union[float, list[float]]] = None,
|
| 392 |
do_convert_rgb: Optional[bool] = None,
|
|
|
|
| 393 |
max_crops: Optional[int] = None,
|
| 394 |
overlap_margins: Optional[list[int]] = None,
|
| 395 |
patch_size: Optional[int] = None,
|
|
@@ -448,6 +457,7 @@ class Molmo2ImageProcessor(BaseImageProcessor):
|
|
| 448 |
image_mean = image_mean or self.image_mean
|
| 449 |
image_std = image_std or self.image_std
|
| 450 |
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
|
|
|
| 451 |
|
| 452 |
max_crops = max_crops or self.max_crops
|
| 453 |
overlap_margins = overlap_margins or self.overlap_margins
|
|
@@ -491,6 +501,7 @@ class Molmo2ImageProcessor(BaseImageProcessor):
|
|
| 491 |
patch_size,
|
| 492 |
image_pooling_w,
|
| 493 |
image_pooling_h,
|
|
|
|
| 494 |
)
|
| 495 |
batch_grids.append(image_grid)
|
| 496 |
batch_crops.append(crops)
|
|
@@ -498,6 +509,10 @@ class Molmo2ImageProcessor(BaseImageProcessor):
|
|
| 498 |
batch_num_crops.append(crops.shape[0])
|
| 499 |
|
| 500 |
pixel_values = np.concatenate(batch_crops, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
| 502 |
image_grids = np.concatenate(batch_grids, 0)
|
| 503 |
image_num_crops = np.array(batch_num_crops)
|
|
|
|
| 29 |
image: np.ndarray,
|
| 30 |
image_mean: list[float],
|
| 31 |
image_std: list[float],
|
| 32 |
+
do_normalize: bool = True,
|
| 33 |
) -> np.ndarray:
|
| 34 |
+
if not do_normalize:
|
| 35 |
+
return image
|
| 36 |
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
| 37 |
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
| 38 |
return image
|
|
|
|
| 113 |
image_mean: list[float],
|
| 114 |
image_std: list[float],
|
| 115 |
image_patch_size: int,
|
| 116 |
+
do_normalize: bool = True,
|
| 117 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 118 |
resized = resize_image(
|
| 119 |
image, base_image_input_size, resample,
|
| 120 |
)
|
| 121 |
+
resized = normalize_image(resized, image_mean, image_std, do_normalize=do_normalize)
|
| 122 |
if len(resized.shape) == 3:
|
| 123 |
resized = np.expand_dims(resized, 0)
|
| 124 |
crop_patch_w = base_image_input_size[1] // image_patch_size
|
|
|
|
| 136 |
image_mean: list[float],
|
| 137 |
image_std: list[float],
|
| 138 |
image_patch_size: int,
|
| 139 |
+
do_normalize: bool = True,
|
| 140 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 141 |
"""Decompose an image into a set of overlapping crops
|
| 142 |
|
|
|
|
| 172 |
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels],
|
| 173 |
resample,
|
| 174 |
)
|
| 175 |
+
src = normalize_image(src, image_mean, image_std, do_normalize=do_normalize)
|
| 176 |
|
| 177 |
# Now we have to split the image into crops, and track what patches came from
|
| 178 |
# where in `patch_idx_arr`
|
|
|
|
| 264 |
image_patch_size: int,
|
| 265 |
image_pooling_w: int,
|
| 266 |
image_pooling_h: int,
|
| 267 |
+
do_normalize: bool = True,
|
| 268 |
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 269 |
"""
|
| 270 |
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
|
|
|
| 290 |
image_mean,
|
| 291 |
image_std,
|
| 292 |
image_patch_size,
|
| 293 |
+
do_normalize=do_normalize,
|
| 294 |
)
|
| 295 |
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
| 296 |
h, w = pooling_idx.shape[:2]
|
|
|
|
| 304 |
image_mean,
|
| 305 |
image_std,
|
| 306 |
image_patch_size,
|
| 307 |
+
do_normalize=do_normalize,
|
| 308 |
)
|
| 309 |
crop_arr = np.concatenate([resized, crop_arr], 0)
|
| 310 |
|
|
|
|
| 398 |
image_mean: Optional[Union[float, list[float]]] = None,
|
| 399 |
image_std: Optional[Union[float, list[float]]] = None,
|
| 400 |
do_convert_rgb: Optional[bool] = None,
|
| 401 |
+
do_normalize: Optional[bool] = None,
|
| 402 |
max_crops: Optional[int] = None,
|
| 403 |
overlap_margins: Optional[list[int]] = None,
|
| 404 |
patch_size: Optional[int] = None,
|
|
|
|
| 457 |
image_mean = image_mean or self.image_mean
|
| 458 |
image_std = image_std or self.image_std
|
| 459 |
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
| 460 |
+
do_normalize = do_normalize if do_normalize is not None else False
|
| 461 |
|
| 462 |
max_crops = max_crops or self.max_crops
|
| 463 |
overlap_margins = overlap_margins or self.overlap_margins
|
|
|
|
| 501 |
patch_size,
|
| 502 |
image_pooling_w,
|
| 503 |
image_pooling_h,
|
| 504 |
+
do_normalize=do_normalize,
|
| 505 |
)
|
| 506 |
batch_grids.append(image_grid)
|
| 507 |
batch_crops.append(crops)
|
|
|
|
| 509 |
batch_num_crops.append(crops.shape[0])
|
| 510 |
|
| 511 |
pixel_values = np.concatenate(batch_crops, 0)
|
| 512 |
+
if not do_normalize:
|
| 513 |
+
# Convert to uint8 so the model can normalize on GPU with exact
|
| 514 |
+
# native precision (CPU and GPU float32 /255 differ for ~half of uint8 values).
|
| 515 |
+
pixel_values = np.clip(pixel_values * 255 + 0.5, 0, 255).astype(np.uint8)
|
| 516 |
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
| 517 |
image_grids = np.concatenate(batch_grids, 0)
|
| 518 |
image_num_crops = np.array(batch_num_crops)
|
modeling_molmo2.py
CHANGED
|
@@ -440,7 +440,11 @@ class Molmo2VisionBackbone(nn.Module):
|
|
| 440 |
|
| 441 |
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
|
| 442 |
batch_size, num_image = images.shape[:2]
|
| 443 |
-
images = images.to(device=self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
image_features = self.encode_image(images)
|
| 445 |
|
| 446 |
image_features = self.image_feature_dropout(image_features)
|
|
@@ -543,7 +547,14 @@ class Molmo2RotaryEmbedding(nn.Module):
|
|
| 543 |
|
| 544 |
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 545 |
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 548 |
cos = emb.cos() * self.attention_scaling
|
| 549 |
sin = emb.sin() * self.attention_scaling
|
|
@@ -710,16 +721,36 @@ class Molmo2Attention(nn.Module):
|
|
| 710 |
if self.config._attn_implementation != "eager":
|
| 711 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
attn_output, attn_weights = attention_interface(
|
| 714 |
self,
|
| 715 |
query_states,
|
| 716 |
key_states,
|
| 717 |
value_states,
|
| 718 |
-
|
| 719 |
dropout=0.0 if not self.training else self.attention_dropout,
|
| 720 |
scaling=self.scaling,
|
| 721 |
**kwargs,
|
| 722 |
)
|
|
|
|
|
|
|
| 723 |
|
| 724 |
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 725 |
attn_output = self.attn_out(attn_output)
|
|
|
|
| 440 |
|
| 441 |
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
|
| 442 |
batch_size, num_image = images.shape[:2]
|
| 443 |
+
images = images.to(device=self.device)
|
| 444 |
+
# Normalize pixel values on GPU: uint8 [0,255] -> float [-1,1]
|
| 445 |
+
# This matches native OLMo's normalize_on_gpu path exactly.
|
| 446 |
+
images = images.float().div_(255.0).mul_(2.0).sub_(1.0)
|
| 447 |
+
images = images.to(dtype=self.dtype)
|
| 448 |
image_features = self.encode_image(images)
|
| 449 |
|
| 450 |
image_features = self.image_feature_dropout(image_features)
|
|
|
|
| 547 |
|
| 548 |
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 549 |
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 550 |
+
# Recompute inv_freq directly on the target device to avoid CPU/GPU
|
| 551 |
+
# float32 rounding differences when inv_freq is initialized on CPU.
|
| 552 |
+
dim = self.inv_freq.shape[0] * 2
|
| 553 |
+
inv_freq = 1.0 / (self.config.rope_theta ** (
|
| 554 |
+
torch.arange(0, dim, 2, dtype=torch.float, device=x.device) / dim
|
| 555 |
+
))
|
| 556 |
+
seq = position_ids[0].float()
|
| 557 |
+
freqs = torch.einsum("i , j -> i j", seq, inv_freq).unsqueeze(0)
|
| 558 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 559 |
cos = emb.cos() * self.attention_scaling
|
| 560 |
sin = emb.sin() * self.attention_scaling
|
|
|
|
| 721 |
if self.config._attn_implementation != "eager":
|
| 722 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 723 |
|
| 724 |
+
# During prefill with SDPA, drop the explicit attention mask so SDPA uses
|
| 725 |
+
# is_causal=True internally. This matches native OLMo's behavior and avoids
|
| 726 |
+
# numerical differences from explicit-mask vs is_causal code paths.
|
| 727 |
+
sdpa_mask = attention_mask
|
| 728 |
+
if self.config._attn_implementation == "sdpa" and query_states.shape[2] > 1:
|
| 729 |
+
sdpa_mask = None
|
| 730 |
+
|
| 731 |
+
# Expand GQA key/value heads to match query heads via repeat_interleave,
|
| 732 |
+
# matching native OLMo's approach. This avoids the enable_gqa=True SDPA
|
| 733 |
+
# path which uses a different kernel and produces different float32 results.
|
| 734 |
+
# Temporarily set num_key_value_groups=1 so the HF SDPA wrapper doesn't
|
| 735 |
+
# try to handle GQA again on already-expanded tensors.
|
| 736 |
+
saved_groups = self.num_key_value_groups
|
| 737 |
+
if saved_groups > 1:
|
| 738 |
+
key_states = key_states.repeat_interleave(saved_groups, dim=1)
|
| 739 |
+
value_states = value_states.repeat_interleave(saved_groups, dim=1)
|
| 740 |
+
self.num_key_value_groups = 1
|
| 741 |
+
|
| 742 |
attn_output, attn_weights = attention_interface(
|
| 743 |
self,
|
| 744 |
query_states,
|
| 745 |
key_states,
|
| 746 |
value_states,
|
| 747 |
+
sdpa_mask,
|
| 748 |
dropout=0.0 if not self.training else self.attention_dropout,
|
| 749 |
scaling=self.scaling,
|
| 750 |
**kwargs,
|
| 751 |
)
|
| 752 |
+
|
| 753 |
+
self.num_key_value_groups = saved_groups
|
| 754 |
|
| 755 |
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 756 |
attn_output = self.attn_out(attn_output)
|