1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
17 
18 #include "mlir/Dialect/Traits.h"  // from @llvm-project
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/Matchers.h"  // from @llvm-project
21 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
22 
23 namespace mlir {
24 namespace TF {
25 
// Forward declarations of TF dialect op classes.
// NOTE(review): neither IdentityOp nor IdentityNOp is referenced by the code
// visible in this file; confirm whether these declarations are still needed
// (e.g. by a template or macro elsewhere) before removing them.
class IdentityOp;
class IdentityNOp;
28 
29 // Returns the RankedTensorType for the given operand. TensorFlow constant ops
30 // may have non-static shape because the shape is not propagated during constant
31 // folding. If the defining op for the given operand is a constant op, this
32 // routine uses the constant op's attribute to get the actual shape.
GetRankedTensorTypeForOperand(Value operand)33 RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
34   DenseElementsAttr attr;
35   if (matchPattern(operand, m_Constant(&attr))) {
36     return attr.getType().dyn_cast<RankedTensorType>();
37   }
38   return operand.getType().dyn_cast<RankedTensorType>();
39 }
40 
41 // Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If
42 // `incompatible_shape_error` is true, reports error if `x` and `y` has
43 // incompatible shapes. Otherwise, returns a tensor type with unknown rank.
DeduceEqualCmpOpType(Builder * builder,Location loc,Value x,Value y,BoolAttr incompatible_shape_error)44 Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, Value y,
45                           BoolAttr incompatible_shape_error) {
46   auto result_type =
47       OpTrait::util::getBroadcastedType(x.getType(), y.getType());
48   if (!result_type) {
49     if (incompatible_shape_error.getValue()) {
50       mlir::emitError(loc, "non-broadcastable operands");
51     } else {
52       return UnrankedTensorType::get(builder->getI1Type());
53     }
54   }
55 
56   auto ranked_type = result_type.dyn_cast<RankedTensorType>();
57   if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type());
58 
59   return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type());
60 }
61 
InferReductionOpType(Value input,Value reduction_indices,BoolAttr keep_dims)62 Type InferReductionOpType(Value input, Value reduction_indices,
63                           BoolAttr keep_dims) {
64   Type input_ty = input.getType();
65   Type element_ty = getElementTypeOrSelf(input_ty);
66 
67   // Output type is unranked if input type is not ranked.
68   auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
69   if (!ranked_ty) return UnrankedTensorType::get(element_ty);
70   int64_t rank = ranked_ty.getRank();
71 
72   DenseIntElementsAttr indices;
73   if (!matchPattern(reduction_indices, m_Constant(&indices))) {
74     // Output type is unranked if reduction indices are not constant and reduced
75     // dimensions are not kept.
76     if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty);
77 
78     // Otherwise, output type has same rank as the input.
79     return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty);
80   }
81 
82   int64_t num_reduce_dim = 0;
83   llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
84   for (const APInt &index : indices.getValues<APInt>()) {
85     int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
86     // Invalid input.
87     if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);
88 
89     if (!is_reduce_dim[dim]) {
90       is_reduce_dim[dim] = true;
91       num_reduce_dim++;
92     }
93   }
94 
95   ArrayRef<int64_t> shape = ranked_ty.getShape();
96   SmallVector<int64_t, 4> out_shape;
97   out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim));
98   for (int64_t i = 0; i < rank; ++i) {
99     if (!is_reduce_dim[i])
100       out_shape.push_back(shape[i]);
101     else if (keep_dims.getValue())
102       out_shape.push_back(1);
103   }
104   return RankedTensorType::get(out_shape, element_ty);
105 }
106 
107 // Verifies that the given types are cast compatible. If not, emits appropriate
108 // error for the given op. If mask_one_dim is set to true, then the types are
109 // allowed to have one mismatching dimension. Masking one of the dimensions is
110 // useful for ops like Concat that requires all ranked inputs to have the same
111 // rank and match dimension sizes for all but one of the dimensions.
VerifyTypesCompatibility(Operation::operand_type_range types,bool mask_one_dim,Operation * op)112 LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types,
113                                        bool mask_one_dim, Operation *op) {
114   constexpr int64_t kUninitialized = -1;
115   int64_t common_rank = kUninitialized;
116   llvm::SmallVector<int64_t, 4> common_dims;
117   int64_t dim_to_mask = kUninitialized;
118 
119   // Initialize common_rank with rank of the first ranked type and verify that
120   // following ranked types have the same rank.
121   // Similarly, initialize each of the dimensions with the first type that has
122   // the dimension size available and verify that all following types have the
123   // same size for the dimension. However, if mask_one_dim is true, note down
124   // the dimension index on the first mismatch and ignore dimension at that
125   // index in following types.
126   for (Type ty : types) {
127     RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
128     if (!ranked_ty) continue;
129 
130     int64_t rank = ranked_ty.getRank();
131     if (common_rank == kUninitialized) {
132       common_rank = rank;
133       common_dims.resize(common_rank, kUninitialized);
134     } else if (common_rank != rank) {
135       return op->emitError()
136              << "operand type " << ranked_ty
137              << " is not compatible with preceding operands; expected rank: "
138              << common_rank;
139     }
140 
141     for (int64_t i = 0, e = common_rank; i != e; i++) {
142       if (i == dim_to_mask) continue;
143 
144       int64_t dim = ranked_ty.getDimSize(i);
145       if (dim == kUninitialized) continue;
146 
147       int64_t &common_dim = common_dims[i];
148       if (common_dim == kUninitialized) {
149         common_dim = dim;
150       } else if (common_dim != dim) {
151         // If mask_one_dim is true, do not emit an error if this is the only
152         // dimension with mismatches. Note down the dimension to mask it from
153         // the following types.
154         if (mask_one_dim && dim_to_mask == kUninitialized) {
155           dim_to_mask = i;
156           continue;
157         }
158 
159         return op->emitError() << "operand type " << ranked_ty
160                                << " is not compatible with preceding operands; "
161                                   "expected dimension at index "
162                                << i << ": " << common_dim;
163       }
164     }
165   }
166   return success();
167 }
168 
169 }  // namespace TF
170 }  // namespace mlir
171