Symbolic DM Alchemy

Symbolic DM Alchemy

 1from argparse import ArgumentParser
 2
 3import wandb
 4
 5from amago.envs import AMAGOEnv
 6from amago.envs.builtin.alchemy import SymbolicAlchemy
 7from amago.cli_utils import *
 8
 9
if __name__ == "__main__":
    # Train, then test, an AMAGO agent on the SymbolicAlchemy environment.
    # The script runs `args.trials` independent trials that share one W&B
    # group name and get per-trial run names.
    parser = ArgumentParser()
    add_common_cli(parser)
    args = parser.parse_args()

    # Rollout horizon for the "bilevel" exploration schedule. The sequence
    # lengths handed to the experiment are derived from it so the values
    # cannot drift apart. The original hard-coded 200 / 201 / 201; the +1
    # presumably leaves room for the reset observation -- TODO confirm.
    ROLLOUT_HORIZON = 200
    SEQ_LEN = ROLLOUT_HORIZON + 1

    # Each `switch_*` helper (from amago.cli_utils) receives the shared
    # `config` dict and returns a type that is passed to the experiment
    # below. `use_config` is then called with `config` and any
    # command-line config files (`args.configs`).
    config = {}
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    exploration_wrapper_type = switch_exploration(
        config,
        "bilevel",
        rollout_horizon=ROLLOUT_HORIZON,
        steps_anneal=2_500_000,
    )
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=100.0)
    tstep_encoder_type = switch_tstep_encoder(
        config, arch="ff", n_layers=2, d_hidden=256, d_output=256
    )

    use_config(config, args.configs)

    def make_train_env():
        """Return a new SymbolicAlchemy environment wrapped in AMAGOEnv."""
        return AMAGOEnv(
            env=SymbolicAlchemy(),
            env_name="dm_symbolic_alchemy",
        )

    group_name = f"{args.run_name}_symbolic_dm_alchemy"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            # Validation uses the same environment factory as training.
            make_val_env=make_train_env,
            max_seq_len=SEQ_LEN,
            traj_save_len=SEQ_LEN,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_wrapper_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=2000,
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        # When a checkpoint is given, it is loaded at the start of every
        # trial, not just the first one.
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        # Clean up this trial's on-disk buffer and close its W&B run before
        # the next trial starts.
        experiment.delete_buffer_from_disk()
        wandb.finish()