1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/executable_run_options.h" 23 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" 24 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 28 29 namespace xla { 30 namespace gpu { 31 32 class GpuExecutable; 33 34 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the 35 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction. 36 // 37 // Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable 38 // to initialize and execute the invocation respectively. Its subclasses are 39 // supposed to override these interfaces to launch a generated kernel or call an 40 // external library function (such as operations in cuBLAS). 41 // 42 // This is thread-compatible. 43 class Thunk { 44 public: 45 enum Kind { 46 kCholesky, 47 kCollectivePermute, 48 kConditional, 49 kConvolution, 50 kCopy, 51 kCudnnBatchNormBackward, 52 kCudnnBatchNormForwardInference, 53 kCudnnBatchNormForwardTraining, 54 kCustomCall, 55 kFft, 56 kGemm, 57 kInfeed, 58 kKernel, 59 kMemset32BitValue, 60 kMemzero, 61 kNcclAllReduce, 62 kOutfeed, 63 kReplicaId, 64 kSequential, 65 kTriangularSolve, 66 kTuple, 67 kWhile, 68 }; 69 70 // The hlo_instruction argument is meant to be the instruction this thunk was 71 // generated from, but Thunk never uses this argument other than to save it 72 // to Thunk::hlo_instruction, so it can be null. Thunk(Kind kind,const HloInstruction * hlo_instruction)73 explicit Thunk(Kind kind, const HloInstruction* hlo_instruction) 74 : kind_(kind), hlo_instruction_(hlo_instruction) {} ~Thunk()75 virtual ~Thunk() {} 76 Thunk(const Thunk&) = delete; 77 Thunk& operator=(const Thunk&) = delete; 78 kind()79 Kind kind() const { return kind_; } hlo_instruction()80 const HloInstruction* hlo_instruction() const { return hlo_instruction_; } 81 82 // Prepares the thunk for execution on the given StreamExecutor. 83 // 84 // This may be called multiple times. Its main purpose is to give us a chance 85 // to do initialization outside of ExecuteOnStream() so that the 86 // time spent initializing doesn't count towards our execution profile. Initialize(const GpuExecutable &,se::StreamExecutor *)87 virtual Status Initialize(const GpuExecutable& /*executable*/, 88 se::StreamExecutor* /*executor*/) { 89 return Status::OK(); 90 } 91 92 // Parameters passed to ExecuteOnStream. Encapsulated in a struct so that 93 // when we add something we don't have to change every subclass of Thunk. 94 struct ExecuteParams { 95 const BufferAllocations* buffer_allocations; // never null 96 se::Stream* stream; 97 RunId run_id; 98 HloExecutionProfiler* profiler; // never null 99 const DeviceAssignment* device_assn; // never null 100 }; 101 102 // Execute the kernel for the thunk on the given stream. This method must be 103 // called after Initialize and can be called multiple times over Thunk's 104 // lifetime. 105 // 106 // Precondition: Initialize(stream->parent()) has been called. 107 virtual Status ExecuteOnStream(const ExecuteParams& params) = 0; 108 109 protected: GetModuleConfig()110 const HloModuleConfig& GetModuleConfig() const { 111 return hlo_instruction()->GetModule()->config(); 112 } 113 114 // Safely copies the given buffer to the GPU, deleting it on the host only 115 // after the copy has completed. 116 template <typename T> SafeH2DMemcpy(se::DeviceMemory<T> dest,std::unique_ptr<T[]> buf,int64 count,se::Stream * stream)117 void SafeH2DMemcpy(se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, 118 int64 count, se::Stream* stream) { 119 stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T)); 120 auto* buf_raw = buf.release(); 121 stream->ThenRunAfterNextBlockHostUntilDone([buf_raw] { delete[] buf_raw; }); 122 } 123 124 private: 125 Kind kind_; 126 const HloInstruction* hlo_instruction_; 127 }; 128 129 // A sequence of thunks. 130 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>; 131 132 absl::string_view ThunkKindToString(Thunk::Kind); 133 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); 134 135 } // namespace gpu 136 } // namespace xla 137 138 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ 139