/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_

#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace cpu {

// CPU-targeting implementation of the XLA Executable interface.
//
// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
// architecture, so JIT-ed code and host code share the same ABI.
class CpuExecutable : public Executable {
 public:
  // Takes ownership of the JIT (which holds the compiled code), the buffer
  // assignment that drove code generation, and the HLO module, together with
  // optional HLO-profiling metadata. `entry_function_name` names the JIT-ed
  // symbol that implements the entry computation.
  CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
                std::unique_ptr<const BufferAssignment> assignment,
                std::unique_ptr<HloModule> hlo_module,
                const string& entry_function_name,
                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
  ~CpuExecutable() override;

  StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ExecutionInput> arguments,
      HloExecutionProfile* hlo_execution_profile) override;

  // Calls the generated function performing the computation with the given
  // arguments using the supplied buffers.
  Status ExecuteComputeFunction(
      const ExecutableRunOptions* run_options,
      absl::Span<MaybeOwningDeviceMemory const> buffers,
      HloExecutionProfile* hlo_execution_profile);

  // Returns the saved (unoptimized) LLVM IR text for this executable.
  // This should be called after set_ir_module_string.
  const string& ir_module_string() const { return ir_module_string_; }

  void set_ir_module_string(const string& ir_module_string) {
    ir_module_string_ = ir_module_string;
  }

  // Size (in bytes) XLA uses for `shape` when laying out buffers on the CPU
  // backend.
  static int64 ShapeSizeBytes(const Shape& shape);

  // Type of the computation function we expect in the JIT.
  using ComputeFunctionType =
      void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
               const void** /*args*/, void** /*buffer_table*/,
               int64* /*profile_counters*/);

  const ComputeFunctionType& compute_function() const {
    return compute_function_;
  }

  const BufferAssignment& buffer_assignment() const { return *assignment_; }

  int64 SizeOfGeneratedCodeInBytes() const override;

 private:
  // Creates an array suitable for passing as the "buffer_table" argument to
  // the JIT-compiled function pointer.
  //
  // Returns one MaybeOwningDeviceMemory per buffer: the device addresses can
  // be passed as the buffer_table argument as-is, and include pointers to the
  // scratch storage required by the computation, the live-out buffer into
  // which the result will be written, and entry computation parameters.
  // Entries for buffers allocated by this routine (temporaries and the
  // live-out buffer) are owning; entries referring to caller-provided
  // parameter buffers are unowning.
  // NOTE(review): an earlier revision documented separate unowning_buffers /
  // owning_buffers / buffers_to_free return values; the declaration now
  // returns a single vector — confirm ownership handling of donated argument
  // buffers against the .cc implementation.
  StatusOr<std::vector<MaybeOwningDeviceMemory>> CreateBufferTable(
      se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
      absl::Span<ExecutionInput const> arguments);

  // Creates an ExecutionOutput holding a ScopedShapedBuffer for the result of
  // the computation, moving buffers out of `buffers` (and, where ownership
  // was donated, out of `arguments`) into the result as appropriate. The
  // addresses are set according to buffer assignment.
  StatusOr<ExecutionOutput> CreateResultShapedBuffer(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<MaybeOwningDeviceMemory> buffers,
      absl::Span<ExecutionInput> arguments);

  // Returns the instruction value set of the root instruction of the entry
  // computation. Uses dataflow analysis from buffer assignment.
  const InstructionValueSet& GetRootValueSet() const;

  // The JIT containing compiled modules.
  const std::unique_ptr<SimpleOrcJIT> jit_;

  // Buffer assignment for the buffers we need to allocate.
  const std::unique_ptr<const BufferAssignment> assignment_;

  std::shared_ptr<const BufferAssignmentProto> buffer_assignment_;

  // The LLVM IR, in string format, of the unoptimized module generated for
  // this CpuExecutable. We save a string instead of an llvm::Module* because
  // leaving llvm::Module* in a singleton can cause the heap checker to emit
  // false positives.
  string ir_module_string_;

  // Unique identifier.
  string module_name_;

  // Entry point into the JIT-ed code; invoked by ExecuteComputeFunction.
  ComputeFunctionType compute_function_;

  // Entry function name for the computation.
  const string entry_function_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_