1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
17 #define TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
18
19 #define EIGEN_USE_THREADS
20
21 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 #define EIGEN_USE_GPU
23 #endif
24
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_types.h"
29 #include "tensorflow/core/framework/variant.h"
30 #include "tensorflow/core/framework/variant_encode_decode.h"
31 #include "tensorflow/core/framework/variant_op_registry.h"
32
33 namespace tensorflow {
34
35 class CSRSparseMatrix {
36 // CreateCSRSparseMatrix is the main method used to construct a
37 // CSRSparseMatrix. The representations for both 2D and 3D
38 // (batched) CSR Sparse Matrices are the same:
39 //
40 // dtype: The datatype of the values.
41 // dense_shape: The dense shape of the matrix.
42 // * Host int64 vector, size 2 or 3.
43 // * Takes on values: (rows, cols) or (batch_size, rows, cols).
44 // batch_pointers: Batch offset pointers into col_indices and values.
45 // * Host int32 vector, size (batch_size + 1).
46 // * Takes on values: (0, nnz[0], nnz[0] + nnz[1], ..., total_nnz).
47 // row_pointers: Row offset pointers into col_indices and values.
48 // * Device int32 vector, size ((rows + 1) * batch_size).
49 // * Each block of size (rows + 1) takes on values:
50 // (0, num_rows{b}[0], num_rows{b}[0] + num_rows{b}[1], ..., nnz[b]).
51 // for b = 0 .. batch_size - 1.
52 // col_indices: Column values for the given row and column index.
53 // * Device int32 vector, size total_nnz.
54 // values: Actual values for the given row and column index.
55 // * Device dtype vector, size total_nnz.
56 //
57 // The storage agreement is such that for a given (batch, row, ix):
58 // offset = batch_pointers(batch) + row_pointers(batch * (rows + 1) + row)
59 // col = col_indices(offset + ix)
60 // val = values(offset + ix)
61 // where ix < #nnz columns in (batch, row).
62 // Then:
63 // matrix(batch, row, col) = val.
64 //
65 // All other elements in the dense representation are treated as 0 / empty.
66 //
67 // For example, for a 2D sparse matrix m shaped (3, 4) such that:
68 //
69 // m[0, 0] = 1.0
70 // m[0, 1] = 2.0
71 // m[0, 2] = 3.0
72 // m[2, 2] = 4.0
73 // m[2, 3] = 5.0
74 //
75 // The corresponding representation is:
76 //
77 // dtype: DT_FLOAT
78 // dense_shape: (3, 4)
79 // batch_pointers: (0, 5)
80 // row_pointers: (0, 3, 3, 5)
81 // col_indices: concat((0, 1, 2), (), (2, 3))
82 // values: concat((1.0, 2.0, 3.0), (), (4.0, 5.0))
83 //
84 // For a 3D sparse matrix m shaped (2, 3, 4) such that:
85 //
86 // m[0, 0, 0] = 1.0
87 // m[0, 0, 2] = 2.0
88 // m[0, 2, 3] = 3.0
89 // m[1, 0, 3] = 4.0
90 // m[1, 1, 0] = 5.0
91 //
92 // The corresponding representation is:
93 // dtype: DT_FLOAT
94 // dense_shape: (2, 3, 4)
95 // batch_pointers: (0, 3, 5)
96 // row_pointers: concat((0, 2, 2, 3), (0, 1, 2, 2))
97 // col_indices: concat(concat((0, 2), (), (3,)),
98 // concat((3,), (), (0,)))
99 // values: concat(concat((1.0, 2.0), (3.0,), ()),
100 /// concat((4.0,), (5.0,), ()))
101 //
102 public:
103 static constexpr const char kTypeName[] = "tensorflow::CSRSparseMatrix";
104
CSRSparseMatrix()105 CSRSparseMatrix() : metadata_{false, DT_INVALID} {}
106
CSRSparseMatrix(const CSRSparseMatrix & rhs)107 CSRSparseMatrix(const CSRSparseMatrix& rhs)
108 : metadata_(rhs.metadata_),
109 dense_shape_(rhs.dense_shape_),
110 batch_pointers_(rhs.batch_pointers_),
111 row_pointers_(rhs.row_pointers_),
112 col_indices_(rhs.col_indices_),
113 values_(rhs.values_) {
114 SetupVecs();
115 }
116
CSRSparseMatrix(CSRSparseMatrix && rhs)117 CSRSparseMatrix(CSRSparseMatrix&& rhs)
118 : metadata_(rhs.metadata_),
119 dense_shape_(std::move(rhs.dense_shape_)),
120 batch_pointers_(std::move(rhs.batch_pointers_)),
121 row_pointers_(std::move(rhs.row_pointers_)),
122 col_indices_(std::move(rhs.col_indices_)),
123 values_(std::move(rhs.values_)) {
124 SetupVecs();
125 rhs.metadata_.validated = false;
126 rhs.metadata_.dtype = DT_INVALID;
127 rhs.ClearVecs();
128 }
129
130 CSRSparseMatrix& operator=(CSRSparseMatrix&& rhs) {
131 if (this == &rhs) return *this;
132 metadata_ = rhs.metadata_;
133 metadata_.validated = rhs.metadata_.validated;
134 dense_shape_ = std::move(rhs.dense_shape_);
135 batch_pointers_ = std::move(rhs.batch_pointers_);
136 row_pointers_ = std::move(rhs.row_pointers_);
137 col_indices_ = std::move(rhs.col_indices_);
138 values_ = std::move(rhs.values_);
139 SetupVecs();
140 rhs.metadata_ = {false, DT_INVALID};
141 rhs.ClearVecs();
142 return *this;
143 }
144
CreateCSRSparseMatrix(DataType dtype,const Tensor & dense_shape,const Tensor & batch_pointers,const Tensor & row_pointers,const Tensor & col_indices,const Tensor & values,CSRSparseMatrix * matrix)145 static Status CreateCSRSparseMatrix(DataType dtype,
146 const Tensor& dense_shape, // on host
147 const Tensor& batch_pointers, // on host
148 const Tensor& row_pointers,
149 const Tensor& col_indices,
150 const Tensor& values,
151 CSRSparseMatrix* matrix) {
152 *matrix = CSRSparseMatrix(dtype, dense_shape, batch_pointers, row_pointers,
153 col_indices, values);
154 Status s = matrix->Validate();
155 matrix->metadata_.validated = s.ok();
156 matrix->SetupVecs();
157 return s;
158 }
159
Validate()160 Status Validate() const {
161 return ValidateTypesAndShapes(metadata_.dtype, dense_shape_,
162 batch_pointers_, row_pointers_, col_indices_,
163 values_);
164 }
165
Clear()166 void Clear() {
167 metadata_ = {false, DT_INVALID};
168 dense_shape_ = Tensor();
169 batch_pointers_ = Tensor();
170 row_pointers_ = Tensor();
171 col_indices_ = Tensor();
172 values_ = Tensor();
173 ClearVecs();
174 }
175
valid()176 bool valid() const {
177 return metadata_.validated && dense_shape_.IsInitialized() &&
178 batch_pointers_.IsInitialized() && row_pointers_.IsInitialized() &&
179 col_indices_.IsInitialized() && values_.IsInitialized() &&
180 dense_shape_.NumElements() > 1 &&
181 batch_pointers_.NumElements() > 0 && row_pointers_.NumElements() > 0;
182 }
183
dtype()184 DataType dtype() const {
185 DCHECK(valid());
186 return metadata_.dtype;
187 }
188
dims()189 inline int dims() const {
190 DCHECK(valid());
191 return dense_shape_.NumElements();
192 }
193
nnz(int batch)194 inline int nnz(int batch) const {
195 DCHECK_LT(batch, batch_size());
196 return (*batch_pointers_vec_)(batch + 1) - (*batch_pointers_vec_)(batch);
197 }
198
batch_offset(int batch)199 inline int batch_offset(int batch) const {
200 DCHECK_LT(batch, batch_size());
201 return (*batch_pointers_vec_)(batch);
202 }
203
total_nnz()204 inline int total_nnz() const {
205 DCHECK(valid());
206 return (*batch_pointers_vec_)(batch_size());
207 }
208
dense_shape()209 inline Tensor& dense_shape() {
210 DCHECK(valid());
211 return dense_shape_;
212 }
213
dense_shape()214 inline const Tensor& dense_shape() const {
215 DCHECK(valid());
216 return dense_shape_;
217 }
218
row_pointers_vec(int batch)219 inline TTypes<int32>::UnalignedVec row_pointers_vec(int batch) {
220 DCHECK(valid());
221 DCHECK_LT(batch, batch_size());
222 const int64 rows = dense_shape().vec<int64>()((dims() == 2) ? 0 : 1);
223 const int offset = batch * (rows + 1);
224 return TTypes<int32>::UnalignedVec(row_pointers_vec_->data() + offset,
225 rows + 1);
226 }
227
row_pointers_vec(int batch)228 inline TTypes<int32>::UnalignedConstVec row_pointers_vec(int batch) const {
229 DCHECK(valid());
230 DCHECK_LT(batch, batch_size());
231 const int64 rows = dense_shape().vec<int64>()((dims() == 2) ? 0 : 1);
232 const int offset = batch * (rows + 1);
233 return TTypes<int32>::UnalignedConstVec(row_pointers_vec_->data() + offset,
234 rows + 1);
235 }
236
col_indices_vec(int batch)237 inline TTypes<int32>::UnalignedVec col_indices_vec(int batch) {
238 DCHECK(valid());
239 DCHECK_LT(batch, batch_size());
240 const int offset = (*batch_pointers_vec_)(batch);
241 const int nnz_in_batch = nnz(batch);
242 return TTypes<int32>::UnalignedVec(col_indices_vec_->data() + offset,
243 nnz_in_batch);
244 }
245
col_indices_vec(int batch)246 inline TTypes<int32>::UnalignedConstVec col_indices_vec(int batch) const {
247 DCHECK(valid());
248 DCHECK_LT(batch, batch_size());
249 const int offset = (*batch_pointers_vec_)(batch);
250 const int nnz_in_batch = nnz(batch);
251 return TTypes<int32>::UnalignedConstVec(col_indices_vec_->data() + offset,
252 nnz_in_batch);
253 }
254
255 template <typename T>
values_vec(int batch)256 inline typename TTypes<T>::UnalignedVec values_vec(int batch) {
257 DCHECK(valid());
258 DCHECK_LT(batch, batch_size());
259 const int offset = (*batch_pointers_vec_)(batch);
260 const int nnz_in_batch = nnz(batch);
261 return typename TTypes<T>::UnalignedVec(&(values().vec<T>()(offset)),
262 nnz_in_batch);
263 }
264
265 template <typename T>
values_vec(int batch)266 inline typename TTypes<T>::UnalignedConstVec values_vec(int batch) const {
267 DCHECK(valid());
268 DCHECK_LT(batch, batch_size());
269 const int offset = (*batch_pointers_vec_)(batch);
270 const int nnz_in_batch = nnz(batch);
271 return typename TTypes<T>::UnalignedConstVec(&(values().vec<T>()(offset)),
272 nnz_in_batch);
273 }
274
row_pointers()275 inline Tensor& row_pointers() {
276 DCHECK(valid());
277 return row_pointers_;
278 }
279
row_pointers()280 inline const Tensor& row_pointers() const {
281 DCHECK(valid());
282 return row_pointers_;
283 }
284
col_indices()285 inline Tensor& col_indices() {
286 DCHECK(valid());
287 return col_indices_;
288 }
289
col_indices()290 inline const Tensor& col_indices() const {
291 DCHECK(valid());
292 return col_indices_;
293 }
294
values()295 inline Tensor& values() {
296 DCHECK(valid());
297 return values_;
298 }
299
values()300 inline const Tensor& values() const {
301 DCHECK(valid());
302 return values_;
303 }
304
batch_pointers()305 inline Tensor& batch_pointers() {
306 DCHECK(valid());
307 return batch_pointers_;
308 }
309
batch_pointers()310 inline const Tensor& batch_pointers() const {
311 DCHECK(valid());
312 return batch_pointers_;
313 }
314
TypeName()315 std::string TypeName() const { return kTypeName; }
316
317 // TODO(ebrevdo): A better debug string.
DebugString()318 std::string DebugString() const { return dense_shape_.DebugString(); }
319
320 // Returns the number of elements. This is equal to 1 if the
321 // CSRSparseMatrix is a singleton matrix (dense_shape is length 2).
batch_size()322 int batch_size() const {
323 DCHECK(valid());
324 return batch_pointers_.NumElements() - 1;
325 }
326
Decode(const VariantTensorData & p)327 bool Decode(const VariantTensorData& p) {
328 if (p.tensors_.empty()) return false;
329 Metadata metadata;
330 if (!p.get_metadata(&metadata)) return false;
331 const bool validated = metadata.validated;
332 const DataType dtype = metadata.dtype;
333
334 // p.tensors_ should contain tensors {dense_shape, batch_pointers,
335 // row_pointers, col_indices, values}.
336 if (p.tensors_.size() != 5) return false;
337
338 Tensor dense_shape = p.tensors_[0];
339 if (dense_shape.dtype() != DT_INT64) return false;
340 if (dense_shape.dims() != 1) return false;
341 int rank = dense_shape.dim_size(0);
342 if (rank < 2 || rank > 3) return false;
343
344 Tensor batch_pointers(p.tensors_[1]);
345 Tensor row_pointers(p.tensors_[2]);
346 Tensor col_indices(p.tensors_[3]);
347 Tensor values(p.tensors_[4]);
348
349 // Check that the validated bool is consistent with the data.
350 Status s = ValidateTypesAndShapes(dtype, dense_shape, batch_pointers,
351 row_pointers, col_indices, values);
352 if (s.ok() != validated) return false;
353
354 // Save to this object.
355 metadata_ = metadata;
356 dense_shape_ = std::move(dense_shape);
357 batch_pointers_ = std::move(batch_pointers);
358 row_pointers_ = std::move(row_pointers);
359 col_indices_ = std::move(col_indices);
360 values_ = std::move(values);
361 SetupVecs();
362 return true;
363 }
364
Encode(VariantTensorData * p)365 void Encode(VariantTensorData* p) const {
366 DCHECK(valid());
367
368 // Store metadata_ to p's metadata
369 p->set_metadata(metadata_);
370
371 // Store dense_shape, row_pointers, col_indices, and values to p->tensors_.
372 p->tensors_.reserve(5);
373 p->tensors_.push_back(dense_shape_);
374 p->tensors_.push_back(batch_pointers_);
375 p->tensors_.push_back(row_pointers_);
376 p->tensors_.push_back(col_indices_);
377 p->tensors_.push_back(values_);
378 }
379
380 // This static method copies CSRSparseMatrices in all directions:
381 // Host->Device, Device->Host, and Device->Device.
DeviceCopy(const CSRSparseMatrix & from,CSRSparseMatrix * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy)382 static Status DeviceCopy(
383 const CSRSparseMatrix& from, CSRSparseMatrix* to,
384 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
385 VLOG(2) << "DeviceCopy from type: " << DataTypeString(from.dtype())
386 << " and shape: " << from.dense_shape().DebugString();
387 Tensor to_row_ptr(DT_INT32);
388 Tensor to_col_ind(DT_INT32);
389 Tensor to_values(from.dtype());
390 TF_RETURN_IF_ERROR(copy(from.row_pointers(), &to_row_ptr));
391 TF_RETURN_IF_ERROR(copy(from.col_indices(), &to_col_ind));
392 TF_RETURN_IF_ERROR(copy(from.values(), &to_values));
393 return CreateCSRSparseMatrix(from.dtype(),
394 from.dense_shape(), // Always on host.
395 from.batch_pointers(), // Always on host.
396 to_row_ptr, to_col_ind, to_values, to);
397 }
398
399 private:
CSRSparseMatrix(DataType dtype,const Tensor & dense_shape,const Tensor & batch_pointers,const Tensor & row_pointers,const Tensor & col_indices,const Tensor & values)400 CSRSparseMatrix(DataType dtype, const Tensor& dense_shape,
401 const Tensor& batch_pointers, const Tensor& row_pointers,
402 const Tensor& col_indices, const Tensor& values)
403 : metadata_{false, dtype},
404 dense_shape_(dense_shape),
405 batch_pointers_(batch_pointers),
406 row_pointers_(row_pointers),
407 col_indices_(col_indices),
408 values_(values) {}
409
SetupVecs()410 void SetupVecs() {
411 if (!metadata_.validated) return;
412 batch_pointers_vec_.reset(
413 new TTypes<int32>::Vec(batch_pointers_.vec<int32>()));
414 row_pointers_vec_.reset(new TTypes<int32>::Vec(row_pointers_.vec<int32>()));
415 col_indices_vec_.reset(new TTypes<int32>::Vec(col_indices_.vec<int32>()));
416 }
417
ClearVecs()418 void ClearVecs() {
419 batch_pointers_vec_.reset();
420 row_pointers_vec_.reset();
421 col_indices_vec_.reset();
422 }
423
ValidateTypesAndShapes(DataType dtype,const Tensor & dense_shape,const Tensor & batch_pointers,const Tensor & row_pointers,const Tensor & col_indices,const Tensor & values)424 static Status ValidateTypesAndShapes(DataType dtype,
425 const Tensor& dense_shape,
426 const Tensor& batch_pointers,
427 const Tensor& row_pointers,
428 const Tensor& col_indices,
429 const Tensor& values) {
430 // TODO(ebrevdo): Consider adding support for other floating point types
431 // (namely, float16).
432 if (dtype != DT_FLOAT && dtype != DT_DOUBLE && dtype != DT_COMPLEX64 &&
433 dtype != DT_COMPLEX128) {
434 return errors::InvalidArgument(
435 "CSRSparseMatrix::Validate: dtype = ", DataTypeString(dtype),
436 " not in {float32, float64, complex64, complex128}");
437 }
438 // dense_shape checks
439 if (dense_shape.dtype() != DT_INT64) {
440 return errors::InvalidArgument(
441 "CSRSparseMatrix::Validate: dense_shape.dtype() = ",
442 DataTypeString(dense_shape.dtype()), " != int64");
443 }
444 if (dense_shape.dims() != 1) {
445 return errors::InvalidArgument(
446 "CSRSparseMatrix::Validate: dense_shape should be a vector, but saw "
447 "tensor: ",
448 dense_shape.DebugString());
449 }
450 int rank = dense_shape.dim_size(0);
451 if (rank < 2 || rank > 3) {
452 return errors::InvalidArgument(
453 "CSRSparseMatrix::Validate: dense_shape should be a 2- or 3- vector, "
454 "but saw: ",
455 dense_shape.SummarizeValue(5));
456 }
457 auto dense_shape_t = dense_shape.vec<int64>();
458 const int64 batch_size = (rank == 2) ? 1 : dense_shape_t(0);
459 const int64 num_rows = (rank == 2) ? dense_shape_t(0) : dense_shape_t(1);
460
461 if (batch_pointers.dtype() != DT_INT32) {
462 return errors::InvalidArgument(
463 "CSRSparseMatrix::Validate: batch_pointers.dtype() = ",
464 DataTypeString(batch_pointers.dtype()), " != int32");
465 }
466 if (batch_pointers.dims() != 1) {
467 return errors::InvalidArgument(
468 "CSRSparseMatrix::Validate: batch_indices is not a vector, saw "
469 "shape: ",
470 batch_pointers.shape().DebugString());
471 }
472
473 // batch size checks
474 if (batch_size != batch_pointers.NumElements() - 1) {
475 return errors::InvalidArgument(
476 "CSRSparseMatrix::Validate: dense_shape is ",
477 dense_shape.SummarizeValue(5),
478 " but batch pointers implies batch size is ",
479 batch_pointers.NumElements() - 1);
480 }
481
482 if (row_pointers.dtype() != DT_INT32) {
483 return errors::InvalidArgument(
484 "CSRSparseMatrix::Validate: row_pointers.dtype() = ",
485 DataTypeString(row_pointers.dtype()), " != int32");
486 }
487 if (row_pointers.dims() != 1) {
488 return errors::InvalidArgument(
489 "CSRSparseMatrix::Validate: row_pointers is not a vector, saw "
490 "shape: ",
491 row_pointers.shape().DebugString());
492 }
493 if (row_pointers.dim_size(0) != batch_size * (num_rows + 1)) {
494 return errors::InvalidArgument(
495 "CSRSparseMatrix::Validate: row_pointers should have size batch_size "
496 "* (num_rows + 1), saw shapes: ",
497 dense_shape.DebugString(), " vs. ",
498 row_pointers.shape().DebugString());
499 }
500 if (col_indices.dtype() != DT_INT32) {
501 return errors::InvalidArgument(
502 "CSRSparseMatrix::Validate: col_indices.dtype() = ",
503 DataTypeString(col_indices.dtype()), " != int32");
504 }
505 if (col_indices.dims() != 1) {
506 return errors::InvalidArgument(
507 "CSRSparseMatrix::Validate: col_indices is not a vector, saw shape: ",
508 col_indices.shape().DebugString());
509 }
510 if (values.dtype() != dtype) {
511 return errors::InvalidArgument(
512 "CSRSparseMatrix::Validate: values.dtype() = ",
513 DataTypeString(values.dtype()),
514 " != dtype = ", DataTypeString(dtype));
515 }
516 if (values.dims() != 1) {
517 return errors::InvalidArgument(
518 "CSRSparseMatrix::Validate: values is not a vector, saw shape: ",
519 values.shape().DebugString());
520 }
521 if (col_indices.dim_size(0) != values.dim_size(0)) {
522 return errors::InvalidArgument(
523 "CSRSparseMatrix::Validate: size(col_indices) = ",
524 col_indices.dim_size(0), " != size(values) = ", values.dim_size(0));
525 }
526 return Status::OK();
527 }
528
529 struct Metadata {
530 bool validated;
531 DataType dtype;
532 };
533 Metadata metadata_;
534 Tensor dense_shape_;
535 Tensor batch_pointers_;
536 Tensor row_pointers_;
537 Tensor col_indices_;
538 Tensor values_;
539 std::unique_ptr<TTypes<int32>::Vec> batch_pointers_vec_;
540 std::unique_ptr<TTypes<int32>::Vec> row_pointers_vec_;
541 std::unique_ptr<TTypes<int32>::Vec> col_indices_vec_;
542 };
543
544 // Call BinaryFunctor<Device, T>()(ctx, a, b, c)
545 // where T depends on a.dtype(). T will be one of: float, double,
546 // complex64, complex128.
547 template <typename Device, template <typename, typename> class BinaryFunctor>
CSRSparseMatrixBinaryHelper(OpKernelContext * ctx,const CSRSparseMatrix & a,const CSRSparseMatrix & b,CSRSparseMatrix * c)548 Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx,
549 const CSRSparseMatrix& a,
550 const CSRSparseMatrix& b,
551 CSRSparseMatrix* c) {
552 DataType dt = a.dtype();
553 if (dt != b.dtype()) {
554 return errors::InvalidArgument(
555 "CSRSparseMatrixBinaryHelper: Inconsistent dtypes for input matrices, "
556 "a "
557 "dtype: ",
558 DataTypeString(dt), ", b dtype: ", DataTypeString(b.dtype()));
559 }
560 switch (dt) {
561 case DT_FLOAT: {
562 BinaryFunctor<Device, float> functor(ctx);
563 return functor(a, b, c);
564 }
565 case DT_DOUBLE: {
566 BinaryFunctor<Device, double> functor(ctx);
567 return functor(a, b, c);
568 }
569 case DT_COMPLEX64: {
570 BinaryFunctor<Device, complex64> functor(ctx);
571 return functor(a, b, c);
572 }
573 case DT_COMPLEX128: {
574 BinaryFunctor<Device, complex128> functor(ctx);
575 return functor(a, b, c);
576 }
577 default:
578 return errors::InvalidArgument(
579 "CSRSparseMatrixBinaryHelper: a.dtype (", DataTypeString(dt),
580 ") is not one of: float, double, complex64, complex128");
581 }
582 }
583
584 // Call UnaryFunctor<Device, T>()(ctx, a, b)
585 // where T depends on a.dtype(). T will be one of: float, double,
586 // complex64, complex128.
587 template <typename Device, template <typename, typename> class UnaryFunctor>
CSRSparseMatrixUnaryHelper(OpKernelContext * ctx,const CSRSparseMatrix & a,CSRSparseMatrix * b)588 Status CSRSparseMatrixUnaryHelper(OpKernelContext* ctx,
589 const CSRSparseMatrix& a,
590 CSRSparseMatrix* b) {
591 DataType dt = a.dtype();
592 switch (dt) {
593 case DT_FLOAT: {
594 UnaryFunctor<Device, float> functor(ctx);
595 return functor(a, b);
596 }
597 case DT_DOUBLE: {
598 UnaryFunctor<Device, double> functor(ctx);
599 return functor(a, b);
600 }
601 case DT_COMPLEX64: {
602 UnaryFunctor<Device, complex64> functor(ctx);
603 return functor(a, b);
604 }
605 case DT_COMPLEX128: {
606 UnaryFunctor<Device, complex128> functor(ctx);
607 return functor(a, b);
608 }
609 default:
610 return errors::InvalidArgument(
611 "CSRSparseMatrixUnaryHelper: a.dtype (", DataTypeString(dt),
612 ") is not one of: float, double, complex64, complex128");
613 }
614 }
615
616 template <typename T>
617 struct ConstCSRComponent {
618 TTypes<int32>::UnalignedConstVec row_ptr;
619 TTypes<int32>::UnalignedConstVec col_ind;
620 typename TTypes<T>::UnalignedConstVec values;
621 TTypes<int64>::ConstVec dense_shape_host;
622 };
623
624 template <typename T>
625 struct CSRComponent {
626 TTypes<int32>::UnalignedVec row_ptr;
627 TTypes<int32>::UnalignedVec col_ind;
628 typename TTypes<T>::UnalignedVec values;
629 TTypes<int64>::Vec dense_shape_host;
630 };
631
632 template <typename T>
ExtractVariantFromInput(OpKernelContext * ctx,int index,const T ** value)633 Status ExtractVariantFromInput(OpKernelContext* ctx, int index,
634 const T** value) {
635 const Tensor& input_t = ctx->input(index);
636 const Variant& input_variant = input_t.scalar<Variant>()();
637 *value = input_variant.get<T>();
638 if (*value == nullptr) {
639 return errors::InvalidArgument("Could not retrieve Variant input ", index);
640 }
641 if (!(*value)->valid()) {
642 return errors::InvalidArgument("Variant input ", index, " is not valid.");
643 }
644 return Status::OK();
645 }
646
647 } // namespace tensorflow
648
649 #endif // TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
650