• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
17 
18 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
19 #include "tensorflow/compiler/xla/types.h"
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 #include "tensorflow/stream_executor/dnn.h"
23 
24 namespace xla {
25 
ConvertConvDimensionNumbers(mlir::mhlo::ConvDimensionNumbers input)26 ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
27     mlir::mhlo::ConvDimensionNumbers input) {
28   ConvolutionDimensionNumbers output;
29 
30   output.set_input_batch_dimension(
31       input.input_batch_dimension().getValue().getSExtValue());
32   output.set_input_feature_dimension(
33       input.input_feature_dimension().getValue().getSExtValue());
34 
35   for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
36     output.add_input_spatial_dimensions(v);
37   }
38 
39   output.set_kernel_input_feature_dimension(
40       input.kernel_input_feature_dimension().getValue().getSExtValue());
41   output.set_kernel_output_feature_dimension(
42       input.kernel_output_feature_dimension().getValue().getSExtValue());
43 
44   for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
45     output.add_kernel_spatial_dimensions(v);
46   }
47 
48   output.set_output_batch_dimension(
49       input.output_batch_dimension().getValue().getSExtValue());
50   output.set_output_feature_dimension(
51       input.output_feature_dimension().getValue().getSExtValue());
52 
53   for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
54     output.add_output_spatial_dimensions(v);
55   }
56 
57   return output;
58 }
59 
ConvertConvActivationMode(llvm::StringRef input)60 StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
61     llvm::StringRef input) {
62   llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
63       mlir::lmhlo_gpu::symbolizeActivation(input);
64   if (!activation) {
65     return InternalError("Unexpected activation");
66   }
67 
68   switch (activation.getValue()) {
69     case mlir::lmhlo_gpu::Activation::None:
70       return stream_executor::dnn::kNone;
71     case mlir::lmhlo_gpu::Activation::Sigmoid:
72       return stream_executor::dnn::kSigmoid;
73     case mlir::lmhlo_gpu::Activation::Tanh:
74       return stream_executor::dnn::kTanh;
75     case mlir::lmhlo_gpu::Activation::Relu:
76       return stream_executor::dnn::kRelu;
77     case mlir::lmhlo_gpu::Activation::Relu6:
78       return stream_executor::dnn::kRelu6;
79     case mlir::lmhlo_gpu::Activation::ReluX:
80       return stream_executor::dnn::kReluX;
81     case mlir::lmhlo_gpu::Activation::BandPass:
82       return stream_executor::dnn::kBandPass;
83     default:
84       return InternalError("Unexpected activation");
85   }
86 }
87 
88 // Convert replica group from MLIR encoding to HLO.
89 // See HloFunctionImporter::ConvertReplicaGroups for the MLIR encoding.
ConvertReplicaGroups(mlir::DenseIntElementsAttr input)90 StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
91     mlir::DenseIntElementsAttr input) {
92   mlir::RankedTensorType type =
93       input.getType().dyn_cast<mlir::RankedTensorType>();
94   if (!type || type.getRank() != 2 ||
95       !type.getElementType().isInteger(/*width=*/64)) {
96     return InternalError("Execpted replica group to be a rank 2 tensor of i64");
97   }
98   // rank 0 is num_groups, rank 1 is group size.
99   auto replica_group_values_it = input.getValues<uint64_t>().begin();
100   std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
101   for (ReplicaGroup& group : replica_groups) {
102     for (int64 element_idx = 0; element_idx < type.getDimSize(1);
103          ++element_idx, ++replica_group_values_it) {
104       // For replica group attribute, -1 indicates padding added by
105       // ConvertReplicaGroups. This show always be at the end and can be dropped
106       // when converting back to XLA HLO ReplicaGroups.
107       if (*replica_group_values_it != -1) {
108         group.add_replica_ids(*replica_group_values_it);
109       }
110     }
111   }
112   return replica_groups;
113 }
114 
115 // Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
116 // and source-target pairs are defined in HLO.
ConvertNx2Attribute(llvm::Optional<mlir::DenseIntElementsAttr> optional_attr)117 StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
118     llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
119   if (!optional_attr.hasValue()) return std::vector<std::pair<int64, int64>>{};
120   mlir::DenseIntElementsAttr attr = *optional_attr;
121   auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
122   if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
123     return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
124   auto it = attr.getValues<int64>().begin();
125   std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
126   for (auto& item : out) {
127     int64 first = *it;
128     ++it;
129     int64 second = *it;
130     ++it;
131     item = {first, second};
132   }
133   return out;
134 }
135 
ConvertFftType(llvm::StringRef type_string)136 StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
137   llvm::Optional<mlir::mhlo::FftType> type =
138       mlir::mhlo::symbolizeEnum<mlir::mhlo::FftType>(type_string);
139   if (!type) return InvalidArgument("Unknown FFT type %s", type_string.str());
140 
141   switch (*type) {
142     case mlir::mhlo::FftType::FFT:
143       return xla::FftType::FFT;
144     case mlir::mhlo::FftType::IFFT:
145       return xla::FftType::IFFT;
146     case mlir::mhlo::FftType::RFFT:
147       return xla::FftType::RFFT;
148     case mlir::mhlo::FftType::IRFFT:
149       return xla::FftType::IRFFT;
150     default:
151       return InvalidArgument("Unknown FFT type enum #%d", *type);
152   }
153 }
154 
ConvertTranspose(llvm::StringRef transpose_string)155 StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
156     llvm::StringRef transpose_string) {
157   llvm::Optional<mlir::mhlo::Transpose> transpose =
158       mlir::mhlo::symbolizeTranspose(transpose_string);
159   if (!transpose)
160     return InvalidArgument("Unknown transpose type %s", transpose_string.str());
161 
162   switch (*transpose) {
163     case mlir::mhlo::Transpose::NO_TRANSPOSE:
164       return TriangularSolveOptions::NO_TRANSPOSE;
165     case mlir::mhlo::Transpose::TRANSPOSE:
166       return TriangularSolveOptions::TRANSPOSE;
167     case mlir::mhlo::Transpose::ADJOINT:
168       return TriangularSolveOptions::ADJOINT;
169     case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
170       return TriangularSolveOptions::TRANSPOSE_INVALID;
171     default:
172       return InvalidArgument("Unknown transpose enum value #%d", *transpose);
173   }
174 }
175 
176 }  // namespace xla
177