1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
17
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22
23 namespace xla {
24
Convert(llvm::ArrayRef<int64_t> elements,mlir::Builder * builder)25 static mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements,
26 mlir::Builder* builder) {
27 return mlir::DenseIntElementsAttr::get(
28 mlir::RankedTensorType::get(elements.size(), builder->getIntegerType(64)),
29 elements);
30 }
31
ConvertPrecisionConfig(const PrecisionConfig * config,mlir::Builder * builder)32 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
33 mlir::Builder* builder) {
34 if (!config) return {};
35
36 // TODO(b/129709049) The HLO text format elides this in the all DEFAULT
37 // case and the parser sticks it in. Maybe we should too.
38 llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
39
40 for (auto prec : config->operand_precision()) {
41 operand_precision_attrs.push_back(
42 builder->getStringAttr(PrecisionConfig_Precision_Name(prec)));
43 }
44 return builder->getArrayAttr(operand_precision_attrs);
45 }
46
47 // Converts the gather dimensions to attributes.
ConvertGatherDimensionNumbers(const xla::GatherDimensionNumbers & dnums,mlir::Builder * builder)48 mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
49 const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
50 std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
51 dnums.offset_dims().end());
52 std::vector<int64_t> collapsed_slice_dims(
53 dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
54 std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
55 dnums.start_index_map().end());
56 return mlir::mhlo::GatherDimensionNumbers::get(
57 Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
58 Convert(start_index_map, builder),
59 builder->getI64IntegerAttr(dnums.index_vector_dim()),
60 builder->getContext());
61 }
62
ConvertScatterDimensionNumbers(const xla::ScatterDimensionNumbers & dnums,mlir::Builder * builder)63 mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
64 const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
65 std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
66 dnums.update_window_dims().end());
67 std::vector<int64_t> inserted_window_dims(
68 dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
69 std::vector<int64_t> scatter_dims_to_operand_dims(
70 dnums.scatter_dims_to_operand_dims().begin(),
71 dnums.scatter_dims_to_operand_dims().end());
72 return mlir::mhlo::ScatterDimensionNumbers::get(
73 Convert(update_window_dims, builder),
74 Convert(inserted_window_dims, builder),
75 Convert(scatter_dims_to_operand_dims, builder),
76 builder->getI64IntegerAttr(dnums.index_vector_dim()),
77 builder->getContext());
78 }
79
ConvertDotDimensionNumbers(const DotDimensionNumbers & dnums,mlir::Builder * builder)80 mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers(
81 const DotDimensionNumbers& dnums, mlir::Builder* builder) {
82 std::vector<int64_t> rhs_contracting_dimensions(
83 dnums.rhs_contracting_dimensions().begin(),
84 dnums.rhs_contracting_dimensions().end());
85 std::vector<int64_t> lhs_contracting_dimensions(
86 dnums.lhs_contracting_dimensions().begin(),
87 dnums.lhs_contracting_dimensions().end());
88 std::vector<int64_t> rhs_batch_dimensions(
89 dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
90 std::vector<int64_t> lhs_batch_dimensions(
91 dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
92
93 // Push the attributes into our new DictionaryAttr.
94 auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions, builder);
95 auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions, builder);
96 auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
97 auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);
98
99 return mlir::mhlo::DotDimensionNumbers::get(
100 lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
101 rhs_contracting_dims_attr, builder->getContext());
102 }
103
ConvertConvDimensionNumbers(const xla::ConvolutionDimensionNumbers & dnums,mlir::Builder * builder)104 mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
105 const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
106 llvm::SmallVector<int64_t, 4> input_spatial_dims(
107 dnums.input_spatial_dimensions().begin(),
108 dnums.input_spatial_dimensions().end());
109 llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
110 dnums.kernel_spatial_dimensions().begin(),
111 dnums.kernel_spatial_dimensions().end());
112 llvm::SmallVector<int64_t, 4> output_spatial_dims(
113 dnums.output_spatial_dimensions().begin(),
114 dnums.output_spatial_dimensions().end());
115 return mlir::mhlo::ConvDimensionNumbers::get(
116 builder->getI64IntegerAttr(dnums.input_batch_dimension()),
117 builder->getI64IntegerAttr(dnums.input_feature_dimension()),
118 Convert(input_spatial_dims, builder),
119 builder->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
120 builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
121 Convert(kernel_spatial_dims, builder),
122 builder->getI64IntegerAttr(dnums.output_batch_dimension()),
123 builder->getI64IntegerAttr(dnums.output_feature_dimension()),
124 Convert(output_spatial_dims, builder), builder->getContext());
125 }
126
ConvertFftType(FftType type)127 StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
128 switch (type) {
129 case FftType::FFT:
130 return mlir::mhlo::FftType::FFT;
131 case FftType::IFFT:
132 return mlir::mhlo::FftType::IFFT;
133 case FftType::RFFT:
134 return mlir::mhlo::FftType::RFFT;
135 case FftType::IRFFT:
136 return mlir::mhlo::FftType::IRFFT;
137 default:
138 return InvalidArgument("Unknown FFT type enum value #%d", type);
139 }
140 }
141
ConvertTranspose(xla::TriangularSolveOptions_Transpose transpose)142 StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
143 xla::TriangularSolveOptions_Transpose transpose) {
144 switch (transpose) {
145 case TriangularSolveOptions::NO_TRANSPOSE:
146 return mlir::mhlo::Transpose::NO_TRANSPOSE;
147 case TriangularSolveOptions::TRANSPOSE:
148 return mlir::mhlo::Transpose::TRANSPOSE;
149 case TriangularSolveOptions::ADJOINT:
150 return mlir::mhlo::Transpose::ADJOINT;
151 case TriangularSolveOptions::TRANSPOSE_INVALID:
152 return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
153 default:
154 return InvalidArgument("Unknown transpose enum value #%d", transpose);
155 }
156 }
157
158 } // namespace xla
159