1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the standard MLIR TensorFlow dialect after control 17 // dependences are raise to the standard form. 18 19 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ 20 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ 21 22 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 23 #include "mlir/IR/Dialect.h" // from @llvm-project 24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" 25 26 namespace mlir { 27 namespace TF { 28 29 class TensorFlowRegistryEffectInterfaceFallback; 30 31 class TensorFlowDialect final : public Dialect { 32 public: 33 explicit TensorFlowDialect(MLIRContext *context); 34 ~TensorFlowDialect() override; 35 getDialectNamespace()36 static StringRef getDialectNamespace() { return "tf"; } 37 38 // Overrides to redirect to tf_type dialect. 39 Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; 40 Type parseType(DialectAsmParser &parser) const override; 41 42 // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a 43 // function references to its gradient function. This attribute in TensorFlow 44 // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the 45 // string description of gradient attribute. GetGradientAttrName()46 static StringRef GetGradientAttrName() { return "tf.gradient"; } 47 48 // This attribute marks if a function is stateful. 49 // Returns the string description of stateful attribute. GetStatefulAttrName()50 static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; } 51 52 // Returns true if the op can be duplicated during transformations. 53 static bool CanDuplicate(Operation *op); 54 55 // Returns true if the op can have side effects. 56 static bool CanHaveSideEffects(Operation *op); 57 58 // Registered hook to materialize a constant operation from a given attribute 59 // value with the desired resultant type. 60 Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, 61 Location loc) override; 62 63 typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction; 64 65 // Register an op registration hook which is invoked during construction. 66 // 67 // A hook may use the public addOperations() method to add additional 68 // operations to the dialect. Hooks will only apply to subsequent 69 // instantations of the Dialect/MLIRContext. RegisterAdditionalOperationHook(AdditionalOpFunction fn)70 static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) { 71 GetAdditionalOperationHooks()->push_back(std::move(fn)); 72 } 73 74 // Re-define publicly the protected addOperations() method from the Dialect 75 // class, usually used in a Dialect constructor. This allows hook 76 // functions to register operations on the TensorFlow dialect using the 77 // same interface. 78 template <typename... Args> addOperations()79 void addOperations() { 80 Dialect::addOperations<Args...>(); 81 } 82 83 using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>, 84 SmallVectorImpl<OpFoldResult> &); RegisterConstantFoldHook(ConstantFoldHook fn)85 static void RegisterConstantFoldHook(ConstantFoldHook fn) { 86 constant_fold_hook_ = std::move(fn); 87 } 88 constantFold(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)89 static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands, 90 SmallVectorImpl<OpFoldResult> &results) { 91 if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); 92 return failure(); 93 } 94 95 using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input, 96 ElementsAttr &output); RegisterDecodeConstantHook(DecodeConstantHook fn)97 static void RegisterDecodeConstantHook(DecodeConstantHook fn) { 98 decode_constant_hook_ = std::move(fn); 99 } decode(OpaqueElementsAttr input,ElementsAttr & output)100 static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) { 101 if (decode_constant_hook_) return decode_constant_hook_(input, output); 102 return failure(); 103 } 104 105 // Provides a hook for op interface. 106 void *getRegisteredInterfaceForOp(mlir::TypeID interface, 107 mlir::OperationName opName) override; 108 109 private: 110 // Hook functions which may add additional operations to the dialect. 111 // These are invoked at construction time. 112 static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks(); 113 114 static ConstantFoldHook constant_fold_hook_; 115 static DecodeConstantHook decode_constant_hook_; 116 117 // Storage for a custom fallback interface. 118 TensorFlowRegistryEffectInterfaceFallback *fallback_effect_op_interface_; 119 }; 120 121 } // namespace TF 122 } // namespace mlir 123 124 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ 125