• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/iterator_range.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
41 #include "mlir/Dialect/Traits.h"  // from @llvm-project
42 #include "mlir/IR/Attributes.h"  // from @llvm-project
43 #include "mlir/IR/Builders.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
45 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
46 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
47 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
48 #include "mlir/IR/Identifier.h"  // from @llvm-project
49 #include "mlir/IR/Location.h"  // from @llvm-project
50 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
51 #include "mlir/IR/Matchers.h"  // from @llvm-project
52 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
53 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
54 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
55 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
56 #include "mlir/IR/Types.h"  // from @llvm-project
57 #include "mlir/IR/Value.h"  // from @llvm-project
58 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"  // from @llvm-project
59 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
60 #include "mlir/Parser.h"  // from @llvm-project
61 #include "mlir/Support/LLVM.h"  // from @llvm-project
62 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
63 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
68 #include "tensorflow/core/platform/logging.h"
69 #include "tensorflow/core/util/tensor_format.h"
70 
71 // These are currently aliases and the alias will be removed, verified
72 // equivalent until then.
73 // TODO(b/178519687): Remvoe once addressed.
74 static_assert(std::is_same<tensorflow::int64, std::int64_t>::value,
75               "tensorflow::int64 is expected to match std::int64_t");
76 
77 namespace mlir {
78 namespace TF {
79 
80 //===----------------------------------------------------------------------===//
81 // TF Dialect Interfaces
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 
86 struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterfacemlir::TF::__anon8c7eb4140111::TFConstantFoldInterface87   TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
foldmlir::TF::__anon8c7eb4140111::TFConstantFoldInterface88   LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
89                      SmallVectorImpl<OpFoldResult> &results) const final {
90     return TensorFlowDialect::constantFold(op, operands, results);
91   }
92 };
93 
94 struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
TFDecodeAttributesInterfacemlir::TF::__anon8c7eb4140111::TFDecodeAttributesInterface95   TFDecodeAttributesInterface(Dialect *dialect)
96       : DialectDecodeAttributesInterface(dialect) {}
decodemlir::TF::__anon8c7eb4140111::TFDecodeAttributesInterface97   LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const {
98     return TensorFlowDialect::decode(input, output);
99   }
100 };
101 
102 struct TFInlinerInterface : public DialectInlinerInterface {
103   using DialectInlinerInterface::DialectInlinerInterface;
104 
105   //===--------------------------------------------------------------------===//
106   // Analysis Hooks
107   //===--------------------------------------------------------------------===//
108 
109   // Returns if it's legal to inline 'callable' into the 'call', where 'call' is
110   // a TF operation.
isLegalToInlinemlir::TF::__anon8c7eb4140111::TFInlinerInterface111   bool isLegalToInline(Operation *call, Operation *callable,
112                        bool wouldBeCloned) const final {
113     // Check that the TF call operation is one that is legal to inline.
114     return !isa<TPUPartitionedCallOp>(call);
115   }
116 
117   // Returns if its legal to inline 'src' region into the 'dest' region
118   // attached to a TF operation.
isLegalToInlinemlir::TF::__anon8c7eb4140111::TFInlinerInterface119   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
120                        BlockAndValueMapping &valueMapping) const final {
121     // Allow inlining in regions attached to region based control flow
122     // operations only if the src region is a single block region
123     return isa<IfRegionOp, WhileRegionOp>(dest->getParentOp()) &&
124            llvm::hasSingleElement(*src);
125   }
126 
127   // Returns true if its legal to inline a TF operation `op` into the `dest`
128   // region.
isLegalToInlinemlir::TF::__anon8c7eb4140111::TFInlinerInterface129   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
130                        BlockAndValueMapping &) const final {
131     // An op is legal to inline if either of the following conditions is true:
132     // (a) Its legal to duplicate the Op.
133     // (b) The Op is inside a single use function. If that function is inlined,
134     //     post inlining, the function will be dead and eliminated from the IR.
135     //     So there won't be any code duplication.
136     // plus the function caller op can be replaced by inlined ops.
137     return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
138   }
139 
140   //===--------------------------------------------------------------------===//
141   // Transformation Hooks
142   //===--------------------------------------------------------------------===//
143 
144   // Attempts to materialize a conversion for a type mismatch between a call
145   // from this dialect, and a callable region. This method should generate an
146   // operation that takes 'input' as the only operand, and produces a single
147   // result of 'resultType'. If a conversion can not be generated, nullptr
148   // should be returned.
materializeCallConversionmlir::TF::__anon8c7eb4140111::TFInlinerInterface149   Operation *materializeCallConversion(OpBuilder &builder, Value input,
150                                        Type result_type,
151                                        Location conversion_loc) const final {
152     if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
153       return nullptr;
154     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
155                                       /*truncate=*/builder.getBoolAttr(false));
156   }
157 };
158 }  // end anonymous namespace
159 
160 //===----------------------------------------------------------------------===//
161 // TF Dialect
162 //===----------------------------------------------------------------------===//
163 
164 // Returns true if the op can be duplicated.
CanDuplicate(Operation * op)165 bool TensorFlowDialect::CanDuplicate(Operation *op) {
166   // If the op is marked with the cannot duplicate trait, it cannot be
167   // duplicated.
168   if (op->hasTrait<OpTrait::TF::CannotDuplicate>()) return false;
169 
170   // If the op has no memory side effects, it can be duplicated.
171   if (MemoryEffectOpInterface::hasNoEffect(op)) return true;
172 
173   // If the op is marked stateless using the `is_stateless` attribute, that
174   // attribute determines if the op can be duplicated.
175   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
176     return is_stateless.getValue();
177 
178   // Otherwise, assume ops can be duplicated by default if its registered, else
179   // it cannot be for unknown ops.
180   return op->isRegistered();
181 }
182 
183 // Returns true if the op can have side effects.
CanHaveSideEffects(Operation * op)184 bool TensorFlowDialect::CanHaveSideEffects(Operation *op) {
185   // If the op has no memory side effects, it has no side effects
186   if (MemoryEffectOpInterface::hasNoEffect(op)) return false;
187 
188   // If the op is marked stateless using the `is_stateless` attribute, then
189   // it has no side effects.
190   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
191     return !is_stateless.getValue();
192 
193   // Terminators defined in the TF dialect do not have side effects.
194   if (op->hasTrait<OpTrait::IsTerminator>()) return false;
195 
196   // Otherwise assume that the op can have side effects.
197   return true;
198 }
199 
200 std::vector<TensorFlowDialect::AdditionalOpFunction>
GetAdditionalOperationHooks()201     *TensorFlowDialect::GetAdditionalOperationHooks() {
202   static auto *const additional_operation_hooks =
203       new std::vector<TensorFlowDialect::AdditionalOpFunction>();
204   return additional_operation_hooks;
205 }
206 
207 TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
208 TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
209 
TensorFlowDialect(MLIRContext * context)210 TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
211     : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
212   addOperations<
213 #define GET_OP_LIST
214 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc"
215       >();
216   addOperations<
217 #define GET_OP_LIST
218 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
219       >();
220   addTypes<
221 #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
222 #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
223 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
224       >();
225   addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
226                 TFConstantFoldInterface>();
227   addAttributes<ShapeAttr, FuncAttr>();
228 
229   // Support unknown operations because not all TensorFlow operations are
230   // registered.
231   allowUnknownOperations();
232 
233   for (const auto &hook : *GetAdditionalOperationHooks()) {
234     hook(*this);
235   }
236 }
237 
238 namespace {
239 
ParseShapeAttr(MLIRContext * context,StringRef spec,Location loc)240 ShapeAttr ParseShapeAttr(MLIRContext *context, StringRef spec, Location loc) {
241   auto emit_error = [&, spec]() {
242     emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
243     return nullptr;
244   };
245 
246   if (!spec.consume_front("shape<")) return emit_error();
247 
248   if (spec.consume_front("*>"))
249     return mlir::TF::ShapeAttr::get(context, llvm::None);
250 
251   SmallVector<int64_t, 4> shape;
252   while (!spec.consume_front(">")) {
253     int64_t dim;
254 
255     if (spec.consume_front("?"))
256       dim = -1;
257     else if (spec.consumeInteger(10, dim) || dim < 0)
258       return emit_error();
259 
260     spec.consume_front("x");
261 
262     shape.push_back(dim);
263   }
264 
265   return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
266 }
267 
PrintShapeAttr(ShapeAttr attr,DialectAsmPrinter & os)268 void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) {  // NOLINT
269   os << "shape";
270 
271   os << "<";
272   if (attr.hasRank()) {
273     auto print_dim = [&](int64_t dim) {
274       if (dim > -1)
275         os << dim;
276       else
277         os << "?";
278     };
279     llvm::interleave(attr.getShape(), os, print_dim, "x");
280   } else {
281     os << "*";
282   }
283   os << ">";
284 }
285 
286 // Parses a #tf.func attribute of the following format:
287 //
288 //   #tf.func<@symbol, {attr = "value"}>
289 //
290 // where the first element is a SymbolRefAttr and the second element is a
291 // DictionaryAttr.
ParseFuncAttr(MLIRContext * context,StringRef spec,Location loc)292 FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) {
293   auto emit_error = [&, spec]() {
294     emitError(loc, "invalid TensorFlow func attribute: ") << spec;
295     return nullptr;
296   };
297 
298   if (!spec.consume_front("func<")) return emit_error();
299 
300   size_t func_name_num_read = 0;
301   Attribute func_name_attr =
302       mlir::parseAttribute(spec, context, func_name_num_read);
303   if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
304     return emit_error();
305   spec = spec.drop_front(func_name_num_read);
306 
307   if (!spec.consume_front(", ")) return emit_error();
308 
309   size_t func_attrs_num_read = 0;
310   Attribute func_attrs_attr =
311       mlir::parseAttribute(spec, context, func_attrs_num_read);
312   if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
313     return emit_error();
314   spec = spec.drop_front(func_attrs_num_read);
315 
316   if (!spec.consume_front(">")) return emit_error();
317 
318   return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
319                                  func_attrs_attr.cast<DictionaryAttr>());
320 }
321 
322 // Prints a #tf.func attribute of the following format:
323 //
324 //   #tf.func<@symbol, {attr = "value"}>
PrintFuncAttr(FuncAttr attr,DialectAsmPrinter & os)325 void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) {
326   os << "func<" << attr.GetName() << ", " << attr.GetAttrs() << ">";
327 }
328 
329 }  // namespace
330 
parseAttribute(DialectAsmParser & parser,Type type) const331 Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
332                                             Type type) const {
333   auto spec = parser.getFullSymbolSpec();
334   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
335 
336   if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
337 
338   if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
339 
340   return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
341 }
342 
printAttribute(Attribute attr,DialectAsmPrinter & os) const343 void TensorFlowDialect::printAttribute(Attribute attr,
344                                        DialectAsmPrinter &os) const {
345   if (auto shape_attr = attr.dyn_cast<ShapeAttr>())
346     PrintShapeAttr(shape_attr, os);
347   else if (auto func_attr = attr.dyn_cast<FuncAttr>())
348     PrintFuncAttr(func_attr, os);
349   else
350     llvm_unreachable("unexpected tensorflow attribute type");
351 }
352 
353 // Parses a type registered to this dialect.
parseType(DialectAsmParser & parser) const354 Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
355   StringRef data;
356   if (parser.parseKeyword(&data)) return Type();
357 
358   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
359 
360 #define HANDLE_TF_TYPE(tftype, enumerant, name) \
361   if (data == name) return tftype##Type::get(getContext());
362 // Custom TensorFlow types are handled separately at the end as they do partial
363 // match.
364 #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
365 // NOLINTNEXTLINE
366 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
367 
368   if (data.startswith("resource")) return ParseResourceType(parser, loc);
369   if (data.startswith("variant")) return ParseVariantType(parser, loc);
370   return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
371 }
372 
373 // Prints a type registered to this dialect.
printType(Type ty,DialectAsmPrinter & os) const374 void TensorFlowDialect::printType(Type ty, DialectAsmPrinter &os) const {
375   assert(ty.isa<TensorFlowType>());
376 #define HANDLE_TF_TYPE(tftype, enumerant, name)        \
377   if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
378     os << name;                                        \
379     return;                                            \
380   }
381 #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \
382   if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
383     Print##tftype##Type(derived_ty, os);               \
384     return;                                            \
385   }
386 // NOLINTNEXTLINE
387 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
388 
389   llvm_unreachable("unexpected tensorflow type kind");
390 }
391 
392 namespace {
393 template <typename TypeWithSubtype>
ParseTypeWithSubtype(MLIRContext * context,DialectAsmParser & parser,Location loc)394 Type ParseTypeWithSubtype(MLIRContext *context, DialectAsmParser &parser,
395                           Location loc) {
396   // Default type without inferred subtypes.
397   if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
398 
399   // Most types with subtypes have only one subtype.
400   SmallVector<TensorType, 1> subtypes;
401   do {
402     TensorType tensor_ty;
403     if (parser.parseType(tensor_ty)) return Type();
404     subtypes.push_back(tensor_ty);
405   } while (succeeded(parser.parseOptionalComma()));
406 
407   if (parser.parseGreater()) return Type();
408   return TypeWithSubtype::getChecked(subtypes, context, loc);
409 }
410 
411 template <typename TypeWithSubtype>
PrintTypeWithSubtype(StringRef type,TypeWithSubtype ty,DialectAsmPrinter & os)412 void PrintTypeWithSubtype(StringRef type, TypeWithSubtype ty,
413                           DialectAsmPrinter &os) {
414   os << type;
415   ArrayRef<TensorType> subtypes = ty.getSubtypes();
416   if (subtypes.empty()) return;
417 
418   os << "<";
419   interleaveComma(subtypes, os);
420   os << ">";
421 }
422 }  // anonymous namespace
423 
ParseResourceType(DialectAsmParser & parser,Location loc) const424 Type TensorFlowDialect::ParseResourceType(DialectAsmParser &parser,
425                                           Location loc) const {
426   return ParseTypeWithSubtype<ResourceType>(getContext(), parser, loc);
427 }
428 
PrintResourceType(ResourceType ty,DialectAsmPrinter & os) const429 void TensorFlowDialect::PrintResourceType(ResourceType ty,
430                                           DialectAsmPrinter &os) const {
431   return PrintTypeWithSubtype("resource", ty, os);
432 }
433 
ParseVariantType(DialectAsmParser & parser,Location loc) const434 Type TensorFlowDialect::ParseVariantType(DialectAsmParser &parser,
435                                          Location loc) const {
436   return ParseTypeWithSubtype<VariantType>(getContext(), parser, loc);
437 }
438 
PrintVariantType(VariantType ty,DialectAsmPrinter & os) const439 void TensorFlowDialect::PrintVariantType(VariantType ty,
440                                          DialectAsmPrinter &os) const {
441   return PrintTypeWithSubtype("variant", ty, os);
442 }
443 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)444 Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
445                                                   Attribute value, Type type,
446                                                   Location loc) {
447   return builder.create<ConstOp>(loc, type, value);
448 }
449 
450 }  // namespace TF
451 }  // namespace mlir
452