Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions src/matkit/graspa/graspa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ase.io import read as ase_read

from matkit.types import GRASPAResult
from matkit.utils.cif import sanitize_cif_stem
from matkit.utils.template import copy_template, render_template
from matkit.utils.unitcell_calculator import calculate_cell_size

Expand Down Expand Up @@ -198,11 +199,11 @@ def setup_simulation(
cifpath = Path(cif)
if not cifpath.exists():
raise FileNotFoundError(f"CIF file does not exist: {cif}")
cifname = cifpath.stem
safe_stem = sanitize_cif_stem(cifpath.stem)

template_path = Path(__file__).parent / "files" / template_dir
copy_template(template_path, outdir)
shutil.copy(cifpath, outdir / f"{cifname}.cif")
shutil.copy(cifpath, outdir / f"{safe_stem}.cif")

# Use pre-computed cell size or read CIF
if cell_size is not None:
Expand All @@ -219,7 +220,7 @@ def setup_simulation(
"TEMPERATURE": str(temperature),
"PRESSURE": str(pressure),
"CUTOFF": str(cutoff),
"CIFFILE": cifname,
"CIFFILE": safe_stem,
"UC_X": str(uc_x),
"UC_Y": str(uc_y),
"UC_Z": str(uc_z),
Expand Down Expand Up @@ -252,10 +253,11 @@ def _setup_single_cif(
"""
atoms = ase_read(cif)
cell_size = calculate_cell_size(atoms)
safe_stem = sanitize_cif_stem(cif.stem)

entries = []
for temp, pres in product(temperatures, pressures):
sim_dir = out_path / cif.stem / f"T{temp}_P{pres:g}"
sim_dir = out_path / safe_stem / f"T{temp}_P{pres:g}"
setup_simulation(
cif=str(cif),
outpath=str(sim_dir),
Expand All @@ -267,15 +269,17 @@ def _setup_single_cif(
template_dir=template_dir,
cell_size=cell_size,
)
entries.append(
{
"sim_dir": str(sim_dir),
"cif": cif.name,
"temperature": temp,
"pressure": pres,
"adsorbates": [ad["MoleculeName"] for ad in adsorbates],
}
)
entry = {
"sim_dir": str(sim_dir),
"cif": cif.name,
"temperature": temp,
"pressure": pres,
"adsorbates": [ad["MoleculeName"] for ad in adsorbates],
}
if safe_stem != cif.stem:
entry["original_cif_stem"] = cif.stem
entry["safe_cif_stem"] = safe_stem
entries.append(entry)
return entries


Expand Down Expand Up @@ -352,4 +356,26 @@ def setup_batch(
for entry in manifest:
f.write(json.dumps(entry) + "\n")

rename_map = {
entry["original_cif_stem"]: entry["safe_cif_stem"]
for entry in manifest
if "original_cif_stem" in entry
}
if rename_map:
_write_cif_mapping(out_path, rename_map)

return manifest


def _write_cif_mapping(out_path: Path, mapping: dict[str, str]) -> None:
    """Write ``{original_stem: safe_stem}`` to ``out_path/cif_mapping.json``.

    Merges with any existing file so re-running setup_batch with
    additional CIFs is additive. When a key is present both on disk and
    in ``mapping``, the value from ``mapping`` wins.

    Args:
        out_path: Directory containing (or receiving) ``cif_mapping.json``.
            It is not created here.
        mapping: Original CIF stems mapped to their sanitized stems.

    Raises:
        json.JSONDecodeError: If an existing mapping file is not valid JSON.
    """
    mapping_path = out_path / "cif_mapping.json"
    existing: dict[str, str] = {}
    if mapping_path.exists():
        # Pin the encoding so reading does not depend on the locale default.
        existing = json.loads(mapping_path.read_text(encoding="utf-8"))
    existing.update(mapping)
    # sort_keys keeps the file stable across runs regardless of CIF order.
    mapping_path.write_text(
        json.dumps(existing, indent=2, sort_keys=True),
        encoding="utf-8",
    )
7 changes: 4 additions & 3 deletions src/matkit/graspa_sycl/graspa_sycl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ase.io import read as ase_read

from matkit.utils.cif import sanitize_cif_stem
from matkit.utils.template import copy_template, render_template
from matkit.utils.unitcell_calculator import calculate_cell_size

Expand Down Expand Up @@ -41,9 +42,9 @@ def setup_simulation(
if not cifpath.exists():
raise FileNotFoundError(f"CIF file does not exist: {cif}")

cifname = cifpath.stem
safe_stem = sanitize_cif_stem(cifpath.stem)
copy_template(_file_dir, outdir)
shutil.copy(cif, outdir)
shutil.copy(cifpath, outdir / f"{safe_stem}.cif")

atoms = ase_read(cif)
uc_x, uc_y, uc_z = calculate_cell_size(atoms)
Expand All @@ -57,7 +58,7 @@ def setup_simulation(
"PRESSURE": str(pressure),
"UC_X UC_Y UC_Z": f"{uc_x} {uc_y} {uc_z}",
"CUTOFF": str(cutoff),
"CIFFILE": cifname,
"CIFFILE": safe_stem,
},
)

Expand Down
13 changes: 13 additions & 0 deletions src/matkit/utils/cif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations


def sanitize_cif_stem(stem: str) -> str:
    """Return a gRASPA-safe CIF stem.

    gRASPA's input parser treats the first ``.`` in ``FrameworkName`` as
    the extension separator, so stems with extra periods (e.g.
    ``str_m5_o11_o18_sra_sym.22``) get truncated and the framework file
    cannot be located. Every ``.`` is replaced with ``_`` to make the
    stem safe; stems without ``.`` are returned unchanged.

    Args:
        stem: CIF file name without its ``.cif`` suffix.

    Returns:
        ``stem`` with each ``.`` replaced by ``_``.

    Note:
        The mapping is not injective: ``"a.b"`` and ``"a_b"`` both yield
        ``"a_b"``, so callers handling several CIFs may need to watch
        for colliding results.
    """
    # str.replace is already a no-op when there is no "." to replace, so
    # a separate membership test would only scan the string twice.
    return stem.replace(".", "_")
88 changes: 88 additions & 0 deletions tests/test_graspa.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Tests for matkit.graspa module."""

import json
import shutil

import pytest
from pathlib import Path

from matkit.graspa.graspa import (
generate_component_blocks,
setup_batch,
setup_simulation,
)
from matkit.utils.cif import sanitize_cif_stem


class TestGenerateComponentBlocks:
Expand Down Expand Up @@ -124,3 +129,86 @@ def test_setup_nonexistent_cif_raises(self, tmp_path):
outpath=str(tmp_path / "out"),
adsorbates=[{"MoleculeName": "CO2"}],
)

def test_setup_sanitizes_multi_period_cif(self, sample_cif, tmp_path):
    """A CIF whose stem contains a period is staged under a safe name."""
    dotted_cif = tmp_path / "foo.bar.cif"
    shutil.copy(sample_cif, dotted_cif)
    sim_dir = tmp_path / "sim_output"

    setup_simulation(
        cif=str(dotted_cif),
        outpath=str(sim_dir),
        adsorbates=[{"MoleculeName": "CO2"}],
    )

    # Only the sanitized copy may exist in the simulation directory.
    assert (sim_dir / "foo_bar.cif").exists()
    assert not (sim_dir / "foo.bar.cif").exists()

    # The rendered input must reference the sanitized stem as well.
    rendered = (sim_dir / "simulation.input").read_text()
    assert "FrameworkName foo_bar" in rendered


class TestSanitizeCifStem:
    """Tests for the sanitize_cif_stem helper."""

    def test_no_period_unchanged(self):
        assert sanitize_cif_stem("MOF5") == "MOF5"

    def test_single_internal_period_replaced(self):
        assert sanitize_cif_stem("foo.bar") == "foo_bar"

    def test_multiple_periods_replaced(self):
        # Regression case: this stem has a single period in its tail.
        assert (
            sanitize_cif_stem("str_m5_o11_o18_sra_sym.22")
            == "str_m5_o11_o18_sra_sym_22"
        )
        # A stem with more than one period, so every occurrence (not
        # just the first) is shown to be replaced.
        assert sanitize_cif_stem("a.b.c") == "a_b_c"


class TestSetupBatch:
    """Batch-setup tests covering the CIF rename mapping file."""

    def test_batch_writes_mapping_only_for_renamed(
        self, sample_cif, tmp_path
    ):
        """Only CIFs whose stem was changed appear in cif_mapping.json."""
        source_dir = tmp_path / "cifs"
        source_dir.mkdir()
        for cif_name in ("clean.cif", "weird.v2.cif"):
            shutil.copy(sample_cif, source_dir / cif_name)

        batch_root = tmp_path / "batch_out"
        entries = setup_batch(
            cif_dir=str(source_dir),
            outpath=str(batch_root),
            adsorbates=[{"MoleculeName": "CO2"}],
            temperatures=[298.0],
            pressures=[1e5],
            n_cycle=10,
        )
        assert len(entries) == 2

        mapping_file = batch_root / "cif_mapping.json"
        assert mapping_file.exists()
        recorded = json.loads(mapping_file.read_text())
        assert recorded == {"weird.v2": "weird_v2"}

        # Each framework gets its own directory holding the staged CIF.
        state_dir = "T298.0_P100000"
        for stem in ("weird_v2", "clean"):
            assert (batch_root / stem / state_dir / f"{stem}.cif").exists()

    def test_batch_no_mapping_when_all_clean(self, sample_cif, tmp_path):
        """No cif_mapping.json is produced when nothing was renamed."""
        source_dir = tmp_path / "cifs"
        source_dir.mkdir()
        shutil.copy(sample_cif, source_dir / "clean.cif")

        batch_root = tmp_path / "batch_out"
        setup_batch(
            cif_dir=str(source_dir),
            outpath=str(batch_root),
            adsorbates=[{"MoleculeName": "CO2"}],
            temperatures=[298.0],
            pressures=[1e5],
            n_cycle=10,
        )

        mapping_file = batch_root / "cif_mapping.json"
        assert not mapping_file.exists()
Loading