1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Utility functions to help distribution class.""" 16import numpy as np 17from mindspore import context 18from mindspore._checkparam import Validator as validator 19from mindspore.common.tensor import Tensor 20from mindspore.common.parameter import Parameter 21from mindspore.common import dtype as mstype 22from mindspore.ops import composite as C 23from mindspore.ops import operations as P 24from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register 25import mindspore.nn as nn 26 27 28def cast_to_tensor(t, hint_type=mstype.float32): 29 """ 30 Cast an user input value into a Tensor of dtype. 31 If the input t is of type Parameter, t is directly returned as a Parameter. 32 33 Args: 34 t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. 35 dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32. 36 37 Raises: 38 RuntimeError: if t cannot be cast to Tensor. 39 40 Returns: 41 Tensor. 42 """ 43 if t is None: 44 raise ValueError(f'Input cannot be None in cast_to_tensor') 45 if isinstance(t, Parameter): 46 return t 47 if isinstance(t, bool): 48 raise TypeError(f'Input cannot be Type Bool') 49 if isinstance(t, (Tensor, np.ndarray, list, int, float)): 50 return Tensor(t, dtype=hint_type) 51 invalid_type = type(t) 52 raise TypeError( 53 f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}") 54 55 56def cast_type_for_device(dtype): 57 """ 58 use the alternative dtype supported by the device. 59 Args: 60 dtype (mindspore.dtype): input dtype. 61 Returns: 62 mindspore.dtype. 63 """ 64 if context.get_context("device_target") == "GPU": 65 if dtype in mstype.uint_type or dtype == mstype.int8: 66 return mstype.int16 67 if dtype == mstype.int64: 68 return mstype.int32 69 if dtype == mstype.float64: 70 return mstype.float32 71 return dtype 72 73 74def check_greater_equal_zero(value, name): 75 """ 76 Check if the given Tensor is greater zero. 77 78 Args: 79 value (Tensor, Parameter): value to be checked. 80 name (str) : name of the value. 81 82 Raises: 83 ValueError: if the input value is less than zero. 84 85 """ 86 if isinstance(value, Parameter): 87 if not isinstance(value.data, Tensor): 88 return 89 value = value.data 90 comp = np.less(value.asnumpy(), np.zeros(value.shape)) 91 if comp.any(): 92 raise ValueError(f'{name} should be greater than ot equal to zero.') 93 94 95def check_greater_zero(value, name): 96 """ 97 Check if the given Tensor is strictly greater than zero. 98 99 Args: 100 value (Tensor, Parameter): value to be checked. 101 name (str) : name of the value. 102 103 Raises: 104 ValueError: if the input value is less than or equal to zero. 105 106 """ 107 if value is None: 108 raise ValueError(f'input value cannot be None in check_greater_zero') 109 if isinstance(value, Parameter): 110 if not isinstance(value.data, Tensor): 111 return 112 value = value.data 113 comp = np.less(np.zeros(value.shape), value.asnumpy()) 114 if not comp.all(): 115 raise ValueError(f'{name} should be greater than zero.') 116 117 118def check_greater(a, b, name_a, name_b): 119 """ 120 Check if Tensor b is strictly greater than Tensor a. 121 122 Args: 123 a (Tensor, Parameter): input tensor a. 124 b (Tensor, Parameter): input tensor b. 125 name_a (str): name of Tensor_a. 126 name_b (str): name of Tensor_b. 127 128 Raises: 129 ValueError: if b is less than or equal to a 130 """ 131 if a is None or b is None: 132 raise ValueError(f'input value cannot be None in check_greater') 133 if isinstance(a, Parameter) or isinstance(b, Parameter): 134 return 135 comp = np.less(a.asnumpy(), b.asnumpy()) 136 if not comp.all(): 137 raise ValueError(f'{name_a} should be less than {name_b}') 138 139 140def check_prob(p): 141 """ 142 Check if p is a proper probability, i.e. 0 < p <1. 143 144 Args: 145 p (Tensor, Parameter): value to be checked. 146 147 Raises: 148 ValueError: if p is not a proper probability. 149 """ 150 if p is None: 151 raise ValueError(f'input value cannot be None in check_greater_zero') 152 if isinstance(p, Parameter): 153 if not isinstance(p.data, Tensor): 154 return 155 p = p.data 156 comp = np.less(np.zeros(p.shape), p.asnumpy()) 157 if not comp.all(): 158 raise ValueError('Probabilities should be greater than zero') 159 comp = np.greater(np.ones(p.shape), p.asnumpy()) 160 if not comp.all(): 161 raise ValueError('Probabilities should be less than one') 162 163 164def check_sum_equal_one(probs): 165 """ 166 Used in categorical distribution. check if probabilities of each category sum to 1. 167 """ 168 if probs is None: 169 raise ValueError(f'input value cannot be None in check_sum_equal_one') 170 if isinstance(probs, Parameter): 171 if not isinstance(probs.data, Tensor): 172 return 173 probs = probs.data 174 if isinstance(probs, Tensor): 175 probs = probs.asnumpy() 176 prob_sum = np.sum(probs, axis=-1) 177 # add a small tolerance here to increase numerical stability 178 comp = np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=1e-14, atol=1e-14) 179 if not comp: 180 raise ValueError('Probabilities for each category should sum to one for Categorical distribution.') 181 182 183def check_rank(probs): 184 """ 185 Used in categorical distribution. check Rank >=1. 186 """ 187 if probs is None: 188 raise ValueError(f'input value cannot be None in check_rank') 189 if isinstance(probs, Parameter): 190 if not isinstance(probs.data, Tensor): 191 return 192 probs = probs.data 193 if probs.asnumpy().ndim == 0: 194 raise ValueError('probs for Categorical distribution must have rank >= 1.') 195 196 197def logits_to_probs(logits, is_binary=False): 198 """ 199 converts logits into probabilities. 200 Args: 201 logits (Tensor) 202 is_binary (bool) 203 """ 204 if is_binary: 205 return nn.Sigmoid()(logits) 206 return nn.Softmax(axis=-1)(logits) 207 208 209def clamp_probs(probs): 210 """ 211 clamp probs boundary 212 """ 213 eps = P.Eps()(probs) 214 return C.clip_by_value(probs, eps, 1-eps) 215 216 217def probs_to_logits(probs, is_binary=False): 218 """ 219 converts probabilities into logits. 220 Args: 221 probs (Tensor) 222 is_binary (bool) 223 """ 224 ps_clamped = clamp_probs(probs) 225 if is_binary: 226 return P.Log()(ps_clamped) - P.Log()(1-ps_clamped) 227 return P.Log()(ps_clamped) 228 229 230@constexpr 231def raise_none_error(name): 232 raise TypeError(f"the type {name} should be subclass of Tensor." 233 f" It should not be None since it is not specified during initialization.") 234 235 236@constexpr 237def raise_probs_logits_error(): 238 raise TypeError("Either 'probs' or 'logits' must be specified, but not both.") 239 240 241@constexpr 242def raise_broadcast_error(shape_a, shape_b): 243 raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.") 244 245 246@constexpr 247def raise_not_impl_error(name): 248 raise ValueError( 249 f"{name} function should be implemented for non-linear transformation") 250 251 252@constexpr 253def raise_not_implemented_util(func_name, obj, *args, **kwargs): 254 raise NotImplementedError( 255 f"{func_name} is not implemented for {obj} distribution.") 256 257 258@constexpr 259def raise_type_error(name, cur_type, required_type): 260 raise TypeError( 261 f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}") 262 263 264@constexpr 265def raise_not_defined(func_name, obj, *args, **kwargs): 266 raise ValueError( 267 f"{func_name} is undefined for {obj} distribution.") 268 269 270@constexpr 271def check_distribution_name(name, expected_name): 272 if name is None: 273 raise ValueError( 274 f"Input dist should be a constant which is not None.") 275 if name != expected_name: 276 raise ValueError( 277 f"Expected dist input is {expected_name}, but got {name}.") 278 279 280class CheckTuple(PrimitiveWithInfer): 281 """ 282 Check if input is a tuple. 283 """ 284 @prim_attr_register 285 def __init__(self): 286 super(CheckTuple, self).__init__("CheckTuple") 287 self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output']) 288 289 def __infer__(self, x, name): 290 if not isinstance(x['dtype'], tuple): 291 raise TypeError( 292 f"For {name['value']}, Input type should b a tuple.") 293 294 out = {'shape': None, 295 'dtype': None, 296 'value': x["value"]} 297 return out 298 299 def __call__(self, x, name): 300 # The op is not used in a cell 301 if isinstance(x, tuple): 302 return x 303 if context.get_context("mode") == 0: 304 return x["value"] 305 raise TypeError(f"For {name}, input type should be a tuple.") 306 307 308class CheckTensor(PrimitiveWithInfer): 309 """ 310 Check if input is a Tensor. 311 """ 312 @prim_attr_register 313 def __init__(self): 314 super(CheckTensor, self).__init__("CheckTensor") 315 self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output']) 316 317 def __infer__(self, x, name): 318 src_type = x['dtype'] 319 validator.check_subclass( 320 "input", src_type, [mstype.tensor], name["value"]) 321 322 out = {'shape': None, 323 'dtype': None, 324 'value': None} 325 return out 326 327 def __call__(self, x, name): 328 if isinstance(x, Tensor): 329 return x 330 raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") 331 332 333def set_param_type(args, hint_type): 334 """ 335 Find the common type among arguments. 336 337 Args: 338 args (dict): dictionary of arguments, {'name':value}. 339 hint_type (mindspore.dtype): hint type to return. 340 341 Raises: 342 TypeError: if tensors in args are not the same dtype. 343 """ 344 int_type = mstype.int_type + mstype.uint_type 345 if hint_type in int_type or hint_type is None: 346 hint_type = mstype.float32 347 common_dtype = None 348 for name, arg in args.items(): 349 if hasattr(arg, 'dtype'): 350 if isinstance(arg, np.ndarray): 351 cur_dtype = mstype.pytype_to_dtype(arg.dtype) 352 else: 353 cur_dtype = arg.dtype 354 if common_dtype is None: 355 common_dtype = cur_dtype 356 elif cur_dtype != common_dtype: 357 raise TypeError(f"{name} should have the same dtype as other arguments.") 358 if common_dtype in int_type or common_dtype == mstype.float64: 359 return mstype.float32 360 return hint_type if common_dtype is None else common_dtype 361