/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
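//
// In brief: the op batches `batch_size` consecutive elements of its
// single-component input dataset into one SparseTensor, emitted as a
// length-3 DT_VARIANT vector holding {indices, values, dense_shape}.
// Dimensions of `row_shape` given as -1 are sized to the largest
// corresponding dimension observed in the batch.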

class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // Create a new DenseToSparseBatchDatasetOp::Dataset, insert it in the
    // step-local container, and return it as the output.
    OP_REQUIRES(
        ctx, input->output_dtypes().size() == 1,
        errors::InvalidArgument("DenseToSparseBatchDataset only supports "
                                "inputs with a single component."));

    int64 batch_size;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
    OP_REQUIRES(
        ctx, batch_size > 0,
        errors::InvalidArgument("Batch size must be greater than zero."));

    const Tensor* row_shape_t;
    OP_REQUIRES_OK(ctx, ctx->input("row_shape", &row_shape_t));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(row_shape_t->shape()),
                errors::InvalidArgument("row_shape must be a vector"));
    PartialTensorShape row_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            row_shape_t->vec<int64>().data(),
                            row_shape_t->NumElements(), &row_shape));

    *output = nullptr;

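// HANDLE_TYPE expands to one `switch` case per dtype in TF_CALL_DATASET_TYPES,
// constructing a Dataset<T> specialized to the input's single component type.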
#define HANDLE_TYPE(T)                                           \
  case DataTypeToEnum<T>::value: {                               \
    *output = new Dataset<T>(ctx, batch_size, row_shape, input); \
    break;                                                       \
  }

    switch (input->output_dtypes()[0]) {
      TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
      default:
        OP_REQUIRES(ctx, false,
                    errors::Unimplemented(
                        "DenseToSparseBatchDataset unhandled data type: ",
                        input->output_dtypes()[0]));
    }
  }

 private:
  // TODO(mrry): Push the templated code down to the raw copying routine.
  template <class T>
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, int64 batch_size,
            const PartialTensorShape& row_shape, const DatasetBase* input)
        : DatasetBase(DatasetContext(ctx)),
          batch_size_(batch_size),
          row_shape_(row_shape),
          input_(input) {
      input_->Ref();

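      // The single declared output shape is [-1] + row_shape: an unknown
      // leading (batch) dimension followed by the possibly-partial row shape.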
      output_shapes_.reserve(1);
      PartialTensorShape output_shape({-1});
      output_shape.AppendShape(row_shape_);
      output_shapes_.push_back(output_shape);
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(typename Iterator::Params{
          this, strings::StrCat(prefix, "::DenseToSparseBatch")});
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
      return *output_dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return strings::StrCat("DenseToSparseBatchDatasetOp(", batch_size_,
                             ")::Dataset");
    }

    int64 Cardinality() const override {
      int64 n = input_->Cardinality();
      if (n == kInfiniteCardinality || n == kUnknownCardinality) {
        return n;
      }
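      // Ceiling division: a final partial batch still yields one element.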
      return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* batch_size_node;
      TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
      Node* row_shape_node;
      std::vector<int64> row_shape;
      row_shape.reserve(
          row_shape_.dims());  // not an unknown rank PartialTensorShape
      for (int i = 0; i < row_shape_.dims(); i++)
        row_shape.emplace_back(row_shape_.dim_size(i));
      TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_node, batch_size_node, row_shape_node}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset<T>> {
     public:
      explicit Iterator(const typename Iterator::Params& params)
          : DatasetIterator<Dataset<T>>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
            ctx, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        // Each row of the output SparseTensor is an individual tensor
        // from the input iterator.
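        //
        // Illustrative example: with batch_size = 2, row_shape = [-1], and
        // input elements [1, 2, 3] and [4, 5], the emitted components are
        //   indices     = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
        //   values      = [1, 2, 3, 4, 5]
        //   dense_shape = [2, 3]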
        std::vector<Tensor> batch_elements;
        int64 total_elements = 0;
        batch_elements.reserve(
            DatasetIterator<Dataset<T>>::dataset()->batch_size_);
        const PartialTensorShape& row_shape =
            DatasetIterator<Dataset<T>>::dataset()->row_shape_;
        const int row_ndims = row_shape.dims();

        // Determine the size of the output tensors:
        // * dense_shape will be a vector of length `row_ndims + 1`.
        Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1});
        auto dense_shape_vec = dense_shape.vec<int64>();
        for (size_t i = 0; i < row_ndims; ++i) {
          if (row_shape.dim_size(i) == -1) {
            dense_shape_vec(i + 1) = 0;
          } else {
            dense_shape_vec(i + 1) = row_shape.dim_size(i);
          }
        }

        {
          mutex_lock l(mu_);
          *end_of_sequence = false;
          for (int i = 0;
               i < DatasetIterator<Dataset<T>>::dataset()->batch_size_ &&
               !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              DCHECK_EQ(1, batch_element_tuple.size());
              batch_elements.push_back(std::move(batch_element_tuple[0]));
              total_elements += batch_element_tuple[0].NumElements();

              // TODO(mrry): Investigate how to hoist this check when we
              // have static information that renders it unnecessary.
              if (batch_element_tuple[0].shape().dims() != row_ndims) {
                return errors::InvalidArgument(
                    "Input element had shape (",
                    batch_element_tuple[0].shape().DebugString(),
                    ") that is incompatible with the row shape (",
                    row_shape.DebugString(), ").");
              }
              for (int j = 0; j < row_ndims; ++j) {
                // Take the maximum in the dimension if -1 is given.
                if (row_shape.dim_size(j) == -1) {
                  dense_shape_vec(j + 1) =
                      std::max(batch_element_tuple[0].dim_size(j),
                               dense_shape_vec(j + 1));
                } else if (batch_element_tuple[0].dim_size(j) >
                           row_shape.dim_size(j)) {
                  return errors::DataLoss(
                      "Input element had shape (",
                      batch_element_tuple[0].shape().DebugString(),
                      ") that is larger than the row shape (",
                      row_shape.DebugString(), ").");
                }
              }
            }
          }
        }

        if (batch_elements.empty()) {
          DCHECK(*end_of_sequence);
          return Status::OK();
        }

        // * indices will be [`total_elements`, `row_ndims + 1`].
        // * values will be [`total_elements`].
        Tensor indices(ctx->allocator({}), DT_INT64,
                       {total_elements, row_ndims + 1});
        Tensor values(
            ctx->allocator({}),
            DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0],
            {total_elements});
        auto indices_matrix = indices.matrix<int64>();
        auto values_flat = values.flat<T>();

        int64 current_position_in_values = 0;
        for (int64 i = 0; i < batch_elements.size(); ++i) {
          const Tensor& t = batch_elements[i];
          const auto& t_flat = t.flat<T>();
          // TODO(mrry): Replace with a memcpy or something more
          // efficient. (Maybe an Eigen assign op?)
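          // Row-major strides of this element's shape, used below to unflatten
          // the linear index j into per-dimension sparse indices.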
          gtl::InlinedVector<int64, 4> strides(row_ndims);
          if (!strides.empty()) {
            strides[row_ndims - 1] = 1;
            for (int64_t row_dim = strides.size() - 2; row_dim >= 0;
                 --row_dim) {
              strides[row_dim] =
                  strides[row_dim + 1] * t.shape().dim_size(row_dim + 1);
            }
          }

          for (int64 j = 0; j < t.NumElements(); ++j) {
            values_flat(current_position_in_values) = t_flat(j);
            indices_matrix(current_position_in_values, 0) = i;
            int64 index = j;
            for (size_t k = 0; k < strides.size(); ++k) {
              indices_matrix(current_position_in_values, k + 1) =
                  index / strides[k];
              index %= strides[k];
            }
            ++current_position_in_values;
          }
        }

        dense_shape_vec(0) = batch_elements.size();

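        // Pack the SparseTensor components {indices, values, dense_shape}
        // into a single length-3 DT_VARIANT vector, the dataset's only
        // output component.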
        Tensor serialized_sparse(DT_VARIANT, TensorShape({3}));
        auto serialized_sparse_t = serialized_sparse.vec<Variant>();
        serialized_sparse_t(0) = std::move(indices);
        serialized_sparse_t(1) = std::move(values);
        serialized_sparse_t(2) = std::move(dense_shape);
        out_tensors->push_back(std::move(serialized_sparse));

        *end_of_sequence = false;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(
            std::move(args),
            DatasetIterator<Dataset<T>>::dataset()->batch_size_);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(Iterator::SaveInput(writer, input_impl_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(Iterator::RestoreInput(ctx, reader, input_impl_));
        return Status::OK();
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    };

    const int64 batch_size_;
    const PartialTensorShape row_shape_;
    const DatasetBase* const input_;
    std::vector<PartialTensorShape> output_shapes_;
  };
};

REGISTER_KERNEL_BUILDER(
    Name("ExperimentalDenseToSparseBatchDataset").Device(DEVICE_CPU),
    DenseToSparseBatchDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow