
Converting UNet model outputs from logits to segmentation masks

I've recently been working on a segmentation task on CT scans. I'm using Python, with the residual UNet architecture implemented in MONAI and DiceLoss as the loss.
Everything went smoothly until I ran inference with the trained model. This residual UNet has no Softmax or Sigmoid layer, so it outputs raw logits: floats ranging from around -30 to 12.
How can I properly convert those logits into probabilities of each pixel belonging to each class?
In BCHWD notation, the input is 1x1x256x256x256 (the image itself), while the output is 1x9x256x256x256, with each channel being the mask for a different class.
The inference code looks more or less like this:

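import torch
from monai.inferers import SimpleInferer
from monai.networks.layers import Norm
from monai.networks.nets import UNet

# Pick the best available device: CUDA, then Apple MPS, then CPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=9,
    channels=(8, 16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    norm=Norm.INSTANCE,
).to(device)

# map_location lets a checkpoint saved on another device load here
checkpoint = torch.load(PATH_TO_SAVED_MODEL_OBJECT, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to inference mode

inputs, labels = next(iter(validation_dataloader))  # one image only, batch_size=1
inferer = SimpleInferer()  # simple inferer from monai.inferers
inputs = inputs.to(device)  # move to the selected device
labels = labels.to(device)  # move to the selected device
with torch.no_grad():  # gradients are not needed for inference
    pred = inferer(inputs=inputs, network=model)  # raw logits, 1x9x256x256x256
pred = pred.cpu().numpy()  # conversion to a NumPy array for viewing purposes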

Thanks in advance for your assistance.

>Solution :

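If the classes can overlap, so that each pixel may belong to several classes at once (multi-label segmentation), use sigmoid, i.e., prob = torch.sigmoid(pred), which gives the probability of each pixel belonging to each class independently. If each pixel belongs to exactly one class (multiclass segmentation), use softmax over the channel dimension, i.e., prob = torch.softmax(pred, dim=1). Softmax produces a valid probability distribution over the 9 classes at every pixel; sigmoid produces independent per-class probabilities that need not sum to 1. Apply either one to the torch tensor, before the conversion to a NumPy array. To get the final mask, take torch.argmax(prob, dim=1) in the softmax case, or threshold the probabilities (for example prob > 0.5) in the sigmoid case.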
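
As a rough illustration of both options, here is a minimal sketch that runs on a small random tensor standing in for the real 1x9x256x256x256 output. The helper name logits_to_masks, the 0.5 threshold, and the 4x4x4 volume size are assumptions made for this example and are not part of the original code.

```python
import torch


def logits_to_masks(logits: torch.Tensor, threshold: float = 0.5):
    """Convert raw logits of shape (B, C, H, W, D) to probabilities and masks.

    Hypothetical helper for illustration; not part of MONAI or PyTorch.
    """
    # Mutually exclusive classes: softmax over the channel axis, then argmax.
    softmax_probs = torch.softmax(logits, dim=1)    # sums to 1 along dim=1
    label_map = torch.argmax(softmax_probs, dim=1)  # (B, H, W, D), class indices

    # Independent, possibly overlapping classes: sigmoid, then threshold.
    sigmoid_probs = torch.sigmoid(logits)           # each value in [0, 1]
    binary_masks = sigmoid_probs > threshold        # (B, C, H, W, D), bool
    return softmax_probs, label_map, sigmoid_probs, binary_masks


# Small random stand-in for the model output (assumed size, for speed).
torch.manual_seed(0)
logits = torch.randn(1, 9, 4, 4, 4) * 10
softmax_probs, label_map, sigmoid_probs, binary_masks = logits_to_masks(logits)
print(softmax_probs.shape, label_map.shape, binary_masks.shape)
```

Which branch is appropriate most likely depends on how the model was trained: MONAI's DiceLoss accepts softmax=True or sigmoid=True, and using the same activation at inference keeps the probabilities consistent with what the loss optimized.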
