-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model.py
More file actions
79 lines (65 loc) · 2.19 KB
/
test_model.py
File metadata and controls
79 lines (65 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict
from mujoco_playground._src import mjx_env
import time
import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp
import mujoco
from mujoco import mjx
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from RoboRLEnv import RoboRLEnv
"""
Run this script first to validate your robot .xml file: it loads the model,
prints the body/joint/geom counts, and times each initialization step.
"""
# ---------------------------------------------------------------------------
# Model sanity check: load the XML directly and report the model's size.
# NOTE(review): this path is hard-coded, while the timed load further down
# uses env._xml_path -- confirm both refer to the same file.
# ---------------------------------------------------------------------------
mj_model = mujoco.MjModel.from_xml_path("robot.xml")
print("Model Details:")
print("Number of bodies:", mj_model.nbody)
print("Number of joints:", mj_model.njnt)
print("Number of geoms:", mj_model.ngeom)

# Build the environment; only its XML path is used by the timing steps below.
env = RoboRLEnv()

# ---------------------------------------------------------------------------
# Initialization timing.
# time.perf_counter() is used for every interval because it is monotonic and
# high-resolution; time.time() follows the wall clock, which can be adjusted
# while the script runs.
# ---------------------------------------------------------------------------
print("\nDetailed Initialization Timing:")
start_total = time.perf_counter()

# Time XML loading
start = time.perf_counter()
mj_model = mujoco.MjModel.from_xml_path(env._xml_path)
print(f"XML Loading time: {time.perf_counter() - start:.4f} seconds")

# Time MjData creation
start = time.perf_counter()
mj_data = mujoco.MjData(mj_model)
print(f"MjData creation time: {time.perf_counter() - start:.4f} seconds")

# JAX dispatches device work asynchronously, so a call can return before the
# work has finished. Each JAX-backed step below therefore blocks on its result
# before the clock is read; otherwise the interval may cover only the dispatch.

# Time MJX model conversion
start = time.perf_counter()
mjx_model = jax.block_until_ready(mjx.put_model(mj_model))
print(f"MJX model conversion time: {time.perf_counter() - start:.4f} seconds")

# Time MJX data conversion
start = time.perf_counter()
mjx_data = jax.block_until_ready(mjx.put_data(mj_model, mj_data))
print(f"MJX data conversion time: {time.perf_counter() - start:.4f} seconds")

# Time mjx_env.init
start = time.perf_counter()
data = jax.block_until_ready(mjx_env.init(mjx_model))
print(f"mjx_env.init time: {time.perf_counter() - start:.4f} seconds")

# The total covers everything after environment construction.
print(f"\nTotal initialization time: {time.perf_counter() - start_total:.4f} seconds")