1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_FRAMEWORK_OP_H_ 17 #define TENSORFLOW_FRAMEWORK_OP_H_ 18 19 #include <functional> 20 #include <unordered_map> 21 22 #include <vector> 23 #include "tensorflow/core/framework/op_def_builder.h" 24 #include "tensorflow/core/framework/op_def_util.h" 25 #include "tensorflow/core/framework/selective_registration.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/strings/str_util.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 // Users that want to look up an OpDef by type name should take an 39 // OpRegistryInterface. Functions accepting a 40 // (const) OpRegistryInterface* may call LookUp() from multiple threads. 41 class OpRegistryInterface { 42 public: 43 virtual ~OpRegistryInterface(); 44 45 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is 46 // registered under that name, otherwise returns the registered OpDef. 47 // Caller must not delete the returned pointer. 48 virtual Status LookUp(const string& op_type_name, 49 const OpRegistrationData** op_reg_data) const = 0; 50 51 // Shorthand for calling LookUp to get the OpDef. 52 Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const; 53 }; 54 55 // The standard implementation of OpRegistryInterface, along with a 56 // global singleton used for registering ops via the REGISTER 57 // macros below. Thread-safe. 58 // 59 // Example registration: 60 // OpRegistry::Global()->Register( 61 // [](OpRegistrationData* op_reg_data)->Status { 62 // // Populate *op_reg_data here. 63 // return Status::OK(); 64 // }); 65 class OpRegistry : public OpRegistryInterface { 66 public: 67 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; 68 69 OpRegistry(); 70 ~OpRegistry() override; 71 72 void Register(const OpRegistrationDataFactory& op_data_factory); 73 74 Status LookUp(const string& op_type_name, 75 const OpRegistrationData** op_reg_data) const override; 76 77 // Fills *ops with all registered OpDefs (except those with names 78 // starting with '_' if include_internal == false) sorted in 79 // ascending alphabetical order. 80 void Export(bool include_internal, OpList* ops) const; 81 82 // Returns ASCII-format OpList for all registered OpDefs (except 83 // those with names starting with '_' if include_internal == false). 84 string DebugString(bool include_internal) const; 85 86 // A singleton available at startup. 87 static OpRegistry* Global(); 88 89 // Get all registered ops. 90 void GetRegisteredOps(std::vector<OpDef>* op_defs); 91 92 // Watcher, a function object. 93 // The watcher, if set by SetWatcher(), is called every time an op is 94 // registered via the Register function. The watcher is passed the Status 95 // obtained from building and adding the OpDef to the registry, and the OpDef 96 // itself if it was successfully built. A watcher returns a Status which is in 97 // turn returned as the final registration status. 98 typedef std::function<Status(const Status&, const OpDef&)> Watcher; 99 100 // An OpRegistry object has only one watcher. This interface is not thread 101 // safe, as different clients are free to set the watcher any time. 102 // Clients are expected to atomically perform the following sequence of 103 // operations : 104 // SetWatcher(a_watcher); 105 // Register some ops; 106 // op_registry->ProcessRegistrations(); 107 // SetWatcher(nullptr); 108 // Returns a non-OK status if a non-null watcher is over-written by another 109 // non-null watcher. 110 Status SetWatcher(const Watcher& watcher); 111 112 // Process the current list of deferred registrations. Note that calls to 113 // Export, LookUp and DebugString would also implicitly process the deferred 114 // registrations. Returns the status of the first failed op registration or 115 // Status::OK() otherwise. 116 Status ProcessRegistrations() const; 117 118 // Defer the registrations until a later call to a function that processes 119 // deferred registrations are made. Normally, registrations that happen after 120 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed 121 // immediately. Call this to defer future registrations. 122 void DeferRegistrations(); 123 124 // Clear the registrations that have been deferred. 125 void ClearDeferredRegistrations(); 126 127 private: 128 // Ensures that all the functions in deferred_ get called, their OpDef's 129 // registered, and returns with deferred_ empty. Returns true the first 130 // time it is called. Prints a fatal log if any op registration fails. 131 bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); 132 133 // Calls the functions in deferred_ and registers their OpDef's 134 // It returns the Status of the first failed op registration or Status::OK() 135 // otherwise. 136 Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); 137 138 // Add 'def' to the registry with additional data 'data'. On failure, or if 139 // there is already an OpDef with that name registered, returns a non-okay 140 // status. 141 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) 142 const EXCLUSIVE_LOCKS_REQUIRED(mu_); 143 144 mutable mutex mu_; 145 // Functions in deferred_ may only be called with mu_ held. 146 mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_); 147 // Values are owned. 148 mutable std::unordered_map<string, const OpRegistrationData*> registry_ 149 GUARDED_BY(mu_); 150 mutable bool initialized_ GUARDED_BY(mu_); 151 152 // Registry watcher. 153 mutable Watcher watcher_ GUARDED_BY(mu_); 154 }; 155 156 // An adapter to allow an OpList to be used as an OpRegistryInterface. 157 // 158 // Note that shape inference functions are not passed in to OpListOpRegistry, so 159 // it will return an unusable shape inference function for every op it supports; 160 // therefore, it should only be used in contexts where this is okay. 161 class OpListOpRegistry : public OpRegistryInterface { 162 public: 163 // Does not take ownership of op_list, *op_list must outlive *this. 164 OpListOpRegistry(const OpList* op_list); 165 ~OpListOpRegistry() override; 166 Status LookUp(const string& op_type_name, 167 const OpRegistrationData** op_reg_data) const override; 168 169 private: 170 // Values are owned. 171 std::unordered_map<string, const OpRegistrationData*> index_; 172 }; 173 174 // Support for defining the OpDef (specifying the semantics of the Op and how 175 // it should be created) and registering it in the OpRegistry::Global() 176 // registry. Usage: 177 // 178 // REGISTER_OP("my_op_name") 179 // .Attr("<name>:<type>") 180 // .Attr("<name>:<type>=<default>") 181 // .Input("<name>:<type-expr>") 182 // .Input("<name>:Ref(<type-expr>)") 183 // .Output("<name>:<type-expr>") 184 // .Doc(R"( 185 // <1-line summary> 186 // <rest of the description (potentially many lines)> 187 // <name-of-attr-input-or-output>: <description of name> 188 // <name-of-attr-input-or-output>: <description of name; 189 // if long, indent the description on subsequent lines> 190 // )"); 191 // 192 // Note: .Doc() should be last. 193 // For details, see the OpDefBuilder class in op_def_builder.h. 194 195 namespace register_op { 196 197 // OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP 198 // calls. This allows the result of REGISTER_OP to be used in chaining, as in 199 // REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective 200 // registration to turn the entire call-chain into a no-op. 201 template <bool should_register> 202 class OpDefBuilderWrapper; 203 204 // Template specialization that forwards all calls to the contained builder. 205 template <> 206 class OpDefBuilderWrapper<true> { 207 public: OpDefBuilderWrapper(const char name[])208 OpDefBuilderWrapper(const char name[]) : builder_(name) {} Attr(StringPiece spec)209 OpDefBuilderWrapper<true>& Attr(StringPiece spec) { 210 builder_.Attr(spec); 211 return *this; 212 } Input(StringPiece spec)213 OpDefBuilderWrapper<true>& Input(StringPiece spec) { 214 builder_.Input(spec); 215 return *this; 216 } Output(StringPiece spec)217 OpDefBuilderWrapper<true>& Output(StringPiece spec) { 218 builder_.Output(spec); 219 return *this; 220 } SetIsCommutative()221 OpDefBuilderWrapper<true>& SetIsCommutative() { 222 builder_.SetIsCommutative(); 223 return *this; 224 } SetIsAggregate()225 OpDefBuilderWrapper<true>& SetIsAggregate() { 226 builder_.SetIsAggregate(); 227 return *this; 228 } SetIsStateful()229 OpDefBuilderWrapper<true>& SetIsStateful() { 230 builder_.SetIsStateful(); 231 return *this; 232 } SetAllowsUninitializedInput()233 OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() { 234 builder_.SetAllowsUninitializedInput(); 235 return *this; 236 } Deprecated(int version,StringPiece explanation)237 OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) { 238 builder_.Deprecated(version, explanation); 239 return *this; 240 } Doc(StringPiece text)241 OpDefBuilderWrapper<true>& Doc(StringPiece text) { 242 builder_.Doc(text); 243 return *this; 244 } SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))245 OpDefBuilderWrapper<true>& SetShapeFn( 246 Status (*fn)(shape_inference::InferenceContext*)) { 247 builder_.SetShapeFn(fn); 248 return *this; 249 } builder()250 const ::tensorflow::OpDefBuilder& builder() const { return builder_; } 251 252 private: 253 mutable ::tensorflow::OpDefBuilder builder_; 254 }; 255 256 // Template specialization that turns all calls into no-ops. 257 template <> 258 class OpDefBuilderWrapper<false> { 259 public: OpDefBuilderWrapper(const char name[])260 constexpr OpDefBuilderWrapper(const char name[]) {} Attr(StringPiece spec)261 OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; } Input(StringPiece spec)262 OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; } Output(StringPiece spec)263 OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; } SetIsCommutative()264 OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; } SetIsAggregate()265 OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; } SetIsStateful()266 OpDefBuilderWrapper<false>& SetIsStateful() { return *this; } SetAllowsUninitializedInput()267 OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; } Deprecated(int,StringPiece)268 OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; } Doc(StringPiece text)269 OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; } SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))270 OpDefBuilderWrapper<false>& SetShapeFn( 271 Status (*fn)(shape_inference::InferenceContext*)) { 272 return *this; 273 } 274 }; 275 276 struct OpDefBuilderReceiver { 277 // To call OpRegistry::Global()->Register(...), used by the 278 // REGISTER_OP macro below. 279 // Note: These are implicitly converting constructors. 280 OpDefBuilderReceiver( 281 const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit) OpDefBuilderReceiverOpDefBuilderReceiver282 constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) { 283 } // NOLINT(runtime/explicit) 284 }; 285 } // namespace register_op 286 287 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name) 288 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name) 289 #define REGISTER_OP_UNIQ(ctr, name) \ 290 static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ 291 TF_ATTRIBUTE_UNUSED = \ 292 ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \ 293 name)>(name) 294 295 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except 296 // that the op is registered unconditionally even when selective 297 // registration is used. 298 #define REGISTER_SYSTEM_OP(name) \ 299 REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name) 300 #define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \ 301 REGISTER_SYSTEM_OP_UNIQ(ctr, name) 302 #define REGISTER_SYSTEM_OP_UNIQ(ctr, name) \ 303 static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ 304 TF_ATTRIBUTE_UNUSED = \ 305 ::tensorflow::register_op::OpDefBuilderWrapper<true>(name) 306 307 } // namespace tensorflow 308 309 #endif // TENSORFLOW_FRAMEWORK_OP_H_ 310