diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py
index 214ab12..9bf26f8 100644
--- a/mtenn/conversion_utils/e3nn.py
+++ b/mtenn/conversion_utils/e3nn.py
@@ -108,6 +108,7 @@ def _get_representation(self, reduce_output=False):
         # Remove last layer
         model_copy.layers = model_copy.layers[:-1]
         model_copy.reduce_output = reduce_output
+        model_copy.irreps_out = model_copy.layers[-1].irreps_out
 
         return model_copy
 
@@ -169,7 +170,10 @@ def _get_concat_strategy(self):
         ConcatStrategy
             ``ConcatStrategy`` for the model
         """
-        return ConcatStrategy(extract_key="x")
+        # Calculate input size as 3 * dimensionality of output of Representation
+        # (last layer in Representation is 2nd to last in original model)
+        input_size = 3 * self.layers[-2].irreps_out.dim
+        return ConcatStrategy(input_size=input_size, extract_key="x")
 
     @staticmethod
     def get_model(
@@ -227,24 +231,21 @@ def get_model(
         if model is None:
             model = E3NN(model_kwargs)
 
+        # Get representation module
+        representation = model._get_representation(reduce_output=strategy == "concat")
+
         # Construct strategy module based on model and
         # representation (if necessary)
         strategy = strategy.lower()
         if strategy == "delta":
             strategy = model._get_delta_strategy()
-            reduce_output = False
         elif strategy == "concat":
             strategy = model._get_concat_strategy()
-            reduce_output = True
         elif strategy == "complex":
             strategy = model._get_complex_only_strategy()
-            reduce_output = False
         else:
             raise ValueError(f"Unknown strategy: {strategy}")
 
-        # Get representation module
-        representation = model._get_representation(reduce_output=reduce_output)
-
         # Check on `combination`
         if grouped and (combination is None):
             raise ValueError(
diff --git a/mtenn/conversion_utils/schnet.py b/mtenn/conversion_utils/schnet.py
index b92944c..a26db02 100644
--- a/mtenn/conversion_utils/schnet.py
+++ b/mtenn/conversion_utils/schnet.py
@@ -143,6 +143,22 @@ def _get_complex_only_strategy(self):
 
         return ComplexOnlyStrategy(self._get_energy_func())
 
+    def _get_concat_strategy(self):
+        """
+        Build a :py:class:`ConcatStrategy <mtenn.strategy.ConcatStrategy>` object with
+        the appropriate ``input_size``.
+
+        Returns
+        -------
+        ConcatStrategy
+            ``ConcatStrategy`` for the model
+        """
+
+        # Calculate input size as 3 * dimensionality of output of Representation
+        # (ie lin1 layer)
+        input_size = 3 * self.lin1.out_features
+        return ConcatStrategy(input_size=input_size)
+
     @staticmethod
     def get_model(
         model=None,
@@ -203,7 +219,7 @@ def get_model(
         if strategy == "delta":
             strategy = model._get_delta_strategy()
         elif strategy == "concat":
-            strategy = ConcatStrategy()
+            strategy = model._get_concat_strategy()
         elif strategy == "complex":
             strategy = model._get_complex_only_strategy()
         else:
diff --git a/mtenn/strategy.py b/mtenn/strategy.py
index 4def3ef..5a4d89c 100644
--- a/mtenn/strategy.py
+++ b/mtenn/strategy.py
@@ -115,18 +115,20 @@ class ConcatStrategy(Strategy):
     initialize a one-layer linear network of the appropriate dimensionality.
     """
 
-    def __init__(self, extract_key=None):
+    def __init__(self, input_size, extract_key=None):
         """
         Set the key to use to access vector representations if ``dict`` s are passed
         to the ``forward`` call.
 
         Parameters
         ----------
+        input_size : int
+            Input size of linear model
         extract_key : str, optional
             Key to use to extract representation from a dict
         """
         super(ConcatStrategy, self).__init__()
-        self.reduce_nn: torch.nn.Module = None
+        self.reduce_nn = torch.nn.Linear(input_size, 1)
         self.extract_key = extract_key
 
     def forward(self, comp, *parts):
@@ -158,17 +160,8 @@ def forward(self, comp, *parts):
         comp = comp.flatten()
         parts = [p.flatten() for p in parts]
 
-        parts_size = sum([len(p) for p in parts])
-        if self.reduce_nn is None:
-            # If we haven't already, initialize a Linear module with appropriate input
-            # size
-            input_size = len(comp) + parts_size
-            self.reduce_nn = torch.nn.Linear(input_size, 1)
-
-        # Move self.reduce_nn to appropriate torch device
-        self.reduce_nn = self.reduce_nn.to(comp.device)
-
         # Enumerate all possible permutations of parts and add together
+        parts_size = sum([len(p) for p in parts])
         parts_cat = torch.zeros((parts_size), device=comp.device)
         for idxs in permutations(range(len(parts)), len(parts)):
            parts_cat += torch.cat([parts[i] for i in idxs])
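Note: a quick way to sanity-check the new `input_size` arithmetic is the standalone sketch below. It mirrors the permutation-sum concatenation done in `ConcatStrategy.forward` (the names `d`, `comp`, `parts`, and `reduce_nn` are illustrative only, not part of mtenn): when the complex representation and each of the two parts (protein and ligand) have length `d`, the vector fed to the linear head has length `3 * d`, which is why the conversion utils now pass `input_size = 3 * <representation dim>`.

import torch
from itertools import permutations

# Illustrative sizes only: each representation (complex, protein, ligand) has length d.
d = 8
comp = torch.randn(d)             # complex representation
parts = [torch.randn(d),          # protein representation
         torch.randn(d)]          # ligand representation

# Permutation-invariant sum over all orderings of the parts, as in ConcatStrategy.forward.
parts_size = sum(len(p) for p in parts)
parts_cat = torch.zeros(parts_size)
for idxs in permutations(range(len(parts)), len(parts)):
    parts_cat += torch.cat([parts[i] for i in idxs])

full_embedding = torch.cat([comp, parts_cat])
assert full_embedding.shape == (3 * d,)  # d (complex) + 2 * d (parts)

# Knowing the input size up front lets the linear head be built in __init__ instead of
# lazily in forward, so it is registered as a submodule and moved/saved with the model.
reduce_nn = torch.nn.Linear(3 * d, 1)
print(reduce_nn(full_embedding).shape)  # torch.Size([1])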