/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
                   random::PhiloxRandom::ResultType* out_counter) {
  // Grab the two seeds
  uint64 seed0;
  uint64 seed1;
  if (seed.dtype() == DT_INT32) {
    const auto seed_vals = seed.flat<int32>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else if (seed.dtype() == DT_INT64) {
    const auto seed_vals = seed.flat<int64>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else {
    return errors::InvalidArgument("Invalid seed type: ",
                                   DataTypeString(seed.dtype()));
  }

  // Scramble the seeds so that the user doesn't need to worry about which
  // part of the seed needs to be strong.
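  // The two constants below are (presumably) arbitrary fixed values; they
  // serve only as the initial key for this scrambling pass, whose Philox
  // output becomes the real key and counter.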
  (*out_key)[0] = 0x3ec8f720;
  (*out_key)[1] = 0x02461e29;
  (*out_counter)[0] = static_cast<uint32>(seed0);
  (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
  (*out_counter)[2] = static_cast<uint32>(seed1);
  (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
  const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
  (*out_key)[0] = mix[0];
  (*out_key)[1] = mix[1];
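  // Keep the mixed bits in the high counter words and zero the low 64 bits;
  // PhiloxRandom advances its counter from the low word first, so this
  // should leave room for 2^64 outputs before the mixed-in words change.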
  (*out_counter)[0] = (*out_counter)[1] = 0;
  (*out_counter)[2] = mix[2];
  (*out_counter)[3] = mix[3];
  return Status::OK();
}

namespace {

class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};

template <typename Device, class Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), random, flat.data(),
        flat.size(), Distribution());
  }
};
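
// StatelessRandomOp is instantiated below, via the REGISTER macro, with the
// uniform, normal, and truncated-normal distributions declared in
// random_distributions.h.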

template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& minval = context->input(2);
    const Tensor& maxval = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval.  Note that we'll never reach this point
    // for empty output; generating zero samples from an impossible range is
    // fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        context, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), random, flat.data(),
        flat.size(), dist);
  }
};

#define REGISTER(DEVICE, TYPE)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomUniform")                                        \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution<        \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomNormal")                                         \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution<         \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessTruncatedNormal")                                      \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<                                                    \
          DEVICE##Device,                                                   \
          random::TruncatedNormalDistribution<                              \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
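
// As an illustrative expansion, REGISTER(CPU, float) registers the
// StatelessRandomUniform, StatelessRandomNormal, and StatelessTruncatedNormal
// ops for float on the CPU device. The "shape" and "seed" inputs are pinned
// to host memory because Compute() and GenerateKey() read them on the CPU.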
180 
181 #define REGISTER_INT(DEVICE, TYPE)                            \
182   REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt")   \
183                               .Device(DEVICE_##DEVICE)        \
184                               .HostMemory("shape")            \
185                               .HostMemory("seed")             \
186                               .HostMemory("minval")           \
187                               .HostMemory("maxval")           \
188                               .TypeConstraint<TYPE>("dtype"), \
189                           StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
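
// "minval" and "maxval" are likewise kept in host memory: Fill() reads them
// with scalar<IntType>()() on the CPU before constructing the distribution.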

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);

#if GOOGLE_CUDA

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);

#endif  // GOOGLE_CUDA
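
// Note (illustrative, not part of this file): at the Python level these
// kernels back the stateless random functions, e.g.
// tf.random.stateless_uniform([2, 3], seed=[1, 2]) in recent TF versions;
// the same (shape, seed) pair always yields the same values.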

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU

}  // namespace

}  // namespace tensorflow