• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/computation_placer.h"
27 #include "tensorflow/compiler/xla/service/executable.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 
36 namespace xla {
37 
38 // A base class for running an HloModule. This executes the given HloModule on a
39 // certain backend directly without using the client interface. HloModule can be
40 // explicitly built, or loaded from a serialization file (e.g., hlo proto
41 // file), or parsed from a hlo textual IR string.
42 class HloRunnerInterface {
43  public:
44   // The options used to configure an ExecuteReplicated() call.
45   struct ReplicatedExecuteOptions {
46     // The number of devices the HLO module should be replicated onto.
47     int64 num_replicas = 1;
48 
49     // The arguments to be fed to each replica. Since this is used for a
50     // replicated execution, all the arguments are the same for all replicas.
51     std::vector<const Literal*> arguments;
52 
53     // If the HLO module being run has an infeed instruction, this will be the
54     // data which will be fed to it, for as many as infeed_steps steps.
55     const Literal* infeed = nullptr;
56 
57     // The number of times the infeed literal should be fed to the HLO module.
58     // For a clean exit, this should match the iterations-per-loop parameter
59     // used when generating the HLO module proto (that is usually the main
60     // while boundary counter). A value higher then iterations-per-loop would
61     // lead to infeed threads feeding to a gone computation, while a lower
62     // value would trigger a stuck ExecuteReplicated() call (the computation
63     // will be trying to infeed data which will never come).
64     int64 infeed_steps = -1;
65 
66     // The shape of the outfeed operation. If empty, the HLO module does not
67     // generate any outfeed.
68     Shape outfeed_shape;
69 
70     // A pointer to a vector where the outfeed values will be stored. If
71     // nullptr, the values will be read and discarded.
72     std::vector<Literal>* outfeed_values = nullptr;
73 
74     // Whether the HLO passes should be run on the input module. Usually
75     // saved modules are coming from after the HLO pass pipeline, so triggering
76     // another run will likely cause errors.
77     bool run_hlo_passes = false;
78 
79     // If true, executes on multiple threads using se::Stream::ExecuteOnStream.
80     // Otherwise, executes using xla::Executable::ExecuteOnStreams.
81     bool use_threads = false;
82   };
83 
84   HloRunnerInterface() = default;
85 
86   virtual ~HloRunnerInterface() = default;
87 
88   // Converts an HloModule from the given hlo textual IR string (in
89   // HloModule::ToString format).
90   static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
91       const absl::string_view hlo_string, const DebugOptions& debug_options);
92 
93   // Reads the proto file in xla.HloProto format, creates and returns the
94   // HloModule.
95   static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
96       const std::string& filename, const DebugOptions& debug_options);
97   static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
98       const std::string& filename, const DebugOptions& debug_options);
99 
100   // Reads the proto file in xla.HloModule format, creates and returns the
101   // HloModule.
102   static StatusOr<std::unique_ptr<HloModule>>
103   ReadModuleFromModuleBinaryProtofile(const std::string& filename,
104                                       const DebugOptions& debug_options);
105 
106   // Reads the hlo text dump file in HloModule::ToString format, creates and
107   // returns the HloModule.
108   static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
109       const std::string& filename, const DebugOptions& debug_options);
110 
111   // Creates an executable object given an HLO module. If run_hlo_passes is
112   // true, the HLO passes will be run as part of compilation.
113   virtual StatusOr<std::unique_ptr<Executable>> CreateExecutable(
114       std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0;
115 
116   // Executes the given module with given literals as input and returns the
117   // result as a Literal.
118   //
119   // If run_hlo_passes is false, the module will be executed without Hlo
120   // optimization
121   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
122                             absl::Span<const Literal* const> arguments,
123                             bool run_hlo_passes = true) {
124     return Execute(std::move(module), arguments, run_hlo_passes, nullptr);
125   }
126 
127   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
128                             absl::Span<const Literal> arguments,
129                             bool run_hlo_passes = true,
130                             ExecutionProfile* profile = nullptr);
131 
132   virtual StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
133                                     absl::Span<const Literal* const> arguments,
134                                     bool run_hlo_passes,
135                                     ExecutionProfile* profile) = 0;
136 
137   // Same as above, but with Executable as input.
138   StatusOr<Literal> ExecuteWithExecutable(
139       std::unique_ptr<Executable> executable,
140       absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr);
141 
ExecuteWithExecutable(std::unique_ptr<Executable> executable,absl::Span<const Literal * const> arguments)142   StatusOr<Literal> ExecuteWithExecutable(
143       std::unique_ptr<Executable> executable,
144       absl::Span<const Literal* const> arguments) {
145     return ExecuteWithExecutable(std::move(executable), arguments, nullptr);
146   }
147 
148   virtual StatusOr<Literal> ExecuteWithExecutable(
149       std::unique_ptr<Executable> executable,
150       absl::Span<const Literal* const> arguments,
151       ExecutionProfile* profile) = 0;
152 
153   // Executes a given HLO module into a set of replicas, and returns a map
154   // with the replica number as key, and the corresponding returned literal as
155   // value.
156   // TODO(b/172931928): change to non-virtual function.
157   virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
158       std::unique_ptr<HloModule> module,
159       const ReplicatedExecuteOptions& options) = 0;
160 
161   // Same as above, but with specified device assignment.
162   virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
163       std::unique_ptr<HloModule> module,
164       const ReplicatedExecuteOptions& options,
165       DeviceAssignment* device_assignment) = 0;
166 
167   virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
168       std::function<Executable*(int64)> executable_provider,
169       std::function<int64(int64)> argument_count_provider,
170       std::function<const Literal*(int64, int64)> argument_provider,
171       const ReplicatedExecuteOptions& options) = 0;
172 };
173 
174 }  // namespace xla
175 
176 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
177