/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// A base class for running an HloModule. This executes the given HloModule on
// a certain backend directly, without using the client interface. The
// HloModule can be built explicitly, loaded from a serialized file (e.g., an
// HLO proto file), or parsed from an HLO textual IR string.
class HloRunner {
 public:
  // The options used to configure an ExecuteReplicated() call.
  struct ReplicatedExecuteOptions {
    // The number of devices the HLO module should be replicated onto.
    int64 num_replicas = 1;

    // The arguments to be fed to each replica. Since this is used for a
    // replicated execution, all the arguments are the same for all replicas.
    std::vector<const Literal*> arguments;

    // If the HLO module being run has an infeed instruction, this is the
    // data that will be fed to it, for as many as infeed_steps steps.
    const Literal* infeed = nullptr;

    // The number of times the infeed literal should be fed to the HLO module.
    // For a clean exit, this should match the iterations-per-loop parameter
    // used when generating the HLO module proto (that is usually the main
    // while boundary counter). A value higher than iterations-per-loop would
    // leave the infeed threads feeding a computation that has already gone
    // away, while a lower value would make the ExecuteReplicated() call hang
    // (the computation would keep trying to infeed data that will never come).
    int64 infeed_steps = -1;

    // The shape of the outfeed operation. If empty, the HLO module does not
    // generate any outfeed.
    Shape outfeed_shape;

    // A pointer to a vector where the outfeed values will be stored. If
    // nullptr, the values will be read and discarded.
    std::vector<Literal>* outfeed_values = nullptr;

    // Whether the HLO passes should be run on the input module. Saved modules
    // usually come from after the HLO pass pipeline, so running the passes
    // again would likely cause errors.
    bool run_hlo_passes = false;

    // If true, executes on multiple threads using se::Stream::ExecuteOnStream.
    // Otherwise, executes using xla::Executable::ExecuteOnStreams.
    bool use_threads = false;
  };
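
  // Example of populating the options (a minimal sketch; the literal value is
  // hypothetical, LiteralUtil::CreateR1 comes from xla/literal_util.h, and the
  // options are consumed by ExecuteReplicated(), declared below):
  //
  //   HloRunner::ReplicatedExecuteOptions options;
  //   options.num_replicas = 2;
  //   Literal arg = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   options.arguments.push_back(&arg);
  //   options.use_threads = true;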

  // intra_op_parallelism_threads: For the CPU backend only. It is the thread
  // pool size for parallel execution of an individual operator. The default
  // value of -1 initializes the thread pool with as many threads as there are
  // cores in the system.
  explicit HloRunner(se::Platform* platform,
                     int intra_op_parallelism_threads = -1);
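
  // Example construction (a minimal sketch; PlatformUtil::GetDefaultPlatform()
  // from xla/service/platform_util.h is assumed as the platform source):
  //
  //   TF_ASSIGN_OR_RETURN(se::Platform* platform,
  //                       PlatformUtil::GetDefaultPlatform());
  //   HloRunner runner(platform);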

  ~HloRunner();

  // Creates an HloModule from the given HLO textual IR string (in
  // HloModule::ToString format).
  static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
      const absl::string_view hlo_string, const DebugOptions& debug_options);
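
  // Example (a minimal sketch; the HLO text is hypothetical, and
  // GetDebugOptionsFromFlags() is assumed to come from
  // xla/debug_options_flags.h):
  //
  //   constexpr absl::string_view kHloText = R"(
  //     HloModule add
  //     ENTRY add {
  //       x = f32[2] parameter(0)
  //       y = f32[2] parameter(1)
  //       ROOT sum = f32[2] add(x, y)
  //     })";
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<HloModule> module,
  //       HloRunner::CreateModuleFromString(kHloText,
  //                                         GetDebugOptionsFromFlags()));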

  // Reads a proto file in xla.HloProto format, and creates and returns the
  // HloModule.
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
      const std::string& filename, const DebugOptions& debug_options);
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
      const std::string& filename, const DebugOptions& debug_options);

  // Reads an HLO text dump file in HloModule::ToString format, and creates
  // and returns the HloModule.
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
      const std::string& filename, const DebugOptions& debug_options);
  // Transfers data between the host and device.
  StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
  StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
      absl::Span<const Literal* const> literals);
  StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
      absl::Span<const Literal> literals);
  StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
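
  // Example round trip (a minimal sketch; `runner` is an HloRunner instance
  // and the literal value is hypothetical):
  //
  //   Literal input = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
  //                       runner.TransferLiteralToDevice(input));
  //   TF_ASSIGN_OR_RETURN(Literal round_tripped,
  //                       runner.TransferLiteralFromDevice(buffer));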

  // Executes the given module with the given literals as input and returns
  // the result as a Literal.
  //
  // If run_hlo_passes is false, the module will be executed without HLO
  // optimization.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<const Literal* const> arguments,
                            bool run_hlo_passes = true,
                            ExecutionProfile* profile = nullptr);

  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<const Literal> arguments,
                            bool run_hlo_passes = true,
                            ExecutionProfile* profile = nullptr);

  StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
                            absl::Span<const Literal* const> arguments,
                            ExecutionProfile* profile = nullptr);

  StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
                            absl::Span<const Literal> arguments,
                            ExecutionProfile* profile = nullptr);
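
  // Example (a minimal sketch; `module` is assumed to have been created with
  // CreateModuleFromString() above, and the argument literals are
  // hypothetical):
  //
  //   Literal x = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   Literal y = LiteralUtil::CreateR1<float>({3.0f, 4.0f});
  //   TF_ASSIGN_OR_RETURN(Literal result,
  //                       runner.Execute(std::move(module), {&x, &y}));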

  // As Execute(), but accepts and returns device buffers instead of host
  // buffers.
  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
      std::unique_ptr<HloModule> module,
      absl::Span<const ShapedBuffer* const> arguments,
      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);

  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
      std::unique_ptr<HloModule> module,
      absl::Span<const ScopedShapedBuffer> arguments,
      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);

  // In the following two calls, "executable" is not a unique_ptr to allow
  // reuse of the Executable. These calls may update the profile information
  // in *executable.
  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
      Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
      ExecutionProfile* profile = nullptr);

  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
      Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
      ExecutionProfile* profile = nullptr);

  // Creates an executable object given an HLO module. If run_hlo_passes is
  // true, the HLO passes will be run as part of compilation.
  StatusOr<std::unique_ptr<Executable>> CreateExecutable(
      std::unique_ptr<HloModule> module, bool run_hlo_passes);
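
  // Example of compiling once and reusing the Executable (a minimal sketch;
  // `module`, `runner`, and `device_buffers` are assumed to exist):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<Executable> executable,
  //       runner.CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
  //   for (int i = 0; i < 10; ++i) {
  //     TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
  //                         runner.ExecuteWithDeviceBuffers(
  //                             executable.get(), device_buffers));
  //   }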

  // Executes the given HLO module on a set of replicas, and returns a vector
  // of result literals, indexed by replica number.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module,
      const ReplicatedExecuteOptions& options);

  // Same as above, but with a specified device assignment.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module,
      const ReplicatedExecuteOptions& options,
      DeviceAssignment* device_assignment);

  // Same as above, but with a reusable Executable. This may update the
  // profile information in *executable.
  //
  // Note that this call ignores ReplicatedExecuteOptions::run_hlo_passes,
  // since the Executable has already been compiled.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      Executable* executable, const ReplicatedExecuteOptions& options,
      DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
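
  // Example (a minimal sketch; `module` is assumed to exist and `options` to
  // be populated as in the ReplicatedExecuteOptions example above):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::vector<Literal> results,
  //       runner.ExecuteReplicated(std::move(module), options));
  //   for (int64 i = 0; i < options.num_replicas; ++i) {
  //     LOG(INFO) << "replica " << i << ": " << results[i].ToString();
  //   }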

  // If the backend was not created in the constructor, creates and returns
  // the default backend. If creation fails, crashes the program.
  //
  // This creates the backend lazily so it's possible to instantiate an
  // HloRunner in a program without any backends linked in.
  Backend& backend();
  const Backend& backend() const;

 private:
  // Creates a ServiceExecutableRunOptions object to configure a run on
  // device, using the provided stream object. If device_assignment is not
  // nullptr, it will be used to configure the replication parameters.
  // Replicated executions should pass the device_assignment parameter.
  ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
      int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
      RunId run_id);

  std::unique_ptr<Backend> backend_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_