Commit 2dbfbc7

fix(examples): te_llama compatibility with transformers >= 4.57 (#2572)
* fix(examples): te_llama compatibility with HuggingFace transformers >= 4.57

  The te_llama.py example was failing with HuggingFace transformers 4.57+ due to API changes in how decoder layer outputs are handled.

  Changes:
  - Handle case where hidden_states is passed as a tuple (older HF versions)
  - Return tensor directly instead of wrapped in tuple (HF 4.57+ expects this)
  - Fix regex pattern to use raw string (fixes SyntaxWarning)

  Error fixed: AttributeError: 'tuple' object has no attribute 'contiguous'

  Tested with:
  - transformer_engine 2.5.0
  - transformers 4.57.3
  - PyTorch container nvcr.io/nvidia/pytorch:25.08-py3

  Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

* docs(te_llama): add requirements.txt

  Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

* fix(docs): add missing notebook output names

  Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

---------

Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>
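For orientation, here is a self-contained toy sketch of the output-format difference the message describes. ToyDecoderLayer is a made-up stand-in, not code from te_llama.py or transformers; it only illustrates why returning a tuple breaks newer callers.

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    # Stand-in for a decoder layer that, like the patched TE layer,
    # returns the hidden-states tensor directly instead of a tuple.
    def forward(self, hidden_states, attention_mask=None):
        return hidden_states

layer = ToyDecoderLayer()
x = torch.randn(1, 4, 8)

# transformers >= 4.57 consumes the layer output directly (and may call
# .contiguous() on it), so a plain tensor works:
out = layer(x).contiguous()

# A tuple-wrapped output, as the old example returned, reproduces the error:
wrapped = (layer(x),)
# wrapped.contiguous()  # AttributeError: 'tuple' object has no attribute 'contiguous'

# Older transformers versions unpacked the tuple themselves:
hidden_states = wrapped[0]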
1 parent 7259276 commit 2dbfbc7

File tree

3 files changed: +793 -762 lines changed

docs/examples/te_llama/requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+transformers==4.57.0
+accelerate==1.10.0
+peft==0.15.2
+datasets==4.0.0
+sentencepiece==0.2.1
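As a convenience, a small, hypothetical sanity check that a local environment matches these pins; only importlib.metadata from the standard library is used, and the package list simply mirrors the file above.

from importlib.metadata import version, PackageNotFoundError

pins = {
    "transformers": "4.57.0",
    "accelerate": "1.10.0",
    "peft": "0.15.2",
    "datasets": "4.0.0",
    "sentencepiece": "0.2.1",
}

for pkg, expected in pins.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "not installed"
    marker = "OK" if installed == expected else "MISMATCH"
    print(f"{pkg}: expected {expected}, found {installed} [{marker}]")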

docs/examples/te_llama/te_llama.py

Lines changed: 10 additions & 5 deletions
@@ -72,10 +72,15 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs):
         forward pass of the `TransformerLayer`. Also, make sure the output
         format matches the output of the HF's `LlamaDecoderLayer`.
         """
-        return (
-            super().forward(
-                hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
-            ),
+        # Handle case where hidden_states might be a tuple (from previous layer output)
+        # This can happen with older versions of HuggingFace transformers
+        if isinstance(hidden_states, tuple):
+            hidden_states = hidden_states[0]
+
+        # Return tensor directly for HuggingFace transformers >= 4.57
+        # (older versions wrapped output in tuple and extracted with layer_outputs[0])
+        return super().forward(
+            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
         )

@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config):
     # collect all layer prefixes to update
     all_layer_prefixes = set()
     for param_key in hf_state_dict.keys():
-        layer_prefix_pat = "model.layers.\d+."
+        layer_prefix_pat = r"model.layers.\d+."
         m = re.match(layer_prefix_pat, param_key)
         if m is not None:
             all_layer_prefixes.add(m.group())
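For reference, a standalone sketch of what the corrected pattern extracts; the parameter key below is a typical HF Llama state-dict name chosen only for illustration.

import re

layer_prefix_pat = r"model.layers.\d+."
param_key = "model.layers.0.self_attn.q_proj.weight"

m = re.match(layer_prefix_pat, param_key)
print(m.group())  # prints "model.layers.0."

# Without the r-prefix, "\d" is an invalid string escape and newer Python
# versions emit a SyntaxWarning, which the commit also fixes.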
