Skip to content

Fix JaxMARL walkthrough notebook API calls and dependencies, Fixes #154 (and some more)#163

Merged
amacrutherford merged 1 commit intoFLAIROx:mainfrom
zitr0y:fix-walkthrough-notebook-api-calls
Dec 19, 2025
Merged

Fix JaxMARL walkthrough notebook API calls and dependencies, Fixes #154 (and some more)#163
amacrutherford merged 1 commit intoFLAIROx:mainfrom
zitr0y:fix-walkthrough-notebook-api-calls

Conversation

@zitr0y
Copy link
Copy Markdown
Contributor

@zitr0y zitr0y commented Dec 2, 2025

Fixes critical bugs that prevented the walkthrough notebook from running:

  1. Fixed action_space() and observation_space() calls to include agent parameter

    • Changed env.action_space().n to env.action_space("agent_0").n
    • Changed env.observation_space().shape to env.observation_space("agent_0").shape
  2. Added missing dependencies to installation cell

    • Added distrax (required for Part 3)
    • Added hydra-core (required for Part 3)
  3. Fixed info dictionary handling in training loop

    • Properly reshape returned_episode_returns and returned_episode metrics
    • Replaced jax.tree_map with explicit dictionary construction

These changes allow the notebook to run successfully from start to finish.

Fixes critical bugs that prevented the walkthrough notebook from running:

1. Fixed action_space() and observation_space() calls to include agent parameter
   - Changed env.action_space().n to env.action_space("agent_0").n
   - Changed env.observation_space().shape to env.observation_space("agent_0").shape

2. Added missing dependencies to installation cell
   - Added distrax (required for Part 3)
   - Added hydra-core (required for Part 3)

3. Fixed info dictionary handling in training loop
   - Properly reshape returned_episode_returns and returned_episode metrics
   - Replaced jax.tree_map with explicit dictionary construction

These changes allow the notebook to run successfully from start to finish.
@zitr0y
Copy link
Copy Markdown
Contributor Author

zitr0y commented Dec 2, 2025

Changes to jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb

Cell 2 (installation):

- !pip install --upgrade -qq "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- !pip install -qq matplotlib jaxmarl pettingzoo
+ !pip install -qq matplotlib jaxmarl pettingzoo jax[cuda] distrax hydra-core
+ #Note: numpy version warnings are expected and can be safely ignored

Cell 11 (training code):

- env.action_space().n
+ env.action_space("agent_0").n

- env.observation_space().shape
+ env.observation_space("agent_0").shape

- info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+ info = {
+     "returned_episode_returns": info["returned_episode_returns"].reshape(-1),
+     "returned_episode": info["returned_episode"].reshape(-1),
+ }

Didn't run before, now it does :) There's some more possible improvements like unused imports but I thought I'd just make it work for starters

@zitr0y zitr0y changed the title Fix JaxMARL walkthrough notebook API calls and dependencies Fix JaxMARL walkthrough notebook API calls and dependencies, Fixes #154 (and some more) Dec 4, 2025
@zitr0y
Copy link
Copy Markdown
Contributor Author

zitr0y commented Dec 4, 2025

Fixes #154 (and some more).

I don't think the docker fail has anything to do with me changing a few lines in the example notebook

@amacrutherford amacrutherford merged commit 63226c5 into FLAIROx:main Dec 19, 2025
1 check failed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants