• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/serialization_utils.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/core/common_runtime/device_factory.h"
25 #include "tensorflow/core/data/dataset_test_base.h"
26 #include "tensorflow/core/data/dataset_utils.h"
27 #include "tensorflow/core/framework/dataset.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function.pb.h"
30 #include "tensorflow/core/framework/node_def_builder.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/types.pb.h"
34 #include "tensorflow/core/framework/variant.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/str_util.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/protobuf/error_codes.pb.h"
40 #include "tensorflow/core/public/session_options.h"
41 #include "tensorflow/core/public/version.h"
42 #include "tensorflow/core/util/work_sharder.h"
43 
44 namespace tensorflow {
45 namespace data {
46 namespace {
47 
48 class TestContext {
49  public:
Create(std::unique_ptr<TestContext> * result)50   static Status Create(std::unique_ptr<TestContext>* result) {
51     *result = absl::WrapUnique<TestContext>(new TestContext());
52 
53     SessionOptions options;
54     auto* device_count = options.config.mutable_device_count();
55     device_count->insert({"CPU", 1});
56     std::vector<std::unique_ptr<Device>> devices;
57     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
58         options, "/job:localhost/replica:0/task:0", &devices));
59     (*result)->device_mgr_ =
60         std::make_unique<StaticDeviceMgr>(std::move(devices));
61 
62     FunctionDefLibrary proto;
63     (*result)->lib_def_ = std::make_unique<FunctionLibraryDefinition>(
64         OpRegistry::Global(), proto);
65 
66     OptimizerOptions opts;
67     (*result)->pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
68         (*result)->device_mgr_.get(), Env::Default(), /*config=*/nullptr,
69         TF_GRAPH_DEF_VERSION, (*result)->lib_def_.get(), opts);
70     (*result)->runner_ = [](const std::function<void()>& fn) { fn(); };
71     (*result)->params_.function_library =
72         (*result)->pflr_->GetFLR("/device:CPU:0");
73     (*result)->params_.device = (*result)->device_mgr_->ListDevices()[0];
74     (*result)->params_.runner = &(*result)->runner_;
75     (*result)->op_ctx_ =
76         std::make_unique<OpKernelContext>(&(*result)->params_, 0);
77     (*result)->iter_ctx_ =
78         std::make_unique<IteratorContext>((*result)->op_ctx_.get());
79     return OkStatus();
80   }
81 
iter_ctx() const82   IteratorContext* iter_ctx() const { return iter_ctx_.get(); }
83 
84  private:
85   TestContext() = default;
86 
87   std::unique_ptr<DeviceMgr> device_mgr_;
88   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
89   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
90   std::function<void(std::function<void()>)> runner_;
91   OpKernelContext::Params params_;
92   std::unique_ptr<OpKernelContext> op_ctx_;
93   std::unique_ptr<IteratorContext> iter_ctx_;
94 };
95 
full_name(string key)96 string full_name(string key) { return FullName("Iterator:", key); }
97 
// Writes two single-component int32 elements of different lengths with
// WriteElementsToCheckpoint, reads them back with ReadElementsFromCheckpoint,
// and checks that every component round-trips with the same dtype, shape and
// values (not just the first value of each tensor).
TEST(SerializationUtilsTest, CheckpointElementsRoundTrip) {
  std::vector<std::vector<Tensor>> elements;
  elements.push_back(CreateTensors<int32>(TensorShape({3}), {{1, 2, 3}}));
  elements.push_back(CreateTensors<int32>(TensorShape({2}), {{4, 5}}));
  VariantTensorDataWriter writer;
  tstring test_prefix = full_name("test_prefix");
  TF_ASSERT_OK(WriteElementsToCheckpoint(&writer, test_prefix, elements));
  std::vector<const VariantTensorData*> data;
  writer.GetData(&data);

  VariantTensorDataReader reader(data);
  std::vector<std::vector<Tensor>> read_elements;

  std::unique_ptr<TestContext> ctx;
  TF_ASSERT_OK(TestContext::Create(&ctx));
  TF_ASSERT_OK(ReadElementsFromCheckpoint(ctx->iter_ctx(), &reader, test_prefix,
                                          &read_elements));
  ASSERT_EQ(elements.size(), read_elements.size());
  for (size_t i = 0; i < elements.size(); ++i) {
    const std::vector<Tensor>& original = elements[i];
    const std::vector<Tensor>& read = read_elements[i];

    ASSERT_EQ(original.size(), read.size());
    for (size_t j = 0; j < original.size(); ++j) {
      // Assert (not expect) the dtype: flat<int32>() below would abort the
      // process on a dtype mismatch instead of reporting a test failure.
      ASSERT_EQ(original[j].dtype(), read[j].dtype());
      ASSERT_EQ(read[j].dtype(), DT_INT32);
      EXPECT_EQ(original[j].shape().DebugString(),
                read[j].shape().DebugString());
      // Guards the element-wise indexing below.
      ASSERT_EQ(original[j].NumElements(), read[j].NumElements());

      const auto expected = original[j].flat<int32>();
      const auto actual = read[j].flat<int32>();
      for (int64_t k = 0; k < expected.size(); ++k) {
        EXPECT_EQ(expected(k), actual(k))
            << "element " << i << ", component " << j << ", index " << k;
      }
    }
  }
}
127 
// Writes an int64 scalar and a one-element float tensor under full_name()
// keys, then checks both values can be read back unchanged.
TEST(SerializationUtilsTest, VariantTensorDataRoundtrip) {
  Tensor written(DT_FLOAT, {1});
  written.flat<float>()(0) = 2.0f;

  VariantTensorDataWriter writer;
  TF_ASSERT_OK(writer.WriteScalar(full_name("Int64"), 24));
  TF_ASSERT_OK(writer.WriteTensor(full_name("Tensor"), written));

  std::vector<const VariantTensorData*> serialized;
  writer.GetData(&serialized);
  VariantTensorDataReader reader(serialized);

  int64_t scalar_out;
  TF_ASSERT_OK(reader.ReadScalar(full_name("Int64"), &scalar_out));
  EXPECT_EQ(scalar_out, 24);

  Tensor tensor_out;
  TF_ASSERT_OK(reader.ReadTensor(full_name("Tensor"), &tensor_out));
  EXPECT_EQ(written.NumElements(), tensor_out.NumElements());
  EXPECT_EQ(written.flat<float>()(0), tensor_out.flat<float>()(0));
}
146 
// Builds a reader over hand-assembled data (metadata "key1@@", one int64
// tensor) and checks that looking up a key that was never written reports
// NOT_FOUND for int64 scalars, string scalars and tensors alike.
TEST(SerializationUtilsTest, VariantTensorDataNonExistentKey) {
  VariantTensorData data;
  strings::StrAppend(&data.metadata_, "key1", "@@");
  data.tensors_.push_back(Tensor(DT_INT64, {1}));

  const std::vector<const VariantTensorData*> reader_data = {&data};
  VariantTensorDataReader reader(reader_data);

  const auto missing_key = full_name("NonExistentKey");

  int64_t val_int64;
  EXPECT_EQ(error::NOT_FOUND, reader.ReadScalar(missing_key, &val_int64).code());

  tstring val_string;
  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadScalar(missing_key, &val_string).code());

  Tensor val_tensor;
  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadTensor(missing_key, &val_tensor).code());
}
164 
// Same round trip as VariantTensorDataRoundtrip, but through the
// (iterator name, key) overloads of the writer and reader.
TEST(SerializationUtilsTest, VariantTensorDataRoundtripIteratorName) {
  Tensor written(DT_FLOAT, {1});
  written.flat<float>()(0) = 2.0f;

  VariantTensorDataWriter writer;
  TF_ASSERT_OK(writer.WriteScalar("Iterator", "Int64", 24));
  TF_ASSERT_OK(writer.WriteTensor("Iterator", "Tensor", written));

  std::vector<const VariantTensorData*> serialized;
  writer.GetData(&serialized);
  VariantTensorDataReader reader(serialized);

  int64_t scalar_out;
  TF_ASSERT_OK(reader.ReadScalar("Iterator", "Int64", &scalar_out));
  EXPECT_EQ(scalar_out, 24);

  Tensor tensor_out;
  TF_ASSERT_OK(reader.ReadTensor("Iterator", "Tensor", &tensor_out));
  EXPECT_EQ(written.NumElements(), tensor_out.NumElements());
  EXPECT_EQ(written.flat<float>()(0), tensor_out.flat<float>()(0));
}
183 
// Missing-key lookups through the (iterator name, key) overloads must report
// NOT_FOUND for int64 scalars, string scalars and tensors alike.
TEST(SerializationUtilsTest, VariantTensorDataNonExistentKeyIteratorName) {
  VariantTensorData data;
  strings::StrAppend(&data.metadata_, "key1", "@@");
  data.tensors_.push_back(Tensor(DT_INT64, {1}));

  const std::vector<const VariantTensorData*> reader_data = {&data};
  VariantTensorDataReader reader(reader_data);

  int64_t val_int64;
  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadScalar("Iterator", "NonExistentKey", &val_int64).code());

  tstring val_string;
  EXPECT_EQ(
      error::NOT_FOUND,
      reader.ReadScalar("Iterator", "NonExistentKey", &val_string).code());

  Tensor val_tensor;
  EXPECT_EQ(
      error::NOT_FOUND,
      reader.ReadTensor("Iterator", "NonExistentKey", &val_tensor).code());
}
203 
// After GetData() has been called on a writer, a further write is expected to
// be rejected with FAILED_PRECONDITION.
TEST(SerializationUtilsTest, VariantTensorDataWriteAfterFlushing) {
  VariantTensorDataWriter writer;
  TF_ASSERT_OK(writer.WriteScalar(full_name("Int64"), 24));
  std::vector<const VariantTensorData*> flushed;
  writer.GetData(&flushed);

  Tensor late_tensor(DT_FLOAT, {1});
  late_tensor.flat<float>()(0) = 2.0f;
  const Status late_write =
      writer.WriteTensor(full_name("Tensor"), late_tensor);
  EXPECT_EQ(error::FAILED_PRECONDITION, late_write.code());
}
214 
215 }  // namespace
216 }  // namespace data
217 }  // namespace tensorflow
218