1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the operations used in the standard MLIR TensorFlow dialect 17 // after control dependences are raise to the standard form. 18 19 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ 20 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ 21 22 #include "mlir/Dialect/Traits.h" // from @llvm-project 23 #include "mlir/IR/Attributes.h" // from @llvm-project 24 #include "mlir/IR/Builders.h" // from @llvm-project 25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 27 #include "mlir/IR/Dialect.h" // from @llvm-project 28 #include "mlir/IR/Matchers.h" // from @llvm-project 29 #include "mlir/IR/OpImplementation.h" // from @llvm-project 30 #include "mlir/IR/TypeUtilities.h" // from @llvm-project 31 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project 32 #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project 33 #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project 34 #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project 35 #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project 36 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project 37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" 38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" 39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" 40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" 41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" 42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" 43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" 44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" 45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" 46 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" 47 48 namespace mlir { 49 namespace TF { 50 51 class TensorFlowDialect : public Dialect { 52 public: 53 TensorFlowDialect(MLIRContext *context); 54 getDialectNamespace()55 static StringRef getDialectNamespace() { return "tf"; } 56 57 // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a 58 // function references to its gradient function. This attribute in TensorFlow 59 // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the 60 // string description of gradient attribute. GetGradientAttrName()61 static StringRef GetGradientAttrName() { return "tf.gradient"; } 62 63 // This attribute marks if a function is stateful. 64 // Returns the string description of stateful attribute. GetStatefulAttrName()65 static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; } 66 67 // Returns true if the op can be duplicated during transformations. 68 static bool CanDuplicate(Operation *op); 69 70 // Returns true if the op can have side effects. 71 static bool CanHaveSideEffects(Operation *op); 72 73 Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; 74 75 void printAttribute(Attribute attr, DialectAsmPrinter &os) const override; 76 77 // Parse a type registered to this dialect. 78 Type parseType(DialectAsmParser &parser) const override; 79 80 // Prints a type registered to this dialect. 81 void printType(Type ty, DialectAsmPrinter &os) const override; 82 83 // Parses resource type with potential subtypes. 84 Type ParseResourceType(DialectAsmParser &parser, Location loc) const; 85 86 // Prints resource type with potential subtypes. 87 void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) const; 88 89 // Parse and print variant type. It may have subtypes inferred using shape 90 // inference. 91 Type ParseVariantType(DialectAsmParser &parser, Location loc) const; 92 void PrintVariantType(VariantType ty, DialectAsmPrinter &os) const; 93 94 // Registered hook to materialize a constant operation from a given attribute 95 // value with the desired resultant type. 96 Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, 97 Location loc) override; 98 99 typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction; 100 101 // Register an op registration hook which is invoked during construction. 102 // 103 // A hook may use the public addOperations() method to add additional 104 // operations to the dialect. Hooks will only apply to subsequent 105 // instantations of the Dialect/MLIRContext. RegisterAdditionalOperationHook(AdditionalOpFunction fn)106 static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) { 107 GetAdditionalOperationHooks()->push_back(std::move(fn)); 108 } 109 110 // Re-define publicly the protected addOperations() method from the Dialect 111 // class, usually used in a Dialect constructor. This allows hook 112 // functions to register operations on the TensorFlow dialect using the 113 // same interface. 114 template <typename... Args> addOperations()115 void addOperations() { 116 Dialect::addOperations<Args...>(); 117 } 118 119 using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>, 120 SmallVectorImpl<OpFoldResult> &); RegisterConstantFoldHook(ConstantFoldHook fn)121 static void RegisterConstantFoldHook(ConstantFoldHook fn) { 122 constant_fold_hook_ = std::move(fn); 123 } 124 constantFold(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)125 static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands, 126 SmallVectorImpl<OpFoldResult> &results) { 127 if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); 128 return failure(); 129 } 130 131 using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input, 132 ElementsAttr &output); RegisterDecodeConstantHook(DecodeConstantHook fn)133 static void RegisterDecodeConstantHook(DecodeConstantHook fn) { 134 decode_constant_hook_ = std::move(fn); 135 } decode(OpaqueElementsAttr input,ElementsAttr & output)136 static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) { 137 if (decode_constant_hook_) return decode_constant_hook_(input, output); 138 return failure(); 139 } 140 141 private: 142 // Hook functions which may add additional operations to the dialect. 143 // These are invoked at construction time. 144 static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks(); 145 146 static ConstantFoldHook constant_fold_hook_; 147 static DecodeConstantHook decode_constant_hook_; 148 }; 149 150 } // namespace TF 151 } // namespace mlir 152 153 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ 154