1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/backend.h" 27 #include "tensorflow/compiler/xla/service/compiler.h" 28 #include "tensorflow/compiler/xla/service/computation_placer.h" 29 #include "tensorflow/compiler/xla/service/executable.h" 30 #include "tensorflow/compiler/xla/service/hlo_computation.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/hlo_runner_interface.h" 33 #include "tensorflow/compiler/xla/status_macros.h" 34 #include "tensorflow/compiler/xla/statusor.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/compiler/xla/util.h" 37 #include "tensorflow/compiler/xla/xla_data.pb.h" 38 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 39 40 namespace xla { 41 42 // A base class for running an HloModule. This executes the given HloModule on a 43 // certain backend directly without using the client interface. HloModule can be 44 // explicitly built, or loaded from a serialization file (e.g., hlo proto 45 // file), or parsed from a hlo textual IR string. 46 class HloRunner : public HloRunnerInterface { 47 public: 48 // intra_op_parallelism_threads: For the CPU backend only. It is the thread 49 // pool size for parallel execution of an individual operator. The default 50 // value of -1 will result in initializing the thread pool with the number of 51 // threads equal to the number of 52 // cores in the system. 53 explicit HloRunner(se::Platform* platform, 54 int intra_op_parallelism_threads = -1); 55 56 ~HloRunner() override; 57 58 // Transfers data between the host and device. 59 StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal); 60 StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice( 61 absl::Span<const Literal* const> literals); 62 StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice( 63 absl::Span<const Literal> literals); 64 StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer); 65 66 // Executes the given module with given literals as input and returns the 67 // result as a Literal. 68 // 69 // If run_hlo_passes is false, the module will be executed without Hlo 70 // optimization. 71 72 using HloRunnerInterface::Execute; 73 74 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 75 absl::Span<const Literal* const> arguments, 76 bool run_hlo_passes, 77 ExecutionProfile* profile) override; 78 79 using HloRunnerInterface::ExecuteWithExecutable; 80 81 StatusOr<Literal> ExecuteWithExecutable( 82 Executable* executable, absl::Span<const Literal* const> arguments, 83 ExecutionProfile* profile) override; 84 85 // As Execute(), but accepts and returns device buffers instead of host 86 // buffers. 87 StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers( 88 std::unique_ptr<HloModule> module, 89 absl::Span<ScopedShapedBuffer const> arguments, 90 bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); 91 92 StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers( 93 Executable* executable, absl::Span<ScopedShapedBuffer const> arguments, 94 ExecutionProfile* profile = nullptr); 95 96 // Creates an executable object given an HLO module. If run_hlo_passes is 97 // true, the HLO passes will be run as part of compilation. 98 StatusOr<std::unique_ptr<Executable>> CreateExecutable( 99 std::unique_ptr<HloModule> module, bool run_hlo_passes) override; 100 101 // Executes a given HLO module into a set of replicas, and returns a map 102 // with the replica number as key, and the corresponding returned literal as 103 // value. 104 StatusOr<std::vector<Literal>> ExecuteReplicated( 105 std::unique_ptr<HloModule> module, 106 const ReplicatedExecuteOptions& options) override; 107 108 // Same as above, but with specified device assignment. 109 StatusOr<std::vector<Literal>> ExecuteReplicated( 110 std::unique_ptr<HloModule> module, 111 const ReplicatedExecuteOptions& options, 112 DeviceAssignment* device_assignment) override; 113 114 // Same as above, but with a reusable Executable. This may update the profile 115 // information in *executable. 116 // 117 // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, 118 // since we've already compiled the Executable. 119 StatusOr<std::vector<Literal>> ExecuteReplicated( 120 Executable* executable, const ReplicatedExecuteOptions& options, 121 DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr); 122 123 // Same as above, but with different reusable Executables. This may update the 124 // profile information in *executables. 125 // 126 // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, 127 // since we've already compiled the Executable. 128 StatusOr<std::vector<Literal>> ExecuteReplicated( 129 std::function<Executable*(int64_t)> executable_provider, 130 std::function<int64(int64_t)> argument_count_provider, 131 std::function<const Literal*(int64_t, int64_t)> argument_provider, 132 const ReplicatedExecuteOptions& options); 133 134 // If backend is not created in the constructor, creates and returns the 135 // default backend. If creation fails, crashes the program. 136 // 137 // This creates the backend lazily so it's possible to instantiate an 138 // HloRunner in a program without any backends linked in. 139 Backend& backend(); 140 const Backend& backend() const; 141 142 private: 143 // Creates a ServiceExecutableRunOptions object to configure a run on device, 144 // using the provided stream object. If device_assignment is not nullptr, it 145 // will be used to configure the replication parameters. Replicated executions 146 // should pass the device_assignment parameter. 147 ServiceExecutableRunOptions GetServiceRunOptionsForDevice( 148 int64_t device, se::Stream* stream, DeviceAssignment* device_assignment, 149 RunId run_id); 150 151 // Common implementation code for ExecuteReplicated() above. 152 StatusOr<std::vector<Literal>> ExecuteReplicatedImpl( 153 std::function<StatusOr<std::vector<ScopedShapedBuffer>>( 154 const std::vector<ServiceExecutableRunOptions>&, 155 const std::vector<absl::Span<const ShapedBuffer* const>>&)> 156 execution_helper, 157 std::function<int64(int64_t)> argument_count_provider, 158 std::function<const Literal*(int64_t, int64_t)> argument_provider, 159 const ReplicatedExecuteOptions& options, 160 DeviceAssignment* device_assignment); 161 162 std::unique_ptr<Backend> backend_; 163 164 DeviceShapeRepresentationFn device_shape_representation_fn_; 165 }; 166 167 } // namespace xla 168 169 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ 170