1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
17
18 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
19 #include "tensorflow/compiler/xla/types.h"
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 #include "tensorflow/stream_executor/dnn.h"
23
24 namespace xla {
25
ConvertConvDimensionNumbers(mlir::mhlo::ConvDimensionNumbers input)26 ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
27 mlir::mhlo::ConvDimensionNumbers input) {
28 ConvolutionDimensionNumbers output;
29
30 output.set_input_batch_dimension(
31 input.input_batch_dimension().getValue().getSExtValue());
32 output.set_input_feature_dimension(
33 input.input_feature_dimension().getValue().getSExtValue());
34
35 for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
36 output.add_input_spatial_dimensions(v);
37 }
38
39 output.set_kernel_input_feature_dimension(
40 input.kernel_input_feature_dimension().getValue().getSExtValue());
41 output.set_kernel_output_feature_dimension(
42 input.kernel_output_feature_dimension().getValue().getSExtValue());
43
44 for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
45 output.add_kernel_spatial_dimensions(v);
46 }
47
48 output.set_output_batch_dimension(
49 input.output_batch_dimension().getValue().getSExtValue());
50 output.set_output_feature_dimension(
51 input.output_feature_dimension().getValue().getSExtValue());
52
53 for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
54 output.add_output_spatial_dimensions(v);
55 }
56
57 return output;
58 }
59
ConvertConvActivationMode(llvm::StringRef input)60 StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
61 llvm::StringRef input) {
62 llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
63 mlir::lmhlo_gpu::symbolizeActivation(input);
64 if (!activation) {
65 return InternalError("Unexpected activation");
66 }
67
68 switch (activation.getValue()) {
69 case mlir::lmhlo_gpu::Activation::None:
70 return stream_executor::dnn::kNone;
71 case mlir::lmhlo_gpu::Activation::Sigmoid:
72 return stream_executor::dnn::kSigmoid;
73 case mlir::lmhlo_gpu::Activation::Tanh:
74 return stream_executor::dnn::kTanh;
75 case mlir::lmhlo_gpu::Activation::Relu:
76 return stream_executor::dnn::kRelu;
77 case mlir::lmhlo_gpu::Activation::Relu6:
78 return stream_executor::dnn::kRelu6;
79 case mlir::lmhlo_gpu::Activation::ReluX:
80 return stream_executor::dnn::kReluX;
81 case mlir::lmhlo_gpu::Activation::BandPass:
82 return stream_executor::dnn::kBandPass;
83 default:
84 return InternalError("Unexpected activation");
85 }
86 }
87
88 // Convert replica group from MLIR encoding to HLO.
89 // See HloFunctionImporter::ConvertReplicaGroups for the MLIR encoding.
ConvertReplicaGroups(mlir::DenseIntElementsAttr input)90 StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
91 mlir::DenseIntElementsAttr input) {
92 mlir::RankedTensorType type =
93 input.getType().dyn_cast<mlir::RankedTensorType>();
94 if (!type || type.getRank() != 2 ||
95 !type.getElementType().isInteger(/*width=*/64)) {
96 return InternalError("Execpted replica group to be a rank 2 tensor of i64");
97 }
98 // rank 0 is num_groups, rank 1 is group size.
99 auto replica_group_values_it = input.getValues<uint64_t>().begin();
100 std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
101 for (ReplicaGroup& group : replica_groups) {
102 for (int64 element_idx = 0; element_idx < type.getDimSize(1);
103 ++element_idx, ++replica_group_values_it) {
104 // For replica group attribute, -1 indicates padding added by
105 // ConvertReplicaGroups. This show always be at the end and can be dropped
106 // when converting back to XLA HLO ReplicaGroups.
107 if (*replica_group_values_it != -1) {
108 group.add_replica_ids(*replica_group_values_it);
109 }
110 }
111 }
112 return replica_groups;
113 }
114
115 // Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
116 // and source-target pairs are defined in HLO.
ConvertNx2Attribute(llvm::Optional<mlir::DenseIntElementsAttr> optional_attr)117 StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
118 llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
119 if (!optional_attr.hasValue()) return std::vector<std::pair<int64, int64>>{};
120 mlir::DenseIntElementsAttr attr = *optional_attr;
121 auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
122 if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
123 return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
124 auto it = attr.getValues<int64>().begin();
125 std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
126 for (auto& item : out) {
127 int64 first = *it;
128 ++it;
129 int64 second = *it;
130 ++it;
131 item = {first, second};
132 }
133 return out;
134 }
135
ConvertFftType(llvm::StringRef type_string)136 StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
137 llvm::Optional<mlir::mhlo::FftType> type =
138 mlir::mhlo::symbolizeEnum<mlir::mhlo::FftType>(type_string);
139 if (!type) return InvalidArgument("Unknown FFT type %s", type_string.str());
140
141 switch (*type) {
142 case mlir::mhlo::FftType::FFT:
143 return xla::FftType::FFT;
144 case mlir::mhlo::FftType::IFFT:
145 return xla::FftType::IFFT;
146 case mlir::mhlo::FftType::RFFT:
147 return xla::FftType::RFFT;
148 case mlir::mhlo::FftType::IRFFT:
149 return xla::FftType::IRFFT;
150 default:
151 return InvalidArgument("Unknown FFT type enum #%d", *type);
152 }
153 }
154
ConvertTranspose(llvm::StringRef transpose_string)155 StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
156 llvm::StringRef transpose_string) {
157 llvm::Optional<mlir::mhlo::Transpose> transpose =
158 mlir::mhlo::symbolizeTranspose(transpose_string);
159 if (!transpose)
160 return InvalidArgument("Unknown transpose type %s", transpose_string.str());
161
162 switch (*transpose) {
163 case mlir::mhlo::Transpose::NO_TRANSPOSE:
164 return TriangularSolveOptions::NO_TRANSPOSE;
165 case mlir::mhlo::Transpose::TRANSPOSE:
166 return TriangularSolveOptions::TRANSPOSE;
167 case mlir::mhlo::Transpose::ADJOINT:
168 return TriangularSolveOptions::ADJOINT;
169 case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
170 return TriangularSolveOptions::TRANSPOSE_INVALID;
171 default:
172 return InvalidArgument("Unknown transpose enum value #%d", *transpose);
173 }
174 }
175
176 } // namespace xla
177