The single-window v2 checkpoints were exported to ONNX (fp32 and fp16) using the following code:
import torch
import onnx
import segmentation_models_pytorch as smp
from torch.export.dynamic_shapes import Dim
from onnxconverter_common import float16
# Rebuild the trained network as a plain smp.Unet from the training checkpoint.
path = "3_Class_FULL_FTW_Pretrained_singleWindow_v2.ckpt"
# NOTE(security): weights_only=False uses the full pickle loader, which can
# execute arbitrary code embedded in the file. Only load checkpoints that come
# from a trusted source.
ckpt = torch.load(path, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]
# Strip only the leading "model." wrapper prefix from each key. str.replace
# would also delete any later "model." inside a key and corrupt the name.
state_dict = {k.removeprefix("model."): v for k, v in ckpt["state_dict"].items()}
# "criterion.weight" is not a Unet parameter (presumably the loss's class
# weights -- confirm against the training code). Drop it if present so the
# strict load below only sees network weights; a checkpoint saved without it
# no longer fails with KeyError.
state_dict.pop("criterion.weight", None)
print(hparams["model"], hparams["backbone"], hparams["in_channels"], hparams["num_classes"])
# encoder_weights=None: no pretrained encoder weights are requested because
# every parameter is overwritten by the checkpoint right after.
model = smp.Unet(
    encoder_name=hparams["backbone"],
    encoder_weights=None,
    in_channels=hparams["in_channels"],
    classes=hparams["num_classes"],
)
# Switch to inference mode before the model is traced for export.
model.eval()
# strict=True makes any missing or unexpected key an error instead of a
# silently half-initialized model.
model.load_state_dict(state_dict, strict=True)
# Capture the network as a torch.export program, then write that program out
# as an ONNX file and validate the result.
example_input = torch.randn(1, hparams["in_channels"], 256, 256)
# Dim.AUTO is used for the batch, height and width axes; the channel axis is
# given as the fixed in_channels value from the checkpoint.
shape_spec = {"x": (Dim.AUTO, hparams["in_channels"], Dim.AUTO, Dim.AUTO)}
program = torch.export.export(
    model,
    args=(example_input,),
    dynamic_shapes=shape_spec,
)
output_path = "ftw-v2-single-window-unet-efficientnetb3-fp32.onnx"
onnx_program = torch.onnx.export(
    model=program,
    f=output_path,
    external_data=False,
)
# Read the file back from disk and run the ONNX checker on it.
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
# Write a float16 variant of the exported model next to the fp32 file.
# The output name is the fp32 file name with "fp32" swapped for "fp16".
fp16_path = output_path.replace("fp32", "fp16")
# keep_io_types=True is passed to the converter (per its name, the graph
# inputs/outputs are expected to keep their original types).
model_fp16 = float16.convert_float_to_float16(
    onnx_model,
    keep_io_types=True,
)
onnx.save(model_fp16, fp16_path)
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support