• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
22 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
23 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
24 #include "tensorflow/compiler/xla/types.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 namespace dnn = se::dnn;
33 
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & offset,const BufferAllocation::Slice & mean,const BufferAllocation::Slice & variance,const BufferAllocation::Slice & output)34 CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
35     ThunkInfo thunk_info, CudnnBatchNormConfig config,
36     const BufferAllocation::Slice& operand,
37     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
38     const BufferAllocation::Slice& mean,
39     const BufferAllocation::Slice& variance,
40     const BufferAllocation::Slice& output)
41     : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
42       config_(std::move(config)),
43       operand_(operand),
44       scale_(scale),
45       offset_(offset),
46       mean_(mean),
47       variance_(variance),
48       output_(output) {}
49 
ExecuteOnStream(const ExecuteParams & params)50 Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
51     const ExecuteParams& params) {
52   auto& buffer_allocations = *params.buffer_allocations;
53   auto op_profiler =
54       params.profiler->MakeScopedInstructionProfiler(profile_index());
55   se::DeviceMemoryBase output_base =
56       buffer_allocations.GetDeviceAddress(output_);
57   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
58   se::DeviceMemory<float> scale(buffer_allocations.GetDeviceAddress(scale_));
59   se::DeviceMemory<float> offset(buffer_allocations.GetDeviceAddress(offset_));
60   se::DeviceMemory<float> mean(buffer_allocations.GetDeviceAddress(mean_));
61   se::DeviceMemory<float> variance(
62       buffer_allocations.GetDeviceAddress(variance_));
63   auto& stream = *params.stream;
64   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
65       config_, operand, output_base, scale, offset, mean, variance, &stream));
66 
67   if (!stream.ok()) {
68     return InternalError("BatchNormalizationForward call failed.");
69   }
70   return Status::OK();
71 }
72 
CudnnBatchNormForwardTrainingThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & offset,const BufferAllocation::Slice & output_data,const BufferAllocation::Slice & output_mean,const BufferAllocation::Slice & output_inv_stddev)73 CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
74     ThunkInfo thunk_info, CudnnBatchNormConfig config,
75     const BufferAllocation::Slice& operand,
76     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
77     const BufferAllocation::Slice& output_data,
78     const BufferAllocation::Slice& output_mean,
79     const BufferAllocation::Slice& output_inv_stddev)
80     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
81       config_(std::move(config)),
82       operand_(operand),
83       scale_(scale),
84       offset_(offset),
85       output_data_(output_data),
86       output_mean_(output_mean),
87       output_inv_stddev_(output_inv_stddev) {}
88 
ExecuteOnStream(const ExecuteParams & params)89 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
90     const ExecuteParams& params) {
91   auto& buffer_allocations = *params.buffer_allocations;
92   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
93   se::DeviceMemoryBase output_data =
94       buffer_allocations.GetDeviceAddress(output_data_);
95 
96   se::DeviceMemory<float> output_mean(
97       buffer_allocations.GetDeviceAddress(output_mean_));
98   se::DeviceMemory<float> output_inv_stddev(
99       buffer_allocations.GetDeviceAddress(output_inv_stddev_));
100 
101   se::DeviceMemory<float> null_device_ptr(nullptr);
102   auto op_profiler =
103       params.profiler->MakeScopedInstructionProfiler(profile_index());
104   auto& stream = *params.stream;
105   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
106       config_, operand, output_data, output_mean, output_inv_stddev,
107       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
108       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
109       &stream));
110 
111   if (!stream.ok()) {
112     return InternalError("BatchNormalizationTraining call failed.");
113   }
114   return Status::OK();
115 }
116 
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & mean,const BufferAllocation::Slice & inv_stddev,const BufferAllocation::Slice & grad_output,const BufferAllocation::Slice & output_grad_data,const BufferAllocation::Slice & output_grad_scale,const BufferAllocation::Slice & output_grad_offset)117 CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
118     ThunkInfo thunk_info, CudnnBatchNormConfig config,
119     const BufferAllocation::Slice& operand,
120     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
121     const BufferAllocation::Slice& inv_stddev,
122     const BufferAllocation::Slice& grad_output,
123     const BufferAllocation::Slice& output_grad_data,
124     const BufferAllocation::Slice& output_grad_scale,
125     const BufferAllocation::Slice& output_grad_offset)
126     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
127       config_(std::move(config)),
128       operand_(operand),
129       scale_(scale),
130       mean_(mean),
131       inv_stddev_(inv_stddev),
132       grad_output_(grad_output),
133       output_grad_data_(output_grad_data),
134       output_grad_scale_(output_grad_scale),
135       output_grad_offset_(output_grad_offset) {}
136 
ExecuteOnStream(const ExecuteParams & params)137 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
138     const ExecuteParams& params) {
139   auto& buffer_allocations = *params.buffer_allocations;
140   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
141   se::DeviceMemoryBase output_grad_data =
142       buffer_allocations.GetDeviceAddress(output_grad_data_);
143   se::DeviceMemoryBase grad_output =
144       buffer_allocations.GetDeviceAddress(grad_output_);
145   se::DeviceMemory<float> output_grad_scale(
146       buffer_allocations.GetDeviceAddress(output_grad_scale_));
147   se::DeviceMemory<float> output_grad_offset(
148       buffer_allocations.GetDeviceAddress(output_grad_offset_));
149 
150   auto op_profiler =
151       params.profiler->MakeScopedInstructionProfiler(profile_index());
152   se::Stream* stream = params.stream;
153   TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
154       config_, operand, output_grad_data, grad_output, output_grad_scale,
155       output_grad_offset,
156       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
157       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
158       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
159       stream));
160 
161   if (!stream->ok()) {
162     return InternalError("BatchNormalizationBackward call failed.");
163   }
164   return Status::OK();
165 }
166 
167 }  // namespace gpu
168 }  // namespace xla
169