# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""internal graph-compatible utility functions"""
import math
from itertools import zip_longest, accumulate
from collections import deque
import operator

import mindspore.context as context
from ..ops import functional as F
from ..ops.primitive import constexpr
from ..common import dtype as mstype
from ..common import Tensor
from .._c_expression import Tensor as Tensor_
from .._c_expression import typing
from .._checkparam import Validator as validator

from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric


_check_axis_type = constexpr(validator.check_axis_type)


@constexpr
def _check_shape(shape):
    """
    Check the shape param to match the numpy style.

    Args:
        shape (Union[int, tuple(int), list(int)]): The shape to validate. An int is
            treated as a 1-D shape.

    Returns:
        The shape with ints and lists converted to tuples; other accepted sequence
        types are returned as given.

    Raises:
        TypeError: If `shape` is not an int/tuple/list, or an entry is not an int.
        ValueError: If an entry is negative.
    """
    if not isinstance(shape, (int, tuple, list, typing.Tuple, typing.List)):
        raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}")
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(shape, (list, typing.List)):
        shape = tuple(shape)
    for s in shape:
        if not isinstance(s, int):
            raise TypeError("each entry in shape should be int.")
        if s < 0:
            raise ValueError("each entry in shape should no less than 0.")
    return shape


@constexpr
def _check_dtype(dtype):
    """
    Check the input dtype and make conversions to a MindSpore dtype.

    Args:
        dtype: A dtype name (str, matched case-insensitively against `dtype_map`),
            a Python type, or an mstype dtype. Python `int` and `float` are
            special-cased to `mstype.int32` and `mstype.float32`.

    Returns:
        The resulting mstype dtype.

    Raises:
        TypeError: If the resulting dtype is not in `dtype_tuple`.
        KeyError: If `dtype` is a string with no entry in `dtype_map`.
    """
    # convert the string dtype to mstype.dtype
    if isinstance(dtype, str):
        dtype = dtype.lower()
        dtype = dtype_map[dtype]
    elif isinstance(dtype, type):
        if dtype is int:
            dtype = mstype.int32
        elif dtype is float:
            dtype = mstype.float32
        else:
            dtype = mstype.pytype_to_dtype(dtype)
    if dtype not in dtype_tuple:
        raise TypeError(f"only {all_types} are allowed for dtype, but got {type(dtype)}")
    return dtype


@constexpr
def _is_shape_empty(shp):
    """Check whether shape contains zero (an int shape is checked directly)."""
    if isinstance(shp, int):
        return shp == 0
    return F.shape_mul(shp) == 0


@constexpr
def _check_start_normalize(start, ndim):
    """
    Check and normalize start argument for rollaxis.

    `start` may range over [-ndim, ndim] inclusive; negative values are shifted
    by `ndim`.

    Raises:
        ValueError: If `start` is outside [-ndim, ndim].
    """
    if start < -ndim or start > ndim:
        raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
    if start < 0:
        start = start + ndim
    return start


@constexpr
def _check_axes_range(axes, ndim):
    """
    Check axes type and normalize the negative axes.

    Args:
        axes: Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Return:
        Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple.

    Raises:
        TypeError: If the axes are not integer, tuple(int) or list(int).
        ValueError: If duplicate axes exists or some axis is out of bounds.
    """
    _check_axis_type(axes, True, True, True)
    if isinstance(axes, (list, tuple)):
        _check_element_int(axes)
    axes = _canonicalize_axis(axes, ndim)
    return axes


@constexpr
def _get_device():
    """Get the current device (`GPU`, `CPU`, `Ascend`)"""
    return context.get_context('device_target')


@constexpr
def _infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.

    Shapes are aligned from their trailing dimensions; missing leading dimensions
    count as 1. In each position every size must be 1 or the common size, and a
    0 in any operand makes the output size 0.
    """
    shape_out = deque()
    reversed_shapes = map(reversed, shapes)
    for items in zip_longest(*reversed_shapes, fillvalue=1):
        max_size = 0 if 0 in items else max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
        # dims are visited from last to first, so prepend to rebuild the order
        shape_out.appendleft(max_size)
    return tuple(shape_out)


@constexpr
def _can_broadcast(*shapes):
    """
    Returns True if shapes can broadcast, False if they cannot.
    """
    try:
        _infer_out_shape(*shapes)
    except ValueError:
        return False
    return True


@constexpr
def _check_axis_in_range(axis, ndim):
    """
    Checks axes are with the bounds of ndim.

    Returns:
        int, the axis normalized into [0, ndim).

    Raises:
        TypeError: If `axis` is not an int.
        ValueError: If `axis` is outside [-ndim, ndim).
    """
    if not isinstance(axis, int):
        raise TypeError(f'axes should be integers, not {type(axis)}')
    if not -ndim <= axis < ndim:
        raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
    return axis % ndim


@constexpr
def _check_axis_valid(axes, ndim):
    """
    Checks axes are valid given ndim, and returns axes that can be passed
    to the built-in operator (non-negative, int or tuple).

    `None` selects every axis via `F.make_range(ndim)`; a single int is returned
    as a one-element tuple.

    Raises:
        TypeError: If an axis is not an int.
        ValueError: If an axis is out of bounds or axes are duplicated.
    """
    if axes is None:
        axes = F.make_range(ndim)
        return axes
    if isinstance(axes, (tuple, list)):
        axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
        # axes are normalized ints here, so a set detects duplicates in O(n)
        if len(set(axes)) != len(axes):
            raise ValueError('duplicate value in "axis"')
        return axes
    return (_check_axis_in_range(axes, ndim),)


@constexpr
def _check_shape_aligned(shape1, shape2):
    """
    Checks shape1 and shape2 are valid shapes to perform inner product.

    The last dimension of each shape must match.

    Raises:
        ValueError: If the last dimensions differ.
    """
    if shape1[-1] != shape2[-1]:
        # the compared axes are the last axis of each operand, so report their indices
        raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim {len(shape1) - 1}) != '
                         f'{shape2[-1]} (dim {len(shape2) - 1})')


@constexpr
def _tile_size(shape, out_shape, ndim):
    """
    Returns tile_size such that shape*tile_size = out_shape.

    Positions where `shape` and `out_shape` agree get a multiplier of 1; differing
    positions take the size from `out_shape` (this assumes the `shape` entry is 1
    there — TODO confirm callers guarantee it).
    """
    size = [1] * ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)


@constexpr
def _raise_type_error(info, param=None):
    """
    Raise TypeError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If is
            not None, then param's type information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise TypeError(info)
    raise TypeError(info + f"{type(param)}")


@constexpr
def _raise_value_error(info, param=None):
    """
    Raise ValueError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise ValueError(info)
    raise ValueError(info + f"{param}")


@constexpr
def _raise_runtime_error(info, param=None):
    """
    Raise RuntimeError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise RuntimeError(info)
    raise RuntimeError(info + f"{param}")


@constexpr
def _raise_unimplemented_error(info, param=None):
    """
    Raise NotImplementedError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise NotImplementedError(info)
    raise NotImplementedError(info + f"{param}")


@constexpr
def _empty(dtype, shape):
    """Returns an uninitialized array with dtype and shape."""
    return Tensor_(dtype, shape)


@constexpr
def _promote(dtype1, dtype2):
    """
    Returns the dtype that `dtype1` and `dtype2` promote to.

    Equal dtypes are returned unchanged; otherwise `promotion_rule` is consulted
    with the pair in either order. Raises KeyError if neither order is listed.
    """
    if dtype1 == dtype2:
        return dtype1
    if (dtype1, dtype2) in promotion_rule:
        return promotion_rule[dtype1, dtype2]
    return promotion_rule[dtype2, dtype1]


@constexpr
def _promote_for_trigonometric(dtype):
    """Returns the output dtype for trigonometric ops, looked up in `rule_for_trigonometric`."""
    return rule_for_trigonometric[dtype]


@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _abs(arg):
    """Returns the absolute value."""
    return abs(arg)


@constexpr
def _check_same_type(dtype1, dtype2):
    """Returns whether the two dtypes compare equal."""
    return dtype1 == dtype2


@constexpr
def _check_is_float(dtype):
    """Returns whether dtype is float16 or float32."""
    return dtype in (mstype.float16, mstype.float32)


@constexpr
def _check_is_int(dtype):
    """Returns whether dtype is an instance of `typing.Int`."""
    return isinstance(dtype, typing.Int)


@constexpr
def _canonicalize_axis(axis, ndim):
    """
    Check axes are within the number of dimensions of tensor x and normalize the negative axes.

    Args:
        axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Return:
        Axis (Union[int, tuple(int)]). A single axis (an int, or a sequence with
        exactly one element) is returned as an int; several axes are returned as a
        sorted tuple; an empty sequence is returned as an empty tuple.

    Raises:
        TypeError: If any axis is not an int.
        ValueError: If any axis is out of bounds, or axes repeat after normalization.
    """
    if isinstance(axis, int):
        axis = (axis,)
    # _check_axis_in_range validates each axis and returns it normalized into [0, ndim)
    canonical = tuple(_check_axis_in_range(ax, ndim) for ax in axis)
    if len(set(canonical)) != len(canonical):
        raise ValueError(f"duplicate axes in {canonical}.")
    if not canonical:
        # no axes given: return an empty tuple instead of indexing into it
        return ()
    return tuple(sorted(canonical)) if len(canonical) > 1 else canonical[0]


@constexpr
def _broadcast_tuples(tup1, tup2):
    """
    Broadcast two 1D tuples to the same length, if inputs are ints, convert to
    tuples first.

    A length-1 sequence is repeated to match the other's length.

    Raises:
        TypeError: If an input is not an int, tuple or list.
        ValueError: If lengths differ and neither is 1.
    """
    tup1 = (tup1,) if isinstance(tup1, int) else tup1
    tup2 = (tup2,) if isinstance(tup2, int) else tup2
    if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
        raise TypeError("input shift and axis must be tuple or list or int.")
    if len(tup1) == len(tup2):
        return tup1, tup2
    # build new sequences rather than using `*=`, which would mutate a caller's list in place
    if len(tup1) == 1:
        tup1 = tup1 * len(tup2)
    elif len(tup2) == 1:
        tup2 = tup2 * len(tup1)
    else:
        raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")
    return tup1, tup2


@constexpr
def _expanded_shape(ndim, axis_size, axis):
    """
    Returns a shape with size = 1 for all dimensions
    except at axis.
    """
    return tuple(axis_size if i == axis else 1 for i in range(ndim))


@constexpr
def _add_unit_axes(shape, ndim, append=False):
    """
    Prepends shape with 1s so that it has the number of dimensions ndim.
    If append is set to True, returns shape appended with 1s instead.
    Shapes already at least `ndim` long are returned unchanged (as a tuple).
    """
    if isinstance(shape, int):
        shape = (shape,)
    ndim_diff = ndim - len(shape)
    if ndim_diff > 0:
        if append:
            shape = list(shape) + [1] * ndim_diff
        else:
            shape = [1] * ndim_diff + list(shape)
    return tuple(shape)


@constexpr
def _check_element_int(lst):
    """
    Check whether each element in `lst` is an integer.

    Returns True on success; raises TypeError naming the first offending element's type.
    """
    for item in lst:
        if not isinstance(item, int):
            raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.")
    return True


@constexpr
def _type_convert(force, obj):
    """
    Convert type of `obj` to `force`.
    """
    return force(obj)


@constexpr
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
    """
    Generates a new list/tuple by list comprehension.

    Args:
        obj (Union[int, list, tuple]):
            If integer, it will be the length of the returned tuple/list.
        item: The value to be filled. Default: None.
            If None, the values in the new list/tuple are the same as obj
            or range(obj) when obj is integer.
        return_tuple(bool): If true, returns tuple, else returns list.
        make_none(bool): If true, every value is None regardless of `item`
            (needed because `item=None` already means "copy obj"). Default: False.

    Returns:
        List or tuple.
    """
    lst = range(obj) if isinstance(obj, int) else obj
    if make_none:
        res = [None for _ in lst]
    elif item is None:
        res = list(lst)
    else:
        res = [item for _ in lst]
    if return_tuple:
        return tuple(res)
    return res


@constexpr
def _tuple_setitem(tup, idx, value):
    """
    Returns a tuple with specified `idx` set to `value`.
    """
    tup = list(tup)
    tup[idx] = value
    return tuple(tup)


@constexpr
def _iota(dtype, num, increasing=True):
    """
    Creates a 1-D tensor with value: [0,1,...num-1] and dtype.

    If `increasing` is False the values run from num-1 down to 0 instead.

    Raises:
        ValueError: If `num` <= 0.
    """
    # Change to P.Linspace when the kernel is implemented on CPU.
    if num <= 0:
        raise ValueError("zero shape Tensor is not currently supported.")
    if increasing:
        return Tensor(list(range(int(num))), dtype)
    return Tensor(list(range(int(num) - 1, -1, -1)), dtype)


@constexpr
def _ceil(number):
    """Ceils the number in graph mode."""
    return math.ceil(number)


@constexpr
def _seq_prod(seq1, seq2):
    """Returns the element-wise product of seq1 and seq2 (truncated to the shorter one)."""
    return tuple(map(lambda x, y: x * y, seq1, seq2))


@constexpr
def _make_tensor(val, dtype):
    """Returns the tensor with value `val` and dtype `dtype`."""
    return Tensor(val, dtype)


@constexpr
def _tuple_slice(tup, start, end):
    """get sliced tuple from start and end."""
    return tup[start:end]


@constexpr
def _isscalar(x):
    """Returns True if x is a scalar type"""
    return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
                          typing.Bool, typing.String))


@constexpr
def _cumprod(x):
    """Returns the running products of `x` as a tuple, e.g. (2, 3, 4) -> (2, 6, 24)."""
    return tuple(accumulate(x, operator.mul))


@constexpr
def _in(x, y):
    """Returns whether `x` is contained in `y`."""
    return x in y


@constexpr
def _callable_const(x):
    """Returns true if x is a function in graph mode."""
    return isinstance(x, typing.Function)


@constexpr
def _check_is_inf(x, negative=False):
    """Returns whether `x` equals +inf, or -inf when `negative` is True."""
    if not negative:
        return x == float('inf')
    return x == float('-inf')