• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17 
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34 
OpNameToAttrTypeMap()35 tensorflow::gtl::FlatMap<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36   static auto* const m =
37       new tensorflow::gtl::FlatMap<string, const AttrTypeMap*>;
38   return m;
39 }
40 
41 const uint32 kIsList = 1U << 31;
42 
DefaultFunctionAttrTypeMap()43 AttrTypeMap* DefaultFunctionAttrTypeMap() {
44   AttrTypeMap* map = new AttrTypeMap();
45   (*map)["executor_type"] = TF_ATTR_STRING;
46   (*map)["config_proto"] = TF_ATTR_STRING;
47   return map;
48 }
49 
GetDefaultFunctionAttrTypeMap()50 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
51   static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
52   return map;
53 }
54 
55 }  // namespace
56 
OpDefForOp(const string & op_name,const OpDef ** op_def)57 Status OpDefForOp(const string& op_name, const OpDef** op_def) {
58   const OpRegistrationData* op_reg_data = nullptr;
59   Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
60   if (s.ok()) {
61     *op_def = &op_reg_data->op_def;
62   }
63   return s;
64 }
65 
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)66 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
67                         bool* is_function) {
68   {
69     tf_shared_lock l(g_op_name_to_attr_type_map_lock);
70     *is_function = false;
71     *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
72     if (*out != nullptr) return Status::OK();
73   }
74 
75   mutex_lock l(g_op_name_to_attr_type_map_lock);
76 
77   // Check the existence of AttrTypeMap for op_name again because another thread
78   // may insert this map after the tf_shared_lock is released but before the
79   // mutex_lock is acquired.
80   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
81   if (*out != nullptr) return Status::OK();
82 
83   const OpDef* op_def = nullptr;
84   Status s = OpDefForOp(op_name, &op_def);
85   if (errors::IsNotFound(s)) {
86     // If we did not find the op def, we assume `op_name` is a function.
87     // If it is actually a misspelled op, user will get another error when
88     // trying to run it.
89     // TODO(iga): If we ever have a use case for different attribute specs
90     // in different functions, we will need to look at the OpDef in the
91     // function def to retrieve their types.
92     *out = GetDefaultFunctionAttrTypeMap();
93     *is_function = true;
94     return Status::OK();
95   } else if (!s.ok()) {
96     return s;
97   }
98   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
99   // TODO(agarwal): Avoid having to create this "registry" at runtime,
100   // perhaps can be done at op registration time?
101   for (const auto& attr : op_def->attr()) {
102     string type = attr.type();
103     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
104     if (is_list) {
105       type = type.substr(5, type.length() - 6);
106     }
107     uint32 t = is_list ? kIsList : 0;
108     if (type == "string") {
109       t |= TF_ATTR_STRING;
110     } else if (type == "int") {
111       t |= TF_ATTR_INT;
112     } else if (type == "float") {
113       t |= TF_ATTR_FLOAT;
114     } else if (type == "bool") {
115       t |= TF_ATTR_BOOL;
116     } else if (type == "type") {
117       t |= TF_ATTR_TYPE;
118     } else if (type == "shape") {
119       t |= TF_ATTR_SHAPE;
120     } else if (type == "tensor") {
121       t |= TF_ATTR_TENSOR;
122     } else if (type == "func") {
123       t |= TF_ATTR_FUNC;
124     } else {
125       return errors::Unimplemented(
126           "TODO(agarwal): Enable support for ops with attributes of type '",
127           type, "'");
128     }
129     gtl::InsertIfNotPresent(m.get(), attr.name(), t);
130   }
131   *out = m.get();
132   auto r = OpNameToAttrTypeMap()->emplace(op_name, m.release());
133   DCHECK(r.second) << "AttrTypeMap already exists for " << op_name;
134 
135   return Status::OK();
136 }
137 
138 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE)                         \
139   template <>                                                           \
140   Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const {   \
141     auto it = encoded_attrs_.find(string(attr_name));                   \
142     if (it == encoded_attrs_.end()) {                                   \
143       return errors::NotFound("No attr named'", attr_name,              \
144                               "' found in AttrBuilder for ", op_name_); \
145     }                                                                   \
146     attr_tmp_.ParseFromString(it->second);                              \
147     TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE));         \
148     *value = attr_tmp_.FIELD();                                         \
149     return Status::OK();                                                \
150   }
151 
152 DEFINE_GET_ATTR(float, f, "float");
153 DEFINE_GET_ATTR(int, i, "int");
154 DEFINE_GET_ATTR(bool, b, "bool");
155 DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
156 
157 #undef DEFINE_GET_ATTR
158 
NumInputs(int n)159 AttrBuilder& AttrBuilder::NumInputs(int n) {
160   DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
161   num_inputs_ = n;
162   return *this;
163 }
164 
FillAttrValueMap(AttrValueMap * m) const165 void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
166   for (auto& entry : encoded_attrs_) {
167     attr_tmp_.ParseFromString(entry.second);
168     m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
169   }
170   // For any attr-value pairs that exist in the op def (from op registry) but
171   // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
172   // specify all the default attr values (e.g. for matmul, the `transpose_a`
173   // attr defaults to false).
174   const OpDef* op_def = nullptr;
175   Status s = OpDefForOp(op_name().c_str(), &op_def);
176   // This is expected, if this op is a custom function, and is therefore not
177   // present in the op registry.
178   if (!s.ok()) return;
179 
180   DCHECK(op_def);
181   for (const auto& attr_def : op_def->attr()) {
182     if (attr_def.has_default_value() && !m->count(attr_def.name())) {
183       SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
184     }
185   }
186 }
187 
188 namespace {
189 
ValueMatchesDefault(const OpDef * op_def,const string & attr_name,const AttrValue & attr_value)190 bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
191                          const AttrValue& attr_value) {
192   // TODO(iga): It might make sense to augment OpRegistrationData with a
193   // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
194   for (const OpDef::AttrDef& attr_def : op_def->attr()) {
195     if (attr_def.name() == attr_name && attr_def.has_default_value() &&
196         AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
197       return true;
198     }
199   }
200   return false;
201 }
202 
203 }  // namespace
204 
FillAttrValueMapWithoutDefaults(AttrValueMap * m) const205 void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
206   const OpDef* op_def = nullptr;
207   Status s = OpDefForOp(op_name().c_str(), &op_def);
208 
209   for (auto& entry : encoded_attrs_) {
210     attr_tmp_.ParseFromString(entry.second);
211     // Insert the attr-value pair if we did not find the OpDef or if the value
212     // is different from default.
213     if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
214       m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
215     }
216   }
217 }
218 
AddAttrIfNotPresent(StringPiece attr_name,const AttrValue & value)219 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
220                                       const AttrValue& value) {
221   encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
222 }
223 
BuildNodeDef()224 const NodeDef& AttrBuilder::BuildNodeDef() {
225   if (node_def_finalized_) return node_def_;
226   if (!node_def_initialized_) {
227     InitializeNodeDef();
228   }
229   for (int i = 0; i < num_inputs_; ++i) {
230     node_def_.add_input("dummy_input");
231   }
232   FillAttrValueMap(node_def_.mutable_attr());
233   node_def_finalized_ = true;
234   return node_def_;
235 }
236 
CopyAttributes(const AttrBuilder & other)237 void AttrBuilder::CopyAttributes(const AttrBuilder& other) {
238   encoded_attrs_.insert(other.encoded_attrs_.begin(),
239                         other.encoded_attrs_.end());
240 }
241 
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)242 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
243                       TF_AttrType* out, unsigned char* is_list) {
244   auto* t = gtl::FindOrNull(m, attr_name);
245   if (t == nullptr) {
246     return errors::InvalidArgument("Attribute '", attr_name,
247                                    "' does not exist for this operation");
248   }
249   *out = static_cast<TF_AttrType>(*t & ~kIsList);
250   if (*t & kIsList) {
251     *is_list = 1;
252   } else {
253     *is_list = 0;
254   }
255   return Status::OK();
256 }
257 
258 namespace {
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)259 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
260                                                const tensorflow::Fprint128& b) {
261   return {tensorflow::FingerprintCat64(a.low64, b.low64),
262           tensorflow::FingerprintCat64(a.high64, b.high64)};
263 }
264 
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)265 void CombineUnordered(const tensorflow::Fprint128& a,
266                       tensorflow::Fprint128* b) {
267   b->low64 += a.low64;
268   b->high64 += a.high64;
269 }
270 
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)271 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
272                                             const tensorflow::Fprint128& b) {
273   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
274   return FingerprintCat128(a, b);
275 }
276 
CacheKeyHelper(StringPiece s,uint64 b)277 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
278   return CacheKeyHelper(s, {b, b});
279 }
280 
281 }  // namespace
282 
CacheKey(const StringPiece device)283 tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
284   if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
285     cached_cache_key_ = BuildCacheKeyForDevice(device);
286     device_for_cached_cache_key_ = string(device);
287   }
288 
289   return *cached_cache_key_;
290 }
291 
BuildCacheKeyForDevice(const StringPiece device) const292 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
293     const StringPiece device) const {
294   tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name());
295   f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
296   for (const auto& p : encoded_attrs_) {
297     CombineUnordered(
298         CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
299   }
300   return f;
301 }
302 
InitializeNodeDef()303 void AttrBuilder::InitializeNodeDef() {
304   DCHECK(!node_def_initialized_);
305   node_def_.Clear();
306   node_def_.set_name(op_name_);
307   node_def_.set_op(op_name_);
308   node_def_initialized_ = true;
309 }
310 
GetNameAttrList(tensorflow::NameAttrList * name_and_attrs) const311 void AttrBuilder::GetNameAttrList(
312     tensorflow::NameAttrList* name_and_attrs) const {
313   FillAttrValueMap(name_and_attrs->mutable_attr());
314   name_and_attrs->set_name(op_name());
315 }
316 
GetInt(absl::string_view attr_name,int64_t * result) const317 bool AttrBuilder::GetInt(absl::string_view attr_name, int64_t* result) const {
318   Status s = Get(attr_name, result);
319   return s.ok();
320 }
GetFloat(absl::string_view attr_name,float * result) const321 bool AttrBuilder::GetFloat(absl::string_view attr_name, float* result) const {
322   Status s = Get(attr_name, result);
323   return s.ok();
324 }
GetBool(absl::string_view attr_name,bool * result) const325 bool AttrBuilder::GetBool(absl::string_view attr_name, bool* result) const {
326   Status s = Get(attr_name, result);
327   return s.ok();
328 }
329 
GetType(absl::string_view attr_name,tensorflow::DataType * result) const330 bool AttrBuilder::GetType(absl::string_view attr_name,
331                           tensorflow::DataType* result) const {
332   Status s = Get(attr_name, result);
333   return s.ok();
334 }
335 
336 }  // namespace tensorflow
337