1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_ 18 19 #include <functional> 20 #include <unordered_map> 21 22 #include <vector> 23 #include "tensorflow/core/framework/op_def_builder.h" 24 #include "tensorflow/core/framework/op_def_util.h" 25 #include "tensorflow/core/framework/selective_registration.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/strings/str_util.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 // Users that want to look up an OpDef by type name should take an 39 // OpRegistryInterface. Functions accepting a 40 // (const) OpRegistryInterface* may call LookUp() from multiple threads. 41 class OpRegistryInterface { 42 public: 43 virtual ~OpRegistryInterface(); 44 45 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is 46 // registered under that name, otherwise returns the registered OpDef. 47 // Caller must not delete the returned pointer. 48 virtual Status LookUp(const string& op_type_name, 49 const OpRegistrationData** op_reg_data) const = 0; 50 51 // Shorthand for calling LookUp to get the OpDef. 52 Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const; 53 }; 54 55 // The standard implementation of OpRegistryInterface, along with a 56 // global singleton used for registering ops via the REGISTER 57 // macros below. Thread-safe. 58 // 59 // Example registration: 60 // OpRegistry::Global()->Register( 61 // [](OpRegistrationData* op_reg_data)->Status { 62 // // Populate *op_reg_data here. 63 // return Status::OK(); 64 // }); 65 class OpRegistry : public OpRegistryInterface { 66 public: 67 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; 68 69 OpRegistry(); 70 ~OpRegistry() override; 71 72 void Register(const OpRegistrationDataFactory& op_data_factory); 73 74 Status LookUp(const string& op_type_name, 75 const OpRegistrationData** op_reg_data) const override; 76 77 // Returns OpRegistrationData* of registered op type, else returns nullptr. 78 const OpRegistrationData* LookUp(const string& op_type_name) const; 79 80 // Fills *ops with all registered OpDefs (except those with names 81 // starting with '_' if include_internal == false) sorted in 82 // ascending alphabetical order. 83 void Export(bool include_internal, OpList* ops) const; 84 85 // Returns ASCII-format OpList for all registered OpDefs (except 86 // those with names starting with '_' if include_internal == false). 87 string DebugString(bool include_internal) const; 88 89 // A singleton available at startup. 90 static OpRegistry* Global(); 91 92 // Get all registered ops. 93 void GetRegisteredOps(std::vector<OpDef>* op_defs); 94 95 // Get all `OpRegistrationData`s. 96 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); 97 98 // Registers a function that validates op registry. RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)99 void RegisterValidator( 100 std::function<Status(const OpRegistryInterface&)> validator) { 101 op_registry_validator_ = std::move(validator); 102 } 103 104 // Watcher, a function object. 105 // The watcher, if set by SetWatcher(), is called every time an op is 106 // registered via the Register function. The watcher is passed the Status 107 // obtained from building and adding the OpDef to the registry, and the OpDef 108 // itself if it was successfully built. A watcher returns a Status which is in 109 // turn returned as the final registration status. 110 typedef std::function<Status(const Status&, const OpDef&)> Watcher; 111 112 // An OpRegistry object has only one watcher. This interface is not thread 113 // safe, as different clients are free to set the watcher any time. 114 // Clients are expected to atomically perform the following sequence of 115 // operations : 116 // SetWatcher(a_watcher); 117 // Register some ops; 118 // op_registry->ProcessRegistrations(); 119 // SetWatcher(nullptr); 120 // Returns a non-OK status if a non-null watcher is over-written by another 121 // non-null watcher. 122 Status SetWatcher(const Watcher& watcher); 123 124 // Process the current list of deferred registrations. Note that calls to 125 // Export, LookUp and DebugString would also implicitly process the deferred 126 // registrations. Returns the status of the first failed op registration or 127 // Status::OK() otherwise. 128 Status ProcessRegistrations() const; 129 130 // Defer the registrations until a later call to a function that processes 131 // deferred registrations are made. Normally, registrations that happen after 132 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed 133 // immediately. Call this to defer future registrations. 134 void DeferRegistrations(); 135 136 // Clear the registrations that have been deferred. 137 void ClearDeferredRegistrations(); 138 139 private: 140 // Ensures that all the functions in deferred_ get called, their OpDef's 141 // registered, and returns with deferred_ empty. Returns true the first 142 // time it is called. Prints a fatal log if any op registration fails. 143 bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); 144 145 // Calls the functions in deferred_ and registers their OpDef's 146 // It returns the Status of the first failed op registration or Status::OK() 147 // otherwise. 148 Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); 149 150 // Add 'def' to the registry with additional data 'data'. On failure, or if 151 // there is already an OpDef with that name registered, returns a non-okay 152 // status. 153 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) 154 const EXCLUSIVE_LOCKS_REQUIRED(mu_); 155 156 const OpRegistrationData* LookUpSlow(const string& op_type_name) const; 157 158 mutable mutex mu_; 159 // Functions in deferred_ may only be called with mu_ held. 160 mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_); 161 // Values are owned. 162 mutable std::unordered_map<string, const OpRegistrationData*> registry_ 163 GUARDED_BY(mu_); 164 mutable bool initialized_ GUARDED_BY(mu_); 165 166 // Registry watcher. 167 mutable Watcher watcher_ GUARDED_BY(mu_); 168 169 std::function<Status(const OpRegistryInterface&)> op_registry_validator_; 170 }; 171 172 // An adapter to allow an OpList to be used as an OpRegistryInterface. 173 // 174 // Note that shape inference functions are not passed in to OpListOpRegistry, so 175 // it will return an unusable shape inference function for every op it supports; 176 // therefore, it should only be used in contexts where this is okay. 177 class OpListOpRegistry : public OpRegistryInterface { 178 public: 179 // Does not take ownership of op_list, *op_list must outlive *this. 180 explicit OpListOpRegistry(const OpList* op_list); 181 ~OpListOpRegistry() override; 182 Status LookUp(const string& op_type_name, 183 const OpRegistrationData** op_reg_data) const override; 184 185 // Returns OpRegistrationData* of op type in list, else returns nullptr. 186 const OpRegistrationData* LookUp(const string& op_type_name) const; 187 188 private: 189 // Values are owned. 190 std::unordered_map<string, const OpRegistrationData*> index_; 191 }; 192 193 // Support for defining the OpDef (specifying the semantics of the Op and how 194 // it should be created) and registering it in the OpRegistry::Global() 195 // registry. Usage: 196 // 197 // REGISTER_OP("my_op_name") 198 // .Attr("<name>:<type>") 199 // .Attr("<name>:<type>=<default>") 200 // .Input("<name>:<type-expr>") 201 // .Input("<name>:Ref(<type-expr>)") 202 // .Output("<name>:<type-expr>") 203 // .Doc(R"( 204 // <1-line summary> 205 // <rest of the description (potentially many lines)> 206 // <name-of-attr-input-or-output>: <description of name> 207 // <name-of-attr-input-or-output>: <description of name; 208 // if long, indent the description on subsequent lines> 209 // )"); 210 // 211 // Note: .Doc() should be last. 212 // For details, see the OpDefBuilder class in op_def_builder.h. 213 214 namespace register_op { 215 216 // OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP 217 // calls. This allows the result of REGISTER_OP to be used in chaining, as in 218 // REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective 219 // registration to turn the entire call-chain into a no-op. 220 template <bool should_register> 221 class OpDefBuilderWrapper; 222 223 // Template specialization that forwards all calls to the contained builder. 224 template <> 225 class OpDefBuilderWrapper<true> { 226 public: OpDefBuilderWrapper(const char name[])227 explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {} Attr(string spec)228 OpDefBuilderWrapper<true>& Attr(string spec) { 229 builder_.Attr(std::move(spec)); 230 return *this; 231 } Input(string spec)232 OpDefBuilderWrapper<true>& Input(string spec) { 233 builder_.Input(std::move(spec)); 234 return *this; 235 } Output(string spec)236 OpDefBuilderWrapper<true>& Output(string spec) { 237 builder_.Output(std::move(spec)); 238 return *this; 239 } SetIsCommutative()240 OpDefBuilderWrapper<true>& SetIsCommutative() { 241 builder_.SetIsCommutative(); 242 return *this; 243 } SetIsAggregate()244 OpDefBuilderWrapper<true>& SetIsAggregate() { 245 builder_.SetIsAggregate(); 246 return *this; 247 } SetIsStateful()248 OpDefBuilderWrapper<true>& SetIsStateful() { 249 builder_.SetIsStateful(); 250 return *this; 251 } SetAllowsUninitializedInput()252 OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() { 253 builder_.SetAllowsUninitializedInput(); 254 return *this; 255 } Deprecated(int version,string explanation)256 OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) { 257 builder_.Deprecated(version, std::move(explanation)); 258 return *this; 259 } Doc(string text)260 OpDefBuilderWrapper<true>& Doc(string text) { 261 builder_.Doc(std::move(text)); 262 return *this; 263 } SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))264 OpDefBuilderWrapper<true>& SetShapeFn( 265 Status (*fn)(shape_inference::InferenceContext*)) { 266 builder_.SetShapeFn(fn); 267 return *this; 268 } builder()269 const ::tensorflow::OpDefBuilder& builder() const { return builder_; } 270 271 private: 272 mutable ::tensorflow::OpDefBuilder builder_; 273 }; 274 275 // Template specialization that turns all calls into no-ops. 276 template <> 277 class OpDefBuilderWrapper<false> { 278 public: OpDefBuilderWrapper(const char name[])279 explicit constexpr OpDefBuilderWrapper(const char name[]) {} Attr(StringPiece spec)280 OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; } Input(StringPiece spec)281 OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; } Output(StringPiece spec)282 OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; } SetIsCommutative()283 OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; } SetIsAggregate()284 OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; } SetIsStateful()285 OpDefBuilderWrapper<false>& SetIsStateful() { return *this; } SetAllowsUninitializedInput()286 OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; } Deprecated(int,StringPiece)287 OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; } Doc(StringPiece text)288 OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; } SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))289 OpDefBuilderWrapper<false>& SetShapeFn( 290 Status (*fn)(shape_inference::InferenceContext*)) { 291 return *this; 292 } 293 }; 294 295 struct OpDefBuilderReceiver { 296 // To call OpRegistry::Global()->Register(...), used by the 297 // REGISTER_OP macro below. 298 // Note: These are implicitly converting constructors. 299 OpDefBuilderReceiver( 300 const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit) OpDefBuilderReceiverOpDefBuilderReceiver301 constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) { 302 } // NOLINT(runtime/explicit) 303 }; 304 } // namespace register_op 305 306 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name) 307 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name) 308 #define REGISTER_OP_UNIQ(ctr, name) \ 309 static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ 310 TF_ATTRIBUTE_UNUSED = \ 311 ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \ 312 name)>(name) 313 314 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except 315 // that the op is registered unconditionally even when selective 316 // registration is used. 317 #define REGISTER_SYSTEM_OP(name) \ 318 REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name) 319 #define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \ 320 REGISTER_SYSTEM_OP_UNIQ(ctr, name) 321 #define REGISTER_SYSTEM_OP_UNIQ(ctr, name) \ 322 static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ 323 TF_ATTRIBUTE_UNUSED = \ 324 ::tensorflow::register_op::OpDefBuilderWrapper<true>(name) 325 326 } // namespace tensorflow 327 328 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ 329