• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
18 
19 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
20 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
21 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
22 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
23 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace xla {
29 namespace gpu {
30 
31 struct NcclAllReduceConfig {
32   NcclCollectiveConfig config;
33   ReductionKind reduction_kind;
34 };
35 
36 // Thunk that performs a NCCL-based All-Reduce or Reduce-Scatter among CUDA
37 // GPU-based replicas.
38 class NcclAllReduceThunkBase : public NcclCollectiveThunk {
39  public:
40   static absl::optional<ReductionKind> MatchAllReduceComputation(
41       mlir::Region& computation);
42 
43   NcclAllReduceThunkBase(Kind kind, ThunkInfo thunk_info,
44                          NcclAllReduceConfig config,
45                          std::vector<Buffer> buffers);
46 
47  protected:
config()48   const NcclCollectiveConfig& config() const override { return config_.config; }
49 
50  protected:
51   const NcclAllReduceConfig config_;
52   const std::vector<Buffer> buffers_;
53 };
54 
55 class NcclAllReduceThunk : public NcclAllReduceThunkBase {
56  public:
57   NcclAllReduceThunk(ThunkInfo thunk_info, mlir::lmhlo::AllReduceOp op,
58                      std::vector<Buffer> buffers);
59 
GetName()60   static const char* GetName() { return "AllReduce"; }
61 
62   static bool CanImplement(mlir::lmhlo::AllReduceOp op);
63   static bool IsDegenerate(mlir::lmhlo::AllReduceOp op, int64_t replica_count,
64                            int64_t partition_count);
65   static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::AllReduceOp op);
66 
67  protected:
68   Status RunNcclCollective(const ExecuteParams& params,
69                            ncclComm_t comm) override;
70 };
71 
72 class NcclAllReduceStartThunk : public NcclAllReduceThunkBase {
73  public:
74   NcclAllReduceStartThunk(ThunkInfo thunk_info,
75                           mlir::lmhlo_gpu::AllReduceStartOp op,
76                           std::vector<Buffer> buffers);
77 
GetName()78   static const char* GetName() { return "AllReduceStart"; }
79 
80   static bool CanImplement(mlir::lmhlo_gpu::AllReduceStartOp op);
81   static bool IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op,
82                            int64_t replica_count, int64_t partition_count);
83   static CollectiveOpGroupMode GetGroupMode(
84       mlir::lmhlo_gpu::AllReduceStartOp op);
85 
86   StatusOr<se::Event> TakeDoneEvent(int device_ordinal)
87       ABSL_LOCKS_EXCLUDED(mu_);
88 
89  protected:
90   Status RunNcclCollective(const ExecuteParams& params,
91                            ncclComm_t comm) override;
92 
93  private:
94   absl::Mutex mu_;
95   // Store done events (by device ordinal) for the done thunk to wait on.
96   absl::flat_hash_map<int, se::Event> done_events_ ABSL_GUARDED_BY(mu_);
97 };
98 
99 class NcclAllReduceDoneThunk : public Thunk {
100  public:
101   explicit NcclAllReduceDoneThunk(ThunkInfo thunk_info,
102                                   NcclAllReduceStartThunk& start_thunk);
103 
104   Status ExecuteOnStream(const ExecuteParams& params) override;
105 
106  private:
107   NcclAllReduceStartThunk& start_thunk_;
108 };
109 
110 class NcclReduceScatterThunk : public NcclAllReduceThunkBase {
111  public:
112   NcclReduceScatterThunk(ThunkInfo thunk_info, mlir::lmhlo::ReduceScatterOp op,
113                          std::vector<Buffer> buffers);
114 
GetName()115   static const char* GetName() { return "ReduceScatter"; }
116 
117   // Returns whether the given instruction can be lowered to a nccl
118   // reduce-scatter call.
119   static bool CanImplement(mlir::lmhlo::ReduceScatterOp op);
120   static bool IsDegenerate(mlir::lmhlo::ReduceScatterOp op,
121                            int64_t replica_count, int64_t partition_count);
122   static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::ReduceScatterOp op);
123 
124  protected:
125   Status RunNcclCollective(const ExecuteParams& params,
126                            ncclComm_t comm) override;
127 };
128 
129 }  // namespace gpu
130 }  // namespace xla
131 
132 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
133