/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_

#include <memory>
#include <ostream>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

class GpuExecutable;

34 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
35 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
36 //
37 // Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable
38 // to initialize and execute the invocation respectively. Its subclasses are
39 // supposed to override these interfaces to launch a generated kernel or call an
40 // external library function (such as operations in cuBLAS).
41 //
42 // This is thread-compatible.
43 class Thunk {
44  public:
45   enum Kind {
46     kCholesky,
47     kCollectivePermute,
48     kConditional,
49     kConvolution,
50     kCopy,
51     kCudnnBatchNormBackward,
52     kCudnnBatchNormForwardInference,
53     kCudnnBatchNormForwardTraining,
54     kCustomCall,
55     kFft,
56     kGemm,
57     kInfeed,
58     kKernel,
59     kMemset32BitValue,
60     kMemzero,
61     kNcclAllReduce,
62     kOutfeed,
63     kReplicaId,
64     kSequential,
65     kTriangularSolve,
66     kTuple,
67     kWhile,
68   };
69 
70   // The hlo_instruction argument is meant to be the instruction this thunk was
71   // generated from, but Thunk never uses this argument other than to save it
72   // to Thunk::hlo_instruction, so it can be null.
Thunk(Kind kind,const HloInstruction * hlo_instruction)73   explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
74       : kind_(kind), hlo_instruction_(hlo_instruction) {}
~Thunk()75   virtual ~Thunk() {}
76   Thunk(const Thunk&) = delete;
77   Thunk& operator=(const Thunk&) = delete;
78 
kind()79   Kind kind() const { return kind_; }
hlo_instruction()80   const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
81 
82   // Prepares the thunk for execution on the given StreamExecutor.
83   //
84   // This may be called multiple times.  Its main purpose is to give us a chance
85   // to do initialization outside of ExecuteOnStream() so that the
86   // time spent initializing doesn't count towards our execution profile.
Initialize(const GpuExecutable &,se::StreamExecutor *)87   virtual Status Initialize(const GpuExecutable& /*executable*/,
88                             se::StreamExecutor* /*executor*/) {
89     return Status::OK();
90   }
91 
92   // Parameters passed to ExecuteOnStream.  Encapsulated in a struct so that
93   // when we add something we don't have to change every subclass of Thunk.
94   struct ExecuteParams {
95     const BufferAllocations* buffer_allocations;  // never null
96     se::Stream* stream;
97     RunId run_id;
98     HloExecutionProfiler* profiler;       // never null
99     const DeviceAssignment* device_assn;  // never null
100   };
101 
102   // Execute the kernel for the thunk on the given stream. This method must be
103   // called after Initialize and can be called multiple times over Thunk's
104   // lifetime.
105   //
106   // Precondition: Initialize(stream->parent()) has been called.
107   virtual Status ExecuteOnStream(const ExecuteParams& params) = 0;
108 
109  protected:
GetModuleConfig()110   const HloModuleConfig& GetModuleConfig() const {
111     return hlo_instruction()->GetModule()->config();
112   }
113 
114   // Safely copies the given buffer to the GPU, deleting it on the host only
115   // after the copy has completed.
116   template <typename T>
SafeH2DMemcpy(se::DeviceMemory<T> dest,std::unique_ptr<T[]> buf,int64 count,se::Stream * stream)117   void SafeH2DMemcpy(se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf,
118                      int64 count, se::Stream* stream) {
119     stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
120     auto* buf_raw = buf.release();
121     stream->ThenRunAfterNextBlockHostUntilDone([buf_raw] { delete[] buf_raw; });
122   }
123 
124  private:
125   Kind kind_;
126   const HloInstruction* hlo_instruction_;
127 };
128 
129 // A sequence of thunks.
130 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
131 
132 absl::string_view ThunkKindToString(Thunk::Kind);
133 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_