1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12
13 #include "mlir-hlo-c/Attributes.h"
14
15 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
16 #include "mlir/CAPI/IR.h"
17 #include "mlir/CAPI/Support.h"
18
19 //
20 // ScatterDimensionNumbersAttr.
21 //
22
mlirMhloScatterDimensionNumbersGet(MlirContext ctx,intptr_t nUpdateWindowDims,const int64_t * updateWindowDims,intptr_t nInsertedWindowDims,const int64_t * insertedWindowDims,intptr_t nScatteredDimsToOperandDims,const int64_t * scatteredDimsToOperandDims,int64_t indexVectorDim)23 MlirAttribute mlirMhloScatterDimensionNumbersGet(
24 MlirContext ctx, intptr_t nUpdateWindowDims,
25 const int64_t *updateWindowDims, intptr_t nInsertedWindowDims,
26 const int64_t *insertedWindowDims, intptr_t nScatteredDimsToOperandDims,
27 const int64_t *scatteredDimsToOperandDims, int64_t indexVectorDim) {
28 return wrap(mlir::mhlo::ScatterDimensionNumbersAttr::get(
29 unwrap(ctx), llvm::makeArrayRef(updateWindowDims, nUpdateWindowDims),
30 llvm::makeArrayRef(insertedWindowDims, nInsertedWindowDims),
31 llvm::makeArrayRef(scatteredDimsToOperandDims,
32 nScatteredDimsToOperandDims),
33 indexVectorDim));
34 }
35
mlirMhloAttributeIsAScatterDimensionNumbers(MlirAttribute attr)36 bool mlirMhloAttributeIsAScatterDimensionNumbers(MlirAttribute attr) {
37 return unwrap(attr).isa<mlir::mhlo::ScatterDimensionNumbersAttr>();
38 }
39
mlirMhloScatterDimensionNumbersGetUpdateWindowDimsSize(MlirAttribute attr)40 intptr_t mlirMhloScatterDimensionNumbersGetUpdateWindowDimsSize(
41 MlirAttribute attr) {
42 return unwrap(attr)
43 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
44 .getUpdateWindowDims()
45 .size();
46 }
47
mlirMhloScatterDimensionNumbersGetUpdateWindowDimsElem(MlirAttribute attr,intptr_t pos)48 int64_t mlirMhloScatterDimensionNumbersGetUpdateWindowDimsElem(
49 MlirAttribute attr, intptr_t pos) {
50 return unwrap(attr)
51 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
52 .getUpdateWindowDims()[pos];
53 }
54
mlirMhloScatterDimensionNumbersGetInsertedWindowDimsSize(MlirAttribute attr)55 intptr_t mlirMhloScatterDimensionNumbersGetInsertedWindowDimsSize(
56 MlirAttribute attr) {
57 return unwrap(attr)
58 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
59 .getInsertedWindowDims()
60 .size();
61 }
62
mlirMhloScatterDimensionNumbersGetInsertedWindowDimsElem(MlirAttribute attr,intptr_t pos)63 int64_t mlirMhloScatterDimensionNumbersGetInsertedWindowDimsElem(
64 MlirAttribute attr, intptr_t pos) {
65 return unwrap(attr)
66 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
67 .getInsertedWindowDims()[pos];
68 }
69
mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(MlirAttribute attr)70 intptr_t mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(
71 MlirAttribute attr) {
72 return unwrap(attr)
73 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
74 .getScatterDimsToOperandDims()
75 .size();
76 }
77
mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(MlirAttribute attr,intptr_t pos)78 int64_t mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(
79 MlirAttribute attr, intptr_t pos) {
80 return unwrap(attr)
81 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
82 .getScatterDimsToOperandDims()[pos];
83 }
84
mlirMhloDimensionNumbersGetIndexVectorDim(MlirAttribute attr)85 int64_t mlirMhloDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
86 return unwrap(attr)
87 .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
88 .getIndexVectorDim();
89 }
90
91 //
92 // GatherDimensionNumbersAttr.
93 //
94
mlirMhloGatherDimensionNumbersGet(MlirContext ctx,intptr_t nOffsetDims,const int64_t * offsetDims,intptr_t nCollapsedSliceDims,const int64_t * collapsedSliceDims,intptr_t nStartIndexMap,const int64_t * startIndexMap,int64_t indexVectorDim)95 MlirAttribute mlirMhloGatherDimensionNumbersGet(
96 MlirContext ctx, intptr_t nOffsetDims, const int64_t *offsetDims,
97 intptr_t nCollapsedSliceDims, const int64_t *collapsedSliceDims,
98 intptr_t nStartIndexMap, const int64_t *startIndexMap,
99 int64_t indexVectorDim) {
100 return wrap(mlir::mhlo::GatherDimensionNumbersAttr::get(
101 unwrap(ctx), llvm::makeArrayRef(offsetDims, nOffsetDims),
102 llvm::makeArrayRef(collapsedSliceDims, nCollapsedSliceDims),
103 llvm::makeArrayRef(startIndexMap, nStartIndexMap), indexVectorDim));
104 }
105
mlirMhloAttributeIsAGatherDimensionNumbers(MlirAttribute attr)106 bool mlirMhloAttributeIsAGatherDimensionNumbers(MlirAttribute attr) {
107 return unwrap(attr).isa<mlir::mhlo::GatherDimensionNumbersAttr>();
108 }
109
mlirMhloGatherDimensionNumbersGetOffsetDimsSize(MlirAttribute attr)110 intptr_t mlirMhloGatherDimensionNumbersGetOffsetDimsSize(MlirAttribute attr) {
111 return unwrap(attr)
112 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
113 .getOffsetDims()
114 .size();
115 }
116
mlirMhloGatherDimensionNumbersGetOffsetDimsElem(MlirAttribute attr,intptr_t pos)117 int64_t mlirMhloGatherDimensionNumbersGetOffsetDimsElem(MlirAttribute attr,
118 intptr_t pos) {
119 return unwrap(attr)
120 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
121 .getOffsetDims()[pos];
122 }
123
mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsSize(MlirAttribute attr)124 intptr_t mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsSize(
125 MlirAttribute attr) {
126 return unwrap(attr)
127 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
128 .getCollapsedSliceDims()
129 .size();
130 }
131
mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsElem(MlirAttribute attr,intptr_t pos)132 int64_t mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsElem(
133 MlirAttribute attr, intptr_t pos) {
134 return unwrap(attr)
135 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
136 .getCollapsedSliceDims()[pos];
137 }
138
mlirMhloGatherDimensionNumbersGetStartIndexMapSize(MlirAttribute attr)139 intptr_t mlirMhloGatherDimensionNumbersGetStartIndexMapSize(
140 MlirAttribute attr) {
141 return unwrap(attr)
142 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
143 .getStartIndexMap()
144 .size();
145 }
146
mlirMhloGatherDimensionNumbersGetStartIndexMapElem(MlirAttribute attr,intptr_t pos)147 int64_t mlirMhloGatherDimensionNumbersGetStartIndexMapElem(MlirAttribute attr,
148 intptr_t pos) {
149 return unwrap(attr)
150 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
151 .getStartIndexMap()[pos];
152 }
153
mlirMhloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr)154 int64_t mlirMhloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
155 return unwrap(attr)
156 .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
157 .getIndexVectorDim();
158 }
159
160 //
161 // DotDimensionNumbersAttr.
162 //
163
mlirMhloDotDimensionNumbersGet(MlirContext ctx,intptr_t nLhsBatchingDimensions,const int64_t * lhsBatchingDimensions,intptr_t nRhsBatchingDimensions,const int64_t * rhsBatchingDimensions,intptr_t nLhsContractingDimensions,const int64_t * lhsContractingDimensions,intptr_t nRhsContractingDimensions,const int64_t * rhsContractingDimensions)164 MlirAttribute mlirMhloDotDimensionNumbersGet(
165 MlirContext ctx, intptr_t nLhsBatchingDimensions,
166 const int64_t *lhsBatchingDimensions, intptr_t nRhsBatchingDimensions,
167 const int64_t *rhsBatchingDimensions, intptr_t nLhsContractingDimensions,
168 const int64_t *lhsContractingDimensions, intptr_t nRhsContractingDimensions,
169 const int64_t *rhsContractingDimensions) {
170 return wrap(mlir::mhlo::DotDimensionNumbersAttr::get(
171 unwrap(ctx),
172 llvm::makeArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions),
173 llvm::makeArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions),
174 llvm::makeArrayRef(lhsContractingDimensions, nLhsContractingDimensions),
175 llvm::makeArrayRef(rhsContractingDimensions, nRhsContractingDimensions)));
176 }
177
mlirMhloAttributeIsADotDimensionNumbers(MlirAttribute attr)178 bool mlirMhloAttributeIsADotDimensionNumbers(MlirAttribute attr) {
179 return unwrap(attr).isa<mlir::mhlo::DotDimensionNumbersAttr>();
180 }
181
mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsSize(MlirAttribute attr)182 intptr_t mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsSize(
183 MlirAttribute attr) {
184 return unwrap(attr)
185 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
186 .getLhsBatchingDimensions()
187 .size();
188 }
189
mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsElem(MlirAttribute attr,intptr_t pos)190 int64_t mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsElem(
191 MlirAttribute attr, intptr_t pos) {
192 return unwrap(attr)
193 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
194 .getLhsBatchingDimensions()[pos];
195 }
196
mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsSize(MlirAttribute attr)197 intptr_t mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsSize(
198 MlirAttribute attr) {
199 return unwrap(attr)
200 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
201 .getRhsBatchingDimensions()
202 .size();
203 }
204
mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsElem(MlirAttribute attr,intptr_t pos)205 int64_t mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsElem(
206 MlirAttribute attr, intptr_t pos) {
207 return unwrap(attr)
208 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
209 .getRhsBatchingDimensions()[pos];
210 }
211
mlirMhloDotDimensionNumbersGetLhsContractingDimensionsSize(MlirAttribute attr)212 intptr_t mlirMhloDotDimensionNumbersGetLhsContractingDimensionsSize(
213 MlirAttribute attr) {
214 return unwrap(attr)
215 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
216 .getLhsContractingDimensions()
217 .size();
218 }
219
mlirMhloDotDimensionNumbersGetLhsContractingDimensionsElem(MlirAttribute attr,intptr_t pos)220 int64_t mlirMhloDotDimensionNumbersGetLhsContractingDimensionsElem(
221 MlirAttribute attr, intptr_t pos) {
222 return unwrap(attr)
223 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
224 .getLhsContractingDimensions()[pos];
225 }
226
mlirMhloDotDimensionNumbersGetRhsContractingDimensionsSize(MlirAttribute attr)227 intptr_t mlirMhloDotDimensionNumbersGetRhsContractingDimensionsSize(
228 MlirAttribute attr) {
229 return unwrap(attr)
230 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
231 .getRhsContractingDimensions()
232 .size();
233 }
234
mlirMhloDotDimensionNumbersGetRhsContractingDimensionsElem(MlirAttribute attr,intptr_t pos)235 int64_t mlirMhloDotDimensionNumbersGetRhsContractingDimensionsElem(
236 MlirAttribute attr, intptr_t pos) {
237 return unwrap(attr)
238 .cast<mlir::mhlo::DotDimensionNumbersAttr>()
239 .getRhsContractingDimensions()[pos];
240 }
241
242 //
243 // ConvDimensionNumbersAttr.
244 //
245
mlirMhloConvDimensionNumbersGet(MlirContext ctx,int64_t inputBatchDimension,int64_t inputFeatureDimension,intptr_t nInputSpatialDimensions,const int64_t * inputSpatialDimensions,int64_t kernelInputFeatureDimension,int64_t kernelOutputFeatureDimension,intptr_t nKernelSpatialDimensions,const int64_t * kernelSpatialDimensions,int64_t outputBatchDimension,int64_t outputFeatureDimension,intptr_t nOutputSpatialDimensions,const int64_t * outputSpatialDimensions)246 MlirAttribute mlirMhloConvDimensionNumbersGet(
247 MlirContext ctx, int64_t inputBatchDimension, int64_t inputFeatureDimension,
248 intptr_t nInputSpatialDimensions, const int64_t *inputSpatialDimensions,
249 int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension,
250 intptr_t nKernelSpatialDimensions, const int64_t *kernelSpatialDimensions,
251 int64_t outputBatchDimension, int64_t outputFeatureDimension,
252 intptr_t nOutputSpatialDimensions, const int64_t *outputSpatialDimensions) {
253 return wrap(mlir::mhlo::ConvDimensionNumbersAttr::get(
254 unwrap(ctx), inputBatchDimension, inputFeatureDimension,
255 llvm::makeArrayRef(inputSpatialDimensions, nInputSpatialDimensions),
256 kernelInputFeatureDimension, kernelOutputFeatureDimension,
257 llvm::makeArrayRef(kernelSpatialDimensions, nKernelSpatialDimensions),
258 outputBatchDimension, outputFeatureDimension,
259 llvm::makeArrayRef(outputSpatialDimensions, nOutputSpatialDimensions)));
260 }
261
mlirMhloAttributeIsAConvDimensionNumbers(MlirAttribute attr)262 bool mlirMhloAttributeIsAConvDimensionNumbers(MlirAttribute attr) {
263 return unwrap(attr).isa<mlir::mhlo::ConvDimensionNumbersAttr>();
264 }
265
mlirMhloConvDimensionNumbersGetInputBatchDimension(MlirAttribute attr)266 int64_t mlirMhloConvDimensionNumbersGetInputBatchDimension(MlirAttribute attr) {
267 return unwrap(attr)
268 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
269 .getInputBatchDimension();
270 }
271
mlirMhloConvDimensionNumbersGetInputFeatureDimension(MlirAttribute attr)272 int64_t mlirMhloConvDimensionNumbersGetInputFeatureDimension(
273 MlirAttribute attr) {
274 return unwrap(attr)
275 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
276 .getInputFeatureDimension();
277 }
278
mlirMhloConvDimensionNumbersGetInputSpatialDimensionsSize(MlirAttribute attr)279 intptr_t mlirMhloConvDimensionNumbersGetInputSpatialDimensionsSize(
280 MlirAttribute attr) {
281 return unwrap(attr)
282 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
283 .getInputSpatialDimensions()
284 .size();
285 }
286
mlirMhloConvDimensionNumbersGetInputSpatialDimensionsElem(MlirAttribute attr,intptr_t pos)287 int64_t mlirMhloConvDimensionNumbersGetInputSpatialDimensionsElem(
288 MlirAttribute attr, intptr_t pos) {
289 return unwrap(attr)
290 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
291 .getInputSpatialDimensions()[pos];
292 }
293
mlirMhloConvDimensionNumbersGetKernelInputFeatureDimension(MlirAttribute attr)294 int64_t mlirMhloConvDimensionNumbersGetKernelInputFeatureDimension(
295 MlirAttribute attr) {
296 return unwrap(attr)
297 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
298 .getKernelInputFeatureDimension();
299 }
300
mlirMhloConvDimensionNumbersGetKernelOutputFeatureDimension(MlirAttribute attr)301 int64_t mlirMhloConvDimensionNumbersGetKernelOutputFeatureDimension(
302 MlirAttribute attr) {
303 return unwrap(attr)
304 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
305 .getKernelOutputFeatureDimension();
306 }
307
mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsSize(MlirAttribute attr)308 intptr_t mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsSize(
309 MlirAttribute attr) {
310 return unwrap(attr)
311 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
312 .getKernelSpatialDimensions()
313 .size();
314 }
315
mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsElem(MlirAttribute attr,intptr_t pos)316 int64_t mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsElem(
317 MlirAttribute attr, intptr_t pos) {
318 return unwrap(attr)
319 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
320 .getKernelSpatialDimensions()[pos];
321 }
322
mlirMhloConvDimensionNumbersGetOutputBatchDimension(MlirAttribute attr)323 int64_t mlirMhloConvDimensionNumbersGetOutputBatchDimension(
324 MlirAttribute attr) {
325 return unwrap(attr)
326 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
327 .getOutputBatchDimension();
328 }
329
mlirMhloConvDimensionNumbersGetOutputFeatureDimension(MlirAttribute attr)330 int64_t mlirMhloConvDimensionNumbersGetOutputFeatureDimension(
331 MlirAttribute attr) {
332 return unwrap(attr)
333 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
334 .getOutputFeatureDimension();
335 }
336
mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsSize(MlirAttribute attr)337 intptr_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsSize(
338 MlirAttribute attr) {
339 return unwrap(attr)
340 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
341 .getOutputSpatialDimensions()
342 .size();
343 }
344
mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem(MlirAttribute attr,intptr_t pos)345 int64_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem(
346 MlirAttribute attr, intptr_t pos) {
347 return unwrap(attr)
348 .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
349 .getOutputSpatialDimensions()[pos];
350 }
351
352 //
353 // ComparisonDirectionAttr.
354 //
mlirMhloComparisonDirectionAttrGet(MlirContext ctx,MlirStringRef direction)355 MlirAttribute mlirMhloComparisonDirectionAttrGet(MlirContext ctx,
356 MlirStringRef direction) {
357 llvm::Optional<mlir::mhlo::ComparisonDirection> compareDirection =
358 mlir::mhlo::symbolizeComparisonDirection(unwrap(direction));
359 if (!compareDirection)
360 llvm_unreachable("Invalid comparison-direction specified.");
361 return wrap(mlir::mhlo::ComparisonDirectionAttr::get(
362 unwrap(ctx), compareDirection.getValue()));
363 }
364
mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr)365 bool mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr) {
366 return unwrap(attr).isa<mlir::mhlo::ComparisonDirectionAttr>();
367 }
368
mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr)369 MlirStringRef mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr) {
370 return wrap(mlir::mhlo::stringifyComparisonDirection(
371 unwrap(attr).cast<mlir::mhlo::ComparisonDirectionAttr>().getValue()));
372 }
373
374 //
375 // ComparisonTypeAttr.
376 //
377
mlirMhloComparisonTypeAttrGet(MlirContext ctx,MlirStringRef type)378 MlirAttribute mlirMhloComparisonTypeAttrGet(MlirContext ctx,
379 MlirStringRef type) {
380 llvm::Optional<mlir::mhlo::ComparisonType> compareType =
381 mlir::mhlo::symbolizeComparisonType(unwrap(type));
382 if (!compareType) llvm_unreachable("Invalid comparison-type specified.");
383 return wrap(
384 mlir::mhlo::ComparisonTypeAttr::get(unwrap(ctx), compareType.getValue()));
385 }
386
mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr)387 bool mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr) {
388 return unwrap(attr).isa<mlir::mhlo::ComparisonTypeAttr>();
389 }
390
mlirMhloComparisonTypeAttrGetType(MlirAttribute attr)391 MlirStringRef mlirMhloComparisonTypeAttrGetType(MlirAttribute attr) {
392 return wrap(mlir::mhlo::stringifyComparisonType(
393 unwrap(attr).cast<mlir::mhlo::ComparisonTypeAttr>().getValue()));
394 }
395
396 //
397 // DomainKindAttr.
398 //
399
mlirMhloDomainKindAttrGet(MlirContext ctx,MlirStringRef kind)400 MlirAttribute mlirMhloDomainKindAttrGet(MlirContext ctx, MlirStringRef kind) {
401 llvm::Optional<mlir::mhlo::DomainKind> domainKind =
402 mlir::mhlo::symbolizeDomainKind(unwrap(kind));
403 if (!domainKind) llvm_unreachable("Invalid domain kind specified.");
404 return wrap(
405 mlir::mhlo::DomainKindAttr::get(unwrap(ctx), domainKind.getValue()));
406 }
407
mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr)408 bool mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr) {
409 return unwrap(attr).isa<mlir::mhlo::DomainKindAttr>();
410 }
411
mlirMhloDomainKindAttrGetType(MlirAttribute attr)412 MlirStringRef mlirMhloDomainKindAttrGetType(MlirAttribute attr) {
413 return wrap(mlir::mhlo::stringifyDomainKind(
414 unwrap(attr).cast<mlir::mhlo::DomainKindAttr>().getValue()));
415 }
416
417 //
418 // PrecisionAttr.
419 //
420
mlirMhloPrecisionAttrGet(MlirContext ctx,MlirStringRef type)421 MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, MlirStringRef type) {
422 llvm::Optional<mlir::mhlo::Precision> precisionType =
423 mlir::mhlo::symbolizePrecision(unwrap(type));
424 if (!precisionType) llvm_unreachable("Invalid precision-type specified.");
425 return wrap(
426 mlir::mhlo::PrecisionAttr::get(unwrap(ctx), precisionType.getValue()));
427 }
428
mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr)429 bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr) {
430 return unwrap(attr).isa<mlir::mhlo::PrecisionAttr>();
431 }
432
mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr)433 MlirStringRef mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr) {
434 return wrap(mlir::mhlo::stringifyPrecision(
435 unwrap(attr).cast<mlir::mhlo::PrecisionAttr>().getValue()));
436 }
437
438 //
439 // FftTypeAttr.
440 //
441
mlirMhloFftTypeAttrGet(MlirContext ctx,MlirStringRef type)442 MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, MlirStringRef type) {
443 llvm::Optional<mlir::mhlo::FftType> fftType =
444 mlir::mhlo::symbolizeFftType(unwrap(type));
445 if (!fftType) llvm_unreachable("Invalid fft-type specified.");
446 return wrap(mlir::mhlo::FftTypeAttr::get(unwrap(ctx), fftType.getValue()));
447 }
448
mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr)449 bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr) {
450 return unwrap(attr).isa<mlir::mhlo::FftTypeAttr>();
451 }
452
mlirMhloFftTypeAttrGetFftType(MlirAttribute attr)453 MlirStringRef mlirMhloFftTypeAttrGetFftType(MlirAttribute attr) {
454 return wrap(mlir::mhlo::stringifyFftType(
455 unwrap(attr).cast<mlir::mhlo::FftTypeAttr>().getValue()));
456 }
457
458 //
459 // DequantizeModeAttr.
460 //
461
mlirMhloDequantizeModeAttrGet(MlirContext ctx,MlirStringRef mode)462 MlirAttribute mlirMhloDequantizeModeAttrGet(MlirContext ctx,
463 MlirStringRef mode) {
464 llvm::Optional<mlir::mhlo::DequantizeMode> dequantizeMode =
465 mlir::mhlo::symbolizeDequantizeMode(unwrap(mode));
466 if (!dequantizeMode) llvm_unreachable("Invalid dequantize-mode specified.");
467 return wrap(mlir::mhlo::DequantizeModeAttr::get(unwrap(ctx),
468 dequantizeMode.getValue()));
469 }
470
mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr)471 bool mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr) {
472 return unwrap(attr).isa<mlir::mhlo::DequantizeModeAttr>();
473 }
474
mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr)475 MlirStringRef mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr) {
476 return wrap(mlir::mhlo::stringifyDequantizeMode(
477 unwrap(attr).cast<mlir::mhlo::DequantizeModeAttr>().getValue()));
478 }
479
480 //
481 // TransposeAttr.
482 //
483
mlirMhloTransposeAttrGet(MlirContext ctx,MlirStringRef type)484 MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, MlirStringRef type) {
485 llvm::Optional<mlir::mhlo::Transpose> transposeType =
486 mlir::mhlo::symbolizeTranspose(unwrap(type));
487 if (!transposeType) llvm_unreachable("Invalid transpose-type specified.");
488 return wrap(
489 mlir::mhlo::TransposeAttr::get(unwrap(ctx), transposeType.getValue()));
490 }
491
mlirMhloAttributeIsATransposeAttr(MlirAttribute attr)492 bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr) {
493 return unwrap(attr).isa<mlir::mhlo::TransposeAttr>();
494 }
495
mlirMhloTransposeAttrGetTranspose(MlirAttribute attr)496 MlirStringRef mlirMhloTransposeAttrGetTranspose(MlirAttribute attr) {
497 return wrap(mlir::mhlo::stringifyTranspose(
498 unwrap(attr).cast<mlir::mhlo::TransposeAttr>().getValue()));
499 }
500
501 //
502 // FusionKindAttr.
503 //
504
mlirMhloFusionKindAttrGet(MlirContext ctx,MlirStringRef kind)505 MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, MlirStringRef kind) {
506 llvm::Optional<mlir::mhlo::FusionKind> fusionKind =
507 mlir::mhlo::symbolizeFusionKind(unwrap(kind));
508 if (!fusionKind) llvm_unreachable("Invalid fusion-kind specified.");
509 return wrap(
510 mlir::mhlo::FusionKindAttr::get(unwrap(ctx), fusionKind.getValue()));
511 }
512
mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr)513 bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr) {
514 return unwrap(attr).isa<mlir::mhlo::FusionKindAttr>();
515 }
516
mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr)517 MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) {
518 return wrap(mlir::mhlo::stringifyFusionKind(
519 unwrap(attr).cast<mlir::mhlo::FusionKindAttr>().getValue()));
520 }
521
522 //
523 // RngDistributionAttr.
524 //
525
mlirMhloRngDistributionAttrGet(MlirContext ctx,MlirStringRef distribution)526 MlirAttribute mlirMhloRngDistributionAttrGet(MlirContext ctx,
527 MlirStringRef distribution) {
528 llvm::Optional<mlir::mhlo::RngDistribution> rngDistribution =
529 mlir::mhlo::symbolizeRngDistribution(unwrap(distribution));
530 if (!rngDistribution) llvm_unreachable("Invalid rng-distribution specified.");
531 return wrap(mlir::mhlo::RngDistributionAttr::get(unwrap(ctx),
532 rngDistribution.getValue()));
533 }
534
mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr)535 bool mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr) {
536 return unwrap(attr).isa<mlir::mhlo::RngDistributionAttr>();
537 }
538
mlirMhloRngDistributionAttrGetRngDistribution(MlirAttribute attr)539 MlirStringRef mlirMhloRngDistributionAttrGetRngDistribution(
540 MlirAttribute attr) {
541 return wrap(mlir::mhlo::stringifyRngDistribution(
542 unwrap(attr).cast<mlir::mhlo::RngDistributionAttr>().getValue()));
543 }
544
545 //
546 // RngAlgorithmAttr.
547 //
548
mlirMhloRngAlgorithmAttrGet(MlirContext ctx,MlirStringRef algorithm)549 MlirAttribute mlirMhloRngAlgorithmAttrGet(MlirContext ctx,
550 MlirStringRef algorithm) {
551 llvm::Optional<mlir::mhlo::RngAlgorithm> rngAlgorithm =
552 mlir::mhlo::symbolizeRngAlgorithm(unwrap(algorithm));
553 if (!rngAlgorithm) llvm_unreachable("Invalid rng-algorithm specified.");
554 return wrap(
555 mlir::mhlo::RngAlgorithmAttr::get(unwrap(ctx), rngAlgorithm.getValue()));
556 }
557
mlirMhloAttributeIsARngAlgorithmAttr(MlirAttribute attr)558 bool mlirMhloAttributeIsARngAlgorithmAttr(MlirAttribute attr) {
559 return unwrap(attr).isa<mlir::mhlo::RngAlgorithmAttr>();
560 }
561
mlirMhloRngAlgorithmAttrGetRngAlgorithm(MlirAttribute attr)562 MlirStringRef mlirMhloRngAlgorithmAttrGetRngAlgorithm(MlirAttribute attr) {
563 return wrap(mlir::mhlo::stringifyRngAlgorithm(
564 unwrap(attr).cast<mlir::mhlo::RngAlgorithmAttr>().getValue()));
565 }
566
567 //
// ChannelHandleAttr.
569 //
570
mlirMhloChannelHandleGet(MlirContext ctx,int64_t handle,int64_t type)571 MlirAttribute mlirMhloChannelHandleGet(MlirContext ctx, int64_t handle,
572 int64_t type) {
573 return wrap(mlir::mhlo::ChannelHandleAttr::get(unwrap(ctx), handle, type));
574 }
575
mlirMhloAttributeIsChannelHandle(MlirAttribute attr)576 bool mlirMhloAttributeIsChannelHandle(MlirAttribute attr) {
577 return unwrap(attr).isa<mlir::mhlo::ChannelHandleAttr>();
578 }
579
mlirMhloChannelHandleGetHandle(MlirAttribute attr)580 int64_t mlirMhloChannelHandleGetHandle(MlirAttribute attr) {
581 return unwrap(attr).cast<mlir::mhlo::ChannelHandleAttr>().getHandle();
582 }
583
mlirMhloChannelHandleGetType(MlirAttribute attr)584 int64_t mlirMhloChannelHandleGetType(MlirAttribute attr) {
585 return unwrap(attr).cast<mlir::mhlo::ChannelHandleAttr>().getType();
586 }
587
588 //
// TypeExtensionsAttr.
590 //
591
mlirMhloTypeExtensionsGet(MlirContext ctx,intptr_t nBounds,const int64_t * bounds)592 MlirAttribute mlirMhloTypeExtensionsGet(MlirContext ctx, intptr_t nBounds,
593 const int64_t *bounds) {
594 return wrap(mlir::mhlo::TypeExtensionsAttr::get(
595 unwrap(ctx), llvm::makeArrayRef(bounds, nBounds)));
596 }
597
mlirMhloAttributeIsTypeExtensions(MlirAttribute attr)598 bool mlirMhloAttributeIsTypeExtensions(MlirAttribute attr) {
599 return unwrap(attr).isa<mlir::mhlo::TypeExtensionsAttr>();
600 }
601
mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr)602 intptr_t mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr) {
603 return unwrap(attr).cast<mlir::mhlo::TypeExtensionsAttr>().getBounds().size();
604 }
605
mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr,intptr_t pos)606 int64_t mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos) {
607 return unwrap(attr).cast<mlir::mhlo::TypeExtensionsAttr>().getBounds()[pos];
608 }
609