1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_ 18 19 #include <functional> 20 #include <unordered_map> 21 22 #include <vector> 23 #include "tensorflow/core/framework/op_def_builder.h" 24 #include "tensorflow/core/framework/op_def_util.h" 25 #include "tensorflow/core/framework/selective_registration.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/strings/str_util.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 // Users that want to look up an OpDef by type name should take an 39 // OpRegistryInterface. Functions accepting a 40 // (const) OpRegistryInterface* may call LookUp() from multiple threads. 41 class OpRegistryInterface { 42 public: 43 virtual ~OpRegistryInterface(); 44 45 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is 46 // registered under that name, otherwise returns the registered OpDef. 47 // Caller must not delete the returned pointer. 48 virtual Status LookUp(const std::string& op_type_name, 49 const OpRegistrationData** op_reg_data) const = 0; 50 51 // Shorthand for calling LookUp to get the OpDef. 52 Status LookUpOpDef(const std::string& op_type_name, 53 const OpDef** op_def) const; 54 }; 55 56 // The standard implementation of OpRegistryInterface, along with a 57 // global singleton used for registering ops via the REGISTER 58 // macros below. Thread-safe. 59 // 60 // Example registration: 61 // OpRegistry::Global()->Register( 62 // [](OpRegistrationData* op_reg_data)->Status { 63 // // Populate *op_reg_data here. 64 // return Status::OK(); 65 // }); 66 class OpRegistry : public OpRegistryInterface { 67 public: 68 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; 69 70 OpRegistry(); 71 ~OpRegistry() override; 72 73 void Register(const OpRegistrationDataFactory& op_data_factory); 74 75 Status LookUp(const std::string& op_type_name, 76 const OpRegistrationData** op_reg_data) const override; 77 78 // Returns OpRegistrationData* of registered op type, else returns nullptr. 79 const OpRegistrationData* LookUp(const std::string& op_type_name) const; 80 81 // Fills *ops with all registered OpDefs (except those with names 82 // starting with '_' if include_internal == false) sorted in 83 // ascending alphabetical order. 84 void Export(bool include_internal, OpList* ops) const; 85 86 // Returns ASCII-format OpList for all registered OpDefs (except 87 // those with names starting with '_' if include_internal == false). 88 std::string DebugString(bool include_internal) const; 89 90 // A singleton available at startup. 91 static OpRegistry* Global(); 92 93 // Get all registered ops. 94 void GetRegisteredOps(std::vector<OpDef>* op_defs); 95 96 // Get all `OpRegistrationData`s. 97 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); 98 99 // Registers a function that validates op registry. RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)100 void RegisterValidator( 101 std::function<Status(const OpRegistryInterface&)> validator) { 102 op_registry_validator_ = std::move(validator); 103 } 104 105 // Watcher, a function object. 106 // The watcher, if set by SetWatcher(), is called every time an op is 107 // registered via the Register function. The watcher is passed the Status 108 // obtained from building and adding the OpDef to the registry, and the OpDef 109 // itself if it was successfully built. A watcher returns a Status which is in 110 // turn returned as the final registration status. 111 typedef std::function<Status(const Status&, const OpDef&)> Watcher; 112 113 // An OpRegistry object has only one watcher. This interface is not thread 114 // safe, as different clients are free to set the watcher any time. 115 // Clients are expected to atomically perform the following sequence of 116 // operations : 117 // SetWatcher(a_watcher); 118 // Register some ops; 119 // op_registry->ProcessRegistrations(); 120 // SetWatcher(nullptr); 121 // Returns a non-OK status if a non-null watcher is over-written by another 122 // non-null watcher. 123 Status SetWatcher(const Watcher& watcher); 124 125 // Process the current list of deferred registrations. Note that calls to 126 // Export, LookUp and DebugString would also implicitly process the deferred 127 // registrations. Returns the status of the first failed op registration or 128 // Status::OK() otherwise. 129 Status ProcessRegistrations() const; 130 131 // Defer the registrations until a later call to a function that processes 132 // deferred registrations are made. Normally, registrations that happen after 133 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed 134 // immediately. Call this to defer future registrations. 135 void DeferRegistrations(); 136 137 // Clear the registrations that have been deferred. 138 void ClearDeferredRegistrations(); 139 140 private: 141 // Ensures that all the functions in deferred_ get called, their OpDef's 142 // registered, and returns with deferred_ empty. Returns true the first 143 // time it is called. Prints a fatal log if any op registration fails. 144 bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 145 146 // Calls the functions in deferred_ and registers their OpDef's 147 // It returns the Status of the first failed op registration or Status::OK() 148 // otherwise. 149 Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 150 151 // Add 'def' to the registry with additional data 'data'. On failure, or if 152 // there is already an OpDef with that name registered, returns a non-okay 153 // status. 154 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) 155 const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 156 157 const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const; 158 159 mutable mutex mu_; 160 // Functions in deferred_ may only be called with mu_ held. 161 mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_); 162 // Values are owned. 163 mutable std::unordered_map<string, const OpRegistrationData*> registry_ 164 TF_GUARDED_BY(mu_); 165 mutable bool initialized_ TF_GUARDED_BY(mu_); 166 167 // Registry watcher. 168 mutable Watcher watcher_ TF_GUARDED_BY(mu_); 169 170 std::function<Status(const OpRegistryInterface&)> op_registry_validator_; 171 }; 172 173 // An adapter to allow an OpList to be used as an OpRegistryInterface. 174 // 175 // Note that shape inference functions are not passed in to OpListOpRegistry, so 176 // it will return an unusable shape inference function for every op it supports; 177 // therefore, it should only be used in contexts where this is okay. 178 class OpListOpRegistry : public OpRegistryInterface { 179 public: 180 // Does not take ownership of op_list, *op_list must outlive *this. 181 explicit OpListOpRegistry(const OpList* op_list); 182 ~OpListOpRegistry() override; 183 Status LookUp(const std::string& op_type_name, 184 const OpRegistrationData** op_reg_data) const override; 185 186 // Returns OpRegistrationData* of op type in list, else returns nullptr. 187 const OpRegistrationData* LookUp(const std::string& op_type_name) const; 188 189 private: 190 // Values are owned. 191 std::unordered_map<string, const OpRegistrationData*> index_; 192 }; 193 194 // Support for defining the OpDef (specifying the semantics of the Op and how 195 // it should be created) and registering it in the OpRegistry::Global() 196 // registry. Usage: 197 // 198 // REGISTER_OP("my_op_name") 199 // .Attr("<name>:<type>") 200 // .Attr("<name>:<type>=<default>") 201 // .Input("<name>:<type-expr>") 202 // .Input("<name>:Ref(<type-expr>)") 203 // .Output("<name>:<type-expr>") 204 // .Doc(R"( 205 // <1-line summary> 206 // <rest of the description (potentially many lines)> 207 // <name-of-attr-input-or-output>: <description of name> 208 // <name-of-attr-input-or-output>: <description of name; 209 // if long, indent the description on subsequent lines> 210 // )"); 211 // 212 // Note: .Doc() should be last. 213 // For details, see the OpDefBuilder class in op_def_builder.h. 214 215 namespace register_op { 216 217 class OpDefBuilderWrapper { 218 public: OpDefBuilderWrapper(const char name[])219 explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {} Attr(std::string spec)220 OpDefBuilderWrapper& Attr(std::string spec) { 221 builder_.Attr(std::move(spec)); 222 return *this; 223 } Input(std::string spec)224 OpDefBuilderWrapper& Input(std::string spec) { 225 builder_.Input(std::move(spec)); 226 return *this; 227 } Output(std::string spec)228 OpDefBuilderWrapper& Output(std::string spec) { 229 builder_.Output(std::move(spec)); 230 return *this; 231 } SetIsCommutative()232 OpDefBuilderWrapper& SetIsCommutative() { 233 builder_.SetIsCommutative(); 234 return *this; 235 } SetIsAggregate()236 OpDefBuilderWrapper& SetIsAggregate() { 237 builder_.SetIsAggregate(); 238 return *this; 239 } SetIsStateful()240 OpDefBuilderWrapper& SetIsStateful() { 241 builder_.SetIsStateful(); 242 return *this; 243 } SetDoNotOptimize()244 OpDefBuilderWrapper& SetDoNotOptimize() { 245 // We don't have a separate flag to disable optimizations such as constant 246 // folding and CSE so we reuse the stateful flag. 247 builder_.SetIsStateful(); 248 return *this; 249 } SetAllowsUninitializedInput()250 OpDefBuilderWrapper& SetAllowsUninitializedInput() { 251 builder_.SetAllowsUninitializedInput(); 252 return *this; 253 } Deprecated(int version,std::string explanation)254 OpDefBuilderWrapper& Deprecated(int version, std::string explanation) { 255 builder_.Deprecated(version, std::move(explanation)); 256 return *this; 257 } Doc(std::string text)258 OpDefBuilderWrapper& Doc(std::string text) { 259 builder_.Doc(std::move(text)); 260 return *this; 261 } SetShapeFn(OpShapeInferenceFn fn)262 OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) { 263 builder_.SetShapeFn(std::move(fn)); 264 return *this; 265 } 266 builder()267 const ::tensorflow::OpDefBuilder& builder() const { return builder_; } 268 269 InitOnStartupMarker operator()(); 270 271 private: 272 mutable ::tensorflow::OpDefBuilder builder_; 273 }; 274 275 } // namespace register_op 276 277 #define REGISTER_OP_IMPL(ctr, name, is_system_op) \ 278 static ::tensorflow::InitOnStartupMarker const register_op##ctr \ 279 TF_ATTRIBUTE_UNUSED = \ 280 TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \ 281 << ::tensorflow::register_op::OpDefBuilderWrapper(name) 282 283 #define REGISTER_OP(name) \ 284 TF_ATTRIBUTE_ANNOTATE("tf:op") \ 285 TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false) 286 287 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except 288 // that the op is registered unconditionally even when selective 289 // registration is used. 290 #define REGISTER_SYSTEM_OP(name) \ 291 TF_ATTRIBUTE_ANNOTATE("tf:op") \ 292 TF_ATTRIBUTE_ANNOTATE("tf:op:system") \ 293 TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true) 294 295 } // namespace tensorflow 296 297 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ 298