/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

class LocalExecutable {
 public:
  // Low-level constructor; LocalClient::Compile() is the usual way to create
  // executables.
  LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
                  ExecutableBuildOptions build_options);

  // Run the compiled computation with the given arguments and options and
  // return the result.
  StatusOr<ScopedShapedBuffer> Run(
      const absl::Span<const ShapedBuffer* const> arguments,
      ExecutableRunOptions run_options);

  // Similar to Run(), but allows for donating argument buffers to the
  // executable.
  StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
                                ExecutableRunOptions run_options);

  // Similar to Run(), but need not block the host waiting for the computation
  // to complete before returning.
  StatusOr<ScopedShapedBuffer> RunAsync(
      const absl::Span<const ShapedBuffer* const> arguments,
      ExecutableRunOptions run_options);

  // Similar to RunAsync(), but allows for donating argument buffers to the
  // executable.
  StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
                                     ExecutableRunOptions run_options);
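
  // Example (illustrative sketch, not part of the original header): a typical
  // synchronous invocation from code that returns a Status. Assumes `client`
  // is a LocalClient*, `executable` is a LocalExecutable*, and `arg_buffer` is
  // a ScopedShapedBuffer already resident on the device.
  //
  //   ExecutableRunOptions run_options;
  //   run_options.set_allocator(client->backend().memory_allocator());
  //   std::vector<const ShapedBuffer*> args = {&arg_buffer};
  //   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
  //                       executable->Run(args, run_options));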

  // Return the options used to build the executable.
  const ExecutableBuildOptions& build_options() const { return build_options_; }

  // Return the built executable.
  Executable* executable() const { return executable_.get(); }

 private:
  StatusOr<ExecutionOutput> RunAsync(
      absl::Span<Shape const* const> argument_host_shapes,
      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);

  // Validates that the given arguments and options satisfy various constraints
  // of the computation.
  //
  // The given ExecutableRunOptions override any values from the TF_XLA_FLAGS
  // environment variable.
  Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
                                  const Backend& backend);

  // Returns a literal containing the contents of the given ShapedBuffer.
  StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);

  StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>> RunHelper(
      const absl::Span<const Shape* const> argument_shapes,
      ExecutableRunOptions run_options);

  // The ordinal of the device for which this executable was compiled. The
  // executable can run on all equivalent devices (as determined by
  // Backend::devices_equivalent).
  int build_device_ordinal() const { return build_options_.device_ordinal(); }

  // Invokes `async_callback` with run options derived from RunHelper(), then
  // blocks the host until the associated stream has completed. Returns the
  // callback's result, or the first error from the callback or the stream.
  template <typename T>
  StatusOr<T> AsyncCallAndBlockHostUntilDone(
      absl::Span<Shape const* const> argument_shapes,
      const ExecutableRunOptions& run_options,
      std::function<StatusOr<T>(const ExecutableRunOptions&)> async_callback) {
    TF_ASSIGN_OR_RETURN(auto options_and_stream,
                        RunHelper(argument_shapes, run_options));
    ExecutableRunOptions options = options_and_stream.first.run_options();
    options.set_device_ordinal(-1);
    StatusOr<T> result = async_callback(options);
    Status block_status = options.stream()->BlockHostUntilDone();
    TF_RETURN_IF_ERROR(result.status());
    TF_RETURN_IF_ERROR(block_status);
    return result;
  }

  // Compiled computation.
  std::unique_ptr<Executable> executable_;

  // Execution backend.
  Backend* backend_ = nullptr;

  // Options used to build the executable.
  const ExecutableBuildOptions build_options_;
};

// An XLA Client specialization for use when the client and service run in
// the same process.
class LocalClient : public Client {
 public:
  explicit LocalClient(LocalService* service)
      : Client(service), local_service_(service) {}

  LocalClient(const LocalClient&) = delete;
  void operator=(const LocalClient&) = delete;

  // Build and return LocalExecutable objects (one per partition, as specified
  // by the build options). Each executable is compiled using the given
  // XlaComputation, argument layouts, and options.
  //
  // The given ExecutableBuildOptions overrides any values from the XLA_FLAGS
  // environment variable.
  StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
      const XlaComputation& computation,
      const absl::Span<const Shape* const> argument_layouts,
      const ExecutableBuildOptions& options);
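
  // Example (illustrative sketch, not part of the original header): compiling
  // a single-partition computation. Assumes `client` is a LocalClient* and
  // `computation` is an XlaComputation taking one f32[2,2] argument.
  //
  //   Shape arg_shape = ShapeUtil::MakeShape(F32, {2, 2});
  //   ExecutableBuildOptions build_options;
  //   TF_ASSIGN_OR_RETURN(
  //       auto executables,
  //       client->Compile(computation, {&arg_shape}, build_options));
  //   LocalExecutable* executable = executables[0].get();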

  // Copy the literal data to the device with the given ordinal and return as a
  // ScopedShapedBuffer. If non-null, the given memory allocator is used for
  // device memory allocation. If null, the default memory allocator for the
  // device is used.
  StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
      const LiteralSlice& literal, int device_ordinal,
      se::DeviceMemoryAllocator* allocator = nullptr);
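
  // Example (illustrative sketch, not part of the original header): copying a
  // host literal to device 0 with the default allocator. Assumes `client` is
  // a LocalClient*.
  //
  //   Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  //   TF_ASSIGN_OR_RETURN(
  //       ScopedShapedBuffer buffer,
  //       client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));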

  // Transfer the BorrowingLiteral to the device with the given ordinal.
  StatusOr<TransferToServerResponse> TransferToLocalServer(
      const ::xla::BorrowingLiteral& literal, int device_ordinal);

  // Copy the device data contained in the given ShapedBuffer and return it as
  // a Literal.
  StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
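
  // Example (illustrative sketch, not part of the original header): reading an
  // execution result back to the host. Assumes `client` is a LocalClient* and
  // `result` is the ScopedShapedBuffer produced by LocalExecutable::Run().
  //
  //   TF_ASSIGN_OR_RETURN(Literal result_literal,
  //                       client->ShapedBufferToLiteral(result));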

  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
  // as long as the handle is valid.
  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
      const GlobalDataHandle& data, int replica_number);

  // Transfer the given literal to the infeed queue of the given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferToInfeed.
  Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);
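
  // Example (illustrative sketch, not part of the original header): feeding a
  // scalar to device 0's infeed queue. Assumes `client` is a LocalClient*.
  //
  //   Literal infeed_value = LiteralUtil::CreateR0<float>(42.0f);
  //   TF_RETURN_IF_ERROR(
  //       client->TransferToInfeedLocal(infeed_value, /*device_ordinal=*/0));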

  // Transfer a value from the outfeed of the given device into `literal`. The
  // shape of the object to transfer is determined by `literal`'s shape.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferFromOutfeed.
  Status TransferFromOutfeedLocal(int device_ordinal,
                                  MutableBorrowingLiteral literal);

  // Returns the device ordinal that corresponds to the given replica number.
  //
  // This returns an error if there is not a one-to-one correspondence of
  // replicas to device ordinals, but is useful as a short-term mechanism for
  // the "easy" case where a single replica is a single device.
  StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
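
  // Example (illustrative sketch, not part of the original header): mapping
  // replica 0 to a device ordinal in the single-replica case. Assumes `client`
  // is a LocalClient*.
  //
  //   TF_ASSIGN_OR_RETURN(int device_ordinal,
  //                       client->ReplicaNumberToDeviceOrdinal(0));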

  // Returns the platform that the underlying service targets.
  se::Platform* platform() const;

  // Returns the number of devices on the system of the service platform
  // type. Not all devices may be supported by the service (see the
  // device_ordinal_supported method).
  int device_count() const;

  // Returns the default device ordinal that the service will run computations
  // on if no device ordinal is specified in execute options.
  int default_device_ordinal() const;

  // Returns whether the device with the given ordinal can be used by the
  // service to execute computations. Not all devices of a particular platform
  // may be usable by the service (e.g., a GPU with insufficient CUDA compute
  // capability).
  bool device_ordinal_supported(int device_ordinal) const;

  // Returns the backend used to execute computations.
  const Backend& backend() const;
  Backend* mutable_backend();

 private:
  LocalService* local_service_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_