| | """by lyuwenyu |
| | """ |
| |
|
| | import json |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torchvision |
| |
|
| |
|
| | __all__ = ['RTDETRPostProcessor'] |
| |
|
| |
|
class RTDETRPostProcessor(nn.Module):
    """Turn raw RT-DETR head outputs into per-image detections.

    The module reads ``outputs['pred_logits']`` and ``outputs['pred_boxes']``,
    converts the boxes from ``cxcywh`` to ``xyxy``, scales them by the
    per-image target size and keeps the best-scoring predictions.

    Args:
        num_classes: Modulus/divisor used to split a flattened
            (query, class) index back into a class label and a query index
            when ``use_focal_loss`` is true.
        use_focal_loss: If true, scores are ``sigmoid(logits)`` and the top-k
            is taken over all (query, class) pairs. Otherwise scores are a
            class-wise softmax with the last class column dropped.
        num_top_queries: Number of predictions kept per image.
        remap_mscoco_category: If true, labels are mapped through the
            project's ``mscoco_label2category`` table before being returned
            (only in non-deploy mode).
    """

    # NOTE(review): presumably consumed by the project's config machinery to
    # share these constructor arguments between components -- confirm there.
    __share__ = ['num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category']

    def __init__(self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        self.num_top_queries = num_top_queries
        self.num_classes = num_classes
        self.remap_mscoco_category = remap_mscoco_category
        # Switched on by ``deploy()``; makes ``forward`` return raw tensors.
        self.deploy_mode = False

    def extra_repr(self) -> str:
        """Return the configuration summary appended to the module repr."""
        return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'

    def forward(self, outputs, orig_target_sizes):
        """Select and rescale the top predictions for every image.

        Args:
            outputs: Mapping with ``'pred_logits'`` and ``'pred_boxes'``.
                The logits are indexed as a 3-D tensor; boxes are read in
                ``cxcywh`` format (presumably normalised to [0, 1], since
                they are multiplied by the target size -- confirm).
            orig_target_sizes: 2-D tensor, one row per image. Each row is
                repeated twice to scale the four ``xyxy`` coordinates, so it
                is presumably ``(width, height)`` -- confirm with the caller.

        Returns:
            In deploy mode, the tuple ``(labels, boxes, scores)`` of batched
            tensors. Otherwise a list with one
            ``dict(labels=..., boxes=..., scores=...)`` per image.
        """
        logits, boxes = outputs['pred_logits'], outputs['pred_boxes']

        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
        # Rows of orig_target_sizes become (a, b, a, b) and broadcast over
        # the query dimension.
        bbox_pred = bbox_pred * orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.use_focal_loss:
            scores = torch.sigmoid(logits)
            # Rank every (query, class) pair jointly, then split the flat
            # index back into its class and query components.
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            labels = index % self.num_classes
            index = index // self.num_classes
            boxes = bbox_pred.gather(
                dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
        else:
            # ``dim`` must be explicit: for a 3-D input the legacy implicit
            # choice is dim=0, which would normalise across the batch instead
            # of across classes.
            # The last class column is dropped (presumably the background /
            # no-object class -- confirm).
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(
                    boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        if self.deploy_mode:
            return labels, boxes, scores

        if self.remap_mscoco_category:
            from ...data.coco import mscoco_label2category
            # A single ``tolist()`` replaces one ``.item()`` call (and one
            # device sync) per label.
            remapped = [mscoco_label2category[int(label)] for label in labels.flatten().tolist()]
            labels = torch.tensor(remapped).to(boxes.device).reshape(labels.shape)

        return [
            dict(labels=lab, boxes=box, scores=sco)
            for lab, box, sco in zip(labels, boxes, scores)
        ]

    def deploy(self, ):
        """Switch to eval mode, enable tuple outputs and return ``self``."""
        self.eval()
        self.deploy_mode = True
        return self

    @property
    def iou_types(self, ):
        """IoU types reported for this post-processor: ``('bbox',)``."""
        return ('bbox', )