Meta-World ML1/ML10/ML45

Meta-World ML1/ML10/ML45

 1from argparse import ArgumentParser
 2
 3import wandb
 4
 5from amago.envs.builtin.metaworld_ml import Metaworld
 6from amago.nets.tstep_encoders import FFTstepEncoder
 7from amago.envs.exploration import BilevelEpsilonGreedy
 8from amago.cli_utils import *
 9
10
def _positive_int(text: str) -> int:
    """Parse ``text`` as an integer and reject values below 1.

    Intended as an argparse ``type=`` callable so that nonsensical sizes
    (zero or negative) are reported as an ordinary command-line usage error
    instead of surfacing later as confusing failures.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is < 1.
    """
    # Function-scope import: the file's top-level import only brings in
    # ``ArgumentParser`` from argparse.
    from argparse import ArgumentTypeError

    try:
        value = int(text)
    except ValueError:
        raise ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def add_cli(parser: ArgumentParser) -> ArgumentParser:
    """Register the Meta-World specific options on ``parser`` and return it.

    Options added:
        --benchmark    Task selector string (default ``"reach-v2"``).
        --k            Number of shots, a positive integer (default 3).
        --max_seq_len  Positive integer sequence length (default 256).
        --hide_rl2s    Flag; when set, hides the previous actions/rewards.

    Args:
        parser: The argument parser to extend in place.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """
    parser.add_argument(
        "--benchmark",
        type=str,
        default="reach-v2",
        help="`name-v2` for ML1, or `ml10`/`ml45`",
    )
    # Both sizes are multiplied into rollout/buffer lengths by the script, so
    # values below 1 are rejected up front.
    parser.add_argument("--k", type=_positive_int, default=3, help="K-Shots")
    parser.add_argument("--max_seq_len", type=_positive_int, default=256)
    parser.add_argument(
        "--hide_rl2s",
        action="store_true",
        help="hides the 'rl2 info' (previous actions, rewards)",
    )
    return parser
26
27
if __name__ == "__main__":
    # Script entry point: parse the CLI, build the configuration overrides,
    # then run `args.trials` independent train + test experiments.
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # The literal 500 recurred throughout this script; it is kept in one
    # place here. NOTE(review): presumably the per-episode step limit of
    # Meta-World -- confirm against the environment wrapper.
    episode_len = 500
    # Timesteps spanned by one K-shot rollout (k episodes back to back).
    rollout_len = episode_len * args.k

    config = {
        "amago.nets.tstep_encoders.FFTstepEncoder.hide_rl2s": args.hide_rl2s,
        # delete the next three lines to use the paper settings, which were
        # intentionally left wide open to avoid reward-specific tuning.
        "amago.nets.actor_critic.NCriticsTwoHot.min_return": -100.0,
        "amago.nets.actor_critic.NCriticsTwoHot.max_return": 5000 * args.k,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 96,
    }
    # Each switch_* helper receives `config` and returns the type that is
    # later handed to the experiment.
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = switch_agent(
        config, args.agent_type, reward_multiplier=1.0, num_critics=4
    )
    exploration_type = switch_exploration(
        config, "bilevel", steps_anneal=2_000_000, rollout_horizon=rollout_len
    )
    use_config(config, args.configs)

    # Zero-argument environment factories; named functions rather than
    # lambdas bound to names (PEP 8).
    def make_train_env():
        return Metaworld(args.benchmark, "train", k_episodes=args.k)

    def make_test_env():
        return Metaworld(args.benchmark, "test", k_episodes=args.k)

    group_name = (
        f"{args.run_name}_metaworld_{args.benchmark}_K_{args.k}_L_{args.max_seq_len}"
    )
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            # Validation reuses the training environments; the "test" split
            # is only used by `evaluate_test` after learning.
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=min(rollout_len + 1, args.max_seq_len * 4),
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=15 * rollout_len + 1,
            learning_rate=5e-4,
            grad_clip=2.0,
            exploration_wrapper_type=exploration_type,
        )

        experiment = switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_test_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        # Close the wandb run so the next trial logs to a fresh one.
        wandb.finish()