1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_ 18 19 #include <functional> 20 #include <unordered_map> 21 22 #include <vector> 23 #include "tensorflow/core/framework/op_def_builder.h" 24 #include "tensorflow/core/framework/op_def_util.h" 25 #include "tensorflow/core/framework/selective_registration.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/strings/str_util.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 // Users that want to look up an OpDef by type name should take an 39 // OpRegistryInterface. Functions accepting a 40 // (const) OpRegistryInterface* may call LookUp() from multiple threads. 41 class OpRegistryInterface { 42 public: 43 virtual ~OpRegistryInterface(); 44 45 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is 46 // registered under that name, otherwise returns the registered OpDef. 47 // Caller must not delete the returned pointer. 48 virtual Status LookUp(const string& op_type_name, 49 const OpRegistrationData** op_reg_data) const = 0; 50 51 // Shorthand for calling LookUp to get the OpDef. 52 Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const; 53 }; 54 55 // The standard implementation of OpRegistryInterface, along with a 56 // global singleton used for registering ops via the REGISTER 57 // macros below. Thread-safe. 58 // 59 // Example registration: 60 // OpRegistry::Global()->Register( 61 // [](OpRegistrationData* op_reg_data)->Status { 62 // // Populate *op_reg_data here. 63 // return Status::OK(); 64 // }); 65 class OpRegistry : public OpRegistryInterface { 66 public: 67 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; 68 69 OpRegistry(); 70 ~OpRegistry() override; 71 72 void Register(const OpRegistrationDataFactory& op_data_factory); 73 74 Status LookUp(const string& op_type_name, 75 const OpRegistrationData** op_reg_data) const override; 76 77 // Fills *ops with all registered OpDefs (except those with names 78 // starting with '_' if include_internal == false) sorted in 79 // ascending alphabetical order. 80 void Export(bool include_internal, OpList* ops) const; 81 82 // Returns ASCII-format OpList for all registered OpDefs (except 83 // those with names starting with '_' if include_internal == false). 84 string DebugString(bool include_internal) const; 85 86 // A singleton available at startup. 87 static OpRegistry* Global(); 88 89 // Get all registered ops. 90 void GetRegisteredOps(std::vector<OpDef>* op_defs); 91 92 // Get all `OpRegistrationData`s. 93 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); 94 95 // Watcher, a function object. 96 // The watcher, if set by SetWatcher(), is called every time an op is 97 // registered via the Register function. The watcher is passed the Status 98 // obtained from building and adding the OpDef to the registry, and the OpDef 99 // itself if it was successfully built. A watcher returns a Status which is in 100 // turn returned as the final registration status. 101 typedef std::function<Status(const Status&, const OpDef&)> Watcher; 102 103 // An OpRegistry object has only one watcher. This interface is not thread 104 // safe, as different clients are free to set the watcher any time. 105 // Clients are expected to atomically perform the following sequence of 106 // operations : 107 // SetWatcher(a_watcher); 108 // Register some ops; 109 // op_registry->ProcessRegistrations(); 110 // SetWatcher(nullptr); 111 // Returns a non-OK status if a non-null watcher is over-written by another 112 // non-null watcher. 113 Status SetWatcher(const Watcher& watcher); 114 115 // Process the current list of deferred registrations. Note that calls to 116 // Export, LookUp and DebugString would also implicitly process the deferred 117 // registrations. Returns the status of the first failed op registration or 118 // Status::OK() otherwise. 119 Status ProcessRegistrations() const; 120 121 // Defer the registrations until a later call to a function that processes 122 // deferred registrations are made. Normally, registrations that happen after 123 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed 124 // immediately. Call this to defer future registrations. 125 void DeferRegistrations(); 126 127 // Clear the registrations that have been deferred. 128 void ClearDeferredRegistrations(); 129 130 private: 131 // Ensures that all the functions in deferred_ get called, their OpDef's 132 // registered, and returns with deferred_ empty. Returns true the first 133 // time it is called. Prints a fatal log if any op registration fails. 134 bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); 135 136 // Calls the functions in deferred_ and registers their OpDef's 137 // It returns the Status of the first failed op registration or Status::OK() 138 // otherwise. 139 Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); 140 141 // Add 'def' to the registry with additional data 'data'. On failure, or if 142 // there is already an OpDef with that name registered, returns a non-okay 143 // status. 144 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) 145 const EXCLUSIVE_LOCKS_REQUIRED(mu_); 146 147 Status LookUpSlow(const string& op_type_name, 148 const OpRegistrationData** op_reg_data) const; 149 150 mutable mutex mu_; 151 // Functions in deferred_ may only be called with mu_ held. 152 mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_); 153 // Values are owned. 154 mutable std::unordered_map<string, const OpRegistrationData*> registry_ 155 GUARDED_BY(mu_); 156 mutable bool initialized_ GUARDED_BY(mu_); 157 158 // Registry watcher. 159 mutable Watcher watcher_ GUARDED_BY(mu_); 160 }; 161 162 // An adapter to allow an OpList to be used as an OpRegistryInterface. 163 // 164 // Note that shape inference functions are not passed in to OpListOpRegistry, so 165 // it will return an unusable shape inference function for every op it supports; 166 // therefore, it should only be used in contexts where this is okay. 167 class OpListOpRegistry : public OpRegistryInterface { 168 public: 169 // Does not take ownership of op_list, *op_list must outlive *this. 170 OpListOpRegistry(const OpList* op_list); 171 ~OpListOpRegistry() override; 172 Status LookUp(const string& op_type_name, 173 const OpRegistrationData** op_reg_data) const override; 174 175 private: 176 // Values are owned. 177 std::unordered_map<string, const OpRegistrationData*> index_; 178 }; 179 180 // Support for defining the OpDef (specifying the semantics of the Op and how 181 // it should be created) and registering it in the OpRegistry::Global() 182 // registry. Usage: 183 // 184 // REGISTER_OP("my_op_name") 185 // .Attr("<name>:<type>") 186 // .Attr("<name>:<type>=<default>") 187 // .Input("<name>:<type-expr>") 188 // .Input("<name>:Ref(<type-expr>)") 189 // .Output("<name>:<type-expr>") 190 // .Doc(R"( 191 // <1-line summary> 192 // <rest of the description (potentially many lines)> 193 // <name-of-attr-input-or-output>: <description of name> 194 // <name-of-attr-input-or-output>: <description of name; 195 // if long, indent the description on subsequent lines> 196 // )"); 197 // 198 // Note: .Doc() should be last. 199 // For details, see the OpDefBuilder class in op_def_builder.h. 200 201 namespace register_op { 202 203 // OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP 204 // calls. This allows the result of REGISTER_OP to be used in chaining, as in 205 // REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective 206 // registration to turn the entire call-chain into a no-op. 207 template <bool should_register> 208 class OpDefBuilderWrapper; 209 210 // Template specialization that forwards all calls to the contained builder. 211 template <> 212 class OpDefBuilderWrapper<true> { 213 public: OpDefBuilderWrapper(const char name[])214 OpDefBuilderWrapper(const char name[]) : builder_(name) {} Attr(string spec)215 OpDefBuilderWrapper<true>& Attr(string spec) { 216 builder_.Attr(std::move(spec)); 217 return *this; 218 } Input(string spec)219 OpDefBuilderWrapper<true>& Input(string spec) { 220 builder_.Input(std::move(spec)); 221 return *this; 222 } Output(string spec)223 OpDefBuilderWrapper<true>& Output(string spec) { 224 builder_.Output(std::move(spec)); 225 return *this; 226 } SetIsCommutative()227 OpDefBuilderWrapper<true>& SetIsCommutative() { 228 builder_.SetIsCommutative(); 229 return *this; 230 } SetIsAggregate()231 OpDefBuilderWrapper<true>& SetIsAggregate() { 232 builder_.SetIsAggregate(); 233 return *this; 234 } SetIsStateful()235 OpDefBuilderWrapper<true>& SetIsStateful() { 236 builder_.SetIsStateful(); 237 return *this; 238 } SetAllowsUninitializedInput()239 OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() { 240 builder_.SetAllowsUninitializedInput(); 241 return *this; 242 } Deprecated(int version,string explanation)243 OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) { 244 builder_.Deprecated(version, std::move(explanation)); 245 return *this; 246 } Doc(string text)247 OpDefBuilderWrapper<true>& Doc(string text) { 248 builder_.Doc(std::move(text)); 249 return *this; 250 } SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))251 OpDefBuilderWrapper<true>& SetShapeFn( 252 Status (*fn)(shape_inference::InferenceContext*)) { 253 builder_.SetShapeFn(fn); 254 return *this; 255 } builder()256 const ::tensorflow::OpDefBuilder& builder() const { return builder_; } 257 258 private: 259 mutable ::tensorflow::OpDefBuilder builder_; 260 }; 261 262 // Template specialization that turns all calls into no-ops. 263 template <> 264 class OpDefBuilderWrapper<false> { 265 public: OpDefBuilderWrapper(const char name[])266 constexpr OpDefBuilderWrapper(const char name[]) {} Attr(StringPiece spec)267 OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; } Input(StringPiece spec)268 OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; } Output(StringPiece spec)269 OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; } SetIsCommutative()270 OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; } SetIsAggregate()271 OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; } SetIsStateful()272 OpDefBuilderWrapper<false>& SetIsStateful() { return *this; } SetAllowsUninitializedInput()273 OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; } Deprecated(int,StringPiece)274 OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; } Doc(StringPiece text)275 OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; } SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))276 OpDefBuilderWrapper<false>& SetShapeFn( 277 Status (*fn)(shape_inference::InferenceContext*)) { 278 return *this; 279 } 280 }; 281 282 struct OpDefBuilderReceiver { 283 // To call OpRegistry::Global()->Register(...), used by the 284 // REGISTER_OP macro below. 285 // Note: These are implicitly converting constructors. 286 OpDefBuilderReceiver( 287 const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit) OpDefBuilderReceiverOpDefBuilderReceiver288 constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) { 289 } // NOLINT(runtime/explicit) 290 }; 291 } // namespace register_op 292 293 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name) 294 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name) 295 #define REGISTER_OP_UNIQ(ctr, name) \ 296 static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ 297 TF_ATTRIBUTE_UNUSED = \ 298 ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \ 299 name)>(name) 300 301 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except 302 // that the op is registered unconditionally even when selective 303 // registration is used. 304 #define REGISTER_SYSTEM_OP(name) \ 305 REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name) 306 #define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \ 307 REGISTER_SYSTEM_OP_UNIQ(ctr, name) 308 #define REGISTER_SYSTEM_OP_UNIQ(ctr, name) \ 309 static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ 310 TF_ATTRIBUTE_UNUSED = \ 311 ::tensorflow::register_op::OpDefBuilderWrapper<true>(name) 312 313 } // namespace tensorflow 314 315 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ 316