• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h"
17 
18 #include <algorithm>
19 
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/OperationSupport.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Transforms/InliningUtils.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
30 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
31 
32 namespace mlir {
33 namespace tf_jitrt {
34 
35 //===----------------------------------------------------------------------===//
36 // JitRuntimeDialect Interfaces
37 //===----------------------------------------------------------------------===//
38 
39 namespace {
40 // Operations in the `tf_jitrt` dialect are always safe to inline because they
41 // are pure compute operations.
42 struct JitRuntimeInlinerInterface : public DialectInlinerInterface {
43   using DialectInlinerInterface::DialectInlinerInterface;
44 
isLegalToInlinemlir::tf_jitrt::__anonce345bd80111::JitRuntimeInlinerInterface45   bool isLegalToInline(Operation*, Operation*, bool) const final {
46     assert(false && "tf_jitrt doesn't have callable operations");
47     return true;
48   }
49 
isLegalToInlinemlir::tf_jitrt::__anonce345bd80111::JitRuntimeInlinerInterface50   bool isLegalToInline(Region*, Region*, bool,
51                        BlockAndValueMapping&) const final {
52     return true;
53   }
54 
isLegalToInlinemlir::tf_jitrt::__anonce345bd80111::JitRuntimeInlinerInterface55   bool isLegalToInline(Operation*, Region*, bool,
56                        BlockAndValueMapping&) const final {
57     return true;
58   }
59 };
60 }  // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // JitRuntimeDialect Dialect
64 //===----------------------------------------------------------------------===//
65 
// Constructs the `tf_jitrt` dialect in `context`. It registers the
// JitRuntimeInlinerInterface and every operation listed in the generated
// `tf_jitrt_ops.cc.inc` file.
JitRuntimeDialect::JitRuntimeDialect(mlir::MLIRContext* context)
    : Dialect(/*name*/ "tf_jitrt", context,
              mlir::TypeID::get<JitRuntimeDialect>()) {
  addInterfaces<JitRuntimeInlinerInterface>();
  // With GET_OP_LIST defined, the included file is expected to expand to the
  // comma-separated list of op classes used as template arguments here
  // (standard mlir-tblgen convention; verify against the generated file).
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfrt/tf_jitrt_ops.cc.inc"
      >();
}
75 
76 // Computes the number of elements in the tensor type. Optimistically use `1` as
77 // a size of all unknown dimensions. These heuristics match cost estimates of
78 // the fallback_async::ExecuteOp operations.
GetRankedTensorSize(TensorType tensor)79 static int64_t GetRankedTensorSize(TensorType tensor) {
80   assert(tensor.hasRank() && "shape must be ranked");
81   if (!tensor.hasRank()) return 0;
82 
83   int64_t size = 1;  // scalars (rank 0) have size 1
84   for (int64_t dim : tensor.getShape()) size *= std::max<int64_t>(1, dim);
85   return size;
86 }
87 
GetMaxArgSize(mlir::func::FuncOp func)88 int64_t GetMaxArgSize(mlir::func::FuncOp func) {
89   int64_t max_arg_size = 1;
90   for (BlockArgument& arg : func.getArguments()) {
91     auto type = arg.getType().cast<mlir::TensorType>();
92     if (type.hasRank())
93       max_arg_size = std::max(max_arg_size, GetRankedTensorSize(type));
94   }
95   return max_arg_size;
96 }
97 
cost()98 int64_t FallbackExecuteOp::cost() {
99   Operation* self = getOperation();
100 
101   // Find the referenced kernel function.
102   auto kernel_fn =
103       SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(self, kernel());
104   if (!kernel_fn) return 1;
105 
106   int64_t cost = 0;
107 
108   // Compute the max argument size, which we will assign to unranked inputs
109   // just like TFRT's cost model does.
110   int64_t max_arg_size = GetMaxArgSize(kernel_fn);
111 
112   // Maybe override max argument size with explicit value passed via attribute.
113   auto module = kernel_fn->getParentOfType<mlir::ModuleOp>();
114   if (auto attr = module->getAttrOfType<IntegerAttr>("tfrt.max-arg-size"))
115     max_arg_size = attr.getValue().getSExtValue();
116 
117   // Get the sum of sizes of all ranked inputs for all operations in the
118   // function body. This approach approximates the cost analysis in the
119   // tfrt_compiler::CostAnalysis, because initially we want to get identical
120   // stream assignments, however long term we want to use more precise cost
121   // estimation, together with a more precise stream assignment.
122   //
123   // TODO(ezhulenev): Once we have a proper cost model for MLIR operations,
124   // use it to compute a more precise cost estimation.
125   for (mlir::Operation& op : kernel_fn.getBody().getOps()) {
126     // Skip return operation.
127     if (mlir::isa<mlir::func::ReturnOp>(op)) continue;
128 
129     // These ops are cheap regardless of their input sizes.
130     if (mlir::isa<mlir::TF::ShapeOp, mlir::TF::StridedSliceOp,
131                   mlir::TF::ReshapeOp, mlir::TF::ExpandDimsOp>(op)) {
132       cost += 1;
133       continue;
134     }
135 
136     // Set initial op cost to 1, just like TFRT's cost model does.
137     cost += 1;
138     for (Type type : op.getOperandTypes()) {
139       if (auto tensor = type.dyn_cast<RankedTensorType>()) {
140         cost += GetRankedTensorSize(tensor);
141       } else {
142         cost += max_arg_size;
143       }
144     }
145   }
146 
147   return std::max<int64_t>(1, cost);
148 }
149 
150 }  // namespace tf_jitrt
151 }  // end namespace mlir
152 
153 #define GET_OP_CLASSES
154 #include "tensorflow/compiler/mlir/tfrt/tf_jitrt_ops.cc.inc"
155