• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
29 
30 namespace xla {
31 namespace gpu {
32 
33 class GpuExecutable;
34 
35 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
36 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
37 //
38 // Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable
39 // to initialize and execute the invocation respectively. Its subclasses are
40 // supposed to override these interfaces to launch a generated kernel or call an
41 // external library function (such as operations in cuBLAS).
42 //
43 // This is thread-compatible.
44 class Thunk {
45  public:
46   enum Kind {
47     kCholesky,
48     kCollectivePermute,
49     kConditional,
50     kConvolution,
51     kCopy,
52     kCudnnBatchNormBackward,
53     kCudnnBatchNormForwardInference,
54     kCudnnBatchNormForwardTraining,
55     kCustomCall,
56     kFft,
57     kGemm,
58     kInfeed,
59     kKernel,
60     kMemset32BitValue,
61     kMemzero,
62     kNcclAllGather,
63     kNcclAllReduce,
64     kNcclAllToAll,
65     kOutfeed,
66     kReplicaId,
67     kPartitionId,
68     kSequential,
69     kTriangularSolve,
70     kTuple,
71     kWhile,
72   };
73 
74   struct ThunkInfo {
75     absl::optional<int64> profile_index;
76     std::string profile_annotation;
77   };
78 
79   // The hlo_instruction argument is meant to be the instruction this thunk was
80   // generated from, but Thunk never uses this argument other than to save it
81   // to Thunk::hlo_instruction, so it can be null.
Thunk(Kind kind,ThunkInfo thunk_info)82   explicit Thunk(Kind kind, ThunkInfo thunk_info)
83       : kind_(kind),
84         profile_index_(thunk_info.profile_index),
85         profile_annotation_(thunk_info.profile_annotation) {}
~Thunk()86   virtual ~Thunk() {}
87   Thunk(const Thunk&) = delete;
88   Thunk& operator=(const Thunk&) = delete;
89 
kind()90   Kind kind() const { return kind_; }
profile_annotation()91   string profile_annotation() const { return profile_annotation_; }
92 
93   // Prepares the thunk for execution on the given StreamExecutor.
94   //
95   // This may be called multiple times.  Its main purpose is to give us a chance
96   // to do initialization outside of ExecuteOnStream() so that the
97   // time spent initializing doesn't count towards our execution profile.
Initialize(const GpuExecutable &,se::StreamExecutor *)98   virtual Status Initialize(const GpuExecutable& /*executable*/,
99                             se::StreamExecutor* /*executor*/) {
100     return Status::OK();
101   }
102 
103   // Parameters passed to ExecuteOnStream.  Encapsulated in a struct so that
104   // when we add something we don't have to change every subclass of Thunk.
105   struct ExecuteParams {
106     const BufferAllocations* buffer_allocations;  // never null
107     se::Stream* stream;
108     RunId run_id;
109     HloExecutionProfiler* profiler;                               // never null
110     const DeviceAssignment* device_assn;                          // never null
111     std::vector<std::function<void()>>* deferred_host_callbacks;  // never null
112     const std::vector<GlobalDeviceId>* gpu_global_device_ids;     // may be null
113     const NcclUniqueIdCallback* nccl_unique_id_callback;          // may be null
114 
115     StatusOr<GlobalDeviceId> GetGlobalDeviceId() const;
116   };
117 
118   // Execute the kernel for the thunk on the given stream. This method must be
119   // called after Initialize and can be called multiple times over Thunk's
120   // lifetime.
121   //
122   // Precondition: Initialize(stream->parent()) has been called.
123   virtual Status ExecuteOnStream(const ExecuteParams& params) = 0;
124 
125  protected:
profile_index()126   absl::optional<int64> profile_index() const { return profile_index_; }
127 
128   // Safely copies the given buffer to the GPU, deleting it on the host only
129   // after the copy has completed.
130   template <typename T>
SafeH2DMemcpy(se::DeviceMemory<T> dest,std::unique_ptr<T[]> buf,int64 count,se::Stream * stream,std::vector<std::function<void ()>> * deferred_host_callbacks)131   void SafeH2DMemcpy(
132       se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64 count,
133       se::Stream* stream,
134       std::vector<std::function<void()>>* deferred_host_callbacks) {
135     stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
136     auto* buf_raw = buf.release();
137     deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; });
138   }
139 
140  private:
141   Kind kind_;
142   absl::optional<int64> profile_index_;
143   std::string profile_annotation_;
144 };
145 
146 // A sequence of thunks.
147 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
148 
149 absl::string_view ThunkKindToString(Thunk::Kind);
150 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
151 
152 // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
153 // shape.
154 struct ShapedSlice {
155   BufferAllocation::Slice slice;
156   Shape shape;
157 };
158 
159 }  // namespace gpu
160 }  // namespace xla
161 
162 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
163