• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1//===- LinalgStructuredInterface.td- Linalg StructuredIfce -*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the definition file for the structured interface for Linalg ops.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LINALG_IR_STRUCTURED_OPS_INTERFACE
14#define LINALG_IR_STRUCTURED_OPS_INTERFACE
15
16include "mlir/Dialect/Linalg/IR/LinalgBase.td"
17
18// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp'
19// interface.
20def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
21  let cppNamespace = "::mlir::linalg";
22  let methods = [
23    //===------------------------------------------------------------------===//
24    // Loop types handling.
25    //===------------------------------------------------------------------===//
26    InterfaceMethod<
27      /*desc=*/[{
28        Return the number of parallel loops.
29      }],
30      /*retTy=*/"unsigned",
31      /*methodName=*/"getNumParallelLoops",
32      /*args=*/(ins),
33      /*methodBody=*/"",
34      /*defaultImplementation=*/[{
35        return getNumIterators(getParallelIteratorTypeName(),
36                               $_op.iterator_types());
37      }]
38    >,
39    InterfaceMethod<
40      /*desc=*/[{
41        Return the dims that are parallel loops.
42      }],
43      /*retTy=*/"void",
44      /*methodName=*/"getParallelDims",
45      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
46      /*methodBody=*/"",
47      /*defaultImplementation=*/[{
48        return getDimsOfType($_op, getParallelIteratorTypeName(), res);
49      }]
50    >,
51    InterfaceMethod<
52      /*desc=*/[{
53        Return the number of reduction loops.
54      }],
55      /*retTy=*/"unsigned",
56      /*methodName=*/"getNumReductionLoops",
57      /*args=*/(ins),
58      /*methodBody=*/"",
59      /*defaultImplementation=*/[{
60        return getNumIterators(getReductionIteratorTypeName(),
61                               $_op.iterator_types());
62      }]
63    >,
64    InterfaceMethod<
65      /*desc=*/[{
66        Return the dims that are reduction loops.
67      }],
68      /*retTy=*/"void",
69      /*methodName=*/"getReductionDims",
70      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
71      /*methodBody=*/"",
72      /*defaultImplementation=*/[{
73        return getDimsOfType($_op, getReductionIteratorTypeName(), res);
74      }]
75    >,
76    InterfaceMethod<
77      /*desc=*/[{
78        Return the number of window loops.
79      }],
80      /*retTy=*/"unsigned",
81      /*methodName=*/"getNumWindowLoops",
82      /*args=*/(ins),
83      /*methodBody=*/"",
84      /*defaultImplementation=*/[{
85        return getNumIterators(getWindowIteratorTypeName(),
86                               $_op.iterator_types());
87      }]
88    >,
89    InterfaceMethod<
90      /*desc=*/[{
91        Return the dims that are window loops.
92      }],
93      /*retTy=*/"void",
94      /*methodName=*/"getWindowDims",
95      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
96      /*methodBody=*/"",
97      /*defaultImplementation=*/[{
98        return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res);
99      }]
100    >,
101    InterfaceMethod<
102      /*desc=*/[{
103        Return the total number of loops within the current operation.
104      }],
105      /*retTy=*/"unsigned",
106      /*methodName=*/"getNumLoops",
107      /*args=*/(ins),
108      /*methodBody=*/"",
109      /*defaultImplementation=*/[{
110        return getNumIterators($_op.iterator_types());
111      }]
112    >,
113    InterfaceMethod<
114      /*desc=*/[{
115        Returns true if the current operation has only one loop and it's a
116        reduction loop.
117      }],
118      /*retTy=*/"bool",
119      /*methodName=*/"hasSingleReductionLoop",
120      /*args=*/(ins),
121      /*methodBody=*/"",
122      /*defaultImplementation=*/[{
123        auto iters = $_op.iterator_types();
124        return iters.size() == 1 &&
125               getNumIterators(getReductionIteratorTypeName(), iters) == 1;
126      }]>,
127    //===------------------------------------------------------------------===//
128    // Num input/output/initTensors arguments handling.
129    //===------------------------------------------------------------------===//
130    // These special methods must be defined by each op that wants to implement
131    // the LinalgStructuredInterface. For now, this is either:
132    // - Explicitly specified in the op definition.
133    // - Derived from variadic attributes (for "named" ops, linalg.generic and
134    //   linalg.indexed_generic ops).
135    InterfaceMethod<
136      /*desc=*/[{
137        Return the number of inputs.
138      }],
139      /*retTy=*/"unsigned",
140      /*methodName=*/"getNumInputs"
141    >,
142    InterfaceMethod<
143      /*desc=*/[{
144        Return the number of init tensors.
145      }],
146      /*retTy=*/"unsigned",
147      /*methodName=*/"getNumInitTensors"
148    >,
149    InterfaceMethod<
150      /*desc=*/[{
151        Return the number of outputs.
152      }],
153      /*retTy=*/"unsigned",
154      /*methodName=*/"getNumOutputs"
155    >,
156    //===------------------------------------------------------------------===//
157    // Input arguments handling.
158    //===------------------------------------------------------------------===//
159    InterfaceMethod<
160      /*desc=*/[{
161        Return the `i`-th input value.
162        The `i^th` input argument is always the `i^th` operand regardless of
163        whether we have tensors or buffers.
164      }],
165      /*retTy=*/"Value",
166      /*methodName=*/"getInput",
167      /*args=*/(ins "unsigned":$i),
168      /*methodBody=*/"",
169      /*defaultImplementation=*/[{
170        assert(i < $_op.getNumInputs());
171        return this->getOperation()->getOperand(i);
172      }]
173    >,
174    InterfaceMethod<
175      /*desc=*/[{
176        Return the index of the given input value `v`, or `None` if the value is
177        not an input.
178      }],
179      /*retTy=*/"llvm::Optional<unsigned>",
180      /*methodName=*/"getIndexOfInput",
181      /*args=*/(ins "Value":$value),
182      /*methodBody=*/"",
183      /*defaultImplementation=*/[{
184        auto it = llvm::find(getInputs(), value);
185        if (it != getInputs().end())
186          return it - getInputs().begin();
187        return llvm::None;
188      }]
189    >,
190    InterfaceMethod<
191      /*desc=*/[{
192        Return the `i`-th input shaped type, irrespective of buffer or tensor
193        type.
194      }],
195      /*retTy=*/"ShapedType",
196      /*methodName=*/"getInputShapedType",
197      /*args=*/(ins "unsigned":$i),
198      /*methodBody=*/"",
199      /*defaultImplementation=*/[{
200        return getInput(i).getType().template cast<ShapedType>();
201      }]
202    >,
203    InterfaceMethod<
204      /*desc=*/[{
205        Return the input operands.
206      }],
207      /*retTy=*/"Operation::operand_range",
208      /*methodName=*/"getInputs",
209      /*args=*/(ins),
210      /*methodBody=*/"",
211      /*defaultImplementation=*/[{
212        auto range = this->getOperation()->getOperands();
213        return {range.begin(), range.begin() + $_op.getNumInputs()};
214      }]
215    >,
216    InterfaceMethod<
217      /*desc=*/[{
218        Return the range over the input operands that are of buffer type.
219      }],
220      /*retTy=*/"SmallVector<Value, 4>",
221      /*methodName=*/"getInputBuffers",
222      /*args=*/(ins),
223      /*methodBody=*/"",
224      /*defaultImplementation=*/[{
225        return llvm::to_vector<4>(llvm::make_filter_range(
226          getInputs(), [](Value in){ return in.getType().isa<MemRefType>(); }));
227      }]
228    >,
229    InterfaceMethod<
230      /*desc=*/[{
231        Return the subset of input operands that are of ranked tensor type.
232      }],
233      /*retTy=*/"SmallVector<RankedTensorType, 4>",
234      /*methodName=*/"getInputTensorTypes" ,
235      /*args=*/(ins),
236      /*methodBody=*/"",
237      /*defaultImplementation=*/[{
238        SmallVector<RankedTensorType, 4> res;
239        for (Type type : getInputs().getTypes())
240          if (auto t = type.template dyn_cast<RankedTensorType>())
241            res.push_back(t);
242        return res;
243      }]
244    >,
245    //===------------------------------------------------------------------===//
246    // Output arguments handling.
247    //===------------------------------------------------------------------===//
248    InterfaceMethod<
249      /*desc=*/[{
250        Return the output buffer at the given index, asserts that this is a
251        buffer operand and not a tensor result.
252        The `i^th` output argument is an operand (resp. a return value) iff it
253        is a value of buffer type (resp. a return value of tensor type).
254      }],
255      /*retTy=*/"Value",
256      /*methodName=*/"getOutputBuffer",
257      /*args=*/(ins "unsigned":$i),
258      /*methodBody=*/"",
259      /*defaultImplementation=*/[{
260        // Output buffers are passed as output buffer operands (side-effecting).
261        // Output tensors are results.
262        // The union of the 2 are all the outputs and we want to ensure i does
263        // not overflow the buffer operands.
264        assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs()
265               && "overflowing output buffer index");
266        return this->getOperation()->getOperand($_op.getNumInputs() + i);
267      }]
268    >,
269    InterfaceMethod<
270      /*desc=*/[{
271        Return the index of the given buffer value, or `None` if the value is
272        not part of the output buffers.
273      }],
274      /*retTy=*/"llvm::Optional<unsigned>",
275      /*methodName=*/"getIndexOfOutputBuffer",
276      /*args=*/(ins "Value":$value),
277      /*methodBody=*/"",
278      /*defaultImplementation=*/[{
279        auto it = llvm::find(getOutputBuffers(), value);
280        if (it != getOutputBuffers().end())
281          return it - getOutputBuffers().begin();
282        return llvm::None;
283      }]
284    >,
285    InterfaceMethod<
286      /*desc=*/[{
287        Return the type of the output buffer at the given index.
288      }],
289      /*retTy=*/"MemRefType",
290      /*methodName=*/"getOutputBufferType",
291      /*args=*/(ins "unsigned":$i),
292      /*methodBody=*/"",
293      /*defaultImplementation=*/[{
294        return getOutputBuffer(i).getType().template cast<MemRefType>();
295      }]>,
296    InterfaceMethod<
297      /*desc=*/[{
298        Return the `i`-th output shaped type, irrespective of buffer or tensor
299        type.
300      }],
301      /*retTy=*/"ShapedType",
302      /*methodName=*/"getOutputShapedType",
303      /*args=*/(ins "unsigned":$i),
304      /*methodBody=*/"",
305      /*defaultImplementation=*/[{
306        return getShapedType(i + $_op.getNumInputs());
307      }]>,
308    InterfaceMethod<
309      /*desc=*/[{
310        Return the results that are of ranked tensor type.
311      }],
312      /*retTy=*/"SmallVector<RankedTensorType, 4>",
313      /*methodName=*/"getOutputTensorTypes",
314      /*args=*/(ins),
315      /*methodBody=*/"",
316      /*defaultImplementation=*/[{
317        SmallVector<RankedTensorType, 4> res;
318        for (Type type : this->getOperation()->getResults().getTypes())
319          res.push_back(type.template cast<RankedTensorType>());
320        return res;
321      }]>,
322    InterfaceMethod<
323      /*desc=*/[{
324        Return the output buffers (operands).
325      }],
326      /*retTy=*/"Operation::operand_range",
327      /*methodName=*/"getOutputBuffers",
328      /*args=*/(ins),
329      /*methodBody=*/"",
330      /*defaultImplementation=*/[{
331        auto range = this->getOperation()->getOperands();
332        return {range.begin() + $_op.getNumInputs(),
333                range.begin() + getNumInputsAndOutputBuffers()};
334      }]
335    >,
336
337    //===------------------------------------------------------------------===//
338    // Input and Output arguments handling.
339    //===------------------------------------------------------------------===//
340    InterfaceMethod<
341      /*desc=*/[{
342        Return one single buffer at position `$i`.
343      }],
344      /*retTy=*/"Value",
345      /*methodName=*/"getBuffer",
346      /*args=*/(ins "unsigned":$i),
347      /*methodBody=*/"",
348      /*defaultImplementation=*/[{
349        assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
350        return this->getOperation()->getOperand(i);
351      }]
352    >,
353    InterfaceMethod<
354      /*desc=*/[{
355        Return the number of output buffers
356      }],
357      /*retTy=*/"unsigned",
358      /*methodName=*/"getNumOutputBuffers",
359      /*args=*/(ins),
360      /*methodBody=*/"",
361      /*defaultImplementation=*/[{
362        return $_op.getNumOutputs() - this->getOperation()->getNumResults();
363      }]
364    >,
365    InterfaceMethod<
366      /*desc=*/[{
367        Return the number of inputs and outputs, irrespective of their buffer or
368        tensor type.
369      }],
370      /*retTy=*/"unsigned",
371      /*methodName=*/"getNumInputsAndOutputs",
372      /*args=*/(ins),
373      /*methodBody=*/"",
374      /*defaultImplementation=*/[{
375        return $_op.getNumInputs() + $_op.getNumOutputs();
376      }]
377    >,
378    InterfaceMethod<
379      /*desc=*/[{
380        Return the number of inputs, irrespective of their buffer or tensor type
381        and output buffers
382      }],
383      /*retTy=*/"unsigned",
384      /*methodName=*/"getNumInputsAndOutputBuffers",
385      /*args=*/(ins),
386      /*methodBody=*/"",
387      /*defaultImplementation=*/[{
388        return $_op.getNumInputs() + $_op.getNumOutputs() -
389          this->getOperation()->getNumResults();
390      }]
391    >,
392    InterfaceMethod<
393      /*desc=*/[{
394        Return the range over inputs (irrespective of type) and output buffers.
395      }],
396      /*retTy=*/"Operation::operand_range",
397      /*methodName=*/"getInputsAndOutputBuffers",
398      /*args=*/(ins),
399      /*methodBody=*/"",
400      /*defaultImplementation=*/[{
401        auto range = this->getOperation()->getOperands();
402        return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
403      }]
404    >,
405    InterfaceMethod<
406      /*desc=*/[{
407        Return the range over init tensors.
408      }],
409      /*retTy=*/"Operation::operand_range",
410      /*methodName=*/"getInitTensors",
411      /*args=*/(ins),
412      /*methodBody=*/"",
413      /*defaultImplementation=*/[{
414        auto range = this->getOperation()->getOperands();
415        auto base = range.begin() + getNumInputsAndOutputBuffers();
416        return {base, base + $_op.getNumInitTensors()};
417      }]
418    >,
419    InterfaceMethod<
420      /*desc=*/[{
421        Return one single init tensor at position `$i`.
422      }],
423      /*retTy=*/"Value",
424      /*methodName=*/"getInitTensor",
425      /*args=*/(ins "unsigned":$i),
426      /*methodBody=*/"",
427      /*defaultImplementation=*/[{
428        assert(i < $_op.getNumInitTensors() && "overflowing init tensor index");
429        return getInitTensors()[i];
430      }]
431    >,
432    InterfaceMethod<
433      /*desc=*/[{
434        Return true if the shaped operand index `i` is the index of an init
435        tensor.
436      }],
437      /*retTy=*/"bool",
438      /*methodName=*/"isIndexOfAnInitTensor",
439      /*args=*/(ins "unsigned":$i),
440      /*methodBody=*/"",
441      /*defaultImplementation=*/[{
442        assert(i < $_op.getNumShapedOperands() && "overflowing shaped operand index");
443        return i >= $_op.getNumInputs() + getNumOutputBuffers();
444      }]
445    >,
446    InterfaceMethod<
447      /*desc=*/[{
448        Return the relative init tensor index of the shaped operand index.
449      }],
450      /*retTy=*/"unsigned",
451      /*methodName=*/"getInitTensorIndexFromShapedIndex",
452      /*args=*/(ins "unsigned":$i),
453      /*methodBody=*/"",
454      /*defaultImplementation=*/[{
455        assert(isIndexOfAnInitTensor(i) && "expected an init tensor index");
456        return i - $_op.getNumInputs() - getNumOutputBuffers();
457      }]
458    >,
459    InterfaceMethod<
460      /*desc=*/[{
461        Return the index of the given init tensor value, or `None` if the value
462        is not part of the init tensors.
463      }],
464      /*retTy=*/"llvm::Optional<unsigned>",
465      /*methodName=*/"getIndexOfInitTensor",
466      /*args=*/(ins "Value":$value),
467      /*methodBody=*/"",
468      /*defaultImplementation=*/[{
469        auto it = llvm::find(getInitTensors(), value);
470        if (it != getInitTensors().end())
471          return it - getInitTensors().begin();
472        return llvm::None;
473      }]
474    >,
475    InterfaceMethod<
476      /*desc=*/[{
477        Return the number of inputs, output buffers and init tensors operands.
478      }],
479      /*retTy=*/"unsigned",
480      /*methodName=*/"getNumShapedOperands",
481      /*args=*/(ins),
482      /*methodBody=*/"",
483      /*defaultImplementation=*/[{
484        return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
485      }]
486    >,
487    InterfaceMethod<
488      /*desc=*/[{
489        Return the `i`-th shaped operand value, which can be an arbitrary input
490        tensor/buffer, init tensor or output buffer.
491      }],
492      /*retTy=*/"Value",
493      /*methodName=*/"getShapedOperand",
494      /*args=*/(ins "unsigned":$i),
495      /*methodBody=*/"",
496      /*defaultImplementation=*/[{
497        assert(i < $_op.getNumShapedOperands());
498        return this->getOperation()->getOperand(i);
499      }]
500    >,
501    InterfaceMethod<
502      /*desc=*/[{
503        Return the range over inputs, output buffers and init tensors.
504      }],
505      /*retTy=*/"Operation::operand_range",
506      /*methodName=*/"getShapedOperands",
507      /*args=*/(ins),
508      /*methodBody=*/"",
509      /*defaultImplementation=*/[{
510        auto range = this->getOperation()->getOperands();
511        return {range.begin(), range.begin() + getNumShapedOperands()};
512      }]
513    >,
514    InterfaceMethod<
515      /*desc=*/[{
516        Return the `i`-th shaped type, there are 3 cases:
517          1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`;
518             otherwise
519          2. if `i < getNumInputsAndOutputBuffers()` then return the
520             `getOutputBufferType(i - $_op.getNumInputs())`; otherwise
521          3. return the `i - getNumInputsAndOutputBuffers()` result type.
522      }],
523      /*retTy=*/"ShapedType",
524      /*methodName=*/"getShapedType",
525      /*args=*/(ins "unsigned":$i),
526      /*methodBody=*/"",
527      /*defaultImplementation=*/[{
528        if (i < $_op.getNumInputs())
529          return getInputShapedType(i);
530        if (i < getNumInputsAndOutputBuffers())
531          return getOutputBufferType(i - $_op.getNumInputs());
532        return this->getOperation()->getResult(
533          i - getNumInputsAndOutputBuffers()).
534          getType().template cast<ShapedType>();
535      }]>,
536    InterfaceMethod<
537      /*desc=*/[{
538        Return the shaped types for all the inputs and outputs
539      }],
540      /*retTy=*/"SmallVector<ShapedType, 4>",
541      /*methodName=*/"getInputOutputShapedTypes",
542      /*args=*/(ins),
543      /*methodBody=*/"",
544      /*defaultImplementation=*/[{
545        SmallVector<Type, 4> inputOutputTypes(
546            this->getOperation()->operand_type_begin(),
547            this->getOperation()->operand_type_end());
548        inputOutputTypes.append(this->getOperation()->result_type_begin(),
549                                this->getOperation()->result_type_end());
550        return llvm::to_vector<4>(
551            llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
552              return type.cast<ShapedType>();
553            }));
554      }]
555    >,
556    InterfaceMethod<
557      /*desc=*/[{
558        Return the first position of the shaped operand in the operand list.
559      }],
560      /*retTy=*/"Optional<unsigned>",
561      /*methodName=*/"getIndexOfShapedOperand",
562      /*args=*/(ins "Value":$value),
563      /*methodBody=*/"",
564      /*defaultImplementation=*/[{
565        Optional<unsigned> inputIndex = getIndexOfInput(value);
566        if (inputIndex.hasValue()) return inputIndex.getValue();
567        Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
568        if (outputIndex.hasValue())
569          return $_op.getNumInputs() + outputIndex.getValue();
570        Optional<unsigned> initTensorIndex = getIndexOfInitTensor(value);
571        if (initTensorIndex.hasValue())
572          return $_op.getNumInputs() + $_op.getNumOutputBuffers() + initTensorIndex.getValue();
573        return llvm::None;
574      }]
575    >,
576    InterfaceMethod<
577      /*desc=*/[{
578        Returns the operand index given the input index. Returns None
579        of the input index is invalid.
580      }],
581      /*retTy=*/"Optional<unsigned>",
582      /*methodName=*/"getOperandIndexForInputIndex",
583      /*args=*/(ins "unsigned":$input_index),
584      /*methodBody=*/"",
585      /*defaultImplementation=*/[{
586        if (input_index >= $_op.getNumInputs())
587          return llvm::None;
588        return input_index;
589      }]
590    >,
591    InterfaceMethod<
592      /*desc=*/[{
593        Returns the operand index given the output index. Returns None
594        of the output index is invalid.
595      }],
596      /*retTy=*/"Optional<unsigned>",
597      /*methodName=*/"getOperandIndexForOutputIndex",
598      /*args=*/(ins "unsigned":$output_index),
599      /*methodBody=*/"",
600      /*defaultImplementation=*/[{
601        if (output_index >= $_op.getNumOutputs())
602          return llvm::None;
603        return output_index + $_op.getNumInputs();
604      }]
605    >,
606    InterfaceMethod<
607      /*desc=*/[{
608        Returns the input index given the operand index. Return None
609        if the operand index doesnt corresponding to an input.
610      }],
611      /*retTy=*/"Optional<unsigned>",
612      /*methodName=*/"getInputIndex",
613      /*args=*/(ins "unsigned":$operand_index),
614      /*methodBody=*/"",
615      /*defaultImplementation=*/[{
616         if (operand_index >= $_op.getNumInputs())
617           return llvm::None;
618         return operand_index;
619      }]
620    >,
621    InterfaceMethod<
622      /*desc=*/[{
623        Returns the output index given the operand index. Return None
624        if the operand index doesnt corresponding to an output.
625      }],
626      /*retTy=*/"Optional<unsigned>",
627      /*methodName=*/"getOutputIndex",
628      /*args=*/(ins "unsigned":$operand_index),
629      /*methodBody=*/"",
630      /*defaultImplementation=*/[{
631         if (operand_index < $_op.getNumInputs() ||
632             operand_index >= $_op.getNumInputs() + $_op.getNumOutputs())
633           return llvm::None;
634         return operand_index - $_op.getNumInputs();
635      }]
636    >,
637
638    //===------------------------------------------------------------------===//
639    // Other interface methods.
640    //===------------------------------------------------------------------===//
641    InterfaceMethod<
642      /*desc=*/[{
643        Return the iterator types attribute within the current operation.
644      }],
645      /*retTy=*/"ArrayAttr",
646      /*methodName=*/"iterator_types",
647      /*args=*/(ins),
648      /*methodBody=*/"",
649      /*defaultImplementation=*/[{
650        return $_op.iterator_types();
651      }]
652    >,
653    InterfaceMethod<
654      /*desc=*/[{
655        Return the indexing maps attribute within the current operation.
656      }],
657      /*retTy=*/"ArrayAttr",
658      /*methodName=*/"indexing_maps"
659    >,
660    InterfaceMethod<
661      /*desc=*/[{
662        Return the indexing maps within the current operation.
663      }],
664      /*retTy=*/"SmallVector<AffineMap, 4>",
665      /*methodName=*/"getIndexingMaps",
666      /*args=*/(ins),
667      /*methodBody=*/"",
668      /*defaultImplementation=*/[{
669        return llvm::to_vector<4>(
670          $_op.indexing_maps().template getAsValueRange<AffineMapAttr>());
671      }]
672    >,
673    InterfaceMethod<
674      /*desc=*/[{
675        Return the input or output indexing map at index `i`.
676      }],
677      /*retTy=*/"AffineMap",
678      /*methodName=*/"getIndexingMap",
679      /*args=*/(ins "unsigned":$i),
680      /*methodBody=*/"",
681      /*defaultImplementation=*/[{
682        assert(i < getNumInputsAndOutputs());
683        return getIndexingMaps()[i];
684      }]
685    >,
686    InterfaceMethod<
687      /*desc=*/[{
688        Return the input indexing map at index `i`.
689      }],
690      /*retTy=*/"AffineMap",
691      /*methodName=*/"getInputIndexingMap",
692      /*args=*/(ins "unsigned":$i),
693      /*methodBody=*/"",
694      /*defaultImplementation=*/[{
695        assert(i < $_op.getNumInputs());
696        return getIndexingMaps()[i];
697      }]
698    >,
699    InterfaceMethod<
700      /*desc=*/[{
701        Return the output indexing map at index `i`.
702      }],
703      /*retTy=*/"AffineMap",
704      /*methodName=*/"getOutputIndexingMap",
705      /*args=*/(ins "unsigned":$i),
706      /*methodBody=*/"",
707      /*defaultImplementation=*/[{
708        assert(i < $_op.getNumOutputs());
709        return getIndexingMaps()[i + $_op.getNumInputs()];
710      }]
711    >,
712    InterfaceMethod<
713      /*desc=*/[{
714        Return whether the op has only MemRef input and outputs.
715      }],
716      /*retTy=*/"bool",
717      /*methodName=*/"hasBufferSemantics",
718      /*args=*/(ins),
719      /*methodBody=*/"",
720      /*defaultImplementation=*/[{
721        return this->getOperation()->getNumResults() == 0 &&
722          llvm::all_of(getInputs(),
723                       [](Value v) { return v.getType().isa<MemRefType>(); });
724      }]
725    >,
726    InterfaceMethod<
727      /*desc=*/[{
728        Return whether the op has only RankedTensor input and outputs.
729      }],
730      /*retTy=*/"bool",
731      /*methodName=*/"hasTensorSemantics",
732      /*args=*/(ins),
733      /*methodBody=*/"",
734      /*defaultImplementation=*/[{
735        auto isTensorType = [](Value v) {
736          return v.getType().isa<RankedTensorType>();
737        };
738        return llvm::all_of(getInputs(), isTensorType) &&
739               llvm::all_of(this->getOperation()->getResults(), isTensorType);
740      }]
741    >,
742    InterfaceMethod<
743      /*desc=*/[{
744        Return whether the op has sparse tensor semantics.
745      }],
746      /*retTy=*/"bool",
747      /*methodName=*/"hasSparseSemantics",
748      /*args=*/(ins),
749      /*methodBody=*/"",
750      /*defaultImplementation=*/[{
751        return $_op.getAttr(getSparseAttrName()).template dyn_cast_or_null<ArrayAttr>() != nullptr;
752      }]
753    >,
754    InterfaceMethod<
755      /*desc=*/[{
756        Return the name registered for this op when lowering to an external
757        library call.
758      }],
759      /*retTy=*/"std::string",
760      /*methodName=*/"getLibraryCallName",
761      /*args=*/(ins),
762      /*methodBody=*/"",
763      /*defaultImplementation=*/[{
764        return $_op.getLibraryCallName();
765      }]
766    >,
767
768    //===------------------------------------------------------------------===//
769    // Linalg generalization hooks.
770    //===------------------------------------------------------------------===//
771    InterfaceMethod<
772      /*desc=*/[{
773        Hook to provide a custom AffineMap used to compute all the operand
774        subshapes given loop bounds. This is used to answer the question: "given
775        an iteration space over the codomain, what are the subshapes of the
776        operands involved in the computation".
777        The default behavior is to just concatenate all the indexing maps.
778        A custom AffineMap allows providing a map that can be used to
779        compute subshapes even in cases where the concatenation of indexing maps
780        (i.e. the data traversal order) is not a simple permutation of the loop
781        traversal order. It is then possible to define ops with skewed data
782        traversal order for which we can still easily compute hyperrectangular
783        loop bounds and subviews.
784      }],
785      /*retTy=*/"AffineMap",
786      /*methodName=*/"getLoopsToShapesMap",
787      /*args=*/(ins),
788      /*methodBody=*/"",
789      /*defaultImplementation=*/[{
790        auto r = $_op.indexing_maps().template getAsRange<AffineMapAttr>();
791        auto maps = llvm::to_vector<8>(
792            llvm::map_range(r, [](AffineMapAttr a) { return a.getValue(); }));
793        return concatAffineMaps(maps);
794      }]
795    >,
796    InterfaceMethod<
797      /*desc=*/[{
798        Hook to provide a custom AffineMap used to construct the
799        hyperrectangular loop iteration space given all the operand subshapes.
800        This is used to answer the question:
801        "Given a list of operand ranges, what is the subportion of the iteration
802        space involved in the computation".
803        This is the inverse problem of `getLoopsToShapesMap`.
804        Return the empty AffineMap when such an AffineMap cannot be constructed.
805        The default behavior is based on a very simple inference procedure that
806        only works with permutation affine maps.
807        A more advanced Tensor-Comprehension like inference is possible but has
808        proven to be ambiguous in unfavorable case.
809        A safer and more robust alternative is to allow each each op to define
810        its own AffineMap.
811      }],
812      /*retTy=*/"AffineMap",
813      /*methodName=*/"getShapesToLoopsMap",
814      /*args=*/(ins),
815      /*methodBody=*/"",
816      /*defaultImplementation=*/[{
817        return inversePermutation(getLoopsToShapesMap());
818      }]
819    >,
820
821    //===------------------------------------------------------------------===//
822    // Other static interface methods.
823    //===------------------------------------------------------------------===//
824    StaticInterfaceMethod<
825      /*desc=*/[{
826        Create an operation of the current type with the given location,
827        operands, and attributes.
828      }],
829      /*retTy=*/"Operation *",
830      /*methodName=*/"create",
831      (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
832           "ValueRange":$operands,
833           "ArrayRef<NamedAttribute>":$attributes), [{
834        return builder.create<ConcreteOp>(
835          loc, resultTypes, operands, attributes);
836      }]
837    >,
838    InterfaceMethod<
839      /*desc=*/[{
840        Clone the current operation with the given location and operands. This
841        is used to abstract away the optional underlying region creation. This
842        does not change the balance between input, output_buffer and
843        init_tensors operands.
844      }],
845      /*retTy=*/"Operation *",
846      /*methodName=*/"clone",
847      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
848           "ValueRange":$operands),
849      [{
850        BlockAndValueMapping map;
851        unsigned numRegions = $_op->getNumRegions();
852        Operation *res = create(b, loc, resultTypes, operands, $_op.getAttrs());
853        assert(res->getNumRegions() == numRegions && "inconsistent # regions");
854        for (unsigned ridx = 0; ridx < numRegions; ++ridx)
855          $_op->getRegion(ridx).cloneInto(
856            &res->getRegion(ridx), map);
857        return res;
858      }]
859    >,
860    StaticInterfaceMethod<
861      /*desc=*/[{
862        Returns the region builder for constructing the body for linalg.generic.
863        Returns a null function if this named op does not define a region
864        builder.
865      }],
866      /*retTy=*/"std::function<void(Block &)>",
867      /*methodName=*/"getRegionBuilder",
868      (ins),
869      [{ return ConcreteOp::getRegionBuilder(); }]
870    >
871  ];
872
873  let extraClassDeclaration = [{
874    /// Return the flat list of all operand dimension sizes in the order they
875    /// appear in the operands.
876    SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
877
878    /// Create the loop ranges to materialize the computation over the current
879    /// operands. This is done by applying `getShapesToLoopsMap` to
880    /// `createFlatListOfOperandDims`.
881    SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
882
883    /// Returns all the operands past the inputs, output_buffers and
884    /// init_tensors operands. Asserts that these operands are value types to
885    /// allow transformations like tiling to just use the values when cloning
886    /// `linalgOp`.
887    SmallVector<Value, 4> getAssumedNonShapedOperands() {
888      unsigned numShapedOperands = getNumShapedOperands();
889      unsigned nExtraOperands =
890        getOperation()->getNumOperands() - numShapedOperands;
891      SmallVector<Value, 4> res;
892      res.reserve(nExtraOperands);
893      for (unsigned i = 0; i < nExtraOperands; ++i) {
894        res.push_back(getOperation()->getOperand(numShapedOperands + i));
895        assert((res.back().getType().isSignlessIntOrIndexOrFloat()
896                || res.back().getType().isa<VectorType>()) &&
897               "expected scalar or vector type");
898      }
899      return res;
900    }
901    //========================================================================//
902    // Helper functions to mutate the `operand_segment_sizes` attribute.
903    // These are useful when cloning and changing operand types.
904    //========================================================================//
905    void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
906    void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
907    void setNumInitTensors(unsigned num) { setOperandSegmentAt(2, num); }
908
909    private:
910    void setOperandSegmentAt(unsigned idx, unsigned val) {
911      auto attr = getOperation()->getAttr("operand_segment_sizes")
912        .cast<DenseIntElementsAttr>();
913      unsigned i = 0;
914      auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
915        [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
916      getOperation()->setAttr("operand_segment_sizes", newAttr);
917    }
918  }];
919}
920
921#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE
922