/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/bef_thunk.h"

#include "tensorflow/core/platform/errors.h"

#if BEF_THUNKS
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/tfrt/gpu/gpu_shared_context.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tfrt/gpu/gpu_types.h"  // from @tf_runtime
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_converter/mlir_to_bef_translate.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime

// Common place for all collective thunks to source nccl/rccl headers.
// Also, all the RunNcclCollective() functions for various thunks should
// use XLA_ENABLE_XCCL to guard NCCL/RCCL usage (and not GOOGLE_XCCL).
#if GOOGLE_XCCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define XLA_ENABLE_XCCL 1
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // GOOGLE_XCCL

#if XLA_ENABLE_XCCL
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#else
#error "Neither CUDA nor ROCm enabled but NCCL/RCCL enabled"
#endif

// Also include this file, which is required by all collective thunks.
73 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
74 
75 #endif  // XLA_ENABLE_XCCL
76 
77 namespace xla {
78 namespace gpu {
79 
IsBefThunkEnabled()80 bool IsBefThunkEnabled() { return true; }
81 
82 namespace {
83 
84 struct CoreRuntimeAndWorkQueue {
85   tfrt::CoreRuntime* core_runtime;
86   tensorflow::tfrt_stub::WorkQueueInterface* work_queue;
87 };
88 
89 class BefThunk : public Thunk {
90  public:
  BefThunk(Thunk::Kind kind, ThunkInfo thunk_info,
           std::vector<BufferAllocation::Slice> inputs,
           std::vector<BufferAllocation::Slice> outputs,
           tfrt::BefBuffer bef_buffer,
           tfrt::RCReference<tfrt::BEFFile> bef_file, mlir::Operation* op)
      : Thunk(kind, thunk_info),
        inputs_(std::move(inputs)),
        outputs_(std::move(outputs)),
        bef_buffer_(std::move(bef_buffer)),
        bef_file_(std::move(bef_file)) {
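    // For collective ops, capture the collective config up front so that
    // ExecuteOnStream() can build the XCCL request context later.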
    // TODO(hanbinyoon): Also handle other collective ops.
    if (auto all_reduce_op = mlir::dyn_cast<mlir::lmhlo::AllReduceOp>(*op)) {
      xccl_config_ = GetNcclCollectiveConfigForMlir(
          all_reduce_op, all_reduce_op.use_global_device_ids());
    }
  }

  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  std::vector<BufferAllocation::Slice> inputs_;
  std::vector<BufferAllocation::Slice> outputs_;
  tfrt::BefBuffer bef_buffer_;
  tfrt::RCReference<tfrt::BEFFile> bef_file_;
  absl::optional<NcclCollectiveConfig> xccl_config_;
};

}  // namespace

static const char kDefaultHostDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

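// Name of the single public function that CreateModule() emits and that
// ExecuteOnStream() looks up in the loaded BEF file.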
static const char kFuncName[] = "main";

// Clones 'op' into a function within a new module.
static mlir::OwningOpRef<mlir::ModuleOp> CreateModule(mlir::Operation* op) {
  mlir::OpBuilder builder(op->getContext());
  mlir::OwningOpRef<mlir::ModuleOp> module =
      builder.create<mlir::ModuleOp>(op->getLoc());

  builder.setInsertionPointToEnd(module->getBody());
  auto func_type = builder.getType<mlir::FunctionType>(op->getOperandTypes(),
                                                       op->getResultTypes());
  auto func = builder.create<mlir::FuncOp>(op->getLoc(), kFuncName, func_type);
  func.setPublic();

  builder.setInsertionPointToEnd(func.addEntryBlock());
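  // Map the original op's operands to the new function's block arguments so
  // that the cloned op reads from the function arguments.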
  mlir::BlockAndValueMapping mapping;
  for (const auto& pair :
       llvm::zip_first(op->getOperands(), func.getArguments())) {
    mapping.map(std::get<0>(pair), std::get<1>(pair));
  }
  builder.clone(*op, mapping);

  builder.create<mlir::lmhlo::TerminatorOp>(op->getLoc());

  return module;
}

// Lowers 'module' to BEF.
static StatusOr<tfrt::BefBuffer> ConvertToBef(mlir::ModuleOp module) {
  mlir::PassManager pass_manager(module->getContext(),
                                 mlir::PassManager::Nesting::Implicit);
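  // Lower the lmhlo/lmhlo_gpu ops to the tfrt_gpu dialect before translating
  // the module to BEF.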
  pass_manager.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
  pass_manager.addPass(mlir::createGpuAsyncRegionPass());
  pass_manager.addPass(tensorflow::createAsyncGpuTfrtConversionPass());
  if (failed(pass_manager.run(module)))
    return tensorflow::errors::Internal("Failed to run pass pipeline.");

  std::string bef;
  llvm::raw_string_ostream bef_ostream(bef);
  if (failed(tfrt::MLIRToBEFTranslate(module, bef_ostream)))
    return tensorflow::errors::Internal("Failed to translate MLIR to BEF.");

  return tfrt::BefBuffer(bef.data(), bef.data() + bef.size());
}

static StatusOr<Thunk::Kind> GetThunkKind(mlir::Operation* op) {
  if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(op)) {
    return Thunk::Kind::kGemm;
  }
  return tensorflow::errors::Unimplemented(
      "Operation is not supported by BefThunk.");
}

static StatusOr<CoreRuntimeAndWorkQueue> GetCoreRuntimeAndWorkQueue() {
  // TODO(hanbinyoon): Make these configurable.
  int tfrt_num_threads = tensorflow::port::MaxParallelism();
  int tfrt_num_blocking_threads = 16;

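  // The CoreRuntime and WorkQueue are created at most once and cached in a
  // function-local static, so all BEF thunks in the process share them.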
  static StatusOr<CoreRuntimeAndWorkQueue>* runtime_and_queue_or =
      [&](int num_threads, int num_blocking_threads) {
        // Create work queue.
        auto work_queue = tensorflow::tfrt_stub::WrapDefaultWorkQueue(
            tfrt::CreateMultiThreadedWorkQueue(num_threads,
                                               num_blocking_threads));
        if (work_queue == nullptr) {
          auto status =
              tensorflow::errors::Internal("Failed to create TFRT work queue.");
          return new StatusOr<CoreRuntimeAndWorkQueue>(status);
        }
        auto* work_queue_ptr = work_queue.get();

        // Create core runtime.
        auto expected_core_runtime = tfrt::CoreRuntime::Create(
            [](const tfrt::DecodedDiagnostic& diag) {
              LOG(ERROR) << diag.message;
            },
            tfrt::CreateMallocAllocator(), std::move(work_queue),
            kDefaultHostDeviceName);
        if (!expected_core_runtime) {
          auto error = expected_core_runtime.takeError();
          auto status =
              tensorflow::errors::Internal(llvm::toString(std::move(error)));
          return new StatusOr<CoreRuntimeAndWorkQueue>(status);
        }

        auto runtime_and_queue = CoreRuntimeAndWorkQueue{
            expected_core_runtime->release(), work_queue_ptr};
        return new StatusOr<CoreRuntimeAndWorkQueue>(runtime_and_queue);
      }(tfrt_num_threads, tfrt_num_blocking_threads);

  TF_RETURN_IF_ERROR(runtime_and_queue_or->status());
  return runtime_and_queue_or->ValueOrDie();
}

StatusOr<std::unique_ptr<Thunk>> CreateBefThunk(
    Thunk::ThunkInfo thunk_info, mlir::Operation* op,
    std::vector<BufferAllocation::Slice> inputs,
    std::vector<BufferAllocation::Slice> outputs) {
  TF_ASSIGN_OR_RETURN(auto kind, GetThunkKind(op));
  auto module = CreateModule(op);
  TF_ASSIGN_OR_RETURN(tfrt::BefBuffer bef_buffer, ConvertToBef(*module));

  TF_ASSIGN_OR_RETURN(auto runtime_and_queue, GetCoreRuntimeAndWorkQueue());
  tfrt::HostContext* host = runtime_and_queue.core_runtime->GetHostContext();
  auto bef_file = tfrt::BEFFile::Open(bef_buffer, host->GetKernelRegistry(),
                                      host->diag_handler(), host->allocator());
  if (!bef_file)
    return tensorflow::errors::Internal("Failed to load BEF file.");

  return std::unique_ptr<Thunk>(
      new BefThunk(kind, thunk_info, std::move(inputs), std::move(outputs),
                   std::move(bef_buffer), std::move(bef_file), op));
}

// Wraps the GPU stream specified in 'params' (initialized by the
// StreamExecutor) so that it can be passed to BEF functions as
// AsyncValueRef<GpuStream>.
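// The wrapper only borrows the underlying GPU context and stream; it does not
// take ownership of them.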
static auto CreateGpuStream(const Thunk::ExecuteParams& params) {
  auto se_gpu_executor = static_cast<stream_executor::gpu::GpuExecutor*>(
      params.stream->parent()->implementation());
  auto se_gpu_stream = static_cast<stream_executor::gpu::GpuStream*>(
      params.stream->implementation());
  return tfrt::gpu::BorrowedGpuStream(
      tfrt::gpu::wrapper::Context(se_gpu_executor->gpu_context()->context()),
      tfrt::gpu::wrapper::Stream(se_gpu_stream->gpu_stream()));
}

// Wraps the GPU buffer specified in 'slice' so that it can be passed to BEF
// functions as AsyncValueRef<GpuBuffer>.
static tfrt::RCReference<tfrt::AsyncValue> CreateGpuBuffer(
    const Thunk::ExecuteParams& params, const BufferAllocation::Slice& slice) {
  se::DeviceMemoryBase data =
      params.buffer_allocations->GetDeviceAddress(slice);
  tfrt::gpu::wrapper::Pointer<void> pointer(data.opaque(),
                                            tfrt::gpu::wrapper::Platform::CUDA);
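  // The one-shot allocator hands back this pre-existing device pointer, so
  // GpuBuffer::Allocate() adopts the XLA-owned buffer instead of allocating
  // new device memory.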
  auto allocator =
      tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::GpuOneShotAllocator<void>>(
          pointer);
  auto buffer =
      tfrt::gpu::GpuBuffer::Allocate(std::move(allocator), data.size());
  if (!buffer)
    return tfrt::MakeErrorAsyncValueRef(tfrt::StrCat(buffer.takeError()));
  return tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::GpuBuffer>(
      std::move(*buffer));
}

static Status CreateXcclContext(
    const Thunk::ExecuteParams& params, const NcclCollectiveConfig& xccl_config,
    tfrt::RequestContextBuilder* request_context_builder) {
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(std::vector<GlobalDeviceId> participants,
                      GetParticipatingDevices(
                          global_device_id, *params.device_assn,
                          xccl_config.replica_groups, xccl_config.group_mode));
  if (IsGlobalNcclConfig() &&
      (participants.size() != params.device_assn->replica_count())) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<LocalParticipant> local_participants,
      GetLocalParticipants(participants, params.gpu_global_device_ids));
  absl::flat_hash_map<tfrt::gpu::GpuSharedContext::LocalDeviceIdentifier, int>
      local_ids_to_rank;
  for (const auto& participant : local_participants) {
    local_ids_to_rank[participant.device_ordinal] = participant.rank;
  }

  std::vector<int64> gpu_global_device_ids;
  if (params.gpu_global_device_ids != nullptr) {
    for (const auto& global_device_id : *params.gpu_global_device_ids) {
      gpu_global_device_ids.push_back(global_device_id.value());
    }
  }

  tfrt::gpu::XcclUniqueIdCallback xccl_unique_id_callback;
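  // Adapt XLA's NcclUniqueIdCallback (keyed by NcclCliqueKey) to the callback
  // type that tfrt::gpu::GpuSharedContext expects (keyed by XcclCliqueKey).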
  if (params.nccl_unique_id_callback != nullptr) {
    xccl_unique_id_callback = [&](const tfrt::gpu::XcclCliqueKey& kernel_key)
        -> llvm::Expected<std::string> {
      std::vector<GlobalDeviceId> devices;
      for (const int64_t device : kernel_key) {
        devices.push_back(GlobalDeviceId(device));
      }
      auto nccl_unique_id_or =
          (*params.nccl_unique_id_callback)(NcclCliqueKey(devices));
      if (!nccl_unique_id_or.ok()) {
        return tfrt::MakeStringError(
            nccl_unique_id_or.status().error_message());
      }
      return nccl_unique_id_or.ValueOrDie();
    };
  }

  request_context_builder->context_data().emplace<tfrt::gpu::GpuSharedContext>(
      params.run_id.ToInt(), std::move(local_ids_to_rank),
      std::move(gpu_global_device_ids), std::move(xccl_unique_id_callback),
      /*compiled_code=*/nullptr);
  return Status::OK();
}

ExecuteOnStream(const ExecuteParams & params)325 Status BefThunk::ExecuteOnStream(const ExecuteParams& params) {
326   VLOG(2) << "Executing BEF thunk.";
327 
328   // Signature: (chain, stream, inputs..., outputs...) -> (chain).
329   const tfrt::Function* function = bef_file_->GetFunction(kFuncName);
330   if (!function) {
331     return tensorflow::errors::Internal("Failed to get '", kFuncName,
332                                         "' function.");
333   }
334 
335   // Create execution context.
336   TF_ASSIGN_OR_RETURN(auto runtime_and_queue, GetCoreRuntimeAndWorkQueue());
337   tfrt::RequestContextBuilder request_context_builder(
338       runtime_and_queue.core_runtime->GetHostContext(),
339       /*resource_context=*/nullptr);
340   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool = nullptr;
341   TF_RETURN_IF_ERROR(runtime_and_queue.work_queue->InitializeRequest(
342       &request_context_builder, &intra_op_threadpool));
343   if (xccl_config_.has_value()) {
344     TF_RETURN_IF_ERROR(
345         CreateXcclContext(params, *xccl_config_, &request_context_builder));
346   }
347   auto expected_req_ctx = std::move(request_context_builder).build();
348   if (!expected_req_ctx) {
349     auto error = expected_req_ctx.takeError();
350     return tensorflow::errors::Internal(llvm::toString(std::move(error)));
351   }
352   tfrt::ExecutionContext exec_ctx(std::move(*expected_req_ctx));
353 
  // Create owning handles for the arguments and add pointers to them to 'args'.
  tfrt::SmallVector<tfrt::AsyncValue*, 8> args;
  args.reserve(function->num_arguments());
  tfrt::AsyncValueRef<tfrt::Chain> chain = tfrt::GetReadyChain(exec_ctx.host());
  args.push_back(chain.GetAsyncValue());
  tfrt::gpu::BorrowedGpuStream stream = CreateGpuStream(params);
  args.push_back(static_cast<tfrt::AsyncValueRef<tfrt::gpu::GpuStream>>(stream)
                     .GetAsyncValue());
  llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 8> buffers;
  for (auto input : inputs_) {
    buffers.push_back(CreateGpuBuffer(params, input));
  }
  for (auto output : outputs_) {
    buffers.push_back(CreateGpuBuffer(params, output));
  }
  for (auto& buffer : buffers) {
    args.push_back(buffer.get());
  }
  if (args.size() != function->num_arguments())
    return tensorflow::errors::Internal("Unexpected argument count.");

  // Create return chain.
  tfrt::RCReference<tfrt::AsyncValue> result;
  if (function->num_results() != 1)
    return tensorflow::errors::Internal("Unexpected result count.");

  // Execute the function.
  function->Execute(exec_ctx, args, {result});

  // Wait for async execution to complete.
  tfrt::Await(exec_ctx, llvm::makeArrayRef(result));

  // Report error if any.
  if (auto* error = result->GetErrorIfPresent())
    return tensorflow::errors::Internal(error->message);

  return Status::OK();
}

}  // namespace gpu
}  // namespace xla
#else   // BEF_THUNKS
namespace xla {

bool gpu::IsBefThunkEnabled() { return false; }

StatusOr<std::unique_ptr<gpu::Thunk>> gpu::CreateBefThunk(
    Thunk::ThunkInfo, mlir::Operation*, std::vector<BufferAllocation::Slice>,
    std::vector<BufferAllocation::Slice>) {
  return tensorflow::errors::FailedPrecondition("BefThunks are disabled.");
}

}  // namespace xla
#endif  // BEF_THUNKS