• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/client_library.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/service/backend.h"
22 #include "tensorflow/compiler/xla/service/platform_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace xla {
28 
LocalClientOptions(se::Platform * platform,int number_of_replicas,int intra_op_parallelism_threads,const std::optional<std::set<int>> & allowed_devices)29 LocalClientOptions::LocalClientOptions(
30     se::Platform* platform, int number_of_replicas,
31     int intra_op_parallelism_threads,
32     const std::optional<std::set<int>>& allowed_devices)
33     : platform_(platform),
34       number_of_replicas_(number_of_replicas),
35       intra_op_parallelism_threads_(intra_op_parallelism_threads),
36       allowed_devices_(allowed_devices) {}
37 
set_platform(se::Platform * platform)38 LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) {
39   platform_ = platform;
40   return *this;
41 }
42 
platform() const43 se::Platform* LocalClientOptions::platform() const { return platform_; }
44 
set_number_of_replicas(int number_of_replicas)45 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
46     int number_of_replicas) {
47   number_of_replicas_ = number_of_replicas;
48   return *this;
49 }
50 
number_of_replicas() const51 int LocalClientOptions::number_of_replicas() const {
52   return number_of_replicas_;
53 }
54 
set_intra_op_parallelism_threads(int num_threads)55 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
56     int num_threads) {
57   intra_op_parallelism_threads_ = num_threads;
58   return *this;
59 }
60 
intra_op_parallelism_threads() const61 int LocalClientOptions::intra_op_parallelism_threads() const {
62   return intra_op_parallelism_threads_;
63 }
64 
set_allowed_devices(const std::optional<std::set<int>> & allowed_devices)65 LocalClientOptions& LocalClientOptions::set_allowed_devices(
66     const std::optional<std::set<int>>& allowed_devices) {
67   allowed_devices_ = allowed_devices;
68   return *this;
69 }
70 
allowed_devices() const71 const std::optional<std::set<int>>& LocalClientOptions::allowed_devices()
72     const {
73   return allowed_devices_;
74 }
75 
Singleton()76 /* static */ ClientLibrary& ClientLibrary::Singleton() {
77   static ClientLibrary* c = new ClientLibrary;
78   return *c;
79 }
80 
81 ClientLibrary::ClientLibrary() = default;
82 ClientLibrary::~ClientLibrary() = default;
83 
GetOrCreateLocalClient(se::Platform * platform,const std::optional<std::set<int>> & device_set)84 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
85     se::Platform* platform, const std::optional<std::set<int>>& device_set) {
86   LocalClientOptions default_options;
87   default_options.set_platform(platform);
88   default_options.set_allowed_devices(device_set);
89   return GetOrCreateLocalClient(default_options);
90 }
91 
GetOrCreateLocalClient(const LocalClientOptions & options)92 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
93     const LocalClientOptions& options) {
94   se::Platform* platform = options.platform();
95   int replica_count = options.number_of_replicas();
96   ClientLibrary& client_library = Singleton();
97   absl::MutexLock lock(&client_library.service_mutex_);
98 
99   if (platform == nullptr) {
100     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
101   }
102 
103   auto it = client_library.local_instances_.find(platform->id());
104   if (it != client_library.local_instances_.end()) {
105     return it->second->client.get();
106   }
107 
108   ServiceOptions service_options;
109   service_options.set_platform(platform);
110   service_options.set_number_of_replicas(replica_count);
111   service_options.set_intra_op_parallelism_threads(
112       options.intra_op_parallelism_threads());
113   service_options.set_allowed_devices(options.allowed_devices());
114   auto instance = std::make_unique<LocalInstance>();
115   TF_ASSIGN_OR_RETURN(instance->service,
116                       LocalService::NewService(service_options));
117   instance->client = std::make_unique<LocalClient>(instance->service.get());
118   LocalClient* cl = instance->client.get();
119 
120   client_library.local_instances_.insert(
121       std::make_pair(platform->id(), std::move(instance)));
122   return cl;
123 }
124 
LocalClientOrDie()125 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
126   auto client_status = GetOrCreateLocalClient();
127   TF_CHECK_OK(client_status.status());
128   return client_status.ValueOrDie();
129 }
130 
GetXlaService(se::Platform * platform)131 /* static */ LocalService* ClientLibrary::GetXlaService(
132     se::Platform* platform) {
133   ClientLibrary& client_library = Singleton();
134   absl::MutexLock lock(&client_library.service_mutex_);
135   auto it = client_library.local_instances_.find(platform->id());
136   CHECK(it != client_library.local_instances_.end());
137   return it->second->service.get();
138 }
139 
140 /* static */ StatusOr<CompileOnlyClient*>
GetOrCreateCompileOnlyClient(se::Platform * platform)141 ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) {
142   ClientLibrary& client_library = Singleton();
143   absl::MutexLock lock(&client_library.service_mutex_);
144 
145   if (platform == nullptr) {
146     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
147   }
148 
149   auto it = client_library.compile_only_instances_.find(platform->id());
150   if (it != client_library.compile_only_instances_.end()) {
151     return it->second->client.get();
152   }
153 
154   auto instance = std::make_unique<CompileOnlyInstance>();
155   TF_ASSIGN_OR_RETURN(instance->service,
156                       CompileOnlyService::NewService(platform));
157   instance->client =
158       std::make_unique<CompileOnlyClient>(instance->service.get());
159   CompileOnlyClient* cl = instance->client.get();
160 
161   client_library.compile_only_instances_.insert(
162       std::make_pair(platform->id(), std::move(instance)));
163   return cl;
164 }
165 
DestroyLocalInstances()166 /* static */ void ClientLibrary::DestroyLocalInstances() {
167   ClientLibrary& client_library = Singleton();
168   absl::MutexLock lock(&client_library.service_mutex_);
169 
170   client_library.local_instances_.clear();
171   client_library.compile_only_instances_.clear();
172 }
173 
174 }  // namespace xla
175