@@ -116,6 +116,7 @@ def __call__(
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -144,6 +145,8 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -217,9 +220,25 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            ucond_tokens: List[str]
+            if negative_prompt is None:
+                ucond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError("`negative_prompt` should be the same type to `prompt`.")
+            elif isinstance(negative_prompt, str):
+                ucond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
+            else:
+                ucond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                ucond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
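The branch added above decides what feeds the unconditional side of classifier-free guidance: empty strings by default, a single negative prompt broadcast across the batch, or a per-prompt list whose length must match the batch size. A minimal standalone sketch of that resolution logic so the branches can be exercised in isolation (the function name `resolve_uncond_tokens` is illustrative and not part of the diff):

from typing import List, Optional, Union


def resolve_uncond_tokens(
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]],
    batch_size: int,
) -> List[str]:
    # Mirrors the branch added in this hunk: default to empty strings,
    # broadcast a single negative prompt, reject mismatched types or lengths.
    if negative_prompt is None:
        return [""] * batch_size
    if type(prompt) is not type(negative_prompt):
        raise TypeError("`negative_prompt` should be the same type to `prompt`.")
    if isinstance(negative_prompt, str):
        return [negative_prompt] * batch_size
    if batch_size != len(negative_prompt):
        raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
    return negative_prompt


# A single string is broadcast; a list must line up with the prompts.
assert resolve_uncond_tokens("a photo of an astronaut", "blurry, low quality", 1) == ["blurry, low quality"]
assert resolve_uncond_tokens(["a cat", "a dog"], ["sketch", "cartoon"], 2) == ["sketch", "cartoon"]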
@@ -128,6 +128,7 @@ def __call__(
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -160,6 +161,8 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -258,9 +261,25 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            ucond_tokens: List[str]
+            if negative_prompt is None:
+                ucond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError("`negative_prompt` should be the same type to `prompt`.")
+            elif isinstance(negative_prompt, str):
+                ucond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
+            else:
+                ucond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                ucond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
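From the caller's side the new keyword is simply threaded through `__call__`. A hedged usage sketch, assuming this file is the image-to-image pipeline (exposed as `StableDiffusionImg2ImgPipeline` in diffusers at the time of this PR); the checkpoint id and the `init_image` argument are outside the diff and used only for illustration:

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline  # assumed import; not shown in the diff

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"  # illustrative checkpoint
).to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))

result = pipe(
    prompt="a fantasy castle on a cliff, detailed matte painting",
    init_image=init_image,
    strength=0.8,
    num_inference_steps=50,
    guidance_scale=7.5,
    negative_prompt="blurry, low quality, watermark",  # new keyword from this PR
)
result.images[0].save("castle.png")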
@@ -144,6 +144,7 @@ def __call__(
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -180,6 +181,8 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -292,9 +295,25 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            ucond_tokens: List[str]
+            if negative_prompt is None:
+                ucond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError("`negative_prompt` should be the same type to `prompt`.")
+            elif isinstance(negative_prompt, str):
+                ucond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
+            else:
+                ucond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                ucond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
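The `uncond_embeddings` built from `ucond_tokens` are later concatenated with the text embeddings, and the two resulting noise predictions are blended with the guidance weight described in the docstring. A minimal sketch of that classifier-free guidance step with illustrative tensor shapes (the variable names follow common diffusers conventions, not this diff):

import torch

guidance_scale = 7.5

# Illustrative noise predictions for a batch of one 64x64 latent:
# one from the unconditional (negative-prompt) embeddings, one from the text prompt.
noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)

# Classifier-free guidance: move away from the unconditional prediction and
# toward the text-conditioned one. With a negative prompt, "away from the
# unconditional prediction" means away from the concepts in `negative_prompt`.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])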
@@ -52,6 +52,7 @@ def __call__(
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: Optional[float] = 0.0,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
@@ -102,9 +103,25 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            ucond_tokens: List[str]
+            if negative_prompt is None:
+                ucond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError("`negative_prompt` should be the same type to `prompt`.")
+            elif isinstance(negative_prompt, str):
+                ucond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
+            else:
+                ucond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+                ucond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="np",
             )
             uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
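The ONNX variant differs only in returning NumPy tensors and casting the token ids to int32 before the ONNX text encoder. A small sketch of that tokenization step in isolation, assuming the usual Stable Diffusion CLIP tokenizer (`openai/clip-vit-large-patch14`); the diff itself only shows `self.tokenizer`, so the concrete class and checkpoint are assumptions:

import numpy as np
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# One negative prompt broadcast over a batch of two, as the new branch would produce.
ucond_tokens = ["blurry, low quality"] * 2

uncond_input = tokenizer(
    ucond_tokens,
    padding="max_length",
    max_length=tokenizer.model_max_length,  # the pipeline instead reuses the prompt's padded length
    truncation=True,
    return_tensors="np",
)

# The ONNX text encoder expects int32 token ids, hence the cast in the diff.
input_ids = uncond_input.input_ids.astype(np.int32)
print(input_ids.shape)  # (2, 77)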