Skip to content

Commit

Permalink
refactor pna addons - sketch MACE
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinBakerMath committed Oct 29, 2024
1 parent c107ff7 commit 0640baf
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 115 deletions.
5 changes: 3 additions & 2 deletions examples/qm9/qm9.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
"NeuralNetwork": {
"Profile": {"enable": 1},
"Architecture": {
"model_type": "SchNet",
"model_type": "MACE",
"radius": 7,
"equivariance": true,
"basis_emb_size": 8,
"envelope_exponent": 5,
"max_ell": 2,
"node_max_ell": 2,
"int_emb_size": 8,
"out_emb_size": 8,
"num_after_skip": 2,
Expand Down
69 changes: 33 additions & 36 deletions hydragnn/models/MACEStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
class MACEStack(Base):
def __init__(
self,
input_args,
conv_args,
r_max: float, # The cutoff radius for the radial basis functions and edge_index
radial_type: str, # The type of radial basis function to use
distance_transform: str, # The distance transform to use
Expand Down Expand Up @@ -126,7 +128,11 @@ def __init__(
) # This makes the irreps string
self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e")

super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

if self.use_edge_attr:
self.input_args = "node_attributes, geom_feat, inv_feat, edge_attributes, edge_features, edge_index"
self.conv_args = "node_attributes, edge_attributes, edge_features, edge_index" # node_features is not used here because it's passed through in the forward

self.spherical_harmonics = o3.SphericalHarmonics(
self.sh_irreps,
Expand Down Expand Up @@ -335,78 +341,61 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False):
hidden_irreps_out, output_irreps
) # Change sizing to output_irreps

input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index"
conv_args = "node_attributes, edge_attributes, edge_features, edge_index" # node_features is not used here because it's passed through in the forward

if not last_layer:
return PyGSequential(
input_args,
self.input_args,
[
(inter, "node_features, " + conv_args + " -> node_features, sc"),
(inter, "inv_feat, " + self.conv_args + " -> node_features, sc"),
(prod, "node_features, sc, node_attributes -> node_features"),
(sizing, "node_features -> node_features"),
(
lambda node_features, pos: [node_features, pos],
"node_features, pos -> node_features, pos",
lambda node_features, geom_feat: [node_features, geom_feat],
"node_features, geom_feat -> node_features, geom_feat",
),
],
)
else:
return PyGSequential(
input_args,
self.input_args,
[
(inter, "node_features, " + conv_args + " -> node_features, sc"),
(inter, "inv_feat, " + self.conv_args + " -> node_features, sc"),
(prod, "node_features, sc, node_attributes -> node_features"),
(sizing, "node_features -> node_features"),
(
lambda node_features, pos: [node_features, pos],
"node_features, pos -> node_features, pos",
lambda node_features, geom_feat: [node_features, geom_feat],
"node_features, geom_feat -> node_features, geom_feat",
),
],
)

def forward(self, data):
data, conv_args = self._conv_args(data)
node_features = data.node_features
node_attributes = data.node_attributes
pos = data.pos

### encoder / decoder part ####
## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models, as they can break equivariance

### There is a readout before the first convolution layer ###
outputs = []
output = self.multihead_decoders[0](
data, node_attributes
) # [index][n_output, size_output]
# Create outputs first
outputs = output
inv_feat, geom_feat, outputs, conv_args = self._embedding(data)

### Do conv --> readout --> repeat for each convolution layer ###
for conv, readout in zip(self.graph_convs, self.multihead_decoders[1:]):
if not self.conv_checkpointing:
node_features, pos = conv(
node_features=node_features, pos=pos, **conv_args
inv_feat, geom_feat = conv(
inv_feat=inv_feat, geom_feat=geom_feat, **conv_args
)
output = readout(data, node_features) # [index][n_output, size_output]
output = readout(data, inv_feat) # [index][n_output, size_output]
else:
node_features, pos = checkpoint(
inv_feat, geom_feat = checkpoint(
conv,
use_reentrant=False,
node_features=node_features,
pos=pos,
inv_feat=inv_feat,
geom_feat=geom_feat,
**conv_args,
)
output = readout(
data, node_features
data, inv_feat
) # output is a list of tensors with [index][n_output, size_output]
# Sum predictions for each index, taking care of size differences
for idx, prediction in enumerate(output):
outputs[idx] = outputs[idx] + prediction

return outputs

def _conv_args(self, data):
def _embedding(self, data):
assert (
data.pos is not None
), "MACE requires node positions (data.pos) to be set."
Expand Down Expand Up @@ -452,7 +441,15 @@ def _conv_args(self, data):
"edge_index": data.edge_index,
}

return data, conv_args
node_attributes = data.node_attributes
outputs = []
output = self.multihead_decoders[0](
data, node_attributes
) # [index][n_output, size_output]
# Create outputs first
outputs = output

return data.node_features, data.pos, outputs, conv_args

def _multihead(self):
# NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer,
Expand Down
73 changes: 9 additions & 64 deletions hydragnn/models/PNAEqStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.num_radial = num_radial
self.radius = radius

super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

self.rbf = rbf_BasisLayer(self.num_radial, self.radius)

Expand Down Expand Up @@ -106,29 +106,22 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
geom_nn.Linear(input_dim, output_dim) if not last_layer else None
)

input_args = "x, v, pos, edge_index, edge_rbf, edge_vec"
conv_args = "x, v, edge_index, edge_rbf, edge_vec"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

if not last_layer:
return geom_nn.Sequential(
input_args,
self.input_args,
[
(message, conv_args + " -> x, v"),
(message, self.conv_args + " -> x, v"),
(update, "x, v -> x, v"),
(node_embed_out, "x -> x"),
(vec_embed_out, "v -> v"),
(lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
(lambda x, v: [x, v], "x, v -> x, v"),
],
)
else:
return geom_nn.Sequential(
input_args,
self.input_args,
[
(message, conv_args + " -> x, v"),
(message, self.conv_args + " -> x, v"),
(
update,
"x, v -> x",
Expand All @@ -137,60 +130,12 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
node_embed_out,
"x -> x",
), # No need to embed down v because it's not used anymore
(lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
(lambda x, v: [x, v], "x, v -> x, v"),
],
)

def forward(self, data):
data, conv_args = self._conv_args(
data
) # Added v to data here (necessary for PNAEq Stack)
x = data.x
v = data.v
pos = data.pos

### encoder part ####
for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
if not self.conv_checkpointing:
c, v, pos = conv(x=x, v=v, pos=pos, **conv_args) # Added v here
else:
c, v, pos = checkpoint( # Added v here
conv, use_reentrant=False, x=x, v=v, pos=pos, **conv_args
)
x = self.activation_function(feat_layer(c))

#### multi-head decoder part####
# shared dense layers for graph level output
if data.batch is None:
x_graph = x.mean(dim=0, keepdim=True)
else:
x_graph = geom_nn.global_mean_pool(x, data.batch.to(x.device))
outputs = []
outputs_var = []
for head_dim, headloc, type_head in zip(
self.head_dims, self.heads_NN, self.head_type
):
if type_head == "graph":
x_graph_head = self.graph_shared(x_graph)
output_head = headloc(x_graph_head)
outputs.append(output_head[:, :head_dim])
outputs_var.append(output_head[:, head_dim:] ** 2)
else:
if self.node_NN_type == "conv":
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
c, v, pos = conv(x=x, v=v, pos=pos, **conv_args)
c = batch_norm(c)
x = self.activation_function(c)
x_node = x
else:
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node[:, :head_dim])
outputs_var.append(x_node[:, head_dim:] ** 2)
if self.var_output:
return outputs, outputs_var
return outputs

def _conv_args(self, data):
def _embedding(self, data):
assert (
data.pos is not None
), "PNAEq requires node positions (data.pos) to be set."
Expand All @@ -212,7 +157,7 @@ def _conv_args(self, data):
"edge_vec": norm_diff,
}

return data, conv_args
return data.x, data.v, conv_args


class PainnMessage(MessagePassing):
Expand Down
21 changes: 8 additions & 13 deletions hydragnn/models/PNAPlusStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
class PNAPlusStack(Base):
def __init__(
self,
input_args,
conv_args,
deg: list,
edge_dim: int,
envelope_exponent: int,
Expand All @@ -61,7 +63,7 @@ def __init__(
self.num_radial = num_radial
self.radius = radius

super().__init__(*args, **kwargs)
super().__init__(input_args,conv_args, *args, **kwargs)

self.rbf = BesselBasisLayer(
self.num_radial, self.radius, self.envelope_exponent
Expand All @@ -81,22 +83,15 @@ def get_conv(self, input_dim, output_dim):
divide_input=False,
)

input_args = "x, pos, edge_index, rbf"
conv_args = "x, edge_index, rbf"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return PyGSequential(
input_args,
self.input_args,
[
(pna, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(pna, self.conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, geom_feat -> x, geom_feat"),
],
)

def _conv_args(self, data):
def _embedding(self, data):
assert (
data.pos is not None
), "PNA+ requires node positions (data.pos) to be set."
Expand All @@ -113,7 +108,7 @@ def _conv_args(self, data):
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return conv_args
return data.x, data.pos, conv_args

def __str__(self):
return "PNAStack"
Expand Down
6 changes: 6 additions & 0 deletions hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def create_model(
assert num_radial is not None, "PNAPlus requires num_radial input."
assert radius is not None, "PNAPlus requires radius input."
model = PNAPlusStack(
"inv_feat, geom_feat, edge_index, rbf",
"inv_feat, edge_index, rbf",
pna_deg,
edge_dim,
envelope_exponent,
Expand Down Expand Up @@ -388,6 +390,8 @@ def create_model(
elif model_type == "PNAEq":
assert pna_deg is not None, "PNAEq requires degree input."
model = PNAEqStack(
"inv_feat, geom_feat, edge_index, edge_rbf, edge_vec",
"inv_feat, geom_feat, edge_index, edge_rbf, edge_vec",
pna_deg,
edge_dim,
num_radial,
Expand All @@ -414,6 +418,8 @@ def create_model(
assert max_ell >= 1, "MACE requires max_ell >= 1."
assert node_max_ell >= 1, "MACE requires node_max_ell >= 1."
model = MACEStack(
"node_attributes, geom_feat, inv_feat, edge_attributes, edge_features, edge_index",
"node_attributes, edge_attributes, edge_features, edge_index",
radius,
radial_type,
distance_transform,
Expand Down

0 comments on commit 0640baf

Please sign in to comment.