• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/dataset_test_base.h"
17 
18 #include <algorithm>
19 #include <complex>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
29 #include "tensorflow/core/common_runtime/device.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/device_mgr.h"
32 #include "tensorflow/core/common_runtime/executor.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
35 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
36 #include "tensorflow/core/data/dataset_utils.h"
37 #include "tensorflow/core/data/name_utils.h"
38 #include "tensorflow/core/data/serialization_utils.h"
39 #include "tensorflow/core/data/split_utils.h"
40 #include "tensorflow/core/framework/allocator.h"
41 #include "tensorflow/core/framework/cancellation.h"
42 #include "tensorflow/core/framework/control_flow.h"
43 #include "tensorflow/core/framework/dataset.h"
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/function.pb.h"
46 #include "tensorflow/core/framework/function_handle_cache.h"
47 #include "tensorflow/core/framework/function_testlib.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/numeric_types.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/framework/op_def.pb.h"
52 #include "tensorflow/core/framework/op_kernel.h"
53 #include "tensorflow/core/framework/register_types.h"
54 #include "tensorflow/core/framework/rendezvous.h"
55 #include "tensorflow/core/framework/resource_mgr.h"
56 #include "tensorflow/core/framework/tensor.h"
57 #include "tensorflow/core/framework/tensor_shape.h"
58 #include "tensorflow/core/framework/types.h"
59 #include "tensorflow/core/framework/types.pb.h"
60 #include "tensorflow/core/framework/variant_tensor_data.h"
61 #include "tensorflow/core/framework/versions.pb.h"
62 #include "tensorflow/core/graph/graph.h"
63 #include "tensorflow/core/lib/core/status_test_util.h"
64 #include "tensorflow/core/lib/gtl/inlined_vector.h"
65 #include "tensorflow/core/lib/io/record_writer.h"
66 #include "tensorflow/core/lib/io/zlib_compression_options.h"
67 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
68 #include "tensorflow/core/platform/bfloat16.h"
69 #include "tensorflow/core/platform/env.h"
70 #include "tensorflow/core/platform/errors.h"
71 #include "tensorflow/core/platform/file_system.h"
72 #include "tensorflow/core/platform/logging.h"
73 #include "tensorflow/core/platform/status.h"
74 #include "tensorflow/core/platform/test.h"
75 #include "tensorflow/core/platform/threadpool.h"
76 #include "tensorflow/core/platform/tstring.h"
77 #include "tensorflow/core/platform/types.h"
78 #include "tensorflow/core/protobuf/config.pb.h"
79 #include "tensorflow/core/public/session_options.h"
80 #include "tensorflow/core/public/version.h"
81 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
82 
83 namespace tensorflow {
84 namespace data {
85 
ToString(CompressionType compression_type)86 string ToString(CompressionType compression_type) {
87   switch (compression_type) {
88     case CompressionType::ZLIB:
89       return "ZLIB";
90     case CompressionType::GZIP:
91       return "GZIP";
92     case CompressionType::RAW:
93       return "RAW";
94     case CompressionType::UNCOMPRESSED:
95       return "";
96   }
97 }
98 
GetZlibCompressionOptions(CompressionType compression_type)99 io::ZlibCompressionOptions GetZlibCompressionOptions(
100     CompressionType compression_type) {
101   switch (compression_type) {
102     case CompressionType::ZLIB:
103       return io::ZlibCompressionOptions::DEFAULT();
104     case CompressionType::GZIP:
105       return io::ZlibCompressionOptions::GZIP();
106     case CompressionType::RAW:
107       return io::ZlibCompressionOptions::RAW();
108     case CompressionType::UNCOMPRESSED:
109       LOG(WARNING) << "ZlibCompressionOptions does not have an option for "
110                    << ToString(compression_type);
111       return io::ZlibCompressionOptions::DEFAULT();
112   }
113 }
114 
WriteDataToFile(const string & filename,const char * data)115 Status WriteDataToFile(const string& filename, const char* data) {
116   return WriteDataToFile(filename, data, CompressionParams());
117 }
118 
WriteDataToFile(const string & filename,const char * data,const CompressionParams & params)119 Status WriteDataToFile(const string& filename, const char* data,
120                        const CompressionParams& params) {
121   Env* env = Env::Default();
122   std::unique_ptr<WritableFile> file_writer;
123   TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
124   if (params.compression_type == CompressionType::UNCOMPRESSED) {
125     TF_RETURN_IF_ERROR(file_writer->Append(data));
126   } else if (params.compression_type == CompressionType::ZLIB ||
127              params.compression_type == CompressionType::GZIP ||
128              params.compression_type == CompressionType::RAW) {
129     auto zlib_compression_options =
130         GetZlibCompressionOptions(params.compression_type);
131     io::ZlibOutputBuffer out(file_writer.get(), params.input_buffer_size,
132                              params.output_buffer_size,
133                              zlib_compression_options);
134     TF_RETURN_IF_ERROR(out.Init());
135     TF_RETURN_IF_ERROR(out.Append(data));
136     TF_RETURN_IF_ERROR(out.Flush());
137     TF_RETURN_IF_ERROR(out.Close());
138   } else {
139     return tensorflow::errors::InvalidArgument(
140         "Unsupported compression_type: ", ToString(params.compression_type));
141   }
142 
143   TF_RETURN_IF_ERROR(file_writer->Flush());
144   TF_RETURN_IF_ERROR(file_writer->Close());
145 
146   return OkStatus();
147 }
148 
WriteDataToTFRecordFile(const string & filename,const std::vector<absl::string_view> & records,const CompressionParams & params)149 Status WriteDataToTFRecordFile(const string& filename,
150                                const std::vector<absl::string_view>& records,
151                                const CompressionParams& params) {
152   Env* env = Env::Default();
153   std::unique_ptr<WritableFile> file_writer;
154   TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
155   auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
156       ToString(params.compression_type));
157   options.zlib_options.input_buffer_size = params.input_buffer_size;
158   io::RecordWriter record_writer(file_writer.get(), options);
159   for (const auto& record : records) {
160     TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
161   }
162   TF_RETURN_IF_ERROR(record_writer.Flush());
163   TF_RETURN_IF_ERROR(record_writer.Close());
164   TF_RETURN_IF_ERROR(file_writer->Flush());
165   TF_RETURN_IF_ERROR(file_writer->Close());
166   return OkStatus();
167 }
168 
169 template <typename T>
IsEqual(const Tensor & t1,const Tensor & t2)170 Status IsEqual(const Tensor& t1, const Tensor& t2) {
171   if (t1.dtype() != t2.dtype()) {
172     return tensorflow::errors::Internal(
173         "Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
174         " vs. ", DataTypeString(t2.dtype()));
175   }
176   if (!t1.IsSameSize(t2)) {
177     return tensorflow::errors::Internal(
178         "Two tensors have different shapes: ", t1.shape().DebugString(),
179         " vs. ", t2.shape().DebugString());
180   }
181 
182   auto flat_t1 = t1.flat<T>();
183   auto flat_t2 = t2.flat<T>();
184   auto length = flat_t1.size();
185 
186   for (int i = 0; i < length; ++i) {
187     if (flat_t1(i) != flat_t2(i)) {
188       return tensorflow::errors::Internal(
189           "Two tensors have different values "
190           "at [",
191           i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
192     }
193   }
194   return OkStatus();
195 }
196 
// Sets up a single CPU test device; the allocator is fetched from that device
// so tensors built by the fixture use the device's allocator.
DatasetOpsTestBase::DatasetOpsTestBase()
    : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
      device_type_(DEVICE_CPU),
      cpu_num_(kDefaultCPUNum),
      thread_num_(kDefaultThreadNum) {
  allocator_ = device_->GetAllocator(AllocatorAttributes());
}
204 
// Releases the reference taken in GetDatasetFromContext() so the dataset is
// destroyed once no other owner remains.
DatasetOpsTestBase::~DatasetOpsTestBase() {
  if (dataset_) {
    dataset_->Unref();
  }
}
210 
// Element-wise equality check dispatched on the runtime dtype of `a`.
// Forwards to the typed IsEqual<T> helper; dtypes outside
// TF_CALL_NUMBER_TYPES/tstring yield an Internal "Unsupported dtype" error.
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  switch (a.dtype()) {
// Expands to one case label per dtype, forwarding to IsEqual<DT>.
#define CASE(DT)                           \
  case DataTypeToEnum<DT>::value:          \
    TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return OkStatus();
}
226 
227 template <typename T>
compare(const Tensor & t1,const Tensor & t2)228 bool compare(const Tensor& t1, const Tensor& t2) {
229   auto flat_t1 = t1.flat<T>();
230   auto flat_t2 = t2.flat<T>();
231   auto length = std::min(flat_t1.size(), flat_t2.size());
232   for (int i = 0; i < length; ++i) {
233     if (flat_t1(i) < flat_t2(i)) return true;
234     if (flat_t1(i) > flat_t2(i)) return false;
235   }
236   return flat_t1.size() < length;
237 }
238 
// Compares two tensor vectors pairwise. When `compare_order` is false, both
// vectors are first sorted with the typed lexicographic comparator so the
// comparison is order-insensitive. Assumes all tensors in each vector share
// the dtype of element 0.
Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
                                       std::vector<Tensor> expected_tensors,
                                       bool compare_order) {
  if (produced_tensors.size() != expected_tensors.size()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different size (", produced_tensors.size(),
        " v.s. ", expected_tensors.size(), ")"));
  }

  if (produced_tensors.empty()) return OkStatus();
  if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) {
    return Status(tensorflow::errors::Internal(
        "The two tensor vectors have different dtypes (",
        produced_tensors[0].dtype(), " v.s. ", expected_tensors[0].dtype(),
        ")"));
  }

  if (!compare_order) {
    const DataType& dtype = produced_tensors[0].dtype();
    switch (dtype) {
// Sorts both vectors with compare<T> so equality can be checked pairwise.
#define CASE(DT)                                                \
  case DT:                                                      \
    std::sort(produced_tensors.begin(), produced_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    std::sort(expected_tensors.begin(), expected_tensors.end(), \
              compare<EnumToDataType<DT>::Type>);               \
    break;
      CASE(DT_FLOAT);
      CASE(DT_DOUBLE);
      CASE(DT_INT32);
      CASE(DT_UINT8);
      CASE(DT_INT16);
      CASE(DT_INT8);
      CASE(DT_STRING);
      CASE(DT_INT64);
      CASE(DT_BOOL);
      CASE(DT_QINT8);
      CASE(DT_QUINT8);
      CASE(DT_QINT32);
      CASE(DT_QINT16);
      CASE(DT_QUINT16);
      CASE(DT_UINT16);
      CASE(DT_HALF);
      CASE(DT_UINT32);
      CASE(DT_UINT64);
      // TODO(feihugis): support other dtypes.
#undef CASE
      default:
        return errors::Internal("Unsupported dtype: ", dtype);
    }
  }

  for (int i = 0; i < produced_tensors.size(); ++i) {
    TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i],
                                                       expected_tensors[i]));
  }
  return OkStatus();
}
297 
CreateOpKernel(const NodeDef & node_def,std::unique_ptr<OpKernel> * op_kernel)298 Status DatasetOpsTestBase::CreateOpKernel(
299     const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
300   OpKernel* kernel;
301   Status s;
302 
303   std::shared_ptr<const NodeProperties> props;
304   TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
305       node_def, flr_->GetFunctionLibraryDefinition(), &props));
306   // Apply attribute defaults.
307   auto props_with_defaults = std::make_shared<NodeProperties>(*props);
308   for (const auto& attr : props->op_def->attr()) {
309     if (attr.has_default_value() &&
310         !props->node_def.attr().contains(attr.name())) {
311       (*props_with_defaults->node_def.mutable_attr())[attr.name()] =
312           attr.default_value();
313     }
314   }
315   TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
316       device_type_, device_.get(), allocator_, flr_,
317       device_->resource_manager(), props_with_defaults, TF_GRAPH_DEF_VERSION,
318       &kernel));
319   op_kernel->reset(kernel);
320   return OkStatus();
321 }
322 
CreateDatasetContext(OpKernel * const dateset_kernel,gtl::InlinedVector<TensorValue,4> * const inputs,std::unique_ptr<OpKernelContext::Params> * dataset_context_params,std::unique_ptr<OpKernelContext> * dataset_context)323 Status DatasetOpsTestBase::CreateDatasetContext(
324     OpKernel* const dateset_kernel,
325     gtl::InlinedVector<TensorValue, 4>* const inputs,
326     std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
327     std::unique_ptr<OpKernelContext>* dataset_context) {
328   Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
329   if (!status.ok()) {
330     VLOG(0) << "WARNING: " << status.ToString();
331   }
332   TF_RETURN_IF_ERROR(CreateOpKernelContext(
333       dateset_kernel, inputs, dataset_context_params, dataset_context));
334   return OkStatus();
335 }
336 
// Runs the dataset op kernel and extracts the resulting DatasetBase from the
// kernel's single output.
Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
                                         OpKernelContext* context,
                                         DatasetBase** const dataset) {
  TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ(context->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
  return OkStatus();
}
346 
// Restores an iterator over `dataset` from the checkpoint state held by
// `reader`, storing the result in `*iterator`.
Status DatasetOpsTestBase::RestoreIterator(
    IteratorContext* ctx, IteratorStateReader* reader,
    const string& output_prefix, const DatasetBase& dataset,
    std::unique_ptr<IteratorBase>* iterator) {
  return dataset.MakeIteratorFromCheckpoint(ctx, output_prefix, reader,
                                            iterator);
}
354 
CreateIteratorContext(OpKernelContext * const op_context,std::unique_ptr<IteratorContext> * iterator_context)355 Status DatasetOpsTestBase::CreateIteratorContext(
356     OpKernelContext* const op_context,
357     std::unique_ptr<IteratorContext>* iterator_context) {
358   IteratorContext::Params params(op_context);
359   params.resource_mgr = op_context->resource_manager();
360   function_handle_cache_ = std::make_unique<FunctionHandleCache>(flr_);
361   params.function_handle_cache = function_handle_cache_.get();
362   params.cancellation_manager = cancellation_manager_.get();
363   *iterator_context = std::make_unique<IteratorContext>(params);
364   return OkStatus();
365 }
366 
GetDatasetFromContext(OpKernelContext * context,int output_index,DatasetBase ** const dataset)367 Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
368                                                  int output_index,
369                                                  DatasetBase** const dataset) {
370   Tensor* output = context->mutable_output(output_index);
371   Status status = GetDatasetFromVariantTensor(*output, dataset);
372   (*dataset)->Ref();
373   return status;
374 }
375 
InitThreadPool(int thread_num)376 Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
377   if (thread_num < 1) {
378     return errors::InvalidArgument(
379         "The `thread_num` argument should be positive but got: ", thread_num);
380   }
381   thread_pool_ = std::make_unique<thread::ThreadPool>(
382       Env::Default(), ThreadOptions(), "test_thread_pool", thread_num);
383   return OkStatus();
384 }
385 
// Builds the function-library machinery for the test: `cpu_num` CPU devices,
// a resource manager, a FunctionLibraryDefinition holding `flib`, and a
// ProcessFunctionLibraryRuntime from which `flr_` is obtained. Must run
// before anything that executes kernels or functions.
Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
  resource_mgr_ = std::make_unique<ResourceMgr>("default_container");

  // Register every FunctionDef from `flib` in a library definition owned by
  // the fixture; the PFLR below keeps a raw pointer to it.
  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
      /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      // Rendezvous factory used for intra-process tensor exchange when
      // functions run.
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return OkStatus();
          }});
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  // Without a thread pool, run closures inline on the caller's thread;
  // otherwise schedule them on the pool.
  if (thread_pool_ == nullptr) {
    runner_ = [](const std::function<void()>& fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return OkStatus();
}
427 
// Runs `op_kernel` synchronously on the test device and returns the status
// the kernel recorded on `context`.
Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}
433 
RunFunction(const FunctionDef & fdef,test::function::Attrs attrs,const std::vector<Tensor> & args,const GraphConstructorOptions & graph_options,std::vector<Tensor * > rets)434 Status DatasetOpsTestBase::RunFunction(
435     const FunctionDef& fdef, test::function::Attrs attrs,
436     const std::vector<Tensor>& args,
437     const GraphConstructorOptions& graph_options, std::vector<Tensor*> rets) {
438   std::unique_ptr<Executor> exec;
439   InstantiationResult result;
440   auto GetOpSig = [](const string& op, const OpDef** sig) {
441     return OpRegistry::Global()->LookUpOpDef(op, sig);
442   };
443   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, GetOpSig, &result));
444 
445   DataTypeVector arg_types = result.arg_types;
446   DataTypeVector ret_types = result.ret_types;
447 
448   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
449   TF_RETURN_IF_ERROR(
450       ConvertNodeDefsToGraph(graph_options, result.nodes, g.get()));
451 
452   const int version = g->versions().producer();
453   LocalExecutorParams params;
454   params.function_library = flr_;
455   params.device = device_.get();
456   params.create_kernel = [this, version](
457                              const std::shared_ptr<const NodeProperties>& props,
458                              OpKernel** kernel) {
459     return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
460                                  kernel);
461   };
462   params.delete_kernel = [](OpKernel* kernel) {
463     DeleteNonCachedKernel(kernel);
464   };
465 
466   Executor* cur_exec;
467   TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
468   exec.reset(cur_exec);
469   FunctionCallFrame frame(arg_types, ret_types);
470   TF_RETURN_IF_ERROR(frame.SetArgs(args));
471   Executor::Args exec_args;
472   exec_args.call_frame = &frame;
473   exec_args.runner = runner_;
474   TF_RETURN_IF_ERROR(exec->Run(exec_args));
475   std::vector<Tensor> computed;
476   TF_RETURN_IF_ERROR(frame.GetRetvals(&computed));
477   if (computed.size() != rets.size()) {
478     return errors::InvalidArgument(
479         "The result does not match the expected number of return outpus",
480         ". Expected: ", rets.size(), ". Actual: ", computed.size());
481   }
482   for (int i = 0; i < rets.size(); ++i) {
483     *(rets[i]) = computed[i];
484   }
485   return OkStatus();
486 }
487 
// Convenience overload that keeps the context params alive in the fixture's
// `params_` member; the params must outlive the returned context.
Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext>* context) {
  return CreateOpKernelContext(kernel, inputs, &params_, context);
}
493 
// Builds OpKernelContext::Params wired to the fixture's device, function
// library, resource manager, and a fresh cancellation manager, then creates
// the context. The params object is handed back through `context_params`
// because the context is constructed from a raw pointer to it — it must
// outlive `*context`.
Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext::Params>* context_params,
    std::unique_ptr<OpKernelContext>* context) {
  auto params = std::make_unique<OpKernelContext::Params>();
  cancellation_manager_ = std::make_unique<CancellationManager>();
  params->cancellation_manager = cancellation_manager_.get();
  params->device = device_.get();
  params->frame_iter = FrameAndIter(0, 0);
  params->function_library = flr_;
  params->inputs = *inputs;
  params->op_kernel = kernel;
  params->resource_manager = resource_mgr_.get();
  params->runner = &runner_;
  slice_reader_cache_ =
      std::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
  params->slice_reader_cache = slice_reader_cache_.get();
  step_container_ =
      std::make_unique<ScopedStepContainer>(0, [](const string&) {});
  params->step_container = step_container_.get();

  // Set the allocator attributes for the outputs.
  allocator_attrs_.clear();
  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    allocator_attrs_.emplace_back(attr);
  }
  params->output_attr_array = allocator_attrs_.data();

  // The context is built on `params.get()`; ownership of the params moves to
  // the caller so the pointer stays valid.
  *context = std::make_unique<OpKernelContext>(params.get());
  *context_params = std::move(params);
  return OkStatus();
}
530 
CreateSerializationContext(std::unique_ptr<SerializationContext> * context)531 Status DatasetOpsTestBase::CreateSerializationContext(
532     std::unique_ptr<SerializationContext>* context) {
533   *context =
534       std::make_unique<SerializationContext>(SerializationContext::Params{});
535   return OkStatus();
536 }
537 
CheckOpKernelInput(const OpKernel & kernel,const gtl::InlinedVector<TensorValue,4> & inputs)538 Status DatasetOpsTestBase::CheckOpKernelInput(
539     const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
540   if (kernel.num_inputs() != inputs.size()) {
541     return errors::InvalidArgument("The number of input elements should be ",
542                                    kernel.num_inputs(),
543                                    ", but got: ", inputs.size());
544   }
545   return OkStatus();
546 }
547 
AddDatasetInput(gtl::InlinedVector<TensorValue,4> * inputs,DataTypeVector input_types,DataType dtype,const TensorShape & shape)548 Status DatasetOpsTestBase::AddDatasetInput(
549     gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
550     DataType dtype, const TensorShape& shape) {
551   if (input_types.size() < inputs->size()) {
552     return errors::InvalidArgument("Adding more inputs than types: ",
553                                    inputs->size(), " vs. ", input_types.size());
554   }
555   bool is_ref = IsRefType(input_types[inputs->size()]);
556   auto input = std::make_unique<Tensor>(allocator_, dtype, shape);
557 
558   if (is_ref) {
559     DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
560     if (expected_dtype != dtype) {
561       return errors::InvalidArgument("The input data type is ", dtype,
562                                      " , but expected: ", expected_dtype);
563     }
564     inputs->push_back({&lock_for_refs_, input.get()});
565   } else {
566     if (input_types[inputs->size()] != dtype) {
567       return errors::InvalidArgument(
568           "The input data type is ", dtype,
569           " , but expected: ", input_types[inputs->size()]);
570     }
571     inputs->push_back({nullptr, input.get()});
572   }
573 
574   // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
575   // collect the inputs.
576   tensors_.push_back(std::move(input));
577 
578   return OkStatus();
579 }
580 
// Convenience overload that checks the fixture's own iterator
// (`iterator_` / `iterator_ctx_`).
Status DatasetOpsTestBase::CheckIteratorGetNext(
    const std::vector<Tensor>& expected_outputs, bool compare_order) {
  return CheckIteratorGetNext(iterator_.get(), iterator_ctx_.get(),
                              expected_outputs, compare_order);
}
586 
// Convenience overload that unwraps a TestIterator into its iterator and
// context.
Status DatasetOpsTestBase::CheckIteratorGetNext(
    TestIterator* iterator, const std::vector<Tensor>& expected_outputs,
    bool compare_order) {
  return CheckIteratorGetNext(iterator->iterator(), iterator->ctx(),
                              expected_outputs, compare_order);
}
593 
CheckIteratorGetNext(IteratorBase * iterator,IteratorContext * ctx,const std::vector<Tensor> & expected_outputs,bool compare_order)594 Status DatasetOpsTestBase::CheckIteratorGetNext(
595     IteratorBase* iterator, IteratorContext* ctx,
596     const std::vector<Tensor>& expected_outputs, bool compare_order) {
597   bool end_of_sequence = false;
598   std::vector<Tensor> out_tensors;
599   while (!end_of_sequence) {
600     std::vector<Tensor> next;
601     TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &next, &end_of_sequence));
602     out_tensors.insert(out_tensors.end(), next.begin(), next.end());
603   }
604   // Call GetNext one more time to make sure it still reports
605   // end_of_sequence = True.
606   std::vector<Tensor> unused;
607   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &unused, &end_of_sequence));
608   EXPECT_TRUE(end_of_sequence);
609 
610   TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
611                            /*compare_order=*/compare_order));
612   return OkStatus();
613 }
614 
CheckIteratorSkip(int num_to_skip,int expected_num_skipped,bool get_next,const std::vector<Tensor> & expected_outputs,bool compare_order)615 Status DatasetOpsTestBase::CheckIteratorSkip(
616     int num_to_skip, int expected_num_skipped, bool get_next,
617     const std::vector<Tensor>& expected_outputs, bool compare_order) {
618   IteratorBase* iterator = iterator_.get();
619   IteratorContext* ctx = iterator_ctx_.get();
620 
621   bool end_of_sequence = false;
622   int num_skipped = 0;
623   TF_RETURN_IF_ERROR(
624       iterator->Skip(ctx, num_to_skip, &end_of_sequence, &num_skipped));
625   EXPECT_TRUE(num_skipped == expected_num_skipped);
626   if (get_next) {
627     EXPECT_TRUE(!end_of_sequence);
628     std::vector<Tensor> out_tensors;
629     TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
630     TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
631                              /*compare_order=*/compare_order));
632   }
633   return OkStatus();
634 }
635 
// Runs a full (unsharded) iteration of the dataset built from `params` using
// the dataset's own split providers, checking the produced elements in order
// against `expected_outputs`.
Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
    const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  std::unique_ptr<TestIterator> iterator;
  TF_RETURN_IF_ERROR(
      MakeIterator(params, *dataset, std::move(split_providers), &iterator));
  TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
                                          /*compare_order=*/true));
  return OkStatus();
}
649 
// Iterates the dataset built from `params` with each split provider wrapped
// in a ShardingSplitProvider (shard `shard_index` of `num_shards`), and runs
// the save/restore check at several breakpoints against `expected_outputs`.
Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
    const DatasetParams& params, int64_t num_shards, int64_t shard_index,
    const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProviders(&split_providers));
  // Wrap each provider so it only yields this shard's splits.
  for (int i = 0; i < split_providers.size(); ++i) {
    split_providers[i] = std::make_unique<ShardingSplitProvider>(
        num_shards, shard_index, std::move(split_providers[i]));
  }
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
  // Rebuild the iterator context with the sharded split providers installed.
  IteratorContext::Params iterator_params(iterator_ctx.get());
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(iterator_params.split_providers));
  iterator_ctx = std::make_unique<IteratorContext>(iterator_params);
  // Breakpoints: start, middle, one-before-end, and end of the expected
  // output sequence.
  int mid_breakpoint = expected_outputs.size() / 2;
  int near_end_breakpoint = expected_outputs.size() - 1;
  int end_breakpoint = expected_outputs.size();
  TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
      dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
      expected_outputs,
      /*breakpoints=*/
      {0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
      /*compare_order=*/true));
  return OkStatus();
}
679 
CheckDatasetNodeName(const string & expected_dataset_node_name)680 Status DatasetOpsTestBase::CheckDatasetNodeName(
681     const string& expected_dataset_node_name) {
682   EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
683   return OkStatus();
684 }
685 
CheckDatasetTypeString(const string & expected_type_str)686 Status DatasetOpsTestBase::CheckDatasetTypeString(
687     const string& expected_type_str) {
688   EXPECT_EQ(dataset_->type_string(), expected_type_str);
689   return OkStatus();
690 }
691 
CheckDatasetOutputDtypes(const DataTypeVector & expected_output_dtypes)692 Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
693     const DataTypeVector& expected_output_dtypes) {
694   TF_EXPECT_OK(
695       VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
696   return OkStatus();
697 }
698 
CheckDatasetOutputShapes(const std::vector<PartialTensorShape> & expected_output_shapes)699 Status DatasetOpsTestBase::CheckDatasetOutputShapes(
700     const std::vector<PartialTensorShape>& expected_output_shapes) {
701   TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
702                                       expected_output_shapes));
703   return OkStatus();
704 }
705 
CheckDatasetCardinality(int expected_cardinality)706 Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) {
707   EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
708   return OkStatus();
709 }
710 
CheckDatasetOptions(const Options & expected_options)711 Status DatasetOpsTestBase::CheckDatasetOptions(
712     const Options& expected_options) {
713   EXPECT_EQ(dataset_->options().SerializeAsString(),
714             expected_options.SerializeAsString());
715   return OkStatus();
716 }
717 
CheckIteratorOutputDtypes(const DataTypeVector & expected_output_dtypes)718 Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
719     const DataTypeVector& expected_output_dtypes) {
720   TF_EXPECT_OK(
721       VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
722   return OkStatus();
723 }
724 
CheckIteratorOutputShapes(const std::vector<PartialTensorShape> & expected_output_shapes)725 Status DatasetOpsTestBase::CheckIteratorOutputShapes(
726     const std::vector<PartialTensorShape>& expected_output_shapes) {
727   TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
728                                       expected_output_shapes));
729   return OkStatus();
730 }
731 
CheckIteratorPrefix(const string & expected_iterator_prefix)732 Status DatasetOpsTestBase::CheckIteratorPrefix(
733     const string& expected_iterator_prefix) {
734   EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix);
735   return OkStatus();
736 }
737 
// Checks that `dataset` produces `expected_outputs` while its iterator is
// checkpointed at each index in `breakpoints`: at every breakpoint the
// iterator state is serialized and immediately restored before pulling more
// elements, validating that save/restore neither loses nor duplicates
// elements. `compare_order` controls whether output order must match.
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    DatasetBase* dataset, IteratorContext* iterator_ctx,
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
                                           iterator_prefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
  bool end_of_sequence = false;
  int cur_iteration = 0;
  std::vector<Tensor> out_tensors;
  for (int breakpoint : breakpoints) {
    // Serialize the current iterator state, then restore the iterator from
    // that serialized state before continuing.
    VariantTensorDataWriter writer;
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
                                 *dataset, &iterator));

    // Advance the restored iterator up to the breakpoint, accumulating all
    // elements produced along the way.
    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return OkStatus();
}
772 
// Convenience overload: runs the save/restore check against the fixture's
// own dataset and iterator context (set up by Initialize()).
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
                                     iterator_prefix, expected_outputs,
                                     breakpoints, compare_order);
}
781 
// One-time fixture setup: initializes the runtime (thread pool and function
// library runtime), builds the dataset described by `dataset_params`, and
// creates an iterator over it. Returns an internal error if called more
// than once on the same fixture.
Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
  if (initialized_) {
    return errors::Internal(
        "The fields (e.g. dataset_kernel_, dataset_ctx_, dataset_, "
        "iterator_ctx_, iterator_) have already been initialized.");
  }
  TF_RETURN_IF_ERROR(InitializeRuntime(dataset_params));
  TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
                                 &dataset_ctx_, &tensors_, &dataset_));
  TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
  TF_RETURN_IF_ERROR(
      dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
                             dataset_params.iterator_prefix(), &iterator_));
  initialized_ = true;
  return OkStatus();
}
798 
// Sets up the minimal runtime needed to run dataset op kernels: a thread
// pool of `thread_num_` threads and a function library runtime holding the
// functions from `dataset_params`, backed by `cpu_num_` CPU devices.
Status DatasetOpsTestBase::InitializeRuntime(
    const DatasetParams& dataset_params) {
  TF_RETURN_IF_ERROR(InitThreadPool(thread_num_));
  TF_RETURN_IF_ERROR(
      InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_));
  return OkStatus();
}
806 
MakeDataset(const DatasetParams & dataset_params,std::unique_ptr<TestDataset> * dataset)807 Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params,
808                                        std::unique_ptr<TestDataset>* dataset) {
809   DatasetBase* dataset_base;
810   std::unique_ptr<OpKernel> dataset_kernel;
811   std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
812   std::unique_ptr<OpKernelContext> dataset_ctx;
813   std::vector<std::unique_ptr<Tensor>> created_tensors;
814   TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel,
815                                  &dataset_ctx_params, &dataset_ctx,
816                                  &created_tensors, &dataset_base));
817   *dataset = std::make_unique<TestDataset>(
818       std::move(dataset_kernel), std::move(dataset_ctx_params),
819       std::move(dataset_ctx), std::move(created_tensors), dataset_base);
820   return OkStatus();
821 }
822 
RunDatasetOp(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel,std::unique_ptr<OpKernelContext::Params> * dataset_ctx_params,std::vector<std::unique_ptr<Tensor>> * created_tensors,std::unique_ptr<OpKernelContext> * dataset_ctx)823 Status DatasetOpsTestBase::RunDatasetOp(
824     const DatasetParams& dataset_params,
825     std::unique_ptr<OpKernel>* dataset_kernel,
826     std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
827     std::vector<std::unique_ptr<Tensor>>* created_tensors,
828     std::unique_ptr<OpKernelContext>* dataset_ctx) {
829   std::vector<Tensor*> input_datasets;
830   for (auto& input : dataset_params.input_dataset_params()) {
831     std::unique_ptr<Tensor> t;
832     TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
833     input_datasets.push_back(t.get());
834     created_tensors->push_back(std::move(t));
835   }
836   gtl::InlinedVector<TensorValue, 4> inputs;
837   for (auto input_dataset : input_datasets) {
838     inputs.emplace_back(TensorValue(input_dataset));
839   }
840 
841   // Copy the input tensors, storing them in the `inputs` vectors, and storing
842   // owned references to the copies in `created_tensors`.
843   for (auto& input : dataset_params.GetInputTensors()) {
844     auto copy = std::make_unique<Tensor>(input);
845     inputs.push_back(TensorValue(copy.get()));
846     created_tensors->push_back(std::move(copy));
847   }
848 
849   TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, dataset_kernel));
850   TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs,
851                                           dataset_ctx_params, dataset_ctx));
852   TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get()));
853   return OkStatus();
854 }
855 
// Runs the dataset op described by `dataset_params` and extracts the
// resulting DatasetBase from the kernel context's single output.
Status DatasetOpsTestBase::MakeDataset(
    const DatasetParams& dataset_params,
    std::unique_ptr<OpKernel>* dataset_kernel,
    std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
    std::unique_ptr<OpKernelContext>* dataset_ctx,
    std::vector<std::unique_ptr<Tensor>>* created_tensors,
    DatasetBase** dataset) {
  TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, dataset_kernel,
                                  dataset_ctx_params, created_tensors,
                                  dataset_ctx));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ((*dataset_ctx)->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset));
  return OkStatus();
}
871 
// Creates an iterator over `dataset`, installing `split_providers` into the
// iterator context. The context is first created from the dataset's kernel
// context and then rebuilt from its params so the providers can be moved in.
Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::vector<std::unique_ptr<SplitProvider>> split_providers,
    std::unique_ptr<TestIterator>* iterator) {
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
  IteratorContext::Params iterator_params(iterator_ctx.get());
  std::move(split_providers.begin(), split_providers.end(),
            std::back_inserter(iterator_params.split_providers));

  iterator_ctx = std::make_unique<IteratorContext>(iterator_params);
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
      iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
      &iterator_base));
  // TestIterator keeps the context alive alongside the iterator itself.
  *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
                                             std::move(iterator_base));
  return OkStatus();
}
892 
// Convenience overload: creates an iterator with no split providers.
Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::unique_ptr<TestIterator>* iterator) {
  return MakeIterator(dataset_params, dataset, /*split_providers=*/{},
                      iterator);
}
899 
RunDatasetOp(const DatasetParams & dataset_params,std::vector<Tensor> * outputs)900 Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
901                                         std::vector<Tensor>* outputs) {
902   TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, &params_,
903                                   &tensors_, &dataset_ctx_));
904   for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) {
905     outputs->emplace_back(*dataset_ctx_->mutable_output(i));
906   }
907   return OkStatus();
908 }
909 
MakeDatasetOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel)910 Status DatasetOpsTestBase::MakeDatasetOpKernel(
911     const DatasetParams& dataset_params,
912     std::unique_ptr<OpKernel>* dataset_kernel) {
913   name_utils::OpNameParams params;
914   params.op_version = dataset_params.op_version();
915   std::vector<string> input_names;
916   TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
917   AttributeVector attributes;
918   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
919   NodeDef node_def =
920       test::function::NDef(dataset_params.node_name(), dataset_params.op_name(),
921                            input_names, attributes);
922   TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel));
923   return OkStatus();
924 }
925 
MakeGetOptionsOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * op_kernel)926 Status DatasetOpsTestBase::MakeGetOptionsOpKernel(
927     const DatasetParams& dataset_params, std::unique_ptr<OpKernel>* op_kernel) {
928   name_utils::OpNameParams params;
929   params.op_version = dataset_params.op_version();
930   std::vector<string> input_names;
931   TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
932   AttributeVector attributes;
933   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
934   NodeDef node_def = test::function::NDef(dataset_params.node_name(),
935                                           dataset_params.dataset_type(),
936                                           input_names, attributes);
937   TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
938   return OkStatus();
939 }
940 
MakeDatasetTensor(const DatasetParams & dataset_params,std::vector<std::unique_ptr<Tensor>> * created_tensors,std::unique_ptr<Tensor> * dataset)941 Status DatasetOpsTestBase::MakeDatasetTensor(
942     const DatasetParams& dataset_params,
943     std::vector<std::unique_ptr<Tensor>>* created_tensors,
944     std::unique_ptr<Tensor>* dataset) {
945   // Make sure all the input dataset tensors have been populated.
946   std::vector<Tensor*> input_datasets;
947   for (auto& input : dataset_params.input_dataset_params()) {
948     std::unique_ptr<Tensor> t;
949     TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
950     input_datasets.push_back(t.get());
951     created_tensors->push_back(std::move(t));
952   }
953 
954   AttributeVector attributes;
955   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
956 
957   auto input_tensors = dataset_params.GetInputTensors();
958   gtl::InlinedVector<TensorValue, 4> inputs;
959   inputs.reserve(input_datasets.size() + input_tensors.size());
960   for (auto input_dataset : input_datasets) {
961     inputs.emplace_back(TensorValue(input_dataset));
962   }
963   for (auto& input_tensor : input_tensors) {
964     inputs.emplace_back(TensorValue(&input_tensor));
965   }
966 
967   DatasetBase* dataset_base;
968   std::unique_ptr<OpKernel> dataset_kernel;
969   std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
970   std::unique_ptr<OpKernelContext> dataset_ctx;
971   TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, &dataset_kernel));
972   TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel.get(), &inputs,
973                                           &dataset_ctx_params, &dataset_ctx));
974   TF_RETURN_IF_ERROR(
975       CreateDataset(dataset_kernel.get(), dataset_ctx.get(), &dataset_base));
976   Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
977   TF_RETURN_IF_ERROR(
978       StoreDatasetInVariantTensor(dataset_base, &dataset_tensor));
979   *dataset = std::make_unique<Tensor>(dataset_tensor);
980   return OkStatus();
981 }
982 
// Base constructor: records the expected output dtypes/shapes and the graph
// node name shared by all DatasetParams subclasses.
DatasetParams::DatasetParams(DataTypeVector output_dtypes,
                             std::vector<PartialTensorShape> output_shapes,
                             string node_name)
    : output_dtypes_(std::move(output_dtypes)),
      output_shapes_(std::move(output_shapes)),
      node_name_(std::move(node_name)) {}
989 
IsDatasetTensor(const Tensor & tensor)990 bool DatasetParams::IsDatasetTensor(const Tensor& tensor) {
991   return tensor.dtype() == DT_VARIANT &&
992          TensorShapeUtils::IsScalar(tensor.shape());
993 }
994 
// Fully-specified constructor: caller provides start/stop/step plus the
// expected output dtypes/shapes and node name.
RangeDatasetParams::RangeDatasetParams(
    int64_t start, int64_t stop, int64_t step, DataTypeVector output_dtypes,
    std::vector<PartialTensorShape> output_shapes, string node_name)
    : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                    std::move(node_name)),
      start_(start),
      stop_(stop),
      step_(step) {}
1003 
// Convenience constructor: defaults to scalar int64 outputs and the node
// name "range_dataset".
RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
                                       int64_t step)
    : DatasetParams({DT_INT64}, {PartialTensorShape({})}, "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}
1010 
// Convenience constructor: caller provides output dtypes; shapes default to
// scalar and the node name to "range_dataset".
RangeDatasetParams::RangeDatasetParams(int64_t start, int64_t stop,
                                       int64_t step,
                                       DataTypeVector output_dtypes)
    : DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
                    "range_dataset"),
      start_(start),
      stop_(stop),
      step_(step) {}
1019 
GetInputTensors() const1020 std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
1021   Tensor start_tensor = CreateTensor<int64_t>(TensorShape({}), {start_});
1022   Tensor stop_tensor = CreateTensor<int64_t>(TensorShape({}), {stop_});
1023   Tensor step_tensor = CreateTensor<int64_t>(TensorShape({}), {step_});
1024   return {start_tensor, stop_tensor, step_tensor};
1025 }
1026 
GetInputNames(std::vector<string> * input_names) const1027 Status RangeDatasetParams::GetInputNames(
1028     std::vector<string>* input_names) const {
1029   *input_names = {"start", "stop", "step"};
1030   return OkStatus();
1031 }
1032 
GetAttributes(AttributeVector * attr_vector) const1033 Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1034   *attr_vector = {{"output_types", output_dtypes_},
1035                   {"output_shapes", output_shapes_},
1036                   {"replicate_on_split", false},
1037                   {"metadata", ""}};
1038   return OkStatus();
1039 }
1040 
dataset_type() const1041 string RangeDatasetParams::dataset_type() const { return "Range"; }
1042 
GetInputTensors() const1043 std::vector<Tensor> BatchDatasetParams::GetInputTensors() const {
1044   Tensor batch_size = CreateTensor<int64_t>(TensorShape({}), {batch_size_});
1045   Tensor drop_remainder =
1046       CreateTensor<bool>(TensorShape({}), {drop_remainder_});
1047   return {batch_size, drop_remainder};
1048 }
1049 
GetInputNames(std::vector<string> * input_names) const1050 Status BatchDatasetParams::GetInputNames(
1051     std::vector<string>* input_names) const {
1052   *input_names = {"input_dataset", "batch_size", "drop_remainder"};
1053   return OkStatus();
1054 }
1055 
GetAttributes(AttributeVector * attr_vector) const1056 Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1057   *attr_vector = {{"parallel_copy", parallel_copy_},
1058                   {"output_types", output_dtypes_},
1059                   {"output_shapes", output_shapes_},
1060                   {"metadata", ""}};
1061   return OkStatus();
1062 }
1063 
dataset_type() const1064 string BatchDatasetParams::dataset_type() const { return "Batch"; }
1065 
GetInputTensors() const1066 std::vector<Tensor> MapDatasetParams::GetInputTensors() const {
1067   return other_arguments_;
1068 }
1069 
GetInputNames(std::vector<string> * input_names) const1070 Status MapDatasetParams::GetInputNames(std::vector<string>* input_names) const {
1071   input_names->emplace_back("input_dataset");
1072   for (int i = 0; i < other_arguments_.size(); ++i) {
1073     input_names->emplace_back(absl::StrCat("other_arguments_", i));
1074   }
1075   return OkStatus();
1076 }
1077 
GetAttributes(AttributeVector * attr_vector) const1078 Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1079   *attr_vector = {{"f", func_},
1080                   {"Targuments", type_arguments_},
1081                   {"output_shapes", output_shapes_},
1082                   {"output_types", output_dtypes_},
1083                   {"use_inter_op_parallelism", use_inter_op_parallelism_},
1084                   {"preserve_cardinality", preserve_cardinality_},
1085                   {"metadata", ""}};
1086   return OkStatus();
1087 }
1088 
dataset_type() const1089 string MapDatasetParams::dataset_type() const { return "Map"; }
1090 
// Functions (e.g. the map function) that must be registered with the
// function library runtime before this dataset can run.
std::vector<FunctionDef> MapDatasetParams::func_lib() const {
  return func_lib_;
}
1094 
// Builds params for a TensorSliceDataset over `components`; the output
// dtypes/shapes are derived from the components (dropping the slice
// dimension). `is_files` marks components that contain file names.
TensorSliceDatasetParams::TensorSliceDatasetParams(
    std::vector<Tensor> components, string node_name, bool is_files)
    : DatasetParams(TensorSliceDtypes(components),
                    TensorSliceShapes(components), std::move(node_name)),
      components_(std::move(components)),
      is_files_(is_files) {}
1101 
// The component tensors are the TensorSliceDataset op's value inputs.
std::vector<Tensor> TensorSliceDatasetParams::GetInputTensors() const {
  return components_;
}
1105 
GetInputNames(std::vector<string> * input_names) const1106 Status TensorSliceDatasetParams::GetInputNames(
1107     std::vector<string>* input_names) const {
1108   input_names->reserve(components_.size());
1109   for (int i = 0; i < components_.size(); ++i) {
1110     input_names->emplace_back(absl::StrCat("components_", i));
1111   }
1112   return OkStatus();
1113 }
1114 
GetAttributes(AttributeVector * attr_vector) const1115 Status TensorSliceDatasetParams::GetAttributes(
1116     AttributeVector* attr_vector) const {
1117   *attr_vector = {{"Toutput_types", output_dtypes_},
1118                   {"output_shapes", output_shapes_},
1119                   {"is_files", is_files_},
1120                   {"replicate_on_split", false},
1121                   {"metadata", ""}};
1122   return OkStatus();
1123 }
1124 
TensorSliceDtypes(const std::vector<Tensor> & input_components)1125 DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes(
1126     const std::vector<Tensor>& input_components) {
1127   DataTypeVector dtypes;
1128   for (const auto& component : input_components) {
1129     dtypes.emplace_back(component.dtype());
1130   }
1131   return dtypes;
1132 }
1133 
TensorSliceShapes(const std::vector<Tensor> & input_components)1134 std::vector<PartialTensorShape> TensorSliceDatasetParams::TensorSliceShapes(
1135     const std::vector<Tensor>& input_components) {
1136   std::vector<PartialTensorShape> shapes;
1137   for (const auto& component : input_components) {
1138     gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
1139     for (int i = 1; i < component.dims(); ++i) {
1140       partial_dim_sizes.push_back(component.dim_size(i));
1141     }
1142     shapes.emplace_back(std::move(partial_dim_sizes));
1143   }
1144   return shapes;
1145 }
1146 
dataset_type() const1147 string TensorSliceDatasetParams::dataset_type() const { return "TensorSlice"; }
1148 
GetInputTensors() const1149 std::vector<Tensor> TakeDatasetParams::GetInputTensors() const {
1150   return {CreateTensor<int64_t>(TensorShape({}), {count_})};
1151 }
1152 
GetInputNames(std::vector<string> * input_names) const1153 Status TakeDatasetParams::GetInputNames(
1154     std::vector<string>* input_names) const {
1155   *input_names = {"input_dataset", "count"};
1156   return OkStatus();
1157 }
1158 
GetAttributes(AttributeVector * attr_vector) const1159 Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1160   *attr_vector = {{"output_shapes", output_shapes_},
1161                   {"output_types", output_dtypes_},
1162                   {"metadata", ""}};
1163   return OkStatus();
1164 }
1165 
dataset_type() const1166 string TakeDatasetParams::dataset_type() const { return "Take"; }
1167 
GetInputTensors() const1168 std::vector<Tensor> ConcatenateDatasetParams::GetInputTensors() const {
1169   return {};
1170 }
1171 
GetInputNames(std::vector<string> * input_names) const1172 Status ConcatenateDatasetParams::GetInputNames(
1173     std::vector<string>* input_names) const {
1174   *input_names = {"input_dataset", "another_dataset"};
1175   return OkStatus();
1176 }
1177 
GetAttributes(AttributeVector * attr_vector) const1178 Status ConcatenateDatasetParams::GetAttributes(
1179     AttributeVector* attr_vector) const {
1180   *attr_vector = {{"output_types", output_dtypes_},
1181                   {"output_shapes", output_shapes_},
1182                   {"metadata", ""}};
1183   return OkStatus();
1184 }
1185 
dataset_type() const1186 string ConcatenateDatasetParams::dataset_type() const { return "Concatenate"; }
1187 
GetInputTensors() const1188 std::vector<Tensor> OptionsDatasetParams::GetInputTensors() const { return {}; }
1189 
GetInputNames(std::vector<string> * input_names) const1190 Status OptionsDatasetParams::GetInputNames(
1191     std::vector<string>* input_names) const {
1192   input_names->emplace_back("input_dataset");
1193   return OkStatus();
1194 }
1195 
GetAttributes(AttributeVector * attr_vector) const1196 Status OptionsDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1197   *attr_vector = {{"serialized_options", serialized_options_},
1198                   {"output_shapes", output_shapes_},
1199                   {"output_types", output_dtypes_},
1200                   {"metadata", ""}};
1201   return OkStatus();
1202 }
1203 
dataset_type() const1204 string OptionsDatasetParams::dataset_type() const { return "Options"; }
1205 
1206 }  // namespace data
1207 }  // namespace tensorflow
1208