---
license: mit
library_name: optax
tags:
- optimizer
- learned-optimizer
- meta-learning
- jax
---
# Celo2-base
<p>
<a href="https://arxiv.org/abs/2602.19142"><img alt="Paper" src="https://img.shields.io/badge/arXiv-2602.19142-b31b1b.svg"></a>
<a href="https://github.com/amoudgl/celo2"><img alt="Code" src="https://img.shields.io/badge/GitHub-black?logo=github&logoColor=white&labelColor=grey"></a>
<a href="https://opensource.org/licenses/MIT"><img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
</p>
Official pretrained weights for the **Celo2-base** learned update rule. This variant applies the learned update rule to all parameters without any optimization harness. For better performance, see [celo2](https://huggingface.co/amoudgl/celo2), which uses Newton-Schulz orthogonalization and AdamW for biases/embeddings.
## Quickstart
Install the package and download the checkpoint:
```bash
pip install git+https://github.com/amoudgl/celo2.git
hf download amoudgl/celo2-base --local-dir ./celo2-base
```
Use the `load_checkpoint` function to load the pretrained params from the checkpoint path:
```python
from celo2_optax import load_checkpoint
pretrained_params = load_checkpoint('./celo2-base/theta.state')
```
Standard optax usage with the `scale_by_celo2` function, which takes the pretrained params as input:
```python
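import optax
from celo2_optax import scale_by_celo2

# Illustrative placeholder values (assumptions, not tuned recommendations).
weight_decay = 1e-4
lr_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)

# `pretrained_params` comes from `load_checkpoint` above.
optimizer = optax.chain(
    scale_by_celo2(pretrained_params, orthogonalize=False),  # learned update rule
    optax.add_decayed_weights(weight_decay),  # decoupled weight decay
    optax.scale_by_learning_rate(lr_schedule),  # applies -lr(step)
)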
```
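To make the order of the chain concrete, the sketch below reproduces its arithmetic on a single scalar parameter in plain Python. It relies on two documented optax behaviors: `optax.add_decayed_weights` adds `weight_decay * param` to the incoming update, and `optax.scale_by_learning_rate` multiplies by the negative learning rate. The `learned_direction` value is a made-up stand-in for the output of `scale_by_celo2`, and `chained_update` is a hypothetical helper that is not part of `celo2_optax`.

```python
def chained_update(learned_direction, param, weight_decay, lr):
    """Mimic the three-stage chain above for one scalar parameter."""
    # Stage 1: stand-in for scale_by_celo2 (made-up number, not the real rule).
    update = learned_direction
    # Stage 2: add_decayed_weights adds weight_decay * param to the update.
    update = update + weight_decay * param
    # Stage 3: scale_by_learning_rate multiplies by -lr so the step descends.
    return -lr * update


param = 2.0
delta = chained_update(learned_direction=0.5, param=param, weight_decay=0.1, lr=0.01)
new_param = param + delta  # what optax.apply_updates does for a scalar
print(delta, new_param)
```

Because the decay term is added after the learned rule has produced its update, weight decay stays decoupled from the learned update, in the same spirit as AdamW.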
## Loading and inspecting MLP update rule weights
```python
from celo2_optax import load_checkpoint
import jax
pretrained_params = load_checkpoint('./celo2-base/theta.state') # dictionary containing weights
print(jax.tree.map(lambda x: x.shape, pretrained_params))
```
The checkpoint contains a small MLP stored under the `ff_mod_stack` key, with weight matrices (`w0__*`, `w1`, `w2`) and biases (`b0`, `b1`, `b2`). Each `w0__*` key holds the weights for a particular input feature, such as the momentum, gradient, or parameter.
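One plausible reading of this layout is that the first layer projects each input feature with its own matrix and sums the results, which is equivalent to applying one stacked matrix to the concatenated features. The sketch below illustrates that equivalence in plain Python with toy numbers. The key names `w0__grad` and `w0__momentum`, the one-dimensional features, the 2-unit hidden width, and the omission of an activation are all assumptions for illustration; the real keys and shapes are whatever the `print` call above reports.

```python
def matvec(x, w):
    """Multiply a feature vector x (length n) by a weight matrix w (n x m)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


# Toy per-feature first-layer weights (hypothetical names and values).
toy_params = {
    "w0__grad": [[1.0, 2.0]],       # 1 input feature -> 2 hidden units
    "w0__momentum": [[0.5, -1.0]],  # 1 input feature -> 2 hidden units
    "b0": [0.25, 0.25],
}
features = {"grad": [3.0], "momentum": [4.0]}

# Per-feature form: project each feature with its own matrix, then sum.
hidden = list(toy_params["b0"])
for name, x in features.items():
    projected = matvec(x, toy_params["w0__" + name])
    hidden = [h + p for h, p in zip(hidden, projected)]

# Concatenated form: stack the matrices row-wise and concatenate the features.
stacked_w = toy_params["w0__grad"] + toy_params["w0__momentum"]
stacked_x = features["grad"] + features["momentum"]
hidden_concat = [h + b for h, b in zip(matvec(stacked_x, stacked_w), toy_params["b0"])]

print(hidden, hidden_concat)
```

Keeping one matrix per feature makes it easy to inspect how strongly each input feature feeds into the update rule.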
## Meta-training config
| Key | Value |
| ----------------------- | ------------------------------------------------------------ |
| **Optimizer architecture** | MLP, 2 hidden layers, 8 units each |
| **Meta-training tasks** | 4 image classification tasks (MNIST, FMNIST, CIFAR-10, SVHN) |
| **Task architecture** | MLP (64-32-10) |
| **Meta-trainer** | Persistent Evolution Strategies (PES) |
| **Outer iterations** | 100K |
| **Truncation length** | 50 |
| **Min unroll length** | 100 |
| **Max unroll length** | 2000 |
For more details, see the [config JSON](./config.json) included in the repo.
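As a quick reading of the unroll settings in the table: if an inner training run is split into consecutive windows of the truncation length, the number of windows per run follows from a ceiling division. The helper `num_truncations` below is hypothetical and only illustrates this arithmetic. It is not taken from the meta-training code, and it makes no assumption about how unroll lengths are sampled between the minimum and maximum.

```python
import math

TRUNCATION_LENGTH = 50    # from the table above
MIN_UNROLL_LENGTH = 100   # from the table above
MAX_UNROLL_LENGTH = 2000  # from the table above


def num_truncations(unroll_length, truncation_length=TRUNCATION_LENGTH):
    """Number of truncation windows needed to cover one inner unroll."""
    return math.ceil(unroll_length / truncation_length)


shortest = num_truncations(MIN_UNROLL_LENGTH)
longest = num_truncations(MAX_UNROLL_LENGTH)
print(shortest, longest)
```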
## Files
| File | Description |
| ------------- | -------------------------------- |
| `theta.state` | Pretrained MLP optimizer weights |
| `config.json` | Meta-training configuration |
## Citation
```bibtex
@misc{moudgil2026celo2,
title={Celo2: Towards Learned Optimization Free Lunch},
author={Abhinav Moudgil and Boris Knyazev and Eugene Belilovsky},
year={2026},
eprint={2602.19142},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2602.19142},
}
```