KublaiKhan1 committed on
Commit
acd627a
·
verified ·
1 Parent(s): b4d7285

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -0
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from torchvision.transforms import v2 as transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ import cv2
8
+ from torchvision.transforms.v2 import functional
9
+
10
+
11
# Preprocessing constants: the side length used for Resize/CenterCrop and the
# per-channel mean/std handed to Normalize.
RESIZE_DIM = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Class index -> display name for the BreakHis head
# (classes: ["TA", "MC", "F", "DC"]).
BREAKHIS_LABELS = dict(enumerate([
    "Tubular Adenoma (TA) - Benign",
    "Mucinous Carcinoma (MC) - Malignant",
    "Fibroadenoma (F) - Benign",
    "Ductal Carcinoma (DC) - Malignant",
]))

# Class index -> display name for the Gleason head.
GLEASON_LABELS = dict(enumerate([
    "Benign",
    "Gleason 3",
    "Gleason 4",
    "Gleason 5",
]))

# NOTE: unlike the two tables above, the BACH and CRC tables are keyed the
# other way round: class name -> class index.
BACH_LABELS = {
    name: index
    for index, name in enumerate(["Benign", "InSitu", "Invasive", "Normal"])
}
CRC_LABELS = {
    name: index
    for index, name in enumerate(
        ["ADI", "BACK", "DEB", "LYM", "MUC", "MUS", "NORM", "STR", "TUM"]
    )
}
42
+
43
# Build the DinoV2 backbone from torch.hub, then overwrite its weights with the
# custom pathology checkpoint.
print("Loading DinoV2 base model...")
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

print("Loading custom pathology checkpoint...")
# map_location="cpu" lets the checkpoint deserialize on hosts without CUDA
# even if its tensors were saved from a GPU; the model is moved to the
# selected device later in the script.
ours = torch.load("/data/linears/teacher_checkpoint_load.pt", map_location="cpu")
# The file already holds a state dict whose keys match the hub model. It was
# written by a one-off torch.save(dinov2.state_dict(), ...) after an earlier
# version of this script dropped the "dino"/"ibot" entries from the raw
# teacher checkpoint and re-keyed the rest onto dinov2.state_dict().
checkpoint = ours

# Swap in the checkpoint's positional embedding before loading, so the
# parameter shape on the model matches the stored tensor
# (presumably the checkpoint was trained at a different resolution than the
# hub default -- TODO confirm).
dinov2.pos_embed = torch.nn.parameter.Parameter(checkpoint["pos_embed"])
dinov2.load_state_dict(checkpoint)
dinov2.eval()
72
+
73
def setup_linear(path):
    """Load a trained linear classification head from a checkpoint file.

    Args:
        path: Location of a checkpoint whose ``"state_dict"`` entry contains
            ``"head.weight"`` and ``"head.bias"`` tensors.

    Returns:
        torch.nn.Linear: The head, in eval mode, holding the stored weight and
        bias. The tensors are loaded onto the CPU; the caller moves the layer
        to its target device.
    """
    print(f"Loading {path} linear classifier...")
    # map_location="cpu" keeps loading working on hosts without CUDA even when
    # the checkpoint was written from a GPU run.
    linear_checkpoint = torch.load(path, map_location="cpu")
    state_dict = linear_checkpoint["state_dict"]
    linear_weights = state_dict["head.weight"]
    linear_bias = state_dict["head.bias"]

    # Size the layer from the stored weight (rows = classes, columns =
    # embedding width) so in_features/out_features always describe the tensor
    # that is actually used.
    out_features, in_features = linear_weights.shape
    linear = torch.nn.Linear(in_features, out_features)
    linear.weight.data = linear_weights
    linear.bias.data = linear_bias
    linear.eval()
    return linear
86
+
87
def setup_linear_crc(path):
    """Load the trained CRC linear classification head from a checkpoint file.

    Args:
        path: Location of a checkpoint whose ``"state_dict"`` entry contains
            ``"head.weight"`` and ``"head.bias"`` tensors.

    Returns:
        torch.nn.Linear: The head, in eval mode, holding the stored weight and
        bias. The tensors are loaded onto the CPU; the caller moves the layer
        to its target device.
    """
    print(f"Loading {path} linear classifier...")
    # map_location="cpu" keeps loading working on hosts without CUDA even when
    # the checkpoint was written from a GPU run.
    linear_checkpoint = torch.load(path, map_location="cpu")
    state_dict = linear_checkpoint["state_dict"]
    linear_weights = state_dict["head.weight"]
    linear_bias = state_dict["head.bias"]

    # Size the layer from the stored weight (rows = classes, columns =
    # embedding width) instead of hard-coding the class count.
    out_features, in_features = linear_weights.shape
    linear = torch.nn.Linear(in_features, out_features)
    linear.weight.data = linear_weights
    linear.bias.data = linear_bias
    linear.eval()
    return linear
100
+
101
# Pick the compute device (CUDA when present, CPU otherwise) and place the
# backbone on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dinov2 = dinov2.to(device)

# Checkpoint locations of the per-dataset linear heads.
breakhis_path = "/data/linears/logs/dino_vits16/offline/breakhis/20251030-190409559498_5b47c293/run_0/checkpoints/best.ckpt"
gleason_path = "/data/linears/logs/dino_vits16/offline/gleason_arvaniti/20251110-164046988851_35daf081/run_4/checkpoints/best.ckpt"
bach_path = "/data/linears/logs/dino_vits16/offline/bach/20251110-164046453320_0b82d41d/run_4/checkpoints/best.ckpt"
crc_path = "/data/linears/logs/dino_vits16/offline/crc/20251110-164127567401_f6ae5d68/run_4/checkpoints/best.ckpt"

# Load every head and move it next to the backbone. The CRC head goes through
# its own loader; the other three share setup_linear.
breakhis_linear = setup_linear(breakhis_path).to(device)
gleason_linear = setup_linear(gleason_path).to(device)
bach_linear = setup_linear(bach_path).to(device)
crc_linear = setup_linear_crc(crc_path).to(device)

print(f"Models loaded on {device}")
119
+
120
+
121
# Preprocessing pipeline applied to each image before it reaches the backbone:
# resize, center-crop to RESIZE_DIM, convert to scaled float32, then normalize
# with NORMALIZE_MEAN / NORMALIZE_STD.
_preprocessing_steps = [
    transforms.Resize(RESIZE_DIM),
    transforms.CenterCrop(RESIZE_DIM),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
model_transforms = transforms.Compose(_preprocessing_steps)
127
+
128
+
129
def cv_path(path):
    """Read an image file with OpenCV and return it as a torchvision image.

    The file is decoded with ``cv2.IMREAD_COLOR``, converted from OpenCV's BGR
    channel order to RGB, cast to ``uint8`` and wrapped by
    ``functional.to_image``.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The value produced by ``functional.to_image`` for the RGB pixel array.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(path, flags=cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising, so
    # report the unreadable path explicitly here.
    if image is None:
        raise ValueError(f"Could not read image file: {path!r}")

    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif image.ndim == 2:
        # Defensive: give a single-channel result an explicit channel axis.
        image = image[:, :, np.newaxis]

    image = np.asarray(image, dtype=np.uint8)
    return functional.to_image(image)
140
+
141
def predict_breakhis(image):
    """Run ``predict_class`` on ``image`` with the BreakHis linear head.

    Args:
        image: Forwarded unchanged to ``predict_class``.

    Returns:
        Whatever ``predict_class`` returns for the "breakhis" dataset.
    """
    return predict_class(image, linear=breakhis_linear, dataset="breakhis")
144
+
145
def predict_gleason(image):
    """Run ``predict_class`` on ``image`` with the Gleason linear head.

    Args:
        image: Forwarded unchanged to ``predict_class``.

    Returns:
        Whatever ``predict_class`` returns for the "gleason" dataset.
    """
    return predict_class(image, linear=gleason_linear, dataset="gleason")
148
+
149
def predict_bach(image):
    """Run ``predict_class`` on ``image`` with the BACH linear head.

    Args:
        image: Forwarded unchanged to ``predict_class``.

    Returns:
        Whatever ``predict_class`` returns for the "bach" dataset.
    """
    return predict_class(image, linear=bach_linear, dataset="bach")
152
+
153
def predict_crc(image):
    """Run ``predict_class`` on ``image`` with the CRC linear head.

    Args:
        image: Forwarded unchanged to ``predict_class``.

    Returns:
        Whatever ``predict_class`` returns for the "crc" dataset.
    """
    return predict_class(image, linear=crc_linear, dataset="crc")
156
+
157
+
158
def _labels_by_index(labels):
    """Return ``labels`` as an index -> name dict, whichever way it is keyed."""
    if all(isinstance(key, int) for key in labels):
        return dict(labels)
    return {index: name for name, index in labels.items()}


def predict_class(image, linear, dataset):
    """Return class probabilities for one histopathology image.

    The image is read with ``cv_path``, preprocessed with ``model_transforms``,
    embedded by the ``dinov2`` backbone and scored by ``linear``; the logits
    are turned into probabilities with a softmax.

    Args:
        image: Path of the image file to classify (it is handed to
            ``cv_path``).
        linear: Linear classification head applied to the backbone embedding.
        dataset: Which label table names the outputs: ``"breakhis"``,
            ``"gleason"``, ``"bach"`` or ``"crc"``.

    Returns:
        dict: Class name -> probability (Python ``float``), one entry per
        output of ``linear``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    label_tables = {
        "breakhis": BREAKHIS_LABELS,
        "gleason": GLEASON_LABELS,
        "bach": BACH_LABELS,
        "crc": CRC_LABELS,
    }
    # Fail before running the model rather than returning an empty result.
    if dataset not in label_tables:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(label_tables)}"
        )
    # The BACH and CRC tables are stored name -> index while the others are
    # index -> name; normalize so every table can be looked up by class index.
    index_to_label = _labels_by_index(label_tables[dataset])

    image = cv_path(image)

    # Preprocess and add the batch dimension.
    image_tensor = model_transforms(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Backbone embedding -> logits -> probabilities.
        embedding = dinov2(image_tensor)
        logits = linear(embedding)
        print(logits)
        probs = torch.nn.functional.softmax(logits, dim=1)
        print(probs)

    return {
        index_to_label[idx]: float(prob)
        for idx, prob in enumerate(probs[0].cpu().tolist())
    }
198
+
199
+ # Create Gradio interface
200
# BreakHis tab: description text, sample images and the Gradio interface that
# routes an uploaded file path to predict_breakhis.
_BREAKHIS_DESCRIPTION = """
Upload a breast histopathology image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 700x460 resolution. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for BreakHis tumor classification.

**Tumor Types:**
- **Benign tumors:** Tubular Adenoma (TA), Fibroadenoma (F)
- **Malignant tumors:** Mucinous Carcinoma (MC), Ductal Carcinoma (DC)

These 4 classes were selected from the full BreakHis dataset as they have sufficient patient counts (≥7 patients) for robust evaluation.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""

_BREAKHIS_EXAMPLES = [
    "./data/breakhis/BreaKHis_v1/histology_slides/breast/benign/SOB/tubular_adenoma/SOB_B_TA_14-13200/40X/SOB_B_TA-14-13200-40-001.png",
    "./data/breakhis/BreaKHis_v1/histology_slides/breast/malignant/SOB/mucinous_carcinoma/SOB_M_MC_14-10147/40X/SOB_M_MC-14-10147-40-001.png",
    "./data/breakhis/BreaKHis_v1/histology_slides/breast/benign/SOB/fibroadenoma/SOB_B_F_14-14134/40X/SOB_B_F-14-14134-40-001.png",
]

breakhis = gr.Interface(
    fn=predict_breakhis,
    inputs=gr.Image(type="filepath", label="Upload Breast Histopathology Image"),
    outputs=gr.Label(num_top_classes=4, label="Tumor Type Prediction"),
    title="BreakHis Breast Tumor Classification",
    description=_BREAKHIS_DESCRIPTION,
    examples=_BREAKHIS_EXAMPLES,
    theme=gr.themes.Soft(),
)
224
+
225
# Gleason tab: description text, sample patches and the Gradio interface that
# routes an uploaded file path to predict_gleason.
_GLEASON_DESCRIPTION = """
Upload a prostate cancer image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 750x750 resolution. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for gleason tumor classification.

Images are classified as benign, Gleason pattern 3, 4 or 5.

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""

_GLEASON_EXAMPLES = [
    "./data/arvaniti_gleason_patches/train_validation_patches_750/ZT111_4_A_1_12/ZT111_4_A_1_12_patch_13_class_2.jpg",
    "./data/arvaniti_gleason_patches/train_validation_patches_750/ZT204_6_A_1_10/ZT204_6_A_1_10_patch_10_class_3.jpg",
]

gleason = gr.Interface(
    fn=predict_gleason,
    inputs=gr.Image(type="filepath", label="Upload Prostate Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Gleason Tumor Type Prediction"),
    title="Gleason Prostate Tumor Classification",
    description=_GLEASON_DESCRIPTION,
    examples=_GLEASON_EXAMPLES,
    theme=gr.themes.Soft(),
)
246
+
247
# CRC tab: description text, sample tiles and the Gradio interface that routes
# an uploaded file path to predict_crc.
_CRC_DESCRIPTION = """
Upload a colorectal cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally at 224x224. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for colorectal tumor classification.

The tissue classes are: Adipose (ADI), background (BACK), debris (DEB), lymphocytes (LYM), mucus (MUC), smooth muscle (MUS), normal colon mucosa (NORM), cancer-associated stroma (STR) and colorectal adenocarcinoma epithelium (TUM)

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""

_CRC_EXAMPLES = [
    "./data/crc/CRC-VAL-HE-7K/ADI/ADI-TCGA-AAICEQFN.tif",
    "./data/crc/CRC-VAL-HE-7K/BACK/BACK-TCGA-AARRNSTS.tif",
    "./data/crc/CRC-VAL-HE-7K/DEB/DEB-TCGA-AANNAWLE.tif",
]

crc = gr.Interface(
    fn=predict_crc,
    inputs=gr.Image(type="filepath", label="Upload Colorectal Cancer Image"),
    outputs=gr.Label(num_top_classes=9, label="CRC Tumor Type Prediction"),
    title="Colorectal Tumor Classification",
    description=_CRC_DESCRIPTION,
    examples=_CRC_EXAMPLES,
    theme=gr.themes.Soft(),
)
268
+
269
# BACH tab: Gradio interface that routes an uploaded file path to
# predict_bach.
bach = gr.Interface(
    fn=predict_bach,
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Bach Tumor Type Prediction"),
    title="Tumor Classification",
    # The text previously said "prostate" (carried over from the Gleason tab);
    # BACH is a breast histology dataset.
    description="""
Upload a breast cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally between 224x224 and 1536x2048 resolution. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for tumor classification.

Images are classified as benign, normal, invasive, inSitu

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    # One sample per class. Each file lives in the folder named after its
    # class (b -> Benign, n -> Normal, is -> InSitu, iv -> Invasive); the
    # InSitu and Invasive samples previously pointed at the Benign folder.
    examples=[
        "./data/bach/ICIAR2018_BACH_Challenge/Photos/Benign/b001.tif",
        "./data/bach/ICIAR2018_BACH_Challenge/Photos/Normal/n001.tif",
        "./data/bach/ICIAR2018_BACH_Challenge/Photos/InSitu/is001.tif",
        "./data/bach/ICIAR2018_BACH_Challenge/Photos/Invasive/iv001.tif",
    ],
    theme=gr.themes.Soft(),
)
291
+
292
+
293
+
294
+
295
# Combine the four per-dataset interfaces into one tabbed app; the tab titles
# are listed in the same order as the interfaces.
_tab_interfaces = [breakhis, gleason, crc, bach]
_tab_titles = ["BreakHis", "Gleason", "CRC", "Bach"]
demo = gr.TabbedInterface(_tab_interfaces, _tab_titles)


if __name__ == "__main__":
    # share=True asks Gradio to create a public share link as well.
    demo.launch(share=True)