import torch
from torchmetrics.classification import BinaryJaccardIndex, JaccardIndex
# Metrics
# Jaccard Index (also known as Intersection over Union, or IoU)
# Binary case
# Binary Jaccard index on a 2x2 example.
# Positives in target: 3 cells; positives in preds: 2 cells, both inside target.
# Intersection = 2, union = 3, so IoU = 2/3.
target = torch.tensor([[1, 1], [1, 0]])
preds = torch.tensor([[1, 1], [0, 0]])
metric = BinaryJaccardIndex()
print(metric(preds, target))  # tensor(0.6667)
# Multiclass case
# Multiclass (2-class) Jaccard index on random label maps.
# Fix the seed so repeated runs print the same value.
torch.manual_seed(42)
target = torch.randint(0, 2, (10, 25, 25))
pred = target.clone()
# Labels are 0/1, so `1 - x` flips them: this corrupts a 3x6x6 sub-block
# (108 of the 6250 elements) so that pred disagrees with target there.
pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
jaccard = JaccardIndex(task="multiclass", num_classes=2)
# Prints a value close to tensor(0.9660); the exact digits depend on the
# random target generated above.
print(jaccard(pred, target))