| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import, division, print_function |
| |
|
# Probe for the optional compiled DCNv3 extension.
# On success: ``DCNv3`` is bound, ``dcn_version`` holds the installed
# distribution version as a float and ``has_cuda_kernel`` is True.
# On any failure: ``has_cuda_kernel`` is False and ``dcn_version`` stays
# undefined, so callers must check ``has_cuda_kernel`` first.
try:
    # pkg_resources must be bound *before* the version lookup below; relying
    # on a later module-level import made this block raise NameError every
    # time, which silently disabled the kernel even when it was installed.
    import pkg_resources
    import DCNv3
    dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
    has_cuda_kernel = True
except Exception:
    # Deliberate best-effort fallback (missing package, missing distribution
    # metadata, unparsable version string, ...). ``Exception`` rather than a
    # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
    has_cuda_kernel = False
| | import pkg_resources |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.autograd import Function |
| | from torch.autograd.function import once_differentiable |
| | from torch.cuda.amp import custom_bwd, custom_fwd |
| |
|
| |
|
class DCNv3Function(Function):
    """Autograd wrapper around the compiled ``DCNv3`` extension.

    ``forward`` / ``backward`` delegate to ``DCNv3.dcnv3_forward`` and
    ``DCNv3.dcnv3_backward``; ``symbolic`` describes the op for ONNX export.
    """

    @staticmethod
    @custom_fwd
    def forward(
            ctx, input, offset, mask,
            kernel_h, kernel_w, stride_h, stride_w,
            pad_h, pad_w, dilation_h, dilation_w,
            group, group_channels, offset_scale, im2col_step, remove_center):
        """Call ``DCNv3.dcnv3_forward`` and stash what ``backward`` needs.

        The three tensors are saved through ``ctx.save_for_backward``; every
        other argument is stored as a plain attribute on ``ctx``.

        Returns:
            Whatever ``DCNv3.dcnv3_forward`` returns.
        """
        # Non-tensor hyper-parameters, kept on ctx for the backward pass.
        ctx.kernel_h, ctx.kernel_w = kernel_h, kernel_w
        ctx.stride_h, ctx.stride_w = stride_h, stride_w
        ctx.pad_h, ctx.pad_w = pad_h, pad_w
        ctx.dilation_h, ctx.dilation_w = dilation_h, dilation_w
        ctx.group, ctx.group_channels = group, group_channels
        ctx.offset_scale = offset_scale
        ctx.im2col_step = im2col_step
        ctx.remove_center = remove_center

        kernel_args = (
            input, offset, mask,
            kernel_h, kernel_w,
            stride_h, stride_w,
            pad_h, pad_w,
            dilation_h, dilation_w,
            group, group_channels,
            offset_scale, im2col_step,
        )
        # The trailing flag is only passed when it is set or when the
        # installed extension reports a version above 1.0 (presumably older
        # builds do not accept it -- confirm against the extension).
        if remove_center or dcn_version > 1.0:
            kernel_args += (remove_center,)

        result = DCNv3.dcnv3_forward(*kernel_args)
        ctx.save_for_backward(input, offset, mask)
        return result

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """Call ``DCNv3.dcnv3_backward`` with the state saved by ``forward``.

        Returns:
            A 16-tuple aligned with ``forward``'s inputs: gradients for
            ``input``, ``offset`` and ``mask`` followed by ``None`` for each
            of the 13 non-tensor arguments.
        """
        input, offset, mask = ctx.saved_tensors

        kernel_args = (
            input, offset, mask,
            ctx.kernel_h, ctx.kernel_w,
            ctx.stride_h, ctx.stride_w,
            ctx.pad_h, ctx.pad_w,
            ctx.dilation_h, ctx.dilation_w,
            ctx.group, ctx.group_channels,
            ctx.offset_scale,
            grad_output.contiguous(),
            ctx.im2col_step,
        )
        # Same optional trailing flag as in forward.
        if ctx.remove_center or dcn_version > 1.0:
            kernel_args += (ctx.remove_center,)

        grad_input, grad_offset, grad_mask = DCNv3.dcnv3_backward(*kernel_args)

        return (grad_input, grad_offset, grad_mask) + (None,) * 13

    @staticmethod
    def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
                 stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                 group_channels, offset_scale, im2col_step, remove_center):
        """Emit the ``mmdeploy::TRTDCNv3`` node for ONNX export.

        The three tensors are passed as node inputs; all remaining arguments
        become attributes (``_i`` suffix for ints, ``_f`` for the float).

        Returns:
            The value returned by ``g.op``.
        """
        attributes = {
            'kernel_h_i': int(kernel_h),
            'kernel_w_i': int(kernel_w),
            'stride_h_i': int(stride_h),
            'stride_w_i': int(stride_w),
            'pad_h_i': int(pad_h),
            'pad_w_i': int(pad_w),
            'dilation_h_i': int(dilation_h),
            'dilation_w_i': int(dilation_w),
            'group_i': int(group),
            'group_channels_i': int(group_channels),
            'offset_scale_f': float(offset_scale),
            'im2col_step_i': int(im2col_step),
            'remove_center_i': int(remove_center),
        }
        return g.op('mmdeploy::TRTDCNv3', input, offset, mask, **attributes)
| |
|
| |
|
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
    """Return the normalized centre of every sliding-window position.

    Args:
        spatial_shapes: 4-element shape ``(_, H, W, _)``; only H and W are read.
        device: device on which the coordinate tensors are created.
        kernel_h, kernel_w: kernel extent along each axis.
        dilation_h, dilation_w: spacing between kernel taps.
        pad_h, pad_w: accepted for signature compatibility; not used here.
        stride_h, stride_w: step between neighbouring window positions.

    Returns:
        float32 tensor of shape ``(1, H_out, W_out, 1, 2)`` whose last axis
        holds ``(x / W, y / H)``.
    """
    _, height, width, _ = spatial_shapes

    # Extent covered by a dilated kernel beyond its first tap.
    span_h = dilation_h * (kernel_h - 1)
    span_w = dilation_w * (kernel_w - 1)

    out_h = (height - (span_h + 1)) // stride_h + 1
    out_w = (width - (span_w + 1)) // stride_w + 1

    # First window centre, in pixel units (+0.5 puts it on a pixel centre).
    first_y = span_h // 2 + 0.5
    first_x = span_w // 2 + 0.5

    centers_y = torch.linspace(
        first_y,
        first_y + (out_h - 1) * stride_h,
        out_h,
        dtype=torch.float32,
        device=device)
    centers_x = torch.linspace(
        first_x,
        first_x + (out_w - 1) * stride_w,
        out_w,
        dtype=torch.float32,
        device=device)

    mesh_y, mesh_x = torch.meshgrid(centers_y, centers_x)

    # Normalize by the full input size and pair up as (x, y).
    normalized = torch.stack((mesh_x / width, mesh_y / height), -1)
    return normalized.reshape(1, out_h, out_w, 1, 2)
| |
|
| |
|
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
    """Return the normalized offsets of the dilated kernel taps, per group.

    Args:
        spatial_shapes: 4-element shape ``(_, H, W, _)``; only H and W are read.
        kernel_h, kernel_w: kernel extent along each axis.
        dilation_h, dilation_w: spacing between kernel taps.
        group: number of times the tap pattern is repeated.
        device: device on which the tensors are created.

    Returns:
        float32 tensor of shape ``(1, 1, 1, group * kernel_h * kernel_w, 2)``
        whose last axis holds ``(dx / W, dy / H)``. Within one group the
        taps are enumerated with the x index varying slowest.
    """
    _, height, width, _ = spatial_shapes

    # Offset of the first tap relative to the kernel centre, in pixels.
    first_dx = -((dilation_w * (kernel_w - 1)) // 2)
    first_dy = -((dilation_h * (kernel_h - 1)) // 2)

    taps_x = torch.linspace(
        first_dx,
        first_dx + (kernel_w - 1) * dilation_w,
        kernel_w,
        dtype=torch.float32,
        device=device)
    taps_y = torch.linspace(
        first_dy,
        first_dy + (kernel_h - 1) * dilation_h,
        kernel_h,
        dtype=torch.float32,
        device=device)

    mesh_x, mesh_y = torch.meshgrid(taps_x, taps_y)

    # (kernel_w * kernel_h, 1, 2): one normalized (dx, dy) pair per tap.
    taps = torch.stack((mesh_x / width, mesh_y / height), -1).reshape(-1, 1, 2)

    # Copy the pattern for every group and make the group axis the outer one.
    per_group = taps.repeat(1, group, 1).permute(1, 0, 2)
    return per_group.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
| |
|
| |
|
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
    """Drop the centre-tap entries along the second-to-last axis.

    With ``C = (kernel_w * kernel_h - 1) // 2``, every index ``i`` on axis
    ``-2`` for which ``(i - C)`` is a multiple of ``2 * C + 1`` is removed;
    all other entries are kept in their original order.

    Args:
        sampling_locations: 5-D array-like indexed as ``[:, :, :, taps, :]``.
        kernel_w, kernel_h: kernel extent used to locate the centre tap.

    Returns:
        ``sampling_locations`` restricted to the kept indices.
    """
    center = (kernel_w * kernel_h - 1) // 2
    period = 2 * center + 1
    # The centre index itself satisfies the modulo test, so a single
    # condition covers it together with its repeats.
    kept = [
        tap
        for tap in range(sampling_locations.shape[-2])
        if (tap - center) % period != 0
    ]
    return sampling_locations[:, :, :, kept, :]
| |
|
| |
|
def dcnv3_core_pytorch(
        input, offset, mask, kernel_h,
        kernel_w, stride_h, stride_w, pad_h,
        pad_w, dilation_h, dilation_w, group,
        group_channels, offset_scale, remove_center):
    """Pure-PyTorch DCNv3 sampling implemented with ``F.grid_sample``.

    Args:
        input: channel-last tensor ``(N, H, W, group * group_channels)``.
        offset: ``(N, H_out, W_out, group * P * 2)`` sampling offsets in
            pixels, stored as interleaved ``(x, y)`` pairs, where
            ``P = kernel_h * kernel_w - remove_center``.
        mask: tensor with ``N * H_out * W_out * group * P`` elements laid
            out as ``(N, H_out, W_out, group * P)``; multiplied onto the
            sampled values before they are summed.
        kernel_h, kernel_w: kernel extent.
        stride_h, stride_w: window stride.
        pad_h, pad_w: zero padding added on both sides of the height /
            width axis respectively.
        dilation_h, dilation_w: spacing between kernel taps.
        group: number of groups the channels are split into.
        group_channels: channels per group.
        offset_scale: factor applied to both the fixed kernel grid and the
            offsets.
        remove_center: truthy to drop the centre tap of every group.

    Returns:
        Contiguous tensor of shape ``(N, H_out, W_out, group * group_channels)``.

    Raises:
        ValueError: if ``remove_center`` is set and the kernel is not square
            with an odd size.
    """
    if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
        raise ValueError('remove_center is only compatible with square odd kernel size.')

    # F.pad reads (before, after) pairs starting from the LAST dimension.
    # For an N,H,W,C tensor that order is C, W, H -- so the width pair must
    # come before the height pair. (Passing pad_h first padded the wrong
    # axes whenever pad_h != pad_w.)
    input = F.pad(
        input,
        [0, 0, pad_w, pad_w, pad_h, pad_h])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape
    # Number of sampling points per group.
    P_ = kernel_h * kernel_w - remove_center

    # ref:  (1, H_out', W_out', 1, 2) normalized window centres.
    # grid: (1, 1, 1, group * kernel_h * kernel_w, 2) normalized tap offsets.
    ref = _get_reference_points(
        input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
    grid = _generate_dilation_grids(
        input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
    # (W, H) repeated once per sampling point so that x offsets are divided
    # by the width and y offsets by the height; built directly on the
    # input's device.
    spatial_norm = torch.tensor([W_in, H_in], device=input.device).\
        reshape(1, 1, 1, 2).repeat(1, 1, 1, group * P_)

    sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
    if remove_center:
        sampling_locations = remove_center_sampling_locations(
            sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
    # Merge the (point, xy) axes so the layout matches ``offset``.
    sampling_locations = sampling_locations.flatten(3, 4)
    sampling_locations = sampling_locations + offset * offset_scale / spatial_norm

    # grid_sample expects coordinates in [-1, 1] rather than [0, 1].
    sampling_grids = 2 * sampling_locations - 1

    # ``reshape`` instead of ``view`` so non-contiguous tensors are accepted;
    # for contiguous tensors the two are equivalent.
    # N_, H_in, W_in, group*group_channels -> N_*group, group_channels, H_in, W_in
    input_ = input.reshape(N_, H_in * W_in, group * group_channels).transpose(1, 2).\
        reshape(N_ * group, group_channels, H_in, W_in)
    # N_, H_out, W_out, group*P_*2 -> N_*group, H_out*W_out, P_, 2
    sampling_grid_ = sampling_grids.reshape(N_, H_out * W_out, group, P_, 2).transpose(1, 2).\
        flatten(0, 1)
    # N_*group, group_channels, H_out*W_out, P_
    sampling_input_ = F.grid_sample(
        input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)

    # N_, H_out, W_out, group*P_ -> N_*group, 1, H_out*W_out, P_
    mask = mask.reshape(N_, H_out * W_out, group, P_).transpose(1, 2).\
        reshape(N_ * group, 1, H_out * W_out, P_)
    # Weighted sum over the sampling points of every output location.
    output = (sampling_input_ * mask).sum(-1).reshape(
        N_, group * group_channels, H_out * W_out)

    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
| |
|