diff --git a/examples/qm9/qm9.json b/examples/qm9/qm9.json
index b4e7c960..4c647b77 100644
--- a/examples/qm9/qm9.json
+++ b/examples/qm9/qm9.json
@@ -5,11 +5,12 @@
     "NeuralNetwork": {
         "Profile": {"enable": 1},
         "Architecture": {
-            "model_type": "SchNet",
+            "model_type": "MACE",
             "radius": 7,
-            "equivariance": true,
             "basis_emb_size": 8,
             "envelope_exponent": 5,
+            "max_ell": 2,
+            "node_max_ell": 2,
             "int_emb_size": 8,
             "out_emb_size": 8,
             "num_after_skip": 2,
diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py
index d61696a6..377693b6 100644
--- a/hydragnn/models/MACEStack.py
+++ b/hydragnn/models/MACEStack.py
@@ -70,6 +70,8 @@ class MACEStack(Base):
     def __init__(
         self,
+        input_args,
+        conv_args,
         r_max: float,  # The cutoff radius for the radial basis functions and edge_index
         radial_type: str,  # The type of radial basis function to use
         distance_transform: str,  # The distance transform to use
@@ -126,7 +128,11 @@ def __init__(
         )  # This makes the irreps string
         self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e")
 
-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)
+
+        if self.use_edge_attr:
+            self.input_args = "node_attributes, geom_feat, inv_feat, edge_attributes, edge_features, edge_index"
+            self.conv_args = "node_attributes, edge_attributes, edge_features, edge_index"  # node_features is not used here because it's passed through in the forward
 
         self.spherical_harmonics = o3.SphericalHarmonics(
             self.sh_irreps,
@@ -335,70 +341,53 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False):
             hidden_irreps_out, output_irreps
         )  # Change sizing to output_irreps
 
-        input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index"
-        conv_args = "node_attributes, edge_attributes, edge_features, edge_index"  # node_features is not used here because it's passed through in the forward
-
         if not last_layer:
             return PyGSequential(
-                input_args,
+                self.input_args,
                 [
-                    (inter, "node_features, " + conv_args + " -> node_features, sc"),
+                    (inter, "inv_feat, " + self.conv_args + " -> node_features, sc"),
                     (prod, "node_features, sc, node_attributes -> node_features"),
                     (sizing, "node_features -> node_features"),
                     (
-                        lambda node_features, pos: [node_features, pos],
-                        "node_features, pos -> node_features, pos",
+                        lambda node_features, geom_feat: [node_features, geom_feat],
+                        "node_features, geom_feat -> node_features, geom_feat",
                     ),
                 ],
             )
         else:
             return PyGSequential(
-                input_args,
+                self.input_args,
                 [
-                    (inter, "node_features, " + conv_args + " -> node_features, sc"),
+                    (inter, "inv_feat, " + self.conv_args + " -> node_features, sc"),
                     (prod, "node_features, sc, node_attributes -> node_features"),
                     (sizing, "node_features -> node_features"),
                     (
-                        lambda node_features, pos: [node_features, pos],
-                        "node_features, pos -> node_features, pos",
+                        lambda node_features, geom_feat: [node_features, geom_feat],
+                        "node_features, geom_feat -> node_features, geom_feat",
                     ),
                 ],
             )
 
     def forward(self, data):
-        data, conv_args = self._conv_args(data)
-        node_features = data.node_features
-        node_attributes = data.node_attributes
-        pos = data.pos
-
-        ### encoder / decoder part ####
-        ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance
-
-        ### There is a readout before the first convolution layer ###
-        outputs = []
-        output = self.multihead_decoders[0](
-            data, node_attributes
-        )  # [index][n_output, size_output]
-        # Create outputs first
-        outputs = output
+        inv_feat, geom_feat, outputs, conv_args = self._embedding(data)
 
         ### Do conv --> readout --> repeat for each convolution layer ###
         for conv, readout in zip(self.graph_convs, self.multihead_decoders[1:]):
             if not self.conv_checkpointing:
-                node_features, pos = conv(
-                    node_features=node_features, pos=pos, **conv_args
+                inv_feat, geom_feat = conv(
+                    inv_feat=inv_feat, geom_feat=geom_feat, **conv_args
                 )
-                output = readout(data, node_features)  # [index][n_output, size_output]
+                output = readout(data, inv_feat)  # [index][n_output, size_output]
             else:
-                node_features, pos = checkpoint(
+                inv_feat, geom_feat = checkpoint(
                     conv,
                     use_reentrant=False,
-                    node_features=node_features,
-                    pos=pos,
+                    inv_feat=inv_feat,
+                    geom_feat=geom_feat,
                     **conv_args,
                 )
                 output = readout(
-                    data, node_features
+                    data, inv_feat
                 )  # output is a list of tensors with [index][n_output, size_output]
             # Sum predictions for each index, taking care of size differences
             for idx, prediction in enumerate(output):
@@ -406,7 +395,7 @@ def forward(self, data):
 
         return outputs
 
-    def _conv_args(self, data):
+    def _embedding(self, data):
         assert (
             data.pos is not None
         ), "MACE requires node positions (data.pos) to be set."
@@ -452,7 +441,15 @@ def _conv_args(self, data):
             "edge_index": data.edge_index,
         }
 
-        return data, conv_args
+        node_attributes = data.node_attributes
+        outputs = []
+        output = self.multihead_decoders[0](
+            data, node_attributes
+        )  # [index][n_output, size_output]
+        # Create outputs first
+        outputs = output
+
+        return data.node_features, data.pos, outputs, conv_args
 
     def _multihead(self):
         # NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer,
diff --git a/hydragnn/models/PNAEqStack.py b/hydragnn/models/PNAEqStack.py
index 8919914f..1ab82dff 100644
--- a/hydragnn/models/PNAEqStack.py
+++ b/hydragnn/models/PNAEqStack.py
@@ -60,7 +60,7 @@ def __init__(
         self.num_radial = num_radial
         self.radius = radius
 
-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)
 
         self.rbf = rbf_BasisLayer(self.num_radial, self.radius)
 
@@ -106,29 +106,22 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
             geom_nn.Linear(input_dim, output_dim) if not last_layer else None
         )
 
-        input_args = "x, v, pos, edge_index, edge_rbf, edge_vec"
-        conv_args = "x, v, edge_index, edge_rbf, edge_vec"
-
-        if self.use_edge_attr:
-            input_args += ", edge_attr"
-            conv_args += ", edge_attr"
-
         if not last_layer:
             return geom_nn.Sequential(
-                input_args,
+                self.input_args,
                 [
-                    (message, conv_args + " -> x, v"),
+                    (message, self.conv_args + " -> x, v"),
                     (update, "x, v -> x, v"),
                     (node_embed_out, "x -> x"),
                     (vec_embed_out, "v -> v"),
-                    (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
+                    (lambda x, v: [x, v], "x, v -> x, v"),
                 ],
             )
         else:
             return geom_nn.Sequential(
-                input_args,
+                self.input_args,
                 [
-                    (message, conv_args + " -> x, v"),
+                    (message, self.conv_args + " -> x, v"),
                     (
                         update,
                         "x, v -> x",
@@ -137,60 +130,12 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
                     ),
                     (
                         node_embed_out,
                         "x -> x",
                     ),  # No need to embed down v because it's not used anymore
-                    (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
+                    (lambda x, v: [x, v], "x, v -> x, v"),
                 ],
             )
 
-    def forward(self, data):
-        data, conv_args = self._conv_args(
-            data
-        )  # Added v to data here (necessary for PNAEq Stack)
-        x = data.x
-        v = data.v
-        pos = data.pos
-
-        ### encoder part ####
-        for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
-            if not self.conv_checkpointing:
-                c, v, pos = conv(x=x, v=v, pos=pos, **conv_args)  # Added v here
-            else:
-                c, v, pos = checkpoint(  # Added v here
-                    conv, use_reentrant=False, x=x, v=v, pos=pos, **conv_args
-                )
-            x = self.activation_function(feat_layer(c))
-        #### multi-head decoder part####
-        # shared dense layers for graph level output
-        if data.batch is None:
-            x_graph = x.mean(dim=0, keepdim=True)
-        else:
-            x_graph = geom_nn.global_mean_pool(x, data.batch.to(x.device))
-        outputs = []
-        outputs_var = []
-        for head_dim, headloc, type_head in zip(
-            self.head_dims, self.heads_NN, self.head_type
-        ):
-            if type_head == "graph":
-                x_graph_head = self.graph_shared(x_graph)
-                output_head = headloc(x_graph_head)
-                outputs.append(output_head[:, :head_dim])
-                outputs_var.append(output_head[:, head_dim:] ** 2)
-            else:
-                if self.node_NN_type == "conv":
-                    for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
-                        c, v, pos = conv(x=x, v=v, pos=pos, **conv_args)
-                        c = batch_norm(c)
-                        x = self.activation_function(c)
-                    x_node = x
-                else:
-                    x_node = headloc(x=x, batch=data.batch)
-                outputs.append(x_node[:, :head_dim])
-                outputs_var.append(x_node[:, head_dim:] ** 2)
-        if self.var_output:
-            return outputs, outputs_var
-        return outputs
-
-    def _conv_args(self, data):
+    def _embedding(self, data):
         assert (
             data.pos is not None
         ), "PNAEq requires node positions (data.pos) to be set."
@@ -212,7 +157,7 @@ def _conv_args(self, data):
             "edge_vec": norm_diff,
         }
 
-        return data, conv_args
+        return data.x, data.v, conv_args
 
 
 class PainnMessage(MessagePassing):
diff --git a/hydragnn/models/PNAPlusStack.py b/hydragnn/models/PNAPlusStack.py
index 06561d6d..276ecb42 100644
--- a/hydragnn/models/PNAPlusStack.py
+++ b/hydragnn/models/PNAPlusStack.py
@@ -39,6 +39,8 @@ class PNAPlusStack(Base):
     def __init__(
         self,
+        input_args,
+        conv_args,
         deg: list,
         edge_dim: int,
         envelope_exponent: int,
@@ -61,7 +63,7 @@ def __init__(
         self.num_radial = num_radial
         self.radius = radius
 
-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)
 
         self.rbf = BesselBasisLayer(
             self.num_radial, self.radius, self.envelope_exponent
         )
@@ -81,22 +83,15 @@ def get_conv(self, input_dim, output_dim):
             divide_input=False,
         )
 
-        input_args = "x, pos, edge_index, rbf"
-        conv_args = "x, edge_index, rbf"
-
-        if self.use_edge_attr:
-            input_args += ", edge_attr"
-            conv_args += ", edge_attr"
-
         return PyGSequential(
-            input_args,
+            self.input_args,
            [
-                (pna, conv_args + " -> x"),
-                (lambda x, pos: [x, pos], "x, pos -> x, pos"),
+                (pna, self.conv_args + " -> x"),
+                (lambda x, pos: [x, pos], "x, geom_feat -> x, geom_feat"),
            ],
        )
 
-    def _conv_args(self, data):
+    def _embedding(self, data):
         assert (
             data.pos is not None
         ), "PNA+ requires node positions (data.pos) to be set."
@@ -113,7 +108,7 @@ def _conv_args(self, data):
             ), "Data must have edge attributes if use_edge_attributes is set."
             conv_args.update({"edge_attr": data.edge_attr})
 
-        return conv_args
+        return data.x, data.pos, conv_args
 
     def __str__(self):
         return "PNAStack"
diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py
index f5be2c0b..7f22e007 100644
--- a/hydragnn/models/create.py
+++ b/hydragnn/models/create.py
@@ -174,6 +174,8 @@ def create_model(
         assert num_radial is not None, "PNAPlus requires num_radial input."
         assert radius is not None, "PNAPlus requires radius input."
         model = PNAPlusStack(
+            "inv_feat, geom_feat, edge_index, rbf",
+            "inv_feat, edge_index, rbf",
             pna_deg,
             edge_dim,
             envelope_exponent,
@@ -388,6 +390,8 @@ def create_model(
     elif model_type == "PNAEq":
         assert pna_deg is not None, "PNAEq requires degree input."
         model = PNAEqStack(
+            "inv_feat, geom_feat, edge_index, edge_rbf, edge_vec",
+            "inv_feat, geom_feat, edge_index, edge_rbf, edge_vec",
            pna_deg,
            edge_dim,
            num_radial,
@@ -414,6 +418,8 @@ def create_model(
         assert max_ell >= 1, "MACE requires max_ell >= 1."
         assert node_max_ell >= 1, "MACE requires node_max_ell >= 1."
         model = MACEStack(
+            "node_attributes, geom_feat, inv_feat, edge_attributes, edge_features, edge_index",
+            "node_attributes, edge_attributes, edge_features, edge_index",
             radius,
             radial_type,
             distance_transform,
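
Note on the signature strings passed to the model constructors above: they are consumed by torch_geometric.nn.Sequential (PyGSequential / geom_nn.Sequential in the diff), which routes tensors between layers by the names given in these strings. A minimal standalone sketch of that routing pattern follows; the layer type, channel sizes, and toy tensors are illustrative placeholders, not taken from HydraGNN.

import torch
from torch_geometric.nn import GCNConv, Sequential

# Signature strings in the style of the refactor: the full block signature,
# and the subset of arguments the message-passing layer actually consumes.
input_args = "inv_feat, geom_feat, edge_index"
conv_args = "inv_feat, edge_index"

block = Sequential(
    input_args,
    [
        # Message passing updates only the invariant node features.
        (GCNConv(16, 16), conv_args + " -> inv_feat"),
        # Geometry is passed through unchanged, mirroring the lambda pattern
        # used in the get_conv methods above.
        (
            lambda inv_feat, geom_feat: [inv_feat, geom_feat],
            "inv_feat, geom_feat -> inv_feat, geom_feat",
        ),
    ],
)

inv_feat = torch.randn(4, 16)  # 4 nodes, 16 invariant channels
geom_feat = torch.randn(4, 3)  # node positions
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
inv_feat, geom_feat = block(inv_feat, geom_feat, edge_index)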