# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility functions to help distribution class."""
import numpy as np
from mindspore import context
from mindspore import _checkparam as validator
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr, _primexpr, PrimitiveWithInfer, prim_attr_register
import mindspore.ops as ops
import mindspore.nn as nn


def cast_to_tensor(t, hint_type=mstype.float32):
    """
    Cast a user input value into a Tensor of dtype `hint_type`.
    If the input t is of type Parameter, t is directly returned as a Parameter.

    Args:
        t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
        hint_type (mindspore.dtype): dtype of the resulting Tensor. Default: mstype.float32.

    Raises:
        ValueError: if t is None.
        TypeError: if t is a bool, or of any type not listed above.

    Returns:
        Tensor, or t itself when t is a Parameter.
    """
    if t is None:
        raise ValueError('Input cannot be None in cast_to_tensor')
    if isinstance(t, Parameter):
        return t
    # bool is a subclass of int, so it has to be rejected before the int check below.
    if isinstance(t, bool):
        raise TypeError('Input cannot be Type Bool')
    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
        return Tensor(t, dtype=hint_type)
    invalid_type = type(t)
    raise TypeError(
        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")


def cast_type_for_device(dtype):
    """
    Use the alternative dtype supported by the device.

    Only the GPU target is remapped: unsigned ints and int8 become int16,
    int64 becomes int32 and float64 becomes float32. On other targets, and for
    any other dtype, `dtype` is returned unchanged.

    Args:
        dtype (mindspore.dtype): input dtype.

    Returns:
        mindspore.dtype.
    """
    if context.get_context("device_target") == "GPU":
        if dtype in mstype.uint_type or dtype == mstype.int8:
            return mstype.int16
        if dtype == mstype.int64:
            return mstype.int32
        if dtype == mstype.float64:
            return mstype.float32
    return dtype


def check_greater_equal_zero(value, name):
    """
    Check if the given Tensor is greater than or equal to zero (element-wise).

    A Parameter whose `.data` is not a Tensor is skipped without checking.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str) : name of the value, used in the error message.

    Raises:
        ValueError: if value is None, or if any element of value is less than zero.
    """
    # Same None guard as the sibling checkers, instead of failing later with an AttributeError.
    if value is None:
        raise ValueError('input value cannot be None in check_greater_equal_zero')
    if isinstance(value, Parameter):
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(value.asnumpy(), np.zeros(value.shape))
    if comp.any():
        raise ValueError(f'{name} must be greater than or equal to zero.')


def check_greater_zero(value, name):
    """
    Check if the given Tensor is strictly greater than zero (element-wise).

    A Parameter whose `.data` is not a Tensor is skipped without checking.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str) : name of the value, used in the error message.

    Raises:
        ValueError: if value is None, or if any element is less than or equal to zero.
    """
    if value is None:
        raise ValueError('input value cannot be None in check_greater_zero')
    if isinstance(value, Parameter):
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(np.zeros(value.shape), value.asnumpy())
    if not comp.all():
        raise ValueError(f'{name} must be greater than zero.')


def check_greater(a, b, name_a, name_b):
    """
    Check if Tensor b is strictly greater than Tensor a (element-wise).

    If either input is a Parameter, no check is performed.

    Args:
        a (Tensor, Parameter): input tensor a.
        b (Tensor, Parameter): input tensor b.
        name_a (str): name of Tensor_a.
        name_b (str): name of Tensor_b.

    Raises:
        ValueError: if a or b is None, or if b is less than or equal to a anywhere.
    """
    if a is None or b is None:
        raise ValueError('input value cannot be None in check_greater')
    if isinstance(a, Parameter) or isinstance(b, Parameter):
        return
    comp = np.less(a.asnumpy(), b.asnumpy())
    if not comp.all():
        raise ValueError(f'{name_a} must be less than {name_b}')


def check_prob(p):
    """
    Check if p is a proper probability, i.e. 0 < p < 1 (element-wise).

    A Parameter whose `.data` is not a Tensor is skipped without checking.

    Args:
        p (Tensor, Parameter): value to be checked.

    Raises:
        ValueError: if p is None or is not a proper probability.
    """
    if p is None:
        raise ValueError('input value cannot be None in check_prob')
    if isinstance(p, Parameter):
        if not isinstance(p.data, Tensor):
            return
        p = p.data
    comp = np.less(np.zeros(p.shape), p.asnumpy())
    if not comp.all():
        raise ValueError('Probabilities must be greater than zero')
    comp = np.greater(np.ones(p.shape), p.asnumpy())
    if not comp.all():
        raise ValueError('Probabilities must be less than one')


def check_sum_equal_one(probs):
    """
    Used in categorical distribution. Check if the probabilities of each category sum to 1.

    The sum is taken over the last axis, and the tolerance is derived from the
    machine epsilon of the (floating) dtype of that sum.

    Args:
        probs (Tensor, Parameter, numpy.ndarray, list): probabilities to be checked.

    Raises:
        ValueError: if probs is None, or if any row does not sum to one within tolerance.
    """
    if probs is None:
        raise ValueError('input value cannot be None in check_sum_equal_one')
    if isinstance(probs, Parameter):
        if not isinstance(probs.data, Tensor):
            return
        probs = probs.data
    if isinstance(probs, Tensor):
        probs = probs.asnumpy()
    probs = np.asarray(probs)
    # np.finfo only accepts floating dtypes; integer/bool inputs would otherwise
    # make the tolerance computation below raise instead of validating.
    if not np.issubdtype(probs.dtype, np.floating):
        probs = probs.astype(np.float64)
    prob_sum = np.sum(probs, axis=-1)
    # add a small tolerance here to increase numerical stability
    eps = np.finfo(prob_sum.dtype).eps
    comp = np.allclose(prob_sum, np.ones(np.shape(prob_sum)), rtol=eps * 10, atol=eps)
    if not comp:
        raise ValueError(
            'Probabilities for each category should sum to one for Categorical distribution.')


def check_rank(probs):
    """
    Used in categorical distribution. Check that probs has rank >= 1.

    A Parameter whose `.data` is not a Tensor is skipped without checking.

    Args:
        probs (Tensor, Parameter): probabilities to be checked.

    Raises:
        ValueError: if probs is None or is a scalar (rank 0).
    """
    if probs is None:
        raise ValueError('input value cannot be None in check_rank')
    if isinstance(probs, Parameter):
        if not isinstance(probs.data, Tensor):
            return
        probs = probs.data
    if probs.asnumpy().ndim == 0:
        raise ValueError(
            'probs for Categorical distribution must have rank >= 1.')


def logits_to_probs(logits, is_binary=False):
    """
    Convert logits into probabilities.

    Args:
        logits (Tensor): input logits.
        is_binary (bool): if True apply an element-wise sigmoid, otherwise a
            softmax over the last axis. Default: False.

    Returns:
        Tensor of probabilities.
    """
    if is_binary:
        return nn.Sigmoid()(logits)
    return nn.Softmax(axis=-1)(logits)


def clamp_probs(probs):
    """
    Clamp probs into [eps, 1 - eps], where eps is produced by `P.Eps` for probs.

    Args:
        probs (Tensor): probabilities to clamp.

    Returns:
        Tensor with values clipped away from exactly 0 and 1.
    """
    eps = P.Eps()(probs)
    return ops.clip_by_value(probs, eps, 1 - eps)


def probs_to_logits(probs, is_binary=False):
    """
    Convert probabilities into logits.

    probs are clamped first (see `clamp_probs`) so the logarithms stay finite.

    Args:
        probs (Tensor): input probabilities.
        is_binary (bool): if True return log(p) - log(1 - p), otherwise log(p).
            Default: False.

    Returns:
        Tensor of logits.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return P.Log()(ps_clamped) - P.Log()(1 - ps_clamped)
    return P.Log()(ps_clamped)


@constexpr(check=False)
def raise_none_error(name):
    """Raise a TypeError because the argument `name` was not specified (is None)."""
    raise TypeError(f"the type {name} must be subclass of Tensor."
                    f" It can not be None since it is not specified during initialization.")


@_primexpr
def raise_broadcast_error(shape_a, shape_b):
    """Raise a ValueError because shape_a and shape_b cannot be broadcast together."""
    raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")


@constexpr(check=False)
def raise_not_impl_error(name):
    """Raise a ValueError because function `name` is missing for a non-linear transformation."""
    raise ValueError(
        f"{name} function must be implemented for non-linear transformation")


@constexpr(check=False)
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
    """
    Raise NotImplementedError for `func_name` on distribution `obj`.

    Extra positional/keyword arguments are accepted and ignored.
    """
    raise NotImplementedError(
        f"{func_name} is not implemented for {obj} distribution.")


@constexpr(check=False)
def raise_type_error(name, cur_type, required_type):
    """Raise a TypeError reporting that `cur_type` is not (a subclass of) `required_type`."""
    raise TypeError(
        f"For {name} , the type must be or be subclass of {required_type}, but got {cur_type}")


@constexpr(check=False)
def raise_not_defined(func_name, obj, *args, **kwargs):
    """
    Raise a ValueError because `func_name` is undefined for distribution `obj`.

    Extra positional/keyword arguments are accepted and ignored.
    """
    raise ValueError(
        f"{func_name} is undefined for {obj} distribution.")


@constexpr(check=False)
def check_distribution_name(name, expected_name):
    """
    Check that the distribution name `name` equals `expected_name`.

    Raises:
        ValueError: if name is None or differs from expected_name.
    """
    if name is None:
        raise ValueError(
            "Input dist must be a constant which is not None.")
    if name != expected_name:
        raise ValueError(
            f"Expected dist input is {expected_name}, but got {name}.")


class CheckTuple(PrimitiveWithInfer):
    """
    Check if input is a tuple.
    """
    @prim_attr_register
    def __init__(self):
        """Register the primitive with inputs ('x', 'name') and a dummy output."""
        super(CheckTuple, self).__init__("CheckTuple")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        """
        Infer-stage check: raise TypeError unless x's dtype is a tuple.

        Returns:
            dict with 'shape' and 'dtype' set to None and 'value' passed through from x.
        """
        if not isinstance(x['dtype'], tuple):
            raise TypeError(
                f"For {name['value']}, Input type must be a tuple.")

        out = {'shape': None,
               'dtype': None,
               'value': x["value"]}
        return out

    def __call__(self, x, name):
        """
        Return x if it is a tuple; otherwise raise TypeError, except in mode 0
        where x["value"] is returned.
        """
        # The op is not used in a cell
        if isinstance(x, tuple):
            return x
        # NOTE(review): mode 0 is expected to be context.GRAPH_MODE -- confirm.
        if context.get_context("mode") == 0:
            return x["value"]
        raise TypeError(f"For {name}, input type must be a tuple.")


class CheckTensor(PrimitiveWithInfer):
    """
    Check if input is a Tensor.
    """
    @prim_attr_register
    def __init__(self):
        """Register the primitive with inputs ('x', 'name') and a dummy output."""
        super(CheckTensor, self).__init__("CheckTensor")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        """
        Infer-stage check: validate that x's dtype is a subclass of mstype.tensor_type.

        Returns:
            dict with 'shape', 'dtype' and 'value' all set to None.
        """
        src_type = x['dtype']
        validator.check_subclass(
            "input", src_type, [mstype.tensor_type], name["value"])

        out = {'shape': None,
               'dtype': None,
               'value': None}
        return out

    def __call__(self, x, name):
        """Return x if it is None or a Tensor; otherwise raise TypeError."""
        # we skip this check in graph mode as it is checked in the infer stage
        # and in the graph mode x is None if x is not const in the graph
        if x is None or isinstance(x, Tensor):
            return x
        raise TypeError(
            f"For {name}, input type must be a Tensor or Parameter.")


def set_param_type(args, hint_type):
    """
    Find the common type among arguments.

    Only arguments exposing a `dtype` attribute take part; numpy arrays have
    their dtype converted with `mstype.pytype_to_dtype`.

    Args:
        args (dict): dictionary of arguments, {'name':value}.
        hint_type (mindspore.dtype): hint type to return. Integer types and None
            are replaced by mstype.float32.

    Raises:
        TypeError: if arguments with a dtype do not all share the same dtype.

    Returns:
        mindspore.dtype: mstype.float32 if the common dtype is an integer type or
        float64; the common dtype if one was found; otherwise the (adjusted) hint_type.
    """
    int_type = mstype.int_type + mstype.uint_type
    if hint_type in int_type or hint_type is None:
        hint_type = mstype.float32
    common_dtype = None
    for name, arg in args.items():
        if hasattr(arg, 'dtype'):
            if isinstance(arg, np.ndarray):
                cur_dtype = mstype.pytype_to_dtype(arg.dtype)
            else:
                cur_dtype = arg.dtype
            if common_dtype is None:
                common_dtype = cur_dtype
            elif cur_dtype != common_dtype:
                raise TypeError(
                    f"{name} should have the same dtype as other arguments.")
    if common_dtype in int_type or common_dtype == mstype.float64:
        return mstype.float32
    return hint_type if common_dtype is None else common_dtype