-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
71 lines (53 loc) · 1.98 KB
/
main.py
File metadata and controls
71 lines (53 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import gradio as gr
import argparse
from PIL import Image
from src.api import DALS_API
from src.utils import read_config
def mirror(x):
    """Identity callback: hand back exactly the object received.

    Wired as ``fn`` of ``gr.Examples`` with the same component as both input
    and output, so picking an example simply places it in the input slot.
    """
    echoed = x
    return echoed
def show_image(selected_label):
    """Load the example image registered under *selected_label*.

    Args:
        selected_label: A key of the module-level ``image_options`` dict.

    Returns:
        A PIL image whose pixel data is already read into memory, so no file
        handle stays open after the call.

    Raises:
        KeyError: If *selected_label* is not a key of ``image_options``.
    """
    image_path = image_options[selected_label]
    # Image.open is lazy: it keeps the file open until the pixels are read.
    # Copying inside the context manager forces the load, and the file is
    # then closed deterministically instead of whenever the GC gets to it.
    with Image.open(image_path) as img:
        return img.copy()
def process_input(prompt, img_path):
    """Generate an image from *prompt*, optionally conditioned on an input image.

    Relies on the module-level ``api`` object created in the ``__main__`` block.

    Args:
        prompt: Text from the input textbox.
        img_path: Value of the ``threeCPM_img`` component. Despite the name,
            with ``gr.Image(type="pil")`` this is presumably a PIL image or
            ``None`` when empty (TODO confirm against the Gradio version).
            Any falsy value selects prompt-only generation.

    Returns:
        Whatever ``api.generate_image`` returns; it is displayed in ``result_img``.
    """
    print(f"process_input: img_path: {img_path}")
    if not img_path:
        # No conditioning image supplied: generate from the prompt alone.
        return api.generate_image(prompt)
    return api.generate_image(prompt, img_path)
# Directory holding the bundled 3CPM example images. It is resolved relative to
# this script, so the demo works wherever the repository is checked out and
# whichever working directory it is launched from.
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3CPM_examples")

# Maps a display label to the absolute path of an example image (read by show_image).
image_options = {
    "3CPM_1": os.path.join(_EXAMPLES_DIR, "far_sighted", "npp1.png"),
    "3CPM_2": os.path.join(_EXAMPLES_DIR, "one_vanishing_point", "opp1.png"),
    "3CPM_3": os.path.join(_EXAMPLES_DIR, "two_vanishing_points", "tpp1.png"),
    # Add as many images as you like
}
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="path of the config file passed to read_config()")
    args = parser.parse_args()
    print(args.config)
    cfg = read_config(args.config)
    print(f"[cfg]: {cfg}")
    # Module-level name read by process_input().
    api = DALS_API(cfg)

    # Example images live next to this script. Computing the directory once
    # avoids repeating the join, and abspath guards against a relative
    # __file__ (older Pythons keep it as typed on the command line).
    examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3CPM_examples")

    with gr.Blocks() as demo:
        with gr.Row():
            threeCPM_img = gr.Image(type="pil")  # optional conditioning image
            result_img = gr.Image(type="pil")  # generated output
        with gr.Row():
            input_prompt = gr.Textbox(label="Input", lines=2)
            btn = gr.Button(value="Submit")
        btn.click(process_input, inputs=[input_prompt, threeCPM_img], outputs=[result_img])

        gr.Markdown("## Image Examples")
        # Clicking an example routes it through mirror() back into threeCPM_img.
        gr.Examples(
            examples=[
                os.path.join(examples_dir, "far_sighted/npp1.png"),
                os.path.join(examples_dir, "one_vanishing_point/opp1.png"),
                os.path.join(examples_dir, "two_vanishing_points/tpp1.png"),
            ],
            inputs=threeCPM_img,
            outputs=threeCPM_img,
            fn=mirror,
            cache_examples=True,
        )
    demo.launch()