# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operator argument handle function."""

from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
# Enum Class:
from mindspore._c_expression import FormatEnum as Format
from mindspore._c_expression import ReductionEnum as Reduction
from mindspore.common import Tensor
from mindspore.common import dtype as mstype


def arg_invalid_info(op_name, arg_name, arg_val):
    """
    Generate the error message used when an operator argument is invalid.

    Args:
        op_name (str): Name of the operator being called.
        arg_name (str): Name of the offending argument.
        arg_val: The rejected value; it is interpolated into the message.

    Returns:
        str: "For '<op_name>', the value of '<arg_name>' is invalid: '<arg_val>'."
    """
    return f"For '{op_name}', the value of '{arg_name}' is invalid: '{arg_val}'."


def _to_spatial(op_name, arg_name, value, ndim):
    """
    Normalize a window-style argument to one value per spatial dimension.

    Shared implementation of the 2D/3D kernel_size, stride, rate, dilation and
    output_padding handlers below.

    Args:
        op_name (str): Operator name, used only in the error message.
        arg_name (str): Argument name, used only in the error message.
        value (Union[int, tuple, list]): The user-supplied argument.
        ndim (int): Number of spatial dimensions (2 or 3 in this module).

    Returns:
        - int: repeated ``ndim`` times into a tuple.
        - tuple/list of length ``ndim + 2``: the first two entries are dropped and
          the remaining ``ndim`` entries are returned as a tuple.
        - any other tuple/list: returned unchanged (not validated, not copied).

    Raises:
        ValueError: If ``value`` is neither an int nor a tuple/list.
    """
    if isinstance(value, int):
        return (value,) * ndim
    if isinstance(value, (tuple, list)):
        if len(value) == ndim + 2:
            # Full-rank form; the two leading entries are presumably the batch and
            # channel axes (e.g. N, C, H, W) -- confirm against the op definitions.
            return tuple(value[2:])
        return value
    raise ValueError(arg_invalid_info(op_name, arg_name, value))


def _repeat_int(op_name, arg_name, value, count):
    """
    Expand an int argument into a tuple of ``count`` copies.

    Shared implementation of the padding handlers below. A tuple/list is returned
    unchanged (its length is not checked).

    Raises:
        ValueError: If ``value`` is neither an int nor a tuple/list.
    """
    if isinstance(value, int):
        return (value,) * count
    if isinstance(value, (tuple, list)):
        return value
    raise ValueError(arg_invalid_info(op_name, arg_name, value))


def to_pair(op_name, arg_name, arg_val):
    """
    Convert arg_val: int/float -> (arg_val, arg_val).

    A tuple/list is returned unchanged.

    Raises:
        ValueError: If arg_val is neither a number nor a tuple/list.
    """
    if isinstance(arg_val, (int, float)):
        return (arg_val, arg_val)
    if isinstance(arg_val, (list, tuple)):
        # Sequences are passed through as-is (neither validated nor converted to tuple).
        return arg_val
    raise ValueError(arg_invalid_info(op_name, arg_name, arg_val))


def to_kernel_size(op_name, arg_name, kernel_size):
    """
    Convert kernel_size: int/tuple[int*4] -> tuple[int*2].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, kernel_size, 2)


def to_strides(op_name, arg_name, stride):
    """
    Convert strides: int/tuple[int*4] -> tuple[int*2].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, stride, 2)


def to_rates(op_name, arg_name, rates):
    """
    Convert rates: int/tuple[int*4] -> tuple[int*2].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, rates, 2)


def to_dilations(op_name, arg_name, dilation):
    """
    Convert dilations: int/tuple[int*4] -> tuple[int*2].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, dilation, 2)


def to_output_padding(op_name, arg_name, output_padding):
    """
    Convert output_padding: int/tuple[int*4] -> tuple[int*2].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, output_padding, 2)


def to_2d_paddings(op_name, arg_name, pad):
    """
    Convert paddings: int -> tuple[int*2].

    A tuple/list is returned unchanged; other types raise ValueError.
    """
    return _repeat_int(op_name, arg_name, pad, 2)


def to_paddings(op_name, arg_name, pad):
    """
    Convert paddings: int -> tuple[int*4].

    A tuple/list is returned unchanged; other types raise ValueError.
    """
    return _repeat_int(op_name, arg_name, pad, 4)


def to_3d_kernel_size(op_name, arg_name, kernel_size):
    """
    Convert 3d kernel_size: int/tuple[int*5] -> tuple[int*3].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, kernel_size, 3)


def to_3d_strides(op_name, arg_name, stride):
    """
    Convert 3d stride: int/tuple[int*5] -> tuple[int*3].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, stride, 3)


def to_3d_dilations(op_name, arg_name, dilation):
    """
    Convert 3d dilation: int/tuple[int*5] -> tuple[int*3].

    Other tuple/list lengths are returned unchanged; other types raise ValueError.
    """
    return _to_spatial(op_name, arg_name, dilation, 3)


def to_3d_paddings(op_name, arg_name, pad):
    """
    Convert 3d paddings: int -> tuple[int*6].

    A tuple/list is returned unchanged; other types raise ValueError.
    """
    return _repeat_int(op_name, arg_name, pad, 6)


def generator_handler(op_name, arg_name, inputs):
    """
    Convert every int element of ``inputs`` to an int64 Tensor.

    Non-int elements are passed through unchanged, and the result is always a
    tuple. Because bool is a subclass of int, True/False are converted as well.
    ``op_name`` and ``arg_name`` are unused; they are kept so that this handler
    shares the (op_name, arg_name, value) signature of the other handlers here.
    """
    return tuple(
        Tensor(item, mstype.int64) if isinstance(item, int) else item
        for item in inputs
    )


# Handler instance built from DtypeToEnum (see gen_ops_inner_prim).
dtype_to_type_id = DtypeToEnum()

# string to enum
# A function for converting str type to enum type is written here,
# but the backend supports str input, so converting str input to enum input is not necessary.
str_to_enum = StringToEnum()