1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
29 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/Diagnostics.h" // from @llvm-project
34 #include "mlir/IR/Location.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/PatternMatch.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
43 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
49 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
50 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
51 #include "tensorflow/compiler/tf2xla/xla_context.h"
52 #include "tensorflow/compiler/tf2xla/xla_expression.h"
53 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
54 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
55 #include "tensorflow/compiler/xla/client/xla_builder.h"
56 #include "tensorflow/core/common_runtime/device.h"
57 #include "tensorflow/core/common_runtime/device_factory.h"
58 #include "tensorflow/core/common_runtime/device_mgr.h"
59 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
60 #include "tensorflow/core/framework/allocator.h"
61 #include "tensorflow/core/framework/function.h"
62 #include "tensorflow/core/framework/function.pb.h"
63 #include "tensorflow/core/framework/node_properties.h"
64 #include "tensorflow/core/framework/op.h"
65 #include "tensorflow/core/framework/op_kernel.h"
66 #include "tensorflow/core/framework/resource_mgr.h"
67 #include "tensorflow/core/framework/tensor.h"
68 #include "tensorflow/core/framework/types.h"
69 #include "tensorflow/core/framework/types.pb.h"
70 #include "tensorflow/core/platform/env.h"
71 #include "tensorflow/core/platform/status.h"
72 #include "tensorflow/core/protobuf/config.pb.h"
73 #include "tensorflow/core/public/session_options.h"
74 #include "tensorflow/stream_executor/lib/statusor.h"
75 #include "tensorflow/stream_executor/stream_executor.h"
76
77 namespace mlir {
78 namespace mhlo {
79
// LINT.IfChange
// Returns true if `op` may always be legalized through the tf2xla kernel
// fallback, independent of the prefer_tf2xla setting. Membership is checked
// by the op's registered TypeID, so ops not registered with the MLIR context
// (no AbstractOperation) are rejected.
bool IsOpAllowedTf2XlaFallback(Operation* op) {
  // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
  // building valid MLIR using MlirHloBuilder.
  // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
  // all tf2xla kernels.
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off

  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AcoshOp>(),
    TypeID::get<TF::AcosOp>(),
    TypeID::get<TF::AddNOp>(),
    TypeID::get<TF::AddV2Op>(),
    TypeID::get<TF::AngleOp>(),
    TypeID::get<TF::AdjustContrastv2Op>(),
    TypeID::get<TF::AdjustHueOp>(),
    TypeID::get<TF::AdjustSaturationOp>(),
    TypeID::get<TF::ApproximateEqualOp>(),
    TypeID::get<TF::ArgMaxOp>(),
    TypeID::get<TF::ArgMinOp>(),
    TypeID::get<TF::AsinhOp>(),
    TypeID::get<TF::AsinOp>(),
    TypeID::get<TF::Atan2Op>(),
    TypeID::get<TF::AtanhOp>(),
    TypeID::get<TF::BatchMatMulV2Op>(),
    TypeID::get<TF::BatchMatMulV3Op>(),
    TypeID::get<TF::BatchToSpaceOp>(),
    TypeID::get<TF::BesselI0eOp>(),
    TypeID::get<TF::BesselI1eOp>(),
    TypeID::get<TF::BetaincOp>(),
    TypeID::get<TF::BiasAddOp>(),
    TypeID::get<TF::BitwiseAndOp>(),
    TypeID::get<TF::BitwiseOrOp>(),
    TypeID::get<TF::BitwiseXorOp>(),
    TypeID::get<TF::BucketizeOp>(),
    TypeID::get<TF::CastOp>(),
    TypeID::get<TF::ClipByValueOp>(),
    TypeID::get<TF::CholeskyOp>(),
    TypeID::get<TF::ComplexAbsOp>(),
    TypeID::get<TF::ConjugateTransposeOp>(),
    TypeID::get<TF::CoshOp>(),
    TypeID::get<TF::CrossOp>(),
    TypeID::get<TF::DataFormatDimMapOp>(),
    TypeID::get<TF::DataFormatVecPermuteOp>(),
    TypeID::get<TF::DepthToSpaceOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
    TypeID::get<TF::DiagOp>(),
    TypeID::get<TF::DigammaOp>(),
    TypeID::get<TF::DivNoNanOp>(),
    TypeID::get<TF::EluGradOp>(),
    TypeID::get<TF::EluOp>(),
    TypeID::get<TF::EnsureShapeOp>(),
    TypeID::get<TF::EqualOp>(),
    TypeID::get<TF::ErfcOp>(),
    TypeID::get<TF::ErfinvOp>(),
    TypeID::get<TF::ErfOp>(),
    TypeID::get<TF::ExtractImagePatchesOp>(),
    TypeID::get<TF::FFT2DOp>(),
    TypeID::get<TF::FFT3DOp>(),
    TypeID::get<TF::FFTOp>(),
    TypeID::get<TF::FakeParamOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
    TypeID::get<TF::FloorDivOp>(),
    TypeID::get<TF::FloorModOp>(),
    TypeID::get<TF::GreaterOp>(),
    TypeID::get<TF::HSVToRGBOp>(),
    TypeID::get<TF::IFFT2DOp>(),
    TypeID::get<TF::IFFT3DOp>(),
    TypeID::get<TF::IRFFT2DOp>(),
    TypeID::get<TF::IRFFT3DOp>(),
    TypeID::get<TF::IgammaOp>(),
    TypeID::get<TF::IgammacOp>(),
    TypeID::get<TF::IgammaGradAOp>(),
    TypeID::get<TF::InplaceAddOp>(),
    TypeID::get<TF::InTopKV2Op>(),
    TypeID::get<TF::InvertOp>(),
    TypeID::get<TF::InvOp>(),
    TypeID::get<TF::KthOrderStatisticOp>(),
    TypeID::get<TF::LRNOp>(),
    TypeID::get<TF::LRNGradOp>(),
    TypeID::get<TF::LeakyReluGradOp>(),
    TypeID::get<TF::LeakyReluOp>(),
    TypeID::get<TF::LeftShiftOp>(),
    TypeID::get<TF::LessOp>(),
    TypeID::get<TF::ListDiffOp>(),
    TypeID::get<TF::LogicalAndOp>(),
    TypeID::get<TF::LogicalNotOp>(),
    TypeID::get<TF::LogOp>(),
    TypeID::get<TF::LowerBoundOp>(),
    TypeID::get<TF::MakeUniqueOp>(),
    TypeID::get<TF::MatMulOp>(),
    TypeID::get<TF::MatrixDiagV3Op>(),
    TypeID::get<TF::MatrixInverseOp>(),
    TypeID::get<TF::MatrixSetDiagV3Op>(),
    TypeID::get<TF::MatrixSolveOp>(),
    TypeID::get<TF::MatrixTriangularSolveOp>(),
    TypeID::get<TF::MaxPool3DGradGradOp>(),
    TypeID::get<TF::MaxPoolGradGradOp>(),
    TypeID::get<TF::MirrorPadOp>(),
    TypeID::get<TF::MirrorPadGradOp>(),
    TypeID::get<TF::MulOp>(),
    TypeID::get<TF::MultinomialOp>(),
    TypeID::get<TF::NdtriOp>(),
    TypeID::get<TF::NegOp>(),
    TypeID::get<TF::NextAfterOp>(),
    TypeID::get<TF::NonMaxSuppressionV4Op>(),
    TypeID::get<TF::NotEqualOp>(),
    TypeID::get<TF::PadOp>(),
    TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::PlaceholderWithDefaultOp>(),
    TypeID::get<TF::PolygammaOp>(),
    TypeID::get<TF::PopulationCountOp>(),
    TypeID::get<TF::PowOp>(),
    // TODO(hinsu): Canonicalize QuantizeAndDequantize and
    // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
    // attributes to operands.
    TypeID::get<TF::QuantizeAndDequantizeOp>(),
    TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV4Op>(),
    TypeID::get<TF::RFFT2DOp>(),
    TypeID::get<TF::RFFT3DOp>(),
    TypeID::get<TF::RGBToHSVOp>(),
    TypeID::get<TF::RandomUniformIntOp>(),
    TypeID::get<TF::RealDivOp>(),
    TypeID::get<TF::ReciprocalGradOp>(),
    TypeID::get<TF::Relu6GradOp>(),
    TypeID::get<TF::ResizeBilinearOp>(),
    TypeID::get<TF::ResizeBilinearGradOp>(),
    TypeID::get<TF::ResizeNearestNeighborOp>(),
    TypeID::get<TF::ResizeNearestNeighborGradOp>(),
    TypeID::get<TF::ReverseSequenceOp>(),
    TypeID::get<TF::RightShiftOp>(),
    TypeID::get<TF::RintOp>(),
    TypeID::get<TF::RollOp>(),
    TypeID::get<TF::RoundOp>(),
    TypeID::get<TF::SelectV2Op>(),
    TypeID::get<TF::SelfAdjointEigV2Op>(),
    TypeID::get<TF::SeluGradOp>(),
    TypeID::get<TF::SeluOp>(),
    TypeID::get<TF::SigmoidGradOp>(),
    TypeID::get<TF::SinOp>(),
    TypeID::get<TF::SoftplusGradOp>(),
    TypeID::get<TF::SoftsignGradOp>(),
    TypeID::get<TF::SoftsignOp>(),
    TypeID::get<TF::SpaceToBatchNDOp>(),
    TypeID::get<TF::SpaceToBatchOp>(),
    TypeID::get<TF::SpaceToDepthOp>(),
    TypeID::get<TF::SparseToDenseOp>(),
    TypeID::get<TF::SquareOp>(),
    TypeID::get<TF::StatelessMultinomialOp>(),
    TypeID::get<TF::StatelessRandomGetAlgOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
    TypeID::get<TF::StatelessRandomNormalOp>(),
    TypeID::get<TF::StatelessRandomNormalV2Op>(),
    TypeID::get<TF::StatelessRandomUniformOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
    TypeID::get<TF::StatelessRandomUniformV2Op>(),
    TypeID::get<TF::StatelessRandomUniformIntOp>(),
    TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
    TypeID::get<TF::StatelessTruncatedNormalOp>(),
    TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
    TypeID::get<TF::SubOp>(),
    TypeID::get<TF::SvdOp>(),
    TypeID::get<TF::TanOp>(),
    TypeID::get<TF::TensorScatterAddOp>(),
    TypeID::get<TF::TensorScatterSubOp>(),
    TypeID::get<TF::TPUEmbeddingActivationsOp>(),
    TypeID::get<TF::TopKUniqueOp>(),
    TypeID::get<TF::TopKWithUniqueOp>(),
    TypeID::get<TF::TransposeOp>(),
    TypeID::get<TF::TridiagonalSolveOp>(),
    TypeID::get<TF::TruncateDivOp>(),
    TypeID::get<TF::TruncatedNormalOp>(),
    TypeID::get<TF::TruncateModOp>(),
    TypeID::get<TF::UnpackOp>(),
    TypeID::get<TF::UpperBoundOp>(),
    TypeID::get<TF::XlaBroadcastHelperOp>(),
    TypeID::get<TF::XlaConvOp>(),
    TypeID::get<TF::XlaConvV2Op>(),
    TypeID::get<TF::XlaDotOp>(),
    TypeID::get<TF::XlaDotV2Op>(),
    TypeID::get<TF::XlaDynamicSliceOp>(),
    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
    TypeID::get<TF::XlaEinsumOp>(),
    TypeID::get<TF::XlaKeyValueSortOp>(),
    TypeID::get<TF::XlaPadOp>(),
    TypeID::get<TF::XlaSortOp>(),
    TypeID::get<TF::XlaSvdOp>(),
  };
  // clang-format on

  // An op with no registered AbstractOperation is unknown to this MLIR
  // context, so it cannot be matched by TypeID.
  auto* abstractOp = op->getAbstractOperation();
  if (!abstractOp) return false;
  return ops->count(abstractOp->typeID);
}
// LINT.ThenChange(:Tf2XlaPreferred)
284
/// List of ops that should use XlaOpKernel legalization only in the case of
/// prefer_tf2xla. All other ops not in this list should use MLIR legalization
/// only or not be legalized by the new bridge.
// LINT.IfChange(Tf2XlaPreferred)
// Returns true if `op` is in the prefer_tf2xla allowlist, i.e. it is only
// legalized via the tf2xla fallback when the user opts in with
// prefer_tf2xla. Membership is checked by the op's registered TypeID.
bool IsOpAllowedTf2XlaPreferred(Operation* op) {
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off
  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AllOp>(),
    TypeID::get<TF::AllToAllOp>(),
    TypeID::get<TF::AnyOp>(),
    TypeID::get<TF::AvgPoolOp>(),
    TypeID::get<TF::AvgPool3DGradOp>(),
    TypeID::get<TF::AvgPoolGradOp>(),
    TypeID::get<TF::BatchToSpaceNDOp>(),
    TypeID::get<TF::BitcastOp>(),
    TypeID::get<TF::BroadcastToOp>(),
    TypeID::get<TF::CollectivePermuteOp>(),
    TypeID::get<TF::ConcatV2Op>(),
    TypeID::get<TF::ConjOp>(),
    TypeID::get<TF::Conv2DOp>(),
    TypeID::get<TF::Conv2DBackpropFilterOp>(),
    TypeID::get<TF::Conv2DBackpropInputOp>(),
    TypeID::get<TF::Conv3DOp>(),
    TypeID::get<TF::Conv3DBackpropFilterV2Op>(),
    TypeID::get<TF::Conv3DBackpropInputV2Op>(),
    TypeID::get<TF::CumprodOp>(),
    TypeID::get<TF::CumsumOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeOp>(),
    TypeID::get<TF::DynamicStitchOp>(),
    TypeID::get<TF::_EagerConstOp>(),
    TypeID::get<TF::EmptyOp>(),
    TypeID::get<TF::ExpandDimsOp>(),
    TypeID::get<TF::FillOp>(),
    TypeID::get<TF::FusedBatchNormOp>(),
    TypeID::get<TF::FusedBatchNormGradOp>(),
    TypeID::get<TF::FusedBatchNormGradV2Op>(),
    TypeID::get<TF::FusedBatchNormGradV3Op>(),
    TypeID::get<TF::FusedBatchNormV2Op>(),
    TypeID::get<TF::FusedBatchNormV3Op>(),
    TypeID::get<TF::GatherNdOp>(),
    TypeID::get<TF::GatherV2Op>(),
    TypeID::get<TF::IdentityOp>(),
    TypeID::get<TF::IdentityNOp>(),
    TypeID::get<TF::InplaceUpdateOp>(),
    TypeID::get<TF::InvertPermutationOp>(),
    TypeID::get<TF::IRFFTOp>(),
    TypeID::get<TF::L2LossOp>(),
    TypeID::get<TF::LegacyCallOp>(),
    TypeID::get<TF::LinSpaceOp>(),
    TypeID::get<TF::MatrixDiagPartV3Op>(),
    TypeID::get<TF::MaxOp>(),
    TypeID::get<TF::MaximumOp>(),
    TypeID::get<TF::MaxPoolOp>(),
    TypeID::get<TF::MaxPool3DOp>(),
    TypeID::get<TF::MaxPoolGradOp>(),
    TypeID::get<TF::MeanOp>(),
    TypeID::get<TF::MinOp>(),
    TypeID::get<TF::MinimumOp>(),
    TypeID::get<TF::MulNoNanOp>(),
    TypeID::get<TF::OneHotOp>(),
    TypeID::get<TF::OnesLikeOp>(),
    TypeID::get<TF::PackOp>(),
    TypeID::get<TF::PadV2Op>(),
    TypeID::get<TF::ParallelDynamicStitchOp>(),
    TypeID::get<TF::PartitionedCallOp>(),
    TypeID::get<TF::ProdOp>(),
    TypeID::get<TF::QrOp>(),
    TypeID::get<TF::RandomStandardNormalOp>(),
    TypeID::get<TF::RandomUniformOp>(),
    TypeID::get<TF::RangeOp>(),
    TypeID::get<TF::ReshapeOp>(),
    TypeID::get<TF::ReverseV2Op>(),
    TypeID::get<TF::RFFTOp>(),
    TypeID::get<TF::RsqrtGradOp>(),
    TypeID::get<TF::ScatterNdOp>(),
    TypeID::get<TF::ShapeOp>(),
    TypeID::get<TF::SinhOp>(),
    TypeID::get<TF::SizeOp>(),
    TypeID::get<TF::SliceOp>(),
    TypeID::get<TF::SoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SoftplusOp>(),
    TypeID::get<TF::SparseMatMulOp>(),
    TypeID::get<TF::SparseSoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SplitOp>(),
    TypeID::get<TF::SplitVOp>(),
    TypeID::get<TF::SqueezeOp>(),
    TypeID::get<TF::StatefulPartitionedCallOp>(),
    TypeID::get<TF::StopGradientOp>(),
    TypeID::get<TF::StridedSliceOp>(),
    TypeID::get<TF::StridedSliceGradOp>(),
    TypeID::get<TF::SumOp>(),
    TypeID::get<TF::TensorScatterUpdateOp>(),
    TypeID::get<TF::TileOp>(),
    TypeID::get<TF::TopKV2Op>(),
    TypeID::get<TF::_UnaryOpsCompositionOp>(),
    TypeID::get<TF::UnsortedSegmentMaxOp>(),
    TypeID::get<TF::UnsortedSegmentMinOp>(),
    TypeID::get<TF::UnsortedSegmentProdOp>(),
    TypeID::get<TF::UnsortedSegmentSumOp>(),
    TypeID::get<TF::XdivyOp>(),
    TypeID::get<TF::XlaAllReduceOp>(),
    TypeID::get<TF::XlaGatherOp>(),
    TypeID::get<TF::XlaReplicaIdOp>(),
    TypeID::get<TF::Xlog1pyOp>(),
    TypeID::get<TF::ZerosLikeOp>(),

    // XlaOpKernel makes use of compiler options which we don't feed in the
    // fallback.
    // TypeID::get<TF::FakeQuantWithMinMaxVarsOp>(),
  };
  // clang-format on
  // An op with no registered AbstractOperation is unknown to this MLIR
  // context, so it cannot be matched by TypeID.
  auto* abstractOp = op->getAbstractOperation();
  if (!abstractOp) return false;
  return ops->count(abstractOp->typeID);
}
// LINT.ThenChange()
404
IsOpAllowedForTesting(Operation * op)405 bool IsOpAllowedForTesting(Operation* op) {
406 // clang-format off
407 static auto* ops =
408 new llvm::SmallDenseSet<mlir::TypeID, 16>{
409 // Op used to verify handling of XlaExpression of kind constant.
410 TypeID::get<TF::ConstOp>(),
411 };
412 // clang-format on
413 auto* abstractOp = op->getAbstractOperation();
414 if (!abstractOp) return false;
415 return ops->count(abstractOp->typeID);
416 }
417
418 namespace {
419
// Local shorthand for TensorFlow's gtl::InlinedVector, used for the small
// per-operand buffers built in LegalizeOp below.
template <typename T, size_t N>
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>;  // non-absl ok
422
CreateDeviceMgr(const std::string & device_type)423 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
424 const std::string& device_type) {
425 // Register compilation kernels for all registered XLA backends.
426 tensorflow::XlaOpRegistry::RegisterCompilationKernels();
427
428 auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
429 tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
430 return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
431 }
432
// Rewrites a single TensorFlow op into equivalent mhlo ops by running the
// op's registered tf2xla OpKernel against an MlirHloBuilder that emits MLIR
// instead of an XLA computation.
class Tf2XlaRewriter {
 public:
  // Legalizes `op` in place using `rewriter`, compiling with the tf2xla
  // kernel registered for `device_type`. Returns failure if the op could not
  // be legalized; a remark (not an error) is emitted for unsupported cases.
  static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
                                 const std::string& device_type) {
    Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
    return tf2xla_rewriter.LegalizeOp();
  }

 private:
  Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
                 const std::string& device_type)
      : op_(op),
        device_type_(device_type),
        rewriter_(rewriter),
        hlo_builder_(op->getName().getStringRef().str(), rewriter_,
                     op->getLoc()),
        context_(nullptr) {}

  // Releases the extra reference taken in PrepareParams (the step container
  // holds its own reference for the duration of kernel execution).
  ~Tf2XlaRewriter() {
    if (context_) context_->Unref();
  }

  // Prepares OpKernelContext params common to all the ops.
  // Emits an error on failure.
  LogicalResult PrepareParams();

  // Tries to legalize the specified TensorFlow op, if supported.
  //
  // Emits an error and returns failure if an error is encountered during
  // conversion. Note that success return value doesn't mean successful
  // legalization.
  LogicalResult LegalizeOp();

  // Converts the given operand to expression of kind kConstant or kXlaOp.
  // Emits a remark and returns expression of kind kInvalid on failure.
  tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);

  Operation* op_;
  std::string device_type_;

  PatternRewriter& rewriter_;
  // Note: hlo_builder_ is declared after rewriter_ because its constructor
  // uses rewriter_ (member initialization order matters here).
  ::xla::MlirHloBuilder hlo_builder_;
  tensorflow::OpOrArgLocNameMapper name_mapper_;

  tensorflow::XlaContext* context_;  // Ref-counted.

  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
  tensorflow::Device* device_;  // Owned by device_mgr_;
  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
  tensorflow::OpKernelContext::Params params_;
};
486
// Sets up the TensorFlow runtime state (XlaContext, device manager, step
// container, function library) needed to execute a single tf2xla kernel.
LogicalResult Tf2XlaRewriter::PrepareParams() {
  // XlaCompiler within the context is only used by the functional ops to
  // compile functions. We are not handling those at the moment so XlaCompiler
  // is not required.
  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
                                        /*graph=*/nullptr);
  // Take an extra reference: the step container below also takes ownership of
  // one reference, and the destructor releases ours.
  context_->Ref();

  device_mgr_ = CreateDeviceMgr(device_type_);
  if (!device_mgr_) return failure();

  // Type of params_.device is DeviceBase* so store it as Device* to access
  // derived class method.
  device_ = device_mgr_->ListDevices().front();
  params_.device = device_;
  params_.resource_manager = device_->resource_manager();

  // Resources are cleared at the time of device manager destruction so pass
  // no-op cleanup function.
  auto cleanup = [](const std::string& name) {};
  // Use step_id zero as we only have a single context concurrently;
  // concurrently running MLIR functions each create their own device.
  step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
      /*step_id=*/0, cleanup);
  tensorflow::Status status = step_container_->Create(
      device_->resource_manager(),
      tensorflow::XlaContext::kXlaContextResourceName, context_);
  if (!status.ok()) {
    return emitRemark(op_->getLoc())
           << "failed to create XlaContext resource: " << status.ToString();
  }
  params_.step_container = step_container_.get();

  // The graph producer version is needed so kernels can apply
  // version-dependent behavior.
  tensorflow::StatusOr<int64_t> version_or =
      tensorflow::GetTfGraphProducerVersion(
          op_->getParentOfType<mlir::ModuleOp>());
  if (!version_or.ok()) {
    return emitError(op_->getLoc()) << version_or.status().ToString();
  }

  // An empty function library suffices: functional ops are not supported here.
  flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
  pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
      version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
  params_.function_library = pflr_->GetFLR(device_->name());
  return success();
}
535
// Executes the op's tf2xla kernel against hlo_builder_ and replaces the op's
// results with the MLIR values the builder produced. Unsupported cases emit a
// remark and return failure so other patterns may still apply.
LogicalResult Tf2XlaRewriter::LegalizeOp() {
  // Only static shaped operands are supported in XLA builders for now.
  for (Type ty : op_->getOperandTypes()) {
    auto ranked_ty = ty.dyn_cast<ShapedType>();
    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
      return op_->emitRemark()
             << "lowering requires static shaped tensor operands";
    }
  }

  // Symbol references (e.g. function attributes) cannot be carried through
  // the NodeDef round trip below.
  for (const auto& attr : op_->getAttrs()) {
    if (attr.second.isa<SymbolRefAttr>()) {
      return op_->emitRemark()
             << "ops with symbol references are not supported";
    }
  }

  // Round-trip the op through a NodeDef, which is what the tf2xla kernel
  // machinery consumes.
  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
      op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
  if (!nodedef_or.ok()) {
    return op_->emitRemark() << "failed to convert op to NodeDef: "
                             << nodedef_or.status().ToString();
  }

  if (failed(PrepareParams())) return failure();

  std::shared_ptr<const tensorflow::NodeProperties> props;
  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
      *nodedef_or.ValueOrDie(),
      params_.function_library->GetFunctionLibraryDefinition(), &props);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create NodeProperties: " << status.ToString();
  }
  tensorflow::OpKernel* op_kernel_raw;
  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create tf2xla kernel: " << status.ToString();
  }
  // Transfer ownership of the kernel to a local smart pointer.
  auto op_kernel = absl::WrapUnique(op_kernel_raw);

  // Some kernel inputs must be compile-time constants; collect their indices
  // so we can verify the corresponding operands are constants.
  std::vector<int> required_constants;
  status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
      *op_kernel, &required_constants);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to compute required constants: " << status.ToString();
  }
  llvm::SmallDenseSet<int, 4> required_consts;
  required_consts.insert(required_constants.begin(), required_constants.end());

  // TensorValue in inputs are backed by tensors which in turn depend on
  // expressions. So, pre-allocate them to the required size.
  InlinedVector<tensorflow::XlaExpression, 4> expressions;
  InlinedVector<tensorflow::Tensor, 4> tensors;
  InlinedVector<tensorflow::TensorValue, 4> inputs;
  expressions.reserve(op_->getNumOperands());
  tensors.reserve(op_->getNumOperands());
  inputs.reserve(op_->getNumOperands());

  // Prepare the list of Tensor inputs for the kernel.
  for (auto it : llvm::enumerate(op_->getOperands())) {
    Value operand = it.value();
    size_t idx = it.index();

    tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
    tensorflow::XlaExpression::Kind kind = expr.kind();
    if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
    if (required_consts.count(idx) &&
        kind != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark()
             << "lowering requires operand #" << idx << " to be a constant";
    }
    expressions.push_back(expr);

    if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
      return op_->emitRemark()
             << "skipping legalization due to unsupported type "
             << operand.getType();
    }

    auto shape_or = expr.GetShape();
    if (!shape_or.ok()) {
      return op_->emitRemark()
             << "failed to get shape for expression. " << expr.HumanString();
    }

    // The tensor only carries the expression; the kernel reads/writes
    // XlaExpressions, not real data.
    tensors.emplace_back(
        device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
        shape_or.ValueOrDie());
    tensorflow::Tensor& tensor = tensors.back();
    tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
    inputs.emplace_back(&tensor);
  }

  params_.inputs = &inputs;
  params_.op_kernel = op_kernel.get();
  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
      op_->getNumResults());
  params_.output_attr_array = output_attr.data();

  hlo_builder_.setInsertionPoint(op_);
  hlo_builder_.SetLocation(op_->getLoc());

  // Execute the kernel.
  tensorflow::OpKernelContext op_context(&params_, op_->getNumResults());
  device_->Compute(params_.op_kernel, &op_context);

  // Merge any deferred builder error into the kernel's status.
  status = op_context.status();
  status.Update(hlo_builder_.GetCurrentStatus());
  if (!status.ok()) {
    return op_->emitRemark()
           << "compilation to HLO failed: " << status.ToString();
  }

  // Replace uses of old results using the corresponding value after the
  // lowering.
  llvm::SmallVector<Value, 2> values;
  values.reserve(op_->getNumResults());
  for (int i = 0, e = op_->getNumResults(); i < e; i++) {
    tensorflow::Tensor* output = op_context.mutable_output(i);
    const tensorflow::XlaExpression* expr =
        tensorflow::XlaExpression::CastExpressionFromTensor(*output);
    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp &&
        expr->kind() != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark(
          "expects XlaExpression of kind kXlaOp or kConstant in compiled "
          "output");
    }
    mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_));
    mlir::OpResult old_result = op_->getResult(i);
    // Insert a tensor.cast if the lowered value's type (e.g. a refined shape)
    // differs from the original result type.
    if (value.getType() != old_result.getType()) {
      value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
                                                        value);
    }
    values.push_back(value);
  }
  rewriter_.replaceOp(op_, values);
  return success();
}
678
GetExprForOperand(Value operand,Operation * op)679 tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
680 Operation* op) {
681 ElementsAttr const_attr;
682 auto defining_op = operand.getDefiningOp();
683 if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
684 tensorflow::Tensor tensor;
685 auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
686 if (!status.ok()) {
687 op->emitRemark() << "skipping legalization due to failed const conversion"
688 << status.ToString();
689 return tensorflow::XlaExpression::Invalid();
690 }
691 return tensorflow::XlaExpression::Constant(tensor);
692 }
693
694 // Skip this op if XLA doesn't support this operand type.
695 auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
696 if (!xla_op_or.ok()) {
697 op->emitRemark() << "skipping legalization due to "
698 << xla_op_or.status().ToString();
699 return tensorflow::XlaExpression::Invalid();
700 }
701 ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
702
703 tensorflow::DataType dtype;
704 auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
705 if (!status.ok()) {
706 op->emitRemark() << "skipping legalization due to " << status.ToString();
707 return tensorflow::XlaExpression::Invalid();
708 }
709 return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
710 }
711
712 class Tf2XlaRewritePattern : public RewritePattern {
713 public:
Tf2XlaRewritePattern(MLIRContext * ctx,const std::string & device_type,bool prefer_tf2xla,bool legalize_test_only_ops)714 explicit Tf2XlaRewritePattern(MLIRContext* ctx,
715 const std::string& device_type,
716 bool prefer_tf2xla, bool legalize_test_only_ops)
717 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
718 device_type_(device_type),
719 prefer_tf2xla_(prefer_tf2xla),
720 legalize_test_only_ops_(legalize_test_only_ops) {}
721
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const722 LogicalResult matchAndRewrite(Operation* op,
723 PatternRewriter& rewriter) const override {
724 if (!(IsOpAllowedTf2XlaFallback(op) ||
725 (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)) ||
726 (legalize_test_only_ops_ && IsOpAllowedForTesting(op))))
727 return failure();
728 return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
729 }
730
731 private:
732 std::string device_type_;
733 bool prefer_tf2xla_;
734 bool legalize_test_only_ops_;
735 };
736
// Function pass that legalizes TensorFlow ops to mhlo by running each
// allowlisted op's tf2xla kernel through the Tf2XlaRewritePattern above.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
 public:
  LegalizeTF() = default;

  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<MhloDialect>();
  }

  explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
    device_type_ = device_type.str();
    prefer_tf2xla_ = prefer_tf2xla;
  }

  // Intentionally empty: Pass Option<> members are not copyable, and MLIR
  // copies passes only before options are parsed.
  LegalizeTF(const LegalizeTF&) {}

  StringRef getArgument() const final { return "xla-legalize-tf-with-tf2xla"; }
  StringRef getDescription() const final {
    return "Legalize from TensorFlow to the HLO dialect using tf2xla kernels";
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns(&getContext());
    patterns.insert<Tf2XlaRewritePattern>(
        &getContext(), device_type_, prefer_tf2xla_, legalize_test_only_ops_);
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
      signalPassFailure();
  }

 private:
  // TODO(hinsu): Support finer grained device type assignment instead of a
  // global device type for all TensorFlow ops.
  Option<std::string> device_type_{
      *this, "device-type",
      llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
  Option<bool> prefer_tf2xla_{
      *this,
      "prefer-tf2xla",
      llvm::cl::desc("Enable legalization when it is not in the list of "
                     "MLIR-legalized ops."),
  };
  Option<bool> legalize_test_only_ops_{
      *this,
      "legalize-test-only-ops",
      llvm::cl::desc("Enable tf2xla legalizations for some ops that are "
                     "enabled only for testing."),
  };
};
785
// Registers the pass with the global pass registry under the argument and
// description returned by LegalizeTF::getArgument()/getDescription().
static PassRegistration<LegalizeTF> pass;
787
788 } // end namespace
789
// Adds the tf2xla fallback pattern to `patterns` so callers can combine it
// with other legalization patterns. Test-only op legalization is always off
// in this entry point.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                          OwningRewritePatternList& patterns,
                                          MLIRContext* ctx,
                                          bool prefer_tf2xla) {
  patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla,
                                        /*legalize_test_only_ops=*/false);
}
797
// Factory for the standalone pass; `device_type` selects which XLA backend's
// tf2xla kernels are used.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
    llvm::StringRef device_type, bool prefer_tf2xla) {
  return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
}
802
803 } // end namespace mhlo
804 } // end namespace mlir
805