1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/client_library.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/service/backend.h"
22 #include "tensorflow/compiler/xla/service/platform_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace xla {
28
LocalClientOptions(se::Platform * platform,int number_of_replicas,int intra_op_parallelism_threads,const std::optional<std::set<int>> & allowed_devices)29 LocalClientOptions::LocalClientOptions(
30 se::Platform* platform, int number_of_replicas,
31 int intra_op_parallelism_threads,
32 const std::optional<std::set<int>>& allowed_devices)
33 : platform_(platform),
34 number_of_replicas_(number_of_replicas),
35 intra_op_parallelism_threads_(intra_op_parallelism_threads),
36 allowed_devices_(allowed_devices) {}
37
set_platform(se::Platform * platform)38 LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) {
39 platform_ = platform;
40 return *this;
41 }
42
platform() const43 se::Platform* LocalClientOptions::platform() const { return platform_; }
44
set_number_of_replicas(int number_of_replicas)45 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
46 int number_of_replicas) {
47 number_of_replicas_ = number_of_replicas;
48 return *this;
49 }
50
number_of_replicas() const51 int LocalClientOptions::number_of_replicas() const {
52 return number_of_replicas_;
53 }
54
set_intra_op_parallelism_threads(int num_threads)55 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
56 int num_threads) {
57 intra_op_parallelism_threads_ = num_threads;
58 return *this;
59 }
60
intra_op_parallelism_threads() const61 int LocalClientOptions::intra_op_parallelism_threads() const {
62 return intra_op_parallelism_threads_;
63 }
64
set_allowed_devices(const std::optional<std::set<int>> & allowed_devices)65 LocalClientOptions& LocalClientOptions::set_allowed_devices(
66 const std::optional<std::set<int>>& allowed_devices) {
67 allowed_devices_ = allowed_devices;
68 return *this;
69 }
70
allowed_devices() const71 const std::optional<std::set<int>>& LocalClientOptions::allowed_devices()
72 const {
73 return allowed_devices_;
74 }
75
Singleton()76 /* static */ ClientLibrary& ClientLibrary::Singleton() {
77 static ClientLibrary* c = new ClientLibrary;
78 return *c;
79 }
80
// The constructor and destructor use the compiler-generated behavior but are
// defined out of line in this translation unit.
// NOTE(review): presumably so the member types only need to be complete here
// rather than in the header — confirm against client_library.h.
ClientLibrary::ClientLibrary() = default;
ClientLibrary::~ClientLibrary() = default;
83
GetOrCreateLocalClient(se::Platform * platform,const std::optional<std::set<int>> & device_set)84 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
85 se::Platform* platform, const std::optional<std::set<int>>& device_set) {
86 LocalClientOptions default_options;
87 default_options.set_platform(platform);
88 default_options.set_allowed_devices(device_set);
89 return GetOrCreateLocalClient(default_options);
90 }
91
GetOrCreateLocalClient(const LocalClientOptions & options)92 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
93 const LocalClientOptions& options) {
94 se::Platform* platform = options.platform();
95 int replica_count = options.number_of_replicas();
96 ClientLibrary& client_library = Singleton();
97 absl::MutexLock lock(&client_library.service_mutex_);
98
99 if (platform == nullptr) {
100 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
101 }
102
103 auto it = client_library.local_instances_.find(platform->id());
104 if (it != client_library.local_instances_.end()) {
105 return it->second->client.get();
106 }
107
108 ServiceOptions service_options;
109 service_options.set_platform(platform);
110 service_options.set_number_of_replicas(replica_count);
111 service_options.set_intra_op_parallelism_threads(
112 options.intra_op_parallelism_threads());
113 service_options.set_allowed_devices(options.allowed_devices());
114 auto instance = std::make_unique<LocalInstance>();
115 TF_ASSIGN_OR_RETURN(instance->service,
116 LocalService::NewService(service_options));
117 instance->client = std::make_unique<LocalClient>(instance->service.get());
118 LocalClient* cl = instance->client.get();
119
120 client_library.local_instances_.insert(
121 std::make_pair(platform->id(), std::move(instance)));
122 return cl;
123 }
124
LocalClientOrDie()125 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
126 auto client_status = GetOrCreateLocalClient();
127 TF_CHECK_OK(client_status.status());
128 return client_status.ValueOrDie();
129 }
130
GetXlaService(se::Platform * platform)131 /* static */ LocalService* ClientLibrary::GetXlaService(
132 se::Platform* platform) {
133 ClientLibrary& client_library = Singleton();
134 absl::MutexLock lock(&client_library.service_mutex_);
135 auto it = client_library.local_instances_.find(platform->id());
136 CHECK(it != client_library.local_instances_.end());
137 return it->second->service.get();
138 }
139
140 /* static */ StatusOr<CompileOnlyClient*>
GetOrCreateCompileOnlyClient(se::Platform * platform)141 ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) {
142 ClientLibrary& client_library = Singleton();
143 absl::MutexLock lock(&client_library.service_mutex_);
144
145 if (platform == nullptr) {
146 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
147 }
148
149 auto it = client_library.compile_only_instances_.find(platform->id());
150 if (it != client_library.compile_only_instances_.end()) {
151 return it->second->client.get();
152 }
153
154 auto instance = std::make_unique<CompileOnlyInstance>();
155 TF_ASSIGN_OR_RETURN(instance->service,
156 CompileOnlyService::NewService(platform));
157 instance->client =
158 std::make_unique<CompileOnlyClient>(instance->service.get());
159 CompileOnlyClient* cl = instance->client.get();
160
161 client_library.compile_only_instances_.insert(
162 std::make_pair(platform->id(), std::move(instance)));
163 return cl;
164 }
165
DestroyLocalInstances()166 /* static */ void ClientLibrary::DestroyLocalInstances() {
167 ClientLibrary& client_library = Singleton();
168 absl::MutexLock lock(&client_library.service_mutex_);
169
170 client_library.local_instances_.clear();
171 client_library.compile_only_instances_.clear();
172 }
173
174 } // namespace xla
175