diff --git a/examples/LennardJones/LJ.json b/examples/LennardJones/LJ.json index 022a1398b..27450da55 100644 --- a/examples/LennardJones/LJ.json +++ b/examples/LennardJones/LJ.json @@ -21,7 +21,7 @@ "periodic_boundary_conditions": true, "global_attn_engine": "", "global_attn_type": "", - "mpnn_type": "EGNN", + "mpnn_type": "DimeNet", "radius": 5.0, "max_neighbours": 5, "int_emb_size": 32, @@ -61,7 +61,7 @@ "output_names": ["graph_energy"] }, "Training": { - "num_epoch": 15, + "num_epoch": 25, "batch_size": 64, "perc_train": 0.7, "patience": 20, diff --git a/examples/LennardJones/LJ_data.py b/examples/LennardJones/LJ_data.py index 83a868eac..ce1a607fa 100644 --- a/examples/LennardJones/LJ_data.py +++ b/examples/LennardJones/LJ_data.py @@ -21,6 +21,7 @@ import torch from torch_geometric.data import Data from torch_geometric.transforms import AddLaplacianEigenvectorPE +from torch_scatter import scatter # torch.set_default_tensor_type(torch.DoubleTensor) # torch.set_default_dtype(torch.float64) @@ -36,6 +37,7 @@ from hydragnn.utils.datasets.abstractrawdataset import AbstractBaseDataset from hydragnn.utils.distributed import nsplit from hydragnn.preprocess.graph_samples_checks_and_updates import get_radius_graph_pbc +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths # Angstrom unit primitive_bravais_lattice_constant_x = 3.8 @@ -51,6 +53,7 @@ def create_dataset(path, config): radius_cutoff = config["NeuralNetwork"]["Architecture"]["radius"] + max_num_neighbors = config["NeuralNetwork"]["Architecture"]["max_neighbours"] number_configurations = ( config["Dataset"]["number_configurations"] if "number_configurations" in config["Dataset"] @@ -73,6 +76,7 @@ def create_dataset(path, config): atom_types, atomic_structure_handler=atomic_structure_handler, radius_cutoff=radius_cutoff, + max_num_neighbors=max_num_neighbors, relative_maximum_atomic_displacement=1e-1, number_configurations=number_configurations, ) @@ -167,7 +171,7 @@ def transform_input_to_data_object_base(self, filepath): forces_pre_scaled = forces * forces_pre_scaling_factor data = Data( - supercell_size=torch_supercell.to(torch.float32), + cell=torch_supercell.to(torch.float32), num_nodes=num_nodes, grad_energy_post_scaling_factor=grad_energy_post_scaling_factor, forces_pre_scaling_factor=torch.tensor(forces_pre_scaling_factor).to( @@ -182,11 +186,14 @@ def transform_input_to_data_object_base(self, filepath): .unsqueeze(0) .to(torch.float32), energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32), - pbc=[ - True, - True, - True, - ], # LJ example always has periodic boundary conditions + pbc=torch.tensor( + [ + True, + True, + True, + ], + dtype=torch.bool, + ), # LJ example always has periodic boundary conditions ) # Create pbc edges and lengths @@ -345,37 +352,28 @@ def create_configuration( supercell_size_x = primitive_bravais_lattice_constant_x * uc_x supercell_size_y = primitive_bravais_lattice_constant_y * uc_y supercell_size_z = primitive_bravais_lattice_constant_z * uc_z - data.supercell_size = torch.diag( + data.cell = torch.diag( torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z]) ) - data.pbc = [True, True, True] + data.pbc = torch.tensor([True, True, True], dtype=torch.bool) + data.x = torch.cat([atom_types, positions], dim=1) create_graph_connectivity_pbc = get_radius_graph_pbc( radius_cutoff, max_num_neighbors ) data = create_graph_connectivity_pbc(data) - atomic_descriptors = torch.cat( - ( - atom_types, - positions, - ), - 1, - ) - - data.x = atomic_descriptors - data = 
atomic_structure_handler.compute(data) total_energy = torch.sum(data.x[:, 4]) energy_per_atom = total_energy / number_nodes - total_energy_str = numpy.array2string(total_energy.detach().numpy()) - energy_per_atom_str = numpy.array2string(energy_per_atom.detach().numpy()) + total_energy_str = numpy.array2string(total_energy.detach().cpu().numpy()) + energy_per_atom_str = numpy.array2string(energy_per_atom.detach().cpu().numpy()) filetxt = total_energy_str + "\n" + energy_per_atom_str for index in range(0, 3): - numpy_row = data.supercell_size[index, :].detach().numpy() + numpy_row = data.cell[index, :].detach().numpy() numpy_string_row = numpy.array2string(numpy_row, precision=64, separator="\t") filetxt += "\n" + numpy_string_row.lstrip("[").rstrip("]") @@ -402,73 +400,47 @@ def __init__( self.radius_cutoff = radius_cutoff self.formula = formula + # Calculate the potential energy with torch gradient tracking, then simply use autograd to calculate the forces def compute(self, data): + # Instantiate assert data.pos.shape[0] == data.x.shape[0] - - interatomic_potential = torch.zeros([data.pos.shape[0], 1]) - interatomic_forces = torch.zeros([data.pos.shape[0], 3]) - - for node_id in range(data.pos.shape[0]): - neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[ - 0 - ].tolist() - neighbor_list = data.edge_index[1, neighbor_list_indices] - - for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices): - neighbor_pos = data.pos[neighbor_id, :] - distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :] - - # Adjust the neighbor position based on periodic boundary conditions (PBC) - ## If the distance between the atoms is larger than the cutoff radius, the edge is because of PBC conditions - if torch.norm(distance_vector) > self.radius_cutoff: - ## At this point, we know that the edge is due to PBC conditions, so we need to adjust the neighbor position. We also know that - ## that this connection MUST be the closest connection possible as a result of the asserted radius_cutoff < supercell_size earlier - ## in the code. Because of this, we can simply adjust the neighbor position coordinate-wise to be closer than - ## as done in the following lines of code. 
The logic goes that if the distance vector[index] is larger than half the supercell size, - ## then there is a closer distance at +- supercell_size[index], and we adjust to that for each coordinate - if abs(distance_vector[0]) > data.supercell_size[0, 0] / 2: - if distance_vector[0] > 0: - neighbor_pos[0] -= data.supercell_size[0, 0] - else: - neighbor_pos[0] += data.supercell_size[0, 0] - - if abs(distance_vector[1]) > data.supercell_size[1, 1] / 2: - if distance_vector[1] > 0: - neighbor_pos[1] -= data.supercell_size[1, 1] - else: - neighbor_pos[1] += data.supercell_size[1, 1] - - if abs(distance_vector[2]) > data.supercell_size[2, 2] / 2: - if distance_vector[2] > 0: - neighbor_pos[2] -= data.supercell_size[2, 2] - else: - neighbor_pos[2] += data.supercell_size[2, 2] - - # The distance vecor may need to be updated after applying PBCs - distance_vector = data.pos[node_id, :] - neighbor_pos - - # pair_distance = data.edge_attr[edge_id].item() - interatomic_potential[node_id] += self.formula.potential_energy( - distance_vector - ) - - derivative_x = self.formula.derivative_x(distance_vector) - derivative_y = self.formula.derivative_y(distance_vector) - derivative_z = self.formula.derivative_z(distance_vector) - - interatomic_forces_contribution_x = -derivative_x - interatomic_forces_contribution_y = -derivative_y - interatomic_forces_contribution_z = -derivative_z - - interatomic_forces[node_id, 0] += interatomic_forces_contribution_x - interatomic_forces[node_id, 1] += interatomic_forces_contribution_y - interatomic_forces[node_id, 2] += interatomic_forces_contribution_z - - data.x = torch.cat( - (data.x, interatomic_potential, interatomic_forces), - 1, + node_potential = torch.zeros([data.pos.shape[0], 1]) + node_forces = torch.zeros([data.pos.shape[0], 3]) + + # Calculate + data.pos.requires_grad = True + edge_vec, edge_dist = get_edge_vectors_and_lengths( + positions=data.pos, + edge_index=data.edge_index, + shifts=data.edge_shifts, + normalize=False, ) + # Sum potential by edge, node, and total + edge_potential = self.formula.potential_energy( + edge_dist + ) # Shape [num_edges, 1] + node_potential = scatter( + edge_potential, + data.edge_index[0], + dim=0, + dim_size=data.pos.shape[0], + reduce="add", + ) # Shape [num_nodes, 1] + total_potential = torch.sum(node_potential, dim=0, keepdim=True) # Shape [1] + + # Autograd to calculate forces + node_forces = -torch.autograd.grad( + total_potential, + data.pos, + grad_outputs=torch.ones_like(total_potential), + )[ + 0 + ] # Shape [num_nodes, 3] + + # Append to data + data.x = torch.cat((data.x, node_potential, node_forces), dim=1) + return data @@ -477,40 +449,13 @@ def __init__(self, epsilon, sigma): self.epsilon = epsilon self.sigma = sigma - def potential_energy(self, distance_vector): - pair_distance = torch.norm(distance_vector) + def potential_energy(self, pair_distance): return ( 4 * self.epsilon * ((self.sigma / pair_distance) ** 12 - (self.sigma / pair_distance) ** 6) ) - def radial_derivative(self, distance_vector): - pair_distance = torch.norm(distance_vector) - return ( - 4 - * self.epsilon - * ( - -12 * (self.sigma / pair_distance) ** 12 * 1 / pair_distance - + 6 * (self.sigma / pair_distance) ** 6 * 1 / pair_distance - ) - ) - - def derivative_x(self, distance_vector): - pair_distance = torch.norm(distance_vector) - radial_derivative = self.radial_derivative(pair_distance) - return radial_derivative * (distance_vector[0].item()) / pair_distance - - def derivative_y(self, distance_vector): - pair_distance = 
torch.norm(distance_vector) - radial_derivative = self.radial_derivative(pair_distance) - return radial_derivative * (distance_vector[1].item()) / pair_distance - - def derivative_z(self, distance_vector): - pair_distance = torch.norm(distance_vector) - radial_derivative = self.radial_derivative(pair_distance) - return radial_derivative * (distance_vector[2].item()) / pair_distance - """Etc""" diff --git a/examples/LennardJones/LJ_inference_plots.py b/examples/LennardJones/LJ_inference_plots.py index 9f5d00cac..f675781ff 100644 --- a/examples/LennardJones/LJ_inference_plots.py +++ b/examples/LennardJones/LJ_inference_plots.py @@ -23,12 +23,13 @@ import hydragnn from hydragnn.utils.profiling_and_tracing.time_utils import Timer -from hydragnn.utils.distributed import get_device +from hydragnn.utils.distributed import get_device, setup_ddp from hydragnn.utils.model import load_existing_model from hydragnn.utils.datasets.pickledataset import SimplePickleDataset from hydragnn.utils.input_config_parsing.config_utils import ( update_config, ) +from hydragnn.utils.print import setup_log from hydragnn.models.create import create_model_config from hydragnn.preprocess import create_dataloaders @@ -42,13 +43,14 @@ from LJ_data import info import matplotlib.pyplot as plt +from sklearn.metrics import r2_score plt.rcParams.update({"font.size": 16}) def get_log_name_config(config): return ( - config["NeuralNetwork"]["Architecture"]["model_type"] + config["NeuralNetwork"]["Architecture"]["mpnn_type"] + "-r-" + str(config["NeuralNetwork"]["Architecture"]["radius"]) + "-ncl-" @@ -132,10 +134,10 @@ def getcolordensity(xdata, ydata): input_filename = os.path.join(dirpwd, args.inputfile) with open(input_filename, "r") as f: config = json.load(f) - hydragnn.utils.setup_log(get_log_name_config(config)) + setup_log(get_log_name_config(config)) ################################################################################################################## # Always initialize for multi-rank training. 
- comm_size, rank = hydragnn.utils.setup_ddp() + comm_size, rank = setup_ddp() ################################################################################################################## comm = MPI.COMM_WORLD @@ -179,11 +181,6 @@ def getcolordensity(xdata, ydata): load_existing_model(model, modelname, path="./logs/") model.eval() - variable_index = 0 - # for output_name, output_type, output_dim in zip(config["NeuralNetwork"]["Variables_of_interest"]["output_names"], config["NeuralNetwork"]["Variables_of_interest"]["type"], config["NeuralNetwork"]["Variables_of_interest"]["output_dim"]): - - test_MAE = 0.0 - num_samples = len(testset) energy_true_list = [] energy_pred_list = [] @@ -196,9 +193,6 @@ def getcolordensity(xdata, ydata): 0 ] # Note that this is sensitive to energy and forces prediction being single-task (current requirement) energy_pred = torch.sum(node_energy_pred, dim=0).float() - test_MAE += torch.norm(energy_pred - data.energy, p=1).item() / len(testset) - # predicted.backward(retain_graph=True) - # gradients = data.pos.grad grads_energy = torch.autograd.grad( outputs=energy_pred, inputs=data.pos, @@ -211,6 +205,14 @@ def getcolordensity(xdata, ydata): forces_pred_list.extend((-grads_energy).flatten().tolist()) forces_true_list.extend(data.forces.flatten().tolist()) + # Show R2 Metrics + print( + f"R2 energy: ", r2_score(np.array(energy_true_list), np.array(energy_pred_list)) + ) + print( + f"R2 forces: ", r2_score(np.array(forces_true_list), np.array(forces_pred_list)) + ) + hist2d_norm = getcolordensity(energy_true_list, energy_pred_list) fig, ax = plt.subplots() @@ -225,8 +227,6 @@ def getcolordensity(xdata, ydata): plt.tight_layout() plt.savefig(f"./energy_Scatterplot" + ".png", dpi=400) - print(f"Test MAE energy: ", test_MAE) - hist2d_norm = getcolordensity(forces_pred_list, forces_true_list) fig, ax = plt.subplots() plt.scatter(forces_pred_list, forces_true_list, s=8, c=hist2d_norm, vmin=0, vmax=1) diff --git a/examples/alexandria/train.py b/examples/alexandria/train.py index b9cb52c90..73886a5f0 100644 --- a/examples/alexandria/train.py +++ b/examples/alexandria/train.py @@ -84,6 +84,7 @@ def __init__(self, dirpath, var_config, energy_per_atom=True, dist=False): self.energy_per_atom = energy_per_atom self.radius_graph = RadiusGraph(5.0, loop=False, max_num_neighbors=50) + self.radius_graph_pbc = RadiusGraphPBC(5.0, loop=False, max_num_neighbors=50) list_dirs = list_directories( os.path.join(dirpath, "compressed_data", "alexandria.icams.rub.de") @@ -149,6 +150,13 @@ def get_magmoms_array_from_structure(structure): print(f"Structure {entry_id} does not have cell", flush=True) return data_object + pbc = None + try: + pbc = structure["lattice"]["pbc"] + except: + print(f"Structure {entry_id} does not have pbc", flush=True) + return data_object + atomic_numbers = None try: atomic_numbers = ( @@ -236,6 +244,7 @@ def get_magmoms_array_from_structure(structure): data_object = Data( pos=pos, cell=cell, + pbc=pbc, atomic_numbers=atomic_numbers, forces=forces, # entry_id=entry_id, @@ -260,7 +269,18 @@ def get_magmoms_array_from_structure(structure): [data_object.atomic_numbers, data_object.pos, data_object.forces], dim=1 ) - data_object = self.radius_graph(data_object) + if data_object.pbc is not None and data_object.cell is not None: + try: + data_object = self.radius_graph_pbc(data_object) + except: + print( + f"Structure {entry_id} could not successfully apply pbc radius graph", + flush=True, + ) + data_object = self.radius_graph(data_object) + else: + data_object = 
self.radius_graph(data_object) + data_object = transform_coordinates(data_object) return data_object diff --git a/examples/mptrj/train.py b/examples/mptrj/train.py index 8dee83db6..856bec77d 100644 --- a/examples/mptrj/train.py +++ b/examples/mptrj/train.py @@ -70,6 +70,7 @@ def __init__( self.energy_per_atom = energy_per_atom self.radius_graph = RadiusGraph(5.0, loop=False, max_num_neighbors=50) + self.radius_graph_pbc = RadiusGraphPBC(5.0, loop=False, max_num_neighbors=50) self.dist = dist if self.dist: @@ -121,9 +122,20 @@ def __init__( info["magmom"] = k["magmom"] # Convert lists to PyTorch tensors - lattice_mat = torch.tensor( - info["atoms"]["lattice_mat"], dtype=torch.float32 - ) + lattice_mat = None + try: + lattice_mat = torch.tensor( + info["atoms"]["lattice_mat"], dtype=torch.float32 + ) + except: + print(f"Structure does not have lattice_mat", flush=True) + + pbc = None + try: + pbc = info["atoms"]["pbc"] + except: + print(f"Structure does not have pbc", flush=True) + coords = torch.tensor(info["atoms"]["coords"], dtype=torch.float32) # Multiply 'lattice_mat' by the transpose of 'coords' @@ -150,7 +162,8 @@ def __init__( # Creating the Data object data = Data( - supercell_size=lattice_mat, + cell=lattice_mat, + pbc=pbc, energy=energy, force=forces, # stress=torch.tensor(stresses, dtype=torch.float32), @@ -161,7 +174,18 @@ def __init__( y=energy, ) - data = self.radius_graph(data) + if data.pbc is not None and data.cell is not None: + try: + data = self.radius_graph_pbc(data) + except: + print( + f"Structure could not successfully apply pbc radius graph", + flush=True, + ) + data = self.radius_graph(data) + else: + data = self.radius_graph(data) + data = transform_coordinates(data) if self.check_forces_values(data.force): self.dataset.append(data) diff --git a/examples/omat24/train.py b/examples/omat24/train.py index ac32973b7..47dbe0ac3 100644 --- a/examples/omat24/train.py +++ b/examples/omat24/train.py @@ -24,7 +24,10 @@ SimplePickleWriter, SimplePickleDataset, ) -from hydragnn.preprocess.graph_samples_checks_and_updates import gather_deg +from hydragnn.preprocess.graph_samples_checks_and_updates import ( + RadiusGraphPBC, + gather_deg, +) from hydragnn.preprocess.load_data import split_dataset import hydragnn.utils.profiling_and_tracing.tracer as tr @@ -50,7 +53,12 @@ def info(*args, logtype="info", sep=" "): # FIXME: this radis cutoff overwrites the radius cutoff currently written in the JSON file -create_graph_fromXYZ = RadiusGraph(r=5.0) # radius cutoff in angstrom +create_graph_fromXYZ = RadiusGraph( + r=5.0, max_num_neighbors=50 +) # radius cutoff in angstrom +create_graph_fromXYZPBC = RadiusGraphPBC( + r=5.0, max_num_neighbors=50 +) # radius cutoff in angstrom compute_edge_lengths = Distance(norm=False, cat=True) @@ -128,12 +136,51 @@ def __init__( ) chemical_formula = dataset.get_atoms(index).get_chemical_formula() + cell = None + try: + cell = torch.tensor( + dataset.get_atoms(index).get_cell(), dtype=torch.float32 + ).view(3, 3) + except: + print( + f"Atomic structure {chemical_formula} does not have cell", + flush=True, + ) + + pbc = None + try: + pbc = dataset.get_atoms(index).get_pbc() + except: + print( + f"Atomic structure {chemical_formula} does not have pbc", + flush=True, + ) + if self.energy_per_atom: energy /= natoms.item() - data = Data(pos=xyz, x=Z, force=forces, energy=energy, y=energy) + data = Data( + pos=xyz, + cell=cell, + pbc=pbc, + x=Z, + force=forces, + energy=energy, + y=energy, + ) data.x = torch.cat((data.x, xyz, forces), dim=1) - data = 
create_graph_fromXYZ(data)
+
+            if data.pbc is not None and data.cell is not None:
+                try:
+                    data = create_graph_fromXYZPBC(data)
+                except:
+                    print(
+                        f"Structure could not successfully apply pbc radius graph",
+                        flush=True,
+                    )
+                    data = create_graph_fromXYZ(data)
+            else:
+                data = create_graph_fromXYZ(data)
 
             # Add edge length as edge feature
             data = compute_edge_lengths(data)
diff --git a/examples/open_catalyst_2020/train.py b/examples/open_catalyst_2020/train.py
index 74577f1e6..f7b14a715 100644
--- a/examples/open_catalyst_2020/train.py
+++ b/examples/open_catalyst_2020/train.py
@@ -89,7 +89,7 @@ def __init__(
             print(self.rank, "WARN: No files to process. Continue ...")
 
         # Initialize feature extractor.
-        a2g = AtomsToGraphs(max_neigh=50, radius=6, r_pbc=False)
+        a2g = AtomsToGraphs(max_neigh=50, radius=6.0)
 
         list_atomistic_structures = write_images_to_adios(
             a2g,
@@ -110,7 +110,6 @@ def __init__(
             random.shuffle(self.dataset)
 
     def check_forces_values(self, forces):
-
         # Calculate the L2 norm for each row
         norms = torch.norm(forces, p=2, dim=1)
         # Check if all norms are less than the threshold
diff --git a/examples/open_catalyst_2020/utils/atoms_to_graphs.py b/examples/open_catalyst_2020/utils/atoms_to_graphs.py
index a0edc97ce..b3621a0db 100644
--- a/examples/open_catalyst_2020/utils/atoms_to_graphs.py
+++ b/examples/open_catalyst_2020/utils/atoms_to_graphs.py
@@ -51,21 +51,19 @@ class AtomsToGraphs:
     def __init__(
         self,
         max_neigh=200,
-        radius=6,
-        r_pbc=False,
+        radius=6.0,
     ):
         self.max_neigh = max_neigh
         self.radius = radius
-        self.r_pbc = r_pbc
 
-        if self.r_pbc:
-            self.radius_graph = RadiusGraphPBC(
-                self.radius, loop=False, max_num_neighbors=self.max_neigh
-            )
-        else:
-            self.radius_graph = RadiusGraph(
-                self.radius, loop=False, max_num_neighbors=self.max_neigh
-            )
+        # NOTE Open Catalyst 2020 dataset has PBC:
+        # https://pubs.acs.org/doi/10.1021/acscatal.0c04525#_i3 (Section 2: Tasks, paragraph 2)
+        self.radius_graph = RadiusGraph(
+            self.radius, loop=False, max_num_neighbors=self.max_neigh
+        )
+        self.radius_graph_pbc = RadiusGraphPBC(
+            self.radius, loop=False, max_num_neighbors=self.max_neigh
+        )
 
     def convert(
         self,
@@ -86,15 +84,27 @@ def convert(
         # set the atomic numbers, positions, and cell
         atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()).unsqueeze(1)
         positions = torch.Tensor(atoms.get_positions())
-        cell = torch.Tensor(np.array(atoms.get_cell())).view(1, 3, 3)
         natoms = torch.IntTensor([positions.shape[0]])
         # initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags tags = torch.Tensor(atoms.get_tags()) + cell = None + try: + cell = torch.Tensor(np.array(atoms.get_cell())).view(3, 3) + except: + print(f"Structure does not have cell", flush=True) + + pbc = None + try: + pbc = atoms.get_pbc() + except: + print(f"Structure does not have pbc", flush=True) + # put the minimum data in torch geometric data object data = Data( cell=cell, + pbc=pbc, pos=positions, atomic_numbers=atomic_numbers, natoms=natoms, @@ -113,7 +123,18 @@ def convert( data.x = torch.cat((atomic_numbers, positions, forces), dim=1) - data = self.radius_graph(data) + if data.pbc is not None and data.cell is not None: + try: + data = self.radius_graph_pbc(data) + except: + print( + f"Structure could not successfully apply pbc radius graph", + flush=True, + ) + data = self.radius_graph(data) + else: + data = self.radius_graph(data) + data = transform_coordinates(data) return data diff --git a/examples/open_catalyst_2022/train.py b/examples/open_catalyst_2022/train.py index 2029f263b..a6e4a253a 100644 --- a/examples/open_catalyst_2022/train.py +++ b/examples/open_catalyst_2022/train.py @@ -38,7 +38,6 @@ RadiusGraph, RadiusGraphPBC, ) - from ase.io import read try: @@ -66,7 +65,6 @@ def __init__( data_type, energy_per_atom=True, dist=False, - r_pbc=False, ): super().__init__() @@ -74,11 +72,10 @@ def __init__( self.data_path = dirpath self.energy_per_atom = energy_per_atom - self.r_pbc = r_pbc - if self.r_pbc: - self.radius_graph = RadiusGraphPBC(6.0, loop=False, max_num_neighbors=50) - else: - self.radius_graph = RadiusGraph(6.0, loop=False, max_num_neighbors=50) + # NOTE Open Catalyst 2022 dataset has PBC: + # https://pubs.acs.org/doi/10.1021/acscatal.2c05426 (Section: Tasks, paragraph 3) + self.radius_graph = RadiusGraph(6.0, loop=False, max_num_neighbors=50) + self.radius_graph_pbc = RadiusGraphPBC(6.0, loop=False, max_num_neighbors=50) # Threshold for atomic forces in eV/angstrom self.forces_norm_threshold = 100.0 @@ -126,15 +123,27 @@ def ase_to_torch_geom(self, atoms): # set the atomic numbers, positions, and cell atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()).unsqueeze(1) positions = torch.Tensor(atoms.get_positions()) - cell = torch.Tensor(np.array(atoms.get_cell())).view(1, 3, 3) natoms = torch.IntTensor([positions.shape[0]]) # initialized to torch.zeros(natoms) if tags missing. 
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags tags = torch.Tensor(atoms.get_tags()) + cell = None + try: + cell = torch.Tensor(np.array(atoms.get_cell())).view(3, 3) + except: + print(f"Structure does not have cell", flush=True) + + pbc = None + try: + pbc = atoms.get_pbc() + except: + print(f"Structure does not have pbc", flush=True) + # put the minimum data in torch geometric data object data = Data( cell=cell, + pbc=pbc, pos=positions, atomic_numbers=atomic_numbers, natoms=natoms, @@ -153,7 +162,18 @@ def ase_to_torch_geom(self, atoms): data.x = torch.cat((atomic_numbers, positions, forces), dim=1) - data = self.radius_graph(data) + if data.pbc is not None and data.cell is not None: + try: + data = self.radius_graph_pbc(data) + except: + print( + f"Structure could not successfully apply pbc radius graph", + flush=True, + ) + data = self.radius_graph(data) + else: + data = self.radius_graph(data) + data = transform_coordinates(data) return data @@ -169,7 +189,6 @@ def traj_to_torch_geom(self, traj_file): return data_list def check_forces_values(self, forces): - # Calculate the L2 norm for each row norms = torch.norm(forces, p=2, dim=1) # Check if all norms are less than the threshold diff --git a/hydragnn/preprocess/cfg_raw_dataset_loader.py b/hydragnn/preprocess/cfg_raw_dataset_loader.py index 32f44c867..d0029976c 100644 --- a/hydragnn/preprocess/cfg_raw_dataset_loader.py +++ b/hydragnn/preprocess/cfg_raw_dataset_loader.py @@ -75,7 +75,7 @@ def __transform_ASE_object_to_data_object(self, filepath): data_object = Data() - data_object.supercell_size = tensor(ase_object.cell.array).float() + data_object.cell = tensor(ase_object.cell.array).float() data_object.pos = tensor(ase_object.arrays["positions"]).float() proton_numbers = np.expand_dims(ase_object.arrays["numbers"], axis=1) masses = np.expand_dims(ase_object.arrays["masses"], axis=1) diff --git a/hydragnn/preprocess/graph_samples_checks_and_updates.py b/hydragnn/preprocess/graph_samples_checks_and_updates.py index d8146ea5f..bf7156198 100644 --- a/hydragnn/preprocess/graph_samples_checks_and_updates.py +++ b/hydragnn/preprocess/graph_samples_checks_and_updates.py @@ -11,14 +11,16 @@ import torch from torch_geometric.transforms import RadiusGraph -from torch_geometric.utils import remove_self_loops, degree from torch_geometric.data import Data +from torch_geometric.utils import remove_self_loops, degree import ase import ase.neighborlist +import numpy as np import os from .dataset_descriptors import AtomFeatures +from hydragnn.utils.distributed import get_device ## This function can be slow if datasets is too large. Use with caution. @@ -134,29 +136,18 @@ def get_radius_graph_pbc_config(config, loop=False): class RadiusGraphPBC(RadiusGraph): r"""Creates edges based on node positions :obj:`pos` to all points within a - given distance, including periodic images. + given distance, including periodic images, and limits the number of neighbors per node. """ def __call__(self, data): - data.edge_attr = None - data.edge_shifts = None - assert ( - "batch" not in data - ), "Periodic boundary conditions not currently supported on batches." - assert hasattr( - data, "supercell_size" - ), "The data must contain the size of the supercell to apply periodic boundary conditions." - assert hasattr( - data, "pbc" - ), "The data must contain data.pbc as a bool (True) or list of bools for the dimensions ([True, False, True]) to apply periodic boundary conditions." 
- # NOTE Cutoff radius being less than half the smallest supercell dimension is a sufficient, but not necessary condition for no dupe connections. - # However, to prevent an issue from being unobserved until long into an experiment, we assert this condition. - assert ( - self.r < min(torch.diagonal(data.supercell_size)) / 2 - ), "Cutoff radius must be smaller than half the smallest supercell dimension." + # Checks for attributes and ensures data type and device consistency + data, device, dtype = self._check_and_standardize_data( + data + ) # dtype gives us whether to use float32 or float64 + ase_atom_object = ase.Atoms( positions=data.pos, - cell=data.supercell_size, + cell=data.cell, pbc=data.pbc, ) # 'i' : first atom index @@ -170,31 +161,126 @@ def __call__(self, data): edge_length, edge_cell_shifts, ) = ase.neighborlist.neighbor_list( - "ijdS", a=ase_atom_object, cutoff=self.r, self_interaction=self.loop - ) - data.edge_index = torch.stack( - [torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], - dim=0, # Shape: [2, n_edges] + "ijdS", + a=ase_atom_object, + cutoff=self.r, + self_interaction=True, # We want self-interactions across periodic boundaries ) - # ensure no duplicate edges - unique_edge_index, unique_indices = torch.unique( - data.edge_index, dim=1, return_inverse=False # Shape: [n_edges] + # Eliminate true self-loops if desired + if not self.loop: + ( + edge_src, + edge_dst, + edge_length, + edge_cell_shifts, + ) = self._remove_true_self_loops( + edge_src, edge_dst, edge_length, edge_cell_shifts + ) + + # Limit neighbors per node + edge_src, edge_dst, edge_length, edge_cell_shifts = self._limit_neighbors( + edge_src, edge_dst, edge_length, edge_cell_shifts, self.max_num_neighbors ) - assert unique_edge_index.unsqueeze(0).size(1) == data.edge_index.size( - 1 - ), "Adding periodic boundary conditions would result in duplicate edges. Cutoff radius must be reduced or system size increased." - data.edge_attr = torch.tensor(edge_length, dtype=torch.float).unsqueeze( + # Assign to data + data.edge_index = torch.stack( + [ + torch.tensor(edge_src, dtype=torch.long, device=device), + torch.tensor(edge_dst, dtype=torch.long, device=device), + ], + dim=0, # Shape: [2, n_edges] + ) + data.edge_attr = torch.tensor( + edge_length, dtype=dtype, device=device + ).unsqueeze( 1 ) # Shape: [n_edges, 1] - # ASE returns whether the cell was shifted or not (-1,0,1). Multiply by the cell size to get the actual shift + # ASE returns the integer number of cell shifts. Multiply by the cell size to get the shift vector. data.edge_shifts = torch.matmul( - torch.tensor(edge_cell_shifts).float(), data.supercell_size.float() + torch.tensor(edge_cell_shifts, dtype=dtype, device=device), + data.cell, ) # Shape: [n_edges, 3] return data + def _remove_true_self_loops( + self, edge_src, edge_dst, edge_length, edge_cell_shifts + ): + # Create a mask to remove true self loops (i.e. 
the same source and destination node in the same cell) + true_self_edges = edge_src == edge_dst + true_self_edges &= np.all(edge_cell_shifts == 0, axis=1) + mask = ~true_self_edges + + # Apply the mask and return + return ( + edge_src[mask], + edge_dst[mask], + edge_length[mask], + edge_cell_shifts[mask], + ) + + def _limit_neighbors( + self, edge_src, edge_dst, edge_length, edge_cell_shifts, max_num_neighbors + ): + # Lexsort will sort primarily by edge_src, then by edge_dst within each src node + sorted_indices = np.lexsort((edge_length, edge_src)) + edge_src, edge_dst, edge_length, edge_cell_shifts = [ + edge_arg[sorted_indices] + for edge_arg in [edge_src, edge_dst, edge_length, edge_cell_shifts] + ] + + # Create a mask to keep only `max_num_neighbors` per node + unique_src, counts = np.unique(edge_src, return_counts=True) + mask = np.zeros_like(edge_src, dtype=bool) + start_idx = 0 + for src, count in zip(unique_src, counts): + end_idx = start_idx + count + # Keep only the first max_num_neighbors for this src + mask[start_idx : start_idx + min(count, max_num_neighbors)] = True + start_idx = end_idx + + # Apply the mask and return + return ( + edge_src[mask], + edge_dst[mask], + edge_length[mask], + edge_cell_shifts[mask], + ) + + def _check_and_standardize_data(self, data): + assert ( + "batch" not in data + ), "Periodic boundary conditions not currently supported on batches." + assert hasattr( + data, "cell" + ), "The data must contain data.cell as a 3x3 matrix to apply periodic boundary conditions." + assert hasattr( + data, "pbc" + ), "The data must contain data.pbc as a bool (True) or list of bools for the dimensions ([True, False, True]) to apply periodic boundary conditions." + + # Ensure data consistency in terms of device and type + if not isinstance(data.pos, torch.Tensor): + data.pos = torch.tensor(data.pos) + if data.pos.dtype not in [torch.float32, torch.float64]: + data.pos = data.pos.to(torch.get_default_dtype()) + # Canonicalize based off data.pos, similar to PyG's default behavior + device, dtype = data.pos.device, data.pos.dtype + if not ( + isinstance(data.cell, torch.Tensor) + and data.cell.dtype == dtype + and data.cell.device == device + ): + data.cell = torch.tensor(data.cell, dtype=dtype, device=device) + if not ( + isinstance(data.pbc, torch.Tensor) + and data.pbc.dtype == torch.bool + and data.pbc.device == device + ): + data.pbc = torch.tensor(data.pbc, dtype=torch.bool, device=device) + + return data, device, dtype + def __repr__(self) -> str: return f"{self.__class__.__name__}(r={self.r})" diff --git a/tests/test_periodic_boundary_conditions.py b/tests/test_periodic_boundary_conditions.py index 570736210..a6ebc26de 100644 --- a/tests/test_periodic_boundary_conditions.py +++ b/tests/test_periodic_boundary_conditions.py @@ -81,9 +81,7 @@ def pytest_periodic_h2(): # Create # Hydrogen molecule (H2) with arbitrary node features data = Data() - data.supercell_size = torch.tensor( - [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]] - ) + data.cell = torch.tensor([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]) data.pbc = [True, True, True] data.atom_types = [1, 1] data.pos = torch.tensor([[1.0, 1.0, 1.0], [1.43, 1.43, 1.43]]) @@ -108,7 +106,7 @@ def pytest_periodic_bcc_large(): # Convert to PyG data2 = Data() - data2.supercell_size = torch.tensor(supercell.cell[:]) + data2.cell = torch.tensor(supercell.cell[:]) data2.pbc = [True, True, True] data2.atom_types = np.ones(len(supercell)) * 27 data2.pos = torch.tensor(supercell.positions)
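Reviewer note (not part of this patch): the rewritten AtomicStructureHandler.compute in examples/LennardJones/LJ_data.py replaces the nested per-node/per-neighbor Python loop with an edge-wise evaluation of the pair potential, a scatter-add of per-edge energies onto the source nodes, and a single autograd call for the forces. The minimal sketch below illustrates that pattern; it assumes only torch and torch_scatter, the two-atom geometry, zero shift vectors, epsilon/sigma values, and the lj helper are invented for illustration, and the edge vectors are formed inline rather than with hydragnn's get_edge_vectors_and_lengths.

import torch
from torch_scatter import scatter

epsilon, sigma = 1.0, 1.0  # hypothetical LJ parameters for this sketch


def lj(pair_distance):
    # 4*eps*[(sigma/r)^12 - (sigma/r)^6], evaluated element-wise on edge lengths
    return 4.0 * epsilon * ((sigma / pair_distance) ** 12 - (sigma / pair_distance) ** 6)


# Toy two-atom system; positions require gradients so forces can come from autograd
pos = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], requires_grad=True)
edge_index = torch.tensor([[0, 1], [1, 0]])  # directed edges (i->j and j->i)
edge_shifts = torch.zeros(2, 3)  # no periodic image needed for this toy pair

# Edge vectors and lengths (the patch uses get_edge_vectors_and_lengths for this step)
edge_vec = pos[edge_index[1]] - pos[edge_index[0]] + edge_shifts
edge_dist = edge_vec.norm(dim=-1, keepdim=True)  # shape [num_edges, 1]

edge_potential = lj(edge_dist)  # shape [num_edges, 1]
node_potential = scatter(
    edge_potential, edge_index[0], dim=0, dim_size=pos.shape[0], reduce="add"
)  # shape [num_nodes, 1]
total_potential = node_potential.sum()  # scalar

# Forces are the negative gradient of the total potential w.r.t. positions
forces = -torch.autograd.grad(total_potential, pos)[0]  # shape [num_nodes, 3]
print(node_potential, forces)

Because the neighbor list is directed, each pair appears twice and therefore contributes to the per-node energy of both endpoints, matching the accumulation of the loop-based implementation this patch removes.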
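Reviewer note (not part of this patch): several example datasets in this patch share the same guard, preferring the PBC-aware radius graph when data.cell and data.pbc are available and falling back to torch_geometric's RadiusGraph otherwise. The sketch below exercises that guard together with the cell/pbc attribute convention introduced here (a 3x3 cell tensor and a boolean pbc tensor); the two-atom cubic cell is made up for illustration, similar in spirit to tests/test_periodic_boundary_conditions.py, and the import path follows this patch.

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RadiusGraph
from hydragnn.preprocess.graph_samples_checks_and_updates import RadiusGraphPBC

radius_graph = RadiusGraph(5.0, loop=False, max_num_neighbors=50)
radius_graph_pbc = RadiusGraphPBC(5.0, loop=False, max_num_neighbors=50)

data = Data()
data.pos = torch.tensor([[0.5, 0.5, 0.5], [9.5, 0.5, 0.5]])
data.cell = torch.eye(3) * 10.0  # 3x3 cell tensor, replacing the old supercell_size
data.pbc = torch.tensor([True, True, True])  # boolean tensor, as standardized here

# Same guarded pattern as in the example train scripts: try the PBC graph first,
# fall back to the plain radius graph if the required attributes are missing or it fails.
if data.pbc is not None and data.cell is not None:
    try:
        data = radius_graph_pbc(data)
    except Exception:
        data = radius_graph(data)
else:
    data = radius_graph(data)

# The two atoms are 9 Angstrom apart inside the cell but only 1 Angstrom apart across
# the periodic boundary, so the PBC transform should connect them and record each
# edge's image offset in data.edge_shifts.
print(data.edge_index, data.edge_attr, data.edge_shifts)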