1from argparse import ArgumentParser
2
3import wandb
4
5from amago.envs.builtin.metaworld_ml import Metaworld
6from amago.nets.tstep_encoders import FFTstepEncoder
7from amago.envs.exploration import BilevelEpsilonGreedy
8from amago import cli_utils
9
10
def add_cli(parser):
    """Register the Metaworld-specific command-line flags on ``parser``.

    Flags are added in this order: ``--benchmark``, ``--k``,
    ``--max_seq_len``, ``--hide_rl2s``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``).

    Returns:
        The same ``parser`` object, so the call can be chained.
    """
    # (flag, keyword options) pairs; registration order matters for --help.
    option_specs = (
        (
            "--benchmark",
            dict(
                type=str,
                default="reach-v2",
                help="`name-v2` for ML1, or `ml10`/`ml45`",
            ),
        ),
        ("--k", dict(type=int, default=3, help="K-Shots")),
        ("--max_seq_len", dict(type=int, default=256)),
        (
            "--hide_rl2s",
            dict(
                action="store_true",
                help="hides the 'rl2 info' (previous actions, rewards)",
            ),
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
26
27
if __name__ == "__main__":
    # Script entry point: parse CLI, build the gin-style config overrides,
    # then train + test one AMAGO experiment per requested trial.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Passed as `max_episode_length` to every Metaworld env below. The same
    # value also sizes the exploration horizon, the saved trajectory length
    # and the validation budget, so it is defined once to keep them in sync.
    max_episode_length = 300
    # Step budget of one rollout of `args.k` episodes (k_episodes=args.k).
    rollout_horizon = args.k * max_episode_length

    config = {
        "amago.nets.tstep_encoders.FFTstepEncoder.hide_rl2s": args.hide_rl2s,
        # delete the next three lines to use the paper settings, which were
        # intentionally left wide open to avoid reward-specific tuning.
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": -100.0,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": 5000 * args.k,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 96,
        # optional override, disabled by default:
        # "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "rope",
        # critic hidden width (unrelated to the episode length above)
        "amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 300,
    }
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = cli_utils.switch_agent(
        config, args.agent_type, reward_multiplier=1.0, num_critics=4
    )
    exploration_type = cli_utils.switch_exploration(
        config, "bilevel", steps_anneal=2_000_000, rollout_horizon=rollout_horizon
    )
    cli_utils.use_config(config, args.configs)

    def make_train_env():
        """Build a Metaworld env over the benchmark's "train" split."""
        return Metaworld(
            args.benchmark,
            "train",
            k_episodes=args.k,
            max_episode_length=max_episode_length,
        )

    def make_test_env():
        """Build a Metaworld env over the benchmark's "test" split."""
        return Metaworld(
            args.benchmark,
            "test",
            k_episodes=args.k,
            max_episode_length=max_episode_length,
        )

    group_name = (
        f"{args.run_name}_metaworld_{args.benchmark}_K_{args.k}_L_{args.max_seq_len}"
    )
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            # NOTE(review): validation uses the *train* split; pass
            # make_test_env instead if held-out validation is intended.
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=min(rollout_horizon + 1, args.max_seq_len * 4),
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            # 15 full rollouts (+1 step) of validation per epoch
            val_timesteps_per_epoch=15 * rollout_horizon + 1,
            learning_rate=5e-4,
            grad_clip=2.0,
            exploration_wrapper_type=exploration_type,
        )

        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_test_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()