//===- LinalgStructuredInterface.td- Linalg StructuredIfce -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the structured interface for Linalg ops.
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_IR_STRUCTURED_OPS_INTERFACE
#define LINALG_IR_STRUCTURED_OPS_INTERFACE

include "mlir/Dialect/Linalg/IR/LinalgBase.td"

// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp'
// interface.
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
  let cppNamespace = "::mlir::linalg";
  let methods = [
    //===------------------------------------------------------------------===//
    // Loop types handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the number of parallel loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumParallelLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumIterators(getParallelIteratorTypeName(),
                               $_op.iterator_types());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the dims that are parallel loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getParallelDims",
      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getDimsOfType($_op, getParallelIteratorTypeName(), res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of reduction loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumReductionLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumIterators(getReductionIteratorTypeName(),
                               $_op.iterator_types());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the dims that are reduction loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getReductionDims",
      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getDimsOfType($_op, getReductionIteratorTypeName(), res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of window loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumWindowLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumIterators(getWindowIteratorTypeName(),
                               $_op.iterator_types());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the dims that are window loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getWindowDims",
      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(),
                             res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the total number of loops within the current operation.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumIterators($_op.iterator_types());
      }]
    >,
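    // For illustration only (not part of the original interface definition):
    // on a hypothetical op whose iterator_types attribute is
    // ["parallel", "parallel", "reduction"], the queries above would return
    //   getNumLoops()          == 3
    //   getNumParallelLoops()  == 2
    //   getNumReductionLoops() == 1
    //   getNumWindowLoops()    == 0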
    InterfaceMethod<
      /*desc=*/[{
        Returns true if the current operation has only one loop and it's a
        reduction loop.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasSingleReductionLoop",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto iters = $_op.iterator_types();
        return iters.size() == 1 &&
               getNumIterators(getReductionIteratorTypeName(), iters) == 1;
      }]>,
    //===------------------------------------------------------------------===//
    // Num input/output/initTensors arguments handling.
    //===------------------------------------------------------------------===//
    // These special methods must be defined by each op that wants to implement
    // the LinalgStructuredInterface. For now, this is either:
    // - Explicitly specified in the op definition.
    // - Derived from variadic attributes (for "named" ops, linalg.generic and
    //   linalg.indexed_generic ops).
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumInputs"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of init tensors.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumInitTensors"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of outputs.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumOutputs"
    >,
    //===------------------------------------------------------------------===//
    // Input arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th input value.
        The `i^th` input argument is always the `i^th` operand regardless of
        whether we have tensors or buffers.
      }],
      /*retTy=*/"Value",
      /*methodName=*/"getInput",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < $_op.getNumInputs());
        return this->getOperation()->getOperand(i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the index of the given input value `v`, or `None` if the value
        is not an input.
      }],
      /*retTy=*/"llvm::Optional<unsigned>",
      /*methodName=*/"getIndexOfInput",
      /*args=*/(ins "Value":$value),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto it = llvm::find(getInputs(), value);
        if (it != getInputs().end())
          return it - getInputs().begin();
        return llvm::None;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th input shaped type, irrespective of buffer or tensor
        type.
      }],
      /*retTy=*/"ShapedType",
      /*methodName=*/"getInputShapedType",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getInput(i).getType().template cast<ShapedType>();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input operands.
      }],
      /*retTy=*/"Operation::operand_range",
      /*methodName=*/"getInputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = this->getOperation()->getOperands();
        return {range.begin(), range.begin() + $_op.getNumInputs()};
      }]
    >,
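    // For illustration only (not part of the original interface definition):
    // the accessors in this interface assume the operand list is laid out as
    //   [input_0, ..., input_{m-1},
    //    output_buffer_0, ..., output_buffer_{n-1},
    //    init_tensor_0, ..., init_tensor_{k-1},
    //    <other non-shaped operands>]
    // so getInputs() covers the first getNumInputs() operands and the output
    // buffer operands start at position getNumInputs().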
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of input operands that are of buffer type.
      }],
      /*retTy=*/"SmallVector<Value, 4>",
      /*methodName=*/"getInputBuffers",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::to_vector<4>(llvm::make_filter_range(
          getInputs(), [](Value in){ return in.getType().isa<MemRefType>(); }));
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of input operands that are of ranked tensor type.
      }],
      /*retTy=*/"SmallVector<RankedTensorType, 4>",
      /*methodName=*/"getInputTensorTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<RankedTensorType, 4> res;
        for (Type type : getInputs().getTypes())
          if (auto t = type.template dyn_cast<RankedTensorType>())
            res.push_back(t);
        return res;
      }]
    >,
    //===------------------------------------------------------------------===//
    // Output arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the output buffer at the given index; asserts that this is a
        buffer operand and not a tensor result.
        The `i^th` output argument is an operand (resp. a return value) iff it
        is a value of buffer type (resp. a return value of tensor type).
      }],
      /*retTy=*/"Value",
      /*methodName=*/"getOutputBuffer",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Output buffers are passed as output buffer operands (side-effecting).
        // Output tensors are results.
        // The union of the 2 are all the outputs and we want to ensure i does
        // not overflow the buffer operands.
        assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs()
               && "overflowing output buffer index");
        return this->getOperation()->getOperand($_op.getNumInputs() + i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the index of the given buffer value, or `None` if the value is
        not part of the output buffers.
      }],
      /*retTy=*/"llvm::Optional<unsigned>",
      /*methodName=*/"getIndexOfOutputBuffer",
      /*args=*/(ins "Value":$value),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto it = llvm::find(getOutputBuffers(), value);
        if (it != getOutputBuffers().end())
          return it - getOutputBuffers().begin();
        return llvm::None;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the type of the output buffer at the given index.
      }],
      /*retTy=*/"MemRefType",
      /*methodName=*/"getOutputBufferType",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getOutputBuffer(i).getType().template cast<MemRefType>();
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th output shaped type, irrespective of buffer or tensor
        type.
      }],
      /*retTy=*/"ShapedType",
      /*methodName=*/"getOutputShapedType",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getShapedType(i + $_op.getNumInputs());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Return the results that are of ranked tensor type.
      }],
      /*retTy=*/"SmallVector<RankedTensorType, 4>",
      /*methodName=*/"getOutputTensorTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<RankedTensorType, 4> res;
        for (Type type : this->getOperation()->getResults().getTypes())
          res.push_back(type.template cast<RankedTensorType>());
        return res;
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Return the output buffers (operands).
      }],
      /*retTy=*/"Operation::operand_range",
      /*methodName=*/"getOutputBuffers",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = this->getOperation()->getOperands();
        return {range.begin() + $_op.getNumInputs(),
                range.begin() + getNumInputsAndOutputBuffers()};
      }]
    >,

    //===------------------------------------------------------------------===//
    // Input and Output arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return one single buffer at position `$i`.
      }],
      /*retTy=*/"Value",
      /*methodName=*/"getBuffer",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < getNumInputsAndOutputBuffers() &&
               "overflowing buffers index");
        return this->getOperation()->getOperand(i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of output buffers.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumOutputBuffers",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getNumOutputs() - this->getOperation()->getNumResults();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs and outputs, irrespective of their buffer
        or tensor type.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumInputsAndOutputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getNumInputs() + $_op.getNumOutputs();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs, irrespective of their buffer or tensor
        type, and output buffers.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumInputsAndOutputBuffers",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getNumInputs() + $_op.getNumOutputs() -
               this->getOperation()->getNumResults();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the range over inputs (irrespective of type) and output buffers.
      }],
      /*retTy=*/"Operation::operand_range",
      /*methodName=*/"getInputsAndOutputBuffers",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = this->getOperation()->getOperands();
        return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the range over init tensors.
      }],
      /*retTy=*/"Operation::operand_range",
      /*methodName=*/"getInitTensors",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = this->getOperation()->getOperands();
        auto base = range.begin() + getNumInputsAndOutputBuffers();
        return {base, base + $_op.getNumInitTensors()};
      }]
    >,
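    // For illustration only (not part of the original interface definition):
    // for a hypothetical op with 2 input tensors, no output buffer, 1 init
    // tensor and 1 tensor result:
    //   getNumInputs()                 == 2
    //   getNumOutputs()                == 1   (the tensor result)
    //   getNumOutputBuffers()          == 0
    //   getNumInitTensors()            == 1
    //   getNumInputsAndOutputBuffers() == 2
    // The buffer form of the same computation (2 input memrefs, 1 output
    // memref, no results) would instead report getNumOutputBuffers() == 1 and
    // getNumInputsAndOutputBuffers() == 3.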
    InterfaceMethod<
      /*desc=*/[{
        Return one single init tensor at position `$i`.
      }],
      /*retTy=*/"Value",
      /*methodName=*/"getInitTensor",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < $_op.getNumInitTensors() && "overflowing init tensor index");
        return getInitTensors()[i];
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the shaped operand index `i` is the index of an init
        tensor.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isIndexOfAnInitTensor",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < $_op.getNumShapedOperands() &&
               "overflowing shaped operand index");
        return i >= $_op.getNumInputs() + getNumOutputBuffers();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the relative init tensor index of the shaped operand index.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getInitTensorIndexFromShapedIndex",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(isIndexOfAnInitTensor(i) && "expected an init tensor index");
        return i - $_op.getNumInputs() - getNumOutputBuffers();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the index of the given init tensor value, or `None` if the value
        is not part of the init tensors.
      }],
      /*retTy=*/"llvm::Optional<unsigned>",
      /*methodName=*/"getIndexOfInitTensor",
      /*args=*/(ins "Value":$value),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto it = llvm::find(getInitTensors(), value);
        if (it != getInitTensors().end())
          return it - getInitTensors().begin();
        return llvm::None;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs, output buffers and init tensors operands.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumShapedOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th shaped operand value, which can be an arbitrary input
        tensor/buffer, init tensor or output buffer.
      }],
      /*retTy=*/"Value",
      /*methodName=*/"getShapedOperand",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < $_op.getNumShapedOperands());
        return this->getOperation()->getOperand(i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the range over inputs, output buffers and init tensors.
      }],
      /*retTy=*/"Operation::operand_range",
      /*methodName=*/"getShapedOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = this->getOperation()->getOperands();
        return {range.begin(), range.begin() + getNumShapedOperands()};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th shaped type; there are 3 cases:
          1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`;
             otherwise
          2. if `i < getNumInputsAndOutputBuffers()` then return the
             `getOutputBufferType(i - $_op.getNumInputs())`; otherwise
          3. return the `i - getNumInputsAndOutputBuffers()` result type.
      }],
      /*retTy=*/"ShapedType",
      /*methodName=*/"getShapedType",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (i < $_op.getNumInputs())
          return getInputShapedType(i);
        if (i < getNumInputsAndOutputBuffers())
          return getOutputBufferType(i - $_op.getNumInputs());
        return this->getOperation()->getResult(
            i - getNumInputsAndOutputBuffers()).
            getType().template cast<ShapedType>();
      }]>,
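    // For illustration only (not part of the original interface definition):
    // with 2 inputs and 1 output buffer, getShapedType(0) and getShapedType(1)
    // return the input shaped types and getShapedType(2) the output buffer
    // type; if the op instead returns a tensor, getShapedType(2) is the type
    // of that result.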
    InterfaceMethod<
      /*desc=*/[{
        Return the shaped types for all the inputs and outputs.
      }],
      /*retTy=*/"SmallVector<ShapedType, 4>",
      /*methodName=*/"getInputOutputShapedTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<Type, 4> inputOutputTypes(
            this->getOperation()->operand_type_begin(),
            this->getOperation()->operand_type_end());
        inputOutputTypes.append(this->getOperation()->result_type_begin(),
                                this->getOperation()->result_type_end());
        return llvm::to_vector<4>(
            llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
              return type.cast<ShapedType>();
            }));
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the position of the given shaped operand value in the operand
        list, or `None` if the value is not a shaped operand.
      }],
      /*retTy=*/"Optional<unsigned>",
      /*methodName=*/"getIndexOfShapedOperand",
      /*args=*/(ins "Value":$value),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        Optional<unsigned> inputIndex = getIndexOfInput(value);
        if (inputIndex.hasValue()) return inputIndex.getValue();
        Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
        if (outputIndex.hasValue())
          return $_op.getNumInputs() + outputIndex.getValue();
        Optional<unsigned> initTensorIndex = getIndexOfInitTensor(value);
        if (initTensorIndex.hasValue())
          return $_op.getNumInputs() + $_op.getNumOutputBuffers() +
                 initTensorIndex.getValue();
        return llvm::None;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns the operand index given the input index. Returns None
        if the input index is invalid.
      }],
      /*retTy=*/"Optional<unsigned>",
      /*methodName=*/"getOperandIndexForInputIndex",
      /*args=*/(ins "unsigned":$input_index),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (input_index >= $_op.getNumInputs())
          return llvm::None;
        return input_index;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns the operand index given the output index. Returns None
        if the output index is invalid.
      }],
      /*retTy=*/"Optional<unsigned>",
      /*methodName=*/"getOperandIndexForOutputIndex",
      /*args=*/(ins "unsigned":$output_index),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (output_index >= $_op.getNumOutputs())
          return llvm::None;
        return output_index + $_op.getNumInputs();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns the input index given the operand index. Returns None
        if the operand index doesn't correspond to an input.
      }],
      /*retTy=*/"Optional<unsigned>",
      /*methodName=*/"getInputIndex",
      /*args=*/(ins "unsigned":$operand_index),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (operand_index >= $_op.getNumInputs())
          return llvm::None;
        return operand_index;
      }]
    >,
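    // For illustration only (not part of the original interface definition):
    // with 2 inputs and 1 output, getOperandIndexForInputIndex(1) == 1,
    // getOperandIndexForOutputIndex(0) == 2, getInputIndex(2) == None and
    // getOutputIndex(2) == 0 (see the method below).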
    InterfaceMethod<
      /*desc=*/[{
        Returns the output index given the operand index. Returns None
        if the operand index doesn't correspond to an output.
      }],
      /*retTy=*/"Optional<unsigned>",
      /*methodName=*/"getOutputIndex",
      /*args=*/(ins "unsigned":$operand_index),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (operand_index < $_op.getNumInputs() ||
            operand_index >= $_op.getNumInputs() + $_op.getNumOutputs())
          return llvm::None;
        return operand_index - $_op.getNumInputs();
      }]
    >,

    //===------------------------------------------------------------------===//
    // Other interface methods.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the iterator types attribute within the current operation.
      }],
      /*retTy=*/"ArrayAttr",
      /*methodName=*/"iterator_types",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.iterator_types();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps attribute within the current operation.
      }],
      /*retTy=*/"ArrayAttr",
      /*methodName=*/"indexing_maps"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps within the current operation.
      }],
      /*retTy=*/"SmallVector<AffineMap, 4>",
      /*methodName=*/"getIndexingMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::to_vector<4>(
            $_op.indexing_maps().template getAsValueRange<AffineMapAttr>());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input or output indexing map at index `i`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getIndexingMap",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < getNumInputsAndOutputs());
        return getIndexingMaps()[i];
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input indexing map at index `i`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getInputIndexingMap",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < $_op.getNumInputs());
        return getIndexingMaps()[i];
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the output indexing map at index `i`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getOutputIndexingMap",
      /*args=*/(ins "unsigned":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i < $_op.getNumOutputs());
        return getIndexingMaps()[i + $_op.getNumInputs()];
      }]
    >,
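    // For illustration only (not part of the original interface definition):
    // a matmul-like op with
    //   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,   // input A
    //                    affine_map<(m, n, k) -> (k, n)>,   // input B
    //                    affine_map<(m, n, k) -> (m, n)>]   // output C
    // has getInputIndexingMap(0) == (m, n, k) -> (m, k) and
    // getOutputIndexingMap(0) == (m, n, k) -> (m, n).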
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has only MemRef inputs and outputs.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasBufferSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return this->getOperation()->getNumResults() == 0 &&
          llvm::all_of(getInputs(),
                       [](Value v) { return v.getType().isa<MemRefType>(); });
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has only RankedTensor inputs and outputs.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasTensorSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto isTensorType = [](Value v) {
          return v.getType().isa<RankedTensorType>();
        };
        return llvm::all_of(getInputs(), isTensorType) &&
               llvm::all_of(this->getOperation()->getResults(), isTensorType);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has sparse tensor semantics.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasSparseSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getAttr(getSparseAttrName())
            .template dyn_cast_or_null<ArrayAttr>() != nullptr;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the name registered for this op when lowering to an external
        library call.
      }],
      /*retTy=*/"std::string",
      /*methodName=*/"getLibraryCallName",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getLibraryCallName();
      }]
    >,

    //===------------------------------------------------------------------===//
    // Linalg generalization hooks.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Hook to provide a custom AffineMap used to compute all the operand
        subshapes given loop bounds. This is used to answer the question:
        "given an iteration space over the codomain, what are the subshapes of
        the operands involved in the computation".
        The default behavior is to just concatenate all the indexing maps.
        A custom AffineMap allows providing a map that can be used to
        compute subshapes even in cases where the concatenation of indexing
        maps (i.e. the data traversal order) is not a simple permutation of the
        loop traversal order. It is then possible to define ops with skewed
        data traversal order for which we can still easily compute
        hyperrectangular loop bounds and subviews.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getLoopsToShapesMap",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto r = $_op.indexing_maps().template getAsRange<AffineMapAttr>();
        auto maps = llvm::to_vector<8>(
            llvm::map_range(r, [](AffineMapAttr a) { return a.getValue(); }));
        return concatAffineMaps(maps);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Hook to provide a custom AffineMap used to construct the
        hyperrectangular loop iteration space given all the operand subshapes.
        This is used to answer the question:
        "Given a list of operand ranges, what is the subportion of the
        iteration space involved in the computation".
        This is the inverse problem of `getLoopsToShapesMap`.
        Return the empty AffineMap when such an AffineMap cannot be
        constructed.
        The default behavior is based on a very simple inference procedure that
        only works with permutation affine maps.
        A more advanced Tensor-Comprehension-like inference is possible but has
        proven to be ambiguous in unfavorable cases.
        A safer and more robust alternative is to allow each op to define its
        own AffineMap.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getShapesToLoopsMap",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return inversePermutation(getLoopsToShapesMap());
      }]
    >,
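    // For illustration only (not part of the original interface definition):
    // for the matmul-like indexing maps sketched above, getLoopsToShapesMap()
    // concatenates the maps into
    //   (m, n, k) -> (m, k, k, n, m, n)
    // and getShapesToLoopsMap() inverts it into
    //   (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
    // recovering the (m, n, k) loop bounds from the flattened list of operand
    // shape dimensions.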

    //===------------------------------------------------------------------===//
    // Other static interface methods.
    //===------------------------------------------------------------------===//
    StaticInterfaceMethod<
      /*desc=*/[{
        Create an operation of the current type with the given location,
        result types, operands, and attributes.
      }],
      /*retTy=*/"Operation *",
      /*methodName=*/"create",
      (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
           "ValueRange":$operands,
           "ArrayRef<NamedAttribute>":$attributes), [{
        return builder.create<ConcreteOp>(
          loc, resultTypes, operands, attributes);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Clone the current operation with the given location and operands. This
        is used to abstract away the optional underlying region creation. This
        does not change the balance between input, output_buffer and
        init_tensors operands.
      }],
      /*retTy=*/"Operation *",
      /*methodName=*/"clone",
      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
           "ValueRange":$operands),
      [{
        BlockAndValueMapping map;
        unsigned numRegions = $_op->getNumRegions();
        Operation *res = create(b, loc, resultTypes, operands, $_op.getAttrs());
        assert(res->getNumRegions() == numRegions && "inconsistent # regions");
        for (unsigned ridx = 0; ridx < numRegions; ++ridx)
          $_op->getRegion(ridx).cloneInto(
            &res->getRegion(ridx), map);
        return res;
      }]
    >,
    StaticInterfaceMethod<
      /*desc=*/[{
        Returns the region builder for constructing the body for linalg.generic.
        Returns a null function if this named op does not define a region
        builder.
      }],
      /*retTy=*/"std::function<void(Block &)>",
      /*methodName=*/"getRegionBuilder",
      (ins),
      [{ return ConcreteOp::getRegionBuilder(); }]
    >
  ];

  let extraClassDeclaration = [{
    /// Return the flat list of all operand dimension sizes in the order they
    /// appear in the operands.
    SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);

    /// Create the loop ranges to materialize the computation over the current
    /// operands. This is done by applying `getShapesToLoopsMap` to
    /// `createFlatListOfOperandDims`.
    SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);

    /// Returns all the operands past the inputs, output_buffers and
    /// init_tensors operands. Asserts that these operands are value types to
    /// allow transformations like tiling to just use the values when cloning
    /// `linalgOp`.
    SmallVector<Value, 4> getAssumedNonShapedOperands() {
      unsigned numShapedOperands = getNumShapedOperands();
      unsigned nExtraOperands =
        getOperation()->getNumOperands() - numShapedOperands;
      SmallVector<Value, 4> res;
      res.reserve(nExtraOperands);
      for (unsigned i = 0; i < nExtraOperands; ++i) {
        res.push_back(getOperation()->getOperand(numShapedOperands + i));
        assert((res.back().getType().isSignlessIntOrIndexOrFloat()
                || res.back().getType().isa<VectorType>()) &&
               "expected scalar or vector type");
      }
      return res;
    }

    //========================================================================//
    // Helper functions to mutate the `operand_segment_sizes` attribute.
    // These are useful when cloning and changing operand types.
    //========================================================================//
    void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
    void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
    void setNumInitTensors(unsigned num) { setOperandSegmentAt(2, num); }

    private:
    void setOperandSegmentAt(unsigned idx, unsigned val) {
      auto attr = getOperation()->getAttr("operand_segment_sizes")
        .cast<DenseIntElementsAttr>();
      unsigned i = 0;
      auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
        [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
      getOperation()->setAttr("operand_segment_sizes", newAttr);
    }
  }];
}

#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE