/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_H_

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Users that want to look up an OpDef by type name should take an
// OpRegistryInterface. Functions accepting a
// (const) OpRegistryInterface* may call LookUp() from multiple threads.
43 class OpRegistryInterface { 44 public: 45 virtual ~OpRegistryInterface(); 46 47 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is 48 // registered under that name, otherwise returns the registered OpDef. 49 // Caller must not delete the returned pointer. 50 virtual Status LookUp(const std::string& op_type_name, 51 const OpRegistrationData** op_reg_data) const = 0; 52 53 // Shorthand for calling LookUp to get the OpDef. 54 Status LookUpOpDef(const std::string& op_type_name, 55 const OpDef** op_def) const; 56 }; 57 58 // The standard implementation of OpRegistryInterface, along with a 59 // global singleton used for registering ops via the REGISTER 60 // macros below. Thread-safe. 61 // 62 // Example registration: 63 // OpRegistry::Global()->Register( 64 // [](OpRegistrationData* op_reg_data)->Status { 65 // // Populate *op_reg_data here. 66 // return Status::OK(); 67 // }); 68 class OpRegistry : public OpRegistryInterface { 69 public: 70 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; 71 72 OpRegistry(); 73 ~OpRegistry() override; 74 75 void Register(const OpRegistrationDataFactory& op_data_factory); 76 77 Status LookUp(const std::string& op_type_name, 78 const OpRegistrationData** op_reg_data) const override; 79 80 // Returns OpRegistrationData* of registered op type, else returns nullptr. 81 const OpRegistrationData* LookUp(const std::string& op_type_name) const; 82 83 // Fills *ops with all registered OpDefs (except those with names 84 // starting with '_' if include_internal == false) sorted in 85 // ascending alphabetical order. 86 void Export(bool include_internal, OpList* ops) const; 87 88 // Returns ASCII-format OpList for all registered OpDefs (except 89 // those with names starting with '_' if include_internal == false). 90 std::string DebugString(bool include_internal) const; 91 92 // A singleton available at startup. 93 static OpRegistry* Global(); 94 95 // Get all registered ops. 
96 void GetRegisteredOps(std::vector<OpDef>* op_defs); 97 98 // Get all `OpRegistrationData`s. 99 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); 100 101 // Registers a function that validates op registry. RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)102 void RegisterValidator( 103 std::function<Status(const OpRegistryInterface&)> validator) { 104 op_registry_validator_ = std::move(validator); 105 } 106 107 // Watcher, a function object. 108 // The watcher, if set by SetWatcher(), is called every time an op is 109 // registered via the Register function. The watcher is passed the Status 110 // obtained from building and adding the OpDef to the registry, and the OpDef 111 // itself if it was successfully built. A watcher returns a Status which is in 112 // turn returned as the final registration status. 113 typedef std::function<Status(const Status&, const OpDef&)> Watcher; 114 115 // An OpRegistry object has only one watcher. This interface is not thread 116 // safe, as different clients are free to set the watcher any time. 117 // Clients are expected to atomically perform the following sequence of 118 // operations : 119 // SetWatcher(a_watcher); 120 // Register some ops; 121 // op_registry->ProcessRegistrations(); 122 // SetWatcher(nullptr); 123 // Returns a non-OK status if a non-null watcher is over-written by another 124 // non-null watcher. 125 Status SetWatcher(const Watcher& watcher); 126 127 // Process the current list of deferred registrations. Note that calls to 128 // Export, LookUp and DebugString would also implicitly process the deferred 129 // registrations. Returns the status of the first failed op registration or 130 // Status::OK() otherwise. 131 Status ProcessRegistrations() const; 132 133 // Defer the registrations until a later call to a function that processes 134 // deferred registrations are made. 
Normally, registrations that happen after 135 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed 136 // immediately. Call this to defer future registrations. 137 void DeferRegistrations(); 138 139 // Clear the registrations that have been deferred. 140 void ClearDeferredRegistrations(); 141 142 private: 143 // Ensures that all the functions in deferred_ get called, their OpDef's 144 // registered, and returns with deferred_ empty. Returns true the first 145 // time it is called. Prints a fatal log if any op registration fails. 146 bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 147 148 // Calls the functions in deferred_ and registers their OpDef's 149 // It returns the Status of the first failed op registration or Status::OK() 150 // otherwise. 151 Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 152 153 // Add 'def' to the registry with additional data 'data'. On failure, or if 154 // there is already an OpDef with that name registered, returns a non-okay 155 // status. 156 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) 157 const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 158 159 const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const; 160 161 mutable mutex mu_; 162 // Functions in deferred_ may only be called with mu_ held. 163 mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_); 164 // Values are owned. 165 mutable std::unordered_map<string, const OpRegistrationData*> registry_ 166 TF_GUARDED_BY(mu_); 167 mutable bool initialized_ TF_GUARDED_BY(mu_); 168 169 // Registry watcher. 170 mutable Watcher watcher_ TF_GUARDED_BY(mu_); 171 172 std::function<Status(const OpRegistryInterface&)> op_registry_validator_; 173 }; 174 175 // An adapter to allow an OpList to be used as an OpRegistryInterface. 
176 // 177 // Note that shape inference functions are not passed in to OpListOpRegistry, so 178 // it will return an unusable shape inference function for every op it supports; 179 // therefore, it should only be used in contexts where this is okay. 180 class OpListOpRegistry : public OpRegistryInterface { 181 public: 182 // Does not take ownership of op_list, *op_list must outlive *this. 183 explicit OpListOpRegistry(const OpList* op_list); 184 ~OpListOpRegistry() override; 185 Status LookUp(const std::string& op_type_name, 186 const OpRegistrationData** op_reg_data) const override; 187 188 // Returns OpRegistrationData* of op type in list, else returns nullptr. 189 const OpRegistrationData* LookUp(const std::string& op_type_name) const; 190 191 private: 192 // Values are owned. 193 std::unordered_map<string, const OpRegistrationData*> index_; 194 }; 195 196 // Support for defining the OpDef (specifying the semantics of the Op and how 197 // it should be created) and registering it in the OpRegistry::Global() 198 // registry. Usage: 199 // 200 // REGISTER_OP("my_op_name") 201 // .Attr("<name>:<type>") 202 // .Attr("<name>:<type>=<default>") 203 // .Input("<name>:<type-expr>") 204 // .Input("<name>:Ref(<type-expr>)") 205 // .Output("<name>:<type-expr>") 206 // .Doc(R"( 207 // <1-line summary> 208 // <rest of the description (potentially many lines)> 209 // <name-of-attr-input-or-output>: <description of name> 210 // <name-of-attr-input-or-output>: <description of name; 211 // if long, indent the description on subsequent lines> 212 // )"); 213 // 214 // Note: .Doc() should be last. 215 // For details, see the OpDefBuilder class in op_def_builder.h. 
216 217 namespace register_op { 218 219 class OpDefBuilderWrapper { 220 public: OpDefBuilderWrapper(const char name[])221 explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {} Attr(std::string spec)222 OpDefBuilderWrapper& Attr(std::string spec) { 223 builder_.Attr(std::move(spec)); 224 return *this; 225 } Input(std::string spec)226 OpDefBuilderWrapper& Input(std::string spec) { 227 builder_.Input(std::move(spec)); 228 return *this; 229 } Output(std::string spec)230 OpDefBuilderWrapper& Output(std::string spec) { 231 builder_.Output(std::move(spec)); 232 return *this; 233 } SetIsCommutative()234 OpDefBuilderWrapper& SetIsCommutative() { 235 builder_.SetIsCommutative(); 236 return *this; 237 } SetIsAggregate()238 OpDefBuilderWrapper& SetIsAggregate() { 239 builder_.SetIsAggregate(); 240 return *this; 241 } SetIsStateful()242 OpDefBuilderWrapper& SetIsStateful() { 243 builder_.SetIsStateful(); 244 return *this; 245 } SetDoNotOptimize()246 OpDefBuilderWrapper& SetDoNotOptimize() { 247 // We don't have a separate flag to disable optimizations such as constant 248 // folding and CSE so we reuse the stateful flag. 
249 builder_.SetIsStateful(); 250 return *this; 251 } SetAllowsUninitializedInput()252 OpDefBuilderWrapper& SetAllowsUninitializedInput() { 253 builder_.SetAllowsUninitializedInput(); 254 return *this; 255 } Deprecated(int version,std::string explanation)256 OpDefBuilderWrapper& Deprecated(int version, std::string explanation) { 257 builder_.Deprecated(version, std::move(explanation)); 258 return *this; 259 } Doc(std::string text)260 OpDefBuilderWrapper& Doc(std::string text) { 261 builder_.Doc(std::move(text)); 262 return *this; 263 } SetShapeFn(OpShapeInferenceFn fn)264 OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) { 265 builder_.SetShapeFn(std::move(fn)); 266 return *this; 267 } SetIsDistributedCommunication()268 OpDefBuilderWrapper& SetIsDistributedCommunication() { 269 builder_.SetIsDistributedCommunication(); 270 return *this; 271 } 272 273 // Type constructor to support type inference. Similar to SetShapeFn, it 274 // allows programmatic control over the output type of an op, including 275 // inferring it from the inputs. 276 // TODO(mdan): Merge with shape inference. 
SetTypeConstructor(OpTypeConstructor fn)277 OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) { 278 builder_.SetTypeConstructor(std::move(fn)); 279 return *this; 280 } 281 builder()282 const ::tensorflow::OpDefBuilder& builder() const { return builder_; } 283 284 InitOnStartupMarker operator()(); 285 286 private: 287 mutable ::tensorflow::OpDefBuilder builder_; 288 }; 289 290 } // namespace register_op 291 292 #define REGISTER_OP_IMPL(ctr, name, is_system_op) \ 293 static ::tensorflow::InitOnStartupMarker const register_op##ctr \ 294 TF_ATTRIBUTE_UNUSED = \ 295 TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \ 296 << ::tensorflow::register_op::OpDefBuilderWrapper(name) 297 298 #define REGISTER_OP(name) \ 299 TF_ATTRIBUTE_ANNOTATE("tf:op") \ 300 TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false) 301 302 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except 303 // that the op is registered unconditionally even when selective 304 // registration is used. 305 #define REGISTER_SYSTEM_OP(name) \ 306 TF_ATTRIBUTE_ANNOTATE("tf:op") \ 307 TF_ATTRIBUTE_ANNOTATE("tf:op:system") \ 308 TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true) 309 310 } // namespace tensorflow 311 312 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ 313