1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" 17 18 #include <algorithm> 19 #include <cstdint> 20 #include <functional> 21 #include <limits> 22 #include <numeric> 23 #include <string> 24 #include <tuple> 25 #include <type_traits> 26 27 #include "llvm/ADT/APFloat.h" 28 #include "llvm/ADT/APInt.h" 29 #include "llvm/ADT/ArrayRef.h" 30 #include "llvm/ADT/Optional.h" 31 #include "llvm/ADT/STLExtras.h" 32 #include "llvm/ADT/Sequence.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/StringExtras.h" 35 #include "llvm/ADT/StringRef.h" 36 #include "llvm/ADT/StringSwitch.h" 37 #include "llvm/ADT/iterator_range.h" 38 #include "llvm/Support/Casting.h" 39 #include "llvm/Support/FormatVariadic.h" 40 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project 41 #include "mlir/Dialect/Traits.h" // from @llvm-project 42 #include "mlir/IR/Attributes.h" // from @llvm-project 43 #include "mlir/IR/Builders.h" // from @llvm-project 44 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 45 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 46 #include "mlir/IR/Diagnostics.h" // from @llvm-project 47 #include "mlir/IR/DialectImplementation.h" // from @llvm-project 48 #include "mlir/IR/Identifier.h" // from @llvm-project 49 #include "mlir/IR/Location.h" // from @llvm-project 50 #include "mlir/IR/MLIRContext.h" // from @llvm-project 51 #include "mlir/IR/Matchers.h" // from @llvm-project 52 #include "mlir/IR/OpDefinition.h" // from @llvm-project 53 #include "mlir/IR/OpImplementation.h" // from @llvm-project 54 #include "mlir/IR/PatternMatch.h" // from @llvm-project 55 #include "mlir/IR/TypeUtilities.h" // from @llvm-project 56 #include "mlir/IR/Types.h" // from @llvm-project 57 #include "mlir/IR/Value.h" // from @llvm-project 58 #include "mlir/Parser.h" // from @llvm-project 59 #include "mlir/Support/LLVM.h" // from @llvm-project 60 #include "mlir/Support/LogicalResult.h" // from @llvm-project 61 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project 62 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" 63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" 64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" 65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" 66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" 67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" 68 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" 69 #include "tensorflow/core/platform/logging.h" 70 #include "tensorflow/core/util/tensor_format.h" 71 72 namespace mlir { 73 namespace TF { 74 namespace { 75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" 76 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" 77 } // namespace 78 79 //===----------------------------------------------------------------------===// 80 // _XlaHostComputeOp 81 //===----------------------------------------------------------------------===// 82 83 // This verifies that `_XlaHostComputeMlirOp` has a well-formed 84 // `host_mlir_module` attribute. 85 // For other attributes, there is no additional verification beyond the default. Verify(_XlaHostComputeMlirOp op)86static LogicalResult Verify(_XlaHostComputeMlirOp op) { 87 // Extract the module and function. 88 StringRef host_module = op.host_mlir_module(); 89 90 if (host_module.empty()) return success(); 91 92 mlir::OwningModuleRef module_for_func; 93 tensorflow::Status status = tensorflow::DeserializeMlirModule( 94 host_module.str(), op->getContext(), &module_for_func); 95 if (!status.ok()) { 96 return op.emitError() 97 << "attribute 'host_mlir_module' can not be deserialized. " 98 << status.error_message(); 99 } 100 101 FuncOp func = module_for_func->lookupSymbol<FuncOp>("host_func"); 102 if (!func) 103 return op.emitError() 104 << "serialized module in attribute 'host_mlir_module' does not " 105 "contain 'host_func' function."; 106 107 if (op->getNumOperands() != func.getType().getNumInputs()) 108 return op.emitError() 109 << "'host_func' has " << func.getType().getNumInputs() 110 << " inputs and '_XlaHostComputeMlir' has " << op->getNumOperands() 111 << " operands. Number of operands/inputs should be the same."; 112 113 if (op->getNumResults() != func.getType().getNumResults()) 114 return op.emitError() << "'host_func' has " 115 << func.getType().getNumResults() 116 << " results and '_XlaHostComputeMlir' has " 117 << op->getNumResults() 118 << " results. Number of results should be the same."; 119 120 return success(); 121 } 122 GetHostFunc(mlir::OwningModuleRef * mlir_module)123FuncOp _XlaHostComputeMlirOp::GetHostFunc(mlir::OwningModuleRef* mlir_module) { 124 if (!tensorflow::DeserializeMlirModule(host_mlir_module().str(), 125 this->getContext(), mlir_module) 126 .ok()) 127 return nullptr; 128 return (*mlir_module)->lookupSymbol<FuncOp>("host_func"); 129 } 130 131 } // namespace TF 132 } // namespace mlir 133 134 //===----------------------------------------------------------------------===// 135 // TableGen'd op method definitions 136 //===----------------------------------------------------------------------===// 137 138 #define GET_OP_CLASSES 139 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc" 140