/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_

#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace cpu {

// CPU-targeting implementation of the XLA Executable interface.
//
// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
// architecture, so JIT-ed code and host code share the same ABI.
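//
// Illustrative usage sketch (hypothetical: the construction of the run
// options and argument buffers is assumed, and error handling is elided):
//
//   StatusOr<ScopedShapedBuffer> result = cpu_executable->ExecuteOnStream(
//       &service_run_options, argument_buffers,
//       /*hlo_execution_profile=*/nullptr);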
class CpuExecutable : public Executable {
 public:
  CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
                std::unique_ptr<const BufferAssignment> assignment,
                std::unique_ptr<HloModule> hlo_module,
                const string& entry_function_name,
                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
  ~CpuExecutable() override {}

  StatusOr<ScopedShapedBuffer> ExecuteOnStream(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<const ShapedBuffer* const> arguments,
      HloExecutionProfile* hlo_execution_profile) override;

  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<const ShapedBuffer* const> arguments) override;

  // This should be called after set_ir_module_string.
  const string& ir_module_string() const { return ir_module_string_; }

  void set_ir_module_string(const string& ir_module_string) {
    ir_module_string_ = ir_module_string;
  }

  static int64 ShapeSizeBytes(const Shape& shape);

  // Type of the computation function we expect in the JIT.
  using ComputeFunctionType =
      void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
               const void** /*args*/, void** /*buffer_table*/,
               int64* /*profile_counters*/);
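  //
  // A minimal sketch of how such a function pointer might be invoked; the
  // variable names below are illustrative only, and the real call site is
  // ExecuteComputeFunction:
  //
  //   ComputeFunctionType fn = executable.compute_function();
  //   fn(/*result=*/nullptr, &executable_run_options, /*args=*/nullptr,
  //      buffer_table.data(), profile_counters.data());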

  const ComputeFunctionType& compute_function() const {
    return compute_function_;
  }

  const BufferAssignment& buffer_assignment() const { return *assignment_; }

 private:
  // This is for sharing the code between ExecuteOnStream and
  // ExecuteAsyncOnStream.
  //
  // Notice that it's tricky to use correctly, as the profile object (when it
  // exists) must outlive the task.
  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<const ShapedBuffer* const> arguments,
      HloExecutionProfile* hlo_execution_profile);

  // Creates an array suitable for passing as the "buffer_table" argument to
  // the JIT-compiled function pointer.
  //
  // Returns (unowning_buffers, owning_buffers) where:
  //
  //  - unowning_buffers.data() can be passed as the buffer_table argument
  //    as-is and includes pointers to the scratch storage required by the
  //    computation, the live-out buffer into which the result will be
  //    written, and the entry computation parameters.
  //
  //  - owning_buffers contains owning pointers to the buffers that were
  //    allocated by this routine.  This routine allocates buffers for
  //    temporary storage and the live-out buffer into which the computation
  //    writes its result.
  StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
                     std::vector<OwningDeviceMemory>>>
  CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
                    absl::Span<const ShapedBuffer* const> arguments);
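  //
  // A hypothetical sketch of consuming the returned pair (the names here are
  // illustrative, not part of this API):
  //
  //   TF_ASSIGN_OR_RETURN(auto buffers,
  //                       CreateBufferTable(allocator, device_ordinal, args));
  //   std::vector<se::DeviceMemoryBase>& unowning_buffers = buffers.first;
  //   std::vector<OwningDeviceMemory>& owning_buffers = buffers.second;
  //   // unowning_buffers.data() becomes the buffer_table argument to the
  //   // JIT entry point; owning_buffers keeps the allocations alive.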

  // Calls the generated function performing the computation with the given
  // arguments using the supplied buffers.
  Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
                                absl::Span<const se::DeviceMemoryBase> buffers,
                                HloExecutionProfile* hlo_execution_profile);

  // Creates a ScopedShapedBuffer for holding the result of the computation,
  // moving buffers out of allocated_buffers and into the result as
  // appropriate.  The addresses are set according to buffer assignment.
  StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<OwningDeviceMemory> buffers);

  // Returns the points-to set of the root instruction of the entry
  // computation.  Uses points-to analysis from buffer assignment.
  const PointsToSet& GetRootPointsToSet() const;

  // The JIT containing compiled modules.
  const std::unique_ptr<SimpleOrcJIT> jit_;

  // Buffer assignment for the buffers we need to allocate.
  const std::unique_ptr<const BufferAssignment> assignment_;

  // The LLVM IR, in string format, of the unoptimized module generated for
  // this CpuExecutable.  We save a string instead of an llvm::Module* because
  // leaving llvm::Module* in a singleton can cause the heap checker to emit
  // false positives.
  string ir_module_string_;

  ComputeFunctionType compute_function_;

  // Entry function name for the computation.
  const string entry_function_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_