/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "tensorflow/compiler/mlir/xla/transforms/tf_xla_passes_detail.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace mlir {
namespace mhlo {

// LINT.IfChange
bool IsOpAllowedTf2XlaFallback(Operation* op) {
  // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
  // building valid MLIR using MlirHloBuilder.
  // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
  // all tf2xla kernels.
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off

  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AcoshOp>(),
    TypeID::get<TF::AcosOp>(),
    TypeID::get<TF::AddNOp>(),
    TypeID::get<TF::AddV2Op>(),
    TypeID::get<TF::AngleOp>(),
    TypeID::get<TF::AdjustContrastv2Op>(),
    TypeID::get<TF::AdjustHueOp>(),
    TypeID::get<TF::AdjustSaturationOp>(),
    TypeID::get<TF::ApproximateEqualOp>(),
    TypeID::get<TF::ArgMaxOp>(),
    TypeID::get<TF::ArgMinOp>(),
    TypeID::get<TF::AsinhOp>(),
    TypeID::get<TF::AsinOp>(),
    TypeID::get<TF::Atan2Op>(),
    TypeID::get<TF::AtanhOp>(),
    TypeID::get<TF::BatchMatMulV2Op>(),
    TypeID::get<TF::BatchMatMulV3Op>(),
    TypeID::get<TF::BatchToSpaceOp>(),
    TypeID::get<TF::BesselI0eOp>(),
    TypeID::get<TF::BesselI1eOp>(),
    TypeID::get<TF::BetaincOp>(),
    TypeID::get<TF::BiasAddOp>(),
    TypeID::get<TF::BitwiseAndOp>(),
    TypeID::get<TF::BitwiseOrOp>(),
    TypeID::get<TF::BitwiseXorOp>(),
    TypeID::get<TF::BucketizeOp>(),
    TypeID::get<TF::CastOp>(),
    TypeID::get<TF::ClipByValueOp>(),
    TypeID::get<TF::CholeskyOp>(),
    TypeID::get<TF::ComplexAbsOp>(),
    TypeID::get<TF::ConjugateTransposeOp>(),
    TypeID::get<TF::CoshOp>(),
    TypeID::get<TF::CrossOp>(),
    TypeID::get<TF::DataFormatDimMapOp>(),
    TypeID::get<TF::DataFormatVecPermuteOp>(),
    TypeID::get<TF::DepthToSpaceOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
    TypeID::get<TF::DiagOp>(),
    TypeID::get<TF::DigammaOp>(),
    TypeID::get<TF::DivNoNanOp>(),
    TypeID::get<TF::EluGradOp>(),
    TypeID::get<TF::EluOp>(),
    TypeID::get<TF::EnsureShapeOp>(),
    TypeID::get<TF::EqualOp>(),
    TypeID::get<TF::ErfcOp>(),
    TypeID::get<TF::ErfinvOp>(),
    TypeID::get<TF::ErfOp>(),
    TypeID::get<TF::ExtractImagePatchesOp>(),
    TypeID::get<TF::FFT2DOp>(),
    TypeID::get<TF::FFT3DOp>(),
    TypeID::get<TF::FFTOp>(),
    TypeID::get<TF::FakeParamOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsPerChannelOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsPerChannelGradientOp>(),
    TypeID::get<TF::FloorDivOp>(),
    TypeID::get<TF::FloorModOp>(),
    TypeID::get<TF::GreaterOp>(),
    TypeID::get<TF::HSVToRGBOp>(),
    TypeID::get<TF::IFFT2DOp>(),
    TypeID::get<TF::IFFT3DOp>(),
    TypeID::get<TF::IRFFT2DOp>(),
    TypeID::get<TF::IRFFT3DOp>(),
    TypeID::get<TF::IgammaOp>(),
    TypeID::get<TF::IgammacOp>(),
    TypeID::get<TF::IgammaGradAOp>(),
    TypeID::get<TF::InplaceAddOp>(),
    TypeID::get<TF::InTopKV2Op>(),
    TypeID::get<TF::InvertOp>(),
    TypeID::get<TF::InvOp>(),
    TypeID::get<TF::KthOrderStatisticOp>(),
    TypeID::get<TF::LRNOp>(),
    TypeID::get<TF::LRNGradOp>(),
    TypeID::get<TF::LeakyReluGradOp>(),
    TypeID::get<TF::LeakyReluOp>(),
    TypeID::get<TF::LeftShiftOp>(),
    TypeID::get<TF::LessOp>(),
    TypeID::get<TF::ListDiffOp>(),
    TypeID::get<TF::LogicalAndOp>(),
    TypeID::get<TF::LogicalNotOp>(),
    TypeID::get<TF::LogOp>(),
    TypeID::get<TF::LowerBoundOp>(),
    TypeID::get<TF::MakeUniqueOp>(),
    TypeID::get<TF::MatMulOp>(),
    TypeID::get<TF::MatrixDiagV3Op>(),
    TypeID::get<TF::MatrixInverseOp>(),
    TypeID::get<TF::MatrixSetDiagV3Op>(),
    TypeID::get<TF::MatrixSolveOp>(),
    TypeID::get<TF::MatrixTriangularSolveOp>(),
    TypeID::get<TF::MaxPool3DGradGradOp>(),
    TypeID::get<TF::MaxPoolGradGradOp>(),
    TypeID::get<TF::MirrorPadOp>(),
    TypeID::get<TF::MirrorPadGradOp>(),
    TypeID::get<TF::MulOp>(),
    TypeID::get<TF::MultinomialOp>(),
    TypeID::get<TF::NdtriOp>(),
    TypeID::get<TF::NegOp>(),
    TypeID::get<TF::NextAfterOp>(),
    TypeID::get<TF::NonMaxSuppressionV4Op>(),
    TypeID::get<TF::NotEqualOp>(),
    TypeID::get<TF::PadOp>(),
    TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::PlaceholderWithDefaultOp>(),
    TypeID::get<TF::PolygammaOp>(),
    TypeID::get<TF::PopulationCountOp>(),
    TypeID::get<TF::PowOp>(),
    // TODO(hinsu): Canonicalize QuantizeAndDequantize and
    // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
    // attributes to operands.
    TypeID::get<TF::QuantizeAndDequantizeOp>(),
    TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV4Op>(),
    TypeID::get<TF::RFFT2DOp>(),
    TypeID::get<TF::RFFT3DOp>(),
    TypeID::get<TF::RGBToHSVOp>(),
    TypeID::get<TF::RandomUniformIntOp>(),
    TypeID::get<TF::RealDivOp>(),
    TypeID::get<TF::ReciprocalGradOp>(),
    TypeID::get<TF::Relu6GradOp>(),
    TypeID::get<TF::ResizeBilinearOp>(),
    TypeID::get<TF::ResizeBilinearGradOp>(),
    TypeID::get<TF::ResizeNearestNeighborOp>(),
    TypeID::get<TF::ResizeNearestNeighborGradOp>(),
    TypeID::get<TF::ReverseSequenceOp>(),
    TypeID::get<TF::RightShiftOp>(),
    TypeID::get<TF::RintOp>(),
    TypeID::get<TF::RollOp>(),
    TypeID::get<TF::RoundOp>(),
    TypeID::get<TF::SelectV2Op>(),
    TypeID::get<TF::SelfAdjointEigV2Op>(),
    TypeID::get<TF::SeluGradOp>(),
    TypeID::get<TF::SeluOp>(),
    TypeID::get<TF::SigmoidGradOp>(),
    TypeID::get<TF::SinOp>(),
    TypeID::get<TF::SoftplusGradOp>(),
    TypeID::get<TF::SoftsignGradOp>(),
    TypeID::get<TF::SoftsignOp>(),
    TypeID::get<TF::SpaceToBatchNDOp>(),
    TypeID::get<TF::SpaceToBatchOp>(),
    TypeID::get<TF::SpaceToDepthOp>(),
    TypeID::get<TF::SparseToDenseOp>(),
    TypeID::get<TF::SquareOp>(),
    TypeID::get<TF::StatelessMultinomialOp>(),
    TypeID::get<TF::StatelessParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::StatelessRandomGetAlgOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
    TypeID::get<TF::StatelessRandomNormalOp>(),
    TypeID::get<TF::StatelessRandomNormalV2Op>(),
    TypeID::get<TF::StatelessRandomUniformOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
    TypeID::get<TF::StatelessRandomUniformV2Op>(),
    TypeID::get<TF::StatelessRandomUniformIntOp>(),
    TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
    TypeID::get<TF::StatelessTruncatedNormalOp>(),
    TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
    TypeID::get<TF::SubOp>(),
    TypeID::get<TF::SvdOp>(),
    TypeID::get<TF::TanOp>(),
    TypeID::get<TF::TensorScatterAddOp>(),
    TypeID::get<TF::TensorScatterSubOp>(),
    TypeID::get<TF::TPUEmbeddingActivationsOp>(),
    TypeID::get<TF::TopKUniqueOp>(),
    TypeID::get<TF::TopKWithUniqueOp>(),
    TypeID::get<TF::TransposeOp>(),
    TypeID::get<TF::TridiagonalSolveOp>(),
    TypeID::get<TF::TridiagonalMatMulOp>(),
    TypeID::get<TF::TruncateDivOp>(),
    TypeID::get<TF::TruncatedNormalOp>(),
    TypeID::get<TF::TruncateModOp>(),
    TypeID::get<TF::UniqueOp>(),
    TypeID::get<TF::UnpackOp>(),
    TypeID::get<TF::UpperBoundOp>(),
    TypeID::get<TF::XlaBroadcastHelperOp>(),
    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
    TypeID::get<TF::XlaKeyValueSortOp>(),
    TypeID::get<TF::XlaPadOp>(),
    TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
    TypeID::get<TF::XlaSvdOp>(),
  };
  // clang-format on

  auto abstractOp = op->getRegisteredInfo();
  if (!abstractOp) return false;
  return ops->count(abstractOp->getTypeID());
}
// LINT.ThenChange(:Tf2XlaPreferred)

/// List of ops that should use XlaOpKernel legalization only when
/// prefer_tf2xla is set. Ops not in this list should either use MLIR
/// legalization only or not be legalized by the new bridge.
// LINT.IfChange(Tf2XlaPreferred)
bool IsOpAllowedTf2XlaPreferred(Operation* op) {
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off
  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AllOp>(),
    TypeID::get<TF::AllToAllOp>(),
    TypeID::get<TF::AnyOp>(),
    TypeID::get<TF::AvgPoolOp>(),
    TypeID::get<TF::AvgPool3DGradOp>(),
    TypeID::get<TF::AvgPoolGradOp>(),
    TypeID::get<TF::BatchToSpaceNDOp>(),
    TypeID::get<TF::BitcastOp>(),
    TypeID::get<TF::BroadcastToOp>(),
    TypeID::get<TF::CollectivePermuteOp>(),
    TypeID::get<TF::ConcatV2Op>(),
    TypeID::get<TF::ConjOp>(),
    TypeID::get<TF::Conv2DOp>(),
    TypeID::get<TF::Conv2DBackpropFilterOp>(),
    TypeID::get<TF::Conv2DBackpropInputOp>(),
    TypeID::get<TF::Conv3DOp>(),
    TypeID::get<TF::Conv3DBackpropFilterV2Op>(),
    TypeID::get<TF::Conv3DBackpropInputV2Op>(),
    TypeID::get<TF::CumprodOp>(),
    TypeID::get<TF::CumsumOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeOp>(),
    TypeID::get<TF::DynamicStitchOp>(),
    TypeID::get<TF::_EagerConstOp>(),
    TypeID::get<TF::EmptyOp>(),
    TypeID::get<TF::ExpandDimsOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsOp>(),
    TypeID::get<TF::FillOp>(),
    TypeID::get<TF::FusedBatchNormOp>(),
    TypeID::get<TF::FusedBatchNormGradOp>(),
    TypeID::get<TF::FusedBatchNormGradV2Op>(),
    TypeID::get<TF::FusedBatchNormGradV3Op>(),
    TypeID::get<TF::FusedBatchNormV2Op>(),
    TypeID::get<TF::FusedBatchNormV3Op>(),
    TypeID::get<TF::GatherNdOp>(),
    TypeID::get<TF::GatherV2Op>(),
    TypeID::get<TF::IdentityOp>(),
    TypeID::get<TF::IdentityNOp>(),
    TypeID::get<TF::InplaceUpdateOp>(),
    TypeID::get<TF::InvertPermutationOp>(),
    TypeID::get<TF::IRFFTOp>(),
    TypeID::get<TF::L2LossOp>(),
    TypeID::get<TF::LegacyCallOp>(),
    TypeID::get<TF::LinSpaceOp>(),
    TypeID::get<TF::MatrixDiagPartV3Op>(),
    TypeID::get<TF::MaxOp>(),
    TypeID::get<TF::MaximumOp>(),
    TypeID::get<TF::MaxPoolOp>(),
    TypeID::get<TF::MaxPool3DOp>(),
    TypeID::get<TF::MaxPoolGradOp>(),
    TypeID::get<TF::MeanOp>(),
    TypeID::get<TF::MinOp>(),
    TypeID::get<TF::MinimumOp>(),
    TypeID::get<TF::MulNoNanOp>(),
    TypeID::get<TF::OneHotOp>(),
    TypeID::get<TF::OnesLikeOp>(),
    TypeID::get<TF::PackOp>(),
    TypeID::get<TF::PadV2Op>(),
    TypeID::get<TF::ParallelDynamicStitchOp>(),
    TypeID::get<TF::PartitionedCallOp>(),
    TypeID::get<TF::ProdOp>(),
    TypeID::get<TF::QrOp>(),
    TypeID::get<TF::RandomStandardNormalOp>(),
    TypeID::get<TF::RangeOp>(),
    TypeID::get<TF::ReshapeOp>(),
    TypeID::get<TF::ReverseV2Op>(),
    TypeID::get<TF::RFFTOp>(),
    TypeID::get<TF::RsqrtGradOp>(),
    TypeID::get<TF::ScatterNdOp>(),
    TypeID::get<TF::ShapeOp>(),
    TypeID::get<TF::SinhOp>(),
    TypeID::get<TF::SizeOp>(),
    TypeID::get<TF::SliceOp>(),
    TypeID::get<TF::SoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SoftplusOp>(),
    TypeID::get<TF::SparseMatMulOp>(),
    TypeID::get<TF::SparseSoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SplitOp>(),
    TypeID::get<TF::SplitVOp>(),
    TypeID::get<TF::SqueezeOp>(),
    TypeID::get<TF::StatelessParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::StatefulPartitionedCallOp>(),
    TypeID::get<TF::StopGradientOp>(),
    TypeID::get<TF::StridedSliceGradOp>(),
    TypeID::get<TF::SumOp>(),
    TypeID::get<TF::TensorScatterUpdateOp>(),
    TypeID::get<TF::TileOp>(),
    TypeID::get<TF::TopKV2Op>(),
    TypeID::get<TF::_UnaryOpsCompositionOp>(),
    TypeID::get<TF::UnsortedSegmentMaxOp>(),
    TypeID::get<TF::UnsortedSegmentMinOp>(),
    TypeID::get<TF::UnsortedSegmentProdOp>(),
    TypeID::get<TF::UnsortedSegmentSumOp>(),
    TypeID::get<TF::XdivyOp>(),
    TypeID::get<TF::XlaAllReduceOp>(),
    TypeID::get<TF::XlaGatherOp>(),
    TypeID::get<TF::Xlog1pyOp>(),
    TypeID::get<TF::ZerosLikeOp>(),
  };
  // clang-format on
  auto abstractOp = op->getRegisteredInfo();
  if (!abstractOp) return false;
  return ops->count(abstractOp->getTypeID());
}
// LINT.ThenChange()

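// Ops allowlisted only when test-only op legalization is enabled (see
// legalize_test_only_ops in Tf2XlaRewritePattern below).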
bool IsOpAllowedForTesting(Operation* op) {
  // clang-format off
  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 16>{
    // Op used to verify handling of XlaExpression of kind constant.
    TypeID::get<TF::ConstOp>(),
  };
  // clang-format on
  auto abstractOp = op->getRegisteredInfo();
  if (!abstractOp) return false;
  return ops->count(abstractOp->getTypeID());
}

// List of ops that require falling back to XlaOpKernel legalization and also
// require the ability to create functions.
bool IsOpAllowedTf2XlaFallbackAndCreateFunctions(Operation* op) {
  static auto* ops = new llvm::SmallDenseSet<mlir::TypeID, 16>{
      TypeID::get<TF::ApproxTopKOp>(),
  };
  auto abstractOp = op->getRegisteredInfo();
  if (!abstractOp) return false;
  return ops->count(abstractOp->getTypeID());
}

namespace {

template <typename T, size_t N>
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>;  // non-absl ok

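// Builds a StaticDeviceMgr holding a single XlaCompilationDevice for the
// requested XLA device type.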
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
    const std::string& device_type) {
  // Register compilation kernels for all registered XLA backends.
  tensorflow::XlaOpRegistry::RegisterCompilationKernels();

  auto device = std::make_unique<tensorflow::XlaCompilationDevice>(
      tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
  return std::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
}

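// Rewrites a single TensorFlow op into mhlo by running the op's registered
// tf2xla XlaOpKernel against an MlirHloBuilder and replacing the op's results
// with the values the builder produces.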
class Tf2XlaRewriter {
 public:
  static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
                                 const std::string& device_type,
                                 bool is_module_pass) {
    Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type, is_module_pass);
    return tf2xla_rewriter.LegalizeOp();
  }

 private:
  Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
                 const std::string& device_type, bool is_module_pass)
      : op_(op),
        device_type_(device_type),
        rewriter_(rewriter),
        hlo_builder_(op->getName().getStringRef().str(), rewriter_,
                     op->getLoc(), /*build_functions=*/is_module_pass),
        context_(nullptr) {}

  ~Tf2XlaRewriter() {
    if (context_) context_->Unref();
  }

  // Prepares OpKernelContext params common to all the ops.
  // Emits an error on failure.
  LogicalResult PrepareParams();

  // Tries to legalize the specified TensorFlow op, if supported.
  //
  // Emits an error and returns failure if an error is encountered during
  // conversion. Note that a success return value doesn't mean the op was
  // successfully legalized.
  LogicalResult LegalizeOp();

  // Converts the given operand to an expression of kind kConstant or kXlaOp.
  // Emits a remark and returns an expression of kind kInvalid on failure.
  tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);

  Operation* op_;
  std::string device_type_;

  PatternRewriter& rewriter_;
  ::xla::MlirHloBuilder hlo_builder_;
  tensorflow::OpOrArgLocNameMapper name_mapper_;

  tensorflow::XlaContext* context_;  // Ref-counted.

  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
  tensorflow::Device* device_;  // Owned by device_mgr_.
  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
  tensorflow::OpKernelContext::Params params_;
};

LogicalResult Tf2XlaRewriter::PrepareParams() {
  // The XlaCompiler within the context is only used by the functional ops to
  // compile functions. We are not handling those at the moment, so XlaCompiler
  // is not required.
  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
                                        /*graph=*/nullptr);
  context_->Ref();

  device_mgr_ = CreateDeviceMgr(device_type_);
  if (!device_mgr_) return failure();

  // The type of params_.device is DeviceBase*, so store the device as a
  // Device* to access derived class methods.
  device_ = device_mgr_->ListDevices().front();
  params_.device = device_;
  params_.resource_manager = device_->resource_manager();

  // Resources are cleared at the time of device manager destruction, so pass
  // a no-op cleanup function.
  auto cleanup = [](const std::string& name) {};
  // Use step_id zero, as only a single context is used at a time and each
  // concurrently running MLIR function creates its own device.
  step_container_ = std::make_unique<tensorflow::ScopedStepContainer>(
      /*step_id=*/0, cleanup);
  tensorflow::Status status = step_container_->Create(
      device_->resource_manager(),
      tensorflow::XlaContext::kXlaContextResourceName, context_);
  if (!status.ok()) {
    return emitRemark(op_->getLoc())
           << "failed to create XlaContext resource: " << status.ToString();
  }
  params_.step_container = step_container_.get();

  tensorflow::StatusOr<int64_t> version_or =
      tensorflow::GetTfGraphProducerVersion(
          op_->getParentOfType<mlir::ModuleOp>());
  if (!version_or.ok()) {
    return emitError(op_->getLoc()) << version_or.status().ToString();
  }

  flib_def_ = std::make_unique<tensorflow::FunctionLibraryDefinition>(
      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
  pflr_ = std::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
      version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
  params_.function_library = pflr_->GetFLR(device_->name());
  return success();
}

LogicalResult Tf2XlaRewriter::LegalizeOp() {
  // Only static shaped operands are supported in XLA builders for now.
  for (Type ty : op_->getOperandTypes()) {
    auto ranked_ty = ty.dyn_cast<ShapedType>();
    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
      return op_->emitRemark()
             << "lowering requires static shaped tensor operands";
    }
  }

  for (const auto& attr : op_->getAttrs()) {
    if (attr.getValue().isa<SymbolRefAttr>()) {
      return op_->emitRemark()
             << "ops with symbol references are not supported";
    }
  }

  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
      op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
  if (!nodedef_or.ok()) {
    return op_->emitRemark() << "failed to convert op to NodeDef: "
                             << nodedef_or.status().ToString();
  }

  if (failed(PrepareParams())) return failure();

  std::shared_ptr<const tensorflow::NodeProperties> props;
  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
      *nodedef_or.ValueOrDie(),
      params_.function_library->GetFunctionLibraryDefinition(), &props);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create NodeProperties: " << status.ToString();
  }
  tensorflow::OpKernel* op_kernel_raw;
  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create tf2xla kernel: " << status.ToString();
  }
  // Transfer ownership of the kernel to a local smart pointer.
  auto op_kernel = absl::WrapUnique(op_kernel_raw);

  std::vector<int> required_constants;
  status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
      *op_kernel, &required_constants);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to compute required constants: " << status.ToString();
  }
  llvm::SmallDenseSet<int, 4> required_consts;
  required_consts.insert(required_constants.begin(), required_constants.end());

  // TensorValues in inputs are backed by tensors, which in turn depend on
  // expressions; pre-allocate all three to the required size.
  InlinedVector<tensorflow::XlaExpression, 4> expressions;
  InlinedVector<tensorflow::Tensor, 4> tensors;
  InlinedVector<tensorflow::TensorValue, 4> inputs;
  expressions.reserve(op_->getNumOperands());
  tensors.reserve(op_->getNumOperands());
  inputs.reserve(op_->getNumOperands());

  // Prepare the list of Tensor inputs for the kernel.
  for (auto it : llvm::enumerate(op_->getOperands())) {
    Value operand = it.value();
    size_t idx = it.index();

    tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
    tensorflow::XlaExpression::Kind kind = expr.kind();
    if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
    if (required_consts.count(idx) &&
        kind != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark()
             << "lowering requires operand #" << idx << " to be a constant";
    }
    expressions.push_back(expr);

    if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
      return op_->emitRemark()
             << "skipping legalization due to unsupported type "
             << operand.getType();
    }

    auto shape_or = expr.GetShape();
    if (!shape_or.ok()) {
      return op_->emitRemark()
             << "failed to get shape for expression. " << expr.HumanString();
    }

    tensors.emplace_back(
        device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
        shape_or.ValueOrDie());
    tensorflow::Tensor& tensor = tensors.back();
    tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
    inputs.emplace_back(&tensor);
  }

  params_.inputs = inputs;
  params_.op_kernel = op_kernel.get();
  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
      op_->getNumResults());
  params_.output_attr_array = output_attr.data();

  hlo_builder_.setInsertionPoint(op_);
  hlo_builder_.SetLocation(op_->getLoc());

  // Execute the kernel.
  tensorflow::OpKernelContext op_context(&params_, op_->getNumResults());
  device_->Compute(params_.op_kernel, &op_context);

  status = op_context.status();
  status.Update(hlo_builder_.GetCurrentStatus());
  if (!status.ok()) {
    return op_->emitRemark()
           << "compilation to HLO failed: " << status.ToString();
  }

  // Replace uses of old results using the corresponding value after the
  // lowering.
  llvm::SmallVector<Value, 2> values;
  values.reserve(op_->getNumResults());
  for (int i = 0, e = op_->getNumResults(); i < e; i++) {
    tensorflow::Tensor* output = op_context.mutable_output(i);
    const tensorflow::XlaExpression* expr =
        tensorflow::XlaExpression::CastExpressionFromTensor(*output);
    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp &&
        expr->kind() != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark(
          "expects XlaExpression of kind kXlaOp or kConstant in compiled "
          "output");
    }
    mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_));
    mlir::OpResult old_result = op_->getResult(i);
    if (value.getType() != old_result.getType()) {
      value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
                                                        value);
    }
    values.push_back(value);
  }
  rewriter_.replaceOp(op_, values);
  return success();
}

tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
                                                            Operation* op) {
  ElementsAttr const_attr;
  auto defining_op = operand.getDefiningOp();
  if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
    tensorflow::Tensor tensor;
    auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
    if (!status.ok()) {
      op->emitRemark() << "skipping legalization due to failed const conversion: "
                       << status.ToString();
      return tensorflow::XlaExpression::Invalid();
    }
    return tensorflow::XlaExpression::Constant(tensor);
  }

  // Skip this op if XLA doesn't support this operand type.
  auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
  if (!xla_op_or.ok()) {
    op->emitRemark() << "skipping legalization due to "
                     << xla_op_or.status().ToString();
    return tensorflow::XlaExpression::Invalid();
  }
  ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();

  tensorflow::DataType dtype;
  auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
  if (!status.ok()) {
    op->emitRemark() << "skipping legalization due to " << status.ToString();
    return tensorflow::XlaExpression::Invalid();
  }
  return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
}

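// RewritePattern that matches any operation and hands it to Tf2XlaRewriter
// when the op is allowlisted for the current configuration: the module-pass
// allowlist, the default fallback allowlist, the prefer_tf2xla list, or the
// test-only ops.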
class Tf2XlaRewritePattern : public RewritePattern {
 public:
  explicit Tf2XlaRewritePattern(MLIRContext* ctx,
                                const std::string& device_type,
                                bool prefer_tf2xla, bool legalize_test_only_ops,
                                bool is_module_pass)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
        device_type_(device_type),
        prefer_tf2xla_(prefer_tf2xla),
        legalize_test_only_ops_(legalize_test_only_ops),
        is_module_pass_(is_module_pass) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    if (is_module_pass_) {
      // Module passes should only ever legalize ops that have been
      // specifically allowlisted for legalization within a module pass. They
      // will never legalize any ops allowlisted for legalization within a
      // func pass.
      if (!IsOpAllowedTf2XlaFallbackAndCreateFunctions(op)) {
        return failure();
      }
    } else if (!(IsOpAllowedTf2XlaFallback(op) ||
                 (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)) ||
                 (legalize_test_only_ops_ && IsOpAllowedForTesting(op)))) {
      return failure();
    }
    return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_,
                                     is_module_pass_);
  }

 private:
  std::string device_type_;
  bool prefer_tf2xla_;
  bool legalize_test_only_ops_;
  bool is_module_pass_;
};

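// Pass that greedily applies Tf2XlaRewritePattern to the ops in a function.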
class LegalizeTF : public LegalizeTFPassBase<LegalizeTF> {
 public:
  LegalizeTF() = default;
  explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
    device_type_ = device_type.str();
    prefer_tf2xla_ = prefer_tf2xla;
  }

  LegalizeTF(const LegalizeTF&) {}

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    patterns.add<Tf2XlaRewritePattern>(&getContext(), device_type_,
                                       prefer_tf2xla_, legalize_test_only_ops_,
                                       /*is_module_pass=*/false);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }

 private:
};

}  // end namespace

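// Adds the tf2xla fallback pattern to the given pattern set; test-only op
// legalization is always disabled here.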
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                          RewritePatternSet& patterns,
                                          MLIRContext* ctx, bool prefer_tf2xla,
                                          bool is_module_pass) {
  patterns.add<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla,
                                     /*legalize_test_only_ops=*/false,
                                     is_module_pass);
}

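// Creates a pass that legalizes TF ops via their tf2xla kernels. Illustrative
// usage (an assumed example, not something this file sets up):
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::mhlo::createLegalizeTfWithTf2XlaPass("XLA_CPU_JIT",
//                                                  /*prefer_tf2xla=*/false));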
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTfWithTf2XlaPass(
    llvm::StringRef device_type, bool prefer_tf2xla) {
  return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
}

}  // end namespace mhlo
}  // end namespace mlir