• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/host_info.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 // OpRegistry -----------------------------------------------------------------
34 
~OpRegistryInterface()35 OpRegistryInterface::~OpRegistryInterface() {}
36 
LookUpOpDef(const string & op_type_name,const OpDef ** op_def) const37 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
38                                         const OpDef** op_def) const {
39   *op_def = nullptr;
40   const OpRegistrationData* op_reg_data = nullptr;
41   TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
42   *op_def = &op_reg_data->op_def;
43   return Status::OK();
44 }
45 
OpRegistry()46 OpRegistry::OpRegistry() : initialized_(false) {}
47 
~OpRegistry()48 OpRegistry::~OpRegistry() {
49   for (const auto& e : registry_) delete e.second;
50 }
51 
Register(const OpRegistrationDataFactory & op_data_factory)52 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
53   mutex_lock lock(mu_);
54   if (initialized_) {
55     TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
56   } else {
57     deferred_.push_back(op_data_factory);
58   }
59 }
60 
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const61 Status OpRegistry::LookUp(const string& op_type_name,
62                           const OpRegistrationData** op_reg_data) const {
63   {
64     tf_shared_lock l(mu_);
65     if (initialized_) {
66       if (const OpRegistrationData* res =
67               gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
68         *op_reg_data = res;
69         return Status::OK();
70       }
71     }
72   }
73   return LookUpSlow(op_type_name, op_reg_data);
74 }
75 
LookUpSlow(const string & op_type_name,const OpRegistrationData ** op_reg_data) const76 Status OpRegistry::LookUpSlow(const string& op_type_name,
77                               const OpRegistrationData** op_reg_data) const {
78   *op_reg_data = nullptr;
79   const OpRegistrationData* res = nullptr;
80 
81   bool first_call = false;
82   bool first_unregistered = false;
83   {  // Scope for lock.
84     mutex_lock lock(mu_);
85     first_call = MustCallDeferred();
86     res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
87 
88     static bool unregistered_before = false;
89     first_unregistered = !unregistered_before && (res == nullptr);
90     if (first_unregistered) {
91       unregistered_before = true;
92     }
93     // Note: Can't hold mu_ while calling Export() below.
94   }
95   if (first_call) {
96     TF_QCHECK_OK(ValidateKernelRegistrations(*this));
97   }
98   if (res == nullptr) {
99     if (first_unregistered) {
100       OpList op_list;
101       Export(true, &op_list);
102       if (VLOG_IS_ON(3)) {
103         LOG(INFO) << "All registered Ops:";
104         for (const auto& op : op_list.op()) {
105           LOG(INFO) << SummarizeOpDef(op);
106         }
107       }
108     }
109     Status status = errors::NotFound(
110         "Op type not registered '", op_type_name, "' in binary running on ",
111         port::Hostname(), ". ",
112         "Make sure the Op and Kernel are registered in the "
113         "binary running in this process. Note that if you "
114         "are loading a saved graph which used ops from "
115         "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
116         "before importing the graph, as contrib ops are lazily registered "
117         "when the module is first accessed.");
118     VLOG(1) << status.ToString();
119     return status;
120   }
121   *op_reg_data = res;
122   return Status::OK();
123 }
124 
GetRegisteredOps(std::vector<OpDef> * op_defs)125 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
126   mutex_lock lock(mu_);
127   MustCallDeferred();
128   for (const auto& p : registry_) {
129     op_defs->push_back(p.second->op_def);
130   }
131 }
132 
GetOpRegistrationData(std::vector<OpRegistrationData> * op_data)133 void OpRegistry::GetOpRegistrationData(
134     std::vector<OpRegistrationData>* op_data) {
135   mutex_lock lock(mu_);
136   MustCallDeferred();
137   for (const auto& p : registry_) {
138     op_data->push_back(*p.second);
139   }
140 }
141 
SetWatcher(const Watcher & watcher)142 Status OpRegistry::SetWatcher(const Watcher& watcher) {
143   mutex_lock lock(mu_);
144   if (watcher_ && watcher) {
145     return errors::AlreadyExists(
146         "Cannot over-write a valid watcher with another.");
147   }
148   watcher_ = watcher;
149   return Status::OK();
150 }
151 
Export(bool include_internal,OpList * ops) const152 void OpRegistry::Export(bool include_internal, OpList* ops) const {
153   mutex_lock lock(mu_);
154   MustCallDeferred();
155 
156   std::vector<std::pair<string, const OpRegistrationData*>> sorted(
157       registry_.begin(), registry_.end());
158   std::sort(sorted.begin(), sorted.end());
159 
160   auto out = ops->mutable_op();
161   out->Clear();
162   out->Reserve(sorted.size());
163 
164   for (const auto& item : sorted) {
165     if (include_internal || !str_util::StartsWith(item.first, "_")) {
166       *out->Add() = item.second->op_def;
167     }
168   }
169 }
170 
DeferRegistrations()171 void OpRegistry::DeferRegistrations() {
172   mutex_lock lock(mu_);
173   initialized_ = false;
174 }
175 
ClearDeferredRegistrations()176 void OpRegistry::ClearDeferredRegistrations() {
177   mutex_lock lock(mu_);
178   deferred_.clear();
179 }
180 
ProcessRegistrations() const181 Status OpRegistry::ProcessRegistrations() const {
182   mutex_lock lock(mu_);
183   return CallDeferred();
184 }
185 
DebugString(bool include_internal) const186 string OpRegistry::DebugString(bool include_internal) const {
187   OpList op_list;
188   Export(include_internal, &op_list);
189   string ret;
190   for (const auto& op : op_list.op()) {
191     strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
192   }
193   return ret;
194 }
195 
MustCallDeferred() const196 bool OpRegistry::MustCallDeferred() const {
197   if (initialized_) return false;
198   initialized_ = true;
199   for (size_t i = 0; i < deferred_.size(); ++i) {
200     TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
201   }
202   deferred_.clear();
203   return true;
204 }
205 
CallDeferred() const206 Status OpRegistry::CallDeferred() const {
207   if (initialized_) return Status::OK();
208   initialized_ = true;
209   for (size_t i = 0; i < deferred_.size(); ++i) {
210     Status s = RegisterAlreadyLocked(deferred_[i]);
211     if (!s.ok()) {
212       return s;
213     }
214   }
215   deferred_.clear();
216   return Status::OK();
217 }
218 
RegisterAlreadyLocked(const OpRegistrationDataFactory & op_data_factory) const219 Status OpRegistry::RegisterAlreadyLocked(
220     const OpRegistrationDataFactory& op_data_factory) const {
221   std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
222   Status s = op_data_factory(op_reg_data.get());
223   if (s.ok()) {
224     s = ValidateOpDef(op_reg_data->op_def);
225     if (s.ok() &&
226         !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
227                                  op_reg_data.get())) {
228       s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
229     }
230   }
231   Status watcher_status = s;
232   if (watcher_) {
233     watcher_status = watcher_(s, op_reg_data->op_def);
234   }
235   if (s.ok()) {
236     op_reg_data.release();
237   } else {
238     op_reg_data.reset();
239   }
240   return watcher_status;
241 }
242 
243 // static
Global()244 OpRegistry* OpRegistry::Global() {
245   static OpRegistry* global_op_registry = new OpRegistry;
246   return global_op_registry;
247 }
248 
249 // OpListOpRegistry -----------------------------------------------------------
250 
OpListOpRegistry(const OpList * op_list)251 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
252   for (const OpDef& op_def : op_list->op()) {
253     auto* op_reg_data = new OpRegistrationData();
254     op_reg_data->op_def = op_def;
255     index_[op_def.name()] = op_reg_data;
256   }
257 }
258 
~OpListOpRegistry()259 OpListOpRegistry::~OpListOpRegistry() {
260   for (const auto& e : index_) delete e.second;
261 }
262 
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const263 Status OpListOpRegistry::LookUp(const string& op_type_name,
264                                 const OpRegistrationData** op_reg_data) const {
265   auto iter = index_.find(op_type_name);
266   if (iter == index_.end()) {
267     *op_reg_data = nullptr;
268     return errors::NotFound(
269         "Op type not registered '", op_type_name, "' in binary running on ",
270         port::Hostname(), ". ",
271         "Make sure the Op and Kernel are registered in the "
272         "binary running in this process. Note that if you "
273         "are loading a saved graph which used ops from "
274         "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
275         "before importing the graph, as contrib ops are lazily registered "
276         "when the module is first accessed.");
277   }
278   *op_reg_data = iter->second;
279   return Status::OK();
280 }
281 
282 // Other registration ---------------------------------------------------------
283 
284 namespace register_op {
OpDefBuilderReceiver(const OpDefBuilderWrapper<true> & wrapper)285 OpDefBuilderReceiver::OpDefBuilderReceiver(
286     const OpDefBuilderWrapper<true>& wrapper) {
287   OpRegistry::Global()->Register(
288       [wrapper](OpRegistrationData* op_reg_data) -> Status {
289         return wrapper.builder().Finalize(op_reg_data);
290       });
291 }
292 }  // namespace register_op
293 
294 }  // namespace tensorflow
295