/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"

#include <string>
#include <utility>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const TensorSliceDatasetOp::kDatasetType;
/* static */ constexpr const char* const TensorSliceDatasetOp::kComponents;
/* static */ constexpr const char* const TensorSliceDatasetOp::kToutputTypes;
/* static */ constexpr const char* const TensorSliceDatasetOp::kOutputShapes;

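// A dataset backed by a fixed list of in-memory tensors. Each element is a
// tuple of slices taken at the same index along the first dimension of every
// component tensor.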
class TensorSliceDatasetOp::Dataset : public DatasetBase {
 public:
  explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
      : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
    for (const Tensor& t : tensors_) {
      dtypes_.push_back(t.dtype());
      gtl::InlinedVector<int64, 4> element_dim_sizes;
      // Handle scalar here. Check that everyone matches here? Or fail
      // at runtime?
      for (int i = 1; i < t.dims(); ++i) {
        element_dim_sizes.push_back(t.dim_size(i));
      }
      partial_shapes_.emplace_back(element_dim_sizes);
      shapes_.emplace_back(std::move(element_dim_sizes));
    }
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

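  // Splits are plain indices into the first dimension, so a single
  // IndexSplitProvider over the number of slices is all that is needed.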
  Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
                                split_providers) const override {
    split_providers->push_back(
        absl::make_unique<IndexSplitProvider>(tensors_[0].dim_size(0)));
    return Status::OK();
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return partial_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64 Cardinality() const override { return tensors_[0].dim_size(0); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    return Status::OK();
  }

  Status CheckExternalState() const override { return Status::OK(); }

 protected:
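  // When serializing the dataset, each component tensor is either embedded in
  // the graph directly (when the context allows serializing data tensors) or
  // replaced by a placeholder whose value is captured in the context's input
  // list.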
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    std::vector<Node*> components;
    components.reserve(tensors_.size());
    for (const Tensor& t : tensors_) {
      Node* node;
      if (ctx->serialize_data_tensors()) {
        TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
      } else {
        TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
        DCHECK_NE(ctx->input_list(), nullptr);
        ctx->input_list()->emplace_back(node->name(), t);
      }
      components.emplace_back(node);
    }
    AttrValue dtypes;
    b->BuildAttrValue(dtypes_, &dtypes);
    TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}},
                                     {{kToutputTypes, dtypes}}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

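    // Prefers the split provider supplied by the iterator context (e.g. when
    // splits are coordinated externally); otherwise falls back to a local
    // IndexSplitProvider covering every slice.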
    Status Initialize(IteratorContext* ctx) override {
      if (ctx->split_providers().empty()) {
        split_provider_ = std::make_shared<IndexSplitProvider>(
            dataset()->tensors_[0].dim_size(0));
      } else {
        TF_ASSIGN_OR_RETURN(split_provider_,
                            GetSingleSplitProvider(ctx, dataset()));
      }
      return Status::OK();
    }

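    // Fetches the next slice index from the split provider and emits the
    // corresponding slice of every component tensor.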
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      Tensor split;
      TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
      if (*end_of_sequence) {
        return Status::OK();
      }
      int64_t index = split.scalar<int64>()();
      out_tensors->reserve(dataset()->tensors_.size());
      for (size_t i = 0; i < dataset()->tensors_.size(); ++i) {
        Tensor slice = dataset()->tensors_[i].SubSlice(index);
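        // SubSlice returns a view into the original buffer; if that view is
        // not suitably aligned, fall back to an aligned deep copy below.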
        if (slice.IsAligned()) {
          out_tensors->push_back(std::move(slice));
        } else {
          out_tensors->push_back(tensor::DeepCopy(std::move(slice)));
        }
      }
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeSourceNode(std::move(args));
    }

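    // Checkpointing simply delegates to the split provider, which records how
    // far iteration has progressed.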
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      return split_provider_->Save(
          [this](const std::string& key) { return full_name(key); }, writer);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      return split_provider_->Restore(
          [this](const std::string& key) { return full_name(key); }, reader);
    }

   private:
    std::shared_ptr<SplitProvider> split_provider_;
  };

  const std::vector<Tensor> tensors_;
  DataTypeVector dtypes_;
  std::vector<TensorShape> shapes_;
  std::vector<PartialTensorShape> partial_shapes_;
};

TensorSliceDatasetOp::TensorSliceDatasetOp(OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kToutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

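// Checks that every component is at least 1-dimensional and that all
// components agree on the size of their first dimension before constructing
// the dataset.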
void TensorSliceDatasetOp::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  OpInputList inputs;
  OP_REQUIRES_OK(ctx, ctx->input_list(kComponents, &inputs));
  std::vector<Tensor> components;
  components.reserve(inputs.size());
  OP_REQUIRES(
      ctx, inputs[0].dims() > 0,
      errors::InvalidArgument("All components must be at least 1-dimensional"));
  const int64_t num_slices = inputs[0].dim_size(0);
  for (const Tensor& t : inputs) {
    components.push_back(t);
    OP_REQUIRES(ctx, t.dims() > 0,
                errors::InvalidArgument(
                    "All components must be at least 1-dimensional"));
    OP_REQUIRES(
        ctx, t.dim_size(0) == num_slices,
        errors::InvalidArgument(
            "All components must have the same size in the 0th dimension"));
  }
  *output = new Dataset(ctx, std::move(components));
  OP_REQUIRES_OK(ctx,
                 VerifyTypesMatch((*output)->output_dtypes(), output_types_));
  OP_REQUIRES_OK(
      ctx, VerifyShapesCompatible((*output)->output_shapes(), output_shapes_));
}

namespace {
REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
                        TensorSliceDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow