from argparse import ArgumentParser

import wandb

import amago
from amago.envs.builtin.popgym_envs import POPGymAMAGO, MultiDomainPOPGymAMAGO
from amago.cli_utils import *
def add_cli(parser):
    parser.add_argument("--env", type=str, default="AutoencodeEasy")
    parser.add_argument("--max_seq_len", type=int, default=2000)
    parser.add_argument(
        "--multidomain",
        action="store_true",
        help="Activate 'MultiDomain' POPGym, where agents play 27 POPGym games at the same time in a 1-shot format (2 episodes; the second one counts).",
    )
    return parser
if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # whenever we need a "max rollout length" value, we use this arbitrarily large number
    artificial_horizon = max(args.max_seq_len, 2000)
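    # overrides of AMAGO defaults, keyed by dotted module path; applied below via use_config()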
    # fmt: off
    config = {
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": None,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
        "binary_filter.threshold": 1e-3,  # not important
        # learnable position embedding
        "amago.nets.transformer.LearnablePosEmb.max_time_idx": artificial_horizon,
        "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "learnable",
        "amago.nets.traj_encoders.TformerTrajEncoder.attention_type": amago.nets.transformer.FlashAttention,
        "amago.nets.policy_dists.Discrete.clip_prob_high": 1.0,  # not important
        "amago.nets.policy_dists.Discrete.clip_prob_low": 1e-6,  # not important
        # the paper version defaulted to a large set of gamma values
        "amago.agent.Multigammas.discrete": [0.1, 0.7, 0.9, 0.93, 0.95, 0.98, 0.99, 0.992, 0.994, 0.995, 0.997, 0.998, 0.999, 0.9991, 0.9992, 0.9993, 0.9994, 0.9995],
    }
    # fmt: on
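    # choose the sequence model, timestep encoder, agent, and exploration wrapper
    # that are passed to the experiment below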
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,  # paper: 256
        layers=args.memory_layers,  # paper: 3
    )
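    # timestep encoder: a small feedforward net that embeds each timestep before the sequence model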
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=512,
        d_output=200,
    )
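    # MultiDomain uses a larger reward multiplier (200x vs. 100x for single-task POPGym)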
    agent_type = switch_agent(
        config,
        args.agent_type,
        reward_multiplier=200.0 if args.multidomain else 100.0,
        tau=0.0025,
    )
    # steps_anneal can safely be set much lower (<500k) in most tasks. More sweeps needed.
    exploration_type = switch_exploration(
        config,
        "egreedy",
        steps_anneal=1_000_000,
    )
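    # apply the overrides above (AMAGO's config system is gin-based), along with
    # anything passed via --configs on the command line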
    use_config(config, args.configs)
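    # run --trials independent training runs; each gets its own run_name under a shared group_name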
    group_name = f"{args.run_name}_{args.env}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        if args.multidomain:
            make_train_env = lambda: MultiDomainPOPGymAMAGO()
        else:
            # in order to match the pre-gymnasium version of popgym (done instead of terminated/truncated),
            # we need to set terminated = terminated or truncated
            make_train_env = lambda: POPGymAMAGO(
                f"popgym-{args.env}-v0", truncated_is_done=True
            )
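        # validation uses the same env as training; rollout-length arguments fall back to
        # the artificial_horizon defined above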
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=artificial_horizon,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=artificial_horizon,
            learning_rate=1e-4,
            grad_clip=1.0,
            lr_warmup_steps=2000,
        )
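        # run the trial: start the experiment, optionally restore a checkpoint, train,
        # run a long evaluation rollout, then clean up the on-disk buffer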
        experiment = switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()