• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
18 
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 class GpuExecutable;
33 
34 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
35 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
36 //
37 // Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable
38 // to initialize and execute the invocation respectively. Its subclasses are
39 // supposed to override these interfaces to launch a generated kernel or call an
40 // external library function (such as operations in cuBLAS).
41 //
42 // This is thread-compatible.
43 class Thunk {
44  public:
45   enum Kind {
46     kCholesky,
47     kCollectivePermute,
48     kConditional,
49     kConvolution,
50     kCopy,
51     kCudnnBatchNormBackward,
52     kCudnnBatchNormForwardInference,
53     kCudnnBatchNormForwardTraining,
54     kCustomCall,
55     kFft,
56     kGemm,
57     kInfeed,
58     kKernel,
59     kMemset32BitValue,
60     kMemzero,
61     kNcclAllGather,
62     kNcclAllReduce,
63     kNcclAllReduceStart,
64     kNcclAllReduceDone,
65     kNcclReduceScatter,
66     kNcclAllToAll,
67     kOutfeed,
68     kReplicaId,
69     kPartitionId,
70     kSequential,
71     kTriangularSolve,
72     kWhile,
73   };
74 
75   struct ThunkInfo {
76     absl::optional<int64> profile_index;
77     std::string profile_annotation;
78   };
79 
80   // The hlo_instruction argument is meant to be the instruction this thunk was
81   // generated from, but Thunk never uses this argument other than to save it
82   // to Thunk::hlo_instruction, so it can be null.
Thunk(Kind kind,ThunkInfo thunk_info)83   explicit Thunk(Kind kind, ThunkInfo thunk_info)
84       : kind_(kind),
85         profile_index_(thunk_info.profile_index),
86         profile_annotation_(thunk_info.profile_annotation) {}
~Thunk()87   virtual ~Thunk() {}
88   Thunk(const Thunk&) = delete;
89   Thunk& operator=(const Thunk&) = delete;
90 
ToStringExtra(int indent)91   virtual std::string ToStringExtra(int indent) const { return ""; }
kind()92   Kind kind() const { return kind_; }
profile_annotation()93   std::string profile_annotation() const { return profile_annotation_; }
94 
95   // Prepares the thunk for execution on the given StreamExecutor.
96   //
97   // This may be called multiple times.  Its main purpose is to give us a chance
98   // to do initialization outside of ExecuteOnStream() so that the
99   // time spent initializing doesn't count towards our execution profile.
Initialize(const GpuExecutable &,se::StreamExecutor *)100   virtual Status Initialize(const GpuExecutable& /*executable*/,
101                             se::StreamExecutor* /*executor*/) {
102     return Status::OK();
103   }
104 
105   // Parameters passed to ExecuteOnStream.  Encapsulated in a struct so that
106   // when we add something we don't have to change every subclass of Thunk.
107   struct ExecuteParams {
108     const BufferAllocations* buffer_allocations;  // never null
109     se::Stream* stream;
110     se::Stream* async_comms_stream;
111     RunId run_id;
112     const DeviceAssignment* device_assn;                          // never null
113     std::vector<std::function<void()>>* deferred_host_callbacks;  // never null
114     const std::vector<GlobalDeviceId>* gpu_global_device_ids;     // may be null
115     const NcclUniqueIdCallback* nccl_unique_id_callback;          // may be null
116 
117     StatusOr<GlobalDeviceId> GetGlobalDeviceId() const;
118   };
119 
120   // Execute the kernel for the thunk on the given stream. This method must be
121   // called after Initialize and can be called multiple times over Thunk's
122   // lifetime.
123   //
124   // Precondition: Initialize(stream->parent()) has been called.
125   virtual Status ExecuteOnStream(const ExecuteParams& params) = 0;
126 
127   static absl::string_view KindToString(Thunk::Kind kind);
128 
129  protected:
profile_index()130   absl::optional<int64> profile_index() const { return profile_index_; }
131 
132   // Safely copies the given buffer to the GPU, deleting it on the host only
133   // after the copy has completed.
134   template <typename T>
SafeH2DMemcpy(se::DeviceMemory<T> dest,std::unique_ptr<T[]> buf,int64_t count,se::Stream * stream,std::vector<std::function<void ()>> * deferred_host_callbacks)135   void SafeH2DMemcpy(
136       se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64_t count,
137       se::Stream* stream,
138       std::vector<std::function<void()>>* deferred_host_callbacks) {
139     stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
140     auto* buf_raw = buf.release();
141     deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; });
142   }
143 
144  private:
145   Kind kind_;
146   absl::optional<int64> profile_index_;
147   std::string profile_annotation_;
148 };
149 
150 // A sequence of thunks.
151 class ThunkSequence : public std::vector<std::unique_ptr<Thunk>> {
152  public:
153   std::string ToString(int indent = 0,
154                        std::function<std::string(const Thunk*)>
155                            get_thunk_annotation = nullptr) const;
156 };
157 
// Stream insertion for Thunk::Kind. Declared here and defined out of line;
// presumably writes the name given by Thunk::KindToString -- TODO confirm
// against the definition.
158 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
159 
160 // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
161 // shape.
162 struct ShapedSlice {
  // The buffer slice being described.
163   BufferAllocation::Slice slice;
  // The shape associated with `slice`.
164   Shape shape;
165 };
166 
167 }  // namespace gpu
168 }  // namespace xla
169 
170 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
171