
Use torch.searchsorted instead of our ad-hoc implementation #19

@jmarshrossney


Hi!

I compared the searchsorted function implemented here, which computes torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1, with the C++ implementation at https://github.com/aliutkus/torchsearchsorted, and the version here appears to be much slower, at least on CPU.

I modified benchmark.py in torchsearchsorted and copy-pasted the function from nflows for comparison.
The output (all on CPU) was:

Benchmark searchsorted:
- a [5000 x 16]
- v [5000 x 1]
- reporting fastest time of 10 runs
- each run executes searchsorted 100 times

Numpy: 	0.9516626670001642
torchsearchsorted: 	0.009861100999842165
nflows: 	50.19729063499926

i.e. sorting 5000 inputs into 5000 individual sets of 16 bins.
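For anyone who wants to reproduce a comparison like this without the external package, here is a rough timing sketch. It is not the modified benchmark.py from this issue: the helper names `fastest_time` and `run_benchmark` are hypothetical, it uses the built-in torch.searchsorted instead of torchsearchsorted, and it assumes one input value per row of 16 sorted bin edges (`v` of shape [5000 x 1] compared against `a` of shape [5000 x 16]). Timings will differ by machine, so no expected numbers are given.

```python
import timeit

import torch


def fastest_time(fn, repeat=10, number=100):
    # Fastest of `repeat` runs, where each run calls `fn` `number` times.
    return min(timeit.repeat(fn, repeat=repeat, number=number))


def run_benchmark(n_rows=5000, n_bins=16, repeat=10, number=100):
    # Assumed layout: each row of `a` holds sorted bin edges,
    # and each row of `v` holds the single value to locate in that row.
    a = torch.sort(torch.rand(n_rows, n_bins), dim=-1).values
    v = torch.rand(n_rows, 1)

    def sum_based():
        # Comparison-and-sum form; `v` already has a trailing axis of size 1.
        return torch.sum(v >= a, dim=-1) - 1

    def builtin():
        # right=True counts the edges that are <= the value, like `>=` above.
        return torch.searchsorted(a, v, right=True).squeeze(-1) - 1

    return {
        "sum-based": fastest_time(sum_based, repeat, number),
        "torch.searchsorted": fastest_time(builtin, repeat, number),
    }


if __name__ == "__main__":
    for name, seconds in run_benchmark().items():
        print(f"{name}: \t{seconds}")
```

Note that the result depends heavily on the shapes passed in: if the trailing axis is added twice, broadcasting compares every value against every row of bins, which is a much larger computation.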

Am I missing something here? If not, it looks like the spline flows could be sped up quite a bit by using torchsearchsorted or something similar?
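As a minimal sketch of what the swap could look like, assuming a PyTorch version that ships torch.searchsorted (as the issue title suggests), the snippet below checks that the built-in gives the same indices as the sum-based expression quoted above. The function names `searchsorted_sum` and `searchsorted_builtin` are made up for illustration, and the shapes are an assumption: `bin_locations` is [5000 x 16] with sorted rows and `inputs` is a flat vector of 5000 values. It only covers the quoted expression, not anything else the nflows function may do.

```python
import torch


def searchsorted_sum(bin_locations, inputs):
    # The comparison-and-sum expression quoted in this issue.
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def searchsorted_builtin(bin_locations, inputs):
    # right=True returns the number of edges <= input, which is what the
    # `>=` comparison above counts; subtracting 1 gives the bin index.
    idx = torch.searchsorted(
        bin_locations.contiguous(),
        inputs[..., None].contiguous(),
        right=True,
    )
    return idx.squeeze(-1) - 1


torch.manual_seed(0)
# Each row holds 16 sorted bin edges; one input value per row.
bin_locations = torch.sort(torch.rand(5000, 16), dim=-1).values
inputs = torch.rand(5000)

assert torch.equal(
    searchsorted_sum(bin_locations, inputs),
    searchsorted_builtin(bin_locations, inputs),
)
```

Both versions return -1 for an input below the first edge and the last index for an input at or above the last edge, so any handling of those edge cases would need to stay as it is.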

Cheers.


Labels: enhancement (New feature or request)
