diff --git a/docs/examples/te_llama/requirements.txt b/docs/examples/te_llama/requirements.txt new file mode 100644 index 0000000000..093849001b --- /dev/null +++ b/docs/examples/te_llama/requirements.txt @@ -0,0 +1,5 @@ +transformers==4.57.0 +accelerate==1.10.0 +peft==0.15.2 +datasets==4.0.0 +sentencepiece==0.2.1 diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index b2d4d183ab..6dfa9b67bb 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -72,10 +72,15 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs): forward pass of the `TransformerLayer`. Also, make sure the output format matches the output of the HF's `LlamaDecoderLayer`. """ - return ( - super().forward( - hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb - ), + # Handle case where hidden_states might be a tuple (from previous layer output) + # This can happen with older versions of HuggingFace transformers + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + # Return tensor directly for HuggingFace transformers >= 4.57 + # (older versions wrapped output in tuple and extracted with layer_outputs[0]) + return super().forward( + hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb ) @@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config): # collect all layer prefixes to update all_layer_prefixes = set() for param_key in hf_state_dict.keys(): - layer_prefix_pat = "model.layers.\d+." + layer_prefix_pat = r"model.layers.\d+." m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 00499cff5f..42eee386b9 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -29,12 +29,11 @@ " - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", "2. `utils.py`\n", " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", - "3. `media/`\n", + "3. `requirements.txt`\n", + " - This file contains the necessary Python packages for this tutorial.\n", + "4. `media/`\n", " - This directory contains the images used in the following tutorial.\n", "\n", - "These packages are necessary to run this tutorial:\n", - "`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n", - "\n", "\n", "
\n", "\n", @@ -45,6 +44,27 @@ "
\n" ] }, + { + "cell_type": "markdown", + "id": "b56526b3", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Install the required Python packages using the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "099697e2", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and run this cell when running the tutorial for the first time\n", + "# %pip install -r requirements.txt" + ] + }, { "cell_type": "markdown", "id": "44abae4f", @@ -233,12 +253,12 @@ "metadata": {}, "outputs": [ { - "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", "Average time taken per step: 248 milliseconds\n" - ] + ], + "name": "stdout" } ], "source": [ @@ -568,12 +588,12 @@ "metadata": {}, "outputs": [ { - "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", "Average time taken per step: 185 milliseconds\n" - ] + ], + "name": "stdout" } ], "source": [ @@ -657,12 +677,12 @@ "metadata": {}, "outputs": [ { - "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", "Average time taken per step: 160 milliseconds\n" - ] + ], + "name": "stdout" } ], "source": [ @@ -755,9 +775,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.13.3" } }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file