Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lilfilter/local_amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def compute(self,
# Gaussian smoothing (to correct for end effects).
amplitudes = torch.empty(
(minibatch_size * num_channels + 1), signal_length,
dtype=self.dtype)
dtype=self.dtype).type_as(input)

# set the last row to all ones.
amplitudes[minibatch_size*num_channels:,:] = 1
Expand Down Expand Up @@ -194,7 +194,7 @@ def _block_sum(self, amplitudes):
(n, s) = amplitudes.shape
t = (s + 2 * b - 1) // b

ans = torch.zeros((n, t), dtype=self.dtype)
ans = torch.zeros((n, t), dtype=self.dtype).type_as(amplitudes)

# make sure `amplitudes` is contiguous.

Expand Down
6 changes: 3 additions & 3 deletions lilfilter/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,15 @@ def resample(self, in_data):
# will be shape (minibatch_size, in_channels, seq_len) with in_channels == 1
in_data = in_data.unsqueeze(1)
out = torch.nn.functional.conv1d(in_data,
self.weights,
self.weights.type_as(in_data),
stride=self.input_sr,
padding=self.padding)
# shape will be (minibatch_size, out_channels = 1, seq_len);
# return as (minibatch_size, seq_len)
return out.squeeze(1)
elif self.resample_type == 'integer_upsample':
out = torch.nn.functional.conv_transpose1d(in_data.unsqueeze(1),
self.weights,
self.weights.type_as(in_data),
stride=self.output_sr,
padding=self.padding)
return out.squeeze(1)
Expand All @@ -221,7 +221,7 @@ def resample(self, in_data):
# in_channels, width) so we need to reshape (note: time is width).
in_data = in_data.transpose(1, 2)

out = torch.nn.functional.conv1d(in_data, self.weights,
out = torch.nn.functional.conv1d(in_data, self.weights.type_as(in_data),
padding=self.padding)
assert out.shape == (minibatch_size, self.output_sr, num_blocks)
return out.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr)
Expand Down
2 changes: 1 addition & 1 deletion lilfilter/torch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def apply(self, input):
# input.unsqueeze(1) changes dim from (minibatch_size, sequence_length) to
# (minibatch_size, num_channels=1, sequence_length)
# the final squeeze(1) removes the num_channels=1 axis
return torch.nn.functional.conv1d(input.unsqueeze(1), self.filt,
return torch.nn.functional.conv1d(input.unsqueeze(1), self.filt.type_as(input),
padding=self.padding).squeeze(1)


Expand Down