/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// A base class for running an HloModule. This executes the given HloModule on
// a certain backend directly, without going through the client interface. The
// HloModule can be built explicitly, loaded from a serialized file (e.g., an
// HLO proto file), or parsed from an HLO textual IR string.
class HloRunner : public HloRunnerInterface {
 public:
  // intra_op_parallelism_threads: For the CPU backend only. It is the thread
  // pool size for the parallel execution of an individual operator. The
  // default value of -1 initializes the thread pool with one thread per core
  // in the system.
  explicit HloRunner(se::Platform* platform,
                     int intra_op_parallelism_threads = -1);

  ~HloRunner() override;

  // Transfers data between the host and device.
  StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
  StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
      absl::Span<const Literal* const> literals);
  StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
      absl::Span<const Literal> literals);
  StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
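  //
  // A minimal round trip through device memory (a usage sketch, not part of
  // the API; assumes the default platform is linked in, that platform_util.h
  // and literal_util.h are included, and that error handling is elided via
  // ValueOrDie()):
  //
  //   HloRunner runner(PlatformUtil::GetDefaultPlatform().ValueOrDie());
  //   Literal input = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  //   ScopedShapedBuffer buffer =
  //       runner.TransferLiteralToDevice(input).ValueOrDie();
  //   Literal output = runner.TransferLiteralFromDevice(buffer).ValueOrDie();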

  // Executes the given module with the given literals as input and returns
  // the result as a Literal.
  //
  // If run_hlo_passes is false, the module will be executed without HLO
  // optimization.

  using HloRunnerInterface::Execute;

  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<const Literal* const> arguments,
                            bool run_hlo_passes,
                            ExecutionProfile* profile) override;

  using HloRunnerInterface::ExecuteWithExecutable;

  StatusOr<Literal> ExecuteWithExecutable(
      std::unique_ptr<Executable> executable,
      absl::Span<const Literal* const> arguments,
      ExecutionProfile* profile) override;

  // As Execute(), but accepts and returns device buffers instead of host
  // buffers.
  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
      std::unique_ptr<HloModule> module,
      absl::Span<ScopedShapedBuffer const> arguments,
      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);

  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
      Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
      ExecutionProfile* profile = nullptr);

  // Creates an executable object given an HLO module. If run_hlo_passes is
  // true, the HLO passes will be run as part of compilation.
  StatusOr<std::unique_ptr<Executable>> CreateExecutable(
      std::unique_ptr<HloModule> module, bool run_hlo_passes) override;

  // Executes a given HLO module on a set of replicas, and returns a vector of
  // literals, one per replica, indexed by replica number.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module,
      const ReplicatedExecuteOptions& options) override;

  // Same as above, but with a specified device assignment.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module,
      const ReplicatedExecuteOptions& options,
      DeviceAssignment* device_assignment) override;

  // Same as above, but with a reusable Executable. This may update the
  // profile information in *executable.
  //
  // Note that this call ignores ReplicatedExecuteOptions::run_hlo_passes,
  // since the Executable has already been compiled.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      Executable* executable, const ReplicatedExecuteOptions& options,
      DeviceAssignment* device_assignment,
      ExecutionProfile* profile = nullptr);

  // Same as above, but with different reusable Executables. This may update
  // the profile information in *executables.
  //
  // Note that this call ignores ReplicatedExecuteOptions::run_hlo_passes,
  // since the Executables have already been compiled.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::function<Executable*(int64)> executable_provider,
      std::function<int64(int64)> argument_count_provider,
      std::function<const Literal*(int64, int64)> argument_provider,
      const ReplicatedExecuteOptions& options);

  // If the backend was not created in the constructor, creates and returns
  // the default backend. If creation fails, crashes the program.
  //
  // This creates the backend lazily so it's possible to instantiate an
  // HloRunner in a program without any backends linked in.
  Backend& backend();
  const Backend& backend() const;
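
  // Typical end-to-end usage (a sketch, not part of the API; assumes an
  // HloRunner `runner` as in the example above, that `hlo_text` holds a
  // textual HLO module whose single parameter matches the shape of the
  // Literal `arg`, that ParseAndReturnUnverifiedModule from hlo_parser.h is
  // available, and that error handling is elided via ValueOrDie()):
  //
  //   auto module = ParseAndReturnUnverifiedModule(hlo_text).ValueOrDie();
  //   Literal result = runner.Execute(std::move(module), {&arg},
  //                                   /*run_hlo_passes=*/true,
  //                                   /*profile=*/nullptr)
  //                        .ValueOrDie();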

 private:
  // Creates a ServiceExecutableRunOptions object to configure a run on a
  // device, using the provided stream object. If device_assignment is not
  // null, it will be used to configure the replication parameters. Replicated
  // executions should pass the device_assignment parameter.
  ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
      int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
      RunId run_id);

  // Common implementation code for the ExecuteReplicated() overloads above.
  StatusOr<std::vector<Literal>> ExecuteReplicatedImpl(
      std::function<StatusOr<std::vector<ScopedShapedBuffer>>(
          const std::vector<ServiceExecutableRunOptions>&,
          const std::vector<absl::Span<const ShapedBuffer* const>>&)>
          execution_helper,
      std::function<int64(int64)> argument_count_provider,
      std::function<const Literal*(int64, int64)> argument_provider,
      const ReplicatedExecuteOptions& options,
      DeviceAssignment* device_assignment);

  std::unique_ptr<Backend> backend_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_