# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Vmap rules for convolution operations (forward, transposed and filter-gradient variants)."""
from __future__ import absolute_import

import mindspore.numpy as mnp
from mindspore.ops import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.primitive import Primitive
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, \
    _raise_value_error, _vmap_update_prim_attr, _vmap_clone_prim


@vmap_rules_getters.register(P.Conv2D)
@vmap_rules_getters.register(P.Conv3D)
def get_conv_vmap_rule(prim, axis_size):
    """Vmap rule for `Conv2D` and `Conv3D` operations.

    Args:
        prim: The convolution primitive, or its name as a string.
        axis_size: Size of the vmapped (batch) axis.

    Returns:
        A rule taking the `(value, dim)` pairs of `x` and `weight` and returning
        the batched output together with its batch axis.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Read the original attribute values once; the rule may later rewrite
    # 'group'/'out_channel' on the clone, never on `prim` itself.
    attr_list = [prim.name, prim.group, prim.data_format]
    new_prim = _vmap_clone_prim(prim)

    def vmap_rule(input_bdim, weight_bdim):
        is_all_none, result = vmap_general_preprocess(prim, input_bdim, weight_bdim)
        if is_all_none:
            return result
        return _conv_vmap_rule(new_prim, axis_size, input_bdim, weight_bdim, attr_list)

    return vmap_rule


@vmap_rules_getters.register(P.Conv2DTranspose)
@vmap_rules_getters.register(P.Conv2DBackpropInput)
def get_conv2d_transpose_vmap_rule(prim, axis_size):
    """Vmap rule for `Conv2DTranspose` and `Conv2DBackpropInput` operations.

    The returned rule takes `(dout, weight, input_size)` in the primitive's input order.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    attr_list = [prim.name, prim.group, prim.data_format]
    new_prim = _vmap_clone_prim(prim)

    def vmap_rule(dout_bdim, weight_bdim, input_size_bdim):
        is_all_none, result = vmap_general_preprocess(prim, dout_bdim, weight_bdim, input_size_bdim)
        if is_all_none:
            return result
        return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim,
                                         weight_bdim, input_size_bdim, attr_list)

    return vmap_rule


@vmap_rules_getters.register(P.Conv3DTranspose)
def get_conv3d_transpose_vmap_rule(prim, axis_size):
    """Vmap rule for `Conv3DTranspose` operation.

    The returned rule takes `(dout, weight)`; this primitive has no `input_size` input.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    attr_list = [prim.name, prim.group, prim.data_format]
    new_prim = _vmap_clone_prim(prim)

    def vmap_rule(dout_bdim, weight_bdim):
        is_all_none, result = vmap_general_preprocess(prim, dout_bdim, weight_bdim)
        if is_all_none:
            return result
        return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim, weight_bdim, None, attr_list)

    return vmap_rule


@vmap_rules_getters.register(nps.Conv3DBackpropInput)
def get_conv3d_backprop_input_vmap_rule(prim, axis_size):
    """Vmap rule for `Conv3DBackpropInput` operation.

    The returned rule takes `(weight, dout, input_size)`, i.e. the weight comes first
    in this primitive's input order.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    attr_list = [prim.name, prim.group, prim.data_format]
    new_prim = _vmap_clone_prim(prim)

    def vmap_rule(weight_bdim, dout_bdim, input_size_bdim):
        is_all_none, result = vmap_general_preprocess(prim, weight_bdim, dout_bdim, input_size_bdim)
        if is_all_none:
            return result
        return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim,
                                         weight_bdim, input_size_bdim, attr_list)

    return vmap_rule


@vmap_rules_getters.register(G.Conv2DBackpropFilter)
def get_conv2d_backprop_filter_vmap_rule(prim, axis_size):
    """Vmap rule for `Conv2DBackpropFilter` operation.

    The returned rule takes `(dout, x, weight_size)` in the primitive's input order.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    attr_list = [prim.name, prim.group, prim.data_format]
    new_prim = _vmap_clone_prim(prim)

    def vmap_rule(dout_bdim, input_x_bdim, weight_size_bdim):
        is_all_none, result = vmap_general_preprocess(prim, dout_bdim, input_x_bdim, weight_size_bdim)
        if is_all_none:
            return result
        return _conv_backprop_filter_vmap_rule(new_prim, axis_size, dout_bdim,
                                               input_x_bdim, weight_size_bdim, attr_list)

    return vmap_rule


@vmap_rules_getters.register(G.Conv3DBackpropFilter)
def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
    """Vmap rule for `Conv3DBackpropFilter` operation.

    The returned rule takes `(x, dout, weight_size)`, i.e. `x` comes first in this
    primitive's input order.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    attr_list = [prim.name, prim.group, prim.data_format]
    new_prim = _vmap_clone_prim(prim)

    def vmap_rule(input_x_bdim, dout_bdim, weight_size_bdim):
        is_all_none, result = vmap_general_preprocess(prim, input_x_bdim, dout_bdim, weight_size_bdim)
        if is_all_none:
            return result
        return _conv_backprop_filter_vmap_rule(new_prim, axis_size, dout_bdim,
                                               input_x_bdim, weight_size_bdim, attr_list)

    return vmap_rule


@_primexpr
def _get_reshape_src_dim(data_dim, cmp_dim):
    """Locate the axes used when splitting `cmp_dim` and then folding in the batch axis.

    `data_dim` is the batch axis of the batched tensor; `cmp_dim` is an axis index of
    the un-batched tensor.

    Returns:
        `(expand_dim, merge_dim)`: `expand_dim` is the position of `cmp_dim` in the
        batched tensor, and `merge_dim` is the position of the batch axis after the
        axis at `expand_dim` has been split into two.
    """
    if data_dim > cmp_dim:
        # Batch axis lies after `cmp_dim`: `cmp_dim` keeps its index, and splitting it
        # pushes the batch axis one position to the right.
        expand_dim = cmp_dim
        merge_dim = data_dim + 1
    else:
        # Batch axis lies at or before `cmp_dim`: `cmp_dim` is shifted right by one,
        # and the split happens after the batch axis, so the batch axis stays put.
        expand_dim = cmp_dim + 1
        merge_dim = data_dim
    return expand_dim, merge_dim


@_primexpr
def _get_merge_shape(src_dim, dst_dim, shape):
    """Get the shape obtained by folding axis `src_dim` into axis `dst_dim`.

    `dst_dim` indexes the shape *after* `src_dim` has been removed. The result has one
    fewer axis, with `shape[src_dim]` multiplied into the `dst_dim` entry.
    """
    new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
    new_shape[dst_dim] *= shape[src_dim]
    return tuple(new_shape)


def _reshape_merge_dims(src_dim, dst_dim, target):
    """Reshape `target` by folding axis `src_dim` into axis `dst_dim`.

    `src_dim` is first moved to position `dst_dim`, so after the reshape it is the
    outer factor of the merged axis.

    Returns:
        A tuple `(output, new_shape)`.
    """
    shape = F.shape(target)
    new_shape = _get_merge_shape(src_dim, dst_dim, shape)
    new_target = mnp.moveaxis(target, src_dim, dst_dim)
    output = F.reshape(new_target, new_shape)
    return output, new_shape


@_primexpr
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
    """Get the shape obtained by splitting axis `src_dim` into `(dst_size, rest)`.

    `rest` is `shape[src_dim] // dst_size`. `prim_name` is currently unused.
    """
    dst_size2 = shape[src_dim] // dst_size
    new_shape = list(shape)
    new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
    return tuple(new_shape)


def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
    """Reshape `target` by splitting axis `src_dim` into `(dst_size, rest)`.

    This is the inverse of `_reshape_merge_dims(src_dim, src_dim, ...)`: `dst_size`
    becomes the outer of the two new axes.
    """
    shape = F.shape(target)
    new_shape = _get_expand_shape(src_dim, dst_size, shape, prim_name)
    return F.reshape(target, new_shape)


@_primexpr
def _get_new_size_by_index(input_size, batch_size, index):
    """Return `input_size` with entry `index` multiplied by `batch_size`.

    Returns an empty tuple when `input_size` is None (e.g. for `Conv3DTranspose`,
    which has no size input).
    """
    if input_size is None:
        new_size = ()
        return new_size
    new_size = list(input_size)
    new_size[index] *= batch_size
    return tuple(new_size)


@_primexpr
def _update_group_attr(prim, groups, batch_size):
    """Set the group count of the convolution primitive to `groups * batch_size`.

    Both the 'group' and 'groups' attribute names are updated.
    """
    group = groups * batch_size
    _vmap_update_prim_attr(prim, 'group', group)
    _vmap_update_prim_attr(prim, 'groups', group)


@constexpr
def _get_channel_index(data_format, prim_name):
    """Get the channel axis for `data_format`; only NHWC/NCHW/NCDHW are supported.

    Returns 3 for NHWC and 1 for NCHW/NCDHW; raises a ValueError otherwise.
    """
    index = 0
    if data_format == "NHWC":
        index = 3
    elif data_format in ("NCHW", "NCDHW"):
        index = 1
    else:
        _raise_value_error("'data_format' in {} should be NHWC/NCHW/NCDHW, "
                           "but got {}.".format(prim_name, data_format))
    return index


def _conv_vmap_rule(prim, batch_size, input_bdim, weight_bdim, attr_list):
    """Vmap rule for Convolution operations, such as `Conv2D` and `Conv3D`.

    The batch axis is folded into an existing axis so that a single un-batched
    convolution call computes all batch entries:

    * `x` and `weight` batched: batch goes into the channel axis of `x` and the
      output-channel axis (0) of `weight`, and the group count is multiplied by the
      batch size so each batch entry forms its own set of groups.
    * only `x` batched: batch goes into axis 0 of `x`.
    * only `weight` batched: batch goes into axis 0 of `weight`; with groups > 1 the
      weight is first split per group so group boundaries stay aligned with the input
      channel groups.

    Args:
        prim: Cloned convolution primitive; its attributes may be updated here.
        batch_size: Size of the vmapped axis.
        input_bdim: `(x, x_dim)` pair.
        weight_bdim: `(weight, w_dim)` pair.
        attr_list: `[name, group, data_format]` of the original primitive.

    Returns:
        `(output, output_batch_axis)`.
    """
    input_x, x_dim = input_bdim
    weight, w_dim = weight_bdim
    prim_name = attr_list[0]
    groups = attr_list[1]
    data_format = attr_list[2]
    c_axis = _get_channel_index(data_format, prim_name)

    def _get_output_for_x_w_vmap():
        new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
        new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)

        _update_group_attr(prim, groups, batch_size)
        _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
        out = prim(new_input, new_weight)
        # Output channels are laid out as (batch, out_channel); split them back apart.
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    def _get_output_for_x_vmap():
        new_input, _ = _reshape_merge_dims(x_dim, 0, input_x)
        out = prim(new_input, weight)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    def _get_output_for_w_vmap():
        if groups > 1:
            # Reorder weight output channels to (group, batch, out_channel / group) so
            # that every group still reads only its own slice of input channels.
            expand_dim, merge_dim = _get_reshape_src_dim(w_dim, 0)
            new_weight = _reshape_expand_dims(expand_dim, groups, weight, prim_name)
            new_weight, _ = _reshape_merge_dims(merge_dim, 1, new_weight)
            new_weight, new_w_shape = _reshape_merge_dims(0, 0, new_weight)

            _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
            out = prim(input_x, new_weight)

            # Undo the (group, batch, out_channel / group) ordering on the output
            # channel axis, leaving (batch, out_channel).
            out = _reshape_expand_dims(c_axis, groups, out, prim_name)
            out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
            out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
            return out, c_axis

        new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)
        _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
        out = prim(input_x, new_weight)
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    if x_dim is not None and w_dim is not None:
        if prim_name == "Conv3D":
            _raise_value_error("vmap in_axes of 'x' and 'weight' in `{}` cannot be non-None at the same time, "
                               "but got {} and {}.".format(prim_name, x_dim, w_dim))
        output = _get_output_for_x_w_vmap()
    elif x_dim is not None:
        output = _get_output_for_x_vmap()
    else:
        output = _get_output_for_w_vmap()
    return output


def _conv_transpose_vmap_rule(prim, batch_size, dout_bdim, weight_bdim, input_size_bdim, attr_list):
    """
    Vmap rule for transposed convolution operations, such as `Conv2DTranspose`,
    `Conv2DBackpropInput`, `Conv3DTranspose` and `Conv3DBackpropInput`.

    `input_size_bdim` is None for `Conv3DTranspose`. Otherwise `input_size` must be
    unbatched and given as a tuple, because its entries are rescaled to match the
    folded batch axis.

    Returns:
        `(output, output_batch_axis)`.
    """
    prim_name = attr_list[0]
    input_size = None
    if input_size_bdim is not None:
        input_size, input_size_dim = input_size_bdim
        if input_size_dim is not None:
            _raise_value_error("Vmap in_axes of 'input_size' in `{}` must be None, "
                               "but got {}.".format(prim_name, input_size_dim))
        if not isinstance(input_size, tuple):
            _raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
                               "'input_size' is a tensor.".format(prim_name))

    dout, dout_dim = dout_bdim
    weight, w_dim = weight_bdim

    groups = attr_list[1]
    data_format = attr_list[2]
    c_axis = _get_channel_index(data_format, prim_name)

    def _get_conv_transpose_output(dout, weight, input_size):
        # Dispatch on the primitive's own input order.
        out = None
        if prim_name in ('Conv2DTranspose', 'Conv2DBackpropInput'):
            out = prim(dout, weight, input_size)
        elif prim_name == "Conv3DTranspose":
            out = prim(dout, weight)
        elif prim_name == "Conv3DBackpropInput":
            out = prim(weight, dout, input_size)
        else:
            _raise_value_error("Unsupported the operation: `{}`.".format(prim_name))
        return out

    def _get_output_for_dout_weight_vmap():
        # Fold batch into dout channels and weight axis 0, with one set of groups per batch entry.
        _update_group_attr(prim, groups, batch_size)
        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
        new_weight, _ = _reshape_merge_dims(w_dim, 0, weight)
        new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)

        out = _get_conv_transpose_output(new_dout, new_weight, new_input_size)
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    def _get_output_for_dout_vmap():
        # Fold batch into the N axis of dout (and of the requested output size).
        new_dout, _ = _reshape_merge_dims(dout_dim, 0, dout)
        new_input_size = _get_new_size_by_index(input_size, batch_size, 0)

        out = _get_conv_transpose_output(new_dout, weight, new_input_size)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    def _get_output_for_weight_vmap():
        # Fold batch into the weight's c_axis, which scales the output channel count.
        new_weight, _ = _reshape_merge_dims(w_dim, c_axis, weight)
        new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)

        out = _get_conv_transpose_output(dout, new_weight, new_input_size)

        if groups > 1:
            # Output channels come out as (group, batch, channel / group);
            # reorder them to (batch, channel).
            out = _reshape_expand_dims(c_axis, groups, out, prim_name)
            out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
            out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
        else:
            out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    if dout_dim is not None and w_dim is not None:
        if prim_name in ("Conv3DTranspose", "Conv3DBackpropInput"):
            _raise_value_error("vmap in_axes of 'dout' and 'weight' in `{}` cannot be non-None at the same time, "
                               "but got {} and {}.".format(prim_name, dout_dim, w_dim))
        output = _get_output_for_dout_weight_vmap()
    elif dout_dim is not None:
        output = _get_output_for_dout_vmap()
    else:
        output = _get_output_for_weight_vmap()
    return output


def _conv_backprop_filter_vmap_rule(prim, batch_size, dout_bdim, input_bdim, weight_size_bdim, attr_list):
    """Vmap rule for `Conv2DBackpropFilter` and `Conv3DBackpropFilter` operations.

    `weight_size` must be unbatched and given as a tuple, because its entries are
    rescaled to match the folded batch axis.

    Returns:
        `(output, output_batch_axis)`.
    """
    dout, dout_dim = dout_bdim
    input_x, x_dim = input_bdim
    weight_size, w_size_dim = weight_size_bdim

    prim_name = attr_list[0]
    groups = attr_list[1]
    data_format = attr_list[2]
    c_axis = _get_channel_index(data_format, prim_name)

    if w_size_dim is not None:
        _raise_value_error("Vmap in_axes of 'weight_size' in `{}` must be None, "
                           "but got {}.".format(prim_name, w_size_dim))
    if not isinstance(weight_size, tuple):
        _raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
                           "'weight_size' is a tensor.".format(prim_name))

    def _get_conv_backprop_filter_output(dout, x, weight_size):
        # Dispatch on the primitive's own input order.
        out = None
        if prim_name == "Conv2DBackpropFilter":
            out = prim(dout, x, weight_size)
        elif prim_name == "Conv3DBackpropFilter":
            out = prim(x, dout, weight_size)
        else:
            _raise_value_error("Unsupported the operation: `{}`.".format(prim_name))
        return out

    def _get_output_for_dout_x_vmap():
        # Fold batch into the channel axis of both inputs, one set of groups per batch entry.
        _update_group_attr(prim, groups, batch_size)

        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
        new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
        new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)

        out = _get_conv_backprop_filter_output(new_dout, new_input, new_w_size)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    def _get_output_for_x_vmap():
        new_w_size = _get_new_size_by_index(weight_size, batch_size, c_axis)
        if groups > 1:
            # Reorder x channels to (group, batch, channel / group) so each group
            # keeps its own contiguous slice of channels.
            expand_dim, merge_dim = _get_reshape_src_dim(x_dim, c_axis)
            new_input = _reshape_expand_dims(expand_dim, groups, input_x, prim_name)
            new_input, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_input)
            new_input, _ = _reshape_merge_dims(c_axis, c_axis, new_input)
        else:
            new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)

        out = _get_conv_backprop_filter_output(dout, new_input, new_w_size)
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    def _get_output_for_dout_vmap():
        new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)
        if groups > 1:
            # Reorder dout channels to (group, batch, channel / group); the filter's
            # axis 0 then comes out in the same order and is reordered to (batch, channel).
            expand_dim, merge_dim = _get_reshape_src_dim(dout_dim, c_axis)
            new_dout = _reshape_expand_dims(expand_dim, groups, dout, prim_name)
            new_dout, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_dout)
            new_dout, _ = _reshape_merge_dims(c_axis, c_axis, new_dout)

            out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
            out = _reshape_expand_dims(0, groups, out, prim_name)
            out = _reshape_expand_dims(1, batch_size, out, prim_name)
            out, _ = _reshape_merge_dims(0, 1, out)
            return out, 0

        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
        out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    if dout_dim is not None and x_dim is not None:
        if prim_name == "Conv3DBackpropFilter":
            _raise_value_error("vmap in_axes of 'dout' and 'x' in `{}` cannot be non-None at the same time, "
                               "but got {} and {}.".format(prim_name, dout_dim, x_dim))
        output = _get_output_for_dout_x_vmap()
    elif x_dim is not None:
        output = _get_output_for_x_vmap()
    else:
        output = _get_output_for_dout_vmap()
    return output