• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_TENSOR_HELPER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_TENSOR_HELPER_H_
18 
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/Operation.h"  // from @llvm-project
21 #include "mlir/IR/Value.h"  // from @llvm-project
22 
23 namespace mlir {
24 
25 class Builder;
26 
27 namespace TF {
28 
29 // Returns the RankedTensorType for the given operand. TensorFlow constant ops
30 // may have non-static shape because the shape is not propagated during constant
31 // folding. If the defining op for the given operand is a constant op, this
32 // routine uses the constant op's attribute to get the actual shape.
33 RankedTensorType GetRankedTensorTypeForOperand(Value operand);
34 
35 // Returns true if the given `value` is of ranked float tensor type with the
36 // given `rank`.
IsOfRankedFloatTensorType(RankedTensorType type,int rank)37 inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) {
38   return type && type.getRank() == rank &&
39          type.getElementType().isa<FloatType>();
40 }
41 
42 // Returns true if the given `value` has the specified rank or has unranked
43 // type.
IsOfRankOrUnranked(Value value,int64_t rank)44 inline bool IsOfRankOrUnranked(Value value, int64_t rank) {
45   RankedTensorType type = GetRankedTensorTypeForOperand(value);
46   return !type || type.getRank() == rank;
47 }
48 
49 // Returns true if the given `value` has at least the specified rank or has
50 // unranked type.
HasRankAtLeast(Value value,int64_t rank)51 inline bool HasRankAtLeast(Value value, int64_t rank) {
52   RankedTensorType type = GetRankedTensorTypeForOperand(value);
53   return !type || type.getRank() >= rank;
54 }
55 
56 // Returns true if the given `value` has at most the specified rank or has
57 // unranked type.
HasRankAtMost(Value value,int64_t rank)58 inline bool HasRankAtMost(Value value, int64_t rank) {
59   RankedTensorType type = GetRankedTensorTypeForOperand(value);
60   return !type || type.getRank() <= rank;
61 }
62 
// Returns true if `dim_or_rank` equals -1, the sentinel this helper treats as
// an unknown dimension size or an unknown rank.
// NOTE(review): newer MLIR versions encode dynamic sizes as
// ShapedType::kDynamic rather than -1; confirm the sentinel matches the MLIR
// version in use before relying on this for shape dimensions.
inline bool IsUnknownDimOrRank(int64_t dim_or_rank) {
  constexpr int64_t kUnknownSentinel = -1;
  return kUnknownSentinel == dim_or_rank;
}
66 
// Converts a TensorFlow `axis`, which may be negative to count from the last
// dimension, into a dimension index for a tensor of rank `rank`: -1 maps to
// rank - 1, -rank maps to 0, and non-negative axes are returned unchanged.
// No range check is performed; an axis outside [-rank, rank) produces an
// out-of-range index that callers must validate.
inline int64_t GetDimForAxis(int64_t axis, int64_t rank) {
  if (axis < 0) return rank + axis;
  return axis;
}
72 
73 // Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If
74 // `incompatible_shape_error` is true, reports error if `x` and `y` has
75 // incompatible shapes. Otherwise, returns a tensor type with unknown rank.
76 Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, Value y,
77                           BoolAttr incompatible_shape_error);
78 
79 Type InferReductionOpType(Value input, Value reduction_indices,
80                           BoolAttr keep_dims);
81 
82 // Verifies that the given types are cast compatible. If not, emits appropriate
83 // error for the given op. If mask_one_dim is set to true, then the types are
84 // allowed to have one mismatching dimension. Masking one of the dimensions is
85 // useful for ops like Concat that requires all ranked inputs to have the same
86 // rank and match dimension sizes for all but one of the dimensions.
87 LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types,
88                                        bool mask_one_dim, Operation *op);
89 
90 }  // namespace TF
91 }  // namespace mlir
92 
93 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_TENSOR_HELPER_H_
94