Multi-scale minimal sufficient representation learning for domain generalization in sleep staging
Motivation
Comparison between (a) sufficient representation and (b) minimal sufficient representation. In conventional contrastive learning, $\boldsymbol{z}_i$ denotes the normalized feature of the $i$-th sample $\boldsymbol{v}_i$, while $\boldsymbol{z}_p$ represents the feature of a positive sample $\boldsymbol{v}_p$ that shares the same label as $\boldsymbol{v}_i$. The domain factor $D$ denotes the set of attributes that contribute to the domain gap. (a) Sufficient Representation Learning: this approach seeks to maximize the shared information between the feature and the positive sample, $I(\boldsymbol{z}_i ; \boldsymbol{v}_p)$, while simultaneously introducing the superfluous information $I(\boldsymbol{z}_i; \boldsymbol{v}_i | \boldsymbol{v}_p)$, i.e., the information present in $\boldsymbol{v}_i$ but absent in $\boldsymbol{v}_p$. Among these, the *excess domain-relevant information* $I(\boldsymbol{z}_i ; d_i | \boldsymbol{v}_p)$ caused by domain attributes hinders the learning of domain-invariant features, where $d_i$ refers to the domain label of $\boldsymbol{v}_i$. (b) Minimal Sufficient Representation Learning: this approach aims to reduce the superfluous information $I(\boldsymbol{z}_i; \boldsymbol{v}_i | \boldsymbol{v}_p)$, thereby diminishing the excess domain-relevant information and enabling the learning of more domain-invariant features.
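To make the "sufficient" term concrete, the sketch below shows a standard InfoNCE-style contrastive loss, which corresponds to maximizing $I(\boldsymbol{z}_i ; \boldsymbol{v}_p)$. This is only an illustration of the conventional contrastive baseline; the paper's additional minimality term that reduces $I(\boldsymbol{z}_i; \boldsymbol{v}_i | \boldsymbol{v}_p)$ is model-specific and is not reproduced here.

```python
import torch
import torch.nn.functional as F

def info_nce(z_i, z_p, temperature=0.1):
    """InfoNCE contrastive loss between anchor features z_i and
    positive features z_p, both of shape (batch, dim); rows at the
    same index form a positive pair, all other rows are negatives."""
    z_i = F.normalize(z_i, dim=1)
    z_p = F.normalize(z_p, dim=1)
    logits = z_i @ z_p.t() / temperature       # (batch, batch) similarities
    labels = torch.arange(z_i.size(0))         # positives lie on the diagonal
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
z_i, z_p = torch.randn(8, 32), torch.randn(8, 32)
loss = info_nce(z_i, z_p)
```

Minimizing this loss pulls each $\boldsymbol{z}_i$ toward its positive's feature while pushing it away from the other samples in the batch.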
Overall framework
Effectiveness
Comparison
Feature visualization
Environment Setup
Python 3.9
CUDA 12.1
PyTorch 2.3.1
The required libraries are listed in `requirements.txt`:
pip install -r requirements.txt
Data Preprocessing
Download the SleepEDF20 and MASS3 datasets and put them in the `data` directory.
Convert the data to `.npz` format:
python Preprocessing.py
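The sketch below illustrates the kind of `.npz` layout such a preprocessing step typically produces, assuming 30-second epochs with per-epoch stage labels. The array names `x` and `y`, the sampling rate, and the file name are illustrative assumptions, not guaranteed to match `Preprocessing.py`.

```python
import numpy as np

fs, n_epochs = 100, 4                        # assumed 100 Hz sampling rate
x = np.random.randn(n_epochs, fs * 30)       # (epochs, samples per 30 s epoch)
y = np.random.randint(0, 5, size=n_epochs)   # 5 sleep stages (W, N1-N3, REM)

# Save one subject's epochs and labels to a single compressed archive.
np.savez("example_subject.npz", x=x, y=y)

data = np.load("example_subject.npz")
print(data["x"].shape, data["y"].shape)
```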
Run
Our model consists of a pretraining stage and a fine-tuning stage.
Pretrain
First, the model's feature extractor learns domain-invariant features via multi-scale minimal sufficient representation learning:
python Pretrain.py
Finetune
Second, to evaluate the feature extractor, we train a transformer-based classifier while keeping the parameters of the feature extractor fixed. The transformer-based classifier follows the model proposed in the prior work SleePyCo for sleep scoring. You can edit the config `.json` file: `batch size = 1024`, `seq_len = 1`, `mode = pretrain`.
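The frozen-extractor setup above can be sketched as follows. The module shapes and layer sizes are illustrative stand-ins (the actual classifier follows SleePyCo); the point is that only the transformer encoder and the linear head receive gradients.

```python
import torch
import torch.nn as nn

# Stand-in for the pretrained multi-scale feature extractor.
feature_extractor = nn.Sequential(
    nn.Conv1d(1, 16, kernel_size=7, padding=3),
    nn.AdaptiveAvgPool1d(32),
)
for p in feature_extractor.parameters():   # keep pretrained weights fixed
    p.requires_grad = False

# Transformer classifier trained on top of the frozen features.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True),
    num_layers=2,
)
head = nn.Linear(16, 5)                    # 5 sleep stages

x = torch.randn(8, 1, 3000)                # batch of 30 s epochs at 100 Hz
with torch.no_grad():
    feats = feature_extractor(x)           # (8, 16, 32)
logits = head(encoder(feats.transpose(1, 2)).mean(dim=1))  # (8, 5)

# Only the transformer and the head are optimized.
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(head.parameters()), lr=1e-4
)
```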