
[CHORE] Ensure shape is handled consistently across package #509

@smcolby

Description


Confirm that every place where input shapes are checked and coerced handles the (N,), (N, 1), and (N, M) cases in a functionally equivalent way. This matters because downstream functions can be opinionated in different ways (e.g. accepting (N,) but not (N, 1)). The implementations do not have to be literally identical, but every pattern should produce the same outcome.

1. openadmet/models/eval/regression.py — Lines 99–102 & 383–386

Pattern: (N,) → (N, 1) for both y_pred and y_true.

```python
if y_pred.ndim == 1:
    y_pred = y_pred.reshape(-1, 1)
if y_true.ndim == 1:
    y_true = y_true.reshape(-1, 1)
```

What it does: Appears in both RegressionMetrics.evaluate() and RegressionPlotter.evaluate(). Coerces any flat (N,) array into (N, 1) so that shape[1] can be read as the number of tasks; the downstream per-task loop over y_true[:, task_id] requires 2D arrays. The two arguments are coerced independently rather than checked against each other, so a mixed input (y_true as (N, 1), y_pred as (N,)) still ends up as two (N, 1) arrays. No coercion is applied to (N, M) beyond reading shape[1].
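A minimal sketch of the independent per-argument coercion, assuming NumPy inputs. The helper names `as_column` and `per_task_mae` are hypothetical, and mean absolute error stands in for whatever per-task metrics the package actually computes.

```python
import numpy as np


def as_column(arr: np.ndarray) -> np.ndarray:
    """Promote a flat (N,) array to (N, 1); leave 2D arrays untouched."""
    # Hypothetical helper mirroring the `ndim == 1` check in evaluate().
    return arr.reshape(-1, 1) if arr.ndim == 1 else arr


def per_task_mae(y_true, y_pred) -> list:
    # Each argument is coerced on its own; there is no joint shape check.
    y_true = as_column(np.asarray(y_true))
    y_pred = as_column(np.asarray(y_pred))
    n_tasks = y_true.shape[1]  # shape[1] is read as the number of tasks
    return [
        float(np.mean(np.abs(y_true[:, task_id] - y_pred[:, task_id])))
        for task_id in range(n_tasks)
    ]


# Mixed case: y_true is (N, 1), y_pred is (N,); both end up (N, 1).
print(per_task_mae(np.array([[1.0], [2.0], [3.0]]), np.array([1.0, 2.0, 5.0])))
# [0.6666666666666666]
```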

2. openadmet/models/eval/uncertainty.py — Lines 125–130 & 325–330

Pattern: (N,) → (N, 1) for y_pred, y_true, and y_std.

```python
if y_pred.ndim == 1:
    y_pred = y_pred.reshape(-1, 1)
if y_true.ndim == 1:
    y_true = y_true.reshape(-1, 1)
if y_std.ndim == 1:
    y_std = y_std.reshape(-1, 1)
```

What it does: Appears in both UncertaintyMetrics.evaluate() and UncertaintyPlotter.evaluate(). Same pattern as regression — coerces all three arrays to (N, 1) before iterating over tasks. After coercion, each task column is extracted and immediately .flatten()'d back to (N,) before passing to metric functions (lines 146–148, 346–348).

3. openadmet/models/eval/classification.py — Lines 114, 117–135, 324–329, 374–382

Pattern: Detects (N,) or (N, 1) as the binary case; (N, K) with K > 1 as the multiclass case.

```python
if (y_true.ndim == 1) or (y_true.ndim == 2 and y_true.shape[1] == 1):
    # binary: ravel y_true; slice y_pred[:, 1] or take argmax
    ...  # branch body elided
else:
    # multiclass: argmax or full ravel
    ...  # branch body elided
```

What it does: Rather than coercing to a canonical shape, this branches on the input shape to select the correct metric computation path. For binary, y_true is always .ravel()'d to (N,). For multiclass, np.argmax or full .ravel() is applied. This same branching pattern appears in ClassificationMetrics.evaluate(), ClassificationMetrics.roc_curve_plot(), and ClassificationMetrics.pr_curve_plot().
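The branch can be sketched as a small function that reduces both arrays to what typical metric functions take. This assumes y_pred holds per-class probabilities, (N, 2) in the binary case and (N, K) in the multiclass case, and that multiclass y_true is one-hot (N, K). The function name is hypothetical, and the package may use `.ravel()` instead of argmax in some paths.

```python
import numpy as np


def labels_and_scores(y_true, y_pred):
    """Branch on y_true's shape instead of coercing to one canonical shape.

    Hypothetical sketch: assumes y_pred holds per-class probabilities.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    is_binary = y_true.ndim == 1 or (y_true.ndim == 2 and y_true.shape[1] == 1)
    if is_binary:
        # Binary: flatten labels to (N,) and keep the positive-class column.
        return y_true.ravel(), y_pred[:, 1]
    # Multiclass: one-hot (N, K) labels and (N, K) probabilities -> class indices.
    return np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1)


probs = np.array([[0.8, 0.2], [0.3, 0.7]])
print(labels_and_scores(np.array([[0], [1]]), probs))
# (array([0, 1]), array([0.2, 0.7]))
```

Both (N,) and (N, 1) labels take the same binary path, so the two input shapes give identical outputs.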

4. openadmet/models/architecture/dummy.py — Lines 40–41 & 64–65

Pattern (train): (N, 1) → (N,) via .ravel().

```python
y_arr = np.asarray(y)
if y_arr.ndim == 2 and y_arr.shape[1] == 1:
    y_arr = y_arr.ravel()
```

Pattern (predict): (N,) → (N, 1) via np.expand_dims.

```python
pred = self.estimator.predict(X)
if pred.ndim == 1:
    pred = np.expand_dims(pred, axis=1)
```

What it does: The sklearn DummyClassifier/DummyRegressor rejects (N, 1) targets during training, so the fit() wrapper squeezes to (N,). Conversely, the predict() wrapper expands the resulting (N,) back out to (N, 1) so the output is consistent with the rest of the pipeline.
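The fit/predict asymmetry can be sketched with a stand-in estimator so that the shape round trip is visible without sklearn. `MeanEstimator` and `ShapeSafeWrapper` are hypothetical names; the stand-in mimics only the 1D-target, 1D-prediction contract the wrapper is assumed to guard.

```python
import numpy as np


class MeanEstimator:
    """Stand-in for a 1D-target estimator: predicts the training mean."""

    def fit(self, X, y):
        if np.asarray(y).ndim != 1:
            raise ValueError("expected y of shape (N,)")
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)  # returns (N,)


class ShapeSafeWrapper:
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        y_arr = np.asarray(y)
        # (N, 1) -> (N,) before handing off to the estimator.
        if y_arr.ndim == 2 and y_arr.shape[1] == 1:
            y_arr = y_arr.ravel()
        self.estimator.fit(X, y_arr)
        return self

    def predict(self, X):
        pred = self.estimator.predict(X)
        # (N,) -> (N, 1) so the output matches the rest of the pipeline.
        if pred.ndim == 1:
            pred = np.expand_dims(pred, axis=1)
        return pred


X = np.zeros((3, 4))
model = ShapeSafeWrapper(MeanEstimator()).fit(X, np.array([[1.0], [2.0], [3.0]]))
print(model.predict(X).shape)  # (3, 1)
```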

5. openadmet/models/features/chemprop.py — Line 160

Pattern: (N,) → (N, 1).

```python
y = y.reshape(-1, 1) if y.ndim == 1 else y
```

What it does: ChemProp's MoleculeDatapoint expects a per-sample target array, so y must be (N, 1) (or (N, M)). A flat (N,) is reshaped before the dataset is constructed; input that is already (N, 1) or (N, M) passes through unchanged.

6. openadmet/models/features/mtenn.py — Line 275

Pattern: Unconditional (N,) or (N, 1) → (N, 1).

```python
y = y.reshape(-1, 1)
```

What it does: Unlike ChemProp's conditional check, this reshape is unconditional and is applied to any y that is not None. An (N, M) multi-task input would therefore be silently corrupted into (N*M, 1); there is no guard for the multi-task case here.
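The difference between the two reshapes on a multi-task target can be shown directly. The guarded variant at the end is hypothetical: it raises instead of corrupting (N, M) input, which is only one option, since whether the MTENN path should support multi-task targets is not settled here.

```python
import numpy as np

y_multi = np.arange(6.0).reshape(3, 2)  # (N, M) = (3, 2): three samples, two tasks

# Unconditional reshape (as in features/mtenn.py): tasks are folded into rows.
unconditional = y_multi.reshape(-1, 1)
print(unconditional.shape)  # (6, 1)

# Conditional reshape (as in features/chemprop.py): only flat input is touched.
conditional = y_multi.reshape(-1, 1) if y_multi.ndim == 1 else y_multi
print(conditional.shape)  # (3, 2)


def reshape_single_task(y: np.ndarray) -> np.ndarray:
    """Hypothetical guarded variant: (N,) or (N, 1) -> (N, 1); anything else raises."""
    if not (y.ndim == 1 or (y.ndim == 2 and y.shape[1] == 1)):
        raise ValueError(f"expected a single task, got shape {y.shape}")
    return y.reshape(-1, 1)
```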

7. openadmet/models/features/combine.py — Lines 110–112

Pattern: (F,) → (1, F) (single-sample case).

```python
feats = [
    feat.reshape(1, -1) if len(feat.shape) == 1 else feat for feat in feats
]
```

What it does: Treats a 1D feature array as a single sample and promotes it to (1, F). This is specifically for combining the outputs of multiple featurizers: if any individual featurizer returns a flat array (e.g., for a single molecule), it is made 2D so that concatenation works. Note the unusual direction: (F,) → (1, F) rather than (F,) → (F, 1).
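A small sketch of why the direction matters, assuming the featurizer outputs are concatenated along the feature axis (axis=1); the two arrays below are hypothetical single-molecule outputs, and `feat.ndim == 1` is used as an equivalent of `len(feat.shape) == 1`.

```python
import numpy as np

# Hypothetical single-molecule outputs from two featurizers: one flat, one 2D.
fingerprint = np.ones(4)        # (F1,) = (4,)
descriptors = np.zeros((1, 3))  # (1, F2) = (1, 3)

feats = [fingerprint, descriptors]
# Promote any flat array to a single-sample row: (F,) -> (1, F).
feats = [feat.reshape(1, -1) if feat.ndim == 1 else feat for feat in feats]
combined = np.concatenate(feats, axis=1)
print(combined.shape)  # (1, 7)

# Reshaping to (F, 1) instead would break the concatenation:
# np.concatenate([np.ones((4, 1)), descriptors], axis=1) raises ValueError.
```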

8. openadmet/models/features/molfeat_fingerprint.py & molfeat_properties.py — Lines 79 & 88

Pattern: Extra leading dimension squeezed away.

```python
return np.squeeze(feat), indices
```

What it does: Both featurizers call datamol's MoleculeTransformer, which returns an array with an extra leading dimension (e.g., (1, N, F) or similar). np.squeeze removes all size-1 dimensions unconditionally, so it also drops a genuine size-1 axis: a single-molecule input (N = 1) would come back as (F,) rather than (1, F). The comment in both files reads "datamol returns with an extra dimension."
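A sketch of the single-molecule pitfall, assuming the transformer output is (1, N, F) as suggested above (the exact shape is not confirmed here). Passing `axis=0` to np.squeeze removes only the extra leading axis; note that it raises a ValueError if that axis is not size 1, which would also surface an unexpected shape early.

```python
import numpy as np

batch = np.zeros((1, 5, 8))   # assumed datamol-style output: (1, N, F), N = 5
print(np.squeeze(batch).shape)           # (5, 8): the intended result

single = np.zeros((1, 1, 8))  # same layout with a single molecule, N = 1
print(np.squeeze(single).shape)          # (8,): the sample axis is lost too
print(np.squeeze(single, axis=0).shape)  # (1, 8): only the extra axis removed
```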

9. openadmet/models/split/cluster.py — Line 101

Pattern: Extra dimensions squeezed.

```python
fp_list = list(np.squeeze(feat))
```

What it does: Same issue as the molfeat featurizers — datamol's transformer output has an extra dimension, so np.squeeze is applied before calling np.stack to build the fingerprint matrix for KMeans.

10. openadmet/models/features/pairwise.py — Line 166

Pattern: DataFrame → (N,) via .values.ravel().

```python
y = y.values.ravel()
```

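What it does: When y is a single-column pd.DataFrame, .values gives (N, 1), and .ravel() flattens it to (N,) for the PairwiseAugmentedDataset. A multi-column (N, M) frame would be flattened row by row into (N*M,) with no guard. The pd.Series path (line 168) does not ravel; its .values is already (N,).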
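A quick illustration of the DataFrame and Series paths, including what happens to a multi-column frame. The column names are hypothetical.

```python
import numpy as np
import pandas as pd

df_one = pd.DataFrame({"pIC50": [5.1, 6.2, 7.3]})  # hypothetical single-task frame
print(df_one.values.shape)          # (3, 1)
print(df_one.values.ravel().shape)  # (3,)

series = df_one["pIC50"]
print(series.values.shape)          # (3,): no ravel needed

# A multi-column frame is flattened row by row into (N * M,), interleaving tasks.
df_two = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
print(df_two.values.ravel())        # [1. 3. 2. 4.]
```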

11. openadmet/models/architecture/nepare.py — Line 175

Pattern: Tensor (N,) → (N, 1) via .unsqueeze(1).

```python
if y.dim() == 1:
    y = y.unsqueeze(1)  # Ensure y is [batch_size, 1]
```

What it does: Inside the Lightning _step method for NEPARE, the target batch from the dataloader may arrive as (batch_size,), while y_hat is (batch_size, 1). torch.nn.functional.mse_loss does not raise on that mismatch: it emits a warning and broadcasts the two to (batch_size, batch_size), which yields a wrong loss. The unsqueeze guards against this.
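The broadcasting failure can be reproduced with NumPy as a stand-in, since PyTorch tensors follow the same broadcasting rules; `y[:, None]` plays the role of `y.unsqueeze(1)`. This is an illustration of the shape problem, not the NEPARE code itself.

```python
import numpy as np

y_hat = np.array([[1.0], [2.0], [3.0]])  # model output, (batch_size, 1)
y = np.array([1.0, 2.0, 3.0])            # target as delivered, (batch_size,)

# Without the guard, (B, 1) - (B,) broadcasts to (B, B).
diff = y_hat - y
print(diff.shape)          # (3, 3)
print(np.mean(diff ** 2))  # nonzero, although every prediction is exact

# With the guard (NumPy analogue of y.unsqueeze(1)), the shapes match.
if y.ndim == 1:
    y = y[:, None]
print(np.mean((y_hat - y) ** 2))  # 0.0
```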

12. openadmet/models/architecture/mtenn.py — Lines 103 & 129

Pattern (train step): Scalar target → (1,) via .unsqueeze(0).

```python
loss = self.loss_fn(pred, target.unsqueeze(0).to(self.device))
```

Pattern (predict step): Per-sample (1,) prediction → (1, 1) → (N, 1) via .unsqueeze(1) + torch.cat.

```python
preds = [self(data).unsqueeze(1) for data in data_batch]
return torch.cat(preds)
```

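What it does: MTENN operates on individual molecular complexes, so pred and target are produced one sample at a time as scalars or (1,) tensors. In the training step, .unsqueeze(0) on the target aligns it with the prediction shape for the loss computation. In predict_step, .unsqueeze(1) promotes each per-sample prediction to (1, 1) before the batch is concatenated into (N, 1). This requires each prediction to already be a (1,) tensor: .unsqueeze(1) on a 0-d (true scalar) tensor raises a dimension-out-of-range error.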
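A NumPy analogue of the predict-step assembly, with `np.expand_dims(p, axis=1)` standing in for `.unsqueeze(1)` and `np.concatenate` for `torch.cat`. The per-sample values are hypothetical. It also shows that a true 0-d output cannot take this path.

```python
import numpy as np

# Hypothetical per-sample outputs: each forward pass yields a (1,) array.
per_sample = [np.array([0.5]), np.array([1.5]), np.array([2.5])]

# (1,) -> (1, 1) per sample, then concatenate along axis 0 -> (N, 1).
preds = np.concatenate([np.expand_dims(p, axis=1) for p in per_sample])
print(preds.shape)  # (3, 1)

# A 0-d output has no axis 1 to insert, so the same step fails.
try:
    np.expand_dims(np.array(0.5), axis=1)
    zero_d_ok = True
except ValueError:  # NumPy's AxisError subclasses ValueError
    zero_d_ok = False
print(zero_d_ok)  # False
```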

Summary Table

| File | Direction | Method | Condition |
|---|---|---|---|
| eval/regression.py | (N,) → (N, 1) | .reshape(-1, 1) | ndim == 1 |
| eval/uncertainty.py | (N,) → (N, 1) | .reshape(-1, 1) | ndim == 1 |
| eval/classification.py | (N,) or (N, 1) → (N,) | .ravel() | ndim == 1 or (ndim == 2 and shape[1] == 1) (branch, not coerce) |
| architecture/dummy.py (fit) | (N, 1) → (N,) | .ravel() | ndim == 2 and shape[1] == 1 |
| architecture/dummy.py (predict) | (N,) → (N, 1) | np.expand_dims(..., 1) | ndim == 1 |
| features/chemprop.py | (N,) → (N, 1) | .reshape(-1, 1) | ndim == 1 |
| features/mtenn.py | (N,) or (N, 1) → (N, 1) | .reshape(-1, 1) | unconditional |
| features/combine.py | (F,) → (1, F) | .reshape(1, -1) | len(shape) == 1 |
| features/molfeat_fingerprint.py | extra dim → standard 2D | np.squeeze | unconditional |
| features/molfeat_properties.py | extra dim → standard 2D | np.squeeze | unconditional |
| split/cluster.py | extra dim → standard 2D | np.squeeze | unconditional |
| features/pairwise.py | (N, 1) DataFrame → (N,) | .values.ravel() | isinstance(y, pd.DataFrame) |
| architecture/nepare.py | (N,) → (N, 1) | .unsqueeze(1) | dim() == 1 |
| architecture/mtenn.py (train) | scalar → (1,) | .unsqueeze(0) | unconditional per-sample |
| architecture/mtenn.py (predict) | (1,) → (1, 1) → (N, 1) | .unsqueeze(1) + torch.cat | unconditional per-sample |
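One way to make the target-array rows above produce consistent outcomes is to route them through two shared helpers, one per direction, with explicit errors for shapes neither accepts. This is a sketch only; the function names are hypothetical, and it does not cover the np.squeeze featurizer rows or the per-sample tensor rows.

```python
import numpy as np


def ensure_2d_targets(y) -> np.ndarray:
    """(N,) -> (N, 1); (N, M) passes through; any other shape raises."""
    y = np.asarray(y)
    if y.ndim == 1:
        return y.reshape(-1, 1)
    if y.ndim == 2:
        return y
    raise ValueError(f"expected (N,) or (N, M) targets, got shape {y.shape}")


def ensure_1d_single_task(y) -> np.ndarray:
    """(N,) or (N, 1) -> (N,); multi-task (N, M > 1) raises instead of flattening."""
    y = np.asarray(y)
    if y.ndim == 1:
        return y
    if y.ndim == 2 and y.shape[1] == 1:
        return y.ravel()
    raise ValueError(f"expected a single task, got shape {y.shape}")
```

Under this sketch, the unconditional reshape in features/mtenn.py would be replaced by one of these helpers, depending on whether multi-task targets are meant to be supported there. The DataFrame branch of features/pairwise.py would call ensure_1d_single_task, so a multi-column frame raises instead of being interleaved.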

Metadata


Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
