/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_

#include <iosfwd>  // std::ostream, used only by reference in operator<< below.
#include <memory>
#include <vector>

#include "absl/strings/string_view.h"  // return type of ThunkKindToString.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

class GpuExecutable;

// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
//
// Thunk provides the Initialize and ExecuteOnStream interface for
// GpuExecutable to initialize and execute the invocation respectively. Its
// subclasses are supposed to override these interfaces to launch a generated
// kernel or call an external library function (such as operations in cuBLAS).
//
// This is thread-compatible.
42 class Thunk { 43 public: 44 enum Kind { 45 kCholesky, 46 kConditional, 47 kConvolution, 48 kCopy, 49 kCudnnBatchNormBackward, 50 kCudnnBatchNormForwardInference, 51 kCudnnBatchNormForwardTraining, 52 kNcclAllReduce, 53 kFft, 54 kGemm, 55 kInfeed, 56 kKernel, 57 kMemset32BitValue, 58 kMemzero, 59 kOutfeed, 60 kSequential, 61 kTriangularSolve, 62 kTuple, 63 kWhile, 64 }; 65 66 // The hlo_instruction argument is meant to be the instruction this thunk was 67 // generated from, but Thunk never uses this argument other than to save it 68 // to Thunk::hlo_instruction, so it can be null. Thunk(Kind kind,const HloInstruction * hlo_instruction)69 explicit Thunk(Kind kind, const HloInstruction* hlo_instruction) 70 : kind_(kind), hlo_instruction_(hlo_instruction) {} ~Thunk()71 virtual ~Thunk() {} 72 Thunk(const Thunk&) = delete; 73 Thunk& operator=(const Thunk&) = delete; 74 kind()75 Kind kind() const { return kind_; } hlo_instruction()76 const HloInstruction* hlo_instruction() const { return hlo_instruction_; } 77 78 // Prepares the thunk for execution on the given StreamExecutor. 79 // 80 // This may be called multiple times. Its main purpose is to give us a chance 81 // to do initialization outside of ExecuteOnStream() so that the 82 // time spent initializing doesn't count towards our execution profile. Initialize(const GpuExecutable &,se::StreamExecutor *)83 virtual Status Initialize(const GpuExecutable& /*executable*/, 84 se::StreamExecutor* /*executor*/) { 85 return Status::OK(); 86 } 87 88 // Returns true if this kernel will autotune for the stream device the next 89 // time it is run. WillAutotuneKernel(se::Stream *)90 virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; } 91 92 // Execute the kernel for the thunk on the given stream. This method must be 93 // called after Initialize and can be called multiple times over Thunk's 94 // lifetime. 'stream' and 'profiler' must be non-null. 
95 // 96 // Precondition: Initialize(stream->parent()) has been called. 97 virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, 98 se::Stream* stream, 99 HloExecutionProfiler* profiler) = 0; 100 101 private: 102 Kind kind_; 103 const HloInstruction* hlo_instruction_; 104 }; 105 106 // A sequence of thunks. 107 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>; 108 109 absl::string_view ThunkKindToString(Thunk::Kind); 110 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); 111 112 } // namespace gpu 113 } // namespace xla 114 115 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ 116