/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo-c/Attributes.h"

#include "llvm/Support/ErrorHandling.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"

19 //
20 // ScatterDimensionNumbersAttr.
21 //
22 
mlirMhloScatterDimensionNumbersGet(MlirContext ctx,intptr_t nUpdateWindowDims,const int64_t * updateWindowDims,intptr_t nInsertedWindowDims,const int64_t * insertedWindowDims,intptr_t nScatteredDimsToOperandDims,const int64_t * scatteredDimsToOperandDims,int64_t indexVectorDim)23 MlirAttribute mlirMhloScatterDimensionNumbersGet(
24     MlirContext ctx, intptr_t nUpdateWindowDims,
25     const int64_t *updateWindowDims, intptr_t nInsertedWindowDims,
26     const int64_t *insertedWindowDims, intptr_t nScatteredDimsToOperandDims,
27     const int64_t *scatteredDimsToOperandDims, int64_t indexVectorDim) {
28   return wrap(mlir::mhlo::ScatterDimensionNumbersAttr::get(
29       unwrap(ctx), llvm::makeArrayRef(updateWindowDims, nUpdateWindowDims),
30       llvm::makeArrayRef(insertedWindowDims, nInsertedWindowDims),
31       llvm::makeArrayRef(scatteredDimsToOperandDims,
32                          nScatteredDimsToOperandDims),
33       indexVectorDim));
34 }
35 
mlirMhloAttributeIsAScatterDimensionNumbers(MlirAttribute attr)36 bool mlirMhloAttributeIsAScatterDimensionNumbers(MlirAttribute attr) {
37   return unwrap(attr).isa<mlir::mhlo::ScatterDimensionNumbersAttr>();
38 }
39 
mlirMhloScatterDimensionNumbersGetUpdateWindowDimsSize(MlirAttribute attr)40 intptr_t mlirMhloScatterDimensionNumbersGetUpdateWindowDimsSize(
41     MlirAttribute attr) {
42   return unwrap(attr)
43       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
44       .getUpdateWindowDims()
45       .size();
46 }
47 
mlirMhloScatterDimensionNumbersGetUpdateWindowDimsElem(MlirAttribute attr,intptr_t pos)48 int64_t mlirMhloScatterDimensionNumbersGetUpdateWindowDimsElem(
49     MlirAttribute attr, intptr_t pos) {
50   return unwrap(attr)
51       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
52       .getUpdateWindowDims()[pos];
53 }
54 
mlirMhloScatterDimensionNumbersGetInsertedWindowDimsSize(MlirAttribute attr)55 intptr_t mlirMhloScatterDimensionNumbersGetInsertedWindowDimsSize(
56     MlirAttribute attr) {
57   return unwrap(attr)
58       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
59       .getInsertedWindowDims()
60       .size();
61 }
62 
mlirMhloScatterDimensionNumbersGetInsertedWindowDimsElem(MlirAttribute attr,intptr_t pos)63 int64_t mlirMhloScatterDimensionNumbersGetInsertedWindowDimsElem(
64     MlirAttribute attr, intptr_t pos) {
65   return unwrap(attr)
66       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
67       .getInsertedWindowDims()[pos];
68 }
69 
mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(MlirAttribute attr)70 intptr_t mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(
71     MlirAttribute attr) {
72   return unwrap(attr)
73       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
74       .getScatterDimsToOperandDims()
75       .size();
76 }
77 
mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(MlirAttribute attr,intptr_t pos)78 int64_t mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(
79     MlirAttribute attr, intptr_t pos) {
80   return unwrap(attr)
81       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
82       .getScatterDimsToOperandDims()[pos];
83 }
84 
mlirMhloDimensionNumbersGetIndexVectorDim(MlirAttribute attr)85 int64_t mlirMhloDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
86   return unwrap(attr)
87       .cast<mlir::mhlo::ScatterDimensionNumbersAttr>()
88       .getIndexVectorDim();
89 }
90 
91 //
92 // GatherDimensionNumbersAttr.
93 //
94 
mlirMhloGatherDimensionNumbersGet(MlirContext ctx,intptr_t nOffsetDims,const int64_t * offsetDims,intptr_t nCollapsedSliceDims,const int64_t * collapsedSliceDims,intptr_t nStartIndexMap,const int64_t * startIndexMap,int64_t indexVectorDim)95 MlirAttribute mlirMhloGatherDimensionNumbersGet(
96     MlirContext ctx, intptr_t nOffsetDims, const int64_t *offsetDims,
97     intptr_t nCollapsedSliceDims, const int64_t *collapsedSliceDims,
98     intptr_t nStartIndexMap, const int64_t *startIndexMap,
99     int64_t indexVectorDim) {
100   return wrap(mlir::mhlo::GatherDimensionNumbersAttr::get(
101       unwrap(ctx), llvm::makeArrayRef(offsetDims, nOffsetDims),
102       llvm::makeArrayRef(collapsedSliceDims, nCollapsedSliceDims),
103       llvm::makeArrayRef(startIndexMap, nStartIndexMap), indexVectorDim));
104 }
105 
mlirMhloAttributeIsAGatherDimensionNumbers(MlirAttribute attr)106 bool mlirMhloAttributeIsAGatherDimensionNumbers(MlirAttribute attr) {
107   return unwrap(attr).isa<mlir::mhlo::GatherDimensionNumbersAttr>();
108 }
109 
mlirMhloGatherDimensionNumbersGetOffsetDimsSize(MlirAttribute attr)110 intptr_t mlirMhloGatherDimensionNumbersGetOffsetDimsSize(MlirAttribute attr) {
111   return unwrap(attr)
112       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
113       .getOffsetDims()
114       .size();
115 }
116 
mlirMhloGatherDimensionNumbersGetOffsetDimsElem(MlirAttribute attr,intptr_t pos)117 int64_t mlirMhloGatherDimensionNumbersGetOffsetDimsElem(MlirAttribute attr,
118                                                         intptr_t pos) {
119   return unwrap(attr)
120       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
121       .getOffsetDims()[pos];
122 }
123 
mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsSize(MlirAttribute attr)124 intptr_t mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsSize(
125     MlirAttribute attr) {
126   return unwrap(attr)
127       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
128       .getCollapsedSliceDims()
129       .size();
130 }
131 
mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsElem(MlirAttribute attr,intptr_t pos)132 int64_t mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsElem(
133     MlirAttribute attr, intptr_t pos) {
134   return unwrap(attr)
135       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
136       .getCollapsedSliceDims()[pos];
137 }
138 
mlirMhloGatherDimensionNumbersGetStartIndexMapSize(MlirAttribute attr)139 intptr_t mlirMhloGatherDimensionNumbersGetStartIndexMapSize(
140     MlirAttribute attr) {
141   return unwrap(attr)
142       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
143       .getStartIndexMap()
144       .size();
145 }
146 
mlirMhloGatherDimensionNumbersGetStartIndexMapElem(MlirAttribute attr,intptr_t pos)147 int64_t mlirMhloGatherDimensionNumbersGetStartIndexMapElem(MlirAttribute attr,
148                                                            intptr_t pos) {
149   return unwrap(attr)
150       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
151       .getStartIndexMap()[pos];
152 }
153 
mlirMhloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr)154 int64_t mlirMhloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
155   return unwrap(attr)
156       .cast<mlir::mhlo::GatherDimensionNumbersAttr>()
157       .getIndexVectorDim();
158 }
159 
160 //
161 // DotDimensionNumbersAttr.
162 //
163 
mlirMhloDotDimensionNumbersGet(MlirContext ctx,intptr_t nLhsBatchingDimensions,const int64_t * lhsBatchingDimensions,intptr_t nRhsBatchingDimensions,const int64_t * rhsBatchingDimensions,intptr_t nLhsContractingDimensions,const int64_t * lhsContractingDimensions,intptr_t nRhsContractingDimensions,const int64_t * rhsContractingDimensions)164 MlirAttribute mlirMhloDotDimensionNumbersGet(
165     MlirContext ctx, intptr_t nLhsBatchingDimensions,
166     const int64_t *lhsBatchingDimensions, intptr_t nRhsBatchingDimensions,
167     const int64_t *rhsBatchingDimensions, intptr_t nLhsContractingDimensions,
168     const int64_t *lhsContractingDimensions, intptr_t nRhsContractingDimensions,
169     const int64_t *rhsContractingDimensions) {
170   return wrap(mlir::mhlo::DotDimensionNumbersAttr::get(
171       unwrap(ctx),
172       llvm::makeArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions),
173       llvm::makeArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions),
174       llvm::makeArrayRef(lhsContractingDimensions, nLhsContractingDimensions),
175       llvm::makeArrayRef(rhsContractingDimensions, nRhsContractingDimensions)));
176 }
177 
mlirMhloAttributeIsADotDimensionNumbers(MlirAttribute attr)178 bool mlirMhloAttributeIsADotDimensionNumbers(MlirAttribute attr) {
179   return unwrap(attr).isa<mlir::mhlo::DotDimensionNumbersAttr>();
180 }
181 
mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsSize(MlirAttribute attr)182 intptr_t mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsSize(
183     MlirAttribute attr) {
184   return unwrap(attr)
185       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
186       .getLhsBatchingDimensions()
187       .size();
188 }
189 
mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsElem(MlirAttribute attr,intptr_t pos)190 int64_t mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsElem(
191     MlirAttribute attr, intptr_t pos) {
192   return unwrap(attr)
193       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
194       .getLhsBatchingDimensions()[pos];
195 }
196 
mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsSize(MlirAttribute attr)197 intptr_t mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsSize(
198     MlirAttribute attr) {
199   return unwrap(attr)
200       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
201       .getRhsBatchingDimensions()
202       .size();
203 }
204 
mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsElem(MlirAttribute attr,intptr_t pos)205 int64_t mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsElem(
206     MlirAttribute attr, intptr_t pos) {
207   return unwrap(attr)
208       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
209       .getRhsBatchingDimensions()[pos];
210 }
211 
mlirMhloDotDimensionNumbersGetLhsContractingDimensionsSize(MlirAttribute attr)212 intptr_t mlirMhloDotDimensionNumbersGetLhsContractingDimensionsSize(
213     MlirAttribute attr) {
214   return unwrap(attr)
215       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
216       .getLhsContractingDimensions()
217       .size();
218 }
219 
mlirMhloDotDimensionNumbersGetLhsContractingDimensionsElem(MlirAttribute attr,intptr_t pos)220 int64_t mlirMhloDotDimensionNumbersGetLhsContractingDimensionsElem(
221     MlirAttribute attr, intptr_t pos) {
222   return unwrap(attr)
223       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
224       .getLhsContractingDimensions()[pos];
225 }
226 
mlirMhloDotDimensionNumbersGetRhsContractingDimensionsSize(MlirAttribute attr)227 intptr_t mlirMhloDotDimensionNumbersGetRhsContractingDimensionsSize(
228     MlirAttribute attr) {
229   return unwrap(attr)
230       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
231       .getRhsContractingDimensions()
232       .size();
233 }
234 
mlirMhloDotDimensionNumbersGetRhsContractingDimensionsElem(MlirAttribute attr,intptr_t pos)235 int64_t mlirMhloDotDimensionNumbersGetRhsContractingDimensionsElem(
236     MlirAttribute attr, intptr_t pos) {
237   return unwrap(attr)
238       .cast<mlir::mhlo::DotDimensionNumbersAttr>()
239       .getRhsContractingDimensions()[pos];
240 }
241 
242 //
243 // ConvDimensionNumbersAttr.
244 //
245 
mlirMhloConvDimensionNumbersGet(MlirContext ctx,int64_t inputBatchDimension,int64_t inputFeatureDimension,intptr_t nInputSpatialDimensions,const int64_t * inputSpatialDimensions,int64_t kernelInputFeatureDimension,int64_t kernelOutputFeatureDimension,intptr_t nKernelSpatialDimensions,const int64_t * kernelSpatialDimensions,int64_t outputBatchDimension,int64_t outputFeatureDimension,intptr_t nOutputSpatialDimensions,const int64_t * outputSpatialDimensions)246 MlirAttribute mlirMhloConvDimensionNumbersGet(
247     MlirContext ctx, int64_t inputBatchDimension, int64_t inputFeatureDimension,
248     intptr_t nInputSpatialDimensions, const int64_t *inputSpatialDimensions,
249     int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension,
250     intptr_t nKernelSpatialDimensions, const int64_t *kernelSpatialDimensions,
251     int64_t outputBatchDimension, int64_t outputFeatureDimension,
252     intptr_t nOutputSpatialDimensions, const int64_t *outputSpatialDimensions) {
253   return wrap(mlir::mhlo::ConvDimensionNumbersAttr::get(
254       unwrap(ctx), inputBatchDimension, inputFeatureDimension,
255       llvm::makeArrayRef(inputSpatialDimensions, nInputSpatialDimensions),
256       kernelInputFeatureDimension, kernelOutputFeatureDimension,
257       llvm::makeArrayRef(kernelSpatialDimensions, nKernelSpatialDimensions),
258       outputBatchDimension, outputFeatureDimension,
259       llvm::makeArrayRef(outputSpatialDimensions, nOutputSpatialDimensions)));
260 }
261 
mlirMhloAttributeIsAConvDimensionNumbers(MlirAttribute attr)262 bool mlirMhloAttributeIsAConvDimensionNumbers(MlirAttribute attr) {
263   return unwrap(attr).isa<mlir::mhlo::ConvDimensionNumbersAttr>();
264 }
265 
mlirMhloConvDimensionNumbersGetInputBatchDimension(MlirAttribute attr)266 int64_t mlirMhloConvDimensionNumbersGetInputBatchDimension(MlirAttribute attr) {
267   return unwrap(attr)
268       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
269       .getInputBatchDimension();
270 }
271 
mlirMhloConvDimensionNumbersGetInputFeatureDimension(MlirAttribute attr)272 int64_t mlirMhloConvDimensionNumbersGetInputFeatureDimension(
273     MlirAttribute attr) {
274   return unwrap(attr)
275       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
276       .getInputFeatureDimension();
277 }
278 
mlirMhloConvDimensionNumbersGetInputSpatialDimensionsSize(MlirAttribute attr)279 intptr_t mlirMhloConvDimensionNumbersGetInputSpatialDimensionsSize(
280     MlirAttribute attr) {
281   return unwrap(attr)
282       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
283       .getInputSpatialDimensions()
284       .size();
285 }
286 
mlirMhloConvDimensionNumbersGetInputSpatialDimensionsElem(MlirAttribute attr,intptr_t pos)287 int64_t mlirMhloConvDimensionNumbersGetInputSpatialDimensionsElem(
288     MlirAttribute attr, intptr_t pos) {
289   return unwrap(attr)
290       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
291       .getInputSpatialDimensions()[pos];
292 }
293 
mlirMhloConvDimensionNumbersGetKernelInputFeatureDimension(MlirAttribute attr)294 int64_t mlirMhloConvDimensionNumbersGetKernelInputFeatureDimension(
295     MlirAttribute attr) {
296   return unwrap(attr)
297       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
298       .getKernelInputFeatureDimension();
299 }
300 
mlirMhloConvDimensionNumbersGetKernelOutputFeatureDimension(MlirAttribute attr)301 int64_t mlirMhloConvDimensionNumbersGetKernelOutputFeatureDimension(
302     MlirAttribute attr) {
303   return unwrap(attr)
304       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
305       .getKernelOutputFeatureDimension();
306 }
307 
mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsSize(MlirAttribute attr)308 intptr_t mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsSize(
309     MlirAttribute attr) {
310   return unwrap(attr)
311       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
312       .getKernelSpatialDimensions()
313       .size();
314 }
315 
mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsElem(MlirAttribute attr,intptr_t pos)316 int64_t mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsElem(
317     MlirAttribute attr, intptr_t pos) {
318   return unwrap(attr)
319       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
320       .getKernelSpatialDimensions()[pos];
321 }
322 
mlirMhloConvDimensionNumbersGetOutputBatchDimension(MlirAttribute attr)323 int64_t mlirMhloConvDimensionNumbersGetOutputBatchDimension(
324     MlirAttribute attr) {
325   return unwrap(attr)
326       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
327       .getOutputBatchDimension();
328 }
329 
mlirMhloConvDimensionNumbersGetOutputFeatureDimension(MlirAttribute attr)330 int64_t mlirMhloConvDimensionNumbersGetOutputFeatureDimension(
331     MlirAttribute attr) {
332   return unwrap(attr)
333       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
334       .getOutputFeatureDimension();
335 }
336 
mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsSize(MlirAttribute attr)337 intptr_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsSize(
338     MlirAttribute attr) {
339   return unwrap(attr)
340       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
341       .getOutputSpatialDimensions()
342       .size();
343 }
344 
mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem(MlirAttribute attr,intptr_t pos)345 int64_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem(
346     MlirAttribute attr, intptr_t pos) {
347   return unwrap(attr)
348       .cast<mlir::mhlo::ConvDimensionNumbersAttr>()
349       .getOutputSpatialDimensions()[pos];
350 }
351 
352 //
353 // ComparisonDirectionAttr.
354 //
mlirMhloComparisonDirectionAttrGet(MlirContext ctx,MlirStringRef direction)355 MlirAttribute mlirMhloComparisonDirectionAttrGet(MlirContext ctx,
356                                                  MlirStringRef direction) {
357   llvm::Optional<mlir::mhlo::ComparisonDirection> compareDirection =
358       mlir::mhlo::symbolizeComparisonDirection(unwrap(direction));
359   if (!compareDirection)
360     llvm_unreachable("Invalid comparison-direction specified.");
361   return wrap(mlir::mhlo::ComparisonDirectionAttr::get(
362       unwrap(ctx), compareDirection.getValue()));
363 }
364 
mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr)365 bool mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr) {
366   return unwrap(attr).isa<mlir::mhlo::ComparisonDirectionAttr>();
367 }
368 
mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr)369 MlirStringRef mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr) {
370   return wrap(mlir::mhlo::stringifyComparisonDirection(
371       unwrap(attr).cast<mlir::mhlo::ComparisonDirectionAttr>().getValue()));
372 }
373 
374 //
375 // ComparisonTypeAttr.
376 //
377 
mlirMhloComparisonTypeAttrGet(MlirContext ctx,MlirStringRef type)378 MlirAttribute mlirMhloComparisonTypeAttrGet(MlirContext ctx,
379                                             MlirStringRef type) {
380   llvm::Optional<mlir::mhlo::ComparisonType> compareType =
381       mlir::mhlo::symbolizeComparisonType(unwrap(type));
382   if (!compareType) llvm_unreachable("Invalid comparison-type specified.");
383   return wrap(
384       mlir::mhlo::ComparisonTypeAttr::get(unwrap(ctx), compareType.getValue()));
385 }
386 
mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr)387 bool mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr) {
388   return unwrap(attr).isa<mlir::mhlo::ComparisonTypeAttr>();
389 }
390 
mlirMhloComparisonTypeAttrGetType(MlirAttribute attr)391 MlirStringRef mlirMhloComparisonTypeAttrGetType(MlirAttribute attr) {
392   return wrap(mlir::mhlo::stringifyComparisonType(
393       unwrap(attr).cast<mlir::mhlo::ComparisonTypeAttr>().getValue()));
394 }
395 
396 //
397 // DomainKindAttr.
398 //
399 
mlirMhloDomainKindAttrGet(MlirContext ctx,MlirStringRef kind)400 MlirAttribute mlirMhloDomainKindAttrGet(MlirContext ctx, MlirStringRef kind) {
401   llvm::Optional<mlir::mhlo::DomainKind> domainKind =
402       mlir::mhlo::symbolizeDomainKind(unwrap(kind));
403   if (!domainKind) llvm_unreachable("Invalid domain kind specified.");
404   return wrap(
405       mlir::mhlo::DomainKindAttr::get(unwrap(ctx), domainKind.getValue()));
406 }
407 
mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr)408 bool mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr) {
409   return unwrap(attr).isa<mlir::mhlo::DomainKindAttr>();
410 }
411 
mlirMhloDomainKindAttrGetType(MlirAttribute attr)412 MlirStringRef mlirMhloDomainKindAttrGetType(MlirAttribute attr) {
413   return wrap(mlir::mhlo::stringifyDomainKind(
414       unwrap(attr).cast<mlir::mhlo::DomainKindAttr>().getValue()));
415 }
416 
417 //
418 // PrecisionAttr.
419 //
420 
mlirMhloPrecisionAttrGet(MlirContext ctx,MlirStringRef type)421 MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, MlirStringRef type) {
422   llvm::Optional<mlir::mhlo::Precision> precisionType =
423       mlir::mhlo::symbolizePrecision(unwrap(type));
424   if (!precisionType) llvm_unreachable("Invalid precision-type specified.");
425   return wrap(
426       mlir::mhlo::PrecisionAttr::get(unwrap(ctx), precisionType.getValue()));
427 }
428 
mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr)429 bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr) {
430   return unwrap(attr).isa<mlir::mhlo::PrecisionAttr>();
431 }
432 
mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr)433 MlirStringRef mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr) {
434   return wrap(mlir::mhlo::stringifyPrecision(
435       unwrap(attr).cast<mlir::mhlo::PrecisionAttr>().getValue()));
436 }
437 
438 //
439 // FftTypeAttr.
440 //
441 
mlirMhloFftTypeAttrGet(MlirContext ctx,MlirStringRef type)442 MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, MlirStringRef type) {
443   llvm::Optional<mlir::mhlo::FftType> fftType =
444       mlir::mhlo::symbolizeFftType(unwrap(type));
445   if (!fftType) llvm_unreachable("Invalid fft-type specified.");
446   return wrap(mlir::mhlo::FftTypeAttr::get(unwrap(ctx), fftType.getValue()));
447 }
448 
mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr)449 bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr) {
450   return unwrap(attr).isa<mlir::mhlo::FftTypeAttr>();
451 }
452 
mlirMhloFftTypeAttrGetFftType(MlirAttribute attr)453 MlirStringRef mlirMhloFftTypeAttrGetFftType(MlirAttribute attr) {
454   return wrap(mlir::mhlo::stringifyFftType(
455       unwrap(attr).cast<mlir::mhlo::FftTypeAttr>().getValue()));
456 }
457 
458 //
459 // DequantizeModeAttr.
460 //
461 
mlirMhloDequantizeModeAttrGet(MlirContext ctx,MlirStringRef mode)462 MlirAttribute mlirMhloDequantizeModeAttrGet(MlirContext ctx,
463                                             MlirStringRef mode) {
464   llvm::Optional<mlir::mhlo::DequantizeMode> dequantizeMode =
465       mlir::mhlo::symbolizeDequantizeMode(unwrap(mode));
466   if (!dequantizeMode) llvm_unreachable("Invalid dequantize-mode specified.");
467   return wrap(mlir::mhlo::DequantizeModeAttr::get(unwrap(ctx),
468                                                   dequantizeMode.getValue()));
469 }
470 
mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr)471 bool mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr) {
472   return unwrap(attr).isa<mlir::mhlo::DequantizeModeAttr>();
473 }
474 
mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr)475 MlirStringRef mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr) {
476   return wrap(mlir::mhlo::stringifyDequantizeMode(
477       unwrap(attr).cast<mlir::mhlo::DequantizeModeAttr>().getValue()));
478 }
479 
480 //
481 // TransposeAttr.
482 //
483 
mlirMhloTransposeAttrGet(MlirContext ctx,MlirStringRef type)484 MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, MlirStringRef type) {
485   llvm::Optional<mlir::mhlo::Transpose> transposeType =
486       mlir::mhlo::symbolizeTranspose(unwrap(type));
487   if (!transposeType) llvm_unreachable("Invalid transpose-type specified.");
488   return wrap(
489       mlir::mhlo::TransposeAttr::get(unwrap(ctx), transposeType.getValue()));
490 }
491 
mlirMhloAttributeIsATransposeAttr(MlirAttribute attr)492 bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr) {
493   return unwrap(attr).isa<mlir::mhlo::TransposeAttr>();
494 }
495 
mlirMhloTransposeAttrGetTranspose(MlirAttribute attr)496 MlirStringRef mlirMhloTransposeAttrGetTranspose(MlirAttribute attr) {
497   return wrap(mlir::mhlo::stringifyTranspose(
498       unwrap(attr).cast<mlir::mhlo::TransposeAttr>().getValue()));
499 }
500 
501 //
502 // FusionKindAttr.
503 //
504 
mlirMhloFusionKindAttrGet(MlirContext ctx,MlirStringRef kind)505 MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, MlirStringRef kind) {
506   llvm::Optional<mlir::mhlo::FusionKind> fusionKind =
507       mlir::mhlo::symbolizeFusionKind(unwrap(kind));
508   if (!fusionKind) llvm_unreachable("Invalid fusion-kind specified.");
509   return wrap(
510       mlir::mhlo::FusionKindAttr::get(unwrap(ctx), fusionKind.getValue()));
511 }
512 
mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr)513 bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr) {
514   return unwrap(attr).isa<mlir::mhlo::FusionKindAttr>();
515 }
516 
mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr)517 MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) {
518   return wrap(mlir::mhlo::stringifyFusionKind(
519       unwrap(attr).cast<mlir::mhlo::FusionKindAttr>().getValue()));
520 }
521 
522 //
523 // RngDistributionAttr.
524 //
525 
mlirMhloRngDistributionAttrGet(MlirContext ctx,MlirStringRef distribution)526 MlirAttribute mlirMhloRngDistributionAttrGet(MlirContext ctx,
527                                              MlirStringRef distribution) {
528   llvm::Optional<mlir::mhlo::RngDistribution> rngDistribution =
529       mlir::mhlo::symbolizeRngDistribution(unwrap(distribution));
530   if (!rngDistribution) llvm_unreachable("Invalid rng-distribution specified.");
531   return wrap(mlir::mhlo::RngDistributionAttr::get(unwrap(ctx),
532                                                    rngDistribution.getValue()));
533 }
534 
mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr)535 bool mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr) {
536   return unwrap(attr).isa<mlir::mhlo::RngDistributionAttr>();
537 }
538 
mlirMhloRngDistributionAttrGetRngDistribution(MlirAttribute attr)539 MlirStringRef mlirMhloRngDistributionAttrGetRngDistribution(
540     MlirAttribute attr) {
541   return wrap(mlir::mhlo::stringifyRngDistribution(
542       unwrap(attr).cast<mlir::mhlo::RngDistributionAttr>().getValue()));
543 }
544 
545 //
546 // RngAlgorithmAttr.
547 //
548 
mlirMhloRngAlgorithmAttrGet(MlirContext ctx,MlirStringRef algorithm)549 MlirAttribute mlirMhloRngAlgorithmAttrGet(MlirContext ctx,
550                                           MlirStringRef algorithm) {
551   llvm::Optional<mlir::mhlo::RngAlgorithm> rngAlgorithm =
552       mlir::mhlo::symbolizeRngAlgorithm(unwrap(algorithm));
553   if (!rngAlgorithm) llvm_unreachable("Invalid rng-algorithm specified.");
554   return wrap(
555       mlir::mhlo::RngAlgorithmAttr::get(unwrap(ctx), rngAlgorithm.getValue()));
556 }
557 
mlirMhloAttributeIsARngAlgorithmAttr(MlirAttribute attr)558 bool mlirMhloAttributeIsARngAlgorithmAttr(MlirAttribute attr) {
559   return unwrap(attr).isa<mlir::mhlo::RngAlgorithmAttr>();
560 }
561 
mlirMhloRngAlgorithmAttrGetRngAlgorithm(MlirAttribute attr)562 MlirStringRef mlirMhloRngAlgorithmAttrGetRngAlgorithm(MlirAttribute attr) {
563   return wrap(mlir::mhlo::stringifyRngAlgorithm(
564       unwrap(attr).cast<mlir::mhlo::RngAlgorithmAttr>().getValue()));
565 }
566 
567 //
568 // ChannelHandle
569 //
570 
mlirMhloChannelHandleGet(MlirContext ctx,int64_t handle,int64_t type)571 MlirAttribute mlirMhloChannelHandleGet(MlirContext ctx, int64_t handle,
572                                        int64_t type) {
573   return wrap(mlir::mhlo::ChannelHandleAttr::get(unwrap(ctx), handle, type));
574 }
575 
mlirMhloAttributeIsChannelHandle(MlirAttribute attr)576 bool mlirMhloAttributeIsChannelHandle(MlirAttribute attr) {
577   return unwrap(attr).isa<mlir::mhlo::ChannelHandleAttr>();
578 }
579 
mlirMhloChannelHandleGetHandle(MlirAttribute attr)580 int64_t mlirMhloChannelHandleGetHandle(MlirAttribute attr) {
581   return unwrap(attr).cast<mlir::mhlo::ChannelHandleAttr>().getHandle();
582 }
583 
mlirMhloChannelHandleGetType(MlirAttribute attr)584 int64_t mlirMhloChannelHandleGetType(MlirAttribute attr) {
585   return unwrap(attr).cast<mlir::mhlo::ChannelHandleAttr>().getType();
586 }
587 
588 //
589 // TypeExtensions
590 //
591 
mlirMhloTypeExtensionsGet(MlirContext ctx,intptr_t nBounds,const int64_t * bounds)592 MlirAttribute mlirMhloTypeExtensionsGet(MlirContext ctx, intptr_t nBounds,
593                                         const int64_t *bounds) {
594   return wrap(mlir::mhlo::TypeExtensionsAttr::get(
595       unwrap(ctx), llvm::makeArrayRef(bounds, nBounds)));
596 }
597 
mlirMhloAttributeIsTypeExtensions(MlirAttribute attr)598 bool mlirMhloAttributeIsTypeExtensions(MlirAttribute attr) {
599   return unwrap(attr).isa<mlir::mhlo::TypeExtensionsAttr>();
600 }
601 
mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr)602 intptr_t mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr) {
603   return unwrap(attr).cast<mlir::mhlo::TypeExtensionsAttr>().getBounds().size();
604 }
605 
mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr,intptr_t pos)606 int64_t mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos) {
607   return unwrap(attr).cast<mlir::mhlo::TypeExtensionsAttr>().getBounds()[pos];
608 }
609