# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""math"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator

__all__ = ['ReduceLogSumExp',
           'Range',
           'LGamma',
           'DiGamma',
           'IGamma',
           'LBeta',
           'MatMul',
           'Moments',
           'MatInverse',
           'MatDet',
           ]

_BASE_LANCZOS_COEFF = 0.99999999999980993227684700473478
_LANCZOS_COEFFICIENTS = [676.520368121885098567009190444019,
                         -1259.13921672240287047156078755283,
                         771.3234287776530788486528258894,
                         -176.61502916214059906584551354,
                         12.507343278686904814458936853,
                         -0.13857109526572011689554707,
                         9.984369578019570859563e-6,
                         1.50563273514931155834e-7]


@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


class ReduceLogSumExp(Cell):
    r"""
    Reduces a dimension of a tensor by calculating the exponential of all elements in the dimension,
    then calculating the logarithm of the sum.

    .. math::

        ReduceLogSumExp(x) = \log(\sum(e^x))

    Args:
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Default: (), reduce all dimensions.
            Only constant value is allowed.
        keep_dims (bool): If True, keep these reduced dimensions and the length is 1.
            If False, don't keep these dimensions. Default: False.

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

        - If axis is (), and keep_dims is False,
          the output is a 0-D tensor representing the log-sum-exp of all elements in the input tensor.
        - If axis is int, set as 2, and keep_dims is False,
          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
        - If axis is tuple(int), set as (2, 3), and keep_dims is False,
          the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Raises:
        TypeError: If `axis` is not one of int, list, tuple.
        TypeError: If `keep_dims` is not bool.
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = nn.ReduceLogSumExp(1, keep_dims=True)
        >>> output = op(x)
        >>> print(output.shape)
        (3, 1, 5, 6)
    """

    def __init__(self, axis, keep_dims=False):
        """Initialize ReduceLogSumExp."""
        super(ReduceLogSumExp, self).__init__()
        validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name)
        validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
        self.axis = axis
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims)
        self.log = P.Log()

    def construct(self, x):
        exp = self.exp(x)
        sumexp = self.sum(exp, self.axis)
        logsumexp = self.log(sumexp)
        return logsumexp
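# Illustrative sketch only (not part of the public API): a NumPy reference for the
# exp -> sum -> log reduction performed by ReduceLogSumExp above. The max-shifted
# form is the standard numerically stable variant and is shown for comparison; it
# is not what the Cell implements.
def _numpy_reduce_logsumexp_sketch(x, axis, keep_dims=False):
    """Return (naive log-sum-exp, max-shifted log-sum-exp) computed with NumPy."""
    naive = np.log(np.sum(np.exp(x), axis=axis, keepdims=keep_dims))
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    if not keep_dims:
        shifted = np.squeeze(shifted, axis=axis)
    return naive, shifted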
class Range(Cell):
    r"""
    Creates a sequence of numbers in range [start, limit) with step size delta.

    The size of the output is :math:`\left \lceil \frac{limit-start}{delta} \right \rceil` and `delta` is the gap
    between two values in the tensor.

    .. math::

        out_{i+1} = out_{i} + delta

    Args:
        start (Union[int, float]): If `limit` is `None`, the value acts as the limit of the range and the first
            entry defaults to `0`. Otherwise, it acts as the first entry of the range.
        limit (Union[int, float]): Acts as the upper limit of the sequence. If `None`, it defaults to the value of
            `start` while the first entry of the range is set to `0`. It can not be equal to `start`. Default: None.
        delta (Union[int, float]): Increment of the range. It can not be equal to zero. Default: 1.

    Outputs:
        Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.Range(1, 8, 2)
        >>> output = net()
        >>> print(output)
        [1 3 5 7]
    """

    def __init__(self, start, limit=None, delta=1):
        """Initialize Range."""
        super(Range, self).__init__()
        if delta == 0:
            raise ValueError(f"For '{self.cls_name}', the 'delta' can not be zero.")
        data = np.arange(start, limit, delta)
        if np.issubdtype(data.dtype, np.floating):
            self.ms_dtype = mstype.float32
        else:
            self.ms_dtype = mstype.int32
        self.result_tensor = Tensor(data, dtype=self.ms_dtype)

    def construct(self):
        return self.result_tensor
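# Illustrative sketch only (not part of the public API): the length of the sequence
# produced by Range follows the same rule as np.arange, i.e. ceil((limit - start) / delta).
def _range_length_sketch(start, limit, delta):
    """Return (ceil-based expected length, actual np.arange length)."""
    import math
    expected = max(0, math.ceil((limit - start) / delta))
    return expected, len(np.arange(start, limit, delta))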
class LGamma(Cell):
    r"""
    Calculates LGamma using Lanczos' approximation referring to "A Precision Approximation of the Gamma Function".
    The algorithm is:

    .. math::
        \begin{array}{ll} \\
            lgamma(z + 1) = \frac{\log(2) + \log(\pi)}{2} + (z + 1/2) * \log(t(z)) - t(z) + A(z) \\
            t(z) = z + kLanczosGamma + 1/2 \\
            A(z) = kBaseLanczosCoeff + \sum_{k=1}^{n} \frac{kLanczosCoefficients[k]}{z + k}
        \end{array}

    However, if the input is less than 0.5 use Euler's reflection formula:

    .. math::

        lgamma(x) = \log(\pi) - lgamma(1-x) - \log(|\sin(\pi * x)|)

    And please note that

    .. math::

        lgamma(+/-inf) = +inf

    Thus, the behaviour of LGamma follows:

    - when x > 0.5, return log(Gamma(x))
    - when x < 0.5 and is not an integer, return the real part of Log(Gamma(x)) where Log is the complex logarithm
    - when x is an integer less or equal to 0, return +inf
    - when x = +/- inf, return +inf

    Inputs:
        - **x** (Tensor) - The input tensor. Only float16, float32 are supported.

    Outputs:
        Tensor, has the same shape and dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([2, 3, 4]).astype(np.float32))
        >>> op = nn.LGamma()
        >>> output = op(x)
        >>> print(output)
        [3.5762787e-07 6.9314754e-01 1.7917603e+00]
    """

    def __init__(self):
        """Initialize LGamma."""
        super(LGamma, self).__init__()
        # const numbers
        self.k_lanczos_gamma = 7
        self.k_base_lanczos_coeff = _BASE_LANCZOS_COEFF
        self.k_lanczos_coefficients = _LANCZOS_COEFFICIENTS
        self.one_half = 0.5
        self.one = 1
        self.two = 2
        self.inf = np.inf
        self.pi = np.pi
        self.log_2 = np.log(self.two)
        self.log_pi = np.log(np.pi)
        self.log_sqrt_two_pi = (self.log_2 + self.log_pi) / self.two
        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)

        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.abs = P.Abs()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()
        self.lessequal = P.LessEqual()
        self.select = P.Select()
        self.sin = P.Sin()
        self.isfinite = P.IsFinite()

    def construct(self, x):
        input_dtype = self.dtype(x)
        _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        infinity = self.fill(input_dtype, self.shape(x), self.inf)

        need_to_reflect = self.less(x, 0.5)
        neg_input = -x
        z = self.select(need_to_reflect, neg_input, x - 1)

        @constexpr
        def _calculate_reflected_x(z, k_base_lanczos_coeff, k_lanczos_coefficients):
            reflex_x = k_base_lanczos_coeff
            for i in range(8):
                product_ = k_lanczos_coefficients[i] / (z + i + 1)
                reflex_x = product_ + reflex_x
            return reflex_x
        reflex_x = _calculate_reflected_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)

        t = z + self.lanczos_gamma_plus_one_half
        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half

        log_y = self.log(reflex_x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi

        abs_input = self.abs(x)
        abs_frac_input = abs_input - self.floor(abs_input)
        x = self.select(self.lessequal(x, 0.0), self.select(self.equal(abs_frac_input, 0.0), infinity, x), x)
        reduced_frac_input = self.select(self.greater(abs_frac_input, 0.5),
                                         1 - abs_frac_input, abs_frac_input)
        reflection_denom = self.log(self.sin(self.pi * reduced_frac_input))

        reflection = self.select(self.isfinite(reflection_denom),
                                 -reflection_denom - log_y + self.log_pi,  # pylint: disable=invalid-unary-operand-type
                                 -reflection_denom)  # pylint: disable=invalid-unary-operand-type

        result = self.select(need_to_reflect, reflection, log_y)

        return self.select(self.isfinite(x), result, infinity)
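# Illustrative sketch only (not part of the public API): a scalar check of the Euler
# reflection formula used by LGamma for x < 0.5,
#   lgamma(x) = log(pi) - lgamma(1 - x) - log(|sin(pi * x)|),
# with the standard library's math.lgamma as the reference.
def _lgamma_reflection_sketch(x=-1.5):
    """Return (math.lgamma(x), reflection-formula value) for a non-integer x < 0.5."""
    import math
    reflected = math.log(math.pi) - math.lgamma(1.0 - x) - math.log(abs(math.sin(math.pi * x)))
    return math.lgamma(x), reflected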
class DiGamma(Cell):
    r"""
    Calculates Digamma using Lanczos' approximation referring to "A Precision Approximation of the Gamma Function".
    The algorithm is:

    .. math::
        \begin{array}{ll} \\
            digamma(z + 1) = \log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) \\
            t(z) = z + kLanczosGamma + 1/2 \\
            A(z) = kBaseLanczosCoeff + \sum_{k=1}^{n} \frac{kLanczosCoefficients[k]}{z + k} \\
            A'(z) = -\sum_{k=1}^{n} \frac{kLanczosCoefficients[k]}{(z + k)^2}
        \end{array}

    However, if the input is less than 0.5 use Euler's reflection formula:

    .. math::

        digamma(x) = digamma(1 - x) - \pi * \cot(\pi * x)

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.

    Outputs:
        Tensor, has the same shape and dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([2, 3, 4]).astype(np.float32))
        >>> op = nn.DiGamma()
        >>> output = op(x)
        >>> print(output)
        [0.42278463 0.92278427 1.2561178]
    """

    def __init__(self):
        """Initialize DiGamma."""
        super(DiGamma, self).__init__()
        # const numbers
        self.k_lanczos_gamma = 7
        self.k_base_lanczos_coeff = _BASE_LANCZOS_COEFF
        self.k_lanczos_coefficients = _LANCZOS_COEFFICIENTS
        self.nan = np.nan
        self.pi = np.pi
        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)

        # operations
        self.log1p = P.Log1p()
        self.abs = P.Abs()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.equal = P.Equal()
        self.less = P.Less()
        self.select = P.Select()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.logicaland = P.LogicalAnd()

    def construct(self, x):
        input_dtype = self.dtype(x)
        _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        need_to_reflect = self.less(x, 0.5)
        neg_input = -x
        z = self.select(need_to_reflect, neg_input, x - 1)

        @constexpr
        def _calculate_num_denom(z, k_base_lanczos_coeff, k_lanczos_coefficients):
            num = 0
            denom = k_base_lanczos_coeff
            for i in range(8):
                num = num - k_lanczos_coefficients[i] / ((z + i + 1) * (z + i + 1))
                denom = denom + k_lanczos_coefficients[i] / (z + i + 1)
            return num, denom
        num, denom = _calculate_num_denom(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)

        t = z + self.lanczos_gamma_plus_one_half
        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half

        y = log_t + num / denom - self.k_lanczos_gamma / t

        reduced_input = x + self.abs(self.floor(x + 0.5))
        reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
        real_result = self.select(need_to_reflect, reflection, y)
        nan = self.fill(self.dtype(x), self.shape(x), np.nan)

        return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
                           nan, real_result)
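# Illustrative sketch only (not part of the public API): digamma is the derivative of
# lgamma, so a central finite difference of math.lgamma gives a quick scalar sanity
# check for the example values above (digamma(2) ~ 0.4228, digamma(3) ~ 0.9228).
def _digamma_finite_difference_sketch(x, h=1e-5):
    """Return (lgamma(x + h) - lgamma(x - h)) / (2 * h) as a digamma estimate."""
    import math
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2.0 * h)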
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)


def _while_helper_func(cond, body, vals):
    while cond(vals).any():
        vals = body(vals)
    return vals


def _igamma_series(ax, x, a, enabled):
    """Helper function for computing Igamma using a power series."""

    logicaland = P.LogicalAnd()
    greater = P.Greater()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # If more data types are supported, this epsilon needs to be selected accordingly.
    epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        return enabled

    def body(vals):
        enabled = vals[0]
        r = vals[1]
        c = vals[2]
        ans = vals[3]
        x = vals[4]
        dc_da = vals[5]
        dans_da = vals[6]

        r = r + 1
        dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r)
        dans_da = dans_da + dc_da
        c = c * (x / r)
        ans = ans + c
        conditional = logicaland(enabled, greater(c / ans, epsilon))

        return (conditional, select(enabled, r, vals[1]),
                select(enabled, c, vals[2]), select(enabled, ans, vals[3]),
                select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
                select(enabled, dans_da, vals[6]))

    ones = fill(dtype(a), shape(a), 1)
    zeros = fill(dtype(a), shape(a), 0)
    vals = (enabled, a, ones, ones, x, zeros, zeros)

    vals = _while_helper_func(cond, body, vals)
    ans = vals[3]
    return (ans * ax) / a
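# Illustrative sketch only (not part of the public API): the same power series as
# _igamma_series above, written for scalar floats. It evaluates
#   P(a, x) = x**a * exp(-x) * sum_k x**k / Gamma(a + k + 1)
# with the same stopping rule (stop once a term is negligible relative to the sum).
def _igamma_series_scalar_sketch(a, x, epsilon=1e-15):
    """Scalar lower regularized incomplete gamma via the power series."""
    import math
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    r, c, ans = a, 1.0, 1.0
    while c / ans > epsilon:
        r += 1.0
        c *= x / r
        ans += c
    return ans * ax / a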
def _igammac_continued_fraction(ax, x, a, enabled):
    """Helper function for computing Igammac using a continued fraction."""

    abs_x = P.Abs()
    logicaland = P.LogicalAnd()
    greater = P.Greater()
    less = P.Less()
    notequal = P.NotEqual()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # If more data types are supported, this epsilon needs to be selected accordingly.
    epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        c = vals[5]
        return logicaland(less(c, 2000), enabled)

    def body(vals):
        enabled = vals[0]
        ans = vals[1]
        t = vals[2]
        y = vals[3]
        z = vals[4]
        c = vals[5]
        pkm1 = vals[6]
        qkm1 = vals[7]
        pkm2 = vals[8]
        qkm2 = vals[9]

        dpkm2_da = vals[10]
        dqkm2_da = vals[11]
        dpkm1_da = vals[12]
        dqkm1_da = vals[13]
        dans_da = vals[14]

        c = c + 1
        y = y + 1
        z = z + 2

        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        qk_is_nonzero = notequal(qk, 0)
        r = pk / qk

        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
        ans = select(qk_is_nonzero, r, ans)

        dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
        dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
        dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
        grad_conditional = select(qk_is_nonzero,
                                  abs_x(dans_da_new - dans_da),
                                  fill(dtype(dans_da), shape(dans_da), 1))

        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk

        dpkm2_da = dpkm1_da
        dqkm2_da = dqkm1_da
        dpkm1_da = dpk_da
        dqkm1_da = dqk_da

        rescale = greater(abs_x(pk), 1 / epsilon)
        pkm2 = select(rescale, pkm2 * epsilon, pkm2)
        pkm1 = select(rescale, pkm1 * epsilon, pkm1)
        qkm2 = select(rescale, qkm2 * epsilon, qkm2)
        qkm1 = select(rescale, qkm1 * epsilon, qkm1)

        dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
        dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
        dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
        dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)

        conditional = logicaland(enabled, greater(grad_conditional, epsilon))

        return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
                select(enabled, y, vals[3]), select(enabled, z, vals[4]),
                c, select(enabled, pkm1, vals[6]),
                select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
                select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
                select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
                select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))

    y = 1 - a
    z = x + y + 1
    c = fill(dtype(x), shape(x), 0)
    pkm2 = fill(dtype(x), shape(x), 1)
    qkm2 = x
    pkm1 = x + 1
    qkm1 = z * x
    ans = pkm1 / qkm1
    t = fill(dtype(x), shape(x), 1)
    dpkm2_da = fill(dtype(x), shape(x), 0)
    dqkm2_da = fill(dtype(x), shape(x), 0)
    dpkm1_da = fill(dtype(x), shape(x), 0)
    dqkm1_da = -x
    dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
    vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
    vals = _while_helper_func(cond, body, vals)
    ans = vals[1]
    return ans * ax
class IGamma(Cell):
    r"""
    Calculates the lower regularized incomplete Gamma function.

    The lower regularized incomplete Gamma function is defined as:

    .. math::
        P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)

    where

    .. math::
        gamma(a, x) = \int_0^x t^{a-1} \exp(-t) dt

    is the lower incomplete Gamma function.

    Above :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    Inputs:
        - **a** (Tensor) - The input tensor. With float32 data type. `a` should have
          the same dtype as `x`.
        - **x** (Tensor) - The input tensor. With float32 data type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If dtype of input `x` or `a` is not float32,
            or if `x` has a different dtype than `a`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igamma = nn.IGamma()
        >>> output = igamma(a, x)
        >>> print(output)
        [0.593994 0.35276785 0.21486944 0.13337152]
    """

    def __init__(self):
        """Initialize IGamma."""
        super(IGamma, self).__init__()
        # const numbers
        # If more data types are supported, this float max value needs to be selected accordingly.
        self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)

        # operations
        self.logicaland = P.LogicalAnd()
        self.logicalor = P.LogicalOr()
        self.logicalnot = P.LogicalNot()
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()
        self.neg = P.Neg()
        self.log = P.Log()
        self.exp = P.Exp()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.fill = P.Fill()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()
        self.const = P.ScalarToArray()
        self.cast = P.Cast()

    def construct(self, a, x):
        a_dtype = self.dtype(a)
        x_dtype = self.dtype(x)
        _check_input_dtype("a", a_dtype, [mstype.float32], self.cls_name)
        _check_input_dtype("x", x_dtype, a_dtype, self.cls_name)
        domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
        use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
        ax = a * self.log(x) - x - self.lgamma(a)
        para_shape = self.shape(ax)
        if para_shape != ():
            broadcastto = P.BroadcastTo(para_shape)
            x = broadcastto(x)
            a = broadcastto(a)
        x_is_zero = self.equal(x, 0)
        log_maxfloat = self.log_maxfloat32
        underflow = self.less(ax, self.neg(log_maxfloat))
        ax = self.exp(ax)
        enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
        output = self.select(use_igammac,
                             1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
                             _igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
        output = self.select(x_is_zero, self.zeroslike(output), output)
        output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
        return output
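# Illustrative sketch only (not part of the public API): for positive integer `a`, the
# lower regularized incomplete Gamma function has the closed form
#   P(a, x) = 1 - exp(-x) * sum_{k=0}^{a-1} x**k / k!,
# which is a convenient cross-check for the example above (e.g. P(2, 2) ~ 0.5940).
def _igamma_integer_a_sketch(a, x):
    """Closed-form P(a, x) for a positive integer `a`."""
    import math
    partial = sum(x ** k / math.factorial(k) for k in range(a))
    return 1.0 - math.exp(-x) * partial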
class LBeta(Cell):
    r"""
    This method avoids numeric cancellation by explicitly decomposing lgamma into the Stirling
    approximation and an explicit log_gamma_correction, and cancelling the large terms from the
    Stirling approximation analytically.

    This is semantically equal to

    .. math::
        P(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y).

    The method is more accurate for arguments above 8. The reason for accuracy loss in the naive computation
    is catastrophic cancellation between the lgammas.

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype as `y`.
        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
          the same dtype as `x`.

    Outputs:
        Tensor, has the same dtype as `x` and `y`.

    Raises:
        TypeError: If dtype of `x` or `y` is neither float16 nor float32,
            or if `x` has a different dtype than `y`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
        >>> lbeta = nn.LBeta()
        >>> output = lbeta(y, x)
        >>> print(output)
        [-1.7917596 -4.094345 -12.000229 -14.754799]
    """

    def __init__(self):
        """Initialize LBeta."""
        super(LBeta, self).__init__()
        # const numbers
        self.log_2pi = np.log(2 * np.pi)
        self.minimax_coeff = [-0.165322962780713e-02,
                              0.837308034031215e-03,
                              -0.595202931351870e-03,
                              0.793650666825390e-03,
                              -0.277777777760991e-02,
                              0.833333333333333e-01]

        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.less = P.Less()
        self.select = P.Select()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()
        self.const = P.ScalarToTensor()

    def construct(self, x, y):
        x_dtype = self.dtype(x)
        y_dtype = self.dtype(y)
        _check_input_dtype("x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("y", y_dtype, x_dtype, self.cls_name)
        x_plus_y = x + y
        para_shape = self.shape(x_plus_y)
        if para_shape != ():
            broadcastto = P.BroadcastTo(para_shape)
            x = broadcastto(x)
            y = broadcastto(y)
        comp_less = self.less(x, y)
        x_min = self.select(comp_less, x, y)
        y_max = self.select(comp_less, y, x)

        @constexpr
        def _log_gamma_correction(x, minimax_coeff):
            inverse_x = 1. / x
            inverse_x_squared = inverse_x * inverse_x
            accum = minimax_coeff[0]
            for i in range(1, 6):
                accum = accum * inverse_x_squared + minimax_coeff[i]
            return accum * inverse_x

        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)

        # Two large arguments case: y >= x >= 8.
        log_beta_two_large = self.const(0.5 * self.log_2pi, x_dtype) - 0.5 * self.log(y_max) \
                             + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
                             + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)

        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
        correction = log_gamma_correction_y - log_gamma_correction_x_y
        log_gamma_difference_big_y = correction + cancelled_stirling

        # One large argument case: x < 8, y >= 8.
        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y

        # Small arguments case: x <= y < 8.
        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)
        comp_xless8 = self.less(x_min, 8)
        comp_yless8 = self.less(y_max, 8)
        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
        return self.select(comp_xless8, temp, log_beta_two_large)
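# Illustrative sketch only (not part of the public API): the naive decomposition that
# LBeta refines, lbeta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y), written with the
# standard library for scalar cross-checks. It loses accuracy for large arguments,
# which is exactly the case the Cell above handles analytically.
def _lbeta_naive_sketch(x, y):
    """Scalar log-beta via the naive lgamma decomposition."""
    import math
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)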
@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
    """get broadcast_matmul shape"""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if (len(x_shape) < 2) or (len(y_shape) < 2):
        raise ValueError(f"{msg_prefix} length of 'x_shape' and 'y_shape' should be equal to or greater than 2, "
                         f"but got the length of 'x_shape': {len(x_shape)} and the length of 'y_shape': "
                         f"{len(y_shape)}.")
    x_shape_batch = x_shape[:-2]
    y_shape_batch = y_shape[:-2]
    if x_shape_batch == y_shape_batch:
        return x_shape, y_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []
    for i in range(-length, -2):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError(f"{msg_prefix} 'x_shape[{i}]' should be equal to 1, or the 'y_shape[{i}]' should be "
                             f"equal to 1, or the 'x_shape[{i}]' should be equal to 'y_shape[{i}]', but got "
                             f"'x_shape[{i}]': {x_shape[i]}, 'y_shape[{i}]': {y_shape[i]}.")

    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
    y_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + y_shape[-2:]
    return x_broadcast_shape, y_broadcast_shape


@constexpr
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None):
    """check col and row equal"""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) == 1:
        transpose_x1 = False
        x1_shape = (1,) + x1_shape
    if len(x2_shape) == 1:
        transpose_x2 = False
        x2_shape = x2_shape + (1,)
    x1_last = x1_shape[-2:]
    x2_last = x2_shape[-2:]
    x1_col = x1_last[not transpose_x1]  # x1_col = x1_last[1] if (not transpose_x1) else x1_last[0]
    x2_row = x2_last[transpose_x2]  # x2_row = x2_last[0] if (not transpose_x2) else x2_last[1]
    if x1_col != x2_row:
        raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' should be equal to "
                         f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")


def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
    """select matmul op"""
    x1_dim, x2_dim = len(x1_shape), len(x2_shape)
    if x1_dim == 1 and x2_dim == 1:
        matmul_op = P.Mul()
    elif x1_dim <= 2 and x2_dim <= 2:
        transpose_x1 = False if x1_dim == 1 else transpose_x1
        transpose_x2 = False if x2_dim == 1 else transpose_x2
        matmul_op = P.MatMul(transpose_x1, transpose_x2)
    elif x1_dim == 1 and x2_dim > 2:
        matmul_op = P.BatchMatMul(False, transpose_x2)
    elif x1_dim > 2 and x2_dim == 1:
        matmul_op = P.BatchMatMul(transpose_x1, False)
    else:
        matmul_op = P.BatchMatMul(transpose_x1, transpose_x2)
    return matmul_op
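# Illustrative sketch only (not part of the public API): the batch-broadcast rule that
# get_broadcast_matmul_shape implements, expressed with NumPy (np.broadcast_shapes is
# assumed to be available, i.e. NumPy >= 1.20). The trailing two matrix dimensions are
# kept as-is; only the leading batch dimensions are broadcast against each other.
def _broadcast_matmul_shape_sketch(x_shape, y_shape):
    """Return the broadcast shapes of the two operands, keeping matrix dims untouched."""
    batch = np.broadcast_shapes(tuple(x_shape[:-2]), tuple(y_shape[:-2]))
    return batch + tuple(x_shape[-2:]), batch + tuple(y_shape[-2:])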
class MatMul(Cell):
    r"""
    Multiplies matrix `x1` by matrix `x2`.

    nn.MatMul will be deprecated in future versions. Please use ops.matmul instead.

    - If both `x1` and `x2` are 1-dimensional, the dot product is returned.
    - If the dimensions of `x1` and `x2` are both no greater than 2, the matrix-matrix product is
      returned. Note that if one of `x1` and `x2` is 1-dimensional, the argument is first expanded to
      2 dimensions, and the expanded dimension is removed after the matrix multiplication.
    - If at least one of `x1` and `x2` is N-dimensional (N > 2), the non-matrix (batch) dimensions of the
      inputs are broadcast and must be broadcastable. Note that if one of `x1` and `x2` is 1-dimensional,
      the argument is first expanded to 2 dimensions, then the non-matrix dimensions are broadcast, and
      the expanded dimension is removed after the matrix multiplication. For example,
      if `x1` is a :math:`(j \times 1 \times n \times m)` tensor and
      `x2` is a :math:`(k \times m \times p)` tensor, the output will be a :math:`(j \times k \times n \times p)`
      tensor.

    Args:
        transpose_x1 (bool): If true, `x1` is transposed before multiplication. Default: False.
        transpose_x2 (bool): If true, `x2` is transposed before multiplication. Default: False.

    Inputs:
        - **x1** (Tensor) - The first tensor to be multiplied.
        - **x2** (Tensor) - The second tensor to be multiplied.

    Outputs:
        Tensor, the shape of the output tensor depends on the dimension of input tensors.

    Raises:
        TypeError: If `transpose_x1` or `transpose_x2` is not a bool.
        ValueError: If the column of matrix dimensions of `x1` is not equal to
            the row of matrix dimensions of `x2`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.MatMul()
        >>> x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
        >>> output = net(x1, x2)
        >>> print(output.shape)
        (3, 2, 4)
    """

    @deprecated('1.2', 'ops.matmul', False)
    def __init__(self, transpose_x1=False, transpose_x2=False):
        """Initialize MatMul."""
        super(MatMul, self).__init__()

        validator.check_value_type('transpose_x1', transpose_x1, [bool], self.cls_name)
        validator.check_value_type('transpose_x2', transpose_x2, [bool], self.cls_name)
        self.transpose_x1 = transpose_x1
        self.transpose_x2 = transpose_x2
        self.shape_op = P.Shape()
        self.expand_op = P.ExpandDims()
        self.squeeze_left_op = P.Squeeze(-2)
        self.squeeze_right_op = P.Squeeze(-1)
        self.reduce_sum_op = P.ReduceSum(keep_dims=False)

    def construct(self, x1, x2):
        x1_shape = self.shape_op(x1)
        x2_shape = self.shape_op(x2)
        check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
        matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)

        x1_dim, x2_dim = len(x1_shape), len(x2_shape)
        if x1_dim == x2_dim and x2_dim == 1:
            return self.reduce_sum_op(matmul_op(x1, x2), -1)
        if x1_dim == 1:
            x1 = self.expand_op(x1, 0)
            x1_shape = self.shape_op(x1)
        if x2_dim == 1:
            x2 = self.expand_op(x2, 1)
            x2_shape = self.shape_op(x2)

        x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape)
        x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape)
        x2_broadcast_to = P.BroadcastTo(x2_broadcast_shape)
        if x1_broadcast_shape != x1_shape:
            x1 = x1_broadcast_to(x1)
        if x2_broadcast_shape != x2_shape:
            x2 = x2_broadcast_to(x2)

        matmul_broadcast = matmul_op(x1, x2)

        if x1_dim == 1:
            matmul_broadcast = self.squeeze_left_op(matmul_broadcast)
        if x2_dim == 1:
            matmul_broadcast = self.squeeze_right_op(matmul_broadcast)

        return matmul_broadcast
class Moments(Cell):
    """
    Calculates the mean and variance of `x`.

    The mean and variance are calculated by aggregating the contents of `x` across `axis`.
    If `x` is 1-D and `axis` is [0], this is just the mean and variance of a vector.

    Args:
        axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: None.
        keep_dims (bool): If true, the dimensions of mean and variance are identical with the input's.
            If false, don't keep these dimensions. Default: None.

    Inputs:
        - **x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
          :math:`(N, *)` where :math:`*` means any number of additional dimensions.

    Outputs:
        - **mean** (Tensor) - The mean of `x`, with the same data type as input `x`.
        - **variance** (Tensor) - The variance of `x`, with the same data type as input `x`.

    Raises:
        TypeError: If `axis` is not one of int, tuple, None.
        TypeError: If `keep_dims` is neither bool nor None.
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
        >>> net = nn.Moments(axis=0, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
        [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]),
        Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
        >>> net = nn.Moments(axis=1, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
        [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]),
        Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
        >>> net = nn.Moments(axis=2, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
        [[[[ 2.00000000e+00, 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]]]),
        Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
        [[[[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]]]]))
        >>> net = nn.Moments(axis=3, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
        [[[[ 2.50000000e+00],
        [ 4.50000000e+00]]]]), Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
        [[[[ 1.25000000e+00],
        [ 1.25000000e+00]]]]))
    """

    def __init__(self, axis=None, keep_dims=None):
        """Initialize Moments."""
        super(Moments, self).__init__()
        if axis is None:
            axis = ()
        if isinstance(axis, tuple):
            for idx, item in enumerate(axis):
                validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
        self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
        if keep_dims is None:
            keep_dims = False
        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square_diff = P.SquaredDifference()
        self.squeeze = P.Squeeze(self.axis)

    def construct(self, x):
        tensor_dtype = F.dtype(x)
        _check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
        if tensor_dtype == mstype.float16:
            x = self.cast(x, mstype.float32)
        mean = self.reduce_mean(x, self.axis)
        variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis)
        if not self.keep_dims:
            mean = self.squeeze(mean)
            variance = self.squeeze(variance)
        if tensor_dtype == mstype.float16:
            mean = self.cast(mean, mstype.float16)
            variance = self.cast(variance, mstype.float16)
            return mean, variance
        return mean, variance
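# Illustrative sketch only (not part of the public API): the same aggregation as Moments,
# done with NumPy. Note that the variance is the population variance (ddof=0), matching
# reduce_mean(squared_difference(x, mean)) in the Cell above.
def _numpy_moments_sketch(x, axis=None, keep_dims=False):
    """Return (mean, variance) of `x` over `axis`."""
    mean = np.mean(x, axis=axis, keepdims=keep_dims)
    variance = np.var(x, axis=axis, keepdims=keep_dims)
    return mean, variance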
class MatInverse(Cell):
    """
    Calculates the inverse of a positive-definite Hermitian matrix using Cholesky decomposition.

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. It must be a positive-definite matrix.
          With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32))
        >>> op = nn.MatInverse()
        >>> output = op(x)
        >>> print(output)
        [[49.36112 -13.555558 2.1111116]
         [-13.555558 3.7777784 -0.5555557]
         [2.1111116 -0.5555557 0.11111113]]
    """

    def __init__(self):
        """Initialize MatInverse."""
        super(MatInverse, self).__init__()
        self.dtype = P.DType()
        self.choleskytrsm = P.CholeskyTrsm()
        self.matmul = MatMul(transpose_x1=True)

    def construct(self, a):
        input_dtype = self.dtype(a)
        _check_input_dtype("input_a", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        l_inverse = self.choleskytrsm(a)
        a_inverse = self.matmul(l_inverse, l_inverse)
        return a_inverse


class MatDet(Cell):
    """
    Calculates the determinant of a positive-definite Hermitian matrix using Cholesky decomposition.

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. It must be a positive-definite matrix.
          With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32))
        >>> op = nn.MatDet()
        >>> output = op(x)
        >>> print(output)
        35.999996
    """

    def __init__(self):
        """Initialize MatDet."""
        super(MatDet, self).__init__()
        self.dtype = P.DType()
        self.cholesky = P.Cholesky()
        self.det_triangle = P.DetTriangle()
        self.square = P.Square()

    def construct(self, a):
        input_dtype = self.dtype(a)
        _check_input_dtype("input_a", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        l = self.cholesky(a)
        l_det = self.det_triangle(l)
        a_det = self.square(l_det)
        return a_det
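# Illustrative sketch only (not part of the public API): the Cholesky identities used by
# MatInverse and MatDet above, written with NumPy. For positive-definite A = L @ L.T:
#   inv(A) = inv(L).T @ inv(L)   and   det(A) = det(L) ** 2.
def _numpy_cholesky_inverse_det_sketch(a):
    """Return (inverse, determinant) of a positive-definite matrix via its Cholesky factor."""
    low = np.linalg.cholesky(a)           # lower-triangular factor L
    low_inv = np.linalg.inv(low)          # inverse of the triangular factor
    a_inv = low_inv.T @ low_inv           # inv(A) = inv(L).T @ inv(L)
    a_det = np.prod(np.diag(low)) ** 2    # det(A) = (prod of diag(L)) ** 2
    return a_inv, a_det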