Configure



The Experiment has lots of other kwargs to control things like the update:data ratio, optimization, and logging.


For additional control over the training process, use gin. First, note that the general format of an AMAGO training script is:

from amago import Experiment
from amago.agent import Agent
from amago.envs import AMAGOEnv
from amago.envs.exploration import EpsilonGreedy
from amago.loading import DiskTrajDataset
from amago.nets.tstep_encoders import CNNTstepEncoder
from amago.nets.traj_encoders import TformerTrajEncoder

# define the env
make_env = lambda: AMAGOEnv(...)

# create a dataset
dataset = DiskTrajDataset(...)

# make the main conceptual choices
tstep_encoder_type = CNNTstepEncoder
traj_encoder_type = TformerTrajEncoder
agent_type = Agent
exploration_wrapper_type = EpsilonGreedy

experiment = Experiment(
    dataset=dataset,
    # assign lots of *callables*, not instances, to the experiment
    make_train_env=make_env,
    make_val_env=make_env,
    tstep_encoder_type=tstep_encoder_type,
    traj_encoder_type=traj_encoder_type,
    agent_type=agent_type,
    exploration_wrapper_type=exploration_wrapper_type,
    # ... other Experiment kwargs
)

experiment.start()
experiment.learn()
experiment.evaluate_test(make_env)
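Passing callables (such as make_env) instead of instances lets the experiment build fresh, independent objects whenever it needs them, for example one environment per parallel actor. The toy sketch below illustrates the pattern in plain Python; ToyEnv and ToyExperiment are hypothetical stand-ins, not AMAGO classes, and parallel_envs is an illustrative parameter name.

```python
import itertools


class ToyEnv:
    """Hypothetical stand-in for an environment; each instance gets a unique id."""

    _ids = itertools.count()

    def __init__(self):
        self.env_id = next(ToyEnv._ids)


class ToyExperiment:
    """Hypothetical stand-in that, like Experiment, receives a factory instead of an instance."""

    def __init__(self, make_train_env, parallel_envs=4):
        self.make_train_env = make_train_env
        self.parallel_envs = parallel_envs

    def start(self):
        # Calling the factory repeatedly yields independent environment objects.
        self.envs = [self.make_train_env() for _ in range(self.parallel_envs)]


make_env = lambda: ToyEnv()
exp = ToyExperiment(make_train_env=make_env, parallel_envs=3)
exp.start()
print([env.env_id for env in exp.envs])  # three distinct ids
```

Had we passed a single ToyEnv() instance, every "parallel" slot would share the same object and its state.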

We choose the AMAGOEnv, RLDataset, and TstepEncoder because they are problem-specific. We also pick the TrajEncoder because it is the key feature of a sequence model agent. start() is going to create an Agent based on the environment and our other choices. This follows a strict rule:

Important

Any time AMAGO needs to initialize a class or call a method, it infers the positional args (based on the environment and our other choices) but leaves every keyword argument set to its default value. gin lets us edit those defaults without editing the source code, and it keeps a record of the settings we used on wandb.
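The division of labor in this rule can be illustrated with a small toy in plain Python. The build function, BINDINGS dict, and ToyEncoder below are hypothetical and far simpler than gin; they only show that positional args come from the caller while keyword args come from their defaults or from an external override table keyed like "ClassName.kwarg".

```python
import inspect

# Hypothetical global table of overrides, keyed "ClassName.kwarg" (loosely mimicking gin bindings).
BINDINGS = {}


def build(cls, *positional):
    """Construct cls following the rule: the caller supplies the positional args;
    every keyword argument keeps its default unless a binding overrides it."""
    kwargs = {}
    for name, param in inspect.signature(cls).parameters.items():
        key = f"{cls.__name__}.{name}"
        if param.default is not inspect.Parameter.empty and key in BINDINGS:
            kwargs[name] = BINDINGS[key]
    return cls(*positional, **kwargs)


class ToyEncoder:
    def __init__(self, obs_dim, d_output=256, activation="leaky_relu"):
        self.obs_dim = obs_dim
        self.d_output = d_output
        self.activation = activation


default = build(ToyEncoder, 64)          # obs_dim is inferred by the "framework"
BINDINGS["ToyEncoder.d_output"] = 512    # edit a default without touching ToyEncoder
custom = build(ToyEncoder, 64)
print(default.d_output, custom.d_output)  # 256 512
```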

The Examples show how almost every application of AMAGO looks the same aside from some minor gin configuration. gin can be complicated, but AMAGO tries to make it hard to get wrong:

Tip

If something is @gin.configurable (there will be a note at the top of its documentation), it means that Experiment will only ever call/construct it with default keyword arguments, and there is no other place where those values should be set or will be overridden. The only exceptions are Experiment and RLDataset, which are explicitly constructed by the user before training begins but are configurable for convenience.


For example, let’s say we want to switch the CNNTstepEncoder to use a larger IMPALA architecture with twice as many channels as usual. The API reference for CNNTstepEncoder looks like this:

class CNNTstepEncoder(obs_space, rl2_space, cnn_type=<class 'amago.nets.cnn.NatureishCNN'>, channels_first=False, img_features=256, rl2_features=12, d_output=256, out_norm='layer', activation='leaky_relu', hide_rl2s=False, drqv2_aug=False, aug_pct_of_batch=0.75, obs_key='observation')

Bases: TstepEncoder

A simple CNN-based TstepEncoder.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Useful for pixel-based environments. Currently only supports the case where observations are a single image without additional state arrays.

Parameters:
  • obs_space (Space) – Environment observation space.

  • rl2_space (Space) – A gym space declaring the shape of previous action and reward features. This is created by the AMAGOEnv wrapper.

  • cnn_type (Type[CNN]) – The type of nets.cnn.CNN to use. Defaults to nets.cnn.NatureishCNN (the small DQN CNN).

  • channels_first (bool) – Whether the image is in channels-first format. Defaults to False.

  • img_features (int) – Linear map the output of the CNN to this many features. Defaults to 256.

  • rl2_features (int) – Linear map the previous action and reward to this many features. Defaults to 12.

  • d_output (int) – The output dimension of a layer that fuses the img_features and rl2_features. Defaults to 256.

  • out_norm (str) – The normalization layer to use. See nets.ff.Normalization for options. Defaults to “layer”.

  • activation (str) – The activation function to use. See nets.utils.activation_switch for options. Defaults to “leaky_relu”.

  • hide_rl2s (bool) – Whether to ignore the previous action and reward features (but otherwise keep the same parameter count and layer dimensions). Defaults to False.

  • drqv2_aug (bool) – Quick-apply the default DrQv2 image augmentation. Applies random crops to a fraction (aug_pct_of_batch) of every batch during training. Currently requires square images. Defaults to False.

  • aug_pct_of_batch (float) – The fraction of every batch (between 0 and 1) to apply DrQv2 augmentation to, if drqv2_aug is True. Defaults to 0.75 (75% of the batch).

  • obs_key (str) – The key in the observation space that contains the image. Defaults to “observation”, which is the default created by AMAGOEnv when the original observation space is not a dict.
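DrQ-v2 describes its augmentation as a random shift: pad each side of the image by 4 pixels by repeating the boundary pixels, then take a random crop of the original size. The pure-Python sketch below shows that idea on tiny 2D images and applies it to a fraction of a batch. It is not AMAGO's implementation: the helper names are hypothetical, augmenting the first aug_pct_of_batch fraction (rather than a random subset) is an assumption, and details such as interpolation are omitted.

```python
import random


def random_shift(img: list[list[float]], pad: int, rng: random.Random) -> list[list[float]]:
    """Pad a 2D image by replicating its border, then crop back to the original size at a random offset."""
    h, w = len(img), len(img[0])
    padded = [
        [img[min(max(r - pad, 0), h - 1)][min(max(c - pad, 0), w - 1)] for c in range(w + 2 * pad)]
        for r in range(h + 2 * pad)
    ]
    dy, dx = rng.randint(0, 2 * pad), rng.randint(0, 2 * pad)
    return [row[dx:dx + w] for row in padded[dy:dy + h]]


def augment_batch(batch, aug_pct_of_batch: float = 0.75, pad: int = 4, seed: int = 0):
    """Shift the first aug_pct_of_batch fraction of the batch (an assumption); leave the rest unchanged."""
    rng = random.Random(seed)
    n_aug = int(len(batch) * aug_pct_of_batch)
    return [random_shift(img, pad, rng) if i < n_aug else img for i, img in enumerate(batch)]


# Four 8x8 "images"; with aug_pct_of_batch=0.75, the first 3 are shifted.
batch = [[[float(r * 8 + c) for c in range(8)] for r in range(8)] for _ in range(4)]
out = augment_batch(batch, aug_pct_of_batch=0.75)
print(len(out), len(out[0]), len(out[0][0]))  # 4 8 8
```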


Following our rule, obs_space and rl2_space are going to be determined for us, but nothing will try to set cnn_type, so it will default to NatureishCNN. The IMPALAishCNN looks like this:


class IMPALAishCNN(img_shape, channels_first, activation, cnn_block_depths=[16, 32, 32], post_group_norm=True)

Bases: CNN

CNN architecture from IMPALA.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).

  • channels_first (bool) – Whether the image is in channels-first format.

  • activation (str) – Activation function to use. See amago.nets.utils.activation_switch.

  • cnn_block_depths (list[int]) – List of ints representing the number of output channels for each convolutional block. Length defines the number of residual blocks. Defaults to [16, 32, 32].

  • post_group_norm (bool) – Whether to use group normalization after each convolutional block. Defaults to True.

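So we can change cnn_block_depths and post_group_norm with gin, but this is not the place to change the activation: it is a positional arg, so it is supplied by whatever constructs the CNN (here CNNTstepEncoder, which has its own activation kwarg). The most organized way to set gin values is with .gin config files, but we can also do it directly in Python: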


from amago import Experiment
from amago.nets.cnn import IMPALAishCNN
from amago.nets.tstep_encoders import CNNTstepEncoder
from amago.cli_utils import use_config

config = {
    "amago.nets.tstep_encoders.CNNTstepEncoder.cnn_type" : IMPALAishCNN,
    "IMPALAishCNN.cnn_block_depths" : [32, 64, 64],
}
# changes the default values
use_config(config)

experiment = Experiment(
    tstep_encoder_type=CNNTstepEncoder,
    # ... other Experiment kwargs
)
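Binding keys have the form [module path.]ConfigurableName.parameter, which is why both the fully qualified "amago.nets.tstep_encoders.CNNTstepEncoder.cnn_type" and the short "IMPALAishCNN.cnn_block_depths" work (gin accepts the short form as long as the name is unambiguous). The hypothetical helper below just splits a key into those parts to make the structure explicit; it is not how gin parses keys and it ignores gin scopes. It also checks that [32, 64, 64] is exactly twice the default depths.

```python
def split_binding_key(key: str) -> tuple[str, str, str]:
    """Split 'pkg.module.Name.param' into (module_path, name, param).

    Hypothetical illustration only; ignores gin scopes such as 'scope/Name.param'.
    """
    *module_parts, name, param = key.split(".")
    return ".".join(module_parts), name, param


print(split_binding_key("amago.nets.tstep_encoders.CNNTstepEncoder.cnn_type"))
# ('amago.nets.tstep_encoders', 'CNNTstepEncoder', 'cnn_type')
print(split_binding_key("IMPALAishCNN.cnn_block_depths"))
# ('', 'IMPALAishCNN', 'cnn_block_depths')

# The example doubles every default IMPALA block depth (3 blocks either way):
default_depths = [16, 32, 32]
print([2 * d for d in default_depths])  # [32, 64, 64]
```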

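As a more complicated example, let’s say we want to use a TformerTrajEncoder with 6 layers of dimension 512, 16 heads, and sliding window attention with a window size of 128. Since d_ff defaults to a fixed 1024 (4 * the default d_model of 256), we also raise it to 2048 to keep the same ratio.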


class TformerTrajEncoder(tstep_dim, max_seq_len, d_model=256, n_heads=8, d_ff=1024, n_layers=3, dropout_ff=0.05, dropout_emb=0.05, dropout_attn=0.0, dropout_qkv=0.0, activation='leaky_relu', norm='layer', pos_emb='fixed', sigma_reparam=True, normformer_norms=True, head_scaling=True, attention_type=<class 'amago.nets.transformer.FlashAttention'>)

Bases: TrajEncoder

Transformer Trajectory Encoder.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A pre-norm Transformer decoder-only model that processes sequences of timestep embeddings.

Parameters:
  • tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.

  • max_seq_len (int) – Maximum sequence length of the model, i.e., its max context length during training.

  • d_model (int) – Dimension of the main residual stream and output. Defaults to 256.

  • n_heads (int) – Number of self-attention heads. Each head has dimension d_model/n_heads. Defaults to 8.

  • d_ff (int) – Dimension of feed-forward network in residual blocks. Defaults to 1024 (4 * the default d_model); it is not rescaled automatically when d_model changes.

  • n_layers (int) – Number of Transformer layers. Defaults to 3.

  • dropout_ff (float) – Dropout rate for linear layers within Transformer. Defaults to 0.05.

  • dropout_emb (float) – Dropout rate for input embedding (combined input sequence and position embeddings passed to Transformer). Defaults to 0.05.

  • dropout_attn (float) – Dropout rate for attention matrix. Defaults to 0.00.

  • dropout_qkv (float) – Dropout rate for query/key/value projections. Defaults to 0.00.

  • activation (str) – Activation function. Defaults to “leaky_relu”.

  • norm (str) – Normalization function. Defaults to “layer” (LayerNorm).

  • pos_emb (str) – Position embedding type. “fixed” (default) uses sinusoidal embeddings, “learned” uses trainable embeddings per timestep.

  • causal – Whether to use a causal attention mask. Defaults to True.

  • sigma_reparam (bool) – Whether to use σ-reparam feed-forward layers from https://arxiv.org/abs/2303.06296. Defaults to True.

  • normformer_norms (bool) – Whether to use extra norm layers from NormFormer (https://arxiv.org/abs/2110.09456). The model is always a pre-norm Transformer regardless of this setting. Defaults to True.

  • head_scaling (bool) – Whether to use head scaling from NormFormer. Defaults to True.

  • attention_type (type[SelfAttention]) – Attention layer type. Defaults to transformer.FlashAttention. transformer.VanillaAttention provided as backup. New types can inherit from transformer.SelfAttention.
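The head and feed-forward sizes for the model in this example can be checked with a little arithmetic. The helper below is hypothetical (not part of AMAGO); it uses the documented head dimension d_model / n_heads and the 4 * d_model feed-forward ratio implied by the defaults, which is a convention rather than a requirement.

```python
def tformer_dims(d_model: int, n_heads: int, ff_mult: int = 4) -> dict:
    """Derived sizes for a Transformer layer (hypothetical helper, not part of AMAGO)."""
    if d_model % n_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    return {"head_dim": d_model // n_heads, "d_ff": ff_mult * d_model}


print(tformer_dims(256, 8))    # defaults: {'head_dim': 32, 'd_ff': 1024}
print(tformer_dims(512, 16))   # this example: {'head_dim': 32, 'd_ff': 2048}
```

Doubling both d_model and n_heads keeps each head at 32 dimensions, while d_ff must be set explicitly to match.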

class SlidingWindowFlexAttention(causal, dropout, window_size=gin.REQUIRED)

Bases: FlexAttention

A more useful test of FlexAttention that implements a sliding window pattern for long context lengths.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.
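Conceptually, causal sliding window attention lets each timestep attend only to itself and a fixed number of preceding timesteps, so the per-query cost is bounded by the window rather than by the full context length. The pure-Python sketch below builds such a boolean mask for a short sequence. The boundary convention (a window of exactly window_size positions, including the current one) is an assumption and may differ from SlidingWindowFlexAttention.

```python
def sliding_window_causal_mask(seq_len: int, window_size: int) -> list[list[bool]]:
    """mask[q][k] is True when query position q may attend to key position k."""
    return [
        [k <= q and q - k < window_size for k in range(seq_len)]
        for q in range(seq_len)
    ]


for row in sliding_window_causal_mask(seq_len=5, window_size=2):
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xx...
# .xx..
# ..xx.
# ...xx
```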


gin.REQUIRED is reserved for settings that are not commonly used but are so important and task-specific that there is no good default. You’ll get an error if you use one but forget to configure it. The complete configuration for our TformerTrajEncoder example is:

from amago import Experiment
from amago.nets.traj_encoders import TformerTrajEncoder
from amago.nets.transformer import SlidingWindowFlexAttention
from amago.cli_utils import use_config

config = {
    "TformerTrajEncoder.n_layers" : 6,
    "TformerTrajEncoder.n_heads" : 16,
    "TformerTrajEncoder.d_model" : 512,
    "TformerTrajEncoder.d_ff" : 2048,
    "TformerTrajEncoder.attention_type": SlidingWindowFlexAttention,
    "SlidingWindowFlexAttention.window_size" : 128,
}

use_config(config)
experiment = Experiment(
    traj_encoder_type=TformerTrajEncoder,
    # ... other Experiment kwargs
)
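Because window_size is declared as gin.REQUIRED, leaving out the "SlidingWindowFlexAttention.window_size" entry would raise an error instead of silently picking a value. The toy sketch below mimics that fail-fast behavior with a hypothetical REQUIRED sentinel and resolve_kwargs helper; gin's actual mechanism, error type, and message differ.

```python
REQUIRED = object()  # hypothetical sentinel standing in for gin.REQUIRED


def resolve_kwargs(name: str, defaults: dict, bindings: dict) -> dict:
    """Merge user bindings over defaults, refusing to proceed if a REQUIRED value is unset."""
    resolved = {**defaults, **bindings}
    missing = [k for k, v in resolved.items() if v is REQUIRED]
    if missing:
        raise ValueError(f"{name}: required parameter(s) not configured: {missing}")
    return resolved


defaults = {"window_size": REQUIRED}
print(resolve_kwargs("SlidingWindowFlexAttention", defaults, {"window_size": 128}))
# {'window_size': 128}
try:
    resolve_kwargs("SlidingWindowFlexAttention", defaults, {})
except ValueError as err:
    print(err)  # reports that window_size was never configured
```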

Customizing the built-in TstepEncoder, TrajEncoder, Agent, and ExplorationWrapper is so common that amago.cli_utils provides shortcut functions that fill in the config dict passed to use_config. For example, we could have made the changes from all the previous examples at once with:

from amago import Experiment
from amago.cli_utils import switch_traj_encoder, switch_tstep_encoder, switch_agent, switch_exploration, use_config
from amago.nets.transformer import SlidingWindowFlexAttention
from amago.nets.cnn import IMPALAishCNN

config = {
    # these are niche changes customized a level below the `TstepEncoder` / `TrajEncoder`, so we still have to specify them
    "amago.nets.transformer.SlidingWindowFlexAttention.window_size" : 128,
    "amago.nets.cnn.IMPALAishCNN.cnn_block_depths" : [32, 64, 64],
}
tstep_encoder_type = switch_tstep_encoder(config, arch="cnn", cnn_type=IMPALAishCNN)
traj_encoder_type = switch_traj_encoder(config, arch="transformer", d_model=512, d_ff=2048, n_heads=16, n_layers=6, attention_type=SlidingWindowFlexAttention)
exploration_wrapper_type = switch_exploration(config, strategy="egreedy", eps_start=1.0, eps_end=.01, steps_anneal=200_000)
# also customize a few RL details as an example
agent_type = switch_agent(config, agent="multitask", num_critics=6, gamma=.998)
use_config(config)

experiment = Experiment(
    tstep_encoder_type=tstep_encoder_type,
    traj_encoder_type=traj_encoder_type,
    agent_type=agent_type,
    exploration_wrapper_type=exploration_wrapper_type,
    # ... other Experiment kwargs
)
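Judging from the calls above, each switch_* helper appears to record its keyword arguments as bindings in config and return the class to pass to Experiment, so that a single use_config(config) applies everything. The sketch below is a hypothetical, simplified version of that pattern (switch_traj_encoder_sketch is not part of AMAGO, and the real helpers may do more); it uses class names as strings so it runs without AMAGO installed.

```python
def switch_traj_encoder_sketch(config: dict, arch: str, **kwargs) -> str:
    """Hypothetical stand-in for a switch_* helper: write kwargs into config, return the class."""
    classes = {"transformer": "amago.nets.traj_encoders.TformerTrajEncoder"}
    qualified_name = classes[arch]
    for param, value in kwargs.items():
        # Each kwarg becomes a fully qualified binding such as "...TformerTrajEncoder.d_model".
        config[f"{qualified_name}.{param}"] = value
    return qualified_name.rsplit(".", 1)[-1]


config = {"amago.nets.transformer.SlidingWindowFlexAttention.window_size": 128}
traj_encoder_type = switch_traj_encoder_sketch(config, arch="transformer", d_model=512, n_heads=16)
print(traj_encoder_type)  # TformerTrajEncoder
for key, value in config.items():
    print(key, "=", value)
```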

If we want to combine hardcoded changes like these with real .gin files, use_config() also accepts a list of file paths:

# These changes are applied in order from left to right. If we override the same param
# in multiple configs, the final one counts. Making gin this complicated is usually a bad idea.
use_config(config, gin_configs=["environment_config.gin", "rl_config.gin"])
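The "final one counts" rule for a repeated parameter behaves like merging dictionaries in order. The sketch below illustrates it with plain dicts standing in for the contents of the two example .gin files; the values are made up, and it does not model gin's parsing or the order in which use_config applies the dict versus the files.

```python
def merge_in_order(*sources: dict) -> dict:
    """Later sources override earlier ones for any repeated key."""
    merged = {}
    for source in sources:
        merged.update(source)
    return merged


environment_config = {"TformerTrajEncoder.d_model": 256, "CNNTstepEncoder.d_output": 128}
rl_config = {"TformerTrajEncoder.d_model": 512}
print(merge_in_order(environment_config, rl_config))
# {'TformerTrajEncoder.d_model': 512, 'CNNTstepEncoder.d_output': 128}
```

Swapping the order of the two sources would leave d_model at 256, which is why silently overriding the same param across files is easy to get wrong.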

Tip

You can view your active gin config (all the active hyperparameters used by an experiment) in the checkpoint directory as config.txt, or on wandb in the Config section.

A full API reference is available in the API Reference section.