• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the operations used in the tf_framework dialect.
17 
18 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
19 
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc"
23 
24 // Generated dialect definitions.
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.cc.inc"
26 
27 namespace mlir {
28 namespace kernel_gen {
29 namespace tf_framework {
30 
// Registers everything the TF Framework dialect provides with MLIR: the
// operation classes from the generated tf_framework_ops.cc.inc, and the
// dialect's custom types.
void TFFrameworkDialect::initialize() {
  // The generated file is expanded directly inside the template argument list;
  // GET_OP_LIST must be defined before the #include so that only the
  // comma-separated op class list is emitted here.
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
      >();
  // Custom types handled by parseType/printType in this file.
  addTypes<JITCallableType, OpKernelContextType>();
}
38 
39 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const40 Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
41   StringRef keyword;
42   if (parser.parseKeyword(&keyword)) return Type();
43 
44   if (keyword == "op_kernel_context") {
45     return OpKernelContextType::get(getContext());
46   }
47   if (keyword == "jit_callable") {
48     return JITCallableType::get(getContext());
49   }
50 
51   parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ")
52       << keyword;
53   return Type();
54 }
55 
56 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const57 void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
58   if (type.isa<OpKernelContextType>()) {
59     os << "op_kernel_context";
60     return;
61   }
62   if (type.isa<JITCallableType>()) {
63     os << "jit_callable";
64     return;
65   }
66   llvm_unreachable("unexpected TF Framework type kind");
67 }
68 
69 template <typename OpTy>
Verify(OpTy op)70 LogicalResult Verify(OpTy op) {
71   return success();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TFAllocOp
76 //===----------------------------------------------------------------------===//
77 template <>
Verify(TFAllocOp op)78 LogicalResult Verify<TFAllocOp>(TFAllocOp op) {
79   // Check that the total number of operands matches the number of dynamic
80   // dimensions specified in the memref type.
81   unsigned result_dyn_dims = op.getType().getNumDynamicDims();
82   unsigned dyn_sizes_count = op.dyn_sizes().size();
83   if (dyn_sizes_count != result_dyn_dims)
84     return op.emitOpError()
85            << "`dyn_sizes` count " << dyn_sizes_count
86            << " does not match dynamic dimensions count in the result type"
87            << op.getType();
88   return success();
89 }
90 
ConvertAttrToEnumValue(ErrorCode error_code)91 ::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) {
92   using ::tensorflow::error::Code;
93   switch (error_code) {
94     case ErrorCode::OK:
95       return Code::OK;
96     case ErrorCode::CANCELLED:
97       return Code::CANCELLED;
98     case ErrorCode::UNKNOWN:
99       return Code::UNKNOWN;
100     case ErrorCode::INVALID_ARGUMENT:
101       return Code::INVALID_ARGUMENT;
102     case ErrorCode::DEADLINE_EXCEEDED:
103       return Code::DEADLINE_EXCEEDED;
104     case ErrorCode::NOT_FOUND:
105       return Code::NOT_FOUND;
106     case ErrorCode::ALREADY_EXISTS:
107       return Code::ALREADY_EXISTS;
108     case ErrorCode::PERMISSION_DENIED:
109       return Code::PERMISSION_DENIED;
110     case ErrorCode::UNAUTHENTICATED:
111       return Code::UNAUTHENTICATED;
112     case ErrorCode::RESOURCE_EXHAUSTED:
113       return Code::RESOURCE_EXHAUSTED;
114     case ErrorCode::FAILED_PRECONDITION:
115       return Code::FAILED_PRECONDITION;
116     case ErrorCode::ABORTED:
117       return Code::ABORTED;
118     case ErrorCode::OUT_OF_RANGE:
119       return Code::OUT_OF_RANGE;
120     case ErrorCode::UNIMPLEMENTED:
121       return Code::UNIMPLEMENTED;
122     case ErrorCode::INTERNAL:
123       return Code::INTERNAL;
124     case ErrorCode::UNAVAILABLE:
125       return Code::UNAVAILABLE;
126     case ErrorCode::DATA_LOSS:
127       return Code::DATA_LOSS;
128   }
129 }
130 
131 }  // namespace tf_framework
132 }  // namespace kernel_gen
133 }  // namespace mlir
134 
135 #define GET_OP_CLASSES
136 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
137