diff --git a/.dockerignore b/.dockerignore index 6b8710a..aa29595 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,20 @@ +# The .dockerignore file excludes files from the container build process. +# +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +# Exclude Git files .git +.github +.gitignore + +# Exclude Python cache files +__pycache__ +.mypy_cache +.pytest_cache +.ruff_cache + +# Exclude weights +diffusers-cache + +# Exclude Python virtual environment +/venv \ No newline at end of file diff --git a/.gitignore b/.gitignore index 0bb6e87..a5d32bd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ -.cog/ __pycache__/ -diffusers-cache/ \ No newline at end of file +diffusers-cache/ +.cog/ \ No newline at end of file diff --git a/README.md b/README.md index a4373cf..d369282 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,12 @@ This is an implementation of the [Diffusers Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog) -First, download the pre-trained weights: - - cog run script/download-weights - -Then, you can run predictions: - - cog predict -i prompt="monkey scuba diving" +Make single prediction: +```bash +cog predict -i prompt="monkey scuba diving" +``` + +Run HTTP API for making predictions: +```bash +cog run -p 5000 +``` \ No newline at end of file diff --git a/cog.yaml b/cog.yaml index 06fad69..a4fd2c4 100644 --- a/cog.yaml +++ b/cog.yaml @@ -11,4 +11,7 @@ build: - "accelerate==0.15.0" - "huggingface-hub==0.13.2" + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.3.0/pget" && chmod +x /usr/local/bin/pget + predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index db3c20e..f2985a5 100644 --- a/predict.py +++ b/predict.py @@ -4,33 +4,48 @@ import torch from cog import BasePredictor, Input, Path from diffusers import ( - StableDiffusionPipeline, - PNDMScheduler, - LMSDiscreteScheduler, DDIMScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, ) from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) +from weights_downloader import WeightsDownloader + # MODEL_ID refers to a diffusers-compatible model on HuggingFace # e.g. prompthero/openjourney-v2, wavymulder/Analog-Diffusion, etc -MODEL_ID = "stabilityai/stable-diffusion-2-1" MODEL_CACHE = "diffusers-cache" + +SD_MODEL_CACHE = os.path.join(MODEL_CACHE, "models--stabilityai--stable-diffusion-2-1") +MODEL_ID = "stabilityai/stable-diffusion-2-1" +SD_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-2-1.tar" + +SAFETY_CACHE = os.path.join( + MODEL_CACHE, "models--CompVis--stable-diffusion-safety-checker" +) SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker" +SAFETY_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-safety-checker.tar" + class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" + print("Loading pipeline...") + WeightsDownloader.download_if_not_exists(SAFETY_URL, SAFETY_CACHE) safety_checker = StableDiffusionSafetyChecker.from_pretrained( SAFETY_MODEL_ID, cache_dir=MODEL_CACHE, local_files_only=True, ) + + WeightsDownloader.download_if_not_exists(SD_URL, SD_MODEL_CACHE) self.pipe = StableDiffusionPipeline.from_pretrained( MODEL_ID, safety_checker=safety_checker, diff --git a/script/download-weights b/script/download-weights deleted file mode 100755 index 6cf233d..0000000 --- a/script/download-weights +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python - -import os -import shutil -import sys - -from diffusers import StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker - -# append project directory to path so predict.py can be imported -sys.path.append('.') - -from predict import MODEL_CACHE, MODEL_ID, SAFETY_MODEL_ID - -if os.path.exists(MODEL_CACHE): - shutil.rmtree(MODEL_CACHE) -os.makedirs(MODEL_CACHE, exist_ok=True) - -saftey_checker = StableDiffusionSafetyChecker.from_pretrained( - SAFETY_MODEL_ID, - cache_dir=MODEL_CACHE, -) - -pipe = StableDiffusionPipeline.from_pretrained( - MODEL_ID, - cache_dir=MODEL_CACHE, -) diff --git a/weights_downloader.py b/weights_downloader.py new file mode 100644 index 0000000..3332365 --- /dev/null +++ b/weights_downloader.py @@ -0,0 +1,18 @@ +import os +import subprocess +import time + + +class WeightsDownloader: + @staticmethod + def download_if_not_exists(url, dest): + if not os.path.exists(dest): + WeightsDownloader.download(url, dest) + + @staticmethod + def download(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start)