• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
22 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
27 
28 namespace xla {
29 namespace gpu {
30 
31 namespace dnn = se::dnn;
32 
CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & offset,const BufferAllocation::Slice & mean,const BufferAllocation::Slice & variance,const BufferAllocation::Slice & output)33 CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
34     ThunkInfo thunk_info, CudnnBatchNormConfig config,
35     const BufferAllocation::Slice& operand,
36     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
37     const BufferAllocation::Slice& mean,
38     const BufferAllocation::Slice& variance,
39     const BufferAllocation::Slice& output)
40     : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info),
41       config_(std::move(config)),
42       operand_(operand),
43       scale_(scale),
44       offset_(offset),
45       mean_(mean),
46       variance_(variance),
47       output_(output) {}
48 
ExecuteOnStream(const ExecuteParams & params)49 Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
50     const ExecuteParams& params) {
51   auto& buffer_allocations = *params.buffer_allocations;
52   se::DeviceMemoryBase output_base =
53       buffer_allocations.GetDeviceAddress(output_);
54   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
55   se::DeviceMemory<float> scale(buffer_allocations.GetDeviceAddress(scale_));
56   se::DeviceMemory<float> offset(buffer_allocations.GetDeviceAddress(offset_));
57   se::DeviceMemory<float> mean(buffer_allocations.GetDeviceAddress(mean_));
58   se::DeviceMemory<float> variance(
59       buffer_allocations.GetDeviceAddress(variance_));
60   auto& stream = *params.stream;
61   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
62       config_, operand, output_base, scale, offset, mean, variance, &stream));
63 
64   if (!stream.ok()) {
65     return InternalError("BatchNormalizationForward call failed.");
66   }
67   return Status::OK();
68 }
69 
CudnnBatchNormForwardTrainingThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & offset,const BufferAllocation::Slice & output_data,const BufferAllocation::Slice & output_mean,const BufferAllocation::Slice & output_inv_stddev)70 CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
71     ThunkInfo thunk_info, CudnnBatchNormConfig config,
72     const BufferAllocation::Slice& operand,
73     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
74     const BufferAllocation::Slice& output_data,
75     const BufferAllocation::Slice& output_mean,
76     const BufferAllocation::Slice& output_inv_stddev)
77     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
78       config_(std::move(config)),
79       operand_(operand),
80       scale_(scale),
81       offset_(offset),
82       output_data_(output_data),
83       output_mean_(output_mean),
84       output_inv_stddev_(output_inv_stddev) {}
85 
ExecuteOnStream(const ExecuteParams & params)86 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
87     const ExecuteParams& params) {
88   auto& buffer_allocations = *params.buffer_allocations;
89   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
90   se::DeviceMemoryBase output_data =
91       buffer_allocations.GetDeviceAddress(output_data_);
92 
93   se::DeviceMemory<float> output_mean(
94       buffer_allocations.GetDeviceAddress(output_mean_));
95   se::DeviceMemory<float> output_inv_stddev(
96       buffer_allocations.GetDeviceAddress(output_inv_stddev_));
97 
98   se::DeviceMemory<float> null_device_ptr(nullptr);
99   auto& stream = *params.stream;
100   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining(
101       config_, operand, output_data, output_mean, output_inv_stddev,
102       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
103       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
104       &stream));
105 
106   if (!stream.ok()) {
107     return InternalError("BatchNormalizationTraining call failed.");
108   }
109   return Status::OK();
110 }
111 
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,CudnnBatchNormConfig config,const BufferAllocation::Slice & operand,const BufferAllocation::Slice & scale,const BufferAllocation::Slice & mean,const BufferAllocation::Slice & inv_stddev,const BufferAllocation::Slice & grad_output,const BufferAllocation::Slice & output_grad_data,const BufferAllocation::Slice & output_grad_scale,const BufferAllocation::Slice & output_grad_offset)112 CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
113     ThunkInfo thunk_info, CudnnBatchNormConfig config,
114     const BufferAllocation::Slice& operand,
115     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
116     const BufferAllocation::Slice& inv_stddev,
117     const BufferAllocation::Slice& grad_output,
118     const BufferAllocation::Slice& output_grad_data,
119     const BufferAllocation::Slice& output_grad_scale,
120     const BufferAllocation::Slice& output_grad_offset)
121     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
122       config_(std::move(config)),
123       operand_(operand),
124       scale_(scale),
125       mean_(mean),
126       inv_stddev_(inv_stddev),
127       grad_output_(grad_output),
128       output_grad_data_(output_grad_data),
129       output_grad_scale_(output_grad_scale),
130       output_grad_offset_(output_grad_offset) {}
131 
ExecuteOnStream(const ExecuteParams & params)132 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
133     const ExecuteParams& params) {
134   auto& buffer_allocations = *params.buffer_allocations;
135   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_);
136   se::DeviceMemoryBase output_grad_data =
137       buffer_allocations.GetDeviceAddress(output_grad_data_);
138   se::DeviceMemoryBase grad_output =
139       buffer_allocations.GetDeviceAddress(grad_output_);
140   se::DeviceMemory<float> output_grad_scale(
141       buffer_allocations.GetDeviceAddress(output_grad_scale_));
142   se::DeviceMemory<float> output_grad_offset(
143       buffer_allocations.GetDeviceAddress(output_grad_offset_));
144 
145   se::Stream* stream = params.stream;
146   TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward(
147       config_, operand, output_grad_data, grad_output, output_grad_scale,
148       output_grad_offset,
149       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
150       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(mean_)),
151       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
152       stream));
153 
154   if (!stream->ok()) {
155     return InternalError("BatchNormalizationBackward call failed.");
156   }
157   return Status::OK();
158 }
159 
160 }  // namespace gpu
161 }  // namespace xla
162