1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/serialization_utils.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/core/common_runtime/device_factory.h"
25 #include "tensorflow/core/data/dataset_test_base.h"
26 #include "tensorflow/core/data/dataset_utils.h"
27 #include "tensorflow/core/framework/dataset.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function.pb.h"
30 #include "tensorflow/core/framework/node_def_builder.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/types.pb.h"
34 #include "tensorflow/core/framework/variant.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/str_util.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/protobuf/error_codes.pb.h"
40 #include "tensorflow/core/public/session_options.h"
41 #include "tensorflow/core/public/version.h"
42 #include "tensorflow/core/util/work_sharder.h"
43
44 namespace tensorflow {
45 namespace data {
46 namespace {
47
48 class TestContext {
49 public:
Create(std::unique_ptr<TestContext> * result)50 static Status Create(std::unique_ptr<TestContext>* result) {
51 *result = absl::WrapUnique<TestContext>(new TestContext());
52
53 SessionOptions options;
54 auto* device_count = options.config.mutable_device_count();
55 device_count->insert({"CPU", 1});
56 std::vector<std::unique_ptr<Device>> devices;
57 TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
58 options, "/job:localhost/replica:0/task:0", &devices));
59 (*result)->device_mgr_ =
60 std::make_unique<StaticDeviceMgr>(std::move(devices));
61
62 FunctionDefLibrary proto;
63 (*result)->lib_def_ = std::make_unique<FunctionLibraryDefinition>(
64 OpRegistry::Global(), proto);
65
66 OptimizerOptions opts;
67 (*result)->pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
68 (*result)->device_mgr_.get(), Env::Default(), /*config=*/nullptr,
69 TF_GRAPH_DEF_VERSION, (*result)->lib_def_.get(), opts);
70 (*result)->runner_ = [](const std::function<void()>& fn) { fn(); };
71 (*result)->params_.function_library =
72 (*result)->pflr_->GetFLR("/device:CPU:0");
73 (*result)->params_.device = (*result)->device_mgr_->ListDevices()[0];
74 (*result)->params_.runner = &(*result)->runner_;
75 (*result)->op_ctx_ =
76 std::make_unique<OpKernelContext>(&(*result)->params_, 0);
77 (*result)->iter_ctx_ =
78 std::make_unique<IteratorContext>((*result)->op_ctx_.get());
79 return OkStatus();
80 }
81
iter_ctx() const82 IteratorContext* iter_ctx() const { return iter_ctx_.get(); }
83
84 private:
85 TestContext() = default;
86
87 std::unique_ptr<DeviceMgr> device_mgr_;
88 std::unique_ptr<FunctionLibraryDefinition> lib_def_;
89 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
90 std::function<void(std::function<void()>)> runner_;
91 OpKernelContext::Params params_;
92 std::unique_ptr<OpKernelContext> op_ctx_;
93 std::unique_ptr<IteratorContext> iter_ctx_;
94 };
95
full_name(string key)96 string full_name(string key) { return FullName("Iterator:", key); }
97
// Writes two int32 elements of different shapes under a shared checkpoint
// prefix, reads them back through an IteratorContext, and checks that every
// component tensor (dtype, shape and all values) survives the round trip.
TEST(SerializationUtilsTest, CheckpointElementsRoundTrip) {
  std::vector<std::vector<Tensor>> elements;
  elements.push_back(CreateTensors<int32>(TensorShape({3}), {{1, 2, 3}}));
  elements.push_back(CreateTensors<int32>(TensorShape({2}), {{4, 5}}));
  VariantTensorDataWriter writer;
  tstring test_prefix = full_name("test_prefix");
  TF_ASSERT_OK(WriteElementsToCheckpoint(&writer, test_prefix, elements));
  std::vector<const VariantTensorData*> data;
  writer.GetData(&data);

  VariantTensorDataReader reader(data);
  std::vector<std::vector<Tensor>> read_elements;

  std::unique_ptr<TestContext> ctx;
  TF_ASSERT_OK(TestContext::Create(&ctx));
  TF_ASSERT_OK(ReadElementsFromCheckpoint(ctx->iter_ctx(), &reader, test_prefix,
                                          &read_elements));
  ASSERT_EQ(elements.size(), read_elements.size());
  for (size_t i = 0; i < elements.size(); ++i) {
    const std::vector<Tensor>& original = elements[i];
    const std::vector<Tensor>& read = read_elements[i];

    ASSERT_EQ(original.size(), read.size());
    for (size_t j = 0; j < original.size(); ++j) {
      // ASSERT (not EXPECT) on dtype and size: flat<int32>() requires the
      // dtype to match, and the value loop below indexes both tensors.
      ASSERT_EQ(original[j].dtype(), read[j].dtype());
      ASSERT_EQ(original[j].NumElements(), read[j].NumElements());
      EXPECT_EQ(original[j].shape().DebugString(),
                read[j].shape().DebugString());
      auto original_values = original[j].flat<int32>();
      auto read_values = read[j].flat<int32>();
      // Compare every value, not just the first, so truncated or corrupted
      // trailing data is caught.
      for (int64_t k = 0; k < original[j].NumElements(); ++k) {
        EXPECT_EQ(original_values(k), read_values(k))
            << "element " << i << ", component " << j << ", index " << k;
      }
    }
  }
}
127
// Round-trips one int64 scalar and one single-element float tensor through
// VariantTensorDataWriter / VariantTensorDataReader using "Iterator:"-qualified
// keys, then checks the values read back match what was written.
TEST(SerializationUtilsTest, VariantTensorDataRoundtrip) {
  Tensor written(DT_FLOAT, {1});
  written.flat<float>()(0) = 2.0f;

  // Write order (scalar first, then tensor) is kept as in the original test.
  VariantTensorDataWriter writer;
  TF_ASSERT_OK(writer.WriteScalar(full_name("Int64"), 24));
  TF_ASSERT_OK(writer.WriteTensor(full_name("Tensor"), written));
  std::vector<const VariantTensorData*> serialized;
  writer.GetData(&serialized);

  VariantTensorDataReader reader(serialized);

  int64_t scalar_out = 0;
  TF_ASSERT_OK(reader.ReadScalar(full_name("Int64"), &scalar_out));
  EXPECT_EQ(scalar_out, 24);

  Tensor tensor_out;
  TF_ASSERT_OK(reader.ReadTensor(full_name("Tensor"), &tensor_out));
  EXPECT_EQ(written.NumElements(), tensor_out.NumElements());
  EXPECT_EQ(written.flat<float>()(0), tensor_out.flat<float>()(0));
}
146
// Reading a key that was never stored must yield NOT_FOUND for each value
// kind the reader supports: int64 scalar, string scalar and tensor.
TEST(SerializationUtilsTest, VariantTensorDataNonExistentKey) {
  // Hand-built serialized data that only knows about "key1".
  VariantTensorData stored;
  strings::StrAppend(&stored.metadata_, "key1", "@@");
  stored.tensors_.push_back(Tensor(DT_INT64, {1}));

  std::vector<const VariantTensorData*> sources{&stored};
  VariantTensorDataReader reader(sources);

  const string missing_key = full_name("NonExistentKey");
  int64_t int_out;
  tstring str_out;
  Tensor tensor_out;

  EXPECT_EQ(error::NOT_FOUND, reader.ReadScalar(missing_key, &int_out).code());
  EXPECT_EQ(error::NOT_FOUND, reader.ReadScalar(missing_key, &str_out).code());
  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadTensor(missing_key, &tensor_out).code());
}
164
// Same round trip as VariantTensorDataRoundtrip, but through the
// (name, key) overloads of the writer and reader instead of pre-qualified keys.
TEST(SerializationUtilsTest, VariantTensorDataRoundtripIteratorName) {
  constexpr char kIteratorName[] = "Iterator";

  Tensor written(DT_FLOAT, {1});
  written.flat<float>()(0) = 2.0f;

  VariantTensorDataWriter writer;
  TF_ASSERT_OK(writer.WriteScalar(kIteratorName, "Int64", 24));
  TF_ASSERT_OK(writer.WriteTensor(kIteratorName, "Tensor", written));
  std::vector<const VariantTensorData*> serialized;
  writer.GetData(&serialized);

  VariantTensorDataReader reader(serialized);

  int64_t scalar_out = 0;
  TF_ASSERT_OK(reader.ReadScalar(kIteratorName, "Int64", &scalar_out));
  EXPECT_EQ(scalar_out, 24);

  Tensor tensor_out;
  TF_ASSERT_OK(reader.ReadTensor(kIteratorName, "Tensor", &tensor_out));
  EXPECT_EQ(written.NumElements(), tensor_out.NumElements());
  EXPECT_EQ(written.flat<float>()(0), tensor_out.flat<float>()(0));
}
183
// (name, key) variant of VariantTensorDataNonExistentKey: every read of an
// absent key must report NOT_FOUND.
TEST(SerializationUtilsTest, VariantTensorDataNonExistentKeyIteratorName) {
  constexpr char kIteratorName[] = "Iterator";
  constexpr char kMissingKey[] = "NonExistentKey";

  // Hand-built serialized data that only knows about "key1".
  VariantTensorData stored;
  strings::StrAppend(&stored.metadata_, "key1", "@@");
  stored.tensors_.push_back(Tensor(DT_INT64, {1}));

  std::vector<const VariantTensorData*> sources{&stored};
  VariantTensorDataReader reader(sources);

  int64_t int_out;
  tstring str_out;
  Tensor tensor_out;

  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadScalar(kIteratorName, kMissingKey, &int_out).code());
  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadScalar(kIteratorName, kMissingKey, &str_out).code());
  EXPECT_EQ(error::NOT_FOUND,
            reader.ReadTensor(kIteratorName, kMissingKey, &tensor_out).code());
}
203
// After GetData() has been called on the writer, a subsequent WriteTensor()
// is expected to return FAILED_PRECONDITION.
TEST(SerializationUtilsTest, VariantTensorDataWriteAfterFlushing) {
  VariantTensorDataWriter writer;
  TF_ASSERT_OK(writer.WriteScalar(full_name("Int64"), 24));

  std::vector<const VariantTensorData*> flushed;
  writer.GetData(&flushed);

  Tensor late_tensor(DT_FLOAT, {1});
  late_tensor.flat<float>()(0) = 2.0f;
  const Status late_write =
      writer.WriteTensor(full_name("Tensor"), late_tensor);
  EXPECT_EQ(error::FAILED_PRECONDITION, late_write.code());
}
214
215 } // namespace
216 } // namespace data
217 } // namespace tensorflow
218