• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""convolution vmap impl"""
17from __future__ import absolute_import
18
19import mindspore.numpy as mnp
20from mindspore.ops import constexpr
21from mindspore.ops.primitive import _primexpr
22from mindspore.ops import operations as P
23from mindspore.ops import functional as F
24from mindspore.ops.operations import nn_ops as nps
25from mindspore.ops.operations import _grad_ops as G
26from mindspore.ops.primitive import Primitive
27from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, \
28    _raise_value_error, _vmap_update_prim_attr, _vmap_clone_prim
29
30
@vmap_rules_getters.register(P.Conv2D)
@vmap_rules_getters.register(P.Conv3D)
def get_conv_vmap_rule(prim, axis_size):
    """
    Build the vmap rule registered for the `Conv2D` and `Conv3D` primitives.

    Args:
        prim: The convolution primitive, or a string that is wrapped into a `Primitive`.
        axis_size: Size of the vmap axis, forwarded to `_conv_vmap_rule` as the batch size.

    Returns:
        Function `vmap_rule(input_bdim, weight_bdim)` implementing the rule.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Attributes are read from the original primitive; the clone is the one handed to
    # `_conv_vmap_rule`, which updates primitive attributes in some branches.
    conv_attrs = [prim.name, prim.group, prim.data_format]
    cloned_prim = _vmap_clone_prim(prim)

    def vmap_rule(input_bdim, weight_bdim):
        # `vmap_general_preprocess` flags the all-None case and supplies the result to return for it.
        all_none, passthrough = vmap_general_preprocess(prim, input_bdim, weight_bdim)
        if all_none:
            return passthrough
        return _conv_vmap_rule(cloned_prim, axis_size, input_bdim, weight_bdim, conv_attrs)

    return vmap_rule
48
49
@vmap_rules_getters.register(P.Conv2DTranspose)
@vmap_rules_getters.register(P.Conv2DBackpropInput)
def get_conv2d_transpose_vmap_rule(prim, axis_size):
    """
    Build the vmap rule registered for `Conv2DTranspose` and `Conv2DBackpropInput`.

    Args:
        prim: The primitive, or a string that is wrapped into a `Primitive`.
        axis_size: Size of the vmap axis, forwarded to `_conv_transpose_vmap_rule` as the batch size.

    Returns:
        Function `vmap_rule(dout_bdim, weight_bdim, input_size_bdim)` implementing the rule.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Attributes come from the original primitive; the clone is what the shared rule invokes.
    conv_attrs = [prim.name, prim.group, prim.data_format]
    cloned_prim = _vmap_clone_prim(prim)

    def vmap_rule(dout_bdim, weight_bdim, input_size_bdim):
        # `vmap_general_preprocess` flags the all-None case and supplies the result to return for it.
        all_none, passthrough = vmap_general_preprocess(prim, dout_bdim, weight_bdim, input_size_bdim)
        if all_none:
            return passthrough
        return _conv_transpose_vmap_rule(
            cloned_prim, axis_size, dout_bdim, weight_bdim, input_size_bdim, conv_attrs)

    return vmap_rule
68
69
@vmap_rules_getters.register(P.Conv3DTranspose)
def get_conv3d_transpose_vmap_rule(prim, axis_size):
    """
    Build the vmap rule registered for the `Conv3DTranspose` primitive.

    Args:
        prim: The primitive, or a string that is wrapped into a `Primitive`.
        axis_size: Size of the vmap axis, forwarded to `_conv_transpose_vmap_rule` as the batch size.

    Returns:
        Function `vmap_rule(dout_bdim, weight_bdim)` implementing the rule.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Attributes come from the original primitive; the clone is what the shared rule invokes.
    conv_attrs = [prim.name, prim.group, prim.data_format]
    cloned_prim = _vmap_clone_prim(prim)

    def vmap_rule(dout_bdim, weight_bdim):
        # `vmap_general_preprocess` flags the all-None case and supplies the result to return for it.
        all_none, passthrough = vmap_general_preprocess(prim, dout_bdim, weight_bdim)
        if all_none:
            return passthrough
        # This primitive takes no `input_size` operand, so None is passed in its place.
        return _conv_transpose_vmap_rule(cloned_prim, axis_size, dout_bdim, weight_bdim, None, conv_attrs)

    return vmap_rule
86
87
@vmap_rules_getters.register(nps.Conv3DBackpropInput)
def get_conv3d_backprop_input_vmap_rule(prim, axis_size):
    """
    Build the vmap rule registered for the `Conv3DBackpropInput` primitive.

    Args:
        prim: The primitive, or a string that is wrapped into a `Primitive`.
        axis_size: Size of the vmap axis, forwarded to `_conv_transpose_vmap_rule` as the batch size.

    Returns:
        Function `vmap_rule(weight_bdim, dout_bdim, input_size_bdim)` implementing the rule.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Attributes come from the original primitive; the clone is what the shared rule invokes.
    conv_attrs = [prim.name, prim.group, prim.data_format]
    cloned_prim = _vmap_clone_prim(prim)

    def vmap_rule(weight_bdim, dout_bdim, input_size_bdim):
        # `vmap_general_preprocess` flags the all-None case and supplies the result to return for it.
        all_none, passthrough = vmap_general_preprocess(prim, weight_bdim, dout_bdim, input_size_bdim)
        if all_none:
            return passthrough
        # The rule receives weight first, while the shared helper expects dout first.
        return _conv_transpose_vmap_rule(
            cloned_prim, axis_size, dout_bdim, weight_bdim, input_size_bdim, conv_attrs)

    return vmap_rule
105
106
@vmap_rules_getters.register(G.Conv2DBackpropFilter)
def get_conv2d_backprop_filter_vmap_rule(prim, axis_size):
    """
    Build the vmap rule registered for the `Conv2DBackpropFilter` primitive.

    Args:
        prim: The primitive, or a string that is wrapped into a `Primitive`.
        axis_size: Size of the vmap axis, forwarded to `_conv_backprop_filter_vmap_rule` as the batch size.

    Returns:
        Function `vmap_rule(dout_bdim, input_x_bdim, weight_size_bdim)` implementing the rule.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Attributes come from the original primitive; the clone is what the shared rule invokes.
    conv_attrs = [prim.name, prim.group, prim.data_format]
    cloned_prim = _vmap_clone_prim(prim)

    def vmap_rule(dout_bdim, input_x_bdim, weight_size_bdim):
        # `vmap_general_preprocess` flags the all-None case and supplies the result to return for it.
        all_none, passthrough = vmap_general_preprocess(prim, dout_bdim, input_x_bdim, weight_size_bdim)
        if all_none:
            return passthrough
        return _conv_backprop_filter_vmap_rule(
            cloned_prim, axis_size, dout_bdim, input_x_bdim, weight_size_bdim, conv_attrs)

    return vmap_rule
124
125
@vmap_rules_getters.register(G.Conv3DBackpropFilter)
def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
    """
    Build the vmap rule registered for the `Conv3DBackpropFilter` primitive.

    Args:
        prim: The primitive, or a string that is wrapped into a `Primitive`.
        axis_size: Size of the vmap axis, forwarded to `_conv_backprop_filter_vmap_rule` as the batch size.

    Returns:
        Function `vmap_rule(input_x_bdim, dout_bdim, weight_size_bdim)` implementing the rule.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    # Attributes come from the original primitive; the clone is what the shared rule invokes.
    conv_attrs = [prim.name, prim.group, prim.data_format]
    cloned_prim = _vmap_clone_prim(prim)

    def vmap_rule(input_x_bdim, dout_bdim, weight_size_bdim):
        # `vmap_general_preprocess` flags the all-None case and supplies the result to return for it.
        all_none, passthrough = vmap_general_preprocess(prim, input_x_bdim, dout_bdim, weight_size_bdim)
        if all_none:
            return passthrough
        # The rule receives x first, while the shared helper expects dout first.
        return _conv_backprop_filter_vmap_rule(
            cloned_prim, axis_size, dout_bdim, input_x_bdim, weight_size_bdim, conv_attrs)

    return vmap_rule
143
144
@_primexpr
def _get_reshape_src_dim(data_dim, cmp_dim):
    """
    Return `(expand_dim, merge_dim)`: the positions of `cmp_dim` and `data_dim` once one extra
    axis has been inserted at the smaller-or-equal... position of the two.

    When `data_dim` is greater than `cmp_dim`, `cmp_dim` is returned unchanged and `data_dim`
    is shifted by one; otherwise `cmp_dim` is shifted by one and `data_dim` is returned unchanged.
    """
    if data_dim > cmp_dim:
        return cmp_dim, data_dim + 1
    return cmp_dim + 1, data_dim
155
156
@_primexpr
def _get_merge_shape(src_dim, dst_dim, shape):
    """
    Compute the shape obtained by folding axis `src_dim` of `shape` into axis `dst_dim`.

    `dst_dim` indexes the shape after `src_dim` has been removed; that entry is multiplied
    by the size of the removed axis. The result is returned as a tuple.
    """
    merged = [size for axis, size in enumerate(shape) if axis != src_dim]
    merged[dst_dim] = merged[dst_dim] * shape[src_dim]
    return tuple(merged)
163
164
def _reshape_merge_dims(src_dim, dst_dim, target):
    """
    Fold axis `src_dim` of `target` into axis `dst_dim`.

    The axis is first moved to position `dst_dim`, then the tensor is reshaped to the merged
    shape. Returns the pair `(reshaped_tensor, merged_shape)`.
    """
    merged_shape = _get_merge_shape(src_dim, dst_dim, F.shape(target))
    moved = mnp.moveaxis(target, src_dim, dst_dim)
    return F.reshape(moved, merged_shape), merged_shape
172
173
@_primexpr
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
    """
    Compute the shape obtained by splitting axis `src_dim` of `shape` into two axes
    of sizes `dst_size` and `shape[src_dim] // dst_size`.

    `prim_name` is accepted but not used by this function.
    """
    dims = list(shape)
    inner_size = dims[src_dim] // dst_size
    # Slice assignment swaps the single axis for the (outer, inner) pair in place.
    dims[src_dim:(src_dim + 1)] = [dst_size, inner_size]
    return tuple(dims)
181
182
def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
    """
    Reshape `target` so that axis `src_dim` is split into two axes, the first of size `dst_size`.

    The new shape is computed by `_get_expand_shape`; the reshaped tensor is returned.
    """
    expanded_shape = _get_expand_shape(src_dim, dst_size, F.shape(target), prim_name)
    return F.reshape(target, expanded_shape)
188
189
@_primexpr
def _get_new_size_by_index(input_size, batch_size, index):
    """
    Return a copy of `input_size` (as a tuple) whose entry at `index` is multiplied by `batch_size`.

    An empty tuple is returned when `input_size` is None.
    """
    if input_size is None:
        return ()
    scaled = list(input_size)
    scaled[index] = scaled[index] * batch_size
    return tuple(scaled)
199
200
@_primexpr
def _update_group_attr(prim, groups, batch_size):
    """
    Write `groups * batch_size` into the 'group' and 'groups' attributes of `prim`
    through `_vmap_update_prim_attr`.
    """
    scaled_group = groups * batch_size
    for attr_name in ('group', 'groups'):
        _vmap_update_prim_attr(prim, attr_name, scaled_group)
207
208
@constexpr
def _get_channel_index(data_format, prim_name):
    """
    Map `data_format` to the index of its channel axis: 3 for NHWC, 1 for NCHW and NCDHW.

    Any other format is reported through `_raise_value_error`, naming `prim_name` in the message.
    """
    if data_format == "NHWC":
        return 3
    if data_format in ("NCHW", "NCDHW"):
        return 1
    _raise_value_error("'data_format' in {} should be NHWC/NCHW/NCDHW, "
                       "but got {}.".format(prim_name, data_format))
    # Fallback value, only relevant should the error helper return.
    return 0
221
222
def _conv_vmap_rule(prim, batch_size, input_bdim, weight_bdim, attr_list):
    """
    Vmap rule for Convolution operations, such as `Conv2D` and `Conv3D`.

    The vmap axis is folded into an existing axis of the operand(s), the primitive is
    invoked once on the folded operands, and the vmap axis is split out of the result again.

    Args:
        prim: Convolution primitive invoked on the reshaped operands. Depending on the branch,
            its 'group', 'groups' and 'out_channel' attributes are updated before the call.
        batch_size: Size of the vmap axis.
        input_bdim: Pair `(input_x, x_dim)`; `x_dim` is the vmap axis of `input_x`, or None.
        weight_bdim: Pair `(weight, w_dim)`; `w_dim` is the vmap axis of `weight`, or None.
        attr_list: `[prim_name, groups, data_format]` of the primitive.

    Returns:
        Pair `(out, out_dim)`, where `out_dim` is the axis of `out` holding the vmap axis.

    Note:
        An error is reported through `_raise_value_error` when both `x_dim` and `w_dim` are
        non-None for `Conv3D`, and (inside `_get_channel_index`) for an unsupported `data_format`.
    """
    input_x, x_dim = input_bdim
    weight, w_dim = weight_bdim
    prim_name = attr_list[0]
    groups = attr_list[1]
    data_format = attr_list[2]
    c_axis = _get_channel_index(data_format, prim_name)

    def _get_output_for_x_w_vmap():
        # Both operands are batched: fold the vmap axis of x into its channel axis and the
        # vmap axis of weight into axis 0, and scale the group attribute by batch_size.
        new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
        new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)

        _update_group_attr(prim, groups, batch_size)
        _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
        out = prim(new_input, new_weight)
        # Split the vmap axis back out of the output channel axis.
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    def _get_output_for_x_vmap():
        # Only x is batched: fold the vmap axis into axis 0 of x and split it out of axis 0 afterwards.
        new_input, _ = _reshape_merge_dims(x_dim, 0, input_x)
        out = prim(new_input, weight)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    def _get_output_for_w_vmap():
        # Only weight is batched.
        if groups > 1:
            # Split axis 0 of weight into (groups, rest), fold the vmap axis into `rest`,
            # then fold the group axis back, so axis 0 is laid out as (groups, batch, rest).
            expand_dim, merge_dim = _get_reshape_src_dim(w_dim, 0)
            new_weight = _reshape_expand_dims(expand_dim, groups, weight, prim_name)
            new_weight, _ = _reshape_merge_dims(merge_dim, 1, new_weight)
            new_weight, new_w_shape = _reshape_merge_dims(0, 0, new_weight)

            _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
            out = prim(input_x, new_weight)

            # Undo the (groups, batch, rest) layout on the output channel axis: split out the
            # group and batch axes, then fold the group axis behind the batch axis so that the
            # vmap axis ends up at c_axis.
            out = _reshape_expand_dims(c_axis, groups, out, prim_name)
            out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
            out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
            return out, c_axis

        # Single group: fold the vmap axis into axis 0 of weight and split it out of the
        # output channel axis.
        new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)
        _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
        out = prim(input_x, new_weight)
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    if x_dim is not None and w_dim is not None:
        if prim_name == "Conv3D":
            _raise_value_error("vmap in_axes of 'x' and 'weight' in `{}` cannot be non-None at the same time, "
                               "but got {} and {}.".format(prim_name, x_dim, w_dim))
        output = _get_output_for_x_w_vmap()
    elif x_dim is not None:
        output = _get_output_for_x_vmap()
    else:
        output = _get_output_for_w_vmap()
    return output
279
280
def _conv_transpose_vmap_rule(prim, batch_size, dout_bdim, weight_bdim, input_size_bdim, attr_list):
    """
    Vmap rule for transposed convolution operations, such as `Conv2DTranspose`,
    `Conv2DBackpropInput`, `Conv3DTranspose` and `Conv3DBackpropInput`.

    Args:
        prim: Primitive invoked on the reshaped operands. Its 'group' and 'groups' attributes
            are updated when both `dout` and `weight` are batched.
        batch_size: Size of the vmap axis.
        dout_bdim: Pair `(dout, dout_dim)`; `dout_dim` is the vmap axis of `dout`, or None.
        weight_bdim: Pair `(weight, w_dim)`; `w_dim` is the vmap axis of `weight`, or None.
        input_size_bdim: Pair `(input_size, input_size_dim)`, or None for primitives that take
            no `input_size` operand. `input_size_dim` must be None and `input_size` a tuple.
        attr_list: `[prim_name, groups, data_format]` of the primitive.

    Returns:
        Pair `(out, out_dim)`, where `out_dim` is the axis of `out` holding the vmap axis.

    Note:
        Errors are reported through `_raise_value_error` for a batched or non-tuple `input_size`,
        for an unsupported primitive name, and when both `dout_dim` and `w_dim` are non-None for
        `Conv3DTranspose` / `Conv3DBackpropInput`.
    """
    prim_name = attr_list[0]
    input_size = None
    if input_size_bdim is not None:
        input_size, input_size_dim = input_size_bdim
        if input_size_dim is not None:
            _raise_value_error("Vmap in_axes of 'input_size' in `{}` must be None, "
                               "but got {}.".format(prim_name, input_size_dim))
        if not isinstance(input_size, tuple):
            _raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
                               "'input_size' is a tensor.".format(prim_name))

    dout, dout_dim = dout_bdim
    weight, w_dim = weight_bdim

    groups = attr_list[1]
    data_format = attr_list[2]
    c_axis = _get_channel_index(data_format, prim_name)

    def _get_conv_transpose_output(dout, weight, input_size):
        # Dispatch on the primitive name because the operand order and count differ per primitive.
        out = None
        if prim_name in ('Conv2DTranspose', 'Conv2DBackpropInput'):
            out = prim(dout, weight, input_size)
        elif prim_name == "Conv3DTranspose":
            out = prim(dout, weight)
        elif prim_name == "Conv3DBackpropInput":
            out = prim(weight, dout, input_size)
        else:
            _raise_value_error("Unsupported the operation: `{}`.".format(prim_name))
        return out

    def _get_output_for_dout_weight_vmap():
        # Both operands are batched: scale the group attribute by batch_size, fold the vmap axis
        # of dout into its channel axis and the vmap axis of weight into axis 0, and scale the
        # channel entry of input_size accordingly.
        _update_group_attr(prim, groups, batch_size)
        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
        new_weight, _ = _reshape_merge_dims(w_dim, 0, weight)
        new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)

        out = _get_conv_transpose_output(new_dout, new_weight, new_input_size)
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    def _get_output_for_dout_vmap():
        # Only dout is batched: fold the vmap axis into axis 0 of dout, scale entry 0 of
        # input_size, and split the vmap axis out of axis 0 of the output.
        new_dout, _ = _reshape_merge_dims(dout_dim, 0, dout)
        new_input_size = _get_new_size_by_index(input_size, batch_size, 0)

        out = _get_conv_transpose_output(new_dout, weight, new_input_size)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    def _get_output_for_weight_vmap():
        # Only weight is batched: fold the vmap axis into axis `c_axis` of weight and scale the
        # channel entry of input_size.
        new_weight, _ = _reshape_merge_dims(w_dim, c_axis, weight)
        new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)

        out = _get_conv_transpose_output(dout, new_weight, new_input_size)

        if groups > 1:
            # Split the output channel axis into (groups, batch, rest), then fold the group axis
            # behind the batch axis so that the vmap axis ends up at c_axis.
            out = _reshape_expand_dims(c_axis, groups, out, prim_name)
            out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
            out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
        else:
            out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    if dout_dim is not None and w_dim is not None:
        if prim_name in ("Conv3DTranspose", "Conv3DBackpropInput"):
            _raise_value_error("vmap in_axes of 'dout' and 'weight' in `{}` cannot be non-None at the same time, "
                               "but got {} and {}.".format(prim_name, dout_dim, w_dim))
        output = _get_output_for_dout_weight_vmap()
    elif dout_dim is not None:
        output = _get_output_for_dout_vmap()
    else:
        output = _get_output_for_weight_vmap()
    return output
358
359
def _conv_backprop_filter_vmap_rule(prim, batch_size, dout_bdim, input_bdim, weight_size_bdim, attr_list):
    """
    Vmap rule for `Conv2DBackpropFilter` and `Conv3DBackpropFilter` operations.

    Args:
        prim: Primitive invoked on the reshaped operands. Its 'group' and 'groups' attributes
            are updated when both `dout` and `x` are batched.
        batch_size: Size of the vmap axis.
        dout_bdim: Pair `(dout, dout_dim)`; `dout_dim` is the vmap axis of `dout`, or None.
        input_bdim: Pair `(input_x, x_dim)`; `x_dim` is the vmap axis of `input_x`, or None.
        weight_size_bdim: Pair `(weight_size, w_size_dim)`; `w_size_dim` must be None and
            `weight_size` a tuple.
        attr_list: `[prim_name, groups, data_format]` of the primitive.

    Returns:
        Pair `(out, out_dim)`, where `out_dim` is the axis of `out` holding the vmap axis.

    Note:
        Errors are reported through `_raise_value_error` for a batched or non-tuple `weight_size`,
        for an unsupported primitive name, and when both `dout_dim` and `x_dim` are non-None for
        `Conv3DBackpropFilter`.
    """
    dout, dout_dim = dout_bdim
    input_x, x_dim = input_bdim
    weight_size, w_size_dim = weight_size_bdim

    prim_name = attr_list[0]
    groups = attr_list[1]
    data_format = attr_list[2]
    c_axis = _get_channel_index(data_format, prim_name)

    if w_size_dim is not None:
        _raise_value_error("Vmap in_axes of 'weight_size' in `{}` must be None, "
                           "but got {}.".format(prim_name, w_size_dim))
    if not isinstance(weight_size, tuple):
        _raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
                           "'weight_size' is a tensor.".format(prim_name))

    def _get_conv_backprop_filter_output(dout, x, weight_size):
        # Dispatch on the primitive name because the two primitives take dout and x in opposite order.
        out = None
        if prim_name == "Conv2DBackpropFilter":
            out = prim(dout, x, weight_size)
        elif prim_name == "Conv3DBackpropFilter":
            out = prim(x, dout, weight_size)
        else:
            _raise_value_error("Unsupported the operation: `{}`.".format(prim_name))
        return out

    def _get_output_for_dout_x_vmap():
        # Both operands are batched: scale the group attribute by batch_size, fold the vmap axis
        # of dout and of x into their channel axes, and scale entry 0 of weight_size.
        _update_group_attr(prim, groups, batch_size)

        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
        new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
        new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)

        out = _get_conv_backprop_filter_output(new_dout, new_input, new_w_size)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    def _get_output_for_x_vmap():
        # Only x is batched: the channel entry of weight_size is scaled by batch_size.
        new_w_size = _get_new_size_by_index(weight_size, batch_size, c_axis)
        if groups > 1:
            # Split the channel axis of x into (groups, rest), fold the vmap axis into `rest`,
            # then fold the group axis back, so the channel axis is laid out as (groups, batch, rest).
            expand_dim, merge_dim = _get_reshape_src_dim(x_dim, c_axis)
            new_input = _reshape_expand_dims(expand_dim, groups, input_x, prim_name)
            new_input, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_input)
            new_input, _ = _reshape_merge_dims(c_axis, c_axis, new_input)
        else:
            new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)

        out = _get_conv_backprop_filter_output(dout, new_input, new_w_size)
        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
        return out, c_axis

    def _get_output_for_dout_vmap():
        # Only dout is batched: entry 0 of weight_size is scaled by batch_size.
        new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)
        if groups > 1:
            # Split the channel axis of dout into (groups, rest), fold the vmap axis into `rest`,
            # then fold the group axis back, so the channel axis is laid out as (groups, batch, rest).
            expand_dim, merge_dim = _get_reshape_src_dim(dout_dim, c_axis)
            new_dout = _reshape_expand_dims(expand_dim, groups, dout, prim_name)
            new_dout, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_dout)
            new_dout, _ = _reshape_merge_dims(c_axis, c_axis, new_dout)

            out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
            # Split axis 0 of the output into (groups, batch, rest), then fold the group axis
            # behind the batch axis so that the vmap axis ends up at axis 0.
            out = _reshape_expand_dims(0, groups, out, prim_name)
            out = _reshape_expand_dims(1, batch_size, out, prim_name)
            out, _ = _reshape_merge_dims(0, 1, out)
            return out, 0

        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
        out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
        out = _reshape_expand_dims(0, batch_size, out, prim_name)
        return out, 0

    if dout_dim is not None and x_dim is not None:
        if prim_name == "Conv3DBackpropFilter":
            _raise_value_error("vmap in_axes of 'dout' and 'x' in `{}` cannot be non-None at the same time, "
                               "but got {} and {}.".format(prim_name, dout_dim, x_dim))
        output = _get_output_for_dout_x_vmap()
    elif x_dim is not None:
        output = _get_output_for_x_vmap()
    else:
        output = _get_output_for_dout_vmap()
    return output
442