1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h"
17
18 #include <algorithm>
19
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/OperationSupport.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Transforms/InliningUtils.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
30 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
31
32 namespace mlir {
33 namespace tf_jitrt {
34
35 //===----------------------------------------------------------------------===//
36 // JitRuntimeDialect Interfaces
37 //===----------------------------------------------------------------------===//
38
39 namespace {
40 // Operations in the `tf_jitrt` dialect are always safe to inline because they
41 // are pure compute operations.
42 struct JitRuntimeInlinerInterface : public DialectInlinerInterface {
43 using DialectInlinerInterface::DialectInlinerInterface;
44
isLegalToInlinemlir::tf_jitrt::__anon60635f270111::JitRuntimeInlinerInterface45 bool isLegalToInline(Operation*, Operation*, bool) const final {
46 assert(false && "tf_jitrt doesn't have callable operations");
47 return true;
48 }
49
isLegalToInlinemlir::tf_jitrt::__anon60635f270111::JitRuntimeInlinerInterface50 bool isLegalToInline(Region*, Region*, bool,
51 BlockAndValueMapping&) const final {
52 return true;
53 }
54
isLegalToInlinemlir::tf_jitrt::__anon60635f270111::JitRuntimeInlinerInterface55 bool isLegalToInline(Operation*, Region*, bool,
56 BlockAndValueMapping&) const final {
57 return true;
58 }
59 };
60 } // namespace
61
62 //===----------------------------------------------------------------------===//
63 // JitRuntimeDialect Dialect
64 //===----------------------------------------------------------------------===//
65
JitRuntimeDialect(mlir::MLIRContext * context)66 JitRuntimeDialect::JitRuntimeDialect(mlir::MLIRContext* context)
67 : Dialect(/*name*/ "tf_jitrt", context,
68 mlir::TypeID::get<JitRuntimeDialect>()) {
69 addInterfaces<JitRuntimeInlinerInterface>();
70 addOperations<
71 #define GET_OP_LIST
72 #include "tensorflow/compiler/mlir/tfrt/tf_jitrt_ops.cc.inc"
73 >();
74 }
75
76 // Computes the number of elements in the tensor type. Optimistically use `1` as
77 // a size of all unknown dimensions. These heuristics match cost estimates of
78 // the fallback_async::ExecuteOp operations.
GetRankedTensorSize(TensorType tensor)79 static int64_t GetRankedTensorSize(TensorType tensor) {
80 assert(tensor.hasRank() && "shape must be ranked");
81 if (!tensor.hasRank()) return 0;
82
83 int64_t size = 1; // scalars (rank 0) have size 1
84 for (int64_t dim : tensor.getShape()) size *= std::max<int64_t>(1, dim);
85 return size;
86 }
87
GetMaxArgSize(mlir::func::FuncOp func)88 int64_t GetMaxArgSize(mlir::func::FuncOp func) {
89 int64_t max_arg_size = 1;
90 for (BlockArgument& arg : func.getArguments()) {
91 auto type = arg.getType().cast<mlir::TensorType>();
92 if (type.hasRank())
93 max_arg_size = std::max(max_arg_size, GetRankedTensorSize(type));
94 }
95 return max_arg_size;
96 }
97
cost()98 int64_t FallbackExecuteOp::cost() {
99 Operation* self = getOperation();
100
101 // Find the referenced kernel function.
102 auto kernel_fn =
103 SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(self, kernel());
104 if (!kernel_fn) return 1;
105
106 int64_t cost = 0;
107
108 // Compute the max argument size, which we will assign to unranked inputs
109 // just like TFRT's cost model does.
110 int64_t max_arg_size = GetMaxArgSize(kernel_fn);
111
112 // Maybe override max argument size with explicit value passed via attribute.
113 auto module = kernel_fn->getParentOfType<mlir::ModuleOp>();
114 if (auto attr = module->getAttrOfType<IntegerAttr>("tfrt.max-arg-size"))
115 max_arg_size = attr.getValue().getSExtValue();
116
117 // Get the sum of sizes of all ranked inputs for all operations in the
118 // function body. This approach approximates the cost analysis in the
119 // tfrt_compiler::CostAnalysis, because initially we want to get identical
120 // stream assignments, however long term we want to use more precise cost
121 // estimation, together with a more precise stream assignment.
122 //
123 // TODO(ezhulenev): Once we have a proper cost model for MLIR operations,
124 // use it to compute a more precise cost estimation.
125 for (mlir::Operation& op : kernel_fn.getBody().getOps()) {
126 // Skip return operation.
127 if (mlir::isa<mlir::func::ReturnOp>(op)) continue;
128
129 // These ops are cheap regardless of their input sizes.
130 if (mlir::isa<mlir::TF::ShapeOp, mlir::TF::StridedSliceOp,
131 mlir::TF::ReshapeOp, mlir::TF::ExpandDimsOp>(op)) {
132 cost += 1;
133 continue;
134 }
135
136 // Set initial op cost to 1, just like TFRT's cost model does.
137 cost += 1;
138 for (Type type : op.getOperandTypes()) {
139 if (auto tensor = type.dyn_cast<RankedTensorType>()) {
140 cost += GetRankedTensorSize(tensor);
141 } else {
142 cost += max_arg_size;
143 }
144 }
145 }
146
147 return std::max<int64_t>(1, cost);
148 }
149
150 } // namespace tf_jitrt
151 } // end namespace mlir
152
153 #define GET_OP_CLASSES
154 #include "tensorflow/compiler/mlir/tfrt/tf_jitrt_ops.cc.inc"
155