diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index 4cfd0bb6..1af9563d 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -11,7 +11,6 @@ Implemented with PyTorch Lightning """ -from abc import abstractmethod import numpy as np import torch from torch.utils.data import DataLoader @@ -38,10 +37,8 @@ ) -class ContextualizedRegressionBase(pl.LightningModule): - """ - Abstract class for Contextualized Regression. - """ +class NaiveContextualizedRegression(pl.LightningModule): + """See NaiveMetamodel""" def __init__( self, @@ -74,78 +71,6 @@ def __init__( **kwargs, ) - @abstractmethod - def _build_metamodel( - self, - context_dim: int, - x_dim: int, - y_dim: int, - **kwargs, - ): - """ - - :param context_dim: Dimension of the context vector - :param x_dim: Dimension of the input features - :param y_dim: Dimension of the output labels - :param **kwargs: Additional keyword arguments for the metamodel - - """ - # builds the metamodel - - @abstractmethod - def dataloader(self, C, X, Y, batch_size=32): - """ - - :param C: - :param X: - :param Y: - :param batch_size: (Default value = 32) - - """ - # returns the dataloader for this class - - @abstractmethod - def _batch_loss(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ - # MSE loss by default - - @abstractmethod - def predict_step(self, batch, batch_idx, dataloader_idx=0): - """ - - :param batch: - :param batch_idx: - :param dataload_idx: - - """ - # returns predicted params on the given batch - - @abstractmethod - def _params_reshape(self, beta_preds, mu_preds, dataloader): - """ - - :param beta_preds: - :param mu_preds: - :param dataloader: - - """ - # reshapes the batch parameter predictions into beta (y_dim, x_dim) - - @abstractmethod - def _y_reshape(self, y_preds, dataloader): - """ - - :param y_preds: - :param dataloader: - - """ - # reshapes the batch y predictions into a desirable format - def forward(self, *args, **kwargs): """ @@ -240,10 +165,6 @@ def _dataloader(self, C, X, Y, dataset_constructor, **kwargs): kwargs["batch_size"] = kwargs.get("batch_size", 32) return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs) - -class NaiveContextualizedRegression(ContextualizedRegressionBase): - """See NaiveMetamodel""" - def _build_metamodel( self, context_dim: int, @@ -342,9 +263,134 @@ def dataloader(self, C, X, Y, **kwargs): return self._dataloader(C, X, Y, MultivariateDataset, **kwargs) -class ContextualizedRegression(ContextualizedRegressionBase): +class ContextualizedRegression(pl.LightningModule): """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs""" + def __init__( + self, + context_dim: int, + x_dim: int, + y_dim: int, + learning_rate: float = 1e-3, + metamodel_type: str = "subtype", + fit_intercept: bool = True, + link_fn: callable = LINK_FUNCTIONS["identity"], + loss_fn: callable = MSE, + model_regularizer: callable = REGULARIZERS["none"], + base_y_predictor: callable = None, + base_param_predictor: callable = None, + **kwargs, + ): + super().__init__() + self.learning_rate = learning_rate + self.metamodel_type = metamodel_type + self.fit_intercept = fit_intercept + self.link_fn = link_fn + self.loss_fn = loss_fn + self.model_regularizer = model_regularizer + self.base_y_predictor = base_y_predictor + self.base_param_predictor = base_param_predictor + self._build_metamodel( + context_dim, + x_dim, + y_dim, + **kwargs, + ) + + def forward(self, *args, **kwargs): + """ + + :param *args: + + """ + beta, mu = self.metamodel(*args) + if not self.fit_intercept: + mu = torch.zeros_like(mu) + if self.base_param_predictor is not None: + base_beta, base_mu = self.base_param_predictor.predict_params(*args) + beta = beta + base_beta.to(beta.device) + mu = mu + base_mu.to(mu.device) + return beta, mu + + def configure_optimizers(self): + """ + Set up optimizer. + """ + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + def training_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"train_loss": loss}) + return loss + + def validation_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"val_loss": loss}) + return loss + + def test_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"test_loss": loss}) + return loss + + def _predict_from_models(self, X, beta_hat, mu_hat): + """ + + :param X: + :param beta_hat: + :param mu_hat: + + """ + return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat) + + def _predict_y(self, C, X, beta_hat, mu_hat): + """ + + :param C: + :param X: + :param beta_hat: + :param mu_hat: + + """ + Y = self._predict_from_models(X, beta_hat, mu_hat) + if self.base_y_predictor is not None: + Y_base = self.base_y_predictor.predict_y(C, X) + Y = Y + Y_base.to(Y.device) + return Y + + def _dataloader(self, C, X, Y, dataset_constructor, **kwargs): + """ + + :param C: + :param X: + :param Y: + :param dataset_constructor: + :param **kwargs: + + """ + kwargs["num_workers"] = kwargs.get("num_workers", 0) + kwargs["batch_size"] = kwargs.get("batch_size", 32) + return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs) + def _build_metamodel( self, context_dim: int, @@ -451,9 +497,134 @@ def dataloader(self, C, X, Y, **kwargs): return self._dataloader(C, X, Y, MultivariateDataset, **kwargs) -class MultitaskContextualizedRegression(ContextualizedRegressionBase): +class MultitaskContextualizedRegression(pl.LightningModule): """See MultitaskMetamodel""" + def __init__( + self, + context_dim: int, + x_dim: int, + y_dim: int, + learning_rate: float = 1e-3, + metamodel_type: str = "subtype", + fit_intercept: bool = True, + link_fn: callable = LINK_FUNCTIONS["identity"], + loss_fn: callable = MSE, + model_regularizer: callable = REGULARIZERS["none"], + base_y_predictor: callable = None, + base_param_predictor: callable = None, + **kwargs, + ): + super().__init__() + self.learning_rate = learning_rate + self.metamodel_type = metamodel_type + self.fit_intercept = fit_intercept + self.link_fn = link_fn + self.loss_fn = loss_fn + self.model_regularizer = model_regularizer + self.base_y_predictor = base_y_predictor + self.base_param_predictor = base_param_predictor + self._build_metamodel( + context_dim, + x_dim, + y_dim, + **kwargs, + ) + + def forward(self, *args, **kwargs): + """ + + :param *args: + + """ + beta, mu = self.metamodel(*args) + if not self.fit_intercept: + mu = torch.zeros_like(mu) + if self.base_param_predictor is not None: + base_beta, base_mu = self.base_param_predictor.predict_params(*args) + beta = beta + base_beta.to(beta.device) + mu = mu + base_mu.to(mu.device) + return beta, mu + + def configure_optimizers(self): + """ + Set up optimizer. + """ + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + def training_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"train_loss": loss}) + return loss + + def validation_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"val_loss": loss}) + return loss + + def test_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"test_loss": loss}) + return loss + + def _predict_from_models(self, X, beta_hat, mu_hat): + """ + + :param X: + :param beta_hat: + :param mu_hat: + + """ + return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat) + + def _predict_y(self, C, X, beta_hat, mu_hat): + """ + + :param C: + :param X: + :param beta_hat: + :param mu_hat: + + """ + Y = self._predict_from_models(X, beta_hat, mu_hat) + if self.base_y_predictor is not None: + Y_base = self.base_y_predictor.predict_y(C, X) + Y = Y + Y_base.to(Y.device) + return Y + + def _dataloader(self, C, X, Y, dataset_constructor, **kwargs): + """ + + :param C: + :param X: + :param Y: + :param dataset_constructor: + :param **kwargs: + + """ + kwargs["num_workers"] = kwargs.get("num_workers", 0) + kwargs["batch_size"] = kwargs.get("batch_size", 32) + return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs) + def _build_metamodel( self, context_dim: int, @@ -555,9 +726,134 @@ def dataloader(self, C, X, Y, **kwargs): return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs) -class TasksplitContextualizedRegression(ContextualizedRegressionBase): +class TasksplitContextualizedRegression(pl.LightningModule): """See TasksplitMetamodel""" + def __init__( + self, + context_dim: int, + x_dim: int, + y_dim: int, + learning_rate: float = 1e-3, + metamodel_type: str = "subtype", + fit_intercept: bool = True, + link_fn: callable = LINK_FUNCTIONS["identity"], + loss_fn: callable = MSE, + model_regularizer: callable = REGULARIZERS["none"], + base_y_predictor: callable = None, + base_param_predictor: callable = None, + **kwargs, + ): + super().__init__() + self.learning_rate = learning_rate + self.metamodel_type = metamodel_type + self.fit_intercept = fit_intercept + self.link_fn = link_fn + self.loss_fn = loss_fn + self.model_regularizer = model_regularizer + self.base_y_predictor = base_y_predictor + self.base_param_predictor = base_param_predictor + self._build_metamodel( + context_dim, + x_dim, + y_dim, + **kwargs, + ) + + def forward(self, *args, **kwargs): + """ + + :param *args: + + """ + beta, mu = self.metamodel(*args) + if not self.fit_intercept: + mu = torch.zeros_like(mu) + if self.base_param_predictor is not None: + base_beta, base_mu = self.base_param_predictor.predict_params(*args) + beta = beta + base_beta.to(beta.device) + mu = mu + base_mu.to(mu.device) + return beta, mu + + def configure_optimizers(self): + """ + Set up optimizer. + """ + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + def training_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"train_loss": loss}) + return loss + + def validation_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"val_loss": loss}) + return loss + + def test_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ + loss = self._batch_loss(batch, batch_idx) + self.log_dict({"test_loss": loss}) + return loss + + def _predict_from_models(self, X, beta_hat, mu_hat): + """ + + :param X: + :param beta_hat: + :param mu_hat: + + """ + return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat) + + def _predict_y(self, C, X, beta_hat, mu_hat): + """ + + :param C: + :param X: + :param beta_hat: + :param mu_hat: + + """ + Y = self._predict_from_models(X, beta_hat, mu_hat) + if self.base_y_predictor is not None: + Y_base = self.base_y_predictor.predict_y(C, X) + Y = Y + Y_base.to(Y.device) + return Y + + def _dataloader(self, C, X, Y, dataset_constructor, **kwargs): + """ + + :param C: + :param X: + :param Y: + :param dataset_constructor: + :param **kwargs: + + """ + kwargs["num_workers"] = kwargs.get("num_workers", 0) + kwargs["batch_size"] = kwargs.get("batch_size", 32) + return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs) + def _build_metamodel( self, context_dim: int, @@ -891,7 +1187,7 @@ def __init__( layers: int = 1, link_fn: callable = LINK_FUNCTIONS["identity"], num_archetypes: int = 10, - **kwargs, # Allows for additional args to be passed to class ContextualizedRegressionBase + **kwargs, # Allows for additional args to be passed to ContextualizedRegression ): super().__init__( context_dim=context_dim, @@ -942,7 +1238,7 @@ def __init__( task_width: int = 25, task_layers: int = 1, task_link_fn: callable = LINK_FUNCTIONS["identity"], - **kwargs, # Allows for additional args to be passed to class ContextualizedRegressionBase + **kwargs, # Allows for additional args to be passed to ContextualizedRegression ): super().__init__( context_dim=context_dim, @@ -995,7 +1291,7 @@ def __init__( link_fn: callable = LINK_FUNCTIONS["identity"], num_archetypes: int = 10, model_regularizer: callable = REGULARIZERS["l1"](1e-3, mu_ratio=0), - **kwargs, # Allows for additional args to be passed to class ContextualizedRegressionBase + **kwargs, # Allows for additional args to be passed to ContextualizedRegression ): super().__init__( context_dim=context_dim, @@ -1057,7 +1353,7 @@ def __init__( layers: int = 1, link_fn: callable = LINK_FUNCTIONS["identity"], num_archetypes: int = 10, - **kwargs, # Allows for additional args to be passed to class ContextualizedRegressionBase + **kwargs, # Allows for additional args to be passed to ContextualizedRegression ): super().__init__( context_dim=context_dim,