# Taken from https://github.com/pytorch/vision
# So that we don't need torchvision to be installed
from collections import OrderedDict

import torch
from torch import nn
from torch.jit.annotations import Dict
from torch.nn import functional as F

try:
    from scipy.optimize import linear_sum_assignment

    scipy_available = True
except Exception:
    scipy_available = False


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch],
    #                                           progress=progress)
    #     model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model, return_layers):
        if not set(return_layers).issubset(
            [name for name, _ in model.named_children()]
        ):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out


class _SimpleSegmentationModel(nn.Module):
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone, classifier, aux_classifier=None):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["aux"] = x

        return result


class FCN(_SimpleSegmentationModel):
    """
    Implements a Fully-Convolutional Network for semantic segmentation.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """


class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1),
        ]

        super().__init__(*layers)


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    # backbone = resnet.__dict__[backbone_name](
    #     pretrained=pretrained_backbone,
    #     replace_stride_with_dilation=[False, True, True])
    # Hardcoded resnet 50
    assert backbone_name == "resnet50"
    backbone = resnet50(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
    )

    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)

    model_map = {
        # 'deeplabv3': (DeepLabHead, DeepLabV3),  # Not used
        "fcn": (FCNHead, FCN),
    }
    inplanes = 2048
    classifier = model_map[name][0](inplanes, num_classes)
    base_model = model_map[name][1]

    model = base_model(backbone, classifier, aux_classifier)
    return model


def _load_model(
    arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs
):
    if pretrained:
        aux_loss = True
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    # if pretrained:
    #     arch = arch_type + '_' + backbone + '_coco'
    #     model_url = model_urls[arch]
    #     if model_url is None:
    #         raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    #     else:
    #         state_dict = load_state_dict_from_url(model_url, progress=progress)
    #         model.load_state_dict(state_dict)
    return model


def fcn_resnet50(
    pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs
):
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _load_model(
        "fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs
    )


# Taken from @fmassa example slides and https://github.com/facebookresearch/detr
class DETR(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
""" def __init__( self, num_classes, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6, ): super().__init__() # create ResNet-50 backbone self.backbone = resnet50() del self.backbone.fc # create conversion layer self.conv = nn.Conv2d(2048, hidden_dim, 1) # create a default PyTorch transformer self.transformer = nn.Transformer( hidden_dim, nheads, num_encoder_layers, num_decoder_layers ) # prediction heads, one extra class for predicting non-empty slots # note that in baseline DETR linear_bbox layer is 3-layer MLP self.linear_class = nn.Linear(hidden_dim, num_classes + 1) self.linear_bbox = nn.Linear(hidden_dim, 4) # output positional encodings (object queries) self.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) # spatial positional encodings # note that in baseline DETR we use sine positional encodings self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) def forward(self, inputs): # propagate inputs through ResNet-50 up to avg-pool layer x = self.backbone.conv1(inputs) x = self.backbone.bn1(x) x = self.backbone.relu(x) x = self.backbone.maxpool(x) x = self.backbone.layer1(x) x = self.backbone.layer2(x) x = self.backbone.layer3(x) x = self.backbone.layer4(x) # convert from 2048 to 256 feature planes for the transformer h = self.conv(x) # construct positional encodings H, W = h.shape[-2:] pos = ( torch.cat( [ self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), ], dim=-1, ) .flatten(0, 1) .unsqueeze(1) ) # propagate through the transformer # TODO (alband) Why this is not automatically broadcasted? (had to add the repeat) f = pos + 0.1 * h.flatten(2).permute(2, 0, 1) s = self.query_pos.unsqueeze(1) s = s.expand(s.size(0), inputs.size(0), s.size(2)) h = self.transformer(f, s).transpose(0, 1) # finally project transformer outputs to class labels and bounding boxes return { "pred_logits": self.linear_class(h), "pred_boxes": self.linear_bbox(h).sigmoid(), } def generalized_box_iou(boxes1, boxes2): """ Generalized IoU from https://giou.stanford.edu/ The boxes should be in [x0, y0, x1, y1] format Returns a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) """ # degenerate boxes gives inf / nan results # so do an early check assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() iou, union = box_iou(boxes1, boxes2) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) wh = (rb - lt).clamp(min=0) # [N,M,2] area = wh[:, :, 0] * wh[:, :, 1] return iou - (area - union) / area def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1) def box_area(boxes): """ Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. Args: boxes (Tensor[N, 4]): boxes for which the area will be computed. 
def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed.
            They are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def is_dist_avail_and_initialized():
    return False


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    # Fallback for completeness; unreachable with the stub above, which always
    # reports that distributed is unavailable.
    return torch.distributed.get_world_size()


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class SetCriterion(nn.Module):
    """This class computes the loss for DETR.

    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)

        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        )
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2), target_classes, self.empty_weight
        )
        losses = {"loss_ce": loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only.
        It doesn't propagate gradients.
        """
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device
        )
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.

        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        )

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.

        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(  # noqa: F821
            [t["masks"] for t in targets]
        ).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(  # noqa: F821
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(  # noqa: F821
                src_masks, target_masks, num_boxes
            ),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),  # noqa: F821
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == "labels":
                        # Logging is enabled only for the last layer
                        kwargs = {"log": False}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_boxes, **kwargs
                    )
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
""" def __init__( self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1 ): """Creates the matcher Params: cost_class: This is the relative weight of the classification error in the matching cost cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost """ super().__init__() self.cost_class = cost_class self.cost_bbox = cost_bbox self.cost_giou = cost_giou assert ( cost_class != 0 or cost_bbox != 0 or cost_giou != 0 ), "all costs cant be 0" @torch.no_grad() def forward(self, outputs, targets): """Performs the matching Params: outputs: This is a dict that contains at least these entries: "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth objects in the target) containing the class labels "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates Returns: A list of size batch_size, containing tuples of (index_i, index_j) where: - index_i is the indices of the selected predictions (in order) - index_j is the indices of the corresponding selected targets (in order) For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes) """ bs, num_queries = outputs["pred_logits"].shape[:2] # We flatten to compute the cost matrices in a batch out_prob = ( outputs["pred_logits"].flatten(0, 1).softmax(-1) ) # [batch_size * num_queries, num_classes] out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] # Also concat the target labels and boxes tgt_ids = torch.cat([v["labels"] for v in targets]) tgt_bbox = torch.cat([v["boxes"] for v in targets]) # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. cost_class = -out_prob[:, tgt_ids] # Compute the L1 cost between boxes cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # Compute the giou cost betwen boxes cost_giou = -generalized_box_iou( box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) ) # Final cost matrix C = ( self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou ) C = C.view(bs, num_queries, -1).cpu() sizes = [len(v["boxes"]) for v in targets] if not scipy_available: raise RuntimeError( "The 'detr' model requires scipy to run. Please make sure you have it installed" " if you enable the 'detr' model." ) indices = [ linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1)) ] return [ ( torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64), ) for i, j in indices ]