1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ 18 19 #include <memory> 20 #include <utility> 21 #include <vector> 22 23 #include "absl/types/span.h" 24 #include "absl/types/variant.h" 25 #include "tensorflow/compiler/xla/debug_options_flags.h" 26 #include "tensorflow/compiler/xla/service/computation_layout.h" 27 #include "tensorflow/compiler/xla/service/device_memory_allocator.h" 28 #include "tensorflow/compiler/xla/service/hlo.pb.h" 29 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" 30 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" 33 #include "tensorflow/compiler/xla/service/owning_device_memory.h" 34 #include "tensorflow/compiler/xla/service/service_executable_run_options.h" 35 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 36 #include "tensorflow/compiler/xla/shape_tree.h" 37 #include "tensorflow/compiler/xla/statusor.h" 38 #include "tensorflow/compiler/xla/util.h" 39 #include "tensorflow/compiler/xla/xla_data.pb.h" 40 #include "tensorflow/core/platform/mutex.h" 41 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 42 
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {

// ExecutionOutput encapsulates the output buffers of an execution and the
// leftover buffers to be released by the caller.
struct ExecutionOutput {
  ExecutionOutput(ScopedShapedBuffer result,
                  std::vector<OwningDeviceMemory> to_be_released)
      : result(std::move(result)), to_be_released(std::move(to_be_released)) {}

  // The result buffers of the computation.
  ScopedShapedBuffer result;

  // Leftover buffers for the caller to release. Elements in this list are
  // donated input memory buffers that are not reused by XLA as outputs.
  std::vector<OwningDeviceMemory> to_be_released;
};

// A given platform's compiler will produce an Executable -- this is a uniform
// interface that is used for launching compiled programs across platforms.
class Executable {
 public:
  explicit Executable(
      std::unique_ptr<HloModule> hlo_module,
      std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
      std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
      : hlo_module_(std::move(hlo_module)),
        hlo_profile_printer_data_(std::move(hlo_profile_printer_data)),
        hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
    // The profile printer data and the profile index map must either both be
    // present or both be absent; having only one of the two is a bug in the
    // caller.
    CHECK_EQ(hlo_profile_printer_data_.get() == nullptr,
             hlo_profile_index_map_.get() == nullptr);
  }
  virtual ~Executable() = default;

  // Enqueues the compilation result on the provided stream, passing the given
  // arguments. This call is blocking and returns after the execution is done.
  //
  // If the hlo_execution_profile is provided as non-nullptr, profiling will be
  // enabled.
  //
  // Returns a shaped buffer containing the result of the computation.
  virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<const ShapedBuffer* const> arguments,
      HloExecutionProfile* hlo_execution_profile) = 0;

  // Same as ExecuteOnStream(), but this call is non-blocking and returns as
  // soon as all of the operations are enqueued for launch on the stream.
  virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<const ShapedBuffer* const> arguments) = 0;

  // Starts the given program executing on the given stream/executor.
  //
  // `arguments` are ShapeTrees containing the input parameters. For each
  // element in the shape tree, if the element holds the ownership of the
  // memory, it is considered donated and XLA will potentially reuse it as
  // output buffers. For all donated inputs, XLA is also responsible for
  // freeing them.
  //
  // If an input is donated to XLA but is not reused as output, it is returned
  // as a leftover buffer for the caller to release.
  virtual StatusOr<ExecutionOutput> ExecuteOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
      HloExecutionProfile* hlo_execution_profile) {
    return Unimplemented(
        "MaybeOwningDeviceMemory version of overload is not implemented ");
  }

  // Same as the donated-buffer ExecuteOnStream() overload above, but
  // non-blocking: returns as soon as all operations are enqueued.
  virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments) {
    return Unimplemented(
        "MaybeOwningDeviceMemory version of overload is not implemented ");
  }

  // Same as ExecuteOnStream(), but runs this executable on multiple
  // streams. arguments[i] contains the arguments to the execution on
  // run_options[i]->stream() and the returned value is at index i of the
  // returned vector.
  virtual StatusOr<std::vector<ScopedShapedBuffer>> ExecuteOnStreams(
      absl::Span<const ServiceExecutableRunOptions> run_options,
      absl::Span<const absl::Span<const ShapedBuffer* const>> arguments);

  // Populates `hlo_execution_profile` from `stream`. This is implicit in any
  // Execute* API call that takes a hlo_execution_profile argument, but must be
  // called explicitly for other (async, for example) variants after the stream
  // has completed.
  virtual Status PopulateExecutionProfile(
      HloExecutionProfile* hlo_execution_profile, se::Stream* stream) {
    return Status::OK();
  }

  // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a
  // timer for the execution, sets up HLO profiling if enabled, and fills in
  // the given ExecutionProfile if non-null.
  StatusOr<ScopedShapedBuffer> ExecuteOnStreamWrapper(
      const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
      absl::Span<const ShapedBuffer* const> arguments);

  // Returns the ExecutionProfile from executing on the device. This includes
  // the number of cycles taken for the computation or the compilation time.
  // Returns a copy under the lock so the caller never observes a profile that
  // is being concurrently updated.
  ExecutionProfile execution_profile() const {
    tensorflow::mutex_lock lock(mutex_);
    return execution_profile_;
  }

  const HloProfilePrinterData& hlo_profile_printer_data() const {
    CHECK(hlo_profiling_enabled());
    return *hlo_profile_printer_data_;
  }

  const HloProfileIndexMap& hlo_profile_index_map() const {
    CHECK(hlo_profiling_enabled());
    return *hlo_profile_index_map_;
  }

  // Returns whether this executable was compiled with HLO profiling support
  // enabled. If not, the caller should not expect an hlo_execution_profile
  // passed to ExecuteOnStream above to be populated during execution.
  bool hlo_profiling_enabled() const {
    return hlo_profile_printer_data_ != nullptr;
  }

  HloModule& module() const { return *hlo_module_; }

  bool has_module() const { return hlo_module_ != nullptr; }

  const HloModuleConfig& module_config() const { return hlo_module_->config(); }

  // The shape (including layout) that results from this execution. This is the
  // shape of the DeviceMemoryBase result value in ExecuteOnStream above.
  const Shape& result_shape() const {
    return hlo_module_->config().entry_computation_layout().result_shape();
  }

  // Returns the size of the executable in bytes. Returns -1 by default if the
  // method is not overridden to support this kind of query.
  virtual int64 SizeInBytes();

  // Dumping helpers.
  void set_hlo_snapshot(std::unique_ptr<xla::HloSnapshot> hlo_snapshot) {
    hlo_snapshot_ = std::move(hlo_snapshot);
  }
  bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; }
  HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); }

 protected:
  // Guards execution_profile_; mutable so const accessors can lock it.
  mutable tensorflow::mutex mutex_;

  // Execution profile data on the device.
  ExecutionProfile execution_profile_ GUARDED_BY(mutex_);

  // HloModule this was compiled from. BufferAssignment keeps pointers to
  // HloInstructions owned by the HloModule so we need to keep the HloModule
  // around.
  const std::unique_ptr<HloModule> hlo_module_;

  // HloSnapshot this was compiled from. Null if not dumping executions.
  std::unique_ptr<HloSnapshot> hlo_snapshot_;

  // Execution count, used to generate a unique filename for each dumped
  // execution.
  int64 execution_count_ = 0;

  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_