1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the operations used in the tf_framework dialect. 17 18 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" 19 20 #include "mlir/IR/Builders.h" // from @llvm-project 21 #include "mlir/IR/DialectImplementation.h" // from @llvm-project 22 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc" 23 24 // Generated dialect definitions. 25 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.cc.inc" 26 27 namespace mlir { 28 namespace kernel_gen { 29 namespace tf_framework { 30 initialize()31void TFFrameworkDialect::initialize() { 32 addOperations< 33 #define GET_OP_LIST 34 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" 35 >(); 36 addTypes<JITCallableType, OpKernelContextType>(); 37 } 38 39 /// Parse a type registered to this dialect. parseType(DialectAsmParser & parser) const40Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const { 41 StringRef keyword; 42 if (parser.parseKeyword(&keyword)) return Type(); 43 44 if (keyword == "op_kernel_context") { 45 return OpKernelContextType::get(getContext()); 46 } 47 if (keyword == "jit_callable") { 48 return JITCallableType::get(getContext()); 49 } 50 51 parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ") 52 << keyword; 53 return Type(); 54 } 55 56 /// Print a type registered to this dialect. printType(Type type,DialectAsmPrinter & os) const57void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const { 58 if (type.isa<OpKernelContextType>()) { 59 os << "op_kernel_context"; 60 return; 61 } 62 if (type.isa<JITCallableType>()) { 63 os << "jit_callable"; 64 return; 65 } 66 llvm_unreachable("unexpected TF Framework type kind"); 67 } 68 69 template <typename OpTy> Verify(OpTy op)70LogicalResult Verify(OpTy op) { 71 return success(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TFAllocOp 76 //===----------------------------------------------------------------------===// 77 template <> Verify(TFAllocOp op)78LogicalResult Verify<TFAllocOp>(TFAllocOp op) { 79 // Check that the total number of operands matches the number of dynamic 80 // dimensions specified in the memref type. 81 unsigned result_dyn_dims = op.getType().getNumDynamicDims(); 82 unsigned dyn_sizes_count = op.dyn_sizes().size(); 83 if (dyn_sizes_count != result_dyn_dims) 84 return op.emitOpError() 85 << "`dyn_sizes` count " << dyn_sizes_count 86 << " does not match dynamic dimensions count in the result type" 87 << op.getType(); 88 return success(); 89 } 90 ConvertAttrToEnumValue(ErrorCode error_code)91::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) { 92 using ::tensorflow::error::Code; 93 switch (error_code) { 94 case ErrorCode::OK: 95 return Code::OK; 96 case ErrorCode::CANCELLED: 97 return Code::CANCELLED; 98 case ErrorCode::UNKNOWN: 99 return Code::UNKNOWN; 100 case ErrorCode::INVALID_ARGUMENT: 101 return Code::INVALID_ARGUMENT; 102 case ErrorCode::DEADLINE_EXCEEDED: 103 return Code::DEADLINE_EXCEEDED; 104 case ErrorCode::NOT_FOUND: 105 return Code::NOT_FOUND; 106 case ErrorCode::ALREADY_EXISTS: 107 return Code::ALREADY_EXISTS; 108 case ErrorCode::PERMISSION_DENIED: 109 return Code::PERMISSION_DENIED; 110 case ErrorCode::UNAUTHENTICATED: 111 return Code::UNAUTHENTICATED; 112 case ErrorCode::RESOURCE_EXHAUSTED: 113 return Code::RESOURCE_EXHAUSTED; 114 case ErrorCode::FAILED_PRECONDITION: 115 return Code::FAILED_PRECONDITION; 116 case ErrorCode::ABORTED: 117 return Code::ABORTED; 118 case ErrorCode::OUT_OF_RANGE: 119 return Code::OUT_OF_RANGE; 120 case ErrorCode::UNIMPLEMENTED: 121 return Code::UNIMPLEMENTED; 122 case ErrorCode::INTERNAL: 123 return Code::INTERNAL; 124 case ErrorCode::UNAVAILABLE: 125 return Code::UNAVAILABLE; 126 case ErrorCode::DATA_LOSS: 127 return Code::DATA_LOSS; 128 } 129 } 130 131 } // namespace tf_framework 132 } // namespace kernel_gen 133 } // namespace mlir 134 135 #define GET_OP_CLASSES 136 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" 137