1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
29 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/Diagnostics.h" // from @llvm-project
34 #include "mlir/IR/Location.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/PatternMatch.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
45 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
49 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
50 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
51 #include "tensorflow/compiler/tf2xla/xla_context.h"
52 #include "tensorflow/compiler/tf2xla/xla_expression.h"
53 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
54 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
55 #include "tensorflow/compiler/xla/client/xla_builder.h"
56 #include "tensorflow/core/common_runtime/device.h"
57 #include "tensorflow/core/common_runtime/device_factory.h"
58 #include "tensorflow/core/common_runtime/device_mgr.h"
59 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
60 #include "tensorflow/core/framework/allocator.h"
61 #include "tensorflow/core/framework/function.h"
62 #include "tensorflow/core/framework/function.pb.h"
63 #include "tensorflow/core/framework/node_properties.h"
64 #include "tensorflow/core/framework/op.h"
65 #include "tensorflow/core/framework/op_kernel.h"
66 #include "tensorflow/core/framework/resource_mgr.h"
67 #include "tensorflow/core/framework/tensor.h"
68 #include "tensorflow/core/framework/types.h"
69 #include "tensorflow/core/framework/types.pb.h"
70 #include "tensorflow/core/platform/env.h"
71 #include "tensorflow/core/platform/status.h"
72 #include "tensorflow/core/protobuf/config.pb.h"
73 #include "tensorflow/core/public/session_options.h"
74 #include "tensorflow/stream_executor/lib/statusor.h"
75 #include "tensorflow/stream_executor/stream_executor.h"
76
77 namespace mlir {
78 namespace mhlo {
79
IsOpAllowedTf2XlaFallback(Operation * op)80 bool IsOpAllowedTf2XlaFallback(Operation* op) {
81 // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
82 // building valid MLIR using MlirHloBuilder.
83 // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
84 // all tf2xla kernels.
85 // clang-format off
86
87 static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
88 TypeID::get<TF::AbsOp>(),
89 TypeID::get<TF::AcoshOp>(),
90 TypeID::get<TF::AcosOp>(),
91 TypeID::get<TF::AddNOp>(),
92 TypeID::get<TF::AddV2Op>(),
93 TypeID::get<TF::AngleOp>(),
94 TypeID::get<TF::AdjustContrastv2Op>(),
95 TypeID::get<TF::AdjustHueOp>(),
96 TypeID::get<TF::AdjustSaturationOp>(),
97 TypeID::get<TF::ApproximateEqualOp>(),
98 TypeID::get<TF::ArgMaxOp>(),
99 TypeID::get<TF::ArgMinOp>(),
100 TypeID::get<TF::AsinhOp>(),
101 TypeID::get<TF::AsinOp>(),
102 TypeID::get<TF::Atan2Op>(),
103 TypeID::get<TF::AtanhOp>(),
104 TypeID::get<TF::AtanOp>(),
105 TypeID::get<TF::BatchMatMulV2Op>(),
106 TypeID::get<TF::BatchToSpaceOp>(),
107 TypeID::get<TF::BesselI0eOp>(),
108 TypeID::get<TF::BesselI1eOp>(),
109 TypeID::get<TF::BetaincOp>(),
110 TypeID::get<TF::BiasAddGradOp>(),
111 TypeID::get<TF::BiasAddOp>(),
112 TypeID::get<TF::BitwiseAndOp>(),
113 TypeID::get<TF::BitwiseOrOp>(),
114 TypeID::get<TF::BitwiseXorOp>(),
115 TypeID::get<TF::BucketizeOp>(),
116 TypeID::get<TF::CastOp>(),
117 TypeID::get<TF::ClipByValueOp>(),
118 TypeID::get<TF::CholeskyOp>(),
119 TypeID::get<TF::ComplexAbsOp>(),
120 TypeID::get<TF::ConjugateTransposeOp>(),
121 TypeID::get<TF::CoshOp>(),
122 TypeID::get<TF::CrossOp>(),
123 TypeID::get<TF::DataFormatDimMapOp>(),
124 TypeID::get<TF::DataFormatVecPermuteOp>(),
125 TypeID::get<TF::DepthToSpaceOp>(),
126 TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
127 TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
128 TypeID::get<TF::DiagOp>(),
129 TypeID::get<TF::DigammaOp>(),
130 TypeID::get<TF::DivNoNanOp>(),
131 TypeID::get<TF::EluGradOp>(),
132 TypeID::get<TF::EluOp>(),
133 TypeID::get<TF::EqualOp>(),
134 TypeID::get<TF::ErfcOp>(),
135 TypeID::get<TF::ErfinvOp>(),
136 TypeID::get<TF::ErfOp>(),
137 TypeID::get<TF::ExtractImagePatchesOp>(),
138 TypeID::get<TF::FFT2DOp>(),
139 TypeID::get<TF::FFT3DOp>(),
140 TypeID::get<TF::FFTOp>(),
141 TypeID::get<TF::FakeParamOp>(),
142 TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
143 TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
144 TypeID::get<TF::FloorDivOp>(),
145 TypeID::get<TF::FloorModOp>(),
146 TypeID::get<TF::GatherNdOp>(),
147 TypeID::get<TF::GreaterEqualOp>(),
148 TypeID::get<TF::GreaterOp>(),
149 TypeID::get<TF::HSVToRGBOp>(),
150 TypeID::get<TF::IFFT2DOp>(),
151 TypeID::get<TF::IFFT3DOp>(),
152 TypeID::get<TF::IRFFT2DOp>(),
153 TypeID::get<TF::IRFFT3DOp>(),
154 TypeID::get<TF::IgammaOp>(),
155 TypeID::get<TF::IgammacOp>(),
156 TypeID::get<TF::IgammaGradAOp>(),
157 TypeID::get<TF::InplaceAddOp>(),
158 TypeID::get<TF::InTopKV2Op>(),
159 TypeID::get<TF::InvertOp>(),
160 TypeID::get<TF::InvOp>(),
161 TypeID::get<TF::KthOrderStatisticOp>(),
162 TypeID::get<TF::LRNOp>(),
163 TypeID::get<TF::LRNGradOp>(),
164 TypeID::get<TF::LeakyReluGradOp>(),
165 TypeID::get<TF::LeakyReluOp>(),
166 TypeID::get<TF::LeftShiftOp>(),
167 TypeID::get<TF::LessEqualOp>(),
168 TypeID::get<TF::LessOp>(),
169 TypeID::get<TF::ListDiffOp>(),
170 TypeID::get<TF::LogicalAndOp>(),
171 TypeID::get<TF::LogicalNotOp>(),
172 TypeID::get<TF::LogicalOrOp>(),
173 TypeID::get<TF::LogOp>(),
174 TypeID::get<TF::LowerBoundOp>(),
175 TypeID::get<TF::MakeUniqueOp>(),
176 TypeID::get<TF::MatMulOp>(),
177 TypeID::get<TF::MatrixDiagV3Op>(),
178 TypeID::get<TF::MatrixInverseOp>(),
179 TypeID::get<TF::MatrixSetDiagV3Op>(),
180 TypeID::get<TF::MatrixSolveOp>(),
181 TypeID::get<TF::MatrixTriangularSolveOp>(),
182 TypeID::get<TF::MaxPool3DGradGradOp>(),
183 TypeID::get<TF::MaxPoolGradGradOp>(),
184 TypeID::get<TF::MirrorPadOp>(),
185 TypeID::get<TF::MirrorPadGradOp>(),
186 TypeID::get<TF::MulOp>(),
187 TypeID::get<TF::MultinomialOp>(),
188 TypeID::get<TF::NdtriOp>(),
189 TypeID::get<TF::NegOp>(),
190 TypeID::get<TF::NextAfterOp>(),
191 TypeID::get<TF::NonMaxSuppressionV4Op>(),
192 TypeID::get<TF::NotEqualOp>(),
193 TypeID::get<TF::PadOp>(),
194 TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
195 TypeID::get<TF::PlaceholderWithDefaultOp>(),
196 TypeID::get<TF::PolygammaOp>(),
197 TypeID::get<TF::PopulationCountOp>(),
198 TypeID::get<TF::PowOp>(),
199 // TODO(hinsu): Canonicalize QuantizeAndDequantize and
200 // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
201 // attributes to operands.
202 TypeID::get<TF::QuantizeAndDequantizeOp>(),
203 TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
204 TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
205 TypeID::get<TF::RFFT2DOp>(),
206 TypeID::get<TF::RFFT3DOp>(),
207 TypeID::get<TF::RGBToHSVOp>(),
208 TypeID::get<TF::RandomUniformIntOp>(),
209 TypeID::get<TF::RealDivOp>(),
210 TypeID::get<TF::ReciprocalGradOp>(),
211 TypeID::get<TF::Relu6GradOp>(),
212 TypeID::get<TF::ResizeBilinearOp>(),
213 TypeID::get<TF::ResizeBilinearGradOp>(),
214 TypeID::get<TF::ResizeNearestNeighborOp>(),
215 TypeID::get<TF::ResizeNearestNeighborGradOp>(),
216 TypeID::get<TF::ReverseSequenceOp>(),
217 TypeID::get<TF::RightShiftOp>(),
218 TypeID::get<TF::RintOp>(),
219 TypeID::get<TF::RollOp>(),
220 TypeID::get<TF::RoundOp>(),
221 TypeID::get<TF::SelectV2Op>(),
222 TypeID::get<TF::SelfAdjointEigV2Op>(),
223 TypeID::get<TF::SeluGradOp>(),
224 TypeID::get<TF::SeluOp>(),
225 TypeID::get<TF::SigmoidGradOp>(),
226 TypeID::get<TF::SinhOp>(),
227 TypeID::get<TF::SinOp>(),
228 TypeID::get<TF::SoftplusGradOp>(),
229 TypeID::get<TF::SoftsignGradOp>(),
230 TypeID::get<TF::SoftsignOp>(),
231 TypeID::get<TF::SpaceToBatchNDOp>(),
232 TypeID::get<TF::SpaceToBatchOp>(),
233 TypeID::get<TF::SpaceToDepthOp>(),
234 TypeID::get<TF::SparseToDenseOp>(),
235 TypeID::get<TF::SquareOp>(),
236 TypeID::get<TF::StatelessMultinomialOp>(),
237 TypeID::get<TF::StatelessRandomGetAlgOp>(),
238 TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
239 TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
240 TypeID::get<TF::StatelessRandomNormalOp>(),
241 TypeID::get<TF::StatelessRandomNormalV2Op>(),
242 TypeID::get<TF::StatelessRandomUniformOp>(),
243 TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
244 TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
245 TypeID::get<TF::StatelessRandomUniformV2Op>(),
246 TypeID::get<TF::StatelessRandomUniformIntOp>(),
247 TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
248 TypeID::get<TF::StatelessTruncatedNormalOp>(),
249 TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
250 TypeID::get<TF::SubOp>(),
251 TypeID::get<TF::SvdOp>(),
252 TypeID::get<TF::TanOp>(),
253 TypeID::get<TF::TensorScatterAddOp>(),
254 TypeID::get<TF::TensorScatterSubOp>(),
255 TypeID::get<TF::TPUEmbeddingActivationsOp>(),
256 TypeID::get<TF::TopKUniqueOp>(),
257 TypeID::get<TF::TopKWithUniqueOp>(),
258 TypeID::get<TF::TransposeOp>(),
259 TypeID::get<TF::TridiagonalSolveOp>(),
260 TypeID::get<TF::TruncateDivOp>(),
261 TypeID::get<TF::TruncatedNormalOp>(),
262 TypeID::get<TF::TruncateModOp>(),
263 TypeID::get<TF::UnpackOp>(),
264 TypeID::get<TF::UpperBoundOp>(),
265 TypeID::get<TF::XlaBroadcastHelperOp>(),
266 TypeID::get<TF::XlaConvOp>(),
267 TypeID::get<TF::XlaDotOp>(),
268 TypeID::get<TF::XlaDynamicSliceOp>(),
269 TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
270 TypeID::get<TF::XlaEinsumOp>(),
271 TypeID::get<TF::XlaKeyValueSortOp>(),
272 TypeID::get<TF::XlaPadOp>(),
273 TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
274 TypeID::get<TF::XlaSortOp>(),
275 TypeID::get<TF::XlaSvdOp>(),
276 TypeID::get<TF::ZetaOp>()
277 };
278 // clang-format on
279
280 auto* abstractOp = op->getAbstractOperation();
281 if (!abstractOp) return false;
282 return ops.count(abstractOp->typeID);
283 }
284
285 namespace {
286
287 template <typename T, size_t N>
288 using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
289
CreateDeviceMgr(const std::string & device_type)290 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
291 const std::string& device_type) {
292 // Register compilation kernels for all registered XLA backends.
293 tensorflow::XlaOpRegistry::RegisterCompilationKernels();
294
295 auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
296 tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
297 return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
298 }
299
300 class Tf2XlaRewriter {
301 public:
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::string & device_type)302 static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
303 const std::string& device_type) {
304 Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
305 return tf2xla_rewriter.LegalizeOp();
306 }
307
308 private:
Tf2XlaRewriter(Operation * op,PatternRewriter & rewriter,const std::string & device_type)309 Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
310 const std::string& device_type)
311 : op_(op),
312 device_type_(device_type),
313 rewriter_(rewriter),
314 hlo_builder_(op->getName().getStringRef().str(), rewriter_,
315 op->getLoc()),
316 context_(nullptr) {}
317
~Tf2XlaRewriter()318 ~Tf2XlaRewriter() {
319 if (context_) context_->Unref();
320 }
321
322 // Prepares OpKernelContext params common to all the ops.
323 // Emits an error on failure.
324 LogicalResult PrepareParams();
325
326 // Tries to legalize the specified TensorFlow op, if supported.
327 //
328 // Emits an error and returns failure if an error is encountered during
329 // conversion. Note that success return value doesn't mean successful
330 // legalization.
331 LogicalResult LegalizeOp();
332
333 // Converts the given operand to expression of kind kConstant or kXlaOp.
334 // Emits a remark and returns expression of kind kInvalid on failure.
335 tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);
336
337 Operation* op_;
338 std::string device_type_;
339
340 PatternRewriter& rewriter_;
341 ::xla::MlirHloBuilder hlo_builder_;
342 tensorflow::OpOrArgLocNameMapper name_mapper_;
343
344 tensorflow::XlaContext* context_; // Ref-counted.
345
346 std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
347 tensorflow::Device* device_; // Owned by device_mgr_;
348 std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
349 std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
350 std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
351 tensorflow::OpKernelContext::Params params_;
352 };
353
PrepareParams()354 LogicalResult Tf2XlaRewriter::PrepareParams() {
355 // XlaCompiler within the context is only used by the functional ops to
356 // compile functions. We are not handling those at the moment so XlaCompiler
357 // is not required.
358 context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
359 /*graph=*/nullptr);
360 context_->Ref();
361
362 device_mgr_ = CreateDeviceMgr(device_type_);
363 if (!device_mgr_) return failure();
364
365 // Type of params_.device is DeviceBase* so store it as Device* to access
366 // derived class method.
367 device_ = device_mgr_->ListDevices().front();
368 params_.device = device_;
369 params_.resource_manager = device_->resource_manager();
370
371 // Resources are cleared at the time of device manager destruction so pass
372 // no-op cleanup function.
373 auto cleanup = [](const std::string& name) {};
374 // Use step_id zero as we only have a single context concurrently and
375 // concurrently running each of the MLIR functions create a new device.
376 step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
377 /*step_id=*/0, cleanup);
378 tensorflow::Status status = step_container_->Create(
379 device_->resource_manager(),
380 tensorflow::XlaContext::kXlaContextResourceName, context_);
381 if (!status.ok()) {
382 return emitError(op_->getLoc())
383 << "failed to create XlaContext resource: " << status.ToString();
384 }
385 params_.step_container = step_container_.get();
386
387 tensorflow::StatusOr<int64_t> version_or =
388 tensorflow::GetTfGraphProducerVersion(
389 op_->getParentOfType<mlir::ModuleOp>());
390 if (!version_or.ok()) {
391 return emitError(op_->getLoc()) << version_or.status().ToString();
392 }
393
394 flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
395 tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
396 pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
397 device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
398 version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
399 params_.function_library = pflr_->GetFLR(device_->name());
400 return success();
401 }
402
LegalizeOp()403 LogicalResult Tf2XlaRewriter::LegalizeOp() {
404 // Only static shaped operands are supported in XLA builders for now.
405 for (Type ty : op_->getOperandTypes()) {
406 auto ranked_ty = ty.dyn_cast<ShapedType>();
407 if (!ranked_ty || !ranked_ty.hasStaticShape()) {
408 return op_->emitRemark()
409 << "lowering requires static shaped tensor operands";
410 }
411 }
412
413 for (const auto& attr : op_->getAttrs()) {
414 if (attr.second.isa<SymbolRefAttr>()) {
415 return op_->emitRemark()
416 << "ops with symbol references are not supported";
417 }
418 }
419
420 auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
421 op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
422 if (!nodedef_or.ok()) {
423 return op_->emitRemark() << "failed to convert op to NodeDef: "
424 << nodedef_or.status().ToString();
425 }
426
427 if (failed(PrepareParams())) return failure();
428
429 std::shared_ptr<const tensorflow::NodeProperties> props;
430 tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
431 *nodedef_or.ValueOrDie(),
432 params_.function_library->GetFunctionLibraryDefinition(), &props);
433 if (!status.ok()) {
434 return op_->emitRemark()
435 << "failed to create NodeProperties: " << status.ToString();
436 }
437 tensorflow::OpKernel* op_kernel_raw;
438 status = params_.function_library->CreateKernel(props, &op_kernel_raw);
439 if (!status.ok()) {
440 return op_->emitRemark()
441 << "failed to create tf2xla kernel: " << status.ToString();
442 }
443 // Transfer ownership of the kernel to a local smart pointer.
444 auto op_kernel = absl::WrapUnique(op_kernel_raw);
445
446 std::vector<int> required_constants;
447 status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
448 *op_kernel, &required_constants);
449 if (!status.ok()) {
450 return op_->emitRemark()
451 << "failed to compute required constants: " << status.ToString();
452 }
453 llvm::SmallDenseSet<int, 4> required_consts;
454 required_consts.insert(required_constants.begin(), required_constants.end());
455
456 // TensorValue in inputs are backed by tensors which in turn depend on
457 // expressions. So, pre-allocate them to the required size.
458 InlinedVector<tensorflow::XlaExpression, 4> expressions;
459 InlinedVector<tensorflow::Tensor, 4> tensors;
460 InlinedVector<tensorflow::TensorValue, 4> inputs;
461 expressions.reserve(op_->getNumOperands());
462 tensors.reserve(op_->getNumOperands());
463 inputs.reserve(op_->getNumOperands());
464
465 // Prepare the list of Tensor inputs for the kernel.
466 for (auto it : llvm::enumerate(op_->getOperands())) {
467 Value operand = it.value();
468 size_t idx = it.index();
469
470 tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
471 tensorflow::XlaExpression::Kind kind = expr.kind();
472 if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
473 if (required_consts.count(idx) &&
474 kind != tensorflow::XlaExpression::Kind::kConstant) {
475 return op_->emitRemark()
476 << "lowering requires operand #" << idx << " to be a constant";
477 }
478 expressions.push_back(expr);
479
480 if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
481 return op_->emitRemark()
482 << "skipping legalization due to unsupported type "
483 << operand.getType();
484 }
485
486 auto shape_or = expr.GetShape();
487 if (!shape_or.ok()) {
488 return op_->emitRemark()
489 << "failed to get shape for expression. " << expr.HumanString();
490 }
491
492 tensors.emplace_back(
493 device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
494 shape_or.ValueOrDie());
495 tensorflow::Tensor& tensor = tensors.back();
496 tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
497 inputs.emplace_back(&tensor);
498 }
499
500 params_.inputs = &inputs;
501 params_.op_kernel = op_kernel.get();
502 llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
503 op_->getNumResults());
504 params_.output_attr_array = output_attr.data();
505
506 hlo_builder_.setInsertionPoint(op_);
507 hlo_builder_.SetLocation(op_->getLoc());
508
509 // Execute the kernel.
510 tensorflow::OpKernelContext op_context(¶ms_, op_->getNumResults());
511 device_->Compute(params_.op_kernel, &op_context);
512 if (!op_context.status().ok()) {
513 return op_->emitRemark()
514 << "compilation to HLO failed: " << op_context.status().ToString();
515 }
516
517 // Replace uses of old results using the corresponding value after the
518 // lowering.
519 llvm::SmallVector<Value, 2> values;
520 values.reserve(op_->getNumResults());
521 for (int i = 0, e = op_->getNumResults(); i < e; i++) {
522 tensorflow::Tensor* output = op_context.mutable_output(i);
523 const tensorflow::XlaExpression* expr =
524 tensorflow::XlaExpression::CastExpressionFromTensor(*output);
525 if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
526 return op_->emitError(
527 "expects XlaExpression of kind kXlaOp in compiled output");
528 auto value = hlo_builder_.GetValue(expr->handle());
529 mlir::OpResult old_result = op_->getResult(i);
530 if (value.getType() != old_result.getType()) {
531 value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
532 value);
533 }
534 values.push_back(value);
535 }
536 rewriter_.replaceOp(op_, values);
537 return success();
538 }
539
GetExprForOperand(Value operand,Operation * op)540 tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
541 Operation* op) {
542 ElementsAttr const_attr;
543 auto defining_op = operand.getDefiningOp();
544 if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
545 tensorflow::Tensor tensor;
546 auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
547 if (!status.ok()) {
548 op->emitRemark() << "skipping legalization due to failed const conversion"
549 << status.ToString();
550 return tensorflow::XlaExpression::Invalid();
551 }
552 return tensorflow::XlaExpression::Constant(tensor);
553 }
554
555 // Skip this op if XLA doesn't support this operand type.
556 auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
557 if (!xla_op_or.ok()) {
558 op->emitRemark() << "skipping legalization due to "
559 << xla_op_or.status().ToString();
560 return tensorflow::XlaExpression::Invalid();
561 }
562 ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
563
564 tensorflow::DataType dtype;
565 auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
566 if (!status.ok()) {
567 op->emitRemark() << "skipping legalization due to " << status.ToString();
568 return tensorflow::XlaExpression::Invalid();
569 }
570 return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
571 }
572
573 class Tf2XlaRewritePattern : public RewritePattern {
574 public:
Tf2XlaRewritePattern(const std::string & device_type)575 explicit Tf2XlaRewritePattern(const std::string& device_type)
576 : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
577 device_type_(device_type) {}
578
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const579 LogicalResult matchAndRewrite(Operation* op,
580 PatternRewriter& rewriter) const override {
581 if (!IsOpAllowedTf2XlaFallback(op)) return failure();
582 return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
583 }
584
585 private:
586 std::string device_type_;
587 };
588
589 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
590 public:
591 LegalizeTF() = default;
592
LegalizeTF(llvm::StringRef device_type)593 explicit LegalizeTF(llvm::StringRef device_type) {
594 device_type_ = device_type.str();
595 }
596
LegalizeTF(const LegalizeTF &)597 LegalizeTF(const LegalizeTF&) {}
598
runOnFunction()599 void runOnFunction() override {
600 OwningRewritePatternList patterns;
601 patterns.insert<Tf2XlaRewritePattern>(device_type_);
602 if (failed(
603 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
604 signalPassFailure();
605 }
606
607 private:
608 // TODO(hinsu): Support finer grained device type assignment instead of a
609 // global device type for all TensorFlow ops.
610 Option<std::string> device_type_{
611 *this, "device-type",
612 llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
613 };
614
615 static PassRegistration<LegalizeTF> pass(
616 "xla-legalize-tf-with-tf2xla",
617 "Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
618
619 } // end namespace
620
PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,OwningRewritePatternList & patterns)621 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
622 OwningRewritePatternList& patterns) {
623 patterns.insert<Tf2XlaRewritePattern>(device_type.str());
624 }
625
createLegalizeTfWithTf2XlaPass(llvm::StringRef device_type)626 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
627 llvm::StringRef device_type) {
628 return std::make_unique<LegalizeTF>(device_type);
629 }
630
631 } // end namespace mhlo
632 } // end namespace mlir
633