# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""

import operator
from functools import partial
from itertools import compress

import numpy as np
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops
from mindspore.ops.primitive import constexpr, _primexpr
from mindspore import log as logger
from mindspore import context
from mindspore._c_expression import Tensor as Tensor_

# Codes describing what kinds of elements a tuple of types contains
# (returned by check_value_elements / tuple_index_type_cnt).
ALL_TENSOR, NO_TENSOR, CONTAIN_TENSOR, ALL_SCALAR = 0, 1, 2, 3
ALL_BASIC = 7
MIXED = 8

# Categories of an index's dtype (returned by get_index_tensor_dtype and
# check_number_index_type).
INT_, BOOL_, UNSUPPORTED_DTYPE = 0, 1, 2

# Operation names interpolated into error messages.
TENSOR_SETITEM = "tensor setitem"
TENSOR_GETITEM = "tensor getitem"

# `op_type` codes for tensor setitem (compared against in
# convert_scalar_to_tensor / generate_updates_shape).
SET_ITEM_BY_ONE_TENSOR, SET_ITEM_BY_TUPLE_OF_TENSOR, SET_ITEM_BY_NON_TENSOR = 0, 1, 2

# Real dtypes listed from lowest to highest promotion priority; each dtype's
# priority is simply its position in this sequence.
type_priority_map = {
    dtype: priority
    for priority, dtype in enumerate((
        mstype.bool_,
        mstype.uint8,
        mstype.int8,
        mstype.uint16,
        mstype.int16,
        mstype.uint32,
        mstype.int32,
        mstype.uint64,
        mstype.int64,
        mstype.float16,
        mstype.float32,
        mstype.float64,
    ))
}

# Promotion priorities used when a complex dtype takes part in a binary op.
# The values are not contiguous (there is no priority 3); only their order matters.
complex_priority_map = dict([
    (mstype.float32, 0),
    (mstype.float64, 1),
    (mstype.complex64, 2),
    (mstype.complex128, 4),
])

complex_types = [mstype.complex64, mstype.complex128]


@constexpr
def raise_value_error(msg):
    """Constexpr for raise_value_error."""
    raise ValueError(msg)


@constexpr
def raise_index_error(msg):
    """Constexpr for raise_index_error."""
    raise IndexError(msg)


@constexpr
def raise_type_error(msg):
    """Constexpr for raise_type_error."""
    raise TypeError(msg)


@constexpr
def raise_unimplemented_error(msg):
    """Constexpr for raise_unimplemented_error."""
    raise NotImplementedError(msg)


@constexpr
def log_warning(msg):
    """Adds warning to logger."""
    logger.warning(msg)


@constexpr
def check_equal(param1, param2, msg="{},{}"):
    """
    Checks whether the two parameters are equal or not.

    Args:
        param1: The first value.
        param2: The second value.
        msg (str): Error message template, formatted with ``param1`` and ``param2``.

    Returns:
        ``param1``, when the two values compare equal.

    Raises:
        ValueError: If ``param1 != param2``.
    """
    if param1 != param2:
        raise ValueError(msg.format(param1, param2))
    return param1


@constexpr
def make_empty_slice():
    """Creates an empty slice, i.e. ``slice(None, None, None)``."""
    return slice(None, None, None)


@_primexpr
def _deep_list(array_like, dim_size=None):
    """
    Convert nested tuple/list mixtures to pure nested lists.

    If ``dim_size`` is given, every int leaf is passed through ``check_range``,
    i.e. bounds-checked against ``dim_size`` and wrapped to a non-negative index.
    """
    if dim_size is not None:
        array_like = check_range(array_like, dim_size)
    if isinstance(array_like, (list, tuple)):
        return list(map(lambda x: _deep_list(x, dim_size), array_like))
    return array_like


@constexpr
def deep_tuple(array_like):
    """Convert nested tuple/list mixtures to pure nested tuples."""
    if isinstance(array_like, (list, tuple)):
        return tuple(map(deep_tuple, array_like))
    return array_like


def _deep_tensor_to_nparray(array_like):
    """
    convert a nested list of tensor to nested list of np_array.

    Args:
        array_like(list(tensor)): In any format of nested lists that may contain
            tensors.

    Returns:
        array_like(list(np_array)): Formatted array that can be directly processed
            by numpy.array(), with all tensor elements converted to numpy_array.

    Note:
        Lists are modified in place (and also returned); tuples are not traversed.
    """
    # Recursively check whether each element is a tensor or not, if is tensor,
    # convert it to a numpy array in place
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    if isinstance(array_like, list):
        for idx, value in enumerate(array_like):
            array_like[idx] = _deep_tensor_to_nparray(value)

    return array_like


@_primexpr
def check_range(x, dim_size):
    """
    Bounds-check an int index against ``dim_size`` and wrap it to a non-negative value.

    Inputs that are not ints (bools included), and any input when ``dim_size`` is
    None, are returned unchanged.

    Raises:
        IndexError: If ``x`` is an int outside ``[-dim_size, dim_size)``.
    """
    if dim_size is None:
        return x
    if isinstance(x, int) and not isinstance(x, bool):
        if x >= dim_size or x < -dim_size:
            raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
        x = x % dim_size
    return x


@_primexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
    """
    Converts the input to tensor.

    This function converts tensors from an array-like object.

    Args:
        a (Union[int, float, bool, list, tuple]): Input data, in any form that can
            be converted to a `Tensor`.
        dtype (:class:`mindspore.dtype`): Designated tensor dtype.
        data_shape (tuple, optional): When truthy, `a` is ignored and a zero-filled
            tensor of this shape is returned.
        dim_size (int, optional): When given, int elements of `a` are bounds-checked
            against it and wrapped to non-negative indices.

    Returns:
        Tensor, generated tensor with the specified dtype.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If input `a` has different sizes at different dimensions.
        IndexError: If `dim_size` is given and an int element of `a` is out of range.
    """
    if data_shape:
        return Tensor(np.zeros(data_shape), dtype)

    if not isinstance(a, (list, tuple, int, float, bool)):
        raise TypeError(f"Input data must be `int`, `float`, `bool`, `list` or `tuple`, but got {a}")

    if dim_size is not None:
        a = check_range(a, dim_size)

    # bool is a subclass of int, so bool scalars take this path as well.
    if isinstance(a, int):
        return P.ScalarToTensor()(a, dtype)

    if isinstance(a, (list, tuple)):
        if not a:
            return Tensor_(a, dtype)
        # Convert all tuple/nested tuples to lists
        a = _deep_list(a, dim_size)
        # Convert all tensor sub-elements to numpy arrays
        a = _deep_tensor_to_nparray(a)
        a = np.asarray(a)
        # Ragged nested input can produce an object array; reject it. (Compare
        # with `==` rather than relying on dtype object identity.)
        if a.dtype == np.dtype('object'):
            raise ValueError('Input array must have the same size across all dimensions.')

    return Tensor(a, dtype)


setattr(tensor_operator_registry, 'make_tensor', make_tensor)


def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
    """Judges whether the data dim is valid, i.e. within [min_data_dim, max_data_dim]."""
    if data_dim < min_data_dim or data_dim > max_data_dim:
        raise ValueError(f"The input data's dim must in the range of [{min_data_dim}, "
                         f"{max_data_dim}], but got '{data_dim}'.")


def get_source_shape(data_shape, value_shape):
    """Returns the shape of value that will be used to broadcast against data."""
    # A value with more dims than the data falls back to the data's shape.
    if len(value_shape) > len(data_shape):
        return data_shape
    return value_shape


@constexpr
def check_tensor_setitem_index(index, element_type=None):
    """
    Checks tuple index type of tensor assignment.

    Accepts a slice, a non-empty tuple of slices/ellipses/ints, or a tensor type
    whose ``element_type`` is ``mstype.bool_``; returns True in those cases.

    Raises:
        IndexError: If the index is None, an empty tuple, or of an unsupported type.
        TypeError: If a tensor index's element type is not bool.
    """
    if index is None:
        raise IndexError("Tensor's index cannot be None.")
    if isinstance(index, slice):
        return True
    if isinstance(index, tuple):
        if not index:
            raise IndexError("Tensor's index cannot be empty.")
        for item in index:
            if not isinstance(item, (slice, type(...), int)):
                raise IndexError(
                    "Index of type '{}' is not supported yet.".format(type(item)))
        return True
    if isinstance(index, mstype.TensorType):
        if element_type is None or element_type != mstype.bool_:
            raise TypeError(
                "The index of tensor should be a bool type tensor. "
                "{} type is not supported yet.".format(element_type))
        return True

    raise IndexError(
        "Index of type '{}' is not supported yet.".format(type(index)))


@constexpr
def is_same_type(inst, type_):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        inst (mindspore.dtype): Inspected type.
        type_ (mindspore.dtype): Target type.

    Outputs:
        bool, the check result.
    """
    return inst == type_


@constexpr
def check_valid_dim(dim, name):
    """Checks whether the dim is valid (1 or 2); raises ValueError otherwise."""
    if dim not in (1, 2):
        raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")


@constexpr
def judge_index_type(index_type, target_type):
    """
    Judges whether the index type is valid.

    ``target_type`` may be a single dtype or a list/tuple of dtypes.
    """
    if index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type):
        return True
    return False


@constexpr
def judge_indexes_types(dtypes, target_type):
    """
    Check a tuple of tensor data type.

    Returns True only if every dtype in ``dtypes`` equals ``target_type`` (or is a
    member of it, when ``target_type`` is a list/tuple).
    """
    for dtype in dtypes:
        if isinstance(target_type, (list, tuple)):
            if dtype not in target_type:
                return False
        else:
            if dtype != target_type:
                return False
    return True


@constexpr
def check_type_isinstance(dtype, target_type):
    """
    Checks whether the dtype is instance of target type.

    For a list/tuple ``dtype``, every element must be an instance of ``target_type``.
    """
    if isinstance(dtype, (list, tuple)):
        return all(isinstance(ele, target_type) for ele in dtype)
    return isinstance(dtype, target_type)


@constexpr
def check_type_invalid(dtype, target_type):
    """Checks whether the dtype is valid."""
    # NOTE(review): when `target_type` is a single dtype (not a list/tuple) this
    # returns False even if `dtype != target_type`, which is not the complement of
    # judge_index_type. Left unchanged because callers may depend on it; confirm.
    return dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type)


@constexpr
def check_type_valid(dtype, target_type, op_name):
    """
    Checks whether the dtype is valid.

    Raises:
        IndexError: If invalid and ``op_name`` is TENSOR_GETITEM or TENSOR_SETITEM.
        TypeError: If invalid for any other ``op_name``.
    """
    # NOTE(review): as in check_type_invalid, a single (non-sequence) target_type
    # never triggers an error here.
    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
        if op_name in (TENSOR_GETITEM, TENSOR_SETITEM):
            raise IndexError(
                f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")
        raise TypeError(
            f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")


@constexpr
def check_types_valid(dtypes, target_type, op_name):
    """Check a tuple of tensor data type by applying check_type_valid to each entry."""
    for dtype in dtypes:
        check_type_valid(dtype, target_type, op_name)


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """
    Separate the position information of tensor and slice and ellipsis from the mixed tensors index.

    Returns:
        A 7-tuple of position tuples: (slice, ellipsis, none, int, bool, tensor,
        sequence) positions within ``indexes_types``.

    Raises:
        TypeError: If an entry is none of the supported index types.
        IndexError: If more than one ellipsis is present.
    """
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = (), (), (), (), (), (), ()
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.Slice):
            slice_positions += (i,)
        elif isinstance(index_type, mstype.Ellipsis_):
            ellipsis_positions += (i,)
        elif isinstance(index_type, mstype.NoneType):
            none_positions += (i,)
        elif isinstance(index_type, mstype.Int):
            int_positions += (i,)
        elif isinstance(index_type, mstype.Bool):
            bool_positions += (i,)
        elif isinstance(index_type, mstype.TensorType):
            tensor_positions += (i,)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions += (i,)
        else:
            raise TypeError(f"For '{op_name}', the types only support 'Slice', 'Ellipsis', 'None', 'Tensor', 'int', "
                            f"'List', 'Tuple', 'bool', but got {index_type}.")
    if len(ellipsis_positions) > 1:
        raise IndexError(
            f"For '{op_name}', an index can only have a single ellipsis('...')")

    return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
        tensor_positions, sequence_positions


def ellipsis2slice(input_, shape):
    """
    Converts ellipsis to slice.

    The (single) ellipsis in ``input_`` is replaced by as many full slices as are
    needed so the result has ``len(shape)`` entries; a bare ellipsis is treated
    as a one-element tuple.

    Raises:
        IndexError: If more than one ellipsis is present.
    """
    input_slice = input_
    result = []
    if isinstance(input_, type(...)):
        input_slice = (input_,)
    ell_count = 0
    for _, element in enumerate(input_slice):
        if not isinstance(element, type(...)):
            result.append(element)
            continue
        ell_count += 1
        if ell_count > 1:
            raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
                             "but it is currently {}".format(input_slice))
        for _ in range(len(shape) - len(input_slice) + 1):
            result.append(slice(None, None, None))
    return tuple(result)


@constexpr
def slice2indices(input_slice, shape):
    """
    Converts slice to indices.

    Inputs:
        input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice.
        shape (tuple): The shape of a tensor is an integer element tuple.

    Outputs:
        Tensor of stacked index grids (indices along the last axis), or False when
        the slice selects nothing along ``shape[0]``.
    """
    start, stop, step = normalize_slice(input_slice, shape[0])
    if check_slice_empty(start, stop, step):
        return False
    ndim = len(shape)
    mesh = list()
    range_op = P.Range()
    cast_op = P.Cast()
    # The first axis follows the slice; every other axis is taken in full.
    grids = [
        range_op(cast_op(start, mstype.int64), cast_op(stop, mstype.int64),
                 cast_op(step, mstype.int64))
    ]
    grids += [
        range_op(Tensor(0, mstype.int64), cast_op(dim_size, mstype.int64),
                 Tensor(1, mstype.int64)) for dim_size in shape[1:]
    ]
    # Reshape grid j to size 1 on every axis except j, then broadcast (a meshgrid).
    for j, grid in enumerate(grids):
        mesh.append(P.Reshape()(grid, tuple(
            [grid.size if j == t else 1 for t in range(ndim)])))
    shapes = map(P.Shape(), mesh)
    out_shape = infer_out_shape(*shapes)
    mesh_arrays = list()
    for arr in mesh:
        mesh_arrays.append(P.BroadcastTo(out_shape)(arr))
    return P.Stack(-1)(mesh_arrays)


@constexpr
def check_indices(indices_size, index):
    """Checks indices whether is empty; raises IndexError if ``indices_size < 1``."""
    if indices_size < 1:
        raise IndexError(
            "The tensor's index is unreasonable. index:{}".format(index))
    return indices_size


@_primexpr
def check_indices_value_size(indices_size, value_size):
    """
    Checks if the sizes are already matched.

    A value of size 1 always matches; any larger value must equal ``indices_size``.

    Raises:
        ValueError: If ``value_size < 1`` or the sizes do not match.
    """
    if value_size < 1:
        raise ValueError("The value assigned to tensor cannot be empty.")
    if value_size > 1:
        if value_size != indices_size:
            raise ValueError(
                "The value given to tensor does not match the index size,"
                " value size:{}, indics size:{}".format(value_size, indices_size))
    return value_size


@constexpr
def tuple_index_type_cnt(types, op_name):
    """
    count the tensor type of types which contains the tuple elements' type.

    Returns ALL_TENSOR, ALL_BASIC (only ints/ellipses/slices) or MIXED.
    ``op_name`` is currently unused.
    """
    if all(isinstance(ele, mstype.TensorType) for ele in types):
        return ALL_TENSOR
    if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types):
        return ALL_BASIC
    return MIXED


@constexpr
def check_value_elements(types):
    """
    Judges the type of all elements of the tuple.

    Returns MIXED if any element is a list/tuple or if there are no tensors but the
    element types differ; otherwise NO_TENSOR, ALL_TENSOR or CONTAIN_TENSOR.
    """
    tensor_number = 0
    last_type = None
    mix_but_no_tensor = False
    for ele in types:
        if isinstance(ele, mstype.TensorType):
            tensor_number += 1
        elif isinstance(ele, (list, tuple)):
            return MIXED

        if last_type is None:
            last_type = type(ele)
        elif not isinstance(ele, last_type):
            mix_but_no_tensor = True

    if tensor_number == 0:
        if mix_but_no_tensor:
            return MIXED
        return NO_TENSOR
    if tensor_number == len(types):
        return ALL_TENSOR
    return CONTAIN_TENSOR


@constexpr
def get_index_tensor_dtype(dtype):
    """
    Check a tuple of tensor data type.

    Returns INT_ for integer dtypes and BOOL_ for ``mstype.bool_``.

    Raises:
        IndexError: For any other dtype.
    """
    if dtype in mstype.int_type:
        return INT_
    if dtype == mstype.bool_:
        return BOOL_
    raise IndexError(
        f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
    """Check tensors data type same; returns True or raises TypeError."""
    if value_dtype == data_dtype:
        return True
    raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
                    f"is not consistent with assigned tensor data type {data_dtype}.")


@constexpr
def get_broadcast_shape(x_shape, y_shape, prim_name):
    """
    Get broadcast shape from input shapes.

    Returns ``x_shape`` itself when the shapes are equal, otherwise a list.

    Raises:
        ValueError: On a None shape, a None dim, or dims that cannot broadcast.
    """
    if x_shape is None or y_shape is None:
        raise ValueError("get_broadcast_shape has dynamic rank input")
    if None in x_shape or None in y_shape:
        raise ValueError("get_broadcast_shape has dynamic shape input")
    if x_shape == y_shape:
        return x_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []

    # Walk the trailing `length` dims right-aligned, as in numpy broadcasting.
    for i in range(-length, 0):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError(f"For '{prim_name}', x.shape and y.shape need to "
                             f"broadcast. The value of x.shape[{i}] or y.shape[{i}]"
                             f" must be 1 or -1 when they are not the same, "
                             f"but got x.shape = {x_shape} "
                             f"and y.shape = {y_shape}.")

    # Leading dims come from whichever shape is longer.
    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
    return broadcast_shape


@constexpr
def generate_broadcast_shape(shapes, op_name):
    """Generate broadcast shape for a tuple of shape; returns () for no shapes."""
    if not shapes:
        return ()
    broadcast_shape = shapes[0]
    for shape in shapes:
        broadcast_shape = get_broadcast_shape(tuple(broadcast_shape), shape, op_name)
    return tuple(broadcast_shape)


@_primexpr
def compute_multiples(origin_shape, broadcast_shape):
    """Compute multiples between origin shape with broadcast shape."""
    len_gap = len(broadcast_shape) - len(origin_shape)
    return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], tuple(origin_shape)))


@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
    """Convert a scalar to a tensor filled with ``value`` in the setitem updates shape."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = indices_shape + data_shape[1:]
    else:
        updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
    return P.FillV2()(updates_shape, P.Cast()(value, data_dtype))


def generate_updates_shape(data_shape, index_shape, op_type):
    """Generate updates shape for 'tensor setitem'."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = index_shape + data_shape[1:]
    else:
        updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
    return updates_shape


@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
    """
    Transforms slice to element list.

    Raises:
        IndexError: If the slice selects no elements.
    """
    slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
    start, stop, step = normalize_slice(slice_obj, dim_len)
    slice_ele_list = list(range(start, stop, step))
    if not slice_ele_list:
        raise IndexError(f"An empty slice is not supported, got {slice_obj}")
    return slice_ele_list


@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                    slice_shapes, op_name, fancy_position=None):
    """
    Generate index info which contain broadcast shape, final shape,
    indexes shapes info, ellipsis size from a tuple of mixed tensors.

    Returns:
        (broadcast_shape, index_tensor_new_shape, final_shape, fancy_position).
        When ``fancy_position`` is None it is the first tensor position if the
        tensor positions are consecutive, else 0.
    """
    tensor_positions = tuple(sorted(tensor_positions))
    if fancy_position is None:
        tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
        fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
    broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)

    final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
    index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
        broadcast_shape + (1,) * len(slice_shapes[fancy_position:])

    return broadcast_shape, index_tensor_new_shape, final_shape, fancy_position


def _judge_order_continuous(order_sequence):
    """Return True if the sequence is non-empty and each item is the previous one plus 1."""
    if not order_sequence:
        return False
    for idx1, idx2 in zip(order_sequence[:-1], order_sequence[1:]):
        if idx1 + 1 != idx2:
            return False
    return True


@constexpr
def scalar_in_sequence(x, y):
    """Determine whether the scalar in the sequence."""
    return x in y


@constexpr
def get_np_eps(input_dtype):
    """Get numpy eps (machine epsilon) of ``input_dtype`` as a Python float."""
    nptype = mstype.dtype_to_nptype(input_dtype)
    eps = np.finfo(nptype).eps
    return float(eps)


@constexpr
def check_number_index_type(number):
    """Check if it is int or bool number; returns BOOL_ or INT_, else raises IndexError."""
    # bool must be tested first since bool is a subclass of int.
    if isinstance(number, bool):
        return BOOL_
    if isinstance(number, int):
        return INT_
    raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None and bool, got {0} type is {1} "
                     .format(number, type(number)))


@constexpr
def scalar_to_tensor(x):
    """Convert a scalar to a tensor"""
    return Tensor(x)


@constexpr
def unpack(x):
    """Recursively unwrap single-element tuples/lists, e.g. ((a,),) -> a."""
    if isinstance(x, (tuple, list)) and len(x) == 1:
        return unpack(x[0])
    return x


@_primexpr
def normalize_start(start, dim_size):
    """
    Normalize `start` according to the size of the dimension (`dim_size`).

    None becomes 0. If `dim_size` is None, any other input is returned directly.
    Otherwise the result is clamped/wrapped into ``[0, dim_size]``.
    """
    if start is None:
        return 0
    if dim_size is None:
        return start
    if start < 0:
        return 0 if start < -dim_size else start % dim_size
    return start if start < dim_size else dim_size


@_primexpr
def normalize_stop(stop, dim_size):
    """
    Normalize `stop` according to the size of the dimension (`dim_size`).

    None becomes `dim_size`. If `dim_size` is None, any other input is returned
    directly. Otherwise the result is clamped/wrapped into ``[0, dim_size]``.

    Raises:
        IndexError: If both `stop` and `dim_size` are None.
    """
    if stop is None and dim_size is None:
        raise IndexError("Not Support stop is None when dim is dynamic")
    if stop is None:
        return dim_size
    if dim_size is None:
        return stop
    if stop < 0:
        return 0 if stop < -dim_size else stop % dim_size
    return stop if stop < dim_size else dim_size


@constexpr
def get_step_from_slice(input_slice):
    """get step in a slice (1 when the step is None)."""
    step = input_slice.step
    if step is None:
        step = 1
    return step


@_primexpr
def normalize_slice(input_slice, dim_size):
    """
    Normalizes start, stop, step in a slice.

    For a negative step the roles of normalize_start/normalize_stop are swapped
    for the slice's start and stop.
    """
    step = input_slice.step
    if step is None:
        step = 1
    if step >= 0:
        start = normalize_start(input_slice.start, dim_size)
        stop = normalize_stop(input_slice.stop, dim_size)
    else:
        start = normalize_stop(input_slice.start, dim_size)
        stop = normalize_start(input_slice.stop, dim_size)
    return start, stop, step


@constexpr
def tuple_slice(tup, start, end):
    """get sliced tuple from start and end."""
    return tup[start:end]


def expanded_shape(shape, expand_size):
    """Prepend `expand_size` ones to `shape`."""
    return (1,)*expand_size + shape


@constexpr
def sequence_mul_int(seq, number):
    """
    Make a new list with native python syntax.

    Args:
        seq (Union[list, tuple]): Input sequence.
        number (int): Input number.

    Returns:
        New sequence, has the same type as `seq`.

    Raises:
        TypeError: If `number` is not an int.
    """
    if not isinstance(number, int):
        raise TypeError(f"can't multiply sequence by non-int of type {type(number)}")
    return seq * number


@constexpr
def check_in_sequence(x, y):
    """Determine whether the input `x` is in the sequence `y`."""
    return x in y


@constexpr
def is_slice(x):
    """Return whether `x` is a Python slice."""
    return isinstance(x, slice)


@_primexpr
def filter_expanded_dims(shape, not_expanded_dim):
    """
    filter_expanded_dims

    Keeps the entries of `shape` whose flag, taken from the trailing part of
    `not_expanded_dim` aligned with `shape`, is truthy.

    Raises:
        ValueError: If `not_expanded_dim` is shorter than `shape`.
    """
    def _check(diff, shape):
        if diff < 0:
            raise ValueError(f'unable to broadcast {shape}')

    diff = len(not_expanded_dim) - len(shape)
    _check(diff, shape)
    res = list()
    for i, flag in zip(shape, not_expanded_dim[diff:]):
        if flag:
            res.append(i)
    return tuple(res)


@constexpr
def sequence_to_index(sequence, dim_size):
    """
    Transforms sequence to tensor index.

    Returns False for an empty sequence or an all-False boolean mask. A boolean
    mask is converted to the positions of its True entries (its length must equal
    `dim_size`, unless `dim_size` is None, in which case the mask itself is made a
    Tensor). Other sequences go through make_tensor as int64.

    Raises:
        IndexError: If a boolean mask's length differs from `dim_size`.
    """
    if not sequence:
        return False
    if all(isinstance(i, bool) for i in sequence):
        if dim_size is None:
            return Tensor(sequence)
        seq_size = len(sequence)
        if seq_size != dim_size:
            raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
        sequence = tuple(compress(range(dim_size), sequence))
        if not sequence:
            return False
    return make_tensor(sequence, mstype.int64, None, dim_size)


@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
    """Adds remaining dimensions not indexed to not_expanded_dim"""
    if idx_advanced != -1:
        if expand_true:
            # tensor indices generate only one dimension with size 1
            tensor_dims = (False,)
        else:
            tensor_dims = (True,)*tensor_index_ndim
        not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
    not_expanded_dim = not_expanded_dim + (True,)*rem_ndim

    # Shift idx_advanced left by the number of leading False flags (floored at 0).
    count_leading_false = 0
    while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
        count_leading_false += 1
    idx_advanced = max(0, idx_advanced - count_leading_false)
    return not_expanded_dim, idx_advanced


@_primexpr
def check_slice_empty(start, stop, step):
    """Return True if a normalized slice (start, stop, step) selects no elements."""
    return (start - stop) * step >= 0


@_primexpr
def real_axes(ndim_orig, ndim_out, axes_orig):
    """Returns the real axes to be reduced after performing broadcast"""
    _diff = ndim_out - ndim_orig
    axes = tuple(range(_diff))
    axes_orig = map(partial(operator.add, _diff), axes_orig)
    return axes + tuple(axes_orig)


@_primexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
    """Computes slice tensor shapes"""
    shape = [1] * len(slice_shape)
    shape[slice_cnt] = slice_shape[slice_cnt]
    shape = shape[:fancy_position] + [1] * broadcast_shape_len + shape[fancy_position:]
    return shape


@_primexpr
def infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting.

    Shapes are right-aligned, with missing leading dims treated as 1. Each output
    dim is 0 if any input has 0 there, otherwise the maximum of the inputs.

    NOTE: incompatible shapes (e.g. 2 vs 3) are NOT detected here; callers must
    validate compatibility separately.
    """
    shape_out = list()
    max_len = max([len(it) for it in shapes])

    for i in range(max_len):
        items = [it[i-max_len+len(it)] if i-max_len +
                 len(it) >= 0 else 1 for it in shapes]
        max_size = 0 if 0 in items else max(items)
        shape_out.append(max_size)
    return tuple(shape_out)


@constexpr(check=False)
def use_copy_slice(tuple_index):
    """
    Return True when `tuple_index` is (int, slice with step 1/None, full slices...).

    At least two entries are required; any other form returns False.
    """
    if tuple_index is not None and len(tuple_index) >= 2:
        return (isinstance(tuple_index[0], int) and
                isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
                all(x == slice(None, None, None) for x in tuple_index[2:]))
    return False


@constexpr
def is_ascend():
    """Device target is Ascend or not"""
    return context.get_context('device_target') == "Ascend"


@constexpr
def gen_exception_msg(msg_format, *args):
    """Format `msg_format` with `args`."""
    return msg_format.format(*args)


@constexpr
def get_output_dtype(dtype_1, dtype_2, use_complex=False):
    """
    Returns output dtype after type promotion.

    Args:
        dtype_1 (mindspore.dtype): The first dtype.
        dtype_2 (mindspore.dtype): The second dtype.
        use_complex (bool): Look priorities up in ``complex_priority_map`` instead
            of ``type_priority_map``.

    Returns:
        Whichever dtype has the higher priority (``dtype_2`` on a tie).

    Raises:
        ValueError: If either dtype is missing from the selected priority map.
    """
    if use_complex:
        priority_map = complex_priority_map
        type_str = "Complex binary"
    else:
        priority_map = type_priority_map
        type_str = "Binary"
    priority_1 = priority_map.get(dtype_1, None)
    priority_2 = priority_map.get(dtype_2, None)
    # Compare with None explicitly: priority 0 (bool_ in type_priority_map,
    # float32 in complex_priority_map) is valid but falsy, and a truthiness test
    # would wrongly reject those dtypes.
    if priority_1 is None or priority_2 is None:
        raise ValueError(f"{type_str} op type promotion not supported for {dtype_1} and {dtype_2}")
    if priority_1 > priority_2:
        return dtype_1
    return dtype_2


@constexpr
def promote_binary_dtype(dtype_1, dtype_2):
    """
    promote binary types

    Uses the complex priority map when either dtype is complex.
    """
    if dtype_1 == dtype_2:
        return dtype_1
    if dtype_1 in complex_types or dtype_2 in complex_types:
        return get_output_dtype(dtype_1, dtype_2, True)
    return get_output_dtype(dtype_1, dtype_2, False)


@_primexpr
def generate_padding_shape(shape, length):
    """
    pad the `shape` to `length` with 1.

    Raises:
        ValueError: If both lengths are constant and `shape` is longer than `length`.
    """

    if _inner_ops.IsConstant()(length) and _inner_ops.IsConstant()(len(shape)):
        if len(shape) > length:
            raise ValueError(f"Can not pad {shape} to length {length}.")

    return shape + (1,) * (length - len(shape))