1from argparse import ArgumentParser
2
3import gymnasium as gym
4import wandb
5
6import amago
7from amago.envs import AMAGOEnv
8from amago.cli_utils import *
9
10
11def add_cli(parser):
12 parser.add_argument(
13 "--env", type=str, required=True, help="Environment name for `gym.make`"
14 )
15 parser.add_argument(
16 "--max_seq_len", type=int, default=128, help="Policy sequence length."
17 )
18 parser.add_argument(
19 "--eval_timesteps",
20 type=int,
21 default=1000,
22 help="Timesteps per actor per evaluation. Tune based on the episode length of the environment (to be at least one full episode).",
23 )
24 return parser
25
26
if __name__ == "__main__":
    # Build the CLI: shared AMAGO flags plus this script's extras.
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Environment setup. Slashes in gym IDs (e.g. "ALE/Pong-v5") are
    # replaced so the name is safe to use in run/group names.
    env_name = args.env.replace("/", "_")

    def make_train_env():
        return AMAGOEnv(gym.make(args.env), env_name=env_name)

    # Dictionary that sets default values for kwargs of classes marked as
    # `gin.configurable` — see `tutorial.md` for more information.
    config = {
        "amago.nets.policy_dists.Discrete.clip_prob_high": 1.0,
        "amago.nets.policy_dists.Discrete.clip_prob_low": 1e-6,
    }
    # Select the sequence model architecture from the CLI flags.
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # Select the agent/learning-update variant.
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=1.0)
    use_config(config, args.configs)

    group_name = f"{args.run_name}_{env_name}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 8,
            run_name=run_name,
            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
        )
        experiment = switch_async_mode(experiment, args.mode)
        experiment.start()
        # Optionally resume from a checkpoint before training.
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        # Final, longer evaluation pass after training completes.
        experiment.evaluate_test(make_train_env, timesteps=10_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()