# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Random functions."""

# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as onp

from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(agarwal): deprecate this.
DEFAULT_RANDN_DTYPE = onp.float32


@np_utils.np_doc('random.seed')
def seed(s):
  """Sets the seed for the random number generator.

  Uses `tf.random.set_seed`, i.e. this sets TensorFlow's global random seed.

  Args:
    s: an integer, or any value accepted by `int()` (note that `int()`
      truncates floats).

  Raises:
    ValueError: if `s` cannot be converted to an integer.
  """
  try:
    s = int(s)
  except (TypeError, ValueError):
    # `int()` raises TypeError for unsupported types (e.g. None) and
    # ValueError for unparsable strings; report both the same way.
    # TODO(wangpeng): support this?
    raise ValueError('np.seed currently only support integer arguments.')
  random_seed.set_seed(s)


@np_utils.np_doc('random.randn')
def randn(*args):
  """Returns samples from a standard normal distribution.

  Uses `tf.random.normal` (via `standard_normal`).

  Args:
    *args: The dimensions of the output array.

  Returns:
    An ndarray with shape `args` and dtype `np_dtypes.default_float_type()`.
  """
  return standard_normal(size=args)


@np_utils.np_doc('random.standard_normal')
def standard_normal(size=None):
  """Returns samples from a standard normal distribution.

  Args:
    size: None for a single scalar sample, a scalar `n` for a 1-D output of
      length `n`, or a sequence giving the output shape.

  Returns:
    Samples drawn with `tf.random.normal`, with dtype
    `np_dtypes.default_float_type()`.
  """
  # TODO(wangpeng): Use new stateful RNG
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    # A scalar size means a 1-D output; the TF op requires a shape vector.
    size = (size,)
  dtype = np_dtypes.default_float_type()
  return random_ops.random_normal(size, dtype=dtype)


@np_utils.np_doc('random.uniform')
def uniform(low=0.0, high=1.0, size=None):
  """Returns samples from a uniform distribution over `[low, high)`.

  Args:
    low: lower bound, converted to `np_dtypes.default_float_type()`.
    high: upper bound, converted to `np_dtypes.default_float_type()`.
    size: None, a scalar `n` (1-D output of length `n`), or a shape sequence.
      When None, the output shape is the broadcast of `low` and `high` shapes.

  Returns:
    Samples drawn with `tf.random.uniform`, with dtype
    `np_dtypes.default_float_type()`.
  """
  dtype = np_dtypes.default_float_type()
  low = np_array_ops.asarray(low, dtype=dtype)
  high = np_array_ops.asarray(high, dtype=dtype)
  if size is None:
    size = array_ops.broadcast_dynamic_shape(low.shape, high.shape)
  elif np_utils.isscalar(size):
    # Same normalization as `standard_normal`/`poisson`: `random_uniform`
    # needs a 1-D shape, so e.g. `uniform(size=5)` must become `(5,)`.
    size = (size,)
  return random_ops.random_uniform(
      shape=size, minval=low, maxval=high, dtype=dtype)


@np_utils.np_doc('random.poisson')
def poisson(lam=1.0, size=None):
  """Returns samples from a Poisson distribution with rate `lam`.

  Args:
    lam: the rate parameter.
    size: None, a scalar `n` (1-D output of length `n`), or a shape sequence.

  Returns:
    Samples drawn with `random_ops.random_poisson`, with dtype
    `np_dtypes.int_`.
  """
  # NOTE(review): `random_poisson` appends `lam`'s shape to `size` rather than
  # broadcasting as NumPy does; this only matters for non-scalar `lam`.
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    size = (size,)
  return random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_)


@np_utils.np_doc('random.random')
def random(size=None):
  """Returns samples from the uniform distribution over `[0, 1)`."""
  return uniform(0., 1., size)


@np_utils.np_doc('random.rand')
def rand(*size):
  """Returns samples over `[0, 1)` with shape given by the positional args."""
  return uniform(0., 1., size)


@np_utils.np_doc('random.randint')
def randint(low, high=None, size=None, dtype=int):
  """Returns random integers from the half-open interval `[low, high)`.

  Args:
    low: lower bound (inclusive), or the exclusive upper bound when `high` is
      None (in which case the lower bound is 0).
    high: upper bound (exclusive), or None.
    size: None (scalar output), a scalar `n` (1-D output of length `n`), or a
      shape sequence.
    dtype: the output integer type; must resolve to int32 or int64. Defaults
      to the builtin `int` (the value the removed `numpy.int` alias had).

  Returns:
    Samples drawn with `tf.random.uniform`.

  Raises:
    ValueError: if `dtype` does not resolve to int32 or int64.
  """
  low = int(low)
  if high is None:
    # Single-argument form `randint(n)` samples from `[0, n)`.
    high = low
    low = 0
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    # Accepts NumPy integer scalars too, not only Python `int`.
    size = (size,)
  dtype = np_utils.result_type(dtype)
  if dtype not in (onp.int32, onp.int64):
    raise ValueError('Only np.int32 or np.int64 types are supported')
  return random_ops.random_uniform(
      shape=size, minval=low, maxval=high, dtype=dtype)