• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
38 
39 namespace xla {
40 
41 class LocalExecutable {
42  public:
43   // Low-level constructor; LocalClient::Compile() is the usual way to create
44   // executables.
45   LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
46                   ExecutableBuildOptions build_options);
47 
48   // Run the compiled computation with the given arguments and options and
49   // return the result.
50   StatusOr<ScopedShapedBuffer> Run(
51       const absl::Span<const ShapedBuffer* const> arguments,
52       ExecutableRunOptions run_options);
53 
54   // Similar to Run(), but need not block the host waiting for the computation
55   // to complete before returning.
56   StatusOr<ScopedShapedBuffer> RunAsync(
57       const absl::Span<const ShapedBuffer* const> arguments,
58       ExecutableRunOptions run_options);
59 
60   // Similar to RunAsync(), but allows for donating argument buffers to the
61   // executable.
62   StatusOr<ExecutionOutput> RunAsync(
63       absl::Span<Shape const* const> argument_host_shapes,
64       std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
65       ExecutableRunOptions run_options);
66 
67   // Return the options used to build the executable.
build_options()68   const ExecutableBuildOptions& build_options() const { return build_options_; }
69 
70   // Return the built executable.
executable()71   Executable* executable() const { return executable_.get(); }
72 
73  private:
74   // Validates that the given arguments and options satisfy various constraints
75   // of the computation.
76   //
77   // The given ExecutableRunOptions override any values from TF_XLA_FLAGS
78   // environment variable.
79   Status ValidateExecutionOptions(
80       const ExecutableRunOptions& run_options, const Backend& backend);
81 
82   // Returns a literal containing the contents of the given ShapedBuffer.
83   StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
84 
85   StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>> RunHelper(
86       const absl::Span<const Shape* const> argument_shapes,
87       ExecutableRunOptions run_options);
88 
89   // The ordinal of the device which this executable was compiled for. The
90   // executable can run on all equivalent devices (as determined by
91   // Backend::devices_equivalent).
build_device_ordinal()92   int build_device_ordinal() const { return build_options_.device_ordinal(); }
93 
94   // Compiled computation.
95   std::unique_ptr<Executable> executable_;
96 
97   // Execution backend.
98   Backend* backend_ = nullptr;
99 
100   // Options used to build the executable.
101   const ExecutableBuildOptions build_options_;
102 };
103 
104 // An XLA Client specialization for use when the client and service run in
105 // the same process.
106 class LocalClient : public Client {
107  public:
LocalClient(LocalService * service)108   explicit LocalClient(LocalService* service)
109       : Client(service), local_service_(service) {}
110 
111   LocalClient(const LocalClient&) = delete;
112   void operator=(const LocalClient&) = delete;
113 
114   // Build and return LocalExecutable objects (one per partition, as specified
115   // by the build options). The executable is compiled using the given
116   // XlaComputation, argument layouts and options.
117   //
118   // The given ExecutableBuildOptions overrides any values from XLA_FLAGS
119   // environment variable.
120   StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
121       const XlaComputation& computation,
122       const absl::Span<const Shape* const> argument_layouts,
123       const ExecutableBuildOptions& options);
124 
125   // Copy the literal data to the device with the given ordinal and return as a
126   // ScopedShapedBuffer. If non-null the given memory allocator is used for
127   // device memory allocation. If null, the default memory allocator for the
128   // device is used.
129   StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
130       const LiteralSlice& literal, int device_ordinal,
131       se::DeviceMemoryAllocator* allocator = nullptr);
132 
133   // Transfer the BorrowingLiteral to the device with the given ordinal.
134   StatusOr<TransferToServerResponse> TransferToLocalServer(
135       const ::xla::BorrowingLiteral& literal, int device_ordinal);
136 
137   // Copy the data from the device contained in the given ShapedBuffer and
138   // return as a Literal.
139   StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
140 
141   // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
142   // as long as the handle is valid.
143   StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
144       const GlobalDataHandle& data, int replica_number);
145 
146   // Transfer the given literal to the infeed queue of the given device.
147   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
148   // not inherit from Client and there is no possibility of confusion with
149   // Client::TransferToInfeed.
150   Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);
151 
152   // Transfer and return a value of the given shape from the outfeed of the
153   // given device.
154   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
155   // not inherit from Client and there is no possibility of confusion with
156   // Client::TransferFromOutfeed.
157   StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
158                                              int device_ordinal);
159 
160   // Returns the device ordinal that corresponds to the given replica number.
161   //
162   // This returns an error if there is not a one-to-one correspondence of
163   // replicas to device ordinals, but is useful as a short term mechanism for
164   // the "easy" case where a single replica is a single device.
165   StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
166 
167   // Returns the platform that the underlying service targets.
168   se::Platform* platform() const;
169 
170   // Returns the number of devices on the system of the service platform
171   // type. Not all devices may be supported by the service (see
172   // device_ordinal_supported method).
173   int device_count() const;
174 
175   // Returns the default device ordinal that the service will run computations
176   // on if no device ordinal is specified in execute options.
177   int default_device_ordinal() const;
178 
179   // Returns whether the device with the given ordinal can be used by the
180   // service to execute computations. Not all devices of a particular platform
181   // may be usable by the service (eg, a GPU with insufficient CUDA compute
182   // capability).
183   bool device_ordinal_supported(int device_ordinal) const;
184 
185   // Returns the backend used to execute computations.
186   const Backend& backend() const;
187   Backend* mutable_backend();
188 
189  private:
190   LocalService* local_service_;
191 };
192 
193 }  // namespace xla
194 
195 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
196