1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/client_library.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/service/backend.h"
20 #include "tensorflow/compiler/xla/service/platform_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/compiler/xla/util.h"
23 #include "tensorflow/core/platform/logging.h"
24
25 namespace xla {
26
LocalClientOptions(se::Platform * platform,int number_of_replicas,int intra_op_parallelism_threads,const absl::optional<std::set<int>> & allowed_devices)27 LocalClientOptions::LocalClientOptions(
28 se::Platform* platform, int number_of_replicas,
29 int intra_op_parallelism_threads,
30 const absl::optional<std::set<int>>& allowed_devices)
31 : platform_(platform),
32 number_of_replicas_(number_of_replicas),
33 intra_op_parallelism_threads_(intra_op_parallelism_threads),
34 allowed_devices_(allowed_devices) {}
35
set_platform(se::Platform * platform)36 LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) {
37 platform_ = platform;
38 return *this;
39 }
40
platform() const41 se::Platform* LocalClientOptions::platform() const { return platform_; }
42
set_number_of_replicas(int number_of_replicas)43 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
44 int number_of_replicas) {
45 number_of_replicas_ = number_of_replicas;
46 return *this;
47 }
48
number_of_replicas() const49 int LocalClientOptions::number_of_replicas() const {
50 return number_of_replicas_;
51 }
52
set_intra_op_parallelism_threads(int num_threads)53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
54 int num_threads) {
55 intra_op_parallelism_threads_ = num_threads;
56 return *this;
57 }
58
intra_op_parallelism_threads() const59 int LocalClientOptions::intra_op_parallelism_threads() const {
60 return intra_op_parallelism_threads_;
61 }
62
set_allowed_devices(const absl::optional<std::set<int>> & allowed_devices)63 LocalClientOptions& LocalClientOptions::set_allowed_devices(
64 const absl::optional<std::set<int>>& allowed_devices) {
65 allowed_devices_ = allowed_devices;
66 return *this;
67 }
68
allowed_devices() const69 const absl::optional<std::set<int>>& LocalClientOptions::allowed_devices()
70 const {
71 return allowed_devices_;
72 }
73
Singleton()74 /* static */ ClientLibrary& ClientLibrary::Singleton() {
75 static ClientLibrary* c = new ClientLibrary;
76 return *c;
77 }
78
// Constructor and destructor are declared in the header and defaulted here,
// out of line. NOTE(review): presumably this is so the member types only need
// to be complete in this translation unit; confirm against the header. In this
// file, ClientLibrary is only constructed by Singleton().
ClientLibrary::ClientLibrary() = default;
ClientLibrary::~ClientLibrary() = default;
81
GetOrCreateLocalClient(se::Platform * platform,const absl::optional<std::set<int>> & device_set)82 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
83 se::Platform* platform, const absl::optional<std::set<int>>& device_set) {
84 LocalClientOptions default_options;
85 default_options.set_platform(platform);
86 default_options.set_allowed_devices(device_set);
87 return GetOrCreateLocalClient(default_options);
88 }
89
GetOrCreateLocalClient(const LocalClientOptions & options)90 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
91 const LocalClientOptions& options) {
92 se::Platform* platform = options.platform();
93 int replica_count = options.number_of_replicas();
94 ClientLibrary& client_library = Singleton();
95 tensorflow::mutex_lock lock(client_library.service_mutex_);
96
97 if (platform == nullptr) {
98 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
99 }
100
101 auto it = client_library.local_instances_.find(platform->id());
102 if (it != client_library.local_instances_.end()) {
103 return it->second->client.get();
104 }
105
106 ServiceOptions service_options;
107 service_options.set_platform(platform);
108 service_options.set_number_of_replicas(replica_count);
109 service_options.set_intra_op_parallelism_threads(
110 options.intra_op_parallelism_threads());
111 service_options.set_allowed_devices(options.allowed_devices());
112 auto instance = absl::make_unique<LocalInstance>();
113 TF_ASSIGN_OR_RETURN(instance->service,
114 LocalService::NewService(service_options));
115 instance->client = absl::make_unique<LocalClient>(instance->service.get());
116 LocalClient* cl = instance->client.get();
117
118 client_library.local_instances_.insert(
119 std::make_pair(platform->id(), std::move(instance)));
120 return cl;
121 }
122
LocalClientOrDie()123 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
124 auto client_status = GetOrCreateLocalClient();
125 TF_CHECK_OK(client_status.status());
126 return client_status.ValueOrDie();
127 }
128
GetXlaService(se::Platform * platform)129 /* static */ LocalService* ClientLibrary::GetXlaService(
130 se::Platform* platform) {
131 ClientLibrary& client_library = Singleton();
132 tensorflow::mutex_lock lock(client_library.service_mutex_);
133 auto it = client_library.local_instances_.find(platform->id());
134 CHECK(it != client_library.local_instances_.end());
135 return it->second->service.get();
136 }
137
138 /* static */ StatusOr<CompileOnlyClient*>
GetOrCreateCompileOnlyClient(se::Platform * platform)139 ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) {
140 ClientLibrary& client_library = Singleton();
141 tensorflow::mutex_lock lock(client_library.service_mutex_);
142
143 if (platform == nullptr) {
144 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
145 }
146
147 auto it = client_library.compile_only_instances_.find(platform->id());
148 if (it != client_library.compile_only_instances_.end()) {
149 return it->second->client.get();
150 }
151
152 auto instance = absl::make_unique<CompileOnlyInstance>();
153 TF_ASSIGN_OR_RETURN(instance->service,
154 CompileOnlyService::NewService(platform));
155 instance->client =
156 absl::make_unique<CompileOnlyClient>(instance->service.get());
157 CompileOnlyClient* cl = instance->client.get();
158
159 client_library.compile_only_instances_.insert(
160 std::make_pair(platform->id(), std::move(instance)));
161 return cl;
162 }
163
DestroyLocalInstances()164 /* static */ void ClientLibrary::DestroyLocalInstances() {
165 ClientLibrary& client_library = Singleton();
166 tensorflow::mutex_lock lock(client_library.service_mutex_);
167
168 client_library.local_instances_.clear();
169 client_library.compile_only_instances_.clear();
170 }
171
172 } // namespace xla
173