• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Operator argument handle function."""
16
17from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
18# Enum Class:
19from mindspore._c_expression import FormatEnum as Format
20from mindspore._c_expression import ReductionEnum as Reduction
21from mindspore.common import Tensor
22from mindspore.common import dtype as mstype
23
24
def arg_invalid_info(op_name, arg_name, arg_val):
    """
    build the error message used when an operator argument has an unsupported value.

    Args:
        op_name: name of the operator, quoted in the message.
        arg_name: name of the offending argument, quoted in the message.
        arg_val: the rejected value; its ``str`` form is quoted in the message.

    Returns:
        str: ``"For '<op>', the value of '<arg>' is invalid: '<val>'."``
    """
    template = "For '{}', the value of '{}' is invalid: '{}'."
    return template.format(op_name, arg_name, arg_val)
30
31
def to_pair(op_name, arg_name, arg_val):
    """
    normalize ``arg_val`` into a pair.

    A scalar int or float ``v`` becomes ``(v, v)``. A tuple or list is returned
    unchanged (same object, no length check, no conversion to tuple).

    Raises:
        ValueError: if ``arg_val`` is neither a number (int/float) nor a tuple/list.
    """
    is_scalar = isinstance(arg_val, (int, float))
    if not is_scalar and not isinstance(arg_val, (list, tuple)):
        raise ValueError(arg_invalid_info(op_name, arg_name, arg_val))
    return (arg_val,) * 2 if is_scalar else arg_val
41
42
def to_kernel_size(op_name, arg_name, kernel_size):
    """
    normalize a 2D kernel size.

    - int ``k``                 -> ``(k, k)``
    - 4-element tuple/list      -> tuple of its entries at indices 2 and 3
    - any other tuple/list      -> returned unchanged

    Raises:
        ValueError: if ``kernel_size`` is neither an int nor a tuple/list.
    """
    if not isinstance(kernel_size, (tuple, list)):
        if isinstance(kernel_size, int):
            return (kernel_size,) * 2
        raise ValueError(arg_invalid_info(op_name, arg_name, kernel_size))
    # Only the trailing two entries of a 4-element spec are kept.
    return (kernel_size[2], kernel_size[3]) if len(kernel_size) == 4 else kernel_size
54
55
def to_strides(op_name, arg_name, stride):
    """
    normalize a 2D stride specification.

    An int ``s`` becomes ``(s, s)``. A 4-element tuple/list is reduced to a
    tuple of its entries at indices 2 and 3. Any other tuple/list is returned
    unchanged.

    Raises:
        ValueError: if ``stride`` is neither an int nor a tuple/list.
    """
    if isinstance(stride, int):
        return (stride,) * 2
    if not isinstance(stride, (tuple, list)):
        raise ValueError(arg_invalid_info(op_name, arg_name, stride))
    if len(stride) != 4:
        return stride
    return tuple(stride[2:4])
67
68
def to_rates(op_name, arg_name, rates):
    """
    normalize a 2D rates specification.

    An int ``r`` becomes ``(r, r)``. A 4-element tuple/list is reduced to a
    tuple of its last two entries. Any other tuple/list is returned unchanged.

    Raises:
        ValueError: if ``rates`` is neither an int nor a tuple/list.
    """
    if isinstance(rates, (tuple, list)):
        if len(rates) != 4:
            return rates
        _, _, third, fourth = rates
        return (third, fourth)
    if isinstance(rates, int):
        return (rates, rates)
    raise ValueError(arg_invalid_info(op_name, arg_name, rates))
80
81
def to_dilations(op_name, arg_name, dilation):
    """
    normalize a 2D dilation specification.

    An int ``d`` becomes ``(d, d)``. A 4-element tuple/list is reduced to a
    tuple of its last two entries. Any other tuple/list is returned unchanged.

    Raises:
        ValueError: if ``dilation`` is neither an int nor a tuple/list.
    """
    if isinstance(dilation, int):
        return (dilation,) * 2
    if isinstance(dilation, (tuple, list)):
        needs_trim = len(dilation) == 4
        return tuple(dilation[-2:]) if needs_trim else dilation
    raise ValueError(arg_invalid_info(op_name, arg_name, dilation))
93
94
def to_output_padding(op_name, arg_name, output_padding):
    """
    normalize a 2D output_padding specification.

    An int ``p`` becomes ``(p, p)``. A 4-element tuple/list is reduced to a
    tuple of its entries at indices 2 and 3. Any other tuple/list is returned
    unchanged.

    Raises:
        ValueError: if ``output_padding`` is neither an int nor a tuple/list.
    """
    value = output_padding
    if not isinstance(value, (tuple, list)):
        if not isinstance(value, int):
            raise ValueError(arg_invalid_info(op_name, arg_name, output_padding))
        return (value, value)
    if len(value) == 4:
        return (value[2], value[3])
    return value
106
107
def to_2d_paddings(op_name, arg_name, pad):
    """
    normalize a 2D padding specification.

    An int ``p`` is expanded to ``(p, p)``. A tuple or list is returned
    unchanged (no length check).

    Raises:
        ValueError: if ``pad`` is neither an int nor a tuple/list.
    """
    if isinstance(pad, (tuple, list)):
        return pad
    if isinstance(pad, int):
        return (pad, pad)
    raise ValueError(arg_invalid_info(op_name, arg_name, pad))
117
118
def to_paddings(op_name, arg_name, pad):
    """
    normalize a padding specification into four values.

    An int ``p`` is expanded to ``(p, p, p, p)``. A tuple or list is returned
    unchanged (no length check).

    Raises:
        ValueError: if ``pad`` is neither an int nor a tuple/list.
    """
    if isinstance(pad, (tuple, list)):
        return pad
    if not isinstance(pad, int):
        raise ValueError(arg_invalid_info(op_name, arg_name, pad))
    return (pad, pad, pad, pad)
128
129
def to_3d_kernel_size(op_name, arg_name, kernel_size):
    """
    convert 3d kernel_size: int/tuple[int*5] -> tuple[int*3].

    - int ``k``                 -> ``(k, k, k)``
    - 5-element tuple/list      -> tuple of its entries at indices 2, 3 and 4
      (presumably the spatial dims of an NCDHW-style spec -- confirm with callers)
    - any other tuple/list      -> returned unchanged (a 6-element input is NOT reduced)

    Raises:
        ValueError: if ``kernel_size`` is neither an int nor a tuple/list.
    """
    if isinstance(kernel_size, int):
        return (kernel_size, kernel_size, kernel_size)
    if isinstance(kernel_size, (tuple, list)):
        # Only a 5-element spec is trimmed; the leading two entries are dropped.
        if len(kernel_size) == 5:
            return (kernel_size[2], kernel_size[3], kernel_size[4])
        return kernel_size
    raise ValueError(arg_invalid_info(op_name, arg_name, kernel_size))
141
142
def to_3d_strides(op_name, arg_name, stride):
    """
    convert 3d stride: int/tuple[int*5] -> tuple[int*3].

    - int ``s``                 -> ``(s, s, s)``
    - 5-element tuple/list      -> tuple of its entries at indices 2, 3 and 4
      (presumably the spatial dims of an NCDHW-style spec -- confirm with callers)
    - any other tuple/list      -> returned unchanged (a 6-element input is NOT reduced)

    Raises:
        ValueError: if ``stride`` is neither an int nor a tuple/list.
    """
    if isinstance(stride, int):
        return (stride, stride, stride)
    if isinstance(stride, (tuple, list)):
        # Only a 5-element spec is trimmed; the leading two entries are dropped.
        if len(stride) == 5:
            return (stride[2], stride[3], stride[4])
        return stride
    raise ValueError(arg_invalid_info(op_name, arg_name, stride))
154
155
def to_3d_dilations(op_name, arg_name, dilation):
    """
    convert 3d dilation: int/tuple[int*5] -> tuple[int*3].

    - int ``d``                 -> ``(d, d, d)``
    - 5-element tuple/list      -> tuple of its entries at indices 2, 3 and 4
      (presumably the spatial dims of an NCDHW-style spec -- confirm with callers)
    - any other tuple/list      -> returned unchanged (a 6-element input is NOT reduced)

    Raises:
        ValueError: if ``dilation`` is neither an int nor a tuple/list.
    """
    if isinstance(dilation, int):
        return (dilation, dilation, dilation)
    if isinstance(dilation, (tuple, list)):
        # Only a 5-element spec is trimmed; the leading two entries are dropped.
        if len(dilation) == 5:
            return (dilation[2], dilation[3], dilation[4])
        return dilation
    raise ValueError(arg_invalid_info(op_name, arg_name, dilation))
167
168
def to_3d_paddings(op_name, arg_name, pad):
    """
    normalize a 3D padding specification into six values.

    An int ``p`` is expanded to a 6-tuple of ``p``. A tuple or list is
    returned unchanged (no length check).

    Raises:
        ValueError: if ``pad`` is neither an int nor a tuple/list.
    """
    if isinstance(pad, (tuple, list)):
        return pad
    if not isinstance(pad, int):
        raise ValueError(arg_invalid_info(op_name, arg_name, pad))
    return tuple(pad for _ in range(6))
178
179
def generator_handler(op_name, arg_name, inputs):
    """
    wrap every int element of ``inputs`` into an int64 Tensor.

    Non-int elements are passed through unchanged. ``op_name`` and ``arg_name``
    are not used; they are accepted for signature parity with the other
    handlers in this module.
    NOTE(review): ``bool`` is a subclass of ``int``, so bool elements are wrapped too.

    Returns:
        tuple: the converted elements, in input order.
    """
    return tuple(
        Tensor(item, mstype.int64) if isinstance(item, int) else item
        for item in inputs
    )
191
# Shared DtypeToEnum instance (from gen_ops_inner_prim).
# NOTE(review): presumably maps dtype arguments to their enum/type-id form -- confirm
# against DtypeToEnum's implementation.
dtype_to_type_id = DtypeToEnum()

# String-to-enum conversion.
# A function converting str arguments to enum values is provided here, but the
# backend accepts str input directly, so converting str input to enum input is
# not necessary.
str_to_enum = StringToEnum()
198