1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29
30 namespace tensorflow {
31 namespace {
32
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34
OpNameToAttrTypeMap()35 tensorflow::gtl::FlatMap<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36 static auto* const m =
37 new tensorflow::gtl::FlatMap<string, const AttrTypeMap*>;
38 return m;
39 }
40
41 const uint32 kIsList = 1U << 31;
42
DefaultFunctionAttrTypeMap()43 AttrTypeMap* DefaultFunctionAttrTypeMap() {
44 AttrTypeMap* map = new AttrTypeMap();
45 (*map)["executor_type"] = TF_ATTR_STRING;
46 (*map)["config_proto"] = TF_ATTR_STRING;
47 return map;
48 }
49
GetDefaultFunctionAttrTypeMap()50 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
51 static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
52 return map;
53 }
54
55 } // namespace
56
OpDefForOp(const char * op_name,const OpDef ** op_def)57 Status OpDefForOp(const char* op_name, const OpDef** op_def) {
58 const OpRegistrationData* op_reg_data = nullptr;
59 Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
60 if (s.ok()) {
61 *op_def = &op_reg_data->op_def;
62 }
63 return s;
64 }
65
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)66 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
67 bool* is_function) {
68 mutex_lock l(g_op_name_to_attr_type_map_lock);
69 *is_function = false;
70 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
71 if (*out != nullptr) return Status::OK();
72 const OpDef* op_def = nullptr;
73 Status s = OpDefForOp(op_name, &op_def);
74 if (errors::IsNotFound(s)) {
75 // If we did not find the op def, we assume `op_name` is a function.
76 // If it is actually a misspelled op, user will get another error when
77 // trying to run it.
78 // TODO(iga): If we ever have a use case for different attribute specs
79 // in different functions, we will need to look at the OpDef in the
80 // function def to retrieve their types.
81 *out = GetDefaultFunctionAttrTypeMap();
82 *is_function = true;
83 return Status::OK();
84 } else if (!s.ok()) {
85 return s;
86 }
87 std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
88 // TODO(agarwal): Avoid having to create this "registry" at runtime,
89 // perhaps can be done at op registration time?
90 for (const auto& attr : op_def->attr()) {
91 string type = attr.type();
92 const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
93 if (is_list) {
94 type = type.substr(5, type.length() - 6);
95 }
96 uint32 t = is_list ? kIsList : 0;
97 if (type == "string") {
98 t |= TF_ATTR_STRING;
99 } else if (type == "int") {
100 t |= TF_ATTR_INT;
101 } else if (type == "float") {
102 t |= TF_ATTR_FLOAT;
103 } else if (type == "bool") {
104 t |= TF_ATTR_BOOL;
105 } else if (type == "type") {
106 t |= TF_ATTR_TYPE;
107 } else if (type == "shape") {
108 t |= TF_ATTR_SHAPE;
109 } else if (type == "tensor") {
110 t |= TF_ATTR_TENSOR;
111 } else if (type == "func") {
112 t |= TF_ATTR_FUNC;
113 } else {
114 return errors::Unimplemented(
115 "TODO(agarwal): Enable support for ops with attributes of type '",
116 type, "'");
117 }
118 gtl::InsertIfNotPresent(m.get(), attr.name(), t);
119 }
120 *out = m.get();
121 (*OpNameToAttrTypeMap())[op_name] = m.release();
122 return Status::OK();
123 }
124
125 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE) \
126 template <> \
127 Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const { \
128 auto it = encoded_attrs_.find(string(attr_name)); \
129 if (it == encoded_attrs_.end()) { \
130 return errors::NotFound("No attr named'", attr_name, \
131 "' found in AttrBuilder for ", op_name_); \
132 } \
133 attr_tmp_.ParseFromString(it->second); \
134 TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE)); \
135 *value = attr_tmp_.FIELD(); \
136 return Status::OK(); \
137 }
138
139 DEFINE_GET_ATTR(float, f, "float");
140 DEFINE_GET_ATTR(int, i, "int");
141 DEFINE_GET_ATTR(bool, b, "bool");
142 DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
143
144 #undef DEFINE_GET_ATTR
145
NumInputs(int n)146 AttrBuilder& AttrBuilder::NumInputs(int n) {
147 DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
148 num_inputs_ = n;
149 return *this;
150 }
151
FillAttrValueMap(AttrValueMap * m) const152 void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
153 for (auto& entry : encoded_attrs_) {
154 attr_tmp_.ParseFromString(entry.second);
155 m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
156 }
157 // For any attr-value pairs that exist in the op def (from op registry) but
158 // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
159 // specify all the default attr values (e.g. for matmul, the `transpose_a`
160 // attr defaults to false).
161 const OpDef* op_def = nullptr;
162 Status s = OpDefForOp(op_name().c_str(), &op_def);
163 // This is expected, if this op is a custom function, and is therefore not
164 // present in the op registry.
165 if (!s.ok()) return;
166
167 DCHECK(op_def);
168 for (const auto& attr_def : op_def->attr()) {
169 if (attr_def.has_default_value() && !m->count(attr_def.name())) {
170 SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
171 }
172 }
173 }
174
175 namespace {
176
ValueMatchesDefault(const OpDef * op_def,const string & attr_name,const AttrValue & attr_value)177 bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
178 const AttrValue& attr_value) {
179 // TODO(iga): It might make sense to augment OpRegistrationData with a
180 // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
181 for (const OpDef::AttrDef& attr_def : op_def->attr()) {
182 if (attr_def.name() == attr_name && attr_def.has_default_value() &&
183 AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
184 return true;
185 }
186 }
187 return false;
188 }
189
190 } // namespace
191
FillAttrValueMapWithoutDefaults(AttrValueMap * m) const192 void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
193 const OpDef* op_def = nullptr;
194 Status s = OpDefForOp(op_name().c_str(), &op_def);
195
196 for (auto& entry : encoded_attrs_) {
197 attr_tmp_.ParseFromString(entry.second);
198 // Insert the attr-value pair if we did not find the OpDef or if the value
199 // is different from default.
200 if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
201 m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
202 }
203 }
204 }
205
AddAttrIfNotPresent(StringPiece attr_name,const AttrValue & value)206 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
207 const AttrValue& value) {
208 encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
209 }
210
BuildNodeDef()211 const NodeDef& AttrBuilder::BuildNodeDef() {
212 if (node_def_finalized_) return node_def_;
213 if (!node_def_initialized_) {
214 InitializeNodeDef();
215 }
216 for (int i = 0; i < num_inputs_; ++i) {
217 node_def_.add_input("dummy_input");
218 }
219 FillAttrValueMap(node_def_.mutable_attr());
220 node_def_finalized_ = true;
221 return node_def_;
222 }
223
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)224 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
225 TF_AttrType* out, unsigned char* is_list) {
226 auto* t = gtl::FindOrNull(m, attr_name);
227 if (t == nullptr) {
228 return errors::InvalidArgument("Attribute '", attr_name,
229 "' does not exist for this operation");
230 }
231 *out = static_cast<TF_AttrType>(*t & ~kIsList);
232 if (*t & kIsList) {
233 *is_list = 1;
234 } else {
235 *is_list = 0;
236 }
237 return Status::OK();
238 }
239
240 namespace {
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)241 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
242 const tensorflow::Fprint128& b) {
243 return {tensorflow::FingerprintCat64(a.low64, b.low64),
244 tensorflow::FingerprintCat64(a.high64, b.high64)};
245 }
246
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)247 void CombineUnordered(const tensorflow::Fprint128& a,
248 tensorflow::Fprint128* b) {
249 b->low64 += a.low64;
250 b->high64 += a.high64;
251 }
252
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)253 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
254 const tensorflow::Fprint128& b) {
255 tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
256 return FingerprintCat128(a, b);
257 }
258
CacheKeyHelper(StringPiece s,uint64 b)259 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
260 return CacheKeyHelper(s, {b, b});
261 }
262
263 } // namespace
264
CacheKey(const StringPiece device)265 tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
266 if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
267 cached_cache_key_ = BuildCacheKeyForDevice(device);
268 device_for_cached_cache_key_ = string(device);
269 }
270
271 return *cached_cache_key_;
272 }
273
BuildCacheKeyForDevice(const StringPiece device) const274 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
275 const StringPiece device) const {
276 tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name());
277 f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
278 for (const auto& p : encoded_attrs_) {
279 CombineUnordered(
280 CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
281 }
282 return f;
283 }
284
InitializeNodeDef()285 void AttrBuilder::InitializeNodeDef() {
286 DCHECK(!node_def_initialized_);
287 node_def_.Clear();
288 node_def_.set_name(op_name_);
289 node_def_.set_op(op_name_);
290 node_def_initialized_ = true;
291 }
292
293 } // namespace tensorflow
294