/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

class LocalExecutable {
 public:
  // Low-level constructor; LocalClient::Compile() is the usual way to create
  // executables.
  LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
                  ExecutableBuildOptions build_options);

  // Run the compiled computation with the given arguments and options and
  // return the result.
  StatusOr<ScopedShapedBuffer> Run(
      const absl::Span<const ShapedBuffer* const> arguments,
      ExecutableRunOptions run_options);

  // Similar to Run(), but allows for donating argument buffers to the
  // executable.
  StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
                                ExecutableRunOptions run_options);

  // Similar to Run(), but need not block the host waiting for the computation
  // to complete before returning.
  StatusOr<ScopedShapedBuffer> RunAsync(
      const absl::Span<const ShapedBuffer* const> arguments,
      ExecutableRunOptions run_options);

  // Similar to RunAsync(), but allows for donating argument buffers to the
  // executable.
  StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
                                     ExecutableRunOptions run_options);
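
  // Example (illustrative sketch, not part of this header): the common
  // blocking pattern, assuming `client` is the LocalClient that compiled
  // `executable` and `arg_buffer` is a ScopedShapedBuffer already resident
  // on the device:
  //
  //   ExecutableRunOptions run_options;
  //   run_options.set_allocator(client->backend().memory_allocator());
  //   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
  //                       executable->Run({&arg_buffer}, run_options));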

  // Return the options used to build the executable.
  const ExecutableBuildOptions& build_options() const { return build_options_; }

  // Return the built executable.
  Executable* executable() const { return executable_.get(); }

 private:
  StatusOr<ExecutionOutput> RunAsync(
      absl::Span<Shape const* const> argument_host_shapes,
      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);

  // Validates that the given arguments and options satisfy various constraints
  // of the computation.
  //
  // The given ExecutableRunOptions overrides any values from the TF_XLA_FLAGS
  // environment variable.
  Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
                                  const Backend& backend);

  // Returns a literal containing the contents of the given ShapedBuffer.
  StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);

  StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>> RunHelper(
      const absl::Span<const Shape* const> argument_shapes,
      ExecutableRunOptions run_options);

  // The ordinal of the device which this executable was compiled for. The
  // executable can run on all equivalent devices (as determined by
  // Backend::devices_equivalent).
  int build_device_ordinal() const { return build_options_.device_ordinal(); }

  template <typename T>
  StatusOr<T> AsyncCallAndBlockHostUntilDone(
      absl::Span<Shape const* const> argument_shapes,
      const ExecutableRunOptions& run_options,
      std::function<StatusOr<T>(const ExecutableRunOptions&)> async_callback) {
    TF_ASSIGN_OR_RETURN(auto options_and_stream,
                        RunHelper(argument_shapes, run_options));
    ExecutableRunOptions options = options_and_stream.first.run_options();
    options.set_device_ordinal(-1);
    StatusOr<T> result = async_callback(options);
    Status block_status = options.stream()->BlockHostUntilDone();
    TF_RETURN_IF_ERROR(result.status());
    TF_RETURN_IF_ERROR(block_status);
    return result;
  }

  // Compiled computation.
  std::unique_ptr<Executable> executable_;

  // Execution backend.
  Backend* backend_ = nullptr;

  // Options used to build the executable.
  const ExecutableBuildOptions build_options_;
};

// An XLA Client specialization for use when the client and service run in
// the same process.
class LocalClient : public Client {
 public:
  explicit LocalClient(LocalService* service)
      : Client(service), local_service_(service) {}

  LocalClient(const LocalClient&) = delete;
  void operator=(const LocalClient&) = delete;

  // Build and return LocalExecutable objects (one per partition, as specified
  // by the build options). The executable is compiled using the given
  // XlaComputation, argument layouts and options.
  //
  // The given ExecutableBuildOptions overrides any values from the XLA_FLAGS
  // environment variable.
  StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
      const XlaComputation& computation,
      const absl::Span<const Shape* const> argument_layouts,
      const ExecutableBuildOptions& options);
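
  // Example (illustrative sketch; `client` is a LocalClient*, and
  // `computation`, `arg_shape`, and `arg_literal` are placeholders; error
  // propagation assumes a Status-returning caller):
  //
  //   TF_ASSIGN_OR_RETURN(auto executables,
  //                       client->Compile(computation, {&arg_shape},
  //                                       ExecutableBuildOptions()));
  //   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer arg_buffer,
  //                       client->LiteralToShapedBuffer(
  //                           arg_literal, client->default_device_ordinal()));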

  // Copy the literal data to the device with the given ordinal and return as a
  // ScopedShapedBuffer. If non-null, the given memory allocator is used for
  // device memory allocation. If null, the default memory allocator for the
  // device is used.
  StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
      const LiteralSlice& literal, int device_ordinal,
      se::DeviceMemoryAllocator* allocator = nullptr);

  // Transfer the BorrowingLiteral to the device with the given ordinal.
  StatusOr<TransferToServerResponse> TransferToLocalServer(
      const ::xla::BorrowingLiteral& literal, int device_ordinal);

  // Copy the data from the device contained in the given ShapedBuffer and
  // return it as a Literal.
  StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);

  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
  // as long as the handle is valid.
  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
      const GlobalDataHandle& data, int replica_number);

  // Transfer the given literal to the infeed queue of the given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferToInfeed.
  Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);

  // Transfer and return a value from the outfeed of the given device. The
  // shape of the object to transfer is determined by `literal`'s shape.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferFromOutfeed.
  Status TransferFromOutfeedLocal(int device_ordinal,
                                  MutableBorrowingLiteral literal);

  // Returns the device ordinal that corresponds to the given replica number.
  //
  // This returns an error if there is not a one-to-one correspondence of
  // replicas to device ordinals, but is useful as a short-term mechanism for
  // the "easy" case where a single replica is a single device.
  StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);

  // Returns the platform that the underlying service targets.
  se::Platform* platform() const;

  // Returns the number of devices of the service's platform type that are
  // present on the system. Not all of them may be usable by the service (see
  // device_ordinal_supported()).
  int device_count() const;

  // Returns the default device ordinal that the service will run computations
  // on if no device ordinal is specified in execute options.
  int default_device_ordinal() const;

  // Returns whether the device with the given ordinal can be used by the
  // service to execute computations. Not all devices of a particular platform
  // may be usable by the service (e.g., a GPU with insufficient CUDA compute
  // capability).
  bool device_ordinal_supported(int device_ordinal) const;

  // Returns the backend used to execute computations.
  const Backend& backend() const;
  Backend* mutable_backend();

 private:
  LocalService* local_service_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_