1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
16
17 #include "tensorflow/core/framework/partial_tensor_shape.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/graph/graph.h"
20 #include "tensorflow/core/kernels/data/dataset_utils.h"
21 #include "tensorflow/core/kernels/data/name_utils.h"
22 #include "tensorflow/core/kernels/data/split_utils.h"
23 #include "tensorflow/core/util/batch_util.h"
24
25 namespace tensorflow {
26 namespace data {
27
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30
31 /* static */ constexpr const char* const TensorSliceDatasetOp::kDatasetType;
32 /* static */ constexpr const char* const TensorSliceDatasetOp::kComponents;
33 /* static */ constexpr const char* const TensorSliceDatasetOp::kToutputTypes;
34 /* static */ constexpr const char* const TensorSliceDatasetOp::kOutputShapes;
35
36 constexpr char kCurIndex[] = "i";
37
38 class TensorSliceDatasetOp::Dataset : public DatasetBase {
39 public:
Dataset(OpKernelContext * ctx,std::vector<Tensor> tensors)40 explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
41 : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
42 for (const Tensor& t : tensors_) {
43 dtypes_.push_back(t.dtype());
44 gtl::InlinedVector<int64, 4> element_dim_sizes;
45 // Handle scalar here. Check that everyone matches here? Or fail
46 // at runtime?
47 for (int i = 1; i < t.dims(); ++i) {
48 element_dim_sizes.push_back(t.dim_size(i));
49 }
50 partial_shapes_.emplace_back(element_dim_sizes);
51 shapes_.emplace_back(std::move(element_dim_sizes));
52 }
53 }
54
MakeIteratorInternal(const string & prefix) const55 std::unique_ptr<IteratorBase> MakeIteratorInternal(
56 const string& prefix) const override {
57 return absl::make_unique<Iterator>(Iterator::Params{
58 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
59 }
60
MakeSplitProvider(std::unique_ptr<SplitProvider> * split_provider) const61 Status MakeSplitProvider(
62 std::unique_ptr<SplitProvider>* split_provider) const override {
63 *split_provider =
64 absl::make_unique<IndexSplitProvider>(tensors_[0].dim_size(0));
65 return Status::OK();
66 }
67
output_dtypes() const68 const DataTypeVector& output_dtypes() const override { return dtypes_; }
69
output_shapes() const70 const std::vector<PartialTensorShape>& output_shapes() const override {
71 return partial_shapes_;
72 }
73
DebugString() const74 string DebugString() const override {
75 return name_utils::DatasetDebugString(kDatasetType);
76 }
77
Cardinality() const78 int64 Cardinality() const override { return tensors_[0].dim_size(0); }
79
InputDatasets(std::vector<const DatasetBase * > * inputs) const80 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
81 return Status::OK();
82 }
83
CheckExternalState() const84 Status CheckExternalState() const override { return Status::OK(); }
85
86 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const87 Status AsGraphDefInternal(SerializationContext* ctx,
88 DatasetGraphDefBuilder* b,
89 Node** output) const override {
90 std::vector<Node*> components;
91 components.reserve(tensors_.size());
92 for (const Tensor& t : tensors_) {
93 Node* node;
94 if (ctx->serialize_data_tensors()) {
95 TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
96 } else {
97 TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
98 DCHECK_NE(ctx->input_list(), nullptr);
99 ctx->input_list()->emplace_back(node->name(), t);
100 }
101 components.emplace_back(node);
102 }
103 AttrValue dtypes;
104 b->BuildAttrValue(dtypes_, &dtypes);
105 TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}},
106 {{kToutputTypes, dtypes}}, output));
107 return Status::OK();
108 }
109
110 private:
111 class Iterator : public DatasetIterator<Dataset> {
112 public:
Iterator(const Params & params)113 explicit Iterator(const Params& params)
114 : DatasetIterator<Dataset>(params) {}
115
Initialize(IteratorContext * ctx)116 Status Initialize(IteratorContext* ctx) override {
117 split_provider_ = ctx->split_provider();
118 if (split_provider_ == nullptr) {
119 split_provider_ = std::make_shared<IndexSplitProvider>(
120 dataset()->tensors_[0].dim_size(0));
121 }
122 return Status::OK();
123 }
124
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)125 Status GetNextInternal(IteratorContext* ctx,
126 std::vector<Tensor>* out_tensors,
127 bool* end_of_sequence) override {
128 Tensor split;
129 TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
130 if (*end_of_sequence) {
131 return Status::OK();
132 }
133 int64 index = split.scalar<int64>()();
134 out_tensors->clear();
135 out_tensors->reserve(dataset()->tensors_.size());
136 for (size_t i = 0; i < dataset()->tensors_.size(); ++i) {
137 const Tensor& t = dataset()->tensors_[i];
138 out_tensors->emplace_back(ctx->allocator({}), t.dtype(),
139 dataset()->shapes_[i]);
140 TF_RETURN_IF_ERROR(
141 batch_util::CopySliceToElement(t, &out_tensors->back(), index));
142 }
143 *end_of_sequence = false;
144 return Status::OK();
145 }
146
147 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const148 std::shared_ptr<model::Node> CreateNode(
149 IteratorContext* ctx, model::Node::Args args) const override {
150 return model::MakeSourceNode(std::move(args));
151 }
152
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)153 Status SaveInternal(SerializationContext* ctx,
154 IteratorStateWriter* writer) override {
155 return split_provider_->Save(
156 [this](const std::string& key) { return full_name(key); }, writer);
157 }
158
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)159 Status RestoreInternal(IteratorContext* ctx,
160 IteratorStateReader* reader) override {
161 return split_provider_->Restore(
162 [this](const std::string& key) { return full_name(key); }, reader);
163 }
164
165 private:
166 std::shared_ptr<SplitProvider> split_provider_;
167 };
168
169 const std::vector<Tensor> tensors_;
170 DataTypeVector dtypes_;
171 std::vector<TensorShape> shapes_;
172 std::vector<PartialTensorShape> partial_shapes_;
173 };
174
TensorSliceDatasetOp(OpKernelConstruction * ctx)175 TensorSliceDatasetOp::TensorSliceDatasetOp(OpKernelConstruction* ctx)
176 : DatasetOpKernel(ctx) {
177 OP_REQUIRES_OK(ctx, ctx->GetAttr(kToutputTypes, &output_types_));
178 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
179 }
180
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)181 void TensorSliceDatasetOp::MakeDataset(OpKernelContext* ctx,
182 DatasetBase** output) {
183 OpInputList inputs;
184 OP_REQUIRES_OK(ctx, ctx->input_list(kComponents, &inputs));
185 std::vector<Tensor> components;
186 components.reserve(inputs.size());
187 OP_REQUIRES(
188 ctx, inputs[0].dims() > 0,
189 errors::InvalidArgument("All components must be at least 1-dimensional"));
190 const int64 num_slices = inputs[0].dim_size(0);
191 for (const Tensor& t : inputs) {
192 components.push_back(t);
193 OP_REQUIRES(ctx, t.dims() > 0,
194 errors::InvalidArgument(
195 "All components must be at least 1-dimensional"));
196 OP_REQUIRES(
197 ctx, t.dim_size(0) == num_slices,
198 errors::InvalidArgument(
199 "All components must have the same size in the 0th dimension"));
200 }
201 *output = new Dataset(ctx, std::move(components));
202 OP_REQUIRES_OK(ctx,
203 VerifyTypesMatch((*output)->output_dtypes(), output_types_));
204 OP_REQUIRES_OK(
205 ctx, VerifyShapesCompatible((*output)->output_shapes(), output_shapes_));
206 }
207
namespace {
// Registers TensorSliceDatasetOp as the CPU kernel for the
// "TensorSliceDataset" op.
REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
                        TensorSliceDatasetOp);
}  // namespace
212 } // namespace data
213 } // namespace tensorflow
214