# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""constexpr util"""
16
17import operator
18from functools import partial
19from itertools import compress
20
21import numpy as np
22from mindspore.common import dtype as mstype
23from mindspore.common._register_for_tensor import tensor_operator_registry
24from mindspore.common.tensor import Tensor
25from mindspore.ops import operations as P
26from mindspore.ops.operations import _inner_ops
27from mindspore.ops.primitive import constexpr, _primexpr
28from mindspore import log as logger
29from mindspore import context
30from mindspore._c_expression import Tensor as Tensor_
31
32ALL_TENSOR = 0
33NO_TENSOR = 1
34CONTAIN_TENSOR = 2
35ALL_SCALAR = 3
36ALL_BASIC = 7
37MIXED = 8
38
39INT_ = 0
40BOOL_ = 1
41UNSUPPORTED_DTYPE = 2
42
43TENSOR_SETITEM = "tensor setitem"
44TENSOR_GETITEM = "tensor getitem"
45
46SET_ITEM_BY_ONE_TENSOR = 0
47SET_ITEM_BY_TUPLE_OF_TENSOR = 1
48SET_ITEM_BY_NON_TENSOR = 2
49
50type_priority_map = {
51    mstype.bool_: 0,
52    mstype.uint8: 1,
53    mstype.int8: 2,
54    mstype.uint16: 3,
55    mstype.int16: 4,
56    mstype.uint32: 5,
57    mstype.int32: 6,
58    mstype.uint64: 7,
59    mstype.int64: 8,
60    mstype.float16: 9,
61    mstype.float32: 10,
62    mstype.float64: 11
63}
64
65complex_priority_map = {
66    mstype.float32: 0,
67    mstype.float64: 1,
68    mstype.complex64: 2,
69    mstype.complex128: 4
70}
71
72complex_types = [mstype.complex64, mstype.complex128]
73
74
@constexpr
def raise_value_error(msg):
    """Constexpr for raise_value_error."""
    raise ValueError(msg)


@constexpr
def raise_index_error(msg):
    """Constexpr for raise_index_error."""
    raise IndexError(msg)


@constexpr
def raise_type_error(msg):
    """Constexpr for raise_type_error."""
    raise TypeError(msg)


@constexpr
def raise_unimplemented_error(msg):
    """Constexpr for raise_unimplemented_error."""
    raise NotImplementedError(msg)


@constexpr
def log_warning(msg):
    """Adds a warning to the logger."""
    logger.warning(msg)


@constexpr
def check_equal(param1, param2, msg="{},{}"):
    """Checks whether the two parameters are equal."""
    if param1 != param2:
        raise ValueError(msg.format(param1, param2))
    return param1


@constexpr
def make_empty_slice():
    """Creates an empty slice."""
    return slice(None, None, None)


@_primexpr
def _deep_list(array_like, dim_size=None):
    """Converts nested tuple/list mixtures to a pure nested list."""
    if dim_size is not None:
        array_like = check_range(array_like, dim_size)
    if isinstance(array_like, (list, tuple)):
        return list(map(lambda x: _deep_list(x, dim_size), array_like))
    return array_like


@constexpr
def deep_tuple(array_like):
    """Converts nested tuple/list mixtures to a pure nested tuple."""
    if isinstance(array_like, (list, tuple)):
        return tuple(map(deep_tuple, array_like))
    return array_like


def _deep_tensor_to_nparray(array_like):
    """
    Converts a nested list of tensors to a nested list of numpy arrays.

    Args:
        array_like (list(tensor)): Nested lists in any format that may contain
            tensors.

    Returns:
        array_like (list(np_array)): Formatted array that can be directly processed
            by numpy.array(), with all tensor elements converted to numpy arrays.
    """
    # Recursively check whether each element is a tensor; if it is a tensor,
    # convert it to a numpy array in place
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    if isinstance(array_like, list):
        for idx, value in enumerate(array_like):
            array_like[idx] = _deep_tensor_to_nparray(value)

    return array_like


@_primexpr
def check_range(x, dim_size):
    """Checks that an integer index is in [-dim_size, dim_size) and wraps negative indices."""
    if dim_size is None:
        return x
    if isinstance(x, int) and not isinstance(x, bool):
        if x >= dim_size or x < -dim_size:
            raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
        x = x % dim_size
    return x


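# A sketch of the expected behavior (illustrative values):
#   check_range(-2, 5)   -> 3   (negative index wrapped via -2 % 5)
#   check_range(4, 5)    -> 4   (in range, returned unchanged)
#   check_range(5, 5)    -> IndexError
#   check_range(True, 5) -> True (bools are passed through untouched)
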
@_primexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
    """
    Converts the input to a tensor.

    This function creates tensors from an array-like object.

    Args:
        a (Union[int, float, bool, list, tuple]): Input data, in any form that can
            be converted to a `Tensor`.
        dtype (:class:`mindspore.dtype`): Designated tensor dtype.
        data_shape (tuple): If given, a zero-filled tensor of this shape is returned.
        dim_size (int): If given, integer indices in `a` are range-checked and
            negative indices are wrapped against this dimension size.

    Returns:
        Tensor, generated tensor with the specified dtype.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If input `a` has different sizes at different dimensions.
    """
    if data_shape:
        return Tensor(np.zeros(data_shape), dtype)

    if not isinstance(a, (list, tuple, int, float, bool)):
        raise TypeError(f"Input data must be `int`, `float`, `bool`, `list` or `tuple`, but got {a}")

    if dim_size is not None:
        a = check_range(a, dim_size)

    if isinstance(a, int):
        return P.ScalarToTensor()(a, dtype)

    if isinstance(a, (list, tuple)):
        if not a:
            return Tensor_(a, dtype)
        # Convert all tuples/nested tuples to lists
        a = _deep_list(a, dim_size)
        # Convert all tensor sub-elements to numpy arrays
        a = _deep_tensor_to_nparray(a)
        a = np.asarray(a)
        if a.dtype is np.dtype('object'):
            raise ValueError('Input array must have the same size across all dimensions.')

    if isinstance(a, np.ndarray):
        if a.dtype is np.dtype('object'):
            raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")

    return Tensor(a, dtype)


setattr(tensor_operator_registry, 'make_tensor', make_tensor)


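# A sketch of how make_tensor behaves (illustrative values, assuming a working
# MindSpore install):
#   make_tensor([1, 2, 3])             -> Tensor([1, 2, 3]) with dtype int64
#   make_tensor(1, mstype.float32)     -> scalar Tensor 1.0
#   make_tensor([0, -1], dim_size=4)   -> Tensor([0, 3]); negatives are wrapped
#   make_tensor(0, data_shape=(2, 2))  -> zero-filled Tensor of shape (2, 2)
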
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
    """Judges whether the data dim is valid."""
    if data_dim < min_data_dim or data_dim > max_data_dim:
        raise ValueError(f"The input data's dim must be in the range of [{min_data_dim}, "
                         f"{max_data_dim}], but got '{data_dim}'.")


def get_source_shape(data_shape, value_shape):
    """Returns the shape of value that will be used to broadcast against data."""
    if len(value_shape) > len(data_shape):
        return data_shape
    return value_shape


@constexpr
def check_tensor_setitem_index(index, element_type=None):
    """Checks the tuple index type of a tensor assignment."""
    if index is None:
        raise IndexError("Tensor's index cannot be None.")
    if isinstance(index, slice):
        return True
    if isinstance(index, tuple):
        if not index:
            raise IndexError("Tensor's index cannot be empty.")
        for item in index:
            if not isinstance(item, (slice, type(...), int)):
                raise IndexError(
                    "Index of type '{}' is not supported yet.".format(type(item)))
        return True
    if isinstance(index, mstype.TensorType):
        if element_type is None or element_type != mstype.bool_:
            raise TypeError(
                "The index of a tensor should be a bool type tensor. "
                "{} type is not supported yet.".format(element_type))
        return True

    raise IndexError(
        "Index of type '{}' is not supported yet.".format(type(index)))


@constexpr
def is_same_type(inst, type_):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        inst (mindspore.dtype): Inspected type.
        type_ (mindspore.dtype): Target type.

    Outputs:
        bool, the check result.
    """
    return inst == type_


@constexpr
def check_valid_dim(dim, name):
    """Checks whether the dim is valid."""
    if dim not in (1, 2):
        raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")


@constexpr
def judge_index_type(index_type, target_type):
    """Judges whether the index type is valid."""
    if index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type):
        return True
    return False


@constexpr
def judge_indexes_types(dtypes, target_type):
    """Checks the data types of a tuple of tensors."""
    for dtype in dtypes:
        if isinstance(target_type, (list, tuple)):
            if dtype not in target_type:
                return False
        else:
            if dtype != target_type:
                return False
    return True


@constexpr
def check_type_isinstance(dtype, target_type):
    """Checks whether the dtype is an instance of the target type."""
    if isinstance(dtype, (list, tuple)):
        return all(isinstance(ele, target_type) for ele in dtype)
    return isinstance(dtype, target_type)


@constexpr
def check_type_invalid(dtype, target_type):
    """Checks whether the dtype is invalid."""
    return dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type)


@constexpr
def check_type_valid(dtype, target_type, op_name):
    """Checks whether the dtype is valid."""
    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
        if op_name in (TENSOR_GETITEM, TENSOR_SETITEM):
            raise IndexError(
                f"The '{op_name}' doesn't support '{dtype}' and expects to receive {target_type}.")
        raise TypeError(
            f"The '{op_name}' doesn't support '{dtype}' and expects to receive {target_type}.")


@constexpr
def check_types_valid(dtypes, target_type, op_name):
    """Checks the data types of a tuple of tensors."""
    for dtype in dtypes:
        check_type_valid(dtype, target_type, op_name)


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """Separates the position information of tensor, slice, and ellipsis from the mixed tensors index."""
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = (), (), (), (), (), (), ()
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.Slice):
            slice_positions += (i,)
        elif isinstance(index_type, mstype.Ellipsis_):
            ellipsis_positions += (i,)
        elif isinstance(index_type, mstype.NoneType):
            none_positions += (i,)
        elif isinstance(index_type, mstype.Int):
            int_positions += (i,)
        elif isinstance(index_type, mstype.Bool):
            bool_positions += (i,)
        elif isinstance(index_type, mstype.TensorType):
            tensor_positions += (i,)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions += (i,)
        else:
            raise TypeError(f"For '{op_name}', the types only support 'Slice', 'Ellipsis', 'None', 'Tensor', 'int', "
                            f"'List', 'Tuple', 'bool', but got {index_type}.")
    if len(ellipsis_positions) > 1:
        raise IndexError(
            f"For '{op_name}', an index can only have a single ellipsis ('...').")

    return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
        tensor_positions, sequence_positions


def ellipsis2slice(input_, shape):
    """Converts an ellipsis to slices."""
    input_slice = input_
    result = []
    if isinstance(input_, type(...)):
        input_slice = (input_,)
    ell_count = 0
    for element in input_slice:
        if not isinstance(element, type(...)):
            result.append(element)
            continue
        ell_count += 1
        if ell_count > 1:
            raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
                             "but it is currently {}".format(input_slice))
        for _ in range(len(shape) - len(input_slice) + 1):
            result.append(slice(None, None, None))
    return tuple(result)


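# Illustration (hypothetical values): with a tensor of shape (3, 4, 5),
#   ellipsis2slice((..., 0), (3, 4, 5))
# expands the ellipsis into enough full slices to cover the missing dims:
#   (slice(None, None, None), slice(None, None, None), 0)
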
@constexpr
def slice2indices(input_slice, shape):
    """
    Converts a slice to indices.

    Inputs:
        input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice.
        shape (tuple): The shape of the tensor, a tuple of integers.

    Outputs:
        Tensor, the stacked index grid, or False if the slice is empty.
    """
    start, stop, step = normalize_slice(input_slice, shape[0])
    if check_slice_empty(start, stop, step):
        return False
    ndim = len(shape)
    mesh = list()
    range_op = P.Range()
    cast_op = P.Cast()
    grids = [
        range_op(cast_op(start, mstype.int64), cast_op(stop, mstype.int64),
                 cast_op(step, mstype.int64))
    ]
    grids += [
        range_op(Tensor(0, mstype.int64), cast_op(dim_size, mstype.int64),
                 Tensor(1, mstype.int64)) for dim_size in shape[1:]
    ]
    for j, grid in enumerate(grids):
        mesh.append(P.Reshape()(grid, tuple(
            [grid.size if j == t else 1 for t in range(ndim)])))
    shapes = map(P.Shape(), mesh)
    out_shape = infer_out_shape(*shapes)
    mesh_arrays = list()
    for arr in mesh:
        mesh_arrays.append(P.BroadcastTo(out_shape)(arr))
    return P.Stack(-1)(mesh_arrays)


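# Sketch of the result (illustrative): for slice(0, 2) on a tensor of shape
# (2, 3), the slice grid is meshed against the remaining dimension and stacked
# along the last axis, giving an index tensor of shape (2, 3, 2):
#   [[[0, 0], [0, 1], [0, 2]],
#    [[1, 0], [1, 1], [1, 2]]]
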
@constexpr
def check_indices(indices_size, index):
    """Checks whether indices is empty."""
    if indices_size < 1:
        raise IndexError(
            "The tensor's index is unreasonable. index:{}".format(index))
    return indices_size


@_primexpr
def check_indices_value_size(indices_size, value_size):
    """Checks whether the sizes match."""
    if value_size < 1:
        raise ValueError("The value assigned to tensor cannot be empty.")
    if value_size > 1:
        if value_size != indices_size:
            raise ValueError(
                "The value given to tensor does not match the index size,"
                " value size:{}, indices size:{}".format(value_size, indices_size))
    return value_size


@constexpr
def tuple_index_type_cnt(types, op_name):
    """Counts the tensor types contained in the tuple elements' types."""
    if all(isinstance(ele, mstype.TensorType) for ele in types):
        return ALL_TENSOR
    if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types):
        return ALL_BASIC
    return MIXED


@constexpr
def check_value_elements(types):
    """Judges the type of all elements of the tuple."""
    tensor_number = 0
    last_type = None
    mix_but_no_tensor = False
    for ele in types:
        if isinstance(ele, mstype.TensorType):
            tensor_number += 1
        elif isinstance(ele, (list, tuple)):
            return MIXED

        if last_type is None:
            last_type = type(ele)
        elif not isinstance(ele, last_type):
            mix_but_no_tensor = True

    if tensor_number == 0:
        if mix_but_no_tensor:
            return MIXED
        return NO_TENSOR
    if tensor_number == len(types):
        return ALL_TENSOR
    return CONTAIN_TENSOR


@constexpr
def get_index_tensor_dtype(dtype):
    """Checks the data type of an index tensor."""
    if dtype in mstype.int_type:
        return INT_
    if dtype == mstype.bool_:
        return BOOL_
    raise IndexError(
        f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
    """Checks whether the tensor data types are the same."""
    if value_dtype == data_dtype:
        return True
    raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
                    f"is not consistent with the assigned tensor data type {data_dtype}.")


@constexpr
def get_broadcast_shape(x_shape, y_shape, prim_name):
    """Gets the broadcast shape from the input shapes."""
    if x_shape is None or y_shape is None:
        raise ValueError("get_broadcast_shape has dynamic rank input")
    if None in x_shape or None in y_shape:
        raise ValueError("get_broadcast_shape has dynamic shape input")
    if x_shape == y_shape:
        return x_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []

    for i in range(-length, 0):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError(f"For '{prim_name}', x.shape and y.shape need to "
                             f"broadcast. The value of x.shape[{i}] or y.shape[{i}]"
                             f" must be 1 or -1 when they are not the same, "
                             f"but got x.shape = {x_shape} "
                             f"and y.shape = {y_shape}.")

    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
    return broadcast_shape


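# Worked example (illustrative shapes): trailing dimensions are matched
# right-to-left (1 broadcasts against 4, 3 matches 3) and the leading (8,)
# is carried over from the longer shape:
#   get_broadcast_shape((8, 1, 3), (4, 3), "test")  ->  [8, 4, 3]
#   get_broadcast_shape((2, 3), (4, 3), "test")     ->  ValueError
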
@constexpr
def generate_broadcast_shape(shapes, op_name):
    """Generates a broadcast shape for a tuple of shapes."""
    if not shapes:
        return ()
    broadcast_shape = shapes[0]
    for shape in shapes:
        broadcast_shape = get_broadcast_shape(tuple(broadcast_shape), shape, op_name)
    return tuple(broadcast_shape)


@_primexpr
def compute_multiples(origin_shape, broadcast_shape):
    """Computes the per-axis tile multiples that expand the origin shape to the broadcast shape."""
    len_gap = len(broadcast_shape) - len(origin_shape)
    return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], tuple(origin_shape)))


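# Illustration: leading new axes are tiled by their full size, existing axes
# by the ratio of the two shapes:
#   compute_multiples((1, 3), (2, 4, 3))  ->  (2, 4, 1)
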
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
    """Converts a scalar to a tensor."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = indices_shape + data_shape[1:]
    else:
        updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
    return P.FillV2()(updates_shape, P.Cast()(value, data_dtype))


def generate_updates_shape(data_shape, index_shape, op_type):
    """Generates the updates shape for 'tensor setitem'."""
    if op_type == SET_ITEM_BY_ONE_TENSOR:
        updates_shape = index_shape + data_shape[1:]
    else:
        updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
    return updates_shape


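# Illustration (hypothetical shapes) of the scatter-style updates shape:
#   SET_ITEM_BY_ONE_TENSOR:      data (5, 3), index (4,)      -> updates (4, 3)
#   SET_ITEM_BY_TUPLE_OF_TENSOR: data (5, 6, 7), index (4, 2) -> updates (4, 7)
# i.e. one update row per index row, each covering the data dimensions that
# the index does not address.
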
@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
    """Transforms a slice to a list of elements."""
    slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
    start, stop, step = normalize_slice(slice_obj, dim_len)
    slice_ele_list = list(range(start, stop, step))
    if not slice_ele_list:
        raise IndexError(f"An empty slice is not supported, got {slice_obj}")
    return slice_ele_list


@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                    slice_shapes, op_name, fancy_position=None):
    """
    Generates index info, which contains the broadcast shape, the final shape,
    the indexes shapes info, and the ellipsis size, from a tuple of mixed tensors.
    """
    tensor_positions = tuple(sorted(tensor_positions))
    if fancy_position is None:
        tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
        fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
    broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)

    final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
    index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
        broadcast_shape + (1,) * len(slice_shapes[fancy_position:])

    return broadcast_shape, index_tensor_new_shape, final_shape, fancy_position


def _judge_order_continuous(order_sequence):
    """Judges whether the positions in the sequence are consecutive."""
    if not order_sequence:
        return False
    for idx1, idx2 in zip(order_sequence[:-1], order_sequence[1:]):
        if idx1 + 1 != idx2:
            return False
    return True


@constexpr
def scalar_in_sequence(x, y):
    """Determines whether the scalar is in the sequence."""
    return x in y


@constexpr
def get_np_eps(input_dtype):
    """Gets the numpy eps for the given dtype."""
    nptype = mstype.dtype_to_nptype(input_dtype)
    eps = np.finfo(nptype).eps
    return float(eps)


@constexpr
def check_number_index_type(number):
    """Checks whether the number is an int or a bool."""
    if isinstance(number, bool):
        return BOOL_
    if isinstance(number, int):
        return INT_
    raise IndexError("Only integers, slices (`:`), ellipsis (`...`), None and bool are supported, "
                     "but got {0} of type {1}.".format(number, type(number)))


@constexpr
def scalar_to_tensor(x):
    """Converts a scalar to a tensor."""
    return Tensor(x)


@constexpr
def unpack(x):
    """Unwraps single-element sequences recursively."""
    if isinstance(x, (tuple, list)) and len(x) == 1:
        return unpack(x[0])
    return x


@_primexpr
def normalize_start(start, dim_size):
    """
    Normalizes `start` against the dimension size (`dim_size`).
    If the dimension size is not given, returns the original input directly.
    """
    if start is None:
        return 0
    if dim_size is None:
        return start
    if start < 0:
        return 0 if start < -dim_size else start % dim_size
    return start if start < dim_size else dim_size


@_primexpr
def normalize_stop(stop, dim_size):
    """
    Normalizes `stop` against the dimension size (`dim_size`).
    If the dimension size is not given, returns the original input directly.
    """
    if stop is None and dim_size is None:
        raise IndexError("`stop` cannot be None when the dimension size is dynamic.")
    if stop is None:
        return dim_size
    if dim_size is None:
        return stop
    if stop < 0:
        return 0 if stop < -dim_size else stop % dim_size
    return stop if stop < dim_size else dim_size


@constexpr
def get_step_from_slice(input_slice):
    """Gets the step of a slice."""
    step = input_slice.step
    if step is None:
        step = 1
    return step


@_primexpr
def normalize_slice(input_slice, dim_size):
    """Normalizes start, stop, and step in a slice."""
    step = input_slice.step
    if step is None:
        step = 1
    if step >= 0:
        start = normalize_start(input_slice.start, dim_size)
        stop = normalize_stop(input_slice.stop, dim_size)
    else:
        start = normalize_stop(input_slice.start, dim_size)
        stop = normalize_start(input_slice.stop, dim_size)
    return start, stop, step


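# Sketch of expected behavior with a positive step (illustrative values):
#   normalize_slice(slice(1, None, 2), 6)   ->  (1, 6, 2)
#   normalize_slice(slice(-2, None, 1), 6)  ->  (4, 6, 1)
# Negative bounds are wrapped and out-of-range bounds are clamped to the
# dimension size, so the result can feed range()/P.Range() directly.
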
@constexpr
def tuple_slice(tup, start, end):
    """Gets a sliced tuple from start to end."""
    return tup[start:end]


def expanded_shape(shape, expand_size):
    """Prepends `expand_size` leading dimensions of size 1 to `shape`."""
    return (1,)*expand_size + shape


@constexpr
def sequence_mul_int(seq, number):
    """
    Makes a new sequence with native Python syntax.

    Args:
        seq (Union[list, tuple]): Input sequence.
        number (int): Input number.

    Returns:
        New sequence, has the same type as `seq`.
    """
    if not isinstance(number, int):
        raise TypeError(f"can't multiply sequence by non-int of type {type(number)}")
    return seq * number


@constexpr
def check_in_sequence(x, y):
    """Determines whether the input `x` is in the sequence `y`."""
    return x in y


@constexpr
def is_slice(x):
    """Checks whether `x` is a slice."""
    return isinstance(x, slice)


@_primexpr
def filter_expanded_dims(shape, not_expanded_dim):
    """Keeps only the dimensions of `shape` whose right-aligned flag in `not_expanded_dim` is True."""
    def _check(diff, shape):
        if diff < 0:
            raise ValueError(f'unable to broadcast {shape}')

    diff = len(not_expanded_dim) - len(shape)
    _check(diff, shape)
    res = list()
    for i, flag in zip(shape, not_expanded_dim[diff:]):
        if flag:
            res.append(i)
    return tuple(res)


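# Illustration (hypothetical flags):
#   filter_expanded_dims((2, 3), (False, True))        ->  (3,)
#   filter_expanded_dims((1, 3), (True, False, True))  ->  (3,)
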
@constexpr
def sequence_to_index(sequence, dim_size):
    """Transforms a sequence to a tensor index."""
    if not sequence:
        return False
    if all(isinstance(i, bool) for i in sequence):
        if dim_size is None:
            return Tensor(sequence)
        seq_size = len(sequence)
        if seq_size != dim_size:
            raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
        sequence = tuple(compress(range(dim_size), sequence))
        if not sequence:
            return False
    return make_tensor(sequence, mstype.int64, None, dim_size)


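# Sketch (illustrative): boolean sequences act as a mask over the dimension,
# other sequences become integer index tensors directly:
#   sequence_to_index([True, False, True], 3)  ->  Tensor([0, 2])
#   sequence_to_index([0, -1], 4)              ->  Tensor([0, 3])
#   sequence_to_index([], None)                ->  False (empty index)
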
@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
    """Adds the remaining dimensions not indexed to not_expanded_dim."""
    if idx_advanced != -1:
        if expand_true:
            # tensor indices generate only one dimension with size 1
            tensor_dims = (False,)
        else:
            tensor_dims = (True,)*tensor_index_ndim
        not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
    not_expanded_dim = not_expanded_dim + (True,)*rem_ndim

    count_leading_false = 0
    while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
        count_leading_false += 1
    idx_advanced = max(0, idx_advanced - count_leading_false)
    return not_expanded_dim, idx_advanced


@_primexpr
def check_slice_empty(start, stop, step):
    """Checks whether a normalized slice is empty."""
    return (start - stop) * step >= 0


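# A slice is empty exactly when start has already met or passed stop in the
# direction of travel (illustrative values):
#   check_slice_empty(0, 5, 1)   ->  False  (forward, room to move)
#   check_slice_empty(5, 0, -1)  ->  False  (backward, room to move)
#   check_slice_empty(3, 3, 1)   ->  True   (start == stop)
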
@_primexpr
def real_axes(ndim_orig, ndim_out, axes_orig):
    """Returns the real axes to be reduced after performing broadcast."""
    _diff = ndim_out - ndim_orig
    axes = tuple(range(_diff))
    axes_orig = map(partial(operator.add, _diff), axes_orig)
    return axes + tuple(axes_orig)


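# Illustration: after broadcasting a 2-D array to 4-D, original axis 1 shifts
# to axis 3, and the two new leading axes are reduced as well:
#   real_axes(2, 4, (1,))  ->  (0, 1, 3)
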
@_primexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
    """Computes slice tensor shapes."""
    shape = [1] * len(slice_shape)
    shape[slice_cnt] = slice_shape[slice_cnt]
    shape = shape[:fancy_position] + [1] * broadcast_shape_len + shape[fancy_position:]
    return shape


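# Illustration (hypothetical arguments): with slice_shape [4, 5],
# broadcast_shape_len 2, slice_cnt 1 and fancy_position 1, only the current
# slice dimension is kept and 1s are spliced in for the broadcast dimensions:
#   [1, 5]  ->  [1, 1, 1, 5]
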
@_primexpr
def infer_out_shape(*shapes):
    """
    Returns the shape of the output after broadcasting. The input shapes are
    right-aligned; each output dimension is the maximum across the inputs,
    with an empty dimension (0) in any input making the output dimension empty.
    """
    shape_out = list()
    max_len = max([len(it) for it in shapes])

    for i in range(max_len):
        items = [it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
        max_size = 0 if 0 in items else max(items)
        shape_out.append(max_size)
    return tuple(shape_out)


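# Sketch (illustrative):
#   infer_out_shape((2, 1, 3), (1, 3))  ->  (2, 1, 3)
#   infer_out_shape((2, 0), (1,))       ->  (2, 0)
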
@constexpr(check=False)
def use_copy_slice(tuple_index):
    """Checks whether the tuple index is an int followed by a contiguous slice, with only full slices after."""
    if tuple_index is not None and len(tuple_index) >= 2:
        return (isinstance(tuple_index[0], int) and
                isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
                all(x == slice(None, None, None) for x in tuple_index[2:]))
    return False


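# Illustration: True only for patterns like data[i, a:b] (an int, then a
# step-1 slice, then nothing but full slices):
#   use_copy_slice((0, slice(2, 5)))     ->  True
#   use_copy_slice((0, slice(2, 5, 2)))  ->  False  (step != 1)
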
@constexpr
def is_ascend():
    """Checks whether the device target is Ascend."""
    return context.get_context('device_target') == "Ascend"


@constexpr
def gen_exception_msg(msg_format, *args):
    """Generates an exception message by formatting `msg_format` with `args`."""
    return msg_format.format(*args)


@constexpr
def get_output_dtype(dtype_1, dtype_2, use_complex=False):
    """Returns the output dtype after type promotion."""
    if use_complex:
        priority_map = complex_priority_map
        type_str = "Complex binary"
    else:
        priority_map = type_priority_map
        type_str = "Binary"
    priority_1 = priority_map.get(dtype_1, None)
    priority_2 = priority_map.get(dtype_2, None)
    # Compare against None explicitly: the lowest priority is 0, which is falsy.
    if priority_1 is None or priority_2 is None:
        raise ValueError(f"{type_str} op type promotion not supported for {dtype_1} and {dtype_2}")
    if priority_1 > priority_2:
        return dtype_1
    return dtype_2


@constexpr
def promote_binary_dtype(dtype_1, dtype_2):
    """
    Promotes binary types.
    """
    if dtype_1 == dtype_2:
        return dtype_1
    if dtype_1 in complex_types or dtype_2 in complex_types:
        return get_output_dtype(dtype_1, dtype_2, True)
    return get_output_dtype(dtype_1, dtype_2, False)


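# Sketch of the promotion rule (illustrative): the higher-priority dtype wins,
# and the complex table is used as soon as either side is complex:
#   promote_binary_dtype(mstype.int32, mstype.float32)      ->  mstype.float32
#   promote_binary_dtype(mstype.float64, mstype.complex64)  ->  mstype.complex64
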
@_primexpr
def generate_padding_shape(shape, length):
    """
    Pads `shape` to `length` with 1s.
    """

    if _inner_ops.IsConstant()(length) and _inner_ops.IsConstant()(len(shape)):
        if len(shape) > length:
            raise ValueError(f"Can not pad {shape} to length {length}.")

    return shape + (1,) * (length - len(shape))
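

# Illustration:
#   generate_padding_shape((2, 3), 4)  ->  (2, 3, 1, 1)
#   generate_padding_shape((2, 3), 1)  ->  ValueError (a shape cannot shrink)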
898