Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ dependencies:
- pandas
- numpy>=1.20,<3.0
- matplotlib
- sphinx>=5,<9
- sphinx_rtd_theme>=1.3.0
- sphinx>=5,<10
- sphinx_rtd_theme>=3.0
- sphinx_copybutton
97 changes: 97 additions & 0 deletions docs/guide/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,100 @@ in their respective folders.

In most cases, we recommend using PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy,
unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``.


SBX (SB3 + Jax) Export
----------------------

Policies from :ref:`Stable Baselines Jax (SBX) <sbx>` can be exported to ONNX manually by first copying their
parameters into an equivalent PyTorch module, as shown in the following example:

.. code-block:: python

import numpy as np
import sbx
import torch as th


class TorchPolicy(th.nn.Module):
    """PyTorch MLP used as the intermediate representation for the ONNX export.

    The network is ``Linear -> Tanh -> Linear -> Tanh -> Linear``, so the
    ``Linear`` layers sit at indices 0, 2 and 4 of ``self.net``.

    :param obs_dim: Number of input features of the first linear layer.
    :param hidden_dim: Number of units in each of the two hidden layers.
    :param act_dim: Number of output features of the last linear layer.
    """

    def __init__(self, obs_dim: int, hidden_dim: int, act_dim: int):
        super().__init__()
        self.net = th.nn.Sequential(
            th.nn.Linear(obs_dim, hidden_dim),
            th.nn.Tanh(),
            th.nn.Linear(hidden_dim, hidden_dim),
            th.nn.Tanh(),
            th.nn.Linear(hidden_dim, act_dim),
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Run ``x`` through the MLP and return its output."""
        return self.net(x)


model = sbx.PPO("MlpPolicy", "Pendulum-v1")
# Also possible: load a trained model
# model = sbx.PPO.load("PathToTrainedModel.zip")

params = model.policy.actor_state.params["params"]
# For debug: print the name and shape of every actor parameter
print("=== SBX params ===")
for key, value in params.items():
    if isinstance(value, dict):
        # Nested entry: one sub-dict of arrays per layer
        for name, val in value.items():
            print(f"{key}.{name}: {val.shape}", end=" ")
    else:
        # Top-level entry that is directly an array
        print(f"{key}: {value.shape}", end=" ")
print("\n" + "=" * 20 + "\n")

# Note: these are shape tuples, the first entry is used as the dimension below
obs_dim = model.observation_space.shape
act_dim = model.action_space.shape

# Number of units in the hidden layers (assume a network architecture like [64, 64]):
# the second axis of the first kernel is the output size of the first layer
hidden_dim = params["Dense_0"]["kernel"].shape[1]

# Map params to torch state_dict keys
num_layers = sum(1 for k in params if k.startswith("Dense_"))
state_dict = {}
for i in range(num_layers):
    layer_name = f"Dense_{i}"
    # Linear layers are at the even indices of ``net`` (Tanh layers are in between)
    state_dict[f"net.{i * 2}.bias"] = th.from_numpy(np.array(params[layer_name]["bias"]))
    # Kernels are stored as (in, out) whereas torch Linear weights are (out, in), hence the transpose
    state_dict[f"net.{i * 2}.weight"] = th.from_numpy(np.array(params[layer_name]["kernel"].T))

torch_policy = TorchPolicy(obs_dim[0], hidden_dim, act_dim[0])
print("=== Torch params ===")
print(" ".join(f"{key}:{tuple(value.shape)}" for key, value in torch_policy.named_parameters()))
print("=" * 20 + "\n")

torch_policy.load_state_dict(state_dict)
torch_policy.eval()

dummy_input = th.zeros((1, *obs_dim))
# Use normal Torch export
th.onnx.export(
    torch_policy,
    (dummy_input,),
    "my_ppo_actor.onnx",
    opset_version=18,
    input_names=["input"],
    output_names=["action"],
)


##### Load and test with onnx

import onnxruntime as ort

onnx_path = "my_ppo_actor.onnx"
ort_sess = ort.InferenceSession(onnx_path)

# Random observation with a leading batch dimension of 1, as float32
observation = np.random.random((1, *obs_dim)).astype(np.float32)
# "input" must match the name given in ``input_names`` at export time
action = ort_sess.run(None, {"input": observation})[0]

print(action)
sbx_action, _ = model.predict(observation, deterministic=True)
# Gradients are not needed for the comparison
with th.no_grad():
    torch_action = torch_policy(th.as_tensor(observation))

# Check that the predictions are the same
assert np.allclose(sbx_action, action)
assert np.allclose(sbx_action, torch_action.numpy())
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Others:
Documentation:
^^^^^^^^^^^^^^
- Added a note on MultiDiscrete spaces with multi-dimensional arrays and a wrapper to fix the issue (@unexploredtest)

- Added an example of manually exporting an SBX (SB3 + Jax) model to ONNX (@m-abr)

Release 2.7.1 (2025-12-05)
--------------------------
Expand Down Expand Up @@ -1949,3 +1949,4 @@ And all the contributors:
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest
@m-abr
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@
"black>=26.1.0,<27",
],
"docs": [
"sphinx>=5,<9",
"sphinx>=5,<10",
"sphinx-autobuild",
"sphinx-rtd-theme>=1.3.0",
"sphinx-rtd-theme>=3.0.0",
# For spelling
"sphinxcontrib.spelling",
# Copy button for code snippets
Expand Down