/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

// Options to configure the service when it is created.
class ServiceOptions {
 public:
  // Sets the platform backing the service, or nullptr for the default
  // platform.
  ServiceOptions& set_platform(se::Platform* platform);
  se::Platform* platform() const;

  // Sets the default number of replicas to use when compiling replicated
  // programs.
  ServiceOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual
  // operator.
  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

  // Sets the allowed_devices set for selectively constructing stream executors
  // on the platform.
  ServiceOptions& set_allowed_devices(
      const absl::optional<std::set<int>>& allowed_devices);
  const absl::optional<std::set<int>>& allowed_devices() const;

 private:
  se::Platform* platform_ = nullptr;
  int number_of_replicas_ = 1;
  int intra_op_parallelism_threads_ = -1;
  absl::optional<std::set<int>> allowed_devices_;
};

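// Illustrative only: a minimal sketch of configuring ServiceOptions and
// creating a Service through the NewService factory declared below. The
// option values are arbitrary and error handling is elided.
//
//   xla::ServiceOptions options;
//   options.set_number_of_replicas(2).set_intra_op_parallelism_threads(4);
//   xla::StatusOr<std::unique_ptr<xla::Service>> service =
//       xla::Service::NewService(options);
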
// The XLA service object, which is the same across all platforms. It maintains
// the service state of computations and allocations, and delegates
// target-specific requests to the target-specific infrastructure
// (target-specific compiler, StreamExecutor).
class Service : public ServiceInterface {
 public:
  // Factory methods for creating a new Service.
  static StatusOr<std::unique_ptr<Service>> NewService(
      se::Platform* platform = nullptr);
  static StatusOr<std::unique_ptr<Service>> NewService(
      const ServiceOptions& options);

  // Unregisters a previously-allocated global handle.
  //
  // If the given handle is not currently allocated, a NOT_FOUND status is
  // returned.
  Status Unregister(const UnregisterRequest* arg,
                    UnregisterResponse* result) override;

  // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
  // element in the tuple.
  Status DeconstructTuple(const DeconstructTupleRequest* arg,
                          DeconstructTupleResponse* result) override;

  // Compiles a computation into an executable. The request contains the whole
  // computation graph. Returns the handle to the executable.
  Status Compile(const CompileRequest* arg, CompileResponse* result) override;

  // Executes an executable with the provided global data passed as immutable
  // arguments. The request contains the handle to the executable. Returns
  // global data output and execution timing.
  Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;

  // Executes one or more computations in parallel with the provided global
  // data passed as immutable arguments. Returns global data output for each
  // computation.
  Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                              ExecuteParallelResponse* result) override;

  // Requests one or more device handles from the target.
  //
  // When N device handles are requested and the number of replicas is R, at
  // least N * R devices must be available. The devices are assigned based on
  // the device ordinals such that the first R available devices are assigned
  // to the first set of replicas, and the next R devices to the second set of
  // replicas, etc. Each returned device handle represents the device with
  // replica id 0.
  Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                          GetDeviceHandlesResponse* result) override;

  // Waits until the specified execution is complete and returns the result.
  // Calling this API multiple times with the same execution handle returns an
  // error, since the execution handle is destroyed after the first call.
  Status WaitForExecution(const WaitForExecutionRequest* arg,
                          WaitForExecutionResponse* result) override;

  // Requests that global data be transferred to the client in literal form.
  Status TransferToClient(const TransferToClientRequest* arg,
                          TransferToClientResponse* result) override;

  // Transfers data from a literal provided by the client into device memory.
  Status TransferToServer(const TransferToServerRequest* arg,
                          TransferToServerResponse* result) override;

  // Transfers data from a literal provided by the client into the Infeed
  // buffer of the device.
  Status TransferToInfeed(const TransferToInfeedRequest* arg,
                          TransferToInfeedResponse* result) override;

  // Transfers data from the Outfeed of the device to the literal provided by
  // the client.
  Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                             TransferFromOutfeedResponse* result) override;

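  // Illustrative only: a rough sketch of a typical client round trip through
  // the methods above (Compile, Execute, TransferToClient). Populating the
  // request protos (declared in xla.pb.h) and checking the returned Status
  // values are elided; "service" is assumed to be a Service*.
  //
  //   CompileRequest compile_req;    // carries the computation graph
  //   CompileResponse compile_resp;  // receives the executable handle
  //   service->Compile(&compile_req, &compile_resp);
  //
  //   ExecuteRequest execute_req;    // references the executable handle
  //   ExecuteResponse execute_resp;  // receives the output global data
  //   service->Execute(&execute_req, &execute_resp);
  //
  //   TransferToClientRequest transfer_req;    // references the output data
  //   TransferToClientResponse transfer_resp;  // receives the result literal
  //   service->TransferToClient(&transfer_req, &transfer_resp);
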
  // Resets devices, clearing all existing state on all the devices associated
  // with this service (including memory allocated on the devices).
  //
  // ResetDevice may only be called where no previous Execution state on the
  // device is used by the next Execution.
  //
  // ResetDevice should be called before an Execution that expects the device
  // to be in the reset state, for example when the prior Execution modifies
  // device state (e.g., architectural state) that the next Execution depends
  // on.
  Status ResetDevice(const ResetDeviceRequest* arg,
                     ResetDeviceResponse* result) override;

  Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                              ComputeConstantResponse* result) override;

  // Returns the shape (with layout) of an array associated with a given data
  // handle.
  Status GetShape(const GetShapeRequest* arg,
                  GetShapeResponse* result) override;

  // Retrieves the statistics of a computation.
  Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg,
                                  ComputationStatsResponse* result) override;

  // Creates a unique channel handle that can be used for Send/Recv
  // instructions.
  Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
                             CreateChannelHandleResponse* result) override;

  // Returns the backend used to execute computations.
  const Backend& backend() const { return *execute_backend_; }
  Backend* mutable_backend() { return execute_backend_.get(); }

  // Creates an HloModuleConfig for the given program shape and arguments.
  // aot_options is optional; if not given, a default is used.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      absl::Span<const Shape* const> argument_shapes,
      const ExecutionOptions* execution_options,
      const AotCompilationOptions* aot_options = nullptr);

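  // Illustrative only: a minimal sketch of calling CreateModuleConfig,
  // assuming the caller already has a ProgramShape, the argument Shapes, and a
  // populated ExecutionOptions (all names below are placeholders).
  //
  //   std::vector<const Shape*> argument_shapes = {&arg0_shape, &arg1_shape};
  //   StatusOr<std::unique_ptr<HloModuleConfig>> config =
  //       service->CreateModuleConfig(program_shape, argument_shapes,
  //                                   &execution_options);
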
 private:
  // A private overload for Service itself, used by other methods within this
  // class.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      absl::Span<const ShapedBuffer* const> arguments,
      const ExecutionOptions& execution_options,
      const AotCompilationOptions* aot_options = nullptr);

  // Prepares the executors for parallel execution.
  StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
      const ExecutionOptions& execution_options, int64 requests_size,
      int64 request_index) const;

  // Prepares the arguments for parallel execution.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
      const ExecutionOptions& execution_options,
      absl::Span<const GlobalDataHandle* const> arguments) const;

 protected:
  friend class LocalExecutable;

  // The constructor is not public; use the NewService factory to create new
  // service objects.
  Service(const ServiceOptions& options,
          std::unique_ptr<Backend> execute_backend);

  // Resolves the given argument handles in the allocation tracker and returns
  // the corresponding allocations for every replica. The function also
  // verifies that each allocation matches the execution platform and device
  // ordinal of the corresponding replica.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
  ResolveAndValidateArguments(
      absl::Span<const GlobalDataHandle* const> arguments,
      absl::Span<se::StreamExecutor* const> stream_executors) const;

  // Builds an Executable for the given parameters.
  //
  // If device_allocator is not null, the compiler may use it to allocate temp
  // buffers, which the compiler is responsible for freeing. The allocator
  // given here need not match the allocator used when running the executable.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const HloModuleProto& module_proto,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      se::StreamExecutor* executor, const Compiler::CompileOptions& options,
      bool run_backend_only = false);

  // Same as BuildExecutable() above, but builds a list of Executables for the
  // given computations that may interact with each other.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      const std::vector<const HloModuleProto*>& module_protos,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend,
      std::vector<std::vector<se::StreamExecutor*>> executors,
      const Compiler::CompileOptions& options, bool run_backend_only = false);

  // Runs the given executable with the given arguments and registers the
  // result in the allocation tracker. The handle of the result from the
  // tracker is returned. If the parameter "profile" is not null, it points to
  // an ExecutionProfile object which will be filled in with profile data.
  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
      Executable* executable,
      absl::Span<const std::vector<const ShapedBuffer*>> arguments,
      Backend* backend, const DeviceHandle& device_handle,
      const string& result_tag, ExecutionProfile* profile);

  // Runs the given executables with the given arguments and registers the
  // result from each executable in the allocation tracker. The handles of the
  // results from the tracker are returned.
  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
      absl::Span<Executable* const> executables,
      absl::Span<const std::vector<std::vector<const ShapedBuffer*>>>
          arguments,
      Backend* backend, absl::Span<const DeviceHandle> device_handles,
      absl::Span<const string> result_tags, ExecutionProfile* profile);

  // Convenience function which checks whether the given client_shape
  // (presumably passed by the client to set the result layout) is valid for
  // the given computation result shape.
  Status ValidateResultShape(const Shape& client_shape,
                             const Shape& result_shape) const;

  // Returns the stream executors assigned to the replicas represented by the
  // given device handle. Each device_handle is a virtual replicated device
  // that represents a set of physical devices for the replicas.
  StatusOr<std::vector<se::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;

  ServiceOptions options_;

  // Cache containing previously built Executables.
  CompilationCache compilation_cache_;

  // Tracks channels created via the API.
  ChannelTracker channel_tracker_;

  // Tracks allocations made via the API and computation execution.
  AllocationTracker allocation_tracker_;

  // Tracks asynchronously launched executions via the API.
  ExecutionTracker execution_tracker_;

  // Backend to compile and execute computations on.
  std::unique_ptr<Backend> execute_backend_;

  TF_DISALLOW_COPY_AND_ASSIGN(Service);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_