-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathpipeline_metaquery.py
More file actions
60 lines (50 loc) · 1.82 KB
/
pipeline_metaquery.py
File metadata and controls
60 lines (50 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union, List
from typing import Tuple
import torch
from PIL import Image
from diffusers.pipelines.pipeline_utils import BaseOutput
import numpy as np
import PIL
from dataclasses import dataclass
from models.metaquery import MetaQuery
@dataclass
class MetaQueryPipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
# text: Optional[List[str]] = [""]
class MetaQueryPipeline(MetaQuery):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.no_grad()
def __call__(
self,
caption: Optional[str] = "",
image: Optional[
Union[Image.Image, List[Image.Image], torch.Tensor, List[torch.Tensor]]
] = None,
negative_prompt: Optional[str] = "",
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 30,
num_images_per_prompt: int = 1,
return_dict: bool = True,
**kwargs,
) -> Union[MetaQueryPipelineOutput, Tuple]:
samples = self.sample_images(
caption=caption,
input_images=image,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
generator=generator,
num_inference_steps=num_inference_steps,
num_images_per_prompt=num_images_per_prompt,
**kwargs,
)
if not return_dict:
return (samples,)
return MetaQueryPipelineOutput(images=samples)