/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
17
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
22 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
27
28 namespace xla {
29 namespace gpu {
30
31 namespace dnn = se::dnn;
32
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & offset,const BufferAllocation::Slice & mean,const BufferAllocation::Slice & variance,const BufferAllocation::Slice & output)33 CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
34 ThunkInfo thunk_info, CudnnBatchNormConfig config,
35 const BufferAllocation::Slice& operand,
36 const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
37 const BufferAllocation::Slice& mean,
38 const BufferAllocation::Slice& variance,
39 const BufferAllocation::Slice& output)
40 : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
41 config_(std::move(config)),
42 operand_(operand),
43 scale_(scale),
44 offset_(offset),
45 mean_(mean),
46 variance_(variance),
47 output_(output) {}
48
ExecuteOnStream(const ExecuteParams & params)49 Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
50 const ExecuteParams& params) {
51 auto& buffer_allocations = *params.buffer_allocations;
52 se::DeviceMemoryBase output_base =
53 buffer_allocations.GetDeviceAddress(output_);
54 se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
55 se::DeviceMemory<float> scale(buffer_allocations.GetDeviceAddress(scale_));
56 se::DeviceMemory<float> offset(buffer_allocations.GetDeviceAddress(offset_));
57 se::DeviceMemory<float> mean(buffer_allocations.GetDeviceAddress(mean_));
58 se::DeviceMemory<float> variance(
59 buffer_allocations.GetDeviceAddress(variance_));
60 auto& stream = *params.stream;
61 TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
62 config_, operand, output_base, scale, offset, mean, variance, &stream));
63
64 if (!stream.ok()) {
65 return InternalError("BatchNormalizationForward call failed.");
66 }
67 return Status::OK();
68 }
69
CudnnBatchNormForwardTrainingThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & offset,const BufferAllocation::Slice & output_data,const BufferAllocation::Slice & output_mean,const BufferAllocation::Slice & output_inv_stddev)70 CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
71 ThunkInfo thunk_info, CudnnBatchNormConfig config,
72 const BufferAllocation::Slice& operand,
73 const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
74 const BufferAllocation::Slice& output_data,
75 const BufferAllocation::Slice& output_mean,
76 const BufferAllocation::Slice& output_inv_stddev)
77 : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
78 config_(std::move(config)),
79 operand_(operand),
80 scale_(scale),
81 offset_(offset),
82 output_data_(output_data),
83 output_mean_(output_mean),
84 output_inv_stddev_(output_inv_stddev) {}
85
ExecuteOnStream(const ExecuteParams & params)86 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
87 const ExecuteParams& params) {
88 auto& buffer_allocations = *params.buffer_allocations;
89 se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
90 se::DeviceMemoryBase output_data =
91 buffer_allocations.GetDeviceAddress(output_data_);
92
93 se::DeviceMemory<float> output_mean(
94 buffer_allocations.GetDeviceAddress(output_mean_));
95 se::DeviceMemory<float> output_inv_stddev(
96 buffer_allocations.GetDeviceAddress(output_inv_stddev_));
97
98 se::DeviceMemory<float> null_device_ptr(nullptr);
99 auto& stream = *params.stream;
100 TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
101 config_, operand, output_data, output_mean, output_inv_stddev,
102 se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
103 se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
104 &stream));
105
106 if (!stream.ok()) {
107 return InternalError("BatchNormalizationTraining call failed.");
108 }
109 return Status::OK();
110 }
111
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & mean,const BufferAllocation::Slice & inv_stddev,const BufferAllocation::Slice & grad_output,const BufferAllocation::Slice & output_grad_data,const BufferAllocation::Slice & output_grad_scale,const BufferAllocation::Slice & output_grad_offset)112 CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
113 ThunkInfo thunk_info, CudnnBatchNormConfig config,
114 const BufferAllocation::Slice& operand,
115 const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
116 const BufferAllocation::Slice& inv_stddev,
117 const BufferAllocation::Slice& grad_output,
118 const BufferAllocation::Slice& output_grad_data,
119 const BufferAllocation::Slice& output_grad_scale,
120 const BufferAllocation::Slice& output_grad_offset)
121 : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
122 config_(std::move(config)),
123 operand_(operand),
124 scale_(scale),
125 mean_(mean),
126 inv_stddev_(inv_stddev),
127 grad_output_(grad_output),
128 output_grad_data_(output_grad_data),
129 output_grad_scale_(output_grad_scale),
130 output_grad_offset_(output_grad_offset) {}
131
ExecuteOnStream(const ExecuteParams & params)132 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
133 const ExecuteParams& params) {
134 auto& buffer_allocations = *params.buffer_allocations;
135 se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
136 se::DeviceMemoryBase output_grad_data =
137 buffer_allocations.GetDeviceAddress(output_grad_data_);
138 se::DeviceMemoryBase grad_output =
139 buffer_allocations.GetDeviceAddress(grad_output_);
140 se::DeviceMemory<float> output_grad_scale(
141 buffer_allocations.GetDeviceAddress(output_grad_scale_));
142 se::DeviceMemory<float> output_grad_offset(
143 buffer_allocations.GetDeviceAddress(output_grad_offset_));
144
145 se::Stream* stream = params.stream;
146 TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
147 config_, operand, output_grad_data, grad_output, output_grad_scale,
148 output_grad_offset,
149 se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
150 se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
151 se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
152 stream));
153
154 if (!stream->ok()) {
155 return InternalError("BatchNormalizationBackward call failed.");
156 }
157 return Status::OK();
158 }
159
160 } // namespace gpu
161 } // namespace xla
162