1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16"""bprop primitives""" 17from mindspore.ops import _constants 18from ..operations import _grad_ops as G 19from .. import functional as F 20from .. import operations as P 21from ..composite import multitype_ops as C 22from .grad_base import bprops 23 24get_dtype = P.DType() 25# Unused parameters are placeholders. 
@bprops.register("MaximumGrad")
@bprops.register("MinimumGrad")
def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
    """Backpropagator for primitive `MaximumGrad` and `MinimumGrad`.

    Returns zero gradients for `x` and `y`. The gradient for `z` is the sum of
    `dout[0]` and `dout[1]`, each masked to the positions where the matching
    forward output `out[0]` / `out[1]` is non-zero.
    """
    # NOTE(review): the masks come from `out[i] != 0`, so positions where the
    # forward output is exactly zero contribute nothing to `dz` -- confirm
    # this is the intended second-order behavior.
    out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
    out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
    dz = out0 * dout[0] + out1 * dout[1]
    return F.zeros_like(x), F.zeros_like(y), dz


@bprops.register("ReluGrad")
def bprop_relu_grad_grad(x, y, out, dout):
    """Backpropagator for primitive `ReluGrad`.

    The gradient for the first input is `ReluGrad` applied to `dout` with the
    same second operand `y`; `y` itself receives a zero gradient.
    """
    input_grad = G.ReluGrad()
    dy = input_grad(dout, y)
    return dy, F.zeros_like(y)


@bprops.register(_constants.kScalarAdd)
def bprop_scalar_add(x, y, out, dout):
    """Backpropagator for primitive `scalar_add`: both operands get `dout`."""
    return dout, dout


@bprops.register(_constants.kScalarMul)
def bprop_scalar_mul(x, y, out, dout):
    """Backpropagator for primitive `scalar_mul`.

    For out = x * y: d/dx = y and d/dy = x.
    """
    return dout*y, dout*x


@bprops.register(_constants.kScalarSub)
def bprop_scalar_sub(x, y, out, dout):
    """Backpropagator for primitive `scalar_sub`.

    For out = x - y: d/dx = 1 and d/dy = -1.
    """
    return dout, -dout


@bprops.register(_constants.kScalarDiv)
def bprop_scalar_div(x, y, out, dout):
    """Backpropagator for primitive `scalar_div`.

    For out = x / y: d/dx = 1 / y and d/dy = -x / y**2, written here as
    -out / y to reuse the forward result.
    """
    return dout/y, (-dout) * (out/y)


@bprops.register(_constants.kScalarPow)
def bprop_scalar_pow(x, y, out, dout):
    """Backpropagator for primitive `scalar_pow`.

    For out = x ** y: d/dx = y * x ** (y - 1) and d/dy = log(x) * out.
    """
    # NOTE(review): `F.scalar_log(x)` presumably requires a positive base --
    # confirm behavior when x <= 0.
    return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)


@bprops.register("scalar_exp")
def bprop_scalar_exp(x, out, dout):
    """Backpropagator for primitive `scalar_exp`.

    For out = exp(x): d/dx = exp(x), which is the forward result `out`.
    """
    return (dout * out,)


@bprops.register(_constants.kScalarUadd)
def bprop_scalar_uadd(x, out, dout):
    """Backpropagator for primitive `scalar_uadd`: passes `dout` through."""
    return (dout,)


@bprops.register(_constants.kScalarUsub)
def bprop_scalar_usub(x, out, dout):
    """Backpropagator for primitive `scalar_usub`: negates `dout`."""
    return (-dout,)


@bprops.register("scalar_gt")
def bprop_scalar_gt(x, y, out, dout):
    """Backpropagator for primitive `scalar_gt`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_lt")
def bprop_scalar_lt(x, y, out, dout):
    """Backpropagator for primitive `scalar_lt`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_ge")
def bprop_scalar_ge(x, y, out, dout):
    """Backpropagator for primitive `scalar_ge`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_le")
def bprop_scalar_le(x, y, out, dout):
    """Backpropagator for primitive `scalar_le`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_eq")
def bprop_scalar_eq(x, y, out, dout):
    """Backpropagator for primitive `scalar_eq`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_ne")
def bprop_scalar_ne(x, y, out, dout):
    """Backpropagator for primitive `scalar_ne`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_cast")
def bprop_scalar_cast(x, t, out, dout):
    """Backpropagator for primitive `scalar_cast`.

    Casts `dout` back to the type of the input `x`.
    """
    # NOTE(review): the second return value is `t` itself rather than a zero
    # like the other non-differentiable inputs in this file -- confirm the
    # caller expects the type object to be echoed back here.
    return F.scalar_cast(dout, F.typeof(x)), t


@bprops.register(_constants.kTupleGetItem)
def bprop_tuple_getitem(data, idx, out, dout):
    """Backpropagator for primitive `tuple_getitem`.

    Places `dout` at position `idx` of a zero-filled copy of `data`; the index
    gets a zero gradient.
    """
    return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
    """Backpropagator for primitive `list_getitem`.

    Places `dout` at position `idx` of a zero-filled copy of `data`; the index
    gets a zero gradient.
    """
    return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("identity")
def bprop_identity(x, out, dout):
    """Backpropagator for primitive `identity`: passes `dout` through."""
    return (dout,)


@bprops.register("make_ref")
def bprop_make_ref(key, x, y, out, dout):
    """Backpropagator for primitive `make_ref`.

    Only the value operand `x` receives `dout`; `key` and `y` get zeros.
    """
    return (C.zeros_like(key), dout, C.zeros_like(y))


@bprops.register("get_ref_value")
def bprop_get_ref_value(x, out, dout):
    """Backpropagator for primitive `get_ref_value`: passes `dout` through."""
    return (dout,)


@bprops.register("get_ref_key")
def bprop_get_ref_key(x, out, dout):
    """Backpropagator for primitive `get_ref_key`: zero gradient for `x`."""
    return (C.zeros_like(x),)


@bprops.register("scalar_to_array")
def bprop_scalar_to_array(x, out, dout):
    """Backpropagator for primitive `scalar_to_array`: converts `dout` back to a scalar."""
    return (F.array_to_scalar(dout),)


@bprops.register("array_to_scalar")
def bprop_array_to_scalar(x, out, dout):
    """Backpropagator for primitive `array_to_scalar`: converts `dout` back to an array."""
    return (F.scalar_to_array(dout),)


@bprops.register("reshape")
def bprop_reshape(xs, shp, out, dout):
    """Backpropagator for primitive `reshape`.

    Reshapes `dout` back to the shape of the input `xs`; the target shape gets
    a zero gradient.
    """
    return F.reshape(dout, F.shape(xs)), C.zeros_like(shp)


@bprops.register("distribute")
def bprop_distribute(arr, shp, out, dout):
    """Backpropagator for primitive `distribute`.

    Sum-reduces `dout` (via `F.scalar_add`) back to the shape of `arr`; the
    target shape gets a zero gradient.
    """
    return F.array_reduce(F.scalar_add, dout, F.shape(arr)), C.zeros_like(shp)


@bprops.register("shape")
def bprop_shape(arr, out, dout):
    """Backpropagator for primitive `shape`: zero gradient for `arr`."""
    return (C.zeros_like(arr),)


@bprops.register("broadcast_shape")
def bprop_broadcast_shape(shp1, shp2, out, dout):
    """Backpropagator for primitive `broadcast_shape`: zero gradient for both shapes."""
    return C.zeros_like(shp1), C.zeros_like(shp2)


@bprops.register("array_reduce")
def bprop_array_reduce(fn, x, shp, out, dout):
    """Backpropagator for primitive `array_reduce`.

    Distributes `dout` back to the shape of `x`; the shape argument gets a
    zero gradient.
    """
    # NOTE(review): the forward takes three inputs (fn, x, shp) but only two
    # gradients are returned here, none for `fn` -- confirm the registry
    # accepts this arity.
    return F.distribute(dout, F.shape(x)), C.zeros_like(shp)


@bprops.register("Depend")
def bprop_depend(x, y, out, dout):
    """Backpropagator for primitive `Depend`.

    `dout` flows to `x`; the dependency operand `y` gets a zero gradient.
    """
    return dout, C.zeros_like(y)


@bprops.register("embed")
def bprop_embed(x, out, dout):
    """Backpropagator for primitive `embed`: zero gradient for `x`."""
    return (C.zeros_like(x),)


@bprops.register("bool_not")
def bprop_bool_not(x, out, dout):
    """Backpropagator for primitive `bool_not`: zero gradient for `x`."""
    return (C.zeros_like(x),)


@bprops.register("bool_or")
def bprop_bool_or(x, y, out, dout):
    """Backpropagator for primitive `bool_or`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("stop_gradient")
def bprop_stop_gradient(x, out, dout):
    """Backpropagator for primitive `stop_gradient`: always a zero gradient for `x`."""
    return (C.zeros_like(x),)


@bprops.register("bool_and")
def bprop_bool_and(x, y, out, dout):
    """Backpropagator for primitive `bool_and`: zero gradient for both operands."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("Switch")
def bprop_switch(cond, tb, fb, out, dout):
    """Backpropagator for primitive `Switch`.

    `dout` is routed to `tb` when `cond` holds and to `fb` otherwise; the
    branch not taken and `cond` itself get zero gradients.
    """
    return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
        F.switch(cond, C.zeros_like(fb), dout)


def _fprop_switch_layer(index, layers):
    """Forward propagator for primitive `switch_layer`.

    Returns the forward result `F.switch_layer(index, layers)` together with
    its backpropagator. The backpropagator returns `dout`, a zero gradient
    shaped like `index`, and an empty tuple.
    """
    # NOTE(review): confirm how the caller maps the three values returned by
    # `_bprop_switch_layer` onto the inputs.
    def _bprop_switch_layer(dout):
        return dout, C.zeros_like(index), ()
    return F.switch_layer(index, layers), _bprop_switch_layer


@bprops.register("UpdateState")
def bprop_update_state(u_monad, x, out, dout):
    """Backpropagator for primitive `UpdateState`: zero gradient for both inputs."""
    return C.zeros_like(u_monad), C.zeros_like(x)


@bprops.register("Load")
def bprop_load(param, u_monad, out, dout):
    """Backpropagator for primitive `Load`.

    `dout` flows to `param`; the monad input gets a zero gradient.
    """
    return dout, C.zeros_like(u_monad)