/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_

#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace cpu {

// CPU-targeting implementation of the XLA Executable interface.
//
// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
// architecture, so JIT-ed code and host code share the same ABI.
class CpuExecutable : public Executable {
 public:
  CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
                std::unique_ptr<const BufferAssignment> assignment,
                std::unique_ptr<HloModule> hlo_module,
                const string& entry_function_name,
                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
  ~CpuExecutable() override;

  StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ExecutionInput> arguments,
      HloExecutionProfile* hlo_execution_profile) override;

  // Calls the generated function performing the computation with the given
  // arguments using the supplied buffers.
  Status ExecuteComputeFunction(
      const ExecutableRunOptions* run_options,
      absl::Span<MaybeOwningDeviceMemory const> buffers,
      HloExecutionProfile* hlo_execution_profile);

  // This should be called after set_ir_module_string.
  const string& ir_module_string() const { return ir_module_string_; }

  void set_ir_module_string(const string& ir_module_string) {
    ir_module_string_ = ir_module_string;
  }

  static int64 ShapeSizeBytes(const Shape& shape);
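  // Illustrative example (not from the original source): for a dense array
  // shape such as f32[2,3], this is the element count times the element size,
  // i.e. 2 * 3 * 4 = 24 bytes.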

  // Type of the computation function we expect in the JIT.
  using ComputeFunctionType =
      void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
               const void** /*args*/, void** /*buffer_table*/,
               int64* /*profile_counters*/);
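  //
  // Illustrative sketch (not guaranteed by this header): the JIT-ed entry
  // point is invoked roughly as
  //
  //   compute_function_(result, run_options, /*args=*/nullptr,
  //                     buffer_table, profile_counters);
  //
  // where buffer_table[i] holds the address of the allocation with index i in
  // the BufferAssignment, and profile_counters may be null when profiling is
  // disabled.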

  const ComputeFunctionType& compute_function() const {
    return compute_function_;
  }

  const BufferAssignment& buffer_assignment() const { return *assignment_; }

  int64 SizeOfGeneratedCodeInBytes() const override;

 private:
  // Creates an array suitable for passing as the "buffer_table" argument to
  // the JIT compiled function pointer.
  //
  // Returns a vector of MaybeOwningDeviceMemory with one entry per buffer
  // allocation. The table includes the scratch storage required by the
  // computation, the live-out buffer into which the result will be written,
  // and the entry computation parameters.
  //
  // Entries for buffers allocated by this routine (temporary storage and the
  // live-out buffer into which the computation writes its result), and for
  // buffers whose ownership was donated by the caller, own their memory;
  // entries referring to argument buffers retained by the caller do not.
  StatusOr<std::vector<MaybeOwningDeviceMemory>> CreateBufferTable(
      se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
      absl::Span<ExecutionInput const> arguments);
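
  // Illustrative sketch (a typical calling pattern, not part of this
  // interface): the table is flattened to raw pointers before the JIT-ed
  // function is invoked:
  //
  //   std::vector<void*> buffer_pointers;
  //   for (auto& buffer : buffers) {
  //     buffer_pointers.push_back(
  //         const_cast<void*>(buffer.AsDeviceMemoryBase().opaque()));
  //   }
  //
  // buffer_pointers.data() is then passed as the buffer_table argument.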

  // Creates an ExecutionOutput holding a ScopedShapedBuffer for the result of
  // the computation, moving buffers out of `buffers` and into the result as
  // appropriate. The addresses are set according to buffer assignment.
  StatusOr<ExecutionOutput> CreateResultShapedBuffer(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<MaybeOwningDeviceMemory> buffers,
      absl::Span<ExecutionInput> arguments);

  // Returns the instruction value set of the root instruction of the entry
  // computation. Uses dataflow analysis from buffer assignment.
  const InstructionValueSet& GetRootValueSet() const;

  // The JIT containing compiled modules.
  const std::unique_ptr<SimpleOrcJIT> jit_;

  // Buffer assignment for the buffers we need to allocate.
  const std::unique_ptr<const BufferAssignment> assignment_;

  // Proto form of the buffer assignment above.
  std::shared_ptr<const BufferAssignmentProto> buffer_assignment_;

  // The LLVM IR, in string format, of the unoptimized module generated for
  // this CpuExecutable. We save a string instead of an llvm::Module* because
  // leaving llvm::Module* in a singleton can cause the heap checker to emit
  // false positives.
  string ir_module_string_;

  // Unique identifier.
  string module_name_;

  // Pointer to the JIT-compiled entry function; see ComputeFunctionType.
  ComputeFunctionType compute_function_;

  // Entry function name for the computation.
  const string entry_function_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
};
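
// Illustrative usage sketch (assumptions: an executable already produced by
// the CPU compiler; run options and inputs elided with "..."):
//
//   std::unique_ptr<Executable> executable = /* ... from CpuCompiler ... */;
//   ServiceExecutableRunOptions run_options = /* ... */;
//   std::vector<ExecutionInput> arguments = /* ... */;
//   StatusOr<ExecutionOutput> result = executable->ExecuteAsyncOnStream(
//       &run_options, std::move(arguments),
//       /*hlo_execution_profile=*/nullptr);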

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_