• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/utils/shape_inference_utils.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "llvm/ADT/None.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/Casting.h"
26 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
29 #include "mlir/IR/Value.h"  // from @llvm-project
30 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
31 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
32 #include "mlir/Support/LLVM.h"  // from @llvm-project
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/framework/op_def_builder.h"
37 #include "tensorflow/core/framework/shape_inference.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/ir/dialect.h"
42 #include "tensorflow/core/ir/importexport/convert_tensor.h"
43 #include "tensorflow/core/ir/importexport/convert_types.h"
44 #include "tensorflow/core/ir/importexport/graphdef_export.h"
45 #include "tensorflow/core/ir/types/dialect.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/types.h"
48 
49 #define DEBUG_TYPE "tfg-shape-inference-utils"
50 
51 using tensorflow::shape_inference::DimensionHandle;
52 using tensorflow::shape_inference::InferenceContext;
53 using tensorflow::shape_inference::ShapeHandle;
54 
55 namespace mlir {
56 namespace tfg {
57 
58 namespace {
59 
60 // Get the tensorflow op name.
GetTensorFlowOpName(Operation * inst)61 llvm::StringRef GetTensorFlowOpName(Operation* inst) {
62   llvm::StringRef op_name = inst->getName().stripDialect();
63   // Control dialect NextIteration sink ends with ".sink" and Executor dialect
64   // NextIteration sink ends with ".Sink".
65   if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
66   return op_name;
67 }
68 
69 // Extracts attributes from a MLIR operation, including derived attributes, into
70 // one NamedAttrList.
GetAllAttributesFromOperation(Operation * op)71 NamedAttrList GetAllAttributesFromOperation(Operation* op) {
72   NamedAttrList attr_list;
73   attr_list.append(op->getAttrDictionary().getValue());
74 
75   if (auto derived = dyn_cast<DerivedAttributeOpInterface>(op)) {
76     auto materialized = derived.materializeDerivedAttributes();
77     attr_list.append(materialized.getValue());
78   }
79 
80   return attr_list;
81 }
82 
83 // Extracts a PartialTensorShape from the MLIR type.
84 // Some MLIR shapes may fail to be represented as PartialTensorShape, e.g.
85 // those where num_elements overflows.
86 // TODO(tlongeri): Should num_elements overflow be handled by the MLIR
87 // verifier? Are there other cases?
GetShapeFromMlirType(Type t)88 Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
89   if (auto ranked_type = t.dyn_cast<RankedTensorType>()) {
90     tensorflow::PartialTensorShape shape;
91     const tensorflow::Status status =
92         tensorflow::PartialTensorShape::BuildPartialTensorShape(
93             ranked_type.getShape(), &shape);
94     if (status.ok()) return shape;
95   }
96   return None;
97 }
98 
99 // Extracts a PartialTensorShape from the MLIR attr.
GetShapeFromMlirAttr(Value v)100 Optional<tensorflow::PartialTensorShape> GetShapeFromMlirAttr(Value v) {
101   // Function arguments may have shape attr to describe its output shape.
102   if (auto arg = v.dyn_cast<BlockArgument>()) {
103     Operation* parent_op = arg.getOwner()->getParentOp();
104     if (auto func_op = llvm::dyn_cast<FunctionOpInterface>(parent_op)) {
105       int arg_idx = arg.getArgNumber();
106       auto attrs =
107           func_op.getArgAttrOfType<ArrayAttr>(arg_idx, "tf._output_shapes");
108       if (!attrs || attrs.size() != 1) return None;
109 
110       // "tf._output_shapes" in certain models may not store the shape as
111       // ShapeAttr, ignore them because we don't know how to interpret it.
112       auto shape_attr = attrs[0].dyn_cast<tf_type::ShapeAttr>();
113       if (shape_attr && shape_attr.hasRank())
114         return tensorflow::PartialTensorShape(shape_attr.getShape());
115     }
116   }
117   return None;
118 }
119 
120 // Gets the subtype's shape and data type for `type`. Templated to support both
121 // ResourceType and VariantType.
122 template <typename T>
123 std::unique_ptr<std::vector<
124     std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
GetSubtypesHelper(Type type)125 GetSubtypesHelper(Type type) {
126   auto type_with_subtypes =
127       type.cast<TensorType>().getElementType().dyn_cast<T>();
128   if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) {
129     return nullptr;
130   }
131   auto shapes_and_types = std::make_unique<std::vector<
132       std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
133   for (auto subtype : type_with_subtypes.getSubtypes()) {
134     auto shape = GetShapeFromMlirType(subtype);
135     // handle_shapes_and_types requires all shapes to be known. So if any
136     // subtype is unknown, clear the vector.
137     if (!shape) {
138       shapes_and_types = nullptr;
139       break;
140     }
141     tensorflow::DataType dtype;
142     auto status = ConvertToDataType(subtype.getElementType(), &dtype);
143     assert(status.ok() && "Unknown element type");
144     shapes_and_types->emplace_back(*shape, dtype);
145   }
146   return shapes_and_types;
147 }
148 
149 // Gets the subtype's shape and data type for `type`.
150 std::unique_ptr<std::vector<
151     std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
GetSubtypes(Type type)152 GetSubtypes(Type type) {
153   auto subclasses = GetSubtypesHelper<tf_type::ResourceType>(type);
154   if (subclasses) return subclasses;
155   return GetSubtypesHelper<tf_type::VariantType>(type);
156 }
157 
158 // Log a shape inference function call failure.
ReportErrorFromShapeFunction(Optional<Location> location,llvm::StringRef op_name,llvm::StringRef error_message)159 LogicalResult ReportErrorFromShapeFunction(Optional<Location> location,
160                                            llvm::StringRef op_name,
161                                            llvm::StringRef error_message) {
162   VLOG(3) << "TensorFlow shape inference function errored for op '"
163           << op_name.data() << "': " << error_message.data();
164   return failure();
165 }
166 
167 // Extracts shape from a shape handle and inference context.
GetShapeFromHandle(InferenceContext & context,const ShapeHandle & sh)168 Optional<SmallVector<int64_t, 8>> GetShapeFromHandle(InferenceContext& context,
169                                                      const ShapeHandle& sh) {
170   if (!context.RankKnown(sh)) return None;
171   SmallVector<int64_t, 8> shape;
172   for (int dim : llvm::seq<int>(0, context.Rank(sh)))
173     shape.push_back(context.Value(context.Dim(sh, dim)));
174   return shape;
175 }
176 
177 // Creates a tensor type from a shape handle and element type.
CreateTensorType(InferenceContext & context,const ShapeHandle & sh,Type element_type)178 TensorType CreateTensorType(InferenceContext& context, const ShapeHandle& sh,
179                             Type element_type) {
180   auto shape = GetShapeFromHandle(context, sh);
181   if (shape.has_value())
182     return RankedTensorType::get(shape.getValue(), element_type);
183   return UnrankedTensorType::get(element_type);
184 }
185 
186 // Creates a ShapedTypeComponent from a shape handle and element type.
CreateShapedTypeComponents(InferenceContext & context,const ShapeHandle & sh,Type element_type)187 ShapedTypeComponents CreateShapedTypeComponents(InferenceContext& context,
188                                                 const ShapeHandle& sh,
189                                                 Type element_type) {
190   auto shape = GetShapeFromHandle(context, sh);
191   if (shape.has_value())
192     return ShapedTypeComponents(shape.getValue(), element_type);
193   return ShapedTypeComponents(element_type);
194 }
195 
196 }  // namespace
197 
InferReturnTypeComponentsForTFOp(Optional<Location> location,Operation * op,ValueRange operands,int64_t graph_version,OperandAsConstantFn operand_as_constant_fn,OpResultAsShapeFn op_result_as_shape_fn,ResultElementTypeFn result_element_type_fn,GetAttrValuesFn get_attr_values_fn,SmallVectorImpl<ShapedTypeComponents> & inferred_return_shapes)198 LogicalResult InferReturnTypeComponentsForTFOp(
199     Optional<Location> location, Operation* op, ValueRange operands,
200     int64_t graph_version, OperandAsConstantFn operand_as_constant_fn,
201     OpResultAsShapeFn op_result_as_shape_fn,
202     ResultElementTypeFn result_element_type_fn,
203     GetAttrValuesFn get_attr_values_fn,
204     SmallVectorImpl<ShapedTypeComponents>& inferred_return_shapes) {
205   llvm::StringRef op_name = GetTensorFlowOpName(op);
206 
207   // Get information from the registry and check if we have a shape function for
208   // this op.
209   const tensorflow::OpRegistrationData* op_reg_data =
210       tensorflow::OpRegistry::Global()->LookUp(op_name.str());
211   if (!op_reg_data) {
212     VLOG(3) << "Skipping inference for unregistered op '" << op_name.data()
213             << "'.\n";
214     return failure();
215   }
216   if (!op_reg_data->shape_inference_fn) {
217     VLOG(3) << "Skipping inference for op without shape function '"
218             << op_name.data() << "'.\n";
219     return failure();
220   }
221 
222   // Convert the operation to NodeDef to get the AttrValue to be able to use the
223   // InferenceContext and the TensorFlow shape function.
224   tensorflow::AttrValueMap attributes;
225 
226   if (get_attr_values_fn) {
227     tensorflow::Status status =
228         get_attr_values_fn(op, op_name, op_reg_data,
229                            /*ignore_unregistered_attrs=*/true, &attributes);
230     if (!status.ok()) {
231       VLOG(3) << op_name.data()
232               << " failed to get AttrValue: " << status.error_message();
233       return failure();
234     }
235   } else {
236     auto* dialect = cast<TFGraphDialect>(op->getDialect());
237     tensorflow::NodeDef node_def;
238     tensorflow::Status status = ConvertToNodeDef(
239         op, &node_def, dialect,
240         [&](Value value) { return GetValueName(value, dialect); });
241     if (!status.ok()) {
242       VLOG(3) << op_name.data() << " failed to be converted to NodeDef: "
243               << status.error_message();
244       return failure();
245     }
246     attributes = node_def.attr();
247   }
248 
249   // Collect an array with input values for constant operands and input shapes
250   // for all the operands.
251   const int num_operands = operands.size();
252   std::vector<tensorflow::PartialTensorShape> input_shapes(num_operands);
253   std::vector<std::unique_ptr<std::vector<
254       std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>>
255       handle_shapes_and_types(num_operands);
256   for (const auto& it : llvm::enumerate(operands)) {
257     Value operand = it.value();
258     size_t index = it.index();
259 
260     Type operand_type = operand.getType();
261     if (auto shape = GetShapeFromMlirType(operand_type)) {
262       input_shapes[index] = *shape;
263     } else if (auto shape = GetShapeFromMlirAttr(operand)) {
264       input_shapes[index] = *shape;
265     }
266     // Collect the handle shapes and types for a resource/variant.
267     handle_shapes_and_types[index] = GetSubtypes(operand_type);
268   }
269 
270   // Perform the shape inference using an InferenceContext with the input
271   // shapes. This object is abstracting the information that the ShapeInference
272   // function operates on.
273   InferenceContext c(graph_version, tensorflow::AttrSlice(&attributes),
274                      op_reg_data->op_def, input_shapes, /*input_tensors*/ {},
275                      /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
276   if (!c.construction_status().ok()) {
277     VLOG(3) << "InferenceContext construction failed on " << op_name.data()
278             << ": " << c.construction_status().error_message();
279     return failure();
280   }
281   auto status = c.Run(op_reg_data->shape_inference_fn);
282   if (!status.ok()) {
283     return ReportErrorFromShapeFunction(location, op_name,
284                                         status.error_message());
285   }
286 
287   std::vector<const tensorflow::Tensor*> input_tensors(num_operands);
288   std::vector<tensorflow::Tensor> tensors(num_operands);
289   std::vector<ShapeHandle> input_tensors_as_shapes(num_operands);
290 
291   // Determine if, during shape computation, the shape functions attempted to
292   // query the input or input as shape where the input wasn't available.
293   auto requires_inputs = [&]() {
294     return any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
295       return !input_tensors[input] &&
296              (c.requested_input_tensor(input) ||
297               c.requested_input_tensor_as_partial_shape(input));
298     });
299   };
300 
301   // Iterate until no new inputs are requested. Some shape functions may not
302   // request all inputs upfront and can return early so this requires multiple
303   // iterations.
304   while (requires_inputs()) {
305     VLOG(4) << "\tfeeding new inputs or input as partial shapes\n";
306 
307     bool has_new_inputs = false;
308     for (int input : llvm::seq<int>(0, c.num_inputs())) {
309       if (input_tensors[input]) continue;
310 
311       if (c.requested_input_tensor(input)) {
312         if (auto attr = operand_as_constant_fn(op->getOperand(input))
313                             .dyn_cast_or_null<ElementsAttr>()) {
314           VLOG(4) << "Requesting " << input << " as constant\n";
315           tensorflow::Tensor* input_tensor = &tensors.at(input);
316           auto status = ConvertToTensor(attr, input_tensor);
317           if (status.ok()) {
318             input_tensors.at(input) = input_tensor;
319             has_new_inputs = true;
320           } else {
321             VLOG(4) << "Error converting input " << input << " of op '"
322                     << op_name.data()
323                     << "' to Tensor: " << status.error_message() << "\n";
324           }
325         }
326       }
327 
328       if (c.requested_input_tensor_as_partial_shape(input) &&
329           !input_tensors[input] && !input_tensors_as_shapes[input].Handle()) {
330         VLOG(4) << "Requesting " << input << " as shape\n";
331         auto op_result = op->getOperand(input).dyn_cast<OpResult>();
332         if (!op_result) continue;
333         // Resize on first valid shape computed.
334         auto handle = op_result_as_shape_fn(c, op_result);
335         VLOG(4) << "Requested " << input << " as shape "
336                 << (handle.Handle() ? "found" : "not found");
337         if (handle.Handle()) {
338           input_tensors_as_shapes[input] = handle;
339           has_new_inputs = true;
340         }
341       }
342     }
343 
344     if (!has_new_inputs) break;
345 
346     c.set_input_tensors(input_tensors);
347     c.set_input_tensors_as_shapes(input_tensors_as_shapes);
348     auto status = c.Run(op_reg_data->shape_inference_fn);
349     if (!status.ok()) {
350       return ReportErrorFromShapeFunction(location, op_name,
351                                           status.error_message());
352     }
353   }
354 
355   // Update the shape for each of the operation result if the InferenceContext
356   // has more precise shapes recorded.
357   for (int output : llvm::seq<int>(0, c.num_outputs())) {
358     ShapeHandle shape_handle = c.output(output);
359     VLOG(4) << "Inferred output " << output << " : "
360             << c.DebugString(shape_handle) << "\n";
361 
362     Type new_element_type = result_element_type_fn(output);
363     // Populate the handle shapes for a resource/variant.
364     if (new_element_type &&
365         new_element_type.isa<tf_type::ResourceType, tf_type::VariantType>()) {
366       auto handle_shapes_types = c.output_handle_shapes_and_types(output);
367       if (handle_shapes_types) {
368         SmallVector<TensorType, 1> subtypes;
369         Builder b(op->getContext());
370         for (const auto& shape_n_type : *handle_shapes_types) {
371           Type element_type;
372           auto status = ConvertDataType(shape_n_type.dtype, b, &element_type);
373           assert(status.ok() && "Unknown element type");
374           subtypes.push_back(
375               CreateTensorType(c, shape_n_type.shape, element_type));
376         }
377         if (new_element_type.isa<tf_type::ResourceType>()) {
378           new_element_type =
379               tf_type::ResourceType::get(subtypes, op->getContext());
380         } else {
381           new_element_type =
382               tf_type::VariantType::get(subtypes, op->getContext());
383         }
384       }
385     }
386     inferred_return_shapes.push_back(
387         CreateShapedTypeComponents(c, shape_handle, new_element_type));
388   }
389 
390   return success();
391 }
392 
InferReturnTypeComponentsForTFOp(Optional<Location> location,Operation * op,ValueRange operands,int64_t graph_version,OperandAsConstantFn operand_as_constant_fn,OpResultAsShapeFn op_result_as_shape_fn,ResultElementTypeFn result_element_type_fn,SmallVectorImpl<ShapedTypeComponents> & inferred_return_shapes)393 LogicalResult InferReturnTypeComponentsForTFOp(
394     Optional<Location> location, Operation* op, ValueRange operands,
395     int64_t graph_version, OperandAsConstantFn operand_as_constant_fn,
396     OpResultAsShapeFn op_result_as_shape_fn,
397     ResultElementTypeFn result_element_type_fn,
398     SmallVectorImpl<ShapedTypeComponents>& inferred_return_shapes) {
399   return InferReturnTypeComponentsForTFOp(
400       location, op, operands, graph_version, operand_as_constant_fn,
401       op_result_as_shape_fn, result_element_type_fn,
402       /*get_attr_values_fn=*/nullptr, inferred_return_shapes);
403 }
404 
405 }  // namespace tfg
406 }  // namespace mlir
407