• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
18 
19 // Support for eager execution of TensorFlow kernels.
20 
21 #include <memory>
22 #include <unordered_map>
23 
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31 #include "tensorflow/core/lib/gtl/optional.h"
32 #include "tensorflow/core/platform/fingerprint.h"
33 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
34 
35 namespace tensorflow {
36 
37 // Maps attribute name to an encoding of the type of the attribute value.
38 // If the type is not a list type, the value is the same as the TF_AttrType type
39 // of the value. Else, the highest order bit is on, and the rest of the bits
40 // represent the TF_AttrType type of the values in the list.
41 typedef std::unordered_map<string, uint32> AttrTypeMap;
42 
43 // Look up OpDef for `op_name`.
44 Status OpDefForOp(const char* op_name, const OpDef** op_def);
45 
46 // Returns the AttrTypeMap for the TensorFlow operation named op_name.
47 // If op_name is not registered in global op registry, AttrTypeMapForOp assumes
48 // the op to be a function and returns the default attributes for a function.
49 // `is_function` is set to true in this case.
50 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
51                         bool* is_function);
52 
53 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
54 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
55                       TF_AttrType* out, unsigned char* is_list);
56 
57 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
58 // An AttrBuilder is a convenience class to help with that - providing a smaller
59 // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
60 // checks (like number of inputs matching the OpDef - we only care about
61 // attributes here).
62 //
63 // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
64 // ones make sense to replicate.
65 
66 // This is a helper class for creating a NodeDef. Additionally, this class
67 // allows computing a cache key based on fingerprinting the attributes of this
68 // NodeDef.
69 //
70 // Example usage:
71 // AttrBuilder a;
72 // a.NumInputs(2);
73 // a.Set("T", TF_FLOAT);
74 // tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0");
75 // const NodeDef& n = a.BuildNodeDef();
76 //
77 // Note that all calls to Set and NumInputs should happen before calling
78 // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
79 // to CacheKey may cause different values to be returned by CacheKey.
80 //
81 // For performance reasons, the class internally delays the actual construction
82 // of the NodeDef till BuildNodeDef is called, or Set is called with certain
83 // uncommon types (see template specializations of Set to see which types
84 // trigger a NodeDef creation).
85 class AttrBuilder {
86  public:
AttrBuilder(const char * op)87   explicit AttrBuilder(const char* op)
88       : op_name_(op),
89         num_inputs_(0),
90         node_def_(nullptr),
91         node_def_finalized_(false) {}
92 
93   // Needed to work around call to ValidateNodeDef in CreateOpKernel.
94   AttrBuilder& NumInputs(int n);
95 
96   template <class T>
Set(StringPiece attr_name,T && value)97   AttrBuilder& Set(StringPiece attr_name, T&& value) {
98     MayBeInitializeNodeDef();
99     SetInAttrValueMap(node_def_->mutable_attr(), string(attr_name), value);
100     cached_cache_key_ = absl::nullopt;
101     return *this;
102   }
103 
104   tensorflow::Fprint128 CacheKey(const string& device);
105 
FillAttrValueMap(AttrValueMap * m)106   void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); }
107   const NodeDef& BuildNodeDef();
108 
109  private:
110   template <class T>
111   using AttrVec = tensorflow::gtl::InlinedVector<std::pair<string, T>, 2>;
112 
113   tensorflow::Fprint128 BuildCacheKeyForDevice(const string& device) const;
114 
115   void MayBeInitializeNodeDef();
116   // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
117   // well as any default attr-value pairs from the associated op_def, if there
118   // is one.
119   //
120   // If `include_those_in_node_def` is true, also include any attr-value pairs
121   // from `node_def_`.
122   void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
123 
124   template <class T>
SetInAttrValueMap(AttrValueMap * m,const string & attr_name,T && value)125   void SetInAttrValueMap(AttrValueMap* m, const string& attr_name,
126                          T&& value) const {
127     DCHECK(!node_def_finalized_)
128         << "Calling SetInAttrValueMap after BuildNodeDef.";
129     // If attribute is set more than once, its first value prevails
130     if (AttrSlice(m).Find(attr_name) == nullptr) {
131       AttrValue attr_value;
132       SetAttrValue(value, &attr_value);
133       m->insert(AttrValueMap::value_type(attr_name, attr_value));
134     }
135   }
136 
137   AttrVec<int> int_attrs_;
138   AttrVec<float> float_attrs_;
139   AttrVec<bool> bool_attrs_;
140   AttrVec<tensorflow::DataType> type_attrs_;
141   const string op_name_;
142   int num_inputs_;
143   std::unique_ptr<NodeDef> node_def_;
144   bool node_def_finalized_;
145 
146   absl::optional<tensorflow::Fprint128> cached_cache_key_;
147   string device_for_cached_cache_key_;
148 };  // namespace tensorflow
149 
150 template <>
151 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
152 template <>
153 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
154 template <>
155 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
156 template <>
157 AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
158                               tensorflow::DataType&& value);
159 
160 }  // namespace tensorflow
161 
162 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
163