/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/dataset_utils.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {

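// Determines, for each return value of the function `func`, whether that
// return value simply forwards one of the function's arguments (a "short
// circuit"). On success, `indices` holds, for each return value, the index
// of the forwarded argument; if any return value is actually computed, or
// the function is stateful, `indices` is cleared and the caller must run
// the full function. For example, a function `f(x, y) = (y, x)` yields
// `indices == {1, 0}`.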
Status ComputeShortCircuitIndices(OpKernelConstruction* ctx,
                                  const NameAttrList& func,
                                  std::vector<int>* indices) {
  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function rather than short-circuiting.
  if (ctx->function_library()->IsStateful(func.name())) {
    indices->clear();
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices->resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    // Skip over any Identity nodes between the return value and its source.
    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
    } else {
      indices->clear();
      break;
    }
  }
  return Status::OK();
}

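// Given the short-circuit `indices` computed above, determines for each
// position whether the corresponding input tensor can be moved (rather than
// copied) into the output. A tensor can be moved only at the last position
// that uses it. For example, `indices == {0, 1, 0}` yields
// `{false, true, true}`, because argument 0 is used again at position 2.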
std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
  std::map<int, int> last_use;
  for (size_t i = 0; i < indices.size(); ++i) {
    last_use[indices[i]] = i;
  }
  std::vector<bool> can_move;
  can_move.resize(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    can_move[i] = last_use[indices[i]] == i;
  }
  return can_move;
}

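// Runs the instantiated captured function `inst_captured_func` on
// `input_element`, expects it to return a single scalar DT_VARIANT tensor
// wrapping a dataset, and creates an iterator for that dataset under the
// name `prefix[thread_index]`.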
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const std::vector<Tensor>& input_element,
    int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
    StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
                                                            &return_values));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  return returned_dataset->MakeIterator(
      ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}

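// Returns an error if the two type vectors differ in length or in any
// component.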
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != received[i]) {
      return errors::InvalidArgument("Data type mismatch at component ", i,
                                     ": expected ", DataTypeString(expected[i]),
                                     " but got ", DataTypeString(received[i]),
                                     ".");
    }
  }
  return Status::OK();
}

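// Returns an error if the two shape vectors differ in length or if any
// received shape is incompatible with the corresponding expected shape.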
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (!expected[i].IsCompatibleWith(received[i])) {
      return errors::InvalidArgument("Incompatible shapes at component ", i,
                                     ": expected ", expected[i].DebugString(),
                                     " but got ", received[i].DebugString(),
                                     ".");
    }
  }

  return Status::OK();
}

namespace {

// Delimiter used to join the keys of serialized tensors in the metadata
// string of a `VariantTensorData`.
constexpr char kDelimiter[] = "@@";

}  // namespace

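// The reader expects the metadata written by `VariantTensorDataWriter`: a
// string of keys joined by `kDelimiter`, where the i-th key names the i-th
// tensor stored in `data`.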
VariantTensorDataReader::VariantTensorDataReader(
    const tensorflow::VariantTensorData* data)
    : data_(data) {
  string metadata;
  data_->get_metadata(&metadata);
  auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
  for (size_t i = 0; i < keys.size(); ++i) {
    map_[keys[i]] = i;
  }
}

Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadScalar(StringPiece key, string* val) {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) {
  return ReadTensorInternal(key, val);
}

bool VariantTensorDataReader::Contains(StringPiece key) {
  return map_.find(string(key)) != map_.end();
}

template <typename T>
Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) {
  if (map_.find(string(key)) == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(map_[string(key)]).scalar<T>()();
  return Status::OK();
}

Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
                                                   Tensor* val) {
  if (map_.find(string(key)) == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(map_[string(key)]);
  return Status::OK();
}

Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteScalar(StringPiece key,
                                            const string& val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteTensor(StringPiece key,
                                            const Tensor& val) {
  return WriteTensorInternal(key, val);
}

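// Serializes the accumulated keys into the metadata string, joined by
// `kDelimiter`, so that `VariantTensorDataReader` can recover the
// key-to-tensor mapping.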
Status VariantTensorDataWriter::Flush() {
  string metadata;
  for (size_t i = 0; i < keys_.size(); ++i) {
    strings::StrAppend(&metadata, kDelimiter, keys_[i]);
  }
  data_->set_metadata(metadata);
  return Status::OK();
}

template <typename T>
Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
                                                    const T& val) {
  Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
  val_t.scalar<T>()() = val;
  return WriteTensorInternal(key, val_t);
}

Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
                                                    const Tensor& val) {
  // Keys must not contain the delimiter, or the metadata string would be
  // ambiguous.
  DCHECK_EQ(key.find(kDelimiter), string::npos);
  keys_.push_back(string(key));
  *(data_->add_tensors()) = val;
  return Status::OK();
}

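// Adds all functions in `to_add` to `base`. If `base` already contains a
// function with the same name, it is overwritten, provided the two
// signatures are equal; otherwise an InvalidArgument error is returned.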
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

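// Same as above, but takes the functions to add as a `FunctionDefLibrary`
// proto.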
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

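// Wraps `runner` so that every function it executes first installs a
// `ScopedPerThreadMaxParallelism` with the given `max_parallelism`, limiting
// intra-op parallelism for work scheduled through the wrapped runner.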
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

}  // namespace data
}  // namespace tensorflow