Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pcntoolkit/math_functions/thrive.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -112,6 +112,14 @@ def get_correlation_matrix(data: NormData, bandwidth: int, covariate_name="age")
"""

df = data.to_dataframe()[["X", "Z", "batch_effects", "subject_ids"]].droplevel(level=0, axis=1)

if not df["subject_ids"].duplicated().any():
raise ValueError(
"Cannot compute correlation matrix: The dataset is cross-sectional. "
"Computing a correlation matrix (e.g., for thrivelines) requires longitudinal data "
"(multiple observations per subject_id at different ages)."
)

# create dictionary of (age:indices)
grps = df.groupby(covariate_name).indices | defaultdict(list)
# get the max age in the dataset
Expand All @@ -120,6 +128,7 @@ def get_correlation_matrix(data: NormData, bandwidth: int, covariate_name="age")
n_responsevars = len(data.response_vars.to_numpy())
# create empty correlation matrix
cors = np.tile(np.eye(max_age + 1), (n_responsevars, 1, 1))

for age1, age2 in offset_indices(max_age, bandwidth):
# merge two ages on subjects
merged = pd.merge(df.iloc[grps[age1]], df.iloc[grps[age2]], how="inner", on="subject_ids")
Expand All @@ -129,7 +138,8 @@ def get_correlation_matrix(data: NormData, bandwidth: int, covariate_name="age")
cors[i, age2, age1] = cors[i, age1, age2] = merged[f"{rv}_x"].corr(merged[f"{rv}_y"])
elif age1 != age2:
# Otherwise, set all response variables to NaN for these ages
cors[:, age2, age1] = cors[:, age1, age2] = np.NaN
cors[:, age2, age1] = cors[:, age1, age2] = np.nan

# Fill in missing correlation values
newcors = fill_missing(bandwidth, cors)
newcors = xr.DataArray(
Expand Down
Loading