amago.cli_utils

Convenience functions that create a generic CLI for Experiment and handle common gin configurations.

These mostly exist to make the examples/ easier to maintain with less boilerplate, and to break up configuration into several smaller steps.

Functions

add_common_cli(parser)

Adds a common CLI for examples and basic training scripts.

create_experiment_from_cli(...[, ...])

A convenience function that assigns Experiment kwargs from add_common_cli() options.

make_experiment_collect_only(experiment)

Modify the experiment to run in collect-only mode.

make_experiment_learn_only(experiment)

Modify the experiment to run in learn-only mode.

switch_agent(config, agent, **kwargs)

Set default kwargs for a built-in Agent without gin syntax or config files.

switch_async_mode(experiment, mode)

Switch the experiment mode between collect, learn, or both.

switch_exploration(config, strategy, **kwargs)

Set default kwargs for a built-in ExplorationWrapper without gin syntax or config files.

switch_traj_encoder(config, arch, ...)

Set default kwargs for a built-in TrajEncoder without gin syntax or config files.

switch_tstep_encoder(config, arch, **kwargs)

Set default kwargs for a built-in TstepEncoder without gin syntax or config files.

use_config(custom_params[, gin_configs, ...])

Bind gin parameters to edit kwarg defaults across the codebase before training begins.

add_common_cli(parser)

Adds a common CLI for examples and basic training scripts.

Parameters:

parser (ArgumentParser) – The argument parser that already contains the application-specific arguments.

Return type:

ArgumentParser

Returns:

The argument parser with common CLI arguments added.
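The intended flow is to build a parser with your own arguments first and then pass it through add_common_cli(). The sketch below imitates that flow with a stand-in function so that it runs without amago installed. The names add_common_cli_sketch and --env_name, and all default values, are hypothetical; only buffer_dir and run_name appear elsewhere on this page, and the real function's flags may differ.

```python
from argparse import ArgumentParser


def add_common_cli_sketch(parser: ArgumentParser) -> ArgumentParser:
    """Hypothetical stand-in for add_common_cli(): adds shared flags and returns the parser."""
    # Flag names mirror the args.buffer_dir / args.run_name attributes mentioned
    # in these docs; the real function's flags and defaults may differ.
    parser.add_argument("--buffer_dir", type=str, default="buffers")
    parser.add_argument("--run_name", type=str, default="demo_run")
    return parser


# 1. Application-specific arguments come first (hypothetical flag).
parser = ArgumentParser()
parser.add_argument("--env_name", type=str, default="my_env")

# 2. The shared options are layered on top of them.
parser = add_common_cli_sketch(parser)

# 3. Parse an explicit argv list so the example does not depend on sys.argv.
args = parser.parse_args(["--run_name", "test_run"])
print(args.env_name, args.buffer_dir, args.run_name)
```

With the real library, the parsed args object would then be handed to create_experiment_from_cli() as command_line_args.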

create_experiment_from_cli(command_line_args, make_train_env, make_val_env, max_seq_len, group_name, run_name, agent_type, tstep_encoder_type, traj_encoder_type, traj_save_len=None, exploration_wrapper_type=<class 'amago.envs.exploration.EpsilonGreedy'>, experiment_type=<class 'amago.experiment.Experiment'>, dataset=None, **extra_experiment_kwargs)

A convenience function that assigns Experiment kwargs from add_common_cli() options.

Parameters:
  • command_line_args – The parsed command line arguments created by cli_utils.add_common_cli().

  • make_train_env (callable) – A callable that makes the training environment.

  • make_val_env (callable) – A callable that makes the validation environment.

  • max_seq_len (int) – The maximum sequence length of the policy during training.

  • group_name (str) – The name of the wandb group to use for logging.

  • run_name (str) – The name of the run for logging & checkpoints.

  • agent_type (type[Agent]) – The type of agent to use. Can be the output of cli_utils.switch_agent().

  • tstep_encoder_type (type[TstepEncoder]) – The type of tstep encoder to use. Can be the output of cli_utils.switch_tstep_encoder().

  • traj_encoder_type (type[TrajEncoder]) – The type of traj encoder to use. Can be the output of cli_utils.switch_traj_encoder().

  • traj_save_len (int | None) – The length of the trajectory to save. If None (the default), a very large number is used, so entire trajectories are saved whenever an episode terminates or is truncated.

  • exploration_wrapper_type (type[ExplorationWrapper]) – The type of exploration wrapper to use. Can be the output of cli_utils.switch_exploration(), but defaults to EpsilonGreedy.

  • experiment_type (type[Experiment]) – The type of experiment to use. Defaults to amago.Experiment.

  • dataset (RLDataset | None) – An optional dataset to use. If not provided, we create a DiskTrajDataset (an online RL replay buffer on disk) in the same directory where the CLI tells us it will save checkpoints ({args.buffer_dir}/{args.run_name}).

  • **extra_experiment_kwargs – Additional keyword arguments to pass to the Experiment constructor.

Return type:

Experiment

Returns:

An Experiment instance.

make_experiment_collect_only(experiment)

Modify the experiment to run in collect-only mode.

Parameters:

experiment (Experiment) – The experiment to modify.

Return type:

Experiment

Returns:

The modified experiment.

make_experiment_learn_only(experiment)

Modify the experiment to run in learn-only mode.

Parameters:

experiment (Experiment) – The experiment to modify.

Return type:

Experiment

Returns:

The modified experiment.

switch_agent(config, agent, **kwargs)

Set default kwargs for a built-in Agent without gin syntax or config files.

Parameters:
  • config (dict) – A dictionary of gin parameters yet to be assigned.

  • agent (str) – A shortcut name for built-in Agents. Options are “agent” (Agent) and “multitask” (MultiTaskAgent).

  • **kwargs – Assign any of the chosen Agent’s default kwargs.

Return type:

type[Agent]

Returns:

A reference to the Agent type that can be passed into the Experiment.

switch_async_mode(experiment, mode)

Switch the experiment mode between collect, learn, or both.

Parameters:
  • experiment (Experiment) – The experiment to modify.

  • mode (str) – The mode to switch to. Options are “collect”, “learn”, or “both”.

Return type:

Experiment

Returns:

The modified experiment.
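One plausible reading of how the three modes relate to make_experiment_collect_only() and make_experiment_learn_only() is a simple dispatch on the mode string. The sketch below is an assumption, not the library's implementation: ToyExperiment, its collects and learns fields, and the *_sketch functions are all hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass
class ToyExperiment:
    """Hypothetical stand-in; the real Experiment tracks this state differently."""
    collects: bool = True
    learns: bool = True


def make_collect_only_sketch(exp: ToyExperiment) -> ToyExperiment:
    # Keep data collection, disable learning updates.
    exp.collects, exp.learns = True, False
    return exp


def make_learn_only_sketch(exp: ToyExperiment) -> ToyExperiment:
    # Keep learning updates, disable data collection.
    exp.collects, exp.learns = False, True
    return exp


def switch_async_mode_sketch(exp: ToyExperiment, mode: str) -> ToyExperiment:
    """Dispatch on the mode string; 'both' is assumed to leave the experiment unchanged."""
    if mode == "collect":
        return make_collect_only_sketch(exp)
    if mode == "learn":
        return make_learn_only_sketch(exp)
    if mode == "both":
        return exp
    raise ValueError(f"unknown mode: {mode!r}")


collector = switch_async_mode_sketch(ToyExperiment(), "collect")
learner = switch_async_mode_sketch(ToyExperiment(), "learn")
print(collector, learner)
```

A single mode argument like this lets one training script be launched several times, with some processes collecting and others learning.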

switch_exploration(config, strategy, **kwargs)

Set default kwargs for a built-in ExplorationWrapper without gin syntax or config files.

Parameters:
  • config (dict) – A dictionary of gin parameters yet to be assigned.

  • strategy (str) – A shortcut name for built-in ExplorationWrappers. Options are “egreedy” (EpsilonGreedy) and “bilevel” (BilevelEpsilonGreedy).

  • **kwargs – Assign any of the chosen ExplorationWrapper’s default kwargs.

Return type:

type[ExplorationWrapper]

Returns:

A reference to the ExplorationWrapper type that can be passed into the Experiment.

switch_traj_encoder(config, arch, memory_size, layers, **kwargs)

Set default kwargs for a built-in TrajEncoder without gin syntax or config files.

Parameters:
  • config (dict) – A dictionary of gin parameters yet to be assigned.

  • arch (str) – A shortcut name for built-in TrajEncoders. Options are “ff” (memory-free residual feed-forward blocks), “rnn” (RNN), “transformer” (Transformer), and “mamba” (Mamba).

  • memory_size (int) – Sets the same conceptual state space dimension across the various architectures. For example, the size of the hidden state in an RNN or d_model in a Transformer.

  • layers (int) – Sets the number of layers in the TrajEncoder.

  • **kwargs – Assign any of the chosen TrajEncoder’s default kwargs.

Return type:

type[TrajEncoder]

Returns:

A reference to the TrajEncoder type that can be passed into the Experiment.
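The switch_* helpers share one pattern: pick a class from a shortcut name, write new kwarg defaults into the config dict, and return the class. The sketch below illustrates that pattern, including how a single memory_size could map to a different kwarg per architecture. It is a toy model under stated assumptions: the Toy* classes, the "ClassName.kwarg" key format, and the kwarg names d_hidden, n_layers, and n_heads are hypothetical, and the real keys and kwargs may differ.

```python
class ToyRNN:
    """Hypothetical stand-in for an RNN TrajEncoder."""


class ToyTransformer:
    """Hypothetical stand-in for a Transformer TrajEncoder."""


# Shortcut name -> (class, name of the kwarg that receives memory_size).
_ARCHS = {
    "rnn": (ToyRNN, "d_hidden"),
    "transformer": (ToyTransformer, "d_model"),
}


def switch_traj_encoder_sketch(config, arch, memory_size, layers, **kwargs):
    """Record new defaults in `config` (mutated in place) and return the chosen class."""
    if arch not in _ARCHS:
        raise ValueError(f"unknown arch: {arch!r}")
    cls, size_kwarg = _ARCHS[arch]
    prefix = cls.__name__
    # The same conceptual size lands on an architecture-specific kwarg.
    config[f"{prefix}.{size_kwarg}"] = memory_size
    config[f"{prefix}.n_layers"] = layers
    # Any extra kwargs become additional default overrides.
    for key, value in kwargs.items():
        config[f"{prefix}.{key}"] = value
    return cls


config = {}
traj_encoder_type = switch_traj_encoder_sketch(
    config, "transformer", memory_size=256, layers=3, n_heads=8
)
print(traj_encoder_type.__name__, config)
```

Because the helper only fills the dict, nothing takes effect until the accumulated config is passed to use_config().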

switch_tstep_encoder(config, arch, **kwargs)

Set default kwargs for a built-in TstepEncoder without gin syntax or config files.

Parameters:
  • config (dict) – A dictionary of gin parameters yet to be assigned.

  • arch (str) – A shortcut name for built-in TstepEncoders. Options are “ff” (generic MLP) and “cnn” (generic CNN).

  • **kwargs – Assign any of the chosen TstepEncoder’s default kwargs.

Return type:

type[TstepEncoder]

Returns:

A reference to the TstepEncoder type that can be passed into the Experiment.

Example

from amago import Experiment, cli_utils

config = {}
# Make the input MLP smaller
tstep_encoder_type = cli_utils.switch_tstep_encoder(
    config, "ff", n_layers=1, d_hidden=128, d_output=128
)
cli_utils.use_config(config)  # set new default parameters
experiment = Experiment(
    ...,  # rest of args
    tstep_encoder_type=tstep_encoder_type,
)

use_config(custom_params, gin_configs=None, finalize=True)

Bind gin parameters to edit kwarg defaults across the codebase before training begins.

Parameters:
  • custom_params (dict) – A dictionary of gin parameters to bind ({param: new_default_value}). It is typically built within the training script or from a few command line args.

  • gin_configs (list[str] | None) – An optional list of .gin configuration files to use. Gin files are the recommended way to handle configs in real projects; the example scripts skip them only for convenience.

  • finalize (bool) – If True, finalize/freeze the gin config to prevent later changes. Defaults to True.

Return type:

None
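The effect of finalize can be pictured with a toy model: bindings accumulate until the store is frozen, after which further changes are rejected. This is only an illustration of the "bind, then freeze" idea, not gin and not the library's code; ToyConfigStore and the ToyEncoder.* keys are hypothetical.

```python
class ToyConfigStore:
    """Toy model of 'bind defaults, then freeze'; not gin itself."""

    def __init__(self):
        self.bindings = {}
        self.finalized = False

    def use_config(self, custom_params, finalize=True):
        # Once frozen, any further attempt to bind parameters is an error.
        if self.finalized:
            raise RuntimeError("config is finalized; no further changes allowed")
        self.bindings.update(custom_params)
        if finalize:
            self.finalized = True


store = ToyConfigStore()
# finalize=False leaves the store open for more bindings.
store.use_config({"ToyEncoder.d_hidden": 128}, finalize=False)
# The default finalize=True freezes it after this call.
store.use_config({"ToyEncoder.n_layers": 1})

try:
    store.use_config({"ToyEncoder.n_layers": 2})
    rejected = False
except RuntimeError:
    rejected = True

print(store.bindings, rejected)
```

In practice this suggests collecting every override from the switch_* helpers into one dict and calling use_config() once, before the Experiment is constructed.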