from argparse import ArgumentParser

import wandb

from amago.envs.builtin.procgen_envs import (
    TwoShotMTProcgen,
    ProcgenAMAGO,
    ALL_PROCGEN_GAMES,
)
from amago.nets.cnn import IMPALAishCNN
from amago.cli_utils import *


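# Script-specific CLI flags. Shared training flags used later in this script
# (e.g. --trials, --run_name, --agent_type, --configs, --mode, --ckpt) are
# expected to be registered by add_common_cli() from amago.cli_utils.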
def add_cli(parser):
    parser.add_argument("--max_seq_len", type=int, default=256)
    parser.add_argument(
        "--distribution",
        type=str,
        default="easy",
        choices=["easy", "easy-rescaled", "memory-hard"],
    )
    parser.add_argument("--train_seeds", type=int, default=10_000)
    return parser


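# Each distribution setting selects the Procgen games to train on, optional
# per-game reward rescaling, and the distribution_mode passed to the env.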
PROCGEN_SETTINGS = {
    "easy": {
        "games": ["climber", "coinrun", "jumper", "ninja", "leaper"],
        "reward_scales": {},
        "distribution_mode": "easy",
    },
    "easy-rescaled": {
        "games": ["climber", "coinrun", "jumper", "ninja", "leaper"],
        "reward_scales": {"coinrun": 100.0, "climber": 0.1},
        "distribution_mode": "easy",
    },
    "memory-hard": {
        "games": ALL_PROCGEN_GAMES,
        "reward_scales": {},
        "distribution_mode": "memory-hard",
    },
}

if __name__ == "__main__":
    parser = ArgumentParser()
    add_cli(parser)
    add_common_cli(parser)
    args = parser.parse_args()

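    # Collect config overrides and pick the timestep encoder, trajectory
    # encoder, and agent class from CLI args (helpers imported from
    # amago.cli_utils above).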
    config = {}
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="cnn",
        cnn_type=IMPALAishCNN,
        channels_first=False,
        drqv2_aug=True,
    )
    agent_type = switch_agent(config, args.agent_type)
    use_config(config, args.configs)

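    # Environments: training envs sample from seed_range (0, train_seeds),
    # while test envs use the held-out range (train_seeds + 1, 10_000_000).
    # `horizon` is a rough per-trial timestep budget (2000 for the easy
    # settings, 5000 for memory-hard) used below to size val/test rollouts.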
    procgen_kwargs = PROCGEN_SETTINGS[args.distribution]
    horizon = 2000 if "easy" in args.distribution else 5000
    make_train_env = lambda: ProcgenAMAGO(
        TwoShotMTProcgen(**procgen_kwargs, seed_range=(0, args.train_seeds)),
    )
    make_test_env = lambda: ProcgenAMAGO(
        TwoShotMTProcgen(
            **procgen_kwargs, seed_range=(args.train_seeds + 1, 10_000_000)
        ),
    )

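    # Run `--trials` independent runs. Each trial optionally resumes from a
    # checkpoint, trains, evaluates on the held-out test envs for
    # 20 * horizon timesteps, then deletes its replay buffer from disk and
    # closes the wandb run.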
    group_name = f"{args.run_name}_{args.distribution}_procgen_l_{args.max_seq_len}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_test_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 4,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=5 * horizon + 1,
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_test_env, timesteps=horizon * 20, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()