/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"

#include <algorithm>
#include <map>
#include <random>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/kernel_spec.h"

namespace xla {
namespace gpu {

using se::dnn::DataLayout;
using se::dnn::DataLayoutString;
using se::dnn::FilterLayout;
using se::dnn::FilterLayoutString;

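// Returns true if the device underlying `stream_executor` has CUDA compute
// capability 7 (Volta) or higher.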
bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
  int major, minor;
  CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
                                                                       &minor));
  return major >= 7;
}

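// Converts a StreamExecutor (input, filter, output) layout triple for a
// convolution into the corresponding XLA Layouts, using `dnums` to map
// logical dimensions onto physical ones. Only the NCHW/NHWC data layouts
// (kBatchDepthYX/kBatchYXDepth) and the OIYX/OYXI filter layouts
// (kOutputInputYX/kOutputYXInput) are supported; anything else is an error.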
StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
                                      DataLayout input, FilterLayout filter,
                                      DataLayout output) {
  std::vector<int64> input_layout;
  switch (input) {
    case DataLayout::kBatchDepthYX:
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      break;
    case DataLayout::kBatchYXDepth:
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(dnums.input_feature_dimension());
      break;
    default:
      return InternalError("Invalid input layout %s for conv with dnums %s",
                           DataLayoutString(input),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> filter_layout;
  switch (filter) {
    case FilterLayout::kOutputInputYX:
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      break;
    case FilterLayout::kOutputYXInput:
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      break;
    default:
      return InternalError("Invalid filter layout %s for conv with dnums %s",
                           FilterLayoutString(filter),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> output_layout;
  switch (output) {
    case DataLayout::kBatchDepthYX:
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      break;
    case DataLayout::kBatchYXDepth:
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      output_layout.push_back(dnums.output_feature_dimension());
      break;
    default:
      return InternalError("Invalid output layout %s for conv with dnums %s",
                           DataLayoutString(output),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
}

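// The inverse of StreamExecutorConvLayoutsToXlaLayouts: matches the given
// XLA layouts against the NCHW and NHWC conventions (relative to `dnums`)
// and returns the corresponding StreamExecutor layouts, or an error if
// neither convention matches.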
StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
                                      const Layout& input, const Layout& filter,
                                      const Layout& output) {
  Layout nchw_input, nchw_filter, nchw_output;
  std::tie(nchw_input, nchw_filter, nchw_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
                                            FilterLayout::kOutputInputYX,
                                            DataLayout::kBatchDepthYX)
          .ConsumeValueOrDie();

  Layout nhwc_input, nhwc_filter, nhwc_output;
  std::tie(nhwc_input, nhwc_filter, nhwc_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
                                            FilterLayout::kOutputYXInput,
                                            DataLayout::kBatchYXDepth)
          .ConsumeValueOrDie();

  DataLayout input_layout;
  if (LayoutUtil::Equal(input, nchw_input)) {
    input_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(input, nhwc_input)) {
    input_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid input layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(input),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  FilterLayout filter_layout;
  if (LayoutUtil::Equal(filter, nchw_filter)) {
    filter_layout = FilterLayout::kOutputInputYX;
  } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
    filter_layout = FilterLayout::kOutputYXInput;
  } else {
    return InternalError("Invalid filter layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(filter),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  DataLayout output_layout;
  if (LayoutUtil::Equal(output, nchw_output)) {
    output_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(output, nhwc_output)) {
    output_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid output layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(output),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(input_layout, filter_layout, output_layout);
}

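// Returns a lock on the per-device mutex associated with `stream_exec`'s
// (platform, device ordinal) pair. The per-device mutexes live in a
// lazily-constructed, never-deleted map; the short-lived lock on `mu` only
// serializes the map lookup itself.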
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  // se::Platform*s are global singletons guaranteed to live forever.
  static auto* mutexes =
      new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
                   tensorflow::mutex>();

  tensorflow::mutex_lock global_lock(mu);
  auto it = mutexes
                ->emplace(std::piecewise_construct,
                          std::make_tuple(stream_exec->platform(),
                                          stream_exec->device_ordinal()),
                          std::make_tuple())
                .first;
  return tensorflow::mutex_lock{it->second};
}

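// Loads a kernel named `kernel_name` taking `num_args` arguments from the
// given PTX and, when `cubin_data` is non-empty, from the pre-compiled cubin
// for the same kernel. Returns the kernel wrapped in a caller-owned
// KernelBase.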
StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
    absl::string_view kernel_name, uint64 num_args, absl::string_view ptx,
    absl::Span<const uint8> cubin_data, se::StreamExecutor* stream_exec) {
  se::MultiKernelLoaderSpec loader_spec(num_args);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  auto kernel_base = absl::make_unique<se::KernelBase>(stream_exec);
  TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

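// Launches `kernel` on `stream` with the given device-memory arguments and a
// 1-D launch configuration; at most kKernelArgsLimit (1024) arguments are
// supported. A minimal usage sketch together with CreateKernel above (names
// such as `ptx_str`, `args`, and `num_blocks` are hypothetical placeholders):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<se::KernelBase> kernel,
//       CreateKernel("my_kernel", /*num_args=*/args.size(), ptx_str,
//                    /*cubin_data=*/{}, stream_exec));
//   TF_RETURN_IF_ERROR(ExecuteKernelOnStream(
//       *kernel, args, /*threads_per_block=*/128,
//       /*block_count=*/num_blocks, stream));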
Status ExecuteKernelOnStream(const se::KernelBase& kernel,
                             absl::Span<const se::DeviceMemoryBase> args,
                             int64 threads_per_block, int64 block_count,
                             se::Stream* stream) {
  static constexpr int kKernelArgsLimit = 1024;
  auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
  for (const se::DeviceMemoryBase& buf : args) {
    kernel_args->add_device_memory_argument(buf);
  }
  return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block),
                                  se::BlockDim(block_count), kernel,
                                  *kernel_args);
}

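// Builds GpuAsmOpts from the module's debug options: whether to disable
// gpuasm optimizations, and which CUDA data directory to use.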
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {
  return se::GpuAsmOpts(
      hlo_module_config.debug_options().xla_gpu_disable_gpuasm_optimizations(),
      hlo_module_config.debug_options().xla_gpu_cuda_data_dir());
}

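// Returns a sample drawn from a uniform real distribution on [lhs, rhs).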
// Deliberately deleted for integral types: integral buffers are instead
// filled via the floating-point path in InitializeTypedBuffer below.
template <typename T, typename Generator>
static typename std::enable_if<std::is_integral<T>::value, T>::type
UniformDistribution(T lhs, T rhs, Generator* gen) = delete;

template <typename T, typename Generator>
static typename std::enable_if<std::is_floating_point<T>::value, T>::type
UniformDistribution(T lhs, T rhs, Generator* gen) {
  return std::uniform_real_distribution<T>(lhs, rhs)(*gen);
}

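// Fills `buffer` with pseudorandom data by repeatedly copying slices of a
// lazily-initialized host buffer of random values. *rng_state is a cursor
// into the host buffer, so consecutive calls produce different (but
// deterministic) buffer contents.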
template <typename T>
static void InitializeTypedBuffer(se::Stream* stream,
                                  se::DeviceMemoryBase buffer,
                                  int64* rng_state) {
  // Accesses to static variables are not locked, since the caller is already
  // in a critical section.
  static std::vector<T>* host_buffer = [] {
    // Use a large prime number to fragment the accesses.
    auto* ret = new std::vector<T>(10069);
    // Default-seeded random numbers.
    std::mt19937 gen;
    for (auto& element : *ret) {
      // Only double gets random values in double; every other data type gets
      // random values generated in float and then cast to the target type.
      using RandomFloatingPointType =
          typename std::conditional<std::is_same<T, Eigen::half>::value, float,
                                    T>::type;
      using RandomType =
          typename std::conditional<std::is_integral<T>::value, float,
                                    RandomFloatingPointType>::type;
      // Scale down the values for fp16 to reduce the chance of overflow.
      auto upper_bound =
          RandomType(std::is_same<T, Eigen::half>::value ? 0.1 : 1.0);
      auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen);
      // For float or double, the value lies in [0, 1].
      // For fp16, it lies in [0, 0.1].
      // For integral types, the element is either 0 or 1, which reduces
      // overflow, especially for int8.
      element = T(std::is_integral<T>::value ? rand_val + 0.5 : rand_val);
    }
    return ret;
  }();

  int64& host_index = *rng_state;

  char* current_addr = static_cast<char*>(buffer.opaque());
  CHECK_EQ(0, buffer.size() % sizeof(T));
  int64 elements_left = buffer.size() / sizeof(T);
  while (elements_left > 0) {
    CHECK_LE(host_index, host_buffer->size());
    if (host_buffer->size() == host_index) {
      host_index = 0;
    }
    int64 elements_copied =
        std::min<int64>(host_buffer->size() - host_index, elements_left);
    se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T));
    stream->ThenMemcpy(&mem, host_buffer->data() + host_index,
                       elements_copied * sizeof(T));
    current_addr += elements_copied * sizeof(T);
    elements_left -= elements_copied;
    host_index += elements_copied;
  }
}

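// Initializes `buffer` with random data of the given XLA primitive type.
// Complex types are filled via their underlying real type (C64 as float,
// C128 as double); unsupported types are a fatal error.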
void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
                      int64* rng_state, se::DeviceMemoryBase buffer) {
  switch (buffer_type) {
    case xla::F16:
      return InitializeTypedBuffer<Eigen::half>(stream, buffer, rng_state);
    case xla::F32:
    case xla::C64:
      return InitializeTypedBuffer<float>(stream, buffer, rng_state);
    case xla::F64:
    case xla::C128:
      return InitializeTypedBuffer<double>(stream, buffer, rng_state);
    case xla::S8:
      return InitializeTypedBuffer<int8>(stream, buffer, rng_state);
    default:
      LOG(FATAL) << "Unexpected type";
  }
}

}  // namespace gpu
}  // namespace xla