/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {

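// Determines whether `func` can be "short circuited", i.e. whether each of
// its return values is (possibly through a chain of Identity ops) simply one
// of its arguments. On success, `indices` holds, for each return value, the
// index of the argument it forwards. `indices` is left empty when the
// function is stateful or returns anything that is not a forwarded argument,
// in which case the caller must run the function in full.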
Status ComputeShortCircuitIndices(OpKernelConstruction* ctx,
                                  const NameAttrList& func,
                                  std::vector<int>* indices) {
  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    indices->clear();
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices->resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
    } else {
      indices->clear();
      break;
    }
  }
  return Status::OK();
}

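// Given the argument indices produced by ComputeShortCircuitIndices, marks
// the positions at which the corresponding input tensor may be moved rather
// than copied: only the last use of each argument may consume it. For
// example, for indices {0, 1, 0} this returns {false, true, true}, since
// argument 0 is used again at position 2.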
std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
  std::map<int, int> last_use;
  for (size_t i = 0; i < indices.size(); ++i) {
    last_use[indices[i]] = i;
  }
  std::vector<bool> can_move;
  can_move.resize(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    can_move[i] = last_use[indices[i]] == i;
  }
  return can_move;
}

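// Runs `inst_captured_func` on `input_element`, expects it to return a single
// scalar DT_VARIANT tensor wrapping a dataset, and creates an iterator over
// that dataset under the prefix "<prefix>[<thread_index>]".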
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const std::vector<Tensor>& input_element,
    int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
    StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
                                                            &return_values));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  return returned_dataset->MakeIterator(
      ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}

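// Returns an InvalidArgument error if `expected` and `received` differ in
// length or in any component data type.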
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != received[i]) {
      return errors::InvalidArgument("Data type mismatch at component ", i,
                                     ": expected ", DataTypeString(expected[i]),
                                     " but got ", DataTypeString(received[i]),
                                     ".");
    }
  }
  return Status::OK();
}

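// Returns an InvalidArgument error if `expected` and `received` differ in
// length or if any received shape is incompatible with the corresponding
// expected (possibly partial) shape.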
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (!expected[i].IsCompatibleWith(received[i])) {
      return errors::InvalidArgument("Incompatible shapes at component ", i,
                                     ": expected ", expected[i].DebugString(),
                                     " but got ", received[i].DebugString(),
                                     ".");
    }
  }

  return Status::OK();
}

namespace {

constexpr char kDelimiter[] = "@@";

}  // namespace

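// The reader recovers the key -> tensor-index mapping from the metadata
// string, which the writer lays down (in Flush() below) as a
// kDelimiter-separated list of keys whose order matches the order of the
// stored tensors.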
VariantTensorDataReader::VariantTensorDataReader(
    const tensorflow::VariantTensorData* data)
    : data_(data) {
  string metadata;
  data_->get_metadata(&metadata);
  auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
  for (size_t i = 0; i < keys.size(); ++i) {
    map_[keys[i]] = i;
  }
}

Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadScalar(StringPiece key, string* val) {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) {
  return ReadTensorInternal(key, val);
}

bool VariantTensorDataReader::Contains(StringPiece key) {
  return map_.find(string(key)) != map_.end();
}

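// Looks up `key` in the metadata map and reads the stored tensor at the
// mapped index as a scalar of type T; returns NotFound for unknown keys.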
template <typename T>
Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) {
  if (map_.find(string(key)) == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(map_[string(key)]).scalar<T>()();
  return Status::OK();
}

Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
                                                   Tensor* val) {
  if (map_.find(string(key)) == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(map_[string(key)]);
  return Status::OK();
}

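// The writer mirrors the reader: each Write* call appends a tensor to `data_`
// and records its key, and Flush() serializes the accumulated keys into the
// metadata string. A round trip might look like this (a sketch, assuming the
// writer is constructed over a VariantTensorData pointer, as the reader is
// above):
//
//   VariantTensorData data;
//   VariantTensorDataWriter writer(&data);
//   TF_RETURN_IF_ERROR(writer.WriteScalar("num_elements", 42));
//   TF_RETURN_IF_ERROR(writer.Flush());
//
//   VariantTensorDataReader reader(&data);
//   int64 num_elements;
//   TF_RETURN_IF_ERROR(reader.ReadScalar("num_elements", &num_elements));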
Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteScalar(StringPiece key,
                                            const string& val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteTensor(StringPiece key,
                                            const Tensor& val) {
  return WriteTensorInternal(key, val);
}

Status VariantTensorDataWriter::Flush() {
  string metadata;
  for (size_t i = 0; i < keys_.size(); ++i) {
    strings::StrAppend(&metadata, kDelimiter, keys_[i]);
  }
  data_->set_metadata(metadata);
  return Status::OK();
}

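// Wraps `val` in a scalar Tensor of the matching dtype and stores it like any
// other tensor.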
template <typename T>
Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
                                                    const T& val) {
  Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
  val_t.scalar<T>()() = val;
  return WriteTensorInternal(key, val_t);
}

Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
                                                    const Tensor& val) {
  DCHECK_EQ(key.find(kDelimiter), string::npos);
  keys_.push_back(string(key));
  *(data_->add_tensors()) = val;
  return Status::OK();
}

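// Adds the functions in `to_add` to `base`. An existing function of the same
// name is replaced only when its signature matches the incoming one; a
// conflicting signature under the same name is an error.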
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

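// Wraps `runner` so that every function it executes first installs a
// ScopedPerThreadMaxParallelism guard, capping intra-op parallelism at
// `max_parallelism` for the duration of the call. For example (a sketch,
// with `pool` standing in for any thread pool whose Schedule matches the
// runner signature):
//
//   auto runner = RunnerWithMaxParallelism(
//       [&pool](std::function<void()> fn) { pool.Schedule(std::move(fn)); },
//       /*max_parallelism=*/2);
//   runner([] { /* runs with per-thread max parallelism == 2 */ });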
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}
}  // namespace data
}  // namespace tensorflow