Link to the paper | Official PyTorch implementation | Project page
This repository contains the TensorFlow/Keras implementation of the paper "Prompt-to-Prompt Image Editing with Cross Attention Control".
Current state-of-the-art methods require the user to provide a spatial mask to localize the edit, which ignores the original structure and content within the masked region. The paper proposes a novel technique to edit the content generated by large-scale text-to-image models such as DALL·E 2, Imagen or Stable Diffusion by manipulating only the text of the original prompt.
To achieve this result, the authors present the Prompt-to-Prompt framework, which comprises two functionalities:
- Prompt Editing: the key idea is to inject cross-attention maps during the diffusion process, controlling which pixels attend to which tokens of the prompt text.
- Attention Re-weighting: amplifies or attenuates the effect of a word in the generated image. This is done by first assigning a weight to each token and then scaling the attention map associated with that token. It's a nice alternative to negative prompting and multi-prompting.
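The two functionalities above can be illustrated on toy attention maps. The following is a minimal sketch, not this repository's code: the function names `reweight_attention` and `inject_attention`, the nested-list layout `[pixels][tokens]`, and the half-open step window are all assumptions made for illustration.

```python
# Illustrative sketch only: toy cross-attention maps stored as nested lists,
# shaped [pixels][tokens]. All names here are hypothetical.

def reweight_attention(attn, token_weights):
    """Scale each token's attention column by the weight given to that token."""
    return [[a * w for a, w in zip(row, token_weights)] for row in attn]


def inject_attention(attn_original, attn_edited, step, num_steps, start=0.0, end=1.0):
    """Return the original prompt's map while the step lies inside the
    fractional window [start, end); otherwise return the edited prompt's map."""
    progress = step / num_steps
    return attn_original if start <= progress < end else attn_edited


attn_original = [[0.7, 0.3], [0.2, 0.8]]  # 2 pixels x 2 tokens
attn_edited = [[0.5, 0.5], [0.5, 0.5]]

# Attention re-weighting: double the influence of the second token.
boosted = reweight_attention(attn_original, [1.0, 2.0])

# Prompt editing: inject the original maps only during the first 80% of steps.
early = inject_attention(attn_original, attn_edited, step=5, num_steps=50, end=0.8)
late = inject_attention(attn_original, attn_edited, step=45, num_steps=50, end=0.8)
```

In a real diffusion model these maps are tensors produced inside the attention layers at every step; the sketch only shows the two operations in isolation.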
Install dependencies using the requirements.txt.
pip install -r requirements.txt
Essentially, you need to have TensorFlow and keras-cv installed.
Try it yourself:
- Prompt-to-Prompt: Prompt Editing - Stable Diffusion
  Notebook with examples for the Prompt-to-Prompt prompt editing approach for Stable Diffusion.
- Prompt-to-Prompt: Attention Re-weighting - Stable Diffusion
  Notebook with examples for the Prompt-to-Prompt attention re-weighting approach for Stable Diffusion.
To start using the Prompt-to-Prompt framework, you first need to set up a TensorFlow distribution strategy for running computations across multiple devices (in case you have more than one).
For example, you can check the available hardware with:
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
tpus = tf.config.list_physical_devices("TPU")
print(f"Num GPUs Available: {len(gpus)} | Num TPUs Available: {len(tpus)}")
Then adjust the strategy to your needs by picking one of the following:
# For running on multiple GPUs
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1", ...])
# To get the default strategy
strategy = tf.distribute.get_strategy()
...
Once the strategy is set, you can start generating images just like in keras-cv:
# Imports
import tensorflow as tf
from stable_diffusion import StableDiffusion

generator = StableDiffusion(
    strategy=strategy,
    img_height=512,
    img_width=512,
    jit_compile=False,
)

# Generate text-to-image
img = generator.text_to_image(
    prompt="a photo of a chiwawa with sunglasses and a bandana",
    num_steps=50,
    unconditional_guidance_scale=8,
    seed=5681067,
    batch_size=1,
)

# Generate Prompt-to-Prompt
img_edit = generator.text_to_image_ptp(
    prompt="a photo of a chiwawa with sunglasses and a bandana",
    prompt_edit="a photo of a chiwawa with sunglasses and a pirate bandana",
    num_steps=50,
    unconditional_guidance_scale=8,
    cross_attn2_replace_steps_start=0.0,
    cross_attn2_replace_steps_end=1.0,
    cross_attn1_replace_steps_start=0.8,
    cross_attn1_replace_steps_end=1.0,
    seed=5681067,
    batch_size=1,
)
This generates the original and pirate bandana images shown below. You can play around and change the <bandana> and <sunglasses> attributes and many others!
Another example of prompt editing where one can control the content of the basket just by replacing a couple of words in the prompt:
img_edit = generator.text_to_image_ptp(
    prompt="a photo of basket with apples",
    prompt_edit="a photo of basket with oranges",
    num_steps=50,
    unconditional_guidance_scale=8,
    cross_attn2_replace_steps_start=0.0,
    cross_attn2_replace_steps_end=1.0,
    cross_attn1_replace_steps_start=0.0,
    cross_attn1_replace_steps_end=1.0,
    seed=1597337,
    batch_size=1,
)
The image below showcases examples where only the word <apples> was replaced with other fruits or animals. Try changing <basket> to other containers (e.g. bowl or nest) and see what happens!
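Prompt editing depends on knowing which parts of the two prompts match and which differ. As a rough illustration, the sketch below compares two prompts word by word with the standard library's `difflib.SequenceMatcher`. It is an assumption-laden stand-in: the helper name `changed_word_spans` is hypothetical, and whitespace splitting replaces the model's real tokenizer.

```python
from difflib import SequenceMatcher


def changed_word_spans(prompt, prompt_edit):
    """List the (operation, original words, edited words) spans that differ.

    Hypothetical helper: words are split on whitespace, which is a
    simplification of real tokenization.
    """
    original_words = prompt.split()
    edited_words = prompt_edit.split()
    matcher = SequenceMatcher(None, original_words, edited_words, autojunk=False)
    return [
        (tag, original_words[i1:i2], edited_words[j1:j2])
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != "equal"
    ]


swap = changed_word_spans(
    "a photo of basket with apples",
    "a photo of basket with oranges",
)
addition = changed_word_spans(
    "a photo of a chiwawa with sunglasses and a bandana",
    "a photo of a chiwawa with sunglasses and a pirate bandana",
)
```

Here `swap` reports a single replaced word and `addition` a single inserted word; every other position is shared between the prompts, which is where the original attention maps can be reused.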
To manipulate the relative importance of tokens, both the text_to_image and text_to_image_ptp methods accept a prompt_weights argument. You can create the array of weights with the create_prompt_weights method.
For example, if you generated a pizza that doesn't have enough pineapple on it, you can edit the weights of your prompt:
prompt = "a photo of a pizza with pineapple"
prompt_weights = generator.create_prompt_weights(prompt, [('pineapple', 2)])
This creates an array of 1's, except at the position of the word pineapple, where the value is 2.
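To make the shape of that array concrete, here is a minimal sketch of the idea. It is not the repository's `create_prompt_weights`: the function `sketch_prompt_weights` is hypothetical, and it assumes one weight per whitespace-separated word, whereas the real method presumably works on tokenizer positions.

```python
def sketch_prompt_weights(prompt, word_weights):
    """Hypothetical stand-in: one weight per whitespace-separated word.

    Every word starts at 1.0; words listed in `word_weights` get their
    requested value instead.
    """
    words = prompt.lower().split()
    weights = [1.0] * len(words)
    for word, weight in word_weights:
        for position, candidate in enumerate(words):
            if candidate == word.lower():
                weights[position] = float(weight)
    return weights


weights = sketch_prompt_weights("a photo of a pizza with pineapple", [("pineapple", 2)])
```

The resulting list has seven entries, all 1.0 except the last one, which holds 2.0 for the word pineapple.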
To generate a pizza with more pineapple (yuck!), you just need to pass the variable prompt_weights to the text_to_image method:
img = generator.text_to_image(
    prompt="a photo of a pizza with pineapple",
    num_steps=50,
    unconditional_guidance_scale=8,
    prompt_weights=prompt_weights,
    seed=1234,
    batch_size=1,
)
Now suppose you want to reduce the amount of blossom in a tree:
prompt = "A photo of a blossom tree"
prompt_weights = generator.create_prompt_weights(prompt, [('blossom', -1)])
img = generator.text_to_image(
    prompt="A photo of a blossom tree",
    num_steps=50,
    unconditional_guidance_scale=8,
    prompt_weights=prompt_weights,
    seed=1407923,
    batch_size=1,
)
Decreasing the weight associated with <blossom> will generate the following images.
For the prompt editing method, implemented in the function text_to_image_ptp, some parameters (e.g. cross_attn2_replace_steps_start, cross_attn1_replace_steps_start) indicate the phase of the diffusion process in which the edited cross-attention maps are injected. Varying them may produce different results (image below).
The cross-attention and prompt-weight hyperparameters should be tuned according to the user's needs and desired outputs.
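As a rough guide to what a start/end pair selects, the sketch below maps a fractional window onto step indices. It rests on assumptions: that the start and end values are fractions of num_steps (as the 0.0 to 1.0 values in the examples suggest), that the start is inclusive and the end exclusive, and that rounding is to the nearest step. The helper `injection_steps` is hypothetical and not part of this repository.

```python
def injection_steps(num_steps, start, end):
    """Return the step indices covered by the fractional window [start, end).

    Hypothetical helper: assumes `start` and `end` are fractions of
    `num_steps` in the range 0.0 to 1.0.
    """
    first = int(round(start * num_steps))
    last = int(round(end * num_steps))
    return range(first, last)


whole_process = injection_steps(50, 0.0, 1.0)  # every step
final_phase = injection_steps(50, 0.8, 1.0)    # only the last fifth of the steps
```

Under these assumptions, a window of 0.8 to 1.0 with 50 steps covers steps 40 through 49, while 0.0 to 1.0 covers all 50 steps.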
More info in bloc97/CrossAttentionControl and the paper.
- Add tutorials and Google Colabs.
- Add multi-batch support.
- Add examples for Stable Diffusion 2.x.
- keras-cv for the TensorFlow implementation of Stable Diffusion.
- bloc97/CrossAttentionControl unofficial implementation of the paper, from which the method get_matching_sentence_tokens and code logic were used.
- google/prompt-to-prompt Official implementation of the paper in PyTorch.
Feel free to open an issue or create a Pull Request.
For PRs, after implementing the changes please run the Makefile for formatting and linting the submitted code:
- make init: to create a Python environment with all the developer packages (optional).
- make format: to format the code.
- make lint: to lint the code.
- make type_check: to check for type hints.
- make all: to run all the checks.
Licensed under the Apache License 2.0. See LICENSE to read it in full.