/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/concatenate_dataset_op.h"

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const ConcatenateDatasetOp::kDatasetType;
/* static */ constexpr const char* const ConcatenateDatasetOp::kInputDataset;
/* static */ constexpr const char* const ConcatenateDatasetOp::kAnotherDataset;
/* static */ constexpr const char* const ConcatenateDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ConcatenateDatasetOp::kOutputShapes;

constexpr char kIndex[] = "i";
constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";

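// A dataset that produces the elements of `input_` followed by the elements
// of `to_concatenate_`.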
class ConcatenateDatasetOp::Dataset : public DatasetBase {
 public:
  explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
                   const DatasetBase* to_concatenate)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        to_concatenate_(to_concatenate) {
    input_->Ref();
    to_concatenate_->Ref();

    auto os_input = input->output_shapes();
    auto os_concatenate = to_concatenate->output_shapes();
    for (int i = 0; i < os_input.size(); i++) {
      output_shapes_.push_back(
          MostSpecificCompatibleShape(os_input[i], os_concatenate[i]));
    }
  }
  ~Dataset() override {
    input_->Unref();
    to_concatenate_->Unref();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

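  // The concatenation is infinite if either input is infinite, unknown if
  // either input is unknown, and the sum of the two cardinalities otherwise.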
  int64 Cardinality() const override {
    int64 n1 = input_->Cardinality();
    int64 n2 = to_concatenate_->Cardinality();
    if (n1 == kInfiniteCardinality || n2 == kInfiniteCardinality) {
      return kInfiniteCardinality;
    }
    if (n1 == kUnknownCardinality || n2 == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    return n1 + n2;
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    inputs->push_back(to_concatenate_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(input_->CheckExternalState());
    return to_concatenate_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
    Node* to_concatenate_graph = nullptr;
    TF_RETURN_IF_ERROR(
        b->AddInputDataset(ctx, to_concatenate_, &to_concatenate_graph));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
    return Status::OK();
  }

 private:
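  // Iterates over all elements of `input_` and then over all elements of
  // `to_concatenate_`. `i_` tracks which input is active: 0 for the first
  // input, 1 for the second, and 2 once both have been exhausted.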
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), i_(0) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(
          ctx, this, strings::StrCat(prefix(), "[0]"), &input_impl_);
    }

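    // Pulls the next element from the currently active input iterator. When
    // the first input reaches end of sequence, switches to an iterator over
    // `to_concatenate_` and retries before reporting end of sequence.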
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        *end_of_sequence = true;
        return Status::OK();
      }
      while (i_ < 2) {
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (!*end_of_sequence) {
          return Status::OK();
        }
        if (++i_ < 2) {
          TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
              ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
        }
      }
      *end_of_sequence = true;
      input_impl_.reset();
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

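    // Checkpoints the active input index and, if an input iterator is still
    // live, its state; otherwise records that the iterator is uninitialized.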
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kInputImplUninitialized), ""));
      }
      return Status::OK();
    }

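    // Restores the input index, re-creates the matching input iterator (or
    // clears it if both inputs were exhausted), and then restores the
    // iterator's state.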
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &i_));
      if (reader->Contains(full_name(kInputImplUninitialized))) {
        input_impl_.reset();
        return Status::OK();
      }
      if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
        return errors::InvalidArgument("i_ must be in range [0, 2].");
      if (i_ == 1) {
        TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
            ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
      } else if (i_ == 2) {
        input_impl_.reset();
      }
      if (input_impl_) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return Status::OK();
    }

   private:
    mutex mu_;
    int64 i_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

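  // Returns the most specific shape compatible with both inputs: dimensions
  // that agree are preserved and dimensions that differ become unknown, e.g.
  // [2, -1, 3] and [2, 4, 5] yield [2, -1, -1]. If the ranks differ or either
  // rank is unknown, the result has unknown rank.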
  static PartialTensorShape MostSpecificCompatibleShape(
      const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
    if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
      return PartialTensorShape();
    PartialTensorShape output_tensorshape({});
    auto dims1 = ts1.dim_sizes();
    auto dims2 = ts2.dim_sizes();
    for (int d = 0; d < ts1.dims(); d++) {
      if (dims1[d] == dims2[d])
        output_tensorshape.AddDim(dims1[d]);
      else
        output_tensorshape.AddDim(-1);
    }
    return output_tensorshape;
  }

  const DatasetBase* input_;
  const DatasetBase* to_concatenate_;
  std::vector<PartialTensorShape> output_shapes_;
};

ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
    : BinaryDatasetOpKernel(ctx) {}

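// Requires both inputs to produce the same output dtypes; their output shapes
// may differ and are merged component-wise by the Dataset constructor.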
void ConcatenateDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase* to_concatenate,
                                       DatasetBase** output) {
  OP_REQUIRES(ctx, input->output_dtypes() == to_concatenate->output_dtypes(),
              errors::InvalidArgument(
                  "input dataset and dataset to concatenate"
                  " have different output_types ",
                  DataTypeVectorString(input->output_dtypes()), " and ",
                  DataTypeVectorString(to_concatenate->output_dtypes())));
  *output = new Dataset(ctx, input, to_concatenate);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
                        ConcatenateDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow