1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/dataset_test_base.h"
17
18 #include <algorithm>
19 #include <complex>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
29 #include "tensorflow/core/common_runtime/device.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/device_mgr.h"
32 #include "tensorflow/core/common_runtime/executor.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
35 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
36 #include "tensorflow/core/data/dataset_utils.h"
37 #include "tensorflow/core/data/name_utils.h"
38 #include "tensorflow/core/data/serialization_utils.h"
39 #include "tensorflow/core/data/split_utils.h"
40 #include "tensorflow/core/framework/allocator.h"
41 #include "tensorflow/core/framework/cancellation.h"
42 #include "tensorflow/core/framework/control_flow.h"
43 #include "tensorflow/core/framework/dataset.h"
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/function.pb.h"
46 #include "tensorflow/core/framework/function_handle_cache.h"
47 #include "tensorflow/core/framework/function_testlib.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/numeric_types.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/framework/op_def.pb.h"
52 #include "tensorflow/core/framework/op_kernel.h"
53 #include "tensorflow/core/framework/register_types.h"
54 #include "tensorflow/core/framework/rendezvous.h"
55 #include "tensorflow/core/framework/resource_mgr.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/framework/tensor_shape.h"
58 #include "tensorflow/core/framework/types.h"
59 #include "tensorflow/core/framework/types.pb.h"
60 #include "tensorflow/core/framework/variant_tensor_data.h"
61 #include "tensorflow/core/framework/versions.pb.h"
62 #include "tensorflow/core/graph/graph.h"
63 #include "tensorflow/core/lib/core/status_test_util.h"
64 #include "tensorflow/core/lib/gtl/inlined_vector.h"
65 #include "tensorflow/core/lib/io/record_writer.h"
66 #include "tensorflow/core/lib/io/zlib_compression_options.h"
67 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
68 #include "tensorflow/core/platform/bfloat16.h"
69 #include "tensorflow/core/platform/env.h"
70 #include "tensorflow/core/platform/errors.h"
71 #include "tensorflow/core/platform/file_system.h"
72 #include "tensorflow/core/platform/logging.h"
73 #include "tensorflow/core/platform/status.h"
74 #include "tensorflow/core/platform/test.h"
75 #include "tensorflow/core/platform/threadpool.h"
76 #include "tensorflow/core/platform/tstring.h"
77 #include "tensorflow/core/platform/types.h"
78 #include "tensorflow/core/protobuf/config.pb.h"
79 #include "tensorflow/core/public/session_options.h"
80 #include "tensorflow/core/public/version.h"
81 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
82
83 namespace tensorflow {
84 namespace data {
85
ToString(CompressionType compression_type)86 string ToString(CompressionType compression_type) {
87 switch (compression_type) {
88 case CompressionType::ZLIB:
89 return "ZLIB";
90 case CompressionType::GZIP:
91 return "GZIP";
92 case CompressionType::RAW:
93 return "RAW";
94 case CompressionType::UNCOMPRESSED:
95 return "";
96 }
97 }
98
GetZlibCompressionOptions(CompressionType compression_type)99 io::ZlibCompressionOptions GetZlibCompressionOptions(
100 CompressionType compression_type) {
101 switch (compression_type) {
102 case CompressionType::ZLIB:
103 return io::ZlibCompressionOptions::DEFAULT();
104 case CompressionType::GZIP:
105 return io::ZlibCompressionOptions::GZIP();
106 case CompressionType::RAW:
107 return io::ZlibCompressionOptions::RAW();
108 case CompressionType::UNCOMPRESSED:
109 LOG(WARNING) << "ZlibCompressionOptions does not have an option for "
110 << ToString(compression_type);
111 return io::ZlibCompressionOptions::DEFAULT();
112 }
113 }
114
WriteDataToFile(const string & filename,const char * data)115 Status WriteDataToFile(const string& filename, const char* data) {
116 return WriteDataToFile(filename, data, CompressionParams());
117 }
118
WriteDataToFile(const string & filename,const char * data,const CompressionParams & params)119 Status WriteDataToFile(const string& filename, const char* data,
120 const CompressionParams& params) {
121 Env* env = Env::Default();
122 std::unique_ptr<WritableFile> file_writer;
123 TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
124 if (params.compression_type == CompressionType::UNCOMPRESSED) {
125 TF_RETURN_IF_ERROR(file_writer->Append(data));
126 } else if (params.compression_type == CompressionType::ZLIB ||
127 params.compression_type == CompressionType::GZIP ||
128 params.compression_type == CompressionType::RAW) {
129 auto zlib_compression_options =
130 GetZlibCompressionOptions(params.compression_type);
131 io::ZlibOutputBuffer out(file_writer.get(), params.input_buffer_size,
132 params.output_buffer_size,
133 zlib_compression_options);
134 TF_RETURN_IF_ERROR(out.Init());
135 TF_RETURN_IF_ERROR(out.Append(data));
136 TF_RETURN_IF_ERROR(out.Flush());
137 TF_RETURN_IF_ERROR(out.Close());
138 } else {
139 return tensorflow::errors::InvalidArgument(
140 "Unsupported compression_type: ", ToString(params.compression_type));
141 }
142
143 TF_RETURN_IF_ERROR(file_writer->Flush());
144 TF_RETURN_IF_ERROR(file_writer->Close());
145
146 return OkStatus();
147 }
148
WriteDataToTFRecordFile(const string & filename,const std::vector<absl::string_view> & records,const CompressionParams & params)149 Status WriteDataToTFRecordFile(const string& filename,
150 const std::vector<absl::string_view>& records,
151 const CompressionParams& params) {
152 Env* env = Env::Default();
153 std::unique_ptr<WritableFile> file_writer;
154 TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
155 auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
156 ToString(params.compression_type));
157 options.zlib_options.input_buffer_size = params.input_buffer_size;
158 io::RecordWriter record_writer(file_writer.get(), options);
159 for (const auto& record : records) {
160 TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
161 }
162 TF_RETURN_IF_ERROR(record_writer.Flush());
163 TF_RETURN_IF_ERROR(record_writer.Close());
164 TF_RETURN_IF_ERROR(file_writer->Flush());
165 TF_RETURN_IF_ERROR(file_writer->Close());
166 return OkStatus();
167 }
168
169 template <typename T>
IsEqual(const Tensor & t1,const Tensor & t2)170 Status IsEqual(const Tensor& t1, const Tensor& t2) {
171 if (t1.dtype() != t2.dtype()) {
172 return tensorflow::errors::Internal(
173 "Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
174 " vs. ", DataTypeString(t2.dtype()));
175 }
176 if (!t1.IsSameSize(t2)) {
177 return tensorflow::errors::Internal(
178 "Two tensors have different shapes: ", t1.shape().DebugString(),
179 " vs. ", t2.shape().DebugString());
180 }
181
182 auto flat_t1 = t1.flat<T>();
183 auto flat_t2 = t2.flat<T>();
184 auto length = flat_t1.size();
185
186 for (int i = 0; i < length; ++i) {
187 if (flat_t1(i) != flat_t2(i)) {
188 return tensorflow::errors::Internal(
189 "Two tensors have different values "
190 "at [",
191 i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
192 }
193 }
194 return OkStatus();
195 }
196
// Sets up a single CPU test device at "/job:a/replica:0/task:0" with the
// default CPU and thread counts, and caches the device's allocator for
// allocating input tensors later.
DatasetOpsTestBase::DatasetOpsTestBase()
    : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
      device_type_(DEVICE_CPU),
      cpu_num_(kDefaultCPUNum),
      thread_num_(kDefaultThreadNum) {
  allocator_ = device_->GetAllocator(AllocatorAttributes());
}
204
// Drops the reference taken on the test dataset (see GetDatasetFromContext),
// allowing it to be destroyed.
DatasetOpsTestBase::~DatasetOpsTestBase() {
  if (dataset_) {
    dataset_->Unref();
  }
}
210
// Compares two tensors for exact equality, dispatching to IsEqual<T> based on
// `a`'s dtype. Supports all numeric types plus tstring; any other dtype
// (e.g. variant) yields an Internal error.
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  switch (a.dtype()) {
// Expands to one `case` per supported dtype, each calling IsEqual<T>.
#define CASE(DT)                           \
  case DataTypeToEnum<DT>::value:          \
    TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return OkStatus();
}
226
227 template <typename T>
compare(const Tensor & t1,const Tensor & t2)228 bool compare(const Tensor& t1, const Tensor& t2) {
229 auto flat_t1 = t1.flat<T>();
230 auto flat_t2 = t2.flat<T>();
231 auto length = std::min(flat_t1.size(), flat_t2.size());
232 for (int i = 0; i < length; ++i) {
233 if (flat_t1(i) < flat_t2(i)) return true;
234 if (flat_t1(i) > flat_t2(i)) return false;
235 }
236 return flat_t1.size() < length;
237 }
238
// Compares two tensor vectors element by element. If `compare_order` is
// false, both vectors are first sorted (lexicographically, via compare<T>)
// so that only multiset equality is required. Both vectors are taken by
// value because the unordered path sorts them in place.
Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
                                       std::vector<Tensor> expected_tensors,
                                       bool compare_order) {
  if (produced_tensors.size() != expected_tensors.size()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different size (", produced_tensors.size(),
        " v.s. ", expected_tensors.size(), ")"));
  }

  if (produced_tensors.empty()) return OkStatus();
  // All elements are assumed to share the dtype of element 0; only the first
  // elements' dtypes are checked here.
  if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different dtypes (",
        produced_tensors[0].dtype(), " v.s. ", expected_tensors[0].dtype(),
        ")"));
  }

  if (!compare_order) {
    const DataType& dtype = produced_tensors[0].dtype();
    switch (dtype) {
// Expands to a `case` per dtype that sorts both vectors with compare<T>.
#define CASE(DT)                                                \
  case DT:                                                      \
    std::sort(produced_tensors.begin(), produced_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    std::sort(expected_tensors.begin(), expected_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    break;
      CASE(DT_FLOAT);
      CASE(DT_DOUBLE);
      CASE(DT_INT32);
      CASE(DT_UINT8);
      CASE(DT_INT16);
      CASE(DT_INT8);
      CASE(DT_STRING);
      CASE(DT_INT64);
      CASE(DT_BOOL);
      CASE(DT_QINT8);
      CASE(DT_QUINT8);
      CASE(DT_QINT32);
      CASE(DT_QINT16);
      CASE(DT_QUINT16);
      CASE(DT_UINT16);
      CASE(DT_HALF);
      CASE(DT_UINT32);
      CASE(DT_UINT64);
      // TODO(feihugis): support other dtypes.
#undef CASE
      default:
        return errors::Internal("Unsupported dtype: ", dtype);
    }
  }

  // Pairwise exact comparison (after optional sorting above).
  for (int i = 0; i < produced_tensors.size(); ++i) {
    TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i],
                                                       expected_tensors[i]));
  }
  return OkStatus();
}
297
// Instantiates an OpKernel from `node_def` on the test CPU device. Attribute
// defaults from the op's OpDef are filled in for any attr the NodeDef does
// not set, since NodeDefs built by tests often omit them.
Status DatasetOpsTestBase::CreateOpKernel(
    const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
  OpKernel* kernel;
  Status s;

  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      node_def, flr_->GetFunctionLibraryDefinition(), &props));
  // Apply attribute defaults.
  auto props_with_defaults = std::make_shared<NodeProperties>(*props);
  for (const auto& attr : props->op_def->attr()) {
    if (attr.has_default_value() &&
        !props->node_def.attr().contains(attr.name())) {
      (*props_with_defaults->node_def.mutable_attr())[attr.name()] =
          attr.default_value();
    }
  }
  TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
      device_type_, device_.get(), allocator_, flr_,
      device_->resource_manager(), props_with_defaults, TF_GRAPH_DEF_VERSION,
      &kernel));
  op_kernel->reset(kernel);
  return OkStatus();
}
322
CreateDatasetContext(OpKernel * const dateset_kernel,gtl::InlinedVector<TensorValue,4> * const inputs,std::unique_ptr<OpKernelContext::Params> * dataset_context_params,std::unique_ptr<OpKernelContext> * dataset_context)323 Status DatasetOpsTestBase::CreateDatasetContext(
324 OpKernel* const dateset_kernel,
325 gtl::InlinedVector<TensorValue, 4>* const inputs,
326 std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
327 std::unique_ptr<OpKernelContext>* dataset_context) {
328 Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
329 if (!status.ok()) {
330 VLOG(0) << "WARNING: " << status.ToString();
331 }
332 TF_RETURN_IF_ERROR(CreateOpKernelContext(
333 dateset_kernel, inputs, dataset_context_params, dataset_context));
334 return OkStatus();
335 }
336
CreateDataset(OpKernel * kernel,OpKernelContext * context,DatasetBase ** const dataset)337 Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
338 OpKernelContext* context,
339 DatasetBase** const dataset) {
340 TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
341 // Assume that DatasetOp has only one output.
342 DCHECK_EQ(context->num_outputs(), 1);
343 TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
344 return OkStatus();
345 }
346
// Rebuilds `iterator` from a serialized checkpoint by delegating to
// DatasetBase::MakeIteratorFromCheckpoint.
Status DatasetOpsTestBase::RestoreIterator(
    IteratorContext* ctx, IteratorStateReader* reader,
    const string& output_prefix, const DatasetBase& dataset,
    std::unique_ptr<IteratorBase>* iterator) {
  return dataset.MakeIteratorFromCheckpoint(ctx, output_prefix, reader,
                                            iterator);
}
354
// Builds an IteratorContext layered on top of `op_context`, wiring in the
// test fixture's resource manager, a fresh function handle cache, and the
// fixture's cancellation manager.
Status DatasetOpsTestBase::CreateIteratorContext(
    OpKernelContext* const op_context,
    std::unique_ptr<IteratorContext>* iterator_context) {
  IteratorContext::Params params(op_context);
  params.resource_mgr = op_context->resource_manager();
  // A new cache per context; owned by the fixture so it outlives the context.
  function_handle_cache_ = std::make_unique<FunctionHandleCache>(flr_);
  params.function_handle_cache = function_handle_cache_.get();
  params.cancellation_manager = cancellation_manager_.get();
  *iterator_context = std::make_unique<IteratorContext>(params);
  return OkStatus();
}
366
GetDatasetFromContext(OpKernelContext * context,int output_index,DatasetBase ** const dataset)367 Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
368 int output_index,
369 DatasetBase** const dataset) {
370 Tensor* output = context->mutable_output(output_index);
371 Status status = GetDatasetFromVariantTensor(*output, dataset);
372 (*dataset)->Ref();
373 return status;
374 }
375
InitThreadPool(int thread_num)376 Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
377 if (thread_num < 1) {
378 return errors::InvalidArgument(
379 "The `thread_num` argument should be positive but got: ", thread_num);
380 }
381 thread_pool_ = std::make_unique<thread::ThreadPool>(
382 Env::Default(), ThreadOptions(), "test_thread_pool", thread_num);
383 return OkStatus();
384 }
385
// Stands up the function-execution machinery for the test: `cpu_num` CPU
// devices under /job:localhost, a resource manager, a function library built
// from `flib`, a ProcessFunctionLibraryRuntime, and a runner that either
// executes inline or schedules on the thread pool (if one was initialized).
Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
  resource_mgr_ = std::make_unique<ResourceMgr>("default_container");

  // Register the caller-provided functions in a fresh library definition.
  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  // The rendezvous factory creates an intra-process rendezvous per step.
  pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
      /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return OkStatus();
          }});
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  // Without a thread pool, run closures synchronously on the calling thread.
  if (thread_pool_ == nullptr) {
    runner_ = [](const std::function<void()>& fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return OkStatus();
}
427
// Executes `op_kernel` synchronously on the test device and returns the
// status recorded in `context`.
Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}
433
RunFunction(const FunctionDef & fdef,test::function::Attrs attrs,const std::vector<Tensor> & args,const GraphConstructorOptions & graph_options,std::vector<Tensor * > rets)434 Status DatasetOpsTestBase::RunFunction(
435 const FunctionDef& fdef, test::function::Attrs attrs,
436 const std::vector<Tensor>& args,
437 const GraphConstructorOptions& graph_options, std::vector<Tensor*> rets) {
438 std::unique_ptr<Executor> exec;
439 InstantiationResult result;
440 auto GetOpSig = [](const string& op, const OpDef** sig) {
441 return OpRegistry::Global()->LookUpOpDef(op, sig);
442 };
443 TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, GetOpSig, &result));
444
445 DataTypeVector arg_types = result.arg_types;
446 DataTypeVector ret_types = result.ret_types;
447
448 std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
449 TF_RETURN_IF_ERROR(
450 ConvertNodeDefsToGraph(graph_options, result.nodes, g.get()));
451
452 const int version = g->versions().producer();
453 LocalExecutorParams params;
454 params.function_library = flr_;
455 params.device = device_.get();
456 params.create_kernel = [this, version](
457 const std::shared_ptr<const NodeProperties>& props,
458 OpKernel** kernel) {
459 return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
460 kernel);
461 };
462 params.delete_kernel = [](OpKernel* kernel) {
463 DeleteNonCachedKernel(kernel);
464 };
465
466 Executor* cur_exec;
467 TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
468 exec.reset(cur_exec);
469 FunctionCallFrame frame(arg_types, ret_types);
470 TF_RETURN_IF_ERROR(frame.SetArgs(args));
471 Executor::Args exec_args;
472 exec_args.call_frame = &frame;
473 exec_args.runner = runner_;
474 TF_RETURN_IF_ERROR(exec->Run(exec_args));
475 std::vector<Tensor> computed;
476 TF_RETURN_IF_ERROR(frame.GetRetvals(&computed));
477 if (computed.size() != rets.size()) {
478 return errors::InvalidArgument(
479 "The result does not match the expected number of return outpus",
480 ". Expected: ", rets.size(), ". Actual: ", computed.size());
481 }
482 for (int i = 0; i < rets.size(); ++i) {
483 *(rets[i]) = computed[i];
484 }
485 return OkStatus();
486 }
487
CreateOpKernelContext(OpKernel * kernel,gtl::InlinedVector<TensorValue,4> * inputs,std::unique_ptr<OpKernelContext> * context)488 Status DatasetOpsTestBase::CreateOpKernelContext(
489 OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
490 std::unique_ptr<OpKernelContext>* context) {
491 return CreateOpKernelContext(kernel, inputs, ¶ms_, context);
492 }
493
// Builds an OpKernelContext::Params fully wired to the fixture (device,
// cancellation manager, resource manager, runner, slice-reader cache, step
// container) plus per-output allocator attributes, then constructs the
// OpKernelContext on top of it. The Params object is returned through
// `context_params` because the context keeps a raw pointer to it and the
// caller must keep it alive.
Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext::Params>* context_params,
    std::unique_ptr<OpKernelContext>* context) {
  auto params = std::make_unique<OpKernelContext::Params>();
  cancellation_manager_ = std::make_unique<CancellationManager>();
  params->cancellation_manager = cancellation_manager_.get();
  params->device = device_.get();
  params->frame_iter = FrameAndIter(0, 0);
  params->function_library = flr_;
  params->inputs = *inputs;
  params->op_kernel = kernel;
  params->resource_manager = resource_mgr_.get();
  params->runner = &runner_;
  slice_reader_cache_ =
      std::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
  params->slice_reader_cache = slice_reader_cache_.get();
  step_container_ =
      std::make_unique<ScopedStepContainer>(0, [](const string&) {});
  params->step_container = step_container_.get();

  // Set the allocator attributes for the outputs.
  allocator_attrs_.clear();
  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    allocator_attrs_.emplace_back(attr);
  }
  params->output_attr_array = allocator_attrs_.data();

  // Construct the context before moving the params out; it stores a raw
  // pointer to them.
  *context = std::make_unique<OpKernelContext>(params.get());
  *context_params = std::move(params);
  return OkStatus();
}
530
CreateSerializationContext(std::unique_ptr<SerializationContext> * context)531 Status DatasetOpsTestBase::CreateSerializationContext(
532 std::unique_ptr<SerializationContext>* context) {
533 *context =
534 std::make_unique<SerializationContext>(SerializationContext::Params{});
535 return OkStatus();
536 }
537
CheckOpKernelInput(const OpKernel & kernel,const gtl::InlinedVector<TensorValue,4> & inputs)538 Status DatasetOpsTestBase::CheckOpKernelInput(
539 const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
540 if (kernel.num_inputs() != inputs.size()) {
541 return errors::InvalidArgument("The number of input elements should be ",
542 kernel.num_inputs(),
543 ", but got: ", inputs.size());
544 }
545 return OkStatus();
546 }
547
// Allocates a tensor of (`dtype`, `shape`), validates it against the next
// expected entry of `input_types`, and appends it to `inputs`. Ref-typed
// slots get the fixture's shared mutex attached; the underlying Tensor is
// kept alive in `tensors_`.
Status DatasetOpsTestBase::AddDatasetInput(
    gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
    DataType dtype, const TensorShape& shape) {
  if (input_types.size() < inputs->size()) {
    return errors::InvalidArgument("Adding more inputs than types: ",
                                   inputs->size(), " vs. ", input_types.size());
  }
  bool is_ref = IsRefType(input_types[inputs->size()]);
  auto input = std::make_unique<Tensor>(allocator_, dtype, shape);

  if (is_ref) {
    // For ref types, compare against the underlying (de-ref'd) dtype.
    DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
    if (expected_dtype != dtype) {
      return errors::InvalidArgument("The input data type is ", dtype,
                                     " , but expected: ", expected_dtype);
    }
    inputs->push_back({&lock_for_refs_, input.get()});
  } else {
    if (input_types[inputs->size()] != dtype) {
      return errors::InvalidArgument(
          "The input data type is ", dtype,
          " , but expected: ", input_types[inputs->size()]);
    }
    inputs->push_back({nullptr, input.get()});
  }

  // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
  // collect the inputs.
  tensors_.push_back(std::move(input));

  return OkStatus();
}
580
// Convenience overload using the fixture's own iterator and iterator context.
Status DatasetOpsTestBase::CheckIteratorGetNext(
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  return CheckIteratorGetNext(iterator_.get(), iterator_ctx_.get(),
                              expected_outputs, compare_order);
}
586
// Convenience overload unpacking a TestIterator into its iterator + context.
Status DatasetOpsTestBase::CheckIteratorGetNext(
    TestIterator* iterator, const std::vector<Tensor>& expected_outputs,
    bool compare_order) {
  return CheckIteratorGetNext(iterator->iterator(), iterator->ctx(),
                              expected_outputs, compare_order);
}
593
CheckIteratorGetNext(IteratorBase * iterator,IteratorContext * ctx,const std::vector<Tensor> & expected_outputs,bool compare_order)594 Status DatasetOpsTestBase::CheckIteratorGetNext(
595 IteratorBase* iterator, IteratorContext* ctx,
596 const std::vector<Tensor>& expected_outputs, bool compare_order) {
597 bool end_of_sequence = false;
598 std::vector<Tensor> out_tensors;
599 while (!end_of_sequence) {
600 std::vector<Tensor> next;
601 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &next, &end_of_sequence));
602 out_tensors.insert(out_tensors.end(), next.begin(), next.end());
603 }
604 // Call GetNext one more time to make sure it still reports
605 // end_of_sequence = True.
606 std::vector<Tensor> unused;
607 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &unused, &end_of_sequence));
608 EXPECT_TRUE(end_of_sequence);
609
610 TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
611 /*compare_order=*/compare_order));
612 return OkStatus();
613 }
614
CheckIteratorSkip(int num_to_skip,int expected_num_skipped,bool get_next,const std::vector<Tensor> & expected_outputs,bool compare_order)615 Status DatasetOpsTestBase::CheckIteratorSkip(
616 int num_to_skip, int expected_num_skipped, bool get_next,
617 const std::vector<Tensor>& expected_outputs, bool compare_order) {
618 IteratorBase* iterator = iterator_.get();
619 IteratorContext* ctx = iterator_ctx_.get();
620
621 bool end_of_sequence = false;
622 int num_skipped = 0;
623 TF_RETURN_IF_ERROR(
624 iterator->Skip(ctx, num_to_skip, &end_of_sequence, &num_skipped));
625 EXPECT_TRUE(num_skipped == expected_num_skipped);
626 if (get_next) {
627 EXPECT_TRUE(!end_of_sequence);
628 std::vector<Tensor> out_tensors;
629 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
630 TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
631 /*compare_order=*/compare_order));
632 }
633 return OkStatus();
634 }
635
// Builds the dataset described by `params`, creates its split providers, and
// verifies that iterating with them yields exactly `expected_outputs` in
// order.
Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
    const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  std::unique_ptr<TestIterator> iterator;
  TF_RETURN_IF_ERROR(
      MakeIterator(params, *dataset, std::move(split_providers), &iterator));
  TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
                                          /*compare_order=*/true));
  return OkStatus();
}
649
// Like CheckSplitProviderFullIteration, but wraps each split provider in a
// ShardingSplitProvider (shard `shard_index` of `num_shards`) and exercises
// save/restore at several breakpoints (start, middle, near end, end) while
// checking the produced elements against `expected_outputs`.
Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
    const DatasetParams& params, int64_t num_shards, int64_t shard_index,
    const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  // Shard every provider so the iterator only sees this shard's splits.
  for (int i = 0; i < split_providers.size(); ++i) {
    split_providers[i] = std::make_unique<ShardingSplitProvider>(
        num_shards, shard_index, std::move(split_providers[i]));
  }
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
  // Rebuild the context with the sharded split providers attached.
  IteratorContext::Params iterator_params(iterator_ctx.get());
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(iterator_params.split_providers));
  iterator_ctx = std::make_unique<IteratorContext>(iterator_params);
  int mid_breakpoint = expected_outputs.size() / 2;
  int near_end_breakpoint = expected_outputs.size() - 1;
  int end_breakpoint = expected_outputs.size();
  TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
      dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
      expected_outputs,
      /*breakpoints=*/
      {0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
      /*compare_order=*/true));
  return OkStatus();
}
679
// Expects the fixture dataset's node name to equal
// `expected_dataset_node_name`.
Status DatasetOpsTestBase::CheckDatasetNodeName(
    const string& expected_dataset_node_name) {
  EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
  return OkStatus();
}
685
// Expects the fixture dataset's op type string to equal `expected_type_str`.
Status DatasetOpsTestBase::CheckDatasetTypeString(
    const string& expected_type_str) {
  EXPECT_EQ(dataset_->type_string(), expected_type_str);
  return OkStatus();
}
691
// Expects the fixture dataset's output dtypes to match
// `expected_output_dtypes` element-wise.
Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
  return OkStatus();
}
698
// Expects the fixture dataset's output shapes to be compatible with
// `expected_output_shapes`.
Status DatasetOpsTestBase::CheckDatasetOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
                                      expected_output_shapes));
  return OkStatus();
}
705
// Expects the fixture dataset's cardinality to equal `expected_cardinality`.
Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) {
  EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
  return OkStatus();
}
710
// Expects the fixture dataset's Options proto to equal `expected_options`
// (compared via serialized bytes).
Status DatasetOpsTestBase::CheckDatasetOptions(
    const Options& expected_options) {
  EXPECT_EQ(dataset_->options().SerializeAsString(),
            expected_options.SerializeAsString());
  return OkStatus();
}
717
// Expects the fixture iterator's output dtypes to match
// `expected_output_dtypes` element-wise.
Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
  return OkStatus();
}
724
// Expects the fixture iterator's output shapes to be compatible with
// `expected_output_shapes`.
Status DatasetOpsTestBase::CheckIteratorOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
                                      expected_output_shapes));
  return OkStatus();
}
731
// Expects the fixture iterator's prefix to equal `expected_iterator_prefix`.
Status DatasetOpsTestBase::CheckIteratorPrefix(
    const string& expected_iterator_prefix) {
  EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix);
  return OkStatus();
}
737
// Exercises iterator checkpointing: at each breakpoint, serializes the
// iterator's state, restores a new iterator from that state, then keeps
// pulling elements until the breakpoint count is reached. All elements
// produced across the whole run are compared against `expected_outputs`.
// Breakpoints are cumulative element counts, assumed to be non-decreasing.
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    DatasetBase* dataset, IteratorContext* iterator_ctx,
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
                                           iterator_prefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
  bool end_of_sequence = false;
  int cur_iteration = 0;
  std::vector<Tensor> out_tensors;
  for (int breakpoint : breakpoints) {
    // Save the current state, then immediately restore from it (replacing
    // `iterator`), so progress after this point proves the round trip worked.
    VariantTensorDataWriter writer;
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
                                 *dataset, &iterator));

    // Advance the restored iterator up to this breakpoint, accumulating
    // everything it produces.
    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return OkStatus();
}
772
// Convenience overload that runs the save/restore check against the
// fixture's own dataset and iterator context (populated by Initialize()).
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
                                     iterator_prefix, expected_outputs,
                                     breakpoints, compare_order);
}
781
// One-stop setup used by most tests: initializes the runtime, builds the
// dataset described by `dataset_params`, and creates an iterator over it,
// storing everything in the fixture's member fields. Must be called at most
// once per fixture instance.
Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
  if (initialized_) {
    return errors::Internal(
        "The fields (e.g. dataset_kernel_, dataset_ctx_, dataset_, "
        "iterator_ctx_, iterator_) have already been initialized.");
  }
  TF_RETURN_IF_ERROR(InitializeRuntime(dataset_params));
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
                                 &dataset_ctx_, &tensors_, &dataset_));
  TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
  TF_RETURN_IF_ERROR(
      dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
                             dataset_params.iterator_prefix(), &iterator_));
  initialized_ = true;
  return OkStatus();
}
798
InitializeRuntime(const DatasetParams & dataset_params)799 Status DatasetOpsTestBase::InitializeRuntime(
800 const DatasetParams& dataset_params) {
801 TF_RETURN_IF_ERROR(InitThreadPool(thread_num_));
802 TF_RETURN_IF_ERROR(
803 InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_));
804 return OkStatus();
805 }
806
MakeDataset(const DatasetParams & dataset_params,std::unique_ptr<TestDataset> * dataset)807 Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params,
808 std::unique_ptr<TestDataset>* dataset) {
809 DatasetBase* dataset_base;
810 std::unique_ptr<OpKernel> dataset_kernel;
811 std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
812 std::unique_ptr<OpKernelContext> dataset_ctx;
813 std::vector<std::unique_ptr<Tensor>> created_tensors;
814 TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel,
815 &dataset_ctx_params, &dataset_ctx,
816 &created_tensors, &dataset_base));
817 *dataset = std::make_unique<TestDataset>(
818 std::move(dataset_kernel), std::move(dataset_ctx_params),
819 std::move(dataset_ctx), std::move(created_tensors), dataset_base);
820 return OkStatus();
821 }
822
// Builds and runs the dataset op kernel for `dataset_params`. Input dataset
// tensors and copies of the plain input tensors are created here; ownership
// of all of them is transferred to `created_tensors`, while `inputs` holds
// raw pointers into them for the kernel context. On success, the executed
// context is returned via `dataset_ctx`.
Status DatasetOpsTestBase::RunDatasetOp(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    std::unique_ptr<OpKernelContext>* dataset_ctx) {
  // Materialize each input dataset as a variant tensor (recursively).
  std::vector<Tensor*> input_datasets;
  for (auto& input : dataset_params.input_dataset_params()) {
    std::unique_ptr<Tensor> t;
    TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
    // Keep a raw view for the kernel inputs; `created_tensors` owns `t`.
    input_datasets.push_back(t.get());
    created_tensors->push_back(std::move(t));
  }
  gtl::InlinedVector<TensorValue, 4> inputs;
  for (auto input_dataset : input_datasets) {
    inputs.emplace_back(TensorValue(input_dataset));
  }

  // Copy the input tensors, storing them in the `inputs` vectors, and storing
  // owned references to the copies in `created_tensors`.
  for (auto& input : dataset_params.GetInputTensors()) {
    auto copy = std::make_unique<Tensor>(input);
    inputs.push_back(TensorValue(copy.get()));
    created_tensors->push_back(std::move(copy));
  }

  TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, dataset_kernel));
  TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs,
                                          dataset_ctx_params, dataset_ctx));
  TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get()));
  return OkStatus();
}
855
// Runs the dataset op and extracts the produced DatasetBase from the
// context's single output. All intermediate objects are returned through the
// out-parameters so callers can keep them alive alongside `dataset`.
Status DatasetOpsTestBase::MakeDataset(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::unique_ptr<OpKernelContext>* dataset_ctx,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    DatasetBase** dataset) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, dataset_kernel,
                                  dataset_ctx_params, created_tensors,
                                  dataset_ctx));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ((*dataset_ctx)->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset));
  return OkStatus();
}
871
MakeIterator(const DatasetParams & dataset_params,const TestDataset & dataset,std::vector<std::unique_ptr<SplitProvider>> split_providers,std::unique_ptr<TestIterator> * iterator)872 Status DatasetOpsTestBase::MakeIterator(
873 const DatasetParams& dataset_params, const TestDataset& dataset,
874 std::vector<std::unique_ptr<SplitProvider>> split_providers,
875 std::unique_ptr<TestIterator>* iterator) {
876 std::unique_ptr<IteratorContext> iterator_ctx;
877 TF_RETURN_IF_ERROR(
878 CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
879 IteratorContext::Params iterator_params(iterator_ctx.get());
880 std::move(split_providers.begin(), split_providers.end(),
881 std::back_inserter(iterator_params.split_providers));
882
883 iterator_ctx = std::make_unique<IteratorContext>(iterator_params);
884 std::unique_ptr<IteratorBase> iterator_base;
885 TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
886 iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
887 &iterator_base));
888 *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
889 std::move(iterator_base));
890 return OkStatus();
891 }
892
MakeIterator(const DatasetParams & dataset_params,const TestDataset & dataset,std::unique_ptr<TestIterator> * iterator)893 Status DatasetOpsTestBase::MakeIterator(
894 const DatasetParams& dataset_params, const TestDataset& dataset,
895 std::unique_ptr<TestIterator>* iterator) {
896 return MakeIterator(dataset_params, dataset, /*split_providers=*/{},
897 iterator);
898 }
899
RunDatasetOp(const DatasetParams & dataset_params,std::vector<Tensor> * outputs)900 Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
901 std::vector<Tensor>* outputs) {
902 TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, ¶ms_,
903 &tensors_, &dataset_ctx_));
904 for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) {
905 outputs->emplace_back(*dataset_ctx_->mutable_output(i));
906 }
907 return OkStatus();
908 }
909
MakeDatasetOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel)910 Status DatasetOpsTestBase::MakeDatasetOpKernel(
911 const DatasetParams& dataset_params,
912 std::unique_ptr<OpKernel>* dataset_kernel) {
913 name_utils::OpNameParams params;
914 params.op_version = dataset_params.op_version();
915 std::vector<string> input_names;
916 TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
917 AttributeVector attributes;
918 TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
919 NodeDef node_def =
920 test::function::NDef(dataset_params.node_name(), dataset_params.op_name(),
921 input_names, attributes);
922 TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel));
923 return OkStatus();
924 }
925
MakeGetOptionsOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * op_kernel)926 Status DatasetOpsTestBase::MakeGetOptionsOpKernel(
927 const DatasetParams& dataset_params, std::unique_ptr<OpKernel>* op_kernel) {
928 name_utils::OpNameParams params;
929 params.op_version = dataset_params.op_version();
930 std::vector<string> input_names;
931 TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
932 AttributeVector attributes;
933 TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
934 NodeDef node_def = test::function::NDef(dataset_params.node_name(),
935 dataset_params.dataset_type(),
936 input_names, attributes);
937 TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
938 return OkStatus();
939 }
940
// Recursively builds a scalar DT_VARIANT tensor holding the dataset produced
// by `dataset_params`. Input datasets are materialized first (depth-first),
// and every tensor created along the way is appended to `created_tensors`
// so it outlives the kernel contexts that reference it.
Status DatasetOpsTestBase::MakeDatasetTensor(
    const DatasetParams& dataset_params,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    std::unique_ptr<Tensor>* dataset) {
  // Make sure all the input dataset tensors have been populated.
  std::vector<Tensor*> input_datasets;
  for (auto& input : dataset_params.input_dataset_params()) {
    std::unique_ptr<Tensor> t;
    // Recursive call: each input dataset may itself have dataset inputs.
    TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
    input_datasets.push_back(t.get());
    created_tensors->push_back(std::move(t));
  }

  // NOTE(review): `attributes` appears unused below — MakeDatasetOpKernel
  // fetches its own copy via GetAttributes(); candidate for removal.
  AttributeVector attributes;
  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));

  // Kernel inputs: dataset variants first, then the plain input tensors.
  auto input_tensors = dataset_params.GetInputTensors();
  gtl::InlinedVector<TensorValue, 4> inputs;
  inputs.reserve(input_datasets.size() + input_tensors.size());
  for (auto input_dataset : input_datasets) {
    inputs.emplace_back(TensorValue(input_dataset));
  }
  for (auto& input_tensor : input_tensors) {
    inputs.emplace_back(TensorValue(&input_tensor));
  }

  // Run the op to obtain the DatasetBase, then wrap it in a scalar variant
  // tensor — the form in which datasets flow between ops.
  DatasetBase* dataset_base;
  std::unique_ptr<OpKernel> dataset_kernel;
  std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
  std::unique_ptr<OpKernelContext> dataset_ctx;
  TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, &dataset_kernel));
  TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel.get(), &inputs,
                                          &dataset_ctx_params, &dataset_ctx));
  TF_RETURN_IF_ERROR(
      CreateDataset(dataset_kernel.get(), dataset_ctx.get(), &dataset_base));
  Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(
      StoreDatasetInVariantTensor(dataset_base, &dataset_tensor));
  *dataset = std::make_unique<Tensor>(dataset_tensor);
  return OkStatus();
}
982
// Base-class state shared by all dataset params: declared output dtypes and
// shapes, plus the node name used when building the NodeDef.
DatasetParams::DatasetParams(DataTypeVector output_dtypes,
                             std::vector<PartialTensorShape> output_shapes,
                             string node_name)
    : output_dtypes_(std::move(output_dtypes)),
      output_shapes_(std::move(output_shapes)),
      node_name_(std::move(node_name)) {}
989
IsDatasetTensor(const Tensor & tensor)990 bool DatasetParams::IsDatasetTensor(const Tensor& tensor) {
991 return tensor.dtype() == DT_VARIANT &&
992 TensorShapeUtils::IsScalar(tensor.shape());
993 }
994
// Fully-specified constructor: caller supplies the range bounds/step and the
// declared output dtypes, shapes, and node name.
RangeDatasetParams::RangeDatasetParams(
    int64_t start, int64_t stop, int64_t step, DataTypeVector output_dtypes,
    std::vector<PartialTensorShape> output_shapes, string node_name)
    : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                    std::move(node_name)),
      start_(start),
      stop_(stop),
      step_(step) {}
1003
// Convenience constructor with the defaults used by most tests: int64 scalar
// outputs and the node name "range_dataset".
RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
                                       int64_t step)
    : DatasetParams({DT_INT64}, {PartialTensorShape({})}, "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}
1010
// Convenience constructor that overrides only the output dtypes; output
// shapes default to a scalar and the node name to "range_dataset".
RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
                                       int64_t step,
                                       DataTypeVector output_dtypes)
    : DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
                    "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}
1019
GetInputTensors() const1020 std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
1021 Tensor start_tensor = CreateTensor<int64_t>(TensorShape({}), {start_});
1022 Tensor stop_tensor = CreateTensor<int64_t>(TensorShape({}), {stop_});
1023 Tensor step_tensor = CreateTensor<int64_t>(TensorShape({}), {step_});
1024 return {start_tensor, stop_tensor, step_tensor};
1025 }
1026
GetInputNames(std::vector<string> * input_names) const1027 Status RangeDatasetParams::GetInputNames(
1028 std::vector<string>* input_names) const {
1029 *input_names = {"start", "stop", "step"};
1030 return OkStatus();
1031 }
1032
GetAttributes(AttributeVector * attr_vector) const1033 Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1034 *attr_vector = {{"output_types", output_dtypes_},
1035 {"output_shapes", output_shapes_},
1036 {"replicate_on_split", false},
1037 {"metadata", ""}};
1038 return OkStatus();
1039 }
1040
dataset_type() const1041 string RangeDatasetParams::dataset_type() const { return "Range"; }
1042
GetInputTensors() const1043 std::vector<Tensor> BatchDatasetParams::GetInputTensors() const {
1044 Tensor batch_size = CreateTensor<int64_t>(TensorShape({}), {batch_size_});
1045 Tensor drop_remainder =
1046 CreateTensor<bool>(TensorShape({}), {drop_remainder_});
1047 return {batch_size, drop_remainder};
1048 }
1049
GetInputNames(std::vector<string> * input_names) const1050 Status BatchDatasetParams::GetInputNames(
1051 std::vector<string>* input_names) const {
1052 *input_names = {"input_dataset", "batch_size", "drop_remainder"};
1053 return OkStatus();
1054 }
1055
GetAttributes(AttributeVector * attr_vector) const1056 Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1057 *attr_vector = {{"parallel_copy", parallel_copy_},
1058 {"output_types", output_dtypes_},
1059 {"output_shapes", output_shapes_},
1060 {"metadata", ""}};
1061 return OkStatus();
1062 }
1063
dataset_type() const1064 string BatchDatasetParams::dataset_type() const { return "Batch"; }
1065
GetInputTensors() const1066 std::vector<Tensor> MapDatasetParams::GetInputTensors() const {
1067 return other_arguments_;
1068 }
1069
GetInputNames(std::vector<string> * input_names) const1070 Status MapDatasetParams::GetInputNames(std::vector<string>* input_names) const {
1071 input_names->emplace_back("input_dataset");
1072 for (int i = 0; i < other_arguments_.size(); ++i) {
1073 input_names->emplace_back(absl::StrCat("other_arguments_", i));
1074 }
1075 return OkStatus();
1076 }
1077
GetAttributes(AttributeVector * attr_vector) const1078 Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1079 *attr_vector = {{"f", func_},
1080 {"Targuments", type_arguments_},
1081 {"output_shapes", output_shapes_},
1082 {"output_types", output_dtypes_},
1083 {"use_inter_op_parallelism", use_inter_op_parallelism_},
1084 {"preserve_cardinality", preserve_cardinality_},
1085 {"metadata", ""}};
1086 return OkStatus();
1087 }
1088
dataset_type() const1089 string MapDatasetParams::dataset_type() const { return "Map"; }
1090
// Returns the function definitions supplied for these params; consumed by
// InitializeRuntime() when setting up the function library runtime.
std::vector<FunctionDef> MapDatasetParams::func_lib() const {
  return func_lib_;
}
1094
// Derives the output dtypes and per-element shapes from `components` and
// then stores the components. Base-class initializers run before member
// initializers, so both reads of `components` happen before the move.
TensorSliceDatasetParams::TensorSliceDatasetParams(
    std::vector<Tensor> components, string node_name, bool is_files)
    : DatasetParams(TensorSliceDtypes(components),
                    TensorSliceShapes(components), std::move(node_name)),
      components_(std::move(components)),
      is_files_(is_files) {}
1101
// The component tensors themselves are the op's inputs.
std::vector<Tensor> TensorSliceDatasetParams::GetInputTensors() const {
  return components_;
}
1105
GetInputNames(std::vector<string> * input_names) const1106 Status TensorSliceDatasetParams::GetInputNames(
1107 std::vector<string>* input_names) const {
1108 input_names->reserve(components_.size());
1109 for (int i = 0; i < components_.size(); ++i) {
1110 input_names->emplace_back(absl::StrCat("components_", i));
1111 }
1112 return OkStatus();
1113 }
1114
GetAttributes(AttributeVector * attr_vector) const1115 Status TensorSliceDatasetParams::GetAttributes(
1116 AttributeVector* attr_vector) const {
1117 *attr_vector = {{"Toutput_types", output_dtypes_},
1118 {"output_shapes", output_shapes_},
1119 {"is_files", is_files_},
1120 {"replicate_on_split", false},
1121 {"metadata", ""}};
1122 return OkStatus();
1123 }
1124
TensorSliceDtypes(const std::vector<Tensor> & input_components)1125 DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes(
1126 const std::vector<Tensor>& input_components) {
1127 DataTypeVector dtypes;
1128 for (const auto& component : input_components) {
1129 dtypes.emplace_back(component.dtype());
1130 }
1131 return dtypes;
1132 }
1133
TensorSliceShapes(const std::vector<Tensor> & input_components)1134 std::vector<PartialTensorShape> TensorSliceDatasetParams::TensorSliceShapes(
1135 const std::vector<Tensor>& input_components) {
1136 std::vector<PartialTensorShape> shapes;
1137 for (const auto& component : input_components) {
1138 gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
1139 for (int i = 1; i < component.dims(); ++i) {
1140 partial_dim_sizes.push_back(component.dim_size(i));
1141 }
1142 shapes.emplace_back(std::move(partial_dim_sizes));
1143 }
1144 return shapes;
1145 }
1146
dataset_type() const1147 string TensorSliceDatasetParams::dataset_type() const { return "TensorSlice"; }
1148
GetInputTensors() const1149 std::vector<Tensor> TakeDatasetParams::GetInputTensors() const {
1150 return {CreateTensor<int64_t>(TensorShape({}), {count_})};
1151 }
1152
GetInputNames(std::vector<string> * input_names) const1153 Status TakeDatasetParams::GetInputNames(
1154 std::vector<string>* input_names) const {
1155 *input_names = {"input_dataset", "count"};
1156 return OkStatus();
1157 }
1158
GetAttributes(AttributeVector * attr_vector) const1159 Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1160 *attr_vector = {{"output_shapes", output_shapes_},
1161 {"output_types", output_dtypes_},
1162 {"metadata", ""}};
1163 return OkStatus();
1164 }
1165
dataset_type() const1166 string TakeDatasetParams::dataset_type() const { return "Take"; }
1167
// ConcatenateDataset takes only dataset inputs; there are no plain input
// tensors.
std::vector<Tensor> ConcatenateDatasetParams::GetInputTensors() const {
  return {};
}
1171
GetInputNames(std::vector<string> * input_names) const1172 Status ConcatenateDatasetParams::GetInputNames(
1173 std::vector<string>* input_names) const {
1174 *input_names = {"input_dataset", "another_dataset"};
1175 return OkStatus();
1176 }
1177
GetAttributes(AttributeVector * attr_vector) const1178 Status ConcatenateDatasetParams::GetAttributes(
1179 AttributeVector* attr_vector) const {
1180 *attr_vector = {{"output_types", output_dtypes_},
1181 {"output_shapes", output_shapes_},
1182 {"metadata", ""}};
1183 return OkStatus();
1184 }
1185
dataset_type() const1186 string ConcatenateDatasetParams::dataset_type() const { return "Concatenate"; }
1187
// OptionsDataset takes only its input dataset; there are no plain input
// tensors.
std::vector<Tensor> OptionsDatasetParams::GetInputTensors() const { return {}; }
1189
GetInputNames(std::vector<string> * input_names) const1190 Status OptionsDatasetParams::GetInputNames(
1191 std::vector<string>* input_names) const {
1192 input_names->emplace_back("input_dataset");
1193 return OkStatus();
1194 }
1195
GetAttributes(AttributeVector * attr_vector) const1196 Status OptionsDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1197 *attr_vector = {{"serialized_options", serialized_options_},
1198 {"output_shapes", output_shapes_},
1199 {"output_types", output_dtypes_},
1200 {"metadata", ""}};
1201 return OkStatus();
1202 }
1203
dataset_type() const1204 string OptionsDatasetParams::dataset_type() const { return "Options"; }
1205
1206 } // namespace data
1207 } // namespace tensorflow
1208