• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
17 
18 #include "tensorflow/c/tf_status.h"
19 #include "tensorflow/c/tf_status_helper.h"
20 #include "tensorflow/core/tpu/tpu_api.h"
21 #include "tensorflow/stream_executor/tpu/status_helper.h"
22 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
23 #include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
24 
25 namespace tensorflow {
26 namespace tpu {
27 
28 const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId();
29 TpuPlatform* tpu_registered_platform = nullptr;
30 
31 using Status = ::stream_executor::port::Status;
32 template <typename T>
33 using StatusOr = ::stream_executor::port::StatusOr<T>;
34 
TpuPlatform()35 TpuPlatform::TpuPlatform() : name_("TPU") {
36   platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
37 }
38 
GetRegisteredPlatform()39 TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
40   return tpu_registered_platform;
41 }
42 
Initialize(const std::map<std::string,std::string> & platform_options)43 Status TpuPlatform::Initialize(
44     const std::map<std::string, std::string>& platform_options) {
45   StatusHelper status;
46 
47   size_t options_size = platform_options.size();
48   const char** options_key =
49       static_cast<const char**>(malloc(sizeof(const char*) * options_size));
50   const char** options_value =
51       static_cast<const char**>(malloc(sizeof(const char*) * options_size));
52 
53   size_t i = 0;
54   for (const auto& option : platform_options) {
55     options_key[i] = option.first.c_str();
56     options_value[i] = option.second.c_str();
57     i++;
58   }
59 
60   tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
61       platform_, options_size, options_key, options_value, status.c_status);
62 
63   free(options_key);
64   free(options_value);
65 
66   return status.status();
67 }
68 
Initialized() const69 bool TpuPlatform::Initialized() const {
70   return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
71 }
72 
~TpuPlatform()73 TpuPlatform::~TpuPlatform() {
74   tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
75 }
76 
VisibleDeviceCount() const77 int TpuPlatform::VisibleDeviceCount() const {
78   return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
79 }
80 
GetExecutor(const::stream_executor::StreamExecutorConfig & config)81 StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
82     const ::stream_executor::StreamExecutorConfig& config) {
83   return executor_cache_.GetOrCreate(
84       config, [&]() { return GetUncachedExecutor(config); });
85 }
86 
87 StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
GetUncachedExecutor(const::stream_executor::StreamExecutorConfig & config)88 TpuPlatform::GetUncachedExecutor(
89     const ::stream_executor::StreamExecutorConfig& config) {
90   SE_StreamExecutorConfig* c_config =
91       tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
92 
93   tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
94                                                              config.ordinal);
95 
96   StatusHelper status;
97   SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
98       platform_, c_config, status.c_status);
99   tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
100   if (!status.ok()) {
101     return status.status();
102   }
103   return std::make_unique<stream_executor::StreamExecutor>(
104       this, std::make_unique<TpuExecutor>(this, executor), config.ordinal);
105 }
106 
id() const107 ::stream_executor::Platform::Id TpuPlatform::id() const {
108   return TpuPlatform::kId;
109 }
110 
Name() const111 const std::string& TpuPlatform::Name() const { return name_; }
112 
TpuMemoryLimit()113 int64 TpuPlatform::TpuMemoryLimit() {
114   return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
115 }
116 
ShouldRegisterTpuDeviceToDeviceCopy()117 bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
118   return tpu::ExecutorApiFn()
119       ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
120 }
121 
GetTopologyPtr()122 const tensorflow::tpu::TpuTopologyPtr TpuPlatform::GetTopologyPtr() {
123   return tpu::ExecutorApiFn()->TpuPlatform_GetTopologyPtrFn(platform_);
124 }
125 
GetTpuHostLocation() const126 const tensorflow::tpu::TpuHostLocationExternal TpuPlatform::GetTpuHostLocation()
127     const {
128   return tpu::TpuHostLocationExternal(
129       tpu::ExecutorApiFn()->TpuPlatform_GetHostLocationFn(platform_));
130 }
131 
version() const132 TpuRuntimeVersion TpuPlatform::version() const {
133   return tpu::ExecutorApiFn()->TpuPlatform_GetRuntimeVersionFn(platform_);
134 }
135 
InsertEvent(stream_executor::internal::EventInterface * key,SE_Event * val)136 void TpuPlatform::InsertEvent(stream_executor::internal::EventInterface* key,
137                               SE_Event* val) {
138   tensorflow::mutex_lock lock(event_map_mu_);
139   event_map_[key] = val;
140 }
141 
LookupEvent(stream_executor::internal::EventInterface * key)142 SE_Event* TpuPlatform::LookupEvent(
143     stream_executor::internal::EventInterface* key) {
144   tensorflow::tf_shared_lock lock(event_map_mu_);
145   return event_map_.at(key);
146 }
147 
EraseEvent(stream_executor::internal::EventInterface * key)148 void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) {
149   tensorflow::mutex_lock lock(event_map_mu_);
150   event_map_.erase(key);
151 }
152 
TpusPerHost(int * tpus)153 Status TpuPlatform::TpusPerHost(int* tpus) {
154   TF_Status* status = TF_NewStatus();
155 
156   if (tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn == nullptr) {
157     *tpus = 0;
158     return Status::OK();
159   }
160 
161   tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
162   auto ret_status = StatusFromTF_Status(status);
163   TF_DeleteStatus(status);
164   return ret_status;
165 }
166 
TpuMemoryLimit(int64 * memory_limit)167 Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) {
168   TF_Status* status = TF_NewStatus();
169 
170   if (tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn == nullptr) {
171     *memory_limit = 0;
172     return Status::OK();
173   }
174 
175   tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
176       reinterpret_cast<int64_t*>(memory_limit), status);
177   auto ret_status = StatusFromTF_Status(status);
178   TF_DeleteStatus(status);
179   return ret_status;
180 }
181 
RegisterTpuPlatform()182 bool RegisterTpuPlatform() {
183   // Silently bail if the underlying TPU C API isn't initialized. This is useful
184   // for code that unconditionally calls RegisterTpuPlatform() but doesn't link
185   // in the underlying TPU library when not running on TPU.
186   if (!tpu::IsInitialized(tpu::ExecutorApiFn())) {
187     return true;
188   }
189   static bool tpu_platform_registered = false;
190   if (!tpu_platform_registered) {
191     tpu_registered_platform = new TpuPlatform();
192     std::unique_ptr<stream_executor::Platform> platform(
193         tpu_registered_platform);
194     SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
195         std::move(platform)));
196     tpu_platform_registered = true;
197   }
198   return true;
199 }
200 
201 }  // namespace tpu
202 }  // namespace tensorflow
203