# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""

from itertools import compress, zip_longest
from functools import partial
from collections import deque
import operator

import numpy as np

from ...primitive import constexpr
from .... import log as logger
from ....common import dtype as mstype
from ....common.tensor import Tensor
from ....common._register_for_tensor import tensor_operator_registry
from ....ops import _utils as op_utils
from ...._checkparam import Validator as validator

# Element-type classifications returned by `check_value_elements` and
# `tuple_index_type_cnt`.
ALL_TENSOR = 0
NO_TENSOR = 1
CONTAIN_TENSOR = 2
ALL_SCALAR = 3
ALL_BASIC = 7
MIXED = 8

# Index kinds returned by `get_index_tensor_dtype` and `check_number_index_type`.
INT_ = 0
BOOL_ = 1
UNSUPPORTED_DTYPE = 2

# Operation names used in error messages; `check_type_valid` raises IndexError
# (instead of TypeError) for these two.
TENSOR_SETITEM = "tensor setitem"
TENSOR_GETITEM = "tensor getitem"

# `op_type` values understood by `convert_scalar_to_tensor` / `generate_updates_shape`.
SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
SET_ITEM_BY_NON_TENSOR = 2


@constexpr
def raise_value_error(msg):
    """Constexpr for raise_value_error."""
    raise ValueError(msg)


@constexpr
def raise_index_error(msg):
    """Constexpr for raise_index_error."""
    raise IndexError(msg)


@constexpr
def raise_type_error(msg):
    """Constexpr for raise_type_error."""
    raise TypeError(msg)


@constexpr
def raise_unimplemented_error(msg):
    """Constexpr for raise_unimplemented_error."""
    raise NotImplementedError(msg)


@constexpr
def check_equal(param1, param2, msg="{},{}"):
    """
    Checks whether the two parameters are equal or not.

    Returns:
        `param1`, when the two parameters compare equal.

    Raises:
        ValueError: If they differ; the message is `msg` formatted with both parameters.
    """
    if param1 != param2:
        raise ValueError(msg.format(param1, param2))
    return param1


@constexpr
def make_empty_slice():
    """Creates the full slice `slice(None, None, None)`, i.e. `:`."""
    return slice(None, None, None)


@constexpr
def _deep_list(array_like, dim_size=-1):
    """
    Converts nested tuple/list mixtures to pure nested lists.

    When `dim_size` is not -1, every non-bool int leaf is bounds-checked and
    wrapped into `[0, dim_size)` by `check_range`.
    """
    if dim_size != -1:
        array_like = check_range(array_like, dim_size)
    if isinstance(array_like, (list, tuple)):
        return [_deep_list(item, dim_size) for item in array_like]
    return array_like


@constexpr
def deep_tuple(array_like):
    """Converts nested tuple/list mixtures to pure nested tuples."""
    if isinstance(array_like, (list, tuple)):
        return tuple(map(deep_tuple, array_like))
    return array_like


def _deep_tensor_to_nparray(array_like):
    """
    Converts a nested list of tensors to a nested list of numpy arrays.

    Lists are modified in place (tuples are not descended into).

    Args:
        array_like (list(tensor)): In any format of nested lists that may contain
            tensors.

    Returns:
        array_like (list(np_array)): Formatted array that can be directly processed
        by numpy.array(), with all tensor elements converted to numpy arrays.
    """
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    if isinstance(array_like, list):
        for idx, value in enumerate(array_like):
            array_like[idx] = _deep_tensor_to_nparray(value)

    return array_like


@constexpr
def check_range(x, dim_size):
    """
    Bounds-checks an integer index against `dim_size` and wraps negatives.

    Non-int inputs (including bools) are returned unchanged.

    Raises:
        IndexError: If `x` is an int outside `[-dim_size, dim_size)`.
    """
    if isinstance(x, int) and not isinstance(x, bool):
        if x >= dim_size or x < -dim_size:
            raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
        x = x % dim_size
    return x


@constexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=-1):
    """
    Converts the input to tensor.

    This function converts tensors from an array-like object.

    Args:
        a (Union[int, float, bool, list, tuple]): Input data, in any form that can
            be converted to a `Tensor`.
        dtype (:class:`mindspore.dtype`): Designated tensor dtype.
        data_shape (tuple): If truthy, `a` is ignored and a zero tensor of this
            shape is returned.
        dim_size (int): If not -1, int entries of `a` are bounds-checked against
            it and wrapped to non-negative indices.

    Returns:
        Tensor, generated tensor with the specified dtype.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If input `a` has different sizes at different dimensions.
    """
    if data_shape:
        return Tensor(np.zeros(data_shape), dtype)

    if not isinstance(a, (list, tuple, int, float, bool)):
        raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")

    if dim_size != -1:
        a = check_range(a, dim_size)

    if isinstance(a, (list, tuple)):
        # Convert all tuple/nested tuples to lists
        a = _deep_list(a, dim_size)
        # Convert all tensor sub-elements to numpy arrays
        a = _deep_tensor_to_nparray(a)
        try:
            a = np.asarray(a)
        except ValueError as ex:
            # Newer NumPy versions raise on ragged nested sequences instead of
            # returning an object array; report it the same way as below.
            raise ValueError('Input array must have the same size across all dimensions.') from ex
        if a.dtype == np.dtype('object'):
            raise ValueError('Input array must have the same size across all dimensions.')

    return Tensor(a, dtype)


tensor_operator_registry.register('make_tensor', make_tensor)


@constexpr
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
    """
    Judges whether the data dim is valid.

    Raises:
        ValueError: If `data_dim` is outside `[min_data_dim, max_data_dim]`.
    """
    if data_dim < min_data_dim or data_dim > max_data_dim:
        raise ValueError(f"The input data's dim should be in the range of [{min_data_dim}, "
                         f"{max_data_dim}], but actually is '{data_dim}'")


@constexpr
def get_source_shape(data_shape, value_shape):
    """
    Returns the shape of value that will be used to broadcast against data.

    If `value_shape` has more dimensions than `data_shape`, the extra leading
    dimensions must all be 1 and `data_shape` is returned; otherwise
    `value_shape` is returned.

    Raises:
        ValueError: If `value_shape` cannot be broadcast to `data_shape`.
    """
    cannot_broadcast = False
    source_shape = value_shape
    for i, j in zip(reversed(data_shape), reversed(value_shape)):
        if j not in (1, i):
            cannot_broadcast = True
    for i in range(len(value_shape) - len(data_shape)):
        source_shape = data_shape
        if value_shape[i] != 1:
            cannot_broadcast = True
    if cannot_broadcast:
        raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}')
    return source_shape


@constexpr
def check_tensor_setitem_index(index, element_type=None):
    """
    Checks tuple index type of tensor assignment.

    Accepts a slice, a non-empty tuple of slices/ellipsis/ints, or a tensor
    whose element type is `mstype.bool_`.

    Returns:
        True, if the index is supported.

    Raises:
        IndexError: If the index is None, an empty tuple, or of an unsupported type.
        TypeError: If the index is a tensor whose element type is not bool.
    """
    if index is None:
        raise IndexError("Tensor's index cannot be None.")
    if isinstance(index, slice):
        return True
    if isinstance(index, tuple):
        if not index:
            raise IndexError("Tensor's index cannot be empty.")
        for item in index:
            if not isinstance(item, (slice, type(...), int)):
                raise IndexError(
                    "Index of type '{}' is not supported yet.".format(type(item)))
        return True
    if isinstance(index, mstype.tensor_type):
        if element_type is None or element_type != mstype.bool_:
            raise TypeError(
                "The index of tensor should be a bool type tensor. "
                "{} type is not supported yet.".format(element_type))
        return True

    raise IndexError(
        "Index of type '{}' is not supported yet.".format(type(index)))


@constexpr
def is_same_type(inst, type_):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        inst (mindspore.dtype): Inspected type.
        type_ (mindspore.dtype): Target type.

    Outputs:
        bool, the check result.
    """
    return inst == type_


@constexpr
def check_valid_dim(dim, name):
    """Checks whether the dim is valid (1 or 2); raises ValueError otherwise."""
    if dim not in (1, 2):
        raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")


@constexpr
def judge_index_type(index_type, target_type):
    """
    Judges whether the index type is valid.

    `target_type` may be a single type or a list/tuple of accepted types.
    """
    return index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type)


@constexpr
def judge_indexes_types(dtypes, target_type):
    """Returns whether every dtype in `dtypes` matches `target_type` (a type or a list/tuple of types)."""
    for dtype in dtypes:
        if isinstance(target_type, (list, tuple)):
            if dtype not in target_type:
                return False
        elif dtype != target_type:
            return False
    return True


@constexpr
def check_type_invalid(dtype, target_type):
    """
    Returns True if `dtype` is not accepted by the list/tuple `target_type`.

    NOTE(review): when `target_type` is a single type (not a list/tuple) this
    always returns False, unlike `judge_index_type`; kept as-is because callers
    may rely on it — confirm whether that is intended.
    """
    return dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type)


@constexpr
def check_type_valid(dtype, target_type, op_name):
    """
    Checks whether the dtype is valid.

    Raises:
        IndexError: If invalid and `op_name` is tensor getitem/setitem.
        TypeError: If invalid for any other `op_name`.

    NOTE(review): like `check_type_invalid`, a single (non list/tuple)
    `target_type` never triggers an error — confirm whether that is intended.
    """
    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
        if op_name in (TENSOR_GETITEM, TENSOR_SETITEM):
            raise IndexError(
                f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")
        raise TypeError(
            f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")


@constexpr
def check_types_valid(dtypes, target_type, op_name):
    """Check a tuple of tensor data type with `check_type_valid`."""
    for dtype in dtypes:
        check_type_valid(dtype, target_type, op_name)


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """
    Separate the position information of tensor and slice and ellipsis from the mixed tensors index.

    Returns:
        tuple: Seven tuples of positions — slice, ellipsis, none, int, bool,
        tensor and sequence — in that order.

    Raises:
        IndexError: If an element type is unsupported or more than one ellipsis is present.
    """
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = (), (), (), (), (), (), ()
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.Slice):
            slice_positions += (i,)
        elif isinstance(index_type, mstype.Ellipsis_):
            ellipsis_positions += (i,)
        elif isinstance(index_type, mstype.none_type):
            none_positions += (i,)
        elif isinstance(index_type, mstype.Int):
            int_positions += (i,)
        elif isinstance(index_type, mstype.Bool):
            bool_positions += (i,)
        elif isinstance(index_type, mstype.tensor_type):
            tensor_positions += (i,)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions += (i,)
        else:
            raise IndexError(f"For '{op_name}', the index elements only support 'Slice', 'Ellipsis', 'None', "
                             f"'Tensor', 'int', 'List', 'Tuple', 'bool' but got {index_type}.")
    if len(ellipsis_positions) > 1:
        raise IndexError(
            f"For '{op_name}', an index can only have a single ellipsis('...')")

    return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
        tensor_positions, sequence_positions


def ellipsis2slice(input_, shape):
    """
    Converts ellipsis to slice.

    The ellipsis is replaced by as many full slices as needed so that the
    index covers every dimension of `shape`.

    Raises:
        IndexError: If more than one ellipsis is present.
    """
    input_slice = input_
    result = []
    if isinstance(input_, type(...)):
        input_slice = (input_,)
    ell_count = 0
    for element in input_slice:
        if not isinstance(element, type(...)):
            result.append(element)
            continue
        ell_count += 1
        if ell_count > 1:
            raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
                             "but it is currently {}".format(input_slice))
        for _ in range(len(shape) - len(input_slice) + 1):
            result.append(slice(None, None, None))
    return tuple(result)


@constexpr
def slice2indices(input_slice, shape):
    """
    Converts slice to indices.

    Inputs:
        input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice.
        shape (tuple): The shape of a tensor is an integer element tuple.

    Outputs:
        Tensor of shape `(n, d1, ..., dk, len(shape))` holding the full coordinates
        of every element selected by the slice on the first axis, or False if
        the slice selects nothing.
    """
    start, stop, step = normalize_slice(input_slice, shape[0])
    if check_slice_empty(start, stop, step):
        return False
    grids = ([np.arange(start, stop, step, dtype=np.int64)] +
             [np.arange(dim_size, dtype=np.int64) for dim_size in shape[1:]])
    mesh = np.ix_(*grids)
    return Tensor(np.stack(np.broadcast_arrays(*mesh), axis=-1))


@constexpr
def check_indices(indices_size, index):
    """Checks indices whether is empty; raises IndexError if `indices_size` < 1."""
    if indices_size < 1:
        raise IndexError(
            "The tensor's index is unreasonable. index:{}".format(index))
    return indices_size


@constexpr
def check_indices_value_size(indices_size, value_size):
    """
    Checks if the sizes are already matched.

    A value of size 1 is always accepted; larger values must match `indices_size`.

    Raises:
        ValueError: If `value_size` < 1, or > 1 and different from `indices_size`.
    """
    if value_size < 1:
        raise ValueError("The value assigned to tensor cannot be empty.")
    if value_size > 1:
        if value_size != indices_size:
            raise ValueError(
                "The value given to tensor does not match the index size,"
                " value size:{}, indices size:{}".format(value_size, indices_size))
    return value_size


@constexpr
def tuple_index_type_cnt(types, op_name):
    """
    Classifies the element types of a tuple index.

    Returns:
        ALL_TENSOR if every element is a tensor type, ALL_BASIC if every element
        is an int/ellipsis/slice, otherwise MIXED. `op_name` is unused.
    """
    if all(isinstance(ele, mstype.tensor_type) for ele in types):
        return ALL_TENSOR
    if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types):
        return ALL_BASIC
    return MIXED


@constexpr
def check_value_elements(types):
    """Judges the type of all elements of the tuple: NO_TENSOR, ALL_TENSOR or CONTAIN_TENSOR."""
    tensor_number = sum(1 for ele in types if isinstance(ele, mstype.tensor_type))
    if tensor_number == 0:
        return NO_TENSOR
    if tensor_number == len(types):
        return ALL_TENSOR
    return CONTAIN_TENSOR


@constexpr
def get_index_tensor_dtype(dtype):
    """
    Classifies the data type of an index tensor.

    Returns:
        INT_ for integer dtypes, BOOL_ for `mstype.bool_`.

    Raises:
        IndexError: For any other dtype.
    """
    if dtype in mstype.int_type:
        return INT_
    if dtype == mstype.bool_:
        return BOOL_
    raise IndexError(
        f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
    """Check tensors data type same; raises TypeError when they differ."""
    if value_dtype == data_dtype:
        return True
    raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
                    f"is not consistent with assigned tensor data type {data_dtype}.")


@constexpr
def generate_broadcast_shape(shapes, op_name):
    """
    Generate broadcast shape for a tuple of shape.

    Raises:
        IndexError: If the shapes cannot be broadcast together (wrapping the
            ValueError raised by `op_utils.get_broadcast_shape`).
    """
    if not shapes:
        return ()
    broadcast_shape = shapes[0]
    for i, shape in enumerate(shapes):
        logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
        try:
            broadcast_shape = op_utils.get_broadcast_shape(
                broadcast_shape, shape, op_name)
        except ValueError as ex:
            raise IndexError(ex) from ex
    return tuple(broadcast_shape)


@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
    """
    Check shape_y needs to be broadcast to shape_x.

    Returns:
        bool, whether the shapes differ.

    Raises:
        ValueError: If a trailing dimension of `shape_y` is neither 1 nor equal to `shape_x`'s.
    """
    if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))):
        raise ValueError(f"{shape_y} could not broadcast with {shape_x}.")
    return shape_y != shape_x


@constexpr
def compute_multiples(origin_shape, broadcast_shape):
    """Compute multiples between origin shape with broadcast shape (e.g. for a Tile op)."""
    len_gap = len(broadcast_shape) - len(origin_shape)
    return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))


@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
    """Convert a scalar to a tensor filled with `value`, shaped like the scatter updates."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = indices_shape + data_shape[1:]
    else:
        updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
    return Tensor(np.full(updates_shape, value), dtype=data_dtype)


@constexpr
def generate_updates_shape(data_shape, index_shape, op_type):
    """Generate updates shape for 'tensor setitem'."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        return index_shape + data_shape[1:]
    return index_shape[:-1] + data_shape[index_shape[-1]:]


@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
    """
    Transforms slice to element list.

    Raises:
        IndexError: If the slice selects no element.
    """
    slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
    start, stop, step = normalize_slice(slice_obj, dim_len)
    slice_ele_list = list(range(start, stop, step))
    if not slice_ele_list:
        raise IndexError(f"An empty slice is not supported, got {slice_obj}")
    return slice_ele_list


@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                    slice_shapes, op_name, fancy_position=None):
    """
    Generate index info which contain broadcast shape, final shape,
    indexes shapes info, ellipsis size from a tuple of mixed tensors.

    When `fancy_position` is not given, the broadcast tensor dimensions are placed
    at the first tensor position if the tensor positions are contiguous, else at 0.

    Returns:
        tuple: (broadcast_shape, index_tensor_new_shape, final_shape, fancy_position).
    """
    tensor_positions = tuple(sorted(tensor_positions))
    if fancy_position is None:
        tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
        fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
    broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)

    final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
    index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
        broadcast_shape + (1,) * len(slice_shapes[fancy_position:])

    return broadcast_shape, index_tensor_new_shape, final_shape, fancy_position


def _judge_order_continuous(order_sequence):
    """Returns True if `order_sequence` is non-empty and increases by exactly 1 at each step."""
    if not order_sequence:
        return False
    return all(idx1 + 1 == idx2 for idx1, idx2 in zip(order_sequence[:-1], order_sequence[1:]))


@constexpr
def scalar_in_sequence(x, y):
    """
    Determine whether the scalar in the sequence.

    Raises:
        ValueError: If either argument is None (i.e. not a compile-time constant).
    """
    if x is None:
        raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
                         "but the scalar is not.")
    if y is None:
        raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
                         "but the sequence is not.")
    return x in y


@constexpr
def get_np_eps(input_dtype):
    """Get numpy machine epsilon of the float type matching `input_dtype`, as a Python float."""
    nptype = mstype.dtype_to_nptype(input_dtype)
    eps = np.finfo(nptype).eps
    return float(eps)


@constexpr
def check_number_index_type(number):
    """
    Check if it is int or bool number.

    Returns:
        BOOL_ for a bool, INT_ for any other int.

    Raises:
        IndexError: For any other type.
    """
    if isinstance(number, bool):
        return BOOL_
    if isinstance(number, int):
        return INT_
    raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None and bool, got {0} type is {1} "
                     .format(number, type(number)))


@constexpr
def get_stride_info_from_slice(data_shape, slice_index):
    """Get (begin, end, step) stride tuples from a python slice applied to the first axis."""
    begin, end, step = get_slice_stride(slice_index, data_shape[0])
    begin_strides = [begin]
    end_strides = [end]
    step_strides = [step]
    for dim_size in data_shape[1:]:
        begin_strides.append(0)
        end_strides.append(dim_size)
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


@constexpr
def get_stride_info_from_integer(data_shape, number):
    """Get (begin, end, step) stride tuples selecting index `number` of the first axis."""
    begin_strides = [number]
    end_strides = [number + 1]
    step_strides = [1]
    for dim_size in data_shape[1:]:
        begin_strides.append(0)
        end_strides.append(dim_size)
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


def get_slice_stride(index_slice, dim_size):
    """
    Get slice stride info (start, stop, step) with strided-slice defaults.

    Missing bounds default to `0`/`dim_size` for a positive step and to
    `-1`/`-(dim_size + 1)` for a negative step; given bounds are returned as-is.
    """
    step = 1 if index_slice.step is None else index_slice.step
    start_default = 0
    stop_default = dim_size
    if step < 0:
        start_default = -1
        stop_default = -(dim_size + 1)
    start = start_default if index_slice.start is None else index_slice.start
    stop = stop_default if index_slice.stop is None else index_slice.stop
    return start, stop, step


@constexpr
def get_stride_info_from_tuple(data_shape, tuple_index):
    """
    Get stride info from a tuple index.

    Elements consume dimensions of `data_shape` from left to right: a slice or an
    int consumes one dimension, the ellipsis consumes every dimension not consumed
    by the other elements, and trailing unindexed dimensions are taken in full.

    Returns:
        tuple: (begin_strides, end_strides, step_strides, shrink_axis), where bit `k`
        of `shrink_axis` is set when dimension `k` of the data is indexed by an int.

    Raises:
        IndexError: If the index has more than one ellipsis, indexes more dimensions
            than `data_shape` has, or contains an unsupported element.
    """
    data_dim = len(data_shape)
    if sum(1 for index in tuple_index if index is ...) > 1:
        raise IndexError("An index can have only one ellipsis (...)")
    # Every element other than the ellipsis consumes exactly one data dimension.
    consumed_dims = sum(1 for index in tuple_index if index is not ...)
    if consumed_dims > data_dim:
        raise IndexError(f"too many indices for tensor of dimension {data_dim}, got {tuple_index}")

    begin_strides, end_strides, step_strides = [], [], []
    shrink_axis, index_count = 0, 0
    for index in tuple_index:
        if isinstance(index, slice):
            # `index_count` is the data dimension this element indexes; it differs from
            # the element's position in `tuple_index` once an ellipsis has been expanded.
            start, stop, step = get_slice_stride(index, data_shape[index_count])
            begin_strides.append(start)
            end_strides.append(stop)
            step_strides.append(step)
            index_count += 1
        elif isinstance(index, int):
            begin_strides.append(index)
            end_strides.append(index + 1)
            step_strides.append(1)
            shrink_axis += 1 << index_count
            index_count += 1
        elif index is ...:
            ellipsis_range_size = data_dim - consumed_dims
            begin_strides.extend([0] * ellipsis_range_size)
            end_strides.extend(data_shape[index_count:index_count + ellipsis_range_size])
            step_strides.extend([1] * ellipsis_range_size)
            index_count += ellipsis_range_size
        else:
            raise IndexError(f"Not supported index data type, got {index} type is {type(index)}")
    for dim_size in data_shape[index_count:]:
        begin_strides.append(0)
        end_strides.append(dim_size)
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis


@constexpr
def scalar_to_tensor(x):
    """Convert a scalar to a tensor"""
    return Tensor(x)


@constexpr
def unpack(x):
    """Recursively unwraps single-element tuples/lists, e.g. `((5,),)` -> `5`."""
    if isinstance(x, (tuple, list)) and len(x) == 1:
        return unpack(x[0])
    return x


@constexpr
def normalize_start(start, dim_size):
    """
    Normalize `start` according to the number of dimensions (`dim_size`).
    If the number of dimensions is not given, return the original input directly.

    A None `start` becomes 0; the result is clamped into `[0, dim_size]`.
    """
    if start is None:
        return 0
    if dim_size is None:
        return start
    if start < 0:
        return 0 if start < -dim_size else start % dim_size
    return start if start < dim_size else dim_size


@constexpr
def normalize_stop(stop, dim_size):
    """
    Normalize `stop` according to the number of dimensions (`dim_size`).
    If the number of dimensions is not given, return the original input directly.

    A None `stop` becomes `dim_size`; the result is clamped into `[0, dim_size]`.
    """
    if stop is None:
        return dim_size
    if dim_size is None:
        return stop
    if stop < 0:
        return 0 if stop < -dim_size else stop % dim_size
    return stop if stop < dim_size else dim_size


@constexpr
def normalize_slice(input_slice, dim_size):
    """
    Normalizes start, stop, step in a slice.

    The result is meant for `range(start, stop, step)`. For a non-negative step,
    start/stop are clamped with `normalize_start`/`normalize_stop` (a zero step
    is returned unchanged). For a negative step, Python's `slice.indices`
    semantics are used, so e.g. `[::-1]` on size 3 yields `(2, -1, -1)`.
    """
    step = 1 if input_slice.step is None else input_slice.step
    if step >= 0:
        return normalize_start(input_slice.start, dim_size), normalize_stop(input_slice.stop, dim_size), step
    start, stop, _ = slice(input_slice.start, input_slice.stop, step).indices(dim_size)
    return start, stop, step


@constexpr
def tuple_slice(tup, start, end):
    """get sliced tuple from start and end."""
    return tup[start:end]


@constexpr
def expanded_shape(shape, expand_size):
    """Prepends `expand_size` dimensions of size 1 to `shape`."""
    return (1,) * expand_size + shape


@constexpr
def sequence_mul_int(seq, number):
    """
    Make a new list with native python syntax.

    Args:
        seq (Union[list, tuple]): Input sequence.
        number (int): Input number.

    Returns:
        New sequence, has the same type as `seq`.

    Raises:
        TypeError: If `number` is not an int.
    """
    if not isinstance(number, int):
        raise TypeError(f"can't multiply sequence by non-int of type {type(number)}")
    return seq * number


@constexpr
def check_in_sequence(x, y):
    """Determine whether the input `x` is in the sequence `y`."""
    return x in y


@constexpr
def is_slice(x):
    """Returns whether `x` is a python slice."""
    return isinstance(x, slice)


@constexpr
def filter_expanded_dims(shape, not_expanded_dim):
    """
    Keeps the entries of `shape` whose flag in the right-aligned tail of
    `not_expanded_dim` is truthy.

    Raises:
        ValueError: If `shape` has more dimensions than `not_expanded_dim`.
    """
    diff = len(not_expanded_dim) - len(shape)
    if diff < 0:
        raise ValueError(f'unable to broadcast {shape}')
    return tuple(compress(shape, not_expanded_dim[diff:]))


@constexpr
def sequence_to_index(sequence, dim_size):
    """
    Transforms sequence to tensor index.

    An all-bool sequence is treated as a mask over the dimension and converted to
    the positions of its True entries. Returns False if nothing is selected.

    Raises:
        IndexError: If a bool mask's length differs from `dim_size`, or an int
            entry is out of bounds.
    """
    if not sequence:
        return False
    if all(isinstance(i, bool) for i in sequence):
        seq_size = len(sequence)
        if seq_size != dim_size:
            raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
        sequence = tuple(compress(range(dim_size), sequence))
        if not sequence:
            return False
    return make_tensor(sequence, mstype.int64, None, dim_size)


@constexpr
def int_to_index(i, shape):
    """
    Converts integer to tensor indices.

    Returns the full coordinates of every element of `x[i]` for a tensor of
    `shape`, with the index as the last axis.

    Raises:
        IndexError: If `i` is out of bounds for axis 0.
    """
    dim_size = shape[0]
    if i < -dim_size or i >= dim_size:
        raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
    i = i % dim_size
    if len(shape) == 1:
        return Tensor([[i]])
    grids = [np.arange(size, dtype=np.int64) for size in shape[1:]]
    mesh = np.ix_(*grids)
    index = np.stack(np.broadcast_arrays(*mesh), -1)
    return Tensor(np.insert(index, 0, i, -1))


@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
    """
    Adds remaining dimensions not indexed to not_expanded_dim.

    Returns:
        tuple: (not_expanded_dim, idx_advanced), where `idx_advanced` is shifted
        left by the number of leading False flags (never below 0).
    """
    if idx_advanced != -1:
        if expand_true:
            # tensor indices generate only one dimension with size 1
            tensor_dims = (False,)
        else:
            tensor_dims = (True,) * tensor_index_ndim
        not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
    not_expanded_dim = not_expanded_dim + (True,) * rem_ndim

    count_leading_false = 0
    while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
        count_leading_false += 1
    idx_advanced = max(0, idx_advanced - count_leading_false)
    return not_expanded_dim, idx_advanced


@constexpr
def check_slice_empty(start, stop, step):
    """Returns whether `range(start, stop, step)` is empty (also True when `step` is 0)."""
    return (start - stop) * step >= 0


@constexpr
def real_axes(ndim_orig, ndim_out, axes_orig):
    """Returns the real axes to be reduced after performing broadcast"""
    _diff = ndim_out - ndim_orig
    axes = tuple(range(_diff))
    axes_orig = map(partial(operator.add, _diff), axes_orig)
    return axes + tuple(axes_orig)


check_axis_valid_const = constexpr(validator.check_axis_valid)


@constexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
    """
    Computes slice tensor shapes.

    All ones except `slice_shape[slice_cnt]` at its position, with
    `broadcast_shape_len` ones inserted at `fancy_position`.
    """
    shape = [1] * len(slice_shape)
    shape[slice_cnt] = slice_shape[slice_cnt]
    shape = shape[:fancy_position] + [1] * broadcast_shape_len + shape[fancy_position:]
    return shape


@constexpr
def infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
    """
    shape_out = deque()
    reversed_shapes = map(reversed, shapes)
    for items in zip_longest(*reversed_shapes, fillvalue=1):
        max_size = 0 if 0 in items else max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
        shape_out.appendleft(max_size)
    return tuple(shape_out)


@constexpr
def use_copy_slice(tuple_index):
    """
    Returns True for an index of the form `(int, slice with step 1/None, :, :, ...)`
    with at least two elements.
    """
    if tuple_index is not None and len(tuple_index) >= 2:
        return (isinstance(tuple_index[0], int) and
                isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
                all(x == slice(None, None, None) for x in tuple_index[2:]))
    return False


@constexpr
def gen_exception_msg(msg_format, *args):
    """Formats `msg_format` with the positional `args` (e.g. to build an exception message)."""
    return msg_format.format(*args)