/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/dataset_test_base.h"

#include <algorithm>
#include <complex>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/batch_dataset_op.h"
#include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/map_dataset_op.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/split_utils.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {
namespace data {

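// Converts a CompressionType to the name expected by RecordWriterOptions;
// UNCOMPRESSED maps to the empty string.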
string ToString(CompressionType compression_type) {
  switch (compression_type) {
    case CompressionType::ZLIB:
      return "ZLIB";
    case CompressionType::GZIP:
      return "GZIP";
    case CompressionType::RAW:
      return "RAW";
    case CompressionType::UNCOMPRESSED:
      return "";
  }
}

io::ZlibCompressionOptions GetZlibCompressionOptions(
    CompressionType compression_type) {
  switch (compression_type) {
    case CompressionType::ZLIB:
      return io::ZlibCompressionOptions::DEFAULT();
    case CompressionType::GZIP:
      return io::ZlibCompressionOptions::GZIP();
    case CompressionType::RAW:
      return io::ZlibCompressionOptions::RAW();
    case CompressionType::UNCOMPRESSED:
      LOG(WARNING) << "ZlibCompressionOptions does not have an option for "
                   << ToString(compression_type);
      return io::ZlibCompressionOptions::DEFAULT();
  }
}

Status WriteDataToFile(const string& filename, const char* data) {
  return WriteDataToFile(filename, data, CompressionParams());
}

Status WriteDataToFile(const string& filename, const char* data,
                       const CompressionParams& params) {
  Env* env = Env::Default();
  std::unique_ptr<WritableFile> file_writer;
  TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
  if (params.compression_type == CompressionType::UNCOMPRESSED) {
    TF_RETURN_IF_ERROR(file_writer->Append(data));
  } else if (params.compression_type == CompressionType::ZLIB ||
             params.compression_type == CompressionType::GZIP ||
             params.compression_type == CompressionType::RAW) {
    auto zlib_compression_options =
        GetZlibCompressionOptions(params.compression_type);
    io::ZlibOutputBuffer out(file_writer.get(), params.input_buffer_size,
                             params.output_buffer_size,
                             zlib_compression_options);
    TF_RETURN_IF_ERROR(out.Init());
    TF_RETURN_IF_ERROR(out.Append(data));
    TF_RETURN_IF_ERROR(out.Flush());
    TF_RETURN_IF_ERROR(out.Close());
  } else {
    return tensorflow::errors::InvalidArgument(
        "Unsupported compression_type: ", ToString(params.compression_type));
  }

  TF_RETURN_IF_ERROR(file_writer->Flush());
  TF_RETURN_IF_ERROR(file_writer->Close());

  return Status::OK();
}

Status WriteDataToTFRecordFile(const string& filename,
                               const std::vector<absl::string_view>& records,
                               const CompressionParams& params) {
  Env* env = Env::Default();
  std::unique_ptr<WritableFile> file_writer;
  TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
  auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
      ToString(params.compression_type));
  options.zlib_options.input_buffer_size = params.input_buffer_size;
  io::RecordWriter record_writer(file_writer.get(), options);
  for (const auto& record : records) {
    TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
  }
  TF_RETURN_IF_ERROR(record_writer.Flush());
  TF_RETURN_IF_ERROR(record_writer.Close());
  TF_RETURN_IF_ERROR(file_writer->Flush());
  TF_RETURN_IF_ERROR(file_writer->Close());
  return Status::OK();
}

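// Checks that two tensors of element type `T` have the same dtype, shape, and
// values; returns an Internal error describing the first mismatch.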
template <typename T>
Status IsEqual(const Tensor& t1, const Tensor& t2) {
  if (t1.dtype() != t2.dtype()) {
    return tensorflow::errors::Internal(
        "Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
        " vs. ", DataTypeString(t2.dtype()));
  }
  if (!t1.IsSameSize(t2)) {
    return tensorflow::errors::Internal(
        "Two tensors have different shapes: ", t1.shape().DebugString(),
        " vs. ", t2.shape().DebugString());
  }

  auto flat_t1 = t1.flat<T>();
  auto flat_t2 = t2.flat<T>();
  auto length = flat_t1.size();

  for (int i = 0; i < length; ++i) {
    if (flat_t1(i) != flat_t2(i)) {
      return tensorflow::errors::Internal(
          "Two tensors have different values "
          "at [",
          i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
    }
  }
  return Status::OK();
}

DatasetOpsTestBase::DatasetOpsTestBase()
    : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
      device_type_(DEVICE_CPU),
      cpu_num_(kDefaultCPUNum),
      thread_num_(kDefaultThreadNum) {
  allocator_ = device_->GetAllocator(AllocatorAttributes());
}

DatasetOpsTestBase::~DatasetOpsTestBase() {
  if (dataset_) {
    dataset_->Unref();
  }
}

Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  switch (a.dtype()) {
#define CASE(DT)                           \
  case DataTypeToEnum<DT>::value:          \
    TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return Status::OK();
}

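// Lexicographic "less than" over the flattened values of two tensors; used to
// sort tensor vectors before comparison when element order is ignored.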
template <typename T>
bool compare(const Tensor& t1, const Tensor& t2) {
  auto flat_t1 = t1.flat<T>();
  auto flat_t2 = t2.flat<T>();
  auto length = std::min(flat_t1.size(), flat_t2.size());
  for (int i = 0; i < length; ++i) {
    if (flat_t1(i) < flat_t2(i)) return true;
    if (flat_t1(i) > flat_t2(i)) return false;
  }
  return flat_t1.size() < length;
}

Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
                                       std::vector<Tensor> expected_tensors,
                                       bool compare_order) {
  if (produced_tensors.size() != expected_tensors.size()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different size (", produced_tensors.size(),
        " v.s. ", expected_tensors.size(), ")"));
  }

  if (produced_tensors.empty()) return Status::OK();
  if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different dtypes (",
        produced_tensors[0].dtype(), " v.s. ", expected_tensors[0].dtype(),
        ")"));
  }

  if (!compare_order) {
    const DataType& dtype = produced_tensors[0].dtype();
    switch (dtype) {
#define CASE(DT)                                                \
  case DT:                                                      \
    std::sort(produced_tensors.begin(), produced_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    std::sort(expected_tensors.begin(), expected_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    break;
      CASE(DT_FLOAT);
      CASE(DT_DOUBLE);
      CASE(DT_INT32);
      CASE(DT_UINT8);
      CASE(DT_INT16);
      CASE(DT_INT8);
      CASE(DT_STRING);
      CASE(DT_INT64);
      CASE(DT_BOOL);
      CASE(DT_QINT8);
      CASE(DT_QUINT8);
      CASE(DT_QINT32);
      CASE(DT_QINT16);
      CASE(DT_QUINT16);
      CASE(DT_UINT16);
      CASE(DT_HALF);
      CASE(DT_UINT32);
      CASE(DT_UINT64);
      // TODO(feihugis): support other dtypes.
#undef CASE
      default:
        return errors::Internal("Unsupported dtype: ", dtype);
    }
  }

  for (int i = 0; i < produced_tensors.size(); ++i) {
    TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i],
                                                       expected_tensors[i]));
  }
  return Status::OK();
}

Status DatasetOpsTestBase::CreateOpKernel(
    const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
  OpKernel* kernel;
  Status s;

  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      node_def, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
      device_type_, device_.get(), allocator_, flr_,
      device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel));
  op_kernel->reset(kernel);
  return Status::OK();
}

Status DatasetOpsTestBase::CreateDatasetContext(
    OpKernel* const dataset_kernel,
    gtl::InlinedVector<TensorValue, 4>* const inputs,
    std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
    std::unique_ptr<OpKernelContext>* dataset_context) {
  Status status = CheckOpKernelInput(*dataset_kernel, *inputs);
  if (!status.ok()) {
    VLOG(0) << "WARNING: " << status.ToString();
  }
  TF_RETURN_IF_ERROR(CreateOpKernelContext(
      dataset_kernel, inputs, dataset_context_params, dataset_context));
  return Status::OK();
}

Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
                                         OpKernelContext* context,
                                         DatasetBase** const dataset) {
  TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ(context->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
  return Status::OK();
}

Status DatasetOpsTestBase::RestoreIterator(
    IteratorContext* ctx, IteratorStateReader* reader,
    const string& output_prefix, const DatasetBase& dataset,
    std::unique_ptr<IteratorBase>* iterator) {
  return dataset.MakeIteratorFromCheckpoint(ctx, output_prefix, reader,
                                            iterator);
}

Status DatasetOpsTestBase::CreateIteratorContext(
    OpKernelContext* const op_context,
    std::unique_ptr<IteratorContext>* iterator_context) {
  IteratorContext::Params params(op_context);
  params.resource_mgr = op_context->resource_manager();
  function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
  params.function_handle_cache = function_handle_cache_.get();
  params.cancellation_manager = cancellation_manager_.get();
  *iterator_context = absl::make_unique<IteratorContext>(params);
  return Status::OK();
}

Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
                                                 int output_index,
                                                 DatasetBase** const dataset) {
  Tensor* output = context->mutable_output(output_index);
  Status status = GetDatasetFromVariantTensor(*output, dataset);
  (*dataset)->Ref();
  return status;
}

Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
  if (thread_num < 1) {
    return errors::InvalidArgument(
        "The `thread_num` argument should be positive but got: ", thread_num);
  }
  thread_pool_ = absl::make_unique<thread::ThreadPool>(
      Env::Default(), ThreadOptions(), "test_thread_pool", thread_num);
  return Status::OK();
}

Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");

  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
      /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return Status::OK();
          }});
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  if (thread_pool_ == nullptr) {
    runner_ = [](const std::function<void()>& fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return Status::OK();
}

Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}

Status DatasetOpsTestBase::RunFunction(
    const FunctionDef& fdef, test::function::Attrs attrs,
    const std::vector<Tensor>& args,
    const GraphConstructorOptions& graph_options, std::vector<Tensor*> rets) {
  std::unique_ptr<Executor> exec;
  InstantiationResult result;
  auto GetOpSig = [](const string& op, const OpDef** sig) {
    return OpRegistry::Global()->LookUpOpDef(op, sig);
  };
  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, GetOpSig, &result));

  DataTypeVector arg_types = result.arg_types;
  DataTypeVector ret_types = result.ret_types;

  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_RETURN_IF_ERROR(
      ConvertNodeDefsToGraph(graph_options, result.nodes, g.get()));

  const int version = g->versions().producer();
  LocalExecutorParams params;
  params.function_library = flr_;
  params.device = device_.get();
  params.create_kernel = [this, version](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };

  Executor* cur_exec;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
  exec.reset(cur_exec);
  FunctionCallFrame frame(arg_types, ret_types);
  TF_RETURN_IF_ERROR(frame.SetArgs(args));
  Executor::Args exec_args;
  exec_args.call_frame = &frame;
  exec_args.runner = runner_;
  TF_RETURN_IF_ERROR(exec->Run(exec_args));
  std::vector<Tensor> computed;
  TF_RETURN_IF_ERROR(frame.GetRetvals(&computed));
  if (computed.size() != rets.size()) {
    return errors::InvalidArgument(
        "The result does not match the expected number of return outputs",
        ". Expected: ", rets.size(), ". Actual: ", computed.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    *(rets[i]) = computed[i];
  }
  return Status::OK();
}

Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext>* context) {
  return CreateOpKernelContext(kernel, inputs, &params_, context);
}

Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext::Params>* context_params,
    std::unique_ptr<OpKernelContext>* context) {
  auto params = absl::make_unique<OpKernelContext::Params>();
  cancellation_manager_ = absl::make_unique<CancellationManager>();
  params->cancellation_manager = cancellation_manager_.get();
  params->device = device_.get();
  params->frame_iter = FrameAndIter(0, 0);
  params->function_library = flr_;
  params->inputs = inputs;
  params->op_kernel = kernel;
  params->resource_manager = resource_mgr_.get();
  params->runner = &runner_;
  slice_reader_cache_ =
      absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
  params->slice_reader_cache = slice_reader_cache_.get();
  step_container_ =
      absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
  params->step_container = step_container_.get();

  // Set the allocator attributes for the outputs.
  allocator_attrs_.clear();
  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    allocator_attrs_.emplace_back(attr);
  }
  params->output_attr_array = allocator_attrs_.data();

  *context = absl::make_unique<OpKernelContext>(params.get());
  *context_params = std::move(params);
  return Status::OK();
}

Status DatasetOpsTestBase::CreateSerializationContext(
    std::unique_ptr<SerializationContext>* context) {
  *context =
      absl::make_unique<SerializationContext>(SerializationContext::Params{});
  return Status::OK();
}

Status DatasetOpsTestBase::CheckOpKernelInput(
    const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
  if (kernel.num_inputs() != inputs.size()) {
    return errors::InvalidArgument("The number of input elements should be ",
                                   kernel.num_inputs(),
                                   ", but got: ", inputs.size());
  }
  return Status::OK();
}

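// Allocates a tensor of `dtype` and `shape`, checks it against the expected
// type at the next input position (handling ref types), appends it to
// `inputs`, and keeps ownership in `tensors_` so it outlives the kernel run.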
Status DatasetOpsTestBase::AddDatasetInput(
    gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
    DataType dtype, const TensorShape& shape) {
  if (input_types.size() < inputs->size()) {
    return errors::InvalidArgument("Adding more inputs than types: ",
                                   inputs->size(), " vs. ", input_types.size());
  }
  bool is_ref = IsRefType(input_types[inputs->size()]);
  auto input = absl::make_unique<Tensor>(allocator_, dtype, shape);

  if (is_ref) {
    DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
    if (expected_dtype != dtype) {
      return errors::InvalidArgument("The input data type is ", dtype,
                                     " , but expected: ", expected_dtype);
    }
    inputs->push_back({&lock_for_refs_, input.get()});
  } else {
    if (input_types[inputs->size()] != dtype) {
      return errors::InvalidArgument(
          "The input data type is ", dtype,
          " , but expected: ", input_types[inputs->size()]);
    }
    inputs->push_back({nullptr, input.get()});
  }

  // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
  // collect the inputs.
  tensors_.push_back(std::move(input));

  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorGetNext(
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  return CheckIteratorGetNext(iterator_.get(), iterator_ctx_.get(),
                              expected_outputs, compare_order);
}

Status DatasetOpsTestBase::CheckIteratorGetNext(
    TestIterator* iterator, const std::vector<Tensor>& expected_outputs,
    bool compare_order) {
  return CheckIteratorGetNext(iterator->iterator(), iterator->ctx(),
                              expected_outputs, compare_order);
}

Status DatasetOpsTestBase::CheckIteratorGetNext(
    IteratorBase* iterator, IteratorContext* ctx,
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (!end_of_sequence) {
    std::vector<Tensor> next;
    TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &next, &end_of_sequence));
    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
  }
  // Call GetNext one more time to make sure it still reports
  // end_of_sequence = True.
  std::vector<Tensor> unused;
  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &unused, &end_of_sequence));
  EXPECT_TRUE(end_of_sequence);

  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return Status::OK();
}

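// Skips `num_to_skip` elements from the test iterator and verifies that
// exactly `expected_num_skipped` elements were skipped; if `get_next` is
// true, also fetches the next element and compares it to `expected_outputs`.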
Status DatasetOpsTestBase::CheckIteratorSkip(
    int num_to_skip, int expected_num_skipped, bool get_next,
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  IteratorBase* iterator = iterator_.get();
  IteratorContext* ctx = iterator_ctx_.get();

  bool end_of_sequence = false;
  int num_skipped = 0;
  TF_RETURN_IF_ERROR(
      iterator->Skip(ctx, num_to_skip, &end_of_sequence, &num_skipped));
  EXPECT_TRUE(num_skipped == expected_num_skipped);
  if (get_next) {
    EXPECT_TRUE(!end_of_sequence);
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
    TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                             /*compare_order=*/compare_order));
  }
  return Status::OK();
}

Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
    const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::unique_ptr<SplitProvider> split_provider;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider));
  std::unique_ptr<TestIterator> iterator;
  TF_RETURN_IF_ERROR(
      MakeIterator(params, *dataset, std::move(split_provider), &iterator));
  TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
                                          /*compare_order=*/true));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
    const DatasetParams& params, int64 num_shards, int64 shard_index,
    const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::unique_ptr<SplitProvider> split_provider;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider));
  split_provider = absl::make_unique<ShardingSplitProvider>(
      num_shards, shard_index, std::move(split_provider));
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
  IteratorContext::Params iterator_params(iterator_ctx.get());
  iterator_params.split_provider = std::move(split_provider);
  iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
  int mid_breakpoint = expected_outputs.size() / 2;
  int near_end_breakpoint = expected_outputs.size() - 1;
  int end_breakpoint = expected_outputs.size();
  TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
      dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
      expected_outputs,
      /*breakpoints=*/
      {0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
      /*compare_order=*/true));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetNodeName(
    const string& expected_dataset_node_name) {
  EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetTypeString(
    const string& expected_type_str) {
  EXPECT_EQ(dataset_->type_string(), expected_type_str);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
                                      expected_output_shapes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) {
  EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
                                      expected_output_shapes));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorPrefix(
    const string& expected_iterator_prefix) {
  EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix);
  return Status::OK();
}

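// At each breakpoint, saves the iterator state, restores a fresh iterator
// from that checkpoint, and continues iterating; finally compares everything
// produced against `expected_outputs`.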
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    DatasetBase* dataset, IteratorContext* iterator_ctx,
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
                                           iterator_prefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
  bool end_of_sequence = false;
  int cur_iteration = 0;
  std::vector<Tensor> out_tensors;
  for (int breakpoint : breakpoints) {
    VariantTensorDataWriter writer;
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
                                 *dataset, &iterator));

    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return Status::OK();
}

Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
                                     iterator_prefix, expected_outputs,
                                     breakpoints, compare_order);
}

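// Typical usage in a dataset op test (a sketch; the parameter values and
// expected elements below are illustrative only):
//
//   RangeDatasetParams params(/*start=*/0, /*stop=*/10, /*step=*/3);
//   TF_ASSERT_OK(Initialize(params));
//   TF_ASSERT_OK(CheckIteratorGetNext(
//       CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}}),
//       /*compare_order=*/true));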
Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
  if (initialized_) {
    return errors::Internal(
        "The fields (e.g. dataset_kernel_, dataset_ctx_, dataset_, "
        "iterator_ctx_, iterator_) have already been initialized.");
  }
  TF_RETURN_IF_ERROR(InitializeRuntime(dataset_params));
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
                                 &dataset_ctx_, &tensors_, &dataset_));
  TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
  TF_RETURN_IF_ERROR(
      dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
                             dataset_params.iterator_prefix(), &iterator_));
  initialized_ = true;
  return Status::OK();
}

Status DatasetOpsTestBase::InitializeRuntime(
    const DatasetParams& dataset_params) {
  TF_RETURN_IF_ERROR(InitThreadPool(thread_num_));
  TF_RETURN_IF_ERROR(
      InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params,
                                       std::unique_ptr<TestDataset>* dataset) {
  DatasetBase* dataset_base;
  std::unique_ptr<OpKernel> dataset_kernel;
  std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
  std::unique_ptr<OpKernelContext> dataset_ctx;
  std::vector<std::unique_ptr<Tensor>> created_tensors;
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel,
                                 &dataset_ctx_params, &dataset_ctx,
                                 &created_tensors, &dataset_base));
  *dataset = std::make_unique<TestDataset>(
      std::move(dataset_kernel), std::move(dataset_ctx_params),
      std::move(dataset_ctx), std::move(created_tensors), dataset_base);
  return Status::OK();
}

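// Builds the dataset op kernel and its inputs (recursively materializing any
// input dataset tensors), then runs the kernel; every tensor created along
// the way is stored in `created_tensors` to keep it alive for the context.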
Status DatasetOpsTestBase::RunDatasetOp(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    std::unique_ptr<OpKernelContext>* dataset_ctx) {
  std::vector<Tensor*> input_datasets;
  for (auto& input : dataset_params.input_dataset_params()) {
    std::unique_ptr<Tensor> t;
    TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
    input_datasets.push_back(t.get());
    created_tensors->push_back(std::move(t));
  }
  gtl::InlinedVector<TensorValue, 4> inputs;
  for (auto input_dataset : input_datasets) {
    inputs.emplace_back(TensorValue(input_dataset));
  }

  // Copy the input tensors, storing them in the `inputs` vectors, and storing
  // owned references to the copies in `created_tensors`.
  for (auto& input : dataset_params.GetInputTensors()) {
    auto copy = absl::make_unique<Tensor>(input);
    inputs.push_back(TensorValue(copy.get()));
    created_tensors->push_back(std::move(copy));
  }

  TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, dataset_kernel));
  TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs,
                                          dataset_ctx_params, dataset_ctx));
  TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get()));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDataset(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::unique_ptr<OpKernelContext>* dataset_ctx,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    DatasetBase** dataset) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, dataset_kernel,
                                  dataset_ctx_params, created_tensors,
                                  dataset_ctx));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ((*dataset_ctx)->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::unique_ptr<SplitProvider> split_provider,
    std::unique_ptr<TestIterator>* iterator) {
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
  IteratorContext::Params iterator_params(iterator_ctx.get());
  iterator_params.split_provider = std::move(split_provider);
  iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
      iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
      &iterator_base));
  *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
                                             std::move(iterator_base));
  return Status::OK();
}

Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::unique_ptr<TestIterator>* iterator) {
  return MakeIterator(dataset_params, dataset, /*split_provider=*/nullptr,
                      iterator);
}

Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
                                        std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, &params_,
                                  &tensors_, &dataset_ctx_));
  for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) {
    outputs->emplace_back(*dataset_ctx_->mutable_output(i));
  }
  return Status::OK();
}

Status DatasetOpsTestBase::MakeDatasetOpKernel(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel) {
  name_utils::OpNameParams params;
  params.op_version = dataset_params.op_version();
  std::vector<string> input_names;
  TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
  AttributeVector attributes;
  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
  NodeDef node_def = test::function::NDef(
      dataset_params.node_name(),
      name_utils::OpName(dataset_params.dataset_type(), params), input_names,
      attributes);
  TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel));
  return Status::OK();
}

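// Recursively materializes the input datasets of `dataset_params`, runs the
// corresponding dataset op, and returns the resulting DatasetBase wrapped in
// a scalar DT_VARIANT tensor.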
Status DatasetOpsTestBase::MakeDatasetTensor(
    const DatasetParams& dataset_params,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    std::unique_ptr<Tensor>* dataset) {
  // Make sure all the input dataset tensors have been populated.
  std::vector<Tensor*> input_datasets;
  for (auto& input : dataset_params.input_dataset_params()) {
    std::unique_ptr<Tensor> t;
    TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
    input_datasets.push_back(t.get());
    created_tensors->push_back(std::move(t));
  }

  AttributeVector attributes;
  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));

  gtl::InlinedVector<TensorValue, 4> inputs;
  for (auto input_dataset : input_datasets) {
    inputs.emplace_back(TensorValue(input_dataset));
  }
  auto input_tensors = dataset_params.GetInputTensors();
  for (auto& input_tensor : input_tensors) {
    inputs.emplace_back(TensorValue(&input_tensor));
  }

  DatasetBase* dataset_base;
  std::unique_ptr<OpKernel> dataset_kernel;
  std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
  std::unique_ptr<OpKernelContext> dataset_ctx;
  TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, &dataset_kernel));
  TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel.get(), &inputs,
                                          &dataset_ctx_params, &dataset_ctx));
  TF_RETURN_IF_ERROR(
      CreateDataset(dataset_kernel.get(), dataset_ctx.get(), &dataset_base));
  Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(
      StoreDatasetInVariantTensor(dataset_base, &dataset_tensor));
  *dataset = absl::make_unique<Tensor>(dataset_tensor);
  return Status::OK();
}

DatasetParams::DatasetParams(DataTypeVector output_dtypes,
                             std::vector<PartialTensorShape> output_shapes,
                             string node_name)
    : output_dtypes_(std::move(output_dtypes)),
      output_shapes_(std::move(output_shapes)),
      node_name_(std::move(node_name)) {}

bool DatasetParams::IsDatasetTensor(const Tensor& tensor) {
  return tensor.dtype() == DT_VARIANT &&
         TensorShapeUtils::IsScalar(tensor.shape());
}

RangeDatasetParams::RangeDatasetParams(
    int64 start, int64 stop, int64 step, DataTypeVector output_dtypes,
    std::vector<PartialTensorShape> output_shapes, string node_name)
    : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                    std::move(node_name)),
      start_(start),
      stop_(stop),
      step_(step) {}

RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step)
    : DatasetParams({DT_INT64}, {PartialTensorShape({})}, "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}

RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step,
                                       DataTypeVector output_dtypes)
    : DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
                    "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}

std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
  Tensor start_tensor = CreateTensor<int64>(TensorShape({}), {start_});
  Tensor stop_tensor = CreateTensor<int64>(TensorShape({}), {stop_});
  Tensor step_tensor = CreateTensor<int64>(TensorShape({}), {step_});
  return {start_tensor, stop_tensor, step_tensor};
}

Status RangeDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {RangeDatasetOp::kStart, RangeDatasetOp::kStop,
                  RangeDatasetOp::kStep};
  return Status::OK();
}

Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_},
                  {RangeDatasetOp::kOutputShapes, output_shapes_}};
  return Status::OK();
}

string RangeDatasetParams::dataset_type() const {
  return RangeDatasetOp::kDatasetType;
}

std::vector<Tensor> BatchDatasetParams::GetInputTensors() const {
  Tensor batch_size = CreateTensor<int64>(TensorShape({}), {batch_size_});
  Tensor drop_remainder =
      CreateTensor<bool>(TensorShape({}), {drop_remainder_});
  return {batch_size, drop_remainder};
}

Status BatchDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {BatchDatasetOp::kInputDataset, BatchDatasetOp::kBatchSize,
                  BatchDatasetOp::kDropRemainder};
  return Status::OK();
}

Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_},
                  {BatchDatasetOp::kOutputTypes, output_dtypes_},
                  {BatchDatasetOp::kOutputShapes, output_shapes_}};
  return Status::OK();
}

string BatchDatasetParams::dataset_type() const {
  return BatchDatasetOp::kDatasetType;
}

std::vector<Tensor> MapDatasetParams::GetInputTensors() const {
  return other_arguments_;
}

Status MapDatasetParams::GetInputNames(std::vector<string>* input_names) const {
  input_names->emplace_back(MapDatasetOp::kInputDataset);
  for (int i = 0; i < other_arguments_.size(); ++i) {
    input_names->emplace_back(
        absl::StrCat(MapDatasetOp::kOtherArguments, "_", i));
  }
  return Status::OK();
}

Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {
      {MapDatasetOp::kFunc, func_},
      {MapDatasetOp::kTarguments, type_arguments_},
      {MapDatasetOp::kOutputShapes, output_shapes_},
      {MapDatasetOp::kOutputTypes, output_dtypes_},
      {MapDatasetOp::kUseInterOpParallelism, use_inter_op_parallelism_},
      {MapDatasetOp::kPreserveCardinality, preserve_cardinality_}};
  return Status::OK();
}

string MapDatasetParams::dataset_type() const {
  return MapDatasetOp::kDatasetType;
}

std::vector<FunctionDef> MapDatasetParams::func_lib() const {
  return func_lib_;
}

TensorSliceDatasetParams::TensorSliceDatasetParams(
    std::vector<Tensor> components, string node_name)
    : DatasetParams(TensorSliceDtypes(components),
                    TensorSliceShapes(components), std::move(node_name)),
      components_(std::move(components)) {}

std::vector<Tensor> TensorSliceDatasetParams::GetInputTensors() const {
  return components_;
}

Status TensorSliceDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  input_names->reserve(components_.size());
  for (int i = 0; i < components_.size(); ++i) {
    input_names->emplace_back(
        absl::StrCat(TensorSliceDatasetOp::kComponents, "_", i));
  }
  return Status::OK();
}

Status TensorSliceDatasetParams::GetAttributes(
    AttributeVector* attr_vector) const {
  *attr_vector = {{TensorSliceDatasetOp::kToutputTypes, output_dtypes_},
                  {TensorSliceDatasetOp::kOutputShapes, output_shapes_}};
  return Status::OK();
}

DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes(
    const std::vector<Tensor>& input_components) {
  DataTypeVector dtypes;
  for (const auto& component : input_components) {
    dtypes.emplace_back(component.dtype());
  }
  return dtypes;
}

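// Derives the per-element output shapes by dropping the leading (slice)
// dimension of each component tensor.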
std::vector<PartialTensorShape> TensorSliceDatasetParams::TensorSliceShapes(
    const std::vector<Tensor>& input_components) {
  std::vector<PartialTensorShape> shapes;
  for (const auto& component : input_components) {
    gtl::InlinedVector<int64, 4> partial_dim_sizes;
    for (int i = 1; i < component.dims(); ++i) {
      partial_dim_sizes.push_back(component.dim_size(i));
    }
    shapes.emplace_back(std::move(partial_dim_sizes));
  }
  return shapes;
}

string TensorSliceDatasetParams::dataset_type() const {
  return TensorSliceDatasetOp::kDatasetType;
}

std::vector<Tensor> TakeDatasetParams::GetInputTensors() const {
  return {CreateTensor<int64>(TensorShape({}), {count_})};
}

Status TakeDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {TakeDatasetOp::kInputDataset, TakeDatasetOp::kCount};
  return Status::OK();
}

Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
  *attr_vector = {{TakeDatasetOp::kOutputShapes, output_shapes_},
                  {TakeDatasetOp::kOutputTypes, output_dtypes_}};
  return Status::OK();
}

string TakeDatasetParams::dataset_type() const {
  return TakeDatasetOp::kDatasetType;
}

std::vector<Tensor> ConcatenateDatasetParams::GetInputTensors() const {
  return {};
}

Status ConcatenateDatasetParams::GetInputNames(
    std::vector<string>* input_names) const {
  *input_names = {ConcatenateDatasetOp::kInputDataset,
                  ConcatenateDatasetOp::kAnotherDataset};
  return Status::OK();
}

Status ConcatenateDatasetParams::GetAttributes(
    AttributeVector* attr_vector) const {
  *attr_vector = {{ConcatenateDatasetOp::kOutputTypes, output_dtypes_},
                  {ConcatenateDatasetOp::kOutputShapes, output_shapes_}};
  return Status::OK();
}

string ConcatenateDatasetParams::dataset_type() const {
  return ConcatenateDatasetOp::kDatasetType;
}

}  // namespace data
}  // namespace tensorflow