/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"

#include <algorithm>
#include <string>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"

namespace mlir {

namespace TFR {

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining within the TFR dialect.
struct TFRInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Handle the given inlined terminator by replacing it with a new operation
  // as necessary. Required when the region has only one block.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    auto retValOp = dyn_cast<TFRReturnOp>(op);
    if (!retValOp) return;

    for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
      std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
    }
  }

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect and a callable region. This method should generate an
  // operation that takes 'input' as its only operand and produces a single
  // result of 'result_type'. If a conversion cannot be generated, nullptr
  // should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<IntegerType>()) return nullptr;
    return builder.create<TruncateIOp>(conversion_loc, result_type, input);
  }
};
}  // namespace
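
// Note (illustrative, not part of the original source): with the interface
// above, an integer-width mismatch between a call result and the callee's
// return value is bridged during inlining with a standard-dialect
// truncation. Sketch for a hypothetical i64 value %wide used where an i32
// is expected:
//   %narrow = trunci %wide : i64 to i32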

//===----------------------------------------------------------------------===//
// TFR Dialect
//===----------------------------------------------------------------------===//

TFRDialect::TFRDialect(MLIRContext *context)
    : Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
  addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
      >();

  addInterfaces<TFRInlinerInterface>();
}

Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                           Type type, Location loc) {
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}

bool TFRType::classof(Type type) {
  return llvm::isa<TFRDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// Custom op methods
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ConstantTensorOp op) {
  auto input_type = op.arg().getType();
  auto output_type = op.out().getType();

  if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
    return success();
  }

  auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
  if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
    op.emitError("output type should be static and ranked.");
    return failure();
  }

  if (output_tensor_type.getRank() == 0) {
    bool same_scalar = output_tensor_type.getElementType() == input_type;
    if (!same_scalar) {
      op.emitError("input and output should have the same scalar types.");
    }
    return success(same_scalar);
  }

  if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
    bool same_element_type = output_tensor_type.getElementType() ==
                             input_vector_type.getElementType();
    bool same_shape =
        output_tensor_type.getShape() == input_vector_type.getShape();
    if (!same_element_type || !same_shape) {
      op.emitError(
          "input and output should have the same shape and element type.");
    }
    return success(same_element_type && same_shape);
  }

  op.emitError("input cannot be converted to an output tensor.");
  return failure();
}
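
// Example (illustrative sketch, not from the original source; generic op
// form and value names are invented). The verifier above accepts a scalar
// converted to a rank-0 tensor of the same element type and rejects
// mismatches:
//   %ok  = "tfr.constant_tensor"(%i) : (i32) -> tensor<i32>
//   %bad = "tfr.constant_tensor"(%i) : (i32) -> tensor<f32>   // error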

static LogicalResult Verify(TFRFuncOp func) {
  // Collect all attribute names used by the tensor and tensor list arguments
  // and results. Also collect the names of all the attribute arguments as
  // the defined list. Later on, the used attribute names will be verified
  // against the defined list.
  llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;

  // While scanning the arguments, record the start/end indices of each
  // argument type, so the order can be verified as well.
  // TODO(fengliuai): do the attribute arguments with default values need to
  // be at the end?
  int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
      last_tensor_list = -1, first_attr = -1;
  for (auto arg : llvm::enumerate(func.getType().getInputs())) {
    Type arg_type = arg.value();

    if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
      if (first_tensor == -1) {
        first_tensor = arg.index();
      }
      last_tensor = arg.index();
      auto used = tensor.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
      if (first_tensor_list == -1) {
        first_tensor_list = arg.index();
      }
      last_tensor_list = arg.index();
      auto used = tensor_list.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    if (!arg_type.isa<TensorType>()) {
      if (first_attr == -1) {
        first_attr = arg.index();
      }
      auto name =
          func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
      if (!name) {
        func.emitError(
            llvm::Twine(arg.index()) +
            " attribute argument doesn't have a tfr.name attribute.");
        return failure();
      }
      continue;
    }

    func.emitError("Builtin TensorType isn't allowed as an argument.");
    return failure();
  }

  // Collect all the undefined attributes used in the inputs.
  llvm::SmallVector<StringAttr, 4> undefined_attrs;
  for (auto attr : input_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify the argument order: tensors, tensor lists, attributes; also
  // verify that there is at most one tensor list argument.
  if (first_attr != -1 &&
      (first_attr < last_tensor_list || first_attr < last_tensor)) {
    func.emitError(
        "tfr.tensor/tfr.tensor_list argument should be before non-tensor "
        "arguments.");
    return failure();
  }
  // The order between tensor and tensor list arguments, and the number of
  // tensor list arguments, are verified only when they cannot be determined
  // from the attributes.
  if (!undefined_attrs.empty()) {
    if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
      func.emitError(
          "tfr.tensor argument should be before tfr.tensor_list argument.");
      return failure();
    }
    if (first_tensor_list != last_tensor_list) {
      func.emitError("More than one tfr.tensor_list argument isn't allowed.");
      return failure();
    }
  }

  // Verify the result order: tensors then tensor lists; also verify that
  // there is at most one tensor list result.
  auto undefined_input_attrs_number = undefined_attrs.size();
  bool seen_tensor_list = false, has_tensor_list_order_error = false,
       has_multiple_tensor_lists_error = false;
  for (auto result_type : func.getType().getResults()) {
    if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
      if (seen_tensor_list) {
        has_tensor_list_order_error = true;
      } else {
        auto used = tensor.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
      if (seen_tensor_list) {
        has_multiple_tensor_lists_error = true;
      } else {
        seen_tensor_list = true;
        auto used = tensor_list.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    func.emitError(
        "Only tfr.tensor and tfr.tensor_list types are allowed as results.");
    return failure();
  }

  // Collect all the undefined attributes used in the outputs.
  for (auto attr : output_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify that the results have no tensor/tensor_list order error and no
  // multiple-tensor-list error.
  if (undefined_input_attrs_number != undefined_attrs.size()) {
    if (has_tensor_list_order_error) {
      func.emitError(
          "tfr.tensor result should be before tfr.tensor_list result.");
      return failure();
    } else if (has_multiple_tensor_lists_error) {
      func.emitError("More than one tfr.tensor_list result isn't allowed.");
      return failure();
    }
  }

  // TODO(fengliuai): We might want to refine this constraint because the
  // tensor element type can be derived.
  if (!undefined_attrs.empty()) {
    llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
    std::transform(undefined_attrs.begin(), undefined_attrs.end(),
                   attr_names.begin(),
                   [](StringAttr attr) { return attr.getValue().str(); });
    func.emitError(llvm::Twine("Undefined attributes are used: ",
                               llvm::join(attr_names, ",")));
    return failure();
  }

  return success();
}
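
// Example (illustrative sketch, not from the original source; names and
// attribute spellings are invented but follow the checks above). A
// well-formed tfr.func lists tensor arguments first, then at most one
// tensor list, then attribute arguments carrying a tfr.name:
//   tfr.func @tf__my_op(%x: !tfr.tensor<T>, %xs: !tfr.tensor_list,
//                       %n: i32 {tfr.name = "N"}) -> !tfr.tensor<T>
//       attributes {T, N} { ... }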

static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) {
  auto build_func_type = [](Builder &builder, ArrayRef<Type> arg_types,
                            ArrayRef<Type> results, impl::VariadicFlag,
                            std::string &) {
    return builder.getFunctionType(arg_types, results);
  };
  return impl::parseFunctionLikeOp(parser, *result, /*allowVariadic=*/false,
                                   build_func_type);
}

static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) {
  FunctionType fn_type = op.getType();
  impl::printFunctionLikeOp(p, op, fn_type.getInputs(), /*isVariadic=*/false,
                            fn_type.getResults());
}

}  // namespace TFR
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"

namespace mlir {
namespace TFR {
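// Canonicalizes a tfr.constant_tensor whose input is a constant into a
// tf.Const, plus a tfr.cast when the result is a TFR tensor type.
// Illustrative sketch (not from the original source; generic op form and
// value names are invented):
//   %0 = constant 42 : i32
//   %1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
// becomes
//   %c = "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
//   %1 = "tfr.cast"(%c) : (tensor<i32>) -> !tfr.tensor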
struct ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
  using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
                                PatternRewriter &rewriter) const override {
    Location loc = cst_tensor_op.getLoc();
    Type out_type = cst_tensor_op.getType();
    Operation *new_cst = nullptr;

    ArrayAttr array;
    if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
      llvm::DenseSet<Type> all_types;
      for (auto it : array) {
        all_types.insert(it.getType());
      }
      if (all_types.size() != 1) return failure();
      ShapedType new_out_type = RankedTensorType::get(
          {static_cast<int64_t>(array.size())}, *all_types.begin());
      DenseElementsAttr attr =
          DenseElementsAttr::get(new_out_type, array.getValue());
      new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
      if (out_type.isa<TFRTensorType>()) {
        new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
      }
      rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
      return success();
    }

    Attribute scalar;
    if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
      Type new_out_type = RankedTensorType::get({}, scalar.getType());
      new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
      if (out_type.isa<TFRTensorType>()) {
        new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
      }
      rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
      return success();
    }
    return failure();
  }
};

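// Removes a back-to-back pair of tfr.cast ops when the round trip preserves
// the type, or replaces the pair with tf.EnsureShape when an unranked input
// is cast to a ranked type. Illustrative sketch (not from the original
// source; value names are invented):
//   %1 = "tfr.cast"(%0) : (tensor<2xf32>) -> !tfr.tensor
//   %2 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<2xf32>
// folds so that uses of %2 become uses of %0.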
struct RemoveRedundantCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp cast_op,
                                PatternRewriter &rewriter) const override {
    auto preceding_cast =
        llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
    if (!preceding_cast) {
      return failure();
    }
    Value input = preceding_cast.arg();
    Type input_type = input.getType();
    Type output_type = cast_op.getType();

    // If the two types are the same, the back-to-back tfr.cast ops can be
    // removed.
    if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
      rewriter.replaceOp(cast_op, {input});
      return success();
    }

    // If the input tensor is unranked, replace the pair with a
    // tf.EnsureShape op so it can be removed after shape inference or
    // checked at runtime.
    if (input_type.isa<UnrankedTensorType>() && output_type.isa<ShapedType>()) {
      auto shape = output_type.cast<ShapedType>().getShape();
      auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
      rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
                                                     input, shape_attr);
      return success();
    }

    // Nothing was rewritten; report failure so the rewriter doesn't loop.
    return failure();
  }
};

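// Rewrites tfr.get_shape of a tfr.cast into shape.shape_of on the cast's
// input, so the existing shape-dialect foldings can apply. Illustrative
// sketch (not from the original source; value names and assembly are
// invented):
//   %t = "tfr.cast"(%0) : (tensor<2x3xf32>) -> !tfr.tensor
//   %s = tfr.get_shape %t -> !shape.shape
// becomes %s = shape.shape_of %0.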
struct GetTensorShape : public OpRewritePattern<GetShapeOp> {
  using OpRewritePattern<GetShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GetShapeOp shape_op,
                                PatternRewriter &rewriter) const override {
    Operation *preceding_op = shape_op.arg().getDefiningOp();
    if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
      // Replace the pair with shape.shape_of so that folding can apply.
      rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
      return success();
    }
    return failure();
  }
};

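// Folds tfr.get_element with a constant index on a tfr.build_list back to
// the corresponding list operand. Illustrative sketch (not from the
// original source; value names and assembly are invented):
//   %l = "tfr.build_list"(%a, %b) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
//   %e = "tfr.get_element"(%l, %idx) : (!tfr.tensor_list, index) -> !tfr.tensor
// becomes %b when %idx is the constant 1.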
struct RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
  using OpRewritePattern<GetElementOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GetElementOp ge_op,
                                PatternRewriter &rewriter) const override {
    IntegerAttr index;
    if (!matchPattern(ge_op.index(), m_Constant(&index))) {
      return failure();
    }
    auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
        ge_op.tensor_list().getDefiningOp());
    if (!preceding_build_list ||
        preceding_build_list.getNumOperands() <= index.getInt()) {
      return failure();
    }
    Value input = preceding_build_list.getOperand(index.getInt());
    Type output_type = ge_op.getType();
    if (input.getType() != output_type &&
        !output_type.isa<UnrankedTensorType>()) {
      return failure();
    }
    rewriter.replaceOp(ge_op, {input});
    return success();
  }
};

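// Folds tfr.get_length of a tfr.build_list into an index constant equal to
// the number of operands of the build_list, e.g. a two-element list folds
// to constant 2 : index. (Illustrative note, not from the original source.)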
struct RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
  using OpRewritePattern<GetLengthOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GetLengthOp gl_op,
                                PatternRewriter &rewriter) const override {
    auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
        gl_op.tensor_list().getDefiningOp());
    if (!preceding_build_list) {
      return failure();
    }
    int64_t num_tensors = preceding_build_list.getNumOperands();
    rewriter.replaceOpWithNewOp<ConstantOp>(gl_op,
                                            rewriter.getIndexAttr(num_tensors));
    return success();
  }
};

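// Replaces a tfr.build_list whose operands all fold to constants with a
// single tfr.constant holding the values as an ArrayAttr. Illustrative
// sketch (not from the original source; assembly and types are invented):
//   %l = "tfr.build_list"(%c0, %c1) : (i32, i32) -> !tfr.attr
// becomes a tfr.constant carrying [0 : i32, 1 : i32].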
struct BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
  using OpRewritePattern<BuildListOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BuildListOp bl_op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Attribute, 4> array_list;
    array_list.reserve(bl_op.getNumOperands());
    for (const auto &operand : bl_op.getOperands()) {
      Attribute array_elt;
      if (!matchPattern(operand, m_Constant(&array_elt))) {
        return failure();
      }
      array_list.push_back(array_elt);
    }
    auto array_attr = rewriter.getArrayAttr(array_list);
    rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
    return success();
  }
};

void ConstantTensorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertConstToTensorConst>(context);
}

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<RemoveRedundantCast>(context);
}

void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<GetTensorShape>(context);
}

void GetElementOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveRedundantGetElement>(context);
}

void GetLengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<RemoveRedundantGetLength>(context);
}

void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<BuildConstantListAsAttr>(context);
}

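// Illustrative note (not from the original source): the fold below resolves
// tfr.equal at compile time whenever both operands fold to attributes,
// returning true only when the two attributes are identical.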
OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "equal op has two operands");
  auto ctx = getContext();
  if (operands[0] == operands[1]) return BoolAttr::get(ctx, true);
  return BoolAttr::get(ctx, false);
}

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}

// CallableOpInterface
Region *TFRFuncOp::getCallableRegion() {
  return isExternal() ? nullptr : &body().front();
}

// CallableOpInterface
ArrayRef<Type> TFRFuncOp::getCallableResults() {
  return getType().getResults();
}

//===----------------------------------------------------------------------===//
// Dialect type definitions
//===----------------------------------------------------------------------===//

// Parses a TFR type.
//   tfr_type ::= tensor_type | tensor_list_type | attr_type
//   string_list ::= `[` string-literal (, string-literal)+ `]`
//   tensor_type ::= `tensor`
//                 | `tensor<` (string-literal | string_list) `>`
//   tensor_list_type ::= `tensor_list`
//                      | `tensor_list<` (string-literal | string_list) `>`
//   attr_type ::= `attr`
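//
// Examples (illustrative; attribute keys such as T and N are placeholders):
//   !tfr.tensor, !tfr.tensor<T>, !tfr.tensor_list<[T, N]>, !tfr.attr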
Type TFRDialect::parseType(DialectAsmParser &parser) const {
  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  MLIRContext *ctx = loc.getContext();

  StringRef typeNameSpelling;
  if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
  llvm::SmallVector<StringAttr, 4> attrs;
  if (succeeded(parser.parseOptionalLess())) {
    bool l_square_parsed = false;
    if (succeeded(parser.parseOptionalLSquare())) {
      l_square_parsed = true;
    }

    do {
      StringRef attr;
      if (failed(parser.parseKeyword(&attr))) return {};
      attrs.push_back(StringAttr::get(ctx, attr));
    } while (succeeded(parser.parseOptionalComma()));

    if (l_square_parsed && failed(parser.parseRSquare())) {
      parser.emitError(parser.getNameLoc(), "expected ']'");
      return {};
    }

    if (failed(parser.parseGreater())) {
      parser.emitError(parser.getNameLoc(), "expected '>'");
      return {};
    }
  }

  if (typeNameSpelling == "tensor") {
    return TFRTensorType::getChecked(attrs, loc);
  } else if (typeNameSpelling == "tensor_list") {
    return TFRTensorListType::getChecked(attrs, loc);
  } else if (typeNameSpelling == "attr") {
    return TFRAttrType::getChecked(loc);
  } else {
    parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
    return {};
  }
}

void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
  llvm::ArrayRef<StringAttr> attrs;

  if (type.isa<TFRAttrType>()) {
    os << "attr";
    return;
  }
  if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
    attrs = tensor_ty.getAttrKeys();
    os << "tensor";
  } else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
    attrs = tensor_list_ty.getAttrKeys();
    os << "tensor_list";
  } else {
    llvm_unreachable("Unhandled tfr type");
  }

  if (attrs.empty()) return;
  os << "<";

  if (attrs.size() > 1) {
    os << "[";
  }

  llvm::interleaveComma(attrs, os,
                        [&](StringAttr attr) { os << attr.getValue(); });

  if (attrs.size() > 1) {
    os << "]";
  }
  os << ">";
}

}  // namespace TFR
}  // namespace mlir