Create an Experiment#



1. Pick a Sequence Embedding (TstepEncoder)#

Each timestep provides a dict observation along with the previous action and reward. AMAGO standardizes its training process with a TstepEncoder that maps each timestep to a fixed-size representation, so the rest of the network can be environment-agnostic. We include customizable defaults for the two most common cases: image observations (CNNTstepEncoder) and state arrays (FFTstepEncoder). All we need to do is tell the Experiment which type to use:

import amago
from amago.nets.tstep_encoders import CNNTstepEncoder

experiment = amago.Experiment(
    make_train_env=make_env,
    ...,
    tstep_encoder_type=CNNTstepEncoder,
)
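
If the environment returns flat state arrays rather than images, the feedforward default mentioned above is selected the same way (a sketch following the snippet above):

from amago.nets.tstep_encoders import FFTstepEncoder

experiment = amago.Experiment(
    make_train_env=make_env,
    ...,
    tstep_encoder_type=FFTstepEncoder,
)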

2. Pick a Sequence Model (TrajEncoder)#

The TrajEncoder is a seq2seq model that enables long-term memory and in-context learning by processing a sequence of TstepEncoder outputs. amago.nets.traj_encoders includes four built-in options: FFTrajEncoder, GRUTrajEncoder, MambaTrajEncoder, and TformerTrajEncoder.

We can select a TrajEncoder just like a TstepEncoder:

from amago.nets.traj_encoders import MambaTrajEncoder

experiment = amago.Experiment(
    ...,
    traj_encoder_type=MambaTrajEncoder,
)
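
At this level the four built-ins are interchangeable: switching to, say, the Transformer option is the same one-line change (the sequence length it trains on is capped by max_seq_len, set in step 4):

from amago.nets.traj_encoders import TformerTrajEncoder

experiment = amago.Experiment(
    ...,
    traj_encoder_type=TformerTrajEncoder,
)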

3. Pick an Agent#

The Agent puts everything together and handles actor-critic RL training on top of the TrajEncoder's outputs. There are two built-in (highly configurable) options: Agent and MultiTaskAgent.

We can switch between them with:

from amago.agent import MultiTaskAgent

experiment = amago.Experiment(
    ...,
    agent_type=MultiTaskAgent,
)
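
Agent hyperparameters are configured through gin (the resolved gin config is saved to config.txt, as shown in the next step). A minimal sketch of overriding one setting before the Experiment is created; the parameter name gamma here is purely illustrative, so check the Agent / MultiTaskAgent signatures for the real gin-configurable names:

import gin

# Hypothetical override for illustration only -- replace "gamma" with an
# actual parameter name from the Agent API reference. Bindings must be set
# before the Experiment constructs its networks.
gin.bind_parameter("MultiTaskAgent.gamma", 0.999)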

4. Create the Experiment and Start Training#

Launch training with:

experiment = amago.Experiment(
    # final required args we haven't mentioned
    run_name="some_name", # a name used for checkpoints and logging
    ckpt_base_dir="some/place/", # path to checkpoint directory
    val_timesteps_per_epoch=1000, # give actors enough time to finish >= 1 episode
    max_seq_len=128, # maximum sequence length for the TrajEncoder
    ...
)
experiment.start()
experiment.learn()
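
Putting steps 1-4 together, the setup from this walkthrough looks roughly like the sketch below. It is not necessarily exhaustive: other Experiment arguments (e.g. parallel_actors, train_timesteps_per_epoch) are omitted here.

import amago
from amago.nets.tstep_encoders import CNNTstepEncoder
from amago.nets.traj_encoders import MambaTrajEncoder
from amago.agent import MultiTaskAgent

experiment = amago.Experiment(
    make_train_env=make_env,             # your environment constructor (step 1)
    run_name="some_name",
    ckpt_base_dir="some/place/",
    val_timesteps_per_epoch=1000,
    max_seq_len=128,
    tstep_encoder_type=CNNTstepEncoder,  # step 1
    traj_encoder_type=MambaTrajEncoder,  # step 2
    agent_type=MultiTaskAgent,           # step 3
    # ... any remaining args not covered in this walkthrough
)
experiment.start()
experiment.learn()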

Checkpoints and logs are saved in:

{Experiment.ckpt_base_dir}
    |-- {Experiment.run_name}/
        |-- config.txt  # stores gin configuration details for reproducibility
        |-- wandb_logs/
        |-- ckpts/
            |-- training_states/
            |   | # full checkpoint dirs used to restore `accelerate` training runs
            |   |-- {Experiment.run_name}_epoch_0/
            |   |-- {Experiment.run_name}_epoch_{Experiment.ckpt_interval}/
            |   |-- ...
            |
            |-- latest/
            |   |-- policy.pt  # the latest model weights
            |-- policy_weights/
                | # standard pytorch weight files
                |-- policy_epoch_0.pt
                |-- policy_epoch_{Experiment.ckpt_interval}.pt
                |-- ...
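
Because the files under policy_weights/ (and latest/policy.pt) are standard PyTorch weight files, they can be inspected outside of AMAGO. A minimal sketch, assuming the example run_name and ckpt_base_dir values used above:

import torch

# Path follows the layout above with the example run_name / ckpt_base_dir.
# Depending on how the file was saved, weights_only=False may be required.
weights = torch.load(
    "some/place/some_name/ckpts/latest/policy.pt",
    map_location="cpu",
)
print(type(weights))  # inspect what the checkpoint file contains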

Each epoch, we:

  1. Interact with the training envs for train_timesteps_per_epoch steps per actor, creating a total of parallel_actors * train_timesteps_per_epoch new timesteps.

  2. Save any training sequences that have finished, if applicable.

  3. Compute the RL training objectives on train_batches_per_epoch batches sampled from the dataset. Gradient steps are taken every batches_per_update batches.
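
As a rough sanity check of the accounting above, here is a tiny worked example; the values of parallel_actors, train_timesteps_per_epoch, train_batches_per_epoch, and batches_per_update are illustrative, not defaults:

# Illustrative (not default) settings for one epoch's accounting:
parallel_actors = 24
train_timesteps_per_epoch = 1000
train_batches_per_epoch = 1000
batches_per_update = 4

new_timesteps = parallel_actors * train_timesteps_per_epoch       # 24,000 timesteps collected
optimizer_steps = train_batches_per_epoch // batches_per_update   # 250 gradient steps
print(f"{new_timesteps} new timesteps, {optimizer_steps} optimizer steps per epoch")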