diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9de07cd --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +uv.lock +.venv +data +logs +__pycache__ diff --git a/src/NoProp_simple.ipynb b/NoProp_simple.ipynb similarity index 100% rename from src/NoProp_simple.ipynb rename to NoProp_simple.ipynb diff --git a/README.md b/README.md index 4fa1532..7d2c246 100644 --- a/README.md +++ b/README.md @@ -41,24 +41,33 @@ current implementation is slightly different. We leverage predefined ResNet Mod #### Quick Start -copy the code +You can either install directly from the repository: + +```bash +pip install git+https://github.com/yhgon/NoProp.git ``` + +Or download the repository and install it in editable mode: + +```bash git clone https://github.com/yhgon/NoProp.git +cd NoProp +pip install --editable . ``` run with default option for mnist dataset ``` -python NoProp/src/nopropct_mnist.py +noprop-mnist ``` run with configure dataset and backbone model ``` -python NoProp/src/noprop_simple.py --dataset cifar10 --backbone resnet50 +noprop-simple --dataset cifar10 --backbone resnet50 ``` run with configure dataset and backbone model default epoch is `400` ``` -python NoProp/src/noprop_simple.py --dataset cifar100 --backbone resnet152 +noprop-simple --dataset cifar100 --backbone resnet152 ``` ## log diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d483a44 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "noprop" +dynamic = ["version"] +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["torch", "torchvision"] + +[project.scripts] +noprop-simple = "noprop.simple:main" +noprop-mnist = "noprop.mnist:main" + +[tool.hatch.version] +source = "vcs" diff --git a/src/noprop/__init__.py b/src/noprop/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/nopropct_mnist.py b/src/noprop/mnist.py similarity index 99% rename from src/nopropct_mnist.py rename to src/noprop/mnist.py index 3695021..4a10536 100644 --- a/src/nopropct_mnist.py +++ b/src/noprop/mnist.py @@ -277,6 +277,9 @@ def train_and_eval(backbone: str): del model, optim, tr, te, ds_train, ds_test torch.cuda.empty_cache(); gc.collect() -if __name__ == '__main__': +def main(): for backbone in ['resnet18', 'resnet50', 'resnet152']: train_and_eval(backbone) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/noprop_simple.py b/src/noprop/simple.py similarity index 99% rename from src/noprop_simple.py rename to src/noprop/simple.py index bb1221c..cc11643 100644 --- a/src/noprop_simple.py +++ b/src/noprop/simple.py @@ -393,7 +393,7 @@ def train_and_eval(backbone: str, time_emb_dim: int, embed_dim: int, dataset: st del model, optimizer, ds_train, ds_test, tr_loader, te_loader torch.cuda.empty_cache(); gc.collect() -if __name__ == '__main__': +def main(): parser = argparse.ArgumentParser() parser.add_argument('--dataset', choices=['mnist','cifar10','cifar100'], required=True) parser.add_argument('--data-root', default='./data') @@ -412,3 +412,6 @@ def train_and_eval(backbone: str, time_emb_dim: int, embed_dim: int, dataset: st data_root = args.data_root, epoches = args.epoches, ) + +if __name__ == '__main__': + main() \ No newline at end of file