• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Utility functions to help distribution class."""
16import numpy as np
17from mindspore import context
18from mindspore import _checkparam as validator
19from mindspore.common.tensor import Tensor
20from mindspore.common.parameter import Parameter
21from mindspore.common import dtype as mstype
22from mindspore.ops import operations as P
23from mindspore.ops.primitive import constexpr, _primexpr, PrimitiveWithInfer, prim_attr_register
24import mindspore.ops as ops
25import mindspore.nn as nn
26
27
def cast_to_tensor(t, hint_type=mstype.float32):
    """
    Cast a user input value into a Tensor of dtype `hint_type`.
    If the input t is of type Parameter, t is directly returned as a Parameter.

    Args:
        t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
        hint_type (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.

    Raises:
        ValueError: if t is None.
        TypeError: if t is a bool or cannot be cast to Tensor.

    Returns:
        Tensor.
    """
    if t is None:
        raise ValueError('Input cannot be None in cast_to_tensor')
    if isinstance(t, Parameter):
        return t
    # bool is a subclass of int in Python, so reject it explicitly
    # before the isinstance(..., int) check below would accept it.
    if isinstance(t, bool):
        raise TypeError('Input cannot be Type Bool')
    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
        return Tensor(t, dtype=hint_type)
    invalid_type = type(t)
    raise TypeError(
        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")
54
55
def cast_type_for_device(dtype):
    """
    Map a dtype to the nearest alternative supported by the current device.

    Args:
        dtype (mindspore.dtype): input dtype.

    Returns:
        mindspore.dtype.
    """
    # Only the GPU backend needs dtype substitution; all other targets
    # pass the dtype through unchanged.
    if context.get_context("device_target") != "GPU":
        return dtype
    if dtype == mstype.int8 or dtype in mstype.uint_type:
        return mstype.int16
    if dtype == mstype.int64:
        return mstype.int32
    if dtype == mstype.float64:
        return mstype.float32
    return dtype
72
73
def check_greater_equal_zero(value, name):
    """
    Check if the given Tensor is greater than or equal to zero.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str) : name of the value.

    Raises:
        ValueError: if the input value is None or less than zero.

    """
    # None guard for consistency with check_greater_zero / check_prob /
    # check_greater; previously a None input crashed with AttributeError.
    if value is None:
        raise ValueError('input value cannot be None in check_greater_equal_zero')
    if isinstance(value, Parameter):
        # A Parameter without concrete Tensor data cannot be validated here.
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(value.asnumpy(), np.zeros(value.shape))
    if comp.any():
        raise ValueError(f'{name} must be greater than or equal to zero.')
93
94
def check_greater_zero(value, name):
    """
    Check if the given Tensor is strictly positive.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str) : name of the value.

    Raises:
        ValueError: if the input value is less than or equal to zero.

    """
    if value is None:
        raise ValueError('input value cannot be None in check_greater_zero')
    if isinstance(value, Parameter):
        # A Parameter may not hold concrete data yet; skip the check then.
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    arr = value.asnumpy()
    if not np.all(np.less(np.zeros(value.shape), arr)):
        raise ValueError(f'{name} must be greater than zero.')
116
117
def check_greater(a, b, name_a, name_b):
    """
    Check if Tensor b is strictly greater than Tensor a, element-wise.

    Args:
        a (Tensor, Parameter): input tensor a.
        b (Tensor, Parameter): input tensor b.
        name_a (str): name of Tensor_a.
        name_b (str): name of Tensor_b.

    Raises:
        ValueError: if b is less than or equal to a
    """
    if a is None or b is None:
        raise ValueError('input value cannot be None in check_greater')
    # Parameters may be uninitialized at construction time; skip the check.
    if isinstance(a, Parameter) or isinstance(b, Parameter):
        return
    if not np.less(a.asnumpy(), b.asnumpy()).all():
        raise ValueError(f'{name_a} must be less than {name_b}')
138
139
def check_prob(p):
    """
    Check if p is a proper probability, i.e. 0 < p < 1.

    Args:
        p (Tensor, Parameter): value to be checked.

    Raises:
        ValueError: if p is None or p is not a proper probability.
    """
    if p is None:
        # Message previously named check_greater_zero (copy-paste bug).
        raise ValueError('input value cannot be None in check_prob')
    if isinstance(p, Parameter):
        # A Parameter without concrete Tensor data cannot be validated here.
        if not isinstance(p.data, Tensor):
            return
        p = p.data
    arr = p.asnumpy()
    if not np.less(np.zeros(p.shape), arr).all():
        raise ValueError('Probabilities must be greater than zero')
    if not np.greater(np.ones(p.shape), arr).all():
        raise ValueError('Probabilities must be less than one')
162
163
def check_sum_equal_one(probs):
    """
    Used in categorical distribution: check that the probabilities of the
    categories (last axis) sum to one.
    """
    if probs is None:
        raise ValueError('input value cannot be None in check_sum_equal_one')
    values = probs
    if isinstance(values, Parameter):
        # A Parameter without concrete Tensor data cannot be validated here.
        if not isinstance(values.data, Tensor):
            return
        values = values.data
    if isinstance(values, Tensor):
        values = values.asnumpy()
    prob_sum = np.sum(values, axis=-1)
    # Small tolerance scaled by machine epsilon for numerical stability.
    eps = np.finfo(prob_sum.dtype).eps
    if not np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=eps * 10, atol=eps):
        raise ValueError(
            'Probabilities for each category should sum to one for Categorical distribution.')
183
184
def check_rank(probs):
    """
    Used in categorical distribution: check that probs has rank >= 1.
    """
    if probs is None:
        raise ValueError('input value cannot be None in check_rank')
    if isinstance(probs, Parameter):
        # A Parameter without concrete Tensor data cannot be validated here.
        if not isinstance(probs.data, Tensor):
            return
        probs = probs.data
    # A scalar (rank-0) probs is invalid for a categorical distribution.
    if probs.asnumpy().ndim < 1:
        raise ValueError(
            'probs for Categorical distribution must have rank >= 1.')
198
199
def logits_to_probs(logits, is_binary=False):
    """
    Convert logits into probabilities.

    Args:
        logits (Tensor)
        is_binary (bool)
    """
    # Binary case maps through sigmoid; multi-class uses a softmax
    # over the last (category) axis.
    activation = nn.Sigmoid() if is_binary else nn.Softmax(axis=-1)
    return activation(logits)
210
211
def clamp_probs(probs):
    """
    Clamp probabilities into the open interval (eps, 1 - eps).
    """
    # Eps gives the smallest representable increment for probs' dtype.
    epsilon = P.Eps()(probs)
    return ops.clip_by_value(probs, epsilon, 1 - epsilon)
218
219
def probs_to_logits(probs, is_binary=False):
    """
    Convert probabilities into logits.

    Args:
        probs (Tensor)
        is_binary (bool)
    """
    # Clamp away from 0/1 first so the log below stays finite.
    clamped = clamp_probs(probs)
    log = P.Log()
    if is_binary:
        return log(clamped) - log(1 - clamped)
    return log(clamped)
231
232
@constexpr(check=False)
def raise_none_error(name):
    """Raise a TypeError for a required Tensor argument that was left as None."""
    message = (f"the type {name} must be subclass of Tensor."
               f" It can not be None since it is not specified during initialization.")
    raise TypeError(message)
237
238
@_primexpr
def raise_broadcast_error(shape_a, shape_b):
    """Raise a ValueError for two shapes that cannot be broadcast together."""
    message = f"Shape {shape_a} and {shape_b} is not broadcastable."
    raise ValueError(message)
242
243
@constexpr(check=False)
def raise_not_impl_error(name):
    """Raise a ValueError for a function that non-linear transformations require."""
    message = f"{name} function must be implemented for non-linear transformation"
    raise ValueError(message)
248
249
@constexpr(check=False)
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
    """Raise NotImplementedError for a distribution lacking the given function."""
    message = f"{func_name} is not implemented for {obj} distribution."
    raise NotImplementedError(message)
254
255
@constexpr(check=False)
def raise_type_error(name, cur_type, required_type):
    """Raise a TypeError describing the required versus actual type."""
    message = f"For {name} , the type must be or be subclass of {required_type}, but got {cur_type}"
    raise TypeError(message)
260
261
@constexpr(check=False)
def raise_not_defined(func_name, obj, *args, **kwargs):
    """Raise a ValueError for a quantity that is undefined for the distribution."""
    message = f"{func_name} is undefined for {obj} distribution."
    raise ValueError(message)
266
267
@constexpr(check=False)
def check_distribution_name(name, expected_name):
    """Validate that the given distribution name matches the expected one."""
    if name is None:
        raise ValueError(
            "Input dist must be a constant which is not None.")
    if name != expected_name:
        raise ValueError(
            f"Expected dist input is {expected_name}, but got {name}.")
276
277
class CheckTuple(PrimitiveWithInfer):
    """
    Check if input is a tuple.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTuple, self).__init__("CheckTuple")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        # Compile-time (infer-stage) validation that x is a tuple.
        # Fixed typo in the message ("must b a tuple") and aligned its
        # wording with the __call__ message below.
        if not isinstance(x['dtype'], tuple):
            raise TypeError(
                f"For {name['value']}, input type must be a tuple.")

        out = {'shape': None,
               'dtype': None,
               'value': x["value"]}
        return out

    def __call__(self, x, name):
        # The op is not used in a cell
        if isinstance(x, tuple):
            return x
        # In graph mode (mode == 0) a constant tuple arrives wrapped in a
        # value dict produced by the infer stage.
        if context.get_context("mode") == 0:
            return x["value"]
        raise TypeError(f"For {name}, input type must be a tuple.")
304
305
class CheckTensor(PrimitiveWithInfer):
    """
    Check if input is a Tensor.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTensor, self).__init__("CheckTensor")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        # Compile-time check: the input's dtype must be (a subclass of)
        # tensor type; validator raises on mismatch.
        validator.check_subclass(
            "input", x['dtype'], [mstype.tensor_type], name["value"])
        return {'shape': None,
                'dtype': None,
                'value': None}

    def __call__(self, x, name):
        # we skip this check in graph mode as it is checked in the infer stage
        # and in the graph mode x is None if x is not const in the graph
        if x is None or isinstance(x, Tensor):
            return x
        raise TypeError(
            f"For {name}, input type must be a Tensor or Parameter.")
332
333
def set_param_type(args, hint_type):
    """
    Find the common dtype among the given arguments.

    Args:
        args (dict): dictionary of arguments, {'name':value}.
        hint_type (mindspore.dtype): hint type to return when no argument
            carries a dtype.

    Raises:
        TypeError: if tensors in args are not the same dtype.
    """
    int_type = mstype.int_type + mstype.uint_type
    # Integer hints (or a missing hint) fall back to float32, since
    # distribution parameters must be floating point.
    if hint_type is None or hint_type in int_type:
        hint_type = mstype.float32
    common_dtype = None
    for name, arg in args.items():
        if not hasattr(arg, 'dtype'):
            continue
        if isinstance(arg, np.ndarray):
            cur_dtype = mstype.pytype_to_dtype(arg.dtype)
        else:
            cur_dtype = arg.dtype
        if common_dtype is None:
            common_dtype = cur_dtype
        elif cur_dtype != common_dtype:
            raise TypeError(
                f"{name} should have the same dtype as other arguments.")
    # Integer / float64 tensors are downgraded to float32.
    if common_dtype in int_type or common_dtype == mstype.float64:
        return mstype.float32
    return common_dtype if common_dtype is not None else hint_type
363