1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/client_library.h"
17
18 #include "tensorflow/compiler/xla/service/backend.h"
19 #include "tensorflow/compiler/xla/service/platform_util.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 #include "tensorflow/compiler/xla/util.h"
22 #include "tensorflow/core/platform/logging.h"
23
24 namespace xla {
25
LocalClientOptions(perftools::gputools::Platform * platform,int number_of_replicas,int intra_op_parallelism_threads)26 LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform,
27 int number_of_replicas,
28 int intra_op_parallelism_threads)
29 : platform_(platform),
30 number_of_replicas_(number_of_replicas),
31 intra_op_parallelism_threads_(intra_op_parallelism_threads) {}
32
set_platform(perftools::gputools::Platform * platform)33 LocalClientOptions& LocalClientOptions::set_platform(
34 perftools::gputools::Platform* platform) {
35 platform_ = platform;
36 return *this;
37 }
38
platform() const39 perftools::gputools::Platform* LocalClientOptions::platform() const {
40 return platform_;
41 }
42
set_number_of_replicas(int number_of_replicas)43 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
44 int number_of_replicas) {
45 number_of_replicas_ = number_of_replicas;
46 return *this;
47 }
48
number_of_replicas() const49 int LocalClientOptions::number_of_replicas() const {
50 return number_of_replicas_;
51 }
52
set_intra_op_parallelism_threads(int num_threads)53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
54 int num_threads) {
55 intra_op_parallelism_threads_ = num_threads;
56 return *this;
57 }
58
intra_op_parallelism_threads() const59 int LocalClientOptions::intra_op_parallelism_threads() const {
60 return intra_op_parallelism_threads_;
61 }
62
Singleton()63 /* static */ ClientLibrary& ClientLibrary::Singleton() {
64 static ClientLibrary* c = new ClientLibrary;
65 return *c;
66 }
67
// Constructor and destructor are defaulted out of line. Note that
// Singleton() never deletes the instance it creates, so the destructor does
// not run for the process-wide library.
ClientLibrary::ClientLibrary() = default;
ClientLibrary::~ClientLibrary() = default;
70
GetOrCreateLocalClient(perftools::gputools::Platform * platform)71 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
72 perftools::gputools::Platform* platform) {
73 LocalClientOptions default_options;
74 default_options.set_platform(platform);
75 return GetOrCreateLocalClient(default_options);
76 }
77
GetOrCreateLocalClient(const LocalClientOptions & options)78 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
79 const LocalClientOptions& options) {
80 perftools::gputools::Platform* platform = options.platform();
81 int replica_count = options.number_of_replicas();
82 ClientLibrary& client_library = Singleton();
83 tensorflow::mutex_lock lock(client_library.service_mutex_);
84
85 if (platform == nullptr) {
86 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
87 }
88
89 auto it = client_library.local_instances_.find(platform->id());
90 if (it != client_library.local_instances_.end()) {
91 return it->second->client.get();
92 }
93
94 ServiceOptions service_options;
95 service_options.set_platform(platform);
96 service_options.set_number_of_replicas(replica_count);
97 service_options.set_intra_op_parallelism_threads(
98 options.intra_op_parallelism_threads());
99
100 auto instance = MakeUnique<LocalInstance>();
101 TF_ASSIGN_OR_RETURN(instance->service,
102 LocalService::NewService(service_options));
103 instance->client = MakeUnique<LocalClient>(instance->service.get());
104 LocalClient* cl = instance->client.get();
105
106 client_library.local_instances_.insert(
107 std::make_pair(platform->id(), std::move(instance)));
108 return cl;
109 }
110
LocalClientOrDie()111 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
112 auto client_status = GetOrCreateLocalClient();
113 TF_CHECK_OK(client_status.status());
114 return client_status.ValueOrDie();
115 }
116
GetXlaService(perftools::gputools::Platform * platform)117 /* static */ LocalService* ClientLibrary::GetXlaService(
118 perftools::gputools::Platform* platform) {
119 ClientLibrary& client_library = Singleton();
120 tensorflow::mutex_lock lock(client_library.service_mutex_);
121 auto it = client_library.local_instances_.find(platform->id());
122 CHECK(it != client_library.local_instances_.end());
123 return it->second->service.get();
124 }
125
126 /* static */ StatusOr<CompileOnlyClient*>
GetOrCreateCompileOnlyClient(perftools::gputools::Platform * platform)127 ClientLibrary::GetOrCreateCompileOnlyClient(
128 perftools::gputools::Platform* platform) {
129 ClientLibrary& client_library = Singleton();
130 tensorflow::mutex_lock lock(client_library.service_mutex_);
131
132 if (platform == nullptr) {
133 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
134 }
135
136 auto it = client_library.compile_only_instances_.find(platform->id());
137 if (it != client_library.compile_only_instances_.end()) {
138 return it->second->client.get();
139 }
140
141 auto instance = MakeUnique<CompileOnlyInstance>();
142 TF_ASSIGN_OR_RETURN(instance->service,
143 CompileOnlyService::NewService(platform));
144 instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
145 CompileOnlyClient* cl = instance->client.get();
146
147 client_library.compile_only_instances_.insert(
148 std::make_pair(platform->id(), std::move(instance)));
149 return cl;
150 }
151
DestroyLocalInstances()152 /* static */ void ClientLibrary::DestroyLocalInstances() {
153 ClientLibrary& client_library = Singleton();
154 tensorflow::mutex_lock lock(client_library.service_mutex_);
155
156 client_library.local_instances_.clear();
157 client_library.compile_only_instances_.clear();
158 }
159
160 } // namespace xla
161