1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/host_info.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace tensorflow {
32
33 // OpRegistry -----------------------------------------------------------------
34
~OpRegistryInterface()35 OpRegistryInterface::~OpRegistryInterface() {}
36
LookUpOpDef(const string & op_type_name,const OpDef ** op_def) const37 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
38 const OpDef** op_def) const {
39 *op_def = nullptr;
40 const OpRegistrationData* op_reg_data = nullptr;
41 TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
42 *op_def = &op_reg_data->op_def;
43 return Status::OK();
44 }
45
OpRegistry()46 OpRegistry::OpRegistry() : initialized_(false) {}
47
~OpRegistry()48 OpRegistry::~OpRegistry() {
49 for (const auto& e : registry_) delete e.second;
50 }
51
Register(const OpRegistrationDataFactory & op_data_factory)52 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
53 mutex_lock lock(mu_);
54 if (initialized_) {
55 TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
56 } else {
57 deferred_.push_back(op_data_factory);
58 }
59 }
60
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const61 Status OpRegistry::LookUp(const string& op_type_name,
62 const OpRegistrationData** op_reg_data) const {
63 {
64 tf_shared_lock l(mu_);
65 if (initialized_) {
66 if (const OpRegistrationData* res =
67 gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
68 *op_reg_data = res;
69 return Status::OK();
70 }
71 }
72 }
73 return LookUpSlow(op_type_name, op_reg_data);
74 }
75
LookUpSlow(const string & op_type_name,const OpRegistrationData ** op_reg_data) const76 Status OpRegistry::LookUpSlow(const string& op_type_name,
77 const OpRegistrationData** op_reg_data) const {
78 *op_reg_data = nullptr;
79 const OpRegistrationData* res = nullptr;
80
81 bool first_call = false;
82 bool first_unregistered = false;
83 { // Scope for lock.
84 mutex_lock lock(mu_);
85 first_call = MustCallDeferred();
86 res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
87
88 static bool unregistered_before = false;
89 first_unregistered = !unregistered_before && (res == nullptr);
90 if (first_unregistered) {
91 unregistered_before = true;
92 }
93 // Note: Can't hold mu_ while calling Export() below.
94 }
95 if (first_call) {
96 TF_QCHECK_OK(ValidateKernelRegistrations(*this));
97 }
98 if (res == nullptr) {
99 if (first_unregistered) {
100 OpList op_list;
101 Export(true, &op_list);
102 if (VLOG_IS_ON(3)) {
103 LOG(INFO) << "All registered Ops:";
104 for (const auto& op : op_list.op()) {
105 LOG(INFO) << SummarizeOpDef(op);
106 }
107 }
108 }
109 Status status = errors::NotFound(
110 "Op type not registered '", op_type_name, "' in binary running on ",
111 port::Hostname(), ". ",
112 "Make sure the Op and Kernel are registered in the "
113 "binary running in this process. Note that if you "
114 "are loading a saved graph which used ops from "
115 "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
116 "before importing the graph, as contrib ops are lazily registered "
117 "when the module is first accessed.");
118 VLOG(1) << status.ToString();
119 return status;
120 }
121 *op_reg_data = res;
122 return Status::OK();
123 }
124
GetRegisteredOps(std::vector<OpDef> * op_defs)125 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
126 mutex_lock lock(mu_);
127 MustCallDeferred();
128 for (const auto& p : registry_) {
129 op_defs->push_back(p.second->op_def);
130 }
131 }
132
GetOpRegistrationData(std::vector<OpRegistrationData> * op_data)133 void OpRegistry::GetOpRegistrationData(
134 std::vector<OpRegistrationData>* op_data) {
135 mutex_lock lock(mu_);
136 MustCallDeferred();
137 for (const auto& p : registry_) {
138 op_data->push_back(*p.second);
139 }
140 }
141
SetWatcher(const Watcher & watcher)142 Status OpRegistry::SetWatcher(const Watcher& watcher) {
143 mutex_lock lock(mu_);
144 if (watcher_ && watcher) {
145 return errors::AlreadyExists(
146 "Cannot over-write a valid watcher with another.");
147 }
148 watcher_ = watcher;
149 return Status::OK();
150 }
151
Export(bool include_internal,OpList * ops) const152 void OpRegistry::Export(bool include_internal, OpList* ops) const {
153 mutex_lock lock(mu_);
154 MustCallDeferred();
155
156 std::vector<std::pair<string, const OpRegistrationData*>> sorted(
157 registry_.begin(), registry_.end());
158 std::sort(sorted.begin(), sorted.end());
159
160 auto out = ops->mutable_op();
161 out->Clear();
162 out->Reserve(sorted.size());
163
164 for (const auto& item : sorted) {
165 if (include_internal || !str_util::StartsWith(item.first, "_")) {
166 *out->Add() = item.second->op_def;
167 }
168 }
169 }
170
DeferRegistrations()171 void OpRegistry::DeferRegistrations() {
172 mutex_lock lock(mu_);
173 initialized_ = false;
174 }
175
ClearDeferredRegistrations()176 void OpRegistry::ClearDeferredRegistrations() {
177 mutex_lock lock(mu_);
178 deferred_.clear();
179 }
180
ProcessRegistrations() const181 Status OpRegistry::ProcessRegistrations() const {
182 mutex_lock lock(mu_);
183 return CallDeferred();
184 }
185
DebugString(bool include_internal) const186 string OpRegistry::DebugString(bool include_internal) const {
187 OpList op_list;
188 Export(include_internal, &op_list);
189 string ret;
190 for (const auto& op : op_list.op()) {
191 strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
192 }
193 return ret;
194 }
195
MustCallDeferred() const196 bool OpRegistry::MustCallDeferred() const {
197 if (initialized_) return false;
198 initialized_ = true;
199 for (size_t i = 0; i < deferred_.size(); ++i) {
200 TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
201 }
202 deferred_.clear();
203 return true;
204 }
205
CallDeferred() const206 Status OpRegistry::CallDeferred() const {
207 if (initialized_) return Status::OK();
208 initialized_ = true;
209 for (size_t i = 0; i < deferred_.size(); ++i) {
210 Status s = RegisterAlreadyLocked(deferred_[i]);
211 if (!s.ok()) {
212 return s;
213 }
214 }
215 deferred_.clear();
216 return Status::OK();
217 }
218
RegisterAlreadyLocked(const OpRegistrationDataFactory & op_data_factory) const219 Status OpRegistry::RegisterAlreadyLocked(
220 const OpRegistrationDataFactory& op_data_factory) const {
221 std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
222 Status s = op_data_factory(op_reg_data.get());
223 if (s.ok()) {
224 s = ValidateOpDef(op_reg_data->op_def);
225 if (s.ok() &&
226 !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(),
227 op_reg_data.get())) {
228 s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
229 }
230 }
231 Status watcher_status = s;
232 if (watcher_) {
233 watcher_status = watcher_(s, op_reg_data->op_def);
234 }
235 if (s.ok()) {
236 op_reg_data.release();
237 } else {
238 op_reg_data.reset();
239 }
240 return watcher_status;
241 }
242
243 // static
Global()244 OpRegistry* OpRegistry::Global() {
245 static OpRegistry* global_op_registry = new OpRegistry;
246 return global_op_registry;
247 }
248
249 // OpListOpRegistry -----------------------------------------------------------
250
OpListOpRegistry(const OpList * op_list)251 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
252 for (const OpDef& op_def : op_list->op()) {
253 auto* op_reg_data = new OpRegistrationData();
254 op_reg_data->op_def = op_def;
255 index_[op_def.name()] = op_reg_data;
256 }
257 }
258
~OpListOpRegistry()259 OpListOpRegistry::~OpListOpRegistry() {
260 for (const auto& e : index_) delete e.second;
261 }
262
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const263 Status OpListOpRegistry::LookUp(const string& op_type_name,
264 const OpRegistrationData** op_reg_data) const {
265 auto iter = index_.find(op_type_name);
266 if (iter == index_.end()) {
267 *op_reg_data = nullptr;
268 return errors::NotFound(
269 "Op type not registered '", op_type_name, "' in binary running on ",
270 port::Hostname(), ". ",
271 "Make sure the Op and Kernel are registered in the "
272 "binary running in this process. Note that if you "
273 "are loading a saved graph which used ops from "
274 "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
275 "before importing the graph, as contrib ops are lazily registered "
276 "when the module is first accessed.");
277 }
278 *op_reg_data = iter->second;
279 return Status::OK();
280 }
281
282 // Other registration ---------------------------------------------------------
283
284 namespace register_op {
OpDefBuilderReceiver(const OpDefBuilderWrapper<true> & wrapper)285 OpDefBuilderReceiver::OpDefBuilderReceiver(
286 const OpDefBuilderWrapper<true>& wrapper) {
287 OpRegistry::Global()->Register(
288 [wrapper](OpRegistrationData* op_reg_data) -> Status {
289 return wrapper.builder().Finalize(op_reg_data);
290 });
291 }
292 } // namespace register_op
293
294 } // namespace tensorflow
295