refactor
Browse files- app.py +563 -541
- encode_image.py +29 -0
- llm_service.py +258 -0
- mv_utils_zs.py +483 -0
- string_utils.py +69 -0
app.py
CHANGED
|
@@ -1,352 +1,116 @@
|
|
| 1 |
-
# app.py
|
| 2 |
-
import os
|
| 3 |
-
import subprocess
|
| 4 |
import asyncio
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
import random
|
| 8 |
-
import
|
| 9 |
-
import
|
| 10 |
-
import zipfile
|
| 11 |
-
import xml.etree.ElementTree as ET
|
| 12 |
import tempfile
|
| 13 |
-
from
|
| 14 |
-
from typing import Tuple, Dict, Any, Union, List
|
| 15 |
|
| 16 |
-
import gradio as gr
|
| 17 |
-
import trimesh
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
from loguru import logger
|
|
|
|
| 23 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 24 |
from torch import Tensor
|
| 25 |
-
import torchvision.transforms.functional as TF
|
| 26 |
-
from torch_scatter import scatter
|
| 27 |
|
| 28 |
-
from
|
| 29 |
-
from
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
#
|
| 39 |
GRADIO_3D_MODEL_DEFAULT_FORMAT = [".obj", ".glb", ".gltf", ".stl", ".splat", ".ply"]
|
| 40 |
USER_REQUIRE_FORMAT = [".3dxml", ".step"]
|
| 41 |
FREECAD_LOW_LEVEL_FORMAT = [".step", ".igs", ".iges"]
|
| 42 |
FREECAD_NATIVE_FORMAT = [".fcstd"]
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
)
|
| 50 |
-
)
|
| 51 |
-
VALID_FILE_TYPES = VALID_FILE_TYPES + [t.upper() for t in VALID_FILE_TYPES]
|
| 52 |
-
|
| 53 |
-
# Realistic Projection Parameters (from mv_utils_zs.py)
TRANS = -1.5  # camera translation along z shared by all view poses
params = {
    "maxpoolz": 1,  # MaxPool3d kernel size along depth
    "maxpoolxy": 7,  # MaxPool3d kernel size in x/y
    "maxpoolpadz": 0,  # MaxPool3d padding along depth
    "maxpoolpadxy": 2,  # MaxPool3d padding in x/y
    "convz": 1,  # Conv3d (Gaussian blur) kernel size along depth
    "convxy": 3,  # Conv3d kernel size in x/y
    "convsigmaxy": 3,  # Gaussian sigma in x/y
    "convsigmaz": 1,  # Gaussian sigma along depth
    "convpadz": 0,  # Conv3d padding along depth
    "convpadxy": 1,  # Conv3d padding in x/y
    "imgbias": 0.0,  # presumably an additive image bias — unused in the visible code; TODO confirm
    "depth_bias": 0.2,  # shifts points away from the near plane in points2grid
    "obj_ratio": 0.8,  # shrink factor applied to x/y so the object fits the frame
    "bg_clr": 0.0,  # background fill value for the voxel grid
    "resolution": 122,  # voxel grid / depth image resolution before upsampling
    "depth": 8,  # number of depth slices in the voxel grid
    "grid_height": 64,  # presumably a legacy grid size — unused in the visible code; TODO confirm
    "grid_width": 64,  # presumably a legacy grid size — unused in the visible code; TODO confirm
}
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def get2DGaussianKernel(ksize, sigma=0):
    """Return a normalized 2D Gaussian kernel as a float32 torch tensor.

    Args:
        ksize: Side length of the square kernel.
        sigma: Standard deviation. If <= 0 it is derived from ``ksize``
            using the OpenCV ``getGaussianKernel`` convention — the
            original code divided by ``2 * sigma**2`` and so produced
            NaNs with the default ``sigma=0``.

    Returns:
        torch.Tensor of shape (ksize, ksize) whose entries sum to 1.
    """
    if sigma <= 0:
        # OpenCV convention: infer a sensible sigma from the kernel size.
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    center = ksize // 2
    xs = np.arange(ksize, dtype=np.float32) - center
    kernel1d = np.exp(-(xs**2) / (2 * sigma**2))
    # Outer product of the 1D kernel with itself yields the separable 2D kernel.
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    return kernel / kernel.sum()
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def get3DGaussianKernel(ksize, depth, sigma=2, zsigma=2):
    """Return a normalized 3D Gaussian kernel as a torch tensor.

    Builds a (depth, ksize, ksize) kernel as the product of a 2D spatial
    Gaussian and a 1D depth Gaussian, normalized to sum to 1.

    Args:
        ksize: Spatial side length of the kernel.
        depth: Number of depth slices.
        sigma: Spatial standard deviation.
        zsigma: Depth standard deviation.
    """
    kernel2d = get2DGaussianKernel(ksize, sigma)
    zs = np.arange(depth, dtype=np.float32) - depth // 2
    zkernel = np.exp(-(zs**2) / (2 * zsigma**2))
    # np.repeat silently converts the torch kernel into a numpy array, so the
    # product below is an ndarray; torch.sum() rejects ndarrays on current
    # PyTorch, so convert back to a tensor before normalizing.
    kernel3d = np.repeat(kernel2d[None, :, :], depth, axis=0) * zkernel[:, None, None]
    kernel3d = torch.from_numpy(kernel3d)
    return kernel3d / kernel3d.sum()
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def euler2mat(angle):
    """Convert Euler angles (x, y, z) into a rotation matrix.

    Accepts either a single angle vector of shape (3,) or a batch of shape
    (B, 3) and returns a (3, 3) or (B, 3, 3) matrix computed as Rx @ Ry @ Rz.
    """
    if len(angle.size()) == 1:
        ax, ay, az = angle[0], angle[1], angle[2]
        stack_dim, out_shape = 0, [3, 3]
    else:
        batch, _ = angle.size()
        ax, ay, az = angle[:, 0], angle[:, 1], angle[:, 2]
        stack_dim, out_shape = 1, [batch, 3, 3]

    # Constant 0/1 entries shaped like the angles, detached from autograd.
    zero = az.detach() * 0
    one = zero + 1

    def _rot(entries):
        # Stack nine scalar (or batched) entries into a 3x3 matrix.
        return torch.stack(entries, dim=stack_dim).reshape(out_shape)

    cz, sz = torch.cos(az), torch.sin(az)
    rot_z = _rot([cz, -sz, zero, sz, cz, zero, zero, zero, one])

    cy, sy = torch.cos(ay), torch.sin(ay)
    rot_y = _rot([cy, zero, sy, zero, one, zero, -sy, zero, cy])

    cx, sx = torch.cos(ax), torch.sin(ax)
    rot_x = _rot([one, zero, zero, zero, cx, -sx, zero, sx, cx])

    return rot_x @ rot_y @ rot_z
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def points2grid(points, resolution=params["resolution"], depth=params["depth"]):
    """Voxelize a batch of point clouds into a dense depth grid.

    Args:
        points: Tensor of shape (batch, num_points, 3).
        resolution: Side length of the x/y grid.
        depth: Number of depth slices.

    Returns:
        Tensor of shape (batch, depth, resolution, resolution) where each
        occupied cell holds the clipped depth value written by scatter-max
        and empty cells hold params["bg_clr"].
    """
    batch, pnum, _ = points.shape
    # Center the cloud and scale its largest extent to [-1, 1].
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][:, None, None]
    points = (points - pcent) / prange * 2.0
    # Shrink x/y so the object keeps a margin inside the frame.
    points[:, :, :2] = points[:, :, :2] * params["obj_ratio"]
    # Map normalized coordinates into grid-index space.
    _x = (points[:, :, 0] + 1) / 2 * resolution
    _y = (points[:, :, 1] + 1) / 2 * resolution
    _z = (
        ((points[:, :, 2] + 1) / 2 + params["depth_bias"])
        / (1 + params["depth_bias"])
        * (depth - 2)
    )
    # In-place ceil for x/y; z keeps a separate integer copy for the linear
    # index while the (clipped) depth value itself is written into the grid.
    _x.ceil_(), _y.ceil_()
    z_int = _z.ceil()
    # Keep a one-cell border so indices never fall on the grid edge.
    _x, _y, _z = (
        torch.clip(_x, 1, resolution - 2),
        torch.clip(_y, 1, resolution - 2),
        torch.clip(_z, 1, depth - 2),
    )
    # Flatten (z, y, x) into one linear coordinate per point.
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    grid = (
        torch.ones([batch, depth, resolution, resolution], device=points.device).view(
            batch, -1
        )
        * params["bg_clr"]
    )
    # scatter-max keeps the largest depth value per cell (torch_scatter).
    grid = scatter(_z, coordinates.long(), dim=1, out=grid, reduce="max")
    # Restore the 4D layout and swap the x/y axes to image orientation.
    grid = grid.reshape((batch, depth, resolution, resolution)).permute((0, 1, 3, 2))
    return grid
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
class Grid2Image(torch.nn.Module):
    """Render a voxel depth grid into a 3-channel depth image.

    Applies 3D max-pooling followed by a fixed (non-learned) 3D Gaussian
    blur, collapses the depth axis with a max, then normalizes and inverts
    the result so the background is bright.
    """

    def __init__(self):
        super().__init__()
        # Dilate the sparse voxel grid so isolated points become visible.
        self.maxpool = torch.nn.MaxPool3d(
            (params["maxpoolz"], params["maxpoolxy"], params["maxpoolxy"]),
            stride=1,
            padding=(
                params["maxpoolpadz"],
                params["maxpoolpadxy"],
                params["maxpoolpadxy"],
            ),
        )
        # Single-channel 3D convolution used purely as a Gaussian smoother.
        self.conv = torch.nn.Conv3d(
            1,
            1,
            kernel_size=(params["convz"], params["convxy"], params["convxy"]),
            stride=1,
            padding=(params["convpadz"], params["convpadxy"], params["convpadxy"]),
            bias=True,
        )
        kn3d = get3DGaussianKernel(
            params["convxy"],
            params["convz"],
            sigma=params["convsigmaxy"],
            zsigma=params["convsigmaz"],
        )
        # Install the fixed Gaussian kernel and zero the bias; the layer is
        # never trained, it only blurs.
        self.conv.weight.data = torch.Tensor(kn3d).repeat(1, 1, 1, 1, 1)
        self.conv.bias.data.fill_(0)

    def forward(self, x):
        # x: (batch, depth, H, W) voxel grid; add a channel axis for 3D ops.
        x = self.maxpool(x.unsqueeze(1))
        x = self.conv(x)
        # Collapse the depth axis, keeping the strongest response per pixel.
        img = torch.max(x, dim=2)[0]
        # Normalize each image by its own maximum, then invert the contrast.
        img = img / torch.max(torch.max(img, dim=-1)[0], dim=-1)[0][:, :, None, None]
        img = 1 - img
        # Replicate the single channel to 3 for downstream RGB consumers.
        img = img.repeat(1, 3, 1, 1)
        return img
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
class Realistic_Projection:
    """Project a point cloud into depth images from 10 fixed camera poses.

    Precomputes per-view rotation matrices (plus a small per-view tilt bias)
    and a shared translation, then rasterizes the transformed points with
    points2grid and renders them through Grid2Image.
    """

    def __init__(self):
        # Each entry: [euler angles (x, y, z)], [translation (x, y, z)].
        _views = np.asarray([
            [[np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[np.pi, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, -np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
        ])
        # Small secondary tilt applied after the main rotation of each view.
        _views_bias = np.asarray([
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
        ])
        angle, angle2 = (
            torch.tensor(_views[:, 0, :]).float(),
            torch.tensor(_views_bias[:, 0, :]).float(),
        )
        # Transposed so point_transform can right-multiply row-vector points.
        self.rot_mat, self.rot_mat2 = (
            euler2mat(angle).transpose(1, 2),
            euler2mat(angle2).transpose(1, 2),
        )
        self.translation = torch.tensor(_views[:, 1, :]).float().unsqueeze(1)
        self.grid2image = Grid2Image()

    def get_img(self, points):
        """Render depth images for *points* (batch, n, 3), one per view."""
        b, _, _ = points.shape
        v = self.translation.shape[0]
        # Duplicate each cloud once per view, then apply the matching pose.
        _points = self.point_transform(
            torch.repeat_interleave(points, v, dim=0),
            self.rot_mat.repeat(b, 1, 1),
            self.rot_mat2.repeat(b, 1, 1),
            self.translation.repeat(b, 1, 1),
        )
        grid = points2grid(
            _points, resolution=params["resolution"], depth=params["depth"]
        ).squeeze()
        return self.grid2image(grid)

    @staticmethod
    def point_transform(points, rot_mat, rot_mat2, translation):
        """Apply two rotations and then a translation to (batch, n, 3) points."""
        # Move the pose tensors onto the same device as the points.
        rot_mat, rot_mat2, translation = (
            rot_mat.to(points.device),
            rot_mat2.to(points.device),
            translation.to(points.device),
        )
        points = torch.matmul(points, rot_mat)
        points = torch.matmul(points, rot_mat2)
        return points - translation
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
# OpenAI Service
|
| 257 |
-
class OpenAIService:
    """Thin async wrapper around the OpenAI chat API for image+text prompts."""

    def __init__(self):
        # Fixed model/temperature; the async client reads the module-level key.
        self.model_name = "gpt-4o"
        self.temperature = 0.3
        self.client = AsyncOpenAI(api_key=OPENAI_API_KEY)

    @staticmethod
    def encode_image(image: Union[str, np.ndarray]) -> str:
        """Return the image as a base64 string.

        Accepts a file path (read as raw bytes) or a NumPy array (encoded
        as JPEG via OpenCV); raises TypeError for anything else.
        """
        if isinstance(image, str):
            with open(image, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        elif isinstance(image, np.ndarray):
            _, buffer = cv2.imencode(".jpg", image)
            return base64.b64encode(buffer).decode("utf-8")
        raise TypeError("Input must be a file path or a NumPy array.")

    async def chat_with_image(
        self, prompt: str, image: str, retry_left: int = 3
    ) -> str:
        """Send *prompt* plus *image* to the chat model and return the reply text.

        Retries up to ``retry_left`` times (1s apart) on any API error and
        returns an empty string once retries are exhausted.
        """
        base64_image = self.encode_image(image=image)
        model_kwargs = {
            "model": self.model_name,
            "temperature": self.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        },
                    ],
                }
            ],
        }
        try:
            response = await self.client.chat.completions.create(**model_kwargs)
            return response.choices[0].message.content
        except Exception as e:
            if retry_left > 0:
                logger.warning(f"OpenAI API failed: {e}. Retrying.")
                await asyncio.sleep(1)
                # Recursive retry with a decremented budget.
                return await self.chat_with_image(prompt, image, retry_left - 1)
            logger.error(f"OpenAI API failed: {e}. Returning empty string.")
            return ""
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
# ==============================================================================
|
| 308 |
-
# PHẦN 2: LOGIC XỬ LÝ 3D (TỪ NOTEBOOK)
|
| 309 |
-
# ==============================================================================
|
| 310 |
-
|
| 311 |
-
# Khởi tạo các model và service
|
| 312 |
-
llm_service = OpenAIService()
|
| 313 |
-
pc_views = Realistic_Projection()
|
| 314 |
-
clip_embedding_model = ClipEmbedding(embed_batch_size=1536)
|
| 315 |
-
text_embedding_model = OpenAIEmbedding(
|
| 316 |
-
mode=OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
|
| 317 |
-
model="text-embedding-3-small",
|
| 318 |
-
api_key=OPENAI_API_KEY,
|
| 319 |
-
dimensions=1536,
|
| 320 |
-
)
|
| 321 |
|
| 322 |
|
| 323 |
-
# Chuyển đổi file STEP/FCStd sang OBJ
|
| 324 |
def convert_step_to_obj_with_freecad(step_path, obj_path):
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
_, ext = os.path.splitext(step_path)
|
| 327 |
ext = ext.lower()
|
| 328 |
-
script_template = ""
|
| 329 |
if ext in FREECAD_LOW_LEVEL_FORMAT:
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
elif ext in FREECAD_NATIVE_FORMAT:
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
else:
|
| 334 |
-
|
|
|
|
| 335 |
|
| 336 |
-
|
| 337 |
command = [freecad_executable, "-c", python_script]
|
|
|
|
|
|
|
| 338 |
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
|
|
|
|
|
| 339 |
stdout, stderr = process.communicate()
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
f"FreeCAD conversion failed for {step_path}. Stderr: {stderr.decode()}"
|
| 343 |
-
)
|
| 344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
|
|
|
|
|
|
| 346 |
def convert_to_obj(file: str) -> str:
|
| 347 |
if file is None:
|
| 348 |
return None
|
| 349 |
logger.info(f"Converting {file} to .obj")
|
|
|
|
| 350 |
prefix_path, ext = os.path.splitext(file)
|
| 351 |
ext = ext.lower()
|
| 352 |
if ext in FREECAD_LOW_LEVEL_FORMAT + FREECAD_NATIVE_FORMAT:
|
|
@@ -355,61 +119,364 @@ def convert_to_obj(file: str) -> str:
|
|
| 355 |
convert_step_to_obj_with_freecad(file, response_path)
|
| 356 |
return response_path
|
| 357 |
elif ext in GRADIO_3D_MODEL_DEFAULT_FORMAT:
|
| 358 |
-
return
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
|
| 361 |
|
| 362 |
-
# Render ảnh chiều sâu
|
| 363 |
def render_depth_images_from_obj(obj_path: str, imsize: int = 512) -> List[np.ndarray]:
|
| 364 |
mesh = trimesh.load_mesh(obj_path)
|
| 365 |
-
points: Tensor = torch.tensor(mesh.vertices).float()
|
|
|
|
|
|
|
| 366 |
images: Tensor = pc_views.get_img(points)
|
| 367 |
images = torch.nn.functional.interpolate(
|
| 368 |
images, size=(imsize, imsize), mode="bilinear", align_corners=True
|
| 369 |
)
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
|
| 373 |
def aggregate_images(
|
| 374 |
np_images: list[np.ndarray], n_rows: int = 2, n_cols: int = 5
|
| 375 |
) -> np.ndarray:
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
(
|
|
|
|
| 379 |
)
|
|
|
|
| 380 |
for i, img in enumerate(np_images):
|
| 381 |
-
row
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
| 384 |
|
|
|
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
|
| 390 |
async def generate_description_from_aggregated_depth_map(np_image: np.ndarray) -> str:
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
|
| 396 |
async def aget_image_embedding_from_np_image(np_image: np.ndarray):
|
|
|
|
| 397 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
return image_embedding
|
| 404 |
|
| 405 |
|
| 406 |
-
# Embedding 3D Object
|
| 407 |
async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
|
|
|
|
| 408 |
depth_images = render_depth_images_from_obj(obj_path=obj_path)
|
|
|
|
| 409 |
aggregated_image = aggregate_images(depth_images)
|
|
|
|
| 410 |
description = await generate_description_from_aggregated_depth_map(
|
| 411 |
np_image=aggregated_image
|
| 412 |
)
|
|
|
|
| 413 |
image_embedding = await aget_image_embedding_from_np_image(
|
| 414 |
np_image=aggregated_image
|
| 415 |
)
|
|
@@ -421,250 +488,205 @@ async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
|
|
| 421 |
}
|
| 422 |
|
| 423 |
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
r"FILE_DESCRIPTION\s*\(\s*\((.*?)\),\s*'(.*?)'\);", content, re.DOTALL
|
| 432 |
-
)
|
| 433 |
-
if desc_match:
|
| 434 |
-
metadata["Description"] = desc_match.group(1).replace("'", "")
|
| 435 |
-
name_match = re.search(
|
| 436 |
-
r"FILE_NAME\s*\(\s*'(.*?)',.*?,'(.*?)'", content, re.DOTALL
|
| 437 |
-
)
|
| 438 |
-
if name_match:
|
| 439 |
-
metadata["FileName"], metadata["OriginatingSystem"] = (
|
| 440 |
-
name_match.group(1),
|
| 441 |
-
name_match.group(2),
|
| 442 |
-
)
|
| 443 |
-
except Exception as e:
|
| 444 |
-
logger.error(f"Failed to read STEP file: {e}")
|
| 445 |
-
return metadata
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
def dict_to_markdown(metadata: dict) -> str:
    """Render a metadata mapping as "key: value" entries.

    NOTE(review): entries are joined with the two-character literal
    sequence backslash-n (``"\\n"``), not a real newline — preserved
    as-is; confirm whether an actual newline was intended.
    """
    entries = [f"{name}: {val}" for name, val in metadata.items()]
    return "\\n".join(entries)
|
| 450 |
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
if original_filepath.lower().endswith((".step")):
|
| 456 |
-
meta = extract_step_metadata(original_filepath)
|
| 457 |
-
return dict_to_markdown(meta) if meta else "No metadata found in STEP file."
|
| 458 |
-
logger.warning(f"No metadata parser for file {original_filepath}")
|
| 459 |
-
return "No metadata found."
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
# ==============================================================================
|
| 463 |
-
# PHẦN 3: LOGIC CỦA GRADIO APP
|
| 464 |
-
# ==============================================================================
|
| 465 |
|
| 466 |
|
| 467 |
async def accumulate_and_embedding(input_files, file_list, embedding_dict):
    """Gradio upload handler: embed newly uploaded 3D files.

    Converts each not-yet-seen upload to .obj, computes its embeddings, and
    merges them into ``embedding_dict`` (mutated in place). Files that fail
    are skipped with a logged error and a UI warning.

    Returns:
        (input_files, dropdown update selecting the last upload, embedding_dict)
    """
    if not isinstance(input_files, list):
        input_files = [input_files]
    # Only process files that are not already tracked in file_list.
    new_files = [
        f.name for f in input_files if f.name not in [fi.name for fi in file_list]
    ]

    for file_path in new_files:
        logger.info(f"Processing new upload file: {file_path}")
        try:
            obj_path = convert_to_obj(file_path)
            embeddings = await embedding_3d_object(obj_path)
            if obj_path not in embedding_dict:
                embedding_dict[obj_path] = {}
            embedding_dict[obj_path].update(embeddings)
        except Exception as e:
            # Best-effort: keep processing the remaining uploads.
            logger.error(f"Failed to process {file_path}: {e}")
            gr.Warning(f"Could not process file: {os.path.basename(file_path)}")

    all_file_paths = [f.name for f in input_files]
    return (
        input_files,
        gr.update(
            choices=all_file_paths, value=all_file_paths[-1] if all_file_paths else None
        ),
        embedding_dict,
    )
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
def render_3D_object(filepath) -> Tuple[str, str]:
    """Convert *filepath* to a renderable .obj and return (obj_path, original_path).

    Returns (None, None) when no file is given or the conversion fails;
    failures are logged and surfaced to the UI as a Gradio warning.
    """
    if not filepath:
        return None, None
    try:
        renderable = convert_to_obj(filepath)
    except Exception as e:
        logger.error(f"Failed to render {filepath}: {e}")
        gr.Warning(f"Could not render file: {os.path.basename(filepath)}")
        return None, None
    return renderable, filepath
|
| 506 |
|
|
|
|
|
|
|
| 507 |
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
"description"
|
| 516 |
-
|
| 517 |
-
|
| 518 |
|
|
|
|
| 519 |
|
| 520 |
-
def find_top_k_similar(query_embedding, embedding_dict, key, top_k=4):
|
| 521 |
-
valid_items = [
|
| 522 |
-
(path, data[key]) for path, data in embedding_dict.items() if key in data
|
| 523 |
-
]
|
| 524 |
-
if not valid_items:
|
| 525 |
-
gr.Warning("No embeddings available for search.")
|
| 526 |
-
return [None] * top_k + ["-"] * top_k
|
| 527 |
-
|
| 528 |
-
filepaths = [item[0] for item in valid_items]
|
| 529 |
-
feature_matrix = np.array([item[1] for item in valid_items])
|
| 530 |
-
similarities = cosine_similarity(query_embedding.reshape(1, -1), feature_matrix)[0]
|
| 531 |
-
scores = sorted(
|
| 532 |
-
list(zip(filepaths, similarities)), key=lambda x: x[1], reverse=True
|
| 533 |
-
)
|
| 534 |
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
return
|
| 542 |
|
| 543 |
|
| 544 |
-
def
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
|
|
|
|
|
|
| 549 |
):
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
query_embedding = np.array(embedding_dict[filepath]["image_embedding"])
|
| 554 |
-
# Exclude the query file itself from the search
|
| 555 |
-
search_dict = {k: v for k, v in embedding_dict.items() if k != filepath}
|
| 556 |
-
return find_top_k_similar(query_embedding, search_dict, "image_embedding", top_k)
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
def query_3D_object(query: str, embedding_dict: dict, top_k: int = 4):
    """Search the uploaded 3D objects with a free-text query.

    Embeds *query* with the text embedding model and ranks the stored
    "text_embedding" vectors by similarity. On empty input or an empty
    store, shows a Gradio warning and returns ``top_k`` None placeholders
    followed by ``top_k`` "-" score placeholders.
    """
    empty_result = [None] * top_k + ["-"] * top_k
    if not query.strip():
        gr.Warning("Query cannot be empty!")
        return empty_result
    if len(embedding_dict) < 1:
        gr.Warning("Please upload and process at least one 3D file.")
        return empty_result

    vector = np.array(text_embedding_model.get_text_embedding(text=query))
    return find_top_k_similar(vector, embedding_dict, "text_embedding", top_k)
|
| 569 |
|
| 570 |
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
|
|
|
| 579 |
)
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
embedding_store = gr.State({})
|
| 584 |
-
|
| 585 |
-
with gr.Row():
|
| 586 |
-
with gr.Column(scale=1):
|
| 587 |
-
file_input = gr.File(
|
| 588 |
-
file_count="multiple",
|
| 589 |
-
label="1. Tải lên File 3D",
|
| 590 |
-
file_types=VALID_FILE_TYPES,
|
| 591 |
-
)
|
| 592 |
-
file_dropdown = gr.Dropdown(
|
| 593 |
-
label="2. Chọn File để xem và tìm kiếm", interactive=True
|
| 594 |
-
)
|
| 595 |
-
sim_button = gr.Button("🔍 Tìm kiếm Tương tự", variant="primary")
|
| 596 |
-
query_input = gr.Textbox(
|
| 597 |
-
label="Hoặc, truy vấn bằng văn bản",
|
| 598 |
-
placeholder="ví dụ: một bộ phận có hai lỗ xuyên...",
|
| 599 |
-
)
|
| 600 |
-
query_button = gr.Button("💬 Tìm kiếm theo Văn bản", variant="primary")
|
| 601 |
-
|
| 602 |
-
with gr.Column(scale=2):
|
| 603 |
-
gr.Markdown("### **Trình xem và Thông tin Chi tiết**")
|
| 604 |
-
model_render = gr.Model3D(label="Mô hình 3D", height=400, interactive=False)
|
| 605 |
-
model_hidden_filepath = gr.Textbox(visible=False)
|
| 606 |
-
original_hidden_filepath = gr.Textbox(visible=False)
|
| 607 |
-
with gr.Accordion("📝 Mô tả & Metadata", open=False):
|
| 608 |
-
description_render = gr.Textbox(label="Mô tả (tạo bởi AI)", lines=8)
|
| 609 |
-
metadata_render = gr.Textbox(
|
| 610 |
-
label="Metadata (trích xuất từ file)", lines=4
|
| 611 |
-
)
|
| 612 |
-
|
| 613 |
with gr.Row():
|
| 614 |
-
gr.
|
| 615 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
|
| 617 |
with gr.Row():
|
| 618 |
with gr.Column():
|
| 619 |
-
gr.
|
|
|
|
|
|
|
| 620 |
with gr.Row():
|
| 621 |
-
|
| 622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
with gr.Row():
|
| 624 |
-
|
| 625 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
with gr.Column():
|
| 627 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
with gr.Row():
|
| 629 |
-
|
| 630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
with gr.Row():
|
| 632 |
-
|
| 633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
|
| 635 |
-
|
| 636 |
-
file_input.upload(
|
| 637 |
fn=accumulate_and_embedding,
|
| 638 |
inputs=[file_input, file_state, embedding_store],
|
| 639 |
outputs=[file_state, file_dropdown, embedding_store],
|
| 640 |
)
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
)
|
| 651 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
sim_button.click(
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
model_s_1,
|
| 657 |
model_s_2,
|
| 658 |
model_s_3,
|
| 659 |
model_s_4,
|
| 660 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
)
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 667 |
)
|
| 668 |
|
| 669 |
if __name__ == "__main__":
|
| 670 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import platform
|
| 4 |
import random
|
| 5 |
+
import subprocess # used to connect to FreeCAD via terminal sub process
|
| 6 |
+
import sys
|
|
|
|
|
|
|
| 7 |
import tempfile
|
| 8 |
+
from typing import Any, Dict, List, Tuple
|
|
|
|
| 9 |
|
| 10 |
+
import gradio as gr # demo with gradio
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
+
import torchvision.transforms.functional as TF
|
| 14 |
+
import trimesh
|
| 15 |
+
from llama_index.embeddings.clip import ClipEmbedding
|
| 16 |
+
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode
|
| 17 |
from loguru import logger
|
| 18 |
+
from PIL import Image
|
| 19 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 20 |
from torch import Tensor
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
from llm_service import LLMService
|
| 23 |
+
from mv_utils_zs import Realistic_Projection
|
| 24 |
|
| 25 |
+
# Gradio cache directory. The original call discarded the lookup result
# (a no-op); keep the value so it can actually be used.
GRADIO_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR", "gradio_cache")  # You must set it in `.env` file also
os_name = platform.system()

# FreeCAD headless conversion is only wired up for Linux and macOS.
if os_name == "Linux":
    print("Running on Linux")
elif os_name == "Darwin":
    print("Running on macOS")
else:
    print(f"Running on an unsupported OS: {os_name}")
|
| 34 |
|
| 35 |
+
# The Gradio 3D Model component default accept
|
| 36 |
GRADIO_3D_MODEL_DEFAULT_FORMAT = [".obj", ".glb", ".gltf", ".stl", ".splat", ".ply"]
|
| 37 |
USER_REQUIRE_FORMAT = [".3dxml", ".step"]
|
| 38 |
FREECAD_LOW_LEVEL_FORMAT = [".step", ".igs", ".iges"]
|
| 39 |
FREECAD_NATIVE_FORMAT = [".fcstd"]
|
| 40 |
+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
| 41 |
+
|
| 42 |
+
####################################################################################################################
|
| 43 |
+
# Transform high-level to low-level
|
| 44 |
+
####################################################################################################################
|
| 45 |
+
# 3D Component of Gradio only allow some kind of format to render in the UI. We need to transform if need it.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
|
|
|
| 48 |
def convert_step_to_obj_with_freecad(step_path, obj_path):
    """Convert a STEP/IGES or FCStd file to OBJ by driving FreeCAD headlessly.

    Builds a small FreeCAD Python script for the input format and runs it
    through the FreeCAD command-line executable in a subprocess.

    Args:
        step_path: Path to the input CAD file (.step/.igs/.iges or .fcstd).
        obj_path: Path where the exported .obj mesh should be written.

    Returns:
        Tuple of (stdout, stderr) text captured from the FreeCAD process.

    Raises:
        Exception: If the input extension or the host OS is unsupported.
    """
    # Path to the FreeCAD executable (per host OS)
    global os_name
    if os_name == "Linux":
        freecad_executable = "/usr/bin/freecadcmd"  # freecadcmd
    elif os_name == "Darwin":
        freecad_executable = "/Applications/FreeCAD.app/Contents/MacOS/FreeCAD"
    else:
        # Previously this fell through and crashed later with a NameError on
        # ``freecad_executable``; fail fast with a clear message instead.
        logger.error(f"FreeCAD conversion not supported on OS: {os_name}")
        raise Exception(f"FreeCAD conversion not supported on OS: {os_name}")
    # Python script to be executed by FreeCAD
    _, ext = os.path.splitext(step_path)
    ext = ext.lower()
    # NOTE(review): paths are interpolated unquoted/unescaped into the
    # script; a path containing '"' would break it — confirm inputs.
    if ext in FREECAD_LOW_LEVEL_FORMAT:
        python_script = """
import FreeCAD
import Part
import Mesh

doc = FreeCAD.newDocument()
shape = Part.read("{step_path}")
obj = doc.addObject("Part::Feature", "MyPart")
obj.Shape = shape
doc.recompute()

Mesh.export([obj], "{obj_path}")
""".format(step_path=step_path, obj_path=obj_path)
    elif ext in FREECAD_NATIVE_FORMAT:
        python_script = """
import FreeCAD
import Part
import Mesh

doc = FreeCAD.open("{step_path}")
to_export = [o for o in doc.Objects if hasattr(o, 'Shape')]
Mesh.export(to_export, "{obj_path}")
""".format(step_path=step_path, obj_path=obj_path)
    else:
        logger.error(f"Not support {ext} format")
        raise Exception(f"Not support {ext} format")

    # Command to run FreeCAD in headless mode with the provided Python script
    command = [freecad_executable, "-c", python_script]

    # Run the command using subprocess
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Capture the output and errors
    stdout, stderr = process.communicate()
    return stdout.decode(), stderr.decode()
|
| 95 |
+
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
# input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/Switches/TS6-THT_H-5.0.step" # ok
|
| 98 |
+
# input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/engrenagens-5.snapshot.6/Engre_con_Z16_mod_1_5-Body.stl" # ok
|
| 99 |
+
# input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/nema-17-stepper-motors-coaxial-60-48-39-23mm-1.snapshot.3/NEMA 17 Stepper Motor 23mm-NEMA 17 Stepper Motor 23mm.step" # ok
|
| 100 |
+
# input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/engrenagens-5.snapshot.6/Engre_con_Z16_mod_1_5.FCStd" # ok
|
| 101 |
+
# input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/engrenagens-5.snapshot.6/Engre_reta_Z_15_mod_1.FCStd" # ok
|
| 102 |
+
# input_path = "/content/TS6-THT_H-5.0.step"
|
| 103 |
+
# print(".".join(input_path.split(".")[:-1]) + ".obj")
|
| 104 |
+
# stdout, stderr = convert_step_to_obj_with_freecad(input_path, ".".join(input_path.split(".")[:-1]) + ".obj")
|
| 105 |
+
# stderr
|
| 106 |
|
| 107 |
+
|
| 108 |
+
# Dummy converter from STEP/3DXML to OBJ (replace with real converter)
|
| 109 |
def convert_to_obj(file: str) -> str:
    """Convert a CAD file to .obj so gr.Model3D can render it.

    Formats FreeCAD can read (.step/.igs/.iges/.fcstd) are converted to a
    sibling ``<name>.obj`` via the headless FreeCAD helper; formats Gradio
    renders natively are returned unchanged.

    Args:
        file: Path to the uploaded CAD file, or None.

    Returns:
        Path to a renderable file (the converted .obj or the input itself),
        or None when ``file`` is None.

    Raises:
        Exception: For extensions neither FreeCAD nor Gradio supports.
    """
    if file is None:
        return None
    logger.info(f"Converting {file} to .obj")
    prefix_path, ext = os.path.splitext(file)
    ext = ext.lower()
    if ext in FREECAD_LOW_LEVEL_FORMAT + FREECAD_NATIVE_FORMAT:
        # Write the converted mesh next to the source file.
        response_path = prefix_path + ".obj"
        convert_step_to_obj_with_freecad(file, response_path)
        return response_path
    if ext in GRADIO_3D_MODEL_DEFAULT_FORMAT:
        # Gradio can already display these formats — no conversion needed.
        return file
    logger.warning(f"Do nothing at convert_to_obj with file {file}")
    raise Exception(f"Do nothing at convert_to_obj with file {file}")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
####################################################################################################################
|
| 129 |
+
# Feature Extraction
|
| 130 |
+
####################################################################################################################
|
| 131 |
+
# We have 2 approaches to extract 3D's features:
|
| 132 |
+
# - By algorithm which extract something like volume, surface
|
| 133 |
+
# - By 3D deep learning model, which embed the 3D object into vector representing 3D's features
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def extract_geometric_features(obj_path: str):  # deprecated
    """Compute simple geometric features (volume, surface area) of a mesh.

    Deprecated: superseded by the depth-map embedding pipeline, kept for
    reference.

    Args:
        obj_path: Path to a mesh file loadable by trimesh.

    Returns:
        np.ndarray of shape (1, 2) holding [volume, surface_area], or None
        when the mesh cannot be loaded.
    """
    try:
        mesh = trimesh.load(obj_path)
        volume = mesh.volume  # type: ignore
        surface_area = mesh.area  # type: ignore
        # Use the project-wide loguru logger instead of bare print().
        logger.debug(f"volume={volume} surface_area={surface_area}")
        return np.array([volume, surface_area]).reshape(1, -1)
    except Exception as e:
        logger.error(f"Error reading file {obj_path}: {e}")
        return None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
####################################################################################################################
|
| 152 |
+
# Similarity Search
|
| 153 |
+
####################################################################################################################
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def search_3D_similarity(filepath: str, embedding_dict: dict, top_k: int = 4):
    """Find the ``top_k`` stored objects most visually similar to ``filepath``.

    Similarity is cosine similarity between the CLIP image embeddings of the
    aggregated depth-map renders stored in ``embedding_dict``.

    Args:
        filepath: Key (converted .obj path) of the query object.
        embedding_dict: In-memory store: path -> {"image_embedding": ...}.
        top_k: Number of neighbours to return.

    Returns:
        Flat list of ``top_k`` file paths followed by their ``top_k``
        basenames (one value per Model3D component + caption button).

    Raises:
        gr.Error: If fewer than 5 objects have been embedded.
        ValueError: If ``filepath`` has no stored image embedding.
    """
    if len(embedding_dict) < 5:
        raise gr.Error("Require at least 5 3D files to search similarity")
    if (
        filepath not in embedding_dict
        or "image_embedding" not in embedding_dict[filepath]
    ):
        raise ValueError(f"No embedding found for {filepath}")

    features1 = np.array(embedding_dict[filepath]["image_embedding"]).reshape(1, -1)

    # Every other stored object that has an image embedding.
    valid_items = [
        (fp, data["image_embedding"])
        for fp, data in embedding_dict.items()
        if "image_embedding" in data and fp != filepath
    ]
    filepaths = [fp for fp, _ in valid_items]
    feature_matrix = np.array([feat for _, feat in valid_items])  # (N, D)
    similarities = cosine_similarity(features1, feature_matrix)[0]  # (N,)
    scores = sorted(zip(filepaths, similarities), key=lambda x: x[1], reverse=True)

    # Pad so the UI always receives exactly top_k entries. The previous
    # code appended a single ("", 0.0) placeholder, which still under-filled
    # the outputs whenever more than one slot was missing.
    while len(scores) < top_k:
        scores.append(("", 0.0))

    top = scores[:top_k]
    return [fp for fp, _ in top] + [os.path.basename(fp) for fp, _ in top]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
####################################################################################################################
|
| 191 |
+
# Text-based Query
|
| 192 |
+
####################################################################################################################
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def query_3D_object(query: str, embedding_dict: dict, top_k: int = 4):
    """Rank stored 3D objects against a natural-language query.

    Embeds ``query`` with the OpenAI text model and compares it against the
    stored description embeddings by cosine similarity.

    Args:
        query: Free-text question about the 3D objects.
        embedding_dict: In-memory store: path -> {"text_embedding": ...}.
        top_k: Number of results to return.

    Returns:
        Flat list of ``top_k`` file paths followed by their ``top_k``
        basenames (one value per Model3D component + caption button).

    Raises:
        gr.Error: On an empty query or fewer than 4 embedded files.
    """
    if query == "":
        raise gr.Error("Query cannot be empty!")
    if len(embedding_dict) < 4:
        raise gr.Error("Require at least 4 3D files to query by features")

    features1 = np.array(
        text_embedding_model.get_text_embedding(text=query)
    ).reshape(1, -1)

    # Every stored object that has a description embedding.
    valid_items = [
        (fp, data["text_embedding"])
        for fp, data in embedding_dict.items()
        if "text_embedding" in data
    ]
    filepaths = [fp for fp, _ in valid_items]
    feature_matrix = np.array([feat for _, feat in valid_items])  # (N, D)
    similarities = cosine_similarity(features1, feature_matrix)[0]  # (N,)
    scores = sorted(zip(filepaths, similarities), key=lambda x: x[1], reverse=True)

    # Pad so the UI always receives exactly top_k entries. The previous
    # code appended only one ("", 0.0) placeholder regardless of how many
    # slots were missing.
    while len(scores) < top_k:
        scores.append(("", 0.0))

    top = scores[:top_k]
    return [fp for fp, _ in top] + [os.path.basename(fp) for fp, _ in top]
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
####################################################################################################################
|
| 229 |
+
# Metadata Extraction
|
| 230 |
+
####################################################################################################################
|
| 231 |
+
|
| 232 |
+
import os
|
| 233 |
+
import xml.etree.ElementTree as ET
|
| 234 |
+
import zipfile
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def extract_header_from_3dxml(file_path):
    """Extract the <Header> metadata block from a .3DXML archive.

    A .3DXML file is a ZIP container. This unpacks it into a throwaway
    temporary directory, scans every contained .xml/.3dxml document for a
    namespaced <Header> element, and collects its children as a dict.

    Args:
        file_path: Path to the .3dxml archive.

    Returns:
        dict: Header tag name (namespace stripped) -> text value. Empty when
        no header is found or all documents fail to parse.
    """
    header_info = {}

    # Use a per-call temporary directory: the previous fixed
    # "tmp_3dxml_extract" folder leaked onto disk and mixed the contents of
    # different archives across calls.
    with tempfile.TemporaryDirectory() as extract_dir:
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        for root, _dirs, files in os.walk(extract_dir):
            for file in files:
                if not file.endswith((".3dxml", ".xml")):
                    continue
                xml_path = os.path.join(root, file)
                try:
                    tree = ET.parse(xml_path)
                    root_el = tree.getroot()
                    # ElementTree embeds the namespace in the tag as
                    # "{uri}Tag"; recover the uri for the namespaced find.
                    ns = {"ns": root_el.tag.split("}")[0].strip("{")}

                    header = root_el.find("ns:Header", ns)
                    if header is not None:
                        for child in header:
                            tag = child.tag.split("}")[-1]  # drop namespace
                            value = child.text.strip() if child.text else ""
                            header_info[tag] = value
                except Exception as e:
                    print(f"Failed to parse {file}: {e}")

    return header_info
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
#######################################################################################################################
|
| 269 |
+
|
| 270 |
+
import re
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def extract_step_metadata(file_path):
    """Parse the HEADER section of a STEP (ISO 10303-21) file.

    Pulls the FILE_DESCRIPTION, FILE_NAME and FILE_SCHEMA records out of the
    raw text with regular expressions and flattens them into a dict. Keys
    absent from the file are simply left out of the result.
    """
    metadata = {}

    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as fh:
            text = fh.read()

        # FILE_DESCRIPTION(('...'), '...');
        desc = re.search(
            r"FILE_DESCRIPTION\s*\(\s*\((.*?)\),\s*\'(.*?)\'\);", text, re.DOTALL
        )
        if desc:
            metadata["Description"] = desc.group(1).replace("'", "")
            metadata["Description_Level"] = desc.group(2)

        # FILE_NAME('name','ts',('authors'),('orgs'),'prep','system','auth');
        name = re.search(
            r"FILE_NAME\s*\(\s*'(.*?)',\s*'(.*?)',\s*\((.*?)\),\s*\((.*?)\),\s*'(.*?)',\s*'(.*?)',\s*'(.*?)'\s*\);",
            text,
            re.DOTALL,
        )
        if name:
            fields = (
                ("FileName", name.group(1)),
                ("Created", name.group(2)),
                ("Authors", name.group(3).replace("'", "")),
                ("Organizations", name.group(4).replace("'", "")),
                ("Preprocessor", name.group(5)),
                ("OriginatingSystem", name.group(6)),
                ("Authorization", name.group(7)),
            )
            metadata.update(fields)

        # FILE_SCHEMA(('...'));
        schema = re.search(r"FILE_SCHEMA\s*\(\s*\((.*?)\)\s*\);", text, re.DOTALL)
        if schema:
            metadata["Schema"] = schema.group(1).replace("'", "")

    except Exception as e:
        logger.error(f"Failed to read STEP file: {e}")

    return metadata
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
#######################################################################################################################
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def dict_to_markdown(metadata: dict) -> str:
    """Render a metadata dict as newline-separated "key: value" lines."""
    lines = [f"{key}: {value}" for key, value in metadata.items()]
    return "\n".join(lines)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
#######################################################################################################################
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# Dummy parser - Replace with real parser
|
| 327 |
+
def parse_3d_file(original_filepath: str):
    """Extract and format embedded metadata from a CAD file.

    Supports .3dxml (ZIP header block) and .step (ISO 10303-21 HEADER
    records). Any other extension yields a fallback message.

    Args:
        original_filepath: Path to the uploaded file, or None.

    Returns:
        Human-readable metadata string for the UI textbox.
    """
    if original_filepath is None:
        return "No file"
    # Lowercase the extension so mixed-case names (".Step", ".3dXml") match
    # too — the previous endswith() pairs only covered all-lower/all-upper.
    ext = os.path.splitext(original_filepath)[1].lower()
    if ext == ".3dxml":
        meta = extract_header_from_3dxml(original_filepath)
        text = dict_to_markdown(meta)
        return f"Parsed metadata: {text}"
    if ext == ".step":
        meta = extract_step_metadata(original_filepath)
        text = dict_to_markdown(meta)
        return f"Parsed metadata: {text}"
    logger.warning(f"No metadata found in the file {original_filepath}")
    return "No metadata found!"
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def render_3D_metadata(
    original_filepath: str, obj_path: str, embedding_dict: dict
) -> Tuple[str, str]:
    """Return the display pair for a selected object.

    Returns:
        (parsed file metadata, stored LLM-generated description) — the
        description falls back to a placeholder when the object has not
        been embedded yet.
    """
    metadata_text = parse_3d_file(original_filepath=original_filepath)
    entry = embedding_dict.get(obj_path, {})
    description = entry.get("description", "No description found!")
    return metadata_text, description
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
#######################################################################################################################
|
| 351 |
+
# https://github.com/yangyangyang127/PointCLIP_V2/blob/main/zeroshot_cls/trainers/zeroshot.py#L64
|
| 352 |
+
#######################################################################################################################
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
pc_views = Realistic_Projection()
|
| 356 |
|
| 357 |
|
|
|
|
| 358 |
def render_depth_images_from_obj(obj_path: str, imsize: int = 512) -> List[np.ndarray]:
    """Project a mesh's vertices into multi-view depth images.

    Loads the mesh at ``obj_path``, feeds its vertex cloud through the
    Realistic_Projection views, and returns every rendered view resized to
    ``imsize`` x ``imsize`` as a numpy array (one per camera view).
    """
    mesh = trimesh.load_mesh(obj_path)
    point_cloud: Tensor = torch.tensor(mesh.vertices).float()
    if point_cloud.ndim == 2:
        # get_img expects a batch dimension: (B, N, 3).
        point_cloud = point_cloud.unsqueeze(0)
    views: Tensor = pc_views.get_img(point_cloud)
    views = torch.nn.functional.interpolate(
        views, size=(imsize, imsize), mode="bilinear", align_corners=True
    )
    return [np.array(TF.to_pil_image(view.cpu())) for view in views]
|
| 371 |
|
| 372 |
|
| 373 |
def aggregate_images(
    np_images: list[np.ndarray], n_rows: int = 2, n_cols: int = 5
) -> np.ndarray:
    """Tile equally-sized images into a single ``n_rows`` x ``n_cols`` grid.

    Images are placed row-major: image ``i`` goes to row ``i // n_cols``,
    column ``i % n_cols``. Works for HxWxC images and — unlike the previous
    version, which indexed ``shape[2]`` unconditionally — also for plain
    HxW grayscale arrays.

    Args:
        np_images: Non-empty list of same-shape arrays.
        n_rows: Grid rows.
        n_cols: Grid columns.

    Returns:
        The tiled grid array with the input dtype.

    Raises:
        ValueError: If the list is empty or has more images than grid cells.
    """
    if not np_images:
        raise ValueError("aggregate_images requires at least one image")
    if len(np_images) > n_rows * n_cols:
        raise ValueError(
            f"Got {len(np_images)} images but the grid only holds {n_rows * n_cols}"
        )

    img_height, img_width = np_images[0].shape[:2]
    # Preserve any trailing channel axis; empty tuple for grayscale.
    grid_shape = (img_height * n_rows, img_width * n_cols) + np_images[0].shape[2:]
    aggregate_img = np.zeros(grid_shape, dtype=np_images[0].dtype)

    for i, img in enumerate(np_images):
        row, col = divmod(i, n_cols)
        aggregate_img[
            row * img_height : (row + 1) * img_height,
            col * img_width : (col + 1) * img_width,
        ] = img

    return aggregate_img
|
| 391 |
|
| 392 |
+
|
| 393 |
+
llm_service = LLMService.from_partner()
|
| 394 |
+
# llm_service.model_name = "o3-mini"
|
| 395 |
+
|
| 396 |
+
DESCRIPTION_AGGREGATED_DEPTH_MAP_PROMPT = """You are a manufacturing expert analyzing 3D objects for production purposes. Given a set of multi-view depth maps of a single object, extract all possible special features relevant to manufacturing.
|
| 397 |
+
|
| 398 |
+
Your output must follow the structured format provided below and be as complete and specific as possible, even if some features are inferred or uncertain.
|
| 399 |
+
```
|
| 400 |
+
🔎 Extracted Manufacturing Features from Depth Maps
|
| 401 |
+
|
| 402 |
+
1. Geometric Features
|
| 403 |
+
Dimensions: <!-- List key dimensions such as height, width, depth, thickness, or aspect ratios. Use units if possible. Mention estimated ranges if exact values are unclear. -->
|
| 404 |
+
Notable Shapes: <!-- Describe the overall shape and form (e.g., cylindrical body with a tapered end, flat rectangular base, spherical top). Mention symmetry or irregularities. -->
|
| 405 |
+
Holes: <!-- Count and describe hole types (e.g., through-holes, blind holes), location if visible, and their arrangement or pattern (e.g., circular array, linear slot). -->
|
| 406 |
+
Surface Features: <!-- Include textures, fillets, chamfers, ribs, grooves, steps, and engravings. Identify raised or recessed areas that are not part of the base shape. -->
|
| 407 |
+
Other: <!-- Any other geometric characteristics not covered above (e.g., draft angles, deformation, cutouts). -->
|
| 408 |
+
|
| 409 |
+
2. Material-Related Inferences
|
| 410 |
+
Likely Material: <!-- Infer from shape, thickness, or typical use cases (e.g., plastic, aluminum, cast iron). State if uncertain or not visible. -->
|
| 411 |
+
Surface Texture: <!-- Describe the expected finish (e.g., rough, matte, polished) based on depth gradients or edge sharpness. -->
|
| 412 |
+
Durability Hints: <!-- Mention any features that suggest mechanical strength or wear resistance (e.g., thick load-bearing sections, reinforcement patterns). -->
|
| 413 |
+
|
| 414 |
+
3. Manufacturing-Related Features
|
| 415 |
+
Manufacturing Process: <!-- Suggest most likely processes (e.g., injection molding, CNC milling, casting) based on geometry and typical industry practices. -->
|
| 416 |
+
Draft Angles: <!-- Indicate presence and estimate angles if the object appears designed for mold release. -->
|
| 417 |
+
Undercuts: <!-- Identify any undercut areas that may require complex tooling or multi-part molds. -->
|
| 418 |
+
Mold Flow Considerations: <!-- Comment on how the material might flow during molding or casting, and whether the geometry supports or hinders it. -->
|
| 419 |
+
|
| 420 |
+
4. Functional and Assembly Features
|
| 421 |
+
Mounting Points: <!-- Identify places where fasteners or brackets might attach (e.g., holes, bosses, flanges). -->
|
| 422 |
+
Jointing Features: <!-- Describe features used to join with other parts, such as snap fits, tabs, slots, dovetails, etc. -->
|
| 423 |
+
Alignment Aids: <!-- Note features like pins, grooves, or guide rails that help align components during assembly. -->
|
| 424 |
+
Modularity: <!-- Assess whether the object is likely part of a modular system based on interface shapes or repeated features. -->
|
| 425 |
+
|
| 426 |
+
5. Inspection and Quality Features
|
| 427 |
+
Critical Dimensions: <!-- Highlight any dimensions likely to be functionally critical or require tight tolerance. -->
|
| 428 |
+
Surface Finish Zones: <!-- Point out areas that may require fine finishing or polishing for performance or cosmetic reasons. -->
|
| 429 |
+
Datums: <!-- Indicate flat surfaces or edges likely to serve as reference datums during measurement or machining. -->
|
| 430 |
+
Tolerances: <!-- Mention if any tolerances can be inferred, e.g., tight fits, loose clearances, or any standard class assumptions. -->
|
| 431 |
+
|
| 432 |
+
```
|
| 433 |
+
If any feature cannot be determined from the depth maps, state “Not visible” or “Cannot be inferred.”
|
| 434 |
+
Use clear technical vocabulary appropriate for manufacturing and quality control."""
|
| 435 |
|
| 436 |
|
| 437 |
async def generate_description_from_aggregated_depth_map(np_image: np.ndarray) -> str:
    """Ask the vision LLM to describe manufacturing features of the object.

    Sends the tiled multi-view depth map together with the structured
    feature-extraction prompt and returns the model's text answer.
    """
    encoded = llm_service.encode_image(image=np_image)
    return await llm_service.chat_with_image(
        prompt=DESCRIPTION_AGGREGATED_DEPTH_MAP_PROMPT, image=encoded
    )
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
clip_embedding_model = ClipEmbedding(
|
| 444 |
+
embed_batch_size=1536, # this parameter does not effect to the model
|
| 445 |
+
)
|
| 446 |
+
text_embedding_model = OpenAIEmbedding(
|
| 447 |
+
mode=OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
|
| 448 |
+
model="text-embedding-3-small",
|
| 449 |
+
api_key=OPENAI_API_KEY,
|
| 450 |
+
dimensions=1536,
|
| 451 |
+
embed_batch_size=512, # default == 100
|
| 452 |
+
)
|
| 453 |
|
| 454 |
|
| 455 |
async def aget_image_embedding_from_np_image(np_image: np.ndarray):
    """Embed a numpy image with CLIP by round-tripping through a temp PNG.

    The CLIP client only accepts file paths, so the array is written to a
    temporary .png first. The file is always removed, even when the
    embedding call raises — the previous version deleted it only on the
    success path and leaked a file per failure.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_file_path = temp_file.name
        # Convert the array to a PIL image and persist it for the client.
        Image.fromarray(np_image).save(temp_file_path)

    try:
        return await clip_embedding_model.aget_image_embedding(temp_file_path)
    finally:
        os.remove(temp_file_path)
|
| 468 |
|
| 469 |
|
|
|
|
| 470 |
async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
|
| 471 |
+
# get 10 depth images
|
| 472 |
depth_images = render_depth_images_from_obj(obj_path=obj_path)
|
| 473 |
+
# aggregate to single image
|
| 474 |
aggregated_image = aggregate_images(depth_images)
|
| 475 |
+
# description
|
| 476 |
description = await generate_description_from_aggregated_depth_map(
|
| 477 |
np_image=aggregated_image
|
| 478 |
)
|
| 479 |
+
# embedding aggregated_image: np.ndarray and description: str
|
| 480 |
image_embedding = await aget_image_embedding_from_np_image(
|
| 481 |
np_image=aggregated_image
|
| 482 |
)
|
|
|
|
| 488 |
}
|
| 489 |
|
| 490 |
|
| 491 |
+
BASE_SAMPLE_DIR = "/Users/tridoan/Spartan/Datum/service-ai/poc/3D/gradio_cache/"
|
| 492 |
+
sample_files = [
|
| 493 |
+
# BASE_SAMPLE_DIR + "C5 Knuckle Object.obj",
|
| 494 |
+
# BASE_SAMPLE_DIR + "NEMA 17 Stepper Motor 23mm-NEMA 17 Stepper Motor 23mm.obj",
|
| 495 |
+
# BASE_SAMPLE_DIR + "TS6-THT_H-5.0.obj",
|
| 496 |
+
# BASE_SAMPLE_DIR + "TS6-THT_H-11.0.obj"
|
| 497 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
|
| 500 |
+
#######################################################################################################################
|
| 501 |
+
## Accumulating and Rendering 3D
|
| 502 |
+
#######################################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
|
| 505 |
async def accumulate_and_embedding(input_files, file_list, embedding_dict):
    """Accumulate newly uploaded files and embed each new 3D object.

    Args:
        input_files: Current value of the gr.File component (a single path
            or a list of paths).
        file_list: Previously seen files (gr.State); everything past this
            prefix in ``input_files`` is treated as newly uploaded.
        embedding_dict: In-memory vector store keyed by converted .obj path.

    Returns:
        (all files, dropdown update with the new choices, updated store)
    """
    if not isinstance(input_files, list):
        input_files = [input_files]

    all_files = input_files
    # gr.File appends uploads, so the new entries are the suffix beyond the
    # previously stored list.
    new_files = input_files[len(file_list):]

    for file_path in new_files:
        # loguru formats messages with {} placeholders; the old
        # logger.info("...:", file_path) call silently dropped the path.
        logger.info(f"Processing new upload file: {file_path}")
        obj_path = convert_to_obj(file_path)
        embeddings = await embedding_3d_object(obj_path)
        entry = embedding_dict.setdefault(obj_path, {})
        entry["description"] = embeddings["description"]
        entry["image_embedding"] = embeddings["image_embedding"]
        entry["text_embedding"] = embeddings["text_embedding"]

    return all_files, gr.update(choices=all_files), embedding_dict
|
| 525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
+
def select_file(filename, file_list):
    """Return a short preview for the entry in ``file_list`` matching ``filename``.

    Each entry is expected to expose a ``.name`` path attribute. The first
    300 characters of the file's content are shown after a separator.
    """
    for candidate in file_list:
        if candidate.name != filename:
            continue
        with open(candidate.name, "r", encoding="utf-8", errors="ignore") as fh:
            preview = fh.read()
        return f"Selected: {candidate.name}\n---\n{preview[:300]}..."
    return "File not found."
|
| 534 |
|
| 535 |
|
| 536 |
+
def render_3D_object(filepath) -> Tuple[str, str]:
    """Resolve a selected file into something gr.Model3D can display.

    Returns:
        (path for the 3D viewer, original path kept for metadata parsing).
        Natively renderable formats pass through unchanged; FreeCAD-readable
        formats are converted to .obj first; anything else falls through
        unchanged.
    """
    ext = os.path.splitext(filepath)[1].lower()
    if ext in GRADIO_3D_MODEL_DEFAULT_FORMAT:
        return filepath, filepath
    convertible = (
        USER_REQUIRE_FORMAT + FREECAD_LOW_LEVEL_FORMAT + FREECAD_NATIVE_FORMAT
    )
    if ext in convertible:
        return convert_to_obj(filepath), filepath
    return filepath, filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
|
| 548 |
+
#######################################################################################################################
|
| 549 |
+
## Launching Gradio server
|
| 550 |
+
#######################################################################################################################
|
| 551 |
+
valid_file_types = list(
|
| 552 |
+
set(
|
| 553 |
+
GRADIO_3D_MODEL_DEFAULT_FORMAT
|
| 554 |
+
+ USER_REQUIRE_FORMAT
|
| 555 |
+
+ FREECAD_NATIVE_FORMAT
|
| 556 |
+
+ FREECAD_LOW_LEVEL_FORMAT
|
| 557 |
)
|
| 558 |
+
)
|
| 559 |
+
valid_file_types = valid_file_types + [t.upper() for t in valid_file_types]
|
| 560 |
+
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
with gr.Row():
|
| 562 |
+
file_state = gr.State(sample_files)
|
| 563 |
+
###################################### !IMPORTANT #############################################################
|
| 564 |
+
embedding_store = gr.State({}) ####### !IMPORTANT. This is in memory vector database ##########################
|
| 565 |
+
file_input = gr.File(
|
| 566 |
+
file_count="multiple",
|
| 567 |
+
label="Upload files (You can append more)",
|
| 568 |
+
file_types=valid_file_types,
|
| 569 |
+
)
|
| 570 |
|
| 571 |
with gr.Row():
|
| 572 |
with gr.Column():
|
| 573 |
+
query_input = gr.Textbox(placeholder="Which 3D CAD contains 2 holes?")
|
| 574 |
+
query_button = gr.Button("Query Search")
|
| 575 |
+
|
| 576 |
with gr.Row():
|
| 577 |
+
with gr.Row():
|
| 578 |
+
model_q_1 = gr.Model3D(
|
| 579 |
+
label="3D Top 1", interactive=False
|
| 580 |
+
) # debugging
|
| 581 |
+
model_q_1_btn = gr.Button(value="3D Top 1", size="sm")
|
| 582 |
+
with gr.Row():
|
| 583 |
+
model_q_2 = gr.Model3D(label="3D Top 2", interactive=False)
|
| 584 |
+
model_q_2_btn = gr.Button(value="3D Top 2", size="sm")
|
| 585 |
+
|
| 586 |
with gr.Row():
|
| 587 |
+
with gr.Row():
|
| 588 |
+
model_q_3 = gr.Model3D(label="3D Top 3", interactive=False)
|
| 589 |
+
model_q_3_btn = gr.Button(value="3D Top 3", size="sm")
|
| 590 |
+
with gr.Row():
|
| 591 |
+
model_q_4 = gr.Model3D(label="3D Top 4", interactive=False)
|
| 592 |
+
model_q_4_btn = gr.Button(value="3D Top 4", size="sm")
|
| 593 |
+
|
| 594 |
with gr.Column():
|
| 595 |
+
model_render = gr.Model3D(label="3D", height=500, interactive=False)
|
| 596 |
+
model_hidden_filepath = gr.Textbox(visible=False)
|
| 597 |
+
description_render = gr.Textbox(label="Description", lines=6)
|
| 598 |
+
metadata_render = gr.Textbox(label="Metadata", lines=6)
|
| 599 |
+
sim_button = gr.Button("Similarity Search")
|
| 600 |
with gr.Row():
|
| 601 |
+
with gr.Row():
|
| 602 |
+
model_s_1 = gr.Model3D(label="3D Sim 1", interactive=False)
|
| 603 |
+
model_s_1_btn = gr.Button(value="3D Sim 1", size="sm")
|
| 604 |
+
with gr.Row():
|
| 605 |
+
model_s_2 = gr.Model3D(label="3D Sim 2", interactive=False)
|
| 606 |
+
model_s_2_btn = gr.Button(value="3D Sim 2", size="sm")
|
| 607 |
with gr.Row():
|
| 608 |
+
with gr.Row():
|
| 609 |
+
model_s_3 = gr.Model3D(label="3D Sim 3", interactive=False)
|
| 610 |
+
model_s_3_btn = gr.Button(value="3D Sim 3", size="sm")
|
| 611 |
+
with gr.Row():
|
| 612 |
+
model_s_4 = gr.Model3D(label="3D Sim 4", interactive=False)
|
| 613 |
+
model_s_4_btn = gr.Button(value="3D Sim 4", size="sm")
|
| 614 |
+
with gr.Column():
|
| 615 |
+
file_dropdown = gr.Dropdown(
|
| 616 |
+
label="Select a file to process", choices=sample_files, interactive=True
|
| 617 |
+
)
|
| 618 |
|
| 619 |
+
file_input.change(
|
|
|
|
| 620 |
fn=accumulate_and_embedding,
|
| 621 |
inputs=[file_input, file_state, embedding_store],
|
| 622 |
outputs=[file_state, file_dropdown, embedding_store],
|
| 623 |
)
|
| 624 |
+
# query button
|
| 625 |
+
query_button.click(
|
| 626 |
+
query_3D_object,
|
| 627 |
+
[query_input, embedding_store],
|
| 628 |
+
[
|
| 629 |
+
model_q_1,
|
| 630 |
+
model_q_2,
|
| 631 |
+
model_q_3,
|
| 632 |
+
model_q_4,
|
| 633 |
+
model_q_1_btn,
|
| 634 |
+
model_q_2_btn,
|
| 635 |
+
model_q_3_btn,
|
| 636 |
+
model_q_4_btn,
|
| 637 |
+
],
|
| 638 |
)
|
| 639 |
+
# model query
|
| 640 |
+
model_q_1_btn.click(
|
| 641 |
+
render_3D_object, model_q_1, [model_render, model_hidden_filepath]
|
| 642 |
+
)
|
| 643 |
+
model_q_2_btn.click(
|
| 644 |
+
render_3D_object, model_q_2, [model_render, model_hidden_filepath]
|
| 645 |
+
)
|
| 646 |
+
model_q_3_btn.click(
|
| 647 |
+
render_3D_object, model_q_3, [model_render, model_hidden_filepath]
|
| 648 |
+
)
|
| 649 |
+
model_q_4_btn.click(
|
| 650 |
+
render_3D_object, model_q_4, [model_render, model_hidden_filepath]
|
| 651 |
+
)
|
| 652 |
+
# sim button
|
| 653 |
sim_button.click(
|
| 654 |
+
search_3D_similarity,
|
| 655 |
+
[model_render, embedding_store],
|
| 656 |
+
[
|
| 657 |
model_s_1,
|
| 658 |
model_s_2,
|
| 659 |
model_s_3,
|
| 660 |
model_s_4,
|
| 661 |
+
model_s_1_btn,
|
| 662 |
+
model_s_2_btn,
|
| 663 |
+
model_s_3_btn,
|
| 664 |
+
model_s_4_btn,
|
| 665 |
+
],
|
| 666 |
)
|
| 667 |
+
# model similarity
|
| 668 |
+
model_s_1_btn.click(
|
| 669 |
+
render_3D_object, model_s_1, [model_render, model_hidden_filepath]
|
| 670 |
+
)
|
| 671 |
+
model_s_2_btn.click(
|
| 672 |
+
render_3D_object, model_s_2, [model_render, model_hidden_filepath]
|
| 673 |
+
)
|
| 674 |
+
model_s_3_btn.click(
|
| 675 |
+
render_3D_object, model_s_3, [model_render, model_hidden_filepath]
|
| 676 |
+
)
|
| 677 |
+
model_s_4_btn.click(
|
| 678 |
+
render_3D_object, model_s_4, [model_render, model_hidden_filepath]
|
| 679 |
+
)
|
| 680 |
+
# drop down
|
| 681 |
+
file_dropdown.change(
|
| 682 |
+
render_3D_object, file_dropdown, [model_render, model_hidden_filepath]
|
| 683 |
+
)
|
| 684 |
+
# parse metadata
|
| 685 |
+
model_hidden_filepath.change(
|
| 686 |
+
render_3D_metadata,
|
| 687 |
+
[model_hidden_filepath, model_render, embedding_store],
|
| 688 |
+
[metadata_render, description_render],
|
| 689 |
)
|
| 690 |
|
| 691 |
if __name__ == "__main__":
|
| 692 |
+
demo.launch(share=True)
|
encode_image.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%writefile encode_image.py
|
| 2 |
+
import base64
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
Image.MAX_IMAGE_PIXELS = None # Removes the limit, use with caution
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def encode_image(image: Union[str, np.ndarray]) -> str:
    """
    Encodes an image as a base64 string.

    Args:
        image (Union[str, np.ndarray]): Path to the image file or a NumPy
            array representing the image.

    Returns:
        str: Base64-encoded image string.

    Raises:
        TypeError: If the input is neither a path nor an array.
    """
    if isinstance(image, str):
        # File on disk: encode the raw bytes as-is.
        with open(image, "rb") as image_file:
            raw = image_file.read()
    elif isinstance(image, np.ndarray):
        # In-memory array: serialize to JPEG before encoding.
        _, buffer = cv2.imencode(".jpg", image)
        raw = buffer
    else:
        raise TypeError("Input must be a file path (str) or a NumPy array.")
    return base64.b64encode(raw).decode("utf-8")
|
llm_service.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%writefile llm_service.py
|
| 2 |
+
import asyncio
|
| 3 |
+
import base64
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import List, Tuple, Union, cast
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from openai import AsyncOpenAI
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from loguru import logger
|
| 14 |
+
|
| 15 |
+
from encode_image import encode_image
|
| 16 |
+
from string_utils import StringUtils
|
| 17 |
+
|
| 18 |
+
Image.MAX_IMAGE_PIXELS = None # Removes the limit, use with caution
|
| 19 |
+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class OpenAIService:
|
| 23 |
+
    def __init__(self):
        """Configure the async OpenAI client with hard-coded defaults.

        The commented settings references mark where values came from
        before the refactor pulled this module out of the service config.
        """
        # self.llm_settings = getattr(settings.llm, settings.llm.name)
        self.model_name = "gpt-4o"  # settings.llm.openai.model
        self.temperature = 0.3  # settings.llm.openai.temperature
        self.client = AsyncOpenAI(api_key=OPENAI_API_KEY)
        # Follow the documentation: https://platform.openai.com/docs/models
        # These reasoning models reject the temperature parameter, so
        # get_temperature() omits it for them.
        self.deprecated_temperature_models = [
            "o4-mini",
            "o4",
            "o3-mini",
            "o3",
        ]  # settings.llm.openai.deprecated_temperature_models
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def encode_image(image: Union[str, np.ndarray]) -> str:
|
| 38 |
+
return encode_image(image=image)
|
| 39 |
+
|
| 40 |
+
def get_temperature(self, temperature: float | None) -> dict:
|
| 41 |
+
return (
|
| 42 |
+
{
|
| 43 |
+
"temperature": temperature
|
| 44 |
+
if temperature is not None
|
| 45 |
+
else self.temperature
|
| 46 |
+
}
|
| 47 |
+
if self.model_name not in self.deprecated_temperature_models
|
| 48 |
+
else {}
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
async def chat_with_text(
|
| 52 |
+
self,
|
| 53 |
+
prompt: str,
|
| 54 |
+
return_as_json: bool = False,
|
| 55 |
+
retry_left: int = 3, # settings.llm.openai.retry_left,
|
| 56 |
+
temperature: float | None = None,
|
| 57 |
+
) -> str:
|
| 58 |
+
"""
|
| 59 |
+
Sends a text-based chat prompt to the OpenAI model.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
prompt (str): User input text.
|
| 63 |
+
return_as_json (bool): whether to generate output as a json object
|
| 64 |
+
retry_left (int): number of retries left
|
| 65 |
+
temperature (float | None): Controls randomness in the response. Lower values make responses more focused and deterministic.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
str: Response from the model.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
model_kwargs = {
|
| 72 |
+
"model": self.model_name,
|
| 73 |
+
"messages": [
|
| 74 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 75 |
+
{"role": "user", "content": prompt},
|
| 76 |
+
],
|
| 77 |
+
**self.get_temperature(temperature=temperature),
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
if return_as_json:
|
| 81 |
+
model_kwargs["response_format"] = {"type": "json_object"}
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
response = await self.client.chat.completions.create(**model_kwargs)
|
| 85 |
+
except Exception as e:
|
| 86 |
+
if retry_left > 0:
|
| 87 |
+
logger.warning(f"OpenAI API calling failed due to {e}. Retry!")
|
| 88 |
+
await asyncio.sleep(1) # quota out
|
| 89 |
+
return await self.chat_with_text(
|
| 90 |
+
prompt=prompt,
|
| 91 |
+
return_as_json=return_as_json,
|
| 92 |
+
retry_left=retry_left - 1,
|
| 93 |
+
temperature=temperature,
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
logger.error(
|
| 97 |
+
f"OpenAI API calling failed due to {e}. Return empty string!"
|
| 98 |
+
)
|
| 99 |
+
return ""
|
| 100 |
+
|
| 101 |
+
return response.choices[0].message.content
|
| 102 |
+
|
| 103 |
+
async def chat_with_image(
|
| 104 |
+
self,
|
| 105 |
+
prompt: str,
|
| 106 |
+
image: str,
|
| 107 |
+
return_as_json: bool = False,
|
| 108 |
+
retry_left: int = 3, # settings.llm.openai.retry_left,
|
| 109 |
+
temperature: float | None = None,
|
| 110 |
+
) -> str:
|
| 111 |
+
"""
|
| 112 |
+
Sends an image along with a text prompt to the OpenAI model.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
prompt (str): User input text.
|
| 116 |
+
image_path (str): Path to the image file.
|
| 117 |
+
return_as_json (bool): whether to generate output as a json object
|
| 118 |
+
retry_left (int): number of retries left
|
| 119 |
+
temperature (float | None): Controls randomness in the response. Lower values make responses more focused and deterministic.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
str: Response from the model.
|
| 123 |
+
"""
|
| 124 |
+
if os.path.isfile(image):
|
| 125 |
+
base64_image = self.encode_image(image=image)
|
| 126 |
+
elif StringUtils.is_base64(image):
|
| 127 |
+
base64_image = image
|
| 128 |
+
else:
|
| 129 |
+
raise Exception(
|
| 130 |
+
"ServiceAiError.UNSUPPORT_INPUT_IMAGE_TYPE.as_http_exception()"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
model_kwargs = {
|
| 134 |
+
"model": self.model_name,
|
| 135 |
+
"messages": [
|
| 136 |
+
{
|
| 137 |
+
"role": "user",
|
| 138 |
+
"content": [
|
| 139 |
+
{"type": "text", "text": prompt},
|
| 140 |
+
{
|
| 141 |
+
"type": "image_url",
|
| 142 |
+
"image_url": {
|
| 143 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
| 144 |
+
},
|
| 145 |
+
},
|
| 146 |
+
],
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
**self.get_temperature(temperature=temperature),
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
if return_as_json:
|
| 153 |
+
model_kwargs["response_format"] = {"type": "json_object"}
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
response = await self.client.chat.completions.create(**model_kwargs)
|
| 157 |
+
except Exception as e:
|
| 158 |
+
if retry_left > 0:
|
| 159 |
+
logger.warning(f"OpenAI API calling failed due to {e}. Retry!")
|
| 160 |
+
await asyncio.sleep(1) # quota out
|
| 161 |
+
return await self.chat_with_image(
|
| 162 |
+
prompt=prompt,
|
| 163 |
+
image=image,
|
| 164 |
+
return_as_json=return_as_json,
|
| 165 |
+
retry_left=retry_left - 1,
|
| 166 |
+
temperature=temperature,
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
logger.error(
|
| 170 |
+
f"OpenAI API calling failed due to {e}. Return empty string!"
|
| 171 |
+
)
|
| 172 |
+
return ""
|
| 173 |
+
return response.choices[0].message.content
|
| 174 |
+
|
| 175 |
+
async def chat_with_multiple_images(
|
| 176 |
+
self,
|
| 177 |
+
prompt: str,
|
| 178 |
+
images: list[str],
|
| 179 |
+
return_as_json: bool = False,
|
| 180 |
+
retry_left: int = 3, # settings.llm.openai.retry_left,
|
| 181 |
+
temperature: float | None = None,
|
| 182 |
+
) -> str:
|
| 183 |
+
"""
|
| 184 |
+
Sends multiple images along with a text prompt to the OpenAI model.
|
| 185 |
+
Args:
|
| 186 |
+
prompt (str): User input text.
|
| 187 |
+
images (list[str]): List of base64 encoded images.
|
| 188 |
+
return_as_json (bool): whether to generate output as a json object
|
| 189 |
+
retry_left (int): number of retries left
|
| 190 |
+
temperature (float | None): Controls randomness in the response. Lower values make responses more focused and deterministic.
|
| 191 |
+
Returns:
|
| 192 |
+
list[str]: Responses from the model for each image.
|
| 193 |
+
"""
|
| 194 |
+
if len(images) == 0:
|
| 195 |
+
logger.warning("OpenAI chats with multiple images mode without any images")
|
| 196 |
+
|
| 197 |
+
base64_images = []
|
| 198 |
+
for image in images:
|
| 199 |
+
if os.path.isfile(image):
|
| 200 |
+
base64_images.append(self.encode_image(image=image))
|
| 201 |
+
elif StringUtils.is_base64(image):
|
| 202 |
+
base64_images.append(image)
|
| 203 |
+
else:
|
| 204 |
+
raise Exception(
|
| 205 |
+
"ServiceAiError.UNSUPPORT_INPUT_IMAGE_TYPE.as_http_exception()"
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
model_kwargs = {
|
| 209 |
+
"model": self.model_name,
|
| 210 |
+
"messages": [
|
| 211 |
+
{
|
| 212 |
+
"role": "user",
|
| 213 |
+
"content": [
|
| 214 |
+
{"type": "text", "text": prompt},
|
| 215 |
+
*[
|
| 216 |
+
{
|
| 217 |
+
"type": "image_url",
|
| 218 |
+
"image_url": {
|
| 219 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
| 220 |
+
},
|
| 221 |
+
}
|
| 222 |
+
for base64_image in base64_images
|
| 223 |
+
],
|
| 224 |
+
],
|
| 225 |
+
}
|
| 226 |
+
],
|
| 227 |
+
**self.get_temperature(temperature=temperature),
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
if return_as_json:
|
| 231 |
+
model_kwargs["response_format"] = {"type": "json_object"}
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
response = await self.client.chat.completions.create(**model_kwargs)
|
| 235 |
+
except Exception as e:
|
| 236 |
+
if retry_left > 0:
|
| 237 |
+
logger.warning(f"OpenAI API calling failed due to {e}. Retry!")
|
| 238 |
+
await asyncio.sleep(1) # quota out
|
| 239 |
+
return await self.chat_with_multiple_images(
|
| 240 |
+
prompt=prompt,
|
| 241 |
+
images=images,
|
| 242 |
+
return_as_json=return_as_json,
|
| 243 |
+
retry_left=retry_left - 1,
|
| 244 |
+
temperature=temperature,
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
logger.error(
|
| 248 |
+
f"OpenAI API calling failed due to {e}. Return empty list!"
|
| 249 |
+
)
|
| 250 |
+
return ""
|
| 251 |
+
|
| 252 |
+
return response.choices[0].message.content
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class LLMService:
    """Factory that selects the concrete LLM backend."""

    @classmethod
    def from_partner(cls):
        # OpenAI is currently the only supported partner backend.
        return OpenAIService()
|
mv_utils_zs.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%writefile mv_utils_zs.py
|
| 2 |
+
"""
|
| 3 |
+
Author: yangyangyang127
|
| 4 |
+
Github: https://github.com/yangyangyang127
|
| 5 |
+
Repo: https://github.com/yangyangyang127/PointCLIP_V2
|
| 6 |
+
Path: https://github.com/yangyangyang127/PointCLIP_V2/blob/main/zeroshot_cls/trainers/mv_utils_zs.py#L135
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch_scatter import scatter
|
| 13 |
+
|
| 14 |
+
# Z-translation applied to every virtual camera (pushes the object away
# from the image plane before voxelization).
TRANS = -1.5

# realistic projection parameters
params = {
    "maxpoolz": 1,  # max-pool kernel size along depth
    "maxpoolxy": 7,  # max-pool kernel size in the image plane
    "maxpoolpadz": 0,
    "maxpoolpadxy": 2,
    "convz": 1,  # Gaussian smoothing kernel size along depth
    "convxy": 3,  # Gaussian smoothing kernel size in the image plane
    "convsigmaxy": 3,
    "convsigmaz": 1,
    "convpadz": 0,
    "convpadxy": 1,
    "imgbias": 0.0,
    "depth_bias": 0.2,  # shifts points away from the near plane
    "obj_ratio": 0.8,  # fraction of the frame the object occupies
    "bg_clr": 0.0,  # background value of the voxel grid
    "resolution": 122,  # spatial resolution of the voxel grid / image
    "depth": 8,  # default = 8
    "grid_height": 64,
    "grid_width": 64,
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Grid2Image(nn.Module):
    """A pytorch implementation to turn 3D grid to 2D image.
    Maxpool: densifying the grid
    Convolution: smoothing via Gaussian
    Maximize: squeezing the depth channel
    """

    def __init__(self):
        super().__init__()
        # Kernels are fixed, so cudnn autotuning brings no benefit.
        torch.backends.cudnn.benchmark = False

        # Densify the sparse voxel grid in the image plane (depth kernel = 1).
        self.maxpool = nn.MaxPool3d(
            (params["maxpoolz"], params["maxpoolxy"], params["maxpoolxy"]),
            stride=1,
            padding=(
                params["maxpoolpadz"],
                params["maxpoolpadxy"],
                params["maxpoolpadxy"],
            ),
        )
        # Single-channel 3D convolution used purely as a Gaussian smoother.
        self.conv = torch.nn.Conv3d(
            1,
            1,
            kernel_size=(params["convz"], params["convxy"], params["convxy"]),
            stride=1,
            padding=(params["convpadz"], params["convpadxy"], params["convpadxy"]),
            bias=True,
        )
        kn3d = get3DGaussianKernel(
            params["convxy"],
            params["convz"],
            sigma=params["convsigmaxy"],
            zsigma=params["convsigmaz"],
        )
        # Install the fixed Gaussian weights; this module is never trained.
        self.conv.weight.data = torch.Tensor(kn3d).repeat(1, 1, 1, 1, 1)
        self.conv.bias.data.fill_(0)

    def forward(self, x):
        # unsqueeze adds the channel dim expected by 3D pool/conv.
        x = self.maxpool(x.unsqueeze(1))
        x = self.conv(x)
        # Collapse the depth axis with a max -> one 2D map per grid.
        img = torch.max(x, dim=2)[0]
        # Per-image max normalization.
        # NOTE(review): an all-zero grid makes this a 0/0 division -> NaNs;
        # confirm inputs always contain at least one occupied voxel.
        img = img / torch.max(torch.max(img, dim=-1)[0], dim=-1)[0][:, :, None, None]
        # Invert so occupied voxels appear dark on a light background.
        img = 1 - img
        # Replicate the single channel to 3 (RGB-like output).
        img = img.repeat(1, 3, 1, 1)
        return img
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def euler2mat(angle):
    """Convert euler angles to rotation matrix.
    :param angle: [3] or [b, 3]
    :return
        rotmat: [3] or [b, 3, 3]
    source
        https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    """
    ndim = len(angle.size())
    if ndim == 1:
        # Single angle triple -> a single 3x3 matrix.
        rx, ry, rz = angle[0], angle[1], angle[2]
        stack_dim = 0
        out_shape = [3, 3]
    elif ndim == 2:
        # Batched input -> one matrix per row.
        batch = angle.size()[0]
        rx, ry, rz = angle[:, 0], angle[:, 1], angle[:, 2]
        stack_dim = 1
        out_shape = [batch, 3, 3]
    else:
        assert False

    # Constants built from the input so device/dtype/grad behavior match.
    zero = rz.detach() * 0
    one = zero.detach() + 1

    cz, sz = torch.cos(rz), torch.sin(rz)
    zmat = torch.stack(
        [cz, -sz, zero, sz, cz, zero, zero, zero, one], dim=stack_dim
    ).reshape(out_shape)

    cy, sy = torch.cos(ry), torch.sin(ry)
    ymat = torch.stack(
        [cy, zero, sy, zero, one, zero, -sy, zero, cy], dim=stack_dim
    ).reshape(out_shape)

    cx, sx = torch.cos(rx), torch.sin(rx)
    xmat = torch.stack(
        [one, zero, zero, zero, cx, -sx, zero, sx, cx], dim=stack_dim
    ).reshape(out_shape)

    # Combined rotation: X, then Y, then Z (matrix product order x @ y @ z).
    return xmat @ ymat @ zmat
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def points_to_2d_grid(
    points, grid_h=params["grid_height"], grid_w=params["grid_width"]
):
    """Project point clouds onto a 2D occupancy grid via their X/Y coordinates.

    Args:
        points (torch.Tensor): point clouds of shape [B, P, 3] (x, y, z).
        grid_h (int): height of the output grid.
        grid_w (int): width of the output grid.

    Returns:
        torch.Tensor: grid of shape [B, grid_h, grid_w]; a cell holds 1.0
        when at least one point falls into it, otherwise params["bg_clr"].
    """
    num_clouds, num_points, _ = points.shape
    device = points.device
    eps = 1e-8  # guards against division by zero for degenerate clouds
    scale = params["obj_ratio"]

    # Normalize X/Y into [-scale, scale], preserving the aspect ratio by
    # dividing both axes by the larger of the two per-cloud extents.
    xy = points[:, :, :2]
    xy_max = xy.max(dim=1)[0]
    xy_min = xy.min(dim=1)[0]
    center = ((xy_max + xy_min) / 2)[:, None, :]  # [B, 1, 2]
    extent = (xy_max - xy_min).max(dim=-1)[0][:, None, None]  # [B, 1, 1]
    xy_norm = (xy - center) / (extent + eps) * 2.0 * scale

    # Map normalized coordinates to integer cell indices and clamp them
    # into the valid grid range.
    col = torch.floor((xy_norm[:, :, 0] + scale) / (2 * scale) * grid_w).long()
    row = torch.floor((xy_norm[:, :, 1] + scale) / (2 * scale) * grid_h).long()
    col = torch.clip(col, 0, grid_w - 1)
    row = torch.clip(row, 0, grid_h - 1)

    # Start from a background-filled grid and mark occupied cells with 1.0;
    # multiple points in the same cell still yield exactly 1.0.
    grid = torch.full(
        (num_clouds, grid_h, grid_w), params["bg_clr"], dtype=torch.float32, device=device
    )
    cloud_idx = torch.arange(num_clouds, device=device).view(-1, 1).repeat(1, num_points)
    grid[cloud_idx.reshape(-1), row.reshape(-1), col.reshape(-1)] = 1.0

    return grid
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def points2grid(points, resolution=params["resolution"], depth=params["depth"]):
    """Quantize each point cloud to a 3D grid.
    Args:
        points (torch.tensor): of size [B, _, 3]
    Returns:
        grid (torch.tensor): of size [B * self.num_views, depth, resolution, resolution]
    """

    batch, pnum, _ = points.shape

    # Center each cloud and scale it isotropically into [-1, 1].
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][:, None, None]
    points = (points - pcent) / prange * 2.0
    # Shrink only X/Y so the object occupies obj_ratio of the frame.
    points[:, :, :2] = points[:, :, :2] * params["obj_ratio"]

    # Map normalized coordinates to (sub-)voxel positions; Z gets an extra
    # bias so surfaces are pushed away from the near plane.
    depth_bias = params["depth_bias"]
    _x = (points[:, :, 0] + 1) / 2 * resolution
    _y = (points[:, :, 1] + 1) / 2 * resolution
    _z = ((points[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    _x.ceil_()
    _y.ceil_()
    z_int = _z.ceil()

    # Keep a one-voxel border free on every side.
    _x = torch.clip(_x, 1, resolution - 2)
    _y = torch.clip(_y, 1, resolution - 2)
    _z = torch.clip(_z, 1, depth - 2)

    # Flattened voxel index per point.
    # NOTE(review): z_int is NOT clipped (unlike _z) — confirm it cannot
    # leave [0, depth - 1] for extreme depth_bias/depth settings.
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    grid = (
        torch.ones([batch, depth, resolution, resolution], device=points.device).view(
            batch, -1
        )
        * params["bg_clr"]
    )

    # Scatter the clipped fractional depth into each voxel along dim=1,
    # keeping the maximum when several points land in the same cell.
    # (A previous draft that scattered constant 1.0 occupancy values lives
    # on as points_to_occupancy_grid below.)
    grid = scatter(_z, coordinates.long(), dim=1, out=grid, reduce="max")
    grid = grid.reshape((batch, depth, resolution, resolution)).permute((0, 1, 3, 2))

    return grid
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Giả sử bạn có thư viện scatter, ví dụ: from torch_scatter import scatter
|
| 294 |
+
# Hoặc hàm scatter tương đương
|
| 295 |
+
# import torch # Đảm bảo đã import torch
|
| 296 |
+
# from torch_scatter import scatter # Ví dụ
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def points_to_occupancy_grid(
    points, resolution=params["resolution"], depth=params["depth"]
):
    """Quantize each point cloud to a 3D occupancy grid."""

    batch, pnum, _ = points.shape
    device = points.device  # target device for the tensors created below

    # --- Normalize the cloud exactly as in points2grid ---
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][
        :, None, None
    ] + 1e-8  # epsilon avoids division by zero for degenerate clouds
    points_norm = (points - pcent) / prange * 2.0
    points_norm[:, :, :2] = points_norm[:, :, :2] * params["obj_ratio"]

    # --- Map normalized coordinates to voxel indices ---
    depth_bias = params["depth_bias"]
    _x = (points_norm[:, :, 0] + 1) / 2 * resolution
    _y = (points_norm[:, :, 1] + 1) / 2 * resolution
    _z = ((points_norm[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    _x.ceil_()
    _y.ceil_()
    z_int = _z.ceil()

    _x = torch.clip(_x, 1, resolution - 2)
    _y = torch.clip(_y, 1, resolution - 2)
    # z_int is used as a coordinate below, so it is clipped here as well.
    z_int = torch.clip(z_int, 1, depth - 2)

    # --- Flattened voxel index per point ---
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    coordinates = coordinates.long()  # scatter requires integer indices

    # --- Build the flat grid and scatter occupancy values ---
    bg_clr_value = params.get("bg_clr", 0.0)  # background value, 0 by default
    grid = torch.full(
        (batch, depth * resolution * resolution),
        bg_clr_value,
        dtype=torch.float32,
        device=device,
    )

    # One source value of 1.0 per point; the shape must match the flattened
    # coordinates, i.e. [B * pnum].
    values_to_scatter = torch.ones(batch * pnum, dtype=torch.float32, device=device)

    # reduce="max" marks a cell as occupied (1.0) when at least one point
    # lands in it; this only yields a strict 0/1 grid when bg_clr <= 1.
    if bg_clr_value != 0.0:
        print(
            "Warning: bg_clr is not 0.0, occupancy grid might not be strictly binary 0/1 with reduce='max'. Consider initializing grid with 0."
        )

    # NOTE(review): `coordinates` carries no per-batch offset, yet the
    # scatter below targets the fully flattened [B*D*R*R] view with dim=0 —
    # every batch item therefore writes into the FIRST [D*R*R] slice. For
    # batch > 1 this looks wrong; compare points2grid, which scatters along
    # dim=1 and keeps the batch dimension intact.
    grid = scatter(
        values_to_scatter,
        coordinates.view(-1),  # flatten coordinates to [B * pnum]
        dim=0,
        out=grid.view(-1),  # flatten grid to [B * D * R * R]
        reduce="max",
    )

    # --- Restore the [B, depth, resolution, resolution] layout and swap the
    # last two axes to match points2grid's output orientation ---
    grid = grid.view(batch, depth, resolution, resolution)
    grid = grid.permute((0, 1, 3, 2))

    return grid
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Realistic_Projection:
    """For creating images from PC based on the view information."""

    def __init__(self):
        # 10 virtual camera poses: [euler angles (x, y, z), translation].
        # The first 8 orbit around the object; the last 2 view it along the
        # vertical axis.
        _views = np.asarray([
            [[1 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[1 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[2 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, -np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
        ])

        # adding some bias to the view angle to reveal more surface
        _views_bias = np.asarray([
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
        ])

        self.num_views = _views.shape[0]

        # Pre-compute transposed rotation matrices for the base and bias
        # views; kept on CPU and moved to the points' device on use.
        angle = torch.tensor(_views[:, 0, :]).float()  # .cuda()
        self.rot_mat = euler2mat(angle).transpose(1, 2)
        angle2 = torch.tensor(_views_bias[:, 0, :]).float()  # .cuda()
        self.rot_mat2 = euler2mat(angle2).transpose(1, 2)

        self.translation = torch.tensor(_views[:, 1, :]).float()  # .cuda()
        self.translation = self.translation.unsqueeze(1)

        self.grid2image = Grid2Image()  # .cuda()

    def get_img(self, points):
        """Render every view of every cloud: [B, P, 3] -> images."""
        b, _, _ = points.shape
        v = self.translation.shape[0]  # number of views

        # Duplicate each cloud once per view and transform into view space.
        # repeat_interleave (p0 x v, p1 x v, ...) pairs with .repeat's tiled
        # view matrices so each copy meets a different view.
        _points = self.point_transform(
            points=torch.repeat_interleave(points, v, dim=0),
            rot_mat=self.rot_mat.repeat(b, 1, 1),
            rot_mat2=self.rot_mat2.repeat(b, 1, 1),
            translation=self.translation.repeat(b, 1, 1),
        )

        # NOTE(review): .squeeze() drops ALL size-1 dims — if b * v == 1 the
        # batch dim disappears and Grid2Image receives the wrong rank;
        # confirm callers always have b * v > 1.
        grid = points2grid(
            points=_points, resolution=params["resolution"], depth=params["depth"]
        ).squeeze()
        img = self.grid2image(grid)
        return img

    @staticmethod
    def point_transform(points, rot_mat, rot_mat2, translation):
        """
        :param points: [batch, num_points, 3]
        :param rot_mat: [batch, 3]
        :param rot_mat2: [batch, 3]
        :param translation: [batch, 1, 3]
        :return:
        """
        rot_mat = rot_mat.to(points.device)
        rot_mat2 = rot_mat2.to(points.device)
        translation = translation.to(points.device)
        # Right-multiplication by the (transposed) matrices rotates the
        # points, then the camera translation is subtracted.
        points = torch.matmul(points, rot_mat)
        points = torch.matmul(points, rot_mat2)
        points = points - translation
        return points
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def get2DGaussianKernel(ksize, sigma=0):
    """Build a normalized 2D Gaussian kernel of shape (ksize, ksize).

    Args:
        ksize (int): kernel side length.
        sigma (float): Gaussian standard deviation. When <= 0 it is derived
            from `ksize` using OpenCV's getGaussianKernel convention.
            (Previously the default sigma=0 divided by zero and produced a
            NaN-filled kernel.)

    Returns:
        torch.Tensor: kernel summing to 1.
    """
    if sigma <= 0:
        # OpenCV's automatic sigma for getGaussianKernel.
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    center = ksize // 2
    xs = np.arange(ksize, dtype=np.float32) - center
    kernel1d = np.exp(-(xs**2) / (2 * sigma**2))
    # Outer product of the 1D kernel with itself gives the 2D kernel.
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    kernel = kernel / kernel.sum()
    return kernel
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# Without numpy
|
| 469 |
+
# def get2DGaussianKernel(ksize, sigma):
|
| 470 |
+
# xs = torch.linspace(-(ksize // 2), ksize // 2, steps=ksize)
|
| 471 |
+
# kernel1d = torch.exp(-(xs ** 2) / (2 * sigma ** 2))
|
| 472 |
+
# kernel2d = torch.outer(kernel1d, kernel1d)
|
| 473 |
+
# kernel2d /= kernel2d.sum()
|
| 474 |
+
# return kernel2d
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def get3DGaussianKernel(ksize, depth, sigma=2, zsigma=2):
    """Build a normalized 3D Gaussian kernel of shape (depth, ksize, ksize).

    Args:
        ksize (int): kernel side length in the image plane.
        depth (int): kernel extent along the depth axis.
        sigma (float): in-plane Gaussian standard deviation.
        zsigma (float): depth Gaussian standard deviation.

    Returns:
        torch.Tensor: kernel summing to 1 (same return type Grid2Image expects).

    The previous implementation passed torch tensors into numpy functions
    and only worked through numpy's implicit array-wrapping fallback; the
    conversions are now explicit.
    """
    # Work in numpy throughout, converting the torch 2D kernel up front.
    kernel2d = get2DGaussianKernel(ksize, sigma).numpy()
    zs = np.arange(depth, dtype=np.float32) - depth // 2
    zkernel = np.exp(-(zs**2) / (2 * zsigma**2))
    # Stack the 2D kernel along depth, weighted by the 1D depth Gaussian.
    kernel3d = np.repeat(kernel2d[None, :, :], depth, axis=0) * zkernel[:, None, None]
    return torch.from_numpy(kernel3d / kernel3d.sum())
|
string_utils.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%writefile string_utils.py
|
| 2 |
+
import base64
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
import string
|
| 6 |
+
from urllib.parse import urlparse
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class StringUtils:
    """Stateless string helpers: random IDs, normalization, and validators."""

    @staticmethod
    def generate_random_string(length: int = 32) -> str:
        """Return a random alphanumeric string of the given length.

        Uses the `random` module; NOT suitable for secrets or tokens
        (use the `secrets` module for anything security-sensitive).
        """
        characters = string.ascii_letters + string.digits
        random_string = "".join(random.choice(characters) for _ in range(length))
        return random_string

    @staticmethod
    def clean_string(input_string: str) -> str:
        """Normalize text: strip non-ASCII, tidy punctuation spacing,
        collapse whitespace, and lower-case the result."""
        # Remove non-ASCII characters
        cleaned_string = re.sub(r"[^\x00-\x7F]+", " ", input_string)

        # Consolidate spaces and ensure correct spacing around punctuation
        cleaned_string = re.sub(r"\s*([.,;!?%:])\s*", r"\1 ", cleaned_string)

        # Adjust spacing for the dollar sign
        cleaned_string = re.sub(r"\$\s+", "$", cleaned_string)

        # Ensure correct spacing inside parentheses around numbers
        cleaned_string = re.sub(r"\(\s*(\d+)\s*\)", r"( \1 )", cleaned_string)

        # Remove extra spaces around punctuation (this might be redundant but
        # ensures no trailing space before punctuation)
        cleaned_string = re.sub(r"\s+([.,;!?%:])", r"\1", cleaned_string)

        # Remove leading and trailing whitespace, reduce multiple spaces to a
        # single space, and convert to lower case
        cleaned_string = re.sub(r"\s+", " ", cleaned_string).strip().lower()

        return cleaned_string

    @staticmethod
    def get_file_name_without_extension(file_name: str) -> str:
        """Return `file_name` with its final extension removed.

        Fix: the original `".".join(file_name.split(".")[:-1])` returned an
        empty string for names with no dot (e.g. "README" -> ""); such names
        are now returned unchanged. Multi-dot names still drop only the last
        segment ("a.b.c" -> "a.b"), matching the original behavior.
        """
        if "." not in file_name:
            return file_name
        return file_name.rsplit(".", 1)[0]

    @staticmethod
    def is_valid_url(url: str) -> bool:
        """Return True if `url` parses with both a scheme and a network
        location (e.g. "https://example.com/path")."""
        try:
            result = urlparse(url)
            return all([result.scheme, result.netloc])
        except ValueError:
            return False

    @staticmethod
    def is_base64(string: str) -> bool:
        """
        Validates if the input string is a Base64-encoded string.

        Args:
            string (str): The string to validate.
                (NOTE: the parameter shadows the imported `string` module
                inside this method; kept for interface compatibility.)

        Returns:
            bool: True if the string is Base64, False otherwise.
        """
        try:
            # Check if the string can be decoded strictly (validate=True
            # rejects any non-alphabet characters instead of ignoring them)
            base64_bytes = base64.b64decode(string, validate=True)
            # Round-trip check: re-encoding must reproduce the input exactly
            return base64.b64encode(base64_bytes).decode("utf-8") == string
        except Exception:
            # Any decode failure (bad padding, invalid chars) means not base64
            return False
|