# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""

from itertools import compress, zip_longest
from functools import partial
from collections import deque
import operator

import numpy as np

from ...primitive import constexpr
from .... import log as logger
from ....common import dtype as mstype
from ....common.tensor import Tensor
from ....common._register_for_tensor import tensor_operator_registry
from ....ops import _utils as op_utils
from ...._checkparam import Validator as validator

ALL_TENSOR = 0
NO_TENSOR = 1
CONTAIN_TENSOR = 2
ALL_SCALAR = 3
ALL_BASIC = 7
MIXED = 8

INT_ = 0
BOOL_ = 1
UNSUPPORTED_DTYPE = 2

TENSOR_SETITEM = "tensor setitem"
TENSOR_GETITEM = "tensor getitem"

SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
SET_ITEM_BY_NON_TENSOR = 2


@constexpr
def raise_value_error(msg):
    """Constexpr for raise_value_error."""
    raise ValueError(msg)


@constexpr
def raise_index_error(msg):
    """Constexpr for raise_index_error."""
    raise IndexError(msg)


@constexpr
def raise_type_error(msg):
    """Constexpr for raise_type_error."""
    raise TypeError(msg)


@constexpr
def raise_unimplemented_error(msg):
    """Constexpr for raise_unimplemented_error."""
    raise NotImplementedError(msg)


@constexpr
def check_equal(param1, param2, msg="{},{}"):
    """Checks whether the two parameters are equal or not."""
    if param1 != param2:
        raise ValueError(msg.format(param1, param2))
    return param1


@constexpr
def make_empty_slice():
    """Creates an empty slice."""
    return slice(None, None, None)


@constexpr
def _deep_list(array_like, dim_size=-1):
    """Converts nested tuple/list mixtures to a pure nested list."""
    if dim_size != -1:
        array_like = check_range(array_like, dim_size)
    if isinstance(array_like, (list, tuple)):
        return list(map(lambda x: _deep_list(x, dim_size), array_like))
    return array_like


@constexpr
def deep_tuple(array_like):
    """Converts nested tuple/list mixtures to a pure nested tuple."""
    if isinstance(array_like, (list, tuple)):
        return tuple(map(deep_tuple, array_like))
    return array_like
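# Illustrative sketch (traced from the definitions above, not part of the module's API surface):
#   _deep_list((1, (2, 3)))  ->  [1, [2, 3]]
#   deep_tuple([1, [2, 3]])  ->  (1, (2, 3))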


def _deep_tensor_to_nparray(array_like):
    """
    Converts a nested list of tensors to a nested list of numpy arrays.

    Args:
        array_like(list(tensor)): In any format of nested lists that may contain
            tensors.

    Returns:
        array_like(list(np_array)): Formatted array that can be directly processed
            by numpy.array(), with all tensor elements converted to numpy arrays.
    """
    # Recursively check whether each element is a tensor; if it is,
    # convert it to a numpy array in place
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    if isinstance(array_like, list):
        for idx, value in enumerate(array_like):
            array_like[idx] = _deep_tensor_to_nparray(value)

    return array_like


@constexpr
def check_range(x, dim_size):
    """Checks that an integer index is within [-dim_size, dim_size) and wraps negative indices."""
    if isinstance(x, int) and not isinstance(x, bool):
        if x >= dim_size or x < -dim_size:
            raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
        x = x % dim_size
    return x
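# For example (traced from the logic above), with dim_size=5:
#   check_range(-2, 5)   ->  3           (negative index wrapped)
#   check_range(7, 5)    ->  IndexError  (out of bounds)
#   check_range(True, 5) ->  True        (bools pass through unchanged)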


@constexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=-1):
    """
    Converts the input to tensor.

    This function converts tensors from an array-like object.

    Args:
        a (Union[int, float, bool, list, tuple]): Input data, in any form that can
            be converted to a `Tensor`.
        dtype (:class:`mindspore.dtype`): Designated tensor dtype.
        data_shape (tuple): If given, a zero-filled tensor of this shape is
            returned and `a` is ignored.
        dim_size (int): If not -1, integer entries in `a` are range-checked
            against this dimension size.

    Returns:
        Tensor, generated tensor with the specified dtype.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If input `a` has different sizes at different dimensions.
    """
    if data_shape:
        return Tensor(np.zeros(data_shape), dtype)

    if not isinstance(a, (list, tuple, int, float, bool)):
        raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")

    if dim_size != -1:
        a = check_range(a, dim_size)

    if isinstance(a, (list, tuple)):
        # Convert all tuple/nested tuples to lists
        a = _deep_list(a, dim_size)
        # Convert all tensor sub-elements to numpy arrays
        a = _deep_tensor_to_nparray(a)
        a = np.asarray(a)
        if a.dtype is np.dtype('object'):
            raise ValueError('Input array must have the same size across all dimensions.')

    if isinstance(a, np.ndarray):
        if a.dtype is np.dtype('object'):
            raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")

    return Tensor(a, dtype)


tensor_operator_registry.register('make_tensor', make_tensor)
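# Usage sketch (the int64 default comes from the signature above):
#   make_tensor([[1, 2], [3, 4]])        ->  2x2 int64 Tensor
#   make_tensor([1, 2], mstype.float32)  ->  1-D float32 Tensor
#   make_tensor(0, data_shape=(2, 3))    ->  2x3 zero-filled int64 Tensor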


@constexpr
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
    """Judges whether the data dim is valid."""
    if data_dim < min_data_dim or data_dim > max_data_dim:
        raise ValueError(f"The input data's dim should be in the range of [{min_data_dim}, "
                         f"{max_data_dim}], but actually is '{data_dim}'")


@constexpr
def get_source_shape(data_shape, value_shape):
    """Returns the shape of value that will be used to broadcast against data."""
    cannot_broadcast = False
    source_shape = value_shape
    for i, j in zip(reversed(data_shape), reversed(value_shape)):
        if j not in (1, i):
            cannot_broadcast = True
    for i in range(len(value_shape) - len(data_shape)):
        source_shape = data_shape
        if value_shape[i] != 1:
            cannot_broadcast = True
    if cannot_broadcast:
        raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}')
    return source_shape
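# For example (traced from the logic above):
#   get_source_shape((2, 3), (1, 3))   ->  (1, 3)  (value broadcasts into data)
#   get_source_shape((3,), (1, 1, 3))  ->  (3,)    (extra leading 1s collapse to data's shape)
#   get_source_shape((2, 3), (3, 3))   ->  ValueError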


@constexpr
def check_tensor_setitem_index(index, element_type=None):
    """Checks tuple index type of tensor assignment."""
    if index is None:
        raise IndexError("Tensor's index cannot be None.")
    if isinstance(index, slice):
        return True
    if isinstance(index, tuple):
        if not index:
            raise IndexError("Tensor's index cannot be empty.")
        for item in index:
            if not isinstance(item, (slice, type(...), int)):
                raise IndexError(
                    "Index of type '{}' is not supported yet.".format(type(item)))
        return True
    if isinstance(index, mstype.tensor_type):
        if element_type is None or element_type != mstype.bool_:
            raise TypeError(
                "The index of tensor should be a bool type tensor. "
                "{} type is not supported yet.".format(element_type))
        return True

    raise IndexError(
        "Index of type '{}' is not supported yet.".format(type(index)))


@constexpr
def is_same_type(inst, type_):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        inst (mindspore.dtype): Inspected type.
        type_ (mindspore.dtype): Target type.

    Outputs:
        bool, the check result.
    """
    return inst == type_


@constexpr
def check_valid_dim(dim, name):
    """Checks whether the dim is valid."""
    if dim not in (1, 2):
        raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")


@constexpr
def judge_index_type(index_type, target_type):
    """Judges whether the index type matches the target type(s)."""
    return index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type)


@constexpr
def judge_indexes_types(dtypes, target_type):
    """Checks whether every dtype in a tuple matches the target type(s)."""
    for dtype in dtypes:
        if isinstance(target_type, (list, tuple)):
            if dtype not in target_type:
                return False
        else:
            if dtype != target_type:
                return False
    return True


@constexpr
def check_type_invalid(dtype, target_type):
    """Checks whether the dtype mismatches the target type(s)."""
    if isinstance(target_type, (list, tuple)):
        return dtype not in target_type
    return dtype != target_type


@constexpr
def check_type_valid(dtype, target_type, op_name):
    """Checks whether the dtype is valid, raising an error when it mismatches the target type(s)."""
    invalid = dtype not in target_type if isinstance(target_type, (list, tuple)) else dtype != target_type
    if invalid:
        if op_name in (TENSOR_GETITEM, TENSOR_SETITEM):
            raise IndexError(
                f"The '{op_name}' doesn't support '{dtype}' and expects to receive {target_type}.")
        raise TypeError(
            f"The '{op_name}' doesn't support '{dtype}' and expects to receive {target_type}.")


@constexpr
def check_types_valid(dtypes, target_type, op_name):
    """Checks each dtype in a tuple against the target type."""
    for dtype in dtypes:
        check_type_valid(dtype, target_type, op_name)


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """Separates the positions of tensors, slices, ellipses and other index types in a mixed tuple index."""
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = (), (), (), (), (), (), ()
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.Slice):
            slice_positions += (i,)
        elif isinstance(index_type, mstype.Ellipsis_):
            ellipsis_positions += (i,)
        elif isinstance(index_type, mstype.none_type):
            none_positions += (i,)
        elif isinstance(index_type, mstype.Int):
            int_positions += (i,)
        elif isinstance(index_type, mstype.Bool):
            bool_positions += (i,)
        elif isinstance(index_type, mstype.tensor_type):
            tensor_positions += (i,)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions += (i,)
        else:
            raise IndexError(f"For '{op_name}', the index elements only support 'Slice', 'Ellipsis', 'None', "
                             f"'Tensor', 'int', 'List', 'Tuple', 'bool' but got {index_type}.")
    if len(ellipsis_positions) > 1:
        raise IndexError(
            f"For '{op_name}', an index can only have a single ellipsis ('...').")

    return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
        tensor_positions, sequence_positions


def ellipsis2slice(input_, shape):
    """Converts ellipsis to slice."""
    input_slice = input_
    result = []
    if isinstance(input_, type(...)):
        input_slice = (input_,)
    ell_count = 0
    for element in input_slice:
        if not isinstance(element, type(...)):
            result.append(element)
            continue
        ell_count += 1
        if ell_count > 1:
            raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
                             "but got {}".format(input_slice))
        for _ in range(len(shape) - len(input_slice) + 1):
            result.append(slice(None, None, None))
    return tuple(result)
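# For example (traced from the logic above), with shape (2, 3, 4):
#   ellipsis2slice((..., 0), (2, 3, 4))
#   ->  (slice(None, None, None), slice(None, None, None), 0)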


@constexpr
def slice2indices(input_slice, shape):
    """
    Converts slice to indices.

    Inputs:
        input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice.
        shape (tuple): The shape of the tensor, a tuple of integers.

    Outputs:
        Tensor, the indices covered by the slice along the first dimension,
        or False if the slice is empty.
    """
    start, stop, step = normalize_slice(input_slice, shape[0])
    if check_slice_empty(start, stop, step):
        return False
    grids = ([np.array(list(range(start, stop, step)), dtype=np.int64)] +
             [np.array(list(range(dim_size)), dtype=np.int64) for dim_size in shape[1:]])
    mesh = np.ix_(*grids)
    return Tensor(np.stack(np.broadcast_arrays(*mesh), axis=-1))
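# For example (traced from the logic above), for a 1-D tensor of shape (5,):
#   slice2indices(slice(0, 3), (5,))  ->  Tensor([[0], [1], [2]]), shape (3, 1)
# An empty slice such as slice(2, 2) returns False.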


@constexpr
def check_indices(indices_size, index):
    """Checks whether the indices are empty."""
    if indices_size < 1:
        raise IndexError(
            "The tensor's index is invalid. index: {}".format(index))
    return indices_size


@constexpr
def check_indices_value_size(indices_size, value_size):
    """Checks whether the value size matches the indices size."""
    if value_size < 1:
        raise ValueError("The value assigned to tensor cannot be empty.")
    if value_size > 1:
        if value_size != indices_size:
            raise ValueError(
                "The value given to tensor does not match the index size,"
                " value size: {}, indices size: {}".format(value_size, indices_size))
    return value_size


@constexpr
def tuple_index_type_cnt(types, op_name):
    """Classifies the types of the tuple index elements as all-tensor, all-basic or mixed."""
    if all(isinstance(ele, mstype.tensor_type) for ele in types):
        return ALL_TENSOR
    if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types):
        return ALL_BASIC
    return MIXED


@constexpr
def check_value_elements(types):
    """Judges the type of all elements of the tuple."""
    tensor_number = 0
    for ele in types:
        if isinstance(ele, mstype.tensor_type):
            tensor_number += 1
    if tensor_number == 0:
        return NO_TENSOR
    if tensor_number == len(types):
        return ALL_TENSOR
    return CONTAIN_TENSOR


@constexpr
def get_index_tensor_dtype(dtype):
    """Classifies the data type of an index tensor as int or bool."""
    if dtype in mstype.int_type:
        return INT_
    if dtype == mstype.bool_:
        return BOOL_
    raise IndexError(
        f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
    """Checks that two tensors have the same data type."""
    if value_dtype == data_dtype:
        return True
    raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
                    f"is not consistent with the assigned tensor data type {data_dtype}.")


@constexpr
def generate_broadcast_shape(shapes, op_name):
    """Generates a broadcast shape for a tuple of shapes."""
    if not shapes:
        return ()
    broadcast_shape = shapes[0]
    for i, shape in enumerate(shapes):
        logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
        try:
            broadcast_shape = op_utils.get_broadcast_shape(
                broadcast_shape, shape, op_name)
        except ValueError as ex:
            raise IndexError(ex)
    return tuple(broadcast_shape)


@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
    """Checks whether shape_y needs to be broadcast to shape_x."""
    if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))):
        raise ValueError(f"{shape_y} could not broadcast with {shape_x}.")
    return shape_y != shape_x


@constexpr
def compute_multiples(origin_shape, broadcast_shape):
    """Computes the tile multiples from the origin shape to the broadcast shape."""
    len_gap = len(broadcast_shape) - len(origin_shape)
    return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
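# For example (traced from the logic above):
#   compute_multiples((3, 1), (2, 3, 4))  ->  (2, 1, 4)
# i.e. tiling a (3, 1) tensor by (2, 1, 4) yields the broadcast shape (2, 3, 4).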


@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
    """Converts a scalar to a tensor."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = indices_shape + data_shape[1:]
    else:
        updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
    return Tensor(np.full(updates_shape, value), dtype=data_dtype)


@constexpr
def generate_updates_shape(data_shape, index_shape, op_type):
    """Generates the updates shape for 'tensor setitem'."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = index_shape + data_shape[1:]
    else:
        updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
    return updates_shape
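# For example (traced from the logic above), with data_shape=(5, 4, 3) and index_shape=(2, 2):
#   generate_updates_shape((5, 4, 3), (2, 2), SET_ITEM_BY_ONE_TENSOR)       ->  (2, 2, 4, 3)
#   generate_updates_shape((5, 4, 3), (2, 2), SET_ITEM_BY_TUPLE_OF_TENSOR)  ->  (2, 3)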


@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
    """Transforms a slice to a list of elements."""
    slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
    start, stop, step = normalize_slice(slice_obj, dim_len)
    slice_ele_list = list(range(start, stop, step))
    if not slice_ele_list:
        raise IndexError(f"An empty slice is not supported, got {slice_obj}")
    return slice_ele_list


@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                    slice_shapes, op_name, fancy_position=None):
    """
    Generates index info from a tuple of mixed tensors: the broadcast shape,
    the new shape for the index tensors, the final shape and the fancy position.
    """
    tensor_positions = tuple(sorted(tensor_positions))
    if fancy_position is None:
        tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
        fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
    broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)

    final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
    index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
        broadcast_shape + (1,) * len(slice_shapes[fancy_position:])

    return broadcast_shape, index_tensor_new_shape, final_shape, fancy_position


def _judge_order_continuous(order_sequence):
    """Judges whether the positions in a sequence are consecutive."""
    if not order_sequence:
        return False
    for idx1, idx2 in zip(order_sequence[:-1], order_sequence[1:]):
        if idx1 + 1 != idx2:
            return False
    return True


@constexpr
def scalar_in_sequence(x, y):
    """Determines whether the scalar is in the sequence."""
    if x is None:
        raise ValueError("Judging whether a scalar is in a tuple or list requires both the scalar "
                         "and the sequence to be constant, but the scalar is not.")
    if y is None:
        raise ValueError("Judging whether a scalar is in a tuple or list requires both the scalar "
                         "and the sequence to be constant, but the sequence is not.")
    return x in y


@constexpr
def get_np_eps(input_dtype):
    """Gets the numpy eps for the given dtype."""
    nptype = mstype.dtype_to_nptype(input_dtype)
    eps = np.finfo(nptype).eps
    return float(eps)


@constexpr
def check_number_index_type(number):
    """Checks whether `number` is an int or a bool."""
    if isinstance(number, bool):
        return BOOL_
    if isinstance(number, int):
        return INT_
    raise IndexError("Only integers, slices (`:`), ellipsis (`...`), None and bool are supported, "
                     "but got {0} of type {1}".format(number, type(number)))


@constexpr
def get_stride_info_from_slice(data_shape, slice_index):
    """Gets stride info from a python slice."""
    begin, end, step = get_slice_stride(slice_index, data_shape[0])
    begin_strides = [begin]
    end_strides = [end]
    step_strides = [step]
    for dim_size in data_shape[1:]:
        begin_strides.append(0)
        end_strides.append(dim_size)
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


@constexpr
def get_stride_info_from_integer(data_shape, number):
    """Gets stride info from an integer."""
    begin_strides = [number]
    end_strides = [number + 1]
    step_strides = [1]
    for dim_size in data_shape[1:]:
        begin_strides.append(0)
        end_strides.append(dim_size)
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


def get_slice_stride(index_slice, dim_size):
    """Gets slice stride info."""
    step = 1 if index_slice.step is None else index_slice.step
    start_default = 0
    stop_default = dim_size
    if step < 0:
        start_default = -1
        stop_default = -(dim_size + 1)
    start = start_default if index_slice.start is None else index_slice.start
    stop = stop_default if index_slice.stop is None else index_slice.stop
    return start, stop, step
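# For example (traced from the defaults above), for a dimension of size 5:
#   get_slice_stride(slice(1, None), 5)         ->  (1, 5, 1)
#   get_slice_stride(slice(None, None, -1), 5)  ->  (-1, -6, -1)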


@constexpr
def get_stride_info_from_tuple(data_shape, tuple_index):
    """Gets stride info from a tuple."""
    begin_strides, end_strides, step_strides = [], [], []
    tuple_index_len = len(tuple_index)
    data_dim = len(data_shape)
    shrink_axis, index_count, ellipsis_count = 0, 0, 0
    for index, dim_size in zip(tuple_index, data_shape):
        if isinstance(index, slice):
            start, stop, step = get_slice_stride(index, dim_size)
            begin_strides.append(start)
            end_strides.append(stop)
            step_strides.append(step)
            index_count = index_count + 1
        elif isinstance(index, int):
            begin_strides.append(index)
            end_strides.append(index + 1)
            step_strides.append(1)
            shrink_axis = shrink_axis + (1 << index_count)
            index_count = index_count + 1
        elif index is ...:
            ellipsis_count = ellipsis_count + 1
            if ellipsis_count > 1:
                raise IndexError("An index can have only one ellipsis (...)")
            ellipsis_range_size = data_dim - tuple_index_len + 1
            begin_strides.extend([0] * ellipsis_range_size)
            end_strides.extend(data_shape[index_count: index_count + ellipsis_range_size])
            step_strides.extend([1] * ellipsis_range_size)
            index_count = index_count + ellipsis_range_size
        else:
            raise IndexError(f"Unsupported index data type: got {index} of type {type(index)}")
    for index in range(index_count, data_dim):
        begin_strides.append(0)
        end_strides.append(data_shape[index])
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis


@constexpr
def scalar_to_tensor(x):
    """Converts a scalar to a tensor."""
    return Tensor(x)


@constexpr
def unpack(x):
    """Recursively unwraps single-element tuples and lists."""
    if isinstance(x, (tuple, list)) and len(x) == 1:
        return unpack(x[0])
    return x


@constexpr
def normalize_start(start, dim_size):
    """
    Normalizes `start` against the dimension size (`dim_size`):
    None maps to 0, negative values wrap around, and values beyond
    the dimension are clamped to `dim_size`.
    """
    if start is None:
        return 0
    if start < 0:
        return 0 if start < -dim_size else start % dim_size
    return start if start < dim_size else dim_size


@constexpr
def normalize_stop(stop, dim_size):
    """
    Normalizes `stop` against the dimension size (`dim_size`):
    None maps to `dim_size`, negative values wrap around, and values
    beyond the dimension are clamped to `dim_size`.
    """
    if stop is None:
        return dim_size
    if stop < 0:
        return 0 if stop < -dim_size else stop % dim_size
    return stop if stop < dim_size else dim_size
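# For example (traced from the logic above), with dim_size=5:
#   normalize_start(None, 5)  ->  0    normalize_stop(None, 5)  ->  5
#   normalize_start(-2, 5)    ->  3    normalize_stop(-9, 5)    ->  0
#   normalize_start(9, 5)     ->  5    normalize_stop(9, 5)     ->  5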


@constexpr
def normalize_slice(input_slice, dim_size):
    """Normalizes start, stop, step in a slice."""
    step = input_slice.step
    if step is None:
        step = 1
    if step >= 0:
        start = normalize_start(input_slice.start, dim_size)
        stop = normalize_stop(input_slice.stop, dim_size)
    else:
        start = normalize_stop(input_slice.start, dim_size)
        stop = normalize_start(input_slice.stop, dim_size)
    return start, stop, step


@constexpr
def tuple_slice(tup, start, end):
    """Gets a sliced tuple from start and end."""
    return tup[start:end]


@constexpr
def expanded_shape(shape, expand_size):
    """Prepends `expand_size` leading 1s to `shape`."""
    return (1,)*expand_size + shape


@constexpr
def sequence_mul_int(seq, number):
    """
    Make a new list with native python syntax.

    Args:
        seq (Union[list, tuple]): Input sequence.
        number (int): Input number.

    Returns:
        New sequence, has the same type as `seq`.
    """
    if not isinstance(number, int):
        raise TypeError(f"can't multiply sequence by non-int of type {type(number)}")
    return seq * number


@constexpr
def check_in_sequence(x, y):
    """Determines whether the input `x` is in the sequence `y`."""
    return x in y


@constexpr
def is_slice(x):
    """Checks whether `x` is a slice."""
    return isinstance(x, slice)


@constexpr
def filter_expanded_dims(shape, not_expanded_dim):
    """Drops the dimensions of `shape` that are marked False in `not_expanded_dim`."""
    diff = len(not_expanded_dim) - len(shape)
    if diff < 0:
        raise ValueError(f'unable to broadcast {shape}')
    return tuple(compress(shape, not_expanded_dim[diff:]))
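# For example (traced from the logic above):
#   filter_expanded_dims((2, 1, 3), (True, False, True))  ->  (2, 3)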


@constexpr
def sequence_to_index(sequence, dim_size):
    """Transforms a sequence to a tensor index."""
    if not sequence:
        return False
    if all(isinstance(i, bool) for i in sequence):
        seq_size = len(sequence)
        if seq_size != dim_size:
            raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
        sequence = tuple(compress(range(dim_size), sequence))
        if not sequence:
            return False
    return make_tensor(sequence, mstype.int64, None, dim_size)
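# For example (traced from the logic above): a boolean mask selects positions,
# while an integer sequence is converted directly:
#   sequence_to_index((True, False, True), 3)  ->  int64 Tensor([0, 2])
#   sequence_to_index((1, 2), 4)               ->  int64 Tensor([1, 2])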


@constexpr
def int_to_index(i, shape):
    """Converts an integer to tensor indices."""
    dim_size = shape[0]
    if i < -dim_size or i >= dim_size:
        raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
    i = i % dim_size
    if len(shape) == 1:
        return Tensor([[i]])
    grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]]
    mesh = np.ix_(*grids)
    index = np.stack(np.broadcast_arrays(*mesh), -1)
    return Tensor(np.insert(index, 0, i, -1))
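# For example (traced from the logic above):
#   int_to_index(-1, (2,))   ->  Tensor([[1]])
#   int_to_index(1, (2, 3))  ->  Tensor([[1, 0], [1, 1], [1, 2]]) of shape (3, 2)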


@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
    """Adds the remaining dimensions that are not indexed to not_expanded_dim."""
    if idx_advanced != -1:
        if expand_true:
            # tensor indices generate only one dimension with size 1
            tensor_dims = (False,)
        else:
            tensor_dims = (True,)*tensor_index_ndim
        not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
    not_expanded_dim = not_expanded_dim + (True,)*rem_ndim

    count_leading_false = 0
    while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
        count_leading_false += 1
    idx_advanced = max(0, idx_advanced - count_leading_false)
    return not_expanded_dim, idx_advanced


@constexpr
def check_slice_empty(start, stop, step):
    """Checks whether a normalized slice selects no elements."""
    return (start - stop)*step >= 0


@constexpr
def real_axes(ndim_orig, ndim_out, axes_orig):
    """Returns the real axes to be reduced after performing broadcast."""
    _diff = ndim_out - ndim_orig
    axes = tuple(range(_diff))
    axes_orig = map(partial(operator.add, _diff), axes_orig)
    return axes + tuple(axes_orig)
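# For example (traced from the logic above): broadcasting from 2 to 4 dims
# prepends axes (0, 1) and shifts the original axes by 2:
#   real_axes(2, 4, (1,))  ->  (0, 1, 3)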


check_axis_valid_const = constexpr(validator.check_axis_valid)


@constexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
    """Computes slice tensor shapes."""
    shape = [1] * len(slice_shape)
    shape[slice_cnt] = slice_shape[slice_cnt]
    shape = shape[:fancy_position] + [1] * broadcast_shape_len + shape[fancy_position:]
    return shape


@constexpr
def infer_out_shape(*shapes):
    """
    Returns the shape of the output after broadcasting. Raises ValueError if the shapes cannot be broadcast.
    """
    shape_out = deque()
    reversed_shapes = map(reversed, shapes)
    for items in zip_longest(*reversed_shapes, fillvalue=1):
        max_size = 0 if 0 in items else max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
        shape_out.appendleft(max_size)
    return tuple(shape_out)
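# For example (traced from the logic above):
#   infer_out_shape((2, 1, 3), (4, 3))  ->  (2, 4, 3)
#   infer_out_shape((2, 3), (3, 3))     ->  ValueError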


@constexpr
def use_copy_slice(tuple_index):
    """Checks whether the tuple index matches the copy-slice pattern: an int, a step-1 slice, then only full slices."""
    if tuple_index is not None and len(tuple_index) >= 2:
        return (isinstance(tuple_index[0], int) and
                isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
                all(x == slice(None, None, None) for x in tuple_index[2:]))
    return False
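# For example (traced from the logic above):
#   use_copy_slice((0, slice(2, 5)))     ->  True
#   use_copy_slice((0, slice(2, 5, 2)))  ->  False  (step is not 1)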


@constexpr
def gen_exception_msg(msg_format, *args):
    """Formats an exception message from a format string and its arguments."""
    return msg_format.format(*args)