/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/bef_thunk.h"

#include "tensorflow/core/platform/errors.h"

#if BEF_THUNKS
#include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/tfrt/gpu/gpu_shared_context.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_converter/mlir_to_bef_translate.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/gpu/gpu_types.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime

// Common place for all collective thunks to source nccl/rccl headers.
// Also, all the RunNcclCollective() functions for the various thunks should
// use XLA_ENABLE_XCCL to guard NCCL/RCCL usage (and not GOOGLE_XCCL).
#if GOOGLE_XCCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define XLA_ENABLE_XCCL 1
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // GOOGLE_XCCL

#if XLA_ENABLE_XCCL
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#else
#error "Neither CUDA nor ROCm enabled but NCCL/RCCL enabled"
#endif

// Also include this file, which is required by all collective thunks.
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#endif  // XLA_ENABLE_XCCL

namespace xla {
namespace gpu {

bool IsBefThunkEnabled() { return true; }

namespace {

struct CoreRuntimeAndWorkQueue {
  tfrt::CoreRuntime* core_runtime;
  tensorflow::tfrt_stub::WorkQueueInterface* work_queue;
};

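// A Thunk that executes a BEF function produced by lowering a single lmhlo op
// to the TFRT GPU dialect. For collective ops, the NCCL/XCCL configuration is
// captured at construction time so that ExecuteOnStream() can set up the
// collective context.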
class BefThunk : public Thunk {
 public:
  BefThunk(Thunk::Kind kind, ThunkInfo thunk_info,
           std::vector<BufferAllocation::Slice> inputs,
           std::vector<BufferAllocation::Slice> outputs,
           tfrt::BefBuffer bef_buffer,
           tfrt::RCReference<tfrt::BEFFile> bef_file, mlir::Operation* op)
      : Thunk(kind, thunk_info),
        inputs_(std::move(inputs)),
        outputs_(std::move(outputs)),
        bef_buffer_(std::move(bef_buffer)),
        bef_file_(std::move(bef_file)) {
    // TODO(hanbinyoon): Also handle other collective ops.
    if (auto all_reduce_op = mlir::dyn_cast<mlir::lmhlo::AllReduceOp>(*op)) {
      xccl_config_ = GetNcclCollectiveConfigForMlir(
          all_reduce_op, all_reduce_op.use_global_device_ids());
    }
  }

  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  std::vector<BufferAllocation::Slice> inputs_;
  std::vector<BufferAllocation::Slice> outputs_;
  tfrt::BefBuffer bef_buffer_;
  tfrt::RCReference<tfrt::BEFFile> bef_file_;
  absl::optional<NcclCollectiveConfig> xccl_config_;
};

}  // namespace

static const char kDefaultHostDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

static const char kFuncName[] = "main";

// Clones 'op' into a function within a new module.
static mlir::OwningOpRef<mlir::ModuleOp> CreateModule(mlir::Operation* op) {
  mlir::OpBuilder builder(op->getContext());
  mlir::OwningOpRef<mlir::ModuleOp> module =
      builder.create<mlir::ModuleOp>(op->getLoc());

  builder.setInsertionPointToEnd(module->getBody());
  auto func_type = builder.getType<mlir::FunctionType>(op->getOperandTypes(),
                                                       op->getResultTypes());
  auto func = builder.create<mlir::FuncOp>(op->getLoc(), kFuncName, func_type);
  func.setPublic();

  builder.setInsertionPointToEnd(func.addEntryBlock());
  mlir::BlockAndValueMapping mapping;
  for (const auto& pair :
       llvm::zip_first(op->getOperands(), func.getArguments())) {
    mapping.map(std::get<0>(pair), std::get<1>(pair));
  }
  builder.clone(*op, mapping);

  builder.create<mlir::lmhlo::TerminatorOp>(op->getLoc());

  return module;
}

// Lowers 'module' to BEF.
static StatusOr<tfrt::BefBuffer> ConvertToBef(mlir::ModuleOp module) {
  mlir::PassManager pass_manager(module->getContext(),
                                 mlir::PassManager::Nesting::Implicit);
  pass_manager.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
  pass_manager.addPass(mlir::createGpuAsyncRegionPass());
  pass_manager.addPass(tensorflow::createAsyncGpuTfrtConversionPass());
  if (failed(pass_manager.run(module)))
    return tensorflow::errors::Internal("Failed to run pass pipeline.");

  std::string bef;
  llvm::raw_string_ostream bef_ostream(bef);
  if (failed(tfrt::MLIRToBEFTranslate(module, bef_ostream)))
    return tensorflow::errors::Internal("Failed to translate MLIR to BEF.");

  return tfrt::BefBuffer(bef.data(), bef.data() + bef.size());
}

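// Maps 'op' to the Thunk::Kind reported for the resulting BefThunk, or returns
// an error if the op is not (yet) supported.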
static StatusOr<Thunk::Kind> GetThunkKind(mlir::Operation* op) {
  if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(op)) {
    return Thunk::Kind::kGemm;
  }
  return tensorflow::errors::Unimplemented(
      "Operation is not supported by BefThunk.");
}

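// Returns the process-wide TFRT CoreRuntime and work queue, creating them on
// first use. The StatusOr is heap-allocated and intentionally never freed so
// that the runtime outlives all thunks.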
static StatusOr<CoreRuntimeAndWorkQueue> GetCoreRuntimeAndWorkQueue() {
  // TODO(hanbinyoon): Make these configurable.
  int tfrt_num_threads = tensorflow::port::MaxParallelism();
  int tfrt_num_blocking_threads = 16;

  static StatusOr<CoreRuntimeAndWorkQueue>* runtime_and_queue_or =
      [&](int num_threads, int num_blocking_threads) {
        // Create work queue.
        auto work_queue = tensorflow::tfrt_stub::WrapDefaultWorkQueue(
            tfrt::CreateMultiThreadedWorkQueue(num_threads,
                                               num_blocking_threads));
        if (work_queue == nullptr) {
          auto status =
              tensorflow::errors::Internal("Failed to create TFRT work queue.");
          return new StatusOr<CoreRuntimeAndWorkQueue>(status);
        }
        auto* work_queue_ptr = work_queue.get();

        // Create core runtime.
        auto expected_core_runtime = tfrt::CoreRuntime::Create(
            [](const tfrt::DecodedDiagnostic& diag) {
              LOG(ERROR) << diag.message;
            },
            tfrt::CreateMallocAllocator(), std::move(work_queue),
            kDefaultHostDeviceName);
        if (!expected_core_runtime) {
          auto error = expected_core_runtime.takeError();
          auto status =
              tensorflow::errors::Internal(llvm::toString(std::move(error)));
          return new StatusOr<CoreRuntimeAndWorkQueue>(status);
        }

        auto runtime_and_queue = CoreRuntimeAndWorkQueue{
            expected_core_runtime->release(), work_queue_ptr};
        return new StatusOr<CoreRuntimeAndWorkQueue>(runtime_and_queue);
      }(tfrt_num_threads, tfrt_num_blocking_threads);

  TF_RETURN_IF_ERROR(runtime_and_queue_or->status());
  return runtime_and_queue_or->ValueOrDie();
}

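// Creates a BefThunk from 'op': the op is cloned into a standalone module,
// lowered to BEF, and the resulting BEF file is loaded into the process-wide
// TFRT host context before being wrapped in a thunk.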
StatusOr<std::unique_ptr<Thunk>> CreateBefThunk(
    Thunk::ThunkInfo thunk_info, mlir::Operation* op,
    std::vector<BufferAllocation::Slice> inputs,
    std::vector<BufferAllocation::Slice> outputs) {
  TF_ASSIGN_OR_RETURN(auto kind, GetThunkKind(op));
  auto module = CreateModule(op);
  TF_ASSIGN_OR_RETURN(tfrt::BefBuffer bef_buffer, ConvertToBef(*module));

  TF_ASSIGN_OR_RETURN(auto runtime_and_queue, GetCoreRuntimeAndWorkQueue());
  tfrt::HostContext* host = runtime_and_queue.core_runtime->GetHostContext();
  auto bef_file = tfrt::BEFFile::Open(bef_buffer, host->GetKernelRegistry(),
                                      host->diag_handler(), host->allocator());
  if (!bef_file)
    return tensorflow::errors::Internal("Failed to load BEF file.");

  return std::unique_ptr<Thunk>(
      new BefThunk(kind, thunk_info, std::move(inputs), std::move(outputs),
                   std::move(bef_buffer), std::move(bef_file), op));
}

// Wraps the GPU stream specified in 'params' (initialized by the
// StreamExecutor) so that it can be passed to BEF functions as an
// AsyncValueRef<GpuStream>.
static auto CreateGpuStream(const Thunk::ExecuteParams& params) {
  auto se_gpu_executor = static_cast<stream_executor::gpu::GpuExecutor*>(
      params.stream->parent()->implementation());
  auto se_gpu_stream = static_cast<stream_executor::gpu::GpuStream*>(
      params.stream->implementation());
  return tfrt::gpu::BorrowedGpuStream(
      tfrt::gpu::wrapper::Context(se_gpu_executor->gpu_context()->context()),
      tfrt::gpu::wrapper::Stream(se_gpu_stream->gpu_stream()));
}

// Wraps the GPU buffer specified in 'slice' so that it can be passed to BEF
// functions as an AsyncValueRef<GpuBuffer>.
static tfrt::RCReference<tfrt::AsyncValue> CreateGpuBuffer(
    const Thunk::ExecuteParams& params, const BufferAllocation::Slice& slice) {
  se::DeviceMemoryBase data =
      params.buffer_allocations->GetDeviceAddress(slice);
  tfrt::gpu::wrapper::Pointer<void> pointer(data.opaque(),
                                            tfrt::gpu::wrapper::Platform::CUDA);
  auto allocator =
      tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::GpuOneShotAllocator<void>>(
          pointer);
  auto buffer =
      tfrt::gpu::GpuBuffer::Allocate(std::move(allocator), data.size());
  if (!buffer)
    return tfrt::MakeErrorAsyncValueRef(tfrt::StrCat(buffer.takeError()));
  return tfrt::MakeAvailableAsyncValueRef<tfrt::gpu::GpuBuffer>(
      std::move(*buffer));
}

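// Populates 'request_context_builder' with a tfrt::gpu::GpuSharedContext that
// carries the local-device-to-rank mapping, the global device ids, and the
// NCCL/XCCL unique-id callback needed to run collective ops.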
static Status CreateXcclContext(
    const Thunk::ExecuteParams& params, const NcclCollectiveConfig& xccl_config,
    tfrt::RequestContextBuilder* request_context_builder) {
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(std::vector<GlobalDeviceId> participants,
                      GetParticipatingDevices(
                          global_device_id, *params.device_assn,
                          xccl_config.replica_groups, xccl_config.group_mode));
  if (IsGlobalNcclConfig() &&
      (participants.size() != params.device_assn->replica_count())) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<LocalParticipant> local_participants,
      GetLocalParticipants(participants, params.gpu_global_device_ids));
  absl::flat_hash_map<tfrt::gpu::GpuSharedContext::LocalDeviceIdentifier, int>
      local_ids_to_rank;
  for (const auto& participant : local_participants) {
    local_ids_to_rank[participant.device_ordinal] = participant.rank;
  }

  std::vector<int64> gpu_global_device_ids;
  if (params.gpu_global_device_ids != nullptr) {
    for (const auto& global_device_id : *params.gpu_global_device_ids) {
      gpu_global_device_ids.push_back(global_device_id.value());
    }
  }

  tfrt::gpu::XcclUniqueIdCallback xccl_unique_id_callback;
  if (params.nccl_unique_id_callback != nullptr) {
    xccl_unique_id_callback = [&](const tfrt::gpu::XcclCliqueKey& kernel_key)
        -> llvm::Expected<std::string> {
      std::vector<GlobalDeviceId> devices;
      for (const int64_t device : kernel_key) {
        devices.push_back(GlobalDeviceId(device));
      }
      auto nccl_unique_id_or =
          (*params.nccl_unique_id_callback)(NcclCliqueKey(devices));
      if (!nccl_unique_id_or.ok()) {
        return tfrt::MakeStringError(
            nccl_unique_id_or.status().error_message());
      }
      return nccl_unique_id_or.ValueOrDie();
    };
  }

  request_context_builder->context_data().emplace<tfrt::gpu::GpuSharedContext>(
      params.run_id.ToInt(), std::move(local_ids_to_rank),
      std::move(gpu_global_device_ids), std::move(xccl_unique_id_callback),
      /*compiled_code=*/nullptr);
  return Status::OK();
}

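// Executes the lowered BEF function on the stream in 'params': a request
// context is built (including the XCCL context for collective ops), the
// chain/stream/buffer arguments are marshalled, and the returned chain is
// awaited before reporting the result.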
Status BefThunk::ExecuteOnStream(const ExecuteParams& params) {
  VLOG(2) << "Executing BEF thunk.";

  // Signature: (chain, stream, inputs..., outputs...) -> (chain).
  const tfrt::Function* function = bef_file_->GetFunction(kFuncName);
  if (!function) {
    return tensorflow::errors::Internal("Failed to get '", kFuncName,
                                        "' function.");
  }

  // Create execution context.
  TF_ASSIGN_OR_RETURN(auto runtime_and_queue, GetCoreRuntimeAndWorkQueue());
  tfrt::RequestContextBuilder request_context_builder(
      runtime_and_queue.core_runtime->GetHostContext(),
      /*resource_context=*/nullptr);
  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool = nullptr;
  TF_RETURN_IF_ERROR(runtime_and_queue.work_queue->InitializeRequest(
      &request_context_builder, &intra_op_threadpool));
  if (xccl_config_.has_value()) {
    TF_RETURN_IF_ERROR(
        CreateXcclContext(params, *xccl_config_, &request_context_builder));
  }
  auto expected_req_ctx = std::move(request_context_builder).build();
  if (!expected_req_ctx) {
    auto error = expected_req_ctx.takeError();
    return tensorflow::errors::Internal(llvm::toString(std::move(error)));
  }
  tfrt::ExecutionContext exec_ctx(std::move(*expected_req_ctx));

  // Create owning handles for the arguments and add pointers to them to
  // 'args'.
  tfrt::SmallVector<tfrt::AsyncValue*, 8> args;
  args.reserve(function->num_arguments());
  tfrt::AsyncValueRef<tfrt::Chain> chain = tfrt::GetReadyChain(exec_ctx.host());
  args.push_back(chain.GetAsyncValue());
  tfrt::gpu::BorrowedGpuStream stream = CreateGpuStream(params);
  args.push_back(static_cast<tfrt::AsyncValueRef<tfrt::gpu::GpuStream>>(stream)
                     .GetAsyncValue());
  llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 8> buffers;
  for (auto input : inputs_) {
    buffers.push_back(CreateGpuBuffer(params, input));
  }
  for (auto output : outputs_) {
    buffers.push_back(CreateGpuBuffer(params, output));
  }
  for (auto& buffer : buffers) {
    args.push_back(buffer.get());
  }
  if (args.size() != function->num_arguments())
    return tensorflow::errors::Internal("Unexpected argument count.");

  // Create the return chain.
  tfrt::RCReference<tfrt::AsyncValue> result;
  if (function->num_results() != 1)
    return tensorflow::errors::Internal("Unexpected result count.");

  // Execute the function.
  function->Execute(exec_ctx, args, {result});

  // Wait for async execution to complete.
  tfrt::Await(exec_ctx, llvm::makeArrayRef(result));

  // Report the error, if any.
  if (auto* error = result->GetErrorIfPresent())
    return tensorflow::errors::Internal(error->message);

  return Status::OK();
}

}  // namespace gpu
}  // namespace xla
#else  // BEF_THUNKS
namespace xla {

bool gpu::IsBefThunkEnabled() { return false; }

StatusOr<std::unique_ptr<gpu::Thunk>> gpu::CreateBefThunk(
    Thunk::ThunkInfo, mlir::Operation*, std::vector<BufferAllocation::Slice>,
    std::vector<BufferAllocation::Slice>) {
  return tensorflow::errors::FailedPrecondition("BefThunks are disabled.");
}

}  // namespace xla
#endif  // BEF_THUNKS