diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..052afcc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = [ + "setuptools >= 68", + "wheel", + "torch", +] +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index aad7b71..046696e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ jupyter scikit-learn torch torch_geometric +matscipy \ No newline at end of file diff --git a/setup.py b/setup.py index 69e1e04..a6187d7 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,30 @@ -from setuptools import setup - +from setuptools import setup, Extension +from torch.utils.cpp_extension import BuildExtension, include_paths, library_paths with open("requirements.txt") as f: requirements = f.read().splitlines() +# Collecting include and library paths +include_dirs = include_paths() +library_dirs = library_paths() + +libraries = [] + +libraries.append('c10') +libraries.append('torch') +libraries.append('torch_cpu') + + +# Defining the extension module without specifying the unwanted libraries +neighbors_convert_extension = Extension( + name="pet.neighbors_convert", + sources=["src/neighbors_convert.cpp"], + include_dirs=include_dirs, + library_dirs=library_dirs, + libraries=libraries, + language='c++', +) + setup( name="pet", version="0.0.0", @@ -18,4 +39,12 @@ ], }, install_requires=requirements, + ext_modules=[neighbors_convert_extension], + cmdclass={ + "build_ext": BuildExtension.with_options(no_python_abi_suffix=True) + }, + package_data={ + 'pet': ['neighbors_convert.so'], # Ensure the shared object file is included + }, + include_package_data=True, ) diff --git a/src/__init__.py b/src/__init__.py index 6014564..3d95399 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1 +1,15 @@ from .single_struct_calculator import SingleStructCalculator + +import torch +import importlib.resources as pkg_resources + +def load_neighbors_convert(): + try: + # Locate the shared object file in the package + with pkg_resources.files(__name__).joinpath('neighbors_convert.so') as lib_path: + # Load the shared object file + torch.ops.load_library(str(lib_path)) + except Exception as e: + print(f"Failed to load neighbors_convert.so: {e}") + +load_neighbors_convert() diff --git a/src/molecule.py b/src/molecule.py index 4b2136b..36dffac 100644 --- a/src/molecule.py +++ b/src/molecule.py @@ -3,7 +3,7 @@ import numpy as np from torch_geometric.data import Data from .long_range import get_reciprocal, get_all_k, get_volume - +from matscipy.neighbours import neighbour_list as neighbor_list class Molecule: def __init__( @@ -188,6 +188,73 @@ def get_graph(self, max_num, all_species, max_num_k): return result +class MoleculeCPP: + def __init__( + self, atoms, r_cut, use_additional_scalar_attributes, use_long_range, k_cut + ): + + self.use_additional_scalar_attributes = use_additional_scalar_attributes + self.atoms = atoms + self.r_cut = r_cut + self.use_long_range = use_long_range + self.k_cut = k_cut + + if self.use_long_range: + raise NotImplementedError("Long range is not implemented in cpp") + if self.use_additional_scalar_attributes: + raise NotImplementedError("Additional scalar attributes are not implemented in cpp") + + def is_3d_crystal(atoms): + pbc = atoms.get_pbc() + if isinstance(pbc, bool): + return pbc + return all(pbc) + + if is_3d_crystal(atoms): + i_list, j_list, D_list, S_list = neighbor_list('ijDS', atoms, r_cut) + else: + i_list, j_list, D_list, S_list = ase.neighborlist.neighbor_list( + "ijDS", atoms, r_cut + ) + + self.i_list = torch.tensor(i_list, dtype=torch.int64).contiguous() + self.j_list = torch.tensor(j_list, dtype=torch.int64).contiguous() + self.D_list = torch.tensor(D_list, dtype=torch.get_default_dtype()).contiguous() + self.S_list = torch.tensor(S_list, dtype=torch.int64).contiguous() + self.species = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int64).contiguous() + if len(self.i_list) == 0: + self.max_num = 0 + else: + self.max_num = torch.max(torch.bincount(self.i_list)) + + def get_num_k(self): + raise NotImplementedError("Long range is not implemented in cpp") + + def get_max_num(self): + return self.max_num + + def get_graph(self, max_num, all_species, max_num_k): + n_atoms = len(self.atoms.get_atomic_numbers()) + all_species = torch.tensor(all_species, dtype=torch.int64).contiguous() + + # torch.ops.my_extension.process(i_list, j_list, S_list, D_list, max_size, n_atoms, species, None) + neighbors_index, relative_positions, nums, mask, neighbor_species, neighbors_pos, species_mapped = torch.ops.neighbors_convert.process(self.i_list, self.j_list, self.S_list, self.D_list, max_num, n_atoms, self.species, all_species) + + kwargs = { + "central_species": species_mapped, + "x": relative_positions, + "neighbor_species": neighbor_species, + "neighbors_pos": neighbors_pos, + "neighbors_index": neighbors_index.transpose(0, 1), + "nums": nums, + "mask": mask, + "n_atoms": len(self.atoms.positions), + } + + result = Data(**kwargs) + + return result + def batch_to_dict(batch): batch_dict = { "x": batch.x, diff --git a/src/neighbors_convert.cpp b/src/neighbors_convert.cpp new file mode 100644 index 0000000..2b54c3e --- /dev/null +++ b/src/neighbors_convert.cpp @@ -0,0 +1,346 @@ +// #include +#include +#include // For std::fill +#include // For c10::optional +#include +#include + +// Template function to process the neighbors +template +std::vector process_neighbors_cpu(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, + at::Tensor all_species) { + // Ensure the tensors are on the CPU and are contiguous + TORCH_CHECK(i_list.device().is_cpu(), "i_list must be on CPU"); + TORCH_CHECK(j_list.device().is_cpu(), "j_list must be on CPU"); + TORCH_CHECK(S_list.device().is_cpu(), "S_list must be on CPU"); + TORCH_CHECK(D_list.device().is_cpu(), "D_list must be on CPU"); + TORCH_CHECK(species.device().is_cpu(), "species must be on CPU"); + TORCH_CHECK(all_species.device().is_cpu(), "all_species must be on CPU"); + + TORCH_CHECK(i_list.is_contiguous(), "i_list must be contiguous"); + TORCH_CHECK(j_list.is_contiguous(), "j_list must be contiguous"); + TORCH_CHECK(S_list.is_contiguous(), "S_list must be contiguous"); + TORCH_CHECK(D_list.is_contiguous(), "D_list must be contiguous"); + TORCH_CHECK(species.is_contiguous(), "species must be contiguous"); + TORCH_CHECK(all_species.is_contiguous(), "all_species must be contiguous"); + + // Ensure the sizes match + TORCH_CHECK(i_list.sizes() == j_list.sizes(), "i_list and j_list must have the same size"); + TORCH_CHECK(i_list.size(0) == S_list.size(0) && S_list.size(1) == 3, "S_list must have the shape [N, 3]"); + TORCH_CHECK(i_list.size(0) == D_list.size(0) && D_list.sizes() == S_list.sizes(), "D_list must have the same shape as S_list"); + + // Initialize tensors with zeros + auto options_int = torch::TensorOptions().dtype(i_list.dtype()).device(torch::kCPU); + auto options_float = torch::TensorOptions().dtype(D_list.dtype()).device(torch::kCPU); + auto options_bool = torch::TensorOptions().dtype(at::kBool).device(torch::kCPU); + + at::Tensor neighbors_index = torch::zeros({n_atoms, max_size}, options_int); + at::Tensor neighbors_shift = torch::zeros({n_atoms, max_size, 3}, options_int); + at::Tensor relative_positions = torch::zeros({n_atoms, max_size, 3}, options_float); + at::Tensor nums = torch::zeros({n_atoms}, options_int); // Tensor to store the count of elements + at::Tensor mask = torch::ones({n_atoms, max_size}, options_bool); // Tensor to store the mask + at::Tensor neighbor_species = all_species.size(0) * torch::ones({n_atoms, max_size}, options_int); + + int64_t scalar_attr_dim = 0; + + // Temporary array to track the current population index + int_t* current_index = new int_t[n_atoms]; + std::fill(current_index, current_index + n_atoms, 0); // Fill the array with zeros + + // Get raw data pointers + int_t* i_list_ptr = i_list.data_ptr(); + int_t* j_list_ptr = j_list.data_ptr(); + int_t* S_list_ptr = S_list.data_ptr(); + float_t* D_list_ptr = D_list.data_ptr(); + int_t* species_ptr = species.data_ptr(); + int_t* all_species_ptr = all_species.data_ptr(); + + int_t* neighbors_index_ptr = neighbors_index.data_ptr(); + int_t* neighbors_shift_ptr = neighbors_shift.data_ptr(); + float_t* relative_positions_ptr = relative_positions.data_ptr(); + int_t* nums_ptr = nums.data_ptr(); + bool* mask_ptr = mask.data_ptr(); + int_t* neighbor_species_ptr = neighbor_species.data_ptr(); + + int64_t all_species_size = all_species.size(0); + + int_t all_species_maximum = -1; + for (int64_t k = 0; k < all_species_size; ++k) { + if (all_species_ptr[k] > all_species_maximum) { + all_species_maximum = all_species_ptr[k]; + } + } + + int_t* mapping = new int_t[all_species_maximum + 1]; + for (int64_t k = 0; k < all_species_size; ++k) { + mapping[all_species_ptr[k]] = k; + } + + + + // Populate the neighbors_index, neighbors_shift, relative_positions, neighbor_species, and neighbor_scalar_attributes tensors + + int64_t shift_i; + int_t i, j, idx; + for (int64_t k = 0; k < i_list.size(0); ++k) { + i = i_list_ptr[k]; + j = j_list_ptr[k]; + idx = current_index[i]; + + shift_i = i * max_size; + if (idx < max_size) { + neighbors_index_ptr[shift_i + idx] = j; + neighbor_species_ptr[shift_i + idx] = mapping[species_ptr[j]]; + /*for (int64_t q = 0; q < all_species_size; ++q) { + if (all_species_ptr[q] == species_ptr[j]) { + neighbor_species_ptr[i * max_size + idx] = q; + break; + } + }*/ + + // Unroll the loop for better computational efficiency + neighbors_shift_ptr[(shift_i + idx) * 3 + 0] = S_list_ptr[k * 3 + 0]; + neighbors_shift_ptr[(shift_i + idx) * 3 + 1] = S_list_ptr[k * 3 + 1]; + neighbors_shift_ptr[(shift_i + idx) * 3 + 2] = S_list_ptr[k * 3 + 2]; + + relative_positions_ptr[(shift_i + idx) * 3 + 0] = D_list_ptr[k * 3 + 0]; + relative_positions_ptr[(shift_i + idx) * 3 + 1] = D_list_ptr[k * 3 + 1]; + relative_positions_ptr[(shift_i + idx) * 3 + 2] = D_list_ptr[k * 3 + 2]; + + mask_ptr[shift_i + idx] = false; + + current_index[i]++; + } + } + + // Copy current_index to nums + for (int64_t i = 0; i < n_atoms; ++i) { + nums_ptr[i] = current_index[i]; + } + + at::Tensor neighbors_pos = torch::zeros({n_atoms, max_size}, options_int); + int_t* neighbors_pos_ptr = neighbors_pos.data_ptr(); + + // Temporary array to track the current population index + int_t* current_index_two = new int_t[n_atoms]; + std::fill(current_index_two, current_index_two + n_atoms, 0); // Fill the array with zeros + + int64_t shift_j; + for (int64_t k = 0; k < i_list.size(0); ++k) { + i = i_list_ptr[k]; + j = j_list_ptr[k]; + shift_j = j * max_size; + for (int64_t q = 0; q < current_index[j]; ++q) { + if (neighbors_index_ptr[shift_j + q] == i && neighbors_shift_ptr[(shift_j + q) * 3 + 0] == -S_list_ptr[k * 3 + 0] && neighbors_shift_ptr[(shift_j + q) * 3 + 1] == -S_list_ptr[k * 3 + 1] && neighbors_shift_ptr[(shift_j + q) * 3 + 2] == -S_list_ptr[k * 3 + 2]) { + neighbors_pos_ptr[i * max_size + current_index_two[i]] = q; + current_index_two[i]++; + break; + } + } + } + + // Clean up temporary memory + delete[] current_index; + delete[] current_index_two; + + at::Tensor species_mapped = torch::zeros({n_atoms}, options_int); + int_t* species_mapped_ptr = species_mapped.data_ptr(); + for (int64_t k = 0; k < n_atoms; ++k) { + species_mapped_ptr[k] = mapping[species_ptr[k]]; + } + + /*for (int64_t k = 0; k < n_atoms; ++k) { + for (int64_t q = 0; q < all_species_size; ++q) { + if (all_species_ptr[q] == species_ptr[k]) { + species_mapped_ptr[k] = q; + break; + } + } + }*/ + + delete[] mapping; + + return {neighbors_index, relative_positions, nums, mask, neighbor_species, neighbors_pos, species_mapped}; +} + +// Template function for backward pass +template +at::Tensor process_neighbors_cpu_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) { + // Ensure the tensors are on the CPU and are contiguous + TORCH_CHECK(grad_output.device().is_cpu(), "grad_output must be on CPU"); + TORCH_CHECK(i_list.device().is_cpu(), "i_list must be on CPU"); + + grad_output = grad_output.contiguous(); + i_list = i_list.contiguous(); + + // TORCH_CHECK(grad_output.is_contiguous(), "grad_output must be contiguous"); + // TORCH_CHECK(i_list.is_contiguous(), "i_list must be contiguous"); + + // Initialize gradient tensor for D_list with zeros + auto options_float = torch::TensorOptions().dtype(grad_output.dtype()).device(torch::kCPU); + at::Tensor grad_D_list = torch::zeros({i_list.size(0), 3}, options_float); + + int_t* current_index = new int_t[n_atoms]; + std::fill(current_index, current_index + n_atoms, 0); // Fill the array with zeros + + float_t* grad_D_list_ptr = grad_D_list.data_ptr(); + float_t* grad_output_ptr = grad_output.data_ptr(); + int_t* i_list_ptr = i_list.data_ptr(); + int_t i, idx; + + for (int64_t k = 0; k < i_list.size(0); ++k) { + i = i_list_ptr[k]; + idx = current_index[i]; + grad_D_list_ptr[k * 3 + 0] = grad_output_ptr[(i * max_size + idx) * 3 + 0]; + grad_D_list_ptr[k * 3 + 1] = grad_output_ptr[(i * max_size + idx) * 3 + 1]; + grad_D_list_ptr[k * 3 + 2] = grad_output_ptr[(i * max_size + idx) * 3 + 2]; + current_index[i]++; + } + + delete[] current_index; + return grad_D_list; +} + +template +at::Tensor process_neighbors_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) { + // Ensure all tensors are on the same device + auto device = grad_output.device(); + TORCH_CHECK(i_list.device() == device, "i_list must be on the same device as grad_output"); + + // Move all tensors to CPU + auto grad_output_cpu = grad_output.cpu(); + auto i_list_cpu = i_list.cpu(); + + // Invoke the CPU version of the function + auto grad_D_list_cpu = process_neighbors_cpu_backward(grad_output_cpu, i_list_cpu, max_size, n_atoms); + + // Move the gradient tensor back to the initial device + return grad_D_list_cpu.to(device); +} + +// Dispatch function based on tensor types for backward +at::Tensor process_dispatch_backward(at::Tensor grad_output, at::Tensor i_list, int64_t max_size, int64_t n_atoms) { + if (i_list.scalar_type() == at::ScalarType::Int && grad_output.scalar_type() == at::ScalarType::Float) { + return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); + } else if (i_list.scalar_type() == at::ScalarType::Int && grad_output.scalar_type() == at::ScalarType::Double) { + return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); + } else if (i_list.scalar_type() == at::ScalarType::Long && grad_output.scalar_type() == at::ScalarType::Float) { + return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); + } else if (i_list.scalar_type() == at::ScalarType::Long && grad_output.scalar_type() == at::ScalarType::Double) { + return process_neighbors_backward(grad_output, i_list, max_size, n_atoms); + } else { + throw std::runtime_error("Unsupported tensor types"); + } +} + +template +std::vector process_neighbors(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, + at::Tensor all_species) { + // Ensure all tensors are on the same device + auto device = i_list.device(); + TORCH_CHECK(j_list.device() == device, "j_list must be on the same device as i_list"); + TORCH_CHECK(S_list.device() == device, "S_list must be on the same device as i_list"); + TORCH_CHECK(D_list.device() == device, "D_list must be on the same device as i_list"); + TORCH_CHECK(species.device() == device, "species must be on the same device as i_list"); + TORCH_CHECK(all_species.device() == device, "all_species must be on the same device as i_list"); + + // Move all tensors to CPU + auto i_list_cpu = i_list.cpu(); + auto j_list_cpu = j_list.cpu(); + auto S_list_cpu = S_list.cpu(); + auto D_list_cpu = D_list.cpu(); + auto species_cpu = species.cpu(); + auto all_species_cpu = all_species.cpu(); + + // Invoke the CPU version of the function + auto result = process_neighbors_cpu(i_list_cpu, j_list_cpu, S_list_cpu, D_list_cpu, max_size, n_atoms, species_cpu, all_species_cpu); + + // Move the output tensors back to the initial device + for (auto& tensor_opt : result) { + tensor_opt = tensor_opt.to(device); + } + + return result; +} + +// Dispatch function based on tensor types +std::vector process_dispatch(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, + at::Tensor all_species) { + if (i_list.scalar_type() == at::ScalarType::Int && j_list.scalar_type() == at::ScalarType::Int && + S_list.scalar_type() == at::ScalarType::Int && D_list.scalar_type() == at::ScalarType::Float) { + return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); + } else if (i_list.scalar_type() == at::ScalarType::Int && j_list.scalar_type() == at::ScalarType::Int && + S_list.scalar_type() == at::ScalarType::Int && D_list.scalar_type() == at::ScalarType::Double) { + return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); + } else if (i_list.scalar_type() == at::ScalarType::Long && j_list.scalar_type() == at::ScalarType::Long && + S_list.scalar_type() == at::ScalarType::Long && D_list.scalar_type() == at::ScalarType::Float) { + return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); + } else if (i_list.scalar_type() == at::ScalarType::Long && j_list.scalar_type() == at::ScalarType::Long && + S_list.scalar_type() == at::ScalarType::Long && D_list.scalar_type() == at::ScalarType::Double) { + return process_neighbors(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); + } else { + throw std::runtime_error("Unsupported tensor types"); + } +} + +class ProcessNeighborsFunction : public torch::autograd::Function { +public: + static std::vector forward(torch::autograd::AutogradContext *ctx, at::Tensor i_list, at::Tensor j_list, + at::Tensor S_list, at::Tensor D_list, int64_t max_size, int64_t n_atoms, + at::Tensor species, at::Tensor all_species) { + auto outputs = process_dispatch(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); + ctx->save_for_backward({i_list}); + ctx->saved_data["max_size"] = max_size; + ctx->saved_data["n_atoms"] = n_atoms; + return outputs; + } + + static std::vector backward(torch::autograd::AutogradContext *ctx, std::vector grad_outputs) { + auto i_list = ctx->get_saved_variables()[0]; + auto max_size = ctx->saved_data["max_size"].toInt(); + auto n_atoms = ctx->saved_data["n_atoms"].toInt(); + + auto grad_relative_positions = grad_outputs[1]; // Assuming this is the gradient w.r.t relative_positions tensor + auto grad_D_list = process_dispatch_backward(grad_relative_positions, i_list, max_size, n_atoms); + + return {at::Tensor(), at::Tensor(), at::Tensor(), grad_D_list, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()}; + } +}; + +// Wrapper function to call apply +std::vector process_neighbors_apply(at::Tensor i_list, at::Tensor j_list, at::Tensor S_list, at::Tensor D_list, + int64_t max_size, int64_t n_atoms, at::Tensor species, at::Tensor all_species) { + return ProcessNeighborsFunction::apply(i_list, j_list, S_list, D_list, max_size, n_atoms, species, all_species); +} + +/*TORCH_LIBRARY(neighbors_convert, m) { + m.def( + "convert_neighbors(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", + &process_neighbors_apply + ); +}*/ + +TORCH_LIBRARY(neighbors_convert, m) { + m.def( + "process", + &process_neighbors_apply + ); +} + +// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +// m.def("process_neighbors(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", &process_neighbors_apply, "Process neighbors and return tensors, including count tensor, mask, and neighbor_species"); +// } + +/*static auto registry = torch::RegisterOperators() + .op("neighbors_convert::process(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", &process_neighbors_apply);*/ + +/*PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("process_neighbors", &process_neighbors_apply, "Process neighbors and return tensors, including count tensor, mask, and neighbor_species");*/ + +/*static auto registry = torch::RegisterOperators() + .op("neighbors_convert::process(Tensor i_list, Tensor j_list, Tensor S_list, Tensor D_list, int max_size, int n_atoms, Tensor species, Tensor all_species) -> Tensor[]", &process_neighbors_apply); + +*/ \ No newline at end of file diff --git a/src/single_struct_calculator.py b/src/single_struct_calculator.py index 963ea44..6292d42 100644 --- a/src/single_struct_calculator.py +++ b/src/single_struct_calculator.py @@ -3,15 +3,20 @@ from torch_geometric.nn import DataParallel from .data_preparation import get_compositional_features -from .molecule import Molecule +from .molecule import Molecule, MoleculeCPP from .hypers import load_hypers_from_file from .pet import PET, PETMLIPWrapper, PETUtilityWrapper +from .utilities import string2dtype, get_quadrature_predictions class SingleStructCalculator: def __init__( - self, path_to_calc_folder, checkpoint="best_val_rmse_both_model", device="cpu" + self, path_to_calc_folder, checkpoint="best_val_rmse_both_model", device="cpu", quadrature_order=None, inversions=False, use_augmentation=False, add_self_contributions=False, ): + if (quadrature_order is not None) and (use_augmentation): + raise NotImplementedError("Simultaneous use of a quadrature and augmentation is not yet implemented") + + self.use_augmentation = use_augmentation hypers_path = path_to_calc_folder + "/hypers_used.yaml" path_to_model_state_dict = ( path_to_calc_folder + "/" + checkpoint + "_state_dict" @@ -37,13 +42,18 @@ def __init__( model = PETMLIPWrapper( model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES ) - if FITTING_SCHEME.MULTI_GPU and torch.cuda.is_available(): + if torch.cuda.is_available() and (torch.cuda.device_count() > 1): model = DataParallel(model) model = model.to(torch.device("cuda:0")) model.load_state_dict( - torch.load(path_to_model_state_dict, map_location=torch.device(device)) + torch.load(path_to_model_state_dict, + map_location=torch.device(device)) ) + + if FITTING_SCHEME.MULTI_GPU and torch.cuda.is_available(): + model = model.module + model.eval() self.model = model @@ -51,31 +61,60 @@ def __init__( self.all_species = all_species self.device = device + if use_augmentation and (quadrature_order is not None or inversions): + raise NotImplementedError("Simultaneous use of a quadrature/inversions and augmentation is not yet implemented") + + if quadrature_order is not None: + self.quadrature_order = int(quadrature_order) + else: + self.quadrature_order = None + + self.inversions = inversions + self.use_augmentation = use_augmentation + self.add_self_contributions = add_self_contributions + def forward(self, structure): - molecule = Molecule( + molecule = MoleculeCPP( structure, self.architectural_hypers.R_CUT, self.architectural_hypers.USE_ADDITIONAL_SCALAR_ATTRIBUTES, self.architectural_hypers.USE_LONG_RANGE, self.architectural_hypers.K_CUT, ) + if self.architectural_hypers.USE_LONG_RANGE: + raise NotImplementedError( + "Long range interactions are not supported in the SingleStructCalculator" + ) graph = molecule.get_graph( - molecule.get_max_num(), self.all_species, molecule.get_num_k() + molecule.get_max_num(), self.all_species, None ) graph.batch = torch.zeros( graph.num_nodes, dtype=torch.long, device=graph.x.device ) graph = graph.to(self.device) - prediction_energy, prediction_forces = self.model( - graph, augmentation=False, create_graph=False - ) - compositional_features = get_compositional_features( - [structure], self.all_species - )[0] - self_contributions_energy = np.dot( - compositional_features, self.self_contributions - ) - energy_total = prediction_energy.data.cpu().numpy() + self_contributions_energy - return energy_total, prediction_forces.data.cpu().numpy() + if self.quadrature_order is None and not self.inversions: + prediction_energy, prediction_forces = self.model( + graph, augmentation=self.use_augmentation, create_graph=False + ) + prediction_energy_final = prediction_energy.data.cpu().numpy() + prediction_forces_final = prediction_forces.data.cpu().numpy() + else: + prediction_energy_final, prediction_forces_final = get_quadrature_predictions( + graph, self.model, self.quadrature_order, self.inversions, string2dtype(self.architectural_hypers.DTYPE) + ) + + energy_total = prediction_energy_final + + if self.add_self_contributions: + compositional_features = get_compositional_features( + [structure], self.all_species + )[0] + self_contributions_energy = np.dot( + compositional_features, self.self_contributions + ) + # note: this may lead to numerical problems in less than double precision + energy_total = energy_total + self_contributions_energy + + return energy_total, prediction_forces_final diff --git a/src/utilities.py b/src/utilities.py index eb43eae..eb4f07f 100644 --- a/src/utilities.py +++ b/src/utilities.py @@ -527,7 +527,7 @@ def get_quadrature(L): for v, weight in zip(all_v, weights_now): weights.append(weight) angles = [theta, v, w] - rotation = R.from_euler("xyz", angles, degrees=False) + rotation = R.from_euler("zxz", angles, degrees=False) rotation_matrix = rotation.as_matrix() matrices.append(rotation_matrix) @@ -554,3 +554,42 @@ def string2dtype(string): return torch.bfloat16 raise ValueError("unknown dtype") + +def get_quadrature_predictions(batch, model, quadrature_order, inversions, dtype): + x_initial = batch.x.clone() + + if quadrature_order is not None: + rotations, weights = get_quadrature(quadrature_order) + else: + assert inversions + + rotations = [np.eye(3)] + weights = [1.0] + + inverse_rotations = [r.T for r in rotations] + + if inversions: + rotations += [-r for r in rotations] + inverse_rotations += [-r for r in inverse_rotations] + weights += weights + + energy_mean, forces_mean, total_weight = 0.0, 0.0, 0.0 + for rotation, inverse, weight in zip(rotations, inverse_rotations, weights): + rotation = torch.tensor(rotation, device = batch.x.device, dtype = dtype) + inverse = torch.tensor(inverse, device = batch.x.device, dtype = dtype) + + batch_rotations = rotation[None, :].repeat(batch.num_nodes, 1, 1) + batch.x = torch.bmm(x_initial, batch_rotations) + energy, forces = model( + batch, augmentation=False, create_graph=False + ) + # rotate forces back into original frame of reference + forces = torch.matmul(forces, inverse[None, :, :]) + + energy_mean += energy.data.cpu().numpy() * weight + forces_mean += forces.data.cpu().numpy() * weight + total_weight += weight + + energy_mean /= total_weight + forces_mean /= total_weight + return energy_mean, forces_mean diff --git a/tests/bulk.xyz b/tests/bulk.xyz new file mode 100644 index 0000000..6a7e03a --- /dev/null +++ b/tests/bulk.xyz @@ -0,0 +1,194 @@ +192 +Lattice="11.74456431055329 0.0 0.0 -5.29177210903e-07 11.74456431055329 0.0 -5.29177210903e-07 -5.29177210903e-07 11.74456431055329" Properties=species:S:1:pos:R:3:forces:R:3 cutoff=-1.0 nneightol=1.2 energy=-30010.461001933934 pbc="T T T" +O 5.73416426 3.02709473 1.80318722 -0.80332611 -1.21376692 0.20230786 +H 5.56943139 2.10076473 1.36991809 0.50217575 1.53231646 1.34795286 +H 5.56662675 2.96533975 2.79956970 0.54751666 -0.28190205 -1.62430029 +O 1.78772996 10.85151956 1.15475463 -2.83291472 2.12456520 -0.92718649 +H 1.27361313 11.62702876 0.64046847 1.81789416 -2.16291579 1.69135469 +H 2.16766332 10.15956744 0.55739294 -0.13023518 0.45973042 -0.26124271 +O 7.52468827 9.97186828 10.15051851 -1.51298176 3.37316030 2.50257924 +H 8.17832796 9.32362620 10.34398569 0.48150767 -0.47552780 0.31697757 +H 7.43679193 10.62175081 11.03466779 1.05936697 -2.82762853 -2.99742426 +O 7.91178140 4.58707211 0.75474958 1.34499609 1.23814613 -2.36786880 +H 8.06085062 4.79654691 1.63976672 0.38014497 0.68880399 2.57828797 +H 7.13320297 4.09945648 0.81464715 -2.16483898 -1.83566563 0.07202743 +O 8.92351531 5.17808368 9.14508181 2.09709552 -3.02514590 -0.30432053 +H 8.62775816 5.97546907 9.42279401 -2.76871939 4.09653535 0.85838888 +H 9.79237137 5.33770469 9.55709918 0.94460315 -1.16752304 -0.52726150 +O 11.39392620 4.75242940 10.31763267 -1.23014486 0.11996567 -0.73341779 +H 11.04689178 4.60206370 11.18738833 -0.18260446 -0.43417416 1.22912156 +H 0.40332564 5.30748867 10.49062070 1.23808957 1.01266029 -0.11816898 +O 0.06869249 6.49660278 7.03350598 -8.74972508 3.31707423 -1.29926643 +H 11.62972756 7.37731241 6.49459191 0.80802095 -2.94173414 0.93328515 +H 0.92074189 6.41346904 7.01763066 7.99459174 -0.50270591 0.52857277 +O 2.58803640 5.05946331 5.42406641 -2.78970988 -0.38920966 -3.53642028 +H 2.09149884 4.27517506 5.74982790 0.69823994 0.18291556 0.51624689 +H 2.05534545 5.27435689 4.54993445 1.81884547 -0.35324094 2.29502427 +O 3.93554913 4.81396213 0.78823062 -4.93885484 -11.15967257 12.97271248 +H 4.29762276 5.26709128 0.27274428 7.51580070 11.72099606 -13.81315506 +H 4.79081063 4.49826559 1.22479124 -1.65517925 -0.69692353 0.36172507 +O 4.09431817 7.33063898 3.19968059 -1.06004574 -1.11637349 -0.28933974 +H 3.38943825 6.58815044 3.16480781 1.14148290 1.68970918 -0.61438594 +H 4.19910055 7.70582563 2.28520416 0.06644145 -0.21912545 1.05279008 +O 4.01414782 0.34585435 10.69001467 0.27125870 -3.90565143 1.86140239 +H 3.97934383 11.01133108 10.83765511 -0.38882553 3.49365767 0.06856364 +H 3.76050260 0.40037812 9.77321516 -0.24997767 0.45197288 -1.37139619 +O 9.40067440 2.29813725 7.12410112 -2.25397402 -0.18400931 -3.24547412 +H 8.77407566 2.12047129 7.78176256 -1.42089509 -0.29561015 1.47685774 +H 10.18867218 2.35986048 7.56548783 3.33147756 0.45038600 1.81095218 +O 10.91142242 11.36201681 9.39247215 0.01736164 0.75490194 5.37813315 +H 10.21386102 0.09946732 9.94308104 1.27966947 -0.25601514 -1.75074206 +H 10.90565439 11.26390736 8.49001334 -1.04485051 0.02427020 -3.42446926 +O 9.59271281 7.85595320 8.22637725 1.68824880 -3.33631637 -2.48706005 +H 10.24471205 7.11245922 7.84028956 -2.08559754 2.31028120 1.60193682 +H 10.10252213 8.20510433 8.96706659 0.16744728 0.53774132 0.59653734 +O 8.17668751 11.46557679 0.59883281 -0.66836885 3.34592202 -5.23775086 +H 7.40086080 0.24360726 0.81692790 -0.70252854 0.12888586 -0.06194922 +H 8.35316811 11.04541009 1.33745836 1.62912882 -3.53077414 5.68080355 +O 0.94004098 10.15162978 3.54788449 -1.08457407 -1.07225848 -2.11326262 +H 0.69010529 9.19133189 3.77189578 0.19180130 1.74153750 -0.34616217 +H 0.93777081 10.30122818 2.50966526 1.29074064 -1.02193169 1.68592452 +O 2.27424490 5.50915811 2.47661814 3.42542057 -4.02492496 3.71742602 +H 2.87039478 5.20416153 1.76153568 0.28323491 -0.31793299 0.44400092 +H 1.75867813 6.03341397 2.01906506 -5.01572057 4.50292924 -3.15397366 +O 8.94351821 9.24398503 5.88873692 -1.92119599 -4.43389509 -2.57106317 +H 9.09221700 8.40449830 6.46427006 -0.54337204 2.43464552 -0.54531580 +H 8.36787924 8.80270415 5.00977887 2.04375025 1.37647669 3.10278297 +O 2.82171577 2.69312570 11.50923391 0.69614192 -0.37589905 1.50967018 +H 3.43897452 3.14968922 0.41529298 -1.06053939 -1.03034434 -1.56750974 +H 2.93838347 1.70842217 11.51627197 0.70775303 1.01310766 -0.29180954 +O 8.16112970 2.63468866 9.79782190 -1.27590023 2.29413467 1.11429604 +H 8.54928118 3.61930224 9.73373854 -1.52802272 -2.32642774 -0.06867111 +H 7.16627654 2.73929112 10.09193859 2.77315197 0.05309382 -1.14957673 +O 1.59972917 8.97754430 9.43268962 -4.27687261 0.84039630 2.26111141 +H 1.59251119 9.12142758 8.50784661 0.80233881 0.43452331 -2.11537093 +H 0.58284107 9.31288390 9.64240255 3.61139368 -1.36863482 0.04291019 +O 5.26826606 3.30655851 9.49391542 -0.84340449 0.50978314 1.64621124 +H 4.78704289 2.59503212 9.98054679 0.22770831 1.03334739 -0.67504858 +H 4.89323288 4.20901733 9.82787916 0.44176714 -1.60587061 -1.14998811 +O 7.36117251 0.04299829 7.68322976 -0.43534967 -0.03071647 3.23220208 +H 7.04043820 0.88966860 7.30592641 -0.03364444 -0.06154709 0.77135186 +H 7.21231496 0.07133097 8.74782847 0.53084562 -0.30424906 -4.25693627 +O 4.18004487 0.62351363 4.09110077 -1.56364280 2.67147508 0.69047006 +H 4.80862273 1.37200834 4.08969845 0.44022654 -0.29248317 0.30456530 +H 3.28940788 1.15635275 4.10035608 1.67807237 -2.07241806 -0.35558115 +O 2.70842422 6.37716749 7.53125007 2.06494643 5.11524802 -0.69604422 +H 2.64632527 5.94334801 6.65921894 0.39734926 -0.66881108 -0.10087985 +H 3.09797273 7.37493112 7.16849909 -1.55386232 -3.75355518 1.65329206 +O 6.41182859 2.43882960 7.01329141 -4.92870926 2.02800480 0.61801120 +H 5.72035273 2.64918283 6.21809682 2.70379386 -1.03493633 1.32465351 +H 5.74453613 2.77554505 7.76408804 2.33632649 -1.16785214 -1.14451680 +O 10.57301359 4.37816353 1.12879320 4.89504836 -1.30601815 2.06174798 +H 10.93983923 3.52935799 1.61142928 -2.07324081 1.97823337 -1.66711946 +H 9.71802197 4.32374295 0.82312457 -3.01427013 -1.27858446 -0.62062344 +O 4.08049606 10.52136590 2.32719437 0.42891316 2.23963240 1.86107843 +H 3.15932554 10.71626186 1.93731777 1.63434302 0.18014648 1.02641569 +H 4.27975774 11.25147169 3.16363304 -1.47402439 -2.26965261 -2.92434833 +O 1.84757461 2.91355447 6.97180392 -2.90159406 -1.25619014 -2.57119687 +H 1.40897666 2.99356607 7.80562845 -0.14525659 0.44582177 1.66832274 +H 0.95316987 2.61135724 6.50289999 2.62214073 0.34147968 0.65385239 +O 5.97425196 7.07160674 6.51708194 -1.32986257 4.64477707 0.74215954 +H 6.45172856 6.31562418 6.45929579 1.62818266 -4.29121937 -0.25735983 +H 5.04161946 6.94539798 6.38552849 -0.03046759 -0.33969019 -0.12541847 +O 3.26609762 9.55064322 11.51695990 -2.28802058 2.61571811 -1.59542165 +H 2.90240471 9.12894190 10.76187694 -0.52591939 -0.96484803 -1.23090590 +H 3.99413434 9.15122026 0.17271762 2.46309222 -2.04713382 1.18471344 +O 10.53697662 4.56069792 5.76512112 1.37467177 5.21414812 -0.75573497 +H 10.35298171 3.68800003 5.75575469 -0.26204901 -4.34917206 0.71721469 +H 11.27576093 4.91570703 6.31964592 -1.20071085 -0.75218685 -0.31628954 +O 4.63525899 10.81807556 6.19613596 6.71521024 -1.10373908 4.35700365 +H 4.35573171 11.34174933 5.54381921 -2.37777269 3.55238684 -4.56307766 +H 5.67410264 11.23263298 6.27461294 -3.69088709 -2.18111407 -0.04988097 +O 0.14165280 7.33720078 1.35500587 -5.25636565 -0.09894072 6.15964602 +H 0.43427986 7.46266870 0.56089609 3.84631548 0.76780374 -7.44962047 +H 10.92671564 7.65920512 1.13098929 2.25378890 -0.78840343 1.73881728 +O 10.27603934 2.10768108 2.86506067 1.27164762 -1.00033441 0.95626568 +H 11.01111940 2.17232008 3.54679438 -1.62994129 0.20511087 -0.66830200 +H 10.25354931 1.09798979 2.87938021 -0.22944278 0.78903592 -0.44897292 +O 5.69627517 0.40138197 0.90041090 2.59900079 0.43527099 4.94188874 +H 5.24481822 0.44575031 0.13631182 -2.71563637 0.20740533 -4.72132913 +H 5.15967360 11.54675258 1.48204016 -0.21494998 -0.23595898 -0.06999420 +O 6.15025630 7.53315510 9.28182120 3.37681641 -8.71825476 -0.73325838 +H 5.88831358 8.32623299 9.17037648 -3.12216395 9.38792459 -0.76454872 +H 6.54803881 7.35413445 8.40211700 -0.57476523 -0.20697287 0.55008777 +O 11.51516070 7.99909564 4.70859766 -1.85752516 -3.33904174 -2.28857594 +H 10.55803788 8.11683757 4.90491182 0.94879405 0.68445368 -0.18751835 +H 11.33058369 7.11780391 3.97989418 0.73715617 2.85156037 2.24620414 +O 4.32177970 9.39263090 8.51430257 10.95624679 -10.19283450 -11.15278201 +H 4.15839623 9.72273165 7.58681367 0.74426271 -1.00007216 1.44278033 +H 3.83315863 9.85920645 8.89837939 -11.80157447 10.92359377 9.90425375 +O 7.32079629 7.68735734 4.30241710 11.26303096 2.57085748 -17.99639319 +H 7.49129719 6.77442082 3.96804590 0.55662866 -0.25286348 0.13183903 +H 6.92412505 7.59215836 4.94862185 -11.68489776 -2.23415081 18.08807877 +O 10.60936807 5.82587067 3.33454140 -3.62983878 8.74340016 19.17817558 +H 10.85887512 5.24453775 3.99905568 0.76538690 -2.26595537 2.00967797 +H 10.70514914 5.60689714 2.61041531 3.34908448 -6.50093440 -20.65116124 +O 8.16710940 1.89280867 4.64662571 -2.78388890 -0.72861497 -4.24536116 +H 8.53457006 2.07518431 5.44433390 2.39779131 1.05826653 3.94904653 +H 8.71888248 2.16570536 3.89810984 0.79682122 0.07677472 0.24100195 +O 7.35476947 9.37220466 2.30818103 -1.57494023 0.21925041 2.66969073 +H 7.02816129 10.06029379 3.00869524 1.12488899 -1.41697158 -1.75353428 +H 7.24983362 8.55541964 2.92457723 0.16987646 1.65735955 -1.30613127 +O 7.59850849 4.79108580 6.60704207 -0.18279832 -1.17206875 0.89977338 +H 7.35064188 3.91050317 7.03265930 -0.45973248 0.75343126 -1.50952619 +H 8.26643596 4.99101424 7.26121093 0.89534594 0.76599882 -0.01246266 +O 10.40399439 11.10430751 6.79294202 2.81111175 -0.62073657 -1.02525355 +H 9.63001980 11.47674243 7.07319427 -3.06936890 2.63521222 1.66147846 +H 9.93990598 10.35631552 6.42130087 0.40481317 -1.41940385 -0.75827008 +O 3.49465984 8.56023515 5.57334730 4.08231201 1.30895435 -3.78476839 +H 4.21974904 9.43364214 5.51672534 -2.68191376 -2.97917456 0.99539275 +H 3.65141801 8.24495137 4.57004848 -0.61115664 0.20748555 2.30058300 +O 10.64836843 8.95923477 10.32953916 -2.39696341 0.22372670 2.32516276 +H 10.35155293 9.90349858 10.16999223 0.96258032 -1.01845041 -0.24083020 +H 10.06870771 8.69586327 11.18304908 2.34525851 0.93941467 -2.29535337 +O 11.45213569 2.78041878 8.81582775 0.13917079 0.55494715 -5.11691409 +H 11.40535643 3.52377517 9.35537683 -0.02186930 2.41838071 2.07431039 +H 11.49658658 2.03804136 9.30441706 0.27983642 -3.85064522 2.72189444 +O 5.72252236 2.78391664 4.58865435 2.00202122 0.11734314 -0.87511135 +H 5.58377209 3.78623649 4.59301477 0.76167423 -0.88388396 -0.15393664 +H 6.80527185 2.57167953 4.53785334 -2.41439550 0.27077174 0.34678438 +O 3.49103498 0.47779834 8.23103401 1.10622791 -0.34709034 -2.37964446 +H 4.13096898 -0.02426013 7.57813517 -2.14205385 1.52066936 1.47370043 +H 2.91503088 1.07665866 7.70577271 0.67948116 -0.68213454 0.43600941 +O 4.20392664 5.85068908 9.77697231 -3.57610785 -0.45028264 -2.74477212 +H 3.54920743 6.04214539 8.83884695 1.55419656 -1.34864705 2.76174141 +H 4.72216576 6.61011129 9.59234238 1.78331795 1.13481345 0.09583637 +O 1.45781442 6.56317328 10.70033363 2.20226397 0.63186945 -0.26552977 +H 2.44013666 6.43590616 10.96735645 -2.54561438 0.39446191 -1.41605113 +H 1.41492990 7.19966763 9.90556238 0.60172068 -0.81249982 1.68488579 +O 7.98681873 5.23825113 3.84425018 3.13433042 -1.45225227 2.74527092 +H 9.01755010 5.30341401 3.65056603 -2.77410842 0.24112279 0.40572437 +H 8.05524134 4.91626267 4.90003810 -0.71303407 1.02736700 -3.07754501 +O 10.04494765 1.29552636 11.26724117 1.73607648 6.18659119 -7.11522264 +H 9.25266353 1.89236946 10.78695994 2.14781312 -1.66582877 1.26821777 +H 9.64420175 0.85762163 0.14036796 -3.95195702 -4.15348017 5.70846863 +O 1.65347770 0.84323330 3.65143388 1.42823302 18.52259540 -1.58570801 +H 1.30082871 1.18126641 2.78746213 0.06763907 -0.72436750 0.47733529 +H 1.56093519 11.80917156 3.71966599 -1.83660666 -17.48993707 1.07095236 +O 9.54164721 8.40767336 0.89833123 11.90981796 -1.53191537 -0.18516168 +H 9.00596112 8.92785456 1.26700370 -5.91832217 7.89755826 6.08966056 +H 8.94113691 7.98808875 0.49614438 -7.23364771 -6.68636245 -5.64295689 +O 5.09521982 8.37512896 0.81392218 4.28942474 -0.28139245 -0.78329207 +H 5.86312474 9.13169362 1.08096087 -2.32849491 -3.31964533 -0.61849971 +H 5.54545966 7.65089703 0.19278931 -0.77571760 2.85232141 2.11019787 +O 7.30703768 11.10160871 4.63375612 -0.88771490 -3.55613037 3.03358941 +H 7.63533922 10.74907085 5.55297398 -0.86701751 0.61395914 -2.54795409 +H 7.79843164 0.12321944 4.54375366 1.37331423 2.81788919 -0.43847047 +O 10.24857504 11.03916580 3.08253663 2.08557697 -0.83679675 1.42397013 +H 11.34180224 10.69340141 3.15046182 -3.53841546 0.81605308 0.79711433 +H 9.89365589 10.83162249 3.99602350 0.78547751 0.29104335 -1.77379458 +O 11.39159782 1.59286574 5.25942351 10.18846363 14.10111874 -9.35825404 +H 11.14796463 0.93367498 5.59827155 -6.44107909 -15.49337173 8.38653088 +H 0.63682773 1.17983234 4.75863136 -3.26021169 1.80275036 1.16045250 +O 7.51865565 7.18246937 11.26967539 2.85097416 -4.61661833 5.64131139 +H 7.03461725 7.13050416 10.50220968 -1.92595253 0.25077060 -3.35841759 +H 7.61030914 6.12718417 11.64062861 -0.25806174 3.92457476 -1.28720280 +O 1.75378853 9.68542466 6.78468686 -3.15344401 -0.46783763 0.39065667 +H 2.25816320 9.04199808 6.22910370 -0.18277466 0.61772838 -0.22779213 +H 0.77725020 9.49312166 6.62916167 2.16249413 0.72466575 0.07833126 +O 0.79380286 1.53270887 1.01780886 -0.44967946 0.87239112 -0.65009858 +H -0.05250973 1.65674801 0.40803849 2.46039770 0.24974010 1.38686909 +H 1.63024682 2.14220989 0.87222163 -2.73532074 -0.92775214 -1.05769575 +O 5.12024461 5.20156327 4.72964303 16.37304935 -0.02624223 3.03762604 +H 5.10836987 5.84100514 4.17073135 0.04007426 6.65016130 -5.75084043 +H 4.40036251 4.91235205 4.83038250 -16.81327383 -5.96156814 2.63335074 diff --git a/tests/bulk_small_unit_cell.xyz b/tests/bulk_small_unit_cell.xyz new file mode 100644 index 0000000..621602a --- /dev/null +++ b/tests/bulk_small_unit_cell.xyz @@ -0,0 +1,46 @@ +4 +Lattice="5 0 0 0 5 0 0 0 5" +Ga 0.028814 -0.044917 0.285450 +Ga 2.706516 2.379548 0.037928 +Ga 2.440308 0.049296 2.573117 +Ga -0.172131 2.582749 2.413305 +1 +Lattice="4 0 0 0 4 0 0 0 4" +N 1.768999 0.270888 1.705446 +2 +Lattice="4 0 0 0 4 0 0 0 4" +N 0.102312 -0.203830 0.256609 +Ga 2.233827 2.074812 2.061545 +4 +Lattice="5 0 0 0 5 0 0 0 5" +N -0.003714 -0.021372 0.191142 +Ga 2.245528 2.312564 0.005043 +Ga 2.733145 -0.138738 2.287960 +Ga 0.146076 2.257762 2.784547 +8 +Lattice="8 0 0 0 8 0 0 0 8" +N 0.295604 0.207245 0.262794 +Ga 1.857919 1.959535 1.919910 +N 3.752949 4.287860 -0.220102 +Ga 5.778130 5.751785 2.179060 +N 4.282583 0.068382 4.181446 +Ga 6.039969 1.705196 6.001205 +N 0.060032 4.273156 3.784467 +Ga 1.813436 5.838242 6.201837 +2 +Lattice="4 0 0 0 4 0 0 0 4" +Ga 0.295633 0.142376 -0.096602 +Ga 1.873007 2.083859 2.029016 +2 +Lattice="2 2 0 2 -2 0 0 0 4" +N -0.015750 -0.265133 0.017505 +N 2.280493 -0.126946 2.195588 +1 +Lattice="2 2 0 2 0 2 0 2 2" +N 0.254564 -0.039968 -0.290847 +4 +Lattice="4 0 0 0 4 0 0 0 4" +N 0.013750 0.199993 0.031294 +N 2.227114 2.069536 -0.026844 +N 2.064534 0.292672 1.906420 +N -0.131838 1.735252 2.009850 \ No newline at end of file diff --git a/tests/test_cpp_extension.py b/tests/test_cpp_extension.py new file mode 100644 index 0000000..727ab4d --- /dev/null +++ b/tests/test_cpp_extension.py @@ -0,0 +1,139 @@ +import pytest +from itertools import product + +import ase.io +from pet.hypers import load_hypers_from_file +from pet.data_preparation import get_all_species +from pet.pet import PET, PETUtilityWrapper, PETMLIPWrapper +import torch +from pet.molecule import MoleculeCPP, Molecule +from matscipy.neighbours import neighbour_list as neighbor_list + +def prepare_test(stucture_path, r_cut, n_gnn, n_trans, structure_index, hypers_path = "../default_hypers/default_hypers.yaml"): + device = 'cpu' + structure = ase.io.read(stucture_path, index=structure_index) + hypers = load_hypers_from_file(hypers_path) + + + MLIP_SETTINGS = hypers.MLIP_SETTINGS + ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS + FITTING_SCHEME = hypers.FITTING_SCHEME + + ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar + ARCHITECTURAL_HYPERS.TARGET_TYPE = "structural" # energy is structural property + ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = ( + "sum" # energy is a sum of atomic energies + ) + ARCHITECTURAL_HYPERS.R_CUT = r_cut + ARCHITECTURAL_HYPERS.N_TRANS_LAYERS = n_trans + ARCHITECTURAL_HYPERS.N_GNN_LAYERS = n_gnn + all_species = get_all_species([structure]) + + + model = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species)).to(device) + model = PETUtilityWrapper(model, FITTING_SCHEME.GLOBAL_AUG) + + model = PETMLIPWrapper( + model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES + ) + return model, structure, all_species, ARCHITECTURAL_HYPERS + +def get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS): + device = 'cpu' + molecule = Molecule( + structure, + ARCHITECTURAL_HYPERS.R_CUT, + ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES, + ARCHITECTURAL_HYPERS.USE_LONG_RANGE, + ARCHITECTURAL_HYPERS.K_CUT, + ) + if ARCHITECTURAL_HYPERS.USE_LONG_RANGE: + raise NotImplementedError( + "Long range interactions are not supported in the SingleStructCalculator" + ) + + graph = molecule.get_graph( + molecule.get_max_num(), all_species, None + ) + graph.batch = torch.zeros( + graph.num_nodes, dtype=torch.long, device=graph.x.device + ) + graph = graph.to(device) + prediction_energy, prediction_forces = model( + graph, augmentation=False, create_graph=False + ) + + return prediction_energy, prediction_forces, graph + +def get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS): + device = 'cpu' + molecule = MoleculeCPP( + structure, + ARCHITECTURAL_HYPERS.R_CUT, + ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES, + ARCHITECTURAL_HYPERS.USE_LONG_RANGE, + ARCHITECTURAL_HYPERS.K_CUT, + ) + if ARCHITECTURAL_HYPERS.USE_LONG_RANGE: + raise NotImplementedError( + "Long range interactions are not supported in the SingleStructCalculator" + ) + + graph = molecule.get_graph( + molecule.get_max_num(), all_species, None + ) + graph.batch = torch.zeros( + graph.num_nodes, dtype=torch.long, device=graph.x.device + ) + graph = graph.to(device) + prediction_energy, prediction_forces = model( + graph, augmentation=False, create_graph=False + ) + + return prediction_energy, prediction_forces, graph + +class Float64DtypeContext: + def __enter__(self): + # Save the current default dtype + self.original_dtype = torch.get_default_dtype() + # Set the default dtype to float64 + torch.set_default_dtype(torch.float64) + return self + + def __exit__(self, exc_type, exc_value, traceback): + # Restore the original default dtype + torch.set_default_dtype(self.original_dtype) + + +def do_single_test(stucture_path, r_cut, n_gnn, n_trans, structure_index, epsilon): + model, structure, all_species, ARCHITECTURAL_HYPERS = prepare_test(stucture_path, r_cut, n_gnn, n_trans, structure_index) + python_energy, python_forces, python_graph = get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS) + cpp_energy, cpp_forces, cpp_graph = get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS) + # print(f"energy difference: {torch.abs(python_energy - cpp_energy)}") + # print(f"forces difference: {torch.abs(python_forces - cpp_forces).max()}") + assert torch.abs(python_energy - cpp_energy) < epsilon, f"Energy difference is {torch.abs(python_energy - cpp_energy)}" + assert torch.abs(python_forces - cpp_forces).max() < epsilon, f"Max force difference is {torch.abs(python_forces - cpp_forces).max()}" + + + +# Define the parameters for each case +case1_params = ("../example/methane_train.xyz", 0, [1.0, 2.0, 3.0, 4.0, 5.0, 10.0]) +case2_params = ("bulk.xyz", 0, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) +case3_params = ("bulk_small_unit_cell.xyz", list(range(9)), [2.0, 3.0, 5.0, 10.0, 15.0, 20.0]) + +# Generate the expanded lists using product +expanded_case1 = list(product([case1_params[0]], [case1_params[1]], case1_params[2])) +expanded_case2 = list(product([case2_params[0]], [case2_params[1]], case2_params[2])) +expanded_case3 = list(product([case3_params[0]], case3_params[1], case3_params[2])) + +# Combine all cases into one list +all_cases = expanded_case1 + expanded_case2 + expanded_case3 + +@pytest.mark.parametrize("structures_path, structure_index, r_cut", all_cases) +def test_do_single(structures_path, structure_index, r_cut): + n_gnn = 2 + n_trans = 2 + epsilon = 1e-10 + with Float64DtypeContext(): + do_single_test(structures_path, r_cut, n_gnn, n_trans, structure_index, epsilon) + #do_single_test(structures_path, r_cut, n_gnn, n_trans, structure_index, epsilon) \ No newline at end of file diff --git a/tests/test_cpp_float32.ipynb b/tests/test_cpp_float32.ipynb new file mode 100644 index 0000000..083e4f0 --- /dev/null +++ b/tests/test_cpp_float32.ipynb @@ -0,0 +1,309 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import ase.io\n", + "from pet.hypers import load_hypers_from_file\n", + "from pet.data_preparation import get_all_species\n", + "from pet.pet import PET, PETUtilityWrapper, PETMLIPWrapper\n", + "import torch\n", + "from pet.molecule import MoleculeCPP, Molecule\n", + "from matscipy.neighbours import neighbour_list as neighbor_list\n", + "torch.set_default_dtype(torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "def prepare_test(stucture_path, r_cut, n_gnn, n_trans, hypers_path = \"../default_hypers/default_hypers.yaml\"):\n", + " structure = ase.io.read(stucture_path, index=0)\n", + " hypers = load_hypers_from_file(hypers_path)\n", + " \n", + "\n", + " MLIP_SETTINGS = hypers.MLIP_SETTINGS\n", + " ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS\n", + " FITTING_SCHEME = hypers.FITTING_SCHEME\n", + "\n", + " ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar\n", + " ARCHITECTURAL_HYPERS.TARGET_TYPE = \"structural\" # energy is structural property\n", + " ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = (\n", + " \"sum\" # energy is a sum of atomic energies\n", + " )\n", + " ARCHITECTURAL_HYPERS.R_CUT = r_cut\n", + " ARCHITECTURAL_HYPERS.N_TRANS_LAYERS = n_trans\n", + " ARCHITECTURAL_HYPERS.N_GNN_LAYERS = n_gnn\n", + " all_species = get_all_species([structure])\n", + "\n", + "\n", + " model = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species)).to(device)\n", + " model = PETUtilityWrapper(model, FITTING_SCHEME.GLOBAL_AUG)\n", + "\n", + " model = PETMLIPWrapper(\n", + " model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES\n", + " )\n", + " return model, structure, all_species, ARCHITECTURAL_HYPERS\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS):\n", + " \n", + " molecule = Molecule(\n", + " structure,\n", + " ARCHITECTURAL_HYPERS.R_CUT,\n", + " ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,\n", + " ARCHITECTURAL_HYPERS.USE_LONG_RANGE,\n", + " ARCHITECTURAL_HYPERS.K_CUT,\n", + " )\n", + " if ARCHITECTURAL_HYPERS.USE_LONG_RANGE:\n", + " raise NotImplementedError(\n", + " \"Long range interactions are not supported in the SingleStructCalculator\"\n", + " )\n", + "\n", + " graph = molecule.get_graph(\n", + " molecule.get_max_num(), all_species, None\n", + " )\n", + " graph.batch = torch.zeros(\n", + " graph.num_nodes, dtype=torch.long, device=graph.x.device\n", + " )\n", + " graph = graph.to(device)\n", + " prediction_energy, prediction_forces = model(\n", + " graph, augmentation=False, create_graph=False\n", + " )\n", + "\n", + " return prediction_energy, prediction_forces, graph\n", + "\n", + "def get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS):\n", + " \n", + " molecule = MoleculeCPP(\n", + " structure,\n", + " ARCHITECTURAL_HYPERS.R_CUT,\n", + " ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,\n", + " ARCHITECTURAL_HYPERS.USE_LONG_RANGE,\n", + " ARCHITECTURAL_HYPERS.K_CUT,\n", + " )\n", + " if ARCHITECTURAL_HYPERS.USE_LONG_RANGE:\n", + " raise NotImplementedError(\n", + " \"Long range interactions are not supported in the SingleStructCalculator\"\n", + " )\n", + "\n", + " graph = molecule.get_graph(\n", + " molecule.get_max_num(), all_species, None\n", + " )\n", + " graph.batch = torch.zeros(\n", + " graph.num_nodes, dtype=torch.long, device=graph.x.device\n", + " )\n", + " graph = graph.to(device)\n", + " prediction_energy, prediction_forces = model(\n", + " graph, augmentation=False, create_graph=False\n", + " )\n", + "\n", + " return prediction_energy, prediction_forces, graph" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Energy difference: tensor([0.], device='cuda:0', grad_fn=)\n", + "Forces difference: tensor(0., device='cuda:0')\n" + ] + } + ], + "source": [ + "model, structure, all_species, ARCHITECTURAL_HYPERS = prepare_test(\"../example/methane_train.xyz\", 10.0, 2, 2)\n", + "python_energy, python_forces, python_graph = get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "cpp_energy, cpp_forces, cpp_graph = get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "\n", + "print(\"Energy difference: \", torch.abs(python_energy - cpp_energy))\n", + "print(\"Forces difference: \", torch.abs(python_forces - cpp_forces).max())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Energy difference: tensor([0.0002], device='cuda:0', grad_fn=)\n", + "Forces difference: tensor(1.7285e-06, device='cuda:0')\n", + "Forces spread: tensor(2.4026, device='cuda:0')\n" + ] + } + ], + "source": [ + "model, structure, all_species, ARCHITECTURAL_HYPERS = prepare_test(\"bulk.xyz\", 4.0, 2, 2)\n", + "python_energy, python_forces, python_graph = get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "cpp_energy, cpp_forces, cpp_graph = get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "\n", + "print(\"Energy difference: \", torch.abs(python_energy - cpp_energy))\n", + "print(\"Forces difference: \", torch.abs(python_forces - cpp_forces).max())\n", + "\n", + "print(\"Forces spread: \", torch.abs(python_forces).max())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "341 ms ± 2.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "76.3 ms ± 457 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6.27 ms ± 7.34 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit i_list, j_list, D_list, S_list = neighbor_list('ijDS', structure, 4.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "83.9 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit i_list, j_list, D_list, S_list = ase.neighborlist.neighbor_list(\"ijDS\", structure, 4.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + " def is_3d_crystal(atoms):\n", + " pbc = atoms.get_pbc()\n", + " if isinstance(pbc, bool):\n", + " return pbc\n", + " return all(pbc)\n", + "\n", + "print(is_3d_crystal(structure))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_cpp_float64.ipynb b/tests/test_cpp_float64.ipynb new file mode 100644 index 0000000..5fa1aa9 --- /dev/null +++ b/tests/test_cpp_float64.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import ase.io\n", + "from pet.hypers import load_hypers_from_file\n", + "from pet.data_preparation import get_all_species\n", + "from pet.pet import PET, PETUtilityWrapper, PETMLIPWrapper\n", + "import torch\n", + "from pet.molecule import MoleculeCPP, Molecule\n", + "from matscipy.neighbours import neighbour_list as neighbor_list\n", + "torch.set_default_dtype(torch.float64)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "def prepare_test(stucture_path, r_cut, n_gnn, n_trans, hypers_path = \"../default_hypers/default_hypers.yaml\"):\n", + " structure = ase.io.read(stucture_path, index=0)\n", + " hypers = load_hypers_from_file(hypers_path)\n", + " \n", + "\n", + " MLIP_SETTINGS = hypers.MLIP_SETTINGS\n", + " ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS\n", + " FITTING_SCHEME = hypers.FITTING_SCHEME\n", + "\n", + " ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar\n", + " ARCHITECTURAL_HYPERS.TARGET_TYPE = \"structural\" # energy is structural property\n", + " ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = (\n", + " \"sum\" # energy is a sum of atomic energies\n", + " )\n", + " ARCHITECTURAL_HYPERS.R_CUT = r_cut\n", + " ARCHITECTURAL_HYPERS.N_TRANS_LAYERS = n_trans\n", + " ARCHITECTURAL_HYPERS.N_GNN_LAYERS = n_gnn\n", + " all_species = get_all_species([structure])\n", + "\n", + "\n", + " model = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species)).to(device)\n", + " model = PETUtilityWrapper(model, FITTING_SCHEME.GLOBAL_AUG)\n", + "\n", + " model = PETMLIPWrapper(\n", + " model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES\n", + " )\n", + " return model, structure, all_species, ARCHITECTURAL_HYPERS\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS):\n", + " \n", + " molecule = Molecule(\n", + " structure,\n", + " ARCHITECTURAL_HYPERS.R_CUT,\n", + " ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,\n", + " ARCHITECTURAL_HYPERS.USE_LONG_RANGE,\n", + " ARCHITECTURAL_HYPERS.K_CUT,\n", + " )\n", + " if ARCHITECTURAL_HYPERS.USE_LONG_RANGE:\n", + " raise NotImplementedError(\n", + " \"Long range interactions are not supported in the SingleStructCalculator\"\n", + " )\n", + "\n", + " graph = molecule.get_graph(\n", + " molecule.get_max_num(), all_species, None\n", + " )\n", + " graph.batch = torch.zeros(\n", + " graph.num_nodes, dtype=torch.long, device=graph.x.device\n", + " )\n", + " graph = graph.to(device)\n", + " prediction_energy, prediction_forces = model(\n", + " graph, augmentation=False, create_graph=False\n", + " )\n", + "\n", + " return prediction_energy, prediction_forces, graph\n", + "\n", + "def get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS):\n", + " \n", + " molecule = MoleculeCPP(\n", + " structure,\n", + " ARCHITECTURAL_HYPERS.R_CUT,\n", + " ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,\n", + " ARCHITECTURAL_HYPERS.USE_LONG_RANGE,\n", + " ARCHITECTURAL_HYPERS.K_CUT,\n", + " )\n", + " if ARCHITECTURAL_HYPERS.USE_LONG_RANGE:\n", + " raise NotImplementedError(\n", + " \"Long range interactions are not supported in the SingleStructCalculator\"\n", + " )\n", + "\n", + " graph = molecule.get_graph(\n", + " molecule.get_max_num(), all_species, None\n", + " )\n", + " graph.batch = torch.zeros(\n", + " graph.num_nodes, dtype=torch.long, device=graph.x.device\n", + " )\n", + " graph = graph.to(device)\n", + " prediction_energy, prediction_forces = model(\n", + " graph, augmentation=False, create_graph=False\n", + " )\n", + "\n", + " return prediction_energy, prediction_forces, graph" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Energy difference: tensor([0.], device='cuda:0', grad_fn=)\n", + "Forces difference: tensor(0., device='cuda:0')\n" + ] + } + ], + "source": [ + "model, structure, all_species, ARCHITECTURAL_HYPERS = prepare_test(\"../example/methane_train.xyz\", 10.0, 2, 2)\n", + "python_energy, python_forces, python_graph = get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "cpp_energy, cpp_forces, cpp_graph = get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "\n", + "print(\"Energy difference: \", torch.abs(python_energy - cpp_energy))\n", + "print(\"Forces difference: \", torch.abs(python_forces - cpp_forces).max())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Energy difference: tensor([1.1369e-13], device='cuda:0', grad_fn=)\n", + "Forces difference: tensor(4.3097e-14, device='cuda:0')\n", + "Forces spread: tensor(1.2995, device='cuda:0')\n" + ] + } + ], + "source": [ + "model, structure, all_species, ARCHITECTURAL_HYPERS = prepare_test(\"bulk.xyz\", 4.0, 2, 2)\n", + "python_energy, python_forces, python_graph = get_predictions_old_python(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "cpp_energy, cpp_forces, cpp_graph = get_predictions_cpp(model, structure, all_species, ARCHITECTURAL_HYPERS)\n", + "\n", + "print(\"Energy difference: \", torch.abs(python_energy - cpp_energy))\n", + "print(\"Forces difference: \", torch.abs(python_forces - cpp_forces).max())\n", + "\n", + "print(\"Forces spread: \", torch.abs(python_forces).max())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_pet_runs_without_errors.py b/tests/test_pet_runs_without_errors.py index 861d158..1197a9a 100644 --- a/tests/test_pet_runs_without_errors.py +++ b/tests/test_pet_runs_without_errors.py @@ -123,7 +123,9 @@ def test_single_struct_calculator(prepare_model): ) structure = ase.io.read("../example/methane_test.xyz", index=0) energy, forces = single_struct_calculator.forward(structure) + assert forces.shape == (5, 3), "single_struct_calculator failed" + energy, forces = single_struct_calculator.forward(structure, quadrature_order = 2) assert forces.shape == (5, 3), "single_struct_calculator failed" diff --git a/tests/water_single.xyz b/tests/water_single.xyz new file mode 100644 index 0000000..6a7e03a --- /dev/null +++ b/tests/water_single.xyz @@ -0,0 +1,194 @@ +192 +Lattice="11.74456431055329 0.0 0.0 -5.29177210903e-07 11.74456431055329 0.0 -5.29177210903e-07 -5.29177210903e-07 11.74456431055329" Properties=species:S:1:pos:R:3:forces:R:3 cutoff=-1.0 nneightol=1.2 energy=-30010.461001933934 pbc="T T T" +O 5.73416426 3.02709473 1.80318722 -0.80332611 -1.21376692 0.20230786 +H 5.56943139 2.10076473 1.36991809 0.50217575 1.53231646 1.34795286 +H 5.56662675 2.96533975 2.79956970 0.54751666 -0.28190205 -1.62430029 +O 1.78772996 10.85151956 1.15475463 -2.83291472 2.12456520 -0.92718649 +H 1.27361313 11.62702876 0.64046847 1.81789416 -2.16291579 1.69135469 +H 2.16766332 10.15956744 0.55739294 -0.13023518 0.45973042 -0.26124271 +O 7.52468827 9.97186828 10.15051851 -1.51298176 3.37316030 2.50257924 +H 8.17832796 9.32362620 10.34398569 0.48150767 -0.47552780 0.31697757 +H 7.43679193 10.62175081 11.03466779 1.05936697 -2.82762853 -2.99742426 +O 7.91178140 4.58707211 0.75474958 1.34499609 1.23814613 -2.36786880 +H 8.06085062 4.79654691 1.63976672 0.38014497 0.68880399 2.57828797 +H 7.13320297 4.09945648 0.81464715 -2.16483898 -1.83566563 0.07202743 +O 8.92351531 5.17808368 9.14508181 2.09709552 -3.02514590 -0.30432053 +H 8.62775816 5.97546907 9.42279401 -2.76871939 4.09653535 0.85838888 +H 9.79237137 5.33770469 9.55709918 0.94460315 -1.16752304 -0.52726150 +O 11.39392620 4.75242940 10.31763267 -1.23014486 0.11996567 -0.73341779 +H 11.04689178 4.60206370 11.18738833 -0.18260446 -0.43417416 1.22912156 +H 0.40332564 5.30748867 10.49062070 1.23808957 1.01266029 -0.11816898 +O 0.06869249 6.49660278 7.03350598 -8.74972508 3.31707423 -1.29926643 +H 11.62972756 7.37731241 6.49459191 0.80802095 -2.94173414 0.93328515 +H 0.92074189 6.41346904 7.01763066 7.99459174 -0.50270591 0.52857277 +O 2.58803640 5.05946331 5.42406641 -2.78970988 -0.38920966 -3.53642028 +H 2.09149884 4.27517506 5.74982790 0.69823994 0.18291556 0.51624689 +H 2.05534545 5.27435689 4.54993445 1.81884547 -0.35324094 2.29502427 +O 3.93554913 4.81396213 0.78823062 -4.93885484 -11.15967257 12.97271248 +H 4.29762276 5.26709128 0.27274428 7.51580070 11.72099606 -13.81315506 +H 4.79081063 4.49826559 1.22479124 -1.65517925 -0.69692353 0.36172507 +O 4.09431817 7.33063898 3.19968059 -1.06004574 -1.11637349 -0.28933974 +H 3.38943825 6.58815044 3.16480781 1.14148290 1.68970918 -0.61438594 +H 4.19910055 7.70582563 2.28520416 0.06644145 -0.21912545 1.05279008 +O 4.01414782 0.34585435 10.69001467 0.27125870 -3.90565143 1.86140239 +H 3.97934383 11.01133108 10.83765511 -0.38882553 3.49365767 0.06856364 +H 3.76050260 0.40037812 9.77321516 -0.24997767 0.45197288 -1.37139619 +O 9.40067440 2.29813725 7.12410112 -2.25397402 -0.18400931 -3.24547412 +H 8.77407566 2.12047129 7.78176256 -1.42089509 -0.29561015 1.47685774 +H 10.18867218 2.35986048 7.56548783 3.33147756 0.45038600 1.81095218 +O 10.91142242 11.36201681 9.39247215 0.01736164 0.75490194 5.37813315 +H 10.21386102 0.09946732 9.94308104 1.27966947 -0.25601514 -1.75074206 +H 10.90565439 11.26390736 8.49001334 -1.04485051 0.02427020 -3.42446926 +O 9.59271281 7.85595320 8.22637725 1.68824880 -3.33631637 -2.48706005 +H 10.24471205 7.11245922 7.84028956 -2.08559754 2.31028120 1.60193682 +H 10.10252213 8.20510433 8.96706659 0.16744728 0.53774132 0.59653734 +O 8.17668751 11.46557679 0.59883281 -0.66836885 3.34592202 -5.23775086 +H 7.40086080 0.24360726 0.81692790 -0.70252854 0.12888586 -0.06194922 +H 8.35316811 11.04541009 1.33745836 1.62912882 -3.53077414 5.68080355 +O 0.94004098 10.15162978 3.54788449 -1.08457407 -1.07225848 -2.11326262 +H 0.69010529 9.19133189 3.77189578 0.19180130 1.74153750 -0.34616217 +H 0.93777081 10.30122818 2.50966526 1.29074064 -1.02193169 1.68592452 +O 2.27424490 5.50915811 2.47661814 3.42542057 -4.02492496 3.71742602 +H 2.87039478 5.20416153 1.76153568 0.28323491 -0.31793299 0.44400092 +H 1.75867813 6.03341397 2.01906506 -5.01572057 4.50292924 -3.15397366 +O 8.94351821 9.24398503 5.88873692 -1.92119599 -4.43389509 -2.57106317 +H 9.09221700 8.40449830 6.46427006 -0.54337204 2.43464552 -0.54531580 +H 8.36787924 8.80270415 5.00977887 2.04375025 1.37647669 3.10278297 +O 2.82171577 2.69312570 11.50923391 0.69614192 -0.37589905 1.50967018 +H 3.43897452 3.14968922 0.41529298 -1.06053939 -1.03034434 -1.56750974 +H 2.93838347 1.70842217 11.51627197 0.70775303 1.01310766 -0.29180954 +O 8.16112970 2.63468866 9.79782190 -1.27590023 2.29413467 1.11429604 +H 8.54928118 3.61930224 9.73373854 -1.52802272 -2.32642774 -0.06867111 +H 7.16627654 2.73929112 10.09193859 2.77315197 0.05309382 -1.14957673 +O 1.59972917 8.97754430 9.43268962 -4.27687261 0.84039630 2.26111141 +H 1.59251119 9.12142758 8.50784661 0.80233881 0.43452331 -2.11537093 +H 0.58284107 9.31288390 9.64240255 3.61139368 -1.36863482 0.04291019 +O 5.26826606 3.30655851 9.49391542 -0.84340449 0.50978314 1.64621124 +H 4.78704289 2.59503212 9.98054679 0.22770831 1.03334739 -0.67504858 +H 4.89323288 4.20901733 9.82787916 0.44176714 -1.60587061 -1.14998811 +O 7.36117251 0.04299829 7.68322976 -0.43534967 -0.03071647 3.23220208 +H 7.04043820 0.88966860 7.30592641 -0.03364444 -0.06154709 0.77135186 +H 7.21231496 0.07133097 8.74782847 0.53084562 -0.30424906 -4.25693627 +O 4.18004487 0.62351363 4.09110077 -1.56364280 2.67147508 0.69047006 +H 4.80862273 1.37200834 4.08969845 0.44022654 -0.29248317 0.30456530 +H 3.28940788 1.15635275 4.10035608 1.67807237 -2.07241806 -0.35558115 +O 2.70842422 6.37716749 7.53125007 2.06494643 5.11524802 -0.69604422 +H 2.64632527 5.94334801 6.65921894 0.39734926 -0.66881108 -0.10087985 +H 3.09797273 7.37493112 7.16849909 -1.55386232 -3.75355518 1.65329206 +O 6.41182859 2.43882960 7.01329141 -4.92870926 2.02800480 0.61801120 +H 5.72035273 2.64918283 6.21809682 2.70379386 -1.03493633 1.32465351 +H 5.74453613 2.77554505 7.76408804 2.33632649 -1.16785214 -1.14451680 +O 10.57301359 4.37816353 1.12879320 4.89504836 -1.30601815 2.06174798 +H 10.93983923 3.52935799 1.61142928 -2.07324081 1.97823337 -1.66711946 +H 9.71802197 4.32374295 0.82312457 -3.01427013 -1.27858446 -0.62062344 +O 4.08049606 10.52136590 2.32719437 0.42891316 2.23963240 1.86107843 +H 3.15932554 10.71626186 1.93731777 1.63434302 0.18014648 1.02641569 +H 4.27975774 11.25147169 3.16363304 -1.47402439 -2.26965261 -2.92434833 +O 1.84757461 2.91355447 6.97180392 -2.90159406 -1.25619014 -2.57119687 +H 1.40897666 2.99356607 7.80562845 -0.14525659 0.44582177 1.66832274 +H 0.95316987 2.61135724 6.50289999 2.62214073 0.34147968 0.65385239 +O 5.97425196 7.07160674 6.51708194 -1.32986257 4.64477707 0.74215954 +H 6.45172856 6.31562418 6.45929579 1.62818266 -4.29121937 -0.25735983 +H 5.04161946 6.94539798 6.38552849 -0.03046759 -0.33969019 -0.12541847 +O 3.26609762 9.55064322 11.51695990 -2.28802058 2.61571811 -1.59542165 +H 2.90240471 9.12894190 10.76187694 -0.52591939 -0.96484803 -1.23090590 +H 3.99413434 9.15122026 0.17271762 2.46309222 -2.04713382 1.18471344 +O 10.53697662 4.56069792 5.76512112 1.37467177 5.21414812 -0.75573497 +H 10.35298171 3.68800003 5.75575469 -0.26204901 -4.34917206 0.71721469 +H 11.27576093 4.91570703 6.31964592 -1.20071085 -0.75218685 -0.31628954 +O 4.63525899 10.81807556 6.19613596 6.71521024 -1.10373908 4.35700365 +H 4.35573171 11.34174933 5.54381921 -2.37777269 3.55238684 -4.56307766 +H 5.67410264 11.23263298 6.27461294 -3.69088709 -2.18111407 -0.04988097 +O 0.14165280 7.33720078 1.35500587 -5.25636565 -0.09894072 6.15964602 +H 0.43427986 7.46266870 0.56089609 3.84631548 0.76780374 -7.44962047 +H 10.92671564 7.65920512 1.13098929 2.25378890 -0.78840343 1.73881728 +O 10.27603934 2.10768108 2.86506067 1.27164762 -1.00033441 0.95626568 +H 11.01111940 2.17232008 3.54679438 -1.62994129 0.20511087 -0.66830200 +H 10.25354931 1.09798979 2.87938021 -0.22944278 0.78903592 -0.44897292 +O 5.69627517 0.40138197 0.90041090 2.59900079 0.43527099 4.94188874 +H 5.24481822 0.44575031 0.13631182 -2.71563637 0.20740533 -4.72132913 +H 5.15967360 11.54675258 1.48204016 -0.21494998 -0.23595898 -0.06999420 +O 6.15025630 7.53315510 9.28182120 3.37681641 -8.71825476 -0.73325838 +H 5.88831358 8.32623299 9.17037648 -3.12216395 9.38792459 -0.76454872 +H 6.54803881 7.35413445 8.40211700 -0.57476523 -0.20697287 0.55008777 +O 11.51516070 7.99909564 4.70859766 -1.85752516 -3.33904174 -2.28857594 +H 10.55803788 8.11683757 4.90491182 0.94879405 0.68445368 -0.18751835 +H 11.33058369 7.11780391 3.97989418 0.73715617 2.85156037 2.24620414 +O 4.32177970 9.39263090 8.51430257 10.95624679 -10.19283450 -11.15278201 +H 4.15839623 9.72273165 7.58681367 0.74426271 -1.00007216 1.44278033 +H 3.83315863 9.85920645 8.89837939 -11.80157447 10.92359377 9.90425375 +O 7.32079629 7.68735734 4.30241710 11.26303096 2.57085748 -17.99639319 +H 7.49129719 6.77442082 3.96804590 0.55662866 -0.25286348 0.13183903 +H 6.92412505 7.59215836 4.94862185 -11.68489776 -2.23415081 18.08807877 +O 10.60936807 5.82587067 3.33454140 -3.62983878 8.74340016 19.17817558 +H 10.85887512 5.24453775 3.99905568 0.76538690 -2.26595537 2.00967797 +H 10.70514914 5.60689714 2.61041531 3.34908448 -6.50093440 -20.65116124 +O 8.16710940 1.89280867 4.64662571 -2.78388890 -0.72861497 -4.24536116 +H 8.53457006 2.07518431 5.44433390 2.39779131 1.05826653 3.94904653 +H 8.71888248 2.16570536 3.89810984 0.79682122 0.07677472 0.24100195 +O 7.35476947 9.37220466 2.30818103 -1.57494023 0.21925041 2.66969073 +H 7.02816129 10.06029379 3.00869524 1.12488899 -1.41697158 -1.75353428 +H 7.24983362 8.55541964 2.92457723 0.16987646 1.65735955 -1.30613127 +O 7.59850849 4.79108580 6.60704207 -0.18279832 -1.17206875 0.89977338 +H 7.35064188 3.91050317 7.03265930 -0.45973248 0.75343126 -1.50952619 +H 8.26643596 4.99101424 7.26121093 0.89534594 0.76599882 -0.01246266 +O 10.40399439 11.10430751 6.79294202 2.81111175 -0.62073657 -1.02525355 +H 9.63001980 11.47674243 7.07319427 -3.06936890 2.63521222 1.66147846 +H 9.93990598 10.35631552 6.42130087 0.40481317 -1.41940385 -0.75827008 +O 3.49465984 8.56023515 5.57334730 4.08231201 1.30895435 -3.78476839 +H 4.21974904 9.43364214 5.51672534 -2.68191376 -2.97917456 0.99539275 +H 3.65141801 8.24495137 4.57004848 -0.61115664 0.20748555 2.30058300 +O 10.64836843 8.95923477 10.32953916 -2.39696341 0.22372670 2.32516276 +H 10.35155293 9.90349858 10.16999223 0.96258032 -1.01845041 -0.24083020 +H 10.06870771 8.69586327 11.18304908 2.34525851 0.93941467 -2.29535337 +O 11.45213569 2.78041878 8.81582775 0.13917079 0.55494715 -5.11691409 +H 11.40535643 3.52377517 9.35537683 -0.02186930 2.41838071 2.07431039 +H 11.49658658 2.03804136 9.30441706 0.27983642 -3.85064522 2.72189444 +O 5.72252236 2.78391664 4.58865435 2.00202122 0.11734314 -0.87511135 +H 5.58377209 3.78623649 4.59301477 0.76167423 -0.88388396 -0.15393664 +H 6.80527185 2.57167953 4.53785334 -2.41439550 0.27077174 0.34678438 +O 3.49103498 0.47779834 8.23103401 1.10622791 -0.34709034 -2.37964446 +H 4.13096898 -0.02426013 7.57813517 -2.14205385 1.52066936 1.47370043 +H 2.91503088 1.07665866 7.70577271 0.67948116 -0.68213454 0.43600941 +O 4.20392664 5.85068908 9.77697231 -3.57610785 -0.45028264 -2.74477212 +H 3.54920743 6.04214539 8.83884695 1.55419656 -1.34864705 2.76174141 +H 4.72216576 6.61011129 9.59234238 1.78331795 1.13481345 0.09583637 +O 1.45781442 6.56317328 10.70033363 2.20226397 0.63186945 -0.26552977 +H 2.44013666 6.43590616 10.96735645 -2.54561438 0.39446191 -1.41605113 +H 1.41492990 7.19966763 9.90556238 0.60172068 -0.81249982 1.68488579 +O 7.98681873 5.23825113 3.84425018 3.13433042 -1.45225227 2.74527092 +H 9.01755010 5.30341401 3.65056603 -2.77410842 0.24112279 0.40572437 +H 8.05524134 4.91626267 4.90003810 -0.71303407 1.02736700 -3.07754501 +O 10.04494765 1.29552636 11.26724117 1.73607648 6.18659119 -7.11522264 +H 9.25266353 1.89236946 10.78695994 2.14781312 -1.66582877 1.26821777 +H 9.64420175 0.85762163 0.14036796 -3.95195702 -4.15348017 5.70846863 +O 1.65347770 0.84323330 3.65143388 1.42823302 18.52259540 -1.58570801 +H 1.30082871 1.18126641 2.78746213 0.06763907 -0.72436750 0.47733529 +H 1.56093519 11.80917156 3.71966599 -1.83660666 -17.48993707 1.07095236 +O 9.54164721 8.40767336 0.89833123 11.90981796 -1.53191537 -0.18516168 +H 9.00596112 8.92785456 1.26700370 -5.91832217 7.89755826 6.08966056 +H 8.94113691 7.98808875 0.49614438 -7.23364771 -6.68636245 -5.64295689 +O 5.09521982 8.37512896 0.81392218 4.28942474 -0.28139245 -0.78329207 +H 5.86312474 9.13169362 1.08096087 -2.32849491 -3.31964533 -0.61849971 +H 5.54545966 7.65089703 0.19278931 -0.77571760 2.85232141 2.11019787 +O 7.30703768 11.10160871 4.63375612 -0.88771490 -3.55613037 3.03358941 +H 7.63533922 10.74907085 5.55297398 -0.86701751 0.61395914 -2.54795409 +H 7.79843164 0.12321944 4.54375366 1.37331423 2.81788919 -0.43847047 +O 10.24857504 11.03916580 3.08253663 2.08557697 -0.83679675 1.42397013 +H 11.34180224 10.69340141 3.15046182 -3.53841546 0.81605308 0.79711433 +H 9.89365589 10.83162249 3.99602350 0.78547751 0.29104335 -1.77379458 +O 11.39159782 1.59286574 5.25942351 10.18846363 14.10111874 -9.35825404 +H 11.14796463 0.93367498 5.59827155 -6.44107909 -15.49337173 8.38653088 +H 0.63682773 1.17983234 4.75863136 -3.26021169 1.80275036 1.16045250 +O 7.51865565 7.18246937 11.26967539 2.85097416 -4.61661833 5.64131139 +H 7.03461725 7.13050416 10.50220968 -1.92595253 0.25077060 -3.35841759 +H 7.61030914 6.12718417 11.64062861 -0.25806174 3.92457476 -1.28720280 +O 1.75378853 9.68542466 6.78468686 -3.15344401 -0.46783763 0.39065667 +H 2.25816320 9.04199808 6.22910370 -0.18277466 0.61772838 -0.22779213 +H 0.77725020 9.49312166 6.62916167 2.16249413 0.72466575 0.07833126 +O 0.79380286 1.53270887 1.01780886 -0.44967946 0.87239112 -0.65009858 +H -0.05250973 1.65674801 0.40803849 2.46039770 0.24974010 1.38686909 +H 1.63024682 2.14220989 0.87222163 -2.73532074 -0.92775214 -1.05769575 +O 5.12024461 5.20156327 4.72964303 16.37304935 -0.02624223 3.03762604 +H 5.10836987 5.84100514 4.17073135 0.04007426 6.65016130 -5.75084043 +H 4.40036251 4.91235205 4.83038250 -16.81327383 -5.96156814 2.63335074 diff --git a/utilities/extract_last_model.py b/utilities/extract_last_model.py new file mode 100644 index 0000000..e70be0c --- /dev/null +++ b/utilities/extract_last_model.py @@ -0,0 +1,34 @@ +import torch +import argparse + +def extract_model_state_dict(input_checkpoint_path, output_model_state_dict_path): + """ + Extracts the model state dictionary from a given checkpoint and saves it to the specified output path. + + Parameters: + input_checkpoint_path (str): Path to the input checkpoint file. + output_model_state_dict_path (str): Path to save the extracted model state dictionary. + """ + # Load the checkpoint + checkpoint = torch.load(input_checkpoint_path, map_location=torch.device('cpu')) + + # Extract the model state dictionary + model_state_dict = checkpoint["model_state_dict"] + + # Save the model state dictionary + torch.save(model_state_dict, output_model_state_dict_path) + print(f"Model state dictionary has been saved to {output_model_state_dict_path}") + +def main(): + parser = argparse.ArgumentParser( + description='Extracting the state of the model at the end of fitting and exposing it as all the other model state dicts, such as "best_val_rmse_both_model_state_dict" or "best_val_mae_both_model_state_dict".' + ) + parser.add_argument('path_to_calc_folder', type=str, help='Path to the calc folder.') + + args = parser.parse_args() + input_checkpoint_path = args.path_to_calc_folder + "/checkpoint" + output_model_state_dict_path = args.path_to_calc_folder + "/last_model_state_dict" + extract_model_state_dict(input_checkpoint_path, output_model_state_dict_path) + +if __name__ == '__main__': + main()