/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_

#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"

namespace tensorflow {

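// Helpers for loading/storing a PhiloxRandom generator from/to a buffer of
// 64-bit state elements: the 128-bit Philox counter occupies the first two
// words and the 64-bit key occupies the third.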
PHILOX_DEVICE_INLINE PhiloxRandom
GetPhiloxRandomFromMem(StateElementType const* ptr) {
  auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
  return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
}

PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
                                                 StateElementType* ptr) {
  auto ptr_ = reinterpret_cast<uint64*>(ptr);
  WriteCounterToMem(philox.counter(), ptr_);
  WriteKeyToMem(philox.key(), ptr_ + 2);
}

PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
                                                   uint64 output_size) {
  auto new_philox = philox;
  // The multiplier 256 must match the one used in FillPhiloxRandomTask; do
  // not change it here without also changing it there.
  auto delta = output_size * 256;
  new_philox.Skip(delta);  // advance the counter
  return new_philox;
}

PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
                                                    uint64 output_size,
                                                    StateElementType* ptr) {
  auto new_philox = SkipPhiloxRandom(philox, output_size);
  WritePhiloxRandomToMem(new_philox, ptr);
}

PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
    PhiloxRandom::ResultType const& counter, uint64 output_size,
    StateElementType* ptr) {
  auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
  auto new_philox = SkipPhiloxRandom(philox, output_size);
  WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
}
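
// A minimal round-trip sketch of the helpers above (illustrative only; the
// seed and output size are arbitrary):
//
//   StateElementType state[3];  // 2 words of counter + 1 word of key
//   WritePhiloxRandomToMem(PhiloxRandom(/*seed=*/42), state);
//   PhiloxRandom gen = GetPhiloxRandomFromMem(state);
//   // Reserve counter space for 1024 outputs and persist the new state.
//   UpdateMemWithPhiloxRandom(gen, /*output_size=*/1024, state);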

namespace functor {

// A per-device helper that does the actual work for `UpdateVariableAndFill`.
// A functor struct is used because C++ does not allow partial specialization
// of function templates.
template <typename Device, typename Distribution>
struct UpdateVariableAndFill_Philox;

template <typename Device>
struct RngSkip_Philox;
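
// For reference, a sketch of why the functor pattern is needed (with a
// hypothetical `Fill` template, not part of this header): partially
// specializing a function template is ill-formed,
//
//   template <typename Device, typename T> void Fill(const Device& d, T* p);
//   template <typename T>
//   void Fill<GPUDevice, T>(const GPUDevice& d, T* p);  // error
//
// but the equivalent class template may be partially specialized:
//
//   template <typename Device, typename T> struct Fill {
//     void operator()(const Device& d, T* p);
//   };
//   template <typename T> struct Fill<GPUDevice, T> {
//     void operator()(const GPUDevice& d, T* p);
//   };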

}  // end namespace functor

using CPUDevice = Eigen::ThreadPoolDevice;

// Bundles the arguments passed to `UpdateVariableAndFill_Philox` so that its
// `operator()` stays at <= 6 parameters (see the ABI note below).
struct UpdateVariableAndFill_Philox_Arg {
  int64 output_size;   // number of output elements to fill
  int64 alg_tag_skip;  // leading state elements to skip (the algorithm tag,
                       // when present)
  ScopedUnlockUnrefVar* not_used;
  Tensor* state_tensor;  // tensor backing the RNG state variable
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Declares the partially GPU-specialized functor structs.
// Note: `operator()` must be kept at <= 6 arguments because of a gcc/clang
// ABI-incompatibility bug.
template <typename Distribution>
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const GPUDevice& device,
                  Distribution dist, UpdateVariableAndFill_Philox_Arg* arg,
                  typename Distribution::ResultElementType* output_data);
};

template <>
struct RngSkip_Philox<GPUDevice> {
  void operator()(const GPUDevice& device, const StateElementType* in_data,
                  uint64 delta, StateElementType* out_data);
};
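
// A hedged usage sketch of the functor above; `dist`, `num_outputs`,
// `state_tensor` and `output_ptr` are assumed to exist in the calling kernel
// (illustrative only, not part of this header):
//
//   UpdateVariableAndFill_Philox_Arg arg;
//   arg.output_size = num_outputs;
//   arg.alg_tag_skip = 0;  // assumption: no leading algorithm tag to skip
//   arg.not_used = nullptr;
//   arg.state_tensor = &state_tensor;
//   UpdateVariableAndFill_Philox<GPUDevice, Distribution>()(
//       ctx, ctx->eigen_device<GPUDevice>(), dist, &arg, output_ptr);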

}  // end namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_