/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_test_base.h"

#include <algorithm>
#include <complex>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {
namespace data {

string ToString(CompressionType compression_type) {
  switch (compression_type) {
    case CompressionType::ZLIB:
      return "ZLIB";
    case CompressionType::GZIP:
      return "GZIP";
    case CompressionType::RAW:
      return "RAW";
    case CompressionType::UNCOMPRESSED:
      return "";
  }
}

io::ZlibCompressionOptions GetZlibCompressionOptions(
    CompressionType compression_type) {
  switch (compression_type) {
    case CompressionType::ZLIB:
      return io::ZlibCompressionOptions::DEFAULT();
    case CompressionType::GZIP:
      return io::ZlibCompressionOptions::GZIP();
    case CompressionType::RAW:
      return io::ZlibCompressionOptions::RAW();
    case CompressionType::UNCOMPRESSED:
      LOG(WARNING) << "ZlibCompressionOptions does not have an option for "
                   << ToString(compression_type);
      return io::ZlibCompressionOptions::DEFAULT();
  }
}

Status WriteDataToFile(const string& filename, const char* data) {
  return WriteDataToFile(filename, data, CompressionParams());
}

Status WriteDataToFile(const string& filename, const char* data,
                       const CompressionParams& params) {
  Env* env = Env::Default();
  std::unique_ptr<WritableFile> file_writer;
  TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
  if (params.compression_type == CompressionType::UNCOMPRESSED) {
    TF_RETURN_IF_ERROR(file_writer->Append(data));
  } else if (params.compression_type == CompressionType::ZLIB ||
             params.compression_type == CompressionType::GZIP ||
             params.compression_type == CompressionType::RAW) {
    auto zlib_compression_options =
        GetZlibCompressionOptions(params.compression_type);
    io::ZlibOutputBuffer out(file_writer.get(), params.input_buffer_size,
                             params.output_buffer_size,
                             zlib_compression_options);
    TF_RETURN_IF_ERROR(out.Init());
    TF_RETURN_IF_ERROR(out.Append(data));
    TF_RETURN_IF_ERROR(out.Flush());
    TF_RETURN_IF_ERROR(out.Close());
  } else {
    return tensorflow::errors::InvalidArgument(
        "Unsupported compression_type: ", ToString(params.compression_type));
  }

  TF_RETURN_IF_ERROR(file_writer->Flush());
  TF_RETURN_IF_ERROR(file_writer->Close());

  return Status::OK();
}

Status WriteDataToTFRecordFile(const string& filename,
                               const std::vector<absl::string_view>& records,
                               const CompressionParams& params) {
  Env* env = Env::Default();
  std::unique_ptr<WritableFile> file_writer;
  TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
  auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
      ToString(params.compression_type));
  options.zlib_options.input_buffer_size = params.input_buffer_size;
  io::RecordWriter record_writer(file_writer.get(), options);
  for (const auto& record : records) {
    TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
  }
  TF_RETURN_IF_ERROR(record_writer.Flush());
  TF_RETURN_IF_ERROR(record_writer.Close());
  TF_RETURN_IF_ERROR(file_writer->Flush());
  TF_RETURN_IF_ERROR(file_writer->Close());
  return Status::OK();
}
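
// Example (sketch): tests that exercise file-based datasets typically
// generate their inputs with the helpers above. The filename below is
// hypothetical; records and buffer sizes are arbitrary.
//
//   CompressionParams params;
//   params.compression_type = CompressionType::GZIP;
//   params.input_buffer_size = 1024;
//   params.output_buffer_size = 1024;
//   TF_ASSERT_OK(WriteDataToTFRecordFile(
//       "/tmp/test.tfrecord", {"record_0", "record_1"}, params));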

template <typename T>
Status IsEqual(const Tensor& t1, const Tensor& t2) {
  if (t1.dtype() != t2.dtype()) {
    return tensorflow::errors::Internal(
        "Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
        " vs. ", DataTypeString(t2.dtype()));
  }
  if (!t1.IsSameSize(t2)) {
    return tensorflow::errors::Internal(
        "Two tensors have different shapes: ", t1.shape().DebugString(),
        " vs. ", t2.shape().DebugString());
  }

  auto flat_t1 = t1.flat<T>();
  auto flat_t2 = t2.flat<T>();
  auto length = flat_t1.size();

  for (int i = 0; i < length; ++i) {
    if (flat_t1(i) != flat_t2(i)) {
      return tensorflow::errors::Internal(
          "Two tensors have different values "
          "at [",
          i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
    }
  }
  return Status::OK();
}

DatasetOpsTestBase::DatasetOpsTestBase()
    : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
      device_type_(DEVICE_CPU),
      cpu_num_(kDefaultCPUNum),
      thread_num_(kDefaultThreadNum) {
  allocator_ = device_->GetAllocator(AllocatorAttributes());
}

DatasetOpsTestBase::~DatasetOpsTestBase() {
  if (dataset_) {
    dataset_->Unref();
  }
}

Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  switch (a.dtype()) {
#define CASE(DT)                           \
  case DataTypeToEnum<DT>::value:          \
    TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return Status::OK();
}

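// Orders two tensors lexicographically by their flattened elements. Used by
// the vector-of-tensors ExpectEqual() overload below to sort both the
// produced and expected tensors before an element-wise comparison when
// `compare_order` is false.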
template <typename T>
bool compare(const Tensor& t1, const Tensor& t2) {
  auto flat_t1 = t1.flat<T>();
  auto flat_t2 = t2.flat<T>();
  auto length = std::min(flat_t1.size(), flat_t2.size());
  for (int i = 0; i < length; ++i) {
    if (flat_t1(i) < flat_t2(i)) return true;
    if (flat_t1(i) > flat_t2(i)) return false;
  }
  // If the common prefix is equal, the shorter tensor orders first.
  return flat_t1.size() < flat_t2.size();
}

Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
                                       std::vector<Tensor> expected_tensors,
                                       bool compare_order) {
  if (produced_tensors.size() != expected_tensors.size()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different sizes (",
        produced_tensors.size(), " vs. ", expected_tensors.size(), ")"));
  }

  if (produced_tensors.empty()) return Status::OK();
  if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different dtypes (",
        produced_tensors[0].dtype(), " vs. ", expected_tensors[0].dtype(),
        ")"));
  }

  if (!compare_order) {
    const DataType& dtype = produced_tensors[0].dtype();
    switch (dtype) {
#define CASE(DT)                                                \
  case DT:                                                      \
    std::sort(produced_tensors.begin(), produced_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    std::sort(expected_tensors.begin(), expected_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    break;
      CASE(DT_FLOAT);
      CASE(DT_DOUBLE);
      CASE(DT_INT32);
      CASE(DT_UINT8);
      CASE(DT_INT16);
      CASE(DT_INT8);
      CASE(DT_STRING);
      CASE(DT_INT64);
      CASE(DT_BOOL);
      CASE(DT_QINT8);
      CASE(DT_QUINT8);
      CASE(DT_QINT32);
      CASE(DT_QINT16);
      CASE(DT_QUINT16);
      CASE(DT_UINT16);
      CASE(DT_HALF);
      CASE(DT_UINT32);
      CASE(DT_UINT64);
      // TODO(feihugis): support other dtypes.
#undef CASE
      default:
        return errors::Internal("Unsupported dtype: ", dtype);
    }
  }

  for (int i = 0; i < produced_tensors.size(); ++i) {
    TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i],
                                                       expected_tensors[i]));
  }
  return Status::OK();
}

Status DatasetOpsTestBase::CreateOpKernel(
    const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
  OpKernel* kernel;
  Status s;

  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      node_def, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
      device_type_, device_.get(), allocator_, flr_,
      device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel));
  op_kernel->reset(kernel);
  return Status::OK();
}

Status DatasetOpsTestBase::CreateDatasetContext(
    OpKernel* const dataset_kernel,
    gtl::InlinedVector<TensorValue, 4>* const inputs,
    std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
    std::unique_ptr<OpKernelContext>* dataset_context) {
  Status status = CheckOpKernelInput(*dataset_kernel, *inputs);
  if (!status.ok()) {
    VLOG(0) << "WARNING: " << status.ToString();
  }
  TF_RETURN_IF_ERROR(CreateOpKernelContext(
      dataset_kernel, inputs, dataset_context_params, dataset_context));
  return Status::OK();
}

Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
                                         OpKernelContext* context,
                                         DatasetBase** const dataset) {
  TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ(context->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
  return Status::OK();
}

Status DatasetOpsTestBase::RestoreIterator(
    IteratorContext* ctx, IteratorStateReader* reader,
    const string& output_prefix, const DatasetBase& dataset,
    std::unique_ptr<IteratorBase>* iterator) {
  return dataset.MakeIteratorFromCheckpoint(ctx, output_prefix, reader,
                                            iterator);
}

Status DatasetOpsTestBase::CreateIteratorContext(
    OpKernelContext* const op_context,
    std::unique_ptr<IteratorContext>* iterator_context) {
  IteratorContext::Params params(op_context);
  params.resource_mgr = op_context->resource_manager();
  function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
  params.function_handle_cache = function_handle_cache_.get();
  params.cancellation_manager = cancellation_manager_.get();
  *iterator_context = absl::make_unique<IteratorContext>(params);
  return Status::OK();
}

Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
                                                 int output_index,
                                                 DatasetBase** const dataset) {
  Tensor* output = context->mutable_output(output_index);
  Status status = GetDatasetFromVariantTensor(*output, dataset);
  (*dataset)->Ref();
  return status;
}

Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
  if (thread_num < 1) {
    return errors::InvalidArgument(
        "The `thread_num` argument should be positive but got: ", thread_num);
  }
  thread_pool_ = absl::make_unique<thread::ThreadPool>(
      Env::Default(), ThreadOptions(), "test_thread_pool", thread_num);
  return Status::OK();
}

Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");

  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
      /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return Status::OK();
          }});
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  if (thread_pool_ == nullptr) {
    runner_ = [](const std::function<void()>& fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return Status::OK();
}

Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}

Status DatasetOpsTestBase::RunFunction(
    const FunctionDef& fdef, test::function::Attrs attrs,
    const std::vector<Tensor>& args,
    const GraphConstructorOptions& graph_options, std::vector<Tensor*> rets) {
  std::unique_ptr<Executor> exec;
  InstantiationResult result;
  auto GetOpSig = [](const string& op, const OpDef** sig) {
    return OpRegistry::Global()->LookUpOpDef(op, sig);
  };
  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, GetOpSig, &result));

  DataTypeVector arg_types = result.arg_types;
  DataTypeVector ret_types = result.ret_types;

  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_RETURN_IF_ERROR(
      ConvertNodeDefsToGraph(graph_options, result.nodes, g.get()));

  const int version = g->versions().producer();
  LocalExecutorParams params;
  params.function_library = flr_;
  params.device = device_.get();
  params.create_kernel = [this, version](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };

  Executor* cur_exec;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
  exec.reset(cur_exec);
  FunctionCallFrame frame(arg_types, ret_types);
  TF_RETURN_IF_ERROR(frame.SetArgs(args));
  Executor::Args exec_args;
  exec_args.call_frame = &frame;
  exec_args.runner = runner_;
  TF_RETURN_IF_ERROR(exec->Run(exec_args));
  std::vector<Tensor> computed;
  TF_RETURN_IF_ERROR(frame.GetRetvals(&computed));
  if (computed.size() != rets.size()) {
    return errors::InvalidArgument(
        "The result does not match the expected number of return outputs",
        ". Expected: ", rets.size(), ". Actual: ", computed.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    *(rets[i]) = computed[i];
  }
  return Status::OK();
}

Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext>* context) {
  return CreateOpKernelContext(kernel, inputs, &params_, context);
}

Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext::Params>* context_params,
    std::unique_ptr<OpKernelContext>* context) {
  auto params = absl::make_unique<OpKernelContext::Params>();
  cancellation_manager_ = absl::make_unique<CancellationManager>();
  params->cancellation_manager = cancellation_manager_.get();
  params->device = device_.get();
  params->frame_iter = FrameAndIter(0, 0);
  params->function_library = flr_;
  params->inputs = inputs;
  params->op_kernel = kernel;
  params->resource_manager = resource_mgr_.get();
  params->runner = &runner_;
  slice_reader_cache_ =
      absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
  params->slice_reader_cache = slice_reader_cache_.get();
  step_container_ =
      absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
  params->step_container = step_container_.get();

  // Set the allocator attributes for the outputs.
  allocator_attrs_.clear();
  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    allocator_attrs_.emplace_back(attr);
  }
  params->output_attr_array = allocator_attrs_.data();

  *context = absl::make_unique<OpKernelContext>(params.get());
  *context_params = std::move(params);
  return Status::OK();
}

Status DatasetOpsTestBase::CreateSerializationContext(
    std::unique_ptr<SerializationContext>* context) {
  *context =
      absl::make_unique<SerializationContext>(SerializationContext::Params{});
  return Status::OK();
}

Status DatasetOpsTestBase::CheckOpKernelInput(
    const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
  if (kernel.num_inputs() != inputs.size()) {
    return errors::InvalidArgument("The number of input elements should be ",
                                   kernel.num_inputs(),
                                   ", but got: ", inputs.size());
  }
  return Status::OK();
}

Status DatasetOpsTestBase::AddDatasetInput(
    gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
    DataType dtype, const TensorShape& shape) {
  if (input_types.size() < inputs->size()) {
    return errors::InvalidArgument("Adding more inputs than types: ",
                                   inputs->size(), " vs. ", input_types.size());
  }
  bool is_ref = IsRefType(input_types[inputs->size()]);
  auto input = absl::make_unique<Tensor>(allocator_, dtype, shape);

  if (is_ref) {
    DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
    if (expected_dtype != dtype) {
      return errors::InvalidArgument("The input data type is ", dtype,
                                     " , but expected: ", expected_dtype);
    }
    inputs->push_back({&lock_for_refs_, input.get()});
  } else {
    if (input_types[inputs->size()] != dtype) {
      return errors::InvalidArgument(
          "The input data type is ", dtype,
          " , but expected: ", input_types[inputs->size()]);
    }
    inputs->push_back({nullptr, input.get()});
  }

  // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
  // collect the inputs.
  tensors_.push_back(std::move(input));

  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorGetNext(
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  return CheckIteratorGetNext(iterator_.get(), iterator_ctx_.get(),
                              expected_outputs, compare_order);
}

Status DatasetOpsTestBase::CheckIteratorGetNext(
    TestIterator* iterator, const std::vector<Tensor>& expected_outputs,
    bool compare_order) {
  return CheckIteratorGetNext(iterator->iterator(), iterator->ctx(),
                              expected_outputs, compare_order);
}

Status DatasetOpsTestBase::CheckIteratorGetNext(
    IteratorBase* iterator, IteratorContext* ctx,
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (!end_of_sequence) {
    std::vector<Tensor> next;
    TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &next, &end_of_sequence));
    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
  }
  // Call GetNext one more time to make sure it still reports
  // end_of_sequence = True.
  std::vector<Tensor> unused;
  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &unused, &end_of_sequence));
  EXPECT_TRUE(end_of_sequence);

  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorSkip(
    int num_to_skip, int expected_num_skipped, bool get_next,
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  IteratorBase* iterator = iterator_.get();
  IteratorContext* ctx = iterator_ctx_.get();

  bool end_of_sequence = false;
  int num_skipped = 0;
  TF_RETURN_IF_ERROR(
      iterator->Skip(ctx, num_to_skip, &end_of_sequence, &num_skipped));
  EXPECT_TRUE(num_skipped == expected_num_skipped);
  if (get_next) {
    EXPECT_TRUE(!end_of_sequence);
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
    TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                             /*compare_order=*/compare_order));
  }
  return Status::OK();
}

Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
    const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  std::unique_ptr<TestIterator> iterator;
  TF_RETURN_IF_ERROR(
      MakeIterator(params, *dataset, std::move(split_providers), &iterator));
  TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
                                          /*compare_order=*/true));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
    const DatasetParams& params, int64_t num_shards, int64_t shard_index,
    const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  for (int i = 0; i < split_providers.size(); ++i) {
    split_providers[i] = std::make_unique<ShardingSplitProvider>(
        num_shards, shard_index, std::move(split_providers[i]));
  }
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
  IteratorContext::Params iterator_params(iterator_ctx.get());
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(iterator_params.split_providers));
  iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
  int mid_breakpoint = expected_outputs.size() / 2;
  int near_end_breakpoint = expected_outputs.size() - 1;
  int end_breakpoint = expected_outputs.size();
  TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
      dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
      expected_outputs,
      /*breakpoints=*/
      {0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
      /*compare_order=*/true));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetNodeName(
    const string& expected_dataset_node_name) {
  EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetTypeString(
    const string& expected_type_str) {
  EXPECT_EQ(dataset_->type_string(), expected_type_str);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
                                      expected_output_shapes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) {
  EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetOptions(
    const Options& expected_options) {
  EXPECT_EQ(dataset_->options().SerializeAsString(),
            expected_options.SerializeAsString());
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
                                      expected_output_shapes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorPrefix(
    const string& expected_iterator_prefix) {
  EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    DatasetBase* dataset, IteratorContext* iterator_ctx,
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
                                           iterator_prefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
  bool end_of_sequence = false;
  int cur_iteration = 0;
  std::vector<Tensor> out_tensors;
  for (int breakpoint : breakpoints) {
    VariantTensorDataWriter writer;
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
                                 *dataset, &iterator));

    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
                                     iterator_prefix, expected_outputs,
                                     breakpoints, compare_order);
}

Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
  if (initialized_) {
    return errors::Internal(
        "The fields (e.g. dataset_kernel_, dataset_ctx_, dataset_, "
        "iterator_ctx_, iterator_) have already been initialized.");
  }
  TF_RETURN_IF_ERROR(InitializeRuntime(dataset_params));
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
                                 &dataset_ctx_, &tensors_, &dataset_));
  TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
  TF_RETURN_IF_ERROR(
      dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
                             dataset_params.iterator_prefix(), &iterator_));
  initialized_ = true;
  return Status::OK();
}
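
// Example (sketch): a typical dataset op test derives from DatasetOpsTestBase
// and drives it through Initialize() and the Check*() helpers. The fixture
// name below is illustrative; RangeDatasetParams is defined later in this
// file.
//
//   class RangeDatasetOpTest : public DatasetOpsTestBase {};
//
//   TEST_F(RangeDatasetOpTest, GetNext) {
//     auto params = RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/3);
//     TF_ASSERT_OK(Initialize(params));
//     TF_ASSERT_OK(CheckIteratorGetNext(
//         CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}}),
//         /*compare_order=*/true));
//   }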

Status DatasetOpsTestBase::InitializeRuntime(
    const DatasetParams& dataset_params) {
  TF_RETURN_IF_ERROR(InitThreadPool(thread_num_));
  TF_RETURN_IF_ERROR(
      InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params,
                                       std::unique_ptr<TestDataset>* dataset) {
  DatasetBase* dataset_base;
  std::unique_ptr<OpKernel> dataset_kernel;
  std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
  std::unique_ptr<OpKernelContext> dataset_ctx;
  std::vector<std::unique_ptr<Tensor>> created_tensors;
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel,
                                 &dataset_ctx_params, &dataset_ctx,
                                 &created_tensors, &dataset_base));
  *dataset = std::make_unique<TestDataset>(
      std::move(dataset_kernel), std::move(dataset_ctx_params),
      std::move(dataset_ctx), std::move(created_tensors), dataset_base);
  return Status::OK();
}

Status DatasetOpsTestBase::RunDatasetOp(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    std::unique_ptr<OpKernelContext>* dataset_ctx) {
  std::vector<Tensor*> input_datasets;
  for (auto& input : dataset_params.input_dataset_params()) {
    std::unique_ptr<Tensor> t;
    TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
    input_datasets.push_back(t.get());
    created_tensors->push_back(std::move(t));
  }
  gtl::InlinedVector<TensorValue, 4> inputs;
  for (auto input_dataset : input_datasets) {
    inputs.emplace_back(TensorValue(input_dataset));
  }

  // Copy the input tensors, storing them in the `inputs` vectors, and storing
  // owned references to the copies in `created_tensors`.
  for (auto& input : dataset_params.GetInputTensors()) {
    auto copy = absl::make_unique<Tensor>(input);
    inputs.push_back(TensorValue(copy.get()));
    created_tensors->push_back(std::move(copy));
  }

  TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, dataset_kernel));
  TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs,
                                          dataset_ctx_params, dataset_ctx));
  TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get()));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDataset(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::unique_ptr<OpKernelContext>* dataset_ctx,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    DatasetBase** dataset) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, dataset_kernel,
                                  dataset_ctx_params, created_tensors,
                                  dataset_ctx));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ((*dataset_ctx)->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::vector<std::unique_ptr<SplitProvider>> split_providers,
    std::unique_ptr<TestIterator>* iterator) {
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
  IteratorContext::Params iterator_params(iterator_ctx.get());
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(iterator_params.split_providers));

  iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
      iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
      &iterator_base));
  *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
                                             std::move(iterator_base));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::unique_ptr<TestIterator>* iterator) {
  return MakeIterator(dataset_params, dataset, /*split_providers=*/{},
                      iterator);
}
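
// Example (sketch): the MakeDataset()/MakeIterator() pair can be used instead
// of Initialize() when a test needs several datasets or iterators at once.
// `params` and `expected_outputs` stand for any DatasetParams subclass and
// its expected elements.
//
//   TF_ASSERT_OK(InitializeRuntime(params));
//   std::unique_ptr<TestDataset> dataset;
//   TF_ASSERT_OK(MakeDataset(params, &dataset));
//   std::unique_ptr<TestIterator> iterator;
//   TF_ASSERT_OK(MakeIterator(params, *dataset, &iterator));
//   TF_ASSERT_OK(CheckIteratorGetNext(iterator.get(), expected_outputs,
//                                     /*compare_order=*/true));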

Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
                                        std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, &params_,
                                  &tensors_, &dataset_ctx_));
  for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) {
    outputs->emplace_back(*dataset_ctx_->mutable_output(i));
  }
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDatasetOpKernel(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel) {
  name_utils::OpNameParams params;
  params.op_version = dataset_params.op_version();
  std::vector<string> input_names;
  TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
  AttributeVector attributes;
  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
  NodeDef node_def =
      test::function::NDef(dataset_params.node_name(),
                           dataset_params.op_name(), input_names, attributes);
  TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeGetOptionsOpKernel(
    const DatasetParams& dataset_params, std::unique_ptr<OpKernel>* op_kernel) {
  name_utils::OpNameParams params;
  params.op_version = dataset_params.op_version();
  std::vector<string> input_names;
  TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
  AttributeVector attributes;
  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
  NodeDef node_def = test::function::NDef(dataset_params.node_name(),
                                          dataset_params.dataset_type(),
                                          input_names, attributes);
  TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDatasetTensor(
    const DatasetParams& dataset_params,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    std::unique_ptr<Tensor>* dataset) {
  // Make sure all the input dataset tensors have been populated.
  std::vector<Tensor*> input_datasets;
  for (auto& input : dataset_params.input_dataset_params()) {
    std::unique_ptr<Tensor> t;
    TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
    input_datasets.push_back(t.get());
    created_tensors->push_back(std::move(t));
  }

  AttributeVector attributes;
  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));

  auto input_tensors = dataset_params.GetInputTensors();
  gtl::InlinedVector<TensorValue, 4> inputs;
  inputs.reserve(input_datasets.size() + input_tensors.size());
  for (auto input_dataset : input_datasets) {
    inputs.emplace_back(TensorValue(input_dataset));
  }
  for (auto& input_tensor : input_tensors) {
    inputs.emplace_back(TensorValue(&input_tensor));
  }

  DatasetBase* dataset_base;
  std::unique_ptr<OpKernel> dataset_kernel;
  std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
  std::unique_ptr<OpKernelContext> dataset_ctx;
  TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, &dataset_kernel));
  TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel.get(), &inputs,
                                          &dataset_ctx_params, &dataset_ctx));
  TF_RETURN_IF_ERROR(
      CreateDataset(dataset_kernel.get(), dataset_ctx.get(), &dataset_base));
  Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(
      StoreDatasetInVariantTensor(dataset_base, &dataset_tensor));
  *dataset = absl::make_unique<Tensor>(dataset_tensor);
  return Status::OK();
}

DatasetParams::DatasetParams(DataTypeVector output_dtypes,
                             std::vector<PartialTensorShape> output_shapes,
                             string node_name)
    : output_dtypes_(std::move(output_dtypes)),
      output_shapes_(std::move(output_shapes)),
      node_name_(std::move(node_name)) {}

bool DatasetParams::IsDatasetTensor(const Tensor& tensor) {
  return tensor.dtype() == DT_VARIANT &&
         TensorShapeUtils::IsScalar(tensor.shape());
}

RangeDatasetParams::RangeDatasetParams(
    int64_t start, int64_t stop, int64_t step, DataTypeVector output_dtypes,
    std::vector<PartialTensorShape> output_shapes, string node_name)
    : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                    std::move(node_name)),
      start_(start),
      stop_(stop),
      step_(step) {}

RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
                                       int64_t step)
    : DatasetParams({DT_INT64}, {PartialTensorShape({})}, "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}
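
// Example (sketch): RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/2)
// describes a RangeDataset whose iterator yields the scalar int64 tensors
// {0, 2, 4, 6, 8}, mirroring the semantics of tf.data.Dataset.range.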

RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
                                       int64_t step,
                                       DataTypeVector output_dtypes)
    : DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
                    "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}

std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
  Tensor start_tensor = CreateTensor<int64>(TensorShape({}), {start_});
  Tensor stop_tensor = CreateTensor<int64>(TensorShape({}), {stop_});
  Tensor step_tensor = CreateTensor<int64>(TensorShape({}), {step_});
  return {start_tensor, stop_tensor, step_tensor};
}

Status RangeDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {"start", "stop", "step"};
  return Status::OK();
}

Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{"output_types", output_dtypes_},
                  {"output_shapes", output_shapes_}};
  return Status::OK();
}

string RangeDatasetParams::dataset_type() const { return "Range"; }

std::vector<Tensor> BatchDatasetParams::GetInputTensors() const {
  Tensor batch_size = CreateTensor<int64>(TensorShape({}), {batch_size_});
  Tensor drop_remainder =
      CreateTensor<bool>(TensorShape({}), {drop_remainder_});
  return {batch_size, drop_remainder};
}

Status BatchDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {"input_dataset", "batch_size", "drop_remainder"};
  return Status::OK();
}

Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{"parallel_copy", parallel_copy_},
                  {"output_types", output_dtypes_},
                  {"output_shapes", output_shapes_}};
  return Status::OK();
}

string BatchDatasetParams::dataset_type() const { return "Batch"; }

std::vector<Tensor> MapDatasetParams::GetInputTensors() const {
  return other_arguments_;
}

Status MapDatasetParams::GetInputNames(std::vector<string>* input_names) const {
  input_names->emplace_back("input_dataset");
  for (int i = 0; i < other_arguments_.size(); ++i) {
    input_names->emplace_back(absl::StrCat("other_arguments_", i));
  }
  return Status::OK();
}

Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{"f", func_},
                  {"Targuments", type_arguments_},
                  {"output_shapes", output_shapes_},
                  {"output_types", output_dtypes_},
                  {"use_inter_op_parallelism", use_inter_op_parallelism_},
                  {"preserve_cardinality", preserve_cardinality_}};
  return Status::OK();
}

string MapDatasetParams::dataset_type() const { return "Map"; }

std::vector<FunctionDef> MapDatasetParams::func_lib() const {
  return func_lib_;
}

TensorSliceDatasetParams::TensorSliceDatasetParams(
    std::vector<Tensor> components, string node_name)
    : DatasetParams(TensorSliceDtypes(components),
                    TensorSliceShapes(components), std::move(node_name)),
      components_(std::move(components)) {}

std::vector<Tensor> TensorSliceDatasetParams::GetInputTensors() const {
  return components_;
}

Status TensorSliceDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  input_names->reserve(components_.size());
  for (int i = 0; i < components_.size(); ++i) {
    input_names->emplace_back(absl::StrCat("components_", i));
  }
  return Status::OK();
}

Status TensorSliceDatasetParams::GetAttributes(
    AttributeVector* attr_vector) const {
  *attr_vector = {{"Toutput_types", output_dtypes_},
                  {"output_shapes", output_shapes_}};
  return Status::OK();
}

DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes(
    const std::vector<Tensor>& input_components) {
  DataTypeVector dtypes;
  for (const auto& component : input_components) {
    dtypes.emplace_back(component.dtype());
  }
  return dtypes;
}

std::vector<PartialTensorShape> TensorSliceDatasetParams::TensorSliceShapes(
    const std::vector<Tensor>& input_components) {
  std::vector<PartialTensorShape> shapes;
  for (const auto& component : input_components) {
    gtl::InlinedVector<int64, 4> partial_dim_sizes;
    for (int i = 1; i < component.dims(); ++i) {
      partial_dim_sizes.push_back(component.dim_size(i));
    }
    shapes.emplace_back(std::move(partial_dim_sizes));
  }
  return shapes;
}

string TensorSliceDatasetParams::dataset_type() const { return "TensorSlice"; }

std::vector<Tensor> TakeDatasetParams::GetInputTensors() const {
  return {CreateTensor<int64>(TensorShape({}), {count_})};
}

Status TakeDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {"input_dataset", "count"};
  return Status::OK();
}

Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{"output_shapes", output_shapes_},
                  {"output_types", output_dtypes_}};
  return Status::OK();
}

string TakeDatasetParams::dataset_type() const { return "Take"; }

std::vector<Tensor> ConcatenateDatasetParams::GetInputTensors() const {
  return {};
}

Status ConcatenateDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {"input_dataset", "another_dataset"};
  return Status::OK();
}

Status ConcatenateDatasetParams::GetAttributes(
    AttributeVector* attr_vector) const {
  *attr_vector = {{"output_types", output_dtypes_},
                  {"output_shapes", output_shapes_}};
  return Status::OK();
}

string ConcatenateDatasetParams::dataset_type() const { return "Concatenate"; }

std::vector<Tensor> OptionsDatasetParams::GetInputTensors() const { return {}; }

Status OptionsDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  input_names->emplace_back("input_dataset");
  return Status::OK();
}

Status OptionsDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{"serialized_options", serialized_options_},
                  {"output_shapes", output_shapes_},
                  {"output_types", output_dtypes_}};
  return Status::OK();
}

string OptionsDatasetParams::dataset_type() const { return "Options"; }

}  // namespace data
}  // namespace tensorflow