1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""internal graph-compatible utility functions""" 16from __future__ import absolute_import 17 18import math 19from itertools import accumulate 20import operator 21 22import mindspore.context as context 23from mindspore.ops import functional as F 24from mindspore.ops.primitive import constexpr 25from mindspore.ops.primitive import _primexpr 26from mindspore.common import dtype as mstype 27from mindspore.common import Tensor 28from mindspore._c_expression import Tensor as Tensor_ 29from mindspore._c_expression import typing 30from mindspore import _checkparam as validator 31 32from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric 33 34_check_axis_type = constexpr(validator.check_axis_type) 35 36 37@_primexpr 38def _check_shape(shape): 39 """check the shape param to match the numpy style""" 40 # convert tensor to int/list, use followed if statements to do further conversions 41 def _check(shape): 42 for s in shape: 43 if not isinstance(s, int): 44 raise TypeError("each entry in shape should be int.") 45 if s < 0: 46 raise ValueError("each entry in shape should no less than 0.") 47 48 if not isinstance(shape, (int, tuple, list, Tensor, typing.Tuple, typing.List)): 49 raise TypeError(f"only int, tuple, list and tensor are allowed for shape, but got {type(shape)}") 50 if isinstance(shape, Tensor): 51 shape = shape.asnumpy().tolist() 52 # this may return an int, don't change the next if to elif 53 if isinstance(shape, int): 54 shape = (shape,) 55 elif isinstance(shape, (list, typing.List)): 56 shape = tuple(shape) 57 _check(shape) 58 return shape 59 60 61@constexpr 62def _check_dtype(dtype): 63 """check the input dtype and make conversions""" 64 # convert the string dtype to mstype.dtype 65 if isinstance(dtype, str): 66 dtype = dtype.lower() 67 dtype = dtype_map.get(dtype, "") 68 elif isinstance(dtype, type): 69 if dtype is int: 70 dtype = mstype.int32 71 elif dtype is float: 72 dtype = mstype.float32 73 else: 74 dtype = mstype.pytype_to_dtype(dtype) 75 if dtype not in dtype_tuple: 76 raise TypeError(f"only {all_types} are allowed for dtype, but got {type(dtype)}") 77 return dtype 78 79 80@_primexpr 81def _is_shape_empty(shp): 82 """Check whether shape contains zero""" 83 if F.is_sequence_shape_unknown(shp): 84 return False 85 if isinstance(shp, int): 86 return shp == 0 87 if isinstance(shp, (tuple, list)): 88 return 0 in shp 89 return F.shape_mul(shp) == 0 90 91 92@_primexpr 93def _check_start_normalize(start, ndim): 94 """check and normalize start argument for rollaxis.""" 95 def _check(): 96 if start < -ndim or start > ndim: 97 raise ValueError(f"For rollaxis, start {start} is out of bounds." 98 f"Ranging from {-ndim} to {ndim} is allowed.") 99 _check() 100 if start < 0: 101 start = start + ndim 102 return start 103 104 105@_primexpr 106def _check_axes_range(axes, ndim): 107 """ 108 Check axes type and normalize the negative axes. 109 110 Args: 111 axes: Axes of the tensor. 112 ndim (int): The number of dimensions of the tensor. 113 114 Return: 115 Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple. 116 117 Raises: 118 TypeError: If the axes are not integer, tuple(int) or list(int). 119 ValueError: If duplicate axes exists or some axis is out of bounds. 120 """ 121 _check_axis_type(axes, True, True, True) 122 if isinstance(axes, (list, tuple)): 123 _check_element_int(axes) 124 axes = _canonicalize_axis(axes, ndim) 125 return axes 126 127 128@constexpr 129def _get_device(): 130 """Get the current device (`GPU`, `CPU`, `Ascend`)""" 131 return context.get_context('device_target') 132 133 134@_primexpr 135def _infer_out_shape(*shapes): 136 """ 137 Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. 138 """ 139 def _check(): 140 if any(item not in (1, max_size) for item in items): 141 raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') 142 143 shape_out = list() 144 max_len = max([len(it) for it in shapes]) 145 for i in range(max_len): 146 items = [ 147 it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes] 148 max_size = 0 if 0 in items else max(items) 149 _check() 150 shape_out.append(max_size) 151 return tuple(shape_out) 152 153 154@_primexpr 155def _can_broadcast(*shapes): 156 """ 157 Returns Ture if shapes can broadcast, False if they cannot. 158 """ 159 max_len = max([len(it) for it in shapes]) 160 for i in range(max_len): 161 items = [ 162 it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes] 163 max_size = 0 if 0 in items else max(items) 164 if any(item not in (1, max_size) for item in items): 165 return False 166 return True 167 168 169@_primexpr 170def _check_axis_in_range(axis, ndim): 171 """Checks axes are with the bounds of ndim""" 172 def _check(): 173 if not isinstance(axis, int): 174 raise TypeError(f'axes should be integers, not {type(axis)}') 175 if not -ndim <= axis < ndim: 176 raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') 177 _check() 178 return axis - axis // ndim * ndim 179 180 181@_primexpr 182def _check_axis_valid(axes, ndim): 183 """ 184 Checks axes are valid given ndim, and returns axes that can be passed 185 to the built-in operator (non-negative, int or tuple) 186 """ 187 def _check(axes): 188 if any(axes.count(el) > 1 for el in axes): 189 raise ValueError('duplicate value in "axis"') 190 191 if axes is None: 192 axes = F.make_range(ndim) 193 return axes 194 if isinstance(axes, (tuple, list)): 195 axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes)) 196 _check(axes) 197 return axes 198 return (_check_axis_in_range(axes, ndim),) 199 200 201@_primexpr 202def _check_shape_aligned(shape1, shape2): 203 """Checks shape1 and shape2 are valid shapes to perform inner product""" 204 if shape1[-1] != shape2[-1]: 205 raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') 206 207 208@_primexpr 209def _tile_size(shape, out_shape, ndim): 210 """Returns tile_size such that shape*tile_size = out_shape""" 211 size = [1]*ndim 212 for idx, (i, j) in enumerate(zip(shape, out_shape)): 213 if i != j: 214 size[idx] = j 215 return tuple(size) 216 217 218@constexpr 219def _raise_type_error(info, param=None): 220 """ 221 Raise TypeError in both graph/pynative mode 222 223 Args: 224 info(str): info string to display 225 param(python obj): any object that can be recognized by graph mode. If is 226 not None, then param's type information will be extracted and displayed. 227 Default is None. 228 """ 229 if param is None: 230 raise TypeError(info) 231 raise TypeError(info + f"{type(param)}") 232 233 234@_primexpr 235def _raise_value_error(info, param=None): 236 """ 237 Raise TypeError in both graph/pynative mode 238 239 Args: 240 info(str): info string to display 241 param(python obj): any object that can be recognized by graph mode. If is 242 not None, then param's value information will be extracted and displayed. 243 Default is None. 244 """ 245 if param is None: 246 raise ValueError(info) 247 raise ValueError(info + f"{param}") 248 249 250@constexpr 251def _raise_runtime_error(info, param=None): 252 """ 253 Raise RuntimeError in both graph/pynative mode 254 255 Args: 256 info(str): info string to display 257 param(python obj): any object that can be recognized by graph mode. If is 258 not None, then param's value information will be extracted and displayed. 259 Default is None. 260 """ 261 if param is None: 262 raise RuntimeError(info) 263 raise RuntimeError(info + f"{param}") 264 265 266@constexpr 267def _raise_unimplemented_error(info, param=None): 268 """ 269 Raise NotImplementedError in both graph/pynative mode 270 271 Args: 272 info(str): info string to display 273 param(python obj): any object that can be recognized by graph mode. If is 274 not None, then param's value information will be extracted and displayed. 275 Default is None. 276 """ 277 if param is None: 278 raise NotImplementedError(info) 279 raise NotImplementedError(info + f"{param}") 280 281 282@_primexpr 283def _empty(dtype, shape): 284 """Returns an uninitialized array with dtype and shape.""" 285 return Tensor_(dtype, shape) 286 287 288@constexpr 289def _promote(dtype1, dtype2): 290 if dtype1 == dtype2: 291 return dtype1 292 if (dtype1, dtype2) in promotion_rule: 293 return promotion_rule.get((dtype1, dtype2)) 294 res = promotion_rule.get((dtype2, dtype1)) 295 if res is None: 296 raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.") 297 return res 298 299 300@constexpr 301def _promote_for_trigonometric(dtype): 302 res_dtype = rule_for_trigonometric.get(dtype) 303 if res_dtype is None: 304 raise TypeError(f"input dtype: {dtype} is not a mindspore datatype.") 305 return res_dtype 306 307 308@_primexpr 309def _max(*args): 310 """Returns the maximum value.""" 311 return max(*args) 312 313 314@_primexpr 315def _min(*args): 316 """"Returns the minimum value.""" 317 return min(*args) 318 319 320@constexpr 321def _abs(arg): 322 """Returns the absolute value.""" 323 return abs(arg) 324 325 326@constexpr 327def _check_same_type(dtype1, dtype2): 328 return dtype1 == dtype2 329 330 331@constexpr 332def _check_is_float(dtype): 333 """Returns whether dtype is float16 or float32.""" 334 return dtype in (mstype.float16, mstype.float32) 335 336 337@constexpr 338def _check_is_int(dtype): 339 return isinstance(dtype, typing.Int) 340 341 342@_primexpr 343def _canonicalize_axis(axis, ndim): 344 """ 345 Check axes are within the number of dimensions of tensor x and normalize the negative axes. 346 Args: 347 axis (Union[int, tuple(int), list(int)]): Axes of the tensor. 348 ndim (int): The number of dimensions of the tensor. 349 Return: 350 Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple. 351 """ 352 def _check(): 353 if not all(axis.count(el) <= 1 for el in axis): 354 raise ValueError(f"duplicate axes in {axis}.") 355 356 if isinstance(axis, int): 357 axis = [axis] 358 for ax in axis: 359 _check_axis_in_range(ax, ndim) 360 361 def canonicalizer(ax): 362 return ax + ndim if ax < 0 else ax 363 364 def _sort_axis(ax): 365 def merge(left, right): 366 result = () 367 i = j = 0 368 while i < len(left) and j < len(right): 369 if left[i] <= right[j]: 370 result += (left[i],) 371 i += 1 372 else: 373 result += (right[j],) 374 j += 1 375 376 return result + left[i:] + right[j:] 377 378 if len(ax) <= 1: 379 return ax 380 381 middle = len(ax) // 2 382 left = _sort_axis(ax[:middle]) 383 right = _sort_axis(ax[middle:]) 384 385 return merge(left, right) 386 387 axis = tuple([canonicalizer(axis) for axis in axis]) 388 _check() 389 return tuple(_sort_axis(axis)) if len(axis) > 1 else axis[0] 390 391 392@constexpr 393def _broadcast_tuples(tup1, tup2): 394 """ 395 Broadcast two 1D tuples to the same length, if inputs are ints, convert to 396 tuples first. 397 """ 398 def _check(): 399 if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)): 400 raise TypeError("input shift and axis must be tuple or list or int.") 401 if len(tup1) == len(tup2) or len(tup1) == 1 or len(tup2) == 1: 402 return 403 raise ValueError("shape mismatch: objects cannot be broadcast to a single shape") 404 405 tup1 = (tup1,) if isinstance(tup1, int) else tup1 406 tup2 = (tup2,) if isinstance(tup2, int) else tup2 407 _check() 408 if len(tup1) == len(tup2): 409 return tup1, tup2 410 if len(tup1) == 1: 411 tup1 *= len(tup2) 412 else: # case: len(tup2) == 1 413 tup2 *= len(tup1) 414 return tup1, tup2 415 416 417@_primexpr 418def _expanded_shape(ndim, axis_size, axis): 419 """ 420 Returns a shape with size = 1 for all dimensions 421 except at axis. 422 """ 423 return tuple([axis_size if i == axis else 1 for i in range(ndim)]) 424 425 426@_primexpr 427def _add_unit_axes(shape, ndim, append=False): 428 """ 429 Prepends shape with 1s so that it has the number of dimensions ndim. 430 If append is set to True, returns shape appended with 1s instead. 431 """ 432 if isinstance(shape, int): 433 shape = (shape,) 434 ndim_diff = ndim - len(shape) 435 if ndim_diff > 0: 436 if append: 437 shape = list(shape) + [1] * ndim_diff 438 else: 439 shape = [1] * ndim_diff + list(shape) 440 return tuple(shape) 441 442 443@constexpr 444def _check_element_int(lst): 445 """ 446 Check whether each element in `lst` is an integer. 447 """ 448 for item in lst: 449 if not isinstance(item, int): 450 raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.") 451 return True 452 453 454@constexpr 455def _type_convert(force, obj): 456 """ 457 Convert type of `obj` to `force`. 458 """ 459 return force(obj) 460 461 462@_primexpr 463def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False): 464 """ 465 Generates a new list/tuple by list comprehension. 466 467 Args: 468 obj (Union[int, list, tuple]): 469 If integer, it will be the length of the returned tuple/list. 470 item: The value to be filled. Default: ``None``. 471 If None, the values in the new list/tuple are the same as obj 472 or range(obj) when obj is integer. 473 return_tuple(bool): If true, returns tuple, else returns list. 474 475 Returns: 476 List or tuple. 477 """ 478 res = [] 479 lst = obj 480 if isinstance(obj, int): 481 lst = range(obj) 482 if make_none: 483 res = [None for _ in lst] 484 elif item is None: 485 res = list(lst) 486 else: 487 res = [item for _ in lst] 488 if return_tuple: 489 return tuple(res) 490 return res 491 492 493@_primexpr 494def _tuple_setitem(tup, idx, value): 495 """ 496 Returns a tuple with specified `idx` set to `value`. 497 """ 498 tup = list(tup) 499 tup[idx] = value 500 return tuple(tup) 501 502 503@constexpr 504def _iota(dtype, num, increasing=True): 505 """Creates a 1-D tensor with value: [0,1,...num-1] and dtype.""" 506 # Change to P.Linspace when the kernel is implemented on CPU. 507 if num <= 0: 508 raise ValueError("zero shape Tensor is not currently supported.") 509 if increasing: 510 return Tensor(list(range(int(num))), dtype) 511 return Tensor(list(range(int(num)-1, -1, -1)), dtype) 512 513 514@constexpr 515def _ceil(number): 516 """Ceils the number in graph mode.""" 517 return math.ceil(number) 518 519 520@_primexpr 521def _seq_prod(seq1, seq2): 522 """Returns the element-wise product of seq1 and seq2.""" 523 return tuple(map(lambda x, y: x*y, seq1, seq2)) 524 525 526@constexpr 527def _make_tensor(val, dtype): 528 """Returns the tensor with value `val` and dtype `dtype`.""" 529 return Tensor(val, dtype) 530 531 532@_primexpr 533def _tuple_slice(tup, start, end): 534 """get sliced tuple from start and end.""" 535 return tup[start:end] 536 537 538@constexpr 539def _isscalar(x): 540 """Returns True if x is a scalar type""" 541 return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float, 542 typing.Bool, typing.String)) 543 544 545@constexpr 546def _cumprod(x): 547 return tuple(accumulate(x, operator.mul)) 548 549 550@constexpr 551def _in(x, y): 552 return x in y 553 554 555@constexpr 556def _callable_const(x): 557 """Returns true if x is a function in graph mode.""" 558 return isinstance(x, typing.Function) 559 560 561@constexpr 562def _check_is_inf(x, negative=False): 563 if not negative: 564 return x == float('inf') 565 return x == float('-inf') 566