/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The compiler API is used by the XLA service to generate executables that
// run on a given platform. This is a registry and abstract interface, for
// pluggability by the various platforms.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {
46 47 // The following types are used for ahead of time compilation. 48 49 // Contains the object file data created as a result of ahead-of-time 50 // compuation. 51 using ObjectFileData = std::vector<char>; 52 53 // Abstract superclass describing the result of an ahead-of-time compilation. 54 class AotCompilationResult { 55 public: 56 AotCompilationResult(const AotCompilationResult&) = delete; 57 AotCompilationResult& operator=(AotCompilationResult const&) = delete; 58 59 virtual ~AotCompilationResult() = default; 60 61 protected: 62 AotCompilationResult() = default; 63 }; 64 65 // Abstract superclass describing options to an ahead-of-time compilation. 66 class AotCompilationOptions { 67 public: 68 AotCompilationOptions(const AotCompilationOptions&) = delete; 69 AotCompilationOptions& operator=(AotCompilationOptions const&) = delete; 70 71 virtual ~AotCompilationOptions() = default; 72 73 // Returns the ID of the platform to which these options apply. 74 virtual se::Platform::Id PlatformId() const = 0; 75 76 // Optional allocator that may be used for allocating temp space on the device 77 // during compilation. 
device_allocator()78 DeviceMemoryAllocator* device_allocator() const { return device_allocator_; } set_device_allocator(DeviceMemoryAllocator * device_allocator)79 void set_device_allocator(DeviceMemoryAllocator* device_allocator) { 80 device_allocator_ = device_allocator; 81 } 82 debug_options()83 const DebugOptions& debug_options() const { return debug_options_; } mutable_debug_options()84 DebugOptions* mutable_debug_options() { return &debug_options_; } 85 has_static_device_assignment()86 bool has_static_device_assignment() const { 87 return static_device_assignment_.has_value(); 88 } static_device_assignment()89 const DeviceAssignment& static_device_assignment() const { 90 CHECK(static_device_assignment_.has_value()); 91 return *static_device_assignment_; 92 } set_static_device_assignment(const DeviceAssignment & device_assignment)93 void set_static_device_assignment(const DeviceAssignment& device_assignment) { 94 static_device_assignment_ = device_assignment; 95 } 96 97 protected: 98 AotCompilationOptions(); 99 100 private: 101 DeviceMemoryAllocator* device_allocator_ = nullptr; 102 DebugOptions debug_options_; 103 absl::optional<DeviceAssignment> static_device_assignment_; 104 }; 105 106 // Abstract superclass describing metadata produced during ahead-of-time 107 // compilation. 108 class AotCompilationMetadata { 109 public: 110 AotCompilationMetadata(const AotCompilationMetadata&) = delete; 111 AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; 112 113 virtual ~AotCompilationMetadata() = default; 114 115 protected: 116 AotCompilationMetadata() = default; 117 }; 118 119 // Abstract compiler interface that is subclassed for compilation on a 120 // particular platform. 121 // 122 // The compiler ties together high level optimization (HLO) and low level 123 // optimization (LLO) / codegen (CG) to generate efficient executables for the 124 // target platform. 
125 // 126 // The platform-based compiler singletons are registered via module initializers 127 // in their corresponding XLA compiler libraries, and are registered via the 128 // RegisterCompilerFactory API below. 129 // 130 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple 131 // XLA clients may be requesting compilation concurrently for a given 132 // platform. 133 class Compiler { 134 public: ~Compiler()135 virtual ~Compiler() {} 136 137 // Returns the ID of the platform that this compiler targets. 138 virtual se::Platform::Id PlatformId() const = 0; 139 140 // Runs Hlo passes to optimize the given Hlo module, returns the optimized 141 // module. 142 // 143 // If device_allocator is not null, the compiler may use it to allocate temp 144 // space on the device for use during compilation. For example, the compiler 145 // may allocate buffers on the device and then run variants of a given 146 // algorithm over those buffers, to see which variant is fastest. Any space 147 // allocated should be deallocated before this function returns. 148 virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 149 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 150 DeviceMemoryAllocator* device_allocator) = 0; 151 152 // Optimizes a HLO module group, a set of module which runs concurrently on 153 // multiple devices potentially communicating data between the modules. 154 virtual Status RunHloPassesOnModuleGroup( 155 HloModuleGroup* module_group, 156 absl::Span<se::StreamExecutor* const> executors, 157 DeviceMemoryAllocator* device_allocator) = 0; 158 159 // Compiles the HLO module for execution on a device given by the executor, 160 // and returns an executable object or an error status. No HLO passes are 161 // applied to module. Generally a module should be passed through RunHloPasses 162 // prior to calling this method because some HLO passes are required for 163 // correctness. Takes ownership of the HLO module. 
164 // 165 // The compiler may optionally specialize to the individual device 166 // (not just type of device) indicated by the executor. 167 // 168 // device_allocator is optional; see RunHloPasses. 169 virtual StatusOr<std::unique_ptr<Executable>> RunBackend( 170 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 171 DeviceMemoryAllocator* device_allocator) = 0; 172 173 // Compiles a set of HLO modules that can run in parallel, potentially 174 // communicating data between the modules. 175 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> 176 RunBackendOnModuleGroup( 177 std::unique_ptr<HloModuleGroup> module_group, 178 std::vector<std::vector<se::StreamExecutor*>> stream_exec, 179 DeviceMemoryAllocator* device_allocator) = 0; 180 181 // Compiles a set of HLO modules that can run in parallel, potentially 182 // communicating data between the modules, and returns a corresponding 183 // sequence of executable objects. 184 // 185 // device_allocator is optional; see RunHloPasses. 186 // 187 // TODO(b/68666782): Remove this method after adding support for multiple 188 // modules to RunHloPasses and RunBackends. 189 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 190 std::unique_ptr<HloModuleGroup> module_group, 191 std::vector<std::vector<se::StreamExecutor*>> stream_exec, 192 DeviceMemoryAllocator* device_allocator) = 0; 193 194 // Returns the backend configurations that the backend will consider for the 195 // given HLO. Returns no configurations if the backend does not support 196 // configurations for the given HLO. 197 // 198 // The stream executor is passed in to provide information about the hardware 199 // that the backend configurations would be targeting. 
200 virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>> 201 ComputeBackendConfigs(const HloInstruction& hlo, 202 se::StreamExecutor* executor) const; 203 204 // Returns the backend configuration that the backend chooses by default for 205 // the given HLO. Returns no configuration if the backend does not support 206 // configurations for the given HLO. 207 // 208 // The stream executor is passed in to provide information about the hardware 209 // that the backend configurations would be targeting. 210 virtual std::unique_ptr<tensorflow::protobuf::Message> 211 ComputeDefaultBackendConfig(const HloInstruction& hlo, 212 se::StreamExecutor* executor) const; 213 214 // Compiles the HLO module group for ahead-of-time execution. This is 215 // intended for use in static compilation. 216 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 217 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, 218 const AotCompilationOptions& options) = 0; 219 220 // Similar to CompileAheadOfTime above but AotCompilationMetadata 221 // has an argument that can be populated during compilation. 222 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 223 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, 224 const AotCompilationOptions& options, 225 std::unique_ptr<AotCompilationMetadata>* metadata); 226 227 ///// 228 // The Compiler class also serves as a point to register compiler objects 229 // for the various platforms. 230 231 using CompilerFactory = std::function<std::unique_ptr<Compiler>()>; 232 233 // Registers the compiler singleton for the platform. This is assumed to 234 // be a singleton, so no ownership is transferred. 235 // 236 // Precondition: a platform kind must not be registered more than once. 
237 static void RegisterCompilerFactory(se::Platform::Id platform_id, 238 CompilerFactory compiler_factory); 239 240 // Returns the compiler singleton pointer if it is available for the given 241 // platform, or an error status if it is not. 242 static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform); 243 244 // Returns a function that computes the size in bytes of the logical 245 // buffer that contains a shape. 246 virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0; 247 248 // Returns a function that computes the size in bytes of a given 249 // logical buffer. BufferSizeBytesFunction()250 std::function<int64(const BufferValue&)> BufferSizeBytesFunction() { 251 HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); 252 return [shape_size](const BufferValue& buffer) { 253 return shape_size(buffer.shape()); 254 }; 255 } 256 257 private: 258 // Mutex that guards the platform-compiler map. 259 static tensorflow::mutex platform_compiler_mutex_; 260 261 // Map from platform kind to compiler factory. 262 static std::map<se::Platform::Id, CompilerFactory>* 263 GetPlatformCompilerFactories(); 264 265 // Map from platform kind to compiler instance, if we made one already (based 266 // on the factories above). 267 static std::map<se::Platform::Id, std::unique_ptr<Compiler>>* 268 GetPlatformCompilers(); 269 }; 270 271 } // namespace xla 272 273 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 274