Create README.md
README.md
ADDED
@@ -0,0 +1,147 @@
---
license: mit
metrics:
- accuracy
- precision
- recall
- f1
pipeline_tag: image-classification
tags:
- medical
- cervical-cancer
- multi-class
- ood-detection
---

# Model Card: DenseNet121 for Cervix Type Image Classification

This model classifies cervical images into three cervix types (**Type_1, Type_2, Type_3**) plus an **Out-of-Distribution (OOD)** category. It uses a **DenseNet121 backbone** pretrained on ImageNet and fine-tuned on cervical images together with OOD examples drawn from Caltech101.

### Model Details

- **Base model:** `torchvision.models.densenet121` pretrained on ImageNet
- **Input:** RGB images (224x224)
- **Output:** 4 classes: `['Type_1', 'Type_2', 'Type_3', 'OOD']`
- **License:** MIT
- **Training dataset sources:**
  - Cervical images: Intel MobileODT competition dataset
  - OOD images: Caltech101 dataset
- **Preprocessing & Augmentation:**
  - Resize to 224x224
  - Normalization (ImageNet mean & std)
  - Data augmentation: Random rotation, color jitter (brightness/contrast)

### Dataset Distribution

| Split      | Type_1 | Type_2 | Type_3 | OOD | Total |
| ---------- | ------ | ------ | ------ | --- | ----- |
| Train      | 557    | 532    | 547    | 424 | 2060  |
| Validation | 151    | 161    | 154    | 122 | 588   |
| Test       | 73     | 88     | 80     | 54  | 295   |
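
The split sizes in the table can be checked with a few lines of arithmetic. The sketch below only re-adds the numbers shown above and reports each split's share of all images; the dictionary layout is an illustrative choice, not the author's data format.

```python
# Per-class image counts copied from the Dataset Distribution table
splits = {
    "Train": {"Type_1": 557, "Type_2": 532, "Type_3": 547, "OOD": 424},
    "Validation": {"Type_1": 151, "Type_2": 161, "Type_3": 154, "OOD": 122},
    "Test": {"Type_1": 73, "Type_2": 88, "Type_3": 80, "OOD": 54},
}

# Row totals and the overall image count
totals = {name: sum(counts.values()) for name, counts in splits.items()}
grand_total = sum(totals.values())

# Share of the whole dataset held by each split
for name, total in totals.items():
    print(f"{name:<10} {total:>5} {total / grand_total:.1%}")
```

The row totals match the table's Total column, and the shares come out at roughly 70% / 20% / 10% for train, validation, and test.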

### Training Details

- Optimizer: Adam
- Loss: CrossEntropyLoss
- Batch size: 8
- Learning rate: 1e-5
- Epochs: 30
- Device: GPU (Tesla T4, 14GB)

## Evaluation

### Evaluation Metrics

| Class  | Precision | Recall | F1-score | Sensitivity | Specificity |
| ------ | --------- | ------ | -------- | ----------- | ----------- |
| OOD    | 1.00      | 1.00   | 1.00     | 1.0000      | 1.0000      |
| Type_1 | 0.74      | 0.93   | 0.82     | 0.9333      | 0.9074      |
| Type_2 | 0.85      | 0.51   | 0.64     | 0.5114      | 0.9574      |
| Type_3 | 0.73      | 0.92   | 0.81     | 0.9189      | 0.8762      |

**Overall accuracy:** 0.81

**Confusion Matrix**

```
             Predicted
             OOD   T1   T2   T3
Actual
OOD           54    0    0    0
Type_1         0   56    3    1
Type_2         0   19   45   24
Type_3         0    1    5   68
```
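
The per-class figures in the metrics table can be re-derived from the confusion matrix above. The sketch below uses plain Python and the usual one-vs-rest definitions (rows are actual classes, columns are predicted classes); the variable names are illustrative.

```python
# Confusion matrix copied from above: rows = actual, columns = predicted
labels = ["OOD", "Type_1", "Type_2", "Type_3"]
cm = [
    [54, 0, 0, 0],
    [0, 56, 3, 1],
    [0, 19, 45, 24],
    [0, 1, 5, 68],
]

total = sum(sum(row) for row in cm)
metrics = {}
for i, label in enumerate(labels):
    tp = cm[i][i]                                         # correctly predicted as this class
    fn = sum(cm[i]) - tp                                  # this class, predicted as another
    fp = sum(cm[r][i] for r in range(len(labels))) - tp   # other classes, predicted as this one
    tn = total - tp - fn - fp                             # everything else
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)                               # identical to sensitivity
    metrics[label] = {
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
        "specificity": tn / (tn + fp),
    }

# Accuracy is the diagonal sum over all samples
accuracy = sum(cm[i][i] for i in range(len(labels))) / total

for label, m in metrics.items():
    print(label, {name: round(value, 4) for name, value in m.items()})
print("accuracy", round(accuracy, 2))
```

Because recall and sensitivity share one definition, those two columns of the metrics table differ only in rounding. The low Type_2 recall traces back to the 19 and 24 Type_2 images predicted as Type_1 and Type_3.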

**Classification Report**

```
              precision    recall  f1-score   support
         OOD       1.00      1.00      1.00        54
      Type_1       0.74      0.93      0.82        60
      Type_2       0.85      0.51      0.64        88
      Type_3       0.73      0.92      0.81        74

    accuracy                           0.81       276
   macro avg       0.83      0.84      0.82       276
weighted avg       0.82      0.81      0.80       276
```

---

## How to Get Started

```python
import torch
from torchvision import transforms, models
from PIL import Image

# Build the architecture without pretrained weights, then load the fine-tuned checkpoint.
# weights=None needs torchvision >= 0.13; older versions use pretrained=False.
model = models.densenet121(weights=None)
model.classifier = torch.nn.Linear(model.classifier.in_features, 4)
model.load_state_dict(torch.load("Dense_net_121.pth", map_location="cpu"))
model.eval()

# Same preprocessing as in training: resize, then ImageNet normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the image and add a batch dimension
image = Image.open("example.jpg").convert("RGB")
image = transform(image).unsqueeze(0)

# Predict without tracking gradients
with torch.no_grad():
    outputs = model(image)
probabilities = torch.softmax(outputs, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
confidence = probabilities[0, predicted_class].item()

# Class order as listed in this card; confirm it matches the label mapping used in training
class_names = ["Type_1", "Type_2", "Type_3", "OOD"]
print(f"Predicted class: {class_names[predicted_class]}, confidence: {confidence:.2f}")
```

---

## Technical Specifications

### Model Architecture

* **Backbone:** DenseNet121 pretrained on ImageNet
* **Output Layer:** Fully connected layer with 4 outputs (`Type_1`, `Type_2`, `Type_3`, `OOD`)
* **Activation:** Softmax for multi-class classification
* **Training Framework:** PyTorch
* **Loss Function:** CrossEntropyLoss
* **Data Handling:** Includes OOD images from Caltech101 along with in-distribution cervical images
* **Preprocessing & Augmentation:** Resize to 224x224, normalization (ImageNet mean/std), random rotation, color jitter
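
The architecture list names both a softmax activation and `CrossEntropyLoss`. In PyTorch these are normally used at different stages: `CrossEntropyLoss` applies log-softmax internally and expects raw logits, while an explicit softmax is applied only when probabilities are needed at inference. The sketch below illustrates this with random tensors standing in for the 4-class model output; it assumes the network returns logits, as the inference snippet in this card suggests.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 4)           # stand-in for model outputs: batch of 8, 4 classes
targets = torch.randint(0, 4, (8,))  # stand-in integer class labels

# Training: CrossEntropyLoss on raw logits equals NLL loss on log-softmax outputs
ce = nn.CrossEntropyLoss()(logits, targets)
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, nll))

# Inference: softmax turns logits into per-class probabilities that sum to 1
probs = torch.softmax(logits, dim=1)
print(probs.sum(dim=1))

# Softmax is monotonic, so the predicted class is the same with or without it
same_prediction = torch.equal(probs.argmax(dim=1), logits.argmax(dim=1))
print(same_prediction)
```

Adding a softmax layer to the model before `CrossEntropyLoss` would apply the normalization twice, which is why the head is usually left as a plain linear layer.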

### Compute Infrastructure

* **Hardware:** Tesla T4 GPU (14GB)
* **Software:** PyTorch, torchvision, CUDA

---