amago.experiment#

Start and launch training runs (main Experiment).

Classes

Experiment(run_name, ckpt_base_dir, ...[, ...])

Build, train, and evaluate an Agent.

class Experiment(run_name, ckpt_base_dir, max_seq_len, dataset, tstep_encoder_type, traj_encoder_type, agent_type, val_timesteps_per_epoch, make_train_env, make_val_env, parallel_actors=12, env_mode='async', async_env_mp_context=None, exploration_wrapper_type=<class 'amago.envs.exploration.EpsilonGreedy'>, sample_actions=True, force_reset_train_envs_every=None, log_to_wandb=False, wandb_project="os.environ['AMAGO_WANDB_PROJECT']", wandb_entity="os.environ['AMAGO_WANDB_ENTITY']", wandb_group_name=None, verbose=True, log_interval=300, traj_save_len=10000000000.0, has_dset_edit_rights=True, stagger_traj_file_lengths=True, save_trajs_as='npz', padded_sampling='none', dloader_workers=6, epochs=1000, start_learning_at_epoch=0, start_collecting_at_epoch=0, train_timesteps_per_epoch=1000, train_batches_per_epoch=1000, val_interval=20, ckpt_interval=50, always_save_latest=True, always_load_latest=False, batch_size=24, batches_per_update=1, learning_rate=0.0001, critic_loss_weight=10.0, lr_warmup_steps=500, grad_clip=1.0, l2_coeff=0.001, mixed_precision='no')[source]#

Bases: object

Build, train, and evaluate an Agent.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.
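
For a concrete picture, a minimal sketch of overriding defaults with gin before building the Experiment (this assumes the configurable is registered under its class name; the specific bindings shown are illustrative):

    import gin

    # Bind default kwarg values before constructing the Experiment.
    # Explicitly passed constructor arguments still take priority over gin.
    gin.parse_config(
        [
            "Experiment.batch_size = 32",
            "Experiment.learning_rate = 3e-4",
        ]
    )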

Required

Parameters:
  • run_name (str) – Name of the experiment. Used to create checkpoint and log directories.

  • ckpt_base_dir (str) – Base directory to store checkpoints and logs. Checkpoints are saved to ckpt_base_dir/run_name.

  • max_seq_len (int) – Maximum sequence length for training. Determines effective batch size (Batch Size × Sequence Length).

  • dataset (RLDataset) – RLDataset for loading training sequences.

  • tstep_encoder_type (type[TstepEncoder]) – a type of TstepEncoder (will be created with default kwargs — edit via gin).

  • traj_encoder_type (type[TrajEncoder]) – a type of TrajEncoder (will be created with default kwargs — edit via gin).

  • agent_type (type[Agent]) – a type of Agent (will be created with default kwargs — edit via gin).

  • make_train_env (callable | Iterable[callable]) – A callable returning an AMAGOEnv, or a list of such callables. A single callable is repeated parallel_actors times; a list manually assigns one callable per parallel actor.

  • make_val_env (callable | Iterable[callable]) – Like make_train_env, but only used for evaluation (trajectories never saved).

  • val_timesteps_per_epoch (int) – Number of steps per parallel environment for evaluation. Determines metric sample size. Should be enough time for at least one episode to finish per actor.
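
Putting the required arguments together, a hedged construction sketch (the import paths for the dataset, encoder, and agent types are assumptions based on the wider amago package, and my_dataset / make_env are user-defined placeholders; make_env is sketched after the Environment parameters below):

    import amago
    from amago.agent import Agent                             # assumed import paths;
    from amago.nets.traj_encoders import TformerTrajEncoder   # check amago.agent and
    from amago.nets.tstep_encoders import FFTstepEncoder      # amago.nets docs

    experiment = amago.Experiment(
        run_name="my_run",
        ckpt_base_dir="./ckpts",
        max_seq_len=128,
        dataset=my_dataset,  # any RLDataset, e.g. a DiskTrajDataset instance
        tstep_encoder_type=FFTstepEncoder,
        traj_encoder_type=TformerTrajEncoder,
        agent_type=Agent,
        make_train_env=make_env,  # see the environment callable sketch below
        make_val_env=make_env,
        val_timesteps_per_epoch=1000,
    )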

Environment

Parameters:
  • parallel_actors (int) – Number of parallel envs for batched inference. Default: 12.

  • env_mode (str) – "async" (default) wraps the envs in an async vectorized pool; "already_vectorized" is for jax/GPU-batched envs that handle their own vectorization; "sync" steps envs sequentially for debugging. Default: "async".

  • exploration_wrapper_type (type[ExplorationWrapper] | None) – Exploration wrapper for training envs. Default: EpsilonGreedy.

  • sample_actions (bool) – Whether to sample from stochastic actor during eval, or take argmax/mean. Default: True.

  • force_reset_train_envs_every (int | None) – If set, forces call to reset every N epochs for already_vectorized envs. Default: None.

  • async_env_mp_context (str | None) – Multiprocessing spawn method for AsyncVectorEnv (e.g., "forkserver"). Only relevant for env_mode="async". Set to None for default method. Default: None.
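
As a reference for make_train_env / make_val_env, a hedged sketch of an environment callable (this assumes AMAGOEnv from amago.envs wraps a gymnasium environment and takes an env_name label; check the amago.envs documentation for the exact signature):

    import gymnasium as gym

    from amago.envs import AMAGOEnv


    def make_env():
        # Each call should build a fresh environment instance. The Experiment
        # repeats this callable parallel_actors times, or accepts a list of
        # callables for manual assignment across actors.
        return AMAGOEnv(gym.make("CartPole-v1"), env_name="CartPole-v1")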

Logging

Parameters:
  • log_to_wandb (bool) – Enable or disable wandb logging. Default: False.

  • wandb_project (str) – wandb project. Default: AMAGO_WANDB_PROJECT env var.

  • wandb_entity (str) – wandb entity (username/team). Default: AMAGO_WANDB_ENTITY env var.

  • wandb_group_name (str | None) – Group runs on wandb dashboard. Default: None.

  • verbose (bool) – Print tqdm bars and info to console. Default: True.

  • log_interval (int) – Log extra metrics every N batches. Default: 300.

  • padded_sampling (str) – Padding for sampling training subsequences. “none”, “left”, “right”, “both”. Default: “none”.

  • dloader_workers (int) – Number of DataLoader workers used to load trajectories from disk. Increase when loading large or compressed trajectory files. Default: 6.

Note

The parameters below are only relevant when doing online data collection. They determine how parallel environments write finished trajectories to disk. The DiskTrajDataset reads these files for training.

Parameters:
  • traj_save_len (int) – Save a trajectory file when an episode ends or after this many timesteps, whichever comes first. Large values (the default) save whole episodes. Default: 1e10.

  • has_dset_edit_rights (bool) – Turn off for collect-only runs where another process manages the replay buffer. Default: True.

  • stagger_traj_file_lengths (bool) – Randomize file lengths when traj_save_len is set to save short trajectory snippets. Default: True.

  • save_trajs_as (str) – Format for saved trajectories. “npz”, “npz-compressed”, or “traj”. Default: “npz”.

Learning Schedule

Parameters:
  • epochs (int) – Total number of epochs, where each epoch is one round of data collection followed by one round of training. Default: 1000.

  • start_learning_at_epoch (int) – Number of epochs to skip before gradient updates (for replay buffer warmup). Default: 0.

  • start_collecting_at_epoch (int) – Number of epochs to skip data collection (for offline→online finetune or full offline). Default: 0.

  • train_timesteps_per_epoch (int) – Number of steps in each parallel env per epoch. Default: 1000.

  • train_batches_per_epoch (int) – Number of training batches per epoch. Default: 1000.

  • val_interval (int | None) – How many epochs between evaluation rollouts. Default: 20.

  • ckpt_interval (int | None) – How many epochs between saving checkpoints. Default: 50.

  • always_save_latest (bool) – Whether to always save the latest weights (for distributed usage). Default: True.

  • always_load_latest (bool) – Whether to always load the latest weights (for distributed usage). Default: False.

Optimization

Parameters:
  • batch_size (int) – Batch size per GPU (in sequences). Default: 24.

  • batches_per_update (int) – Number of batches to accumulate gradients over before optimizer update. Default: 1.

  • learning_rate (float) – Learning rate for the optimizer (AdamW by default; see init_optimizer). Default: 1e-4.

  • critic_loss_weight (float) – Weight applied to the critic loss relative to the actor loss when updating shared (encoder) parameters. Default: 10.

  • lr_warmup_steps (int) – Number of warmup steps for learning rate scheduler. Default: 500.

  • grad_clip (float) – Gradient norm clipping value. Default: 1.0.

  • l2_coeff (float) – L2 regularization coefficient (AdamW). Default: 1e-3.

  • mixed_precision (str) – Mixed precision mode for accelerate (“no”, “fp16”, “bf16”). Default: “no”.
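
As a quick reference for how these settings interact, a small sketch of the per-process batch size arithmetic (global totals scale with the number of accelerate processes):

    batch_size = 24         # sequences per forward pass (per GPU)
    batches_per_update = 1  # gradient accumulation steps
    max_seq_len = 128       # timesteps per sequence (constructor argument)

    sequences_per_update = batch_size * batches_per_update  # 24
    timesteps_per_forward = batch_size * max_seq_len         # 3072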

property DEVICE#

Return the device (cpu/gpu) that the experiment is running on.

agent_type: type[Agent]#
always_load_latest: bool = False#
always_save_latest: bool = True#
async_env_mp_context: str | None = None#
batch_size: int = 24#
batches_per_update: int = 1#
caster()[source]#

Get the context manager for mixed precision training.

ckpt_base_dir: str#
ckpt_interval: int | None = 50#
collect_new_training_data()[source]#

Generate train_timesteps_per_epoch * parallel_actors timesteps of new environment interaction, which is saved to the replay buffer when the rollouts finish.

Return type:

None

compute_loss(batch, log_step)[source]#

Core computation of the actor and critic RL loss terms from a Batch of data.

Parameters:
  • batch (Batch) – The batch of data.

  • log_step (bool) – Whether to compute extra metrics for wandb logging.

Returns:

Loss terms and any logging metrics. “Actor Loss”, “Critic Loss”, “Sequence Length”, “Batch Size (in Timesteps)”, and “Unmasked Batch Size (in Timesteps)” are always provided. Additional keys are determined by what is logged in the Agent.forward method.

Return type:

dict

critic_loss_weight: float = 10.0#
dataset: RLDataset#
delete_buffer_from_disk()[source]#

Clear the replay buffer from disk (mainly for examples/).

Calls self.dataset.delete() if the current process is the main process.

Return type:

None

dloader_workers: int = 6#
edit_actor_mask(batch, actor_loss, pad_mask)[source]#

Customize the actor loss mask.

Parameters:
  • batch (Batch) – The batch of data.

  • actor_loss (FloatTensor) – The unmasked actor loss. Shape: (Batch, Length, Num Gammas, 1)

  • pad_mask (BoolTensor) – The default mask. True where the sequence was not padded by the dataloader.

Return type:

BoolTensor

Returns:

The mask. True where the actor loss should count, False where it should be ignored.
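
For example, a hedged sketch of a subclass that ignores the actor loss during the first few timesteps of every sequence (the burn-in idea and the burn_in attribute are illustrative, not part of the library); edit_critic_mask can be overridden the same way, with its extra Num Critics dimension:

    from amago import Experiment  # assuming the top-level export


    class BurnInExperiment(Experiment):
        burn_in: int = 8  # hypothetical extra setting

        def edit_actor_mask(self, batch, actor_loss, pad_mask):
            # pad_mask is True where the timestep is real (not padding).
            mask = pad_mask.clone()
            mask[:, : self.burn_in] = False  # drop the first burn_in steps
            return mask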

edit_critic_mask(batch, critic_loss, pad_mask)[source]#

Customize the critic loss mask.

Parameters:
  • batch (Batch) – The batch of data.

  • critic_loss (FloatTensor) – The unmasked critic loss. Shape: (Batch, Length, Num Critics, Num Gammas, 1)

  • pad_mask (BoolTensor) – The default mask. True where the sequence was not padded by the dataloader.

Return type:

BoolTensor

Returns:

The mask. True where the critic loss should count, False where it should be ignored.

env_mode: str = 'async'#
epochs: int = 1000#
evaluate_test(make_test_env, timesteps, render=False, save_trajs_to=None, episodes=None)[source]#

One-off evaluation of a new environment callable for testing.

Parameters:
  • make_test_env (callable | Iterable[callable]) – A callable or iterable of callables that make and return a test environment. If an iterable, it must be of length Experiment.parallel_actors.

  • timesteps (int) – The number of timesteps to interact with each environment.

  • render (bool) – Whether to render the environment. Defaults to False.

  • save_trajs_to (str | None) – The directory to save trajectories. Useful when using evaluate_test to gather demonstration data for another run. If None, no data is saved. Defaults to None.

  • episodes (int | None) – The number of episodes to interact with the environment. If provided, the loop will terminate after this many episodes have been completed OR we hit the timesteps limit, whichever comes first. Defaults to None.

Returns:

A dictionary of evaluation metrics.

Return type:

dict[str, float]
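
A hedged usage sketch (make_harder_env is a hypothetical callable that returns an AMAGOEnv, like the make_env sketch above):

    # One-off evaluation on a new environment, optionally saving rollouts
    # for use as demonstration data in another run.
    metrics = experiment.evaluate_test(
        make_test_env=make_harder_env,
        timesteps=2_000,
        episodes=10,         # stop once 10 episodes finish (or timesteps run out)
        save_trajs_to=None,  # or a directory path to save the trajectories
    )
    print(metrics)           # dict[str, float]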

evaluate_val()[source]#

Evaluate the current policy without exploration noise on the validation environments, and average the performance metrics across accelerate processes.

Return type:

None

exploration_wrapper_type#

alias of EpsilonGreedy

force_reset_train_envs_every: int | None = None#
grad_clip: float = 1.0#
has_dset_edit_rights: bool = True#
init_checkpoints()[source]#

Create the ckpts/training_states, ckpts/policy_weights, and ckpts/latest directories.

Return type:

None

init_dloaders()[source]#

Create pytorch dataloaders to batch trajectories in parallel.

Return type:

DataLoader

init_dsets()[source]#

Modifies the provided RLDataset (in place) to use settings configured by the experiment.

Return type:

RLDataset

init_envs()[source]#

Construct parallel training and validation environments.

Returns:

A description of the environment setup, which is printed to the console when Experiment.verbose is True.

Return type:

str

init_logger()[source]#

Configure log dir and wandb compatibility.

Return type:

None

init_model()[source]#

Build an initial policy based on observation shapes

Return type:

None

init_optimizer(policy)[source]#

Defines the optimizer.

Override to switch from AdamW.

Return type:

Optimizer

Returns:

The torch.optim.Optimizer in charge of updating the Agent’s parameters (Agent.trainable_params).
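
A hedged sketch of switching optimizers by overriding this method (it assumes Agent.trainable_params yields an iterable of parameters; adjust to the actual Agent API):

    import torch

    from amago import Experiment  # assuming the top-level export


    class AdamExperiment(Experiment):
        def init_optimizer(self, policy):
            # Reuse the Experiment's learning_rate but swap AdamW for Adam.
            params = policy.trainable_params  # assumed iterable of parameters
            return torch.optim.Adam(params, lr=self.learning_rate)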

interact(envs, timesteps, hidden_state=None, render=False, save_on_done=False, episodes=None)[source]#

Main policy loop for interacting with the environment.

Parameters:
  • envs – The (parallel) environments to interact with.

  • timesteps (int) – The number of timesteps to interact with each environment.

  • hidden_state – The hidden state of the policy. If None, a fresh hidden state is initialized. Defaults to None.

  • render (bool) – Whether to render the environment. Defaults to False.

  • save_on_done (bool) – If True, save completed trajectory sequences to disk as soon as they are finished. If False, wait until all rollouts are completed. Only applicable if the provided envs are configured to save rollouts to disk. Defaults to False.

  • episodes (int | None) – The number of episodes to interact with the environment. If provided, the loop will terminate after this many episodes have been completed OR we hit the timesteps limit, whichever comes first. Defaults to None.

Returns:

Objects that keep track of standard eval stats (average returns) and any additional eval metrics the envs have been configured to record.

Return type:

tuple[ReturnHistory, SpecialMetricHistory]

l2_coeff: float = 0.001#
learn()[source]#

Main training loop for the experiment.

Return type:

None

For every epoch, we:
  1. Load the latest policy checkpoint if always_load_latest is True.

  2. Evaluate the policy on the validation set if val_interval is not None and the current epoch is divisible by val_interval.

  3. Collect new training data if train_timesteps_per_epoch is not None and the current epoch >= start_collecting_at_epoch.

  4. Train the policy on the training data for train_batches_per_epoch batches if self.dataset.ready_for_training is True.

  5. Save a policy checkpoint if ckpt_interval is not None and the current epoch is divisible by ckpt_interval.

  6. Write the latest policy checkpoint if always_save_latest is True.

An Experiment can be configured so that each process does some or all of the above. For example, an actor process might only do steps 1, 2, and 3, while a learner process might only do steps 4, 5, and 6.
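
One plausible way to split those steps across processes, sketched only from the kwargs documented above (the exact recipe for collect-only and learn-only runs may differ, so treat these settings as assumptions and check the project's examples):

    # Collect-only ("actor") process: steps 1-3. Reads the learner's latest
    # weights and never takes gradient updates or edits the buffer.
    actor_kwargs = dict(
        start_learning_at_epoch=10**9,  # effectively: never train
        has_dset_edit_rights=False,     # another process owns the replay buffer
        always_load_latest=True,
        always_save_latest=False,
    )

    # Learn-only ("learner") process: steps 4-6. Never collects, and publishes
    # its latest weights for the actor to load.
    learner_kwargs = dict(
        start_collecting_at_epoch=10**9,  # effectively: never collect
        always_save_latest=True,
        always_load_latest=False,
    )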

learning_rate: float = 0.0001#
load_checkpoint(epoch, resume_training_state=True)[source]#

Load a historical checkpoint from the ckpts directory of this experiment.

Parameters:
  • epoch (int) – The epoch number of the checkpoint to load.

  • resume_training_state (bool) – Whether to resume the entire training process (True) or only the policy weights (False). Defaults to True.

Return type:

None

load_checkpoint_from_path(path, is_accelerate_state=True)[source]#

Load a checkpoint from a given path.

Parameters:
  • path (str) – Full path to the checkpoint file to load.

  • is_accelerate_state (bool) – Whether the checkpoint is a full accelerate state (True) or pytorch weights only (False). Defaults to True.

Return type:

None
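
A short usage sketch of the two loading paths (the epoch number and file path are placeholders):

    # Resume the full training state this experiment saved at epoch 200:
    experiment.load_checkpoint(epoch=200, resume_training_state=True)

    # Or load only policy weights from an arbitrary file (not an accelerate state):
    experiment.load_checkpoint_from_path(
        "/path/to/policy_weights.pt", is_accelerate_state=False
    )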

log(metrics_dict, key)[source]#

Log a dict of metrics to the key panel of the wandb console alongside the current x-axis metrics.

Return type:

None

log_interval: int = 300#
log_to_wandb: bool = False#
lr_warmup_steps: int = 500#
make_train_env: callable | Iterable[callable]#
make_val_env: callable | Iterable[callable]#
max_seq_len: int#
mixed_precision: str = 'no'#
padded_sampling: str = 'none'#
parallel_actors: int = 12#
property policy: Agent#

Returns the current Agent policy free from the accelerator wrapper.

policy_metrics(returns, specials)[source]#

Gather policy performance metrics across parallel environments.

Parameters:
  • returns (Iterable[ReturnHistory]) – The return history loggers from the environments.

  • specials (Iterable[SpecialMetricHistory]) – The special metrics history loggers from the environments.

Returns:

A dictionary of policy performance metrics.

Return type:

dict

read_latest_policy()[source]#

Read the latest policy; used to communicate weight updates between learning and collecting processes.

Return type:

None

run_name: str#
sample_actions: bool = True#
save_checkpoint()[source]#

Save both the training state and the policy weights to the ckpt_dir.

Return type:

None

save_trajs_as: str = 'npz'#
stagger_traj_file_lengths: bool = True#
start()[source]#

Manual initialization after __init__ to give time for gin configuration.

Call this before Experiment.learn().
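
A minimal end-to-end sketch of the intended call order (assuming the experiment was constructed as in the earlier sketch):

    experiment.start()  # finalize setup once gin configuration is done
    experiment.learn()  # run the epoch loop described under learn()

    # Optional follow-ups after training:
    metrics = experiment.evaluate_test(make_env, timesteps=5_000)
    experiment.delete_buffer_from_disk()  # clear the replay buffer (see examples/)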

start_collecting_at_epoch: int = 0#
start_learning_at_epoch: int = 0#
summary(env_summary)[source]#

Print key hparams to the console for reference.

Return type:

None

train_batches_per_epoch: int = 1000#
train_step(batch, log_step)[source]#

Take a single training step on a batch of data.

Parameters:
  • batch (Batch) – The batch of data.

  • log_step (bool) – Whether to compute extra metrics for wandb logging.

Returns:

metrics from the compute_loss method.

Return type:

dict

train_timesteps_per_epoch: int = 1000#
traj_encoder_type: type[TrajEncoder]#
traj_save_len: int = 10000000000.0#
tstep_encoder_type: type[TstepEncoder]#
val_interval: int | None = 20#
val_timesteps_per_epoch: int#
verbose: bool = True#
wandb_entity: str = "os.environ['AMAGO_WANDB_ENTITY']"#
wandb_group_name: str | None = None#
wandb_project: str = "os.environ['AMAGO_WANDB_PROJECT']"#
write_latest_policy()[source]#

Write the absolute latest policy to a hardcoded location used by read_latest_policy.

Return type:

None

x_axis_metrics()[source]#

Get current x-axis metrics for wandb.

Return type:

dict[str, int | float]