• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_
18 
19 #include <functional>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/full_type.pb.h"
24 #include "tensorflow/core/framework/full_type_util.h"
25 #include "tensorflow/core/framework/op_def_builder.h"
26 #include "tensorflow/core/framework/op_def_util.h"
27 #include "tensorflow/core/framework/registration/registration.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 
40 // Users that want to look up an OpDef by type name should take an
41 // OpRegistryInterface.  Functions accepting a
42 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
43 class OpRegistryInterface {
44  public:
45   virtual ~OpRegistryInterface();
46 
47   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
48   // registered under that name, otherwise returns the registered OpDef.
49   // Caller must not delete the returned pointer.
50   virtual Status LookUp(const std::string& op_type_name,
51                         const OpRegistrationData** op_reg_data) const = 0;
52 
53   // Shorthand for calling LookUp to get the OpDef.
54   Status LookUpOpDef(const std::string& op_type_name,
55                      const OpDef** op_def) const;
56 };
57 
58 // The standard implementation of OpRegistryInterface, along with a
59 // global singleton used for registering ops via the REGISTER
60 // macros below.  Thread-safe.
61 //
62 // Example registration:
63 //   OpRegistry::Global()->Register(
64 //     [](OpRegistrationData* op_reg_data)->Status {
65 //       // Populate *op_reg_data here.
66 //       return Status::OK();
67 //   });
68 class OpRegistry : public OpRegistryInterface {
69  public:
70   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
71 
72   OpRegistry();
73   ~OpRegistry() override;
74 
75   void Register(const OpRegistrationDataFactory& op_data_factory);
76 
77   Status LookUp(const std::string& op_type_name,
78                 const OpRegistrationData** op_reg_data) const override;
79 
80   // Returns OpRegistrationData* of registered op type, else returns nullptr.
81   const OpRegistrationData* LookUp(const std::string& op_type_name) const;
82 
83   // Fills *ops with all registered OpDefs (except those with names
84   // starting with '_' if include_internal == false) sorted in
85   // ascending alphabetical order.
86   void Export(bool include_internal, OpList* ops) const;
87 
88   // Returns ASCII-format OpList for all registered OpDefs (except
89   // those with names starting with '_' if include_internal == false).
90   std::string DebugString(bool include_internal) const;
91 
92   // A singleton available at startup.
93   static OpRegistry* Global();
94 
95   // Get all registered ops.
96   void GetRegisteredOps(std::vector<OpDef>* op_defs);
97 
98   // Get all `OpRegistrationData`s.
99   void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
100 
101   // Registers a function that validates op registry.
RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)102   void RegisterValidator(
103       std::function<Status(const OpRegistryInterface&)> validator) {
104     op_registry_validator_ = std::move(validator);
105   }
106 
107   // Watcher, a function object.
108   // The watcher, if set by SetWatcher(), is called every time an op is
109   // registered via the Register function. The watcher is passed the Status
110   // obtained from building and adding the OpDef to the registry, and the OpDef
111   // itself if it was successfully built. A watcher returns a Status which is in
112   // turn returned as the final registration status.
113   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
114 
115   // An OpRegistry object has only one watcher. This interface is not thread
116   // safe, as different clients are free to set the watcher any time.
117   // Clients are expected to atomically perform the following sequence of
118   // operations :
119   // SetWatcher(a_watcher);
120   // Register some ops;
121   // op_registry->ProcessRegistrations();
122   // SetWatcher(nullptr);
123   // Returns a non-OK status if a non-null watcher is over-written by another
124   // non-null watcher.
125   Status SetWatcher(const Watcher& watcher);
126 
127   // Process the current list of deferred registrations. Note that calls to
128   // Export, LookUp and DebugString would also implicitly process the deferred
129   // registrations. Returns the status of the first failed op registration or
130   // Status::OK() otherwise.
131   Status ProcessRegistrations() const;
132 
133   // Defer the registrations until a later call to a function that processes
134   // deferred registrations are made. Normally, registrations that happen after
135   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
136   // immediately. Call this to defer future registrations.
137   void DeferRegistrations();
138 
139   // Clear the registrations that have been deferred.
140   void ClearDeferredRegistrations();
141 
142  private:
143   // Ensures that all the functions in deferred_ get called, their OpDef's
144   // registered, and returns with deferred_ empty.  Returns true the first
145   // time it is called. Prints a fatal log if any op registration fails.
146   bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
147 
148   // Calls the functions in deferred_ and registers their OpDef's
149   // It returns the Status of the first failed op registration or Status::OK()
150   // otherwise.
151   Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
152 
153   // Add 'def' to the registry with additional data 'data'. On failure, or if
154   // there is already an OpDef with that name registered, returns a non-okay
155   // status.
156   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
157       const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
158 
159   const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
160 
161   mutable mutex mu_;
162   // Functions in deferred_ may only be called with mu_ held.
163   mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
164   // Values are owned.
165   mutable std::unordered_map<string, const OpRegistrationData*> registry_
166       TF_GUARDED_BY(mu_);
167   mutable bool initialized_ TF_GUARDED_BY(mu_);
168 
169   // Registry watcher.
170   mutable Watcher watcher_ TF_GUARDED_BY(mu_);
171 
172   std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
173 };
174 
175 // An adapter to allow an OpList to be used as an OpRegistryInterface.
176 //
177 // Note that shape inference functions are not passed in to OpListOpRegistry, so
178 // it will return an unusable shape inference function for every op it supports;
179 // therefore, it should only be used in contexts where this is okay.
180 class OpListOpRegistry : public OpRegistryInterface {
181  public:
182   // Does not take ownership of op_list, *op_list must outlive *this.
183   explicit OpListOpRegistry(const OpList* op_list);
184   ~OpListOpRegistry() override;
185   Status LookUp(const std::string& op_type_name,
186                 const OpRegistrationData** op_reg_data) const override;
187 
188   // Returns OpRegistrationData* of op type in list, else returns nullptr.
189   const OpRegistrationData* LookUp(const std::string& op_type_name) const;
190 
191  private:
192   // Values are owned.
193   std::unordered_map<string, const OpRegistrationData*> index_;
194 };
195 
// Support for defining the OpDef (specifying the semantics of the Op and how
// it should be created) and registering it in the OpRegistry::Global()
// registry.  Usage:
//
// REGISTER_OP("my_op_name")
//     .Attr("<name>:<type>")
//     .Attr("<name>:<type>=<default>")
//     .Input("<name>:<type-expr>")
//     .Input("<name>:Ref(<type-expr>)")
//     .Output("<name>:<type-expr>")
//     .Doc(R"(
// <1-line summary>
// <rest of the description (potentially many lines)>
// <name-of-attr-input-or-output>: <description of name>
// <name-of-attr-input-or-output>: <description of name;
//   if long, indent the description on subsequent lines>
// )");
//
// Note: .Doc() should be last.
// For details, see the OpDefBuilder class in op_def_builder.h.
216 
217 namespace register_op {
218 
219 class OpDefBuilderWrapper {
220  public:
OpDefBuilderWrapper(const char name[])221   explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
Attr(std::string spec)222   OpDefBuilderWrapper& Attr(std::string spec) {
223     builder_.Attr(std::move(spec));
224     return *this;
225   }
Input(std::string spec)226   OpDefBuilderWrapper& Input(std::string spec) {
227     builder_.Input(std::move(spec));
228     return *this;
229   }
Output(std::string spec)230   OpDefBuilderWrapper& Output(std::string spec) {
231     builder_.Output(std::move(spec));
232     return *this;
233   }
SetIsCommutative()234   OpDefBuilderWrapper& SetIsCommutative() {
235     builder_.SetIsCommutative();
236     return *this;
237   }
SetIsAggregate()238   OpDefBuilderWrapper& SetIsAggregate() {
239     builder_.SetIsAggregate();
240     return *this;
241   }
SetIsStateful()242   OpDefBuilderWrapper& SetIsStateful() {
243     builder_.SetIsStateful();
244     return *this;
245   }
SetDoNotOptimize()246   OpDefBuilderWrapper& SetDoNotOptimize() {
247     // We don't have a separate flag to disable optimizations such as constant
248     // folding and CSE so we reuse the stateful flag.
249     builder_.SetIsStateful();
250     return *this;
251   }
SetAllowsUninitializedInput()252   OpDefBuilderWrapper& SetAllowsUninitializedInput() {
253     builder_.SetAllowsUninitializedInput();
254     return *this;
255   }
Deprecated(int version,std::string explanation)256   OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
257     builder_.Deprecated(version, std::move(explanation));
258     return *this;
259   }
Doc(std::string text)260   OpDefBuilderWrapper& Doc(std::string text) {
261     builder_.Doc(std::move(text));
262     return *this;
263   }
SetShapeFn(OpShapeInferenceFn fn)264   OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
265     builder_.SetShapeFn(std::move(fn));
266     return *this;
267   }
SetIsDistributedCommunication()268   OpDefBuilderWrapper& SetIsDistributedCommunication() {
269     builder_.SetIsDistributedCommunication();
270     return *this;
271   }
272 
273   // Type constructor to support type inference. Similar to SetShapeFn, it
274   // allows programmatic control over the output type of an op, including
275   // inferring it from the inputs.
276   // TODO(mdan): Merge with shape inference.
SetTypeConstructor(OpTypeConstructor fn)277   OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
278     builder_.SetTypeConstructor(std::move(fn));
279     return *this;
280   }
281 
builder()282   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
283 
284   InitOnStartupMarker operator()();
285 
286  private:
287   mutable ::tensorflow::OpDefBuilder builder_;
288 };
289 
290 }  // namespace register_op
291 
// Implementation detail shared by REGISTER_OP and REGISTER_SYSTEM_OP.
//
// Defines a static InitOnStartupMarker named register_op<unique_ctr>, where
// unique_ctr is supplied by TF_NEW_ID_FOR_INIT so that each use gets a
// distinct variable. Its initializer streams an OpDefBuilderWrapper for
// `op_name` into TF_INIT_ON_STARTUP_IF, and the caller's builder chain
// (.Attr(...), .Input(...), ...) follows the wrapper expression.
//
// Registration happens when `always_register` is true or when
// SHOULD_REGISTER_OP(op_name) allows it.
#define REGISTER_OP_IMPL(unique_ctr, op_name, always_register)              \
  static ::tensorflow::InitOnStartupMarker const register_op##unique_ctr    \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(always_register ||                          \
                                SHOULD_REGISTER_OP(op_name))                \
          << ::tensorflow::register_op::OpDefBuilderWrapper(op_name)

// Registers an op; subject to selective registration.
#define REGISTER_OP(op_name)     \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, op_name, false)

// Same as REGISTER_OP(), except that the op is always registered, even when
// selective registration is in use.
#define REGISTER_SYSTEM_OP(op_name)     \
  TF_ATTRIBUTE_ANNOTATE("tf:op")        \
  TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, op_name, true)
309 
310 }  // namespace tensorflow
311 
312 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_H_
313