/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The compiler API is used by the XLA service to generate executables that
// run on a given platform. This is a registry and abstract interface, for
// pluggability by the various platforms.
19 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

// The following types are used for ahead-of-time (AOT) compilation.

// Contains the object file data created as a result of ahead-of-time
// computation: the raw bytes of the compiled artifact.
using ObjectFileData = std::vector<char>;

// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
 public:
  // A result uniquely owns its compiled artifacts, so copying is disallowed.
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(const AotCompilationResult&) = delete;

  virtual ~AotCompilationResult() = default;

 protected:
  // Only concrete platform-specific subclasses may be constructed.
  AotCompilationResult() = default;
};
66 
67 // Abstract superclass describing options to an ahead-of-time compilation.
68 class AotCompilationOptions {
69  public:
70   AotCompilationOptions(const AotCompilationOptions&) = delete;
71   AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;
72 
73   virtual ~AotCompilationOptions() = default;
74 
75   // Returns the ID of the platform to which these options apply.
76   virtual se::Platform::Id PlatformId() const = 0;
77 
replica_count()78   virtual int64 replica_count() const { return 0; }
num_cores()79   virtual int64 num_cores() const { return 0; }
use_spmd_partitioning()80   virtual bool use_spmd_partitioning() const { return false; }
deduplicate_hlo()81   virtual bool deduplicate_hlo() const { return false; }
82 
83   // Optional allocator that may be used for allocating temp space on the device
84   // during compilation.
device_allocator()85   se::DeviceMemoryAllocator* device_allocator() const {
86     return device_allocator_;
87   }
set_device_allocator(se::DeviceMemoryAllocator * device_allocator)88   void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
89     device_allocator_ = device_allocator;
90   }
91 
debug_options()92   const DebugOptions& debug_options() const { return debug_options_; }
mutable_debug_options()93   DebugOptions* mutable_debug_options() { return &debug_options_; }
94 
has_static_device_assignment()95   bool has_static_device_assignment() const {
96     return static_device_assignment_.has_value();
97   }
static_device_assignment()98   const DeviceAssignment& static_device_assignment() const {
99     CHECK(static_device_assignment_.has_value());
100     return *static_device_assignment_;
101   }
set_static_device_assignment(const DeviceAssignment & device_assignment)102   void set_static_device_assignment(const DeviceAssignment& device_assignment) {
103     static_device_assignment_ = device_assignment;
104   }
105 
fusion_config_collection()106   FusionConfigCollection fusion_config_collection() const {
107     return fusion_config_collection_;
108   }
set_fusion_config_collection(FusionConfigCollection fusion_config_collection)109   void set_fusion_config_collection(
110       FusionConfigCollection fusion_config_collection) {
111     fusion_config_collection_ = fusion_config_collection;
112   }
113 
fusion_config()114   const std::vector<std::vector<bool>>& fusion_config() const {
115     return fusion_config_;
116   }
set_fusion_config(const std::vector<std::vector<bool>> & fusion_config)117   void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
118     fusion_config_ = fusion_config;
119   }
120 
121  protected:
122   AotCompilationOptions();
123 
124  private:
125   se::DeviceMemoryAllocator* device_allocator_ = nullptr;
126   DebugOptions debug_options_;
127   absl::optional<DeviceAssignment> static_device_assignment_;
128   std::vector<std::vector<bool>> fusion_config_;
129   FusionConfigCollection fusion_config_collection_ =
130       FusionConfigCollection::kOff;
131 };
132 
// Abstract superclass describing metadata produced during ahead-of-time
// compilation.
class AotCompilationMetadata {
 public:
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(const AotCompilationMetadata&) = delete;

  // Human-readable rendering of the metadata; empty unless overridden.
  virtual std::string ToString() const { return ""; }

  virtual ~AotCompilationMetadata() = default;

 protected:
  // Only concrete subclasses may be constructed.
  AotCompilationMetadata() = default;
};
145 
146 // Abstract compiler interface that is subclassed for compilation on a
147 // particular platform.
148 //
149 // The compiler ties together high level optimization (HLO) and low level
150 // optimization (LLO) / codegen (CG) to generate efficient executables for the
151 // target platform.
152 //
153 // The platform-based compiler singletons are registered via module initializers
154 // in their corresponding XLA compiler libraries, and are registered via the
155 // RegisterCompilerFactory API below.
156 //
157 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple
158 // XLA clients may be requesting compilation concurrently for a given
159 // platform.
160 class Compiler {
161  public:
162   struct CompileOptions {
163     // If device_allocator is not null, the compiler may use it to allocate temp
164     // space on the device for use during compilation.  For example, the
165     // compiler may allocate buffers on the device and then run variants of a
166     // given algorithm over those buffers, to see which variant is fastest.  Any
167     // space allocated will be deallocated before the compilation returns.
168     se::DeviceMemoryAllocator* device_allocator = nullptr;
169 
170     // An optional thread pool for parallel compilation.
171     tensorflow::thread::ThreadPool* thread_pool = nullptr;
172   };
173 
~Compiler()174   virtual ~Compiler() {}
175 
176   // Returns the ID of the platform that this compiler targets.
177   virtual se::Platform::Id PlatformId() const = 0;
178 
179   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
180   // module.
181   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
182       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
183       const CompileOptions& options) = 0;
RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)184   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
185       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
186       se::DeviceMemoryAllocator* device_allocator) {
187     return RunHloPasses(std::move(module), executor,
188                         CompileOptions{device_allocator});
189   }
190 
191   // Runs HLO passes to optimize the given HloModule, perform scheduling and
192   // buffer assignment, returns the optimized module and the buffer assignments.
193   // This interface is intentionally narrow.
194   virtual StatusOr<
195       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,bool optimize,const CompileOptions & options)196   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
197                                    se::StreamExecutor* executor, bool optimize,
198                                    const CompileOptions& options) {
199     return Unimplemented("This compiler does not support this method");
200   }
201 
202   // Compiles the HLO module for execution on a device given by the executor,
203   // and returns an executable object or an error status. No HLO passes are
204   // applied to module. Generally a module should be passed through RunHloPasses
205   // prior to calling this method because some HLO passes are required for
206   // correctness. Takes ownership of the HLO module.
207   //
208   // The compiler may optionally specialize to the individual device
209   // (not just type of device) indicated by the executor.
210   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
211       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
212       const CompileOptions& options) = 0;
RunBackend(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)213   StatusOr<std::unique_ptr<Executable>> RunBackend(
214       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
215       se::DeviceMemoryAllocator* device_allocator) {
216     return RunBackend(std::move(module), executor,
217                       CompileOptions{device_allocator});
218   }
219 
220   // Compiles a set of HLO modules that can run in parallel, potentially
221   // communicating data between the modules, and returns a corresponding
222   // sequence of executable objects.
223   //
224   // TODO(b/68666782): Remove this method after adding support for multiple
225   // modules to RunHloPasses and RunBackends.
226   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
227       std::unique_ptr<HloModuleGroup> module_group,
228       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
229       const CompileOptions& options) = 0;
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_exec,se::DeviceMemoryAllocator * device_allocator)230   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
231       std::unique_ptr<HloModuleGroup> module_group,
232       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
233       se::DeviceMemoryAllocator* device_allocator) {
234     return Compile(std::move(module_group), stream_exec,
235                    CompileOptions{device_allocator});
236   }
237 
238   // Returns the backend configurations that the backend will consider for the
239   // given HLO. Returns no configurations if the backend does not support
240   // configurations for the given HLO.
241   //
242   // The stream executor is passed in to provide information about the hardware
243   // that the backend configurations would be targeting.
244   virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
245   ComputeBackendConfigs(const HloInstruction& hlo,
246                         se::StreamExecutor* executor) const;
247 
248   // Returns the backend configuration that the backend chooses by default for
249   // the given HLO. Returns no configuration if the backend does not support
250   // configurations for the given HLO.
251   //
252   // The stream executor is passed in to provide information about the hardware
253   // that the backend configurations would be targeting.
254   virtual std::unique_ptr<tensorflow::protobuf::Message>
255   ComputeDefaultBackendConfig(const HloInstruction& hlo,
256                               se::StreamExecutor* executor) const;
257 
258   // Compiles the HLO module group for ahead-of-time execution.  This is
259   // intended for use in static compilation.
260   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
261   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
262                      const AotCompilationOptions& options) = 0;
263 
264   // Similar to CompileAheadOfTime above but AotCompilationMetadata
265   // has an argument that can be populated during compilation.
266   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
267   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
268                      const AotCompilationOptions& options,
269                      std::unique_ptr<AotCompilationMetadata>* metadata);
270 
271   /////
272   // The Compiler class also serves as a point to register compiler objects
273   // for the various platforms.
274 
275   using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
276 
277   // Registers the compiler singleton for the platform. This is assumed to
278   // be a singleton, so no ownership is transferred.
279   //
280   // Precondition: a platform kind must not be registered more than once.
281   static void RegisterCompilerFactory(se::Platform::Id platform_id,
282                                       CompilerFactory compiler_factory);
283 
284   // Returns the compiler singleton pointer if it is available for the given
285   // platform, or an error status if it is not.
286   static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);
287 
288   // Returns a function that computes the size in bytes of the logical
289   // buffer that contains a shape.
290   virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;
291 
292   // Returns a function that computes the size in bytes of a given
293   // logical buffer.
BufferSizeBytesFunction()294   std::function<int64(const BufferValue&)> BufferSizeBytesFunction() {
295     HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
296     return [shape_size](const BufferValue& buffer) {
297       return shape_size(buffer.shape());
298     };
299   }
300 
DeviceShapeRepresentation(const Shape & shape)301   virtual Shape DeviceShapeRepresentation(const Shape& shape) const {
302     return shape;
303   }
304 
305  private:
306   // Mutex that guards the platform-compiler map.
307   static tensorflow::mutex platform_compiler_mutex_;
308 
309   // Map from platform kind to compiler factory.
310   static std::map<se::Platform::Id, CompilerFactory>*
311   GetPlatformCompilerFactories();
312 
313   // Map from platform kind to compiler instance, if we made one already (based
314   // on the factories above).
315   static std::map<se::Platform::Id, std::unique_ptr<Compiler>>*
316   GetPlatformCompilers();
317 };
318 
319 }  // namespace xla
320 
321 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
322