Meta-World ML1/ML10/ML45

Meta-World ML1/ML10/ML45

 1from argparse import ArgumentParser
 2
 3import wandb
 4
 5from amago.envs.builtin.metaworld_ml import Metaworld
 6from amago.nets.tstep_encoders import FFTstepEncoder
 7from amago.envs.exploration import BilevelEpsilonGreedy
 8from amago import cli_utils
 9
10
def _positive_int(text: str) -> int:
    """Parse a command-line token as an integer that must be at least 1.

    Used as an argparse ``type=`` callable so that zero or negative values
    are rejected with a normal usage error instead of being passed on.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is < 1.
    """
    # Imported locally because the file only imports ``ArgumentParser``.
    from argparse import ArgumentTypeError

    try:
        value = int(text)
    except ValueError:
        raise ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def add_cli(parser: ArgumentParser) -> ArgumentParser:
    """Register the Meta-World specific command-line options on ``parser``.

    Options added:
        --benchmark    task name (``name-v2``) for ML1, or ``ml10`` / ``ml45``.
        --k            number of shots (K); positive integer, default 3.
        --max_seq_len  positive integer, default 256.
        --hide_rl2s    flag; when set, hides the previous actions and rewards.

    Args:
        parser: the parser to extend; it is modified in place.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    parser.add_argument(
        "--benchmark",
        type=str,
        default="reach-v2",
        help="`name-v2` for ML1, or `ml10`/`ml45`",
    )
    # Both values are multiplied into rollout and sequence lengths by the
    # script, so anything below 1 is rejected up front.
    parser.add_argument("--k", type=_positive_int, default=3, help="K-Shots")
    parser.add_argument("--max_seq_len", type=_positive_int, default=256)
    parser.add_argument(
        "--hide_rl2s",
        action="store_true",
        help="hides the 'rl2 info' (previous actions, rewards)",
    )
    return parser
26
27
if __name__ == "__main__":
    # Build the CLI: the shared AMAGO options first, then the options
    # defined by add_cli in this file.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Per-episode step cap. It is used for the environments, the exploration
    # horizon, the saved trajectory length and the validation budget, so it
    # is defined once to keep those values in agreement.
    EPISODE_STEPS = 300
    # Steps in one K-episode rollout.
    rollout_steps = args.k * EPISODE_STEPS

    config = {
        "amago.nets.tstep_encoders.FFTstepEncoder.hide_rl2s": args.hide_rl2s,
        # delete the next three lines to use the paper settings, which were
        # intentionally left wide open to avoid reward-specific tuning.
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": -100.0,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": 5000 * args.k,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 96,
        # "amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "rope",
        "amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 300,
    }
    # Each switch_* helper is handed `config` and returns the type that is
    # later passed to the experiment.
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = cli_utils.switch_agent(
        config, args.agent_type, reward_multiplier=1.0, num_critics=4
    )
    exploration_type = cli_utils.switch_exploration(
        config, "bilevel", steps_anneal=2_000_000, rollout_horizon=rollout_steps
    )
    cli_utils.use_config(config, args.configs)

    # Environment factories for the "train" and "test" splits of the benchmark.
    make_train_env = lambda: Metaworld(
        args.benchmark, "train", k_episodes=args.k, max_episode_length=EPISODE_STEPS
    )
    make_test_env = lambda: Metaworld(
        args.benchmark, "test", k_episodes=args.k, max_episode_length=EPISODE_STEPS
    )

    group_name = (
        f"{args.run_name}_metaworld_{args.benchmark}_K_{args.k}_L_{args.max_seq_len}"
    )
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            # Validation reuses the train factory; the "test" split is only
            # used by evaluate_test after learning.
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            # One rollout plus one step, capped at four context lengths.
            traj_save_len=min(rollout_steps + 1, args.max_seq_len * 4),
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            # 15 rollout lengths plus one step.
            val_timesteps_per_epoch=15 * rollout_steps + 1,
            learning_rate=5e-4,
            grad_clip=2.0,
            exploration_wrapper_type=exploration_type,
        )

        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_test_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        # Close this trial's wandb run before the next trial begins.
        wandb.finish()