• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/iterator_range.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
41 #include "mlir/Dialect/Traits.h"  // from @llvm-project
42 #include "mlir/IR/Attributes.h"  // from @llvm-project
43 #include "mlir/IR/Builders.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
45 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
46 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
47 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
48 #include "mlir/IR/Identifier.h"  // from @llvm-project
49 #include "mlir/IR/Location.h"  // from @llvm-project
50 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
51 #include "mlir/IR/Matchers.h"  // from @llvm-project
52 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
53 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
54 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
55 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
56 #include "mlir/IR/Types.h"  // from @llvm-project
57 #include "mlir/IR/Value.h"  // from @llvm-project
58 #include "mlir/Parser.h"  // from @llvm-project
59 #include "mlir/Support/LLVM.h"  // from @llvm-project
60 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
61 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
62 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
68 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
69 #include "tensorflow/core/platform/logging.h"
70 #include "tensorflow/core/util/tensor_format.h"
71 
72 namespace mlir {
73 namespace TF {
74 namespace {
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
76 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
77 }  // namespace
78 
79 //===----------------------------------------------------------------------===//
80 // _XlaHostComputeOp
81 //===----------------------------------------------------------------------===//
82 
83 // This verifies that `_XlaHostComputeMlirOp` has a well-formed
84 // `host_mlir_module` attribute.
85 // For other attributes, there is no additional verification beyond the default.
Verify(_XlaHostComputeMlirOp op)86 static LogicalResult Verify(_XlaHostComputeMlirOp op) {
87   // Extract the module and function.
88   StringRef host_module = op.host_mlir_module();
89 
90   if (host_module.empty()) return success();
91 
92   mlir::OwningModuleRef module_for_func;
93   tensorflow::Status status = tensorflow::DeserializeMlirModule(
94       host_module.str(), op->getContext(), &module_for_func);
95   if (!status.ok()) {
96     return op.emitError()
97            << "attribute 'host_mlir_module' can not be deserialized. "
98            << status.error_message();
99   }
100 
101   FuncOp func = module_for_func->lookupSymbol<FuncOp>("host_func");
102   if (!func)
103     return op.emitError()
104            << "serialized module in attribute 'host_mlir_module' does not "
105               "contain 'host_func' function.";
106 
107   if (op->getNumOperands() != func.getType().getNumInputs())
108     return op.emitError()
109            << "'host_func' has " << func.getType().getNumInputs()
110            << " inputs and '_XlaHostComputeMlir' has " << op->getNumOperands()
111            << " operands.  Number of operands/inputs should be the same.";
112 
113   if (op->getNumResults() != func.getType().getNumResults())
114     return op.emitError() << "'host_func' has "
115                           << func.getType().getNumResults()
116                           << " results and '_XlaHostComputeMlir' has "
117                           << op->getNumResults()
118                           << " results.  Number of results should be the same.";
119 
120   return success();
121 }
122 
GetHostFunc(mlir::OwningModuleRef * mlir_module)123 FuncOp _XlaHostComputeMlirOp::GetHostFunc(mlir::OwningModuleRef* mlir_module) {
124   if (!tensorflow::DeserializeMlirModule(host_mlir_module().str(),
125                                          this->getContext(), mlir_module)
126            .ok())
127     return nullptr;
128   return (*mlir_module)->lookupSymbol<FuncOp>("host_func");
129 }
130 
131 }  // namespace TF
132 }  // namespace mlir
133 
134 //===----------------------------------------------------------------------===//
135 // TableGen'd op method definitions
136 //===----------------------------------------------------------------------===//
137 
138 #define GET_OP_CLASSES
139 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc"
140