diff --git a/pcntoolkit/normative_model.py b/pcntoolkit/normative_model.py index 36599c50..05f16559 100644 --- a/pcntoolkit/normative_model.py +++ b/pcntoolkit/normative_model.py @@ -54,6 +54,13 @@ class NormativeModel: Input (X/covariates) scaler to use. outscaler: str Output (Y/response_vars) scaler to use. + y_transform : str or None + Optional transform applied to Y before fitting and inverted + after prediction. Currently supported: + - ``"log1p"`` applies log(Y+1) + - ``"log"`` applies natural log(Y) + This is useful for phenotypes that cannot be negative. + Default is ``None`` (no transform). name: str Name of the model """ @@ -68,6 +75,7 @@ def __init__( save_dir: Optional[str] = None, inscaler: str = "standardize", outscaler: str = "standardize", + y_transform: Optional[str] = None, name: Optional[str] = None, ): self.savemodel: bool = savemodel @@ -77,6 +85,7 @@ def __init__( self._save_dir = save_dir if save_dir is not None else get_default_save_dir() self.inscaler: str = inscaler self.outscaler: str = outscaler + self.y_transform: Optional[str] = y_transform self.name: Optional[str] = name self.response_vars: list[str] = None # type: ignore self.template_regression_model: RegressionModel = template_regression_model @@ -181,6 +190,7 @@ def transfer(self, transfer_data: NormData, save_dir: str | None = None, **kwarg saveplots=True, inscaler=self.inscaler, outscaler=self.outscaler, + y_transform=self.y_transform, save_dir=self.save_dir, ) if save_dir is not None: @@ -240,6 +250,7 @@ def extend(self, data: NormData, save_dir: str | None = None, n_synth_samples: i saveplots=True, inscaler=self.inscaler, outscaler=self.outscaler, + y_transform=self.y_transform, save_dir=save_dir, ) @@ -460,6 +471,7 @@ def load(cls, path: str, into: NormativeModel | None = None) -> NormativeModel: outscaler = metadata["outscaler"] saveplots = metadata["saveplots"] evaluate_model = metadata["evaluate_model"] + y_transform = metadata.get("y_transform", None) name = 
metadata["name"] response_vars = [] @@ -492,6 +504,7 @@ def load(cls, path: str, into: NormativeModel | None = None) -> NormativeModel: save_dir=save_dir, inscaler=inscaler, outscaler=outscaler, + y_transform=y_transform, name=name, ) else: @@ -549,9 +562,13 @@ def preprocess(self, data: NormData) -> None: """ Applies preprocessing transformations to the input data. + First applies an optional response transform (e.g. log1p), then scales. + Args: data (NormData): Data to preprocess. """ + # Enforce positivity if necessary + self._apply_y_transform(data) self.scale_forward(data) def scale_forward(self, data: NormData, overwrite: bool = False) -> None: @@ -586,10 +603,15 @@ def scale_forward(self, data: NormData, overwrite: bool = False) -> None: def postprocess(self, data: NormData) -> None: """Apply postprocessing to the data. + First unscales, then applies the inverse response transform (e.g. expm1). + Args: data (NormData): Data to postprocess. """ self.scale_backward(data) + # Invert Y to its original space if positivity was enforced during + # preprocessing + self._invert_y_transform(data) def scale_backward(self, data: NormData) -> None: """ @@ -606,6 +628,80 @@ def scale_backward(self, data: NormData) -> None: """ data.scale_backward(self.inscalers, self.outscalers) + def _apply_y_transform(self, data: NormData) -> None: + """ + Apply the forward response transform (e.g. log1p) to Y-like variables + in the data. + + Parameters + ---------- + data : NormData + Data object containing response variable arrays (Y, Yhat, + centiles, thrive_Y) to which the transform should be applied. + + """ + if self.y_transform is None: + return + + # TODO: Check if we need to track if transform has already been + # applied to avoid double-inverting. Normally I dont expect any issues + # as every process() is followed by a postprocess(). 
The only issues can + # be if users call postprocess() multiple times manually or with + # compute_thrivelines() that has a preprocess() call without a postprocess(). + + if self.y_transform == "log1p": + # Apply log1p transform to the response variable Y + for var in ["Y"]: + if (data[var] <= -1).any(): + raise ValueError("Cannot apply log1p transform to variable " + f"'{var}' because it contains values less " + "than or equal to -1." + ) + else: + data[var] = np.log1p(data[var]) + + elif self.y_transform == "log": + # Apply natural log transform to the response variable Y + for var in ["Y"]: + if (data[var] <= 0).any(): + raise ValueError( + f"Cannot apply log transform to variable '{var}' " + "because it contains non-positive values. " + "Consider using 'log1p' transform or ensuring " + "all values are positive." + ) + else: + data[var] = np.log(data[var]) + + def _invert_y_transform(self, data: NormData) -> None: + """ + Apply the inverse response transform (e.g. expm1) to Y-like variables + in the data. + + Parameters + ---------- + data : NormData + Data object containing response variable arrays (Y, Yhat, + centiles, thrive_Y) to which the inverse transform should be applied. + """ + if self.y_transform is None: + return + + # TODO: Check if we need to track if inverse transform has already been + # applied to avoid double-inverting. Normally I don't expect any issues + # as every process() is followed by a postprocess(). The only issues can + # be if users call postprocess() multiple times manually or with + # compute_thrivelines() that has a preprocess() call without a postprocess(). 
+ + if self.y_transform == "log1p": + for var in ("Y", "centiles", "Yhat", "Y_harmonized", "thrive_Y"): + if var in data.data_vars: + data[var] = np.expm1(data[var]) + elif self.y_transform == "log": + for var in ("Y", "centiles", "Yhat", "Y_harmonized", "thrive_Y"): + if var in data.data_vars: + data[var] = np.exp(data[var]) + def evaluate(self, data: NormData) -> None: """ Evaluates the model performance on the data. @@ -991,6 +1087,7 @@ def to_dict(self): "is_fitted": self.is_fitted, "inscaler": self.inscaler, "outscaler": self.outscaler, + "y_transform": self.y_transform, "ptk_version": importlib.metadata.version("pcntoolkit"), } @@ -1050,6 +1147,7 @@ def from_args(cls, **kwargs) -> NormativeModel: inscaler = kwargs.get("inscaler", "none") outscaler = kwargs.get("outscaler", "none") name = kwargs.get("name", None) + y_transform = kwargs.get("y_transform", None) assert "alg" in kwargs, "Algorithm must be specified" if kwargs["alg"] == "blr": template_regression_model = BLR.from_args("template", kwargs) @@ -1068,6 +1166,7 @@ def from_args(cls, **kwargs) -> NormativeModel: save_dir=save_dir, inscaler=inscaler, outscaler=outscaler, + y_transform=y_transform, name=name, ) diff --git a/pcntoolkit/regression_model/blr.py b/pcntoolkit/regression_model/blr.py index 306fb702..d5a9a197 100644 --- a/pcntoolkit/regression_model/blr.py +++ b/pcntoolkit/regression_model/blr.py @@ -347,6 +347,7 @@ def backward(self, X: xr.DataArray, be: xr.DataArray, Z: xr.DataArray) -> xr.Dat self.ys[mask] = self.ys[mask] + residual_mean self.s2[mask] = np.square(np.sqrt(self.s2[mask]) * correction_factor) + # Compute the centiles in the original Y space: centiles = Z * std + mean centiles = np_Z * np.sqrt(self.s2) + self.ys if self.warp: centiles = self.warp.invf(centiles, self.gamma) diff --git a/pcntoolkit/util/plotter.py b/pcntoolkit/util/plotter.py index 2a675474..d9f2ccc9 100644 --- a/pcntoolkit/util/plotter.py +++ b/pcntoolkit/util/plotter.py @@ -101,9 +101,10 @@ def plot_centiles( # 
Batch effects are the first ones in the highlighted batch effects for be, v in batch_effects.items(): centile_df[be] = v - # Response vars are all 0, we don't need them + # Assign random values for response vars because they are not needed. + # They must be > 0 to satisfy later checks that require response_vars > 0. for rv in response_vars: - centile_df[rv] = 0 + centile_df[rv] = 1e-6 centile_data = NormData.from_dataframe( "centile", @@ -352,9 +353,10 @@ def plot_centiles_advanced( # Batch effects are the first ones in the highlighted batch effects for be, v in batch_effects.items(): centile_df[be] = v[0] - # Response vars are all 0, we don't need them + # Assign random values for response vars because they are not needed. + # They must be > 0 to satisfy later checks that require response_vars > 0. for rv in model.response_vars: - centile_df[rv] = 0 + centile_df[rv] = 1e-6 centile_data = NormData.from_dataframe( "centile", dataframe=centile_df, diff --git a/test/fixtures/blr_model_fixtures.py b/test/fixtures/blr_model_fixtures.py index 322e6bae..c6dfebdf 100644 --- a/test/fixtures/blr_model_fixtures.py +++ b/test/fixtures/blr_model_fixtures.py @@ -1,3 +1,5 @@ +from math import log + import pytest from typing import Any, Callable @@ -11,6 +13,7 @@ from pcntoolkit.regression_model.blr import BLR from test.fixtures.norm_data_fixtures import * from test.fixtures.path_fixtures import * +import os # Default keyword arguments shared by all BLR tests. 
BLR_BASE_CONFIG: dict[str, Any] = { @@ -84,7 +87,7 @@ def fitted_blr_model( @pytest.fixture -def new_norm_blr_model( +def norm_blr_model( blr_model_factory: Callable, save_dir_blr ) -> NormativeModel: @@ -102,12 +105,80 @@ def new_norm_blr_model( @pytest.fixture -def fitted_norm_blr_model(new_norm_blr_model: NormativeModel, +def fitted_norm_blr_model(norm_blr_model: NormativeModel, norm_data_from_arrays: NormData ) -> NormativeModel: print("removing items") - if os.path.exists(new_norm_blr_model.save_dir): - shutil.rmtree(new_norm_blr_model.save_dir) - os.makedirs(new_norm_blr_model.save_dir, exist_ok=True) - new_norm_blr_model.fit(norm_data_from_arrays) - return new_norm_blr_model + if os.path.exists(norm_blr_model.save_dir): + shutil.rmtree(norm_blr_model.save_dir) + os.makedirs(norm_blr_model.save_dir, exist_ok=True) + norm_blr_model.fit(norm_data_from_arrays) + return norm_blr_model + + +@pytest.fixture +def log1p_transform_norm_blr_model( + save_dir_test_model: str +) -> NormativeModel: + """Create a NormativeModel using BLR with log1p. + + Returns + ------- + NormativeModel + Un-fitted normative model with log1p transform. + """ + # Build a fresh save directory + log_dir = os.path.join(save_dir_test_model, "log1p") + if os.path.exists(log_dir): + shutil.rmtree(log_dir) + os.makedirs(log_dir, exist_ok=True) + # Create test regression model + blr_model = BLR("test_model_log1p") + # Return a NormativeModel with the log1p transform + return NormativeModel( + template_regression_model=blr_model, + savemodel=False, + saveresults=False, + evaluate_model=False, + saveplots=False, + save_dir=log_dir, + inscaler="standardize", + outscaler="standardize", + name="test_model_log1p", + y_transform="log1p", + ) + + +@pytest.fixture +def log_transform_norm_blr_model( + save_dir_test_model: str, +) -> NormativeModel: + """Create a NormativeModel using BLR with natural log y_transform. 
+ + Returns + ------- + NormativeModel + Un-fitted normative model with natural-log transform. + """ + # Build a fresh save directory for this fixture + log_dir = os.path.join(save_dir_test_model, "log") + if os.path.exists(log_dir): + # Remove stale directory from previous test runs + shutil.rmtree(log_dir) + os.makedirs(log_dir, exist_ok=True) + # Create a BLR regression model for the natural-log transform test + blr_model = BLR("test_model_log") + # Return NormativeModel with natural-log transform enabled + return NormativeModel( + template_regression_model=blr_model, + savemodel=False, + saveresults=False, + evaluate_model=False, + saveplots=False, + save_dir=log_dir, + inscaler="standardize", + outscaler="standardize", + name="test_model_log", + y_transform="log", + ) + diff --git a/test/fixtures/test_model_fixtures.py b/test/fixtures/test_model_fixtures.py index 4a7f26f3..63a6bd9b 100644 --- a/test/fixtures/test_model_fixtures.py +++ b/test/fixtures/test_model_fixtures.py @@ -1,3 +1,6 @@ +import os +import shutil + import pytest from pcntoolkit.dataio.norm_data import NormData diff --git a/test/test_core/test_normative_model.py b/test/test_core/test_normative_model_main.py similarity index 100% rename from test/test_core/test_normative_model.py rename to test/test_core/test_normative_model_main.py diff --git a/test/test_norm/test_normative_model.py b/test/test_norm/test_normative_model_helper.py similarity index 84% rename from test/test_norm/test_normative_model.py rename to test/test_norm/test_normative_model_helper.py index 2e50d191..6e5a5549 100644 --- a/test/test_norm/test_normative_model.py +++ b/test/test_norm/test_normative_model_helper.py @@ -9,6 +9,7 @@ from pcntoolkit.normative_model import NormativeModel from pcntoolkit.regression_model.blr import * from pcntoolkit.regression_model.hbr import * +from test.fixtures.blr_model_fixtures import * from test.fixtures.data_fixtures import * from test.fixtures.norm_data_fixtures import * from 
test.fixtures.path_fixtures import * @@ -294,3 +295,70 @@ def test_blr_model_to_and_from_dict_and_args(blr_model_args: dict, norm_data_fro assert model1.l_bfgs_b_norm == "l2" assert model1.fixed_effect assert not model1.fixed_effect_var + + +def test_log1p_transform( + log1p_transform_norm_blr_model: NormativeModel, + norm_data_from_arrays: NormData, + test_norm_data_from_arrays: NormData, +) -> None: + log1p_transform_norm_blr_model.fit_predict( + norm_data_from_arrays, test_norm_data_from_arrays + ) + + # Check that Y are non-negative in the original data + assert bool( + np.all(test_norm_data_from_arrays["Y"].values >= 0) + ) + + # Both training and test centiles should be bigger than -1 due to the + # exp(Y) - 1 transform + assert bool( + np.all(norm_data_from_arrays["centiles"].values > -1) + ) + assert bool( + np.all(test_norm_data_from_arrays["centiles"].values > -1) + ) + + # We don't expect any negative yhat values in the train and test dataset + assert bool( + np.all(norm_data_from_arrays["Yhat"].values >= 0) + ) + assert bool( + np.all(test_norm_data_from_arrays["Yhat"].values >= 0) + ) + + +def test_log_transformed( + log_transform_norm_blr_model: NormativeModel, + norm_data_from_arrays: NormData, + test_norm_data_from_arrays: NormData, +) -> None: + # Force the data to be >= 1e-6 (clip returns a copy, so write it back) + test_norm_data_from_arrays["Y"].values[:] = test_norm_data_from_arrays["Y"].values.clip(min=1e-6) + + log_transform_norm_blr_model.fit_predict( + norm_data_from_arrays, test_norm_data_from_arrays + ) + + # Check that Y are positive in the original data + assert bool( + np.all(test_norm_data_from_arrays["Y"].values > 0) + ) + + # Both training and test centiles should be bigger than 0 due to the exp(Y) + # transform + assert bool( + np.all(norm_data_from_arrays["centiles"].values > 0) + ) + assert bool( + np.all(test_norm_data_from_arrays["centiles"].values > 0) + ) + + # We don't expect any negative yhat values in the train and test dataset + assert bool( + np.all(norm_data_from_arrays["Yhat"].values > 0) + ) + assert 
bool( + np.all(test_norm_data_from_arrays["Yhat"].values > 0) + )