1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/gtl/stl_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29
30 namespace tensorflow {
31 namespace {
32
// Serializes access to the map returned by OpNameToAttrTypeMap().
// AttrTypeMapForOp acquires it before looking up or inserting an entry.
mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34
OpNameToAttrTypeMap()35 std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36 static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
37 return m;
38 }
39
// Flag bit (the top bit of a uint32) OR-ed into an AttrTypeMap value when the
// attribute is a list. AttrTypeByName masks it off to recover the TF_AttrType
// and reports it separately through its `is_list` output.
const uint32 kIsList = 1U << 31;
41
DefaultFunctionAttrTypeMap()42 AttrTypeMap* DefaultFunctionAttrTypeMap() {
43 AttrTypeMap* map = new AttrTypeMap();
44 (*map)["executor_type"] = TF_ATTR_STRING;
45 (*map)["config_proto"] = TF_ATTR_STRING;
46 return map;
47 }
48
GetDefaultFunctionAttrTypeMap()49 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
50 static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
51 return map;
52 }
53
54 } // namespace
55
OpDefForOp(const char * op_name,const OpDef ** op_def)56 Status OpDefForOp(const char* op_name, const OpDef** op_def) {
57 const OpRegistrationData* op_reg_data = nullptr;
58 Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
59 if (s.ok()) {
60 *op_def = &op_reg_data->op_def;
61 }
62 return s;
63 }
64
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)65 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
66 bool* is_function) {
67 mutex_lock l(g_op_name_to_attr_type_map_lock);
68 *is_function = false;
69 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
70 if (*out != nullptr) return Status::OK();
71 const OpDef* op_def = nullptr;
72 Status s = OpDefForOp(op_name, &op_def);
73 if (errors::IsNotFound(s)) {
74 // If we did not find the op def, we assume `op_name` is a function.
75 // If it is actually a misspelled op, user will get another error when
76 // trying to run it.
77 // TODO(iga): If we ever have a use case for different attribute specs
78 // in different functions, we will need to look at the OpDef in the
79 // function def to retrieve their types.
80 *out = GetDefaultFunctionAttrTypeMap();
81 *is_function = true;
82 return Status::OK();
83 } else if (!s.ok()) {
84 return s;
85 }
86 std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
87 // TODO(agarwal): Avoid having to create this "registry" at runtime,
88 // perhaps can be done at op registration time?
89 for (const auto& attr : op_def->attr()) {
90 string type = attr.type();
91 const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
92 if (is_list) {
93 type = type.substr(5, type.length() - 6);
94 }
95 uint32 t = is_list ? kIsList : 0;
96 if (type == "string") {
97 t |= TF_ATTR_STRING;
98 } else if (type == "int") {
99 t |= TF_ATTR_INT;
100 } else if (type == "float") {
101 t |= TF_ATTR_FLOAT;
102 } else if (type == "bool") {
103 t |= TF_ATTR_BOOL;
104 } else if (type == "type") {
105 t |= TF_ATTR_TYPE;
106 } else if (type == "shape") {
107 t |= TF_ATTR_SHAPE;
108 } else if (type == "tensor") {
109 t |= TF_ATTR_TENSOR;
110 } else if (type == "func") {
111 t |= TF_ATTR_FUNC;
112 } else {
113 return errors::Unimplemented(
114 "TODO(agarwal): Enable support for ops with attributes of type '",
115 type, "'");
116 }
117 gtl::InsertIfNotPresent(m.get(), attr.name(), t);
118 }
119 *out = m.get();
120 (*OpNameToAttrTypeMap())[op_name] = m.release();
121 return Status::OK();
122 }
123
124 #define DEFINE_SET_ATTR(value_type, value_field) \
125 template <> \
126 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
127 value_field.push_back(std::make_pair(string(attr_name), value)); \
128 cached_cache_key_ = absl::nullopt; \
129 return *this; \
130 }
131
132 DEFINE_SET_ATTR(float, float_attrs_);
133 DEFINE_SET_ATTR(int, int_attrs_);
134 DEFINE_SET_ATTR(bool, bool_attrs_);
135 DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
136
137 #undef DEFINE_SET_ATTR
138
NumInputs(int n)139 AttrBuilder& AttrBuilder::NumInputs(int n) {
140 DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
141 num_inputs_ = n;
142 return *this;
143 }
144
FillAttrValueMap(AttrValueMap * m,bool include_those_in_node_def) const145 void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
146 bool include_those_in_node_def) const {
147 for (const auto& p : int_attrs_) {
148 SetInAttrValueMap(m, p.first, p.second);
149 }
150 for (const auto& p : float_attrs_) {
151 SetInAttrValueMap(m, p.first, p.second);
152 }
153 for (const auto& p : bool_attrs_) {
154 SetInAttrValueMap(m, p.first, p.second);
155 }
156 for (const auto& p : type_attrs_) {
157 SetInAttrValueMap(m, p.first, p.second);
158 }
159 if (include_those_in_node_def && node_def_ != nullptr) {
160 for (AttrValueMap::const_iterator it = node_def_->attr().begin();
161 it != node_def_->attr().end(); ++it) {
162 m->insert(*it);
163 }
164 }
165 // For any attr-value pairs that exist in the op def (from op registry) but
166 // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
167 // specify all the default attr values (e.g. for matmul, the `transpose_a`
168 // attr defaults to false).
169 const OpDef* op_def = nullptr;
170 Status s = OpDefForOp(op_name_.c_str(), &op_def);
171 // This is expected, if this op is a custom function, and is therefore not
172 // present in the op registry.
173 if (!s.ok()) return;
174
175 DCHECK(op_def);
176 for (const auto& attr_def : op_def->attr()) {
177 if (attr_def.has_default_value() && !m->count(attr_def.name())) {
178 SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
179 }
180 }
181 }
182
BuildNodeDef()183 const NodeDef& AttrBuilder::BuildNodeDef() {
184 if (node_def_finalized_) return *node_def_;
185 MayBeInitializeNodeDef();
186 for (int i = 0; i < num_inputs_; ++i) {
187 node_def_->add_input("dummy_input");
188 }
189 FillAttrValueMap(node_def_->mutable_attr(), false);
190 node_def_finalized_ = true;
191 return *node_def_;
192 }
193
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)194 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
195 TF_AttrType* out, unsigned char* is_list) {
196 auto* t = gtl::FindOrNull(m, attr_name);
197 if (t == nullptr) {
198 return errors::InvalidArgument("Attribute '", attr_name,
199 "' does not exist for this operation");
200 }
201 *out = static_cast<TF_AttrType>(*t & ~kIsList);
202 if (*t & kIsList) {
203 *is_list = 1;
204 } else {
205 *is_list = 0;
206 }
207 return Status::OK();
208 }
209
210 namespace {
// Combines two 128-bit fingerprints half by half: the low 64 bits of `a` and
// `b` are fed to FingerprintCat64, and likewise the high 64 bits. The order
// of the arguments is passed through to FingerprintCat64 unchanged.
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  const auto low = tensorflow::FingerprintCat64(a.low64, b.low64);
  const auto high = tensorflow::FingerprintCat64(a.high64, b.high64);
  return {low, high};
}
216
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)217 void CombineUnordered(const tensorflow::Fprint128& a,
218 tensorflow::Fprint128* b) {
219 b->low64 += a.low64;
220 b->high64 += a.high64;
221 }
222
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)223 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
224 const tensorflow::Fprint128& b) {
225 tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
226 return FingerprintCat128(a, b);
227 }
228
CacheKeyHelper(StringPiece s,uint64 b)229 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
230 return CacheKeyHelper(s, {b, b});
231 }
232
233 } // namespace
234
CacheKey(const string & device)235 tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) {
236 if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
237 cached_cache_key_ = BuildCacheKeyForDevice(device);
238 device_for_cached_cache_key_ = device;
239 }
240
241 return *cached_cache_key_;
242 }
243
BuildCacheKeyForDevice(const string & device) const244 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
245 const string& device) const {
246 tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
247 f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
248 if (node_def_ != nullptr) {
249 // Some attributes are directly written to node_def_ instead of being
250 // stored explicitly.
251 string value;
252 for (const auto& attr : node_def_->attr()) {
253 attr.second.SerializeToString(&value);
254 CombineUnordered(
255 CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
256 }
257 // Note that node_def_ may be created but not finalized. This can happen
258 // when the creation was triggered by a call to Set, but BuildNodeDef has
259 // not been called.
260 if (node_def_finalized_) return f;
261 }
262 for (const auto& p : int_attrs_) {
263 CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
264 &f);
265 }
266 static std::hash<float> float_hasher;
267 for (const auto& p : float_attrs_) {
268 CombineUnordered(
269 CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
270 &f);
271 }
272 for (const auto& p : bool_attrs_) {
273 CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
274 }
275 for (const auto& p : type_attrs_) {
276 CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
277 &f);
278 }
279 return f;
280 }
281
MayBeInitializeNodeDef()282 void AttrBuilder::MayBeInitializeNodeDef() {
283 if (node_def_ == nullptr) {
284 node_def_.reset(new NodeDef());
285 node_def_->set_name(op_name_);
286 node_def_->set_op(op_name_);
287 }
288 }
289
290 } // namespace tensorflow
291