1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
16
17 #include <string>
18 #include <utility>
19
20 #include "tensorflow/core/data/dataset_utils.h"
21 #include "tensorflow/core/data/name_utils.h"
22 #include "tensorflow/core/data/split_utils.h"
23 #include "tensorflow/core/framework/partial_tensor_shape.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_util.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/util/batch_util.h"
28
29 namespace tensorflow {
30 namespace data {
31
32 // See documentation in ../../ops/dataset_ops.cc for a high-level
33 // description of the following op.
34
35 /* static */ constexpr const char* const TensorSliceDatasetOp::kDatasetType;
36 /* static */ constexpr const char* const TensorSliceDatasetOp::kComponents;
37 /* static */ constexpr const char* const TensorSliceDatasetOp::kToutputTypes;
38 /* static */ constexpr const char* const TensorSliceDatasetOp::kOutputShapes;
39
40 class TensorSliceDatasetOp::Dataset : public DatasetBase {
41 public:
Dataset(OpKernelContext * ctx,std::vector<Tensor> tensors)42 explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
43 : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
44 for (const Tensor& t : tensors_) {
45 dtypes_.push_back(t.dtype());
46 gtl::InlinedVector<int64, 4> element_dim_sizes;
47 // Handle scalar here. Check that everyone matches here? Or fail
48 // at runtime?
49 for (int i = 1; i < t.dims(); ++i) {
50 element_dim_sizes.push_back(t.dim_size(i));
51 }
52 partial_shapes_.emplace_back(element_dim_sizes);
53 shapes_.emplace_back(std::move(element_dim_sizes));
54 }
55 }
56
MakeIteratorInternal(const string & prefix) const57 std::unique_ptr<IteratorBase> MakeIteratorInternal(
58 const string& prefix) const override {
59 return absl::make_unique<Iterator>(Iterator::Params{
60 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
61 }
62
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * split_providers) const63 Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
64 split_providers) const override {
65 split_providers->push_back(
66 absl::make_unique<IndexSplitProvider>(tensors_[0].dim_size(0)));
67 return Status::OK();
68 }
69
output_dtypes() const70 const DataTypeVector& output_dtypes() const override { return dtypes_; }
71
output_shapes() const72 const std::vector<PartialTensorShape>& output_shapes() const override {
73 return partial_shapes_;
74 }
75
DebugString() const76 string DebugString() const override {
77 return name_utils::DatasetDebugString(kDatasetType);
78 }
79
Cardinality() const80 int64 Cardinality() const override { return tensors_[0].dim_size(0); }
81
InputDatasets(std::vector<const DatasetBase * > * inputs) const82 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
83 return Status::OK();
84 }
85
CheckExternalState() const86 Status CheckExternalState() const override { return Status::OK(); }
87
88 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const89 Status AsGraphDefInternal(SerializationContext* ctx,
90 DatasetGraphDefBuilder* b,
91 Node** output) const override {
92 std::vector<Node*> components;
93 components.reserve(tensors_.size());
94 for (const Tensor& t : tensors_) {
95 Node* node;
96 if (ctx->serialize_data_tensors()) {
97 TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
98 } else {
99 TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
100 DCHECK_NE(ctx->input_list(), nullptr);
101 ctx->input_list()->emplace_back(node->name(), t);
102 }
103 components.emplace_back(node);
104 }
105 AttrValue dtypes;
106 b->BuildAttrValue(dtypes_, &dtypes);
107 TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}},
108 {{kToutputTypes, dtypes}}, output));
109 return Status::OK();
110 }
111
112 private:
113 class Iterator : public DatasetIterator<Dataset> {
114 public:
Iterator(const Params & params)115 explicit Iterator(const Params& params)
116 : DatasetIterator<Dataset>(params) {}
117
Initialize(IteratorContext * ctx)118 Status Initialize(IteratorContext* ctx) override {
119 if (ctx->split_providers().empty()) {
120 split_provider_ = std::make_shared<IndexSplitProvider>(
121 dataset()->tensors_[0].dim_size(0));
122 } else {
123 TF_ASSIGN_OR_RETURN(split_provider_,
124 GetSingleSplitProvider(ctx, dataset()));
125 }
126 return Status::OK();
127 }
128
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)129 Status GetNextInternal(IteratorContext* ctx,
130 std::vector<Tensor>* out_tensors,
131 bool* end_of_sequence) override {
132 Tensor split;
133 TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
134 if (*end_of_sequence) {
135 return Status::OK();
136 }
137 int64_t index = split.scalar<int64>()();
138 out_tensors->reserve(dataset()->tensors_.size());
139 for (size_t i = 0; i < dataset()->tensors_.size(); ++i) {
140 Tensor slice = dataset()->tensors_[i].SubSlice(index);
141 if (slice.IsAligned()) {
142 out_tensors->push_back(std::move(slice));
143 } else {
144 out_tensors->push_back(tensor::DeepCopy(std::move(slice)));
145 }
146 }
147 *end_of_sequence = false;
148 return Status::OK();
149 }
150
151 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const152 std::shared_ptr<model::Node> CreateNode(
153 IteratorContext* ctx, model::Node::Args args) const override {
154 return model::MakeSourceNode(std::move(args));
155 }
156
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)157 Status SaveInternal(SerializationContext* ctx,
158 IteratorStateWriter* writer) override {
159 return split_provider_->Save(
160 [this](const std::string& key) { return full_name(key); }, writer);
161 }
162
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)163 Status RestoreInternal(IteratorContext* ctx,
164 IteratorStateReader* reader) override {
165 return split_provider_->Restore(
166 [this](const std::string& key) { return full_name(key); }, reader);
167 }
168
169 private:
170 std::shared_ptr<SplitProvider> split_provider_;
171 };
172
173 const std::vector<Tensor> tensors_;
174 DataTypeVector dtypes_;
175 std::vector<TensorShape> shapes_;
176 std::vector<PartialTensorShape> partial_shapes_;
177 };
178
TensorSliceDatasetOp(OpKernelConstruction * ctx)179 TensorSliceDatasetOp::TensorSliceDatasetOp(OpKernelConstruction* ctx)
180 : DatasetOpKernel(ctx) {
181 OP_REQUIRES_OK(ctx, ctx->GetAttr(kToutputTypes, &output_types_));
182 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
183 }
184
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)185 void TensorSliceDatasetOp::MakeDataset(OpKernelContext* ctx,
186 DatasetBase** output) {
187 OpInputList inputs;
188 OP_REQUIRES_OK(ctx, ctx->input_list(kComponents, &inputs));
189 std::vector<Tensor> components;
190 components.reserve(inputs.size());
191 OP_REQUIRES(
192 ctx, inputs[0].dims() > 0,
193 errors::InvalidArgument("All components must be at least 1-dimensional"));
194 const int64_t num_slices = inputs[0].dim_size(0);
195 for (const Tensor& t : inputs) {
196 components.push_back(t);
197 OP_REQUIRES(ctx, t.dims() > 0,
198 errors::InvalidArgument(
199 "All components must be at least 1-dimensional"));
200 OP_REQUIRES(
201 ctx, t.dim_size(0) == num_slices,
202 errors::InvalidArgument(
203 "All components must have the same size in the 0th dimension"));
204 }
205 *output = new Dataset(ctx, std::move(components));
206 OP_REQUIRES_OK(ctx,
207 VerifyTypesMatch((*output)->output_dtypes(), output_types_));
208 OP_REQUIRES_OK(
209 ctx, VerifyShapesCompatible((*output)->output_shapes(), output_shapes_));
210 }
211
212 namespace {
213 REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
214 TensorSliceDatasetOp);
215 } // namespace
216 } // namespace data
217 } // namespace tensorflow
218