1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/utils/shape_inference_utils.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "llvm/ADT/None.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/Casting.h"
26 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/MLIRContext.h" // from @llvm-project
29 #include "mlir/IR/Value.h" // from @llvm-project
30 #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
31 #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
32 #include "mlir/Support/LLVM.h" // from @llvm-project
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/framework/op_def_builder.h"
37 #include "tensorflow/core/framework/shape_inference.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/ir/dialect.h"
42 #include "tensorflow/core/ir/importexport/convert_tensor.h"
43 #include "tensorflow/core/ir/importexport/convert_types.h"
44 #include "tensorflow/core/ir/importexport/graphdef_export.h"
45 #include "tensorflow/core/ir/types/dialect.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/types.h"
48
49 #define DEBUG_TYPE "tfg-shape-inference-utils"
50
51 using tensorflow::shape_inference::DimensionHandle;
52 using tensorflow::shape_inference::InferenceContext;
53 using tensorflow::shape_inference::ShapeHandle;
54
55 namespace mlir {
56 namespace tfg {
57
58 namespace {
59
60 // Get the tensorflow op name.
GetTensorFlowOpName(Operation * inst)61 llvm::StringRef GetTensorFlowOpName(Operation* inst) {
62 llvm::StringRef op_name = inst->getName().stripDialect();
63 // Control dialect NextIteration sink ends with ".sink" and Executor dialect
64 // NextIteration sink ends with ".Sink".
65 if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
66 return op_name;
67 }
68
69 // Extracts attributes from a MLIR operation, including derived attributes, into
70 // one NamedAttrList.
GetAllAttributesFromOperation(Operation * op)71 NamedAttrList GetAllAttributesFromOperation(Operation* op) {
72 NamedAttrList attr_list;
73 attr_list.append(op->getAttrDictionary().getValue());
74
75 if (auto derived = dyn_cast<DerivedAttributeOpInterface>(op)) {
76 auto materialized = derived.materializeDerivedAttributes();
77 attr_list.append(materialized.getValue());
78 }
79
80 return attr_list;
81 }
82
83 // Extracts a PartialTensorShape from the MLIR type.
84 // Some MLIR shapes may fail to be represented as PartialTensorShape, e.g.
85 // those where num_elements overflows.
86 // TODO(tlongeri): Should num_elements overflow be handled by the MLIR
87 // verifier? Are there other cases?
GetShapeFromMlirType(Type t)88 Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
89 if (auto ranked_type = t.dyn_cast<RankedTensorType>()) {
90 tensorflow::PartialTensorShape shape;
91 const tensorflow::Status status =
92 tensorflow::PartialTensorShape::BuildPartialTensorShape(
93 ranked_type.getShape(), &shape);
94 if (status.ok()) return shape;
95 }
96 return None;
97 }
98
99 // Extracts a PartialTensorShape from the MLIR attr.
GetShapeFromMlirAttr(Value v)100 Optional<tensorflow::PartialTensorShape> GetShapeFromMlirAttr(Value v) {
101 // Function arguments may have shape attr to describe its output shape.
102 if (auto arg = v.dyn_cast<BlockArgument>()) {
103 Operation* parent_op = arg.getOwner()->getParentOp();
104 if (auto func_op = llvm::dyn_cast<FunctionOpInterface>(parent_op)) {
105 int arg_idx = arg.getArgNumber();
106 auto attrs =
107 func_op.getArgAttrOfType<ArrayAttr>(arg_idx, "tf._output_shapes");
108 if (!attrs || attrs.size() != 1) return None;
109
110 // "tf._output_shapes" in certain models may not store the shape as
111 // ShapeAttr, ignore them because we don't know how to interpret it.
112 auto shape_attr = attrs[0].dyn_cast<tf_type::ShapeAttr>();
113 if (shape_attr && shape_attr.hasRank())
114 return tensorflow::PartialTensorShape(shape_attr.getShape());
115 }
116 }
117 return None;
118 }
119
120 // Gets the subtype's shape and data type for `type`. Templated to support both
121 // ResourceType and VariantType.
122 template <typename T>
123 std::unique_ptr<std::vector<
124 std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
GetSubtypesHelper(Type type)125 GetSubtypesHelper(Type type) {
126 auto type_with_subtypes =
127 type.cast<TensorType>().getElementType().dyn_cast<T>();
128 if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) {
129 return nullptr;
130 }
131 auto shapes_and_types = std::make_unique<std::vector<
132 std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
133 for (auto subtype : type_with_subtypes.getSubtypes()) {
134 auto shape = GetShapeFromMlirType(subtype);
135 // handle_shapes_and_types requires all shapes to be known. So if any
136 // subtype is unknown, clear the vector.
137 if (!shape) {
138 shapes_and_types = nullptr;
139 break;
140 }
141 tensorflow::DataType dtype;
142 auto status = ConvertToDataType(subtype.getElementType(), &dtype);
143 assert(status.ok() && "Unknown element type");
144 shapes_and_types->emplace_back(*shape, dtype);
145 }
146 return shapes_and_types;
147 }
148
149 // Gets the subtype's shape and data type for `type`.
150 std::unique_ptr<std::vector<
151 std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
GetSubtypes(Type type)152 GetSubtypes(Type type) {
153 auto subclasses = GetSubtypesHelper<tf_type::ResourceType>(type);
154 if (subclasses) return subclasses;
155 return GetSubtypesHelper<tf_type::VariantType>(type);
156 }
157
158 // Log a shape inference function call failure.
ReportErrorFromShapeFunction(Optional<Location> location,llvm::StringRef op_name,llvm::StringRef error_message)159 LogicalResult ReportErrorFromShapeFunction(Optional<Location> location,
160 llvm::StringRef op_name,
161 llvm::StringRef error_message) {
162 VLOG(3) << "TensorFlow shape inference function errored for op '"
163 << op_name.data() << "': " << error_message.data();
164 return failure();
165 }
166
167 // Extracts shape from a shape handle and inference context.
GetShapeFromHandle(InferenceContext & context,const ShapeHandle & sh)168 Optional<SmallVector<int64_t, 8>> GetShapeFromHandle(InferenceContext& context,
169 const ShapeHandle& sh) {
170 if (!context.RankKnown(sh)) return None;
171 SmallVector<int64_t, 8> shape;
172 for (int dim : llvm::seq<int>(0, context.Rank(sh)))
173 shape.push_back(context.Value(context.Dim(sh, dim)));
174 return shape;
175 }
176
177 // Creates a tensor type from a shape handle and element type.
CreateTensorType(InferenceContext & context,const ShapeHandle & sh,Type element_type)178 TensorType CreateTensorType(InferenceContext& context, const ShapeHandle& sh,
179 Type element_type) {
180 auto shape = GetShapeFromHandle(context, sh);
181 if (shape.has_value())
182 return RankedTensorType::get(shape.getValue(), element_type);
183 return UnrankedTensorType::get(element_type);
184 }
185
186 // Creates a ShapedTypeComponent from a shape handle and element type.
CreateShapedTypeComponents(InferenceContext & context,const ShapeHandle & sh,Type element_type)187 ShapedTypeComponents CreateShapedTypeComponents(InferenceContext& context,
188 const ShapeHandle& sh,
189 Type element_type) {
190 auto shape = GetShapeFromHandle(context, sh);
191 if (shape.has_value())
192 return ShapedTypeComponents(shape.getValue(), element_type);
193 return ShapedTypeComponents(element_type);
194 }
195
196 } // namespace
197
// Runs TensorFlow's C++ shape inference function for `op` and appends one
// ShapedTypeComponents per op result to `inferred_return_shapes`.
//
// High-level flow:
//   1. Look up the op in the global TensorFlow OpRegistry; fail if it is
//      unregistered or has no shape inference function.
//   2. Build an AttrValueMap for the op, either through `get_attr_values_fn`
//      or by exporting the op to a NodeDef.
//   3. Seed an InferenceContext with the operand shapes (and resource/variant
//      subtype shapes) and run the shape function.
//   4. Iteratively feed constant operand values / operand-derived shapes that
//      the shape function requested and re-run it, until a fixed point where
//      nothing new can be supplied.
//   5. Convert the resulting output shape handles to ShapedTypeComponents.
//
// Returns failure() when inference could not be performed; in that case
// `inferred_return_shapes` is left untouched.
LogicalResult InferReturnTypeComponentsForTFOp(
    Optional<Location> location, Operation* op, ValueRange operands,
    int64_t graph_version, OperandAsConstantFn operand_as_constant_fn,
    OpResultAsShapeFn op_result_as_shape_fn,
    ResultElementTypeFn result_element_type_fn,
    GetAttrValuesFn get_attr_values_fn,
    SmallVectorImpl<ShapedTypeComponents>& inferred_return_shapes) {
  llvm::StringRef op_name = GetTensorFlowOpName(op);

  // Get information from the registry and check if we have a shape function for
  // this op.
  const tensorflow::OpRegistrationData* op_reg_data =
      tensorflow::OpRegistry::Global()->LookUp(op_name.str());
  if (!op_reg_data) {
    VLOG(3) << "Skipping inference for unregistered op '" << op_name.data()
            << "'.\n";
    return failure();
  }
  if (!op_reg_data->shape_inference_fn) {
    VLOG(3) << "Skipping inference for op without shape function '"
            << op_name.data() << "'.\n";
    return failure();
  }

  // Convert the operation to NodeDef to get the AttrValue to be able to use the
  // InferenceContext and the TensorFlow shape function.
  tensorflow::AttrValueMap attributes;

  if (get_attr_values_fn) {
    // Caller-supplied attribute extraction.
    tensorflow::Status status =
        get_attr_values_fn(op, op_name, op_reg_data,
                           /*ignore_unregistered_attrs=*/true, &attributes);
    if (!status.ok()) {
      VLOG(3) << op_name.data()
              << " failed to get AttrValue: " << status.error_message();
      return failure();
    }
  } else {
    // Default path: export the op to a NodeDef and take its attr map.
    auto* dialect = cast<TFGraphDialect>(op->getDialect());
    tensorflow::NodeDef node_def;
    tensorflow::Status status = ConvertToNodeDef(
        op, &node_def, dialect,
        [&](Value value) { return GetValueName(value, dialect); });
    if (!status.ok()) {
      VLOG(3) << op_name.data() << " failed to be converted to NodeDef: "
              << status.error_message();
      return failure();
    }
    attributes = node_def.attr();
  }

  // Collect an array with input values for constant operands and input shapes
  // for all the operands.
  const int num_operands = operands.size();
  std::vector<tensorflow::PartialTensorShape> input_shapes(num_operands);
  std::vector<std::unique_ptr<std::vector<
      std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>>
      handle_shapes_and_types(num_operands);
  for (const auto& it : llvm::enumerate(operands)) {
    Value operand = it.value();
    size_t index = it.index();

    Type operand_type = operand.getType();
    if (auto shape = GetShapeFromMlirType(operand_type)) {
      input_shapes[index] = *shape;
    } else if (auto shape = GetShapeFromMlirAttr(operand)) {
      // Fall back to a "tf._output_shapes" attr on function arguments.
      input_shapes[index] = *shape;
    }
    // Collect the handle shapes and types for a resource/variant.
    handle_shapes_and_types[index] = GetSubtypes(operand_type);
  }

  // Perform the shape inference using an InferenceContext with the input
  // shapes. This object is abstracting the information that the ShapeInference
  // function operates on.
  InferenceContext c(graph_version, tensorflow::AttrSlice(&attributes),
                     op_reg_data->op_def, input_shapes, /*input_tensors*/ {},
                     /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
  if (!c.construction_status().ok()) {
    VLOG(3) << "InferenceContext construction failed on " << op_name.data()
            << ": " << c.construction_status().error_message();
    return failure();
  }
  auto status = c.Run(op_reg_data->shape_inference_fn);
  if (!status.ok()) {
    return ReportErrorFromShapeFunction(location, op_name,
                                        status.error_message());
  }

  // `tensors` owns any constant tensors materialized below; `input_tensors`
  // holds non-owning pointers into it (nullptr = not available).
  std::vector<const tensorflow::Tensor*> input_tensors(num_operands);
  std::vector<tensorflow::Tensor> tensors(num_operands);
  std::vector<ShapeHandle> input_tensors_as_shapes(num_operands);

  // Determine if, during shape computation, the shape functions attempted to
  // query the input or input as shape where the input wasn't available.
  auto requires_inputs = [&]() {
    return any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
      return !input_tensors[input] &&
             (c.requested_input_tensor(input) ||
              c.requested_input_tensor_as_partial_shape(input));
    });
  };

  // Iterate until no new inputs are requested. Some shape functions may not
  // request all inputs upfront and can return early so this requires multiple
  // iterations.
  while (requires_inputs()) {
    VLOG(4) << "\tfeeding new inputs or input as partial shapes\n";

    bool has_new_inputs = false;
    for (int input : llvm::seq<int>(0, c.num_inputs())) {
      if (input_tensors[input]) continue;

      if (c.requested_input_tensor(input)) {
        // The shape function needs this operand's *value*; try to materialize
        // it as a constant tensor.
        if (auto attr = operand_as_constant_fn(op->getOperand(input))
                            .dyn_cast_or_null<ElementsAttr>()) {
          VLOG(4) << "Requesting " << input << " as constant\n";
          tensorflow::Tensor* input_tensor = &tensors.at(input);
          auto status = ConvertToTensor(attr, input_tensor);
          if (status.ok()) {
            input_tensors.at(input) = input_tensor;
            has_new_inputs = true;
          } else {
            // Best-effort: log and continue without this input.
            VLOG(4) << "Error converting input " << input << " of op '"
                    << op_name.data()
                    << "' to Tensor: " << status.error_message() << "\n";
          }
        }
      }

      if (c.requested_input_tensor_as_partial_shape(input) &&
          !input_tensors[input] && !input_tensors_as_shapes[input].Handle()) {
        VLOG(4) << "Requesting " << input << " as shape\n";
        auto op_result = op->getOperand(input).dyn_cast<OpResult>();
        if (!op_result) continue;
        // Resize on first valid shape computed.
        auto handle = op_result_as_shape_fn(c, op_result);
        VLOG(4) << "Requested " << input << " as shape "
                << (handle.Handle() ? "found" : "not found");
        if (handle.Handle()) {
          input_tensors_as_shapes[input] = handle;
          has_new_inputs = true;
        }
      }
    }

    // Nothing new could be materialized this round; stop so we don't loop
    // forever on inputs we can never supply.
    if (!has_new_inputs) break;

    c.set_input_tensors(input_tensors);
    c.set_input_tensors_as_shapes(input_tensors_as_shapes);
    auto status = c.Run(op_reg_data->shape_inference_fn);
    if (!status.ok()) {
      return ReportErrorFromShapeFunction(location, op_name,
                                          status.error_message());
    }
  }

  // Update the shape for each of the operation result if the InferenceContext
  // has more precise shapes recorded.
  for (int output : llvm::seq<int>(0, c.num_outputs())) {
    ShapeHandle shape_handle = c.output(output);
    VLOG(4) << "Inferred output " << output << " : "
            << c.DebugString(shape_handle) << "\n";

    Type new_element_type = result_element_type_fn(output);
    // Populate the handle shapes for a resource/variant.
    if (new_element_type &&
        new_element_type.isa<tf_type::ResourceType, tf_type::VariantType>()) {
      auto handle_shapes_types = c.output_handle_shapes_and_types(output);
      if (handle_shapes_types) {
        // Rebuild the resource/variant element type with the inferred
        // subtype shapes attached.
        SmallVector<TensorType, 1> subtypes;
        Builder b(op->getContext());
        for (const auto& shape_n_type : *handle_shapes_types) {
          Type element_type;
          auto status = ConvertDataType(shape_n_type.dtype, b, &element_type);
          assert(status.ok() && "Unknown element type");
          subtypes.push_back(
              CreateTensorType(c, shape_n_type.shape, element_type));
        }
        if (new_element_type.isa<tf_type::ResourceType>()) {
          new_element_type =
              tf_type::ResourceType::get(subtypes, op->getContext());
        } else {
          new_element_type =
              tf_type::VariantType::get(subtypes, op->getContext());
        }
      }
    }
    inferred_return_shapes.push_back(
        CreateShapedTypeComponents(c, shape_handle, new_element_type));
  }

  return success();
}
392
// Convenience overload without a custom attribute-extraction callback:
// forwards to the full version with get_attr_values_fn == nullptr, so
// attributes are obtained by converting the operation to a NodeDef.
LogicalResult InferReturnTypeComponentsForTFOp(
    Optional<Location> location, Operation* op, ValueRange operands,
    int64_t graph_version, OperandAsConstantFn operand_as_constant_fn,
    OpResultAsShapeFn op_result_as_shape_fn,
    ResultElementTypeFn result_element_type_fn,
    SmallVectorImpl<ShapedTypeComponents>& inferred_return_shapes) {
  return InferReturnTypeComponentsForTFOp(
      location, op, operands, graph_version, operand_as_constant_fn,
      op_result_as_shape_fn, result_element_type_fn,
      /*get_attr_values_fn=*/nullptr, inferred_return_shapes);
}
404
405 } // namespace tfg
406 } // namespace mlir
407