Create an Experiment
1. Pick a Sequence Embedding (TstepEncoder)
Each timestep provides a dict observation along with the previous action and reward.
AMAGO standardizes its training process by creating a TstepEncoder to map each timestep to a fixed-size representation.
After this, the rest of the network can be environment-agnostic.
We include customizable defaults for the two most common cases: images (CNNTstepEncoder)
and state arrays (FFTstepEncoder).
All we need to do is tell the Experiment which type to use:
import amago
from amago.nets.tstep_encoders import CNNTstepEncoder

experiment = amago.Experiment(
    make_train_env=make_env,
    ...,
    tstep_encoder_type=CNNTstepEncoder,
)
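If the environment returns flat state arrays instead of images, the same one-line swap selects the feedforward default:

from amago.nets.tstep_encoders import FFTstepEncoder

experiment = amago.Experiment(
    ...,
    tstep_encoder_type=FFTstepEncoder,
)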
2. Pick a Sequence Model (TrajEncoder)
The TrajEncoder is a seq2seq model that enables long-term memory and in-context learning by processing a sequence of TstepEncoder outputs. amago.nets.traj_encoders includes four built-in options:
FFTrajEncoder, GRUTrajEncoder, MambaTrajEncoder, and TformerTrajEncoder.
We can select a TrajEncoder just like a TstepEncoder:
from amago.nets.traj_encoders import MambaTrajEncoder

experiment = amago.Experiment(
    ...,
    traj_encoder_type=MambaTrajEncoder,
)
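Any of the four built-ins drops into the same slot. For example, to use the Transformer backbone instead:

from amago.nets.traj_encoders import TformerTrajEncoder

experiment = amago.Experiment(
    ...,
    traj_encoder_type=TformerTrajEncoder,
)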
3. Pick an Agent
The Agent puts everything together and handles actor-critic RL training on top of the outputs of the TrajEncoder.
There are two built-in (highly configurable) options: Agent and MultiTaskAgent.
We can switch between them with:
from amago.agent import MultiTaskAgent

experiment = amago.Experiment(
    ...,
    agent_type=MultiTaskAgent,
)
4. Create the Experiment and Start Training
Launch training with:
experiment = amago.Experiment(
    ...,  # args from the previous steps
    # final required args we haven't mentioned
    run_name="some_name",  # a name used for checkpoints and logging
    ckpt_base_dir="some/place/",  # path to checkpoint directory
    val_timesteps_per_epoch=1000,  # give actors enough time to finish >= 1 episode
    max_seq_len=128,  # maximum sequence length for the TrajEncoder
)
experiment.start()
experiment.learn()
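Putting steps 1-3 together, a complete launch script follows the same pattern. The ... below stands in for any remaining setup (such as environment construction) that this page doesn't cover:

import amago
from amago.nets.tstep_encoders import CNNTstepEncoder
from amago.nets.traj_encoders import MambaTrajEncoder
from amago.agent import MultiTaskAgent

experiment = amago.Experiment(
    make_train_env=make_env,
    ...,  # remaining environment/dataset args
    tstep_encoder_type=CNNTstepEncoder,
    traj_encoder_type=MambaTrajEncoder,
    agent_type=MultiTaskAgent,
    run_name="some_name",
    ckpt_base_dir="some/place/",
    val_timesteps_per_epoch=1000,
    max_seq_len=128,
)
experiment.start()
experiment.learn()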
Checkpoints and logs are saved in:
{Experiment.ckpt_base_dir}
|-- {Experiment.run_name}/
    |-- config.txt  # stores gin configuration details for reproducibility
    |-- wandb_logs/
    |-- ckpts/
        |-- training_states/
        |   |  # full checkpoint dirs used to restore `accelerate` training runs
        |   |-- {Experiment.run_name}_epoch_0/
        |   |-- {Experiment.run_name}_epoch_{Experiment.ckpt_interval}/
        |   |-- ...
        |
        |-- latest/
        |   |-- policy.pt  # the latest model weights
        |
        |-- policy_weights/
            |  # standard pytorch weight files
            |-- policy_epoch_0.pt
            |-- policy_epoch_{Experiment.ckpt_interval}.pt
            |-- ...
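Because policy_weights/ holds standard PyTorch weight files, a checkpoint can be inspected outside the training loop with plain torch. A minimal sketch, assuming the example run_name and ckpt_base_dir values used above (what the file contains, e.g. a state_dict, depends on the Experiment version):

import torch

# path follows the directory layout shown above (example values)
weights = torch.load("some/place/some_name/ckpts/latest/policy.pt", map_location="cpu")
print(type(weights))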
Each epoch, we:
1. Interact with the training envs for train_timesteps_per_epoch steps, creating a total of parallel_actors * train_timesteps_per_epoch new timesteps, and save any training sequences that have finished, if applicable.
2. Compute the RL training objectives on train_batches_per_epoch batches sampled from the dataset. Gradient steps are taken every batches_per_update batches (see the arithmetic sketch below).
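The resulting data and update counts per epoch are easy to work out. An illustrative calculation with made-up values (the variable names mirror the Experiment args; the numbers are hypothetical):

# Illustrative arithmetic only -- values are hypothetical, names mirror Experiment args.
parallel_actors = 24
train_timesteps_per_epoch = 1000
new_timesteps = parallel_actors * train_timesteps_per_epoch
print(new_timesteps)  # 24000 new timesteps collected this epoch

train_batches_per_epoch = 1000
batches_per_update = 2  # a gradient step every 2 batches
gradient_steps = train_batches_per_epoch // batches_per_update
print(gradient_steps)  # 500 gradient steps this epoch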