# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Basic composite operations."""
from functools import partial
from types import FunctionType

from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
from ..operations import _grad_ops
from .. import operations as P
from .. import signature as sig

# `__all__` must contain names (strings); object entries make `from ... import *` raise TypeError.
__all__ = ['EnvInstance_', 'TupleAdd_', 'TupleSlice_', 'UnpackCall_', 'TupleGetItemTensor_']


def add_flags(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    Existing flags on `fn` are kept; the given flags are merged into them.

    Note:
        Only supports bool value.

    Args:
        fn (Function): Function or cell to add flag. Default: None.
        flags (dict): Flags use kwargs. Default: None.

    Returns:
        Function, the function with added flags. If `fn` is None, the decorator itself is returned
        so that `add_flags(**flags)` can be used with the `@` syntax.

    Examples:
        >>> net = Net()
        >>> net = add_flags(net, predict=True)
        >>> print(hasattr(net, '_mindspore_flags'))
        True
    """
    def deco(fn):
        # need set the attr and access on c++
        if not hasattr(fn, "_mindspore_flags"):
            fn._mindspore_flags = {}

        fn._mindspore_flags.update({**flags})
        return fn
    ret = deco
    if fn is not None:
        ret = deco(fn)
    return ret


def core(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    By default, the function is marked as True, enabling to use this decorator to
    set flag to a graph.

    Note:
        Unlike `add_flags`, this replaces any `_mindspore_flags` already present on `fn`.

    Args:
        fn (Function): Function to add flag. Default: None.
        flags (dict): The following flags can be set core, which indicates that this is a core function or
            other flag. Default: None.

    Returns:
        Function, the function with the flags set. If `fn` is None, the decorator itself is returned.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> net = core(net, predict=True)
        >>> print(hasattr(net, '_mindspore_flags'))
        True
    """
    # need set the attr and access on c++

    def deco(fn):
        fn._mindspore_flags = {
            'core': True,
            **flags,
        }
        return fn

    if fn is not None:
        ret = deco(fn)
    else:
        ret = deco
    return ret


def _weights_cache_key(weights):
    """
    Build a cheap identity key for the `weights` argument of `GradOperation.__call__`.

    Identity (`id`) is used instead of equality so that no comparison is ever performed
    on the weight objects themselves.

    Args:
        weights: None, a tuple/list of weights, or a single weight object.

    Returns:
        None if `weights` is None, otherwise a tuple of object ids.
    """
    if weights is None:
        return None
    if isinstance(weights, (tuple, list)):
        return tuple(id(weight) for weight in weights)
    return (id(weights),)


class GradOperation(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function for the input function.

    The gradient function generated by `GradOperation` higher-order function can be customized by
    construction arguments.

    Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
    see `Net` in Examples.


    To generate a gradient function that returns gradients with respect to the first input
    (see `GradNetWrtX` in Examples).

    1. Construct a `GradOperation` higher-order function with default arguments:
       `grad_op = GradOperation()`.

    2. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.

    3. Call the gradient function with input function's inputs to get the gradients with respect to the first input:
       `grad_op(net)(x, y)`.


    To generate a gradient function that returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_all=True` which
       indicates getting gradients with respect to all inputs, they are `x` and `y` in example function `Net()`:
       `grad_op = GradOperation(get_all=True)`.

    2. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.

    3. Call the gradient function with input function's inputs to get the gradients with respect to all inputs:
       `gradient_function(x, y)`.

    To generate a gradient function that returns gradients with respect to given parameters
    (see `GradNetWithWrtParams` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
       `grad_op = GradOperation(get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed to the input function when constructing
       `GradOperation` higher-order function, it will be used as a parameter filter that determine
       which gradient to return: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with input function's inputs to get the gradients with
       respect to given parameters: `gradient_function(x, y)`.

    To generate a gradient function that returns gradients with respect to all inputs and given parameters
    in the format of ((dx, dy), (dz))(see `GradNetWrtInputsAndParams` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
       `grad_op = GradOperation(get_all=True, get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed along input function when constructing
       `GradOperation` higher-order function: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with input function's inputs
       to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.


    We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
    passing an extra sensitivity input to the gradient function, the sensitivity input should have the
    same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
       `grad_op = GradOperation(get_all=True, sens_param=True)`.

    2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
       `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

    3. Call it with input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    4. Call the gradient function with input function's inputs and `sens_param` to
       get the gradients with respect to all inputs:
       `gradient_function(x, y, grad_wrt_output)`.

    Args:
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            Default: False.
            If the sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred
            through the positional parameter or key-value pair parameter. If the value is transferred through
            the key-value pair parameter, the key must be sens.

    Returns:
        The higher-order function which takes a function as argument and returns gradient function for it.

    Raises:
        TypeError: If `get_all`, `get_by_list` or `sens_param` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ParameterTuple
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = P.MatMul()
        ...         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
        ...     def construct(self, x, y):
        ...         x = x * self.z
        ...         out = self.matmul(x, y)
        ...         return out
        ...
        >>> class GradNetWrtX(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtX, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation()
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        ...
        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        >>> output = GradNetWrtX(Net())(x, y)
        >>> print(output)
        [[1.4100001 1.5999999 6.6      ]
         [1.4100001 1.5999999 6.6      ]]
        >>>
        >>> class GradNetWrtXY(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtXY, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation(get_all=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtXY(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 4.50999975e+00,  2.70000005e+00,  3.60000014e+00],
         [ 4.50999975e+00,  2.70000005e+00,  3.60000014e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 2.59999990e+00,  2.59999990e+00,  2.59999990e+00],
         [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
         [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]]))
        >>>
        >>> class GradNetWrtXYWithSensParam(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtXYWithSensParam, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation(get_all=True, sens_param=True)
        ...         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y, self.grad_wrt_output)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtXYWithSensParam(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 2.21099997e+00,  5.09999990e-01,  1.49000001e+00],
         [ 5.58800030e+00,  2.68000007e+00,  4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 1.51999998e+00,  2.81999993e+00,  2.14000010e+00],
         [ 1.09999990e+00,  2.04999995e+00,  1.54999995e+00],
         [ 9.00000036e-01,  1.54999995e+00,  1.25000000e+00]]))
        >>>
        >>> class GradNetWithWrtParams(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWithWrtParams, self).__init__()
        ...         self.net = net
        ...         self.params = ParameterTuple(net.trainable_params())
        ...         self.grad_op = GradOperation(get_by_list=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net, self.params)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWithWrtParams(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)
        >>>
        >>> class GradNetWrtInputsAndParams(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtInputsAndParams, self).__init__()
        ...         self.net = net
        ...         self.params = ParameterTuple(net.trainable_params())
        ...         self.grad_op = GradOperation(get_all=True, get_by_list=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net, self.params)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtInputsAndParams(Net())(x, y)
        >>> print(output)
        ((Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00],
         [ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 6.00000024e-01,  6.00000024e-01,  6.00000024e-01],
         [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
         [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]])), (Tensor(shape=[1], dtype=Float32, value=
         [ 1.29020004e+01]),))
    """

    def __init__(self, get_all=False, get_by_list=False, sens_param=False):
        """Initialize GradOperation."""
        if not isinstance(get_all, bool):
            raise TypeError(f"For 'GradOperation', the 'get_all' should be bool, but got {type(get_all).__name__}")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For 'GradOperation', the 'get_by_list' should be bool, "
                            f"but got {type(get_by_list).__name__}")
        if not isinstance(sens_param, bool):
            raise TypeError(f"For 'GradOperation', the 'sens_param' should be bool, "
                            f"but got {type(sens_param).__name__}")
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param)
        # Cache of the most recently built gradient function and the (fn, weights) it was built for.
        self.grad_fn = None
        self.fn = None
        self._weights_key = None
        # Set to True on the inner GradOperation created by __call__ in pure PYNATIVE_MODE.
        self.pynative_ = False

    def _pynative_forward_run(self, grad, args, kwargs, fn):
        """
        Pynative forward run to build grad graph.

        Args:
            grad (GradOperation): The grad operation the forward run is performed for.
            args (tuple): Positional inputs of the gradient function.
            kwargs (dict): Keyword inputs of the gradient function.
            fn (Union[Function, Cell]): The function or cell to differentiate.
        """
        # The sensitivity is an input of the gradient function only, never of `fn` itself:
        # drop it from the keyword arguments if passed as `sens`, else from the end of `args`.
        new_kwargs = kwargs
        if self.sens_param:
            if 'sens' in kwargs:
                new_kwargs = kwargs.copy()
                new_kwargs.pop('sens')
            else:
                args = args[:-1]
        if isinstance(fn, FunctionType):
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                _pynative_executor.set_grad_flag(True)
                _pynative_executor.new_graph(fn, *args, **new_kwargs)
                output = fn(*args, **new_kwargs)
                _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
        else:
            # Check if fn have run already
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                fn.set_grad()
                fn(*args, **new_kwargs)
                fn.set_grad(False)

    def __call__(self, fn, weights=None):
        """
        Build (or fetch from cache) the gradient function of `fn`.

        Args:
            fn (Union[Function, Cell]): The function or cell to differentiate.
            weights: The parameters to differentiate with respect to; used when `get_by_list`
                is True. Default: None.

        Returns:
            Function, the gradient function of `fn`.
        """
        # The cached closure captures `weights`, so the cache is only valid for the same
        # (fn, weights) pair; keying on `fn` alone would return gradients for stale weights.
        weights_key = _weights_cache_key(weights)
        if self.grad_fn is not None and self.fn == fn and self._weights_key == weights_key:
            return self.grad_fn
        grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
        # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
        # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
        #   In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
        #   In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
        if context.get_context("mode") == context.GRAPH_MODE:
            if self.get_by_list:
                @ms_function
                def after_grad(*args):
                    return grad_(fn, weights)(*args)
            else:
                @ms_function
                def after_grad(*args):
                    return grad_(fn)(*args)
        elif self.pynative_:
            @_wrap_func
            def after_grad(*args, **kwargs):
                if _pynative_executor.check_graph(fn, *args, **kwargs):
                    print("Another grad step is running")
                self._pynative_forward_run(grad_, args, kwargs, fn)
                _pynative_executor.grad(grad_, fn, weights, *args, **kwargs)
                out = _pynative_executor(fn, *args, **kwargs)
                _pynative_executor.clear_grad(fn, *args, **kwargs)
                return out
        else:
            grad_.pynative_ = True
            # after_grad of this branch can't use @ms_function, just directly call grad_
            if self.get_by_list:
                def after_grad(*args, **kwargs):
                    return grad_(fn, weights)(*args, **kwargs)
            else:
                def after_grad(*args, **kwargs):
                    return grad_(fn)(*args, **kwargs)

        self.grad_fn = after_grad
        self.fn = fn
        self._weights_key = weights_key
        return self.grad_fn


class MultitypeFuncGraph(MultitypeFuncGraph_):
    """
    Generates overloaded functions.

    MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs.
    Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
    for the function to be registered. And the object can be called with different types of inputs,
    and work with `HyperMap` and `Map`.

    Args:
        name (str): Operator name.
        read_value (bool): If the registered function not need to set value on Parameter,
            and all inputs will pass by value, set `read_value` to True. Default: False.

    Raises:
        ValueError: If failed to find a matching function for the given arguments.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # `add` is a metagraph object which will add two objects according to
        >>> # input type using ".register" decorator.
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore import dtype as mstype
        >>>
        >>> tensor_add = ops.Add()
        >>> add = MultitypeFuncGraph('add')
        >>> @add.register("Number", "Number")
        ... def add_scalar(x, y):
        ...     return x + y
        >>> @add.register("Tensor", "Tensor")
        ... def add_tensor(x, y):
        ...     return tensor_add(x, y)
        >>> output = add(1, 2)
        >>> print(output)
        3
        >>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
        >>> print(output)
        [0.2 1.2 2.4]
    """

    def __init__(self, name, read_value=False):
        """Initialize MultitypeFuncGraph."""
        MultitypeFuncGraph_.__init__(self, name)
        # List of (types, fn) pairs in registration order; consulted by the Python-side __call__.
        self.entries = []
        if read_value:
            self.set_signatures((
                sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))

    def __call__(self, *args):
        """
        Dispatch `args` to the first registered function whose signature matches their types.

        Raises:
            ValueError: If no registered function matches the types of `args`.
        """
        # With a single registration there is nothing to choose between: call it without type checks.
        if len(self.entries) == 1:
            output = self.entries[0][1](*args)
            return output
        types = tuple(map(mstype.get_py_obj_dtype, args))
        for sigs, fn in self.entries:
            if len(sigs) != len(types):
                continue
            # `expected` (not `sig`) so the imported `signature as sig` module is not shadowed.
            if any(not mstype.issubclass_(type_, expected) for expected, type_ in zip(sigs, types)):
                continue
            output = fn(*args)
            return output
        raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args. Got (sigs, fn): {self.entries}, "
                         f"and (dtype, args): {types}.")

    def register(self, *type_names):
        """
        Register a function for the given type string.

        Args:
            type_names (Union[str, :class:`mindspore.dtype`]): Inputs type names or types list.

        Returns:
            decorator, a decorator to register the function to run, when called under the
            types described in `type_names`.

        Raises:
            TypeError: When the decorator is applied, if an entry of `type_names` is neither
                a str nor a :class:`mindspore.dtype` type.
        """
        def deco(fn):
            def convert_type(type_input):
                if isinstance(type_input, str):
                    return mstype.typing.str_to_type(type_input)
                if not isinstance(type_input, mstype.Type):
                    raise TypeError(f"For 'MultitypeFuncGraph', register only support str or {mstype.Type}, but got "
                                    f"'type_input': {type_input}.")
                return type_input

            types = tuple(map(convert_type, type_names))
            self.register_fn(type_names, fn)
            self.entries.append((types, fn))
            return fn
        return deco


class HyperMap(HyperMap_):
    """
    Hypermap will apply the set operation to input sequences.

    Apply the operations to every elements of the sequence or nested sequence. Different
    from `Map`, the `HyperMap` supports to apply on nested structure.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operations should be put in the first input of the instance. Default is None.
        reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel performance,
            general users please ignore. `reverse` is the flag to decide if apply the operation reversely.
            Only supported in graph mode. Default is False.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same length.
          And each row of the sequences will be the inputs of the operation.

          If `ops` is `None`, the first input is the operation, and the others are inputs.

    Outputs:
        Sequence or nested sequence, the sequence of output after applying the function.
        e.g. `operation(args[0][i], args[1][i])`.

    Raises:
        TypeError: If `ops` is neither MultitypeFuncGraph nor None.
        TypeError: If `args` is not a Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import dtype as mstype
        >>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
        ...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
        >>> # square all the tensor in the nested list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = HyperMap()
        >>> output = common_map(square, nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
        >>> square_map = HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
    """

    def __init__(self, ops=None, reverse=False):
        """Initialize HyperMap."""
        self.ops = ops
        if ops:
            HyperMap_.__init__(self, reverse, ops)
        else:
            HyperMap_.__init__(self, reverse)

    def __call__(self, *args):
        """Apply the operation recursively over the (possibly nested) input sequences."""
        func = self.ops
        args_list = args
        hypermap = self
        if self.ops is None:
            # The operation travels as the first argument; bind it so recursion keeps passing it.
            func = args[0]
            args_list = args[1:]
            hypermap = partial(self, func)
        # is leaf
        if not isinstance(args_list[0], (tuple, list)):
            return func(*args_list)
        return tuple(map(hypermap, *args_list))


class Map(Map_):
    """
    Map will apply the set operation on input sequences.

    Apply the operations to every elements of the sequence.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operations should be put in the first input of the instance. Default: None
        reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel performance,
            general users please ignore. `reverse` is the flag to decide if apply the operation reversely.
            Only supported in graph mode. Default is False.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
          and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence
          `(args[0][i], args[1][i])` will be the input of the operation.

          If `ops` is `None`, the first input is the operation, and the other is inputs.

    Outputs:
        Sequence, the sequence of output after applying the function. e.g. `operation(args[0][i], args[1][i])`.

    Examples:
        >>> from mindspore import dtype as mstype
        >>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
        >>> # square all the tensor in the list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = Map()
        >>> output = common_map(square, tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
        >>> square_map = Map(square, False)
        >>> output = square_map(tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
    """

    def __init__(self, ops=None, reverse=False):
        """Initialize Map."""
        self.ops = ops
        if ops:
            Map_.__init__(self, reverse, ops)
        else:
            Map_.__init__(self, reverse)

    def __call__(self, *args):
        """Apply the operation element-wise over the input sequences and return a tuple."""
        func = self.ops
        args_list = args
        if self.ops is None:
            # The operation travels as the first argument.
            func = args[0]
            args_list = args[1:]
        return tuple(map(func, *args_list))


class _ListAppend(ListAppend_):
    """
    A metafuncgraph class that append one element to list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _ListAppend."""
        ListAppend_.__init__(self, name)

    def __call__(self, *args):
        # No-op when invoked from Python.
        # NOTE(review): presumably the behavior comes from the C++ base class at graph compile time - confirm.
        pass


_append = _ListAppend("append")


class _Tail(Tail_):
    """
    A metafuncgraph class that generates tail elements of the tuple.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _Tail."""
        Tail_.__init__(self, name)

    def __call__(self, *args):
        # No-op when invoked from Python.
        # NOTE(review): presumably the behavior comes from the C++ base class at graph compile time - confirm.
        pass


tail = _Tail('tail')


class _ZipOperation(ZipOperation_):
    """Generates a tuple of zip iterations for inputs."""

    def __init__(self, name):
        """Initialize _ZipOperation."""
        ZipOperation_.__init__(self, name)

    def __call__(self, *args):
        # No-op when invoked from Python.
        # NOTE(review): presumably the behavior comes from the C++ base class at graph compile time - confirm.
        pass


zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs."""


env_get = MultitypeFuncGraph("env_get")


env_getitem = Primitive('env_getitem')
ref_to_embed = _grad_ops.RefToEmbed()
zeros_like = P.ZerosLike()


@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
    """Used to get env."""
    # `zeros_like(parameter)` is passed as the third argument of `env_getitem`.
    # NOTE(review): presumably the default returned when `parameter` has no entry in `env` - confirm.
    return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))