• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/backend.h"
27 #include "tensorflow/compiler/xla/service/compiler.h"
28 #include "tensorflow/compiler/xla/service/computation_placer.h"
29 #include "tensorflow/compiler/xla/service/executable.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
38 
39 namespace xla {
40 
41 // A base class for running an HloModule. This executes the given HloModule on a
42 // certain backend directly without using the client interface. HloModule can be
43 // explicitly built, or loaded from a serialization file (e.g., hlo proto
44 // file), or parsed from a hlo textual IR string.
45 class HloRunner {
46  public:
47   // The options used to configure a ExecuteReplicated() call.
48   struct ReplicatedExecuteOptions {
49     // The number of devices the HLO module should be replicated onto.
50     int64 num_replicas = 1;
51 
52     // The arguments to be fed to each replica. Since this is used for a
53     // replicated execution, all the arguments are the same for all replicas.
54     std::vector<const Literal*> arguments;
55 
56     // If the HLO module being run has an infeed instruction, this will be the
57     // data which will be fed to it, for as many as infeed_steps steps.
58     const Literal* infeed = nullptr;
59 
60     // The number of times the infeed literal should be fed to the HLO module.
61     // For a clean exit, this should match the iterations-per-loop parameter
62     // used when generating the HLO module proto (that is usually the main
63     // while boundary counter). A value higher then iterations-per-loop would
64     // lead to infeed threads feeding to a gone computation, while a lower
65     // value would trigger a stuck ExecuteReplicated() call (the computation
66     // will be trying to infeed data which will never come).
67     int64 infeed_steps = -1;
68 
69     // The shape of the outfeed operation. If empty, the HLO module does not
70     // generate any outfeed.
71     Shape outfeed_shape;
72 
73     // A pointer to a vector where the outfeed values will be stored. If
74     // nullptr, the values will be read and discarded.
75     std::vector<Literal>* outfeed_values = nullptr;
76 
77     // Whether the HLO passes should be run on the input module. Usually
78     // saved modules are coming from after the HLO pass pipeline, so triggering
79     // another run will likely cause errors.
80     bool run_hlo_passes = false;
81   };
82 
83   explicit HloRunner(se::Platform* platform);
84 
85   ~HloRunner();
86 
87   // Converts an HloModule from the given hlo textual IR string (in
88   // HloModule::ToString format).
89   static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
90       const absl::string_view hlo_string, const DebugOptions& debug_options);
91 
92   // Reads the proto file in xla.HloProto format, creates and returns the
93   // HloModule.
94   static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
95       const std::string& filename, const DebugOptions& debug_options);
96   static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
97       const std::string& filename, const DebugOptions& debug_options);
98 
99   // Reads the hlo text dump file in HloModule::ToString format, creates and
100   // returns the HloModule.
101   static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
102       const std::string& filename, const DebugOptions& debug_options);
103 
104   // Transfers data between the host and device.
105   StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
106   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
107       const absl::Span<const Literal* const> literals);
108   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
109       const absl::Span<const Literal> literals);
110   StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
111 
112   // Executes the given module with given literals as input and returns the
113   // result as a Literal.
114   //
115   // If run_hlo_passes is false, the module will be executed without Hlo
116   // optimization.
117   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
118                             const absl::Span<const Literal* const> arguments,
119                             bool run_hlo_passes = true,
120                             ExecutionProfile* profile = nullptr);
121 
122   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
123                             const absl::Span<const Literal> arguments,
124                             bool run_hlo_passes = true,
125                             ExecutionProfile* profile = nullptr);
126 
127   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
128                             const absl::Span<const Literal* const> arguments,
129                             ExecutionProfile* profile = nullptr);
130 
131   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
132                             const absl::Span<const Literal> arguments,
133                             ExecutionProfile* profile = nullptr);
134 
135   // As Execute(), but accepts and returns device buffers instead of host
136   // buffers.
137   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
138       std::unique_ptr<HloModule> module,
139       const absl::Span<const ShapedBuffer* const> arguments,
140       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
141 
142   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
143       std::unique_ptr<HloModule> module,
144       const absl::Span<const ScopedShapedBuffer> arguments,
145       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
146 
147   // In the following two calls, "executable" is not a unique_ptr to allow
148   // reuse of the Executable.  This call may update the profile information in
149   // *executable.
150   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
151       Executable* executable,
152       const absl::Span<const ShapedBuffer* const> arguments,
153       ExecutionProfile* profile = nullptr);
154 
155   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
156       Executable* executable,
157       const absl::Span<const ScopedShapedBuffer> arguments,
158       ExecutionProfile* profile = nullptr);
159 
160   // Creates an executable object given an HLO module. If run_hlo_passes is
161   // true, the HLO passes will be run as part of compilation.
162   StatusOr<std::unique_ptr<Executable>> CreateExecutable(
163       std::unique_ptr<HloModule> module, bool run_hlo_passes);
164 
165   // Executes a given HLO module into a set of replicas, and returns a map
166   // with the replica number as key, and the corresponding returned literal as
167   // value.
168   //
169   // use_threads indicates whether this replicated computation will be executed
170   // with a thread-per-replica, vs using an implicitly async call such as
171   // Executable::ExecuteOnStreams.
172   StatusOr<std::vector<Literal>> ExecuteReplicated(
173       std::unique_ptr<HloModule> module,
174       const ReplicatedExecuteOptions& options, bool use_threads = false);
175 
176   // Same as above, but with specified device assignment.
177   StatusOr<std::vector<Literal>> ExecuteReplicated(
178       std::unique_ptr<HloModule> module,
179       const ReplicatedExecuteOptions& options,
180       DeviceAssignment* device_assignment, bool use_threads = false);
181 
182   // If backend is not created in the constructor, creates and returns the
183   // default backend. If creation fails, crashes the program.
184   //
185   // This creates the backend lazily so it's possible to instantiate an
186   // HloRunner in a program without any backends linked in.
187   Backend& backend();
188   const Backend& backend() const;
189 
190  private:
191   // Creates a ServiceExecutableRunOptions object to configure a run on device,
192   // using the provided stream object. If device_assignment is not nullptr, it
193   // will be used to configure the replication parameters. Replicated executions
194   // should pass the device_assignment parameter.
195   ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
196       int64 device, se::Stream* stream, DeviceAssignment* device_assignment);
197 
198   std::unique_ptr<Backend> backend_;
199 };
200 
201 }  // namespace xla
202 
203 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
204