diff --git a/pcntoolkit/math_functions/thrive.py b/pcntoolkit/math_functions/thrive.py index 81e3d46c..61dd1183 100644 --- a/pcntoolkit/math_functions/thrive.py +++ b/pcntoolkit/math_functions/thrive.py @@ -112,6 +112,14 @@ def get_correlation_matrix(data: NormData, bandwidth: int, covariate_name="age") """ df = data.to_dataframe()[["X", "Z", "batch_effects", "subject_ids"]].droplevel(level=0, axis=1) + + if not df["subject_ids"].duplicated().any(): + raise ValueError( + "Cannot compute correlation matrix: The dataset is cross-sectional. " + "Computing a correlation matrix (e.g., for thrivelines) requires longitudinal data " + "(multiple observations per subject_id at different ages)." + ) + # create dictionary of (age:indices) grps = df.groupby(covariate_name).indices | defaultdict(list) # get the max age in the dataset @@ -120,6 +128,7 @@ def get_correlation_matrix(data: NormData, bandwidth: int, covariate_name="age") n_responsevars = len(data.response_vars.to_numpy()) # create empty correlation matrix cors = np.tile(np.eye(max_age + 1), (n_responsevars, 1, 1)) + for age1, age2 in offset_indices(max_age, bandwidth): # merge two ages on subjects merged = pd.merge(df.iloc[grps[age1]], df.iloc[grps[age2]], how="inner", on="subject_ids") @@ -129,7 +138,8 @@ def get_correlation_matrix(data: NormData, bandwidth: int, covariate_name="age") cors[i, age2, age1] = cors[i, age1, age2] = merged[f"{rv}_x"].corr(merged[f"{rv}_y"]) elif age1 != age2: # Otherwise, set all response variables to NaN for these ages - cors[:, age2, age1] = cors[:, age1, age2] = np.NaN + cors[:, age2, age1] = cors[:, age1, age2] = np.nan + # Fill in missing correlation values newcors = fill_missing(bandwidth, cors) newcors = xr.DataArray(