1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/stream_executor/tpu/tpu_platform.h"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
24
namespace tensorflow {
namespace tpu {

// Stable platform id used to register and look up the TPU platform with
// MultiPlatformManager; the value comes from GetTpuPlatformId().
const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId();
// Process-wide singleton, set by RegisterTpuPlatform() below; remains
// nullptr until registration runs.
TpuPlatform* tpu_registered_platform = nullptr;

// Local aliases for the stream_executor status types used throughout
// this file.
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
34
TpuPlatform()35 TpuPlatform::TpuPlatform() : name_("TPU") {
36 platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
37 }
38
GetRegisteredPlatform()39 TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
40 return tpu_registered_platform;
41 }
42
Initialize(const std::map<std::string,std::string> & platform_options)43 Status TpuPlatform::Initialize(
44 const std::map<std::string, std::string>& platform_options) {
45 StatusHelper status;
46
47 size_t options_size = platform_options.size();
48 const char** options_key =
49 static_cast<const char**>(malloc(sizeof(const char*) * options_size));
50 const char** options_value =
51 static_cast<const char**>(malloc(sizeof(const char*) * options_size));
52
53 size_t i = 0;
54 for (const auto& option : platform_options) {
55 options_key[i] = option.first.c_str();
56 options_value[i] = option.second.c_str();
57 i++;
58 }
59
60 tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
61 platform_, options_size, options_key, options_value, status.c_status);
62
63 free(options_key);
64 free(options_value);
65
66 return status.status();
67 }
68
Initialized() const69 bool TpuPlatform::Initialized() const {
70 return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
71 }
72
~TpuPlatform()73 TpuPlatform::~TpuPlatform() {
74 tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
75 }
76
VisibleDeviceCount() const77 int TpuPlatform::VisibleDeviceCount() const {
78 return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
79 }
80
GetExecutor(const::stream_executor::StreamExecutorConfig & config)81 StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
82 const ::stream_executor::StreamExecutorConfig& config) {
83 return executor_cache_.GetOrCreate(
84 config, [&]() { return GetUncachedExecutor(config); });
85 }
86
87 StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
GetUncachedExecutor(const::stream_executor::StreamExecutorConfig & config)88 TpuPlatform::GetUncachedExecutor(
89 const ::stream_executor::StreamExecutorConfig& config) {
90 SE_StreamExecutorConfig* c_config =
91 tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
92
93 tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
94 config.ordinal);
95
96 StatusHelper status;
97 SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
98 platform_, c_config, status.c_status);
99 tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
100 if (!status.ok()) {
101 return status.status();
102 }
103 return std::make_unique<stream_executor::StreamExecutor>(
104 this, std::make_unique<TpuExecutor>(this, executor), config.ordinal);
105 }
106
id() const107 ::stream_executor::Platform::Id TpuPlatform::id() const {
108 return TpuPlatform::kId;
109 }
110
Name() const111 const std::string& TpuPlatform::Name() const { return name_; }
112
TpuMemoryLimit()113 int64 TpuPlatform::TpuMemoryLimit() {
114 return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
115 }
116
ShouldRegisterTpuDeviceToDeviceCopy()117 bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
118 return tpu::ExecutorApiFn()
119 ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
120 }
121
GetTopologyPtr()122 const tensorflow::tpu::TpuTopologyPtr TpuPlatform::GetTopologyPtr() {
123 return tpu::ExecutorApiFn()->TpuPlatform_GetTopologyPtrFn(platform_);
124 }
125
GetTpuHostLocation() const126 const tensorflow::tpu::TpuHostLocationExternal TpuPlatform::GetTpuHostLocation()
127 const {
128 return tpu::TpuHostLocationExternal(
129 tpu::ExecutorApiFn()->TpuPlatform_GetHostLocationFn(platform_));
130 }
131
version() const132 TpuRuntimeVersion TpuPlatform::version() const {
133 return tpu::ExecutorApiFn()->TpuPlatform_GetRuntimeVersionFn(platform_);
134 }
135
InsertEvent(stream_executor::internal::EventInterface * key,SE_Event * val)136 void TpuPlatform::InsertEvent(stream_executor::internal::EventInterface* key,
137 SE_Event* val) {
138 tensorflow::mutex_lock lock(event_map_mu_);
139 event_map_[key] = val;
140 }
141
LookupEvent(stream_executor::internal::EventInterface * key)142 SE_Event* TpuPlatform::LookupEvent(
143 stream_executor::internal::EventInterface* key) {
144 tensorflow::tf_shared_lock lock(event_map_mu_);
145 return event_map_.at(key);
146 }
147
EraseEvent(stream_executor::internal::EventInterface * key)148 void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) {
149 tensorflow::mutex_lock lock(event_map_mu_);
150 event_map_.erase(key);
151 }
152
TpusPerHost(int * tpus)153 Status TpuPlatform::TpusPerHost(int* tpus) {
154 TF_Status* status = TF_NewStatus();
155
156 if (tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn == nullptr) {
157 *tpus = 0;
158 return Status::OK();
159 }
160
161 tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
162 auto ret_status = StatusFromTF_Status(status);
163 TF_DeleteStatus(status);
164 return ret_status;
165 }
166
TpuMemoryLimit(int64 * memory_limit)167 Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) {
168 TF_Status* status = TF_NewStatus();
169
170 if (tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn == nullptr) {
171 *memory_limit = 0;
172 return Status::OK();
173 }
174
175 tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
176 reinterpret_cast<int64_t*>(memory_limit), status);
177 auto ret_status = StatusFromTF_Status(status);
178 TF_DeleteStatus(status);
179 return ret_status;
180 }
181
RegisterTpuPlatform()182 bool RegisterTpuPlatform() {
183 // Silently bail if the underlying TPU C API isn't initialized. This is useful
184 // for code that unconditionally calls RegisterTpuPlatform() but doesn't link
185 // in the underlying TPU library when not running on TPU.
186 if (!tpu::IsInitialized(tpu::ExecutorApiFn())) {
187 return true;
188 }
189 static bool tpu_platform_registered = false;
190 if (!tpu_platform_registered) {
191 tpu_registered_platform = new TpuPlatform();
192 std::unique_ptr<stream_executor::Platform> platform(
193 tpu_registered_platform);
194 SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
195 std::move(platform)));
196 tpu_platform_registered = true;
197 }
198 return true;
199 }
200
201 } // namespace tpu
202 } // namespace tensorflow
203