/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"

namespace tensorflow {
namespace data {
namespace experimental {
namespace {

class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // Create a new DenseToSparseBatchDatasetOp::Dataset, insert it in the
    // step-local container, and return it as the output.
    OP_REQUIRES(
        ctx, input->output_dtypes().size() == 1,
        errors::InvalidArgument("DenseToSparseBatchDataset only supports "
                                "inputs with a single component."));

    int64 batch_size;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
    OP_REQUIRES(
        ctx, batch_size > 0,
        errors::InvalidArgument("Batch size must be greater than zero."));

    const Tensor* row_shape_t;
    OP_REQUIRES_OK(ctx, ctx->input("row_shape", &row_shape_t));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(row_shape_t->shape()),
                errors::InvalidArgument("row_shape must be a vector"));
    PartialTensorShape row_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            row_shape_t->vec<int64>().data(),
                            row_shape_t->NumElements(), &row_shape));

    *output = nullptr;

#define HANDLE_TYPE(T)                                           \
  case DataTypeToEnum<T>::value: {                               \
    *output = new Dataset<T>(ctx, batch_size, row_shape, input); \
    break;                                                       \
  }

    switch (input->output_dtypes()[0]) {
      TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
      default:
        OP_REQUIRES(ctx, false,
                    errors::Unimplemented(
                        "DenseToSparseBatchDataset unhandled data type: ",
                        input->output_dtypes()[0]));
    }
  }

 private:
  // TODO(mrry): Push the templated code down to the raw copying routine.
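  // A `Dataset` that batches up to `batch_size_` consecutive elements of
  // `input_` into a single SparseTensor per output element. Each output
  // element is emitted as a length-3 DT_VARIANT vector holding the
  // SparseTensor's `indices`, `values`, and `dense_shape` components (see
  // `Iterator::GetNextInternal` below).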
  template <class T>
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, int64 batch_size,
            const PartialTensorShape& row_shape, const DatasetBase* input)
        : DatasetBase(DatasetContext(ctx)),
          batch_size_(batch_size),
          row_shape_(row_shape),
          input_(input) {
      input_->Ref();

      output_shapes_.reserve(1);
      PartialTensorShape output_shape({-1});
      output_shape.AppendShape(row_shape_);
      output_shapes_.push_back(output_shape);
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(typename Iterator::Params{
          this, strings::StrCat(prefix, "::DenseToSparseBatch")});
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
      return *output_dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return strings::StrCat("DenseToSparseBatchDatasetOp(", batch_size_,
                             ")::Dataset");
    }

    int64 Cardinality() const override {
      int64 n = input_->Cardinality();
      if (n == kInfiniteCardinality || n == kUnknownCardinality) {
        return n;
      }
      return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
    }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return Status::OK();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* batch_size_node;
      TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
      Node* row_shape_node;
      std::vector<int64> row_shape;
      row_shape.reserve(
          row_shape_.dims());  // not an unknown rank PartialTensorShape
      for (int i = 0; i < row_shape_.dims(); i++)
        row_shape.emplace_back(row_shape_.dim_size(i));
      TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_node, batch_size_node, row_shape_node}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset<T>> {
     public:
      explicit Iterator(const typename Iterator::Params& params)
          : DatasetIterator<Dataset<T>>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
            ctx, this, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
      }

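      // Assembles the next output element as a sparse (COO) batch:
      //   1. Pulls up to `batch_size_` elements from `input_impl_`, checking
      //      that each element's rank matches `row_shape` and that no
      //      dimension exceeds a fixed (non -1) entry of `row_shape`.
      //   2. Grows any -1 entries of `dense_shape` to the maximum size seen
      //      in the corresponding dimension across the batch.
      //   3. Copies every value into `values` and its coordinates (batch
      //      index first) into `indices`.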
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        // Each row of the output SparseTensor is an individual tensor
        // from the input iterator.
        std::vector<Tensor> batch_elements;
        int64 total_elements = 0;
        batch_elements.reserve(
            DatasetIterator<Dataset<T>>::dataset()->batch_size_);
        const PartialTensorShape& row_shape =
            DatasetIterator<Dataset<T>>::dataset()->row_shape_;
        const int row_ndims = row_shape.dims();

        // Determine the size of the output tensors:
        // * dense_shape will be [`row_shape + 1`].
        Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1});
        auto dense_shape_vec = dense_shape.vec<int64>();
        for (size_t i = 0; i < row_ndims; ++i) {
          if (row_shape.dim_size(i) == -1) {
            dense_shape_vec(i + 1) = 0;
          } else {
            dense_shape_vec(i + 1) = row_shape.dim_size(i);
          }
        }

        {
          mutex_lock l(mu_);
          *end_of_sequence = false;
          for (int i = 0;
               i < DatasetIterator<Dataset<T>>::dataset()->batch_size_ &&
               !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              DCHECK_EQ(1, batch_element_tuple.size());
              batch_elements.push_back(std::move(batch_element_tuple[0]));
              total_elements += batch_element_tuple[0].NumElements();

              // TODO(mrry): Investigate how to hoist this check when we
              // have static information that renders it unnecessary.
              if (batch_element_tuple[0].shape().dims() != row_ndims) {
                return errors::InvalidArgument(
                    "Input element had shape (",
                    batch_element_tuple[0].shape().DebugString(),
                    ") that is incompatible with the row shape (",
                    row_shape.DebugString(), ").");
              }
              for (int j = 0; j < row_ndims; ++j) {
                // Take the maximum in the dimension if -1 is given.
                if (row_shape.dim_size(j) == -1) {
                  dense_shape_vec(j + 1) =
                      std::max(batch_element_tuple[0].dim_size(j),
                               dense_shape_vec(j + 1));
                } else if (batch_element_tuple[0].dim_size(j) >
                           row_shape.dim_size(j)) {
                  return errors::DataLoss(
                      "Input element had shape (",
                      batch_element_tuple[0].shape().DebugString(),
                      ") that is larger than the row shape (",
                      row_shape.DebugString(), ").");
                }
              }
            }
          }
        }

        if (batch_elements.empty()) {
          DCHECK(*end_of_sequence);
          return Status::OK();
        }

        // * indices will be [`total_elements`, `row_shape + 1`].
        // * values will be [`total_elements`].
        Tensor indices(ctx->allocator({}), DT_INT64,
                       {total_elements, row_ndims + 1});
        Tensor values(
            ctx->allocator({}),
            DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0],
            {total_elements});
        auto indices_matrix = indices.matrix<int64>();
        auto values_flat = values.flat<T>();

        int64 current_position_in_values = 0;
        for (int64 i = 0; i < batch_elements.size(); ++i) {
          const Tensor& t = batch_elements[i];
          const auto& t_flat = t.flat<T>();
          // TODO(mrry): Replace with a memcpy or something more
          // efficient. (Maybe an Eigen assign op?)
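          // Row-major strides for element `t`: strides[k] is the number of
          // values spanned by one step along dimension k, so the flat index
          // `j` below can be decomposed into per-dimension coordinates by
          // repeated division. For example, a [2, 3] element has strides
          // {3, 1}, and flat index 4 maps to coordinates (1, 1).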
          gtl::InlinedVector<int64, 4> strides(row_ndims);
          if (!strides.empty()) {
            strides[row_ndims - 1] = 1;
            for (int64_t row_dim = strides.size() - 2; row_dim >= 0;
                 --row_dim) {
              strides[row_dim] =
                  strides[row_dim + 1] * t.shape().dim_size(row_dim + 1);
            }
          }

          for (int64 j = 0; j < t.NumElements(); ++j) {
            values_flat(current_position_in_values) = t_flat(j);
            indices_matrix(current_position_in_values, 0) = i;
            int64 index = j;
            for (size_t k = 0; k < strides.size(); ++k) {
              indices_matrix(current_position_in_values, k + 1) =
                  index / strides[k];
              index %= strides[k];
            }
            ++current_position_in_values;
          }
        }

        dense_shape_vec(0) = batch_elements.size();

        Tensor serialized_sparse(DT_VARIANT, TensorShape({3}));
        auto serialized_sparse_t = serialized_sparse.vec<Variant>();
        serialized_sparse_t(0) = std::move(indices);
        serialized_sparse_t(1) = std::move(values);
        serialized_sparse_t(2) = std::move(dense_shape);
        out_tensors->push_back(std::move(serialized_sparse));

        *end_of_sequence = false;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(
            std::move(args),
            DatasetIterator<Dataset<T>>::dataset()->batch_size_);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(Iterator::SaveInput(ctx, writer, input_impl_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(Iterator::RestoreInput(ctx, reader, input_impl_));
        return Status::OK();
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
    };

    const int64 batch_size_;
    const PartialTensorShape row_shape_;
    const DatasetBase* const input_;
    std::vector<PartialTensorShape> output_shapes_;
  };
};

REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
                        DenseToSparseBatchDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalDenseToSparseBatchDataset").Device(DEVICE_CPU),
    DenseToSparseBatchDatasetOp);

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow