• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See core/ops/sparse_ops.cc for documentation.
17 //
// NOTE: the operations in this file are only suitable for execution
// on CPUs.
20 
21 #define EIGEN_USE_THREADS
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 
25 #include <numeric>
26 #include <sstream>
27 #include <string>
28 #include <unordered_map>
29 #include <utility>
30 
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/types.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
37 #include "tensorflow/core/lib/strings/stringprintf.h"
38 #include "tensorflow/core/util/sparse/sparse_tensor.h"
39 
40 namespace tensorflow {
41 
42 // Operator to convert sparse representations to dense.
43 template <typename T, typename Index>
44 class SparseToDense : public OpKernel {
45  public:
SparseToDense(OpKernelConstruction * context)46   explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {
47     OP_REQUIRES_OK(context,
48                    context->GetAttr("validate_indices", &validate_indices_));
49   }
50 
Compute(OpKernelContext * c)51   void Compute(OpKernelContext* c) override {
52     // sparse_indices
53     const Tensor& indices = c->input(0);
54     OP_REQUIRES(c, indices.dims() <= 2,
55                 errors::InvalidArgument(
56                     "sparse_indices should be a scalar, vector, or matrix, "
57                     "got shape ",
58                     indices.shape().DebugString()));
59     const int64 num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
60     const int64 num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;
61 
62     // output_shape
63     const Tensor& output_shape = c->input(1);
64     OP_REQUIRES(
65         c, IsLegacyVector(output_shape.shape()),
66         errors::InvalidArgument("output_shape should be a vector, got shape ",
67                                 output_shape.shape().DebugString()));
68     OP_REQUIRES(c, output_shape.NumElements() == num_dims,
69                 errors::InvalidArgument(
70                     "output_shape has incorrect number of elements: ",
71                     output_shape.NumElements(), " should be: ", num_dims));
72 
73     // sparse_values
74     const Tensor& sparse_values = c->input(2);
75     const int64 num_values = sparse_values.NumElements();
76     OP_REQUIRES(c,
77                 sparse_values.dims() == 0 ||
78                     (sparse_values.dims() == 1 && num_values == num_elems),
79                 errors::InvalidArgument("sparse_values has incorrect shape ",
80                                         sparse_values.shape().DebugString(),
81                                         ", should be [] or [", num_elems, "]"));
82 
83     // default_value
84     const Tensor& default_value = c->input(3);
85     OP_REQUIRES(c, TensorShapeUtils::IsScalar(default_value.shape()),
86                 errors::InvalidArgument("default_value should be a scalar."));
87 
88     auto output_shape_vec = output_shape.flat<Index>();
89     TensorShape output_tensor_shape;
90     OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(output_shape_vec.data(),
91                                                   output_shape_vec.size(),
92                                                   &output_tensor_shape));
93     Tensor* output = nullptr;
94     OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output));
95 
96     TensorShape ix_shape({num_elems, num_dims});
97     Tensor indices_shaped(DT_INT64, ix_shape);
98     if (indices.dtype() == DT_INT64) {
99       CHECK(indices_shaped.CopyFrom(indices, ix_shape));
100     } else {
101       indices_shaped.matrix<int64>() =
102           indices.shaped<Index, 2>(ix_shape.dim_sizes()).template cast<int64>();
103     }
104 
105     // If we received a scalar, we'll need to create a new
106     // tensor with copies of the values as a vec.
107     // TODO(ebrevdo): find a way to avoid this temp allocation.
108     Tensor sparse_values_b;
109 
110     if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
111       OP_REQUIRES_OK(
112           c, c->allocate_temp(DataTypeToEnum<T>::value,
113                               TensorShape({num_elems}), &sparse_values_b));
114       sparse_values_b.vec<T>().setConstant(sparse_values.scalar<T>()());
115     } else {
116       sparse_values_b = sparse_values;
117     }
118 
119     // Assume SparseTensor is lexicographically sorted.
120     gtl::InlinedVector<int64, 8> order(output->shape().dims());
121     std::iota(order.begin(), order.end(), 0);
122     sparse::SparseTensor st;
123     OP_REQUIRES_OK(c,
124                    sparse::SparseTensor::Create(indices_shaped, sparse_values_b,
125                                                 output->shape(), order, &st));
126 
127     if (validate_indices_) {
128       OP_REQUIRES_OK(c, st.IndicesValid());
129     }
130 
131     output->flat<T>().setConstant(default_value.scalar<T>()());
132     OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
133                 errors::InvalidArgument(
134                     "Indices are not valid (out of bounds).  Shape: ",
135                     output->shape().DebugString()));
136   }
137 
138  private:
139   bool validate_indices_;
140 };
141 
142 #define REGISTER_KERNELS(type, index_type)                             \
143   REGISTER_KERNEL_BUILDER(Name("SparseToDense")                        \
144                               .Device(DEVICE_CPU)                      \
145                               .TypeConstraint<type>("T")               \
146                               .TypeConstraint<index_type>("Tindices"), \
147                           SparseToDense<type, index_type>);
148 
149 #define REGISTER_KERNELS_ALL(type) \
150   REGISTER_KERNELS(type, int32);   \
151   REGISTER_KERNELS(type, int64);
152 
153 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL);
154 REGISTER_KERNELS_ALL(bool);
155 REGISTER_KERNELS_ALL(string);
156 
157 #undef REGISTER_KERNELS_ALL
158 #undef REGISTER_KERNELS
159 
160 }  // namespace tensorflow
161