• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/dataset.h"
16 #include "tensorflow/core/framework/partial_tensor_shape.h"
17 #include "tensorflow/core/framework/tensor.h"
18 #include "tensorflow/core/framework/variant.h"
19 
20 namespace tensorflow {
21 namespace data {
22 namespace experimental {
23 namespace {
24 
25 class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
26  public:
DenseToSparseBatchDatasetOp(OpKernelConstruction * ctx)27   explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx)
28       : UnaryDatasetOpKernel(ctx) {}
29 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)30   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
31                    DatasetBase** output) override {
32     // Create a new DenseToSparseBatchDatasetOp::Dataset, insert it in the
33     // step-local container, and return it as the output.
34     OP_REQUIRES(
35         ctx, input->output_dtypes().size() == 1,
36         errors::InvalidArgument("DenseToSparseBatchDataset only supports "
37                                 "inputs with a single component."));
38 
39     int64 batch_size;
40     OP_REQUIRES_OK(ctx,
41                    ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
42     OP_REQUIRES(
43         ctx, batch_size > 0,
44         errors::InvalidArgument("Batch size must be greater than zero."));
45 
46     const Tensor* row_shape_t;
47     OP_REQUIRES_OK(ctx, ctx->input("row_shape", &row_shape_t));
48     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(row_shape_t->shape()),
49                 errors::InvalidArgument("row_shape must be a vector"));
50     PartialTensorShape row_shape;
51     OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
52                             row_shape_t->vec<int64>().data(),
53                             row_shape_t->NumElements(), &row_shape));
54 
55     *output = nullptr;
56 
57 #define HANDLE_TYPE(T)                                           \
58   case DataTypeToEnum<T>::value: {                               \
59     *output = new Dataset<T>(ctx, batch_size, row_shape, input); \
60     break;                                                       \
61   }
62 
63     switch (input->output_dtypes()[0]) {
64       TF_CALL_DATASET_TYPES(HANDLE_TYPE);
65 #undef HANDLE_TYPE
66       default:
67         OP_REQUIRES(ctx, false,
68                     errors::Unimplemented(
69                         "DenseToSparseBatchDataset unhandled data type: ",
70                         input->output_dtypes()[0]));
71     }
72   }
73 
74  private:
75   // TODO(mrry): Push the templated code down to the raw copying routine.
76   template <class T>
77   class Dataset : public DatasetBase {
78    public:
Dataset(OpKernelContext * ctx,int64 batch_size,const PartialTensorShape & row_shape,const DatasetBase * input)79     Dataset(OpKernelContext* ctx, int64 batch_size,
80             const PartialTensorShape& row_shape, const DatasetBase* input)
81         : DatasetBase(DatasetContext(ctx)),
82           batch_size_(batch_size),
83           row_shape_(row_shape),
84           input_(input) {
85       input_->Ref();
86 
87       output_shapes_.reserve(1);
88       PartialTensorShape output_shape({-1});
89       output_shape.AppendShape(row_shape_);
90       output_shapes_.push_back(output_shape);
91     }
92 
~Dataset()93     ~Dataset() override { input_->Unref(); }
94 
MakeIteratorInternal(const string & prefix) const95     std::unique_ptr<IteratorBase> MakeIteratorInternal(
96         const string& prefix) const override {
97       return absl::make_unique<Iterator>(typename Iterator::Params{
98           this, strings::StrCat(prefix, "::DenseToSparseBatch")});
99     }
100 
output_dtypes() const101     const DataTypeVector& output_dtypes() const override {
102       static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
103       return *output_dtypes;
104     }
105 
output_shapes() const106     const std::vector<PartialTensorShape>& output_shapes() const override {
107       return output_shapes_;
108     }
109 
DebugString() const110     string DebugString() const override {
111       return strings::StrCat("DenseToSparseBatchDatasetOp(", batch_size_,
112                              ")::Dataset");
113     }
114 
Cardinality() const115     int64 Cardinality() const override {
116       int64 n = input_->Cardinality();
117       if (n == kInfiniteCardinality || n == kUnknownCardinality) {
118         return n;
119       }
120       return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
121     }
122 
InputDatasets(std::vector<const DatasetBase * > * inputs) const123     Status InputDatasets(
124         std::vector<const DatasetBase*>* inputs) const override {
125       inputs->push_back(input_);
126       return Status::OK();
127     }
128 
CheckExternalState() const129     Status CheckExternalState() const override {
130       return input_->CheckExternalState();
131     }
132 
133    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const134     Status AsGraphDefInternal(SerializationContext* ctx,
135                               DatasetGraphDefBuilder* b,
136                               Node** output) const override {
137       Node* input_node;
138       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
139       Node* batch_size_node;
140       TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
141       Node* row_shape_node;
142       std::vector<int64> row_shape;
143       row_shape.reserve(
144           row_shape_.dims());  // not an unknown rank PartialTensorShape
145       for (int i = 0; i < row_shape_.dims(); i++)
146         row_shape.emplace_back(row_shape_.dim_size(i));
147       TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
148       TF_RETURN_IF_ERROR(b->AddDataset(
149           this, {input_node, batch_size_node, row_shape_node}, output));
150       return Status::OK();
151     }
152 
153    private:
154     class Iterator : public DatasetIterator<Dataset<T>> {
155      public:
Iterator(const typename Iterator::Params & params)156       explicit Iterator(const typename Iterator::Params& params)
157           : DatasetIterator<Dataset<T>>(params) {}
158 
Initialize(IteratorContext * ctx)159       Status Initialize(IteratorContext* ctx) override {
160         return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
161             ctx, this, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
162       }
163 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)164       Status GetNextInternal(IteratorContext* ctx,
165                              std::vector<Tensor>* out_tensors,
166                              bool* end_of_sequence) override {
167         // Each row of the output SparseTensor is an individual tensor
168         // from the input iterator.
169         std::vector<Tensor> batch_elements;
170         int64 total_elements = 0;
171         batch_elements.reserve(
172             DatasetIterator<Dataset<T>>::dataset()->batch_size_);
173         const PartialTensorShape& row_shape =
174             DatasetIterator<Dataset<T>>::dataset()->row_shape_;
175         const int row_ndims = row_shape.dims();
176 
177         // Determine the size of the output tensors:
178         // * dense_shape will be [`row_shape + 1`].
179         Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1});
180         auto dense_shape_vec = dense_shape.vec<int64>();
181         for (size_t i = 0; i < row_ndims; ++i) {
182           if (row_shape.dim_size(i) == -1) {
183             dense_shape_vec(i + 1) = 0;
184           } else {
185             dense_shape_vec(i + 1) = row_shape.dim_size(i);
186           }
187         }
188 
189         {
190           mutex_lock l(mu_);
191           *end_of_sequence = false;
192           for (int i = 0;
193                i < DatasetIterator<Dataset<T>>::dataset()->batch_size_ &&
194                !*end_of_sequence;
195                ++i) {
196             std::vector<Tensor> batch_element_tuple;
197             TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
198                                                     end_of_sequence));
199             if (!*end_of_sequence) {
200               DCHECK_EQ(1, batch_element_tuple.size());
201               batch_elements.push_back(std::move(batch_element_tuple[0]));
202               total_elements += batch_element_tuple[0].NumElements();
203 
204               // TODO(mrry): Investigate how to hoist this check when we
205               // have static information that renders it unnecessary.
206               if (batch_element_tuple[0].shape().dims() != row_ndims) {
207                 return errors::InvalidArgument(
208                     "Input element had shape (",
209                     batch_element_tuple[0].shape().DebugString(),
210                     ") that is incompatible with the row shape (",
211                     row_shape.DebugString(), ").");
212               }
213               for (int j = 0; j < row_ndims; ++j) {
214                 // Take the maximum in the dimension if -1 is given.
215                 if (row_shape.dim_size(j) == -1) {
216                   dense_shape_vec(j + 1) =
217                       std::max(batch_element_tuple[0].dim_size(j),
218                                dense_shape_vec(j + 1));
219                 } else if (batch_element_tuple[0].dim_size(j) >
220                            row_shape.dim_size(j)) {
221                   return errors::DataLoss(
222                       "Input element had shape (",
223                       batch_element_tuple[0].shape().DebugString(),
224                       ") that is larger than the row shape (",
225                       row_shape.DebugString(), ").");
226                 }
227               }
228             }
229           }
230         }
231 
232         if (batch_elements.empty()) {
233           DCHECK(*end_of_sequence);
234           return Status::OK();
235         }
236 
237         // * indices will be [`total_elements`, `row_shape + 1`].
238         // * values will be [`total_elements`].
239         Tensor indices(ctx->allocator({}), DT_INT64,
240                        {total_elements, row_ndims + 1});
241         Tensor values(
242             ctx->allocator({}),
243             DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0],
244             {total_elements});
245         auto indices_matrix = indices.matrix<int64>();
246         auto values_flat = values.flat<T>();
247 
248         int64 current_position_in_values = 0;
249         for (int64 i = 0; i < batch_elements.size(); ++i) {
250           const Tensor& t = batch_elements[i];
251           const auto& t_flat = t.flat<T>();
252           // TODO(mrry): Replace with a memcpy or something more
253           // efficient. (Maybe an Eigen assign op?)
254           gtl::InlinedVector<int64, 4> strides(row_ndims);
255           if (!strides.empty()) {
256             strides[row_ndims - 1] = 1;
257             for (int64_t row_dim = strides.size() - 2; row_dim >= 0;
258                  --row_dim) {
259               strides[row_dim] =
260                   strides[row_dim + 1] * t.shape().dim_size(row_dim + 1);
261             }
262           }
263 
264           for (int64 j = 0; j < t.NumElements(); ++j) {
265             values_flat(current_position_in_values) = t_flat(j);
266             indices_matrix(current_position_in_values, 0) = i;
267             int64 index = j;
268             for (size_t k = 0; k < strides.size(); ++k) {
269               indices_matrix(current_position_in_values, k + 1) =
270                   index / strides[k];
271               index %= strides[k];
272             }
273             ++current_position_in_values;
274           }
275         }
276 
277         dense_shape_vec(0) = batch_elements.size();
278 
279         Tensor serialized_sparse(DT_VARIANT, TensorShape({3}));
280         auto serialized_sparse_t = serialized_sparse.vec<Variant>();
281         serialized_sparse_t(0) = std::move(indices);
282         serialized_sparse_t(1) = std::move(values);
283         serialized_sparse_t(2) = std::move(dense_shape);
284         out_tensors->push_back(std::move(serialized_sparse));
285 
286         *end_of_sequence = false;
287         return Status::OK();
288       }
289 
290      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const291       std::shared_ptr<model::Node> CreateNode(
292           IteratorContext* ctx, model::Node::Args args) const override {
293         return model::MakeKnownRatioNode(
294             std::move(args),
295             DatasetIterator<Dataset<T>>::dataset()->batch_size_);
296       }
297 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)298       Status SaveInternal(SerializationContext* ctx,
299                           IteratorStateWriter* writer) override {
300         mutex_lock l(mu_);
301         TF_RETURN_IF_ERROR(Iterator::SaveInput(ctx, writer, input_impl_));
302         return Status::OK();
303       }
304 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)305       Status RestoreInternal(IteratorContext* ctx,
306                              IteratorStateReader* reader) override {
307         mutex_lock l(mu_);
308         TF_RETURN_IF_ERROR(Iterator::RestoreInput(ctx, reader, input_impl_));
309         return Status::OK();
310       }
311 
312      private:
313       mutex mu_;
314       std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
315     };
316 
317     const int64 batch_size_;
318     const PartialTensorShape row_shape_;
319     const DatasetBase* const input_;
320     std::vector<PartialTensorShape> output_shapes_;
321   };
322 };
323 
324 REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
325                         DenseToSparseBatchDatasetOp);
326 REGISTER_KERNEL_BUILDER(
327     Name("ExperimentalDenseToSparseBatchDataset").Device(DEVICE_CPU),
328     DenseToSparseBatchDatasetOp);
329 
330 }  // namespace
331 }  // namespace experimental
332 }  // namespace data
333 }  // namespace tensorflow
334