• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the operations used in the standard MLIR TensorFlow dialect
17 // after control dependences are raised to the standard form.
18 
19 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
20 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
21 
22 #include "mlir/Dialect/Traits.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Dialect.h"  // from @llvm-project
28 #include "mlir/IR/Matchers.h"  // from @llvm-project
29 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
30 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
31 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
32 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
33 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
34 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
35 #include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
36 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h"
47 
48 namespace mlir {
49 namespace TF {
50 
51 class TensorFlowDialect : public Dialect {
52  public:
53   TensorFlowDialect(MLIRContext *context);
54 
getDialectNamespace()55   static StringRef getDialectNamespace() { return "tf"; }
56 
57   // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
58   // function references to its gradient function. This attribute in TensorFlow
59   // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
60   // string description of gradient attribute.
GetGradientAttrName()61   static StringRef GetGradientAttrName() { return "tf.gradient"; }
62 
63   // This attribute marks if a function is stateful.
64   // Returns the string description of stateful attribute.
GetStatefulAttrName()65   static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; }
66 
67   // Returns true if the op can be duplicated during transformations.
68   static bool CanDuplicate(Operation *op);
69 
70   // Returns true if the op can have side effects.
71   static bool CanHaveSideEffects(Operation *op);
72 
73   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
74 
75   void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;
76 
77   // Parse a type registered to this dialect.
78   Type parseType(DialectAsmParser &parser) const override;
79 
80   // Prints a type registered to this dialect.
81   void printType(Type ty, DialectAsmPrinter &os) const override;
82 
83   // Parses resource type with potential subtypes.
84   Type ParseResourceType(DialectAsmParser &parser, Location loc) const;
85 
86   // Prints resource type with potential subtypes.
87   void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) const;
88 
89   // Parse and print variant type. It may have subtypes inferred using shape
90   // inference.
91   Type ParseVariantType(DialectAsmParser &parser, Location loc) const;
92   void PrintVariantType(VariantType ty, DialectAsmPrinter &os) const;
93 
94   // Registered hook to materialize a constant operation from a given attribute
95   // value with the desired resultant type.
96   Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
97                                  Location loc) override;
98 
99   typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
100 
101   // Register an op registration hook which is invoked during construction.
102   //
103   // A hook may use the public addOperations() method to add additional
104   // operations to the dialect. Hooks will only apply to subsequent
105   // instantations of the Dialect/MLIRContext.
RegisterAdditionalOperationHook(AdditionalOpFunction fn)106   static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
107     GetAdditionalOperationHooks()->push_back(std::move(fn));
108   }
109 
110   // Re-define publicly the protected addOperations() method from the Dialect
111   // class, usually used in a Dialect constructor. This allows hook
112   // functions to register operations on the TensorFlow dialect using the
113   // same interface.
114   template <typename... Args>
addOperations()115   void addOperations() {
116     Dialect::addOperations<Args...>();
117   }
118 
119   using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
120                                              SmallVectorImpl<OpFoldResult> &);
RegisterConstantFoldHook(ConstantFoldHook fn)121   static void RegisterConstantFoldHook(ConstantFoldHook fn) {
122     constant_fold_hook_ = std::move(fn);
123   }
124 
constantFold(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)125   static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
126                                     SmallVectorImpl<OpFoldResult> &results) {
127     if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
128     return failure();
129   }
130 
131   using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
132                                                ElementsAttr &output);
RegisterDecodeConstantHook(DecodeConstantHook fn)133   static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
134     decode_constant_hook_ = std::move(fn);
135   }
decode(OpaqueElementsAttr input,ElementsAttr & output)136   static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
137     if (decode_constant_hook_) return decode_constant_hook_(input, output);
138     return failure();
139   }
140 
141  private:
142   // Hook functions which may add additional operations to the dialect.
143   // These are invoked at construction time.
144   static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks();
145 
146   static ConstantFoldHook constant_fold_hook_;
147   static DecodeConstantHook decode_constant_hook_;
148 };
149 
150 }  // namespace TF
151 }  // namespace mlir
152 
153 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
154