1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_OPDEFS_TF_JITRT_OPS_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_OPDEFS_TF_JITRT_OPS_H_ 18 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/OpDefinition.h" 23 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" 24 #include "tfrt/compiler/opdefs/tfrt_op_interfaces.h" // from @tf_runtime 25 #include "tfrt/compiler/opdefs/tfrt_traits.h" // from @tf_runtime 26 27 namespace mlir { 28 namespace tf_jitrt { 29 30 class JitRuntimeDialect : public mlir::Dialect { 31 public: 32 explicit JitRuntimeDialect(mlir::MLIRContext *context); getDialectNamespace()33 static mlir::StringRef getDialectNamespace() { return "tf_jitrt"; } 34 }; 35 36 // Returns the maximum size of ranked tensor argument of the `func`. Returns `1` 37 // if all arguments are unranked. 38 // 39 // To get the cost of the `tf_jitrt.execute` operations, when compiled cluster 40 // has unranked inputs, we use the maximum size of the arguments of the parent 41 // function as an estimate (see TFRT_CostFunctionInterface). 42 int64_t GetMaxArgSize(mlir::func::FuncOp func); 43 44 } // namespace tf_jitrt 45 } // namespace mlir 46 47 #define GET_OP_CLASSES 48 #include "tensorflow/compiler/mlir/tfrt/tf_jitrt_ops.h.inc" 49 50 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_OPDEFS_TF_JITRT_OPS_H_ 51