from argparse import ArgumentParser

import wandb

import amago
from amago.envs.builtin.popgym_envs import POPGymAMAGO, MultiDomainPOPGymAMAGO
from amago import cli_utils


def add_cli(parser):
    parser.add_argument("--env", type=str, default="AutoencodeEasy")
    parser.add_argument("--max_seq_len", type=int, default=2000)
    parser.add_argument(
        "--multidomain",
        action="store_true",
        help="Activate the 'MultiDomain' POPGym setting, where the agent plays all 27 POPGym games at once in a 1-shot format (2 episodes per rollout; only the second counts).",
    )
    return parser


if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # whenever we need a "max rollout length" value, we use this arbitrarily large number
    artificial_horizon = max(args.max_seq_len, 2000)

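    # configuration overrides keyed by dotted module paths; applied further down by cli_utils.use_config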
    # fmt: off
    config = {
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
        "binary_filter.threshold": 1e-3,  # not important
        # learnable position embedding
        "amago.nets.transformer.LearnablePosEmb.max_time_idx": artificial_horizon,
        "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "learnable",
        "amago.nets.traj_encoders.TformerTrajEncoder.attention_type": amago.nets.transformer.FlashAttention,
        "amago.nets.policy_dists.Discrete.clip_prob_high": 1.0,  # not important
        "amago.nets.policy_dists.Discrete.clip_prob_low": 1e-6,  # not important
        # the paper version defaulted to a large set of discount factors (gammas)
        "amago.agent.Multigammas.discrete": [0.1, 0.7, 0.9, 0.93, 0.95, 0.98, 0.99, 0.992, 0.994, 0.995, 0.997, 0.998, 0.999, 0.9991, 0.9992, 0.9993, 0.9994, 0.9995],
    }
    # fmt: on

    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,  # paper: 256
        layers=args.memory_layers,  # paper: 3
    )
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=512,
        d_output=200,
    )
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        reward_multiplier=200.0 if args.multidomain else 100.0,
        tau=0.0025,
    )
    # steps_anneal can safely be set much lower (< 500k) on most tasks; more sweeps are needed
    exploration_type = cli_utils.switch_exploration(
        config,
        "egreedy",
        steps_anneal=1_000_000,
    )
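    # apply the overrides collected above, along with any extra configs passed via the common CLI (args.configs)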
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_{args.env}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        if args.multidomain:
            make_train_env = lambda: MultiDomainPOPGymAMAGO()
        else:
            # to match the pre-gymnasium version of popgym (a single `done` flag instead of
            # terminated/truncated), we treat truncation as termination
            make_train_env = lambda: POPGymAMAGO(
                f"popgym-{args.env}-v0", truncated_is_done=True
            )
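        # build the Experiment from the shared CLI args plus the POPGym-specific settings above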
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=artificial_horizon,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=artificial_horizon,
            learning_rate=1e-4,
            grad_clip=1.0,
            lr_warmup_steps=2000,
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
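        # optionally resume from a saved checkpoint before training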
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()