• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_MLIR_SPARSE_EXPANDER_COMMON_H_
17 #define TENSORFLOW_DTENSOR_MLIR_SPARSE_EXPANDER_COMMON_H_
18 
19 #include <optional>
20 
21 #include "absl/types/optional.h"
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
24 
25 namespace tensorflow {
26 namespace dtensor {
27 
28 // Gets the SparseToDenseOp that generates `value` if `value` is the result of
29 // a SparseToDenseOp. Returns empty otherwise. This is useful
30 // in SparseExpansion where we want to check whether some operand
31 // is a SparseTensor, by checking whether that operand is a result of a
32 // SparseToDenseOp. If this value is eventually an output of a SparseToDenseOp,
33 // there should only be DTensor related ops between the actual SparseToDenseOp,
34 // e.g. DTensorRelayout ops or DTensorLayout op.
35 StatusOr<mlir::TF::SparseToDenseOp> GetSparseToDenseOp(mlir::Value value);
36 
37 // Checks whether `value is an output of a SparseToDenseOp value.
38 bool IsSparseValue(mlir::Value value);
39 
40 // Checks if `op` has any sparse value operands.
41 bool HasAnySparseInput(mlir::Operation* op);
42 
43 // Checks if all operands of `op` is a sparse value.
44 bool AllSparseInput(mlir::Operation* op);
45 
46 // Returns the indices component dense tensor from `value`. `value` represents
47 // a SparseTensor value.
48 StatusOr<mlir::Value> GetIndicesFromSparseTensor(mlir::Value value);
49 
50 // Returns the values component dense tensor from `value`.`value` represents
51 // a SparseTensor value.
52 StatusOr<mlir::Value> GetValuesFromSparseTensor(mlir::Value value);
53 
54 // Returns the dense shape component dense tensor from `value`. `value`
55 // represents a SparseTensor value.
56 StatusOr<mlir::Value> GetDenseShapesFromSparseTensor(mlir::Value value);
57 
58 }  // namespace dtensor
59 }  // namespace tensorflow
60 
61 #endif  // TENSORFLOW_DTENSOR_MLIR_SPARSE_EXPANDER_COMMON_H_
62