1from argparse import ArgumentParser
2from functools import partial
3
4import wandb
5
6from amago.envs.builtin.ale_retro import AtariAMAGOWrapper, AtariGame
7from amago.nets.cnn import NatureishCNN, IMPALAishCNN
8from amago.cli_utils import *
9
10
def add_cli(parser):
    """Attach this script's Atari-specific command-line flags to ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser``; mutated in place.

    Returns:
        The same parser, for call chaining.
    """
    parser.add_argument(
        "--games",
        nargs="+",
        default=None,
        help="Subset of Atari game names to train on. Defaults to the "
        "built-in ten-game DEFAULT_MULTIGAME_LIST when omitted.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=80,
        help="Maximum context length (in timesteps) for the sequence policy.",
    )
    parser.add_argument(
        "--cnn",
        type=str,
        choices=["nature", "impala"],
        default="impala",
        help="Image encoder architecture for the timestep encoder.",
    )
    return parser
18
19
# Ten-game suite used when --games is not provided on the command line.
# Actors are assigned to games round-robin (see the env_funcs loop in
# __main__), so args.parallel_actors must be divisible by this list's length.
DEFAULT_MULTIGAME_LIST = [
    "Pong",
    "Boxing",
    "Breakout",
    "Gopher",
    "MsPacman",
    "ChopperCommand",
    "CrazyClimber",
    "BattleZone",
    "Qbert",
    "Seaquest",
]

# Episode/eval time limit in agent steps: 30 min * 60 sec * 60 (presumably
# the ALE's 60 FPS — confirm) raw frames, divided by the frame_skip of 5
# used in make_atari_game, since the agent acts once per 5 frames.
ATARI_TIME_LIMIT = (30 * 60 * 60) // 5  # (30 minutes of game time)
34
35
def make_atari_game(game_name):
    """Build one AMAGO-wrapped Atari environment for ``game_name``.

    Fixed settings: discrete actions, episodes continue across life loss,
    ALE "v5" with sticky actions (p=0.25), frame skip of 5, full-color
    (RGB) observations, and unclipped rewards.
    """
    base_env = AtariGame(
        game=game_name,
        action_space="discrete",
        terminal_on_life_loss=False,
        version="v5",
        frame_skip=5,
        grayscale=False,
        sticky_action_prob=0.25,
        clip_rewards=False,
    )
    return AtariAMAGOWrapper(base_env)
49
50
if __name__ == "__main__":
    # --- CLI -------------------------------------------------------------
    parser = ArgumentParser()
    add_cli(parser)
    add_common_cli(parser)
    args = parser.parse_args()

    # --- gin-style config overrides shared by every trial ----------------
    # The multitask agent trains purely offline-style (coeff 1.0); any other
    # agent type uses 0.0.
    config = {
        "amago.agent.Agent.reward_multiplier": 0.25,
        "amago.agent.Agent.offline_coeff": float(args.agent_type == "multitask"),
    }
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )

    # Map the CLI choice onto a CNN class (argparse already restricts the
    # value to these two keys).
    cnn_by_name = {"nature": NatureishCNN, "impala": IMPALAishCNN}
    cnn_type = cnn_by_name[args.cnn]
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="cnn",
        cnn_type=cnn_type,
        channels_first=True,
        drqv2_aug=True,
    )

    agent_type = switch_agent(config, args.agent_type)
    use_config(config, args.configs)

    # Episode lengths in Atari vary widely across games, so we manually set
    # actors to a specific game so that all games are always played in
    # parallel (round-robin assignment).
    games = args.games or DEFAULT_MULTIGAME_LIST
    assert (
        args.parallel_actors % len(games) == 0
    ), "Number of actors must be divisible by number of games."
    env_funcs = [
        partial(make_atari_game, games[actor_idx % len(games)])
        for actor_idx in range(args.parallel_actors)
    ]

    # --- run trials ------------------------------------------------------
    group_name = f"{args.run_name}_atari_l_{args.max_seq_len}_cnn_{args.cnn}"
    for trial in range(args.trials):
        run_name = f"{group_name}_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=env_funcs,
            make_val_env=env_funcs,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 3,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=ATARI_TIME_LIMIT,
            save_trajs_as="npz-compressed",
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        # Optionally resume from a checkpoint before training.
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        # Final evaluation gets 5x the per-epoch validation budget.
        experiment.evaluate_test(
            env_funcs, timesteps=ATARI_TIME_LIMIT * 5, render=False
        )
        experiment.delete_buffer_from_disk()
        wandb.finish()