xiaoyao9184 commited on
Commit
87c1b67
·
verified ·
1 Parent(s): f226f0e

Synced repo using 'sync_with_huggingface' Github Action

Browse files

original:
- remote: "https://github.com/xiaoyao9184/docker-dall-e"
- commit: "817977c7edeb8eef1a5c2245179439389c278051"
sync_with_huggingface:
- repository: ""
- ref: ""

Files changed (5) hide show
  1. README.md +9 -5
  2. app.py +50 -0
  3. gradio_app.py +198 -0
  4. gradio_run.py +7 -0
  5. requirements.txt +12 -0
README.md CHANGED
@@ -1,12 +1,16 @@
1
  ---
2
- title: Dall E
3
- emoji: 👁
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.9.0
 
8
  app_file: app.py
9
  pinned: false
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: DALL-E
3
+ emoji: 🖼️
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
+ python_version: '3.8.20'
9
  app_file: app.py
10
  pinned: false
11
+ license: apache-2.0
12
+ models:
13
+ - xiaoyao9184/dall-e
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import git
4
+ import subprocess
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ REPO_URL = "https://github.com/openai/DALL-E.git"
8
+ REPO_BRANCH = '5be4b236bc3ade6943662354117a0e83752cc322'
9
+ LOCAL_PATH = "./DALL-E"
10
+ MODEL_ID = "xiaoyao9184/dall-e"
11
+
12
+ def install_src():
13
+ if not os.path.exists(LOCAL_PATH):
14
+ print(f"Cloning repository from {REPO_URL}@{REPO_BRANCH} to {LOCAL_PATH}...")
15
+ repo = git.Repo.clone_from(REPO_URL, LOCAL_PATH)
16
+ repo.git.checkout(REPO_BRANCH)
17
+ else:
18
+ print(f"Repository already exists at {LOCAL_PATH}")
19
+
20
+ requirements_path = os.path.join(LOCAL_PATH, "requirements.txt")
21
+ if os.path.exists(requirements_path):
22
+ print("Installing requirements...")
23
+ subprocess.check_call(["pip", "install", "-r", requirements_path])
24
+ else:
25
+ print("No requirements.txt found.")
26
+
27
+ def install_model():
28
+ checkpoint_path = os.path.join(LOCAL_PATH)
29
+ print(f"Downloading model from {MODEL_ID}...")
30
+ hf_hub_download(repo_id=MODEL_ID, revision='master', filename='encoder.pkl', local_dir=checkpoint_path)
31
+ hf_hub_download(repo_id=MODEL_ID, revision='master', filename='decoder.pkl', local_dir=checkpoint_path)
32
+
33
+ # clone repo and download model
34
+ install_src()
35
+ install_model()
36
+
37
+ # fix sys.path for import
38
+ print(f"LOCAL_PATH: {os.path.abspath(LOCAL_PATH)}")
39
+ os.environ["APP_PATH"] = os.path.abspath(LOCAL_PATH)
40
+
41
+ # run gradio in subprocess in reloaded mode
42
+ # huggingface space issue: https://github.com/gradio-app/gradio/issues/10048
43
+ # need disable reload for huggingface space
44
+ import re
45
+ import sys
46
+ from gradio.cli import cli
47
+ if __name__ == '__main__':
48
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
49
+ sys.argv.append(re.sub(r'app\.py$', 'gradio_app.py', sys.argv[0]))
50
+ sys.exit(cli())
gradio_app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import git
4
+
5
+ if "APP_PATH" in os.environ:
6
+ # fix sys.path for import
7
+ os.chdir(os.environ["APP_PATH"])
8
+ if os.getcwd() not in sys.path:
9
+ sys.path.append(os.getcwd())
10
+
11
+ # remove duplicate gradio_app path from sys.path
12
+ sys.path = list(dict.fromkeys(sys.path))
13
+
14
+ # remove gradio reload env if in huggingface space
15
+ if "SPACE_ID" in os.environ:
16
+ for key in ["GRADIO_WATCH_DIRS", "GRADIO_WATCH_MODULE_NAME", "GRADIO_WATCH_DEMO_NAME", "GRADIO_WATCH_DEMO_PATH"]:
17
+ if key in os.environ:
18
+ del os.environ[key]
19
+
20
+ def get_app_git_commit():
21
+ app_path = os.environ.get("APP_PATH")
22
+ if not app_path:
23
+ return None
24
+ try:
25
+ repo = git.Repo(app_path, search_parent_directories=False)
26
+ hexsha = repo.head.commit.hexsha
27
+ return hexsha
28
+ except (git.exc.InvalidGitRepositoryError, ValueError, git.exc.GitError):
29
+ return None
30
+
31
+ # here the subprocess stops loading, because __name__ is NOT '__main__'
32
+ # gradio will reload
33
+ if '__main__' == __name__:
34
+
35
+ import gradio as gr
36
+ from contextlib import suppress
37
+
38
+ import os
39
+ import torch
40
+ import torchvision.transforms as T
41
+ import torchvision.transforms.functional as TF
42
+ import torch.nn.functional as F
43
+ import json
44
+ from PIL import Image
45
+
46
+ from dall_e import map_pixels, unmap_pixels, load_model
47
+
48
+ CHECKPOINT_MODEL_PATH = os.environ.get("CHECKPOINT_MODEL_PATH", './')
49
+
50
+ # Device configuration
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+ enc = load_model(f"{CHECKPOINT_MODEL_PATH}/encoder.pkl", device)
53
+ dec = load_model(f"{CHECKPOINT_MODEL_PATH}/decoder.pkl", device)
54
+
55
+ target_image_size = 256
56
+ transform = T.Compose([
57
+ T.Resize(target_image_size, interpolation=Image.LANCZOS),
58
+ T.CenterCrop(target_image_size),
59
+ T.ToTensor(),
60
+ ])
61
+
62
+ def encode_image(image):
63
+ """
64
+ Encodes the given image.
65
+
66
+ Args:
67
+ image (Union[PIL.Image.Image, str]): Input image, can be a PIL Image object, image array, or URL.
68
+
69
+ Returns:
70
+ dict: A dictionary with the following fields:
71
+ - ``shape`` (List[int]): Spatial size of the code grid, fixed as ``[32, 32]``.
72
+ - ``vocab_size`` (int): Vocabulary size of the model, fixed as `8192`.
73
+ - ``tokens`` (List[List[int]]): Quantized image tokens on a ``32 x 32`` grid,
74
+ converted from the internal one-hot representation to nested Python lists
75
+ on CPU for serialization and front-end display.
76
+ """
77
+ if image is None:
78
+ raise gr.Error("Please upload an image before clicking the \"Encoding\" button.")
79
+
80
+ if isinstance(image, Image.Image) is False:
81
+ image = Image.fromarray(image)
82
+
83
+ x = transform(image).to(device)
84
+ x = map_pixels(x.unsqueeze(0))
85
+
86
+ z_logits = enc(x)
87
+ z = torch.argmax(z_logits, axis=1)
88
+
89
+ z_cpu = z.squeeze(0).cpu().numpy().tolist()
90
+ return {
91
+ "shape": [32, 32],
92
+ "vocab_size": enc.vocab_size,
93
+ "tokens": z_cpu
94
+ }
95
+
96
+ def decode_code(code):
97
+ """
98
+ Embeds one or more watermarks into the input image.
99
+
100
+ Args:
101
+ code (Union[str, dict]): Encoded image code. It can be a JSON string or a
102
+ Python dictionary with the same structure as the output of ``encode_image``,
103
+ i.e. containing at least a ``"tokens"`` field that holds the quantized
104
+ image tokens.
105
+
106
+ Returns:
107
+ image (Union[PIL.Image.Image, str]): Output image, either a PIL Image object or a URL pointing to the image.
108
+ """
109
+ if isinstance(code, str):
110
+ code = json.loads(code)
111
+
112
+ # code["tokens"] is expected to be a 32x32 grid of token ids
113
+ # match the shape used in process_image: [1, 32, 32] -> one-hot -> [1, vocab, 32, 32]
114
+ z = torch.tensor(code["tokens"], dtype=torch.long).unsqueeze(0).to(device) # [1,32,32]
115
+ z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float() # [1,vocab,32,32]
116
+
117
+ x_stats = dec(z).float()
118
+ x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
119
+ x_rec = T.ToPILImage(mode='RGB')(x_rec[0].cpu())
120
+
121
+ return x_rec
122
+
123
+ def process_image(image):
124
+ if isinstance(image, Image.Image) is False:
125
+ image = Image.fromarray(image)
126
+
127
+ x = transform(image).to(device)
128
+ x = map_pixels(x.unsqueeze(0))
129
+
130
+ z_logits = enc(x)
131
+ z = torch.argmax(z_logits, axis=1)
132
+
133
+
134
+ z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
135
+
136
+ x_stats = dec(z).float()
137
+ x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
138
+ x_rec = T.ToPILImage(mode='RGB')(x_rec[0].to('cpu'))
139
+
140
+ return x_rec
141
+
142
+ DALLE_VERSION = get_app_git_commit() or "unknown"
143
+
144
+ with gr.Blocks(title="DALL- Demo", css="""
145
+ .align-bottom {display:flex; flex-direction:column; justify-content:flex-end;}
146
+ """) as demo:
147
+ gr.Markdown(f"""
148
+ # DALL-E Demo
149
+
150
+ > DALL-E: [`{DALLE_VERSION}`](https://github.com/openai/DALL-E/tree/{DALLE_VERSION})
151
+
152
+ Find the original project [here](https://github.com/openai/DALL-E).
153
+ Or this project [here](https://github.com/xiaoyao9184/docker-dall-e).
154
+ See the [README](./blob/main/README.md) for Spaces's metadata.
155
+ """)
156
+
157
+ with gr.Tabs():
158
+ with gr.TabItem("Encoding -> Decoding"):
159
+ with gr.Row():
160
+ with gr.Column():
161
+ original_img = gr.Image(label="Original Image", type="numpy", height=512)
162
+ process_btn = gr.Button("process")
163
+ with gr.Column():
164
+ reconstructed_image = gr.Image(label="Reconstructed Image")
165
+ with gr.TabItem("Encoding"):
166
+ with gr.Row():
167
+ with gr.Column():
168
+ encoding_img = gr.Image(label="Input Image", type="numpy", height=512)
169
+ encoding_btn = gr.Button("Encoding")
170
+ with gr.Column():
171
+ encoding_messages = gr.JSON(label="Encoding Messages")
172
+ with gr.TabItem("Decoding"):
173
+ with gr.Row():
174
+ with gr.Column():
175
+ encoding_code = gr.Code(label="Encoding Code", language="json")
176
+ decoding_btn = gr.Button("Decoding")
177
+ with gr.Column(elem_classes="align-bottom"):
178
+ decoding_image = gr.Image(label="Decoding Image")
179
+
180
+ encoding_btn.click(
181
+ fn=encode_image,
182
+ inputs=encoding_img,
183
+ outputs=encoding_messages
184
+ )
185
+ decoding_btn.click(
186
+ fn=decode_code,
187
+ inputs=encoding_code,
188
+ outputs=decoding_image
189
+ )
190
+ process_btn.click(
191
+ fn=process_image,
192
+ inputs=[original_img],
193
+ outputs=[reconstructed_image],
194
+ api_name=False
195
+ )
196
+
197
+ if __name__ == '__main__':
198
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
gradio_run.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # NOTE: copy from gradio bin
2
+ import re
3
+ import sys
4
+ from gradio.cli import cli
5
+ if __name__ == '__main__':
6
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
7
+ sys.exit(cli())
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.7.1
2
+ torchvision==0.8.2
3
+ GitPython==3.1.43
4
+ gradio==4.44.1
5
+ huggingface-hub==0.28.1
6
+
7
+ Pillow==9.3.0
8
+ blobfile==3.1.0
9
+ mypy==1.14.1
10
+ numpy==1.24.3
11
+ pytest==8.3.5
12
+ requests==2.32.4