• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_
18 
19 #include <functional>
20 #include <unordered_map>
21 
22 #include <vector>
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/selective_registration.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 
38 // Users that want to look up an OpDef by type name should take an
39 // OpRegistryInterface.  Functions accepting a
40 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
41 class OpRegistryInterface {
42  public:
43   virtual ~OpRegistryInterface();
44 
45   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
46   // registered under that name, otherwise returns the registered OpDef.
47   // Caller must not delete the returned pointer.
48   virtual Status LookUp(const string& op_type_name,
49                         const OpRegistrationData** op_reg_data) const = 0;
50 
51   // Shorthand for calling LookUp to get the OpDef.
52   Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
53 };
54 
55 // The standard implementation of OpRegistryInterface, along with a
56 // global singleton used for registering ops via the REGISTER
57 // macros below.  Thread-safe.
58 //
59 // Example registration:
60 //   OpRegistry::Global()->Register(
61 //     [](OpRegistrationData* op_reg_data)->Status {
62 //       // Populate *op_reg_data here.
63 //       return Status::OK();
64 //   });
65 class OpRegistry : public OpRegistryInterface {
66  public:
67   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
68 
69   OpRegistry();
70   ~OpRegistry() override;
71 
72   void Register(const OpRegistrationDataFactory& op_data_factory);
73 
74   Status LookUp(const string& op_type_name,
75                 const OpRegistrationData** op_reg_data) const override;
76 
77   // Returns OpRegistrationData* of registered op type, else returns nullptr.
78   const OpRegistrationData* LookUp(const string& op_type_name) const;
79 
80   // Fills *ops with all registered OpDefs (except those with names
81   // starting with '_' if include_internal == false) sorted in
82   // ascending alphabetical order.
83   void Export(bool include_internal, OpList* ops) const;
84 
85   // Returns ASCII-format OpList for all registered OpDefs (except
86   // those with names starting with '_' if include_internal == false).
87   string DebugString(bool include_internal) const;
88 
89   // A singleton available at startup.
90   static OpRegistry* Global();
91 
92   // Get all registered ops.
93   void GetRegisteredOps(std::vector<OpDef>* op_defs);
94 
95   // Get all `OpRegistrationData`s.
96   void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
97 
98   // Registers a function that validates op registry.
RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)99   void RegisterValidator(
100       std::function<Status(const OpRegistryInterface&)> validator) {
101     op_registry_validator_ = std::move(validator);
102   }
103 
104   // Watcher, a function object.
105   // The watcher, if set by SetWatcher(), is called every time an op is
106   // registered via the Register function. The watcher is passed the Status
107   // obtained from building and adding the OpDef to the registry, and the OpDef
108   // itself if it was successfully built. A watcher returns a Status which is in
109   // turn returned as the final registration status.
110   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
111 
112   // An OpRegistry object has only one watcher. This interface is not thread
113   // safe, as different clients are free to set the watcher any time.
114   // Clients are expected to atomically perform the following sequence of
115   // operations :
116   // SetWatcher(a_watcher);
117   // Register some ops;
118   // op_registry->ProcessRegistrations();
119   // SetWatcher(nullptr);
120   // Returns a non-OK status if a non-null watcher is over-written by another
121   // non-null watcher.
122   Status SetWatcher(const Watcher& watcher);
123 
124   // Process the current list of deferred registrations. Note that calls to
125   // Export, LookUp and DebugString would also implicitly process the deferred
126   // registrations. Returns the status of the first failed op registration or
127   // Status::OK() otherwise.
128   Status ProcessRegistrations() const;
129 
130   // Defer the registrations until a later call to a function that processes
131   // deferred registrations are made. Normally, registrations that happen after
132   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
133   // immediately. Call this to defer future registrations.
134   void DeferRegistrations();
135 
136   // Clear the registrations that have been deferred.
137   void ClearDeferredRegistrations();
138 
139  private:
140   // Ensures that all the functions in deferred_ get called, their OpDef's
141   // registered, and returns with deferred_ empty.  Returns true the first
142   // time it is called. Prints a fatal log if any op registration fails.
143   bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
144 
145   // Calls the functions in deferred_ and registers their OpDef's
146   // It returns the Status of the first failed op registration or Status::OK()
147   // otherwise.
148   Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
149 
150   // Add 'def' to the registry with additional data 'data'. On failure, or if
151   // there is already an OpDef with that name registered, returns a non-okay
152   // status.
153   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
154       const EXCLUSIVE_LOCKS_REQUIRED(mu_);
155 
156   const OpRegistrationData* LookUpSlow(const string& op_type_name) const;
157 
158   mutable mutex mu_;
159   // Functions in deferred_ may only be called with mu_ held.
160   mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
161   // Values are owned.
162   mutable std::unordered_map<string, const OpRegistrationData*> registry_
163       GUARDED_BY(mu_);
164   mutable bool initialized_ GUARDED_BY(mu_);
165 
166   // Registry watcher.
167   mutable Watcher watcher_ GUARDED_BY(mu_);
168 
169   std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
170 };
171 
172 // An adapter to allow an OpList to be used as an OpRegistryInterface.
173 //
174 // Note that shape inference functions are not passed in to OpListOpRegistry, so
175 // it will return an unusable shape inference function for every op it supports;
176 // therefore, it should only be used in contexts where this is okay.
177 class OpListOpRegistry : public OpRegistryInterface {
178  public:
179   // Does not take ownership of op_list, *op_list must outlive *this.
180   explicit OpListOpRegistry(const OpList* op_list);
181   ~OpListOpRegistry() override;
182   Status LookUp(const string& op_type_name,
183                 const OpRegistrationData** op_reg_data) const override;
184 
185   // Returns OpRegistrationData* of op type in list, else returns nullptr.
186   const OpRegistrationData* LookUp(const string& op_type_name) const;
187 
188  private:
189   // Values are owned.
190   std::unordered_map<string, const OpRegistrationData*> index_;
191 };
192 
193 // Support for defining the OpDef (specifying the semantics of the Op and how
194 // it should be created) and registering it in the OpRegistry::Global()
195 // registry.  Usage:
196 //
197 // REGISTER_OP("my_op_name")
198 //     .Attr("<name>:<type>")
199 //     .Attr("<name>:<type>=<default>")
200 //     .Input("<name>:<type-expr>")
201 //     .Input("<name>:Ref(<type-expr>)")
202 //     .Output("<name>:<type-expr>")
203 //     .Doc(R"(
204 // <1-line summary>
205 // <rest of the description (potentially many lines)>
206 // <name-of-attr-input-or-output>: <description of name>
207 // <name-of-attr-input-or-output>: <description of name;
208 //   if long, indent the description on subsequent lines>
209 // )");
210 //
211 // Note: .Doc() should be last.
212 // For details, see the OpDefBuilder class in op_def_builder.h.
213 
214 namespace register_op {
215 
// Primary template for the builder wrapper produced by the REGISTER_OP
// macros.  Only the <true> and <false> specializations are defined.  Having
// a wrapper lets the result of REGISTER_OP be chained, as in
// REGISTER_OP(a).Attr("...").Input("...");, while selective registration can
// still turn the entire call-chain into a no-op.
template <bool should_register> class OpDefBuilderWrapper;
222 
223 // Template specialization that forwards all calls to the contained builder.
224 template <>
225 class OpDefBuilderWrapper<true> {
226  public:
OpDefBuilderWrapper(const char name[])227   explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
Attr(string spec)228   OpDefBuilderWrapper<true>& Attr(string spec) {
229     builder_.Attr(std::move(spec));
230     return *this;
231   }
Input(string spec)232   OpDefBuilderWrapper<true>& Input(string spec) {
233     builder_.Input(std::move(spec));
234     return *this;
235   }
Output(string spec)236   OpDefBuilderWrapper<true>& Output(string spec) {
237     builder_.Output(std::move(spec));
238     return *this;
239   }
SetIsCommutative()240   OpDefBuilderWrapper<true>& SetIsCommutative() {
241     builder_.SetIsCommutative();
242     return *this;
243   }
SetIsAggregate()244   OpDefBuilderWrapper<true>& SetIsAggregate() {
245     builder_.SetIsAggregate();
246     return *this;
247   }
SetIsStateful()248   OpDefBuilderWrapper<true>& SetIsStateful() {
249     builder_.SetIsStateful();
250     return *this;
251   }
SetAllowsUninitializedInput()252   OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
253     builder_.SetAllowsUninitializedInput();
254     return *this;
255   }
Deprecated(int version,string explanation)256   OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
257     builder_.Deprecated(version, std::move(explanation));
258     return *this;
259   }
Doc(string text)260   OpDefBuilderWrapper<true>& Doc(string text) {
261     builder_.Doc(std::move(text));
262     return *this;
263   }
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))264   OpDefBuilderWrapper<true>& SetShapeFn(
265       Status (*fn)(shape_inference::InferenceContext*)) {
266     builder_.SetShapeFn(fn);
267     return *this;
268   }
builder()269   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
270 
271  private:
272   mutable ::tensorflow::OpDefBuilder builder_;
273 };
274 
275 // Template specialization that turns all calls into no-ops.
276 template <>
277 class OpDefBuilderWrapper<false> {
278  public:
OpDefBuilderWrapper(const char name[])279   explicit constexpr OpDefBuilderWrapper(const char name[]) {}
Attr(StringPiece spec)280   OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
Input(StringPiece spec)281   OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
Output(StringPiece spec)282   OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
SetIsCommutative()283   OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
SetIsAggregate()284   OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
SetIsStateful()285   OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
SetAllowsUninitializedInput()286   OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
Deprecated(int,StringPiece)287   OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
Doc(StringPiece text)288   OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))289   OpDefBuilderWrapper<false>& SetShapeFn(
290       Status (*fn)(shape_inference::InferenceContext*)) {
291     return *this;
292   }
293 };
294 
295 struct OpDefBuilderReceiver {
296   // To call OpRegistry::Global()->Register(...), used by the
297   // REGISTER_OP macro below.
298   // Note: These are implicitly converting constructors.
299   OpDefBuilderReceiver(
300       const OpDefBuilderWrapper<true>& wrapper);  // NOLINT(runtime/explicit)
OpDefBuilderReceiverOpDefBuilderReceiver301   constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
302   }  // NOLINT(runtime/explicit)
303 };
304 }  // namespace register_op
305 
// REGISTER_OP(name) defines a static OpDefBuilderReceiver whose identifier is
// made unique with __COUNTER__, initialized from an OpDefBuilderWrapper so
// that builder calls can be chained after the macro.  SHOULD_REGISTER_OP(name)
// selects the forwarding or the no-op wrapper specialization.
//
// REGISTER_OP_UNIQ_HELPER adds one level of macro expansion so that
// __COUNTER__ is replaced by its value before `##` pastes it onto
// `register_op`.
#define REGISTER_OP_UNIQ(ctr, name)                                         \
  static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr   \
      TF_ATTRIBUTE_UNUSED = ::tensorflow::register_op::OpDefBuilderWrapper< \
          SHOULD_REGISTER_OP(name)>(name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
313 
// `REGISTER_SYSTEM_OP()` behaves like `REGISTER_OP()`, except that it always
// instantiates OpDefBuilderWrapper<true>, so the op is registered
// unconditionally even when selective registration is used.
//
// As with REGISTER_OP, the _HELPER macro exists so that __COUNTER__ is
// expanded before `##` pastes it onto `register_op`.
#define REGISTER_SYSTEM_OP_UNIQ(ctr, name)                                \
  static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
      TF_ATTRIBUTE_UNUSED =                                               \
          ::tensorflow::register_op::OpDefBuilderWrapper<true>(name)
#define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \
  REGISTER_SYSTEM_OP_UNIQ(ctr, name)
#define REGISTER_SYSTEM_OP(name) \
  REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name)
325 
326 }  // namespace tensorflow
327 
328 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_H_
329