1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21
22 #include "tensorflow/core/framework/op_def_builder.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/host_info.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/platform/types.h"
31
32 namespace tensorflow {
33
DefaultValidator(const OpRegistryInterface & op_registry)34 Status DefaultValidator(const OpRegistryInterface& op_registry) {
35 LOG(WARNING) << "No kernel validator registered with OpRegistry.";
36 return Status::OK();
37 }
38
39 // OpRegistry -----------------------------------------------------------------
40
~OpRegistryInterface()41 OpRegistryInterface::~OpRegistryInterface() {}
42
LookUpOpDef(const string & op_type_name,const OpDef ** op_def) const43 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
44 const OpDef** op_def) const {
45 *op_def = nullptr;
46 const OpRegistrationData* op_reg_data = nullptr;
47 TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
48 *op_def = &op_reg_data->op_def;
49 return Status::OK();
50 }
51
OpRegistry()52 OpRegistry::OpRegistry()
53 : initialized_(false), op_registry_validator_(DefaultValidator) {}
54
~OpRegistry()55 OpRegistry::~OpRegistry() {
56 for (const auto& e : registry_) delete e.second;
57 }
58
Register(const OpRegistrationDataFactory & op_data_factory)59 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
60 mutex_lock lock(mu_);
61 if (initialized_) {
62 TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
63 } else {
64 deferred_.push_back(op_data_factory);
65 }
66 }
67
68 namespace {
69 // Helper function that returns Status message for failed LookUp.
OpNotFound(const string & op_type_name)70 Status OpNotFound(const string& op_type_name) {
71 Status status = errors::NotFound(
72 "Op type not registered '", op_type_name, "' in binary running on ",
73 port::Hostname(), ". ",
74 "Make sure the Op and Kernel are registered in the binary running in "
75 "this process. Note that if you are loading a saved graph which used ops "
76 "from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
77 "before importing the graph, as contrib ops are lazily registered when "
78 "the module is first accessed.");
79 VLOG(1) << status.ToString();
80 return status;
81 }
82 } // namespace
83
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const84 Status OpRegistry::LookUp(const string& op_type_name,
85 const OpRegistrationData** op_reg_data) const {
86 if ((*op_reg_data = LookUp(op_type_name))) return Status::OK();
87 return OpNotFound(op_type_name);
88 }
89
LookUp(const string & op_type_name) const90 const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const {
91 {
92 tf_shared_lock l(mu_);
93 if (initialized_) {
94 if (const OpRegistrationData* res =
95 gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
96 return res;
97 }
98 }
99 }
100 return LookUpSlow(op_type_name);
101 }
102
LookUpSlow(const string & op_type_name) const103 const OpRegistrationData* OpRegistry::LookUpSlow(
104 const string& op_type_name) const {
105 const OpRegistrationData* res = nullptr;
106
107 bool first_call = false;
108 bool first_unregistered = false;
109 { // Scope for lock.
110 mutex_lock lock(mu_);
111 first_call = MustCallDeferred();
112 res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
113
114 static bool unregistered_before = false;
115 first_unregistered = !unregistered_before && (res == nullptr);
116 if (first_unregistered) {
117 unregistered_before = true;
118 }
119 // Note: Can't hold mu_ while calling Export() below.
120 }
121 if (first_call) {
122 TF_QCHECK_OK(op_registry_validator_(*this));
123 }
124 if (res == nullptr) {
125 if (first_unregistered) {
126 OpList op_list;
127 Export(true, &op_list);
128 if (VLOG_IS_ON(3)) {
129 LOG(INFO) << "All registered Ops:";
130 for (const auto& op : op_list.op()) {
131 LOG(INFO) << SummarizeOpDef(op);
132 }
133 }
134 }
135 }
136 return res;
137 }
138
GetRegisteredOps(std::vector<OpDef> * op_defs)139 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
140 mutex_lock lock(mu_);
141 MustCallDeferred();
142 for (const auto& p : registry_) {
143 op_defs->push_back(p.second->op_def);
144 }
145 }
146
GetOpRegistrationData(std::vector<OpRegistrationData> * op_data)147 void OpRegistry::GetOpRegistrationData(
148 std::vector<OpRegistrationData>* op_data) {
149 mutex_lock lock(mu_);
150 MustCallDeferred();
151 for (const auto& p : registry_) {
152 op_data->push_back(*p.second);
153 }
154 }
155
SetWatcher(const Watcher & watcher)156 Status OpRegistry::SetWatcher(const Watcher& watcher) {
157 mutex_lock lock(mu_);
158 if (watcher_ && watcher) {
159 return errors::AlreadyExists(
160 "Cannot over-write a valid watcher with another.");
161 }
162 watcher_ = watcher;
163 return Status::OK();
164 }
165
Export(bool include_internal,OpList * ops) const166 void OpRegistry::Export(bool include_internal, OpList* ops) const {
167 mutex_lock lock(mu_);
168 MustCallDeferred();
169
170 std::vector<std::pair<string, const OpRegistrationData*>> sorted(
171 registry_.begin(), registry_.end());
172 std::sort(sorted.begin(), sorted.end());
173
174 auto out = ops->mutable_op();
175 out->Clear();
176 out->Reserve(sorted.size());
177
178 for (const auto& item : sorted) {
179 if (include_internal || !absl::StartsWith(item.first, "_")) {
180 *out->Add() = item.second->op_def;
181 }
182 }
183 }
184
DeferRegistrations()185 void OpRegistry::DeferRegistrations() {
186 mutex_lock lock(mu_);
187 initialized_ = false;
188 }
189
ClearDeferredRegistrations()190 void OpRegistry::ClearDeferredRegistrations() {
191 mutex_lock lock(mu_);
192 deferred_.clear();
193 }
194
ProcessRegistrations() const195 Status OpRegistry::ProcessRegistrations() const {
196 mutex_lock lock(mu_);
197 return CallDeferred();
198 }
199
DebugString(bool include_internal) const200 string OpRegistry::DebugString(bool include_internal) const {
201 OpList op_list;
202 Export(include_internal, &op_list);
203 string ret;
204 for (const auto& op : op_list.op()) {
205 strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
206 }
207 return ret;
208 }
209
MustCallDeferred() const210 bool OpRegistry::MustCallDeferred() const {
211 if (initialized_) return false;
212 initialized_ = true;
213 for (size_t i = 0; i < deferred_.size(); ++i) {
214 TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
215 }
216 deferred_.clear();
217 return true;
218 }
219
CallDeferred() const220 Status OpRegistry::CallDeferred() const {
221 if (initialized_) return Status::OK();
222 initialized_ = true;
223 for (size_t i = 0; i < deferred_.size(); ++i) {
224 Status s = RegisterAlreadyLocked(deferred_[i]);
225 if (!s.ok()) {
226 return s;
227 }
228 }
229 deferred_.clear();
230 return Status::OK();
231 }
232
RegisterAlreadyLocked(const OpRegistrationDataFactory & op_data_factory) const233 Status OpRegistry::RegisterAlreadyLocked(
234 const OpRegistrationDataFactory& op_data_factory) const {
235 std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
236 Status s = op_data_factory(op_reg_data.get());
237 if (s.ok()) {
238 s = ValidateOpDef(op_reg_data->op_def);
239 if (s.ok() &&
240 !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(),
241 op_reg_data.get())) {
242 s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
243 }
244 }
245 Status watcher_status = s;
246 if (watcher_) {
247 watcher_status = watcher_(s, op_reg_data->op_def);
248 }
249 if (s.ok()) {
250 op_reg_data.release();
251 } else {
252 op_reg_data.reset();
253 }
254 return watcher_status;
255 }
256
257 // static
Global()258 OpRegistry* OpRegistry::Global() {
259 static OpRegistry* global_op_registry = new OpRegistry;
260 return global_op_registry;
261 }
262
263 // OpListOpRegistry -----------------------------------------------------------
264
OpListOpRegistry(const OpList * op_list)265 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
266 for (const OpDef& op_def : op_list->op()) {
267 auto* op_reg_data = new OpRegistrationData();
268 op_reg_data->op_def = op_def;
269 index_[op_def.name()] = op_reg_data;
270 }
271 }
272
~OpListOpRegistry()273 OpListOpRegistry::~OpListOpRegistry() {
274 for (const auto& e : index_) delete e.second;
275 }
276
LookUp(const string & op_type_name) const277 const OpRegistrationData* OpListOpRegistry::LookUp(
278 const string& op_type_name) const {
279 auto iter = index_.find(op_type_name);
280 if (iter == index_.end()) {
281 return nullptr;
282 }
283 return iter->second;
284 }
285
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const286 Status OpListOpRegistry::LookUp(const string& op_type_name,
287 const OpRegistrationData** op_reg_data) const {
288 if ((*op_reg_data = LookUp(op_type_name))) return Status::OK();
289 return OpNotFound(op_type_name);
290 }
291
292 namespace register_op {
293
operator ()()294 InitOnStartupMarker OpDefBuilderWrapper::operator()() {
295 OpRegistry::Global()->Register(
296 [builder =
297 std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
298 return builder.Finalize(op_reg_data);
299 });
300 return {};
301 }
302
303 } // namespace register_op
304
305 } // namespace tensorflow
306