• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""bprop primitives"""
17from mindspore.ops import _constants
18from ..operations import _grad_ops as G
19from .. import functional as F
20from .. import operations as P
21from ..composite import multitype_ops as C
22from .grad_base import bprops
23
get_dtype = P.DType()
# Unused parameters are placeholders.


@bprops.register("MaximumGrad")
@bprops.register("MinimumGrad")
def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
    """Second-order backpropagator shared by `MaximumGrad` and `MinimumGrad`.

    `out` and `dout` are both indexed as pairs. Every non-zero entry of
    `out[k]` becomes a 0/1 mask (cast to the dtype of `dout[k]`) that selects
    where `dout[k]` flows back into the third input `z`. The first two inputs
    `x` and `y` receive zero gradients.

    NOTE(review): the mask is recovered from `out[k] != 0`. Positions where
    `z` is zero therefore contribute nothing, even if they were selected.
    Presumably the exact mask should come from comparing `x` and `y`, but the
    tie-breaking rule of the forward kernel is not visible here — confirm.
    """
    sens_x, sens_y = dout[0], dout[1]
    mask_x = F.cast(out[0] != 0, get_dtype(sens_x))
    mask_y = F.cast(out[1] != 0, get_dtype(sens_y))
    grad_z = mask_x * sens_x + mask_y * sens_y
    return F.zeros_like(x), F.zeros_like(y), grad_z
36
37
@bprops.register("ReluGrad")
def bprop_relu_grad_grad(x, y, out, dout):
    """Backpropagator for primitive `ReluGrad`.

    The gradient for the first input is obtained by applying `ReluGrad` again
    to (`dout`, `y`). The second input `y` receives a zero gradient.
    """
    relu_grad_op = G.ReluGrad()
    grad_x = relu_grad_op(dout, y)
    return grad_x, F.zeros_like(y)
44
45
@bprops.register(_constants.kScalarAdd)
def bprop_scalar_add(x, y, out, dout):
    """Backpropagator for primitive `scalar_add`.

    Both partial derivatives of ``x + y`` are 1, so the sensitivity is passed
    through unchanged to each operand.
    """
    return (dout, dout)
50
51
@bprops.register(_constants.kScalarMul)
def bprop_scalar_mul(x, y, out, dout):
    """Backpropagator for primitive `scalar_mul`.

    d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the sensitivity.
    """
    grad_x = dout * y
    grad_y = dout * x
    return grad_x, grad_y
56
57
@bprops.register(_constants.kScalarSub)
def bprop_scalar_sub(x, y, out, dout):
    """Backpropagator for primitive `scalar_sub`.

    The minuend receives the sensitivity as-is; the subtrahend receives its
    negation.
    """
    grad_y = -dout
    return dout, grad_y
62
63
@bprops.register(_constants.kScalarDiv)
def bprop_scalar_div(x, y, out, dout):
    """Backpropagator for primitive `scalar_div`.

    With out = x / y: d(out)/dx = 1 / y, and d(out)/dy = -x / y**2, which is
    expressed as -out / y to reuse the forward result.
    """
    grad_x = dout / y
    out_over_y = out / y
    grad_y = (-dout) * out_over_y
    return grad_x, grad_y
68
69
@bprops.register(_constants.kScalarPow)
def bprop_scalar_pow(x, y, out, dout):
    """Backpropagator for primitive `scalar_pow`.

    With out = x ** y: d(out)/dx = y * x ** (y - 1), and
    d(out)/dy = log(x) * out.
    """
    lowered_exponent = y - 1
    grad_x = dout * (y * (x ** lowered_exponent))
    grad_y = dout * (F.scalar_log(x) * out)
    return grad_x, grad_y
74
75
@bprops.register("scalar_exp")
def bprop_scalar_exp(x, out, dout):
    """Backpropagator for primitive `scalar_exp`.

    exp is its own derivative, so the forward result scales the sensitivity.
    """
    grad_x = dout * out
    return (grad_x,)
80
81
@bprops.register(_constants.kScalarUadd)
def bprop_scalar_uadd(x, out, dout):
    """Backpropagator for primitive `scalar_uadd` (unary plus): identity."""
    grads = (dout,)
    return grads
86
87
@bprops.register(_constants.kScalarUsub)
def bprop_scalar_usub(x, out, dout):
    """Backpropagator for primitive `scalar_usub` (unary minus).

    The derivative of -x is -1, so the sensitivity is negated.
    """
    negated = -dout
    return (negated,)
92
93
@bprops.register("scalar_gt")
def bprop_scalar_gt(x, y, out, dout):
    """Backpropagator for primitive `scalar_gt`.

    A comparison has no useful derivative; both operands get zeros.
    """
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
98
99
@bprops.register("scalar_lt")
def bprop_scalar_lt(x, y, out, dout):
    """Backpropagator for primitive `scalar_lt`.

    A comparison has no useful derivative; both operands get zeros.
    """
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
104
105
@bprops.register("scalar_ge")
def bprop_scalar_ge(x, y, out, dout):
    """Backpropagator for primitive `scalar_ge`.

    A comparison has no useful derivative; both operands get zeros.
    """
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
110
111
@bprops.register("scalar_le")
def bprop_scalar_le(x, y, out, dout):
    """Backpropagator for primitive `scalar_le`.

    A comparison has no useful derivative; both operands get zeros.
    """
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
116
117
@bprops.register("scalar_eq")
def bprop_scalar_eq(x, y, out, dout):
    """Backpropagator for primitive `scalar_eq`.

    A comparison has no useful derivative; both operands get zeros.
    """
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
122
123
@bprops.register("scalar_ne")
def bprop_scalar_ne(x, y, out, dout):
    """Backpropagator for primitive `scalar_ne`.

    A comparison has no useful derivative; both operands get zeros.
    """
    return C.zeros_like(x), C.zeros_like(y)
128
129
@bprops.register("scalar_cast")
def bprop_scalar_cast(x, t, out, dout):
    """Backpropagator for primitive `scalar_cast`.

    The sensitivity is cast back to the type of the original input `x`.

    NOTE(review): the target-type argument `t` is returned unchanged rather
    than as a zero gradient, presumably because it is a type object rather
    than a value — confirm.
    """
    source_type = F.typeof(x)
    grad_x = F.scalar_cast(dout, source_type)
    return grad_x, t
134
135
@bprops.register(_constants.kTupleGetItem)
def bprop_tuple_getitem(data, idx, out, dout):
    """Backpropagator for primitive `tuple_getitem`.

    Builds a zero-filled tuple shaped like `data` and places `dout` at
    position `idx`. The index itself gets a zero gradient.
    """
    zero_data = C.zeros_like(data)
    grad_data = F.tuple_setitem(zero_data, idx, dout)
    return grad_data, C.zeros_like(idx)
140
141
@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
    """Backpropagator for primitive `list_getitem`.

    Builds a zero-filled list shaped like `data` and places `dout` at
    position `idx`. The index itself gets a zero gradient.
    """
    zero_data = C.zeros_like(data)
    grad_data = F.list_setitem(zero_data, idx, dout)
    return grad_data, C.zeros_like(idx)
146
147
@bprops.register("identity")
def bprop_identity(x, out, dout):
    """Backpropagator for primitive `identity`: pass the sensitivity through."""
    grads = (dout,)
    return grads
152
153
@bprops.register("make_ref")
def bprop_make_ref(key, x, y, out, dout):
    """Backpropagator for primitive `make_ref`.

    Only the value input `x` receives the sensitivity. `key` and `y` get
    zero gradients.
    """
    key_grad = C.zeros_like(key)
    y_grad = C.zeros_like(y)
    return (key_grad, dout, y_grad)
158
159
@bprops.register("get_ref_value")
def bprop_get_ref_value(x, out, dout):
    """Backpropagator for primitive `get_ref_value`: pass the sensitivity through."""
    grads = (dout,)
    return grads
164
165
@bprops.register("get_ref_key")
def bprop_get_ref_key(x, out, dout):
    """Backpropagator for primitive `get_ref_key`: the input gets zeros."""
    zero_x = C.zeros_like(x)
    return (zero_x,)
170
171
@bprops.register("scalar_to_array")
def bprop_scalar_to_array(x, out, dout):
    """Backpropagator for primitive `scalar_to_array`.

    The array-valued sensitivity is converted back to a scalar.
    """
    grad_x = F.array_to_scalar(dout)
    return (grad_x,)
176
177
@bprops.register("array_to_scalar")
def bprop_array_to_scalar(x, out, dout):
    """Backpropagator for primitive `array_to_scalar`.

    The scalar sensitivity is converted back to an array.
    """
    grad_x = F.scalar_to_array(dout)
    return (grad_x,)
182
183
@bprops.register("reshape")
def bprop_reshape(xs, shp, out, dout):
    """Backpropagator for primitive `reshape`.

    The sensitivity is reshaped back to the shape of the input `xs`. The
    target shape `shp` gets a zero gradient.
    """
    input_shape = F.shape(xs)
    grad_xs = F.reshape(dout, input_shape)
    return grad_xs, C.zeros_like(shp)
188
189
@bprops.register("distribute")
def bprop_distribute(arr, shp, out, dout):
    """Backpropagator for primitive `distribute`.

    The sensitivity is summed (via `array_reduce` with `scalar_add`) back
    down to the shape of `arr`. The shape argument gets a zero gradient.
    """
    arr_shape = F.shape(arr)
    grad_arr = F.array_reduce(F.scalar_add, dout, arr_shape)
    return grad_arr, C.zeros_like(shp)
194
195
@bprops.register("shape")
def bprop_shape(arr, out, dout):
    """Backpropagator for primitive `shape`: the input gets zeros."""
    zero_arr = C.zeros_like(arr)
    return (zero_arr,)
200
201
@bprops.register("broadcast_shape")
def bprop_broadcast_shape(shp1, shp2, out, dout):
    """Backpropagator for primitive `broadcast_shape`: both shapes get zeros."""
    zero_shp1 = C.zeros_like(shp1)
    zero_shp2 = C.zeros_like(shp2)
    return zero_shp1, zero_shp2
206
207
@bprops.register("array_reduce")
def bprop_array_reduce(fn, x, shp, out, dout):
    """Backpropagator for primitive `array_reduce`.

    `array_reduce` takes three inputs (`fn`, `x`, `shp`), so one gradient is
    returned per input, in input order:

    - `fn`: zero gradient (the reduction function is not differentiated).
    - `x`: the sensitivity distributed back to the shape of `x`.
    - `shp`: zero gradient.

    The previous version returned only two values. That misaligned the
    gradients with the inputs: `x`'s gradient landed in `fn`'s slot.
    """
    grad_fn = C.zeros_like(fn)
    grad_x = F.distribute(dout, F.shape(x))
    return grad_fn, grad_x, C.zeros_like(shp)
212
213
@bprops.register("Depend")
def bprop_depend(x, y, out, dout):
    """Backpropagator for primitive `depend`.

    The sensitivity flows to the value input `x`. The dependency input `y`
    gets a zero gradient.
    """
    zero_y = C.zeros_like(y)
    return dout, zero_y
218
219
@bprops.register("embed")
def bprop_embed(x, out, dout):
    """Backpropagator for primitive `embed`: the input gets zeros."""
    zero_x = C.zeros_like(x)
    return (zero_x,)
224
225
@bprops.register("bool_not")
def bprop_bool_not(x, out, dout):
    """Backpropagator for primitive `bool_not`: the boolean input gets zeros."""
    zero_x = C.zeros_like(x)
    return (zero_x,)
230
231
@bprops.register("bool_or")
def bprop_bool_or(x, y, out, dout):
    """Backpropagator for primitive `bool_or`: both boolean inputs get zeros."""
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
236
237
@bprops.register("stop_gradient")
def bprop_stop_gradient(x, out, dout):
    """Backpropagator for primitive `stop_gradient`.

    The incoming sensitivity is discarded and the input receives zeros.
    """
    blocked = C.zeros_like(x)
    return (blocked,)
242
243
@bprops.register("bool_and")
def bprop_bool_and(x, y, out, dout):
    """Backpropagator for primitive `bool_and`: both boolean inputs get zeros."""
    zero_x = C.zeros_like(x)
    zero_y = C.zeros_like(y)
    return zero_x, zero_y
248
249
@bprops.register("Switch")
def bprop_switch(cond, tb, fb, out, dout):
    """Backpropagator for primitive `switch`.

    The sensitivity is routed to whichever branch `cond` selected; the other
    branch receives zeros. The condition itself gets a zero gradient.
    """
    grad_cond = C.zeros_like(cond)
    grad_true = F.switch(cond, dout, C.zeros_like(tb))
    grad_false = F.switch(cond, C.zeros_like(fb), dout)
    return grad_cond, grad_true, grad_false
255
256
def _fprop_switch_layer(index, layers):
    """Forward propagator for primitive `switch_layer`.

    Returns a pair: the result of ``F.switch_layer(index, layers)`` and a
    backpropagator closure. Given the output sensitivity ``dout``, the
    closure returns ``dout`` itself, a zero gradient for ``index``, and an
    empty tuple for ``layers``.

    NOTE(review): the closure returns three values for two inputs.
    Presumably the leading ``dout`` fills the slot the framework reserves
    ahead of the input gradients — confirm.
    """
    def _bprop_switch_layer(dout):
        return dout, C.zeros_like(index), ()
    return F.switch_layer(index, layers), _bprop_switch_layer
262
263
@bprops.register("UpdateState")
def bprop_update_state(u_monad, x, out, dout):
    """Backpropagator for primitive `UpdateState`: both inputs get zeros."""
    zero_monad = C.zeros_like(u_monad)
    zero_x = C.zeros_like(x)
    return zero_monad, zero_x
268
269
@bprops.register("Load")
def bprop_load(param, u_monad, out, dout):
    """Backpropagator for primitive `load`.

    The sensitivity flows to `param`; the monad input gets a zero gradient.
    """
    zero_monad = C.zeros_like(u_monad)
    return dout, zero_monad
274