1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
17
18 #include <algorithm>
19 #include <string>
20
21 #include "llvm/ADT/DenseSet.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/ADT/Twine.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
30 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Builders.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
36 #include "mlir/IR/FunctionImplementation.h" // from @llvm-project
37 #include "mlir/IR/MLIRContext.h" // from @llvm-project
38 #include "mlir/IR/Matchers.h" // from @llvm-project
39 #include "mlir/IR/OpDefinition.h" // from @llvm-project
40 #include "mlir/IR/OpImplementation.h" // from @llvm-project
41 #include "mlir/IR/PatternMatch.h" // from @llvm-project
42 #include "mlir/IR/Types.h" // from @llvm-project
43 #include "mlir/Support/LogicalResult.h" // from @llvm-project
44 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
48 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
49
50 namespace mlir {
51
52 namespace TFR {
53
54 //===----------------------------------------------------------------------===//
55 // InlinerInterface
56 //===----------------------------------------------------------------------===//
57
58 namespace {
59 /// This class defines the interface for inlining within the TFR dialect.
60 struct TFRInlinerInterface : public DialectInlinerInterface {
61 using DialectInlinerInterface::DialectInlinerInterface;
62
63 // Allow all call operations to be inlined.
isLegalToInlinemlir::TFR::__anon23dd953f0111::TFRInlinerInterface64 bool isLegalToInline(Operation *call, Operation *callable,
65 bool wouldBeCloned) const final {
66 return true;
67 }
68 // Returns true if the given region 'src' can be inlined into the region
69 // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInlinemlir::TFR::__anon23dd953f0111::TFRInlinerInterface70 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71 BlockAndValueMapping &) const final {
72 return true;
73 }
74
75 // Returns true if the given operation 'op', that is registered to this
76 // dialect, can be inlined into the region 'dest' that is attached to an
77 // operation registered to the current dialect.
isLegalToInlinemlir::TFR::__anon23dd953f0111::TFRInlinerInterface78 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79 BlockAndValueMapping &) const final {
80 return true;
81 }
82
83 // Handle the given inlined terminator by replacing it with a new operation
84 // as necessary. Required when the region has only one block.
handleTerminatormlir::TFR::__anon23dd953f0111::TFRInlinerInterface85 void handleTerminator(Operation *op,
86 ArrayRef<Value> valuesToRepl) const final {
87 auto retValOp = dyn_cast<TFRReturnOp>(op);
88 if (!retValOp) return;
89
90 for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
91 std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
92 }
93 }
94
95 // Attempts to materialize a conversion for a type mismatch between a call
96 // from this dialect, and a callable region. This method should generate an
97 // operation that takes 'input' as the only operand, and produces a single
98 // result of 'resultType'. If a conversion can not be generated, nullptr
99 // should be returned.
materializeCallConversionmlir::TFR::__anon23dd953f0111::TFRInlinerInterface100 Operation *materializeCallConversion(OpBuilder &builder, Value input,
101 Type result_type,
102 Location conversion_loc) const final {
103 if (!result_type.isa<IntegerType>()) return nullptr;
104 return builder.create<TruncateIOp>(conversion_loc, result_type, input);
105 }
106 };
107 } // namespace
108
109 //===----------------------------------------------------------------------===//
110 // TFR Dialect
111 //===----------------------------------------------------------------------===//
112
TFRDialect(MLIRContext * context)113 TFRDialect::TFRDialect(MLIRContext *context)
114 : Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
115 addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
116 addOperations<
117 #define GET_OP_LIST
118 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
119 >();
120
121 addInterfaces<TFRInlinerInterface>();
122 }
123
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)124 Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
125 Type type, Location loc) {
126 if (ConstantOp::isBuildableWith(value, type))
127 return builder.create<ConstantOp>(loc, type, value);
128 return nullptr;
129 }
130
classof(Type type)131 bool TFRType::classof(Type type) {
132 return llvm::isa<TFRDialect>(type.getDialect());
133 }
134
135 //===----------------------------------------------------------------------===//
136 // Custom op methods
137 //===----------------------------------------------------------------------===//
138
Verify(ConstantTensorOp op)139 static LogicalResult Verify(ConstantTensorOp op) {
140 auto input_type = op.arg().getType();
141 auto output_type = op.out().getType();
142
143 if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
144 return success();
145 }
146
147 auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
148 if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
149 op.emitError("output type should be static and ranked.");
150 return failure();
151 }
152
153 if (output_tensor_type.getRank() == 0) {
154 bool same_scalar = output_tensor_type.getElementType() == input_type;
155 if (!same_scalar) {
156 op.emitError("input and output should have the same scalar types.");
157 }
158 return success(same_scalar);
159 }
160
161 if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
162 bool same_element_type = output_tensor_type.getElementType() ==
163 input_vector_type.getElementType();
164 bool same_shape =
165 output_tensor_type.getShape() == input_vector_type.getShape();
166 if (!same_element_type || !same_shape) {
167 op.emitError("input and output should have same shape and element type.");
168 }
169 return success(same_element_type && same_shape);
170 }
171
172 op.emitError("input can not be converted to an output tensor.");
173 return failure();
174 }
175
Verify(TFRFuncOp func)176 static LogicalResult Verify(TFRFuncOp func) {
177 // Collect all attribute names used by the tensor and tensor list arguments
178 // and returns. Also, collect the names of all the attribute arguments as the
179 // defined list. Later on, the used attribute names will be verified to be in
180 // the defined list.
181 llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;
182
183 // While scanning the arguments, record the start/end indices of each argument
184 // type, so the order can be verified as well.
185 // TODO(fengliuai): the attribute arguments with default values need to be
186 // at the end?
187 int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
188 last_tensor_list = -1, first_attr = -1;
189 for (auto arg : llvm::enumerate(func.getType().getInputs())) {
190 Type arg_type = arg.value();
191
192 if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
193 if (first_tensor == -1) {
194 first_tensor = arg.index();
195 }
196 last_tensor = arg.index();
197 auto used = tensor.getAttrKeys();
198 input_used_attrs.append(used.begin(), used.end());
199 continue;
200 }
201
202 if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
203 if (first_tensor_list == -1) {
204 first_tensor_list = arg.index();
205 }
206 last_tensor_list = arg.index();
207 auto used = tensor_list.getAttrKeys();
208 input_used_attrs.append(used.begin(), used.end());
209 continue;
210 }
211
212 if (!arg_type.isa<TensorType>()) {
213 if (first_attr == -1) {
214 first_attr = arg.index();
215 }
216 auto name =
217 func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
218 if (!name) {
219 func.emitError(
220 llvm::Twine(arg.index()) +
221 " attribute argument doesn't have a tfr.name attribute.");
222 return failure();
223 }
224 continue;
225 }
226
227 func.emitError("Builtin TensorType isn't allowed as the argument.");
228 return failure();
229 }
230
231 // Collect all the undefined attributes used in the inputs.
232 llvm::SmallVector<StringAttr, 4> undefined_attrs;
233 for (auto attr : input_used_attrs) {
234 if (!func->getAttr(attr.getValue())) {
235 undefined_attrs.push_back(attr);
236 }
237 }
238
239 // Verify the argument order: tensors, tensor list, attributes; and also
240 // verify there is at most one tensor list argument.
241 if (first_attr != -1 &&
242 (first_attr < last_tensor_list || first_attr < last_tensor)) {
243 func.emitError(
244 "tfr.tensor/tfr.tensor_list argument should be before non tensor "
245 "arguments.");
246 return failure();
247 }
248 // The order between tensor arguments and tensor list arguments and the number
249 // of tensor list arguments are verified only when they couldn't be determined
250 // by the attributes.
251 if (!undefined_attrs.empty()) {
252 if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
253 func.emitError(
254 "tfr.tensor argument should be before tfr.tensor_list argument.");
255 return failure();
256 }
257 if (first_tensor_list != last_tensor_list) {
258 func.emitError("More than one tfr.tensor_list argument isn't allowed.");
259 return failure();
260 }
261 }
262
263 // Verify the result order: tensor, tensor list, and also verify at most one
264 // tensor list result.
265 int undefined_input_attrs_number = undefined_attrs.size();
266 bool seen_tensor_list = false, has_tensor_list_order_error = false,
267 has_multiple_tensor_lists_error = false;
268 for (auto result_type : func.getType().getResults()) {
269 if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
270 if (seen_tensor_list) {
271 has_tensor_list_order_error = true;
272 } else {
273 auto used = tensor.getAttrKeys();
274 output_used_attrs.append(used.begin(), used.end());
275 }
276 continue;
277 }
278
279 if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
280 if (seen_tensor_list) {
281 has_multiple_tensor_lists_error = true;
282 } else {
283 seen_tensor_list = true;
284 auto used = tensor_list.getAttrKeys();
285 output_used_attrs.append(used.begin(), used.end());
286 }
287 continue;
288 }
289
290 func.emitError(
291 "None tfr.tensor/tfr.tensor_list results aren't allowed as a "
292 "result.");
293 return failure();
294 }
295
296 // Collect all the undefined attributes used in the outputs.
297 for (auto attr : output_used_attrs) {
298 if (!func->getAttr(attr.getValue())) {
299 undefined_attrs.push_back(attr);
300 }
301 }
302
303 // Verify there are no tensor/tensor list order error and multiple tensor
304 // list arguments error.
305 if (undefined_input_attrs_number != undefined_attrs.size()) {
306 if (has_tensor_list_order_error) {
307 func.emitError(
308 "tfr.tensor result should be before tfr.tensor_list result.");
309 return failure();
310 } else if (has_multiple_tensor_lists_error) {
311 func.emitError("More than one tfr.tensor_list result isn't allowed.");
312 return failure();
313 }
314 }
315
316 // TODO(fengliuai): We might want to refine this constraint because the
317 // tensor element type can be derived.
318 if (!undefined_attrs.empty()) {
319 llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
320 std::transform(undefined_attrs.begin(), undefined_attrs.end(),
321 attr_names.begin(),
322 [](StringAttr attr) { return attr.getValue().str(); });
323 func.emitError(llvm::Twine("Undefined attributes are used: ",
324 llvm::join(attr_names, ",")));
325 return failure();
326 }
327
328 return success();
329 }
330
ParseFuncOp(OpAsmParser & parser,OperationState * result)331 static ParseResult ParseFuncOp(OpAsmParser &parser, OperationState *result) {
332 auto build_func_type = [](Builder &builder, ArrayRef<Type> arg_types,
333 ArrayRef<Type> results, impl::VariadicFlag,
334 std::string &) {
335 return builder.getFunctionType(arg_types, results);
336 };
337 return impl::parseFunctionLikeOp(parser, *result, /*allowVariadic=*/false,
338 build_func_type);
339 }
340
PrintFuncOp(OpAsmPrinter & p,TFRFuncOp op)341 static void PrintFuncOp(OpAsmPrinter &p, TFRFuncOp op) {
342 FunctionType fn_type = op.getType();
343 impl::printFunctionLikeOp(p, op, fn_type.getInputs(), /*isVariadic=*/false,
344 fn_type.getResults());
345 }
346
347 } // namespace TFR
348 } // namespace mlir
349
350 //===----------------------------------------------------------------------===//
351 // TableGen'd op method definitions
352 //===----------------------------------------------------------------------===//
353
354 #define GET_OP_CLASSES
355 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
356
357 namespace mlir {
358 namespace TFR {
359 struct ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
360 using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
361
matchAndRewritemlir::TFR::ConvertConstToTensorConst362 LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
363 PatternRewriter &rewriter) const override {
364 Location loc = cst_tensor_op.getLoc();
365 Type out_type = cst_tensor_op.getType();
366 Operation *new_cst = nullptr;
367
368 ArrayAttr array;
369 if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
370 llvm::DenseSet<Type> all_types;
371 for (auto it : array) {
372 all_types.insert(it.getType());
373 }
374 if (all_types.size() != 1) return failure();
375 ShapedType new_out_type = RankedTensorType::get(
376 {static_cast<int64_t>(array.size())}, *all_types.begin());
377 DenseElementsAttr attr =
378 DenseElementsAttr::get(new_out_type, array.getValue());
379 new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
380 if (out_type.isa<TFRTensorType>()) {
381 new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
382 }
383 rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
384 return success();
385 }
386
387 Attribute scalar;
388 if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
389 Type new_out_type = RankedTensorType::get({}, scalar.getType());
390 new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
391 if (out_type.isa<TFRTensorType>()) {
392 new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
393 }
394 rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
395 return success();
396 }
397 return failure();
398 }
399 };
400
401 struct RemoveRedundantCast : public OpRewritePattern<CastOp> {
402 using OpRewritePattern<CastOp>::OpRewritePattern;
403
matchAndRewritemlir::TFR::RemoveRedundantCast404 LogicalResult matchAndRewrite(CastOp cast_op,
405 PatternRewriter &rewriter) const override {
406 auto preceding_cast =
407 llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
408 if (!preceding_cast) {
409 return failure();
410 }
411 Value input = preceding_cast.arg();
412 Type input_type = input.getType();
413 Type output_type = cast_op.getType();
414
415 // If the two types are the same, the back-to-back tfr.cast ops can be
416 // removed.
417 if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
418 rewriter.replaceOp(cast_op, {input});
419 return success();
420 }
421
422 // If the rank of the input tensor isn't ranked, we replace the pair
423 // with tf.EnsureShape op so it can be removed after shape inference or
424 // confirmed at runtime.
425 if (input_type.isa<UnrankedTensorType>() && output_type.isa<ShapedType>()) {
426 auto shape = output_type.cast<ShapedType>().getShape();
427 auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
428 rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
429 input, shape_attr);
430 }
431
432 return success();
433 }
434 };
435
436 struct GetTensorShape : public OpRewritePattern<GetShapeOp> {
437 using OpRewritePattern<GetShapeOp>::OpRewritePattern;
438
matchAndRewritemlir::TFR::GetTensorShape439 LogicalResult matchAndRewrite(GetShapeOp shape_op,
440 PatternRewriter &rewriter) const override {
441 Operation *preceding_op = shape_op.arg().getDefiningOp();
442 if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
443 // replace this pair by shape.shape_of, so the folding works.
444 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
445 return success();
446 }
447 return failure();
448 }
449 };
450
451 struct RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
452 using OpRewritePattern<GetElementOp>::OpRewritePattern;
453
matchAndRewritemlir::TFR::RemoveRedundantGetElement454 LogicalResult matchAndRewrite(GetElementOp ge_op,
455 PatternRewriter &rewriter) const override {
456 IntegerAttr index;
457 if (!matchPattern(ge_op.index(), m_Constant(&index))) {
458 return failure();
459 }
460 auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
461 ge_op.tensor_list().getDefiningOp());
462 if (!preceding_build_list ||
463 preceding_build_list.getNumOperands() <= index.getInt()) {
464 return failure();
465 }
466 Value input = preceding_build_list.getOperand(index.getInt());
467 Type output_type = ge_op.getType();
468 if (input.getType() != output_type &&
469 !output_type.isa<UnrankedTensorType>()) {
470 return failure();
471 }
472 rewriter.replaceOp(ge_op, {input});
473 return success();
474 }
475 };
476
477 struct RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
478 using OpRewritePattern<GetLengthOp>::OpRewritePattern;
479
matchAndRewritemlir::TFR::RemoveRedundantGetLength480 LogicalResult matchAndRewrite(GetLengthOp gl_op,
481 PatternRewriter &rewriter) const override {
482 auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
483 gl_op.tensor_list().getDefiningOp());
484 if (!preceding_build_list) {
485 return failure();
486 }
487 int64_t num_tensors = preceding_build_list.getNumOperands();
488 rewriter.replaceOpWithNewOp<ConstantOp>(gl_op,
489 rewriter.getIndexAttr(num_tensors));
490 return success();
491 }
492 };
493
494 struct BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
495 using OpRewritePattern<BuildListOp>::OpRewritePattern;
496
matchAndRewritemlir::TFR::BuildConstantListAsAttr497 LogicalResult matchAndRewrite(BuildListOp bl_op,
498 PatternRewriter &rewriter) const override {
499 SmallVector<Attribute, 4> array_list;
500 array_list.reserve(bl_op.getNumOperands());
501 for (const auto &operand : bl_op.getOperands()) {
502 Attribute array_elt;
503 if (!matchPattern(operand, m_Constant(&array_elt))) {
504 return failure();
505 }
506 array_list.push_back(array_elt);
507 }
508 auto array_attr = rewriter.getArrayAttr(array_list);
509 rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
510 return success();
511 }
512 };
513
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)514 void ConstantTensorOp::getCanonicalizationPatterns(
515 OwningRewritePatternList &results, MLIRContext *context) {
516 results.insert<ConvertConstToTensorConst>(context);
517 }
518
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)519 void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
520 MLIRContext *context) {
521 results.insert<RemoveRedundantCast>(context);
522 }
523
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)524 void GetShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
525 MLIRContext *context) {
526 results.insert<GetTensorShape>(context);
527 }
528
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)529 void GetElementOp::getCanonicalizationPatterns(
530 OwningRewritePatternList &results, MLIRContext *context) {
531 results.insert<RemoveRedundantGetElement>(context);
532 }
533
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)534 void GetLengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
535 MLIRContext *context) {
536 results.insert<RemoveRedundantGetLength>(context);
537 }
538
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)539 void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
540 MLIRContext *context) {
541 results.insert<BuildConstantListAsAttr>(context);
542 }
543
fold(ArrayRef<Attribute> operands)544 OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
545 assert(operands.size() == 2 && "equal op has two operands");
546 auto ctx = getContext();
547 if (operands[0] == operands[1]) return BoolAttr::get(ctx, true);
548 return BoolAttr::get(ctx, false);
549 }
550
fold(ArrayRef<Attribute> operands)551 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
552 assert(operands.empty() && "constant has no operands");
553
554 // Return the held attribute value.
555 return value();
556 }
557
558 // CallableOpInterface
getCallableRegion()559 Region *TFRFuncOp::getCallableRegion() {
560 return isExternal() ? nullptr : &body().front();
561 }
562
563 // CallableOpInterface
getCallableResults()564 ArrayRef<Type> TFRFuncOp::getCallableResults() {
565 return getType().getResults();
566 }
567
568 //===----------------------------------------------------------------------===//
569 // Dialect type definitions
570 //===----------------------------------------------------------------------===//
571
572 // Parses a TFR type.
573 // tfr_type ::= tensor_type | tensor_list_type | attr_type
574 // string_list ::= `[` string-literal (, string-literal)+ `]`
575 // tensor_type ::= `tensor`
576 // | `tensor<` (string-literal | string_list) '>'
577 // tensor_list_type ::= `tensor_list`
578 // | `tensor_list<` (string-literal | string_list) '>'
579 // attr_type ::= `attr`
parseType(DialectAsmParser & parser) const580 Type TFRDialect::parseType(DialectAsmParser &parser) const {
581 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
582 MLIRContext *ctx = loc.getContext();
583
584 StringRef typeNameSpelling;
585 if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
586 llvm::SmallVector<StringAttr, 4> attrs;
587 if (succeeded(parser.parseOptionalLess())) {
588 bool l_square_parsed = false;
589 if (succeeded(parser.parseOptionalLSquare())) {
590 l_square_parsed = true;
591 }
592
593 do {
594 StringRef attr;
595 if (failed(parser.parseKeyword(&attr))) return {};
596 attrs.push_back(StringAttr::get(ctx, attr));
597 } while (succeeded(parser.parseOptionalComma()));
598
599 if (l_square_parsed && failed(parser.parseRSquare())) {
600 parser.emitError(parser.getNameLoc(), "expected ']'");
601 }
602
603 if (failed(parser.parseGreater())) {
604 parser.emitError(parser.getNameLoc(), "expected '>'");
605 }
606 }
607
608 if (typeNameSpelling == "tensor") {
609 return TFRTensorType::getChecked(attrs, loc);
610 } else if (typeNameSpelling == "tensor_list") {
611 return TFRTensorListType::getChecked(attrs, loc);
612 } else if (typeNameSpelling == "attr") {
613 return TFRAttrType::getChecked(loc);
614 } else {
615 parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
616 return {};
617 }
618 }
619
printType(Type type,DialectAsmPrinter & os) const620 void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
621 llvm::ArrayRef<StringAttr> attrs;
622
623 if (type.isa<TFRAttrType>()) {
624 os << "attr";
625 return;
626 }
627 if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
628 attrs = tensor_ty.getAttrKeys();
629 os << "tensor";
630 } else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
631 attrs = tensor_list_ty.getAttrKeys();
632 os << "tensor_list";
633 } else {
634 llvm_unreachable("Unhandled tfr type");
635 }
636
637 if (attrs.empty()) return;
638 os << "<";
639
640 if (attrs.size() > 1) {
641 os << "[";
642 }
643
644 llvm::interleaveComma(attrs, os,
645 [&](StringAttr attr) { os << attr.getValue(); });
646
647 if (attrs.size() > 1) {
648 os << "]";
649 }
650 os << ">";
651 }
652
653 } // namespace TFR
654 } // namespace mlir
655