TeeA committed on
Commit
d6cfb5e
·
1 Parent(s): 9a19e9e
Files changed (5) hide show
  1. app.py +563 -541
  2. encode_image.py +29 -0
  3. llm_service.py +258 -0
  4. mv_utils_zs.py +483 -0
  5. string_utils.py +69 -0
app.py CHANGED
@@ -1,352 +1,116 @@
1
- # app.py
2
- import os
3
- import subprocess
4
  import asyncio
5
- import base64
6
- import io
7
  import random
8
- import string
9
- import re
10
- import zipfile
11
- import xml.etree.ElementTree as ET
12
  import tempfile
13
- from urllib.parse import urlparse
14
- from typing import Tuple, Dict, Any, Union, List
15
 
16
- import gradio as gr
17
- import trimesh
18
  import numpy as np
19
  import torch
20
- from openai import AsyncOpenAI
21
- from PIL import Image
 
 
22
  from loguru import logger
 
23
  from sklearn.metrics.pairwise import cosine_similarity
24
  from torch import Tensor
25
- import torchvision.transforms.functional as TF
26
- from torch_scatter import scatter
27
 
28
- from llama_index.embeddings.clip import ClipEmbedding
29
- from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode
30
 
31
- # Tải API Key từ biến môi trường (sẽ được set trong Hugging Face Secrets)
32
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
33
 
34
- # ==============================================================================
35
- # PHẦN 1: CÁC HÀM TIỆN ÍCH VÀ KHAI BÁO (TỪ NOTEBOOK)
36
- # ==============================================================================
 
 
 
37
 
38
# File-format constants.
GRADIO_3D_MODEL_DEFAULT_FORMAT = [".obj", ".glb", ".gltf", ".stl", ".splat", ".ply"]
USER_REQUIRE_FORMAT = [".3dxml", ".step"]
FREECAD_LOW_LEVEL_FORMAT = [".step", ".igs", ".iges"]
FREECAD_NATIVE_FORMAT = [".fcstd"]

# Deduplicated union of every supported extension, accepted in both
# lower-case and upper-case spellings.
_all_extensions = {
    ext
    for group in (
        GRADIO_3D_MODEL_DEFAULT_FORMAT,
        USER_REQUIRE_FORMAT,
        FREECAD_NATIVE_FORMAT,
        FREECAD_LOW_LEVEL_FORMAT,
    )
    for ext in group
}
VALID_FILE_TYPES = list(_all_extensions)
VALID_FILE_TYPES = VALID_FILE_TYPES + [ext.upper() for ext in VALID_FILE_TYPES]
52
-
53
# Realistic Projection Parameters (from mv_utils_zs.py)
TRANS = -1.5  # z-translation applied to every view (moves the cloud away from the camera)
params = {
    "maxpoolz": 1,       # MaxPool3d kernel size along z (see Grid2Image)
    "maxpoolxy": 7,      # MaxPool3d kernel size in x/y
    "maxpoolpadz": 0,    # MaxPool3d padding along z
    "maxpoolpadxy": 2,   # MaxPool3d padding in x/y
    "convz": 1,          # Gaussian Conv3d kernel depth
    "convxy": 3,         # Gaussian Conv3d kernel size in x/y
    "convsigmaxy": 3,    # Gaussian sigma in x/y
    "convsigmaz": 1,     # Gaussian sigma along z
    "convpadz": 0,       # Conv3d padding along z
    "convpadxy": 1,      # Conv3d padding in x/y
    "imgbias": 0.0,      # NOTE(review): not referenced by the visible code — confirm before removing
    "depth_bias": 0.2,   # shifts z before quantization (see points2grid)
    "obj_ratio": 0.8,    # x/y shrink factor so the object keeps a border margin
    "bg_clr": 0.0,       # background fill value for empty grid cells
    "resolution": 122,   # grid side length in x/y
    "depth": 8,          # number of discrete z slices
    "grid_height": 64,   # NOTE(review): not referenced by the visible code — confirm before removing
    "grid_width": 64,    # NOTE(review): not referenced by the visible code — confirm before removing
}
75
-
76
-
77
def get2DGaussianKernel(ksize, sigma=0):
    """Return a normalized 2-D Gaussian kernel of shape (ksize, ksize).

    Args:
        ksize: kernel side length.
        sigma: Gaussian standard deviation. If <= 0, it is derived from
            ksize using OpenCV's getGaussianKernel rule.

    Returns:
        torch.Tensor (float32) summing to 1.
    """
    if sigma <= 0:
        # BUGFIX: the original default sigma=0 divided by zero below and
        # produced a NaN kernel. Fall back to OpenCV's sigma-from-ksize rule.
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    center = ksize // 2
    xs = np.arange(ksize, dtype=np.float32) - center
    kernel1d = np.exp(-(xs**2) / (2 * sigma**2))
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    kernel = kernel / kernel.sum()
    return kernel
85
-
86
-
87
def get3DGaussianKernel(ksize, depth, sigma=2, zsigma=2):
    """Return a normalized 3-D Gaussian kernel of shape (depth, ksize, ksize).

    Args:
        ksize: kernel side length in x/y.
        depth: kernel extent along z.
        sigma: Gaussian sigma in x/y.
        zsigma: Gaussian sigma along z.

    Returns:
        torch.Tensor summing to 1.
    """
    kernel2d = get2DGaussianKernel(ksize, sigma)  # torch tensor (ksize, ksize)
    zs = np.arange(depth, dtype=np.float32) - depth // 2
    zkernel = torch.from_numpy(np.exp(-(zs**2) / (2 * zsigma**2)))
    # BUGFIX: the original called np.repeat on a torch tensor and then
    # torch.sum on the resulting numpy array, mixing array types. Stay in
    # torch throughout; broadcasting replaces the explicit repeat.
    kernel3d = kernel2d.unsqueeze(0) * zkernel[:, None, None]
    kernel3d = kernel3d / kernel3d.sum()
    return kernel3d
94
-
95
-
96
def euler2mat(angle):
    """Convert Euler angles to a rotation matrix R = Rx @ Ry @ Rz.

    Args:
        angle: tensor of shape (3,) for a single rotation, or (B, 3) for a
            batch; components are (x, y, z) angles in radians.

    Returns:
        (3, 3) or (B, 3, 3) rotation matrix tensor.
    """
    batched = len(angle.size()) != 1
    if batched:
        bsz, _ = angle.size()
        ax, ay, az = angle[:, 0], angle[:, 1], angle[:, 2]
        stack_dim, out_shape = 1, [bsz, 3, 3]
    else:
        ax, ay, az = angle[0], angle[1], angle[2]
        stack_dim, out_shape = 0, [3, 3]

    # Constant tensors shaped like the angle components, detached so no
    # gradient flows through them.
    zeros = az.detach() * 0
    ones = zeros + 1

    def _assemble(entries):
        # Stack nine scalar (or batch) entries and reshape into 3x3 matrices.
        return torch.stack(entries, dim=stack_dim).reshape(out_shape)

    cz, sz = torch.cos(az), torch.sin(az)
    rot_z = _assemble([cz, -sz, zeros, sz, cz, zeros, zeros, zeros, ones])

    cy, sy = torch.cos(ay), torch.sin(ay)
    rot_y = _assemble([cy, zeros, sy, zeros, ones, zeros, -sy, zeros, cy])

    cx, sx = torch.cos(ax), torch.sin(ax)
    rot_x = _assemble([ones, zeros, zeros, zeros, cx, -sx, zeros, sx, cx])

    return rot_x @ rot_y @ rot_z
118
-
119
-
120
def points2grid(points, resolution=params["resolution"], depth=params["depth"]):
    """Quantize point clouds into dense grids of shape (batch, depth, resolution, resolution).

    Args:
        points: (batch, num_points, 3) tensor of xyz coordinates.
        resolution: output grid side length in x/y.
        depth: number of discrete z slices.

    Returns:
        Grid tensor where each occupied cell holds the maximum normalized z
        value scattered into it; empty cells hold params["bg_clr"].
    """
    batch, pnum, _ = points.shape
    # Normalize each cloud to roughly [-1, 1] around its bounding-box center.
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][:, None, None]
    points = (points - pcent) / prange * 2.0
    # Shrink x/y so the object keeps a margin inside the image.
    points[:, :, :2] = points[:, :, :2] * params["obj_ratio"]
    # Map normalized coordinates to grid indices; z gets an extra bias so
    # near-camera points are not clipped to slice 0.
    _x = (points[:, :, 0] + 1) / 2 * resolution
    _y = (points[:, :, 1] + 1) / 2 * resolution
    _z = (
        ((points[:, :, 2] + 1) / 2 + params["depth_bias"])
        / (1 + params["depth_bias"])
        * (depth - 2)
    )
    _x.ceil_(), _y.ceil_()  # in-place rounding up of pixel indices
    z_int = _z.ceil()
    # Clamp to keep a 1-cell border on every side of the grid.
    _x, _y, _z = (
        torch.clip(_x, 1, resolution - 2),
        torch.clip(_y, 1, resolution - 2),
        torch.clip(_z, 1, depth - 2),
    )
    # Flatten (z, y, x) into a single linear index per point.
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    grid = (
        torch.ones([batch, depth, resolution, resolution], device=points.device).view(
            batch, -1
        )
        * params["bg_clr"]
    )
    # Scatter the (clamped) z value of every point, keeping the max per cell.
    grid = scatter(_z, coordinates.long(), dim=1, out=grid, reduce="max")
    # Swap the last two axes — presumably to match the renderer's image
    # orientation; TODO confirm against Grid2Image/consumers.
    grid = grid.reshape((batch, depth, resolution, resolution)).permute((0, 1, 3, 2))
    return grid
152
-
153
-
154
class Grid2Image(torch.nn.Module):
    """Turn a 3-D occupancy/depth grid into a 3-channel 2-D image.

    The grid is dilated with a max-pool, smoothed with a fixed 3-D Gaussian
    convolution, then flattened along the depth axis with a max.
    """

    def __init__(self):
        super().__init__()
        self.maxpool = torch.nn.MaxPool3d(
            (params["maxpoolz"], params["maxpoolxy"], params["maxpoolxy"]),
            stride=1,
            padding=(
                params["maxpoolpadz"],
                params["maxpoolpadxy"],
                params["maxpoolpadxy"],
            ),
        )
        self.conv = torch.nn.Conv3d(
            1,
            1,
            kernel_size=(params["convz"], params["convxy"], params["convxy"]),
            stride=1,
            padding=(params["convpadz"], params["convpadxy"], params["convpadxy"]),
            bias=True,
        )
        # The conv weights are overwritten with a fixed Gaussian kernel:
        # this module acts as a deterministic image filter, not a trainable
        # layer.
        kn3d = get3DGaussianKernel(
            params["convxy"],
            params["convz"],
            sigma=params["convsigmaxy"],
            zsigma=params["convsigmaz"],
        )
        self.conv.weight.data = torch.Tensor(kn3d).repeat(1, 1, 1, 1, 1)
        self.conv.bias.data.fill_(0)

    def forward(self, x):
        """Map a (batch, depth, H, W) grid to a (batch, 3, H, W) image."""
        x = self.maxpool(x.unsqueeze(1))  # add channel dim for the 3-D ops
        x = self.conv(x)
        # Collapse the depth axis, keeping the strongest response per pixel.
        img = torch.max(x, dim=2)[0]
        # Normalize each image by its own max, then invert so the background
        # becomes bright and the object dark.
        img = img / torch.max(torch.max(img, dim=-1)[0], dim=-1)[0][:, :, None, None]
        img = 1 - img
        # Replicate to 3 channels for RGB consumers downstream.
        img = img.repeat(1, 3, 1, 1)
        return img
191
-
192
-
193
class Realistic_Projection:
    """Render a point cloud into depth images from 10 fixed viewpoints.

    Each viewpoint is a hard-coded (Euler angles, translation) pair; a second
    rotation set (_views_bias) applies a small extra tilt per view.
    """

    def __init__(self):
        # Per-view [[euler angles], [translation]]; TRANS moves the cloud
        # away from the camera along z.
        _views = np.asarray([
            [[np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[np.pi, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, -np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
        ])
        # Small secondary tilt applied after the main rotation (last two
        # views get a gentler tilt).
        _views_bias = np.asarray([
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
        ])
        angle, angle2 = (
            torch.tensor(_views[:, 0, :]).float(),
            torch.tensor(_views_bias[:, 0, :]).float(),
        )
        # Transposed rotation matrices, precomputed once per process.
        self.rot_mat, self.rot_mat2 = (
            euler2mat(angle).transpose(1, 2),
            euler2mat(angle2).transpose(1, 2),
        )
        self.translation = torch.tensor(_views[:, 1, :]).float().unsqueeze(1)
        self.grid2image = Grid2Image()

    def get_img(self, points):
        """Project (batch, N, 3) points into one image per view.

        Each cloud is replicated once per view, transformed, rasterized with
        points2grid, and converted to an image by Grid2Image.
        """
        b, _, _ = points.shape
        v = self.translation.shape[0]  # number of views (10)
        _points = self.point_transform(
            torch.repeat_interleave(points, v, dim=0),
            self.rot_mat.repeat(b, 1, 1),
            self.rot_mat2.repeat(b, 1, 1),
            self.translation.repeat(b, 1, 1),
        )
        grid = points2grid(
            _points, resolution=params["resolution"], depth=params["depth"]
        ).squeeze()
        return self.grid2image(grid)

    @staticmethod
    def point_transform(points, rot_mat, rot_mat2, translation):
        """Apply both rotations, then subtract the translation.

        All operands are moved to points' device first.
        """
        rot_mat, rot_mat2, translation = (
            rot_mat.to(points.device),
            rot_mat2.to(points.device),
            translation.to(points.device),
        )
        points = torch.matmul(points, rot_mat)
        points = torch.matmul(points, rot_mat2)
        return points - translation
254
-
255
-
256
# OpenAI Service
class OpenAIService:
    """Thin async wrapper around the OpenAI chat API for image+text prompts."""

    def __init__(self):
        self.model_name = "gpt-4o"
        self.temperature = 0.3
        self.client = AsyncOpenAI(api_key=OPENAI_API_KEY)

    @staticmethod
    def encode_image(image: Union[str, np.ndarray]) -> str:
        """Base64-encode an image given as a file path or a NumPy array.

        Args:
            image: path to an image file, or an array PIL can interpret
                (e.g. HxWx3 uint8 RGB — TODO confirm callers' layout).

        Returns:
            Base64 string of the raw file bytes (path) or JPEG bytes (array).

        Raises:
            TypeError: if *image* is neither a path nor a NumPy array.
        """
        if isinstance(image, str):
            with open(image, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        elif isinstance(image, np.ndarray):
            # BUGFIX: the original called cv2.imencode, but cv2 is never
            # imported in this module, so this branch raised NameError at
            # runtime. Use the already-imported PIL to produce JPEG bytes.
            buffer = io.BytesIO()
            Image.fromarray(image).save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        raise TypeError("Input must be a file path or a NumPy array.")

    async def chat_with_image(
        self, prompt: str, image: str, retry_left: int = 3
    ) -> str:
        """Send *prompt* plus one image to the chat model.

        Retries up to *retry_left* times with a 1s pause on any API error;
        returns "" when every attempt fails.
        """
        base64_image = self.encode_image(image=image)
        model_kwargs = {
            "model": self.model_name,
            "temperature": self.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        },
                    ],
                }
            ],
        }
        try:
            response = await self.client.chat.completions.create(**model_kwargs)
            return response.choices[0].message.content
        except Exception as e:
            if retry_left > 0:
                logger.warning(f"OpenAI API failed: {e}. Retrying.")
                await asyncio.sleep(1)
                return await self.chat_with_image(prompt, image, retry_left - 1)
            logger.error(f"OpenAI API failed: {e}. Returning empty string.")
            return ""
305
-
306
-
307
# ==============================================================================
# PART 2: 3D PROCESSING LOGIC (FROM NOTEBOOK)
# ==============================================================================

# Module-level singletons: LLM chat service, multi-view projector, and the
# two embedding models (CLIP for images, OpenAI for text).
llm_service = OpenAIService()
pc_views = Realistic_Projection()
clip_embedding_model = ClipEmbedding(embed_batch_size=1536)
text_embedding_model = OpenAIEmbedding(
    mode=OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
    model="text-embedding-3-small",
    api_key=OPENAI_API_KEY,
    dimensions=1536,  # matches the CLIP batch size above; TODO confirm both must agree
)
321
 
322
 
323
# Convert STEP/FCStd files to OBJ.
def convert_step_to_obj_with_freecad(step_path, obj_path):
    """Convert a CAD file to OBJ by running FreeCAD headless.

    Args:
        step_path: input file; extension decides the conversion script
            (.step/.igs/.iges via Part, .fcstd via FreeCAD.open).
        obj_path: destination mesh path passed to Mesh.export.

    Raises:
        Exception: for unsupported extensions. FreeCAD failures are only
        logged, not raised.
    """
    freecad_executable = "/usr/bin/freecadcmd"
    _, ext = os.path.splitext(step_path)
    ext = ext.lower()
    script_template = ""
    if ext in FREECAD_LOW_LEVEL_FORMAT:
        # Import the exchange-format shape into a fresh document, then export.
        script_template = "import FreeCAD, Part, Mesh; doc = FreeCAD.newDocument(); shape = Part.read('{step}'); obj = doc.addObject('Part::Feature', 'MyPart'); obj.Shape = shape; doc.recompute(); Mesh.export([obj], '{obj}')"
    elif ext in FREECAD_NATIVE_FORMAT:
        # Open the native document and export every object that has a Shape.
        script_template = "import FreeCAD, Mesh; doc = FreeCAD.open('{step}'); to_export = [o for o in doc.Objects if hasattr(o, 'Shape')]; Mesh.export(to_export, '{obj}')"
    else:
        raise Exception(f"Unsupported format for conversion: {ext}")

    # NOTE(review): str.format on raw paths breaks for paths containing
    # braces or single quotes — confirm upload paths are sanitized upstream.
    python_script = script_template.format(step=step_path, obj=obj_path)
    command = [freecad_executable, "-c", python_script]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        logger.error(
            f"FreeCAD conversion failed for {step_path}. Stderr: {stderr.decode()}"
        )
344
 
 
 
 
 
 
 
 
 
 
345
 
 
 
346
  def convert_to_obj(file: str) -> str:
347
  if file is None:
348
  return None
349
  logger.info(f"Converting {file} to .obj")
 
350
  prefix_path, ext = os.path.splitext(file)
351
  ext = ext.lower()
352
  if ext in FREECAD_LOW_LEVEL_FORMAT + FREECAD_NATIVE_FORMAT:
@@ -355,61 +119,364 @@ def convert_to_obj(file: str) -> str:
355
  convert_step_to_obj_with_freecad(file, response_path)
356
  return response_path
357
  elif ext in GRADIO_3D_MODEL_DEFAULT_FORMAT:
358
- return file
359
- raise Exception(f"Cannot convert file type {ext}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
 
362
# Render multi-view depth images.
def render_depth_images_from_obj(obj_path: str, imsize: int = 512) -> List[np.ndarray]:
    """Render one depth image per fixed viewpoint for an OBJ mesh.

    Args:
        obj_path: path to a mesh trimesh can load.
        imsize: output image side length after bilinear resizing.

    Returns:
        List of HxWxC uint8-style numpy images, one per view.
    """
    mesh = trimesh.load_mesh(obj_path)
    # Use the raw vertices as the point cloud (no face sampling).
    points: Tensor = torch.tensor(mesh.vertices).float().unsqueeze(0)
    images: Tensor = pc_views.get_img(points)
    images = torch.nn.functional.interpolate(
        images, size=(imsize, imsize), mode="bilinear", align_corners=True
    )
    return [np.array(TF.to_pil_image(img.cpu())) for img in images]
 
 
 
371
 
372
 
373
def aggregate_images(
    np_images: list[np.ndarray], n_rows: int = 2, n_cols: int = 5
) -> np.ndarray:
    """Tile up to n_rows*n_cols equally sized images into one grid image.

    Args:
        np_images: non-empty list of images, all with identical (H, W, C)
            shape and dtype; images are placed row-major.
        n_rows: number of tile rows.
        n_cols: number of tile columns.

    Returns:
        Array of shape (H*n_rows, W*n_cols, C); unfilled tiles stay zero.

    Raises:
        ValueError: if *np_images* is empty (the original raised an opaque
            IndexError here).
    """
    if not np_images:
        raise ValueError("aggregate_images requires at least one image")
    img_h, img_w, channels = np_images[0].shape
    agg_img = np.zeros(
        (img_h * n_rows, img_w * n_cols, channels), dtype=np_images[0].dtype
    )
    for i, img in enumerate(np_images):
        row, col = i // n_cols, i % n_cols
        agg_img[row * img_h : (row + 1) * img_h, col * img_w : (col + 1) * img_w] = img
    return agg_img
 
 
 
384
 
 
385
 
386
# Prompt used to generate a description from the aggregated depth-map image.
DESCRIPTION_AGGREGATED_DEPTH_MAP_PROMPT = "You are a manufacturing expert. Given these multi-view depth maps, extract all possible special features relevant to manufacturing. Provide a detailed, structured analysis covering geometry, materials, manufacturing processes, and assembly features. If a feature is not visible, state 'Not visible'."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
 
390
async def generate_description_from_aggregated_depth_map(np_image: np.ndarray) -> str:
    """Ask the LLM to describe the aggregated multi-view depth-map image."""
    return await llm_service.chat_with_image(
        prompt=DESCRIPTION_AGGREGATED_DEPTH_MAP_PROMPT, image=np_image
    )
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
 
396
async def aget_image_embedding_from_np_image(np_image: np.ndarray):
    """Compute a CLIP image embedding for an in-memory image.

    The CLIP client only accepts file paths, so the array is written to a
    temporary PNG which is always removed afterwards.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        Image.fromarray(np_image).save(temp_file.name)
    try:
        image_embedding = await clip_embedding_model.aget_image_embedding(
            temp_file.name
        )
    finally:
        # BUGFIX: the original skipped this cleanup when the embedding call
        # raised, leaking one temp file per failure.
        os.remove(temp_file.name)
    return image_embedding
404
 
405
 
406
- # Embedding 3D Object
407
  async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
 
408
  depth_images = render_depth_images_from_obj(obj_path=obj_path)
 
409
  aggregated_image = aggregate_images(depth_images)
 
410
  description = await generate_description_from_aggregated_depth_map(
411
  np_image=aggregated_image
412
  )
 
413
  image_embedding = await aget_image_embedding_from_np_image(
414
  np_image=aggregated_image
415
  )
@@ -421,250 +488,205 @@ async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
421
  }
422
 
423
 
424
# Extract metadata from the STEP file header.
def extract_step_metadata(file_path):
    """Parse FILE_DESCRIPTION / FILE_NAME entries from a STEP file header.

    Args:
        file_path: path to a STEP file; read as text, decode errors ignored.

    Returns:
        Dict with any of "Description", "FileName", "OriginatingSystem";
        empty dict when nothing matched or the file could not be read.
    """
    metadata = {}
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        # FILE_DESCRIPTION(('...'), '...') — first group is the description list.
        desc_match = re.search(
            r"FILE_DESCRIPTION\s*\(\s*\((.*?)\),\s*'(.*?)'\);", content, re.DOTALL
        )
        if desc_match:
            metadata["Description"] = desc_match.group(1).replace("'", "")
        # FILE_NAME('name', 'timestamp', ..., 'originating_system', ...)
        name_match = re.search(
            r"FILE_NAME\s*\(\s*'(.*?)',.*?,'(.*?)'", content, re.DOTALL
        )
        if name_match:
            metadata["FileName"], metadata["OriginatingSystem"] = (
                name_match.group(1),
                name_match.group(2),
            )
    except Exception as e:
        # Best-effort: a missing/corrupt file just yields empty metadata.
        logger.error(f"Failed to read STEP file: {e}")
    return metadata
446
-
447
-
448
def dict_to_markdown(metadata: dict) -> str:
    """Render a metadata dict as one "key: value" line per entry.

    BUGFIX: the original joined entries with "\\n" (a literal backslash-n),
    so the UI showed the two characters "\\n" instead of a line break; join
    with a real newline instead.
    """
    return "\n".join(f"{key}: {value}" for key, value in metadata.items())
450
 
451
 
452
def parse_3d_file(original_filepath: str):
    """Return a human-readable metadata string for *original_filepath*.

    Only STEP files have a metadata parser today; any other type (or a
    missing selection) yields a placeholder message.
    """
    if original_filepath is None:
        return "No file selected."

    is_step_file = original_filepath.lower().endswith(".step")
    if not is_step_file:
        logger.warning(f"No metadata parser for file {original_filepath}")
        return "No metadata found."

    meta = extract_step_metadata(original_filepath)
    if meta:
        return dict_to_markdown(meta)
    return "No metadata found in STEP file."
460
-
461
-
462
# ==============================================================================
# PART 3: GRADIO APP LOGIC
# ==============================================================================


async def accumulate_and_embedding(input_files, file_list, embedding_dict):
    """Gradio upload handler: embed newly uploaded files, refresh the dropdown.

    Args:
        input_files: file object(s) from the gr.File component.
        file_list: previously processed file objects (gr.State).
        embedding_dict: obj_path -> embedding data mapping (gr.State).

    Returns:
        (all uploaded files, dropdown update selecting the last file,
        updated embedding dict).
    """
    if not isinstance(input_files, list):
        input_files = [input_files]
    # Only embed files we have not seen in a previous upload.
    new_files = [
        f.name for f in input_files if f.name not in [fi.name for fi in file_list]
    ]

    for file_path in new_files:
        logger.info(f"Processing new upload file: {file_path}")
        try:
            obj_path = convert_to_obj(file_path)
            embeddings = await embedding_3d_object(obj_path)
            if obj_path not in embedding_dict:
                embedding_dict[obj_path] = {}
            embedding_dict[obj_path].update(embeddings)
        except Exception as e:
            # Best-effort: skip the failing file and warn the user in the UI.
            logger.error(f"Failed to process {file_path}: {e}")
            gr.Warning(f"Could not process file: {os.path.basename(file_path)}")

    all_file_paths = [f.name for f in input_files]
    return (
        input_files,
        gr.update(
            choices=all_file_paths, value=all_file_paths[-1] if all_file_paths else None
        ),
        embedding_dict,
    )
494
-
495
-
496
def render_3D_object(filepath) -> Tuple[str, str]:
    """Convert *filepath* to a Gradio-renderable OBJ.

    Returns:
        (obj_path, original_filepath), or (None, None) when no file is
        selected or the conversion fails.
    """
    if not filepath:
        return None, None
    try:
        obj_path = convert_to_obj(filepath)
        return obj_path, filepath
    except Exception as e:
        logger.error(f"Failed to render {filepath}: {e}")
        gr.Warning(f"Could not render file: {os.path.basename(filepath)}")
        return None, None
506
 
 
 
507
 
508
def render_3D_metadata(
    original_filepath: str, obj_path: str, embedding_dict: dict
) -> Tuple[str, str]:
    """Fetch the metadata string and AI description for the selected file.

    Returns:
        (metadata markdown, description), with placeholder text when the
        file is missing or the description has not been generated yet.
    """
    if not original_filepath or not obj_path:
        return "No file selected.", "No description found."
    metadata = parse_3d_file(original_filepath=original_filepath)
    description = embedding_dict.get(obj_path, {}).get(
        "description", "Description not generated yet."
    )
    return metadata, description
518
 
 
519
 
520
def find_top_k_similar(query_embedding, embedding_dict, key, top_k=4):
    """Rank stored embeddings by cosine similarity to *query_embedding*.

    Args:
        query_embedding: 1-D embedding vector of the query.
        embedding_dict: path -> data mapping; only entries containing *key*
            participate in the search.
        key: which embedding field to compare ("image_embedding" / "text_embedding").
        top_k: number of results to return.

    Returns:
        Flat list: top_k file paths (padded with None) followed by their
        top_k basenames (padded with "-").
    """
    candidates = {
        path: data[key] for path, data in embedding_dict.items() if key in data
    }
    if not candidates:
        gr.Warning("No embeddings available for search.")
        return [None] * top_k + ["-"] * top_k

    paths = list(candidates.keys())
    matrix = np.array(list(candidates.values()))
    sims = cosine_similarity(query_embedding.reshape(1, -1), matrix)[0]

    # Best matches first; sort is stable so ties keep insertion order.
    ranked = sorted(zip(paths, sims), key=lambda pair: pair[1], reverse=True)[:top_k]

    top_paths = [path for path, _ in ranked]
    top_names = [os.path.basename(path) for path, _ in ranked]
    # Pad to exactly top_k slots so the fixed UI components always fill.
    padding = top_k - len(top_paths)
    return top_paths + [None] * padding + top_names + ["-"] * padding
542
 
543
 
544
def search_3D_similarity(filepath: str, embedding_dict: dict, top_k: int = 4):
    """Shape search: rank stored objects against *filepath*'s image embedding.

    Returns:
        top_k paths + top_k basenames (see find_top_k_similar), or
        placeholder values when the selected file has no embedding yet.
    """
    if (
        not filepath
        or filepath not in embedding_dict
        or "image_embedding" not in embedding_dict[filepath]
    ):
        gr.Warning("Please select a file with a generated embedding first.")
        return [None] * top_k + ["-"] * top_k

    query_embedding = np.array(embedding_dict[filepath]["image_embedding"])
    # Exclude the query file itself from the search
    search_dict = {k: v for k, v in embedding_dict.items() if k != filepath}
    return find_top_k_similar(query_embedding, search_dict, "image_embedding", top_k)
557
-
558
-
559
def query_3D_object(query: str, embedding_dict: dict, top_k: int = 4):
    """Text search: rank stored objects against a free-text query.

    Returns:
        top_k paths + top_k basenames (see find_top_k_similar), or
        placeholder values when the query/store is empty.
    """
    if not query.strip():
        gr.Warning("Query cannot be empty!")
        return [None] * top_k + ["-"] * top_k
    if len(embedding_dict) < 1:
        gr.Warning("Please upload and process at least one 3D file.")
        return [None] * top_k + ["-"] * top_k

    query_embedding = np.array(text_embedding_model.get_text_embedding(text=query))
    # Compare against the text_embedding field (not the CLIP image one).
    return find_top_k_similar(query_embedding, embedding_dict, "text_embedding", top_k)
569
 
570
 
571
- # ==============================================================================
572
- # PHẦN 4: GIAO DIỆN GRADIO
573
- # ==============================================================================
574
-
575
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
576
- gr.Markdown("# 🚀 Demo Tìm kiếm và Truy vấn CAD 3D")
577
- gr.Markdown(
578
- "Tải lên các file 3D (STEP, FCStd, OBJ, etc.), hệ thống sẽ tự động 'hiểu' và cho phép bạn tìm kiếm theo hình dạng hoặc mô tả văn bản."
 
579
  )
580
-
581
- # State variables
582
- file_state = gr.State([])
583
- embedding_store = gr.State({})
584
-
585
- with gr.Row():
586
- with gr.Column(scale=1):
587
- file_input = gr.File(
588
- file_count="multiple",
589
- label="1. Tải lên File 3D",
590
- file_types=VALID_FILE_TYPES,
591
- )
592
- file_dropdown = gr.Dropdown(
593
- label="2. Chọn File để xem và tìm kiếm", interactive=True
594
- )
595
- sim_button = gr.Button("🔍 Tìm kiếm Tương tự", variant="primary")
596
- query_input = gr.Textbox(
597
- label="Hoặc, truy vấn bằng văn bản",
598
- placeholder="ví dụ: một bộ phận có hai lỗ xuyên...",
599
- )
600
- query_button = gr.Button("💬 Tìm kiếm theo Văn bản", variant="primary")
601
-
602
- with gr.Column(scale=2):
603
- gr.Markdown("### **Trình xem và Thông tin Chi tiết**")
604
- model_render = gr.Model3D(label="Mô hình 3D", height=400, interactive=False)
605
- model_hidden_filepath = gr.Textbox(visible=False)
606
- original_hidden_filepath = gr.Textbox(visible=False)
607
- with gr.Accordion("📝 Mô tả & Metadata", open=False):
608
- description_render = gr.Textbox(label="Mô tả (tạo bởi AI)", lines=8)
609
- metadata_render = gr.Textbox(
610
- label="Metadata (trích xuất từ file)", lines=4
611
- )
612
-
613
  with gr.Row():
614
- gr.Markdown("---")
615
- gr.Markdown("### **Kết quả Tìm kiếm**")
 
 
 
 
 
 
616
 
617
  with gr.Row():
618
  with gr.Column():
619
- gr.Markdown("#### Tương tự về Hình dạng")
 
 
620
  with gr.Row():
621
- model_s_1 = gr.Model3D(label="Top 1", interactive=False)
622
- model_s_2 = gr.Model3D(label="Top 2", interactive=False)
 
 
 
 
 
 
 
623
  with gr.Row():
624
- model_s_3 = gr.Model3D(label="Top 3", interactive=False)
625
- model_s_4 = gr.Model3D(label="Top 4", interactive=False)
 
 
 
 
 
626
  with gr.Column():
627
- gr.Markdown("#### Tương tự về Văn bản")
 
 
 
 
628
  with gr.Row():
629
- model_q_1 = gr.Model3D(label="Top 1", interactive=False)
630
- model_q_2 = gr.Model3D(label="Top 2", interactive=False)
 
 
 
 
631
  with gr.Row():
632
- model_q_3 = gr.Model3D(label="Top 3", interactive=False)
633
- model_q_4 = gr.Model3D(label="Top 4", interactive=False)
 
 
 
 
 
 
 
 
634
 
635
- # Event Handlers
636
- file_input.upload(
637
  fn=accumulate_and_embedding,
638
  inputs=[file_input, file_state, embedding_store],
639
  outputs=[file_state, file_dropdown, embedding_store],
640
  )
641
-
642
- file_dropdown.change(
643
- fn=render_3D_object,
644
- inputs=file_dropdown,
645
- outputs=[model_render, original_hidden_filepath],
646
- ).then(
647
- fn=render_3D_metadata,
648
- inputs=[original_hidden_filepath, model_render, embedding_store],
649
- outputs=[metadata_render, description_render],
 
 
 
 
 
650
  )
651
-
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  sim_button.click(
653
- fn=search_3D_similarity,
654
- inputs=[model_render, embedding_store],
655
- outputs=[
656
  model_s_1,
657
  model_s_2,
658
  model_s_3,
659
  model_s_4,
660
- ], # Chỉ cần cập nhật model, không cần button
 
 
 
 
661
  )
662
-
663
- query_button.click(
664
- fn=query_3D_object,
665
- inputs=[query_input, embedding_store],
666
- outputs=[model_q_1, model_q_2, model_q_3, model_q_4],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  )
668
 
669
  if __name__ == "__main__":
670
- demo.launch()
 
 
 
 
1
  import asyncio
2
+ import os
3
+ import platform
4
  import random
5
+ import subprocess # used to connect to FreeCAD via terminal sub process
6
+ import sys
 
 
7
  import tempfile
8
+ from typing import Any, Dict, List, Tuple
 
9
 
10
+ import gradio as gr # demo with gradio
 
11
  import numpy as np
12
  import torch
13
+ import torchvision.transforms.functional as TF
14
+ import trimesh
15
+ from llama_index.embeddings.clip import ClipEmbedding
16
+ from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode
17
  from loguru import logger
18
+ from PIL import Image
19
  from sklearn.metrics.pairwise import cosine_similarity
20
  from torch import Tensor
 
 
21
 
22
+ from llm_service import LLMService
23
+ from mv_utils_zs import Realistic_Projection
24
 
25
# BUGFIX: the original called os.environ.get(...) and discarded the result,
# which had no effect; setdefault actually applies the fallback cache dir.
os.environ.setdefault("GRADIO_TEMP_DIR", "gradio_cache")  # You must set it in `.env` file also
os_name = platform.system()

# FreeCAD conversion paths below only cover Linux and macOS.
if os_name == "Linux":
    print("Running on Linux")
elif os_name == "Darwin":
    print("Running on macOS")
else:
    print(f"Running on an unsupported OS: {os_name}")
34
 
35
# The Gradio 3D Model component default accept
GRADIO_3D_MODEL_DEFAULT_FORMAT = [".obj", ".glb", ".gltf", ".stl", ".splat", ".ply"]
# Extra formats requested by users (converted before rendering).
USER_REQUIRE_FORMAT = [".3dxml", ".step"]
# Exchange formats FreeCAD can read directly.
FREECAD_LOW_LEVEL_FORMAT = [".step", ".igs", ".iges"]
# FreeCAD's native document format.
FREECAD_NATIVE_FORMAT = [".fcstd"]
# Defaults to "" when unset; OpenAI-backed features then cannot authenticate.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

####################################################################################################################
# Transform high-level to low-level
####################################################################################################################
# 3D Component of Gradio only allow some kind of format to render in the UI. We need to transform if need it.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
 
48
def convert_step_to_obj_with_freecad(step_path, obj_path):
    """Convert a CAD file to OBJ by running FreeCAD in headless mode.

    Args:
        step_path: input file; .step/.igs/.iges go through Part.read,
            .fcstd through FreeCAD.open.
        obj_path: destination mesh path passed to Mesh.export.

    Returns:
        (stdout, stderr) of the FreeCAD subprocess, decoded as text.

    Raises:
        Exception: for unsupported extensions or an unsupported OS.
    """
    # Path to the FreeCAD executable, per platform.
    if os_name == "Linux":
        freecad_executable = "/usr/bin/freecadcmd"  # freecadcmd
    elif os_name == "Darwin":
        freecad_executable = "/Applications/FreeCAD.app/Contents/MacOS/FreeCAD"
    else:
        # BUGFIX: previously this fell through with freecad_executable
        # unbound, raising UnboundLocalError much later (e.g. on Windows).
        logger.error(f"FreeCAD conversion is not supported on {os_name}")
        raise Exception(f"FreeCAD conversion is not supported on {os_name}")

    # Python script to be executed by FreeCAD
    _, ext = os.path.splitext(step_path)
    ext = ext.lower()
    if ext in FREECAD_LOW_LEVEL_FORMAT:
        python_script = """
import FreeCAD
import Part
import Mesh

doc = FreeCAD.newDocument()
shape = Part.read("{step_path}")
obj = doc.addObject("Part::Feature", "MyPart")
obj.Shape = shape
doc.recompute()

Mesh.export([obj], "{obj_path}")
""".format(step_path=step_path, obj_path=obj_path)
    elif ext in FREECAD_NATIVE_FORMAT:
        python_script = """
import FreeCAD
import Part
import Mesh

doc = FreeCAD.open("{step_path}")
to_export = [o for o in doc.Objects if hasattr(o, 'Shape')]
Mesh.export(to_export, "{obj_path}")
""".format(step_path=step_path, obj_path=obj_path)
    else:
        logger.error(f"Not support {ext} format")
        raise Exception(f"Not support {ext} format")

    # Command to run FreeCAD in headless mode with the provided Python script
    command = [freecad_executable, "-c", python_script]

    # Run the command using subprocess and capture the output and errors.
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        # Surface conversion failures in the logs; callers still receive stderr.
        logger.error(f"FreeCAD conversion failed for {step_path}: {stderr.decode()}")
    return stdout.decode(), stderr.decode()
95
+
 
 
96
 
97
+ # input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/Switches/TS6-THT_H-5.0.step" # ok
98
+ # input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/engrenagens-5.snapshot.6/Engre_con_Z16_mod_1_5-Body.stl" # ok
99
+ # input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/nema-17-stepper-motors-coaxial-60-48-39-23mm-1.snapshot.3/NEMA 17 Stepper Motor 23mm-NEMA 17 Stepper Motor 23mm.step" # ok
100
+ # input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/engrenagens-5.snapshot.6/Engre_con_Z16_mod_1_5.FCStd" # ok
101
+ # input_path = "/Users/tridoan/Spartan/Datum/service-ai/poc/resources/notebooks/3d_files/engrenagens-5.snapshot.6/Engre_reta_Z_15_mod_1.FCStd" # ok
102
+ # input_path = "/content/TS6-THT_H-5.0.step"
103
+ # print(".".join(input_path.split(".")[:-1]) + ".obj")
104
+ # stdout, stderr = convert_step_to_obj_with_freecad(input_path, ".".join(input_path.split(".")[:-1]) + ".obj")
105
+ # stderr
106
 
107
+
108
+ # Dummy converter from STEP/3DXML to OBJ (replace with real converter)
109
  def convert_to_obj(file: str) -> str:
110
  if file is None:
111
  return None
112
  logger.info(f"Converting {file} to .obj")
113
+ response_path = file
114
  prefix_path, ext = os.path.splitext(file)
115
  ext = ext.lower()
116
  if ext in FREECAD_LOW_LEVEL_FORMAT + FREECAD_NATIVE_FORMAT:
 
119
  convert_step_to_obj_with_freecad(file, response_path)
120
  return response_path
121
  elif ext in GRADIO_3D_MODEL_DEFAULT_FORMAT:
122
+ return response_path
123
+ else:
124
+ logger.warning(f"Do nothing at convert_to_obj with file {file}")
125
+ raise Exception(f"Do nothing at convert_to_obj with file {file}")
126
+
127
+
128
+ ####################################################################################################################
129
+ # Feature Extraction
130
+ ####################################################################################################################
131
+ # We have 2 approaches to extract 3D's features:
132
+ # - By algorithm which extract something like volume, surface
133
+ # - By 3D deep learning model, which embed the 3D object into vector representing 3D's features
134
+
135
+
136
def extract_geometric_features(obj_path: str):  # deprecated
    """Extract simple geometric features (volume, surface area) from a mesh.

    Deprecated in favor of the embedding-based pipeline; kept for reference.

    Args:
        obj_path: path to a mesh trimesh can load.

    Returns:
        numpy array of shape (1, 2) holding [volume, surface_area], or None
        when the file cannot be read.
    """
    try:
        mesh = trimesh.load(obj_path)
        volume = mesh.volume  # type: ignore
        surface_area = mesh.area  # type: ignore
        # Use the module logger instead of bare print() for consistency with
        # the rest of the file.
        logger.debug(f"volume={volume} surface_area={surface_area}")
        # Add other features depending on your needs
        features = np.array([volume, surface_area]).reshape(1, -1)
        return features
    except Exception as e:
        logger.error(f"Error reading file {obj_path}: {e}")
        return None
149
+
150
+
151
+ ####################################################################################################################
152
+ # Similarity Search
153
+ ####################################################################################################################
154
+
155
+
156
def search_3D_similarity(filepath: str, embedding_dict: dict, top_k: int = 4):
    """Find the ``top_k`` most visually similar 3D objects to ``filepath``.

    Similarity is the cosine similarity between the stored multi-view
    depth-map image embeddings.

    Args:
        filepath: Key of the query object inside ``embedding_dict``.
        embedding_dict: In-memory store mapping obj path -> embedding data.
        top_k: Number of neighbours to return.

    Returns:
        A flat list of ``2 * top_k`` items: the top_k neighbour file paths
        followed by their basenames (consumed by paired Gradio outputs).
        Missing slots are padded with ``""``.

    Raises:
        gr.Error: When fewer than 5 files have been uploaded.
        ValueError: When the query object has no stored image embedding.
    """
    if len(embedding_dict) < 5:
        raise gr.Error("Require at least 5 3D files to search similarity")
    if (
        filepath not in embedding_dict
        or "image_embedding" not in embedding_dict[filepath]
    ):
        raise ValueError(f"No embedding found for {filepath}")

    query_vec = np.array(
        embedding_dict[filepath]["image_embedding"], dtype=float
    ).ravel()

    # Candidates: every other object that has an image embedding.
    valid_items = [
        (fp, data["image_embedding"])
        for fp, data in embedding_dict.items()
        if "image_embedding" in data and fp != filepath
    ]
    filepaths = [fp for fp, _ in valid_items]
    feature_matrix = np.array([feat for _, feat in valid_items], dtype=float)  # (N, D)

    # Cosine similarity via numpy (removes the sklearn dependency); guard
    # against zero-norm vectors to avoid division by zero.
    norms = np.linalg.norm(feature_matrix, axis=1) * np.linalg.norm(query_vec)
    similarities = feature_matrix @ query_vec / np.where(norms == 0.0, 1.0, norms)

    scores = sorted(zip(filepaths, similarities), key=lambda x: x[1], reverse=True)

    # Pad until exactly top_k entries exist.  The original code appended at
    # most ONE placeholder, which was not enough when several candidates
    # were missing embeddings, leaving the Gradio outputs mismatched.
    while len(scores) < top_k:
        scores.append(("", 0.0))

    return [x[0] for x in scores[:top_k]] + [
        os.path.basename(x[0]) for x in scores[:top_k]
    ]
+ ]
188
+
189
+
190
+ ####################################################################################################################
191
+ # Text-based Query
192
+ ####################################################################################################################
193
+
194
+
195
def query_3D_object(query: str, embedding_dict: dict, top_k: int = 4):
    """Rank stored 3D objects against a natural-language query.

    The query is embedded with the OpenAI text-embedding model and compared
    (cosine similarity) against the stored text embeddings of each object's
    generated description.

    Args:
        query: Free-text search query; must be non-empty.
        embedding_dict: In-memory store mapping obj path -> embedding data.
        top_k: Number of results to return.

    Returns:
        A flat list of ``2 * top_k`` items: the top_k file paths followed
        by their basenames (consumed by paired Gradio outputs).  Missing
        slots are padded with ``""``.

    Raises:
        gr.Error: On an empty query or fewer than 4 uploaded files.
    """
    if query == "":
        raise gr.Error("Query cannot be empty!")
    if len(embedding_dict) < 4:
        raise gr.Error("Require at least 4 3D files to query by features")

    query_vec = np.array(
        text_embedding_model.get_text_embedding(text=query), dtype=float
    ).ravel()

    # Candidates: every object that has a text embedding stored.
    valid_items = [
        (fp, data["text_embedding"])
        for fp, data in embedding_dict.items()
        if "text_embedding" in data
    ]
    filepaths = [fp for fp, _ in valid_items]
    feature_matrix = np.array([feat for _, feat in valid_items], dtype=float)  # (N, D)

    # Cosine similarity via numpy; guard against zero-norm vectors.
    norms = np.linalg.norm(feature_matrix, axis=1) * np.linalg.norm(query_vec)
    similarities = feature_matrix @ query_vec / np.where(norms == 0.0, 1.0, norms)

    scores = sorted(zip(filepaths, similarities), key=lambda x: x[1], reverse=True)

    # Pad until exactly top_k entries exist.  The original code appended at
    # most ONE placeholder, which left the paired Gradio outputs short when
    # several entries were missing text embeddings.
    while len(scores) < top_k:
        scores.append(("", 0.0))

    return [x[0] for x in scores[:top_k]] + [
        os.path.basename(x[0]) for x in scores[:top_k]
    ]
+ ]
226
+
227
+
228
+ ####################################################################################################################
229
+ # Metadata Extraction
230
+ ####################################################################################################################
231
+
232
+ import os
233
+ import xml.etree.ElementTree as ET
234
+ import zipfile
235
+
236
+
237
def extract_header_from_3dxml(file_path):
    """Read the ``<Header>`` element of a .3DXML archive.

    A .3DXML file is a zip archive of XML documents.  The archive is
    unpacked into a throwaway temporary directory, every ``*.xml`` /
    ``*.3dxml`` member is parsed, and the children of any ``<Header>``
    element found are collected as ``{tag: text}``.

    Note: the original implementation extracted into a fixed
    ``tmp_3dxml_extract`` folder that was never cleaned up, so stale files
    from a previously-inspected archive could leak into the result; using
    a per-call temp directory fixes both the leak and the contamination.
    """
    header_info = {}

    with tempfile.TemporaryDirectory() as extract_dir:
        # Step 1: Unzip the .3DXML file.
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        # Step 2: Find and parse the XML containing <Header>.
        for root, dirs, files in os.walk(extract_dir):
            for file in files:
                if file.endswith((".3dxml", ".xml")):
                    xml_path = os.path.join(root, file)
                    try:
                        tree = ET.parse(xml_path)
                        root_el = tree.getroot()
                        # Namespace comes from the root tag, e.g.
                        # "{http://...}Model" -> "http://...".
                        ns = {"ns": root_el.tag.split("}")[0].strip("{")}

                        header = root_el.find("ns:Header", ns)
                        if header is not None:
                            for child in header:
                                tag = child.tag.split("}")[-1]  # drop namespace
                                value = child.text.strip() if child.text else ""
                                header_info[tag] = value
                    except Exception as e:
                        print(f"Failed to parse {file}: {e}")

    return header_info
266
+
267
+
268
+ #######################################################################################################################
269
+
270
+ import re
271
+
272
+
273
def extract_step_metadata(file_path):
    """Pull header metadata out of a STEP (ISO 10303-21) file.

    STEP files carry a plain-text HEADER section with FILE_DESCRIPTION,
    FILE_NAME and FILE_SCHEMA records; each is matched with a regular
    expression.  Records that cannot be matched are simply omitted.

    Returns:
        dict mapping metadata keys (Description, FileName, Authors,
        Schema, ...) to their string values; empty on read failure.
    """
    metadata = {}

    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as fh:
            text = fh.read()

        description = re.search(
            r"FILE_DESCRIPTION\s*\(\s*\((.*?)\),\s*\'(.*?)\'\);", text, re.DOTALL
        )
        if description:
            metadata["Description"] = description.group(1).replace("'", "")
            metadata["Description_Level"] = description.group(2)

        file_name = re.search(
            r"FILE_NAME\s*\(\s*'(.*?)',\s*'(.*?)',\s*\((.*?)\),\s*\((.*?)\),\s*'(.*?)',\s*'(.*?)',\s*'(.*?)'\s*\);",
            text,
            re.DOTALL,
        )
        if file_name:
            # Group index -> output key; the two list-valued groups need
            # their single quotes stripped.
            for key, idx in (
                ("FileName", 1),
                ("Created", 2),
                ("Authors", 3),
                ("Organizations", 4),
                ("Preprocessor", 5),
                ("OriginatingSystem", 6),
                ("Authorization", 7),
            ):
                value = file_name.group(idx)
                if key in ("Authors", "Organizations"):
                    value = value.replace("'", "")
                metadata[key] = value

        schema = re.search(r"FILE_SCHEMA\s*\(\s*\((.*?)\)\s*\);", text, re.DOTALL)
        if schema:
            metadata["Schema"] = schema.group(1).replace("'", "")

    except Exception as e:
        logger.error(f"Failed to read STEP file: {e}")

    return metadata
314
+
315
+
316
+ #######################################################################################################################
317
+
318
+
319
def dict_to_markdown(metadata: dict) -> str:
    """Render a metadata dict as newline-separated "key: value" lines."""
    lines = []
    for key, value in metadata.items():
        lines.append(f"{key}: {value}")
    return "\n".join(lines)
321
+
322
+
323
+ #######################################################################################################################
324
+
325
+
326
+ # Dummy parser - Replace with real parser
327
def parse_3d_file(original_filepath: str):
    """Extract printable metadata from a 3D CAD file.

    Supports .3dxml archives and .step files.  Extension matching is
    case-insensitive (the original only recognised the exact lower/upper
    spellings, missing mixed-case names such as "part.Step").

    Returns:
        A human-readable string; "No file" when no path is given and
        "No metadata found!" for unsupported formats.
    """
    if original_filepath is None:
        return "No file"
    ext = os.path.splitext(original_filepath)[1].lower()
    if ext == ".3dxml":
        meta = extract_header_from_3dxml(original_filepath)
        return f"Parsed metadata: {dict_to_markdown(meta)}"
    if ext == ".step":
        meta = extract_step_metadata(original_filepath)
        return f"Parsed metadata: {dict_to_markdown(meta)}"
    logger.warning(f"No metadata found in the file {original_filepath}")
    return "No metadata found!"
340
+
341
+
342
def render_3D_metadata(
    original_filepath: str, obj_path: str, embedding_dict: dict
) -> Tuple[str, str]:
    """Return (parsed file metadata, stored LLM description) for the UI panes."""
    metadata_text = parse_3d_file(original_filepath=original_filepath)
    description = embedding_dict.get(obj_path, {}).get(
        "description", "No description found!"
    )
    return metadata_text, description
348
+
349
+
350
+ #######################################################################################################################
351
+ # https://github.com/yangyangyang127/PointCLIP_V2/blob/main/zeroshot_cls/trainers/zeroshot.py#L64
352
+ #######################################################################################################################
353
+
354
+
355
+ pc_views = Realistic_Projection()
356
 
357
 
 
358
def render_depth_images_from_obj(obj_path: str, imsize: int = 512) -> List[np.ndarray]:
    """Render the mesh's vertex cloud into multi-view depth images.

    The points are projected by the PointCLIP-V2 realistic projection
    (``pc_views``) and each rendered view is upsampled to imsize x imsize.
    """
    mesh = trimesh.load_mesh(obj_path)
    cloud: Tensor = torch.tensor(mesh.vertices).float()
    if cloud.ndim == 2:
        cloud = cloud.unsqueeze(0)  # add batch axis -> (1, N, 3)
    rendered: Tensor = pc_views.get_img(cloud)
    rendered = torch.nn.functional.interpolate(
        rendered, size=(imsize, imsize), mode="bilinear", align_corners=True
    )
    return [np.array(TF.to_pil_image(view.cpu())) for view in rendered]
371
 
372
 
373
def aggregate_images(
    np_images: list[np.ndarray], n_rows: int = 2, n_cols: int = 5
) -> np.ndarray:
    """Tile the per-view images into one n_rows x n_cols mosaic.

    All images must share the first image's height/width/channels; tiles
    are placed row-major in list order.
    """
    tile_h, tile_w = np_images[0].shape[:2]
    canvas = np.zeros(
        (tile_h * n_rows, tile_w * n_cols, np_images[0].shape[2]),
        dtype=np_images[0].dtype,
    )

    for idx, tile in enumerate(np_images):
        r, c = divmod(idx, n_cols)
        canvas[r * tile_h : (r + 1) * tile_h, c * tile_w : (c + 1) * tile_w] = tile

    return canvas
391
 
392
+
393
+ llm_service = LLMService.from_partner()
394
+ # llm_service.model_name = "o3-mini"
395
+
396
+ DESCRIPTION_AGGREGATED_DEPTH_MAP_PROMPT = """You are a manufacturing expert analyzing 3D objects for production purposes. Given a set of multi-view depth maps of a single object, extract all possible special features relevant to manufacturing.
397
+
398
+ Your output must follow the structured format provided below and be as complete and specific as possible, even if some features are inferred or uncertain.
399
+ ```
400
+ 🔎 Extracted Manufacturing Features from Depth Maps
401
+
402
+ 1. Geometric Features
403
+ Dimensions: <!-- List key dimensions such as height, width, depth, thickness, or aspect ratios. Use units if possible. Mention estimated ranges if exact values are unclear. -->
404
+ Notable Shapes: <!-- Describe the overall shape and form (e.g., cylindrical body with a tapered end, flat rectangular base, spherical top). Mention symmetry or irregularities. -->
405
+ Holes: <!-- Count and describe hole types (e.g., through-holes, blind holes), location if visible, and their arrangement or pattern (e.g., circular array, linear slot). -->
406
+ Surface Features: <!-- Include textures, fillets, chamfers, ribs, grooves, steps, and engravings. Identify raised or recessed areas that are not part of the base shape. -->
407
+ Other: <!-- Any other geometric characteristics not covered above (e.g., draft angles, deformation, cutouts). -->
408
+
409
+ 2. Material-Related Inferences
410
+ Likely Material: <!-- Infer from shape, thickness, or typical use cases (e.g., plastic, aluminum, cast iron). State if uncertain or not visible. -->
411
+ Surface Texture: <!-- Describe the expected finish (e.g., rough, matte, polished) based on depth gradients or edge sharpness. -->
412
+ Durability Hints: <!-- Mention any features that suggest mechanical strength or wear resistance (e.g., thick load-bearing sections, reinforcement patterns). -->
413
+
414
+ 3. Manufacturing-Related Features
415
+ Manufacturing Process: <!-- Suggest most likely processes (e.g., injection molding, CNC milling, casting) based on geometry and typical industry practices. -->
416
+ Draft Angles: <!-- Indicate presence and estimate angles if the object appears designed for mold release. -->
417
+ Undercuts: <!-- Identify any undercut areas that may require complex tooling or multi-part molds. -->
418
+ Mold Flow Considerations: <!-- Comment on how the material might flow during molding or casting, and whether the geometry supports or hinders it. -->
419
+
420
+ 4. Functional and Assembly Features
421
+ Mounting Points: <!-- Identify places where fasteners or brackets might attach (e.g., holes, bosses, flanges). -->
422
+ Jointing Features: <!-- Describe features used to join with other parts, such as snap fits, tabs, slots, dovetails, etc. -->
423
+ Alignment Aids: <!-- Note features like pins, grooves, or guide rails that help align components during assembly. -->
424
+ Modularity: <!-- Assess whether the object is likely part of a modular system based on interface shapes or repeated features. -->
425
+
426
+ 5. Inspection and Quality Features
427
+ Critical Dimensions: <!-- Highlight any dimensions likely to be functionally critical or require tight tolerance. -->
428
+ Surface Finish Zones: <!-- Point out areas that may require fine finishing or polishing for performance or cosmetic reasons. -->
429
+ Datums: <!-- Indicate flat surfaces or edges likely to serve as reference datums during measurement or machining. -->
430
+ Tolerances: <!-- Mention if any tolerances can be inferred, e.g., tight fits, loose clearances, or any standard class assumptions. -->
431
+
432
+ ```
433
+ If any feature cannot be determined from the depth maps, state “Not visible” or “Cannot be inferred.”
434
+ Use clear technical vocabulary appropriate for manufacturing and quality control."""
435
 
436
 
437
async def generate_description_from_aggregated_depth_map(np_image: np.ndarray) -> str:
    """Ask the LLM to describe manufacturing features of the depth-map mosaic."""
    encoded = llm_service.encode_image(image=np_image)
    return await llm_service.chat_with_image(
        prompt=DESCRIPTION_AGGREGATED_DEPTH_MAP_PROMPT, image=encoded
    )
441
+
442
+
443
+ clip_embedding_model = ClipEmbedding(
444
+ embed_batch_size=1536, # this parameter does not effect to the model
445
+ )
446
+ text_embedding_model = OpenAIEmbedding(
447
+ mode=OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
448
+ model="text-embedding-3-small",
449
+ api_key=OPENAI_API_KEY,
450
+ dimensions=1536,
451
+ embed_batch_size=512, # default == 100
452
+ )
453
 
454
 
455
async def aget_image_embedding_from_np_image(np_image: np.ndarray):
    """Embed a rendered image with CLIP.

    The CLIP embedder only accepts file paths, so the array is written to a
    temporary PNG first.  The temp file is always removed — including when
    the embedding call fails (the original only deleted it on success,
    leaking one PNG per failed request).
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_file_path = temp_file.name
        # Convert the array to a PIL Image and persist it for the embedder.
        Image.fromarray(np_image).save(temp_file_path)
    try:
        return await clip_embedding_model.aget_image_embedding(temp_file_path)
    finally:
        os.remove(temp_file_path)
468
 
469
 
 
470
  async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
471
+ # get 10 depth images
472
  depth_images = render_depth_images_from_obj(obj_path=obj_path)
473
+ # aggregate to single image
474
  aggregated_image = aggregate_images(depth_images)
475
+ # description
476
  description = await generate_description_from_aggregated_depth_map(
477
  np_image=aggregated_image
478
  )
479
+ # embedding aggregated_image: np.ndarray and description: str
480
  image_embedding = await aget_image_embedding_from_np_image(
481
  np_image=aggregated_image
482
  )
 
488
  }
489
 
490
 
491
# Local directory holding demo models.  The sample list ships empty so the
# app starts with no preloaded files; uncomment entries to seed the dropdown.
BASE_SAMPLE_DIR = "/Users/tridoan/Spartan/Datum/service-ai/poc/3D/gradio_cache/"
sample_files = [
    # BASE_SAMPLE_DIR + "C5 Knuckle Object.obj",
    # BASE_SAMPLE_DIR + "NEMA 17 Stepper Motor 23mm-NEMA 17 Stepper Motor 23mm.obj",
    # BASE_SAMPLE_DIR + "TS6-THT_H-5.0.obj",
    # BASE_SAMPLE_DIR + "TS6-THT_H-11.0.obj"
]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
 
499
 
500
+ #######################################################################################################################
501
+ ## Accumulating and Rendering 3D
502
+ #######################################################################################################################
 
 
 
 
 
 
 
 
 
 
503
 
504
 
505
  async def accumulate_and_embedding(input_files, file_list, embedding_dict):
506
+ # accumulate
507
  if not isinstance(input_files, list):
508
  input_files = [input_files]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
+ all_files = input_files
511
+ new_files = input_files[len(file_list) :]
512
 
513
+ # embedding
514
+ for file_path in new_files:
515
+ logger.info("Processing new upload file:", file_path)
516
+ obj_path = convert_to_obj(file_path)
517
+ embeddings = await embedding_3d_object(obj_path)
518
+ if obj_path not in embedding_dict:
519
+ embedding_dict[obj_path] = {}
520
+ embedding_dict[obj_path]["description"] = embeddings["description"]
521
+ embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
522
+ embedding_dict[obj_path]["text_embedding"] = embeddings["text_embedding"]
523
 
524
+ return all_files, gr.update(choices=all_files), embedding_dict
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
def select_file(filename, file_list):
    """Return a short preview of the entry in ``file_list`` whose ``.name``
    matches ``filename``, or "File not found." when nothing matches."""
    chosen = next((f for f in file_list if f.name == filename), None)
    if chosen is None:
        return "File not found."
    with open(chosen.name, "r", encoding="utf-8", errors="ignore") as handle:
        content = handle.read()
    return f"Selected: {chosen.name}\n---\n{content[:300]}..."
534
 
535
 
536
def render_3D_object(filepath) -> Tuple[str, str]:
    """Resolve a selected file to something ``gr.Model3D`` can display.

    Returns (renderable_path, original_path): formats Gradio renders
    natively pass through unchanged; CAD formats are converted to .obj;
    anything else is passed through as-is.
    """
    ext = os.path.splitext(filepath)[1].lower()
    convertible = (
        USER_REQUIRE_FORMAT + FREECAD_LOW_LEVEL_FORMAT + FREECAD_NATIVE_FORMAT
    )
    if ext in GRADIO_3D_MODEL_DEFAULT_FORMAT:
        return filepath, filepath
    if ext in convertible:
        return convert_to_obj(filepath), filepath
    return filepath, filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
 
548
+ #######################################################################################################################
549
+ ## Launching Gradio server
550
+ #######################################################################################################################
551
+ valid_file_types = list(
552
+ set(
553
+ GRADIO_3D_MODEL_DEFAULT_FORMAT
554
+ + USER_REQUIRE_FORMAT
555
+ + FREECAD_NATIVE_FORMAT
556
+ + FREECAD_LOW_LEVEL_FORMAT
557
  )
558
+ )
559
+ valid_file_types = valid_file_types + [t.upper() for t in valid_file_types]
560
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  with gr.Row():
562
+ file_state = gr.State(sample_files)
563
+ ###################################### !IMPORTANT #############################################################
564
+ embedding_store = gr.State({}) ####### !IMPORTANT. This is in memory vector database ##########################
565
+ file_input = gr.File(
566
+ file_count="multiple",
567
+ label="Upload files (You can append more)",
568
+ file_types=valid_file_types,
569
+ )
570
 
571
  with gr.Row():
572
  with gr.Column():
573
+ query_input = gr.Textbox(placeholder="Which 3D CAD contains 2 holes?")
574
+ query_button = gr.Button("Query Search")
575
+
576
  with gr.Row():
577
+ with gr.Row():
578
+ model_q_1 = gr.Model3D(
579
+ label="3D Top 1", interactive=False
580
+ ) # debugging
581
+ model_q_1_btn = gr.Button(value="3D Top 1", size="sm")
582
+ with gr.Row():
583
+ model_q_2 = gr.Model3D(label="3D Top 2", interactive=False)
584
+ model_q_2_btn = gr.Button(value="3D Top 2", size="sm")
585
+
586
  with gr.Row():
587
+ with gr.Row():
588
+ model_q_3 = gr.Model3D(label="3D Top 3", interactive=False)
589
+ model_q_3_btn = gr.Button(value="3D Top 3", size="sm")
590
+ with gr.Row():
591
+ model_q_4 = gr.Model3D(label="3D Top 4", interactive=False)
592
+ model_q_4_btn = gr.Button(value="3D Top 4", size="sm")
593
+
594
  with gr.Column():
595
+ model_render = gr.Model3D(label="3D", height=500, interactive=False)
596
+ model_hidden_filepath = gr.Textbox(visible=False)
597
+ description_render = gr.Textbox(label="Description", lines=6)
598
+ metadata_render = gr.Textbox(label="Metadata", lines=6)
599
+ sim_button = gr.Button("Similarity Search")
600
  with gr.Row():
601
+ with gr.Row():
602
+ model_s_1 = gr.Model3D(label="3D Sim 1", interactive=False)
603
+ model_s_1_btn = gr.Button(value="3D Sim 1", size="sm")
604
+ with gr.Row():
605
+ model_s_2 = gr.Model3D(label="3D Sim 2", interactive=False)
606
+ model_s_2_btn = gr.Button(value="3D Sim 2", size="sm")
607
  with gr.Row():
608
+ with gr.Row():
609
+ model_s_3 = gr.Model3D(label="3D Sim 3", interactive=False)
610
+ model_s_3_btn = gr.Button(value="3D Sim 3", size="sm")
611
+ with gr.Row():
612
+ model_s_4 = gr.Model3D(label="3D Sim 4", interactive=False)
613
+ model_s_4_btn = gr.Button(value="3D Sim 4", size="sm")
614
+ with gr.Column():
615
+ file_dropdown = gr.Dropdown(
616
+ label="Select a file to process", choices=sample_files, interactive=True
617
+ )
618
 
619
+ file_input.change(
 
620
  fn=accumulate_and_embedding,
621
  inputs=[file_input, file_state, embedding_store],
622
  outputs=[file_state, file_dropdown, embedding_store],
623
  )
624
+ # query button
625
+ query_button.click(
626
+ query_3D_object,
627
+ [query_input, embedding_store],
628
+ [
629
+ model_q_1,
630
+ model_q_2,
631
+ model_q_3,
632
+ model_q_4,
633
+ model_q_1_btn,
634
+ model_q_2_btn,
635
+ model_q_3_btn,
636
+ model_q_4_btn,
637
+ ],
638
  )
639
+ # model query
640
+ model_q_1_btn.click(
641
+ render_3D_object, model_q_1, [model_render, model_hidden_filepath]
642
+ )
643
+ model_q_2_btn.click(
644
+ render_3D_object, model_q_2, [model_render, model_hidden_filepath]
645
+ )
646
+ model_q_3_btn.click(
647
+ render_3D_object, model_q_3, [model_render, model_hidden_filepath]
648
+ )
649
+ model_q_4_btn.click(
650
+ render_3D_object, model_q_4, [model_render, model_hidden_filepath]
651
+ )
652
+ # sim button
653
  sim_button.click(
654
+ search_3D_similarity,
655
+ [model_render, embedding_store],
656
+ [
657
  model_s_1,
658
  model_s_2,
659
  model_s_3,
660
  model_s_4,
661
+ model_s_1_btn,
662
+ model_s_2_btn,
663
+ model_s_3_btn,
664
+ model_s_4_btn,
665
+ ],
666
  )
667
+ # model similarity
668
+ model_s_1_btn.click(
669
+ render_3D_object, model_s_1, [model_render, model_hidden_filepath]
670
+ )
671
+ model_s_2_btn.click(
672
+ render_3D_object, model_s_2, [model_render, model_hidden_filepath]
673
+ )
674
+ model_s_3_btn.click(
675
+ render_3D_object, model_s_3, [model_render, model_hidden_filepath]
676
+ )
677
+ model_s_4_btn.click(
678
+ render_3D_object, model_s_4, [model_render, model_hidden_filepath]
679
+ )
680
+ # drop down
681
+ file_dropdown.change(
682
+ render_3D_object, file_dropdown, [model_render, model_hidden_filepath]
683
+ )
684
+ # parse metadata
685
+ model_hidden_filepath.change(
686
+ render_3D_metadata,
687
+ [model_hidden_filepath, model_render, embedding_store],
688
+ [metadata_render, description_render],
689
  )
690
 
691
  if __name__ == "__main__":
692
+ demo.launch(share=True)
encode_image.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile encode_image.py
2
+ import base64
3
+ from typing import Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ Image.MAX_IMAGE_PIXELS = None # Removes the limit, use with caution
10
+
11
+
12
def encode_image(image: Union[str, np.ndarray]) -> str:
    """
    Encode an image as a base64 string.

    Args:
        image (Union[str, np.ndarray]): Path to an image file, or a NumPy
            array holding the image data.

    Returns:
        str: Base64-encoded image string (arrays are JPEG-compressed first).

    Raises:
        TypeError: If the input is neither a path nor a NumPy array.
    """
    if isinstance(image, np.ndarray):
        # Compress the raw array to JPEG bytes before encoding.
        _, buffer = cv2.imencode(".jpg", image)
        return base64.b64encode(buffer).decode("utf-8")
    if isinstance(image, str):
        with open(image, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    raise TypeError("Input must be a file path (str) or a NumPy array.")
llm_service.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile llm_service.py
2
+ import asyncio
3
+ import base64
4
+ import io
5
+ import os
6
+ from enum import Enum
7
+ from typing import List, Tuple, Union, cast
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from openai import AsyncOpenAI
12
+ from PIL import Image
13
+ from loguru import logger
14
+
15
+ from encode_image import encode_image
16
+ from string_utils import StringUtils
17
+
18
+ Image.MAX_IMAGE_PIXELS = None # Removes the limit, use with caution
19
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
20
+
21
+
22
+ class OpenAIService:
23
+ def __init__(self):
24
+ # self.llm_settings = getattr(settings.llm, settings.llm.name)
25
+ self.model_name = "gpt-4o" # settings.llm.openai.model
26
+ self.temperature = 0.3 # settings.llm.openai.temperature
27
+ self.client = AsyncOpenAI(api_key=OPENAI_API_KEY)
28
+ # Follow the documentation: https://platform.openai.com/docs/models
29
+ self.deprecated_temperature_models = [
30
+ "o4-mini",
31
+ "o4",
32
+ "o3-mini",
33
+ "o3",
34
+ ] # settings.llm.openai.deprecated_temperature_models
35
+
36
+ @staticmethod
37
+ def encode_image(image: Union[str, np.ndarray]) -> str:
38
+ return encode_image(image=image)
39
+
40
+ def get_temperature(self, temperature: float | None) -> dict:
41
+ return (
42
+ {
43
+ "temperature": temperature
44
+ if temperature is not None
45
+ else self.temperature
46
+ }
47
+ if self.model_name not in self.deprecated_temperature_models
48
+ else {}
49
+ )
50
+
51
+ async def chat_with_text(
52
+ self,
53
+ prompt: str,
54
+ return_as_json: bool = False,
55
+ retry_left: int = 3, # settings.llm.openai.retry_left,
56
+ temperature: float | None = None,
57
+ ) -> str:
58
+ """
59
+ Sends a text-based chat prompt to the OpenAI model.
60
+
61
+ Args:
62
+ prompt (str): User input text.
63
+ return_as_json (bool): whether to generate output as a json object
64
+ retry_left (int): number of retries left
65
+ temperature (float | None): Controls randomness in the response. Lower values make responses more focused and deterministic.
66
+
67
+ Returns:
68
+ str: Response from the model.
69
+ """
70
+
71
+ model_kwargs = {
72
+ "model": self.model_name,
73
+ "messages": [
74
+ {"role": "system", "content": "You are a helpful assistant."},
75
+ {"role": "user", "content": prompt},
76
+ ],
77
+ **self.get_temperature(temperature=temperature),
78
+ }
79
+
80
+ if return_as_json:
81
+ model_kwargs["response_format"] = {"type": "json_object"}
82
+
83
+ try:
84
+ response = await self.client.chat.completions.create(**model_kwargs)
85
+ except Exception as e:
86
+ if retry_left > 0:
87
+ logger.warning(f"OpenAI API calling failed due to {e}. Retry!")
88
+ await asyncio.sleep(1) # quota out
89
+ return await self.chat_with_text(
90
+ prompt=prompt,
91
+ return_as_json=return_as_json,
92
+ retry_left=retry_left - 1,
93
+ temperature=temperature,
94
+ )
95
+ else:
96
+ logger.error(
97
+ f"OpenAI API calling failed due to {e}. Return empty string!"
98
+ )
99
+ return ""
100
+
101
+ return response.choices[0].message.content
102
+
103
+ async def chat_with_image(
104
+ self,
105
+ prompt: str,
106
+ image: str,
107
+ return_as_json: bool = False,
108
+ retry_left: int = 3, # settings.llm.openai.retry_left,
109
+ temperature: float | None = None,
110
+ ) -> str:
111
+ """
112
+ Sends an image along with a text prompt to the OpenAI model.
113
+
114
+ Args:
115
+ prompt (str): User input text.
116
+ image_path (str): Path to the image file.
117
+ return_as_json (bool): whether to generate output as a json object
118
+ retry_left (int): number of retries left
119
+ temperature (float | None): Controls randomness in the response. Lower values make responses more focused and deterministic.
120
+
121
+ Returns:
122
+ str: Response from the model.
123
+ """
124
+ if os.path.isfile(image):
125
+ base64_image = self.encode_image(image=image)
126
+ elif StringUtils.is_base64(image):
127
+ base64_image = image
128
+ else:
129
+ raise Exception(
130
+ "ServiceAiError.UNSUPPORT_INPUT_IMAGE_TYPE.as_http_exception()"
131
+ )
132
+
133
+ model_kwargs = {
134
+ "model": self.model_name,
135
+ "messages": [
136
+ {
137
+ "role": "user",
138
+ "content": [
139
+ {"type": "text", "text": prompt},
140
+ {
141
+ "type": "image_url",
142
+ "image_url": {
143
+ "url": f"data:image/jpeg;base64,{base64_image}"
144
+ },
145
+ },
146
+ ],
147
+ }
148
+ ],
149
+ **self.get_temperature(temperature=temperature),
150
+ }
151
+
152
+ if return_as_json:
153
+ model_kwargs["response_format"] = {"type": "json_object"}
154
+
155
+ try:
156
+ response = await self.client.chat.completions.create(**model_kwargs)
157
+ except Exception as e:
158
+ if retry_left > 0:
159
+ logger.warning(f"OpenAI API calling failed due to {e}. Retry!")
160
+ await asyncio.sleep(1) # quota out
161
+ return await self.chat_with_image(
162
+ prompt=prompt,
163
+ image=image,
164
+ return_as_json=return_as_json,
165
+ retry_left=retry_left - 1,
166
+ temperature=temperature,
167
+ )
168
+ else:
169
+ logger.error(
170
+ f"OpenAI API calling failed due to {e}. Return empty string!"
171
+ )
172
+ return ""
173
+ return response.choices[0].message.content
174
+
175
+ async def chat_with_multiple_images(
176
+ self,
177
+ prompt: str,
178
+ images: list[str],
179
+ return_as_json: bool = False,
180
+ retry_left: int = 3, # settings.llm.openai.retry_left,
181
+ temperature: float | None = None,
182
+ ) -> str:
183
+ """
184
+ Sends multiple images along with a text prompt to the OpenAI model.
185
+ Args:
186
+ prompt (str): User input text.
187
+ images (list[str]): List of base64 encoded images.
188
+ return_as_json (bool): whether to generate output as a json object
189
+ retry_left (int): number of retries left
190
+ temperature (float | None): Controls randomness in the response. Lower values make responses more focused and deterministic.
191
+ Returns:
192
+ list[str]: Responses from the model for each image.
193
+ """
194
+ if len(images) == 0:
195
+ logger.warning("OpenAI chats with multiple images mode without any images")
196
+
197
+ base64_images = []
198
+ for image in images:
199
+ if os.path.isfile(image):
200
+ base64_images.append(self.encode_image(image=image))
201
+ elif StringUtils.is_base64(image):
202
+ base64_images.append(image)
203
+ else:
204
+ raise Exception(
205
+ "ServiceAiError.UNSUPPORT_INPUT_IMAGE_TYPE.as_http_exception()"
206
+ )
207
+
208
+ model_kwargs = {
209
+ "model": self.model_name,
210
+ "messages": [
211
+ {
212
+ "role": "user",
213
+ "content": [
214
+ {"type": "text", "text": prompt},
215
+ *[
216
+ {
217
+ "type": "image_url",
218
+ "image_url": {
219
+ "url": f"data:image/jpeg;base64,{base64_image}"
220
+ },
221
+ }
222
+ for base64_image in base64_images
223
+ ],
224
+ ],
225
+ }
226
+ ],
227
+ **self.get_temperature(temperature=temperature),
228
+ }
229
+
230
+ if return_as_json:
231
+ model_kwargs["response_format"] = {"type": "json_object"}
232
+
233
+ try:
234
+ response = await self.client.chat.completions.create(**model_kwargs)
235
+ except Exception as e:
236
+ if retry_left > 0:
237
+ logger.warning(f"OpenAI API calling failed due to {e}. Retry!")
238
+ await asyncio.sleep(1) # quota out
239
+ return await self.chat_with_multiple_images(
240
+ prompt=prompt,
241
+ images=images,
242
+ return_as_json=return_as_json,
243
+ retry_left=retry_left - 1,
244
+ temperature=temperature,
245
+ )
246
+ else:
247
+ logger.error(
248
+ f"OpenAI API calling failed due to {e}. Return empty list!"
249
+ )
250
+ return ""
251
+
252
+ return response.choices[0].message.content
253
+
254
+
255
class LLMService:
    """Factory hiding the concrete LLM partner behind one entry point."""

    @classmethod
    def from_partner(cls):
        """Instantiate the currently configured partner service (OpenAI)."""
        return OpenAIService()
mv_utils_zs.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile mv_utils_zs.py
2
+ """
3
+ Author: yangyangyang127
4
+ Github: https://github.com/yangyangyang127
5
+ Repo: https://github.com/yangyangyang127/PointCLIP_V2
6
+ Path: https://github.com/yangyangyang127/PointCLIP_V2/blob/main/zeroshot_cls/trainers/mv_utils_zs.py#L135
7
+ """
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch_scatter import scatter
13
+
14
# Camera translation along Z applied to every view (see Realistic_Projection).
TRANS = -1.5

# realistic projection parameters
params = {
    "maxpoolz": 1,  # max-pool kernel size along depth (Z)
    "maxpoolxy": 7,  # max-pool kernel size in the image plane (XY)
    "maxpoolpadz": 0,  # max-pool padding along Z
    "maxpoolpadxy": 2,  # max-pool padding in XY
    "convz": 1,  # Gaussian-conv kernel size along Z
    "convxy": 3,  # Gaussian-conv kernel size in XY
    "convsigmaxy": 3,  # Gaussian sigma in XY
    "convsigmaz": 1,  # Gaussian sigma along Z
    "convpadz": 0,  # conv padding along Z
    "convpadxy": 1,  # conv padding in XY
    "imgbias": 0.0,  # not referenced elsewhere in this file
    "depth_bias": 0.2,  # offset pushing normalized depth away from the near plane
    "obj_ratio": 0.8,  # fraction of the grid the object occupies in XY
    "bg_clr": 0.0,  # background value for empty grid cells
    "resolution": 122,  # XY resolution of the 3-D grid
    "depth": 8,  # number of depth (Z) bins; default = 8
    "grid_height": 64,  # default height for points_to_2d_grid
    "grid_width": 64,  # default width for points_to_2d_grid
}
37
+
38
+
39
class Grid2Image(nn.Module):
    """Turn a 3-D occupancy grid into a 2-D, 3-channel image.

    Max-pooling densifies the sparse grid, a fixed 3-D Gaussian convolution
    smooths it, and a max over the depth axis squeezes it to a 2-D map.
    """

    def __init__(self):
        super().__init__()
        torch.backends.cudnn.benchmark = False

        pool_kernel = (params["maxpoolz"], params["maxpoolxy"], params["maxpoolxy"])
        pool_pad = (
            params["maxpoolpadz"],
            params["maxpoolpadxy"],
            params["maxpoolpadxy"],
        )
        self.maxpool = nn.MaxPool3d(pool_kernel, stride=1, padding=pool_pad)

        conv_kernel = (params["convz"], params["convxy"], params["convxy"])
        conv_pad = (params["convpadz"], params["convpadxy"], params["convpadxy"])
        self.conv = torch.nn.Conv3d(
            1,
            1,
            kernel_size=conv_kernel,
            stride=1,
            padding=conv_pad,
            bias=True,
        )

        # Freeze the convolution to a fixed 3-D Gaussian smoothing kernel.
        gauss = get3DGaussianKernel(
            params["convxy"],
            params["convz"],
            sigma=params["convsigmaxy"],
            zsigma=params["convsigmaz"],
        )
        self.conv.weight.data = torch.Tensor(gauss).repeat(1, 1, 1, 1, 1)
        self.conv.bias.data.fill_(0)

    def forward(self, x):
        # x: [B, depth, H, W] grid -> add a channel dim for the 3-D ops.
        smoothed = self.conv(self.maxpool(x.unsqueeze(1)))
        # Squeeze the depth axis with a max, then normalize per image.
        depth_map = torch.max(smoothed, dim=2)[0]
        per_image_max = torch.max(torch.max(depth_map, dim=-1)[0], dim=-1)[0]
        depth_map = depth_map / per_image_max[:, :, None, None]
        # Invert (background becomes bright) and expand to 3 identical channels.
        return (1 - depth_map).repeat(1, 3, 1, 1)
84
+
85
+
86
def euler2mat(angle):
    """Convert Euler angles (x, y, z) to a rotation matrix.

    Args:
        angle (torch.Tensor): angles in radians, shape [3] or [B, 3].

    Returns:
        torch.Tensor: rotation matrix of shape [3, 3] or [B, 3, 3],
        composed as Rx @ Ry @ Rz.

    Raises:
        ValueError: if ``angle`` is not 1-D or 2-D.

    Source:
        https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    """
    if angle.dim() == 1:
        x, y, z = angle[0], angle[1], angle[2]
        _dim = 0
        _view = [3, 3]
    elif angle.dim() == 2:
        b = angle.size(0)
        x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
        _dim = 1
        _view = [b, 3, 3]
    else:
        # Was `assert False`; a real exception also survives `python -O`.
        raise ValueError(f"angle must be 1-D or 2-D, got {angle.dim()}-D")

    # Zero/one tensors that broadcast with (possibly batched) angles and
    # stay on the same device/dtype as the input.
    zero = z.detach() * 0
    one = zero.detach() + 1

    cosz, sinz = torch.cos(z), torch.sin(z)
    zmat = torch.stack(
        [cosz, -sinz, zero, sinz, cosz, zero, zero, zero, one], dim=_dim
    ).reshape(_view)

    cosy, siny = torch.cos(y), torch.sin(y)
    ymat = torch.stack(
        [cosy, zero, siny, zero, one, zero, -siny, zero, cosy], dim=_dim
    ).reshape(_view)

    cosx, sinx = torch.cos(x), torch.sin(x)
    xmat = torch.stack(
        [one, zero, zero, zero, cosx, -sinx, zero, sinx, cosx], dim=_dim
    ).reshape(_view)

    return xmat @ ymat @ zmat
135
+
136
+
137
def points_to_2d_grid(
    points, grid_h=params["grid_height"], grid_w=params["grid_width"]
):
    """Convert a point cloud into a 2-D occupancy grid using X, Y only.

    Points are projected onto a plane and quantized into grid cells.

    Args:
        points (torch.Tensor): point clouds of shape [B, P, 3]
            (B: batch size, P: number of points, 3: x, y, z coordinates).
        grid_h (int): height of the output 2-D grid.
        grid_w (int): width of the output 2-D grid.

    Returns:
        torch.Tensor: occupancy grid of shape [B, grid_h, grid_w].
            Cell (y, x) holds 1.0 if at least one point falls into it,
            otherwise the background value params["bg_clr"].
    """
    batch, pnum, _ = points.shape
    device = points.device

    # --- Step 1: normalize point coordinates ---
    # Per-cloud min/max over X, Y only (better for a 2-D normalization).
    pmax_xy = points[:, :, :2].max(dim=1)[0]
    pmin_xy = points[:, :, :2].min(dim=1)[0]

    # Center of the X, Y bounding box; add a point dim to broadcast [B, 1, 2].
    pcent_xy = (pmax_xy + pmin_xy) / 2
    pcent_xy = pcent_xy[:, None, :]

    # Use the larger of the X/Y extents so the aspect ratio is preserved.
    prange_xy = (pmax_xy - pmin_xy).max(dim=-1)[0][:, None, None]  # [B, 1, 1]

    # Small epsilon avoids division by zero when all points coincide.
    epsilon = 1e-8
    # Normalize X, Y into [-1, 1] based on the X/Y extent:
    # (points[:, :, :2] - pcent_xy) -> [B, P, 2]; prange_xy -> [B, 1, 1].
    points_normalized_xy = (points[:, :, :2] - pcent_xy) / (prange_xy + epsilon) * 2.0

    # Shrink by obj_ratio so the object does not touch the grid border.
    points_normalized_xy = points_normalized_xy * params["obj_ratio"]

    # --- Step 2: map normalized coordinates to grid indices ---
    # X: [-obj_ratio, obj_ratio] -> [0, grid_w]; Y likewise -> [0, grid_h].
    # General form: (normalized_coord + scale) / (2 * scale) * grid_dim.
    _x = (
        (points_normalized_xy[:, :, 0] + params["obj_ratio"])
        / (2 * params["obj_ratio"])
        * grid_w
    )
    _y = (
        (points_normalized_xy[:, :, 1] + params["obj_ratio"])
        / (2 * params["obj_ratio"])
        * grid_h
    )

    # Floor to get integer cell indices.
    _x = torch.floor(_x).long()
    _y = torch.floor(_y).long()

    # --- Step 3: clamp indices to the valid grid range ---
    # _x into [0, grid_w - 1]; _y into [0, grid_h - 1].
    _x = torch.clip(_x, 0, grid_w - 1)
    _y = torch.clip(_y, 0, grid_h - 1)

    # --- Step 4: build the grid and mark occupied cells ---
    # Initialize the 2-D grid with the background value.
    grid = torch.full(
        (batch, grid_h, grid_w), params["bg_clr"], dtype=torch.float32, device=device
    )

    # Batch index for every point.
    batch_indices = torch.arange(batch, device=device).view(-1, 1).repeat(1, pnum)

    # Flatten the index tensors for advanced indexing.
    batch_idx_flat = batch_indices.view(-1)
    y_idx_flat = _y.view(-1)
    x_idx_flat = _x.view(-1)

    # Write 1.0 into every cell (y, x) that receives a point; multiple points
    # in the same cell still yield a single 1.0.
    grid[batch_idx_flat, y_idx_flat, x_idx_flat] = 1.0

    return grid
225
+
226
+
227
def points2grid(points, resolution=params["resolution"], depth=params["depth"]):
    """Quantize each point cloud into a 3-D depth grid.

    Args:
        points (torch.Tensor): point clouds of shape [B, P, 3].
        resolution (int): side length of the square XY grid.
        depth (int): number of discrete depth (Z) bins.

    Returns:
        torch.Tensor: grid of shape [B, depth, resolution, resolution];
        occupied cells hold the (clipped) normalized depth of the deepest
        point, empty cells hold params["bg_clr"].
    """

    batch, pnum, _ = points.shape

    # Normalize each cloud to [-1, 1] around its center, keeping aspect ratio.
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    # Epsilon guards against a zero range when all points coincide
    # (consistent with points_to_occupancy_grid in this file).
    prange = (pmax - pmin).max(dim=-1)[0][:, None, None] + 1e-8
    points = (points - pcent) / prange * 2.0
    points[:, :, :2] = points[:, :, :2] * params["obj_ratio"]

    # Map normalized coordinates to (fractional) grid indices.
    depth_bias = params["depth_bias"]
    _x = (points[:, :, 0] + 1) / 2 * resolution
    _y = (points[:, :, 1] + 1) / 2 * resolution
    _z = ((points[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    _x.ceil_()
    _y.ceil_()
    z_int = _z.ceil()

    _x = torch.clip(_x, 1, resolution - 2)
    _y = torch.clip(_y, 1, resolution - 2)
    _z = torch.clip(_z, 1, depth - 2)

    # Flattened (z, y, x) index per point; scatter keeps the max depth per cell.
    # NOTE(review): z_int is not clipped before indexing; with the normalization
    # above it stays within range, but verify if the input contract changes.
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    grid = (
        torch.ones([batch, depth, resolution, resolution], device=points.device).view(
            batch, -1
        )
        * params["bg_clr"]
    )

    grid = scatter(_z, coordinates.long(), dim=1, out=grid, reduce="max")
    grid = grid.reshape((batch, depth, resolution, resolution)).permute((0, 1, 3, 2))

    return grid
291
+
292
+
293
+ # Giả sử bạn có thư viện scatter, ví dụ: from torch_scatter import scatter
294
+ # Hoặc hàm scatter tương đương
295
+ # import torch # Đảm bảo đã import torch
296
+ # from torch_scatter import scatter # Ví dụ
297
+
298
+
299
def points_to_occupancy_grid(
    points, resolution=params["resolution"], depth=params["depth"]
):
    """Quantize each point cloud into a binary 3-D occupancy grid.

    Args:
        points (torch.Tensor): point clouds of shape [B, P, 3].
        resolution (int): side length of the square XY grid.
        depth (int): number of discrete depth (Z) bins.

    Returns:
        torch.Tensor: grid of shape [B, depth, resolution, resolution];
        cells containing at least one point hold 1.0, others hold
        params["bg_clr"] (assumed 0.0 for a strictly binary grid).
    """

    batch, pnum, _ = points.shape
    device = points.device  # build new tensors on the input's device

    # --- Normalize to [-1, 1]; epsilon avoids division by zero ---
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][:, None, None] + 1e-8
    points_norm = (points - pcent) / prange * 2.0
    points_norm[:, :, :2] = points_norm[:, :, :2] * params["obj_ratio"]

    depth_bias = params["depth_bias"]
    _x = (points_norm[:, :, 0] + 1) / 2 * resolution
    _y = (points_norm[:, :, 1] + 1) / 2 * resolution
    _z = ((points_norm[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    _x.ceil_()
    _y.ceil_()
    z_int = _z.ceil()

    _x = torch.clip(_x, 1, resolution - 2)
    _y = torch.clip(_y, 1, resolution - 2)
    # z_int is used as an index below, so clip it as well.
    z_int = torch.clip(z_int, 1, depth - 2)

    # Per-batch flattened (z, y, x) index of every point: [B, P].
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    coordinates = coordinates.long()

    # Background value (0.0 keeps the grid strictly binary with reduce="max").
    bg_clr_value = params.get("bg_clr", 0.0)
    if bg_clr_value != 0.0:
        print(
            "Warning: bg_clr is not 0.0, occupancy grid might not be strictly binary 0/1 with reduce='max'. Consider initializing grid with 0."
        )

    grid = torch.full(
        (batch, depth * resolution * resolution),
        bg_clr_value,
        dtype=torch.float32,
        device=device,
    )

    # One 1.0 per point, shaped [B, P] to match `coordinates`.
    values_to_scatter = torch.ones(batch, pnum, dtype=torch.float32, device=device)

    # BUG FIX: scatter along dim=1 so each batch writes into its own row.
    # The previous flat dim=0 scatter used per-batch coordinates without a
    # batch offset, so every cloud was written into batch 0's grid.
    grid = scatter(
        values_to_scatter,
        coordinates,
        dim=1,
        out=grid,
        reduce="max",
    )

    # Restore the 3-D layout and match points2grid's axis order.
    grid = grid.view(batch, depth, resolution, resolution)
    grid = grid.permute((0, 1, 3, 2))

    return grid
378
+
379
+
380
class Realistic_Projection:
    """Creates multi-view images from a point cloud.

    Precomputes 10 camera views (8 around the object plus top and bottom),
    applies the view transforms, voxelizes with points2grid, and renders
    each grid to an image with Grid2Image.
    """

    def __init__(self):
        # Per-view [euler_angles, translation]; 8 side views + top + bottom.
        _views = np.asarray([
            [[1 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[1 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[2 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, -np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
        ])

        # adding some bias to the view angle to reveal more surface
        _views_bias = np.asarray([
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
        ])

        self.num_views = _views.shape[0]

        # Main view rotations and their bias rotations, transposed so that
        # point_transform can right-multiply row-vector points.
        angle = torch.tensor(_views[:, 0, :]).float()  # .cuda()
        self.rot_mat = euler2mat(angle).transpose(1, 2)
        angle2 = torch.tensor(_views_bias[:, 0, :]).float()  # .cuda()
        self.rot_mat2 = euler2mat(angle2).transpose(1, 2)

        # Per-view translation, shaped [V, 1, 3] to broadcast over points.
        self.translation = torch.tensor(_views[:, 1, :]).float()  # .cuda()
        self.translation = self.translation.unsqueeze(1)

        self.grid2image = Grid2Image()  # .cuda()

    def get_img(self, points):
        """Render ``points`` [B, P, 3] into images, one per (cloud, view) pair.

        Returns the Grid2Image output for the B * num_views projected grids.
        """
        b, _, _ = points.shape
        v = self.translation.shape[0]

        # Pairing scheme: repeat_interleave repeats each cloud v times in a
        # row, while .repeat tiles the v view matrices b times, so row i
        # pairs cloud i // v with view i % v.
        _points = self.point_transform(
            points=torch.repeat_interleave(points, v, dim=0),
            rot_mat=self.rot_mat.repeat(b, 1, 1),
            rot_mat2=self.rot_mat2.repeat(b, 1, 1),
            translation=self.translation.repeat(b, 1, 1),
        )

        grid = points2grid(
            points=_points, resolution=params["resolution"], depth=params["depth"]
        ).squeeze()
        img = self.grid2image(grid)
        return img

    @staticmethod
    def point_transform(points, rot_mat, rot_mat2, translation):
        """Apply the per-view rigid transform to every point.

        :param points: [batch, num_points, 3]
        :param rot_mat: [batch, 3, 3] main view rotation (pre-transposed)
        :param rot_mat2: [batch, 3, 3] bias rotation (pre-transposed)
        :param translation: [batch, 1, 3]
        :return: transformed points, [batch, num_points, 3]
        """
        rot_mat = rot_mat.to(points.device)
        rot_mat2 = rot_mat2.to(points.device)
        translation = translation.to(points.device)
        points = torch.matmul(points, rot_mat)
        points = torch.matmul(points, rot_mat2)
        points = points - translation
        return points
456
+
457
+
458
def get2DGaussianKernel(ksize, sigma=0):
    """Build a normalized 2-D Gaussian kernel of shape (ksize, ksize).

    Note: ``sigma`` must be positive; sigma == 0 divides by zero.
    Returns a torch.Tensor that sums to 1.
    """
    offsets = np.arange(ksize, dtype=np.float32) - ksize // 2
    profile = np.exp(-(offsets**2) / (2 * sigma**2))
    kernel2d = torch.from_numpy(np.outer(profile, profile))
    return kernel2d / kernel2d.sum()
466
+
467
+
468
+ # Without numpy
469
+ # def get2DGaussianKernel(ksize, sigma):
470
+ # xs = torch.linspace(-(ksize // 2), ksize // 2, steps=ksize)
471
+ # kernel1d = torch.exp(-(xs ** 2) / (2 * sigma ** 2))
472
+ # kernel2d = torch.outer(kernel1d, kernel1d)
473
+ # kernel2d /= kernel2d.sum()
474
+ # return kernel2d
475
+
476
+
477
+ def get3DGaussianKernel(ksize, depth, sigma=2, zsigma=2):
478
+ kernel2d = get2DGaussianKernel(ksize, sigma)
479
+ zs = np.arange(depth, dtype=np.float32) - depth // 2
480
+ zkernel = np.exp(-(zs**2) / (2 * zsigma**2))
481
+ kernel3d = np.repeat(kernel2d[None, :, :], depth, axis=0) * zkernel[:, None, None]
482
+ kernel3d = kernel3d / torch.sum(kernel3d)
483
+ return kernel3d
string_utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile string_utils.py
2
+ import base64
3
+ import random
4
+ import re
5
+ import string
6
+ from urllib.parse import urlparse
7
+
8
+
9
class StringUtils:
    """Small collection of stateless string helpers."""

    @staticmethod
    def generate_random_string(length: int = 32) -> str:
        """Return a random alphanumeric string of the given length.

        Uses ``random`` (not ``secrets``); do not use the result for
        security-sensitive tokens.
        """
        characters = string.ascii_letters + string.digits
        random_string = "".join(random.choice(characters) for _ in range(length))
        return random_string

    @staticmethod
    def clean_string(input_string: str) -> str:
        """Normalize a string: ASCII-only, tidy punctuation spacing, lowercase."""
        # Remove non-ASCII characters
        cleaned_string = re.sub(r"[^\x00-\x7F]+", " ", input_string)

        # Consolidate spaces and ensure correct spacing around punctuation
        cleaned_string = re.sub(r"\s*([.,;!?%:])\s*", r"\1 ", cleaned_string)

        # Adjust spacing for the dollar sign
        cleaned_string = re.sub(r"\$\s+", "$", cleaned_string)

        # Ensure correct spacing inside parentheses around numbers
        cleaned_string = re.sub(r"\(\s*(\d+)\s*\)", r"( \1 )", cleaned_string)

        # Remove extra spaces around punctuation (this might be redundant but
        # ensures no trailing space before punctuation)
        cleaned_string = re.sub(r"\s+([.,;!?%:])", r"\1", cleaned_string)

        # Remove leading and trailing whitespace, reduce multiple spaces to a
        # single space, and convert to lower case
        cleaned_string = re.sub(r"\s+", " ", cleaned_string).strip().lower()

        return cleaned_string

    @staticmethod
    def get_file_name_without_extension(file_name: str) -> str:
        """Strip the last extension from ``file_name``.

        BUG FIX: a name with no dot previously returned "" (the name was
        lost); it is now returned unchanged. Dotfiles like ".bashrc" still
        return "" as before.
        """
        base, sep, _ext = file_name.rpartition(".")
        return base if sep else file_name

    @staticmethod
    def is_valid_url(url: str):
        """Return True when ``url`` parses with both a scheme and a netloc."""
        try:
            result = urlparse(url)
            return all([result.scheme, result.netloc])
        except ValueError:
            return False

    @staticmethod
    def is_base64(string: str) -> bool:
        """
        Validates if the input string is a Base64-encoded string.

        Args:
            string (str): The string to validate.

        Returns:
            bool: True if the string is Base64, False otherwise.
        """
        try:
            # Check if the string can be decoded
            base64_bytes = base64.b64decode(string, validate=True)
            # Check if decoded bytes can be re-encoded to the original string
            return base64.b64encode(base64_bytes).decode("utf-8") == string
        except Exception:
            return False