/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

namespace dnn = se::dnn;

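// Thunk that runs a cuDNN batch-norm forward pass in inference mode. The
// constructor only records the buffer slices; device addresses are resolved
// when the thunk executes.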
CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
    ThunkInfo thunk_info, CudnnBatchNormConfig config,
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
    const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& variance,
    const BufferAllocation::Slice& output)
    : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
      config_(std::move(config)),
      operand_(operand),
      scale_(scale),
      offset_(offset),
      mean_(mean),
      variance_(variance),
      output_(output) {}

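// Resolves the recorded buffer slices to device memory and invokes the cuDNN
// batch-norm forward-inference runner on the execution stream.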
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
    const ExecuteParams& params) {
  auto& buffer_allocations = *params.buffer_allocations;
  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());
  se::DeviceMemoryBase output_base =
      buffer_allocations.GetDeviceAddress(output_);
  se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
  se::DeviceMemory<float> scale(buffer_allocations.GetDeviceAddress(scale_));
  se::DeviceMemory<float> offset(buffer_allocations.GetDeviceAddress(offset_));
  se::DeviceMemory<float> mean(buffer_allocations.GetDeviceAddress(mean_));
  se::DeviceMemory<float> variance(
      buffer_allocations.GetDeviceAddress(variance_));
  auto& stream = *params.stream;
  TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
      config_, operand, output_base, scale, offset, mean, variance, &stream));

  if (!stream.ok()) {
    return InternalError("BatchNormalizationForward call failed.");
  }
  return Status::OK();
}

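// Thunk that runs a cuDNN batch-norm forward pass in training mode, producing
// the normalized output together with the batch mean and inverse stddev.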
CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
    ThunkInfo thunk_info, CudnnBatchNormConfig config,
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
    const BufferAllocation::Slice& output_data,
    const BufferAllocation::Slice& output_mean,
    const BufferAllocation::Slice& output_inv_stddev)
    : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
      config_(std::move(config)),
      operand_(operand),
      scale_(scale),
      offset_(offset),
      output_data_(output_data),
      output_mean_(output_mean),
      output_inv_stddev_(output_inv_stddev) {}

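// Resolves buffer slices and invokes the cuDNN batch-norm forward-training
// runner; the batch mean and inverse stddev are written as extra outputs.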
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
    const ExecuteParams& params) {
  auto& buffer_allocations = *params.buffer_allocations;
  se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
  se::DeviceMemoryBase output_data =
      buffer_allocations.GetDeviceAddress(output_data_);

  se::DeviceMemory<float> output_mean(
      buffer_allocations.GetDeviceAddress(output_mean_));
  se::DeviceMemory<float> output_inv_stddev(
      buffer_allocations.GetDeviceAddress(output_inv_stddev_));

  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());
  auto& stream = *params.stream;
  TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
      config_, operand, output_data, output_mean, output_inv_stddev,
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
      &stream));

  if (!stream.ok()) {
    return InternalError("BatchNormalizationTraining call failed.");
  }
  return Status::OK();
}

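// Thunk that runs the cuDNN batch-norm backward pass, computing gradients
// with respect to the input data, scale, and offset.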
CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
    ThunkInfo thunk_info, CudnnBatchNormConfig config,
    const BufferAllocation::Slice& operand,
    const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
    const BufferAllocation::Slice& inv_stddev,
    const BufferAllocation::Slice& grad_output,
    const BufferAllocation::Slice& output_grad_data,
    const BufferAllocation::Slice& output_grad_scale,
    const BufferAllocation::Slice& output_grad_offset)
    : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
      config_(std::move(config)),
      operand_(operand),
      scale_(scale),
      mean_(mean),
      inv_stddev_(inv_stddev),
      grad_output_(grad_output),
      output_grad_data_(output_grad_data),
      output_grad_scale_(output_grad_scale),
      output_grad_offset_(output_grad_offset) {}

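// Resolves buffer slices and invokes the cuDNN batch-norm backward runner,
// consuming the mean and inverse stddev produced by the forward pass.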
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
    const ExecuteParams& params) {
  auto& buffer_allocations = *params.buffer_allocations;
  se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
  se::DeviceMemoryBase output_grad_data =
      buffer_allocations.GetDeviceAddress(output_grad_data_);
  se::DeviceMemoryBase grad_output =
      buffer_allocations.GetDeviceAddress(grad_output_);
  se::DeviceMemory<float> output_grad_scale(
      buffer_allocations.GetDeviceAddress(output_grad_scale_));
  se::DeviceMemory<float> output_grad_offset(
      buffer_allocations.GetDeviceAddress(output_grad_offset_));

  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());
  se::Stream* stream = params.stream;
  TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
      config_, operand, output_grad_data, grad_output, output_grad_scale,
      output_grad_offset,
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
      se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
      stream));

  if (!stream->ok()) {
    return InternalError("BatchNormalizationBackward call failed.");
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla