• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
29 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
34 #include "mlir/IR/Location.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
43 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
49 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
50 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
51 #include "tensorflow/compiler/tf2xla/xla_context.h"
52 #include "tensorflow/compiler/tf2xla/xla_expression.h"
53 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
54 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
55 #include "tensorflow/compiler/xla/client/xla_builder.h"
56 #include "tensorflow/core/common_runtime/device.h"
57 #include "tensorflow/core/common_runtime/device_factory.h"
58 #include "tensorflow/core/common_runtime/device_mgr.h"
59 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
60 #include "tensorflow/core/framework/allocator.h"
61 #include "tensorflow/core/framework/function.h"
62 #include "tensorflow/core/framework/function.pb.h"
63 #include "tensorflow/core/framework/node_properties.h"
64 #include "tensorflow/core/framework/op.h"
65 #include "tensorflow/core/framework/op_kernel.h"
66 #include "tensorflow/core/framework/resource_mgr.h"
67 #include "tensorflow/core/framework/tensor.h"
68 #include "tensorflow/core/framework/types.h"
69 #include "tensorflow/core/framework/types.pb.h"
70 #include "tensorflow/core/platform/env.h"
71 #include "tensorflow/core/platform/status.h"
72 #include "tensorflow/core/protobuf/config.pb.h"
73 #include "tensorflow/core/public/session_options.h"
74 #include "tensorflow/stream_executor/lib/statusor.h"
75 #include "tensorflow/stream_executor/stream_executor.h"
76 
77 namespace mlir {
78 namespace mhlo {
79 
80 // LINT.IfChange
// Returns true if `op` is a TensorFlow op that may always be legalized via the
// tf2xla fallback, i.e. by running its XlaOpKernel with MlirHloBuilder.
bool IsOpAllowedTf2XlaFallback(Operation* op) {
  // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
  // building valid MLIR using MlirHloBuilder.
  // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
  // all tf2xla kernels.
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off

  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AcoshOp>(),
    TypeID::get<TF::AcosOp>(),
    TypeID::get<TF::AddNOp>(),
    TypeID::get<TF::AddV2Op>(),
    TypeID::get<TF::AngleOp>(),
    TypeID::get<TF::AdjustContrastv2Op>(),
    TypeID::get<TF::AdjustHueOp>(),
    TypeID::get<TF::AdjustSaturationOp>(),
    TypeID::get<TF::ApproximateEqualOp>(),
    TypeID::get<TF::ArgMaxOp>(),
    TypeID::get<TF::ArgMinOp>(),
    TypeID::get<TF::AsinhOp>(),
    TypeID::get<TF::AsinOp>(),
    TypeID::get<TF::Atan2Op>(),
    TypeID::get<TF::AtanhOp>(),
    TypeID::get<TF::BatchMatMulV2Op>(),
    TypeID::get<TF::BatchMatMulV3Op>(),
    TypeID::get<TF::BatchToSpaceOp>(),
    TypeID::get<TF::BesselI0eOp>(),
    TypeID::get<TF::BesselI1eOp>(),
    TypeID::get<TF::BetaincOp>(),
    TypeID::get<TF::BiasAddOp>(),
    TypeID::get<TF::BitwiseAndOp>(),
    TypeID::get<TF::BitwiseOrOp>(),
    TypeID::get<TF::BitwiseXorOp>(),
    TypeID::get<TF::BucketizeOp>(),
    TypeID::get<TF::CastOp>(),
    TypeID::get<TF::ClipByValueOp>(),
    TypeID::get<TF::CholeskyOp>(),
    TypeID::get<TF::ComplexAbsOp>(),
    TypeID::get<TF::ConjugateTransposeOp>(),
    TypeID::get<TF::CoshOp>(),
    TypeID::get<TF::CrossOp>(),
    TypeID::get<TF::DataFormatDimMapOp>(),
    TypeID::get<TF::DataFormatVecPermuteOp>(),
    TypeID::get<TF::DepthToSpaceOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
    TypeID::get<TF::DiagOp>(),
    TypeID::get<TF::DigammaOp>(),
    TypeID::get<TF::DivNoNanOp>(),
    TypeID::get<TF::EluGradOp>(),
    TypeID::get<TF::EluOp>(),
    TypeID::get<TF::EnsureShapeOp>(),
    TypeID::get<TF::EqualOp>(),
    TypeID::get<TF::ErfcOp>(),
    TypeID::get<TF::ErfinvOp>(),
    TypeID::get<TF::ErfOp>(),
    TypeID::get<TF::ExtractImagePatchesOp>(),
    TypeID::get<TF::FFT2DOp>(),
    TypeID::get<TF::FFT3DOp>(),
    TypeID::get<TF::FFTOp>(),
    TypeID::get<TF::FakeParamOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
    TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
    TypeID::get<TF::FloorDivOp>(),
    TypeID::get<TF::FloorModOp>(),
    TypeID::get<TF::GreaterOp>(),
    TypeID::get<TF::HSVToRGBOp>(),
    TypeID::get<TF::IFFT2DOp>(),
    TypeID::get<TF::IFFT3DOp>(),
    TypeID::get<TF::IRFFT2DOp>(),
    TypeID::get<TF::IRFFT3DOp>(),
    TypeID::get<TF::IgammaOp>(),
    TypeID::get<TF::IgammacOp>(),
    TypeID::get<TF::IgammaGradAOp>(),
    TypeID::get<TF::InplaceAddOp>(),
    TypeID::get<TF::InTopKV2Op>(),
    TypeID::get<TF::InvertOp>(),
    TypeID::get<TF::InvOp>(),
    TypeID::get<TF::KthOrderStatisticOp>(),
    TypeID::get<TF::LRNOp>(),
    TypeID::get<TF::LRNGradOp>(),
    TypeID::get<TF::LeakyReluGradOp>(),
    TypeID::get<TF::LeakyReluOp>(),
    TypeID::get<TF::LeftShiftOp>(),
    TypeID::get<TF::LessOp>(),
    TypeID::get<TF::ListDiffOp>(),
    TypeID::get<TF::LogicalAndOp>(),
    TypeID::get<TF::LogicalNotOp>(),
    TypeID::get<TF::LogOp>(),
    TypeID::get<TF::LowerBoundOp>(),
    TypeID::get<TF::MakeUniqueOp>(),
    TypeID::get<TF::MatMulOp>(),
    TypeID::get<TF::MatrixDiagV3Op>(),
    TypeID::get<TF::MatrixInverseOp>(),
    TypeID::get<TF::MatrixSetDiagV3Op>(),
    TypeID::get<TF::MatrixSolveOp>(),
    TypeID::get<TF::MatrixTriangularSolveOp>(),
    TypeID::get<TF::MaxPool3DGradGradOp>(),
    TypeID::get<TF::MaxPoolGradGradOp>(),
    TypeID::get<TF::MirrorPadOp>(),
    TypeID::get<TF::MirrorPadGradOp>(),
    TypeID::get<TF::MulOp>(),
    TypeID::get<TF::MultinomialOp>(),
    TypeID::get<TF::NdtriOp>(),
    TypeID::get<TF::NegOp>(),
    TypeID::get<TF::NextAfterOp>(),
    TypeID::get<TF::NonMaxSuppressionV4Op>(),
    TypeID::get<TF::NotEqualOp>(),
    TypeID::get<TF::PadOp>(),
    TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
    TypeID::get<TF::PlaceholderWithDefaultOp>(),
    TypeID::get<TF::PolygammaOp>(),
    TypeID::get<TF::PopulationCountOp>(),
    TypeID::get<TF::PowOp>(),
    // TODO(hinsu): Canonicalize QuantizeAndDequantize and
    // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
    // attributes to operands.
    TypeID::get<TF::QuantizeAndDequantizeOp>(),
    TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
    TypeID::get<TF::QuantizeAndDequantizeV4Op>(),
    TypeID::get<TF::RFFT2DOp>(),
    TypeID::get<TF::RFFT3DOp>(),
    TypeID::get<TF::RGBToHSVOp>(),
    TypeID::get<TF::RandomUniformIntOp>(),
    TypeID::get<TF::RealDivOp>(),
    TypeID::get<TF::ReciprocalGradOp>(),
    TypeID::get<TF::Relu6GradOp>(),
    TypeID::get<TF::ResizeBilinearOp>(),
    TypeID::get<TF::ResizeBilinearGradOp>(),
    TypeID::get<TF::ResizeNearestNeighborOp>(),
    TypeID::get<TF::ResizeNearestNeighborGradOp>(),
    TypeID::get<TF::ReverseSequenceOp>(),
    TypeID::get<TF::RightShiftOp>(),
    TypeID::get<TF::RintOp>(),
    TypeID::get<TF::RollOp>(),
    TypeID::get<TF::RoundOp>(),
    TypeID::get<TF::SelectV2Op>(),
    TypeID::get<TF::SelfAdjointEigV2Op>(),
    TypeID::get<TF::SeluGradOp>(),
    TypeID::get<TF::SeluOp>(),
    TypeID::get<TF::SigmoidGradOp>(),
    TypeID::get<TF::SinOp>(),
    TypeID::get<TF::SoftplusGradOp>(),
    TypeID::get<TF::SoftsignGradOp>(),
    TypeID::get<TF::SoftsignOp>(),
    TypeID::get<TF::SpaceToBatchNDOp>(),
    TypeID::get<TF::SpaceToBatchOp>(),
    TypeID::get<TF::SpaceToDepthOp>(),
    TypeID::get<TF::SparseToDenseOp>(),
    TypeID::get<TF::SquareOp>(),
    TypeID::get<TF::StatelessMultinomialOp>(),
    TypeID::get<TF::StatelessRandomGetAlgOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
    TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
    TypeID::get<TF::StatelessRandomNormalOp>(),
    TypeID::get<TF::StatelessRandomNormalV2Op>(),
    TypeID::get<TF::StatelessRandomUniformOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
    TypeID::get<TF::StatelessRandomUniformV2Op>(),
    TypeID::get<TF::StatelessRandomUniformIntOp>(),
    TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
    TypeID::get<TF::StatelessTruncatedNormalOp>(),
    TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
    TypeID::get<TF::SubOp>(),
    TypeID::get<TF::SvdOp>(),
    TypeID::get<TF::TanOp>(),
    TypeID::get<TF::TensorScatterAddOp>(),
    TypeID::get<TF::TensorScatterSubOp>(),
    TypeID::get<TF::TPUEmbeddingActivationsOp>(),
    TypeID::get<TF::TopKUniqueOp>(),
    TypeID::get<TF::TopKWithUniqueOp>(),
    TypeID::get<TF::TransposeOp>(),
    TypeID::get<TF::TridiagonalSolveOp>(),
    TypeID::get<TF::TruncateDivOp>(),
    TypeID::get<TF::TruncatedNormalOp>(),
    TypeID::get<TF::TruncateModOp>(),
    TypeID::get<TF::UnpackOp>(),
    TypeID::get<TF::UpperBoundOp>(),
    TypeID::get<TF::XlaBroadcastHelperOp>(),
    TypeID::get<TF::XlaConvOp>(),
    TypeID::get<TF::XlaConvV2Op>(),
    TypeID::get<TF::XlaDotOp>(),
    TypeID::get<TF::XlaDotV2Op>(),
    TypeID::get<TF::XlaDynamicSliceOp>(),
    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
    TypeID::get<TF::XlaEinsumOp>(),
    TypeID::get<TF::XlaKeyValueSortOp>(),
    TypeID::get<TF::XlaPadOp>(),
    TypeID::get<TF::XlaSortOp>(),
    TypeID::get<TF::XlaSvdOp>(),
  };
  // clang-format on

  // Ops without a registered AbstractOperation (unregistered/unknown ops)
  // cannot be identified by TypeID and are never allowed.
  auto* abstractOp = op->getAbstractOperation();
  if (!abstractOp) return false;
  return ops->count(abstractOp->typeID);
}
283 // LINT.ThenChange(:Tf2XlaPreferred)
284 
285 /// List of ops that should use XlaOpKernel legalization only in the case of
286 /// prefer_tf2xla. All other ops not in this list should use MLIR legalization
287 /// only or not be legalized by the new bridge.
288 // LINT.IfChange(Tf2XlaPreferred)
// Returns true if `op` should be legalized via its tf2xla XlaOpKernel, but
// only when the caller opted in with prefer_tf2xla; see the comment above the
// LINT block for the contract.
bool IsOpAllowedTf2XlaPreferred(Operation* op) {
  // Use a pointer for the static set, so the set is not destructed upon thread
  // end, which would not be thread safe.
  // clang-format off
  static auto* ops =
      new llvm::SmallDenseSet<mlir::TypeID, 512>{
    TypeID::get<TF::AllOp>(),
    TypeID::get<TF::AllToAllOp>(),
    TypeID::get<TF::AnyOp>(),
    TypeID::get<TF::AvgPoolOp>(),
    TypeID::get<TF::AvgPool3DGradOp>(),
    TypeID::get<TF::AvgPoolGradOp>(),
    TypeID::get<TF::BatchToSpaceNDOp>(),
    TypeID::get<TF::BitcastOp>(),
    TypeID::get<TF::BroadcastToOp>(),
    TypeID::get<TF::CollectivePermuteOp>(),
    TypeID::get<TF::ConcatV2Op>(),
    TypeID::get<TF::ConjOp>(),
    TypeID::get<TF::Conv2DOp>(),
    TypeID::get<TF::Conv2DBackpropFilterOp>(),
    TypeID::get<TF::Conv2DBackpropInputOp>(),
    TypeID::get<TF::Conv3DOp>(),
    TypeID::get<TF::Conv3DBackpropFilterV2Op>(),
    TypeID::get<TF::Conv3DBackpropInputV2Op>(),
    TypeID::get<TF::CumprodOp>(),
    TypeID::get<TF::CumsumOp>(),
    TypeID::get<TF::DepthwiseConv2dNativeOp>(),
    TypeID::get<TF::DynamicStitchOp>(),
    TypeID::get<TF::_EagerConstOp>(),
    TypeID::get<TF::EmptyOp>(),
    TypeID::get<TF::ExpandDimsOp>(),
    TypeID::get<TF::FillOp>(),
    TypeID::get<TF::FusedBatchNormOp>(),
    TypeID::get<TF::FusedBatchNormGradOp>(),
    TypeID::get<TF::FusedBatchNormGradV2Op>(),
    TypeID::get<TF::FusedBatchNormGradV3Op>(),
    TypeID::get<TF::FusedBatchNormV2Op>(),
    TypeID::get<TF::FusedBatchNormV3Op>(),
    TypeID::get<TF::GatherNdOp>(),
    TypeID::get<TF::GatherV2Op>(),
    TypeID::get<TF::IdentityOp>(),
    TypeID::get<TF::IdentityNOp>(),
    TypeID::get<TF::InplaceUpdateOp>(),
    TypeID::get<TF::InvertPermutationOp>(),
    TypeID::get<TF::IRFFTOp>(),
    TypeID::get<TF::L2LossOp>(),
    TypeID::get<TF::LegacyCallOp>(),
    TypeID::get<TF::LinSpaceOp>(),
    TypeID::get<TF::MatrixDiagPartV3Op>(),
    TypeID::get<TF::MaxOp>(),
    TypeID::get<TF::MaximumOp>(),
    TypeID::get<TF::MaxPoolOp>(),
    TypeID::get<TF::MaxPool3DOp>(),
    TypeID::get<TF::MaxPoolGradOp>(),
    TypeID::get<TF::MeanOp>(),
    TypeID::get<TF::MinOp>(),
    TypeID::get<TF::MinimumOp>(),
    TypeID::get<TF::MulNoNanOp>(),
    TypeID::get<TF::OneHotOp>(),
    TypeID::get<TF::OnesLikeOp>(),
    TypeID::get<TF::PackOp>(),
    TypeID::get<TF::PadV2Op>(),
    TypeID::get<TF::ParallelDynamicStitchOp>(),
    TypeID::get<TF::PartitionedCallOp>(),
    TypeID::get<TF::ProdOp>(),
    TypeID::get<TF::QrOp>(),
    TypeID::get<TF::RandomStandardNormalOp>(),
    TypeID::get<TF::RandomUniformOp>(),
    TypeID::get<TF::RangeOp>(),
    TypeID::get<TF::ReshapeOp>(),
    TypeID::get<TF::ReverseV2Op>(),
    TypeID::get<TF::RFFTOp>(),
    TypeID::get<TF::RsqrtGradOp>(),
    TypeID::get<TF::ScatterNdOp>(),
    TypeID::get<TF::ShapeOp>(),
    TypeID::get<TF::SinhOp>(),
    TypeID::get<TF::SizeOp>(),
    TypeID::get<TF::SliceOp>(),
    TypeID::get<TF::SoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SoftplusOp>(),
    TypeID::get<TF::SparseMatMulOp>(),
    TypeID::get<TF::SparseSoftmaxCrossEntropyWithLogitsOp>(),
    TypeID::get<TF::SplitOp>(),
    TypeID::get<TF::SplitVOp>(),
    TypeID::get<TF::SqueezeOp>(),
    TypeID::get<TF::StatefulPartitionedCallOp>(),
    TypeID::get<TF::StopGradientOp>(),
    TypeID::get<TF::StridedSliceOp>(),
    TypeID::get<TF::StridedSliceGradOp>(),
    TypeID::get<TF::SumOp>(),
    TypeID::get<TF::TensorScatterUpdateOp>(),
    TypeID::get<TF::TileOp>(),
    TypeID::get<TF::TopKV2Op>(),
    TypeID::get<TF::_UnaryOpsCompositionOp>(),
    TypeID::get<TF::UnsortedSegmentMaxOp>(),
    TypeID::get<TF::UnsortedSegmentMinOp>(),
    TypeID::get<TF::UnsortedSegmentProdOp>(),
    TypeID::get<TF::UnsortedSegmentSumOp>(),
    TypeID::get<TF::XdivyOp>(),
    TypeID::get<TF::XlaAllReduceOp>(),
    TypeID::get<TF::XlaGatherOp>(),
    TypeID::get<TF::XlaReplicaIdOp>(),
    TypeID::get<TF::Xlog1pyOp>(),
    TypeID::get<TF::ZerosLikeOp>(),

    // XlaOpKernel makes use of compiler options which we don't feed in the
    // fallback.
    // TypeID::get<TF::FakeQuantWithMinMaxVarsOp>(),
  };
  // clang-format on
  // Ops without a registered AbstractOperation cannot be identified by TypeID
  // and are never allowed.
  auto* abstractOp = op->getAbstractOperation();
  if (!abstractOp) return false;
  return ops->count(abstractOp->typeID);
}
403 // LINT.ThenChange()
404 
IsOpAllowedForTesting(Operation * op)405 bool IsOpAllowedForTesting(Operation* op) {
406   // clang-format off
407   static auto* ops =
408       new llvm::SmallDenseSet<mlir::TypeID, 16>{
409     // Op used to verify handling of XlaExpression of kind constant.
410     TypeID::get<TF::ConstOp>(),
411   };
412   // clang-format on
413   auto* abstractOp = op->getAbstractOperation();
414   if (!abstractOp) return false;
415   return ops->count(abstractOp->typeID);
416 }
417 
418 namespace {
419 
420 template <typename T, size_t N>
421 using InlinedVector = tensorflow::gtl::InlinedVector<T, N>;  // non-absl ok
422 
CreateDeviceMgr(const std::string & device_type)423 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
424     const std::string& device_type) {
425   // Register compilation kernels for all registered XLA backends.
426   tensorflow::XlaOpRegistry::RegisterCompilationKernels();
427 
428   auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
429       tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
430   return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
431 }
432 
// Legalizes a single TensorFlow op to mhlo by instantiating and running its
// registered tf2xla XlaOpKernel against an MlirHloBuilder, which emits MLIR
// instead of XLA HLO. One instance is created per rewritten op.
class Tf2XlaRewriter {
 public:
  // Entry point: attempts to rewrite `op` using the tf2xla kernel registered
  // for `device_type`. Returns failure if an error is encountered; success
  // does not necessarily mean the op was legalized (see LegalizeOp).
  static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
                                 const std::string& device_type) {
    Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
    return tf2xla_rewriter.LegalizeOp();
  }

 private:
  Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
                 const std::string& device_type)
      : op_(op),
        device_type_(device_type),
        rewriter_(rewriter),
        hlo_builder_(op->getName().getStringRef().str(), rewriter_,
                     op->getLoc()),
        context_(nullptr) {}

  // Drops the reference taken in PrepareParams (if it ran).
  ~Tf2XlaRewriter() {
    if (context_) context_->Unref();
  }

  // Prepares OpKernelContext params common to all the ops.
  // Emits an error on failure.
  LogicalResult PrepareParams();

  // Tries to legalize the specified TensorFlow op, if supported.
  //
  // Emits an error and returns failure if an error is encountered during
  // conversion. Note that success return value doesn't mean successful
  // legalization.
  LogicalResult LegalizeOp();

  // Converts the given operand to expression of kind kConstant or kXlaOp.
  // Emits a remark and returns expression of kind kInvalid on failure.
  tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);

  // The op being rewritten; not owned.
  Operation* op_;
  std::string device_type_;

  PatternRewriter& rewriter_;
  // Builder that translates XlaOpKernel emissions into mhlo ops.
  ::xla::MlirHloBuilder hlo_builder_;
  tensorflow::OpOrArgLocNameMapper name_mapper_;

  tensorflow::XlaContext* context_;  // Ref-counted.

  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
  tensorflow::Device* device_;  // Owned by device_mgr_;
  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
  // Kernel-invocation parameters; pointers inside refer to the members above.
  tensorflow::OpKernelContext::Params params_;
};
486 
// Sets up the TF runtime state (XlaContext, device manager, step container,
// function library) needed before a tf2xla kernel can be invoked, populating
// the corresponding fields of params_. Emits a diagnostic and returns failure
// on any error.
LogicalResult Tf2XlaRewriter::PrepareParams() {
  // XlaCompiler within the context is only used by the functional ops to
  // compile functions. We are not handling those at the moment so XlaCompiler
  // is not required.
  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
                                        /*graph=*/nullptr);
  // Extra ref balanced by Unref in the destructor; the step container below
  // also holds a reference while it owns the resource.
  context_->Ref();

  device_mgr_ = CreateDeviceMgr(device_type_);
  if (!device_mgr_) return failure();

  // Type of params_.device is DeviceBase* so store it as Device* to access
  // derived class method.
  device_ = device_mgr_->ListDevices().front();
  params_.device = device_;
  params_.resource_manager = device_->resource_manager();

  // Resources are cleared at the time of device manager destruction so pass
  // no-op cleanup function.
  auto cleanup = [](const std::string& name) {};
  // Use step_id zero as we only have a single context concurrently and
  // concurrently running each of the MLIR functions create a new device.
  step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
      /*step_id=*/0, cleanup);
  tensorflow::Status status = step_container_->Create(
      device_->resource_manager(),
      tensorflow::XlaContext::kXlaContextResourceName, context_);
  if (!status.ok()) {
    return emitRemark(op_->getLoc())
           << "failed to create XlaContext resource: " << status.ToString();
  }
  params_.step_container = step_container_.get();

  // The producer version is required by the function library runtime below.
  tensorflow::StatusOr<int64_t> version_or =
      tensorflow::GetTfGraphProducerVersion(
          op_->getParentOfType<mlir::ModuleOp>());
  if (!version_or.ok()) {
    return emitError(op_->getLoc()) << version_or.status().ToString();
  }

  // Empty function library: functional ops are not handled here (see the
  // XlaContext comment above), so no function definitions are needed.
  flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
  pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
      version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
  params_.function_library = pflr_->GetFLR(device_->name());
  return success();
}
535 
// Runs the op's tf2xla kernel via the compilation device and replaces op_'s
// results with the mhlo values the kernel produced through hlo_builder_.
// Emits a remark and returns failure when the op cannot be handled; a remark
// (rather than an error) keeps the conversion best-effort.
LogicalResult Tf2XlaRewriter::LegalizeOp() {
  // Only static shaped operands are supported in XLA builders for now.
  for (Type ty : op_->getOperandTypes()) {
    auto ranked_ty = ty.dyn_cast<ShapedType>();
    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
      return op_->emitRemark()
             << "lowering requires static shaped tensor operands";
    }
  }

  // Symbol references (e.g. function attributes) cannot be carried through
  // the NodeDef round-trip below.
  for (const auto& attr : op_->getAttrs()) {
    if (attr.second.isa<SymbolRefAttr>()) {
      return op_->emitRemark()
             << "ops with symbol references are not supported";
    }
  }

  // Round-trip the MLIR op to a NodeDef so the TF kernel machinery can
  // instantiate the corresponding XlaOpKernel.
  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
      op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
  if (!nodedef_or.ok()) {
    return op_->emitRemark() << "failed to convert op to NodeDef: "
                             << nodedef_or.status().ToString();
  }

  if (failed(PrepareParams())) return failure();

  std::shared_ptr<const tensorflow::NodeProperties> props;
  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
      *nodedef_or.ValueOrDie(),
      params_.function_library->GetFunctionLibraryDefinition(), &props);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create NodeProperties: " << status.ToString();
  }
  tensorflow::OpKernel* op_kernel_raw;
  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to create tf2xla kernel: " << status.ToString();
  }
  // Transfer ownership of the kernel to a local smart pointer.
  auto op_kernel = absl::WrapUnique(op_kernel_raw);

  // Operands the kernel requires to be compile-time constants.
  std::vector<int> required_constants;
  status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
      *op_kernel, &required_constants);
  if (!status.ok()) {
    return op_->emitRemark()
           << "failed to compute required constants: " << status.ToString();
  }
  llvm::SmallDenseSet<int, 4> required_consts;
  required_consts.insert(required_constants.begin(), required_constants.end());

  // TensorValue in inputs are backed by tensors which in turn depend on
  // expressions. So, pre-allocate them to the required size.
  InlinedVector<tensorflow::XlaExpression, 4> expressions;
  InlinedVector<tensorflow::Tensor, 4> tensors;
  InlinedVector<tensorflow::TensorValue, 4> inputs;
  expressions.reserve(op_->getNumOperands());
  tensors.reserve(op_->getNumOperands());
  inputs.reserve(op_->getNumOperands());

  // Prepare the list of Tensor inputs for the kernel.
  for (auto it : llvm::enumerate(op_->getOperands())) {
    Value operand = it.value();
    size_t idx = it.index();

    tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
    tensorflow::XlaExpression::Kind kind = expr.kind();
    if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
    if (required_consts.count(idx) &&
        kind != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark()
             << "lowering requires operand #" << idx << " to be a constant";
    }
    expressions.push_back(expr);

    if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
      return op_->emitRemark()
             << "skipping legalization due to unsupported type "
             << operand.getType();
    }

    auto shape_or = expr.GetShape();
    if (!shape_or.ok()) {
      return op_->emitRemark()
             << "failed to get shape for expression. " << expr.HumanString();
    }

    // Back each input TensorValue with a Tensor carrying the expression.
    tensors.emplace_back(
        device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
        shape_or.ValueOrDie());
    tensorflow::Tensor& tensor = tensors.back();
    tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
    inputs.emplace_back(&tensor);
  }

  params_.inputs = &inputs;
  params_.op_kernel = op_kernel.get();
  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
      op_->getNumResults());
  params_.output_attr_array = output_attr.data();

  // Emit the generated mhlo ops right before the op being replaced.
  hlo_builder_.setInsertionPoint(op_);
  hlo_builder_.SetLocation(op_->getLoc());

  // Execute the kernel.
  tensorflow::OpKernelContext op_context(&params_, op_->getNumResults());
  device_->Compute(params_.op_kernel, &op_context);

  // Combine errors from the kernel invocation and the builder.
  status = op_context.status();
  status.Update(hlo_builder_.GetCurrentStatus());
  if (!status.ok()) {
    return op_->emitRemark()
           << "compilation to HLO failed: " << status.ToString();
  }

  // Replace uses of old results using the corresponding value after the
  // lowering.
  llvm::SmallVector<Value, 2> values;
  values.reserve(op_->getNumResults());
  for (int i = 0, e = op_->getNumResults(); i < e; i++) {
    tensorflow::Tensor* output = op_context.mutable_output(i);
    const tensorflow::XlaExpression* expr =
        tensorflow::XlaExpression::CastExpressionFromTensor(*output);
    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp &&
        expr->kind() != tensorflow::XlaExpression::Kind::kConstant) {
      return op_->emitRemark(
          "expects XlaExpression of kind kXlaOp or kConstant in compiled "
          "output");
    }
    mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_));
    mlir::OpResult old_result = op_->getResult(i);
    // Insert a cast when the generated value's type differs (e.g. refined
    // shape) from what the original result's users expect.
    if (value.getType() != old_result.getType()) {
      value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
                                                        value);
    }
    values.push_back(value);
  }
  rewriter_.replaceOp(op_, values);
  return success();
}
678 
GetExprForOperand(Value operand,Operation * op)679 tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
680                                                             Operation* op) {
681   ElementsAttr const_attr;
682   auto defining_op = operand.getDefiningOp();
683   if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
684     tensorflow::Tensor tensor;
685     auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
686     if (!status.ok()) {
687       op->emitRemark() << "skipping legalization due to failed const conversion"
688                        << status.ToString();
689       return tensorflow::XlaExpression::Invalid();
690     }
691     return tensorflow::XlaExpression::Constant(tensor);
692   }
693 
694   // Skip this op if XLA doesn't support this operand type.
695   auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
696   if (!xla_op_or.ok()) {
697     op->emitRemark() << "skipping legalization due to "
698                      << xla_op_or.status().ToString();
699     return tensorflow::XlaExpression::Invalid();
700   }
701   ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
702 
703   tensorflow::DataType dtype;
704   auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
705   if (!status.ok()) {
706     op->emitRemark() << "skipping legalization due to " << status.ToString();
707     return tensorflow::XlaExpression::Invalid();
708   }
709   return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
710 }
711 
712 class Tf2XlaRewritePattern : public RewritePattern {
713  public:
Tf2XlaRewritePattern(MLIRContext * ctx,const std::string & device_type,bool prefer_tf2xla,bool legalize_test_only_ops)714   explicit Tf2XlaRewritePattern(MLIRContext* ctx,
715                                 const std::string& device_type,
716                                 bool prefer_tf2xla, bool legalize_test_only_ops)
717       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
718         device_type_(device_type),
719         prefer_tf2xla_(prefer_tf2xla),
720         legalize_test_only_ops_(legalize_test_only_ops) {}
721 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const722   LogicalResult matchAndRewrite(Operation* op,
723                                 PatternRewriter& rewriter) const override {
724     if (!(IsOpAllowedTf2XlaFallback(op) ||
725           (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)) ||
726           (legalize_test_only_ops_ && IsOpAllowedForTesting(op))))
727       return failure();
728     return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
729   }
730 
731  private:
732   std::string device_type_;
733   bool prefer_tf2xla_;
734   bool legalize_test_only_ops_;
735 };
736 
// Function pass (-xla-legalize-tf-with-tf2xla) that legalizes allowlisted TF
// ops to mhlo by applying Tf2XlaRewritePattern greedily over the function.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
 public:
  LegalizeTF() = default;

  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<MhloDialect>();
  }

  // Programmatic constructor used by createLegalizeTfWithTf2XlaPass; sets the
  // pass options directly instead of parsing them from the command line.
  explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
    device_type_ = device_type.str();
    prefer_tf2xla_ = prefer_tf2xla;
  }

  // Copy constructor required by the pass infrastructure; Option members are
  // not copyable, so a fresh default-valued instance is produced.
  LegalizeTF(const LegalizeTF&) {}

  StringRef getArgument() const final { return "xla-legalize-tf-with-tf2xla"; }
  StringRef getDescription() const final {
    return "Legalize from TensorFlow to the HLO dialect using tf2xla kernels";
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns(&getContext());
    patterns.insert<Tf2XlaRewritePattern>(
        &getContext(), device_type_, prefer_tf2xla_, legalize_test_only_ops_);
    // Best-effort greedy application; failure here means the driver did not
    // converge, not that some op stayed unlegalized.
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
      signalPassFailure();
  }

 private:
  // TODO(hinsu): Support finer grained device type assignment instead of a
  // global device type for all TensorFlow ops.
  Option<std::string> device_type_{
      *this, "device-type",
      llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
  Option<bool> prefer_tf2xla_{
      *this,
      "prefer-tf2xla",
      llvm::cl::desc("Enable legalization when it is not in the list of "
                     "MLIR-legalized ops."),
  };
  Option<bool> legalize_test_only_ops_{
      *this,
      "legalize-test-only-ops",
      llvm::cl::desc("Enable tf2xla legalizations for some ops that are "
                     "enabled only for testing."),
  };
};
785 
786 static PassRegistration<LegalizeTF> pass;
787 
788 }  // end namespace
789 
PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,OwningRewritePatternList & patterns,MLIRContext * ctx,bool prefer_tf2xla)790 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
791                                           OwningRewritePatternList& patterns,
792                                           MLIRContext* ctx,
793                                           bool prefer_tf2xla) {
794   patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla,
795                                         /*legalize_test_only_ops=*/false);
796 }
797 
createLegalizeTfWithTf2XlaPass(llvm::StringRef device_type,bool prefer_tf2xla)798 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
799     llvm::StringRef device_type, bool prefer_tf2xla) {
800   return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
801 }
802 
803 }  // end namespace mhlo
804 }  // end namespace mlir
805