[BUG] Fix RandomState handling in ROCKET and Hydra transformers#3214
[BUG] Fix RandomState handling in ROCKET and Hydra transformers#3214satwiksps wants to merge 18 commits intoaeon-toolkit:mainfrom
Conversation
Thank you for contributing to
|
…into rocket-randomstate
748357b to
34a0eb9
Compare
…into rocket-randomstate
34a0eb9 to
9dfd0a9
Compare
Reference Issues/PRs
What does this implement/fix? Explain your changes.
This PR fixes a bug where passing a
RandomStateobject to therandom_stateparameter in ROCKET-based and Hydra transformers resulted in the seed being ignored (falling back toNone) or causing crashes in backend computations.Changes implemented:
_fitmethods inRocket,MiniRocket,MultiRocket,HydraTransformer, andROCKETGPU. Now usessklearn.utils.check_random_stateto process the inputrandom_state.RandomStateto pass to the underlying computation backends (Numba, PyTorch, and TensorFlow), ensuring reproducibility regardless of whether anint,None, orRandomStateobject is provided.ROCKETGPUcaused by data type mismatches between float32 inputs and float64 kernels/biases generated by numpy. Explicitly casts generated parameters and inputs tofloat32to ensure compatibility with TensorFlow'sconv1d.test_all_rockets.pyand_rockad.pywere updated. Since the random seeding logic was corrected, the deterministic stream of random numbers changed, necessitating updates to the hardcoded expected output values.Does your contribution introduce a new dependency? If yes, which one?
No.
Any other comments?
I have verified these changes with a local test script that confirms:
RandomStateobjects.ROCKETGPUno longer crashes on standard inputs.PR checklist
For all contributions
For new estimators and functions
__maintainer__at the top of relevant files and want to be contacted regarding its maintenance. Unmaintained files may be removed. This is for the full file, and you should not add yourself if you are just making minor changes or do not want to help maintain its contents.For developers with write access