diff --git a/nflows/nn/nde/made.py b/nflows/nn/nde/made.py
index a3a53ae..2faef3b 100644
--- a/nflows/nn/nde/made.py
+++ b/nflows/nn/nde/made.py
@@ -227,11 +227,12 @@ def __init__(
             raise ValueError("Residual blocks can't be used with random masks.")
         super().__init__()
 
+        self.output_multiplier = output_multiplier
         # Initial layer.
         self.initial_layer = MaskedLinear(
-            in_degrees=_get_input_degrees(features),
+            in_degrees=_get_input_degrees(features+1),
             out_features=hidden_features,
-            autoregressive_features=features,
+            autoregressive_features=features+1,
             random_mask=random_mask,
             is_output=False,
         )
@@ -250,7 +251,7 @@ def __init__(
             blocks.append(
                 block_constructor(
                     in_degrees=prev_out_degrees,
-                    autoregressive_features=features,
+                    autoregressive_features=features+1,
                     context_features=context_features,
                     random_mask=random_mask,
                     activation=activation,
@@ -265,20 +266,43 @@ def __init__(
         # Final layer.
         self.final_layer = MaskedLinear(
             in_degrees=prev_out_degrees,
-            out_features=features * output_multiplier,
-            autoregressive_features=features,
+            out_features=(features+1) * output_multiplier,
+            autoregressive_features=(features+1),
             random_mask=random_mask,
             is_output=True,
         )
 
     def forward(self, inputs, context=None):
-        temps = self.initial_layer(inputs)
+        """Compute the network outputs for ``inputs``.
+
+        A single zero-valued dummy feature is prepended to ``inputs`` along
+        the last dimension before the first masked layer, and the leading
+        ``self.output_multiplier`` entries of the final layer's output are
+        sliced off before returning.
+
+        Args:
+            inputs: tensor whose last dimension holds the features.
+            context: optional tensor; when not ``None`` its ``context_layer``
+                projection is added after the first layer. Passed to each block.
+
+        Returns:
+            The final layer's output without its first
+            ``self.output_multiplier`` entries along the last dimension.
+        """
+        # Add a dummy input to ensure all dims are conditioned on context.
+        # new_zeros inherits dtype and device from `inputs`; a bare
+        # torch.zeros would be a default-dtype CPU tensor and make torch.cat
+        # fail for CUDA inputs.
+        dummy_input = inputs.new_zeros((*inputs.shape[:-1], 1))
+        concat_input = torch.cat((dummy_input, inputs), dim=-1)
+        temps = self.initial_layer(concat_input)
         if context is not None:
             temps += self.context_layer(context)
         for block in self.blocks:
             temps = block(temps, context)
         outputs = self.final_layer(temps)
-        return outputs
+        # Remove the leading entries produced for the dummy input.
+        return outputs[..., self.output_multiplier:]
 
 
 class MixtureOfGaussiansMADE(MADE):