/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

class GpuExecutable;

// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
//
// Thunk provides the Initialize and ExecuteOnStream interface for
// GpuExecutable to initialize and execute the invocation respectively. Its
// subclasses are supposed to override these interfaces to launch a generated
// kernel or call an external library function (such as operations in cuBLAS).
//
// This is thread-compatible.
43 class Thunk { 44 public: 45 enum Kind { 46 kCholesky, 47 kCollectivePermute, 48 kConditional, 49 kConvolution, 50 kCopy, 51 kCudnnBatchNormBackward, 52 kCudnnBatchNormForwardInference, 53 kCudnnBatchNormForwardTraining, 54 kCustomCall, 55 kFft, 56 kGemm, 57 kInfeed, 58 kKernel, 59 kMemset32BitValue, 60 kMemzero, 61 kNcclAllGather, 62 kNcclAllReduce, 63 kNcclAllReduceStart, 64 kNcclAllReduceDone, 65 kNcclReduceScatter, 66 kNcclAllToAll, 67 kOutfeed, 68 kReplicaId, 69 kPartitionId, 70 kSequential, 71 kTriangularSolve, 72 kWhile, 73 }; 74 75 struct ThunkInfo { 76 absl::optional<int64> profile_index; 77 std::string profile_annotation; 78 }; 79 80 // The hlo_instruction argument is meant to be the instruction this thunk was 81 // generated from, but Thunk never uses this argument other than to save it 82 // to Thunk::hlo_instruction, so it can be null. Thunk(Kind kind,ThunkInfo thunk_info)83 explicit Thunk(Kind kind, ThunkInfo thunk_info) 84 : kind_(kind), 85 profile_index_(thunk_info.profile_index), 86 profile_annotation_(thunk_info.profile_annotation) {} ~Thunk()87 virtual ~Thunk() {} 88 Thunk(const Thunk&) = delete; 89 Thunk& operator=(const Thunk&) = delete; 90 ToStringExtra(int indent)91 virtual std::string ToStringExtra(int indent) const { return ""; } kind()92 Kind kind() const { return kind_; } profile_annotation()93 std::string profile_annotation() const { return profile_annotation_; } 94 95 // Prepares the thunk for execution on the given StreamExecutor. 96 // 97 // This may be called multiple times. Its main purpose is to give us a chance 98 // to do initialization outside of ExecuteOnStream() so that the 99 // time spent initializing doesn't count towards our execution profile. Initialize(const GpuExecutable &,se::StreamExecutor *)100 virtual Status Initialize(const GpuExecutable& /*executable*/, 101 se::StreamExecutor* /*executor*/) { 102 return Status::OK(); 103 } 104 105 // Parameters passed to ExecuteOnStream. 
Encapsulated in a struct so that 106 // when we add something we don't have to change every subclass of Thunk. 107 struct ExecuteParams { 108 const BufferAllocations* buffer_allocations; // never null 109 se::Stream* stream; 110 se::Stream* async_comms_stream; 111 RunId run_id; 112 const DeviceAssignment* device_assn; // never null 113 std::vector<std::function<void()>>* deferred_host_callbacks; // never null 114 const std::vector<GlobalDeviceId>* gpu_global_device_ids; // may be null 115 const NcclUniqueIdCallback* nccl_unique_id_callback; // may be null 116 117 StatusOr<GlobalDeviceId> GetGlobalDeviceId() const; 118 }; 119 120 // Execute the kernel for the thunk on the given stream. This method must be 121 // called after Initialize and can be called multiple times over Thunk's 122 // lifetime. 123 // 124 // Precondition: Initialize(stream->parent()) has been called. 125 virtual Status ExecuteOnStream(const ExecuteParams& params) = 0; 126 127 static absl::string_view KindToString(Thunk::Kind kind); 128 129 protected: profile_index()130 absl::optional<int64> profile_index() const { return profile_index_; } 131 132 // Safely copies the given buffer to the GPU, deleting it on the host only 133 // after the copy has completed. 134 template <typename T> SafeH2DMemcpy(se::DeviceMemory<T> dest,std::unique_ptr<T[]> buf,int64_t count,se::Stream * stream,std::vector<std::function<void ()>> * deferred_host_callbacks)135 void SafeH2DMemcpy( 136 se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64_t count, 137 se::Stream* stream, 138 std::vector<std::function<void()>>* deferred_host_callbacks) { 139 stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T)); 140 auto* buf_raw = buf.release(); 141 deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; }); 142 } 143 144 private: 145 Kind kind_; 146 absl::optional<int64> profile_index_; 147 std::string profile_annotation_; 148 }; 149 150 // A sequence of thunks. 
151 class ThunkSequence : public std::vector<std::unique_ptr<Thunk>> { 152 public: 153 std::string ToString(int indent = 0, 154 std::function<std::string(const Thunk*)> 155 get_thunk_annotation = nullptr) const; 156 }; 157 158 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); 159 160 // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its 161 // shape. 162 struct ShapedSlice { 163 BufferAllocation::Slice slice; 164 Shape shape; 165 }; 166 167 } // namespace gpu 168 } // namespace xla 169 170 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ 171