From 3fbc9af16450b3df9f17be823afbab9fc718a228 Mon Sep 17 00:00:00 2001 From: Chigozie Nri Date: Sun, 3 Dec 2023 21:41:05 +0000 Subject: [PATCH 1/5] Use replicate-weights --- README.md | 16 +++++++++------- cog.yaml | 3 +++ predict.py | 14 +++++++++++++- script/download-weights | 28 ---------------------------- 4 files changed, 25 insertions(+), 36 deletions(-) delete mode 100755 script/download-weights diff --git a/README.md b/README.md index a4373cf..a5fef02 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,12 @@ This is an implementation of the [Diffusers Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog) -First, download the pre-trained weights: - - cog run script/download-weights - -Then, you can run predictions: - - cog predict -i prompt="monkey scuba diving" +Make single prediction: +```bash +cog predict -i prompt="monkey scuba diving" +``` + +Run HTTP API for making predictions: +```bash +cog run -p 5000 python -m cog.server.http +``` \ No newline at end of file diff --git a/cog.yaml b/cog.yaml index 06fad69..e2e25b3 100644 --- a/cog.yaml +++ b/cog.yaml @@ -11,4 +11,7 @@ build: - "accelerate==0.15.0" - "huggingface-hub==0.13.2" + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.6/pget" && chmod +x /usr/local/bin/pget + predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index db3c20e..feba699 100644 --- a/predict.py +++ b/predict.py @@ -15,22 +15,34 @@ from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) +from weights_downloader import WeightsDownloader # MODEL_ID refers to a diffusers-compatible model on HuggingFace # e.g. prompthero/openjourney-v2, wavymulder/Analog-Diffusion, etc -MODEL_ID = "stabilityai/stable-diffusion-2-1" MODEL_CACHE = "diffusers-cache" + +SD_MODEL_CACHE = os.path.join(MODEL_CACHE, "models--stabilityai--stable-diffusion-2-1") +MODEL_ID = "stabilityai/stable-diffusion-2-1" +SD_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-2-1.tar" + +SAFETY_CACHE = os.path.join(MODEL_CACHE, "models--CompVis--stable-diffusion-safety-checker") SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker" +SAFETY_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-safety-checker.tar" class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" + + print("Loading pipeline...") + WeightsDownloader.download_if_not_exists(SAFETY_URL, SAFETY_CACHE) safety_checker = StableDiffusionSafetyChecker.from_pretrained( SAFETY_MODEL_ID, cache_dir=MODEL_CACHE, local_files_only=True, ) + + WeightsDownloader.download_if_not_exists(SD_URL, SD_MODEL_CACHE) self.pipe = StableDiffusionPipeline.from_pretrained( MODEL_ID, safety_checker=safety_checker, diff --git a/script/download-weights b/script/download-weights deleted file mode 100755 index 6cf233d..0000000 --- a/script/download-weights +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python - -import os -import shutil -import sys - -from diffusers import StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker - -# append project directory to path so predict.py can be imported -sys.path.append('.') - -from predict import MODEL_CACHE, MODEL_ID, SAFETY_MODEL_ID - -if os.path.exists(MODEL_CACHE): - shutil.rmtree(MODEL_CACHE) -os.makedirs(MODEL_CACHE, exist_ok=True) - -saftey_checker = StableDiffusionSafetyChecker.from_pretrained( - SAFETY_MODEL_ID, - cache_dir=MODEL_CACHE, -) - -pipe = StableDiffusionPipeline.from_pretrained( - MODEL_ID, - cache_dir=MODEL_CACHE, -) From 2076465ae32602153528e0c929f6fd75bbc3839e Mon Sep 17 00:00:00 2001 From: Chigozie Nri Date: Mon, 4 Dec 2023 13:22:20 +0000 Subject: [PATCH 2/5] Add weights_downloader.py --- weights_downloader.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 weights_downloader.py diff --git a/weights_downloader.py b/weights_downloader.py new file mode 100644 index 0000000..3332365 --- /dev/null +++ b/weights_downloader.py @@ -0,0 +1,18 @@ +import os +import subprocess +import time + + +class WeightsDownloader: + @staticmethod + def download_if_not_exists(url, dest): + if not os.path.exists(dest): + WeightsDownloader.download(url, dest) + + @staticmethod + def download(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) From 451f510766cdeee26efaf6c20f093a1043fb4d9e Mon Sep 17 00:00:00 2001 From: Chigozie Nri Date: Mon, 4 Dec 2023 13:23:07 +0000 Subject: [PATCH 3/5] Black + isort --- predict.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/predict.py b/predict.py index feba699..f2985a5 100644 --- a/predict.py +++ b/predict.py @@ -4,17 +4,18 @@ import torch from cog import BasePredictor, Input, Path from diffusers import ( - StableDiffusionPipeline, - PNDMScheduler, - LMSDiscreteScheduler, DDIMScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, ) from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) + from weights_downloader import WeightsDownloader # MODEL_ID refers to a diffusers-compatible model on HuggingFace @@ -25,15 +26,17 @@ MODEL_ID = "stabilityai/stable-diffusion-2-1" SD_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-2-1.tar" -SAFETY_CACHE = os.path.join(MODEL_CACHE, "models--CompVis--stable-diffusion-safety-checker") +SAFETY_CACHE = os.path.join( + MODEL_CACHE, "models--CompVis--stable-diffusion-safety-checker" +) SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker" SAFETY_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-safety-checker.tar" + class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" - print("Loading pipeline...") WeightsDownloader.download_if_not_exists(SAFETY_URL, SAFETY_CACHE) safety_checker = StableDiffusionSafetyChecker.from_pretrained( From 775ecdeba9303da21b6a6a371733aa8cf7d122b8 Mon Sep 17 00:00:00 2001 From: Chigozie Nri Date: Mon, 4 Dec 2023 13:23:35 +0000 Subject: [PATCH 4/5] Update .dockerignore and .gitignore --- .dockerignore | 19 +++++++++++++++++++ .gitignore | 5 +++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/.dockerignore b/.dockerignore index 6b8710a..aa29595 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,20 @@ +# The .dockerignore file excludes files from the container build process. +# +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +# Exclude Git files .git +.github +.gitignore + +# Exclude Python cache files +__pycache__ +.mypy_cache +.pytest_cache +.ruff_cache + +# Exclude weights +diffusers-cache + +# Exclude Python virtual environment +/venv \ No newline at end of file diff --git a/.gitignore b/.gitignore index 0bb6e87..0fb9c33 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ -.cog/ +.vscode/ __pycache__/ -diffusers-cache/ \ No newline at end of file +diffusers-cache/ +.cog/ \ No newline at end of file From c62d9137d92c894062548b5ac3abfe556df093af Mon Sep 17 00:00:00 2001 From: chigozienri Date: Mon, 4 Dec 2023 14:58:51 +0000 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Yorick --- .gitignore | 1 - README.md | 2 +- cog.yaml | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 0fb9c33..a5d32bd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -.vscode/ __pycache__/ diffusers-cache/ .cog/ \ No newline at end of file diff --git a/README.md b/README.md index a5fef02..d369282 100644 --- a/README.md +++ b/README.md @@ -11,5 +11,5 @@ cog predict -i prompt="monkey scuba diving" Run HTTP API for making predictions: ```bash -cog run -p 5000 python -m cog.server.http +cog run -p 5000 ``` \ No newline at end of file diff --git a/cog.yaml b/cog.yaml index e2e25b3..a4fd2c4 100644 --- a/cog.yaml +++ b/cog.yaml @@ -12,6 +12,6 @@ build: - "huggingface-hub==0.13.2" run: - - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.6/pget" && chmod +x /usr/local/bin/pget + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.3.0/pget" && chmod +x /usr/local/bin/pget predict: "predict.py:Predictor"