-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_medina_example.sh
executable file
·88 lines (86 loc) · 1.71 KB
/
train_medina_example.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env bash
# Smoke-test training run on the dataset used by Medina et al. (2022).
#
# This script trains for only a few epochs (--number_epochs 10). Use it to check
# that the installed libraries work; the resulting model will not be useful,
# because real training requires many more epochs. For real training, see the
# example command-line parameters in the files in the "hyperparameters"
# directory.
#
# Usage (run from the directory that contains main.py and data/, because both
# are resolved relative to the current working directory):
#   ./train_medina_example.sh [extra main.py arguments...]
#   PYTHON=python3 ./train_medina_example.sh    # use another interpreter
#
# Extra arguments are appended after the defaults below. Whether a repeated
# flag overrides the earlier default depends on main.py's argument parser
# (argparse keeps the last value for plain options) -- confirm before relying
# on it.
set -euo pipefail

# Interpreter used to run main.py; defaults to the one this script always used.
readonly python_bin="${PYTHON:-python3.9}"

# Built as an array so no fragile backslash line continuations are needed and
# every element is passed to main.py as exactly one argument.
args=(
  # Experiment output location and input data.
  --experiment_base_path "${PWD}/experiments"
  --data_path "${PWD}/data/medina_2022/medina_data.csv"
  --experiment_name medina-baseline
  --number_of_solutes 156
  --number_of_solvents 262
  --data Medina
  --split_dataset
  --split_sizes 10%10%80%
  --exclude_solutes_solvents_not_present_in_train
  --num_workers 4
  --data_ensemble_id 1

  # Model.
  --model Diagonal-Gaussian-PMF-VI
  --diagonal_prior
  --data_likelihood_std 0.15
  --number_of_samples_for_expectation 16
  --dimensionality_of_embedding 16
  --predict_from_prior
  --get_point_estimate
  --maximize_entropy

  # Training loop, seeding and reporting.
  --batch_size 1000
  --number_epochs 10
  --use_seed
  --seed 1
  --save_test_summary

  # Gradient clipping.
  --clipping_schedule
  --clip_grad_value
  --max_grad 1.0
  --clipping_schedule_max_grad_factor 10.0
  --clipping_schedule_fraction 0.1

  # Learning rate and its schedule.
  --lr 0.001
  --use_lr_schedule
  --lr_schedule Cyclical
  --lr_scheduler_number_of_cycles 2

  # Graph model options.
  --graph_model FiLM
  --graph_featurizer simple-atom
  --graph_activation ELU
  --graph_jumping_knowledge cat
  --graph_norm LayerNorm
  --graph_final_prediction_agg mean
  --graph_num_relations 4
  --graph_dropout 0.1
  --graph_featurizer_atom_embedding_dim 16
  --graph_number_of_layers 6
  --graph_message_agg add
  --graph_residual_every_layer -1
  --graph_bias
  --graph_hidden_channels 16
)

# exec replaces this shell with Python so signals (e.g. Ctrl-C) reach the
# training process directly and its exit status becomes the script's.
exec "$python_bin" main.py "${args[@]}" "$@"