# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Scalar Affine Bijector"""
from mindspore.ops import operations as P
from ..distribution._utils.custom_ops import log_generic
from .bijector import Bijector


class ScalarAffine(Bijector):
    """
    Scalar Affine Bijector.
    This Bijector performs the operation:

    .. math::
        Y = a * X + b

    where a is the scale factor and b is the shift factor.

    Args:
        scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
        shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0.
        name (str): The name of the bijector. Default: 'ScalarAffine'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        The dtype of `shift` and `scale` must be float.
        If `shift` and `scale` are passed in as numpy.ndarray or Tensor, they must
        have the same dtype; otherwise an error is raised.

    Raises:
        TypeError: When the dtype of `shift` or `scale` is not float,
            or when the dtypes of `shift` and `scale` differ.

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>>
        >>> # To initialize a ScalarAffine bijector of scale 1.0 and shift 2.
        >>> scalaraffine = nn.probability.bijector.ScalarAffine(1.0, 2.0)
        >>> value = Tensor([1, 2, 3], dtype=mindspore.float32)
        >>> ans1 = scalaraffine.forward(value)
        >>> print(ans1.shape)
        (3,)
        >>> ans2 = scalaraffine.inverse(value)
        >>> print(ans2.shape)
        (3,)
        >>> ans3 = scalaraffine.forward_log_jacobian(value)
        >>> print(ans3.shape)
        ()
        >>> ans4 = scalaraffine.inverse_log_jacobian(value)
        >>> print(ans4.shape)
        ()
    """

    def __init__(self,
                 scale=1.0,
                 shift=0.0,
                 name='ScalarAffine'):
        """
        Constructor of ScalarAffine Bijector.

        Args:
            scale (float, list, numpy.ndarray, Tensor): The scale factor `a`. Default: 1.0.
            shift (float, list, numpy.ndarray, Tensor): The shift factor `b`. Default: 0.0.
            name (str): The name of the bijector. Default: 'ScalarAffine'.
        """
        # Snapshot of the constructor arguments handed to the base class;
        # presumably used by `Bijector` for bookkeeping/repr — confirm in bijector.py.
        param = dict(locals())
        param['param_dict'] = {'scale': scale, 'shift': shift}
        # The log-Jacobian depends only on `scale` (see `_forward_log_jacobian`),
        # hence `is_constant_jacobian=True`.
        super(ScalarAffine, self).__init__(
            is_constant_jacobian=True,
            is_injective=True,
            name=name,
            dtype=None,
            param=param)

        # `_add_parameter` is inherited from `Bijector`; it registers and stores
        # the raw argument under the given name.
        self._scale = self._add_parameter(scale, 'scale')
        self._shift = self._add_parameter(shift, 'shift')

        self.abs = P.Abs()
        self.oneslike = P.OnesLike()
        self.dtypeop = P.DType()
        self.cast = P.Cast()
        self.log = log_generic

    @property
    def scale(self):
        """Return the scale factor `a` as stored by `_add_parameter`."""
        return self._scale

    @property
    def shift(self):
        """Return the shift factor `b` as stored by `_add_parameter`."""
        return self._shift

    def extend_repr(self):
        """Display instance object as string.

        Shows the parameter values for a scalar batch, otherwise only the batch shape.
        """
        if self.is_scalar_batch:
            str_info = 'scale = {}, shift = {}'.format(self.scale, self.shift)
        else:
            str_info = 'batch_shape = {}'.format(self.batch_shape)
        return str_info

    def _forward(self, x):
        r"""
        Compute the forward mapping.

        .. math::
            f(x) = a * x + b

        Args:
            x (Tensor): The input value; passed through `_check_value_dtype`
                (inherited from `Bijector`) before use.

        Returns:
            Tensor, `scale * x + shift`, with the parameters cast to match `x`.
        """
        x = self._check_value_dtype(x)
        scale_local = self.cast_param_by_value(x, self.scale)
        shift_local = self.cast_param_by_value(x, self.shift)
        # Multiplying by ones shaped like `x` makes the shift term explicitly
        # take the shape of `x` before it is added.
        forward_v = scale_local * x + shift_local * self.oneslike(x)
        return forward_v

    def _inverse(self, y):
        r"""
        Compute the inverse mapping.

        .. math::
            g(y) = \frac{y - b}{a}

        Args:
            y (Tensor): The output value of the forward mapping; passed through
                `_check_value_dtype` (inherited from `Bijector`) before use.

        Returns:
            Tensor, `(y - shift) / scale`, with the parameters cast to match `y`.

        Note:
            No check is made here that `scale` is nonzero.
        """
        y = self._check_value_dtype(y)
        scale_local = self.cast_param_by_value(y, self.scale)
        shift_local = self.cast_param_by_value(y, self.shift)
        inverse_v = (y - shift_local) / scale_local
        return inverse_v

    def _forward_log_jacobian(self, x):
        r"""
        Compute the log absolute value of the derivative of the forward mapping.

        .. math::
            f(x) = a * x + b
            f'(x) = a
            \log(|f'(x)|) = \log(|a|)

        Args:
            x (Tensor): The input value; only used to validate the dtype and to
                cast `scale`. The result does not otherwise depend on `x`.

        Returns:
            Tensor, `log(|scale|)`.
        """
        x = self._check_value_dtype(x)
        scale_local = self.cast_param_by_value(x, self.scale)
        forward_log_j = self.log(self.abs(scale_local))
        return forward_log_j

    def _inverse_log_jacobian(self, y):
        r"""
        Compute the log absolute value of the derivative of the inverse mapping.

        .. math::
            g(y) = \frac{y - b}{a}
            g'(y) = \frac{1}{a}
            \log(|g'(y)|) = -\log(|a|)

        Args:
            y (Tensor): The output value; only used to validate the dtype and to
                cast `scale`. The result does not otherwise depend on `y`.

        Returns:
            Tensor, `-log(|scale|)`.
        """
        y = self._check_value_dtype(y)
        scale_local = self.cast_param_by_value(y, self.scale)
        inverse_log_j = -1. * self.log(self.abs(scale_local))
        return inverse_log_j