1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/computation_placer.h" 27 #include "tensorflow/compiler/xla/service/executable.h" 28 #include "tensorflow/compiler/xla/service/hlo_computation.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/status_macros.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/util.h" 34 #include "tensorflow/compiler/xla/xla_data.pb.h" 35 36 namespace xla { 37 38 // A base class for running an HloModule. This executes the given HloModule on a 39 // certain backend directly without using the client interface. HloModule can be 40 // explicitly built, or loaded from a serialization file (e.g., hlo proto 41 // file), or parsed from a hlo textual IR string. 42 class HloRunnerInterface { 43 public: 44 // The options used to configure an ExecuteReplicated() call. 45 struct ReplicatedExecuteOptions { 46 // The number of devices the HLO module should be replicated onto. 47 int64 num_replicas = 1; 48 49 // The arguments to be fed to each replica. Since this is used for a 50 // replicated execution, all the arguments are the same for all replicas. 51 std::vector<const Literal*> arguments; 52 53 // If the HLO module being run has an infeed instruction, this will be the 54 // data which will be fed to it, for as many as infeed_steps steps. 55 const Literal* infeed = nullptr; 56 57 // The number of times the infeed literal should be fed to the HLO module. 58 // For a clean exit, this should match the iterations-per-loop parameter 59 // used when generating the HLO module proto (that is usually the main 60 // while boundary counter). A value higher then iterations-per-loop would 61 // lead to infeed threads feeding to a gone computation, while a lower 62 // value would trigger a stuck ExecuteReplicated() call (the computation 63 // will be trying to infeed data which will never come). 64 int64 infeed_steps = -1; 65 66 // The shape of the outfeed operation. If empty, the HLO module does not 67 // generate any outfeed. 68 Shape outfeed_shape; 69 70 // A pointer to a vector where the outfeed values will be stored. If 71 // nullptr, the values will be read and discarded. 72 std::vector<Literal>* outfeed_values = nullptr; 73 74 // Whether the HLO passes should be run on the input module. Usually 75 // saved modules are coming from after the HLO pass pipeline, so triggering 76 // another run will likely cause errors. 77 bool run_hlo_passes = false; 78 79 // If true, executes on multiple threads using se::Stream::ExecuteOnStream. 80 // Otherwise, executes using xla::Executable::ExecuteOnStreams. 81 bool use_threads = false; 82 }; 83 84 HloRunnerInterface() = default; 85 86 virtual ~HloRunnerInterface() = default; 87 88 // Converts an HloModule from the given hlo textual IR string (in 89 // HloModule::ToString format). 90 static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString( 91 const absl::string_view hlo_string, const DebugOptions& debug_options); 92 93 // Reads the proto file in xla.HloProto format, creates and returns the 94 // HloModule. 95 static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile( 96 const std::string& filename, const DebugOptions& debug_options); 97 static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile( 98 const std::string& filename, const DebugOptions& debug_options); 99 100 // Reads the proto file in xla.HloModule format, creates and returns the 101 // HloModule. 102 static StatusOr<std::unique_ptr<HloModule>> 103 ReadModuleFromModuleBinaryProtofile(const std::string& filename, 104 const DebugOptions& debug_options); 105 106 // Reads the hlo text dump file in HloModule::ToString format, creates and 107 // returns the HloModule. 108 static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile( 109 const std::string& filename, const DebugOptions& debug_options); 110 111 // Creates an executable object given an HLO module. If run_hlo_passes is 112 // true, the HLO passes will be run as part of compilation. 113 virtual StatusOr<std::unique_ptr<Executable>> CreateExecutable( 114 std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0; 115 116 // Executes the given module with given literals as input and returns the 117 // result as a Literal. 118 // 119 // If run_hlo_passes is false, the module will be executed without Hlo 120 // optimization 121 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 122 absl::Span<const Literal* const> arguments, 123 bool run_hlo_passes = true) { 124 return Execute(std::move(module), arguments, run_hlo_passes, nullptr); 125 } 126 127 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 128 absl::Span<const Literal> arguments, 129 bool run_hlo_passes = true, 130 ExecutionProfile* profile = nullptr); 131 132 virtual StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 133 absl::Span<const Literal* const> arguments, 134 bool run_hlo_passes, 135 ExecutionProfile* profile) = 0; 136 137 // Same as above, but with Executable as input. 138 StatusOr<Literal> ExecuteWithExecutable( 139 std::unique_ptr<Executable> executable, 140 absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr); 141 ExecuteWithExecutable(std::unique_ptr<Executable> executable,absl::Span<const Literal * const> arguments)142 StatusOr<Literal> ExecuteWithExecutable( 143 std::unique_ptr<Executable> executable, 144 absl::Span<const Literal* const> arguments) { 145 return ExecuteWithExecutable(std::move(executable), arguments, nullptr); 146 } 147 148 virtual StatusOr<Literal> ExecuteWithExecutable( 149 std::unique_ptr<Executable> executable, 150 absl::Span<const Literal* const> arguments, 151 ExecutionProfile* profile) = 0; 152 153 // Executes a given HLO module into a set of replicas, and returns a map 154 // with the replica number as key, and the corresponding returned literal as 155 // value. 156 // TODO(b/172931928): change to non-virtual function. 157 virtual StatusOr<std::vector<Literal>> ExecuteReplicated( 158 std::unique_ptr<HloModule> module, 159 const ReplicatedExecuteOptions& options) = 0; 160 161 // Same as above, but with specified device assignment. 162 virtual StatusOr<std::vector<Literal>> ExecuteReplicated( 163 std::unique_ptr<HloModule> module, 164 const ReplicatedExecuteOptions& options, 165 DeviceAssignment* device_assignment) = 0; 166 167 virtual StatusOr<std::vector<Literal>> ExecuteReplicated( 168 std::function<Executable*(int64)> executable_provider, 169 std::function<int64(int64)> argument_count_provider, 170 std::function<const Literal*(int64, int64)> argument_provider, 171 const ReplicatedExecuteOptions& options) = 0; 172 }; 173 174 } // namespace xla 175 176 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ 177