POPGym

POPGym

  1from argparse import ArgumentParser
  2
  3import wandb
  4
  5from amago.envs.builtin.popgym_envs import POPGymAMAGO, MultiDomainPOPGymAMAGO
  6from amago.cli_utils import *
  7
  8
  9def add_cli(parser):
 10    parser.add_argument("--env", type=str, default="AutoencodeEasy")
 11    parser.add_argument("--max_seq_len", type=int, default=2000)
 12    parser.add_argument(
 13        "--multidomain",
 14        action="store_true",
 15        help="Activate 'MultiDomain' POPGym, where agents play 27 POPGym games at the same time in 1-shot format (2 episodes, second one counts).",
 16    )
 17    return parser
 18
 19
 20if __name__ == "__main__":
 21    parser = ArgumentParser()
 22    add_common_cli(parser)
 23    add_cli(parser)
 24    args = parser.parse_args()
 25
 26    # whenever we need a "max rollout length" value, we use this arbitrarily large number
 27    artificial_horizon = max(args.max_seq_len, 2000)
 28
 29    # fmt: off
 30    config = {
 31        "amago.nets.actor_critic.NCriticsTwoHot.min_return": None,
 32        "amago.nets.actor_critic.NCriticsTwoHot.max_return": None,
 33        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
 34        "binary_filter.threshold": 1e-3, # not important
 35        # learnable position embedding
 36        "amago.nets.transformer.LearnablePosEmb.max_time_idx": artificial_horizon,
 37        "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "learnable",
 38        "amago.nets.traj_encoders.TformerTrajEncoder.attention_type": amago.nets.transformer.FlashAttention,
 39        "amago.nets.policy_dists.Discrete.clip_prob_high": 1.0, # not important
 40        "amago.nets.policy_dists.Discrete.clip_prob_low": 1e-6, # not important
 41        # paper version defaulted to large set of gamma values
 42        "amago.agent.Multigammas.discrete": [0.1, 0.7, 0.9, 0.93, 0.95, 0.98, 0.99, 0.992, 0.994, 0.995, 0.997, 0.998, 0.999, 0.9991, 0.9992, 0.9993, 0.9994, 0.9995],
 43    }
 44    # fmt: on
 45
 46    traj_encoder_type = switch_traj_encoder(
 47        config,
 48        arch=args.traj_encoder,
 49        memory_size=args.memory_size,  # paper: 256
 50        layers=args.memory_layers,  # paper: 3
 51    )
 52    tstep_encoder_type = switch_tstep_encoder(
 53        config,
 54        arch="ff",
 55        n_layers=2,
 56        d_hidden=512,
 57        d_output=200,
 58    )
 59    agent_type = switch_agent(
 60        config,
 61        args.agent_type,
 62        reward_multiplier=200.0 if args.multidomain else 100.0,
 63        tau=0.0025,
 64    )
 65    # steps_anneal can safely be set much lower (<500k) in most tasks. More sweeps needed.
 66    exploration_type = switch_exploration(
 67        config,
 68        "egreedy",
 69        steps_anneal=1_000_000,
 70    )
 71    use_config(config, args.configs)
 72
 73    group_name = f"{args.run_name}_{args.env}"
 74    for trial in range(args.trials):
 75        run_name = group_name + f"_trial_{trial}"
 76        if args.multidomain:
 77            make_train_env = lambda: MultiDomainPOPGymAMAGO()
 78        else:
 79            # in order to match the pre-gymnasium version of popgym (done instead of terminated/truncated),
 80            # we need to set terminated = terminated or truncated
 81            make_train_env = lambda: POPGymAMAGO(
 82                f"popgym-{args.env}-v0", truncated_is_done=True
 83            )
 84        experiment = create_experiment_from_cli(
 85            args,
 86            make_train_env=make_train_env,
 87            make_val_env=make_train_env,
 88            max_seq_len=args.max_seq_len,
 89            traj_save_len=artificial_horizon,
 90            group_name=group_name,
 91            run_name=run_name,
 92            tstep_encoder_type=tstep_encoder_type,
 93            traj_encoder_type=traj_encoder_type,
 94            exploration_wrapper_type=exploration_type,
 95            agent_type=agent_type,
 96            val_timesteps_per_epoch=artificial_horizon,
 97            learning_rate=1e-4,
 98            grad_clip=1.0,
 99            lr_warmup_steps=2000,
100        )
101        experiment = switch_async_mode(experiment, args.mode)
102        experiment.start()
103        if args.ckpt is not None:
104            experiment.load_checkpoint(args.ckpt)
105        experiment.learn()
106        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
107        experiment.delete_buffer_from_disk()
108        wandb.finish()