# Create an Experiment
## 1. Pick a Sequence Embedding (`TstepEncoder`)
Each timestep provides a dict observation along with the previous action and reward.
AMAGO standardizes its training process by creating a `TstepEncoder` to map timesteps to a fixed-size representation. After this, the rest of the network can be environment-agnostic. We include customizable defaults for the two most common cases: images (`CNNTstepEncoder`) and state arrays (`FFTstepEncoder`). All we need to do is tell the `Experiment` which type to use:
```python
import amago
from amago.nets.tstep_encoders import CNNTstepEncoder

experiment = amago.Experiment(
    ...,
    make_train_env=make_env,
    tstep_encoder_type=CNNTstepEncoder,
)
```
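If the environment provides flat state arrays rather than images, the same one-line change selects the other built-in default. A minimal variation on the snippet above:

```python
from amago.nets.tstep_encoders import FFTstepEncoder

experiment = amago.Experiment(
    ...,
    tstep_encoder_type=FFTstepEncoder,
)
```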
## 2. Pick a Sequence Model (`TrajEncoder`)
The `TrajEncoder` is a seq2seq model that enables long-term memory and in-context learning by processing a sequence of `TstepEncoder` outputs. `amago.nets.traj_encoders` includes four built-in options: `FFTrajEncoder`, `GRUTrajEncoder`, `MambaTrajEncoder`, and `TformerTrajEncoder`.

We can select a `TrajEncoder` just like a `TstepEncoder`:
```python
from amago.nets.traj_encoders import MambaTrajEncoder

experiment = amago.Experiment(
    ...,
    traj_encoder_type=MambaTrajEncoder,
)
```
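Any of the four built-ins can be swapped in the same way. For example, a minimal variation on the snippet above that selects the Transformer option instead:

```python
from amago.nets.traj_encoders import TformerTrajEncoder

experiment = amago.Experiment(
    ...,
    traj_encoder_type=TformerTrajEncoder,
)
```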
## 3. Pick an `Agent`
The `Agent` puts everything together and handles actor-critic RL training on top of the outputs of the `TrajEncoder`. There are two built-in (highly configurable) options: `Agent` and `MultiTaskAgent`. We can switch between them with:
```python
from amago.agent import MultiTaskAgent

experiment = amago.Experiment(
    ...,
    agent_type=MultiTaskAgent,
)
```
## 4. Create the `Experiment` and Start Training
Launch training with:
```python
experiment = amago.Experiment(
    ...,
    # final required args we haven't mentioned
    run_name="some_name",  # a name used for checkpoints and logging
    ckpt_base_dir="some/place/",  # path to checkpoint directory
    val_timesteps_per_epoch=1000,  # give actors enough time to finish >= 1 episode
    max_seq_len=128,  # maximum sequence length for the TrajEncoder
)
experiment.start()
experiment.learn()
```
Checkpoints and logs are saved in:
```
{Experiment.ckpt_base_dir}
|-- {Experiment.run_name}/
    |-- config.txt  # stores gin configuration details for reproducibility
    |-- wandb_logs/
    |-- ckpts/
        |-- training_states/
        |   |   # full checkpoint dirs used to restore `accelerate` training runs
        |   |-- {Experiment.run_name}_epoch_0/
        |   |-- {Experiment.run_name}_epoch_{Experiment.ckpt_interval}/
        |   |-- ...
        |
        |-- latest/
        |   |-- policy.pt  # the latest model weights
        |-- policy_weights/
            |   # standard pytorch weight files
            |-- policy_epoch_0.pt
            |-- policy_epoch_{Experiment.ckpt_interval}.pt
            |-- ...
```
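Because the files in `policy_weights/` are standard PyTorch weight files, they can be inspected outside of AMAGO with plain `torch.load`. A minimal sketch, assuming the `ckpt_base_dir` and `run_name` values from the example above (the epoch number is a placeholder; restoring a full `accelerate` training run goes through the `training_states/` directories instead):

```python
import torch

# Hypothetical path following the layout above.
weights_path = "some/place/some_name/ckpts/policy_weights/policy_epoch_0.pt"

# Standard PyTorch loading; useful for offline inspection of saved weights.
checkpoint = torch.load(weights_path, map_location="cpu")
print(type(checkpoint))
```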
Each epoch, we:

1. Interact with the training envs for `train_timesteps_per_epoch` timesteps, creating a total of `parallel_actors * train_timesteps_per_epoch` new timesteps.
2. Save any training sequences that have finished, if applicable.
3. Compute the RL training objectives on `train_batches_per_epoch` batches sampled from the dataset. Gradient steps are taken every `batches_per_update` batches (see the sketch below).
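In pseudocode, one epoch therefore looks roughly like the toy loop below. This is a simplified sketch for intuition only, not AMAGO's actual implementation; the numbers are made-up example values and only the loop structure mirrors the description above:

```python
# Example values (hypothetical); in practice these are Experiment arguments.
parallel_actors = 8
train_timesteps_per_epoch = 1000
train_batches_per_epoch = 500
batches_per_update = 2

# 1) Environment interaction: every parallel actor steps its env once per
#    timestep, and finished sequences are saved to the dataset as they complete.
new_timesteps = 0
for _ in range(train_timesteps_per_epoch):
    new_timesteps += parallel_actors
assert new_timesteps == parallel_actors * train_timesteps_per_epoch

# 2) RL updates: sample batches from the dataset and take an optimizer step
#    every `batches_per_update` batches (gradients accumulate in between).
gradient_steps = 0
for b in range(train_batches_per_epoch):
    # ... compute the RL training objectives on a sampled batch ...
    if (b + 1) % batches_per_update == 0:
        gradient_steps += 1

print(f"{new_timesteps} new timesteps, {gradient_steps} gradient steps this epoch")
```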