1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ 18 19 // Support for eager execution of TensorFlow kernels. 20 21 #include <memory> 22 #include <unordered_map> 23 24 #include "tensorflow/c/c_api.h" 25 #include "tensorflow/core/common_runtime/device.h" 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/lib/gtl/optional.h" 32 #include "tensorflow/core/platform/fingerprint.h" 33 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 34 35 namespace tensorflow { 36 37 // Maps attribute name to an encoding of the type of the attribute value. 38 // If the type is not a list type, the value is the same as the TF_AttrType type 39 // of the value. Else, the highest order bit is on, and the rest of the bits 40 // represent the TF_AttrType type of the values in the list. 41 typedef std::unordered_map<string, uint32> AttrTypeMap; 42 43 // Look up OpDef for `op_name`. 44 Status OpDefForOp(const char* op_name, const OpDef** op_def); 45 46 // Returns the AttrTypeMap for the TensorFlow operation named op_name. 47 // If op_name is not registered in global op registry, AttrTypeMapForOp assumes 48 // the op to be a function and returns the default attributes for a function. 49 // `is_function` is set to true in this case. 50 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, 51 bool* is_function); 52 53 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. 54 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, 55 TF_AttrType* out, unsigned char* is_list); 56 57 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. 58 // An AttrBuilder is a convenience class to help with that - providing a smaller 59 // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity 60 // checks (like number of inputs matching the OpDef - we only care about 61 // attributes here). 62 // 63 // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which 64 // ones make sense to replicate. 65 66 // This is a helper class for creating a NodeDef. Additionally, this class 67 // allows computing a cache key based on fingerprinting the attributes of this 68 // NodeDef. 69 // 70 // Example usage: 71 // AttrBuilder a; 72 // a.NumInputs(2); 73 // a.Set("T", TF_FLOAT); 74 // tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0"); 75 // const NodeDef& n = a.BuildNodeDef(); 76 // 77 // Note that all calls to Set and NumInputs should happen before calling 78 // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations 79 // to CacheKey may cause different values to be returned by CacheKey. 80 // 81 // For performance reasons, the class internally delays the actual construction 82 // of the NodeDef till BuildNodeDef is called, or Set is called with certain 83 // uncommon types (see template specializations of Set to see which types 84 // trigger a NodeDef creation). 85 class AttrBuilder { 86 public: AttrBuilder(const char * op)87 explicit AttrBuilder(const char* op) 88 : op_name_(op), 89 num_inputs_(0), 90 node_def_(nullptr), 91 node_def_finalized_(false) {} 92 93 // Needed to work around call to ValidateNodeDef in CreateOpKernel. 94 AttrBuilder& NumInputs(int n); 95 96 template <class T> Set(StringPiece attr_name,T && value)97 AttrBuilder& Set(StringPiece attr_name, T&& value) { 98 MayBeInitializeNodeDef(); 99 SetInAttrValueMap(node_def_->mutable_attr(), string(attr_name), value); 100 cached_cache_key_ = absl::nullopt; 101 return *this; 102 } 103 104 tensorflow::Fprint128 CacheKey(const string& device); 105 FillAttrValueMap(AttrValueMap * m)106 void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); } 107 const NodeDef& BuildNodeDef(); 108 109 private: 110 template <class T> 111 using AttrVec = tensorflow::gtl::InlinedVector<std::pair<string, T>, 2>; 112 113 tensorflow::Fprint128 BuildCacheKeyForDevice(const string& device) const; 114 115 void MayBeInitializeNodeDef(); 116 // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as 117 // well as any default attr-value pairs from the associated op_def, if there 118 // is one. 119 // 120 // If `include_those_in_node_def` is true, also include any attr-value pairs 121 // from `node_def_`. 122 void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const; 123 124 template <class T> SetInAttrValueMap(AttrValueMap * m,const string & attr_name,T && value)125 void SetInAttrValueMap(AttrValueMap* m, const string& attr_name, 126 T&& value) const { 127 DCHECK(!node_def_finalized_) 128 << "Calling SetInAttrValueMap after BuildNodeDef."; 129 // If attribute is set more than once, its first value prevails 130 if (AttrSlice(m).Find(attr_name) == nullptr) { 131 AttrValue attr_value; 132 SetAttrValue(value, &attr_value); 133 m->insert(AttrValueMap::value_type(attr_name, attr_value)); 134 } 135 } 136 137 AttrVec<int> int_attrs_; 138 AttrVec<float> float_attrs_; 139 AttrVec<bool> bool_attrs_; 140 AttrVec<tensorflow::DataType> type_attrs_; 141 const string op_name_; 142 int num_inputs_; 143 std::unique_ptr<NodeDef> node_def_; 144 bool node_def_finalized_; 145 146 absl::optional<tensorflow::Fprint128> cached_cache_key_; 147 string device_for_cached_cache_key_; 148 }; // namespace tensorflow 149 150 template <> 151 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value); 152 template <> 153 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value); 154 template <> 155 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value); 156 template <> 157 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, 158 tensorflow::DataType&& value); 159 160 } // namespace tensorflow 161 162 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ 163