/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The compiler API is used by the XLA service to generate executables that
// run on a given platform. This is a registry and abstract interface, for
// pluggability by the various platforms.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

// The following types are used for ahead-of-time compilation.

// Contains the object file data created as a result of ahead-of-time
// compilation.
using ObjectFileData = std::vector<char>;

// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
 public:
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(AotCompilationResult const&) = delete;

  virtual ~AotCompilationResult() = default;

 protected:
  AotCompilationResult() = default;
};

// Abstract superclass describing options to an ahead-of-time compilation.
class AotCompilationOptions {
 public:
  AotCompilationOptions(const AotCompilationOptions&) = delete;
  AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;

  virtual ~AotCompilationOptions() = default;

  // Returns the ID of the platform to which these options apply.
  virtual se::Platform::Id PlatformId() const = 0;

  virtual int64 replica_count() const { return 0; }
  virtual int64 num_cores() const { return 0; }
  virtual bool use_spmd_partitioning() const { return false; }
  virtual bool deduplicate_hlo() const { return false; }
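
  // For illustration only: a platform-specific subclass would pin PlatformId()
  // to its platform and override the knobs above as needed. The class and
  // member names below are hypothetical, not part of XLA:
  //
  //   class HypotheticalGpuAotCompilationOptions
  //       : public AotCompilationOptions {
  //    public:
  //     se::Platform::Id PlatformId() const override {
  //       return se::cuda::kCudaPlatformId;
  //     }
  //     int64 replica_count() const override { return replica_count_; }
  //
  //    private:
  //     int64 replica_count_ = 1;
  //   };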

  // Optional allocator that may be used for allocating temp space on the
  // device during compilation.
  se::DeviceMemoryAllocator* device_allocator() const {
    return device_allocator_;
  }
  void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
    device_allocator_ = device_allocator;
  }

  const DebugOptions& debug_options() const { return debug_options_; }
  DebugOptions* mutable_debug_options() { return &debug_options_; }

  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(
      const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

  FusionConfigCollection fusion_config_collection() const {
    return fusion_config_collection_;
  }
  void set_fusion_config_collection(
      FusionConfigCollection fusion_config_collection) {
    fusion_config_collection_ = fusion_config_collection;
  }

  const std::vector<std::vector<bool>>& fusion_config() const {
    return fusion_config_;
  }
  void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
    fusion_config_ = fusion_config;
  }

 protected:
  AotCompilationOptions();

 private:
  se::DeviceMemoryAllocator* device_allocator_ = nullptr;
  DebugOptions debug_options_;
  absl::optional<DeviceAssignment> static_device_assignment_;
  std::vector<std::vector<bool>> fusion_config_;
  FusionConfigCollection fusion_config_collection_ =
      FusionConfigCollection::kOff;
};
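
// A minimal sketch of configuring AOT options, reusing the hypothetical
// subclass from the example above; the device ordinals are made up:
//
//   HypotheticalGpuAotCompilationOptions options;
//   DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
//   assignment(0, 0) = 0;  // Replica 0 runs on device 0.
//   assignment(1, 0) = 1;  // Replica 1 runs on device 1.
//   options.set_static_device_assignment(assignment);
//   options.set_fusion_config_collection(FusionConfigCollection::kPerEdge);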

// Abstract superclass describing metadata produced during ahead-of-time
// compilation.
class AotCompilationMetadata {
 public:
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete;

  virtual std::string ToString() const { return ""; }

  virtual ~AotCompilationMetadata() = default;

 protected:
  AotCompilationMetadata() = default;
};

// Abstract compiler interface that is subclassed for compilation on a
// particular platform.
//
// The compiler ties together high-level optimization (HLO) and low-level
// optimization (LLO) / codegen (CG) to generate efficient executables for the
// target platform.
//
// The platform-based compiler singletons are created via module initializers
// in their corresponding XLA compiler libraries, and are registered via the
// RegisterCompilerFactory API below.
//
// Thread-safety: subclasses of Compiler must be thread-safe, as multiple
// XLA clients may be requesting compilation concurrently for a given
// platform.
class Compiler {
 public:
  struct CompileOptions {
    // If device_allocator is not null, the compiler may use it to allocate
    // temp space on the device for use during compilation. For example, the
    // compiler may allocate buffers on the device and then run variants of a
    // given algorithm over those buffers, to see which variant is fastest. Any
    // space allocated will be deallocated before the compilation returns.
    se::DeviceMemoryAllocator* device_allocator = nullptr;

    // An optional thread pool for parallel compilation.
    tensorflow::thread::ThreadPool* thread_pool = nullptr;
  };

  virtual ~Compiler() {}

  // Returns the ID of the platform that this compiler targets.
  virtual se::Platform::Id PlatformId() const = 0;

  // Runs HLO passes to optimize the given HLO module and returns the
  // optimized module.
  virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      const CompileOptions& options) = 0;
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      se::DeviceMemoryAllocator* device_allocator) {
    return RunHloPasses(std::move(module), executor,
                        CompileOptions{device_allocator});
  }

  // Runs HLO passes to optimize the given HloModule, performs scheduling and
  // buffer assignment, and returns the optimized module together with the
  // buffer assignments. This interface is intentionally narrow.
  virtual StatusOr<
      std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
  RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
                                   se::StreamExecutor* executor, bool optimize,
                                   const CompileOptions& options) {
    return Unimplemented("This compiler does not support this method");
  }

  // Compiles the HLO module for execution on a device given by the executor,
  // and returns an executable object or an error status. No HLO passes are
  // applied to the module. Generally a module should be passed through
  // RunHloPasses prior to calling this method, because some HLO passes are
  // required for correctness. Takes ownership of the HLO module.
  //
  // The compiler may optionally specialize to the individual device
  // (not just the type of device) indicated by the executor.
  virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      const CompileOptions& options) = 0;
  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      se::DeviceMemoryAllocator* device_allocator) {
    return RunBackend(std::move(module), executor,
                      CompileOptions{device_allocator});
  }
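
  // A minimal sketch of the usual two-phase flow (error handling via
  // TF_ASSIGN_OR_RETURN; `compiler`, `executor`, and `module` are assumed to
  // be provided by the caller):
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> optimized,
  //                       compiler->RunHloPasses(std::move(module), executor,
  //                                              Compiler::CompileOptions{}));
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
  //                       compiler->RunBackend(std::move(optimized), executor,
  //                                            Compiler::CompileOptions{}));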

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules, and returns a corresponding
  // sequence of executable objects.
  //
  // TODO(b/68666782): Remove this method after adding support for multiple
  // modules to RunHloPasses and RunBackend.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      const CompileOptions& options) = 0;
  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      se::DeviceMemoryAllocator* device_allocator) {
    return Compile(std::move(module_group), stream_exec,
                   CompileOptions{device_allocator});
  }

  // Returns the backend configurations that the backend will consider for the
  // given HLO. Returns no configurations if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
  ComputeBackendConfigs(const HloInstruction& hlo,
                        se::StreamExecutor* executor) const;

  // Returns the backend configuration that the backend chooses by default for
  // the given HLO. Returns no configuration if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configuration would be targeting.
  virtual std::unique_ptr<tensorflow::protobuf::Message>
  ComputeDefaultBackendConfig(const HloInstruction& hlo,
                              se::StreamExecutor* executor) const;

  // Compiles the HLO module group for ahead-of-time execution. This is
  // intended for use in static compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) = 0;

  // Similar to CompileAheadOfTime above, but also takes an
  // AotCompilationMetadata argument that can be populated during compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options,
                     std::unique_ptr<AotCompilationMetadata>* metadata);

  /////
  // The Compiler class also serves as a point to register compiler objects
  // for the various platforms.

  using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;

  // Registers the compiler singleton for the platform. This is assumed to
  // be a singleton, so no ownership is transferred.
  //
  // Precondition: a platform kind must not be registered more than once.
  static void RegisterCompilerFactory(se::Platform::Id platform_id,
                                      CompilerFactory compiler_factory);

  // Returns the compiler singleton pointer if it is available for the given
  // platform, or an error status if it is not.
  static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);
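
  // Registration normally happens in a module initializer in the platform's
  // compiler library; a minimal sketch, with a hypothetical compiler class:
  //
  //   static bool InitModule() {
  //     xla::Compiler::RegisterCompilerFactory(
  //         se::host::kHostPlatformId,
  //         [] { return absl::make_unique<xla::HypotheticalCpuCompiler>(); });
  //     return true;
  //   }
  //   static bool module_initialized = InitModule();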

  // Returns a function that computes the size in bytes of the logical
  // buffer that contains a shape.
  virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;

  // Returns a function that computes the size in bytes of a given
  // logical buffer.
  std::function<int64(const BufferValue&)> BufferSizeBytesFunction() {
    HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
    return [shape_size](const BufferValue& buffer) {
      return shape_size(buffer.shape());
    };
  }

  // Returns the backend-specific representation that the given shape has on
  // the device. The default is the identity: host and device shapes match.
  virtual Shape DeviceShapeRepresentation(const Shape& shape) const {
    return shape;
  }

 private:
  // Mutex that guards the platform-compiler map.
  static tensorflow::mutex platform_compiler_mutex_;

  // Map from platform kind to compiler factory.
  static std::map<se::Platform::Id, CompilerFactory>*
  GetPlatformCompilerFactories();

  // Map from platform kind to compiler instance, if we made one already (based
  // on the factories above).
  static std::map<se::Platform::Id, std::unique_ptr<Compiler>>*
  GetPlatformCompilers();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_