/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

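// Converts a length-2 seed tensor (int32 or int64) into the key/counter pair
// that seeds a Philox generator; returns an InvalidArgument error for any
// other seed dtype.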
Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
                   random::PhiloxRandom::ResultType* out_counter) {
  // Grab the two seeds
  uint64 seed0;
  uint64 seed1;
  if (seed.dtype() == DT_INT32) {
    const auto seed_vals = seed.flat<int32>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else if (seed.dtype() == DT_INT64) {
    const auto seed_vals = seed.flat<int64>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else {
    return errors::InvalidArgument("Invalid seed type: ",
                                   DataTypeString(seed.dtype()));
  }

  // Scramble the seeds so that the user doesn't need to worry about which
  // part of the seed needs to be strong.
  (*out_key)[0] = 0x3ec8f720;
  (*out_key)[1] = 0x02461e29;
  (*out_counter)[0] = static_cast<uint32>(seed0);
  (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
  (*out_counter)[2] = static_cast<uint32>(seed1);
  (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
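  // Run Philox once on (counter, key): the four resulting 32-bit words become
  // the scrambled key and the high half of the counter, diffusing both seed
  // words into the generator state.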
  const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
  (*out_key)[0] = mix[0];
  (*out_key)[1] = mix[1];
  (*out_counter)[0] = (*out_counter)[1] = 0;
  (*out_counter)[2] = mix[2];
  (*out_counter)[3] = mix[3];
  return Status::OK();
}

namespace {

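// Base class for the stateless random ops: validates the shape and seed
// inputs, derives a Philox generator from the seed, and delegates the actual
// sampling to the device/type/distribution-specific Fill().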
class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
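    // Nothing to fill for an empty shape; skip key generation entirely.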
    if (shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};

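// Stateless op for distributions that need no extra inputs (uniform floats,
// normal, truncated normal); Distribution supplies both the element type and
// the sampling logic.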
template <typename Device, class Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), random, flat.data(),
        flat.size(), Distribution());
  }
};

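// Stateless uniform integer op: takes additional scalar minval/maxval inputs
// and samples integers from [minval, maxval).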
template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& minval = context->input(2);
    const Tensor& maxval = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval. Note that we never reach this point for
    // empty output, so an invalid range is harmless there: zero impossible
    // things are fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        context, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), random, flat.data(),
        flat.size(), dist);
  }
};

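// Registers the floating-point stateless kernels for one device. The shape
// and seed inputs always live in host memory, even for GPU kernels, since
// both are consumed on the host. TruncatedNormalDistribution uses rejection
// sampling, so it consumes a variable number of samples per output and needs
// the SingleSampleAdapter wrapper around the generator.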
#define REGISTER(DEVICE, TYPE)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomUniform")                                        \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution<        \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomNormal")                                         \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution<         \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessTruncatedNormal")                                      \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<                                                    \
          DEVICE##Device,                                                   \
          random::TruncatedNormalDistribution<                              \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);

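// Registers the integer kernels for StatelessRandomUniformInt; minval and
// maxval are scalar host-memory inputs in addition to shape and seed.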
#define REGISTER_INT(DEVICE, TYPE)                            \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt")   \
                              .Device(DEVICE_##DEVICE)        \
                              .HostMemory("shape")            \
                              .HostMemory("seed")             \
                              .HostMemory("minval")           \
                              .HostMemory("maxval")           \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);

#if GOOGLE_CUDA

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);

#endif  // GOOGLE_CUDA

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU

}  // namespace

}  // namespace tensorflow