nvidia-oliver-holworthy committed on
Commit
382fc3a
·
unverified ·
1 Parent(s): 5cebd3b

Fix MBartDecoderLayer forward pass for transformers 5.x compatibility

Browse files

Detect the renamed `past_key_values` parameter (introduced in ~4.57) and
route through a separate call path that passes the Cache object and handles
both true 5.x (single-Tensor return) and intermediate versions (tuple return)
via an isinstance guard. Backward compatibility with 4.51.x is preserved
through the original singular-param branch.

Signed-off-by: Oliver Holworthy <nvidia-oliver-holworthy@users.noreply.huggingface.co>

Files changed (1) hide show
  1. hf_nemotron_parse_modeling.py +132 -34
hf_nemotron_parse_modeling.py CHANGED
@@ -23,6 +23,38 @@ from transformers.modeling_attn_mask_utils import (
23
  _prepare_4d_causal_attention_mask_for_sdpa,
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  class NemotronParseDecoder(MBartPreTrainedModel):
28
  """
@@ -47,7 +79,11 @@ class NemotronParseDecoder(MBartPreTrainedModel):
47
  if embed_tokens is not None:
48
  self.embed_tokens.weight = embed_tokens.weight
49
 
50
- self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
 
 
 
 
51
  self.config = config
52
 
53
  self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -163,8 +199,8 @@ class NemotronParseDecoder(MBartPreTrainedModel):
163
  else:
164
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
165
 
166
- # past_key_values_length
167
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
168
 
169
  if inputs_embeds is None:
170
  inputs_embeds = self.embed_tokens(input_ids)
@@ -221,7 +257,22 @@ class NemotronParseDecoder(MBartPreTrainedModel):
221
  all_hidden_states = () if output_hidden_states else None
222
  all_self_attns = () if output_attentions else None
223
  all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
224
- next_decoder_cache = () if use_cache else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
227
  for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
@@ -240,45 +291,68 @@ class NemotronParseDecoder(MBartPreTrainedModel):
240
  if dropout_probability < self.layerdrop:
241
  continue
242
 
243
- past_key_value = past_key_values[idx] if past_key_values is not None else None
244
-
245
- if self.gradient_checkpointing and self.training:
246
- layer_outputs = self._gradient_checkpointing_func(
247
- decoder_layer.__call__,
248
- hidden_states,
249
- attention_mask,
250
- encoder_hidden_states,
251
- encoder_attention_mask,
252
- head_mask[idx] if head_mask is not None else None,
253
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
254
- None,
255
- output_attentions,
256
- use_cache,
257
- )
258
- else:
259
  layer_outputs = decoder_layer(
260
  hidden_states,
261
  attention_mask=attention_mask,
262
  encoder_hidden_states=encoder_hidden_states,
263
  encoder_attention_mask=encoder_attention_mask,
264
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
265
- cross_attn_layer_head_mask=(
266
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
267
- ),
268
- past_key_value=past_key_value,
269
- output_attentions=output_attentions,
270
  use_cache=use_cache,
271
  )
272
- hidden_states = layer_outputs[0]
273
-
274
- if use_cache:
275
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
- if output_attentions:
278
- all_self_attns += (layer_outputs[1],)
 
 
 
279
 
280
- if encoder_hidden_states is not None:
281
- all_cross_attentions += (layer_outputs[2],)
 
 
282
 
283
  hidden_states = self.layer_norm(hidden_states)
284
 
@@ -533,6 +607,30 @@ class NemotronParseForConditionalGeneration(NemotronParsePreTrainedModel, Genera
533
  encoder_attentions=encoder_outputs.attentions,
534
  )
535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
537
  return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
538
 
 
23
  _prepare_4d_causal_attention_mask_for_sdpa,
24
  )
25
 
26
+ # ---------------------------------------------------------------------------
27
+ # Cache compatibility (recent transformers releases pass Cache objects;
28
+ # older 4.x releases used plain tuple-of-tuples for past_key_values)
29
+ # ---------------------------------------------------------------------------
30
+ import inspect
31
try:
    # Only the import can fail, so it is the only statement kept in the try body.
    from transformers.cache_utils import Cache as _CacheBase
except ImportError:
    # `transformers.cache_utils.Cache` is unavailable in this environment;
    # with no base class to test against, nothing can be a Cache instance.
    _CacheBase = None


def _is_cache_object(obj) -> bool:
    """Return True when *obj* is an instance of the transformers ``Cache`` class.

    Always returns False when ``transformers.cache_utils.Cache`` could not be
    imported, so callers fall back to treating ``past_key_values`` as the
    legacy tuple-of-tuples format.
    """
    return _CacheBase is not None and isinstance(obj, _CacheBase)
38
+
39
def _past_key_values_length(past_key_values) -> int:
    """Return the number of already-decoded tokens regardless of cache format.

    Args:
        past_key_values: ``None``, a Cache object (as recognised by
            ``_is_cache_object``), or a legacy tuple/list of per-layer
            key/value entries.

    Returns:
        ``0`` when there is no cache or the legacy cache is empty; otherwise
        the value of ``get_seq_length()`` for Cache objects, or
        ``past_key_values[0][0].shape[2]`` for the legacy format.
    """
    if past_key_values is None:
        return 0
    # Plain tuples/lists are never Cache objects, so the Cache check is only
    # needed for other types.
    is_legacy_sequence = isinstance(past_key_values, (tuple, list))
    if not is_legacy_sequence and _is_cache_object(past_key_values):
        return past_key_values.get_seq_length()
    # An empty legacy cache (no layers, or a first layer with no tensors)
    # holds no decoded tokens; indexing it would raise IndexError.
    if len(past_key_values) == 0 or len(past_key_values[0]) == 0:
        return 0
    # NOTE(review): assumes axis 2 of the cached key tensor is the sequence
    # axis, e.g. (batch, heads, seq, head_dim) — confirm against the layers.
    return past_key_values[0][0].shape[2]
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # MBartDecoderLayer API detection
49
+ #
50
+ # transformers <~4.57: forward() takes `past_key_value` (singular), returns a
51
+ # tuple (hidden_states, [attentions], [present_key_value])
52
+ # transformers >=~4.57: forward() takes `past_key_values` (plural, Cache).
53
+ # True 5.x returns a single torch.Tensor (cache updated in-place);
54
+ # intermediate versions (e.g. 4.57.x) still return a tuple.
55
+ # ---------------------------------------------------------------------------
56
+ _layer_takes_plural_past_kv = 'past_key_values' in inspect.signature(MBartDecoderLayer.forward).parameters
57
+
58
 
59
  class NemotronParseDecoder(MBartPreTrainedModel):
60
  """
 
79
  if embed_tokens is not None:
80
  self.embed_tokens.weight = embed_tokens.weight
81
 
82
+ _layer_supports_idx = 'layer_idx' in inspect.signature(MBartDecoderLayer.__init__).parameters
83
+ self.layers = nn.ModuleList([
84
+ MBartDecoderLayer(config, layer_idx=i) if _layer_supports_idx else MBartDecoderLayer(config)
85
+ for i in range(config.decoder_layers)
86
+ ])
87
  self.config = config
88
 
89
  self.layernorm_embedding = nn.LayerNorm(config.d_model)
 
199
  else:
200
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
201
 
202
+ # past_key_values_length — works with both tuple-of-tuples (4.x) and Cache objects (5.x)
203
+ past_key_values_length = _past_key_values_length(past_key_values)
204
 
205
  if inputs_embeds is None:
206
  inputs_embeds = self.embed_tokens(input_ids)
 
257
  all_hidden_states = () if output_hidden_states else None
258
  all_self_attns = () if output_attentions else None
259
  all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
260
+ # In 5.x the Cache object is updated in-place by each layer, so we just
261
+ # carry the same object through. In 4.x we collect per-layer tuples.
262
+ _using_cache_obj = _is_cache_object(past_key_values)
263
+ next_decoder_cache = past_key_values if (_using_cache_obj and use_cache) else (() if use_cache else None)
264
+
265
+ # 5.x: on the first call (past_key_values=None), create an EncoderDecoderCache
266
+ # so each MBartAttention layer can populate cross-/self-attention KV states
267
+ # in-place. This enables proper KV caching during multi-step generation.
268
+ if _layer_takes_plural_past_kv and use_cache and past_key_values is None:
269
+ try:
270
+ from transformers.cache_utils import EncoderDecoderCache, DynamicCache
271
+ past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
272
+ _using_cache_obj = True
273
+ next_decoder_cache = past_key_values
274
+ except (ImportError, AttributeError, TypeError):
275
+ pass # fallback: layers recompute KV each step (correct but slower)
276
 
277
  # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
278
  for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
 
291
  if dropout_probability < self.layerdrop:
292
  continue
293
 
294
+ if _layer_takes_plural_past_kv:
295
+ # Plural-param API: cache updated in-place, nothing to collect.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  layer_outputs = decoder_layer(
297
  hidden_states,
298
  attention_mask=attention_mask,
299
  encoder_hidden_states=encoder_hidden_states,
300
  encoder_attention_mask=encoder_attention_mask,
301
+ past_key_values=past_key_values if use_cache else None,
 
 
 
 
 
302
  use_cache=use_cache,
303
  )
304
+ # True 5.x returns a single Tensor; intermediate versions
305
+ # (e.g. 4.57.x) have the renamed parameter but still return
306
+ # a tuple — handle both.
307
+ hidden_states = layer_outputs if isinstance(layer_outputs, torch.Tensor) else layer_outputs[0]
308
+ else:
309
+ # Singular-param API: returns a tuple, collect cache per-layer.
310
+ if past_key_values is None:
311
+ past_key_value = None
312
+ elif _using_cache_obj:
313
+ past_key_value = past_key_values # full Cache object
314
+ else:
315
+ past_key_value = past_key_values[idx] # per-layer tuple
316
+
317
+ if self.gradient_checkpointing and self.training:
318
+ layer_outputs = self._gradient_checkpointing_func(
319
+ decoder_layer.__call__,
320
+ hidden_states,
321
+ attention_mask,
322
+ encoder_hidden_states,
323
+ encoder_attention_mask,
324
+ head_mask[idx] if head_mask is not None else None,
325
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
326
+ None,
327
+ output_attentions,
328
+ use_cache,
329
+ )
330
+ else:
331
+ layer_outputs = decoder_layer(
332
+ hidden_states,
333
+ attention_mask=attention_mask,
334
+ encoder_hidden_states=encoder_hidden_states,
335
+ encoder_attention_mask=encoder_attention_mask,
336
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
337
+ cross_attn_layer_head_mask=(
338
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
339
+ ),
340
+ past_key_value=past_key_value,
341
+ output_attentions=output_attentions,
342
+ use_cache=use_cache,
343
+ )
344
+ hidden_states = layer_outputs[0]
345
 
346
+ if use_cache and not _using_cache_obj:
347
+ # 4.x: cache is the last element of layer_outputs.
348
+ cache_idx = 3 if output_attentions else 1
349
+ if len(layer_outputs) > cache_idx:
350
+ next_decoder_cache += (layer_outputs[cache_idx],)
351
 
352
+ if output_attentions:
353
+ all_self_attns += (layer_outputs[1],)
354
+ if encoder_hidden_states is not None:
355
+ all_cross_attentions += (layer_outputs[2],)
356
 
357
  hidden_states = self.layer_norm(hidden_states)
358
 
 
607
  encoder_attentions=encoder_outputs.attentions,
608
  )
609
 
610
+ def prepare_inputs_for_generation(
611
+ self,
612
+ input_ids,
613
+ past_key_values=None,
614
+ attention_mask=None,
615
+ use_cache=None,
616
+ encoder_outputs=None,
617
+ **kwargs,
618
+ ):
619
+ if past_key_values is not None:
620
+ past_length = _past_key_values_length(past_key_values)
621
+ if input_ids.shape[1] > past_length:
622
+ input_ids = input_ids[:, past_length:]
623
+ else:
624
+ input_ids = input_ids[:, -1:]
625
+ return {
626
+ "pixel_values": None, # encoder_outputs carries the image features
627
+ "encoder_outputs": encoder_outputs,
628
+ "past_key_values": past_key_values,
629
+ "decoder_input_ids": input_ids,
630
+ "decoder_attention_mask": attention_mask,
631
+ "use_cache": use_cache,
632
+ }
633
+
634
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
635
  return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
636