• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_OPDEFS_TF_JITRT_OPS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_OPDEFS_TF_JITRT_OPS_H_
18 
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
24 #include "tfrt/compiler/opdefs/tfrt_op_interfaces.h"  // from @tf_runtime
25 #include "tfrt/compiler/opdefs/tfrt_traits.h"  // from @tf_runtime
26 
27 namespace mlir {
28 namespace tf_jitrt {
29 
30 class JitRuntimeDialect : public mlir::Dialect {
31  public:
32   explicit JitRuntimeDialect(mlir::MLIRContext *context);
getDialectNamespace()33   static mlir::StringRef getDialectNamespace() { return "tf_jitrt"; }
34 };
35 
36 // Returns the maximum size of ranked tensor argument of the `func`. Returns `1`
37 // if all arguments are unranked.
38 //
39 // To get the cost of the `tf_jitrt.execute` operations, when compiled cluster
40 // has unranked inputs, we use the maximum size of the arguments of the parent
41 // function as an estimate (see TFRT_CostFunctionInterface).
42 int64_t GetMaxArgSize(mlir::func::FuncOp func);
43 
44 }  // namespace tf_jitrt
45 }  // namespace mlir
46 
47 #define GET_OP_CLASSES
48 #include "tensorflow/compiler/mlir/tfrt/tf_jitrt_ops.h.inc"
49 
50 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_OPDEFS_TF_JITRT_OPS_H_
51