1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
29 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/Diagnostics.h" // from @llvm-project
34 #include "mlir/IR/Location.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/PatternMatch.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
48 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
49 #include "tensorflow/compiler/mlir/xla/transforms/tf_xla_passes_detail.h"
50 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
51 #include "tensorflow/compiler/tf2xla/xla_context.h"
52 #include "tensorflow/compiler/tf2xla/xla_expression.h"
53 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
54 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
55 #include "tensorflow/compiler/xla/client/xla_builder.h"
56 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
57 #include "tensorflow/core/common_runtime/device.h"
58 #include "tensorflow/core/common_runtime/device_factory.h"
59 #include "tensorflow/core/common_runtime/device_mgr.h"
60 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
61 #include "tensorflow/core/framework/allocator.h"
62 #include "tensorflow/core/framework/function.h"
63 #include "tensorflow/core/framework/function.pb.h"
64 #include "tensorflow/core/framework/node_properties.h"
65 #include "tensorflow/core/framework/op.h"
66 #include "tensorflow/core/framework/op_kernel.h"
67 #include "tensorflow/core/framework/resource_mgr.h"
68 #include "tensorflow/core/framework/tensor.h"
69 #include "tensorflow/core/framework/types.h"
70 #include "tensorflow/core/framework/types.pb.h"
71 #include "tensorflow/core/platform/env.h"
72 #include "tensorflow/core/platform/status.h"
73 #include "tensorflow/core/protobuf/config.pb.h"
74 #include "tensorflow/core/public/session_options.h"
75 #include "tensorflow/stream_executor/lib/statusor.h"
76 #include "tensorflow/stream_executor/stream_executor.h"
77
78 namespace mlir {
79 namespace mhlo {
80
81 // LINT.IfChange
// Returns true if `op` is a TF dialect op that is always allowed to be
// legalized via the tf2xla fallback (XlaOpKernel) path, regardless of the
// prefer_tf2xla setting. Ops not registered in the TF dialect return false.
bool IsOpAllowedTf2XlaFallback(Operation* op) {
  // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
  // building valid MLIR using MlirHloBuilder.
  // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
  // all tf2xla kernels.
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off

  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AcoshOp>(),
    TypeID::get<TF::AcosOp>(),
    TypeID::get<TF::AddNOp>(),
    TypeID::get<TF::AddV2Op>(),
    TypeID::get<TF::AngleOp>(),
    TypeID::get<TF::AdjustContrastv2Op>(),
    TypeID::get<TF::AdjustHueOp>(),
    TypeID::get<TF::AdjustSaturationOp>(),
    TypeID::get<TF::ApproximateEqualOp>(),
    TypeID::get<TF::ArgMaxOp>(),
    TypeID::get<TF::ArgMinOp>(),
    TypeID::get<TF::AsinhOp>(),
    TypeID::get<TF::AsinOp>(),
    TypeID::get<TF::Atan2Op>(),
    TypeID::get<TF::AtanhOp>(),
    TypeID::get<TF::BatchMatMulV2Op>(),
    TypeID::get<TF::BatchMatMulV3Op>(),
    TypeID::get<TF::BatchToSpaceOp>(),
    TypeID::get<TF::BesselI0eOp>(),
    TypeID::get<TF::BesselI1eOp>(),
    TypeID::get<TF::BetaincOp>(),
    TypeID::get<TF::BiasAddOp>(),
    TypeID::get<TF::BitwiseAndOp>(),
    TypeID::get<TF::BitwiseOrOp>(),
    TypeID::get<TF::BitwiseXorOp>(),
    TypeID::get<TF::BucketizeOp>(),
    TypeID::get<TF::CastOp>(),
    TypeID::get<TF::ClipByValueOp>(),
    TypeID::get<TF::CholeskyOp>(),
    TypeID::get<TF::ComplexAbsOp>(),
    TypeID::get<TF::ConjugateTransposeOp>(),
    TypeID::get<TF::CoshOp>(),
    TypeID::get<TF::CrossOp>(),
    TypeID::get<TF::DataFormatDimMapOp>(),
    TypeID::get<TF::DataFormatVecPermuteOp>(),
    TypeID::get<TF::DepthToSpaceOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
    TypeID::get<TF::DiagOp>(),
    TypeID::get<TF::DigammaOp>(),
    TypeID::get<TF::DivNoNanOp>(),
    TypeID::get<TF::EluGradOp>(),
    TypeID::get<TF::EluOp>(),
    TypeID::get<TF::EnsureShapeOp>(),
    TypeID::get<TF::EqualOp>(),
    TypeID::get<TF::ErfcOp>(),
    TypeID::get<TF::ErfinvOp>(),
    TypeID::get<TF::ErfOp>(),
    TypeID::get<TF::ExtractImagePatchesOp>(),
    TypeID::get<TF::FFT2DOp>(),
    TypeID::get<TF::FFT3DOp>(),
    TypeID::get<TF::FFTOp>(),
    TypeID::get<TF::FakeParamOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsPerChannelOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsPerChannelGradientOp>(),
    TypeID::get<TF::FloorDivOp>(),
    TypeID::get<TF::FloorModOp>(),
    TypeID::get<TF::GreaterOp>(),
    TypeID::get<TF::HSVToRGBOp>(),
    TypeID::get<TF::IFFT2DOp>(),
    TypeID::get<TF::IFFT3DOp>(),
    TypeID::get<TF::IRFFT2DOp>(),
    TypeID::get<TF::IRFFT3DOp>(),
    TypeID::get<TF::IgammaOp>(),
    TypeID::get<TF::IgammacOp>(),
    TypeID::get<TF::IgammaGradAOp>(),
    TypeID::get<TF::InplaceAddOp>(),
    TypeID::get<TF::InTopKV2Op>(),
    TypeID::get<TF::InvertOp>(),
    TypeID::get<TF::InvOp>(),
    TypeID::get<TF::KthOrderStatisticOp>(),
    TypeID::get<TF::LRNOp>(),
    TypeID::get<TF::LRNGradOp>(),
    TypeID::get<TF::LeakyReluGradOp>(),
    TypeID::get<TF::LeakyReluOp>(),
    TypeID::get<TF::LeftShiftOp>(),
    TypeID::get<TF::LessOp>(),
    TypeID::get<TF::ListDiffOp>(),
    TypeID::get<TF::LogicalAndOp>(),
    TypeID::get<TF::LogicalNotOp>(),
    TypeID::get<TF::LogOp>(),
    TypeID::get<TF::LowerBoundOp>(),
    TypeID::get<TF::MakeUniqueOp>(),
    TypeID::get<TF::MatMulOp>(),
    TypeID::get<TF::MatrixDiagV3Op>(),
    TypeID::get<TF::MatrixInverseOp>(),
    TypeID::get<TF::MatrixSetDiagV3Op>(),
    TypeID::get<TF::MatrixSolveOp>(),
    TypeID::get<TF::MatrixTriangularSolveOp>(),
    TypeID::get<TF::MaxPool3DGradGradOp>(),
    TypeID::get<TF::MaxPoolGradGradOp>(),
    TypeID::get<TF::MirrorPadOp>(),
    TypeID::get<TF::MirrorPadGradOp>(),
    TypeID::get<TF::MulOp>(),
    TypeID::get<TF::MultinomialOp>(),
    TypeID::get<TF::NdtriOp>(),
    TypeID::get<TF::NegOp>(),
    TypeID::get<TF::NextAfterOp>(),
    TypeID::get<TF::NonMaxSuppressionV4Op>(),
    TypeID::get<TF::NotEqualOp>(),
    TypeID::get<TF::PadOp>(),
    TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::PlaceholderWithDefaultOp>(),
    TypeID::get<TF::PolygammaOp>(),
    TypeID::get<TF::PopulationCountOp>(),
    TypeID::get<TF::PowOp>(),
    // TODO(hinsu): Canonicalize QuantizeAndDequantize and
    // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
    // attributes to operands.
    TypeID::get<TF::QuantizeAndDequantizeOp>(),
    TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV4Op>(),
    TypeID::get<TF::RFFT2DOp>(),
    TypeID::get<TF::RFFT3DOp>(),
    TypeID::get<TF::RGBToHSVOp>(),
    TypeID::get<TF::RandomUniformIntOp>(),
    TypeID::get<TF::RealDivOp>(),
    TypeID::get<TF::ReciprocalGradOp>(),
    TypeID::get<TF::Relu6GradOp>(),
    TypeID::get<TF::ResizeBilinearOp>(),
    TypeID::get<TF::ResizeBilinearGradOp>(),
    TypeID::get<TF::ResizeNearestNeighborOp>(),
    TypeID::get<TF::ResizeNearestNeighborGradOp>(),
    TypeID::get<TF::ReverseSequenceOp>(),
    TypeID::get<TF::RightShiftOp>(),
    TypeID::get<TF::RintOp>(),
    TypeID::get<TF::RollOp>(),
    TypeID::get<TF::RoundOp>(),
    TypeID::get<TF::SelectV2Op>(),
    TypeID::get<TF::SelfAdjointEigV2Op>(),
    TypeID::get<TF::SeluGradOp>(),
    TypeID::get<TF::SeluOp>(),
    TypeID::get<TF::SigmoidGradOp>(),
    TypeID::get<TF::SinOp>(),
    TypeID::get<TF::SoftplusGradOp>(),
    TypeID::get<TF::SoftsignGradOp>(),
    TypeID::get<TF::SoftsignOp>(),
    TypeID::get<TF::SpaceToBatchNDOp>(),
    TypeID::get<TF::SpaceToBatchOp>(),
    TypeID::get<TF::SpaceToDepthOp>(),
    TypeID::get<TF::SparseToDenseOp>(),
    TypeID::get<TF::SquareOp>(),
    TypeID::get<TF::StatelessMultinomialOp>(),
    TypeID::get<TF::StatelessParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::StatelessRandomGetAlgOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
    TypeID::get<TF::StatelessRandomNormalOp>(),
    TypeID::get<TF::StatelessRandomNormalV2Op>(),
    TypeID::get<TF::StatelessRandomUniformOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
    TypeID::get<TF::StatelessRandomUniformV2Op>(),
    TypeID::get<TF::StatelessRandomUniformIntOp>(),
    TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
    TypeID::get<TF::StatelessTruncatedNormalOp>(),
    TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
    TypeID::get<TF::SubOp>(),
    TypeID::get<TF::SvdOp>(),
    TypeID::get<TF::TanOp>(),
    TypeID::get<TF::TensorScatterAddOp>(),
    TypeID::get<TF::TensorScatterSubOp>(),
    TypeID::get<TF::TPUEmbeddingActivationsOp>(),
    TypeID::get<TF::TopKUniqueOp>(),
    TypeID::get<TF::TopKWithUniqueOp>(),
    TypeID::get<TF::TransposeOp>(),
    TypeID::get<TF::TridiagonalSolveOp>(),
    TypeID::get<TF::TridiagonalMatMulOp>(),
    TypeID::get<TF::TruncateDivOp>(),
    TypeID::get<TF::TruncatedNormalOp>(),
    TypeID::get<TF::TruncateModOp>(),
    TypeID::get<TF::UniqueOp>(),
    TypeID::get<TF::UnpackOp>(),
    TypeID::get<TF::UpperBoundOp>(),
    TypeID::get<TF::XlaBroadcastHelperOp>(),
    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
    TypeID::get<TF::XlaKeyValueSortOp>(),
    TypeID::get<TF::XlaPadOp>(),
    TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
    TypeID::get<TF::XlaSvdOp>(),
  };
  // clang-format on

  // Unregistered (unknown) ops cannot be matched by TypeID, so they are never
  // allowed through the fallback.
  auto abstractOp = op->getRegisteredInfo();
  if (!abstractOp) return false;
  return ops->count(abstractOp->getTypeID());
}
283 // LINT.ThenChange(:Tf2XlaPreferred)
284
285 /// List of ops that should use XlaOpKernel legalization only in the case of
286 /// prefer_tf2xla. All other ops not in this list should use MLIR legalization
287 /// only or not be legalized by the new bridge.
288 // LINT.IfChange(Tf2XlaPreferred)
// Returns true if `op` is in the set of ops that use the tf2xla fallback only
// when prefer_tf2xla is set (see the class comment above). Unregistered ops
// return false.
bool IsOpAllowedTf2XlaPreferred(Operation* op) {
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off
  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AllOp>(),
    TypeID::get<TF::AllToAllOp>(),
    TypeID::get<TF::AnyOp>(),
    TypeID::get<TF::AvgPoolOp>(),
    TypeID::get<TF::AvgPool3DGradOp>(),
    TypeID::get<TF::AvgPoolGradOp>(),
    TypeID::get<TF::BatchToSpaceNDOp>(),
    TypeID::get<TF::BitcastOp>(),
    TypeID::get<TF::BroadcastToOp>(),
    TypeID::get<TF::CollectivePermuteOp>(),
    TypeID::get<TF::ConcatV2Op>(),
    TypeID::get<TF::ConjOp>(),
    TypeID::get<TF::Conv2DOp>(),
    TypeID::get<TF::Conv2DBackpropFilterOp>(),
    TypeID::get<TF::Conv2DBackpropInputOp>(),
    TypeID::get<TF::Conv3DOp>(),
    TypeID::get<TF::Conv3DBackpropFilterV2Op>(),
    TypeID::get<TF::Conv3DBackpropInputV2Op>(),
    TypeID::get<TF::CumprodOp>(),
    TypeID::get<TF::CumsumOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeOp>(),
    TypeID::get<TF::DynamicStitchOp>(),
    TypeID::get<TF::_EagerConstOp>(),
    TypeID::get<TF::EmptyOp>(),
    TypeID::get<TF::ExpandDimsOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsOp>(),
    TypeID::get<TF::FillOp>(),
    TypeID::get<TF::FusedBatchNormOp>(),
    TypeID::get<TF::FusedBatchNormGradOp>(),
    TypeID::get<TF::FusedBatchNormGradV2Op>(),
    TypeID::get<TF::FusedBatchNormGradV3Op>(),
    TypeID::get<TF::FusedBatchNormV2Op>(),
    TypeID::get<TF::FusedBatchNormV3Op>(),
    TypeID::get<TF::GatherNdOp>(),
    TypeID::get<TF::GatherV2Op>(),
    TypeID::get<TF::IdentityOp>(),
    TypeID::get<TF::IdentityNOp>(),
    TypeID::get<TF::InplaceUpdateOp>(),
    TypeID::get<TF::InvertPermutationOp>(),
    TypeID::get<TF::IRFFTOp>(),
    TypeID::get<TF::L2LossOp>(),
    TypeID::get<TF::LegacyCallOp>(),
    TypeID::get<TF::LinSpaceOp>(),
    TypeID::get<TF::MatrixDiagPartV3Op>(),
    TypeID::get<TF::MaxOp>(),
    TypeID::get<TF::MaximumOp>(),
    TypeID::get<TF::MaxPoolOp>(),
    TypeID::get<TF::MaxPool3DOp>(),
    TypeID::get<TF::MaxPoolGradOp>(),
    TypeID::get<TF::MeanOp>(),
    TypeID::get<TF::MinOp>(),
    TypeID::get<TF::MinimumOp>(),
    TypeID::get<TF::MulNoNanOp>(),
    TypeID::get<TF::OneHotOp>(),
    TypeID::get<TF::OnesLikeOp>(),
    TypeID::get<TF::PackOp>(),
    TypeID::get<TF::PadV2Op>(),
    TypeID::get<TF::ParallelDynamicStitchOp>(),
    TypeID::get<TF::PartitionedCallOp>(),
    TypeID::get<TF::ProdOp>(),
    TypeID::get<TF::QrOp>(),
    TypeID::get<TF::RandomStandardNormalOp>(),
    TypeID::get<TF::RangeOp>(),
    TypeID::get<TF::ReshapeOp>(),
    TypeID::get<TF::ReverseV2Op>(),
    TypeID::get<TF::RFFTOp>(),
    TypeID::get<TF::RsqrtGradOp>(),
    TypeID::get<TF::ScatterNdOp>(),
    TypeID::get<TF::ShapeOp>(),
    TypeID::get<TF::SinhOp>(),
    TypeID::get<TF::SizeOp>(),
    TypeID::get<TF::SliceOp>(),
    TypeID::get<TF::SoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SoftplusOp>(),
    TypeID::get<TF::SparseMatMulOp>(),
    TypeID::get<TF::SparseSoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SplitOp>(),
    TypeID::get<TF::SplitVOp>(),
    TypeID::get<TF::SqueezeOp>(),
    TypeID::get<TF::StatelessParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::StatefulPartitionedCallOp>(),
    TypeID::get<TF::StopGradientOp>(),
    TypeID::get<TF::StridedSliceGradOp>(),
    TypeID::get<TF::SumOp>(),
    TypeID::get<TF::TensorScatterUpdateOp>(),
    TypeID::get<TF::TileOp>(),
    TypeID::get<TF::TopKV2Op>(),
    TypeID::get<TF::_UnaryOpsCompositionOp>(),
    TypeID::get<TF::UnsortedSegmentMaxOp>(),
    TypeID::get<TF::UnsortedSegmentMinOp>(),
    TypeID::get<TF::UnsortedSegmentProdOp>(),
    TypeID::get<TF::UnsortedSegmentSumOp>(),
    TypeID::get<TF::XdivyOp>(),
    TypeID::get<TF::XlaAllReduceOp>(),
    TypeID::get<TF::XlaGatherOp>(),
    TypeID::get<TF::Xlog1pyOp>(),
    TypeID::get<TF::ZerosLikeOp>(),
  };
  // clang-format on

  // Unregistered (unknown) ops have no TypeID to match against.
  auto abstractOp = op->getRegisteredInfo();
  if (!abstractOp) return false;
  return ops->count(abstractOp->getTypeID());
}
398 // LINT.ThenChange()
399
IsOpAllowedForTesting(Operation * op)400 bool IsOpAllowedForTesting(Operation* op) {
401 // clang-format off
402 static auto* ops =
403 new llvm::SmallDenseSet<mlir::TypeID, 16>{
404 // Op used to verify handling of XlaExpression of kind constant.
405 TypeID::get<TF::ConstOp>(),
406 };
407 // clang-format on
408 auto abstractOp = op->getRegisteredInfo();
409 if (!abstractOp) return false;
410 return ops->count(abstractOp->getTypeID());
411 }
412
413 // List of ops that require falling back to XlaOpKernel legalizations and also
414 // require the ability to create functions.
IsOpAllowedTf2XlaFallbackAndCreateFunctions(Operation * op)415 bool IsOpAllowedTf2XlaFallbackAndCreateFunctions(Operation* op) {
416 static auto* ops = new llvm::SmallDenseSet<mlir::TypeID, 16>{
417 TypeID::get<TF::ApproxTopKOp>(),
418 };
419 auto abstractOp = op->getRegisteredInfo();
420 if (!abstractOp) return false;
421 return ops->count(abstractOp->getTypeID());
422 }
423
424 namespace {
425
426 template <typename T, size_t N>
427 using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
428
CreateDeviceMgr(const std::string & device_type)429 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
430 const std::string& device_type) {
431 // Register compilation kernels for all registered XLA backends.
432 tensorflow::XlaOpRegistry::RegisterCompilationKernels();
433
434 auto device = std::make_unique<tensorflow::XlaCompilationDevice>(
435 tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
436 return std::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
437 }
438
// Rewrites a single TF dialect op into mhlo by running its tf2xla
// (XlaOpKernel) implementation against an MlirHloBuilder. One instance is
// created per rewritten op; all state below is scoped to that one rewrite.
class Tf2XlaRewriter {
 public:
  // Entry point: legalizes `op` in place using `rewriter`. `device_type`
  // selects the XLA backend kernels; `is_module_pass` allows the underlying
  // builder to create functions (only safe from a module-level pass).
  static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
                                 const std::string& device_type,
                                 bool is_module_pass) {
    Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type, is_module_pass);
    return tf2xla_rewriter.LegalizeOp();
  }

 private:
  Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
                 const std::string& device_type, bool is_module_pass)
      : op_(op),
        device_type_(device_type),
        rewriter_(rewriter),
        hlo_builder_(op->getName().getStringRef().str(), rewriter_,
                     op->getLoc(), /*build_functions=*/is_module_pass),
        context_(nullptr) {}

  // Releases the reference taken in PrepareParams(); the XlaContext may still
  // be kept alive by the step container's resource manager entry.
  ~Tf2XlaRewriter() {
    if (context_) context_->Unref();
  }

  // Prepares OpKernelContext params common to all the ops.
  // Emits an error on failure.
  LogicalResult PrepareParams();

  // Tries to legalize the specified TensorFlow op, if supported.
  //
  // Emits an error and returns failure if an error is encountered during
  // conversion. Note that success return value doesn't mean successful
  // legalization.
  LogicalResult LegalizeOp();

  // Converts the given operand to expression of kind kConstant or kXlaOp.
  // Emits a remark and returns expression of kind kInvalid on failure.
  tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);

  // The op being legalized. Not owned.
  Operation* op_;
  // XLA device type name used to look up compilation kernels.
  std::string device_type_;

  PatternRewriter& rewriter_;
  // Builds mhlo ops as the tf2xla kernel emits XLA operations.
  ::xla::MlirHloBuilder hlo_builder_;
  // Produces a unique NodeDef name for op_.
  tensorflow::OpOrArgLocNameMapper name_mapper_;

  tensorflow::XlaContext* context_;  // Ref-counted.

  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
  tensorflow::Device* device_;  // Owned by device_mgr_;
  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
  // Kernel-execution parameters assembled by PrepareParams()/LegalizeOp();
  // points into the members above, so must not outlive this object.
  tensorflow::OpKernelContext::Params params_;
};
493
// Sets up the device manager, step container, function library and
// OpKernelContext::Params needed to run a tf2xla kernel for op_. Emits an
// error/remark and returns failure on any setup problem.
LogicalResult Tf2XlaRewriter::PrepareParams() {
  // XlaCompiler within the context is only used by the functional ops to
  // compile functions. We are not handling those at the moment so XlaCompiler
  // is not required.
  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
                                        /*graph=*/nullptr);
  // Extra ref held by this rewriter; released in the destructor. The step
  // container below takes over the initial ref from `new`.
  context_->Ref();

  device_mgr_ = CreateDeviceMgr(device_type_);
  if (!device_mgr_) return failure();

  // Type of params_.device is DeviceBase* so store it as Device* to access
  // derived class method.
  device_ = device_mgr_->ListDevices().front();
  params_.device = device_;
  params_.resource_manager = device_->resource_manager();

  // Resources are cleared at the time of device manager destruction so pass
  // no-op cleanup function.
  auto cleanup = [](const std::string& name) {};
  // Use step_id zero as we only have a single context concurrently and
  // concurrently running each of the MLIR functions create a new device.
  step_container_ = std::make_unique<tensorflow::ScopedStepContainer>(
      /*step_id=*/0, cleanup);
  tensorflow::Status status = step_container_->Create(
      device_->resource_manager(),
      tensorflow::XlaContext::kXlaContextResourceName, context_);
  if (!status.ok()) {
    return emitRemark(op_->getLoc())
           << "failed to create XlaContext resource: " << status.ToString();
  }
  params_.step_container = step_container_.get();

  // The graph producer version is needed by the function library runtime
  // below; read it from the enclosing module's attributes.
  tensorflow::StatusOr<int64_t> version_or =
      tensorflow::GetTfGraphProducerVersion(
          op_->getParentOfType<mlir::ModuleOp>());
  if (!version_or.ok()) {
    return emitError(op_->getLoc()) << version_or.status().ToString();
  }

  // Empty function library: functional ops are not handled here (see the
  // XlaContext comment above).
  flib_def_ = std::make_unique<tensorflow::FunctionLibraryDefinition>(
      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
  pflr_ = std::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
      version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
  params_.function_library = pflr_->GetFLR(device_->name());
  return success();
}
542
// Runs op_'s tf2xla kernel against the MlirHloBuilder and replaces op_ with
// the resulting mhlo values. Returns failure (with a remark, not an error)
// when the op cannot be legalized this way, so other patterns may apply.
LogicalResult Tf2XlaRewriter::LegalizeOp() {
  // Only static shaped operands are supported in XLA builders for now.
  for (Type ty : op_->getOperandTypes()) {
    auto ranked_ty = ty.dyn_cast<ShapedType>();
    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
      return op_->emitRemark()
             << "lowering requires static shaped tensor operands";
    }
  }

  // Symbol references (e.g. function attributes) cannot be represented in a
  // NodeDef conversion below, so bail out early.
  for (const auto& attr : op_->getAttrs()) {
    if (attr.getValue().isa<SymbolRefAttr>()) {
      return op_->emitRemark()
             << "ops with symbol references are not supported";
    }
  }

  // Convert the MLIR op into a NodeDef so a TF kernel can be instantiated.
  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
      op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
  if (!nodedef_or.ok()) {
    return op_->emitRemark() << "failed to convert op to NodeDef: "
                             << nodedef_or.status().ToString();
  }

  if (failed(PrepareParams())) return failure();

  std::shared_ptr<const tensorflow::NodeProperties> props;
  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
      *nodedef_or.ValueOrDie(),
      params_.function_library->GetFunctionLibraryDefinition(), &props);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create NodeProperties: " << status.ToString();
  }
  tensorflow::OpKernel* op_kernel_raw;
  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create tf2xla kernel: " << status.ToString();
  }
  // Transfer ownership of the kernel to a local smart pointer.
  auto op_kernel = absl::WrapUnique(op_kernel_raw);

  // Some kernel inputs must be compile-time constants; collect their indices
  // so we can verify the corresponding operands are constants below.
  std::vector<int> required_constants;
  status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
      *op_kernel, &required_constants);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to compute required constants: " << status.ToString();
  }
  llvm::SmallDenseSet<int, 4> required_consts;
  required_consts.insert(required_constants.begin(), required_constants.end());

  // TensorValue in inputs are backed by tensors which in turn depend on
  // expressions. So, pre-allocate them to the required size.
  InlinedVector<tensorflow::XlaExpression, 4> expressions;
  InlinedVector<tensorflow::Tensor, 4> tensors;
  InlinedVector<tensorflow::TensorValue, 4> inputs;
  expressions.reserve(op_->getNumOperands());
  tensors.reserve(op_->getNumOperands());
  inputs.reserve(op_->getNumOperands());

  // Prepare the list of Tensor inputs for the kernel.
  for (auto it : llvm::enumerate(op_->getOperands())) {
    Value operand = it.value();
    size_t idx = it.index();

    tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
    tensorflow::XlaExpression::Kind kind = expr.kind();
    if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
    if (required_consts.count(idx) &&
        kind != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark()
             << "lowering requires operand #" << idx << " to be a constant";
    }
    expressions.push_back(expr);

    if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
      return op_->emitRemark()
             << "skipping legalization due to unsupported type "
             << operand.getType();
    }

    auto shape_or = expr.GetShape();
    if (!shape_or.ok()) {
      return op_->emitRemark()
             << "failed to get shape for expression. " << expr.HumanString();
    }

    // Each input tensor carries its XlaExpression as payload; the kernel
    // retrieves it via CastExpressionFromTensor.
    tensors.emplace_back(
        device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
        shape_or.ValueOrDie());
    tensorflow::Tensor& tensor = tensors.back();
    tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
    inputs.emplace_back(&tensor);
  }

  params_.inputs = inputs;
  params_.op_kernel = op_kernel.get();
  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
      op_->getNumResults());
  params_.output_attr_array = output_attr.data();

  // New mhlo ops are emitted where op_ currently sits, with its location.
  hlo_builder_.setInsertionPoint(op_);
  hlo_builder_.SetLocation(op_->getLoc());

  // Execute the kernel.
  tensorflow::OpKernelContext op_context(&params_, op_->getNumResults());
  device_->Compute(params_.op_kernel, &op_context);

  // Fold any builder-side error into the kernel's status before checking.
  status = op_context.status();
  status.Update(hlo_builder_.GetCurrentStatus());
  if (!status.ok()) {
    return op_->emitRemark()
           << "compilation to HLO failed: " << status.ToString();
  }

  // Replace uses of old results using the corresponding value after the
  // lowering.
  llvm::SmallVector<Value, 2> values;
  values.reserve(op_->getNumResults());
  for (int i = 0, e = op_->getNumResults(); i < e; i++) {
    tensorflow::Tensor* output = op_context.mutable_output(i);
    const tensorflow::XlaExpression* expr =
        tensorflow::XlaExpression::CastExpressionFromTensor(*output);
    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp &&
        expr->kind() != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark(
          "expects XlaExpression of kind kXlaOp or kConstant in compiled "
          "output");
    }
    mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_));
    mlir::OpResult old_result = op_->getResult(i);
    // Bridge any type mismatch (e.g. refined shapes) with a tensor.cast so
    // existing users keep their expected types.
    if (value.getType() != old_result.getType()) {
      value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
                                                        value);
    }
    values.push_back(value);
  }
  rewriter_.replaceOp(op_, values);
  return success();
}
685
GetExprForOperand(Value operand,Operation * op)686 tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
687 Operation* op) {
688 ElementsAttr const_attr;
689 auto defining_op = operand.getDefiningOp();
690 if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
691 tensorflow::Tensor tensor;
692 auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
693 if (!status.ok()) {
694 op->emitRemark() << "skipping legalization due to failed const conversion"
695 << status.ToString();
696 return tensorflow::XlaExpression::Invalid();
697 }
698 return tensorflow::XlaExpression::Constant(tensor);
699 }
700
701 // Skip this op if XLA doesn't support this operand type.
702 auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
703 if (!xla_op_or.ok()) {
704 op->emitRemark() << "skipping legalization due to "
705 << xla_op_or.status().ToString();
706 return tensorflow::XlaExpression::Invalid();
707 }
708 ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
709
710 tensorflow::DataType dtype;
711 auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
712 if (!status.ok()) {
713 op->emitRemark() << "skipping legalization due to " << status.ToString();
714 return tensorflow::XlaExpression::Invalid();
715 }
716 return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
717 }
718
719 class Tf2XlaRewritePattern : public RewritePattern {
720 public:
Tf2XlaRewritePattern(MLIRContext * ctx,const std::string & device_type,bool prefer_tf2xla,bool legalize_test_only_ops,bool is_module_pass)721 explicit Tf2XlaRewritePattern(MLIRContext* ctx,
722 const std::string& device_type,
723 bool prefer_tf2xla, bool legalize_test_only_ops,
724 bool is_module_pass)
725 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
726 device_type_(device_type),
727 prefer_tf2xla_(prefer_tf2xla),
728 legalize_test_only_ops_(legalize_test_only_ops),
729 is_module_pass_(is_module_pass) {}
730
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const731 LogicalResult matchAndRewrite(Operation* op,
732 PatternRewriter& rewriter) const override {
733 if (is_module_pass_) {
734 // Module passes should only ever legalize ops that have been specifically
735 // whitelisted for legalization within a module pass. They will never
736 // legalize any ops whitelisted for legalization within a func pass.
737 if (!IsOpAllowedTf2XlaFallbackAndCreateFunctions(op)) {
738 return failure();
739 }
740 } else if (!(IsOpAllowedTf2XlaFallback(op) ||
741 (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)) ||
742 (legalize_test_only_ops_ && IsOpAllowedForTesting(op)))) {
743 return failure();
744 }
745 return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_,
746 is_module_pass_);
747 }
748
749 private:
750 std::string device_type_;
751 bool prefer_tf2xla_;
752 bool legalize_test_only_ops_;
753 bool is_module_pass_;
754 };
755
756 class LegalizeTF : public LegalizeTFPassBase<LegalizeTF> {
757 public:
758 LegalizeTF() = default;
LegalizeTF(llvm::StringRef device_type,bool prefer_tf2xla)759 explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
760 device_type_ = device_type.str();
761 prefer_tf2xla_ = prefer_tf2xla;
762 }
763
LegalizeTF(const LegalizeTF &)764 LegalizeTF(const LegalizeTF&) {}
765
runOnOperation()766 void runOnOperation() override {
767 RewritePatternSet patterns(&getContext());
768 patterns.add<Tf2XlaRewritePattern>(&getContext(), device_type_,
769 prefer_tf2xla_, legalize_test_only_ops_,
770 /*is_module_pass=*/false);
771 if (failed(
772 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
773 signalPassFailure();
774 }
775
776 private:
777 };
778
779 } // end namespace
780
PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,RewritePatternSet & patterns,MLIRContext * ctx,bool prefer_tf2xla,bool is_module_pass)781 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
782 RewritePatternSet& patterns,
783 MLIRContext* ctx, bool prefer_tf2xla,
784 bool is_module_pass) {
785 patterns.add<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla,
786 /*legalize_test_only_ops=*/false,
787 is_module_pass);
788 }
789
createLegalizeTfWithTf2XlaPass(llvm::StringRef device_type,bool prefer_tf2xla)790 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTfWithTf2XlaPass(
791 llvm::StringRef device_type, bool prefer_tf2xla) {
792 return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
793 }
794
795 } // end namespace mhlo
796 } // end namespace mlir
797