• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

#include <cstring>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34 
OpNameToAttrTypeMap()35 std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36   static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
37   return m;
38 }
39 
40 const uint32 kIsList = 1U << 31;
41 
DefaultFunctionAttrTypeMap()42 AttrTypeMap* DefaultFunctionAttrTypeMap() {
43   AttrTypeMap* map = new AttrTypeMap();
44   (*map)["executor_type"] = TF_ATTR_STRING;
45   (*map)["config_proto"] = TF_ATTR_STRING;
46   return map;
47 }
48 
GetDefaultFunctionAttrTypeMap()49 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
50   static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
51   return map;
52 }
53 
54 }  // namespace
55 
OpDefForOp(const char * op_name,const OpDef ** op_def)56 Status OpDefForOp(const char* op_name, const OpDef** op_def) {
57   const OpRegistrationData* op_reg_data = nullptr;
58   Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
59   if (s.ok()) {
60     *op_def = &op_reg_data->op_def;
61   }
62   return s;
63 }
64 
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)65 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
66                         bool* is_function) {
67   mutex_lock l(g_op_name_to_attr_type_map_lock);
68   *is_function = false;
69   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
70   if (*out != nullptr) return Status::OK();
71   const OpDef* op_def = nullptr;
72   Status s = OpDefForOp(op_name, &op_def);
73   if (errors::IsNotFound(s)) {
74     // If we did not find the op def, we assume `op_name` is a function.
75     // If it is actually a misspelled op, user will get another error when
76     // trying to run it.
77     // TODO(iga): If we ever have a use case for different attribute specs
78     // in different functions, we will need to look at the OpDef in the
79     // function def to retrieve their types.
80     *out = GetDefaultFunctionAttrTypeMap();
81     *is_function = true;
82     return Status::OK();
83   } else if (!s.ok()) {
84     return s;
85   }
86   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
87   // TODO(agarwal): Avoid having to create this "registry" at runtime,
88   // perhaps can be done at op registration time?
89   for (const auto& attr : op_def->attr()) {
90     string type = attr.type();
91     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
92     if (is_list) {
93       type = type.substr(5, type.length() - 6);
94     }
95     uint32 t = is_list ? kIsList : 0;
96     if (type == "string") {
97       t |= TF_ATTR_STRING;
98     } else if (type == "int") {
99       t |= TF_ATTR_INT;
100     } else if (type == "float") {
101       t |= TF_ATTR_FLOAT;
102     } else if (type == "bool") {
103       t |= TF_ATTR_BOOL;
104     } else if (type == "type") {
105       t |= TF_ATTR_TYPE;
106     } else if (type == "shape") {
107       t |= TF_ATTR_SHAPE;
108     } else if (type == "tensor") {
109       t |= TF_ATTR_TENSOR;
110     } else if (type == "func") {
111       t |= TF_ATTR_FUNC;
112     } else {
113       return errors::Unimplemented(
114           "TODO(agarwal): Enable support for ops with attributes of type '",
115           type, "'");
116     }
117     gtl::InsertIfNotPresent(m.get(), attr.name(), t);
118   }
119   *out = m.get();
120   (*OpNameToAttrTypeMap())[op_name] = m.release();
121   return Status::OK();
122 }
123 
124 #define DEFINE_SET_ATTR(value_type, value_field)                             \
125   template <>                                                                \
126   AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
127     value_field.push_back(std::make_pair(string(attr_name), value));         \
128     cached_cache_key_ = absl::nullopt;                                       \
129     return *this;                                                            \
130   }
131 
132 DEFINE_SET_ATTR(float, float_attrs_);
133 DEFINE_SET_ATTR(int, int_attrs_);
134 DEFINE_SET_ATTR(bool, bool_attrs_);
135 DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
136 
137 #undef DEFINE_SET_ATTR
138 
NumInputs(int n)139 AttrBuilder& AttrBuilder::NumInputs(int n) {
140   DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
141   num_inputs_ = n;
142   return *this;
143 }
144 
FillAttrValueMap(AttrValueMap * m,bool include_those_in_node_def) const145 void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
146                                    bool include_those_in_node_def) const {
147   for (const auto& p : int_attrs_) {
148     SetInAttrValueMap(m, p.first, p.second);
149   }
150   for (const auto& p : float_attrs_) {
151     SetInAttrValueMap(m, p.first, p.second);
152   }
153   for (const auto& p : bool_attrs_) {
154     SetInAttrValueMap(m, p.first, p.second);
155   }
156   for (const auto& p : type_attrs_) {
157     SetInAttrValueMap(m, p.first, p.second);
158   }
159   if (include_those_in_node_def && node_def_ != nullptr) {
160     for (AttrValueMap::const_iterator it = node_def_->attr().begin();
161          it != node_def_->attr().end(); ++it) {
162       m->insert(*it);
163     }
164   }
165   // For any attr-value pairs that exist in the op def (from op registry) but
166   // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
167   // specify all the default attr values (e.g. for matmul, the `transpose_a`
168   // attr defaults to false).
169   const OpDef* op_def = nullptr;
170   Status s = OpDefForOp(op_name_.c_str(), &op_def);
171   // This is expected, if this op is a custom function, and is therefore not
172   // present in the op registry.
173   if (!s.ok()) return;
174 
175   DCHECK(op_def);
176   for (const auto& attr_def : op_def->attr()) {
177     if (attr_def.has_default_value() && !m->count(attr_def.name())) {
178       SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
179     }
180   }
181 }
182 
BuildNodeDef()183 const NodeDef& AttrBuilder::BuildNodeDef() {
184   if (node_def_finalized_) return *node_def_;
185   MayBeInitializeNodeDef();
186   for (int i = 0; i < num_inputs_; ++i) {
187     node_def_->add_input("dummy_input");
188   }
189   FillAttrValueMap(node_def_->mutable_attr(), false);
190   node_def_finalized_ = true;
191   return *node_def_;
192 }
193 
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)194 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
195                       TF_AttrType* out, unsigned char* is_list) {
196   auto* t = gtl::FindOrNull(m, attr_name);
197   if (t == nullptr) {
198     return errors::InvalidArgument("Attribute '", attr_name,
199                                    "' does not exist for this operation");
200   }
201   *out = static_cast<TF_AttrType>(*t & ~kIsList);
202   if (*t & kIsList) {
203     *is_list = 1;
204   } else {
205     *is_list = 0;
206   }
207   return Status::OK();
208 }
209 
210 namespace {
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)211 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
212                                                const tensorflow::Fprint128& b) {
213   return {tensorflow::FingerprintCat64(a.low64, b.low64),
214           tensorflow::FingerprintCat64(a.high64, b.high64)};
215 }
216 
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)217 void CombineUnordered(const tensorflow::Fprint128& a,
218                       tensorflow::Fprint128* b) {
219   b->low64 += a.low64;
220   b->high64 += a.high64;
221 }
222 
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)223 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
224                                             const tensorflow::Fprint128& b) {
225   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
226   return FingerprintCat128(a, b);
227 }
228 
CacheKeyHelper(StringPiece s,uint64 b)229 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
230   return CacheKeyHelper(s, {b, b});
231 }
232 
233 }  // namespace
234 
CacheKey(const string & device)235 tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) {
236   if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
237     cached_cache_key_ = BuildCacheKeyForDevice(device);
238     device_for_cached_cache_key_ = device;
239   }
240 
241   return *cached_cache_key_;
242 }
243 
BuildCacheKeyForDevice(const string & device) const244 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
245     const string& device) const {
246   tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
247   f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
248   if (node_def_ != nullptr) {
249     // Some attributes are directly written to node_def_ instead of being
250     // stored explicitly.
251     string value;
252     for (const auto& attr : node_def_->attr()) {
253       attr.second.SerializeToString(&value);
254       CombineUnordered(
255           CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
256     }
257     // Note that node_def_ may be created but not finalized. This can happen
258     // when the creation was triggered by a call to Set, but BuildNodeDef has
259     // not been called.
260     if (node_def_finalized_) return f;
261   }
262   for (const auto& p : int_attrs_) {
263     CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
264                      &f);
265   }
266   static std::hash<float> float_hasher;
267   for (const auto& p : float_attrs_) {
268     CombineUnordered(
269         CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
270         &f);
271   }
272   for (const auto& p : bool_attrs_) {
273     CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
274   }
275   for (const auto& p : type_attrs_) {
276     CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
277                      &f);
278   }
279   return f;
280 }
281 
MayBeInitializeNodeDef()282 void AttrBuilder::MayBeInitializeNodeDef() {
283   if (node_def_ == nullptr) {
284     node_def_.reset(new NodeDef());
285     node_def_->set_name(op_name_);
286     node_def_->set_op(op_name_);
287   }
288 }
289 
290 }  // namespace tensorflow
291