• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17 
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34 
OpNameToAttrTypeMap()35 tensorflow::gtl::FlatMap<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36   static auto* const m =
37       new tensorflow::gtl::FlatMap<string, const AttrTypeMap*>;
38   return m;
39 }
40 
41 const uint32 kIsList = 1U << 31;
42 
DefaultFunctionAttrTypeMap()43 AttrTypeMap* DefaultFunctionAttrTypeMap() {
44   AttrTypeMap* map = new AttrTypeMap();
45   (*map)["executor_type"] = TF_ATTR_STRING;
46   (*map)["config_proto"] = TF_ATTR_STRING;
47   return map;
48 }
49 
GetDefaultFunctionAttrTypeMap()50 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
51   static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
52   return map;
53 }
54 
55 }  // namespace
56 
OpDefForOp(const char * op_name,const OpDef ** op_def)57 Status OpDefForOp(const char* op_name, const OpDef** op_def) {
58   const OpRegistrationData* op_reg_data = nullptr;
59   Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
60   if (s.ok()) {
61     *op_def = &op_reg_data->op_def;
62   }
63   return s;
64 }
65 
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)66 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
67                         bool* is_function) {
68   mutex_lock l(g_op_name_to_attr_type_map_lock);
69   *is_function = false;
70   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
71   if (*out != nullptr) return Status::OK();
72   const OpDef* op_def = nullptr;
73   Status s = OpDefForOp(op_name, &op_def);
74   if (errors::IsNotFound(s)) {
75     // If we did not find the op def, we assume `op_name` is a function.
76     // If it is actually a misspelled op, user will get another error when
77     // trying to run it.
78     // TODO(iga): If we ever have a use case for different attribute specs
79     // in different functions, we will need to look at the OpDef in the
80     // function def to retrieve their types.
81     *out = GetDefaultFunctionAttrTypeMap();
82     *is_function = true;
83     return Status::OK();
84   } else if (!s.ok()) {
85     return s;
86   }
87   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
88   // TODO(agarwal): Avoid having to create this "registry" at runtime,
89   // perhaps can be done at op registration time?
90   for (const auto& attr : op_def->attr()) {
91     string type = attr.type();
92     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
93     if (is_list) {
94       type = type.substr(5, type.length() - 6);
95     }
96     uint32 t = is_list ? kIsList : 0;
97     if (type == "string") {
98       t |= TF_ATTR_STRING;
99     } else if (type == "int") {
100       t |= TF_ATTR_INT;
101     } else if (type == "float") {
102       t |= TF_ATTR_FLOAT;
103     } else if (type == "bool") {
104       t |= TF_ATTR_BOOL;
105     } else if (type == "type") {
106       t |= TF_ATTR_TYPE;
107     } else if (type == "shape") {
108       t |= TF_ATTR_SHAPE;
109     } else if (type == "tensor") {
110       t |= TF_ATTR_TENSOR;
111     } else if (type == "func") {
112       t |= TF_ATTR_FUNC;
113     } else {
114       return errors::Unimplemented(
115           "TODO(agarwal): Enable support for ops with attributes of type '",
116           type, "'");
117     }
118     gtl::InsertIfNotPresent(m.get(), attr.name(), t);
119   }
120   *out = m.get();
121   (*OpNameToAttrTypeMap())[op_name] = m.release();
122   return Status::OK();
123 }
124 
125 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE)                         \
126   template <>                                                           \
127   Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const {   \
128     auto it = encoded_attrs_.find(string(attr_name));                   \
129     if (it == encoded_attrs_.end()) {                                   \
130       return errors::NotFound("No attr named'", attr_name,              \
131                               "' found in AttrBuilder for ", op_name_); \
132     }                                                                   \
133     attr_tmp_.ParseFromString(it->second);                              \
134     TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE));         \
135     *value = attr_tmp_.FIELD();                                         \
136     return Status::OK();                                                \
137   }
138 
139 DEFINE_GET_ATTR(float, f, "float");
140 DEFINE_GET_ATTR(int, i, "int");
141 DEFINE_GET_ATTR(bool, b, "bool");
142 DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
143 
144 #undef DEFINE_GET_ATTR
145 
NumInputs(int n)146 AttrBuilder& AttrBuilder::NumInputs(int n) {
147   DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
148   num_inputs_ = n;
149   return *this;
150 }
151 
FillAttrValueMap(AttrValueMap * m) const152 void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
153   for (auto& entry : encoded_attrs_) {
154     attr_tmp_.ParseFromString(entry.second);
155     m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
156   }
157   // For any attr-value pairs that exist in the op def (from op registry) but
158   // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
159   // specify all the default attr values (e.g. for matmul, the `transpose_a`
160   // attr defaults to false).
161   const OpDef* op_def = nullptr;
162   Status s = OpDefForOp(op_name().c_str(), &op_def);
163   // This is expected, if this op is a custom function, and is therefore not
164   // present in the op registry.
165   if (!s.ok()) return;
166 
167   DCHECK(op_def);
168   for (const auto& attr_def : op_def->attr()) {
169     if (attr_def.has_default_value() && !m->count(attr_def.name())) {
170       SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
171     }
172   }
173 }
174 
175 namespace {
176 
ValueMatchesDefault(const OpDef * op_def,const string & attr_name,const AttrValue & attr_value)177 bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
178                          const AttrValue& attr_value) {
179   // TODO(iga): It might make sense to augment OpRegistrationData with a
180   // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
181   for (const OpDef::AttrDef& attr_def : op_def->attr()) {
182     if (attr_def.name() == attr_name && attr_def.has_default_value() &&
183         AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
184       return true;
185     }
186   }
187   return false;
188 }
189 
190 }  // namespace
191 
FillAttrValueMapWithoutDefaults(AttrValueMap * m) const192 void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
193   const OpDef* op_def = nullptr;
194   Status s = OpDefForOp(op_name().c_str(), &op_def);
195 
196   for (auto& entry : encoded_attrs_) {
197     attr_tmp_.ParseFromString(entry.second);
198     // Insert the attr-value pair if we did not find the OpDef or if the value
199     // is different from default.
200     if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
201       m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
202     }
203   }
204 }
205 
AddAttrIfNotPresent(StringPiece attr_name,const AttrValue & value)206 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
207                                       const AttrValue& value) {
208   encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
209 }
210 
BuildNodeDef()211 const NodeDef& AttrBuilder::BuildNodeDef() {
212   if (node_def_finalized_) return node_def_;
213   if (!node_def_initialized_) {
214     InitializeNodeDef();
215   }
216   for (int i = 0; i < num_inputs_; ++i) {
217     node_def_.add_input("dummy_input");
218   }
219   FillAttrValueMap(node_def_.mutable_attr());
220   node_def_finalized_ = true;
221   return node_def_;
222 }
223 
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)224 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
225                       TF_AttrType* out, unsigned char* is_list) {
226   auto* t = gtl::FindOrNull(m, attr_name);
227   if (t == nullptr) {
228     return errors::InvalidArgument("Attribute '", attr_name,
229                                    "' does not exist for this operation");
230   }
231   *out = static_cast<TF_AttrType>(*t & ~kIsList);
232   if (*t & kIsList) {
233     *is_list = 1;
234   } else {
235     *is_list = 0;
236   }
237   return Status::OK();
238 }
239 
240 namespace {
// Combines two 128-bit fingerprints by applying FingerprintCat64 to the low
// halves and to the high halves independently.
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  tensorflow::Fprint128 combined;
  combined.low64 = tensorflow::FingerprintCat64(a.low64, b.low64);
  combined.high64 = tensorflow::FingerprintCat64(a.high64, b.high64);
  return combined;
}
246 
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)247 void CombineUnordered(const tensorflow::Fprint128& a,
248                       tensorflow::Fprint128* b) {
249   b->low64 += a.low64;
250   b->high64 += a.high64;
251 }
252 
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)253 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
254                                             const tensorflow::Fprint128& b) {
255   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
256   return FingerprintCat128(a, b);
257 }
258 
CacheKeyHelper(StringPiece s,uint64 b)259 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
260   return CacheKeyHelper(s, {b, b});
261 }
262 
263 }  // namespace
264 
CacheKey(const StringPiece device)265 tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
266   if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
267     cached_cache_key_ = BuildCacheKeyForDevice(device);
268     device_for_cached_cache_key_ = string(device);
269   }
270 
271   return *cached_cache_key_;
272 }
273 
BuildCacheKeyForDevice(const StringPiece device) const274 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
275     const StringPiece device) const {
276   tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name());
277   f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
278   for (const auto& p : encoded_attrs_) {
279     CombineUnordered(
280         CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
281   }
282   return f;
283 }
284 
InitializeNodeDef()285 void AttrBuilder::InitializeNodeDef() {
286   DCHECK(!node_def_initialized_);
287   node_def_.Clear();
288   node_def_.set_name(op_name_);
289   node_def_.set_op(op_name_);
290   node_def_initialized_ = true;
291 }
292 
293 }  // namespace tensorflow
294