# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Implementation for internal polymorphism `add` operations.

This module defines two `base.MultitypeFuncGraph` dispatchers and registers
one overload per input-type pair on each of them via `.register`:

* `add` - the overloads for Number, String, Tensor, Tuple, List, RowTensor
  and None operands.
* `_add_backward` - a separate set of overloads (Number, EnvType, Tensor,
  UMonad, IOMonad, RowTensor, None) which is wrapped by
  `hyper_add = base.HyperMap(_add_backward)` at the bottom of the file.

NOTE(review): the function bodies below are compiled by the MindSpore graph
parser, which accepts only a subset of Python; keep edits to them minimal.
"""

from . import _compile_utils as utils
from ...composite import base
from ... import functional as F


# Dispatcher for `add`; overloads are attached below with `@add.register(...)`.
# NOTE(review): the meaning of the second positional argument `True` is defined
# by base.MultitypeFuncGraph - confirm there.
add = base.MultitypeFuncGraph('add', True)
"""`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator."""


# Separate dispatcher whose overloads are attached with
# `@_add_backward.register(...)`; it is wrapped by `hyper_add` at the end of
# this module.
_add_backward = base.MultitypeFuncGraph('add_backward')
"""
`_add_backward` is an metafuncgraph object which will add_backward two objects according to input type
using ".register" decorator.
"""


class _TupleAdd(base.TupleAdd_):
    """
    Meta function graph that concatenates two tuples.

    Args:
        x (tuple): The first tuple.
        y (tuple): The second tuple.

    Returns:
        Tuple, consisting of the elements of x followed by the elements of y.
    """

    def __init__(self, name):
        """Initialize _TupleAdd by forwarding `name` to base.TupleAdd_."""
        base.TupleAdd_.__init__(self, name)

    def __call__(self, *args):
        # Intentionally empty on the Python side. Presumably the concatenation
        # is implemented by the base class `base.TupleAdd_` and applied during
        # graph compilation - confirm in the base implementation.
        pass


# Single shared instance used by the ("Tuple", "Tuple") overload of `add`.
_tuple_add = _TupleAdd('tuple_add')
"""`_tuple_add` is an metafuncgraph object which will concatenate two tuples to form a tuple."""


@add.register("Number", "Number")
@_add_backward.register("Number", "Number")
def _scalar_add_scalar(x, y):
    """
    Returns the sum of two numbers.

    Registered on both `add` and `_add_backward`.

    Args:
        x (Number): The first operand.
        y (Number): The second operand.

    Returns:
        Number, equal to x + y as computed by `F.scalar_add`.
        NOTE(review): the result type for mixed int/float operands is decided
        by F.scalar_add; an earlier note claimed it always matches x - confirm.
    """
    return F.scalar_add(x, y)


@add.register("String", "String")
def _string_concat_string(x, y):
    """
    Concatenates the string y to the end of the string x.

    Args:
        x (str): The first input string.
        y (str): The second input string.

    Returns:
        str, x followed by y, as computed by `F.string_concat`.
    """
    return F.string_concat(x, y)


@add.register("Number", "Tensor")
def _scalar_add_tensor(x, y):
    """
    Adds a number to a tensor.

    Args:
        x (Number): The scalar operand.
        y (Tensor): The tensor operand; presumably its dtype should be
            compatible with x - confirm against F.add.

    Returns:
        Tensor, the element-wise sum computed by `F.add(x, y)`.
        NOTE(review): the result dtype follows F.add's promotion rules; the
        original note said "same dtype as x" - confirm.
    """
    return F.add(x, y)


@add.register("Tensor", "Number")
def _tensor_add_scalar(x, y):
    """
    Adds a tensor to a number.

    Args:
        x (Tensor): The tensor operand.
        y (Number): The scalar operand; presumably expected to be compatible
            with x's dtype - confirm against F.add.

    Returns:
        Tensor, the element-wise sum computed by `F.add(x, y)`; presumably
        with the same dtype as x - confirm.
    """
    return F.add(x, y)


@add.register("Tuple", "Tensor")
def _tuple_add_tensor(x, y):
    """
    Adds a tuple to a tensor.

    The tuple is first converted to a tensor with y's dtype by
    `utils.sequence_to_tensor`, then added to y.

    Args:
        x (Tuple): The sequence operand, converted with `y.dtype`.
        y (Tensor): The tensor operand.

    Returns:
        Tensor, the result of `F.tensor_add` on the converted x and y.
    """
    # NOTE(review): this overload uses F.tensor_add while the Number/Tensor
    # overloads use F.add - confirm the difference is intended.
    x = utils.sequence_to_tensor(x, y.dtype)
    return F.tensor_add(x, y)


@add.register("Tensor", "Tuple")
def _tensor_add_tuple(x, y):
    """
    Adds a tensor to a tuple.

    The tuple is first converted to a tensor with x's dtype by
    `utils.sequence_to_tensor`, then added to x.

    Args:
        x (Tensor): The tensor operand.
        y (Tuple): The sequence operand, converted with `x.dtype`.

    Returns:
        Tensor, the result of `F.tensor_add` on x and the converted y.
    """
    y = utils.sequence_to_tensor(y, x.dtype)
    return F.tensor_add(x, y)


@add.register("List", "Tensor")
def _list_add_tensor(x, y):
    """
    Adds a list to a tensor.

    The list is first converted to a tensor with y's dtype by
    `utils.sequence_to_tensor`, then added to y.

    Args:
        x (List): The sequence operand, converted with `y.dtype`.
        y (Tensor): The tensor operand.

    Returns:
        Tensor, the result of `F.tensor_add` on the converted x and y.
    """
    x = utils.sequence_to_tensor(x, y.dtype)
    return F.tensor_add(x, y)


@add.register("Tensor", "List")
def _tensor_add_list(x, y):
    """
    Adds a tensor to a list.

    The list is first converted to a tensor with x's dtype by
    `utils.sequence_to_tensor`, then added to x.

    Args:
        x (Tensor): The tensor operand.
        y (List): The sequence operand, converted with `x.dtype`.

    Returns:
        Tensor, the result of `F.tensor_add` on x and the converted y.
    """
    y = utils.sequence_to_tensor(y, x.dtype)
    return F.tensor_add(x, y)


@add.register("List", "List")
def _list_add_list(x, y):
    """
    Concatenates two lists.

    Each element of y is appended to x, and x is returned.

    Args:
        x (list): The first list.
        y (list): The second list.

    Returns:
        list, the elements of x followed by the elements of y.
    """
    # NOTE(review): under plain Python semantics `x.append` mutates the
    # caller's list in place; presumably the graph compiler treats `append`
    # functionally here - confirm before relying on either behavior.
    for i in y:
        x.append(i)
    return x


@add.register("Tensor", "Tensor")
def _tensor_add_tensor(x, y):
    """
    Returns x + y element-wise.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor; presumably expected to share x's
            dtype - confirm against F.add.

    Returns:
        Tensor, the result of `F.add(x, y)`.
    """
    return F.add(x, y)


@add.register("RowTensor", "Tensor")
def add_rowtensor_tensor(x, y):
    """
    Adds a RowTensor and a Tensor.

    Args:
        x (RowTensor): The row-tensor operand.
        y (Tensor): The dense tensor operand.

    Returns:
        RowTensor, as produced by `F.row_tensor_add(x, y)`.
    """
    return F.row_tensor_add(x, y)


@add.register("None", "None")
def _none_add_none(x, y):
    """
    Adds None and None.

    Both inputs are ignored.

    Args:
        x (None): Ignored.
        y (None): Ignored.

    Returns:
        None.
    """
    return None


@_add_backward.register("EnvType", "EnvType")
def _add_env(x, y):
    """
    Adds two EnvType values.

    Args:
        x (EnvType): The first environment.
        y (EnvType): The second environment.

    Returns:
        EnvType, as produced by `F.env_add(x, y)`.
    """
    return F.env_add(x, y)


@add.register("Tuple", "Tuple")
def _add_tuple(x, y):
    """
    Concatenates two tuples.

    Args:
        x (tuple): The first tuple.
        y (tuple): The second tuple.

    Returns:
        Tuple, the elements of x followed by the elements of y, produced by
        the shared `_tuple_add` instance.
    """
    return _tuple_add(x, y)


@_add_backward.register("Tensor", "Tensor")
def _add_addn(x, y):
    """
    Adds two tensors element-wise by passing the pair (x, y) to `F.addn`.

    Args:
        x (Tensor): The first tensor.
        y (Tensor): The second tensor; presumably expected to share x's
            dtype - confirm against F.addn.

    Returns:
        Tensor, the result of `F.addn((x, y))`.
    """
    return F.addn((x, y))


@_add_backward.register("UMonad", "UMonad")
def _add_umonad_umonad(x, y):
    """
    Adds two UMonad values.

    Args:
        x (UMonad): Returned unchanged.
        y (UMonad): Ignored.

    Returns:
        UMonad, x itself.
    """
    return x


@_add_backward.register("IOMonad", "IOMonad")
def _add_iomonad_iomonad(x, y):
    """
    Adds two IOMonad values.

    Args:
        x (IOMonad): Returned unchanged.
        y (IOMonad): Ignored.

    Returns:
        IOMonad, x itself.
    """
    return x


@_add_backward.register("RowTensor", "Tensor")
def _add_rowtensor_tensor(x, y):
    """
    Adds a RowTensor and a Tensor.

    Delegates to `x + y`. Presumably, in graph mode, this dispatches to the
    ("RowTensor", "Tensor") overload of `add` (add_rowtensor_tensor) - confirm.

    Args:
        x (RowTensor): The row-tensor operand.
        y (Tensor): The dense tensor operand.

    Returns:
        RowTensor, the result of `x + y`.
    """
    return x + y


@_add_backward.register("None", "None")
def _add_nonetensor_tensor(x, y):
    """
    Adds None and None.

    Despite its name, this overload is registered for ("None", "None").
    It returns `x + y`. Plain Python would raise TypeError for None + None,
    so this presumably relies on graph-mode dispatch to the ("None", "None")
    overload of `add`, which returns None - confirm.

    Args:
        x (None): The first operand.
        y (None): The second operand.

    Returns:
        None (assuming the graph-mode dispatch described above).
    """
    return x + y


# `_add_backward` wrapped in base.HyperMap. Presumably this applies the
# overloads element-wise across nested tuples/lists - confirm in base.HyperMap.
hyper_add = base.HyperMap(_add_backward)