# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Logistic Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero
from ._utils.custom_ops import exp_generic, log_generic


class Logistic(Distribution):
    """
    Logistic distribution.

    Args:
        loc (int, float, list, numpy.ndarray, Tensor): The location of the Logistic distribution. Default: None.
        scale (int, float, list, numpy.ndarray, Tensor): The scale of the Logistic distribution. Default: None.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
        name (str): The name of the distribution. Default: 'Logistic'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        `scale` must be greater than zero.
        `dist_spec_args` are `loc` and `scale`.
        `dtype` must be a float type because Logistic distributions are continuous.

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import mindspore.nn.probability.distribution as msd
        >>> from mindspore import Tensor
        >>> # To initialize a Logistic distribution of loc 3.0 and scale 4.0.
        >>> l1 = msd.Logistic(3.0, 4.0, dtype=mindspore.float32)
        >>> # A Logistic distribution can be initialized without arguments.
        >>> # In this case, `loc` and `scale` must be passed in through arguments.
        >>> l2 = msd.Logistic(dtype=mindspore.float32)
        >>>
        >>> # Here are some tensors used below for testing
        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
        >>> loc_a = Tensor([2.0], dtype=mindspore.float32)
        >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
        >>>
        >>> # Private interfaces of probability functions corresponding to public interfaces, including
        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`,
        >>> # have the same arguments as follows.
        >>> # Args:
        >>> #     value (Tensor): the value to be evaluated.
        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
        >>> # Examples of `prob`.
        >>> # Similar calls can be made to other probability functions
        >>> # by replacing 'prob' by the name of the function
        >>> ans = l1.prob(value)
        >>> print(ans.shape)
        (3,)
        >>> # Evaluate with respect to distribution b.
        >>> ans = l1.prob(value, loc_b, scale_b)
        >>> print(ans.shape)
        (3,)
        >>> # `loc` and `scale` must be passed in during function calls
        >>> ans = l2.prob(value, loc_a, scale_a)
        >>> print(ans.shape)
        (3,)
        >>> # Functions `mean`, `mode`, `sd`, `var`, and `entropy` have the same arguments.
        >>> # Args:
        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
        >>> # Example of `mean`. `mode`, `sd`, `var`, and `entropy` are similar.
        >>> ans = l1.mean()
        >>> print(ans.shape)
        ()
        >>> ans = l1.mean(loc_b, scale_b)
        >>> print(ans.shape)
        (3,)
        >>> # `loc` and `scale` must be passed in during function calls.
        >>> ans = l2.mean(loc_a, scale_a)
        >>> print(ans.shape)
        (3,)
        >>> # Examples of `sample`.
        >>> # Args:
        >>> #     shape (tuple): the shape of the sample. Default: ()
        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
        >>> ans = l1.sample()
        >>> print(ans.shape)
        ()
        >>> ans = l1.sample((2,3))
        >>> print(ans.shape)
        (2, 3)
        >>> ans = l1.sample((2,3), loc_b, scale_b)
        >>> print(ans.shape)
        (2, 3, 3)
        >>> ans = l2.sample((2,3), loc_a, scale_a)
        >>> print(ans.shape)
        (2, 3, 3)
    """

    def __init__(self,
                 loc=None,
                 scale=None,
                 seed=None,
                 dtype=mstype.float32,
                 name="Logistic"):
        """
        Constructor of Logistic.
        """
        # Must stay the first statement: it snapshots only the constructor
        # arguments, before any other local name is bound.
        param = dict(locals())
        param['param_dict'] = {'loc': loc, 'scale': scale}
        valid_dtype = mstype.float_type
        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        # Keep the explicit two-argument form: zero-argument `super()` would add
        # an implicit `__class__` entry to the `locals()` snapshot taken above.
        super(Logistic, self).__init__(seed, dtype, name, param)

        self._loc = self._add_parameter(loc, 'loc')
        self._scale = self._add_parameter(scale, 'scale')
        if self._scale is not None:
            check_greater_zero(self._scale, "scale")

        # ops needed for the class
        self.cast = P.Cast()
        self.const = P.ScalarToArray()
        self.consttensor = P.ScalarToTensor()
        self.dtypeop = P.DType()
        self.exp = exp_generic
        self.expm1 = P.Expm1()
        self.fill = P.Fill()
        self.less = P.Less()
        self.log = log_generic
        self.log1p = P.Log1p()
        self.logicalor = P.LogicalOr()
        self.erf = P.Erf()
        self.greater = P.Greater()
        self.sigmoid = P.Sigmoid()
        self.squeeze = P.Squeeze(0)
        self.select = P.Select()
        self.shape = P.Shape()
        self.softplus = self._softplus
        self.sqrt = P.Sqrt()
        self.uniform = C.uniform

        # Cut-off used by `_softplus` to switch to its asymptotic forms.
        self.threshold = np.log(np.finfo(np.float32).eps) + 1.
        # Strictly positive lower bound for the uniform draw in `_sample`.
        # `np.float` was removed from NumPy, and a float64 tiny (~2.2e-308)
        # would underflow to zero if the bound is narrowed to float32, so the
        # float32 tiny is used and kept as a plain Python float.
        self.tiny = float(np.finfo(np.float32).tiny)
        # sd = scale * pi / sqrt(3)
        self.sd_const = np.pi / np.sqrt(3)

    def _softplus(self, x):
        """
        Numerically guarded softplus, log(1 + exp(x)).

        Below `self.threshold` the result is replaced by exp(x); above
        `-self.threshold` it is replaced by x itself. In between, the direct
        formula is evaluated.
        """
        too_small = self.less(x, self.threshold)
        too_large = self.greater(x, -self.threshold)
        too_small_value = self.exp(x)
        too_large_value = x
        # Out-of-range entries are replaced by 1.0 before the direct formula is
        # evaluated, so exp() never sees them; their results are discarded by
        # the final select.
        too_small_or_too_large = self.logicalor(too_small, too_large)
        ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
        x = self.select(too_small_or_too_large, ones, x)
        y = self.log(self.exp(x) + 1.0)
        return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))

    def extend_repr(self):
        """Display instance object as string."""
        if self.is_scalar_batch:
            s = 'location = {}, scale = {}'.format(self._loc, self._scale)
        else:
            s = 'batch_shape = {}'.format(self._broadcast_shape)
        return s

    @property
    def loc(self):
        """
        Return the location parameter stored on the distribution.
        """
        return self._loc

    @property
    def scale(self):
        """
        Return the scale parameter stored on the distribution.
        """
        return self._scale

    def _get_dist_type(self):
        return "Logistic"

    def _get_dist_args(self, loc=None, scale=None):
        """
        Return `(loc, scale)`, falling back to the stored parameters for any
        argument that is None. Explicitly passed arguments are checked with
        `checktensor`.
        """
        if loc is None:
            loc = self.loc
        else:
            self.checktensor(loc, 'loc')
        if scale is None:
            scale = self.scale
        else:
            self.checktensor(scale, 'scale')
        return loc, scale

    def _mean(self, loc=None, scale=None):
        """
        The mean of the distribution.
        """
        loc, _ = self._check_param_type(loc, scale)
        return loc

    def _mode(self, loc=None, scale=None):
        """
        The mode of the distribution.
        """
        loc, _ = self._check_param_type(loc, scale)
        return loc

    def _sd(self, loc=None, scale=None):
        """
        The standard deviation of the distribution.
        """
        _, scale = self._check_param_type(loc, scale)
        return scale * self.consttensor(self.sd_const, self.dtypeop(scale))

    def _entropy(self, loc=None, scale=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = \log(scale) + 2.
        """
        _, scale = self._check_param_type(loc, scale)
        return self.log(scale) + 2.

    def _log_prob(self, value, loc=None, scale=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            loc (Tensor): The location of the distribution. Default: self.loc.
            scale (Tensor): The scale of the distribution. Default: self.scale.

        .. math::
            z = (x - \mu) / \sigma
            L(x) = -z - 2. * softplus(-z) - \log(\sigma)
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        loc, scale = self._check_param_type(loc, scale)
        z = (value - loc) / scale
        return -z - 2. * self.softplus(-z) - self.log(scale)

    def _cdf(self, value, loc=None, scale=None):
        r"""
        Evaluate the cumulative distribution function on the given value.

        Args:
            value (Tensor): The value to be evaluated.
            loc (Tensor): The location of the distribution. Default: self.loc.
            scale (Tensor): The scale of the distribution. Default: self.scale.

        .. math::
            cdf(x) = sigmoid((x - loc) / scale)
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        loc, scale = self._check_param_type(loc, scale)
        z = (value - loc) / scale
        return self.sigmoid(z)

    def _log_cdf(self, value, loc=None, scale=None):
        r"""
        Evaluate the log cumulative distribution function on the given value.

        Args:
            value (Tensor): The value to be evaluated.
            loc (Tensor): The location of the distribution. Default: self.loc.
            scale (Tensor): The scale of the distribution. Default: self.scale.

        .. math::
            log_cdf(x) = -softplus(-(x - loc) / scale)
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        loc, scale = self._check_param_type(loc, scale)
        z = (value - loc) / scale
        return (-1) * self.softplus(-z)

    def _survival_function(self, value, loc=None, scale=None):
        r"""
        Evaluate the survival function on the given value.

        Args:
            value (Tensor): The value to be evaluated.
            loc (Tensor): The location of the distribution. Default: self.loc.
            scale (Tensor): The scale of the distribution. Default: self.scale.

        .. math::
            survival(x) = sigmoid(-(x - loc) / scale)
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        loc, scale = self._check_param_type(loc, scale)
        z = (value - loc) / scale
        return self.sigmoid(-z)

    def _log_survival(self, value, loc=None, scale=None):
        r"""
        Evaluate the log survival function on the given value.

        Args:
            value (Tensor): The value to be evaluated.
            loc (Tensor): The location of the distribution. Default: self.loc.
            scale (Tensor): The scale of the distribution. Default: self.scale.

        .. math::
            log_survival(x) = -softplus((x - loc) / scale)
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        loc, scale = self._check_param_type(loc, scale)
        z = (value - loc) / scale
        return (-1) * self.softplus(z)

    def _sample(self, shape=(), loc=None, scale=None):
        r"""
        Sampling.

        Draws u uniformly from [tiny, 1) and applies the inverse CDF of the
        standard Logistic distribution before scaling and shifting:

        .. math::
            x = loc + scale * (\log(u) - \log(1 - u))

        Args:
            shape (tuple): The shape of the sample. Default: ().
            loc (Tensor): The location of the samples. Default: self.loc.
            scale (Tensor): The scale of the samples. Default: self.scale.

        Returns:
            Tensor, with the shape being shape + batch_shape.
        """
        shape = self.checktuple(shape, 'shape')
        loc, scale = self._check_param_type(loc, scale)
        batch_shape = self.shape(loc + scale)
        origin_shape = shape + batch_shape
        # A scalar request is drawn with shape (1,) and squeezed back at the end.
        if origin_shape == ():
            sample_shape = (1,)
        else:
            sample_shape = origin_shape
        l_zero = self.const(self.tiny)
        h_one = self.const(1.0)
        sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
        # Inverse CDF (logit) of u: log(u) - log(1 - u). The second term must be
        # log1p(-u); log1p(u) would give log(u / (1 + u)), which is not a
        # Logistic draw.
        sample = self.log(sample_uniform) - self.log1p(-sample_uniform)
        sample = sample * scale + loc
        value = self.cast(sample, self.dtype)
        if origin_shape == ():
            value = self.squeeze(value)
        return value