1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ 18 19 #include <cstddef> 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 27 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" 28 #include "tensorflow/compiler/xla/service/device_memory_allocator.h" 29 #include "tensorflow/compiler/xla/service/executable.h" 30 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" 31 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 32 #include "tensorflow/compiler/xla/service/hlo_module.h" 33 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 34 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 35 #include "tensorflow/compiler/xla/statusor.h" 36 #include "tensorflow/compiler/xla/types.h" 37 #include "tensorflow/core/platform/macros.h" 38 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 39 #include "tensorflow/core/platform/types.h" 40 41 namespace xla { 42 namespace cpu { 43 44 // CPU-targeting implementation of the XLA Executable interface. 45 // 46 // Wraps a JIT-ed object that can be executed "on device". We JIT for the host 47 // architecture, so JIT-ed code and host code share the same ABI. 48 class CpuExecutable : public Executable { 49 public: 50 CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit, 51 std::unique_ptr<const BufferAssignment> assignment, 52 std::unique_ptr<HloModule> hlo_module, 53 const string& entry_function_name, 54 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, 55 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); ~CpuExecutable()56 ~CpuExecutable() override {} 57 58 StatusOr<ScopedShapedBuffer> ExecuteOnStream( 59 const ServiceExecutableRunOptions* run_options, 60 absl::Span<const ShapedBuffer* const> arguments, 61 HloExecutionProfile* hlo_execution_profile) override; 62 63 StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream( 64 const ServiceExecutableRunOptions* run_options, 65 absl::Span<const ShapedBuffer* const> arguments) override; 66 67 // This should be called after set_ir_module_string. ir_module_string()68 const string& ir_module_string() const { return ir_module_string_; } 69 set_ir_module_string(const string & ir_module_string)70 void set_ir_module_string(const string& ir_module_string) { 71 ir_module_string_ = ir_module_string; 72 } 73 74 static int64 ShapeSizeBytes(const Shape& shape); 75 76 // Type of the computation function we expect in the JIT. 77 using ComputeFunctionType = 78 void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/, 79 const void** /*args*/, void** /*buffer_table*/, 80 int64* /*profile_counters*/); 81 compute_function()82 const ComputeFunctionType& compute_function() const { 83 return compute_function_; 84 } 85 buffer_assignment()86 const BufferAssignment& buffer_assignment() const { return *assignment_; } 87 88 private: 89 // This is for sharing the code between ExecuteOnStream and 90 // ExecuteAsyncOnStream. 91 // 92 // Notice that it's tricky to use correctly, as the profile object (when it 93 // exists) must out-live the task. 94 StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl( 95 const ServiceExecutableRunOptions* run_options, 96 absl::Span<const ShapedBuffer* const> arguments, 97 HloExecutionProfile* hlo_execution_profile); 98 99 // Creates an array suitable for passing as the "buffer_table" argument to the 100 // JIT compiled function pointer. 101 // 102 // Returns (unowning_buffers, owning_buffers) where: 103 // 104 // - unowning_buffers.data() can be passed as the buffer_table argument as-is 105 // and includes pointers to the scratch storage required by the 106 // computation, the live-out buffer into which the result will be written 107 // and entry computation parameters. 108 // 109 // - owning_buffers contains owning pointers to the buffers that were 110 // allocated by this routine. This routine allocates buffers for temporary 111 // storage and the live-out buffer into which the computation writes it 112 // result. 113 StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, 114 std::vector<OwningDeviceMemory>>> 115 CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal, 116 absl::Span<const ShapedBuffer* const> arguments); 117 118 // Calls the generated function performing the computation with the given 119 // arguments using the supplied buffers. 120 Status ExecuteComputeFunction(const ExecutableRunOptions* run_options, 121 absl::Span<const se::DeviceMemoryBase> buffers, 122 HloExecutionProfile* hlo_execution_profile); 123 124 // Creates a ScopedShapedBuffer for holding the result of the computation, 125 // moving buffers out of allocated_buffers and into the result as appropriate. 126 // The addresses are set according to buffer assignment. 127 StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer( 128 const ServiceExecutableRunOptions* run_options, 129 absl::Span<OwningDeviceMemory> buffers); 130 131 // Returns the points-to set of the root instruction of the entry 132 // computation. Uses points-to analysis from buffer assignment. 133 const PointsToSet& GetRootPointsToSet() const; 134 135 // The JIT containing compiled modules. 136 const std::unique_ptr<SimpleOrcJIT> jit_; 137 138 // Buffer assignment for the buffers we need to allocate. 139 const std::unique_ptr<const BufferAssignment> assignment_; 140 141 // The LLVM IR, in string format, of the unoptimized module generated for this 142 // CpuExecutable. We save a string instead of an llvm::Module* because leaving 143 // llvm::Module* in a singleton can cause the heap checker to emit false 144 // positives. 145 string ir_module_string_; 146 147 ComputeFunctionType compute_function_; 148 149 // Entry function name for the computation. 150 const string entry_function_name_; 151 152 TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable); 153 }; 154 155 } // namespace cpu 156 } // namespace xla 157 158 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ 159