/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"

#include <limits>

#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"

namespace xla {
namespace gpu {

using tensorflow::AutotuneResult;

using GemmCacheKey =
    std::tuple<se::StreamExecutor*, Shape, Shape, Shape, std::string>;

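// Process-wide cache of autotuning results.  The value is the algorithm
// picked for that gemm (or nullopt if we fell back to the generic algorithm);
// the hit/miss counters feed the periodic logging in DoGemmAutotune below.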
static tensorflow::mutex autotune_cache_mu(tensorflow::LINKER_INITIALIZED);
static auto& autotune_cache TF_GUARDED_BY(autotune_cache_mu) =
    *new absl::flat_hash_map<GemmCacheKey,
                             absl::optional<se::blas::AlgorithmType>>();
static int64 cache_hits TF_GUARDED_BY(autotune_cache_mu) = 0;
static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0;

// Experimentally tries to pick the best algorithm for the given gemm.
//
// This may fail under perfectly normal circumstances.  In particular, it will
// fail if the program was built with CUDA older than 8.0 or if we're running
// on a GPU older than sm_50 -- in both cases, cuBLAS doesn't support
// gemm-with-algorithm at all.
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
    const HloInstruction* gemm, se::Stream* stream,
    se::DeviceMemoryAllocator* allocator) {
  if (!stream->parent()->SynchronizeAllActivity()) {
    return InternalError("Failed to synchronize GPU for autotuning.");
  }

  const HloModuleConfig& hlo_module_config = gemm->GetModule()->config();
  const bool init_cublas_data =
      hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
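  // All autotuning buffers come from a redzone allocator, which surrounds
  // each allocation with guard bytes so that an out-of-bounds write by a
  // candidate algorithm can be detected by CheckRedzones() below.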
  se::RedzoneAllocator input_output_allocator(
      stream, allocator, PtxOptsFromConfig(hlo_module_config),
      /*memory_limit=*/std::numeric_limits<int64>::max());

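  // Checks each candidate algorithm's output against the reference result,
  // tolerating small numeric differences between algorithms.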
  BufferComparator comparator(gemm->shape(), hlo_module_config);

  int64 rng_state = 0;
  auto get_initialized_buffer =
      [&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
    TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(op->shape())));
    if (init_cublas_data) {
      InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
    }
    return buffer;
  };
87 
88   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
89                       get_initialized_buffer(gemm->operand(0)));
90   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
91                       get_initialized_buffer(gemm->operand(1)));
92   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
93                       get_initialized_buffer(gemm));
94   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
95                       get_initialized_buffer(gemm));
96 
97   const DebugOptions& debug_options =
98       gemm->GetModule()->config().debug_options();
99 
100   const bool crash_on_checking_failure =
101       debug_options.xla_gpu_crash_on_verification_failures();
102 
103   GemmBackendConfig backend_config =
104       gemm->backend_config<GemmBackendConfig>().ValueOrDie();
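  // xla_gpu_autotune_level thresholds used in this pass: level > 1
  // initializes buffers with random data, level > 2 also reinitializes the
  // output between candidate runs, and level > 3 additionally runs the
  // redzone and correctness checks.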
  const int32 cublas_autotune_level =
      gemm->GetModule()->config().debug_options().xla_gpu_autotune_level();
  const bool reinit_cublas_data = cublas_autotune_level > 2;
  const bool check_cublas = cublas_autotune_level > 3;

  VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();

  std::vector<se::blas::AlgorithmType> algorithms;
  CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));

  absl::optional<se::blas::AlgorithmType> first_algorithm;
  std::vector<AutotuneResult> profile_results;

  GpuGemmConfig config = GetGpuGemmConfig(gemm);

  for (se::blas::AlgorithmType algorithm : algorithms) {
    // When beta != 0 the output buffer also acts as an input (it carries the
    // bias term), so reset it to the same contents before each run to keep
    // results comparable across algorithms.
    if (reinit_cublas_data && backend_config.beta() != 0) {
      int64 rng_state = 0;
      InitializeBuffer(stream, gemm->shape().element_type(), &rng_state,
                       output_buffer);
    }
    se::blas::ProfileResult profile_result;

    // We expect the gemm to fail sometimes -- in fact, it will fail for all
    // algorithms if we're targeting < sm_50.  But because we pass a non-null
    // ProfileResult, RunGemm should always return an OK status, and the
    // actual success-ness is reported in ProfileResult::is_valid.
    CHECK(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer, stream,
                  /*implements_whole_instruction=*/true,
                  /*profile_index=*/-1,
                  /*profiler=*/nullptr,
                  /*profile_result=*/&profile_result, algorithm)
              .ok());

    if (!profile_result.is_valid()) {
      // Unsupported algorithm.
      continue;
    }

    profile_results.emplace_back();
    AutotuneResult& result = profile_results.back();
    result.mutable_gemm()->set_algorithm(algorithm);

    VLOG(2) << "cublas gemm algorithm " << algorithm << " took "
            << profile_result.elapsed_time_in_ms() << "ms" << std::endl;

    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));

    if (!check_cublas) {
      continue;
    }

    TF_ASSIGN_OR_RETURN(
        se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
        input_output_allocator.CheckRedzones());
    if (!rz_check_status.ok()) {
      result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
      *result.mutable_failure()->mutable_msg() =
          rz_check_status.RedzoneFailureMsg();
      LOG(ERROR) << "Detected cuBLAS out-of-bounds write in gemm buffer";
      CHECK(!crash_on_checking_failure);
      continue;
    }

    if (!first_algorithm) {
      // First run: set the reference result buffer.
      CHECK(reference_result_buffer.size() == output_buffer.size());
      stream->ThenMemcpy(&reference_result_buffer, output_buffer,
                         output_buffer.size());
      first_algorithm.emplace(algorithm);
    } else {
      // Perform the comparison.
      TF_ASSIGN_OR_RETURN(bool compare_result,
                          comparator.CompareEqual(stream, output_buffer,
                                                  reference_result_buffer));
      if (!compare_result) {
        LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
                   << "This is likely a bug/unexpected loss of precision "
                   << "in cuBLAS.";
        CHECK(!crash_on_checking_failure);

        result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT);
        result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
            *first_algorithm);
      }
    }
  }

  tensorflow::AutotuningLog log;
  for (const AutotuneResult& profile : profile_results) {
    *log.add_results() = profile;
  }
  if (!crash_on_checking_failure) {
    tensorflow::Logger::GetSingleton()->LogProto(log);
  }

  // Choose the fastest correct GEMM, but allow for incorrect results (since
  // the reference result is chosen arbitrarily).
  auto has_failure = [](const AutotuneResult& r) {
    return r.has_failure() &&
           r.failure().kind() != AutotuneResult::WRONG_RESULT;
  };

  auto result_comparison_key = [&has_failure](const AutotuneResult& r) {
    return std::make_tuple(
        has_failure(r),
        tensorflow::proto_utils::FromDurationProto(r.run_time()));
  };
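  // With this key, c_min_element orders failure-free results before failed
  // ones and, within each group, prefers the smallest measured run time.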
  const auto& best_result = absl::c_min_element(
      profile_results,
      [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return result_comparison_key(lhs) < result_comparison_key(rhs);
      });

  if (best_result != profile_results.end() && !has_failure(*best_result)) {
    return {best_result->gemm().algorithm()};
  }

  VLOG(1) << "Unable to autotune cuBLAS gemm on stream " << stream
          << "; none of the " << algorithms.size()
          << " algorithms ran successfully";
  return {absl::nullopt};
}

static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
    const HloInstruction* instr, const GemmBackendConfig& gemm_config,
    se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
  const HloInstruction* lhs = instr->operand(0);
  const HloInstruction* rhs = instr->operand(1);

  // Don't run autotuning concurrently on the same GPU.
  tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());

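  // The cache key combines the device, the operand and result shapes, and
  // the full serialized backend config, so gemms that differ only in e.g.
  // alpha/beta or dot dimensions are tuned separately.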
  GemmCacheKey key =
      std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
                      instr->shape(), gemm_config.SerializeAsString());

  tensorflow::mutex_lock cache_lock(autotune_cache_mu);
  auto it = autotune_cache.find(key);
  int64 autotuning_requests = cache_hits + cache_misses;
  if (autotuning_requests && autotuning_requests % 10 == 0) {
    VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/"
            << autotuning_requests;
  }

  if (it != autotune_cache.end()) {
    cache_hits++;
    VLOG(4) << "Autotuning cache hit, using algorithm: "
            << (it->second.has_value() ? absl::StrCat(*(it->second))
                                       : "<generic>");
    return it->second;
  }
  cache_misses++;
  VLOG(4) << "Autotuning cache miss";

  int64 batch_size = gemm_config.batch_size();
  absl::optional<se::blas::AlgorithmType> result;
  if (batch_size != 1) {
    // TODO(b/112111608): Implement auto tune for batched gemm.
    VLOG(2) << "Batch size is non-singular, using generic algorithm";
    result = absl::nullopt;
  } else {
    TF_ASSIGN_OR_RETURN(result,
                        DoUncachedGemmAutotune(instr, stream, allocator));
  }

  CHECK(autotune_cache.emplace(key, result).second);
  return result;
}

static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
                                       se::StreamExecutor* executor,
                                       se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = executor->GetAllocator();
  }
  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                      allocator->GetStream(executor->device_ordinal()));

  GemmBackendConfig gemm_config =
      instr->backend_config<GemmBackendConfig>().ValueOrDie();

  TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
                      DoGemmAutotune(instr, gemm_config, allocator, stream));

  // We update the instruction's backend_config(). If autotuning selected no
  // algorithm, the config is left without one; at runtime a different cuBLAS
  // API, which does not require specifying an algorithm, is used instead.
  GemmBackendConfig updated_config = gemm_config;
  if (gemm_algorithm) {
    updated_config.set_selected_algorithm(*gemm_algorithm);
  }
  TF_RETURN_IF_ERROR(instr->set_backend_config(updated_config));
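  // Report a change iff the selected config differs from the original one.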
  return updated_config.SerializeAsString() != gemm_config.SerializeAsString();
}

static StatusOr<bool> RunOnComputation(HloComputation* computation,
                                       se::StreamExecutor* se,
                                       se::DeviceMemoryAllocator* allocator) {
  bool changed = false;
  for (HloInstruction* instr : computation->instructions()) {
    if (IsCublasGemm(*instr)) {
      TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, se, allocator));
      changed |= result;
    }
  }
  return changed;
}

StatusOr<bool> GemmAlgorithmPicker::Run(HloModule* module) {
  XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker");

  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(
        bool result, RunOnComputation(computation, stream_exec_, allocator_));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla