Multi-Game ProcGen

Multi-Game ProcGen

  1from argparse import ArgumentParser
  2
  3import wandb
  4
  5from amago.envs.builtin.procgen_envs import (
  6    TwoShotMTProcgen,
  7    ProcgenAMAGO,
  8    ALL_PROCGEN_GAMES,
  9)
 10from amago.nets.cnn import IMPALAishCNN
 11from amago.cli_utils import *
 12
 13
 14def add_cli(parser):
 15    parser.add_argument("--max_seq_len", type=int, default=256)
 16    parser.add_argument(
 17        "--distribution",
 18        type=str,
 19        default="easy",
 20        choices=["easy", "easy-rescaled", "memory-hard"],
 21    )
 22    parser.add_argument("--train_seeds", type=int, default=10_000)
 23    return parser
 24
 25
 26PROCGEN_SETTINGS = {
 27    "easy": {
 28        "games": ["climber", "coinrun", "jumper", "ninja", "leaper"],
 29        "reward_scales": {},
 30        "distribution_mode": "easy",
 31    },
 32    "easy-rescaled": {
 33        "games": ["climber", "coinrun", "jumper", "ninja", "leaper"],
 34        "reward_scales": {"coinrun": 100.0, "climber": 0.1},
 35        "distribution_mode": "easy",
 36    },
 37    "memory-hard": {
 38        "games": ALL_PROCGEN_GAMES,
 39        "reward_scales": {},
 40        "distribution_mode": "memory-hard",
 41    },
 42}
 43
 44if __name__ == "__main__":
 45    parser = ArgumentParser()
 46    add_cli(parser)
 47    add_common_cli(parser)
 48    args = parser.parse_args()
 49
 50    config = {}
 51    traj_encoder_type = switch_traj_encoder(
 52        config,
 53        arch=args.traj_encoder,
 54        memory_size=args.memory_size,
 55        layers=args.memory_layers,
 56    )
 57    tstep_encoder_type = switch_tstep_encoder(
 58        config,
 59        arch="cnn",
 60        cnn_type=IMPALAishCNN,
 61        channels_first=False,
 62        drqv2_aug=True,
 63    )
 64    agent_type = switch_agent(config, args.agent_type)
 65    use_config(config, args.configs)
 66
 67    procgen_kwargs = PROCGEN_SETTINGS[args.distribution]
 68    horizon = 2000 if "easy" in args.distribution else 5000
 69    make_train_env = lambda: ProcgenAMAGO(
 70        TwoShotMTProcgen(**procgen_kwargs, seed_range=(0, args.train_seeds)),
 71    )
 72    make_test_env = lambda: ProcgenAMAGO(
 73        TwoShotMTProcgen(
 74            **procgen_kwargs, seed_range=(args.train_seeds + 1, 10_000_000)
 75        ),
 76    )
 77
 78    group_name = f"{args.run_name}_{args.distribution}_procgen_l_{args.max_seq_len}"
 79    for trial in range(args.trials):
 80        run_name = group_name + f"_trial_{trial}"
 81        experiment = create_experiment_from_cli(
 82            args,
 83            make_train_env=make_train_env,
 84            make_val_env=make_test_env,
 85            max_seq_len=args.max_seq_len,
 86            traj_save_len=args.max_seq_len * 4,
 87            run_name=run_name,
 88            tstep_encoder_type=tstep_encoder_type,
 89            traj_encoder_type=traj_encoder_type,
 90            agent_type=agent_type,
 91            group_name=group_name,
 92            val_timesteps_per_epoch=5 * horizon + 1,
 93        )
 94        switch_async_mode(experiment, args.mode)
 95        experiment.start()
 96        if args.ckpt is not None:
 97            experiment.load_checkpoint(args.ckpt)
 98        experiment.learn()
 99        experiment.evaluate_test(make_test_env, timesteps=horizon * 20, render=False)
100        experiment.delete_buffer_from_disk()
101        wandb.finish()