/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
30 template <typename T>
31 class Dataset : public DatasetBase {
32  public:
Dataset(OpKernelContext * ctx,const sparse::SparseTensor & sparse_tensor)33   explicit Dataset(OpKernelContext* ctx,
34                    const sparse::SparseTensor& sparse_tensor)
35       : DatasetBase(DatasetContext(ctx)),
36         sparse_tensor_(sparse_tensor),
37         dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}),
38         shapes_({{-1, sparse_tensor.dims() - 1},
39                  {-1},
40                  {sparse_tensor.dims() - 1}}) {}
41 
MakeIteratorInternal(const string & prefix) const42   std::unique_ptr<IteratorBase> MakeIteratorInternal(
43       const string& prefix) const override {
44     return std::make_unique<Iterator>(typename Iterator::Params{
45         this, strings::StrCat(prefix, "::SparseTensorSlice")});
46   }
47 
output_dtypes() const48   const DataTypeVector& output_dtypes() const override { return dtypes_; }
output_shapes() const49   const std::vector<PartialTensorShape>& output_shapes() const override {
50     return shapes_;
51   }
52 
DebugString() const53   string DebugString() const override {
54     return "SparseTensorSliceDatasetOp::Dataset";
55   }
56 
CardinalityInternal() const57   int64_t CardinalityInternal() const override {
58     return sparse_tensor_.shape()[0];
59   }
60 
InputDatasets(std::vector<const DatasetBase * > * inputs) const61   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
62     return OkStatus();
63   }
64 
CheckExternalState() const65   Status CheckExternalState() const override { return OkStatus(); }
66 
67  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const68   Status AsGraphDefInternal(SerializationContext* ctx,
69                             DatasetGraphDefBuilder* b,
70                             Node** output) const override {
71     Node* indices_node;
72     TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node));
73     Node* value_node;
74     TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.values(), &value_node));
75     Node* dense_shape_node;
76     std::vector<int64_t> dense_shape;
77     dense_shape.reserve(sparse_tensor_.shape().size());
78     for (int i = 0; i < sparse_tensor_.shape().size(); i++)
79       dense_shape.emplace_back(sparse_tensor_.shape()[i]);
80     TF_RETURN_IF_ERROR(b->AddVector(dense_shape, &dense_shape_node));
81     AttrValue val_dtype;
82     b->BuildAttrValue(sparse_tensor_.dtype(), &val_dtype);
83     TF_RETURN_IF_ERROR(
84         b->AddDataset(this, {indices_node, value_node, dense_shape_node},
85                       {{"Tvalues", val_dtype}}, output));
86     return OkStatus();
87   }
88 
89  private:
90   class Iterator : public DatasetIterator<Dataset<T>> {
91    public:
Iterator(const typename Iterator::Params & params)92     explicit Iterator(const typename Iterator::Params& params)
93         : DatasetIterator<Dataset<T>>(params),
94           num_elements_(params.dataset->sparse_tensor_.shape()[0]),
95           dense_shape_(DT_INT64, {params.dataset->sparse_tensor_.dims() - 1}),
96           group_iterable_(params.dataset->sparse_tensor_.group({0})),
97           iter_(group_iterable_.begin()) {
98       for (size_t i = 0; i < dense_shape_.NumElements(); ++i) {
99         dense_shape_.vec<int64_t>()(i) =
100             params.dataset->sparse_tensor_.shape()[i + 1];
101       }
102     }
103 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)104     Status GetNextInternal(IteratorContext* ctx,
105                            std::vector<Tensor>* out_tensors,
106                            bool* end_of_sequence) override {
107       mutex_lock l(mu_);
108       if (i_ == num_elements_) {
109         *end_of_sequence = true;
110         return OkStatus();
111       }
112 
113       out_tensors->clear();
114       out_tensors->reserve(3);
115       const int rank = Iterator::dataset()->sparse_tensor_.dims();
116 
117       if (i_ > next_non_empty_i_ && iter_ != group_iterable_.end()) {
118         // We still have elements to consume from `group_iterable_`
119         // and we have emitted all elements up to and including the
120         // current position.
121         sparse::Group group = *iter_;
122         const auto indices = group.indices();
123         const auto values = group.values<T>();
124         const int64_t num_entries = values.size();
125         next_non_empty_i_ = indices(0, 0);
126 
127         next_indices_ = Tensor(DT_INT64, {num_entries, rank - 1});
128         next_values_ = Tensor(DataTypeToEnum<T>::value, {num_entries});
129 
130         auto next_indices_t = next_indices_.matrix<int64_t>();
131         auto next_values_t = next_values_.vec<T>();
132 
133         for (int64_t i = 0; i < num_entries; ++i) {
134           for (int d = 1; d < rank; ++d) {
135             next_indices_t(i, d - 1) = indices(i, d);
136           }
137           next_values_t(i) = values(i);
138         }
139 
140         ++iter_;
141       }
142       if (i_ == next_non_empty_i_) {
143         // The current position is non-empty in the input
144         // `SparseTensor`, and we have already read the value from the
145         // `GroupIterable`.
146         out_tensors->push_back(std::move(next_indices_));
147         out_tensors->push_back(std::move(next_values_));
148         out_tensors->push_back(dense_shape_);
149         next_non_empty_i_ = kNextNonEmptyUnknown;
150       } else {
151         DCHECK(i_ < next_non_empty_i_ || iter_ == group_iterable_.end());
152         // The current position is empty in the input `SparseTensor`,
153         // so emit empty indices and values.
154         out_tensors->push_back(Tensor(DT_INT64, TensorShape({0, rank - 1})));
155         out_tensors->push_back(Tensor(DataTypeToEnum<T>::value, {0}));
156         out_tensors->push_back(dense_shape_);
157       }
158 
159       ++i_;
160       *end_of_sequence = false;
161       return OkStatus();
162     }
163 
164    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const165     std::shared_ptr<model::Node> CreateNode(
166         IteratorContext* ctx, model::Node::Args args) const override {
167       return model::MakeSourceNode(std::move(args));
168     }
169 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)170     Status SaveInternal(SerializationContext* ctx,
171                         IteratorStateWriter* writer) override {
172       mutex_lock l(mu_);
173       TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_));
174       TF_RETURN_IF_ERROR(
175           writer->WriteScalar(Iterator::full_name("iter_loc"), iter_.loc()));
176       TF_RETURN_IF_ERROR(writer->WriteScalar(
177           Iterator::full_name("next_non_empty_i_"), next_non_empty_i_));
178       if (i_ <= next_non_empty_i_) {
179         TF_RETURN_IF_ERROR(writer->WriteTensor(
180             Iterator::full_name("next_indices_"), next_indices_));
181         TF_RETURN_IF_ERROR(writer->WriteTensor(
182             Iterator::full_name("next_values_"), next_values_));
183       }
184       return OkStatus();
185     }
186 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)187     Status RestoreInternal(IteratorContext* ctx,
188                            IteratorStateReader* reader) override {
189       mutex_lock l(mu_);
190       TF_RETURN_IF_ERROR(reader->ReadScalar(Iterator::full_name("i"), &i_));
191       int64_t iter_loc;
192       TF_RETURN_IF_ERROR(
193           reader->ReadScalar(Iterator::full_name("iter_loc"), &iter_loc));
194       iter_ = group_iterable_.at(iter_loc);
195       TF_RETURN_IF_ERROR(reader->ReadScalar(
196           Iterator::full_name("next_non_empty_i_"), &next_non_empty_i_));
197       if (i_ <= next_non_empty_i_) {
198         TF_RETURN_IF_ERROR(reader->ReadTensor(
199             Iterator::full_name("next_indices_"), &next_indices_));
200         TF_RETURN_IF_ERROR(reader->ReadTensor(
201             Iterator::full_name("next_values_"), &next_values_));
202       }
203       return OkStatus();
204     }
205 
206    private:
207     const int64_t num_elements_;
208 
209     Tensor dense_shape_;
210 
211     mutex mu_;
212     sparse::GroupIterable group_iterable_ TF_GUARDED_BY(mu_);
213     sparse::GroupIterable::IteratorStep iter_ TF_GUARDED_BY(mu_);
214     int64_t i_ TF_GUARDED_BY(mu_) = 0;
215     const int64_t kNextNonEmptyUnknown = -1;
216     int64_t next_non_empty_i_ TF_GUARDED_BY(mu_) = kNextNonEmptyUnknown;
217     Tensor next_indices_ TF_GUARDED_BY(mu_);
218     Tensor next_values_ TF_GUARDED_BY(mu_);
219   };
220 
221   const sparse::SparseTensor sparse_tensor_;
222   const DataTypeVector dtypes_;
223   const std::vector<PartialTensorShape> shapes_;
224 };
225 
226 template <typename T>
227 class SparseTensorSliceDatasetOp : public DatasetOpKernel {
228  public:
SparseTensorSliceDatasetOp(OpKernelConstruction * ctx)229   explicit SparseTensorSliceDatasetOp(OpKernelConstruction* ctx)
230       : DatasetOpKernel(ctx) {}
231 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)232   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
233     // Create a new SparseTensorSliceDatasetOp::Dataset, insert it in
234     // the step container, and return it as the output.
235     const Tensor* indices;
236     OP_REQUIRES_OK(ctx, ctx->input("indices", &indices));
237     const Tensor* values;
238     OP_REQUIRES_OK(ctx, ctx->input("values", &values));
239     const Tensor* dense_shape;
240     OP_REQUIRES_OK(ctx, ctx->input("dense_shape", &dense_shape));
241 
242     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices->shape()),
243                 errors::InvalidArgument("Input indices must be a matrix. Got: ",
244                                         indices->shape().DebugString()));
245     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
246                 errors::InvalidArgument("Input values must be a vector. Got: ",
247                                         values->shape().DebugString()));
248     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(dense_shape->shape()),
249                 errors::InvalidArgument("Input shape must be a vector. Got: ",
250                                         dense_shape->shape().DebugString()));
251     OP_REQUIRES(
252         ctx, values->shape().dim_size(0) == indices->shape().dim_size(0),
253         errors::InvalidArgument(
254             "Number of values must match first dimension of indices. ", "Got ",
255             values->shape().dim_size(0),
256             " values, indices shape: ", indices->shape().DebugString()));
257     OP_REQUIRES(
258         ctx, dense_shape->shape().dim_size(0) == indices->shape().dim_size(1),
259         errors::InvalidArgument(
260             "Number of dimensions must match second dimension of indices. ",
261             "Got ", dense_shape->shape().dim_size(0),
262             " dimensions, indices shape: ", indices->shape().DebugString()));
263     OP_REQUIRES(ctx, dense_shape->NumElements() > 0,
264                 errors::InvalidArgument(
265                     "The shape argument requires at least one element."));
266 
267     // We currently ensure that `sparse_tensor` is ordered in the
268     // batch dimension.
269     // TODO(mrry): Investigate ways to avoid this unconditional check
270     // if we can be sure that the sparse tensor was produced in an
271     // appropriate order (e.g. by `tf.parse_example()` or a Dataset
272     // that batches elements into rows of a SparseTensor).
273     int64_t previous_batch_index = -1;
274     for (int64_t i = 0; i < indices->dim_size(0); ++i) {
275       int64_t next_batch_index = indices->matrix<int64_t>()(i, 0);
276       OP_REQUIRES(
277           ctx, next_batch_index >= previous_batch_index,
278           errors::Unimplemented("The SparseTensor must be ordered in the batch "
279                                 "dimension; handling arbitrarily ordered input "
280                                 "is not currently supported."));
281       previous_batch_index = next_batch_index;
282     }
283     gtl::InlinedVector<int64_t, 8> std_order(dense_shape->NumElements(), 0);
284     TensorShape shape;
285     OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
286                             dense_shape->vec<int64_t>(), &shape));
287     sparse::SparseTensor tensor;
288     OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(*indices, *values, shape,
289                                                      std_order, &tensor));
290     *output = new Dataset<T>(ctx, std::move(tensor));
291   }
292 
293  private:
294 };
295 
296 #define REGISTER_DATASET_KERNEL(type)                           \
297   REGISTER_KERNEL_BUILDER(Name("SparseTensorSliceDataset")      \
298                               .Device(DEVICE_CPU)               \
299                               .TypeConstraint<type>("Tvalues"), \
300                           SparseTensorSliceDatasetOp<type>);
301 
302 TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL);
303 #undef REGISTER_DATASET_KERNEL
304 
}  // namespace
}  // namespace data
}  // namespace tensorflow
308