/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The compiler API is used by the XLA service to generate executables that
// run on a given platform. This is a registry and abstract interface, for
// pluggability by the various platforms.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {

// The following types are used for ahead-of-time compilation.

// Contains the object file data created as a result of ahead-of-time
// compilation.
using ObjectFileData = std::vector<char>;

// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
 public:
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(AotCompilationResult const&) = delete;

  virtual ~AotCompilationResult() = default;

 protected:
  AotCompilationResult() = default;
};

// Abstract superclass describing options to an ahead-of-time compilation.
class AotCompilationOptions {
 public:
  AotCompilationOptions(const AotCompilationOptions&) = delete;
  AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;

  virtual ~AotCompilationOptions() = default;

  // Returns the ID of the platform to which these options apply.
  virtual se::Platform::Id PlatformId() const = 0;

  // Optional allocator that may be used for allocating temp space on the device
  // during compilation.
  DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
  void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
    device_allocator_ = device_allocator;
  }

  const DebugOptions& debug_options() const { return debug_options_; }
  DebugOptions* mutable_debug_options() { return &debug_options_; }

  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

 protected:
  AotCompilationOptions();

 private:
  DeviceMemoryAllocator* device_allocator_ = nullptr;
  DebugOptions debug_options_;
  absl::optional<DeviceAssignment> static_device_assignment_;
};

// Abstract superclass describing metadata produced during ahead-of-time
// compilation.
class AotCompilationMetadata {
 public:
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete;

  virtual ~AotCompilationMetadata() = default;

 protected:
  AotCompilationMetadata() = default;
};

// Abstract compiler interface that is subclassed for compilation on a
// particular platform.
//
// The compiler ties together high level optimization (HLO) and low level
// optimization (LLO) / codegen (CG) to generate efficient executables for the
// target platform.
//
// The platform-based compiler singletons are registered via module
// initializers in their corresponding XLA compiler libraries, using the
// RegisterCompilerFactory API below.
//
// Thread-safety: subclasses of Compiler must be thread-safe, as multiple
// XLA clients may be requesting compilation concurrently for a given
// platform.
class Compiler {
 public:
  virtual ~Compiler() {}

  // Returns the ID of the platform that this compiler targets.
  virtual se::Platform::Id PlatformId() const = 0;

  // Runs HLO passes to optimize the given HLO module; returns the optimized
  // module.
  //
  // If device_allocator is not null, the compiler may use it to allocate temp
  // space on the device for use during compilation.  For example, the compiler
  // may allocate buffers on the device and then run variants of a given
  // algorithm over those buffers, to see which variant is fastest.  Any space
  // allocated should be deallocated before this function returns.
  virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Optimizes an HLO module group, a set of modules which run concurrently on
  // multiple devices, potentially communicating data between the modules.
  virtual Status RunHloPassesOnModuleGroup(
      HloModuleGroup* module_group,
      absl::Span<se::StreamExecutor* const> executors,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles the HLO module for execution on a device given by the executor,
  // and returns an executable object or an error status. No HLO passes are
  // applied to the module. Generally a module should be passed through
  // RunHloPasses prior to calling this method because some HLO passes are
  // required for correctness. Takes ownership of the HLO module.
  //
  // The compiler may optionally specialize to the individual device
  // (not just type of device) indicated by the executor.
  //
  // device_allocator is optional; see RunHloPasses.
  virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator) = 0;
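
  // Example (illustrative sketch only, not part of this interface): typical
  // single-module use is to run the HLO passes and then the backend. The
  // TF_ASSIGN_OR_RETURN macro and the surrounding helper function are assumed
  // here purely for illustration.
  //
  //   StatusOr<std::unique_ptr<Executable>> CompileOneModule(
  //       Compiler* compiler, std::unique_ptr<HloModule> module,
  //       se::StreamExecutor* executor) {
  //     // First run the optimization pipeline over the module.
  //     TF_ASSIGN_OR_RETURN(
  //         module, compiler->RunHloPasses(std::move(module), executor,
  //                                        /*device_allocator=*/nullptr));
  //     // Then lower the optimized module to an executable.
  //     return compiler->RunBackend(std::move(module), executor,
  //                                 /*device_allocator=*/nullptr);
  //   }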

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>>
  RunBackendOnModuleGroup(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules, and returns a corresponding
  // sequence of executable objects.
  //
  // device_allocator is optional; see RunHloPasses.
  //
  // TODO(b/68666782): Remove this method after adding support for multiple
  // modules to RunHloPasses and RunBackend.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      DeviceMemoryAllocator* device_allocator) = 0;

  // Returns the backend configurations that the backend will consider for the
  // given HLO. Returns no configurations if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
  ComputeBackendConfigs(const HloInstruction& hlo,
                        se::StreamExecutor* executor) const;

  // Returns the backend configuration that the backend chooses by default for
  // the given HLO. Returns no configuration if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::unique_ptr<tensorflow::protobuf::Message>
  ComputeDefaultBackendConfig(const HloInstruction& hlo,
                              se::StreamExecutor* executor) const;

  // Compiles the HLO module group for ahead-of-time execution.  This is
  // intended for use in static compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) = 0;

  // Similar to CompileAheadOfTime above, but also takes an
  // AotCompilationMetadata output argument that may be populated during
  // compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options,
                     std::unique_ptr<AotCompilationMetadata>* metadata);
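
  // Example (illustrative sketch only): ahead-of-time compilation is driven
  // with a backend-specific subclass of AotCompilationOptions. MyAotOptions
  // and the module group construction below are hypothetical placeholders.
  //
  //   auto module_group =
  //       absl::make_unique<HloModuleGroup>(std::move(module));
  //   MyAotOptions aot_options;  // hypothetical AotCompilationOptions subclass
  //   TF_ASSIGN_OR_RETURN(
  //       auto aot_results,
  //       compiler->CompileAheadOfTime(std::move(module_group), aot_options));
  //   // Each AotCompilationResult can then be serialized by the backend.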

  /////
  // The Compiler class also serves as a point to register compiler objects
  // for the various platforms.

  using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;

  // Registers a compiler factory for the given platform. The compiler created
  // by the factory is treated as a singleton, so no ownership is transferred.
  //
  // Precondition: a platform kind must not be registered more than once.
  static void RegisterCompilerFactory(se::Platform::Id platform_id,
                                      CompilerFactory compiler_factory);

  // Returns the compiler singleton pointer if it is available for the given
  // platform, or an error status if it is not.
  static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);
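
  // Example (illustrative sketch only): a backend typically registers its
  // compiler from a module initializer, and clients later look it up by
  // platform. MyCompiler and kMyPlatformId are hypothetical placeholders.
  //
  //   static bool RegisterMyCompiler() {
  //     Compiler::RegisterCompilerFactory(
  //         kMyPlatformId,
  //         []() { return absl::make_unique<MyCompiler>(); });
  //     return true;
  //   }
  //   static bool my_compiler_registered = RegisterMyCompiler();
  //
  //   // Later, given a se::Platform* for that platform:
  //   TF_ASSIGN_OR_RETURN(Compiler* const compiler,
  //                       Compiler::GetForPlatform(platform));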

  // Returns a function that computes the size in bytes of the logical
  // buffer that contains a shape.
  virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;

  // Returns a function that computes the size in bytes of a given
  // logical buffer.
  std::function<int64(const BufferValue&)> BufferSizeBytesFunction() {
    HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
    return [shape_size](const BufferValue& buffer) {
      return shape_size(buffer.shape());
    };
  }
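
  // Example (illustrative sketch only): the returned callback can be applied
  // to any BufferValue, e.g. when configuring buffer assignment. The buffer
  // variable below is a placeholder.
  //
  //   auto buffer_size = compiler->BufferSizeBytesFunction();
  //   int64 bytes = buffer_size(buffer);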

 private:
  // Mutex that guards the platform-compiler map.
  static tensorflow::mutex platform_compiler_mutex_;

  // Map from platform kind to compiler factory.
  static std::map<se::Platform::Id, CompilerFactory>*
  GetPlatformCompilerFactories();

  // Map from platform kind to compiler instance, if we made one already (based
  // on the factories above).
  static std::map<se::Platform::Id, std::unique_ptr<Compiler>>*
  GetPlatformCompilers();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_