• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/dataset_test_base.h"
17 
18 #include <algorithm>
19 #include <complex>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
29 #include "tensorflow/core/common_runtime/device.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/device_mgr.h"
32 #include "tensorflow/core/common_runtime/executor.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
35 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
36 #include "tensorflow/core/data/dataset_utils.h"
37 #include "tensorflow/core/data/name_utils.h"
38 #include "tensorflow/core/data/serialization_utils.h"
39 #include "tensorflow/core/data/split_utils.h"
40 #include "tensorflow/core/framework/allocator.h"
41 #include "tensorflow/core/framework/cancellation.h"
42 #include "tensorflow/core/framework/control_flow.h"
43 #include "tensorflow/core/framework/dataset.h"
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/function.pb.h"
46 #include "tensorflow/core/framework/function_handle_cache.h"
47 #include "tensorflow/core/framework/function_testlib.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/numeric_types.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/framework/op_def.pb.h"
52 #include "tensorflow/core/framework/op_kernel.h"
53 #include "tensorflow/core/framework/register_types.h"
54 #include "tensorflow/core/framework/rendezvous.h"
55 #include "tensorflow/core/framework/resource_mgr.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/framework/tensor_shape.h"
58 #include "tensorflow/core/framework/types.h"
59 #include "tensorflow/core/framework/types.pb.h"
60 #include "tensorflow/core/framework/variant_tensor_data.h"
61 #include "tensorflow/core/framework/versions.pb.h"
62 #include "tensorflow/core/graph/graph.h"
63 #include "tensorflow/core/lib/core/status_test_util.h"
64 #include "tensorflow/core/lib/gtl/inlined_vector.h"
65 #include "tensorflow/core/lib/io/record_writer.h"
66 #include "tensorflow/core/lib/io/zlib_compression_options.h"
67 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
68 #include "tensorflow/core/platform/bfloat16.h"
69 #include "tensorflow/core/platform/env.h"
70 #include "tensorflow/core/platform/errors.h"
71 #include "tensorflow/core/platform/file_system.h"
72 #include "tensorflow/core/platform/logging.h"
73 #include "tensorflow/core/platform/status.h"
74 #include "tensorflow/core/platform/test.h"
75 #include "tensorflow/core/platform/threadpool.h"
76 #include "tensorflow/core/platform/tstring.h"
77 #include "tensorflow/core/platform/types.h"
78 #include "tensorflow/core/protobuf/config.pb.h"
79 #include "tensorflow/core/public/session_options.h"
80 #include "tensorflow/core/public/version.h"
81 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
82 
83 namespace tensorflow {
84 namespace data {
85 
ToString(CompressionType compression_type)86 string ToString(CompressionType compression_type) {
87   switch (compression_type) {
88     case CompressionType::ZLIB:
89       return "ZLIB";
90     case CompressionType::GZIP:
91       return "GZIP";
92     case CompressionType::RAW:
93       return "RAW";
94     case CompressionType::UNCOMPRESSED:
95       return "";
96   }
97 }
98 
GetZlibCompressionOptions(CompressionType compression_type)99 io::ZlibCompressionOptions GetZlibCompressionOptions(
100     CompressionType compression_type) {
101   switch (compression_type) {
102     case CompressionType::ZLIB:
103       return io::ZlibCompressionOptions::DEFAULT();
104     case CompressionType::GZIP:
105       return io::ZlibCompressionOptions::GZIP();
106     case CompressionType::RAW:
107       return io::ZlibCompressionOptions::RAW();
108     case CompressionType::UNCOMPRESSED:
109       LOG(WARNING) << "ZlibCompressionOptions does not have an option for "
110                    << ToString(compression_type);
111       return io::ZlibCompressionOptions::DEFAULT();
112   }
113 }
114 
WriteDataToFile(const string & filename,const char * data)115 Status WriteDataToFile(const string& filename, const char* data) {
116   return WriteDataToFile(filename, data, CompressionParams());
117 }
118 
WriteDataToFile(const string & filename,const char * data,const CompressionParams & params)119 Status WriteDataToFile(const string& filename, const char* data,
120                        const CompressionParams& params) {
121   Env* env = Env::Default();
122   std::unique_ptr<WritableFile> file_writer;
123   TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
124   if (params.compression_type == CompressionType::UNCOMPRESSED) {
125     TF_RETURN_IF_ERROR(file_writer->Append(data));
126   } else if (params.compression_type == CompressionType::ZLIB ||
127              params.compression_type == CompressionType::GZIP ||
128              params.compression_type == CompressionType::RAW) {
129     auto zlib_compression_options =
130         GetZlibCompressionOptions(params.compression_type);
131     io::ZlibOutputBuffer out(file_writer.get(), params.input_buffer_size,
132                              params.output_buffer_size,
133                              zlib_compression_options);
134     TF_RETURN_IF_ERROR(out.Init());
135     TF_RETURN_IF_ERROR(out.Append(data));
136     TF_RETURN_IF_ERROR(out.Flush());
137     TF_RETURN_IF_ERROR(out.Close());
138   } else {
139     return tensorflow::errors::InvalidArgument(
140         "Unsupported compression_type: ", ToString(params.compression_type));
141   }
142 
143   TF_RETURN_IF_ERROR(file_writer->Flush());
144   TF_RETURN_IF_ERROR(file_writer->Close());
145 
146   return Status::OK();
147 }
148 
WriteDataToTFRecordFile(const string & filename,const std::vector<absl::string_view> & records,const CompressionParams & params)149 Status WriteDataToTFRecordFile(const string& filename,
150                                const std::vector<absl::string_view>& records,
151                                const CompressionParams& params) {
152   Env* env = Env::Default();
153   std::unique_ptr<WritableFile> file_writer;
154   TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
155   auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
156       ToString(params.compression_type));
157   options.zlib_options.input_buffer_size = params.input_buffer_size;
158   io::RecordWriter record_writer(file_writer.get(), options);
159   for (const auto& record : records) {
160     TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
161   }
162   TF_RETURN_IF_ERROR(record_writer.Flush());
163   TF_RETURN_IF_ERROR(record_writer.Close());
164   TF_RETURN_IF_ERROR(file_writer->Flush());
165   TF_RETURN_IF_ERROR(file_writer->Close());
166   return Status::OK();
167 }
168 
169 template <typename T>
IsEqual(const Tensor & t1,const Tensor & t2)170 Status IsEqual(const Tensor& t1, const Tensor& t2) {
171   if (t1.dtype() != t2.dtype()) {
172     return tensorflow::errors::Internal(
173         "Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
174         " vs. ", DataTypeString(t2.dtype()));
175   }
176   if (!t1.IsSameSize(t2)) {
177     return tensorflow::errors::Internal(
178         "Two tensors have different shapes: ", t1.shape().DebugString(),
179         " vs. ", t2.shape().DebugString());
180   }
181 
182   auto flat_t1 = t1.flat<T>();
183   auto flat_t2 = t2.flat<T>();
184   auto length = flat_t1.size();
185 
186   for (int i = 0; i < length; ++i) {
187     if (flat_t1(i) != flat_t2(i)) {
188       return tensorflow::errors::Internal(
189           "Two tensors have different values "
190           "at [",
191           i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
192     }
193   }
194   return Status::OK();
195 }
196 
// Sets up a single-CPU test fixture: creates the default CPU device, records
// the device type, and seeds the CPU/thread counts used later by
// InitThreadPool/InitFunctionLibraryRuntime.
DatasetOpsTestBase::DatasetOpsTestBase()
    : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
      device_type_(DEVICE_CPU),
      cpu_num_(kDefaultCPUNum),
      thread_num_(kDefaultThreadNum) {
  // Allocate test tensors with the device's allocator so they live in the
  // same memory space the kernels under test use.
  allocator_ = device_->GetAllocator(AllocatorAttributes());
}
204 
DatasetOpsTestBase::~DatasetOpsTestBase() {
  // Release the reference taken in GetDatasetFromContext (via Ref()) so the
  // dataset is destroyed together with the fixture.
  if (dataset_) {
    dataset_->Unref();
  }
}
210 
// Compares two tensors element-wise by dispatching on dtype to the typed
// IsEqual<T> helper. Handles all standard numeric types plus tstring;
// returns an Internal error for anything else (e.g. variant).
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  switch (a.dtype()) {
// Expands into one `case` per dtype that forwards to IsEqual<T>.
#define CASE(DT)                           \
  case DataTypeToEnum<DT>::value:          \
    TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return Status::OK();
}
226 
227 template <typename T>
compare(const Tensor & t1,const Tensor & t2)228 bool compare(const Tensor& t1, const Tensor& t2) {
229   auto flat_t1 = t1.flat<T>();
230   auto flat_t2 = t2.flat<T>();
231   auto length = std::min(flat_t1.size(), flat_t2.size());
232   for (int i = 0; i < length; ++i) {
233     if (flat_t1(i) < flat_t2(i)) return true;
234     if (flat_t1(i) > flat_t2(i)) return false;
235   }
236   return flat_t1.size() < length;
237 }
238 
// Compares two vectors of tensors. When `compare_order` is false, both
// vectors are first sorted with the same lexicographic comparator so the
// check is order-insensitive. Assumes every tensor in a vector shares the
// dtype of element 0.
Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
                                       std::vector<Tensor> expected_tensors,
                                       bool compare_order) {
  if (produced_tensors.size() != expected_tensors.size()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different size (", produced_tensors.size(),
        " v.s. ", expected_tensors.size(), ")"));
  }

  if (produced_tensors.empty()) return Status::OK();
  if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different dtypes (",
        produced_tensors[0].dtype(), " v.s. ", expected_tensors[0].dtype(),
        ")"));
  }

  if (!compare_order) {
    // Sort both sides with the same typed comparator so matching multisets of
    // tensors line up index-by-index for the element-wise check below.
    const DataType& dtype = produced_tensors[0].dtype();
    switch (dtype) {
// Expands into one `case` per dtype that sorts both vectors with compare<T>.
#define CASE(DT)                                                \
  case DT:                                                      \
    std::sort(produced_tensors.begin(), produced_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    std::sort(expected_tensors.begin(), expected_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    break;
      CASE(DT_FLOAT);
      CASE(DT_DOUBLE);
      CASE(DT_INT32);
      CASE(DT_UINT8);
      CASE(DT_INT16);
      CASE(DT_INT8);
      CASE(DT_STRING);
      CASE(DT_INT64);
      CASE(DT_BOOL);
      CASE(DT_QINT8);
      CASE(DT_QUINT8);
      CASE(DT_QINT32);
      CASE(DT_QINT16);
      CASE(DT_QUINT16);
      CASE(DT_UINT16);
      CASE(DT_HALF);
      CASE(DT_UINT32);
      CASE(DT_UINT64);
      // TODO(feihugis): support other dtypes.
#undef CASE
      default:
        return errors::Internal("Unsupported dtype: ", dtype);
    }
  }

  for (int i = 0; i < produced_tensors.size(); ++i) {
    TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i],
                                                       expected_tensors[i]));
  }
  return Status::OK();
}
297 
CreateOpKernel(const NodeDef & node_def,std::unique_ptr<OpKernel> * op_kernel)298 Status DatasetOpsTestBase::CreateOpKernel(
299     const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
300   OpKernel* kernel;
301   Status s;
302 
303   std::shared_ptr<const NodeProperties> props;
304   TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
305       node_def, flr_->GetFunctionLibraryDefinition(), &props));
306   TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
307       device_type_, device_.get(), allocator_, flr_,
308       device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel));
309   op_kernel->reset(kernel);
310   return Status::OK();
311 }
312 
CreateDatasetContext(OpKernel * const dateset_kernel,gtl::InlinedVector<TensorValue,4> * const inputs,std::unique_ptr<OpKernelContext::Params> * dataset_context_params,std::unique_ptr<OpKernelContext> * dataset_context)313 Status DatasetOpsTestBase::CreateDatasetContext(
314     OpKernel* const dateset_kernel,
315     gtl::InlinedVector<TensorValue, 4>* const inputs,
316     std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
317     std::unique_ptr<OpKernelContext>* dataset_context) {
318   Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
319   if (!status.ok()) {
320     VLOG(0) << "WARNING: " << status.ToString();
321   }
322   TF_RETURN_IF_ERROR(CreateOpKernelContext(
323       dateset_kernel, inputs, dataset_context_params, dataset_context));
324   return Status::OK();
325 }
326 
// Runs `kernel` (a dataset op) in `context` and extracts the resulting
// DatasetBase from its single output. GetDatasetFromContext Ref()s the
// dataset, so the caller owns one reference.
Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
                                         OpKernelContext* context,
                                         DatasetBase** const dataset) {
  TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ(context->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
  return Status::OK();
}
336 
// Rebuilds an iterator for `dataset` from a serialized checkpoint: delegates
// to DatasetBase::MakeIteratorFromCheckpoint, which creates the iterator
// under `output_prefix` and restores its state from `reader`.
Status DatasetOpsTestBase::RestoreIterator(
    IteratorContext* ctx, IteratorStateReader* reader,
    const string& output_prefix, const DatasetBase& dataset,
    std::unique_ptr<IteratorBase>* iterator) {
  return dataset.MakeIteratorFromCheckpoint(ctx, output_prefix, reader,
                                            iterator);
}
344 
// Builds an IteratorContext layered on `op_context`. Note: this re-creates
// the fixture's FunctionHandleCache, so handles cached by a previously
// created iterator context are discarded.
Status DatasetOpsTestBase::CreateIteratorContext(
    OpKernelContext* const op_context,
    std::unique_ptr<IteratorContext>* iterator_context) {
  IteratorContext::Params params(op_context);
  params.resource_mgr = op_context->resource_manager();
  function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
  params.function_handle_cache = function_handle_cache_.get();
  params.cancellation_manager = cancellation_manager_.get();
  *iterator_context = absl::make_unique<IteratorContext>(params);
  return Status::OK();
}
356 
GetDatasetFromContext(OpKernelContext * context,int output_index,DatasetBase ** const dataset)357 Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
358                                                  int output_index,
359                                                  DatasetBase** const dataset) {
360   Tensor* output = context->mutable_output(output_index);
361   Status status = GetDatasetFromVariantTensor(*output, dataset);
362   (*dataset)->Ref();
363   return status;
364 }
365 
InitThreadPool(int thread_num)366 Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
367   if (thread_num < 1) {
368     return errors::InvalidArgument(
369         "The `thread_num` argument should be positive but got: ", thread_num);
370   }
371   thread_pool_ = absl::make_unique<thread::ThreadPool>(
372       Env::Default(), ThreadOptions(), "test_thread_pool", thread_num);
373   return Status::OK();
374 }
375 
// Creates the CPU devices, resource manager, function library, and process
// function library runtime the fixture needs to run kernels. Must be called
// before CreateOpKernel / CreateOpKernelContext, which read flr_, device
// state, and runner_.
Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");

  // Register the test-provided functions so dataset ops can call them.
  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
      /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return Status::OK();
          }});
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  // If InitThreadPool was never called, fall back to running closures inline
  // on the caller's thread.
  if (thread_pool_ == nullptr) {
    runner_ = [](const std::function<void()>& fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return Status::OK();
}
417 
// Executes `op_kernel` synchronously on the fixture's device. Compute() does
// not return a status; the kernel records it in `context`, so that is what
// gets returned here.
Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}
423 
RunFunction(const FunctionDef & fdef,test::function::Attrs attrs,const std::vector<Tensor> & args,const GraphConstructorOptions & graph_options,std::vector<Tensor * > rets)424 Status DatasetOpsTestBase::RunFunction(
425     const FunctionDef& fdef, test::function::Attrs attrs,
426     const std::vector<Tensor>& args,
427     const GraphConstructorOptions& graph_options, std::vector<Tensor*> rets) {
428   std::unique_ptr<Executor> exec;
429   InstantiationResult result;
430   auto GetOpSig = [](const string& op, const OpDef** sig) {
431     return OpRegistry::Global()->LookUpOpDef(op, sig);
432   };
433   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, GetOpSig, &result));
434 
435   DataTypeVector arg_types = result.arg_types;
436   DataTypeVector ret_types = result.ret_types;
437 
438   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
439   TF_RETURN_IF_ERROR(
440       ConvertNodeDefsToGraph(graph_options, result.nodes, g.get()));
441 
442   const int version = g->versions().producer();
443   LocalExecutorParams params;
444   params.function_library = flr_;
445   params.device = device_.get();
446   params.create_kernel = [this, version](
447                              const std::shared_ptr<const NodeProperties>& props,
448                              OpKernel** kernel) {
449     return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
450                                  kernel);
451   };
452   params.delete_kernel = [](OpKernel* kernel) {
453     DeleteNonCachedKernel(kernel);
454   };
455 
456   Executor* cur_exec;
457   TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
458   exec.reset(cur_exec);
459   FunctionCallFrame frame(arg_types, ret_types);
460   TF_RETURN_IF_ERROR(frame.SetArgs(args));
461   Executor::Args exec_args;
462   exec_args.call_frame = &frame;
463   exec_args.runner = runner_;
464   TF_RETURN_IF_ERROR(exec->Run(exec_args));
465   std::vector<Tensor> computed;
466   TF_RETURN_IF_ERROR(frame.GetRetvals(&computed));
467   if (computed.size() != rets.size()) {
468     return errors::InvalidArgument(
469         "The result does not match the expected number of return outpus",
470         ". Expected: ", rets.size(), ". Actual: ", computed.size());
471   }
472   for (int i = 0; i < rets.size(); ++i) {
473     *(rets[i]) = computed[i];
474   }
475   return Status::OK();
476 }
477 
// Convenience overload that stores the OpKernelContext::Params in the
// fixture's params_ member, keeping them alive as long as the context that
// points at them.
Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext>* context) {
  return CreateOpKernelContext(kernel, inputs, &params_, context);
}
483 
// Wires up an OpKernelContext for `kernel` with `inputs`. The params object
// is returned via `context_params` because OpKernelContext keeps a raw
// pointer to it — the caller must keep it alive as long as `context`.
// Several helper objects (cancellation manager, slice reader cache, step
// container) are stored on the fixture for the same lifetime reason.
Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext::Params>* context_params,
    std::unique_ptr<OpKernelContext>* context) {
  auto params = absl::make_unique<OpKernelContext::Params>();
  cancellation_manager_ = absl::make_unique<CancellationManager>();
  params->cancellation_manager = cancellation_manager_.get();
  params->device = device_.get();
  params->frame_iter = FrameAndIter(0, 0);
  params->function_library = flr_;
  params->inputs = inputs;
  params->op_kernel = kernel;
  params->resource_manager = resource_mgr_.get();
  params->runner = &runner_;
  slice_reader_cache_ =
      absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
  params->slice_reader_cache = slice_reader_cache_.get();
  step_container_ =
      absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
  params->step_container = step_container_.get();

  // Set the allocator attributes for the outputs.
  allocator_attrs_.clear();
  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    allocator_attrs_.emplace_back(attr);
  }
  params->output_attr_array = allocator_attrs_.data();

  // Construct the context before releasing ownership of params so the raw
  // pointer it stores stays valid.
  *context = absl::make_unique<OpKernelContext>(params.get());
  *context_params = std::move(params);
  return Status::OK();
}
520 
CreateSerializationContext(std::unique_ptr<SerializationContext> * context)521 Status DatasetOpsTestBase::CreateSerializationContext(
522     std::unique_ptr<SerializationContext>* context) {
523   *context =
524       absl::make_unique<SerializationContext>(SerializationContext::Params{});
525   return Status::OK();
526 }
527 
CheckOpKernelInput(const OpKernel & kernel,const gtl::InlinedVector<TensorValue,4> & inputs)528 Status DatasetOpsTestBase::CheckOpKernelInput(
529     const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
530   if (kernel.num_inputs() != inputs.size()) {
531     return errors::InvalidArgument("The number of input elements should be ",
532                                    kernel.num_inputs(),
533                                    ", but got: ", inputs.size());
534   }
535   return Status::OK();
536 }
537 
AddDatasetInput(gtl::InlinedVector<TensorValue,4> * inputs,DataTypeVector input_types,DataType dtype,const TensorShape & shape)538 Status DatasetOpsTestBase::AddDatasetInput(
539     gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
540     DataType dtype, const TensorShape& shape) {
541   if (input_types.size() < inputs->size()) {
542     return errors::InvalidArgument("Adding more inputs than types: ",
543                                    inputs->size(), " vs. ", input_types.size());
544   }
545   bool is_ref = IsRefType(input_types[inputs->size()]);
546   auto input = absl::make_unique<Tensor>(allocator_, dtype, shape);
547 
548   if (is_ref) {
549     DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
550     if (expected_dtype != dtype) {
551       return errors::InvalidArgument("The input data type is ", dtype,
552                                      " , but expected: ", expected_dtype);
553     }
554     inputs->push_back({&lock_for_refs_, input.get()});
555   } else {
556     if (input_types[inputs->size()] != dtype) {
557       return errors::InvalidArgument(
558           "The input data type is ", dtype,
559           " , but expected: ", input_types[inputs->size()]);
560     }
561     inputs->push_back({nullptr, input.get()});
562   }
563 
564   // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
565   // collect the inputs.
566   tensors_.push_back(std::move(input));
567 
568   return Status::OK();
569 }
570 
// Checks the fixture's default iterator/context against `expected_outputs`.
Status DatasetOpsTestBase::CheckIteratorGetNext(
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  return CheckIteratorGetNext(iterator_.get(), iterator_ctx_.get(),
                              expected_outputs, compare_order);
}
576 
// Convenience overload: unpacks the iterator and context from a TestIterator.
Status DatasetOpsTestBase::CheckIteratorGetNext(
    TestIterator* iterator, const std::vector<Tensor>& expected_outputs,
    bool compare_order) {
  return CheckIteratorGetNext(iterator->iterator(), iterator->ctx(),
                              expected_outputs, compare_order);
}
583 
CheckIteratorGetNext(IteratorBase * iterator,IteratorContext * ctx,const std::vector<Tensor> & expected_outputs,bool compare_order)584 Status DatasetOpsTestBase::CheckIteratorGetNext(
585     IteratorBase* iterator, IteratorContext* ctx,
586     const std::vector<Tensor>& expected_outputs, bool compare_order) {
587   bool end_of_sequence = false;
588   std::vector<Tensor> out_tensors;
589   while (!end_of_sequence) {
590     std::vector<Tensor> next;
591     TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &next, &end_of_sequence));
592     out_tensors.insert(out_tensors.end(), next.begin(), next.end());
593   }
594   // Call GetNext one more time to make sure it still reports
595   // end_of_sequence = True.
596   std::vector<Tensor> unused;
597   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &unused, &end_of_sequence));
598   EXPECT_TRUE(end_of_sequence);
599 
600   TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
601                            /*compare_order=*/compare_order));
602   return Status::OK();
603 }
604 
// Skips `num_to_skip` elements on the fixture's default iterator and checks
// that exactly `expected_num_skipped` were skipped. If `get_next` is true,
// additionally pulls one more batch and compares it to `expected_outputs`
// (implying the iterator must not be exhausted after the skip).
Status DatasetOpsTestBase::CheckIteratorSkip(
    int num_to_skip, int expected_num_skipped, bool get_next,
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  IteratorBase* iterator = iterator_.get();
  IteratorContext* ctx = iterator_ctx_.get();

  bool end_of_sequence = false;
  int num_skipped = 0;
  TF_RETURN_IF_ERROR(
      iterator->Skip(ctx, num_to_skip, &end_of_sequence, &num_skipped));
  EXPECT_TRUE(num_skipped == expected_num_skipped);
  if (get_next) {
    EXPECT_TRUE(!end_of_sequence);
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
    TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                             /*compare_order=*/compare_order));
  }
  return Status::OK();
}
625 
// Builds the dataset described by `params`, creates its split providers, and
// verifies that an iterator driven by those providers yields exactly
// `expected_outputs` in order.
Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
    const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  std::unique_ptr<TestIterator> iterator;
  TF_RETURN_IF_ERROR(
      MakeIterator(params, *dataset, std::move(split_providers), &iterator));
  TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
                                          /*compare_order=*/true));
  return Status::OK();
}
639 
// Like CheckSplitProviderFullIteration, but wraps each split provider in a
// ShardingSplitProvider (shard `shard_index` of `num_shards`) and exercises
// save/restore at several breakpoints (start, middle, near end, end) while
// checking the sharded output sequence.
Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
    const DatasetParams& params, int64_t num_shards, int64_t shard_index,
    const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  for (int i = 0; i < split_providers.size(); ++i) {
    split_providers[i] = std::make_unique<ShardingSplitProvider>(
        num_shards, shard_index, std::move(split_providers[i]));
  }
  // Rebuild the iterator context so the sharded split providers are attached.
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
  IteratorContext::Params iterator_params(iterator_ctx.get());
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(iterator_params.split_providers));
  iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
  int mid_breakpoint = expected_outputs.size() / 2;
  int near_end_breakpoint = expected_outputs.size() - 1;
  int end_breakpoint = expected_outputs.size();
  TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
      dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
      expected_outputs,
      /*breakpoints=*/
      {0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
      /*compare_order=*/true));
  return Status::OK();
}
669 
// Asserts (via EXPECT_EQ) that the dataset's node name matches; always
// returns OK — failures surface as gtest expectation failures.
Status DatasetOpsTestBase::CheckDatasetNodeName(
    const string& expected_dataset_node_name) {
  EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
  return Status::OK();
}
675 
// Asserts that the dataset's type string matches; always returns OK.
Status DatasetOpsTestBase::CheckDatasetTypeString(
    const string& expected_type_str) {
  EXPECT_EQ(dataset_->type_string(), expected_type_str);
  return Status::OK();
}
681 
CheckDatasetOutputDtypes(const DataTypeVector & expected_output_dtypes)682 Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
683     const DataTypeVector& expected_output_dtypes) {
684   TF_EXPECT_OK(
685       VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
686   return Status::OK();
687 }
688 
CheckDatasetOutputShapes(const std::vector<PartialTensorShape> & expected_output_shapes)689 Status DatasetOpsTestBase::CheckDatasetOutputShapes(
690     const std::vector<PartialTensorShape>& expected_output_shapes) {
691   TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
692                                       expected_output_shapes));
693   return Status::OK();
694 }
695 
CheckDatasetCardinality(int expected_cardinality)696 Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) {
697   EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
698   return Status::OK();
699 }
700 
CheckDatasetOptions(const Options & expected_options)701 Status DatasetOpsTestBase::CheckDatasetOptions(
702     const Options& expected_options) {
703   EXPECT_EQ(dataset_->options().SerializeAsString(),
704             expected_options.SerializeAsString());
705   return Status::OK();
706 }
707 
CheckIteratorOutputDtypes(const DataTypeVector & expected_output_dtypes)708 Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
709     const DataTypeVector& expected_output_dtypes) {
710   TF_EXPECT_OK(
711       VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
712   return Status::OK();
713 }
714 
CheckIteratorOutputShapes(const std::vector<PartialTensorShape> & expected_output_shapes)715 Status DatasetOpsTestBase::CheckIteratorOutputShapes(
716     const std::vector<PartialTensorShape>& expected_output_shapes) {
717   TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
718                                       expected_output_shapes));
719   return Status::OK();
720 }
721 
CheckIteratorPrefix(const string & expected_iterator_prefix)722 Status DatasetOpsTestBase::CheckIteratorPrefix(
723     const string& expected_iterator_prefix) {
724   EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix);
725   return Status::OK();
726 }
727 
// Verifies that iterating `dataset` yields `expected_outputs` while the
// iterator's state is serialized and restored at every index listed in
// `breakpoints`. Each round serializes the current iterator, restores a
// fresh iterator from the serialized state, then advances it to the next
// breakpoint — so the restore path is exercised mid-stream as well as at
// the ends. When `compare_order` is false, outputs are compared as a
// multiset rather than a sequence.
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    DatasetBase* dataset, IteratorContext* iterator_ctx,
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
                                           iterator_prefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
  bool end_of_sequence = false;
  int cur_iteration = 0;
  std::vector<Tensor> out_tensors;
  for (int breakpoint : breakpoints) {
    // Save the current iterator state and immediately replace `iterator`
    // with one restored from the serialized data.
    VariantTensorDataWriter writer;
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
                                 *dataset, &iterator));

    // Advance the restored iterator up to the breakpoint, accumulating
    // every element it produces.
    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return Status::OK();
}
762 
// Convenience overload: runs the save/restore check against the fixture's
// own dataset (`dataset_`) and iterator context (`iterator_ctx_`), which
// must have been set up by a prior call to `Initialize`.
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
                                     iterator_prefix, expected_outputs,
                                     breakpoints, compare_order);
}
771 
// One-shot fixture setup: initializes the runtime, builds the dataset
// kernel/context and the dataset itself, then creates an iterator over it.
// The steps are order-dependent (runtime before kernels, dataset before
// iterator). Returns an Internal error if called more than once.
Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
  if (initialized_) {
    return errors::Internal(
        "The fields (e.g. dataset_kernel_, dataset_ctx_, dataset_, "
        "iterator_ctx_, iterator_) have already been initialized.");
  }
  TF_RETURN_IF_ERROR(InitializeRuntime(dataset_params));
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
                                 &dataset_ctx_, &tensors_, &dataset_));
  TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
  TF_RETURN_IF_ERROR(
      dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
                             dataset_params.iterator_prefix(), &iterator_));
  initialized_ = true;
  return Status::OK();
}

// Sets up the thread pool and the function library runtime (with the
// functions declared by `dataset_params`) used by the kernels under test.
Status DatasetOpsTestBase::InitializeRuntime(
    const DatasetParams& dataset_params) {
  TF_RETURN_IF_ERROR(InitThreadPool(thread_num_));
  TF_RETURN_IF_ERROR(
      InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_));
  return Status::OK();
}
796 
MakeDataset(const DatasetParams & dataset_params,std::unique_ptr<TestDataset> * dataset)797 Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params,
798                                        std::unique_ptr<TestDataset>* dataset) {
799   DatasetBase* dataset_base;
800   std::unique_ptr<OpKernel> dataset_kernel;
801   std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
802   std::unique_ptr<OpKernelContext> dataset_ctx;
803   std::vector<std::unique_ptr<Tensor>> created_tensors;
804   TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel,
805                                  &dataset_ctx_params, &dataset_ctx,
806                                  &created_tensors, &dataset_base));
807   *dataset = std::make_unique<TestDataset>(
808       std::move(dataset_kernel), std::move(dataset_ctx_params),
809       std::move(dataset_ctx), std::move(created_tensors), dataset_base);
810   return Status::OK();
811 }
812 
RunDatasetOp(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel,std::unique_ptr<OpKernelContext::Params> * dataset_ctx_params,std::vector<std::unique_ptr<Tensor>> * created_tensors,std::unique_ptr<OpKernelContext> * dataset_ctx)813 Status DatasetOpsTestBase::RunDatasetOp(
814     const DatasetParams& dataset_params,
815     std::unique_ptr<OpKernel>* dataset_kernel,
816     std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
817     std::vector<std::unique_ptr<Tensor>>* created_tensors,
818     std::unique_ptr<OpKernelContext>* dataset_ctx) {
819   std::vector<Tensor*> input_datasets;
820   for (auto& input : dataset_params.input_dataset_params()) {
821     std::unique_ptr<Tensor> t;
822     TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
823     input_datasets.push_back(t.get());
824     created_tensors->push_back(std::move(t));
825   }
826   gtl::InlinedVector<TensorValue, 4> inputs;
827   for (auto input_dataset : input_datasets) {
828     inputs.emplace_back(TensorValue(input_dataset));
829   }
830 
831   // Copy the input tensors, storing them in the `inputs` vectors, and storing
832   // owned references to the copies in `created_tensors`.
833   for (auto& input : dataset_params.GetInputTensors()) {
834     auto copy = absl::make_unique<Tensor>(input);
835     inputs.push_back(TensorValue(copy.get()));
836     created_tensors->push_back(std::move(copy));
837   }
838 
839   TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, dataset_kernel));
840   TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs,
841                                           dataset_ctx_params, dataset_ctx));
842   TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get()));
843   return Status::OK();
844 }
845 
// Runs the dataset op and extracts the resulting `DatasetBase` from the
// executed kernel context. The out-params (kernel, context params, context,
// created tensors) keep the dataset's backing state alive; callers must
// retain them for as long as `*dataset` is used.
Status DatasetOpsTestBase::MakeDataset(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::unique_ptr<OpKernelContext>* dataset_ctx,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    DatasetBase** dataset) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, dataset_kernel,
                                  dataset_ctx_params, created_tensors,
                                  dataset_ctx));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ((*dataset_ctx)->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset));
  return Status::OK();
}
861 
MakeIterator(const DatasetParams & dataset_params,const TestDataset & dataset,std::vector<std::unique_ptr<SplitProvider>> split_providers,std::unique_ptr<TestIterator> * iterator)862 Status DatasetOpsTestBase::MakeIterator(
863     const DatasetParams& dataset_params, const TestDataset& dataset,
864     std::vector<std::unique_ptr<SplitProvider>> split_providers,
865     std::unique_ptr<TestIterator>* iterator) {
866   std::unique_ptr<IteratorContext> iterator_ctx;
867   TF_RETURN_IF_ERROR(
868       CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
869   IteratorContext::Params iterator_params(iterator_ctx.get());
870   std::move(split_providers.begin(), split_providers.end(),
871             std::back_inserter(iterator_params.split_providers));
872 
873   iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
874   std::unique_ptr<IteratorBase> iterator_base;
875   TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
876       iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
877       &iterator_base));
878   *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
879                                              std::move(iterator_base));
880   return Status::OK();
881 }
882 
MakeIterator(const DatasetParams & dataset_params,const TestDataset & dataset,std::unique_ptr<TestIterator> * iterator)883 Status DatasetOpsTestBase::MakeIterator(
884     const DatasetParams& dataset_params, const TestDataset& dataset,
885     std::unique_ptr<TestIterator>* iterator) {
886   return MakeIterator(dataset_params, dataset, /*split_providers=*/{},
887                       iterator);
888 }
889 
RunDatasetOp(const DatasetParams & dataset_params,std::vector<Tensor> * outputs)890 Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
891                                         std::vector<Tensor>* outputs) {
892   TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, &params_,
893                                   &tensors_, &dataset_ctx_));
894   for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) {
895     outputs->emplace_back(*dataset_ctx_->mutable_output(i));
896   }
897   return Status::OK();
898 }
899 
MakeDatasetOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel)900 Status DatasetOpsTestBase::MakeDatasetOpKernel(
901     const DatasetParams& dataset_params,
902     std::unique_ptr<OpKernel>* dataset_kernel) {
903   name_utils::OpNameParams params;
904   params.op_version = dataset_params.op_version();
905   std::vector<string> input_names;
906   TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
907   AttributeVector attributes;
908   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
909   NodeDef node_def =
910       test::function::NDef(dataset_params.node_name(), dataset_params.op_name(),
911                            input_names, attributes);
912   TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel));
913   return Status::OK();
914 }
915 
MakeGetOptionsOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * op_kernel)916 Status DatasetOpsTestBase::MakeGetOptionsOpKernel(
917     const DatasetParams& dataset_params, std::unique_ptr<OpKernel>* op_kernel) {
918   name_utils::OpNameParams params;
919   params.op_version = dataset_params.op_version();
920   std::vector<string> input_names;
921   TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
922   AttributeVector attributes;
923   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
924   NodeDef node_def = test::function::NDef(dataset_params.node_name(),
925                                           dataset_params.dataset_type(),
926                                           input_names, attributes);
927   TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
928   return Status::OK();
929 }
930 
MakeDatasetTensor(const DatasetParams & dataset_params,std::vector<std::unique_ptr<Tensor>> * created_tensors,std::unique_ptr<Tensor> * dataset)931 Status DatasetOpsTestBase::MakeDatasetTensor(
932     const DatasetParams& dataset_params,
933     std::vector<std::unique_ptr<Tensor>>* created_tensors,
934     std::unique_ptr<Tensor>* dataset) {
935   // Make sure all the input dataset tensors have been populated.
936   std::vector<Tensor*> input_datasets;
937   for (auto& input : dataset_params.input_dataset_params()) {
938     std::unique_ptr<Tensor> t;
939     TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
940     input_datasets.push_back(t.get());
941     created_tensors->push_back(std::move(t));
942   }
943 
944   AttributeVector attributes;
945   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
946 
947   auto input_tensors = dataset_params.GetInputTensors();
948   gtl::InlinedVector<TensorValue, 4> inputs;
949   inputs.reserve(input_datasets.size() + input_tensors.size());
950   for (auto input_dataset : input_datasets) {
951     inputs.emplace_back(TensorValue(input_dataset));
952   }
953   for (auto& input_tensor : input_tensors) {
954     inputs.emplace_back(TensorValue(&input_tensor));
955   }
956 
957   DatasetBase* dataset_base;
958   std::unique_ptr<OpKernel> dataset_kernel;
959   std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
960   std::unique_ptr<OpKernelContext> dataset_ctx;
961   TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, &dataset_kernel));
962   TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel.get(), &inputs,
963                                           &dataset_ctx_params, &dataset_ctx));
964   TF_RETURN_IF_ERROR(
965       CreateDataset(dataset_kernel.get(), dataset_ctx.get(), &dataset_base));
966   Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
967   TF_RETURN_IF_ERROR(
968       StoreDatasetInVariantTensor(dataset_base, &dataset_tensor));
969   *dataset = absl::make_unique<Tensor>(dataset_tensor);
970   return Status::OK();
971 }
972 
DatasetParams(DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)973 DatasetParams::DatasetParams(DataTypeVector output_dtypes,
974                              std::vector<PartialTensorShape> output_shapes,
975                              string node_name)
976     : output_dtypes_(std::move(output_dtypes)),
977       output_shapes_(std::move(output_shapes)),
978       node_name_(std::move(node_name)) {}
979 
IsDatasetTensor(const Tensor & tensor)980 bool DatasetParams::IsDatasetTensor(const Tensor& tensor) {
981   return tensor.dtype() == DT_VARIANT &&
982          TensorShapeUtils::IsScalar(tensor.shape());
983 }
984 
RangeDatasetParams(int64_t start,int64_t stop,int64_t step,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)985 RangeDatasetParams::RangeDatasetParams(
986     int64_t start, int64_t stop, int64_t step, DataTypeVector output_dtypes,
987     std::vector<PartialTensorShape> output_shapes, string node_name)
988     : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
989                     std::move(node_name)),
990       start_(start),
991       stop_(stop),
992       step_(step) {}
993 
RangeDatasetParams(int64_t start,int64_t stop,int64_t step)994 RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
995                                        int64_t step)
996     : DatasetParams({DT_INT64}, {PartialTensorShape({})}, "range_dataset"),
997       start_(start),
998       stop_(stop),
999       step_(step) {}
1000 
RangeDatasetParams(int64_t start,int64_t stop,int64_t step,DataTypeVector output_dtypes)1001 RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
1002                                        int64_t step,
1003                                        DataTypeVector output_dtypes)
1004     : DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
1005                     "range_dataset"),
1006       start_(start),
1007       stop_(stop),
1008       step_(step) {}
1009 
GetInputTensors() const1010 std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
1011   Tensor start_tensor = CreateTensor<int64>(TensorShape({}), {start_});
1012   Tensor stop_tensor = CreateTensor<int64>(TensorShape({}), {stop_});
1013   Tensor step_tensor = CreateTensor<int64>(TensorShape({}), {step_});
1014   return {start_tensor, stop_tensor, step_tensor};
1015 }
1016 
GetInputNames(std::vector<string> * input_names) const1017 Status RangeDatasetParams::GetInputNames(
1018     std::vector<string>* input_names) const {
1019   *input_names = {"start", "stop", "step"};
1020   return Status::OK();
1021 }
1022 
GetAttributes(AttributeVector * attr_vector) const1023 Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1024   *attr_vector = {{"output_types", output_dtypes_},
1025                   {"output_shapes", output_shapes_}};
1026   return Status::OK();
1027 }
1028 
dataset_type() const1029 string RangeDatasetParams::dataset_type() const { return "Range"; }
1030 
GetInputTensors() const1031 std::vector<Tensor> BatchDatasetParams::GetInputTensors() const {
1032   Tensor batch_size = CreateTensor<int64>(TensorShape({}), {batch_size_});
1033   Tensor drop_remainder =
1034       CreateTensor<bool>(TensorShape({}), {drop_remainder_});
1035   return {batch_size, drop_remainder};
1036 }
1037 
GetInputNames(std::vector<string> * input_names) const1038 Status BatchDatasetParams::GetInputNames(
1039     std::vector<string>* input_names) const {
1040   *input_names = {"input_dataset", "batch_size", "drop_remainder"};
1041   return Status::OK();
1042 }
1043 
GetAttributes(AttributeVector * attr_vector) const1044 Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1045   *attr_vector = {{"parallel_copy", parallel_copy_},
1046                   {"output_types", output_dtypes_},
1047                   {"output_shapes", output_shapes_}};
1048   return Status::OK();
1049 }
1050 
dataset_type() const1051 string BatchDatasetParams::dataset_type() const { return "Batch"; }
1052 
GetInputTensors() const1053 std::vector<Tensor> MapDatasetParams::GetInputTensors() const {
1054   return other_arguments_;
1055 }
1056 
GetInputNames(std::vector<string> * input_names) const1057 Status MapDatasetParams::GetInputNames(std::vector<string>* input_names) const {
1058   input_names->emplace_back("input_dataset");
1059   for (int i = 0; i < other_arguments_.size(); ++i) {
1060     input_names->emplace_back(absl::StrCat("other_arguments_", i));
1061   }
1062   return Status::OK();
1063 }
1064 
GetAttributes(AttributeVector * attr_vector) const1065 Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1066   *attr_vector = {{"f", func_},
1067                   {"Targuments", type_arguments_},
1068                   {"output_shapes", output_shapes_},
1069                   {"output_types", output_dtypes_},
1070                   {"use_inter_op_parallelism", use_inter_op_parallelism_},
1071                   {"preserve_cardinality", preserve_cardinality_}};
1072   return Status::OK();
1073 }
1074 
dataset_type() const1075 string MapDatasetParams::dataset_type() const { return "Map"; }
1076 
func_lib() const1077 std::vector<FunctionDef> MapDatasetParams::func_lib() const {
1078   return func_lib_;
1079 }
1080 
TensorSliceDatasetParams(std::vector<Tensor> components,string node_name)1081 TensorSliceDatasetParams::TensorSliceDatasetParams(
1082     std::vector<Tensor> components, string node_name)
1083     : DatasetParams(TensorSliceDtypes(components),
1084                     TensorSliceShapes(components), std::move(node_name)),
1085       components_(std::move(components)) {}
1086 
GetInputTensors() const1087 std::vector<Tensor> TensorSliceDatasetParams::GetInputTensors() const {
1088   return components_;
1089 }
1090 
GetInputNames(std::vector<string> * input_names) const1091 Status TensorSliceDatasetParams::GetInputNames(
1092     std::vector<string>* input_names) const {
1093   input_names->reserve(components_.size());
1094   for (int i = 0; i < components_.size(); ++i) {
1095     input_names->emplace_back(absl::StrCat("components_", i));
1096   }
1097   return Status::OK();
1098 }
1099 
GetAttributes(AttributeVector * attr_vector) const1100 Status TensorSliceDatasetParams::GetAttributes(
1101     AttributeVector* attr_vector) const {
1102   *attr_vector = {{"Toutput_types", output_dtypes_},
1103                   {"output_shapes", output_shapes_}};
1104   return Status::OK();
1105 }
1106 
TensorSliceDtypes(const std::vector<Tensor> & input_components)1107 DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes(
1108     const std::vector<Tensor>& input_components) {
1109   DataTypeVector dtypes;
1110   for (const auto& component : input_components) {
1111     dtypes.emplace_back(component.dtype());
1112   }
1113   return dtypes;
1114 }
1115 
TensorSliceShapes(const std::vector<Tensor> & input_components)1116 std::vector<PartialTensorShape> TensorSliceDatasetParams::TensorSliceShapes(
1117     const std::vector<Tensor>& input_components) {
1118   std::vector<PartialTensorShape> shapes;
1119   for (const auto& component : input_components) {
1120     gtl::InlinedVector<int64, 4> partial_dim_sizes;
1121     for (int i = 1; i < component.dims(); ++i) {
1122       partial_dim_sizes.push_back(component.dim_size(i));
1123     }
1124     shapes.emplace_back(std::move(partial_dim_sizes));
1125   }
1126   return shapes;
1127 }
1128 
dataset_type() const1129 string TensorSliceDatasetParams::dataset_type() const { return "TensorSlice"; }
1130 
GetInputTensors() const1131 std::vector<Tensor> TakeDatasetParams::GetInputTensors() const {
1132   return {CreateTensor<int64>(TensorShape({}), {count_})};
1133 }
1134 
GetInputNames(std::vector<string> * input_names) const1135 Status TakeDatasetParams::GetInputNames(
1136     std::vector<string>* input_names) const {
1137   *input_names = {"input_dataset", "count"};
1138   return Status::OK();
1139 }
1140 
GetAttributes(AttributeVector * attr_vector) const1141 Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1142   *attr_vector = {{"output_shapes", output_shapes_},
1143                   {"output_types", output_dtypes_}};
1144   return Status::OK();
1145 }
1146 
dataset_type() const1147 string TakeDatasetParams::dataset_type() const { return "Take"; }
1148 
GetInputTensors() const1149 std::vector<Tensor> ConcatenateDatasetParams::GetInputTensors() const {
1150   return {};
1151 }
1152 
GetInputNames(std::vector<string> * input_names) const1153 Status ConcatenateDatasetParams::GetInputNames(
1154     std::vector<string>* input_names) const {
1155   *input_names = {"input_dataset", "another_dataset"};
1156   return Status::OK();
1157 }
1158 
GetAttributes(AttributeVector * attr_vector) const1159 Status ConcatenateDatasetParams::GetAttributes(
1160     AttributeVector* attr_vector) const {
1161   *attr_vector = {{"output_types", output_dtypes_},
1162                   {"output_shapes", output_shapes_}};
1163   return Status::OK();
1164 }
1165 
dataset_type() const1166 string ConcatenateDatasetParams::dataset_type() const { return "Concatenate"; }
1167 
GetInputTensors() const1168 std::vector<Tensor> OptionsDatasetParams::GetInputTensors() const { return {}; }
1169 
GetInputNames(std::vector<string> * input_names) const1170 Status OptionsDatasetParams::GetInputNames(
1171     std::vector<string>* input_names) const {
1172   input_names->emplace_back("input_dataset");
1173   return Status::OK();
1174 }
1175 
GetAttributes(AttributeVector * attr_vector) const1176 Status OptionsDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1177   *attr_vector = {{"serialized_options", serialized_options_},
1178                   {"output_shapes", output_shapes_},
1179                   {"output_types", output_dtypes_}};
1180   return Status::OK();
1181 }
1182 
dataset_type() const1183 string OptionsDatasetParams::dataset_type() const { return "Options"; }
1184 
1185 }  // namespace data
1186 }  // namespace tensorflow
1187