-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_e3nn_jax.py
156 lines (121 loc) · 5 KB
/
train_e3nn_jax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copied from https://github.com/e3nn/e3nn-jax/blob/245e17eb23deaccad9f2c9cfd40fe40515e3c074/examples/tetris_point.py#L13
import time
import flax
import jax
import jax.numpy as jnp
import numpy as np
import jraph
import optax
from tqdm.auto import tqdm
import e3nn_jax as e3nn
def tetris() -> jraph.GraphsTuple:
    """Build the eight-shape Tetris toy dataset as a single batched graph.

    Every shape is given as four 3D points.  Each shape becomes its own graph:
    the node features are the point coordinates, the single global is the
    integer class label (0-7, in listing order), and the edges are whatever
    ``e3nn.radius_graph(points, 1.1)`` returns (presumably all point pairs
    closer than 1.1 -- confirm against the e3nn_jax documentation).

    Returns:
        The result of ``jraph.batch`` applied to the eight per-shape graphs.
    """
    shapes = jnp.array(
        [
            [[0, 0, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]],  # chiral_shape_1
            [[1, 1, 1], [1, 1, 2], [2, 1, 1], [2, 0, 1]],  # chiral_shape_2
            [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],  # square
            [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]],  # line
            [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]],  # corner
            [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0]],  # L
            [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 1]],  # T
            [[0, 0, 0], [1, 0, 0], [1, 1, 0], [2, 1, 0]],  # zigzag
        ],
        dtype=jnp.float32,
    )
    class_ids = jnp.arange(8)

    def single_graph(coords, class_id):
        """Wrap one shape (its points and its label) in a one-graph GraphsTuple."""
        src, dst = e3nn.radius_graph(coords, 1.1)
        return jraph.GraphsTuple(
            nodes=coords.reshape((4, 3)),  # [num_nodes, 3]
            edges=None,
            globals=class_id[None],  # [num_graphs]
            senders=src,  # [num_edges]
            receivers=dst,  # [num_edges]
            n_node=jnp.array([len(coords)]),  # [num_graphs]
            n_edge=jnp.array([len(src)]),  # [num_graphs]
        )

    per_shape_graphs = [
        single_graph(coords, class_id) for coords, class_id in zip(shapes, class_ids)
    ]
    return jraph.batch(per_shape_graphs)
class Layer(flax.linen.Module):
    """One message-passing layer implemented with ``jraph.GraphNetwork``.

    Edge update: the sender's features are concatenated with their tensor
    product against the spherical harmonics (degrees ``1..sh_lmax``) of the
    edge vector ``positions[receiver] - positions[sender]``.

    Node update: the features jraph hands in as ``receiver_features``
    (presumably the incoming edge messages aggregated per node -- confirm
    against jraph) are divided by ``denominator``, passed through a linear map
    to ``target_irreps`` and added to a linear shortcut of the old node
    features.
    """

    # Irreps requested from the "linear_pre" map; converted with ``e3nn.Irreps``.
    target_irreps: e3nn.Irreps
    # Divisor applied to ``receiver_features`` before the linear map.
    denominator: float
    # Highest spherical-harmonic degree evaluated on the edge vectors.
    sh_lmax: int = 3

    @flax.linen.compact
    def __call__(self, graphs, positions):
        """Apply the layer to ``graphs`` and return the updated graph.

        Args:
            graphs: Graph whose ``senders`` / ``receivers`` index ``positions``.
            positions: Per-node coordinates, indexable by the edge index arrays.

        Returns:
            Whatever ``jraph.GraphNetwork`` returns for the two update
            functions defined below.
        """
        out_irreps = e3nn.Irreps(self.target_irreps)

        # NOTE: the parameter names of both callbacks are kept exactly as jraph
        # receives them; only the bodies use new local names.
        def update_edge_fn(edge_features, sender_features, receiver_features, globals):
            edge_vectors = positions[graphs.receivers] - positions[graphs.senders]
            degrees = list(range(1, self.sh_lmax + 1))
            # Third positional argument is presumably the ``normalize`` flag
            # -- confirm against the e3nn_jax documentation.
            harmonics = e3nn.spherical_harmonics(degrees, edge_vectors, True)
            interaction = e3nn.tensor_product(sender_features, harmonics)
            messages = e3nn.concatenate([sender_features, interaction])
            return messages.regroup()

        def update_node_fn(node_features, sender_features, receiver_features, globals):
            aggregated = receiver_features / self.denominator
            aggregated = e3nn.flax.Linear(out_irreps, name="linear_pre")(aggregated)
            # The shortcut targets the irreps actually produced above so the
            # two terms can be summed.
            skip_linear = e3nn.flax.Linear(
                aggregated.irreps, name="shortcut", force_irreps_out=True
            )
            return skip_linear(node_features) + aggregated

        network = jraph.GraphNetwork(update_edge_fn, update_node_fn)
        return network(graphs)
class Model(flax.linen.Module):
    """Three stacked ``Layer`` modules followed by a per-graph sum readout."""

    @flax.linen.compact
    def __call__(self, graphs):
        """Return class logits of shape ``[num_graphs, 8]`` for ``graphs``.

        The incoming node features are used as coordinates (wrapped as a "1o"
        ``IrrepsArray``) and then replaced by a constant 1 per node before the
        layers run.
        """
        positions = e3nn.IrrepsArray("1o", graphs.nodes)
        num_nodes = len(positions)
        graphs = graphs._replace(nodes=jnp.ones((num_nodes, 1)))

        hidden_irreps = "32x0e + 32x0o + 8x1e + 8x1o + 8x2e + 8x2o"
        readout_irreps = "0o + 7x0e"
        for layer_irreps in (hidden_irreps, hidden_irreps, readout_irreps):
            graphs = Layer(layer_irreps, 1.5)(graphs, positions)

        # Readout logits
        pooled = e3nn.scatter_sum(
            graphs.nodes.array, nel=graphs.n_node
        )  # [num_graphs, 1 + 7]
        odd = pooled[:, :1]
        even1 = pooled[:, 1:2]
        even2 = pooled[:, 2:]

        # The first two logits are +/- the product of the odd channel with the
        # first even channel (presumably so the two mirror-image shapes get
        # opposite-sign scores -- confirm); the remaining six even channels
        # are used directly.
        signed = odd * even1
        logits = jnp.concatenate([signed, -signed, even2], axis=1)
        assert logits.shape == (len(graphs.n_node), 8)  # [num_graphs, num_classes]
        return logits
def train(steps=200):
    """Train ``Model`` on the Tetris dataset and print accuracy and timing.

    Args:
        steps: Number of timed optimisation steps run after the three
            warm-up calls that trigger JIT compilation.

    Prints the compilation time, the accuracy observed during warm-up (the
    parameters are not updated there), the accuracy reported by the last
    training step, and the mean wall-clock time per step.  Returns ``None``.
    """
    model = Model()

    # Optimizer
    opt = optax.adam(learning_rate=0.01)

    def loss_fn(params, graphs):
        logits = model.apply(params, graphs)
        labels = graphs.globals  # [num_graphs]
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
        loss = jnp.mean(loss)
        return loss, logits

    @jax.jit
    def update_fn(params, opt_state, graphs):
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, logits = grad_fn(params, graphs)
        labels = graphs.globals
        # Accuracy is computed from the logits of the parameters *before*
        # this step's update is applied.
        accuracy = jnp.mean(jnp.argmax(logits, axis=1) == labels)

        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, accuracy

    # Dataset
    graphs = tetris()

    # Init
    init = jax.jit(model.init)
    params = init(jax.random.PRNGKey(3), graphs)
    opt_state = opt.init(params)

    # compile jit
    wall = time.perf_counter()
    print("compiling...", flush=True)
    for _ in range(3):
        # Results are discarded on purpose: these calls only warm up the JIT.
        _, _, accuracy = update_fn(params, opt_state, graphs)
    print(f"initial accuracy = {100 * accuracy:.0f}%", flush=True)
    print(f"compilation took {time.perf_counter() - wall:.1f}s")

    # Train
    timings = []
    for _ in tqdm(range(steps)):
        # perf_counter is monotonic, so it is the right clock for intervals.
        start = time.perf_counter()
        params, opt_state, accuracy = jax.block_until_ready(
            update_fn(params, opt_state, graphs)
        )
        timings.append(time.perf_counter() - start)

    print(f"final accuracy = {100 * accuracy:.0f}%")

    # Skip the first 20 steps when averaging, but fall back to all steps when
    # the run is too short for that, so short runs report a real number
    # instead of the mean of an empty list (nan plus a NumPy warning).
    skipped_steps = 20
    measured = timings[skipped_steps:] or timings
    mean_ms = np.mean(measured) * 1000 if measured else float("nan")
    print(f"Training time/step {mean_ms:.3f} ms")
# Script entry point: run training with the default number of steps when the
# file is executed directly; importing the module does nothing.
if __name__ == "__main__":
    train()