1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29
30 namespace tensorflow {
31 namespace {
32
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34
OpNameToAttrTypeMap()35 tensorflow::gtl::FlatMap<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36 static auto* const m =
37 new tensorflow::gtl::FlatMap<string, const AttrTypeMap*>;
38 return m;
39 }
40
41 const uint32 kIsList = 1U << 31;
42
DefaultFunctionAttrTypeMap()43 AttrTypeMap* DefaultFunctionAttrTypeMap() {
44 AttrTypeMap* map = new AttrTypeMap();
45 (*map)["executor_type"] = TF_ATTR_STRING;
46 (*map)["config_proto"] = TF_ATTR_STRING;
47 return map;
48 }
49
GetDefaultFunctionAttrTypeMap()50 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
51 static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
52 return map;
53 }
54
55 } // namespace
56
OpDefForOp(const string & op_name,const OpDef ** op_def)57 Status OpDefForOp(const string& op_name, const OpDef** op_def) {
58 const OpRegistrationData* op_reg_data = nullptr;
59 Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
60 if (s.ok()) {
61 *op_def = &op_reg_data->op_def;
62 }
63 return s;
64 }
65
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)66 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
67 bool* is_function) {
68 {
69 tf_shared_lock l(g_op_name_to_attr_type_map_lock);
70 *is_function = false;
71 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
72 if (*out != nullptr) return Status::OK();
73 }
74
75 mutex_lock l(g_op_name_to_attr_type_map_lock);
76
77 // Check the existence of AttrTypeMap for op_name again because another thread
78 // may insert this map after the tf_shared_lock is released but before the
79 // mutex_lock is acquired.
80 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
81 if (*out != nullptr) return Status::OK();
82
83 const OpDef* op_def = nullptr;
84 Status s = OpDefForOp(op_name, &op_def);
85 if (errors::IsNotFound(s)) {
86 // If we did not find the op def, we assume `op_name` is a function.
87 // If it is actually a misspelled op, user will get another error when
88 // trying to run it.
89 // TODO(iga): If we ever have a use case for different attribute specs
90 // in different functions, we will need to look at the OpDef in the
91 // function def to retrieve their types.
92 *out = GetDefaultFunctionAttrTypeMap();
93 *is_function = true;
94 return Status::OK();
95 } else if (!s.ok()) {
96 return s;
97 }
98 std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
99 // TODO(agarwal): Avoid having to create this "registry" at runtime,
100 // perhaps can be done at op registration time?
101 for (const auto& attr : op_def->attr()) {
102 string type = attr.type();
103 const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
104 if (is_list) {
105 type = type.substr(5, type.length() - 6);
106 }
107 uint32 t = is_list ? kIsList : 0;
108 if (type == "string") {
109 t |= TF_ATTR_STRING;
110 } else if (type == "int") {
111 t |= TF_ATTR_INT;
112 } else if (type == "float") {
113 t |= TF_ATTR_FLOAT;
114 } else if (type == "bool") {
115 t |= TF_ATTR_BOOL;
116 } else if (type == "type") {
117 t |= TF_ATTR_TYPE;
118 } else if (type == "shape") {
119 t |= TF_ATTR_SHAPE;
120 } else if (type == "tensor") {
121 t |= TF_ATTR_TENSOR;
122 } else if (type == "func") {
123 t |= TF_ATTR_FUNC;
124 } else {
125 return errors::Unimplemented(
126 "TODO(agarwal): Enable support for ops with attributes of type '",
127 type, "'");
128 }
129 gtl::InsertIfNotPresent(m.get(), attr.name(), t);
130 }
131 *out = m.get();
132 auto r = OpNameToAttrTypeMap()->emplace(op_name, m.release());
133 DCHECK(r.second) << "AttrTypeMap already exists for " << op_name;
134
135 return Status::OK();
136 }
137
138 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE) \
139 template <> \
140 Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const { \
141 auto it = encoded_attrs_.find(string(attr_name)); \
142 if (it == encoded_attrs_.end()) { \
143 return errors::NotFound("No attr named'", attr_name, \
144 "' found in AttrBuilder for ", op_name_); \
145 } \
146 attr_tmp_.ParseFromString(it->second); \
147 TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE)); \
148 *value = attr_tmp_.FIELD(); \
149 return Status::OK(); \
150 }
151
152 DEFINE_GET_ATTR(float, f, "float");
153 DEFINE_GET_ATTR(int, i, "int");
154 DEFINE_GET_ATTR(bool, b, "bool");
155 DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
156
157 #undef DEFINE_GET_ATTR
158
NumInputs(int n)159 AttrBuilder& AttrBuilder::NumInputs(int n) {
160 DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
161 num_inputs_ = n;
162 return *this;
163 }
164
FillAttrValueMap(AttrValueMap * m) const165 void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
166 for (auto& entry : encoded_attrs_) {
167 attr_tmp_.ParseFromString(entry.second);
168 m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
169 }
170 // For any attr-value pairs that exist in the op def (from op registry) but
171 // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
172 // specify all the default attr values (e.g. for matmul, the `transpose_a`
173 // attr defaults to false).
174 const OpDef* op_def = nullptr;
175 Status s = OpDefForOp(op_name().c_str(), &op_def);
176 // This is expected, if this op is a custom function, and is therefore not
177 // present in the op registry.
178 if (!s.ok()) return;
179
180 DCHECK(op_def);
181 for (const auto& attr_def : op_def->attr()) {
182 if (attr_def.has_default_value() && !m->count(attr_def.name())) {
183 SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
184 }
185 }
186 }
187
188 namespace {
189
ValueMatchesDefault(const OpDef * op_def,const string & attr_name,const AttrValue & attr_value)190 bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
191 const AttrValue& attr_value) {
192 // TODO(iga): It might make sense to augment OpRegistrationData with a
193 // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
194 for (const OpDef::AttrDef& attr_def : op_def->attr()) {
195 if (attr_def.name() == attr_name && attr_def.has_default_value() &&
196 AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
197 return true;
198 }
199 }
200 return false;
201 }
202
203 } // namespace
204
FillAttrValueMapWithoutDefaults(AttrValueMap * m) const205 void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
206 const OpDef* op_def = nullptr;
207 Status s = OpDefForOp(op_name().c_str(), &op_def);
208
209 for (auto& entry : encoded_attrs_) {
210 attr_tmp_.ParseFromString(entry.second);
211 // Insert the attr-value pair if we did not find the OpDef or if the value
212 // is different from default.
213 if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
214 m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
215 }
216 }
217 }
218
AddAttrIfNotPresent(StringPiece attr_name,const AttrValue & value)219 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
220 const AttrValue& value) {
221 encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
222 }
223
BuildNodeDef()224 const NodeDef& AttrBuilder::BuildNodeDef() {
225 if (node_def_finalized_) return node_def_;
226 if (!node_def_initialized_) {
227 InitializeNodeDef();
228 }
229 for (int i = 0; i < num_inputs_; ++i) {
230 node_def_.add_input("dummy_input");
231 }
232 FillAttrValueMap(node_def_.mutable_attr());
233 node_def_finalized_ = true;
234 return node_def_;
235 }
236
CopyAttributes(const AttrBuilder & other)237 void AttrBuilder::CopyAttributes(const AttrBuilder& other) {
238 encoded_attrs_.insert(other.encoded_attrs_.begin(),
239 other.encoded_attrs_.end());
240 }
241
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)242 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
243 TF_AttrType* out, unsigned char* is_list) {
244 auto* t = gtl::FindOrNull(m, attr_name);
245 if (t == nullptr) {
246 return errors::InvalidArgument("Attribute '", attr_name,
247 "' does not exist for this operation");
248 }
249 *out = static_cast<TF_AttrType>(*t & ~kIsList);
250 if (*t & kIsList) {
251 *is_list = 1;
252 } else {
253 *is_list = 0;
254 }
255 return Status::OK();
256 }
257
258 namespace {
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)259 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
260 const tensorflow::Fprint128& b) {
261 return {tensorflow::FingerprintCat64(a.low64, b.low64),
262 tensorflow::FingerprintCat64(a.high64, b.high64)};
263 }
264
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)265 void CombineUnordered(const tensorflow::Fprint128& a,
266 tensorflow::Fprint128* b) {
267 b->low64 += a.low64;
268 b->high64 += a.high64;
269 }
270
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)271 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
272 const tensorflow::Fprint128& b) {
273 tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
274 return FingerprintCat128(a, b);
275 }
276
CacheKeyHelper(StringPiece s,uint64 b)277 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
278 return CacheKeyHelper(s, {b, b});
279 }
280
281 } // namespace
282
CacheKey(const StringPiece device)283 tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
284 if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
285 cached_cache_key_ = BuildCacheKeyForDevice(device);
286 device_for_cached_cache_key_ = string(device);
287 }
288
289 return *cached_cache_key_;
290 }
291
BuildCacheKeyForDevice(const StringPiece device) const292 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
293 const StringPiece device) const {
294 tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name());
295 f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
296 for (const auto& p : encoded_attrs_) {
297 CombineUnordered(
298 CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
299 }
300 return f;
301 }
302
InitializeNodeDef()303 void AttrBuilder::InitializeNodeDef() {
304 DCHECK(!node_def_initialized_);
305 node_def_.Clear();
306 node_def_.set_name(op_name_);
307 node_def_.set_op(op_name_);
308 node_def_initialized_ = true;
309 }
310
GetNameAttrList(tensorflow::NameAttrList * name_and_attrs) const311 void AttrBuilder::GetNameAttrList(
312 tensorflow::NameAttrList* name_and_attrs) const {
313 FillAttrValueMap(name_and_attrs->mutable_attr());
314 name_and_attrs->set_name(op_name());
315 }
316
GetInt(absl::string_view attr_name,int64_t * result) const317 bool AttrBuilder::GetInt(absl::string_view attr_name, int64_t* result) const {
318 Status s = Get(attr_name, result);
319 return s.ok();
320 }
GetFloat(absl::string_view attr_name,float * result) const321 bool AttrBuilder::GetFloat(absl::string_view attr_name, float* result) const {
322 Status s = Get(attr_name, result);
323 return s.ok();
324 }
GetBool(absl::string_view attr_name,bool * result) const325 bool AttrBuilder::GetBool(absl::string_view attr_name, bool* result) const {
326 Status s = Get(attr_name, result);
327 return s.ok();
328 }
329
GetType(absl::string_view attr_name,tensorflow::DataType * result) const330 bool AttrBuilder::GetType(absl::string_view attr_name,
331 tensorflow::DataType* result) const {
332 Status s = Get(attr_name, result);
333 return s.ok();
334 }
335
336 } // namespace tensorflow
337