ARX5 MuJoCo World Model - nanoGPT h32

State-only world model for the ARX5 7-DOF robot arm in MuJoCo. It predicts the next state (14-D: 7 qpos + 7 qvel) from the state history and the action.

Architecture

Parameter        Value
Type             nanoGPT (causal transformer)
Layers           6
Heads            12
Embedding dim    384
Context length   32 (history_horizon=31)
Parameters       ~10.7M
Ensemble size    2
State dim        14 (7 qpos + 7 qvel)
Action dim       7 (joint position targets)
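As a rough sanity check on the parameter count, the common estimate of about 12 · n_layer · n_embd² weights for a GPT-style block stack can be evaluated for the values above. This is a back-of-the-envelope sketch that assumes a standard 4x MLP expansion and ignores embeddings, layer norms, biases, and input/output projections; the function name is illustrative and not part of the repo.

```python
def approx_transformer_params(n_layer: int, n_embd: int) -> int:
    """Approximate weight count of a GPT-style block stack.

    Per block: 4 * n_embd^2 for attention (Q, K, V, output projection)
    plus 8 * n_embd^2 for an MLP with 4x expansion (assumed here).
    Biases, layer norms, embeddings, and heads are ignored.
    """
    attention = 4 * n_embd * n_embd
    mlp = 8 * n_embd * n_embd
    return n_layer * (attention + mlp)


n_params = approx_transformer_params(n_layer=6, n_embd=384)
print(f"{n_params:,} (~{n_params / 1e6:.1f}M)")
```

The estimate comes out slightly below the reported ~10.7M, which is plausible once the ignored terms are included. With 12 heads and an embedding dim of 384, each head works on 384 / 12 = 32 dimensions.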

Training

Metric           Value
best_val_loss    0.1058 (step 50,000)
train_loss       0.129 (step 50,000)
Steps completed  75,000 (stopped early; best checkpoint at step 50,000)
Batch size       64
Learning rate    3e-4 (cosine decay to 3e-5)
Runtime          ~2h 42m on NVIDIA L4 (GCloud)
Dataset          1,024 episodes, 3.5M frames
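The learning-rate entry describes a cosine decay from 3e-4 to 3e-5. A minimal sketch of such a schedule is shown below. The decay length (`decay_steps`) is not stated in this card, so it is a required, hypothetical parameter here, and any warmup phase the actual training config may use is omitted.

```python
import math


def cosine_lr(step: int, decay_steps: int,
              max_lr: float = 3e-4, min_lr: float = 3e-5) -> float:
    """Cosine decay from max_lr at step 0 to min_lr at decay_steps.

    decay_steps is an assumption; the card does not state it.
    """
    if step >= decay_steps:
        return min_lr
    progress = step / decay_steps
    # coeff goes smoothly from 1.0 (start) to 0.0 (end of decay).
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coeff * (max_lr - min_lr)


# Example with a purely illustrative decay length of 100,000 steps.
for s in (0, 50_000, 100_000):
    print(s, cosine_lr(s, decay_steps=100_000))
```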

Trained on pravsels/arx5-mujoco-trajectories.

Files

checkpoints/step_50000/params/params.pt            # Model weights (84MB)
checkpoints/step_50000/train_state/train_state.pt  # Optimizer state (86MB)
assets/normalization_stats.json                    # z-score mean/std
TRAINING_LOG.md                                    # Training log
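The normalization file holds z-score statistics, so inputs would typically be standardized before being fed to the model and predictions mapped back afterwards. The sketch below shows that transform on plain lists. The layout of `normalization_stats.json` is not documented here, so the example uses made-up `mean` and `std` values rather than reading the file, and the `eps` guard is an assumption.

```python
from typing import List


def normalize(x: List[float], mean: List[float], std: List[float],
              eps: float = 1e-8) -> List[float]:
    """Z-score: subtract the mean and divide by the standard deviation."""
    return [(xi - m) / (s + eps) for xi, m, s in zip(x, mean, std)]


def denormalize(z: List[float], mean: List[float], std: List[float],
                eps: float = 1e-8) -> List[float]:
    """Inverse of normalize(): map standardized values back to raw units."""
    return [zi * (s + eps) + m for zi, m, s in zip(z, mean, std)]


# Toy statistics for a 14-D state (7 qpos + 7 qvel); NOT the real values.
mean = [0.1] * 7 + [0.0] * 7
std = [0.5] * 7 + [2.0] * 7
state = [0.6] * 7 + [1.0] * 7

z = normalize(state, mean, std)
restored = denormalize(z, mean, std)
```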

Checkpoint Hashes

Verify integrity after download:

cd checkpoints/step_50000
find params -type f | sort | xargs sha256sum | sha256sum
# Expected: c27e9b11f1d5d1291a96d305fc0d846c28e91dfc66f578a66942e9a71ae74b32

find train_state -type f | sort | xargs sha256sum | sha256sum
# Expected: 05ec0450208e92822f10c6f4c14e80131a92eeef550793e01e8c50e315c9370a
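Where `sha256sum` is unavailable, the same digest can be approximated in Python. The sketch below mirrors the shell pipeline: hash each file, emit one `<hash>  <path>` line per file in sorted order, then hash that listing. It assumes GNU `sha256sum` text-mode output (two spaces between hash and path), plain file names without characters that `sha256sum` would escape, and a sort order that matches the shell's `sort`. The function names are illustrative.

```python
import hashlib
from pathlib import Path


def file_sha256(path: Path, chunk_size: int = 1 << 20) -> str:
    """SHA-256 of one file, read in chunks to keep memory use low."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def tree_digest(root: str, subdir: str) -> str:
    """Digest of the sha256sum-style listing of all files under root/subdir."""
    base = Path(root)
    # Paths are relative to root, like the output of `find params -type f`.
    files = sorted(
        p.relative_to(base).as_posix()
        for p in (base / subdir).rglob("*")
        if p.is_file()
    )
    listing = "".join(f"{file_sha256(base / rel)}  {rel}\n" for rel in files)
    return hashlib.sha256(listing.encode()).hexdigest()


# Usage (run from the repository root):
#   tree_digest("checkpoints/step_50000", "params")
#   tree_digest("checkpoints/step_50000", "train_state")
```

Compare the returned strings against the expected values listed above.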

Usage

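import torch
from rsl_rl.modules.system_dynamics import SystemDynamicsEnsemble

# Fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the weights saved at the best validation step (50,000).
ckpt = torch.load("checkpoints/step_50000/params/params.pt", map_location=device)

# Rebuild the ensemble with the hyperparameters listed under Architecture.
model = SystemDynamicsEnsemble(
    state_dim=14, action_dim=7, device=device,
    ensemble_size=2, history_horizon=31,
    architecture_config={"type": "gpt", "n_embd": 384, "n_head": 12, "n_layer": 6,
                         "block_size": 32, "bias": False, "dropout": 0.0,
                         "state_mean_shape": [384], "state_logstd_shape": [384]},
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # Switch to inference mode.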
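Once loaded, a world model like this is usually rolled out autoregressively: each predicted state is appended to the history window and used for the next prediction. The sketch below illustrates that loop with a fixed-length buffer. The `predict_next` callable and the toy dynamics are hypothetical stand-ins; the actual inference interface of `SystemDynamicsEnsemble` is not described in this card and may differ.

```python
from collections import deque
from typing import Callable, List, Sequence

State = List[float]   # 14 values: 7 qpos + 7 qvel
Action = List[float]  # 7 joint position targets


def rollout(predict_next: Callable[[List[State], Action], State],
            initial_history: Sequence[State],
            actions: Sequence[Action],
            history_horizon: int = 31) -> List[State]:
    """Feed each prediction back into a sliding window of recent states."""
    history = deque(initial_history, maxlen=history_horizon)
    predicted = []
    for action in actions:
        next_state = predict_next(list(history), action)
        predicted.append(next_state)
        history.append(next_state)  # the oldest state drops out automatically
    return predicted


def toy_predict_next(history: List[State], action: Action) -> State:
    # Stand-in dynamics, NOT the trained model: move qpos halfway to the target.
    last = history[-1]
    qpos, qvel = last[:7], last[7:]
    return [q + 0.5 * (a - q) for q, a in zip(qpos, action)] + qvel


history = [[0.0] * 14 for _ in range(31)]
actions = [[1.0] * 7 for _ in range(3)]
states = rollout(toy_predict_next, history, actions)
print(len(states), states[-1][0])
```

Because predictions are fed back in, errors can compound over long horizons, so short rollouts are a safer starting point.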

Source

  • Repo: github.com/pravsels/rsl_rwm (branch: feat/arx5_mujoco_wm)
  • Config: config/arx5_mujoco/arx5_mujoco_rsl_rwm_h32_gpt_gcloud.yaml
  • Commit: 1842346