• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
18 
19 // Support for eager execution of TensorFlow kernels.
20 
21 #include <memory>
22 #include <unordered_map>
23 
24 #include "tensorflow/c/eager/abstract_op_attrs.h"
25 #include "tensorflow/c/tf_attrtype.h"
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/gtl/optional.h"
34 #include "tensorflow/core/platform/fingerprint.h"
35 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
36 
37 namespace tensorflow {
38 
39 // Maps attribute name to an encoding of the type of the attribute value.
40 // If the type is not a list type, the value is the same as the TF_AttrType type
41 // of the value. Else, the highest order bit is on, and the rest of the bits
42 // represent the TF_AttrType type of the values in the list.
43 typedef std::unordered_map<string, uint32> AttrTypeMap;
44 
45 // Looks up the OpDef for `op_name`; the result is reported through `op_def`.
46 Status OpDefForOp(const string& op_name, const OpDef** op_def);
47 
48 // Returns (through `out`) the AttrTypeMap for the TensorFlow operation named
49 // `op_name`. If `op_name` is not registered in the global op registry,
50 // AttrTypeMapForOp assumes the op to be a function and returns the default
51 // attributes for a function; `*is_function` is set to true in this case.
52 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
53                         bool* is_function);
54 
55 // Looks for `attr_name` in `m` and sets `out` and `is_list` from its entry.
// NOTE(review): presumably `*is_list` mirrors the high-order "list" bit of the
// encoded value stored in the AttrTypeMap -- confirm against the definition.
56 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
57                       TF_AttrType* out, unsigned char* is_list);
58 
59 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
60 // An AttrBuilder is a convenience class to help with that - providing a smaller
61 // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
62 // checks (like number of inputs matching the OpDef - we only care about
63 // attributes here).
64 //
65 // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
66 // ones make sense to replicate.
67 
68 // This is a helper class for creating a NodeDef. Additionally, this class
69 // allows computing a cache key based on fingerprinting the attributes of this
70 // NodeDef.
71 //
72 // Example usage:
73 // AttrBuilder a;
74 // a.NumInputs(2);
75 // a.Set("T", TF_FLOAT);
76 // tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0");
77 // const NodeDef& n = a.BuildNodeDef();
78 //
79 // Note that all calls to Set and NumInputs should happen before calling
80 // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
81 // to CacheKey may cause different values to be returned by CacheKey.
82 //
83 // For performance reasons, the class internally delays the actual construction
84 // of the NodeDef till BuildNodeDef is called, or Set is called with certain
85 // uncommon types (see template specializations of Set to see which types
86 // trigger a NodeDef creation).
87 //
88 // Setting attributes via `Set` may cause arena-allocated protocol buffer
89 // messages to be destructed, which is not thread safe. This means that it is
90 // currently not safe to set attributes on *different* AttrBuilder objects from
91 // multiple threads. This does not apply to `CopyAttributes`.
92 class AttrBuilder : public AbstractOpAttrs {
93  public:
AttrBuilder()94   AttrBuilder()
95       : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {}
96 
~AttrBuilder()97   ~AttrBuilder() override {}
AttrBuilder(const char * op)98   explicit AttrBuilder(const char* op)
99       : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {
100     Reset(op);
101   }
102 
Reset(const char * op)103   void Reset(const char* op) {
104     op_name_ = op;
105     num_inputs_ = 0;
106     encoded_attrs_.clear();
107     node_def_initialized_ = false;
108     node_def_finalized_ = false;
109     cached_cache_key_ = absl::nullopt;
110     device_for_cached_cache_key_.clear();
111   }
112 
op_name()113   const string& op_name() const { return op_name_; }
114 
115   // Needed to work around call to ValidateNodeDef in CreateOpKernel.
116   AttrBuilder& NumInputs(int n);
117 
118   template <class T>
Set(StringPiece attr_name,T && value)119   AttrBuilder& Set(StringPiece attr_name, T&& value) {
120     SetAttrValue(value, &attr_tmp_);
121     AddAttrIfNotPresent(attr_name, attr_tmp_);
122     cached_cache_key_ = absl::nullopt;
123     return *this;
124   }
125 
NumAttributes()126   size_t NumAttributes() const { return encoded_attrs_.size(); }
127 
Set(StringPiece attr_name,const AttrValue & value)128   AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
129     AddAttrIfNotPresent(attr_name, value);
130     cached_cache_key_ = absl::nullopt;
131     return *this;
132   }
133 
134   // Retrieves the attribute value.
135   // Note that Get() can involve a linear scan of all attributes with the same
136   // value type in this Node. This is not an issue, because Get is used rarely
137   // and nodes have a small number of attributes.
138   template <class T>
Get(StringPiece attr_name,T * value)139   Status Get(StringPiece attr_name, T* value) const {
140     // Common attributes are stored in AttrVecs. This Get() template
141     // is specialized for them below. If we end up here, the type must be
142     // among those that we store in the node_def_.
143     if (!node_def_initialized_) {
144       return errors::NotFound("No attr named'", attr_name,
145                               "' found in AttrBuilder for ", op_name_);
146     }
147     return GetNodeAttr(AttrSlice(node_def_), attr_name, value);
148   }
149 
150   tensorflow::Fprint128 CacheKey(const StringPiece device);
151 
152   // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
153   // well as any default attr-value pairs from the associated op_def, if there
154   // is one.
155   void FillAttrValueMap(AttrValueMap* m) const;
156 
157   // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except
158   // when the value matches the default for this attr.
159   // More precisely, if the global op registry contains an OpDef for this op
160   // and if an attribute value is the same as the default (according to the
161   // OpDef), this attr-value pair is not added to `m`.
162   void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const;
163   const NodeDef& BuildNodeDef();
164 
165   // Transfers the attributes from `other` to this AttrBuilder. Does not
166   // overwrite existing attributes. Since it does not require deserializing and
167   // re-serializing attributes, it is much more efficient than going through an
168   // AttrValueMap.
169   void CopyAttributes(const AttrBuilder& other);
170 
171   void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override;
172 
173   bool GetInt(absl::string_view attr_name, int64_t* result) const override;
174   bool GetFloat(absl::string_view attr_name, float* result) const override;
175   bool GetBool(absl::string_view attr_name, bool* result) const override;
176   bool GetType(absl::string_view attr_name,
177                tensorflow::DataType* result) const override;
178 
179  private:
180   tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
181 
182   // Initialize the node_def_ object.
183   // REQUIRES: node_def_initialized_ = false
184   void InitializeNodeDef();
185 
186   template <class T>
SetInAttrValueMap(AttrValueMap * m,const string & attr_name,T && value)187   void SetInAttrValueMap(AttrValueMap* m, const string& attr_name,
188                          T&& value) const {
189     DCHECK(!node_def_finalized_)
190         << "Calling SetInAttrValueMap after BuildNodeDef.";
191     // If attribute is set more than once, its first value prevails
192     m->insert({attr_name, value});
193   }
194 
195   void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value);
196 
197   gtl::FlatMap<string, string> encoded_attrs_;
198   mutable AttrValue attr_tmp_;  // For encoding
199 
200   string op_name_;  // Conceptually const, but can't be because of Reset(...)
201   int num_inputs_;
202   NodeDef node_def_;
203   bool node_def_initialized_;
204   bool node_def_finalized_;
205 
206   absl::optional<tensorflow::Fprint128> cached_cache_key_;
207   string device_for_cached_cache_key_;
208 };
209 
// Explicit specializations of AttrBuilder::Get for int, float, bool and
// DataType. Only declared in this header; the definitions live elsewhere.
210 template <>
211 Status AttrBuilder::Get(StringPiece attr_name, int* value) const;
212 template <>
213 Status AttrBuilder::Get(StringPiece attr_name, float* value) const;
214 template <>
215 Status AttrBuilder::Get(StringPiece attr_name, bool* value) const;
216 template <>
217 Status AttrBuilder::Get(StringPiece attr_name,
218                         tensorflow::DataType* value) const;
219 }  // namespace tensorflow
220 
221 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
222