# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Random functions."""

# pylint: disable=g-direct-tensorflow-import

import numpy as onp

from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(agarwal): deprecate this.
DEFAULT_RANDN_DTYPE = onp.float32


@np_utils.np_doc('random.seed')
def seed(s):
  """Sets the seed for the random number generator.

  Converts `s` with `int()` and forwards it to `random_seed.set_seed`.

  Args:
    s: an integer (or a value `int()` can convert).

  Raises:
    ValueError: if `int(s)` raises `TypeError`. A `ValueError` raised by
      `int()` itself (e.g. for a non-numeric string) propagates unchanged.
  """
  try:
    s = int(s)
  except TypeError as e:
    # TODO(wangpeng): support this?
    raise ValueError(
        f'Argument `s` got an invalid value {s}. Only integers are supported.'
    ) from e
  random_seed.set_seed(s)


@np_utils.np_doc('random.randn')
def randn(*args):
  """Returns samples from a normal distribution.

  Delegates to `standard_normal(size=args)`.

  Args:
    *args: The shape of the output array.

  Returns:
    An array with shape `args` whose dtype is
    `np_dtypes.default_float_type()`.
  """
  return standard_normal(size=args)


@np_utils.np_doc('random.standard_normal')
def standard_normal(size=None):
  """Returns samples drawn with `random_ops.random_normal`.

  Args:
    size: The output shape. `None` produces a scalar (shape `()`); a scalar
      `n` is treated as the 1-D shape `(n,)`.

  Returns:
    The result of `random_ops.random_normal` for that shape, with dtype
    `np_dtypes.default_float_type()`.
  """
  # TODO(wangpeng): Use new stateful RNG
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    size = (size,)
  dtype = np_dtypes.default_float_type()
  return random_ops.random_normal(size, dtype=dtype)


@np_utils.np_doc('random.uniform')
def uniform(low=0.0, high=1.0, size=None):
  """Returns samples drawn with `random_ops.random_uniform`.

  Args:
    low: Lower bound, forwarded as `minval` after conversion to the default
      float type.
    high: Upper bound, forwarded as `maxval` after conversion to the default
      float type.
    size: The output shape. `None` uses the broadcast of the shapes of `low`
      and `high`; a scalar `n` is treated as the 1-D shape `(n,)`.

  Returns:
    The result of `random_ops.random_uniform` with dtype
    `np_dtypes.default_float_type()`.
  """
  dtype = np_dtypes.default_float_type()
  low = np_array_ops.asarray(low, dtype=dtype)
  high = np_array_ops.asarray(high, dtype=dtype)
  if size is None:
    size = array_ops.broadcast_dynamic_shape(low.shape, high.shape)
  elif np_utils.isscalar(size):
    # Normalize a bare scalar to a 1-D shape, matching `standard_normal` and
    # `poisson`; this is also what makes `random(n)` work.
    size = (size,)
  return random_ops.random_uniform(
      shape=size, minval=low, maxval=high, dtype=dtype)


@np_utils.np_doc('random.poisson')
def poisson(lam=1.0, size=None):
  """Returns samples drawn with `random_ops.random_poisson`.

  Args:
    lam: Forwarded unchanged as the `lam` argument.
    size: The `shape` argument. `None` becomes `()`; a scalar `n` becomes
      `(n,)`.

  Returns:
    The result of `random_ops.random_poisson` with dtype `np_dtypes.int_`.
  """
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    size = (size,)
  return random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_)


@np_utils.np_doc('random.random')
def random(size=None):
  """Returns `uniform(0., 1., size)`."""
  return uniform(0., 1., size)


@np_utils.np_doc('random.rand')
def rand(*size):
  """Returns `uniform(0., 1., size)` with the positional args as the shape."""
  return uniform(0., 1., size)


@np_utils.np_doc('random.randint')
def randint(low, high=None, size=None, dtype=onp.int64):
  """Returns random integers drawn with `random_ops.random_uniform`.

  Args:
    low: Lower bound (converted with `int()`), forwarded as `minval`. When
      `high` is None, this value is used as the upper bound instead and the
      lower bound becomes 0.
    high: Upper bound, forwarded as `maxval`.
    size: The output shape. `None` produces a scalar (shape `()`); a scalar
      `n` is treated as the 1-D shape `(n,)`.
    dtype: Requested integer dtype; after `np_utils.result_type` it must be
      `int32` or `int64`.

  Returns:
    The result of `random_ops.random_uniform` with the resolved dtype.

  Raises:
    ValueError: if `dtype` does not resolve to `int32` or `int64`.
  """
  low = int(low)
  if high is None:
    high = low
    low = 0
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    # Same scalar test as the sibling samplers, so NumPy integer sizes are
    # normalized too (a plain `isinstance(size, int)` check misses them).
    size = (size,)
  dtype_orig = dtype
  dtype = np_utils.result_type(dtype)
  accepted_dtypes = (onp.int32, onp.int64)
  if dtype not in accepted_dtypes:
    raise ValueError(
        f'Argument `dtype` got an invalid value {dtype_orig}. Only those '
        f'convertible to {accepted_dtypes} are supported.')
  return random_ops.random_uniform(
      shape=size, minval=low, maxval=high, dtype=dtype)