POPGym

Example script for training AMAGO on POPGym's partially observable memory tasks, either one environment at a time (--env) or the 27-game "MultiDomain" suite (--multidomain).

from argparse import ArgumentParser

import wandb

import amago
from amago.envs.builtin.popgym_envs import POPGymAMAGO, MultiDomainPOPGymAMAGO
from amago import cli_utils


def add_cli(parser):
    parser.add_argument("--env", type=str, default="AutoencodeEasy")
    parser.add_argument("--max_seq_len", type=int, default=2000)
    parser.add_argument(
        "--multidomain",
        action="store_true",
        help="Activate 'MultiDomain' POPGym, where a single agent plays all 27 POPGym games in a 1-shot format (2 episodes, the second one counts).",
    )
    return parser


if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # whenever we need a "max rollout length" value, we use this arbitrarily large number
    artificial_horizon = max(args.max_seq_len, 2000)

    # fmt: off
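    # gin-style overrides: each key is the fully qualified name of a configurable
    # parameter inside amago, applied below by cli_utils.use_config()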
    config = {
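        # distributional "two-hot" critic heads that classify returns into 64 bins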
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
        "binary_filter.threshold": 1e-3, # not important
        # learnable position embedding
        "amago.nets.transformer.LearnablePosEmb.max_time_idx": artificial_horizon,
        "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "learnable",
        "amago.nets.traj_encoders.TformerTrajEncoder.attention_type": amago.nets.transformer.FlashAttention,
        "amago.nets.policy_dists.Discrete.clip_prob_high": 1.0, # not important
        "amago.nets.policy_dists.Discrete.clip_prob_low": 1e-6, # not important
        # paper version defaulted to large set of gamma values
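        # Multigammas trains the agent on each of these extra discount factors in
        # parallel with its primary gamma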
        "amago.agent.Multigammas.discrete": [0.1, 0.7, 0.9, 0.93, 0.95, 0.98, 0.99, 0.992, 0.994, 0.995, 0.997, 0.998, 0.999, 0.9991, 0.9992, 0.9993, 0.9994, 0.9995],
    }
    # fmt: on

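    # sequence model (trajectory encoder) architecture and size come from the shared CLI flags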
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,  # paper: 256
        layers=args.memory_layers,  # paper: 3
    )
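    # timestep encoder: a small feedforward net that embeds each timestep before it reaches the sequence model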
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=512,
        d_output=200,
    )
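    # reward_multiplier rescales POPGym's small rewards before learning; tau is the slow
    # target network update coefficient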
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        reward_multiplier=200.0 if args.multidomain else 100.0,
        tau=0.0025,
    )
    # steps_anneal can safely be set much lower (<500k) on most tasks; more sweeps are needed to confirm.
    exploration_type = cli_utils.switch_exploration(
        config,
        "egreedy",
        steps_anneal=1_000_000,
    )
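    # apply the overrides above, plus any extra config files listed in args.configs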
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_{args.env}"
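    # each trial is an independent training run logged under the same wandb group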
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        if args.multidomain:
            make_train_env = lambda: MultiDomainPOPGymAMAGO()
        else:
            # in order to match the pre-gymnasium version of popgym (done instead of terminated/truncated),
            # we need to set terminated = terminated or truncated
            make_train_env = lambda: POPGymAMAGO(
                f"popgym-{args.env}-v0", truncated_is_done=True
            )
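        # training sequences are capped at max_seq_len, while saved trajectory files and
        # validation rollouts use the large placeholder horizon defined above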
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=artificial_horizon,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=artificial_horizon,
            learning_rate=1e-4,
            grad_clip=1.0,
            lr_warmup_steps=2000,
        )
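        # switch_async_mode reconfigures the experiment according to args.mode (e.g.,
        # running environment collection and learning asynchronously)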
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
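        # start() initializes the run, learn() runs the training loop, evaluate_test()
        # performs a final evaluation rollout, and delete_buffer_from_disk() removes the
        # replay data written during training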
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()
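
A hypothetical invocation (the script filename here is a placeholder; shared flags such as the run name, trial count, and architecture sizes are added by cli_utils.add_common_cli): python popgym_example.py --env AutoencodeEasy --max_seq_len 2000.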