1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
17 #define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
18
19 #include "tensorflow/core/framework/resource_var.h"
20 #include "tensorflow/core/kernels/stateful_random_ops.h"
21
22 namespace tensorflow {
23
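// The following 5 functions are defined in this header and marked
// PHILOX_DEVICE_INLINE (presumably expanding to include `inline`) so that
// including the header from several translation units does not produce
// duplicate symbols at link time.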
26
27 // The following 2 functions use the contract "lower 32 bits for the first
28 // uint32, higher 32 bits for the second". Note that this is endian-neutral,
29 // unlike a direct memory copy `memcpy(output, &input, 8)`.
Int64ToUint32s(int64 input,uint32 * output1,uint32 * output2)30 PHILOX_DEVICE_INLINE void Int64ToUint32s(int64 input, uint32* output1,
31 uint32* output2) {
32 auto u64 = static_cast<uint64>(input);
33 *output1 = static_cast<uint32>(u64);
34 *output2 = static_cast<uint32>(u64 >> 32);
35 }
36
Uint32sToInt64(uint32 input1,uint32 input2)37 PHILOX_DEVICE_INLINE int64 Uint32sToInt64(uint32 input1, uint32 input2) {
38 auto u64_1 = static_cast<uint64>(input1);
39 auto u64_2 = static_cast<uint64>(input2);
40 return static_cast<int64>(u64_1 | (u64_2 << 32));
41 }
42
43 PHILOX_DEVICE_INLINE PhiloxRandom
GetPhiloxRandomFromMem(StateElementType const * ptr)44 GetPhiloxRandomFromMem(StateElementType const* ptr) {
45 PhiloxRandom::ResultType counter;
46 PhiloxRandom::Key key;
47 Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
48 Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
49 Int64ToUint32s(ptr[2], &key[0], &key[1]);
50 return PhiloxRandom(counter, key);
51 }
52
WritePhiloxRandomToMem(PhiloxRandom const & philox,StateElementType * ptr)53 PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
54 StateElementType* ptr) {
55 PhiloxRandom::ResultType const& counter = philox.counter();
56 PhiloxRandom::Key const& key = philox.key();
57 ptr[0] = Uint32sToInt64(counter[0], counter[1]);
58 ptr[1] = Uint32sToInt64(counter[2], counter[3]);
59 ptr[2] = Uint32sToInt64(key[0], key[1]);
60 }
61
UpdateMemWithPhiloxRandom(PhiloxRandom const & philox,int64 output_size,StateElementType * ptr)62 PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
63 int64 output_size,
64 StateElementType* ptr) {
65 auto new_philox = philox;
66 // Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
67 // it just here.
68 auto delta = output_size * 256;
69 new_philox.Skip(delta); // do the actual increasing
70 WritePhiloxRandomToMem(new_philox, ptr);
71 }
72
// Per-device functor that performs the actual work of `UpdateVariableAndFill`.
// It is a class template rather than a function template because C++ does not
// allow partial specialization of function templates. Device-specific
// implementations are provided by partially specializing on `Device`.
template <class Device, class Distribution>
struct UpdateVariableAndFill_Philox;
79
80 using CPUDevice = Eigen::ThreadPoolDevice;
81
82 #if GOOGLE_CUDA
83
84 using GPUDevice = Eigen::GpuDevice;
85
86 // Declares the partially GPU-specialized functor struct.
87 template <typename Distribution>
88 struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
89 void operator()(OpKernelContext* ctx, const GPUDevice& device,
90 int64 output_size, int64 alg_tag_skip,
91 ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
92 typename Distribution::ResultElementType* output_data);
93 };
94
95 #endif // GOOGLE_CUDA
96
97 } // end namespace tensorflow
98
99 #endif // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
100