• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Utility functions to help distribution class."""
16import numpy as np
17from mindspore import context
18from mindspore._checkparam import Validator as validator
19from mindspore.common.tensor import Tensor
20from mindspore.common.parameter import Parameter
21from mindspore.common import dtype as mstype
22from mindspore.ops import composite as C
23from mindspore.ops import operations as P
24from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
25import mindspore.nn as nn
26
27
def cast_to_tensor(t, hint_type=mstype.float32):
    """
    Cast a user input value into a Tensor of dtype.
    If the input t is of type Parameter, t is directly returned as a Parameter.

    Args:
        t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
        hint_type (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.

    Raises:
        ValueError: if t is None.
        TypeError: if t is a bool, or cannot be cast to Tensor.

    Returns:
        Tensor.
    """
    if t is None:
        raise ValueError('Input cannot be None in cast_to_tensor')
    if isinstance(t, Parameter):
        return t
    # bool is a subclass of int, so it must be rejected before the int branch below.
    if isinstance(t, bool):
        raise TypeError('Input cannot be Type Bool')
    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
        return Tensor(t, dtype=hint_type)
    invalid_type = type(t)
    raise TypeError(
        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")
54
55
def cast_type_for_device(dtype):
    """
    use the alternative dtype supported by the device.
    Args:
        dtype (mindspore.dtype): input dtype.
    Returns:
        mindspore.dtype.
    """
    # Only the GPU target needs substitutions; every other device keeps the dtype.
    if context.get_context("device_target") != "GPU":
        return dtype
    if dtype == mstype.int8 or dtype in mstype.uint_type:
        return mstype.int16
    if dtype == mstype.int64:
        return mstype.int32
    if dtype == mstype.float64:
        return mstype.float32
    return dtype
72
73
def check_greater_equal_zero(value, name):
    """
    Check if the given Tensor is greater than or equal to zero.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str) : name of the value.

    Raises:
        ValueError: if the input value is less than zero.

    """
    if isinstance(value, Parameter):
        # An uninitialized Parameter carries no concrete data to validate; skip it.
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(value.asnumpy(), np.zeros(value.shape))
    if comp.any():
        # fix: message typo "ot" -> "or"
        raise ValueError(f'{name} should be greater than or equal to zero.')
93
94
def check_greater_zero(value, name):
    """
    Check if the given Tensor is strictly greater than zero.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str) : name of the value.

    Raises:
        ValueError: if the input value is None, or is less than or equal to zero.

    """
    if value is None:
        raise ValueError('input value cannot be None in check_greater_zero')
    if isinstance(value, Parameter):
        # An uninitialized Parameter carries no concrete data to validate; skip it.
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(np.zeros(value.shape), value.asnumpy())
    if not comp.all():
        raise ValueError(f'{name} should be greater than zero.')
116
117
def check_greater(a, b, name_a, name_b):
    """
    Check if Tensor b is strictly greater than Tensor a.

    Args:
        a (Tensor, Parameter): input tensor a.
        b (Tensor, Parameter): input tensor b.
        name_a (str): name of Tensor_a.
        name_b (str): name of Tensor_b.

    Raises:
        ValueError: if either input is None, or if b is less than or equal to a.
    """
    if a is None or b is None:
        raise ValueError('input value cannot be None in check_greater')
    if isinstance(a, Parameter) or isinstance(b, Parameter):
        # Parameter values may not be initialized yet, so the comparison is skipped.
        return
    comp = np.less(a.asnumpy(), b.asnumpy())
    if not comp.all():
        raise ValueError(f'{name_a} should be less than {name_b}')
138
139
def check_prob(p):
    """
    Check if p is a proper probability, i.e. 0 < p < 1.

    Args:
        p (Tensor, Parameter): value to be checked.

    Raises:
        ValueError: if p is None, or is not a proper probability.
    """
    if p is None:
        # fix: the original message named the wrong function (check_greater_zero).
        raise ValueError('input value cannot be None in check_prob')
    if isinstance(p, Parameter):
        # An uninitialized Parameter carries no concrete data to validate; skip it.
        if not isinstance(p.data, Tensor):
            return
        p = p.data
    comp = np.less(np.zeros(p.shape), p.asnumpy())
    if not comp.all():
        raise ValueError('Probabilities should be greater than zero')
    comp = np.greater(np.ones(p.shape), p.asnumpy())
    if not comp.all():
        raise ValueError('Probabilities should be less than one')
162
163
def check_sum_equal_one(probs):
    """
    Used in categorical distribution. check if probabilities of each category sum to 1.

    Raises:
        ValueError: if probs is None, or the probabilities along the last axis
            do not sum to one (within a small tolerance).
    """
    if probs is None:
        raise ValueError('input value cannot be None in check_sum_equal_one')
    if isinstance(probs, Parameter):
        # An uninitialized Parameter carries no concrete data to validate; skip it.
        if not isinstance(probs.data, Tensor):
            return
        probs = probs.data
    if isinstance(probs, Tensor):
        probs = probs.asnumpy()
    prob_sum = np.sum(probs, axis=-1)
    # add a small tolerance here to increase numerical stability
    comp = np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=1e-14, atol=1e-14)
    if not comp:
        raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')
181
182
def check_rank(probs):
    """
    Used in categorical distribution. check Rank >= 1.

    Raises:
        ValueError: if probs is None, or is a scalar (rank 0).
    """
    if probs is None:
        raise ValueError('input value cannot be None in check_rank')
    if isinstance(probs, Parameter):
        # An uninitialized Parameter carries no concrete data to validate; skip it.
        if not isinstance(probs.data, Tensor):
            return
        probs = probs.data
    if probs.asnumpy().ndim == 0:
        raise ValueError('probs for Categorical distribution must have rank >= 1.')
195
196
def logits_to_probs(logits, is_binary=False):
    """
    converts logits into probabilities.
    Args:
        logits (Tensor)
        is_binary (bool)
    """
    # Binary case maps through a sigmoid; otherwise normalize over the last axis.
    activation = nn.Sigmoid() if is_binary else nn.Softmax(axis=-1)
    return activation(logits)
207
208
def clamp_probs(probs):
    """
    clamp probs boundary
    """
    # Keep probabilities strictly inside (0, 1) by one machine epsilon on each side.
    tiny = P.Eps()(probs)
    upper = 1 - tiny
    return C.clip_by_value(probs, tiny, upper)
215
216
def probs_to_logits(probs, is_binary=False):
    """
    converts probabilities into logits.

    Args:
        probs (Tensor)
        is_binary (bool)
    """
    # Clamp first so the log never sees exact 0 or 1.
    ps_clamped = clamp_probs(probs)
    if is_binary:
        # Binary logit is log(p) - log(1 - p).
        return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
    return P.Log()(ps_clamped)
228
229
@constexpr
def raise_none_error(name):
    """Raise at graph-compile time when a required argument was left as None."""
    message = (f"the type {name} should be subclass of Tensor."
               f" It should not be None since it is not specified during initialization.")
    raise TypeError(message)
234
235
@constexpr
def raise_probs_logits_error():
    """Raise at graph-compile time when 'probs'/'logits' are given both or neither."""
    message = "Either 'probs' or 'logits' must be specified, but not both."
    raise TypeError(message)
239
240
@constexpr
def raise_broadcast_error(shape_a, shape_b):
    """Raise at graph-compile time when two shapes cannot be broadcast together."""
    message = f"Shape {shape_a} and {shape_b} is not broadcastable."
    raise ValueError(message)
244
245
@constexpr
def raise_not_impl_error(name):
    """Raise at graph-compile time for a function a transformation must implement."""
    message = f"{name} function should be implemented for non-linear transformation"
    raise ValueError(message)
250
251
@constexpr
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
    """Raise at graph-compile time for a method a distribution does not implement."""
    message = f"{func_name} is not implemented for {obj} distribution."
    raise NotImplementedError(message)
256
257
@constexpr
def raise_type_error(name, cur_type, required_type):
    """Raise at graph-compile time when an argument has the wrong type."""
    message = f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}"
    raise TypeError(message)
262
263
@constexpr
def raise_not_defined(func_name, obj, *args, **kwargs):
    """Raise at graph-compile time for a quantity undefined for a distribution."""
    message = f"{func_name} is undefined for {obj} distribution."
    raise ValueError(message)
268
269
@constexpr
def check_distribution_name(name, expected_name):
    """
    Check at graph-compile time that the dist argument matches the expected name.

    Raises:
        ValueError: if name is None, or does not equal expected_name.
    """
    if name is None:
        raise ValueError(
            "Input dist should be a constant which is not None.")
    if name != expected_name:
        raise ValueError(
            f"Expected dist input is {expected_name}, but got {name}.")
278
279
class CheckTuple(PrimitiveWithInfer):
    """
    Check if input is a tuple.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTuple, self).__init__("CheckTuple")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        if not isinstance(x['dtype'], tuple):
            # fix: message typo "should b a tuple" -> "should be a tuple"
            raise TypeError(
                f"For {name['value']}, input type should be a tuple.")

        # Pass the tuple value through so downstream infer steps can use it.
        out = {'shape': None,
               'dtype': None,
               'value': x["value"]}
        return out

    def __call__(self, x, name):
        # The op is not used in a cell
        if isinstance(x, tuple):
            return x
        # NOTE(review): presumably mode 0 is context.GRAPH_MODE, where x arrives
        # as an abstract dict — confirm against mindspore.context constants.
        if context.get_context("mode") == 0:
            return x["value"]
        raise TypeError(f"For {name}, input type should be a tuple.")
306
307
class CheckTensor(PrimitiveWithInfer):
    """
    Check if input is a Tensor.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTensor, self).__init__("CheckTensor")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        # Validate at graph-compile time that the abstract input is a Tensor subclass.
        validator.check_subclass(
            "input", x['dtype'], [mstype.tensor], name["value"])
        return {'shape': None, 'dtype': None, 'value': None}

    def __call__(self, x, name):
        if isinstance(x, Tensor):
            return x
        raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
331
332
def set_param_type(args, hint_type):
    """
    Find the common type among arguments.

    Args:
        args (dict): dictionary of arguments, {'name':value}.
        hint_type (mindspore.dtype): hint type to return.

    Raises:
        TypeError: if tensors in args are not the same dtype.
    """
    int_type = mstype.int_type + mstype.uint_type
    # Integer hints are promoted to float32: distribution math needs floats.
    if hint_type is None or hint_type in int_type:
        hint_type = mstype.float32
    common_dtype = None
    for name, arg in args.items():
        if not hasattr(arg, 'dtype'):
            continue
        if isinstance(arg, np.ndarray):
            cur_dtype = mstype.pytype_to_dtype(arg.dtype)
        else:
            cur_dtype = arg.dtype
        if common_dtype is None:
            common_dtype = cur_dtype
        elif cur_dtype != common_dtype:
            raise TypeError(f"{name} should have the same dtype as other arguments.")
    # Integer or float64 tensors also fall back to float32.
    if common_dtype in int_type or common_dtype == mstype.float64:
        return mstype.float32
    return hint_type if common_dtype is None else common_dtype
361