/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <numeric>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
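//
// For illustration: this dataset slices its input SparseTensor along the
// first (batch) dimension. A hypothetical input with dense_shape [3, 4],
// indices [[0, 1], [0, 3], [2, 2]], and placeholder values [a, b, c] would
// yield three elements, each an (indices, values, dense_shape) triple with
// the batch dimension stripped from the indices:
//
//   element 0: indices [[1], [3]]       values [a, b]  dense_shape [4]
//   element 1: indices [] (shape 0x1)   values []      dense_shape [4]
//   element 2: indices [[2]]            values [c]     dense_shape [4]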

template <typename T>
class Dataset : public DatasetBase {
 public:
  explicit Dataset(OpKernelContext* ctx,
                   const sparse::SparseTensor& sparse_tensor)
      : DatasetBase(DatasetContext(ctx)),
        sparse_tensor_(sparse_tensor),
        dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}),
        shapes_({{-1, sparse_tensor.dims() - 1},
                 {-1},
                 {sparse_tensor.dims() - 1}}) {}

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(typename Iterator::Params{
        this, strings::StrCat(prefix, "::SparseTensorSlice")});
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  string DebugString() const override {
    return "SparseTensorSliceDatasetOp::Dataset";
  }

  int64_t CardinalityInternal() const override {
    return sparse_tensor_.shape()[0];
  }

  Status InputDatasets(
      std::vector<const DatasetBase*>* inputs) const override {
    return OkStatus();
  }

  Status CheckExternalState() const override { return OkStatus(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* indices_node;
    TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node));
    Node* value_node;
    TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.values(), &value_node));
    Node* dense_shape_node;
    std::vector<int64_t> dense_shape;
    dense_shape.reserve(sparse_tensor_.shape().size());
    for (int i = 0; i < sparse_tensor_.shape().size(); i++)
      dense_shape.emplace_back(sparse_tensor_.shape()[i]);
    TF_RETURN_IF_ERROR(b->AddVector(dense_shape, &dense_shape_node));
    AttrValue val_dtype;
    b->BuildAttrValue(sparse_tensor_.dtype(), &val_dtype);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {indices_node, value_node, dense_shape_node},
                      {{"Tvalues", val_dtype}}, output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset<T>> {
   public:
    explicit Iterator(const typename Iterator::Params& params)
        : DatasetIterator<Dataset<T>>(params),
          num_elements_(params.dataset->sparse_tensor_.shape()[0]),
          dense_shape_(DT_INT64, {params.dataset->sparse_tensor_.dims() - 1}),
          group_iterable_(params.dataset->sparse_tensor_.group({0})),
          iter_(group_iterable_.begin()) {
      for (size_t i = 0; i < dense_shape_.NumElements(); ++i) {
        dense_shape_.vec<int64_t>()(i) =
            params.dataset->sparse_tensor_.shape()[i + 1];
      }
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (i_ == num_elements_) {
        *end_of_sequence = true;
        return OkStatus();
      }

      out_tensors->clear();
      out_tensors->reserve(3);
      const int rank = Iterator::dataset()->sparse_tensor_.dims();

      if (i_ > next_non_empty_i_ && iter_ != group_iterable_.end()) {
        // We still have elements to consume from `group_iterable_`
        // and we have emitted all elements up to and including the
        // current position.
        sparse::Group group = *iter_;
        const auto indices = group.indices();
        const auto values = group.values<T>();
        const int64_t num_entries = values.size();
        next_non_empty_i_ = indices(0, 0);

        next_indices_ = Tensor(DT_INT64, {num_entries, rank - 1});
        next_values_ = Tensor(DataTypeToEnum<T>::value, {num_entries});

        auto next_indices_t = next_indices_.matrix<int64_t>();
        auto next_values_t = next_values_.vec<T>();

        for (int64_t i = 0; i < num_entries; ++i) {
          for (int d = 1; d < rank; ++d) {
            next_indices_t(i, d - 1) = indices(i, d);
          }
          next_values_t(i) = values(i);
        }

        ++iter_;
      }
      if (i_ == next_non_empty_i_) {
        // The current position is non-empty in the input
        // `SparseTensor`, and we have already read the value from the
        // `GroupIterable`.
        out_tensors->push_back(std::move(next_indices_));
        out_tensors->push_back(std::move(next_values_));
        out_tensors->push_back(dense_shape_);
        next_non_empty_i_ = kNextNonEmptyUnknown;
      } else {
        DCHECK(i_ < next_non_empty_i_ || iter_ == group_iterable_.end());
        // The current position is empty in the input `SparseTensor`,
        // so emit empty indices and values.
        out_tensors->push_back(Tensor(DT_INT64, TensorShape({0, rank - 1})));
        out_tensors->push_back(Tensor(DataTypeToEnum<T>::value, {0}));
        out_tensors->push_back(dense_shape_);
      }

      ++i_;
      *end_of_sequence = false;
      return OkStatus();
    }
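
    // Checkpointing: `SaveInternal()` below persists the element cursor
    // `i_`, the group iterator's location, and `next_non_empty_i_`. The
    // buffered lookahead tensors `next_indices_`/`next_values_` are
    // serialized only while `i_ <= next_non_empty_i_`, i.e. while a group
    // has been read from `group_iterable_` but not yet emitted.
    // `RestoreInternal()` mirrors this, reseating the group iterator with
    // `group_iterable_.at(iter_loc)`.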

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeSourceNode(std::move(args));
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(Iterator::full_name("iter_loc"), iter_.loc()));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          Iterator::full_name("next_non_empty_i_"), next_non_empty_i_));
      if (i_ <= next_non_empty_i_) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            Iterator::full_name("next_indices_"), next_indices_));
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            Iterator::full_name("next_values_"), next_values_));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(Iterator::full_name("i"), &i_));
      int64_t iter_loc;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(Iterator::full_name("iter_loc"), &iter_loc));
      iter_ = group_iterable_.at(iter_loc);
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          Iterator::full_name("next_non_empty_i_"), &next_non_empty_i_));
      if (i_ <= next_non_empty_i_) {
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            Iterator::full_name("next_indices_"), &next_indices_));
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            Iterator::full_name("next_values_"), &next_values_));
      }
      return OkStatus();
    }

   private:
    const int64_t num_elements_;

    Tensor dense_shape_;

    mutex mu_;
    sparse::GroupIterable group_iterable_ TF_GUARDED_BY(mu_);
    sparse::GroupIterable::IteratorStep iter_ TF_GUARDED_BY(mu_);
    int64_t i_ TF_GUARDED_BY(mu_) = 0;
    const int64_t kNextNonEmptyUnknown = -1;
    int64_t next_non_empty_i_ TF_GUARDED_BY(mu_) = kNextNonEmptyUnknown;
    Tensor next_indices_ TF_GUARDED_BY(mu_);
    Tensor next_values_ TF_GUARDED_BY(mu_);
  };

  const sparse::SparseTensor sparse_tensor_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
};
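
// The op kernel below validates its inputs before constructing the dataset:
// `indices` must be a matrix and `values` and `dense_shape` vectors; the row
// count of `indices` must match the length of `values`; the length of
// `dense_shape` must match the column count of `indices`; and the indices
// must already be sorted along the batch (first) dimension.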

template <typename T>
class SparseTensorSliceDatasetOp : public DatasetOpKernel {
 public:
  explicit SparseTensorSliceDatasetOp(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    // Create a new SparseTensorSliceDatasetOp::Dataset, insert it in
    // the step container, and return it as the output.
    const Tensor* indices;
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices));
    const Tensor* values;
    OP_REQUIRES_OK(ctx, ctx->input("values", &values));
    const Tensor* dense_shape;
    OP_REQUIRES_OK(ctx, ctx->input("dense_shape", &dense_shape));

    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices->shape()),
                errors::InvalidArgument("Input indices must be a matrix. Got: ",
                                        indices->shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
                errors::InvalidArgument("Input values must be a vector. Got: ",
                                        values->shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(dense_shape->shape()),
                errors::InvalidArgument("Input shape must be a vector. Got: ",
                                        dense_shape->shape().DebugString()));
    OP_REQUIRES(
        ctx, values->shape().dim_size(0) == indices->shape().dim_size(0),
        errors::InvalidArgument(
            "Number of values must match first dimension of indices. ", "Got ",
            values->shape().dim_size(0),
            " values, indices shape: ", indices->shape().DebugString()));
    OP_REQUIRES(
        ctx, dense_shape->shape().dim_size(0) == indices->shape().dim_size(1),
        errors::InvalidArgument(
            "Number of dimensions must match second dimension of indices. ",
            "Got ", dense_shape->shape().dim_size(0),
            " dimensions, indices shape: ", indices->shape().DebugString()));
    OP_REQUIRES(ctx, dense_shape->NumElements() > 0,
                errors::InvalidArgument(
                    "The shape argument requires at least one element."));

    // We currently ensure that `sparse_tensor` is ordered in the
    // batch dimension.
    // TODO(mrry): Investigate ways to avoid this unconditional check
    // if we can be sure that the sparse tensor was produced in an
    // appropriate order (e.g. by `tf.parse_example()` or a Dataset
    // that batches elements into rows of a SparseTensor).
    int64_t previous_batch_index = -1;
    for (int64_t i = 0; i < indices->dim_size(0); ++i) {
      int64_t next_batch_index = indices->matrix<int64_t>()(i, 0);
      OP_REQUIRES(
          ctx, next_batch_index >= previous_batch_index,
          errors::Unimplemented("The SparseTensor must be ordered in the batch "
                                "dimension; handling arbitrarily ordered input "
                                "is not currently supported."));
      previous_batch_index = next_batch_index;
    }
    gtl::InlinedVector<int64_t, 8> std_order(dense_shape->NumElements(), 0);
    TensorShape shape;
    OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
                            dense_shape->vec<int64_t>(), &shape));
    sparse::SparseTensor tensor;
    OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(*indices, *values, shape,
                                                     std_order, &tensor));
    *output = new Dataset<T>(ctx, std::move(tensor));
  }
};

#define REGISTER_DATASET_KERNEL(type)                           \
  REGISTER_KERNEL_BUILDER(Name("SparseTensorSliceDataset")      \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<type>("Tvalues"), \
                          SparseTensorSliceDatasetOp<type>);

TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL);
#undef REGISTER_DATASET_KERNEL

}  // namespace
}  // namespace data
}  // namespace tensorflow