• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// The "client library" instantiates a local (in-process) XLA service for use
// by this process and connects to it with a singleton XLA local client (one
// per platform). ClientLibrary::GetOrCreateLocalClient spawns the local
// service on first use and returns a client that is connected to it and
// ready to run XLA computations. ClientLibrary::GetOrCreateCompileOnlyClient
// does the same for compile-only services and clients.
21 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
22 #define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
23 
#include <functional>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
40 
41 namespace xla {
42 
43 // Options to configure the local client when it is created.
44 class LocalClientOptions {
45  public:
46   LocalClientOptions(
47       se::Platform* platform = nullptr, int number_of_replicas = 1,
48       int intra_op_parallelism_threads = -1,
49       const std::optional<std::set<int>>& allowed_devices = std::nullopt);
50 
51   // Set the platform backing the service, or nullptr for the default platform.
52   LocalClientOptions& set_platform(se::Platform* platform);
53   se::Platform* platform() const;
54 
55   // Set the number of replicas to use when compiling replicated
56   // programs.
57   LocalClientOptions& set_number_of_replicas(int number_of_replicas);
58   int number_of_replicas() const;
59 
60   // Sets the thread pool size for parallel execution of an individual operator.
61   LocalClientOptions& set_intra_op_parallelism_threads(int num_threads);
62   int intra_op_parallelism_threads() const;
63 
64   // Sets the allowed_devices set for selectively constructing stream executors
65   // on the platform.
66   LocalClientOptions& set_allowed_devices(
67       const std::optional<std::set<int>>& allowed_devices);
68   const std::optional<std::set<int>>& allowed_devices() const;
69 
70  private:
71   se::Platform* platform_;
72   int number_of_replicas_;
73   int intra_op_parallelism_threads_;
74   std::optional<std::set<int>> allowed_devices_;
75 };
76 
77 class ClientLibrary {
78  public:
79   // Singleton constructor-or-accessor -- returns a client for the application
80   // to issue XLA commands on. Arguments:
81   //
82   //   platform : The platform the underlying XLA service should target. If
83   //     null then default platform is used.
84   //   device_set: Set of device IDs for which the stream executor will be
85   //   created, for the given platform.
86   static StatusOr<LocalClient*> GetOrCreateLocalClient(
87       se::Platform* platform = nullptr,
88       const std::optional<std::set<int>>& allowed_devices = std::nullopt);
89   static StatusOr<LocalClient*> GetOrCreateLocalClient(
90       const LocalClientOptions& options);
91 
92   // Convenience "or-die" wrapper around the above which returns the existing
93   // client library or creates one with default platform and allocator.
94   static LocalClient* LocalClientOrDie();
95 
96   // Returns the service from the service thread. Only used in unit tests to
97   // access user computations from client.
98   static LocalService* GetXlaService(se::Platform* platform);
99 
100   // Singleton constructor-or-accessor for compile-only clients. Arguments:
101   //
102   //   platform : The platform the underlying XLA service should target. If
103   //     null then default platform is used.
104   static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
105       se::Platform* platform = nullptr);
106 
107   // Clears the local instance and compile only instance caches. The client
108   // pointers returned by the previous GetOrCreateLocalClient() or
109   // GetOrCreateCompileOnlyClient() invocations are not valid anymore.
110   static void DestroyLocalInstances();
111 
112  private:
113   // Returns the singleton instance of ClientLibrary.
114   static ClientLibrary& Singleton();
115 
116   ClientLibrary();
117   ~ClientLibrary();
118 
119   struct LocalInstance {
120     // Service that is wrapped by the singleton client object.
121     std::unique_ptr<LocalService> service;
122     // Singleton client object.
123     std::unique_ptr<LocalClient> client;
124   };
125 
126   struct CompileOnlyInstance {
127     // Service that is wrapped by the singleton client object.
128     std::unique_ptr<CompileOnlyService> service;
129     // Singleton client object.
130     std::unique_ptr<CompileOnlyClient> client;
131   };
132 
133   absl::Mutex service_mutex_;  // Guards the singleton creation state.
134   absl::flat_hash_map<se::Platform::Id, std::unique_ptr<LocalInstance>>
135       local_instances_ ABSL_GUARDED_BY(service_mutex_);
136 
137   absl::flat_hash_map<se::Platform::Id, std::unique_ptr<CompileOnlyInstance>>
138       compile_only_instances_ ABSL_GUARDED_BY(service_mutex_);
139 
140   ClientLibrary(const ClientLibrary&) = delete;
141   ClientLibrary& operator=(const ClientLibrary&) = delete;
142 };
143 
144 }  // namespace xla
145 
146 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
147