1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ 18 19 // Support for eager execution of TensorFlow kernels. 20 21 #include <memory> 22 #include <unordered_map> 23 24 #include "tensorflow/c/eager/abstract_op_attrs.h" 25 #include "tensorflow/c/tf_attrtype.h" 26 #include "tensorflow/core/common_runtime/device.h" 27 #include "tensorflow/core/framework/node_def.pb.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/lib/gtl/inlined_vector.h" 33 #include "tensorflow/core/lib/gtl/optional.h" 34 #include "tensorflow/core/platform/fingerprint.h" 35 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 36 37 namespace tensorflow { 38 39 // Maps attribute name to an encoding of the type of the attribute value. 40 // If the type is not a list type, the value is the same as the TF_AttrType type 41 // of the value. Else, the highest order bit is on, and the rest of the bits 42 // represent the TF_AttrType type of the values in the list. 43 typedef std::unordered_map<string, uint32> AttrTypeMap; 44 45 // Look up OpDef for `op_name`. 46 Status OpDefForOp(const string& op_name, const OpDef** op_def); 47 48 // Returns the AttrTypeMap for the TensorFlow operation named op_name. 49 // If op_name is not registered in global op registry, AttrTypeMapForOp assumes 50 // the op to be a function and returns the default attributes for a function. 51 // `is_function` is set to true in this case. 52 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, 53 bool* is_function); 54 55 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. 56 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, 57 TF_AttrType* out, unsigned char* is_list); 58 59 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. 60 // An AttrBuilder is a convenience class to help with that - providing a smaller 61 // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity 62 // checks (like number of inputs matching the OpDef - we only care about 63 // attributes here). 64 // 65 // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which 66 // ones make sense to replicate. 67 68 // This is a helper class for creating a NodeDef. Additionally, this class 69 // allows computing a cache key based on fingerprinting the attributes of this 70 // NodeDef. 71 // 72 // Example usage: 73 // AttrBuilder a; 74 // a.NumInputs(2); 75 // a.Set("T", TF_FLOAT); 76 // tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0"); 77 // const NodeDef& n = a.BuildNodeDef(); 78 // 79 // Note that all calls to Set and NumInputs should happen before calling 80 // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations 81 // to CacheKey may cause different values to be returned by CacheKey. 82 // 83 // For performance reasons, the class internally delays the actual construction 84 // of the NodeDef till BuildNodeDef is called, or Set is called with certain 85 // uncommon types (see template specializations of Set to see which types 86 // trigger a NodeDef creation). 87 // 88 // Setting attributes via `Set` may cause arena-allocated protocol buffer 89 // messages to be destructed, which is not thread safe. This means that it is 90 // currently not safe to set attributes on *different* AttrBuilder objects from 91 // multiple threads. This does not apply to `CopyAttributes`. 92 class AttrBuilder : public AbstractOpAttrs { 93 public: AttrBuilder()94 AttrBuilder() 95 : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {} 96 ~AttrBuilder()97 ~AttrBuilder() override {} AttrBuilder(const char * op)98 explicit AttrBuilder(const char* op) 99 : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) { 100 Reset(op); 101 } 102 Reset(const char * op)103 void Reset(const char* op) { 104 op_name_ = op; 105 num_inputs_ = 0; 106 encoded_attrs_.clear(); 107 node_def_initialized_ = false; 108 node_def_finalized_ = false; 109 cached_cache_key_ = absl::nullopt; 110 device_for_cached_cache_key_.clear(); 111 } 112 op_name()113 const string& op_name() const { return op_name_; } 114 115 // Needed to work around call to ValidateNodeDef in CreateOpKernel. 116 AttrBuilder& NumInputs(int n); 117 118 template <class T> Set(StringPiece attr_name,T && value)119 AttrBuilder& Set(StringPiece attr_name, T&& value) { 120 SetAttrValue(value, &attr_tmp_); 121 AddAttrIfNotPresent(attr_name, attr_tmp_); 122 cached_cache_key_ = absl::nullopt; 123 return *this; 124 } 125 NumAttributes()126 size_t NumAttributes() const { return encoded_attrs_.size(); } 127 Set(StringPiece attr_name,const AttrValue & value)128 AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) { 129 AddAttrIfNotPresent(attr_name, value); 130 cached_cache_key_ = absl::nullopt; 131 return *this; 132 } 133 134 // Retrieves the attribute value. 135 // Note that Get() can involve a linear scan of all attributes with the same 136 // value type in this Node. This is not an issue, because Get is used rarely 137 // and nodes have a small number of attributes. 138 template <class T> Get(StringPiece attr_name,T * value)139 Status Get(StringPiece attr_name, T* value) const { 140 // Common attributes are stored in AttrVecs. This Get() template 141 // is specialized for them below. If we end up here, the type must be 142 // among those that we store in the node_def_. 143 if (!node_def_initialized_) { 144 return errors::NotFound("No attr named'", attr_name, 145 "' found in AttrBuilder for ", op_name_); 146 } 147 return GetNodeAttr(AttrSlice(node_def_), attr_name, value); 148 } 149 150 tensorflow::Fprint128 CacheKey(const StringPiece device); 151 152 // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as 153 // well as any default attr-value pairs from the associated op_def, if there 154 // is one. 155 void FillAttrValueMap(AttrValueMap* m) const; 156 157 // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except 158 // when the value matches the default for this attr. 159 // More precisely, if the global op registry contains an OpDef for this op 160 // and if an attribute value is the same as the default (according to the 161 // OpDef), this attr-value pair is not added to `m`. 162 void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const; 163 const NodeDef& BuildNodeDef(); 164 165 // Transfers the attributes from `other` to this AttrBuilder. Does not 166 // overwrite existing attributes. Since it does not require deserializing and 167 // re-serializing attributes, it is much more efficient than going through an 168 // AttrValueMap. 169 void CopyAttributes(const AttrBuilder& other); 170 171 void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override; 172 173 bool GetInt(absl::string_view attr_name, int64_t* result) const override; 174 bool GetFloat(absl::string_view attr_name, float* result) const override; 175 bool GetBool(absl::string_view attr_name, bool* result) const override; 176 bool GetType(absl::string_view attr_name, 177 tensorflow::DataType* result) const override; 178 179 private: 180 tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const; 181 182 // Initialize the node_def_ object. 183 // REQUIRES: node_def_initialized_ = false 184 void InitializeNodeDef(); 185 186 template <class T> SetInAttrValueMap(AttrValueMap * m,const string & attr_name,T && value)187 void SetInAttrValueMap(AttrValueMap* m, const string& attr_name, 188 T&& value) const { 189 DCHECK(!node_def_finalized_) 190 << "Calling SetInAttrValueMap after BuildNodeDef."; 191 // If attribute is set more than once, its first value prevails 192 m->insert({attr_name, value}); 193 } 194 195 void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value); 196 197 gtl::FlatMap<string, string> encoded_attrs_; 198 mutable AttrValue attr_tmp_; // For encoding 199 200 string op_name_; // Conceptually const, but can't be because of Reset(...) 201 int num_inputs_; 202 NodeDef node_def_; 203 bool node_def_initialized_; 204 bool node_def_finalized_; 205 206 absl::optional<tensorflow::Fprint128> cached_cache_key_; 207 string device_for_cached_cache_key_; 208 }; 209 210 template <> 211 Status AttrBuilder::Get(StringPiece attr_name, int* value) const; 212 template <> 213 Status AttrBuilder::Get(StringPiece attr_name, float* value) const; 214 template <> 215 Status AttrBuilder::Get(StringPiece attr_name, bool* value) const; 216 template <> 217 Status AttrBuilder::Get(StringPiece attr_name, 218 tensorflow::DataType* value) const; 219 } // namespace tensorflow 220 221 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ 222