1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/backend.h"
27 #include "tensorflow/compiler/xla/service/compiler.h"
28 #include "tensorflow/compiler/xla/service/computation_placer.h"
29 #include "tensorflow/compiler/xla/service/executable.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
39 
40 namespace xla {
41 
42 // A base class for running an HloModule. This executes the given HloModule on a
43 // certain backend directly without using the client interface. HloModule can be
44 // explicitly built, or loaded from a serialization file (e.g., hlo proto
45 // file), or parsed from a hlo textual IR string.
46 class HloRunner : public HloRunnerInterface {
47  public:
48   // intra_op_parallelism_threads: For the CPU backend only. It is the thread
49   // pool size for parallel execution of an individual operator. The default
50   // value of -1 will result in initializing the thread pool with the number of
51   // threads equal to the number of
52   // cores in the system.
53   explicit HloRunner(se::Platform* platform,
54                      int intra_op_parallelism_threads = -1);
55 
56   ~HloRunner() override;
57 
58   // Transfers data between the host and device.
59   StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
60   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
61       absl::Span<const Literal* const> literals);
62   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
63       absl::Span<const Literal> literals);
64   StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
65 
66   // Executes the given module with given literals as input and returns the
67   // result as a Literal.
68   //
69   // If run_hlo_passes is false, the module will be executed without Hlo
70   // optimization.
71 
72   using HloRunnerInterface::Execute;
73 
74   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
75                             absl::Span<const Literal* const> arguments,
76                             bool run_hlo_passes,
77                             ExecutionProfile* profile) override;
78 
79   using HloRunnerInterface::ExecuteWithExecutable;
80 
81   StatusOr<Literal> ExecuteWithExecutable(
82       std::unique_ptr<Executable> executable,
83       absl::Span<const Literal* const> arguments,
84       ExecutionProfile* profile) override;
85 
86   // As Execute(), but accepts and returns device buffers instead of host
87   // buffers.
88   StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
89       std::unique_ptr<HloModule> module,
90       absl::Span<ScopedShapedBuffer const> arguments,
91       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
92 
93   StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
94       Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
95       ExecutionProfile* profile = nullptr);
96 
97   // Creates an executable object given an HLO module. If run_hlo_passes is
98   // true, the HLO passes will be run as part of compilation.
99   StatusOr<std::unique_ptr<Executable>> CreateExecutable(
100       std::unique_ptr<HloModule> module, bool run_hlo_passes) override;
101 
102   // Executes a given HLO module into a set of replicas, and returns a map
103   // with the replica number as key, and the corresponding returned literal as
104   // value.
105   StatusOr<std::vector<Literal>> ExecuteReplicated(
106       std::unique_ptr<HloModule> module,
107       const ReplicatedExecuteOptions& options) override;
108 
109   // Same as above, but with specified device assignment.
110   StatusOr<std::vector<Literal>> ExecuteReplicated(
111       std::unique_ptr<HloModule> module,
112       const ReplicatedExecuteOptions& options,
113       DeviceAssignment* device_assignment) override;
114 
115   // Same as above, but with a reusable Executable.  This may update the profile
116   // information in *executable.
117   //
118   // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
119   // since we've already compiled the Executable.
120   StatusOr<std::vector<Literal>> ExecuteReplicated(
121       Executable* executable, const ReplicatedExecuteOptions& options,
122       DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
123 
124   // Same as above, but with different reusable Executables. This may update the
125   // profile information in *executables.
126   //
127   // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
128   // since we've already compiled the Executable.
129   StatusOr<std::vector<Literal>> ExecuteReplicated(
130       std::function<Executable*(int64)> executable_provider,
131       std::function<int64(int64)> argument_count_provider,
132       std::function<const Literal*(int64, int64)> argument_provider,
133       const ReplicatedExecuteOptions& options);
134 
135   // If backend is not created in the constructor, creates and returns the
136   // default backend. If creation fails, crashes the program.
137   //
138   // This creates the backend lazily so it's possible to instantiate an
139   // HloRunner in a program without any backends linked in.
140   Backend& backend();
141   const Backend& backend() const;
142 
143  private:
144   // Creates a ServiceExecutableRunOptions object to configure a run on device,
145   // using the provided stream object. If device_assignment is not nullptr, it
146   // will be used to configure the replication parameters. Replicated executions
147   // should pass the device_assignment parameter.
148   ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
149       int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
150       RunId run_id);
151 
152   // Common implementation code for ExecuteReplicated() above.
153   StatusOr<std::vector<Literal>> ExecuteReplicatedImpl(
154       std::function<StatusOr<std::vector<ScopedShapedBuffer>>(
155           const std::vector<ServiceExecutableRunOptions>&,
156           const std::vector<absl::Span<const ShapedBuffer* const>>&)>
157           execution_helper,
158       std::function<int64(int64)> argument_count_provider,
159       std::function<const Literal*(int64, int64)> argument_provider,
160       const ReplicatedExecuteOptions& options,
161       DeviceAssignment* device_assignment);
162 
163   std::unique_ptr<Backend> backend_;
164 };
165 
166 }  // namespace xla
167 
168 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
169