
How to use DiceLoss for multiple-class 2D segmentation? #8523

Answered by NabJa
alexaex asked this question in Q&A


You are right: the Dice loss should equal 1 - Dice score. Consider the following code, which fixes some minor problems in yours:

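```python
import torch
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.utils import one_hot

batch_size, n_classes, h, w = 1, 3, 128, 128

# With num_classes set, DiceMetric accepts single-channel label maps for y_pred and y
dice_metric = DiceMetric(include_background=True, num_classes=n_classes, reduction="mean_batch")
# reduction="none" keeps one loss value per image and per class
dice_loss = DiceLoss(to_onehot_y=True, reduction="none")

# img = torch.rand(batch_size, n_classes, h, w)
mask = torch.randint(0, n_classes, size=(batch_size, 1, h, w))  # ground-truth label map

logits = torch.rand(batch_size, n_classes, h, w)  # self.model(img)
y_pred = torch.argmax(logits, dim=1, keepdim=True)  # convert to label map

dice_score = dice_metric(y_pred, mask)  # per-image, per-class scores: (batch_size, n_classes)
loss = dice_loss(one_hot(y_pred, n_classes), mask).squeeze()  # (n_classes,) when batch_size=1

torch.allclose(dice_score,…
```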
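To see why the metric and `1 - loss` agree without relying on MONAI internals, here is a minimal NumPy sketch that reimplements the per-class formulas by hand. The helper names `to_one_hot`, `per_class_dice` and `per_class_dice_loss` are hypothetical, not MONAI's API. The loss follows the form 1 - (2·intersection + smooth_nr) / (denominator + smooth_dr) with both smoothing terms at 1e-5, which matches MONAI's `DiceLoss` defaults as far as I know; on hard one-hot predictions the smoothing changes the result only negligibly, so `score` and `1 - loss` match within `np.allclose` tolerance.

```python
import numpy as np


def to_one_hot(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Convert a (B, 1, H, W) integer label map to a (B, C, H, W) float one-hot array."""
    oh = np.eye(n_classes)[labels[:, 0]]  # (B, H, W, C): class axis last
    return np.moveaxis(oh, -1, 1)  # move class axis to position 1


def per_class_dice(pred_oh: np.ndarray, target_oh: np.ndarray) -> np.ndarray:
    """Unsmoothed Dice score per image and class, shape (B, C).
    Assumes every class occurs in pred or target, otherwise this is 0/0."""
    inter = (pred_oh * target_oh).sum(axis=(2, 3))
    denom = pred_oh.sum(axis=(2, 3)) + target_oh.sum(axis=(2, 3))
    return 2.0 * inter / denom


def per_class_dice_loss(pred_oh, target_oh, smooth_nr=1e-5, smooth_dr=1e-5):
    """Dice loss per image and class, shape (B, C): 1 - smoothed Dice."""
    inter = (pred_oh * target_oh).sum(axis=(2, 3))
    denom = pred_oh.sum(axis=(2, 3)) + target_oh.sum(axis=(2, 3))
    return 1.0 - (2.0 * inter + smooth_nr) / (denom + smooth_dr)


rng = np.random.default_rng(0)
batch_size, n_classes, h, w = 1, 3, 128, 128
mask = rng.integers(0, n_classes, size=(batch_size, 1, h, w))
logits = rng.random((batch_size, n_classes, h, w))
y_pred = logits.argmax(axis=1)[:, None]  # (B, 1, H, W) label map

pred_oh = to_one_hot(y_pred, n_classes)
mask_oh = to_one_hot(mask, n_classes)
score = per_class_dice(pred_oh, mask_oh)
loss = per_class_dice_loss(pred_oh, mask_oh)
print(np.allclose(score, 1.0 - loss))  # True
```

This only checks that loss and metric agree on hard predictions. For training, `argmax` is not differentiable, so the loss would normally be computed on the raw network output instead, for example with MONAI's `DiceLoss(to_onehot_y=True, softmax=True)`.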

Answer selected by alexaex