• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/types/dialect.h"
17 
18 #include <cstdint>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/SMLoc.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Traits.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/Dialect.h"  // from @llvm-project
32 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
33 #include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
34 #include "mlir/IR/FunctionSupport.h"  // from @llvm-project
35 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
36 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 
39 #define GET_ATTRDEF_CLASSES
40 #include "tensorflow/core/ir/types/attributes.cc.inc"
41 
42 #define GET_TYPEDEF_CLASSES
43 #include "tensorflow/core/ir/types/types.cc.inc"
44 
45 // Generated definitions.
46 #include "tensorflow/core/ir/types/dialect.cpp.inc"
47 
48 namespace mlir {
49 namespace tf_type {
50 
51 //===----------------------------------------------------------------------===//
52 // TFType dialect.
53 //===----------------------------------------------------------------------===//
54 
55 // Dialect construction: there is one instance per context and it registers its
56 // operations, types, and interfaces here.
// Dialect initialization hook: registers every tf_type attribute and type
// with the owning MLIRContext. Invoked once per context when the dialect is
// loaded.
void TFTypeDialect::initialize() {
  // The attribute list is generated by TableGen (attributes.cc.inc).
  addAttributes<
#define GET_ATTRDEF_LIST
#include "tensorflow/core/ir/types/attributes.cc.inc"
      >();
  // Register the control-token type, the opaque tensor type, and every TF
  // element type enumerated in types.def (the macros expand into a
  // comma-separated template argument list).
  addTypes<ControlType, OpaqueTensorType,
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/core/ir/types/types.def"
           >();
}
68 
69 // Entry point for Attribute parsing, TableGen generated code will handle the
70 // dispatch to the individual classes.
parseAttribute(DialectAsmParser & parser,Type type) const71 Attribute TFTypeDialect::parseAttribute(DialectAsmParser &parser,
72                                         Type type) const {
73   StringRef attr_tag;
74   if (failed(parser.parseKeyword(&attr_tag))) return Attribute();
75   {
76     Attribute attr;
77     auto parse_result =
78         generatedAttributeParser(getContext(), parser, attr_tag, type, attr);
79     if (parse_result.hasValue()) return attr;
80   }
81   parser.emitError(parser.getNameLoc(), "unknown tf_type attribute");
82   return Attribute();
83 }
84 
85 // Entry point for Attribute printing, TableGen generated code will handle the
86 // dispatch to the individual classes.
// Entry point for attribute printing; dispatch is handled entirely by the
// TableGen-generated printer. Its LogicalResult is deliberately discarded.
void TFTypeDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &os) const {
  (void)generatedAttributePrinter(attr, os);
}
91 
92 namespace {
93 template <typename TypeWithSubtype>
ParseTypeWithSubtype(MLIRContext * context,DialectAsmParser & parser)94 Type ParseTypeWithSubtype(MLIRContext *context, DialectAsmParser &parser) {
95   // Default type without inferred subtypes.
96   if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
97 
98   // Most types with subtypes have only one subtype.
99   SmallVector<TensorType, 1> subtypes;
100   do {
101     TensorType tensor_ty;
102     if (parser.parseType(tensor_ty)) return Type();
103 
104     // Each of the subtypes should be a valid TensorFlow type.
105     // TODO(jpienaar): Remove duplication.
106     if (!IsValidTFTensorType(tensor_ty)) {
107       parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
108       return Type();
109     }
110     subtypes.push_back(tensor_ty);
111   } while (succeeded(parser.parseOptionalComma()));
112 
113   if (parser.parseGreater()) return Type();
114 
115   return TypeWithSubtype::get(subtypes, context);
116 }
117 
118 template <typename TypeWithSubtype>
PrintTypeWithSubtype(StringRef type,TypeWithSubtype ty,DialectAsmPrinter & os)119 void PrintTypeWithSubtype(StringRef type, TypeWithSubtype ty,
120                           DialectAsmPrinter &os) {
121   os << type;
122   ArrayRef<TensorType> subtypes = ty.getSubtypes();
123   if (subtypes.empty()) return;
124 
125   os << "<";
126   interleaveComma(subtypes, os);
127   os << ">";
128 }
// Parses a `resource` type, delegating to the shared subtype-aware parser.
Type ParseResourceType(MLIRContext *context, DialectAsmParser &parser) {
  return ParseTypeWithSubtype<ResourceType>(context, parser);
}
132 
// Prints a `resource` type (with subtypes, if any).
void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) {
  return PrintTypeWithSubtype("resource", ty, os);
}
136 
// Parses a `variant` type, delegating to the shared subtype-aware parser.
Type ParseVariantType(MLIRContext *context, DialectAsmParser &parser) {
  return ParseTypeWithSubtype<VariantType>(context, parser);
}
140 
// Prints a `variant` type (with subtypes, if any).
void PrintVariantType(VariantType ty, DialectAsmPrinter &os) {
  return PrintTypeWithSubtype("variant", ty, os);
}
144 
145 }  // namespace
146 
147 // Entry point for Type parsing, TableGen generated code will handle the
148 // dispatch to the individual classes.
// Entry point for type parsing. Simple (parameterless) TF types are matched
// directly against the names in types.def; resource/variant get dedicated
// parsers for their optional subtype lists; everything else falls through to
// the TableGen-generated dispatcher.
Type TFTypeDialect::parseType(DialectAsmParser &parser) const {
  StringRef type_tag;
  llvm::SMLoc loc = parser.getNameLoc();
  if (failed(parser.parseKeyword(&type_tag))) return Type();

  // Expand to one string-compare-and-return per simple type; custom types
  // (HANDLE_CUSTOM_TF_TYPE) are deliberately skipped here.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (type_tag == name) return tftype##Type::get(getContext());
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // resource/variant may be followed by a `<subtype...>` clause, which the
  // dedicated parsers consume.
  if (type_tag.startswith("resource")) {
    Type ret = ParseResourceType(getContext(), parser);
    if (!ret) parser.emitError(loc, "invalid resource type");
    return ret;
  }
  if (type_tag.startswith("variant")) {
    Type ret = ParseVariantType(getContext(), parser);
    if (!ret) parser.emitError(loc, "invalid variant type");
    return ret;
  }

  // Fall back to TableGen-defined types.
  Type genType;
  auto parse_result =
      generatedTypeParser(getContext(), parser, type_tag, genType);
  if (parse_result.hasValue()) return genType;
  parser.emitError(parser.getNameLoc(),
                   "unknown type in TF graph dialect: " + type_tag);
  return {};
}
179 
180 // Entry point for Type parsing, TableGen generated code will handle the
181 // dispatch to the individual classes.
// Entry point for type printing. Simple TF types print as their bare name;
// custom types (resource/variant) delegate to their dedicated printers which
// also emit subtypes; anything else must be a TableGen-defined type.
void TFTypeDialect::printType(Type type, DialectAsmPrinter &printer) const {
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    printer << name;                                     \
    return;                                              \
  }
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)   \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    Print##tftype##Type(derived_ty, printer);            \
    return;                                              \
  }
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // Reaching this point with a type the generated printer also rejects means
  // an unregistered type made it into the IR: fail hard.
  if (failed(generatedTypePrinter(type, printer)))
    llvm::report_fatal_error("unexpected tensorflow graph type kind");
}
199 
200 //===----------------------------------------------------------------------===//
201 // Attributes
202 //===----------------------------------------------------------------------===//
203 
// Parses a version attribute of the form:
//   <producer = 42, min_consumer = 33>
// optionally followed by `, bad_consumers = [5, 12]` before the closing `>`.
Attribute VersionAttr::parse(MLIRContext *context, DialectAsmParser &parser,
                             Type) {
  if (failed(parser.parseLess())) return {};

  // `producer` and `min_consumer` are both mandatory, in that order.
  int32_t producer, min_consumer;
  if (parser.parseKeyword("producer", " in tf_type version") ||
      parser.parseEqual() || parser.parseInteger(producer) ||
      parser.parseComma() ||
      parser.parseKeyword("min_consumer", " in tf_type version") ||
      parser.parseEqual() || parser.parseInteger(min_consumer))
    return {};

  // Optional bad_consumers list. Note: parseOptionalComma() yields success()
  // when the comma IS present, and ParseResult converts success to `false`,
  // hence the negations below.
  SmallVector<int32_t, 4> bad_consumers;
  if (!parser.parseOptionalComma()) {
    if (parser.parseKeyword("bad_consumers", " in tf_type version") ||
        parser.parseEqual() || parser.parseLSquare())
      return {};
    do {
      int32_t bad_consumer;
      if (parser.parseInteger(bad_consumer)) return {};
      bad_consumers.push_back(bad_consumer);
    } while (!parser.parseOptionalComma());
    if (parser.parseRSquare()) return {};
  }
  if (failed(parser.parseGreater())) return {};

  return VersionAttr::get(context, producer, min_consumer, bad_consumers);
}
232 
print(DialectAsmPrinter & printer) const233 void VersionAttr::print(DialectAsmPrinter &printer) const {
234   llvm::raw_ostream &os = printer.getStream();
235   os << getMnemonic();
236   os << "<producer = " << getProducer()
237      << ", min_consumer = " << getMinConsumer();
238   ArrayRef<int32_t> badConsumers = getBadConsumers();
239   if (!badConsumers.empty()) {
240     os << ", bad_consumers = [";
241     llvm::interleaveComma(badConsumers, os);
242     os << "]";
243   }
244   os << ">";
245 }
246 
247 // Print a #tf.func attribute of the following format:
248 //
249 //   #tf.func<@symbol, {attr = "value"}>
250 // or
251 //   #tf.func<"", {attr = "value"}>
252 // in case of null symbol ref.
print(DialectAsmPrinter & os) const253 void FuncAttr::print(DialectAsmPrinter &os) const {
254   if (getName().getRootReference().empty())
255     os << "func<\"\", " << getAttrs() << ">";
256   else
257     os << "func<" << getName() << ", " << getAttrs() << ">";
258 }
259 
260 // Parses a #tf.func attribute of the following format:
261 //
262 //   #tf.func<@symbol, {attr = "value"}>
263 //
264 // where the first element is a SymbolRefAttr and the second element is a
265 // DictionaryAttr.
// Parses a #tf.func attribute of the following format:
//
//   #tf.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr (or the empty string "", which
// is normalized to an empty symbol ref) and the second is a DictionaryAttr.
Attribute FuncAttr::parse(MLIRContext *context, DialectAsmParser &parser,
                          Type type) {
  if (failed(parser.parseLess())) return {};
  llvm::SMLoc loc = parser.getCurrentLocation();
  Attribute name, dict;
  if (failed(parser.parseAttribute(name))) {
    parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
    return {};
  }
  // Accept `""` as a stand-in for a null symbol; any other string is an
  // error. The empty string is canonicalized into an empty SymbolRefAttr.
  if (auto func_name_str = name.dyn_cast<StringAttr>()) {
    if (!func_name_str.getValue().empty()) {
      parser.emitError(loc)
          << "expected empty string or symbol while parsing tf.func "
             "attribute";
      return {};
    }
    name = SymbolRefAttr::get(context, "");
  }
  if (!name.isa<SymbolRefAttr>()) {
    parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
    return {};
  }
  if (failed(parser.parseComma())) return {};
  loc = parser.getCurrentLocation();
  if (failed(parser.parseAttribute(dict)) || !dict.isa<DictionaryAttr>()) {
    parser.emitError(loc)
        << "expected Dictionary attribute while parsing tf.func attribute";
    return {};
  }
  if (failed(parser.parseGreater())) return {};
  return FuncAttr::get(context, name.cast<SymbolRefAttr>(),
                       dict.cast<DictionaryAttr>());
}
299 
// Prints a placeholder attribute as `placeholder<"value">`; wrapping the
// value in a StringAttr makes the printer emit it quoted and escaped.
void PlaceholderAttr::print(DialectAsmPrinter &os) const {
  os << "placeholder<" << StringAttr::get(getContext(), getValue()) << ">";
}
303 
parse(MLIRContext * context,DialectAsmParser & parser,Type type)304 Attribute PlaceholderAttr::parse(MLIRContext *context, DialectAsmParser &parser,
305                                  Type type) {
306   if (failed(parser.parseLess())) return {};
307   StringRef content;
308   if (failed(parser.parseOptionalString(&content))) {
309     parser.emitError(parser.getCurrentLocation())
310         << "expected string while parsing tf.placeholder attribute";
311     return {};
312   }
313   if (failed(parser.parseGreater())) return {};
314   return PlaceholderAttr::get(context, content);
315 }
316 
print(DialectAsmPrinter & os) const317 void ShapeAttr::print(DialectAsmPrinter &os) const {
318   os << "shape<";
319   if (hasRank()) {
320     auto print_dim = [&](int64_t dim) {
321       if (dim > -1)
322         os << dim;
323       else
324         os << "?";
325     };
326     llvm::interleave(getShape(), os, print_dim, "x");
327   } else {
328     os << "*";
329   }
330   os << ">";
331 }
332 
// Parses a shape attribute: `<*>` for an unranked shape, otherwise an
// `x`-separated dimension list (e.g. `<2x?x4>`) where `?` is a dynamic
// dimension; `<>` yields the rank-0 shape.
Attribute ShapeAttr::parse(MLIRContext *context, DialectAsmParser &parser,
                           Type type) {
  if (failed(parser.parseLess())) return {};

  // `*` denotes an unranked shape and must be immediately followed by `>`.
  if (succeeded(parser.parseOptionalStar())) {
    if (failed(parser.parseGreater())) {
      parser.emitError(parser.getCurrentLocation())
          << "expected `>` after `*` when parsing a tf.shape "
             "attribute";
      return {};
    }
    return ShapeAttr::get(context, llvm::None);
  }

  SmallVector<int64_t> shape;
  // An immediate `>` means a ranked shape with zero dimensions.
  if (failed(parser.parseOptionalGreater())) {
    // Parses one dimension: `?` becomes kDynamicSize, otherwise a
    // non-negative integer.
    auto parse_element = [&]() {
      shape.emplace_back();
      llvm::SMLoc loc = parser.getCurrentLocation();
      if (succeeded(parser.parseOptionalQuestion())) {
        shape.back() = ShapedType::kDynamicSize;
      } else if (failed(parser.parseInteger(shape.back())) ||
                 shape.back() < 0) {
        // NOTE(review): the check accepts 0, so the message should arguably
        // say "non-negative"; left as-is since diagnostics may be matched by
        // tests elsewhere.
        parser.emitError(loc) << "expected a positive integer or `?` when "
                                 "parsing a tf.shape attribute";
        return failure();
      }
      return success();
    };
    if (failed(parse_element())) return {};
    // Remaining dimensions are each preceded by an `x` separator.
    while (failed(parser.parseOptionalGreater())) {
      if (failed(parser.parseXInDimensionList()) || failed(parse_element()))
        return {};
    }
  }
  return ShapeAttr::get(context, llvm::makeArrayRef(shape));
}
370 
371 // Get or create a shape attribute.
get(MLIRContext * context,llvm::Optional<ArrayRef<int64_t>> shape)372 ShapeAttr ShapeAttr::get(MLIRContext *context,
373                          llvm::Optional<ArrayRef<int64_t>> shape) {
374   if (shape) return Base::get(context, *shape, /*unranked=*/false);
375 
376   return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
377 }
378 
379 // Get or create a shape attribute.
get(MLIRContext * context,ShapedType shaped_type)380 ShapeAttr ShapeAttr::get(MLIRContext *context, ShapedType shaped_type) {
381   if (shaped_type.hasRank())
382     return Base::get(context, shaped_type.getShape(), /*unranked=*/false);
383 
384   return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
385 }
386 
getValue() const387 llvm::Optional<ArrayRef<int64_t>> ShapeAttr::getValue() const {
388   if (hasRank()) return getShape();
389   return llvm::None;
390 }
391 
hasRank() const392 bool ShapeAttr::hasRank() const { return !getImpl()->unranked; }
393 
// Returns the rank. Must only be called when hasRank() is true.
int64_t ShapeAttr::getRank() const {
  assert(hasRank());
  return getImpl()->shape.size();
}
398 
hasStaticShape() const399 bool ShapeAttr::hasStaticShape() const {
400   if (!hasRank()) return false;
401 
402   for (auto dim : getShape()) {
403     if (dim < 0) return false;
404   }
405 
406   return true;
407 }
408 
409 namespace {
410 // Returns the shape of the given value if it's ranked; returns llvm::None
411 // otherwise.
GetShape(Value value)412 llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(Value value) {
413   auto shaped_type = value.getType().cast<ShapedType>();
414   if (shaped_type.hasRank()) return shaped_type.getShape();
415   return llvm::None;
416 }
417 
418 // Merges cast compatible shapes and returns a more refined shape. The two
419 // shapes are cast compatible if they have the same rank and at each dimension,
420 // either both have same size or one of them is dynamic. Returns false if the
421 // given shapes are not cast compatible. The refined shape is same or more
422 // precise than the two input shapes.
GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,llvm::ArrayRef<int64_t> b_shape,llvm::SmallVectorImpl<int64_t> * refined_shape)423 bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
424                             llvm::ArrayRef<int64_t> b_shape,
425                             llvm::SmallVectorImpl<int64_t> *refined_shape) {
426   if (a_shape.size() != b_shape.size()) return false;
427   int64_t rank = a_shape.size();
428   refined_shape->reserve(rank);
429   for (auto dims : llvm::zip(a_shape, b_shape)) {
430     int64_t dim1 = std::get<0>(dims);
431     int64_t dim2 = std::get<1>(dims);
432 
433     if (ShapedType::isDynamic(dim1)) {
434       refined_shape->push_back(dim2);
435       continue;
436     }
437     if (ShapedType::isDynamic(dim2)) {
438       refined_shape->push_back(dim1);
439       continue;
440     }
441     if (dim1 == dim2) {
442       refined_shape->push_back(dim1);
443       continue;
444     }
445     return false;
446   }
447   return true;
448 }
449 
450 }  // namespace
451 
452 //===----------------------------------------------------------------------===//
453 // Utility iterators
454 //===----------------------------------------------------------------------===//
455 
// Iterator adaptor mapping each operand Value to its optional static shape
// via GetShape.
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
460 
// Iterator adaptor mapping each result Value to its optional static shape
// via GetShape.
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
465 
466 //===----------------------------------------------------------------------===//
467 // TF types helper functions
468 //===----------------------------------------------------------------------===//
469 
// A type is a TensorFlow type iff it belongs to the tf_type dialect.
bool TensorFlowType::classof(Type type) {
  return llvm::isa<TFTypeDialect>(type.getDialect());
}
// A type is a TensorFlow ref type iff it is one of the *Ref types from
// types.def; the empty HANDLE_TF_TYPE expansion filters out non-ref types.
bool TensorFlowRefType::classof(Type type) {
  return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
      >();
}
482 
// Maps a (possibly shaped) type to the corresponding TensorFlow ref type of
// its element type. Builtin float/complex/integer element types get explicit
// mappings; the remaining TF types are handled by the types.def expansion.
// Aborts on element types with no ref counterpart.
TensorFlowType TensorFlowRefType::get(Type type) {
  MLIRContext *ctx = type.getContext();
  // Only the element type matters for the ref mapping.
  type = getElementTypeOrSelf(type);
  if (type.isF16()) {
    return HalfRefType::get(ctx);
  } else if (type.isF32()) {
    return FloatRefType::get(ctx);
  } else if (type.isF64()) {
    return DoubleRefType::get(ctx);
  } else if (type.isBF16()) {
    return Bfloat16RefType::get(ctx);
  } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
    Type etype = complex_type.getElementType();
    if (etype.isF32()) {
      return Complex64RefType::get(ctx);
    } else if (etype.isF64()) {
      return Complex128RefType::get(ctx);
    }
    llvm_unreachable("unexpected complex type");
  } else if (auto itype = type.dyn_cast<IntegerType>()) {
    // Signless and signed integers map to the IntNRef types; explicitly
    // unsigned ones map to UintNRef.
    switch (itype.getWidth()) {
      case 1:
        return BoolRefType::get(ctx);
      case 8:
        return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                                  : Int8RefType::get(ctx);
      case 16:
        return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                                  : Int16RefType::get(ctx);
      case 32:
        return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                                  : Int32RefType::get(ctx);
      case 64:
        return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                                  : Int64RefType::get(ctx);
      default:
        llvm_unreachable("unexpected integer type");
    }
  }
  // Every remaining non-ref TF type maps to its generated Ref twin; ref
  // types themselves are excluded via the empty HANDLE_TF_REF_TYPE.
#define HANDLE_TF_TYPE(tftype, enumerant, name)        \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
    return tftype##RefType::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
  llvm_unreachable("unexpected type kind");
}
531 
// Inverse of TensorFlowRefType::get: maps this ref type back to the builtin
// (or plain TF) type it wraps. Aborts if no mapping exists.
Type TensorFlowRefType::RemoveRef() {
  MLIRContext *ctx = getContext();
  if (isa<HalfRefType>()) return FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return FloatType::getBF16(ctx);
  if (isa<BoolRefType>()) return IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return IntegerType::get(ctx, 64, IntegerType::Unsigned);
  if (isa<Complex64RefType>()) return ComplexType::get(FloatType::getF32(ctx));
  if (isa<Complex128RefType>()) return ComplexType::get(FloatType::getF64(ctx));
  // Remaining ref types map back to their generated non-ref twins.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}
561 
// Resource and Variant are the only TF types that carry subtypes.
bool TensorFlowTypeWithSubtype::classof(Type type) {
  return type.isa<ResourceType, VariantType>();
}
565 
RemoveSubtypes()566 Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
567   MLIRContext *ctx = getContext();
568   if (isa<VariantType>()) return VariantType::get(ctx);
569   if (isa<ResourceType>()) return ResourceType::get(ctx);
570   llvm_unreachable("unexpected tensorflow type with subtypes kind");
571 }
572 
// Returns the same kind of type (resource or variant) carrying the given
// subtype list in place of the current one.
TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
    ArrayRef<TensorType> new_subtypes) {
  MLIRContext *ctx = getContext();
  if (isa<VariantType>())
    return VariantType::get(new_subtypes, ctx)
        .cast<TensorFlowTypeWithSubtype>();
  if (isa<ResourceType>())
    return ResourceType::get(new_subtypes, ctx)
        .cast<TensorFlowTypeWithSubtype>();
  llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
584 
// Returns the subtype list of the underlying variant or resource type.
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
  if (auto variant_type = dyn_cast<VariantType>())
    return variant_type.getSubtypes();
  if (auto resource_type = dyn_cast<ResourceType>())
    return resource_type.getSubtypes();
  llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
592 
593 // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
594 // similar structure that could be extracted into helper method.
BroadcastCompatible(TypeRange lhs,TypeRange rhs)595 bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) {
596   if (lhs.size() != rhs.size()) return false;
597   for (auto types : llvm::zip(lhs, rhs)) {
598     // Drop ref types because they don't affect broadcast compatibility. E.g.,
599     // `tensor<!tf_type.f32ref>` and `tensor<f32>` should be considered
600     // broadcast compatible.
601     auto lhs_type = DropRefType(std::get<0>(types));
602     auto rhs_type = DropRefType(std::get<1>(types));
603 
604     // This should be true for all TF ops:
605     auto lhs_tt = lhs_type.dyn_cast<TensorType>();
606     auto rhs_tt = rhs_type.dyn_cast<TensorType>();
607     if (!lhs_tt || !rhs_tt) {
608       if (lhs_type != rhs_type) return false;
609       continue;
610     }
611 
612     // Verify matching element types. These should be identical, except for
613     // variant type where unknown subtype is considered compatible with all
614     // subtypes.
615     auto lhs_et = lhs_tt.getElementType();
616     auto rhs_et = rhs_tt.getElementType();
617     if (lhs_et != rhs_et) {
618       // If either does not have subtypes, then the element types don't match.
619       auto lhs_wst = lhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
620       auto rhs_wst = rhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
621       if (!lhs_wst || !rhs_wst) return false;
622 
623       // Consider the subtype of variant types.
624       auto lhs_wst_st = lhs_wst.GetSubtypes();
625       auto rhs_wst_st = rhs_wst.GetSubtypes();
626       if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
627         for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
628           if (!BroadcastCompatible(std::get<0>(subtypes),
629                                    std::get<1>(subtypes)))
630             return false;
631         }
632       }
633     }
634 
635     auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
636     auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
637     if (!lhs_rt || !rhs_rt) return true;
638     SmallVector<int64_t, 4> shape;
639     return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
640                                               rhs_rt.getShape(), shape);
641   }
642   return true;
643 }
644 
645 // Given two types `a` and `b`, returns a refined type which is cast compatible
646 // with both `a` and `b` and is equal to or more precise than both of them. It
647 // returns empty Type if the input types are not cast compatible.
648 //
649 // The two types are considered cast compatible if they have dynamically equal
650 // shapes and element type. For element types that do not have subtypes, they
651 // must be equal. However for TensorFlow types such as Resource and Variant,
652 // that also have subtypes, we recursively check for subtype compatibility for
653 // Resource types and assume all variant types are cast compatible. If either
654 // one of `a` or `b` have empty subtypes, they are considered cast compatible.
655 //
656 // The returned type is same or more precise than the input types. For example,
657 // if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
658 // tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
659 //
660 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
661 // might allow operands to either be same as result type or be a ref type
662 // corresponding to it.
// Given two types `a` and `b`, returns a refined type which is cast
// compatible with both and at least as precise as each, or a null Type when
// they are not cast compatible. Element types must match exactly unless they
// carry subtypes (resource/variant), which are checked recursively; shapes
// are merged dimension-wise via GetCastCompatibleShape. When
// `may_ignore_ref_type_a` is set, a ref type on `a` may be stripped to match
// `b` (useful for ops whose operand may be a ref of the result type).
Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) {
  // Fast path if everything is equal.
  if (a == b) return b;

  auto a_tt = a.dyn_cast<TensorType>();
  auto b_tt = b.dyn_cast<TensorType>();

  // If only one of a or b is a tensor type, they are incompatible.
  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

  // For non-tensor types, we do not need to worry about shape and can return
  // early.
  if (!a_tt && !b_tt) {
    // Remove ref types.
    if (may_ignore_ref_type_a) {
      if (auto ref_type = a.dyn_cast<TensorFlowRefType>()) {
        a = ref_type.RemoveRef();
        if (a == b) return a;
      }
    }
    // Distinct concrete types (even within the same dialect) are not
    // compatible unless both carry subtypes, handled below.
    if (a.getTypeID() != b.getTypeID()) return nullptr;

    // If either is not a type that contain subtypes then the types are not cast
    // compatible.
    auto a_wst = a.dyn_cast<TensorFlowTypeWithSubtype>();
    auto b_wst = b.dyn_cast<TensorFlowTypeWithSubtype>();
    if (!a_wst || !b_wst) return nullptr;

    // For Variant types we are more permissive right now and accept all pairs
    // of Variant types. If we are more constrainted and check compatibility of
    // subtypes, we might reject valid graphs.
    // TODO(prakalps): Variant doesn't have a subtype, we assign it
    // one, so we should only assign it one when we know the subtype. Then we
    // can be more constrained and check subtypes for cast compatibility as
    // well.
    if (a.isa<VariantType>()) return a;

    // For Resource types, we recursively check the subtypes for cast
    // compatibility, if possible. Otherwise treat them as compatible.
    auto a_wst_st = a_wst.GetSubtypes();
    auto b_wst_st = b_wst.GetSubtypes();
    if (a_wst_st.empty() || b_wst_st.empty()) return a;
    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
    llvm::SmallVector<TensorType, 4> refined_subtypes;
    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
      Type refined_st =
          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                /*may_ignore_ref_type_a=*/false);
      if (!refined_st) return nullptr;
      refined_subtypes.push_back(refined_st.cast<TensorType>());
    }

    return ResourceType::get(refined_subtypes, a.getContext());
  }

  // For tensor types, check compatibility of both element type and shape.
  Type refined_element_ty = GetCastCompatibleType(
      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
  if (!refined_element_ty) return nullptr;

  // A ranked side refines an unranked one; two ranked sides are merged
  // dimension-wise below.
  if (!a_tt.hasRank() && !b_tt.hasRank()) {
    return UnrankedTensorType::get(refined_element_ty);
  }
  if (!a_tt.hasRank()) {
    return RankedTensorType::get(b_tt.getShape(), refined_element_ty);
  }
  if (!b_tt.hasRank()) {
    return RankedTensorType::get(a_tt.getShape(), refined_element_ty);
  }

  llvm::SmallVector<int64_t, 8> refined_shape;
  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
    return nullptr;

  return RankedTensorType::get(refined_shape, refined_element_ty);
}
739 
// True iff `lhs` and `rhs` are cast compatible (a refined common type
// exists); `may_ignore_ref_type_lhs` lets a ref type on `lhs` match the
// corresponding non-ref `rhs`.
bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs) {
  return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
}
744 
AreCastCompatible(TypeRange types)745 bool AreCastCompatible(TypeRange types) {
746   Type common = types.front();
747   for (auto type : types.drop_front()) {
748     Type refined_type =
749         GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
750     if (!refined_type) return false;
751     common = refined_type;
752   }
753   return true;
754 }
755 
ArraysAreCastCompatible(TypeRange lhs,TypeRange rhs)756 bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) {
757   if (lhs.size() != rhs.size()) return false;
758   for (auto pair : llvm::zip(lhs, rhs)) {
759     auto lhs_i = std::get<0>(pair);
760     auto rhs_i = std::get<1>(pair);
761     if (!AreCastCompatible({lhs_i, rhs_i})) return false;
762   }
763   return true;
764 }
765 
766 // Returns the corresponding TensorFlow or standard type from TensorFlowRef
767 // type.
// Returns the corresponding TensorFlow or standard type from TensorFlowRef
// type (overload used by DropTypeHelper below).
static Type GetDefaultTypeOf(TensorFlowRefType type) {
  return type.RemoveRef();
}
771 
772 // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
773 // type for a composed type (such as a ref type or a type with subtypes).
774 template <typename ComposedType>
DropTypeHelper(Type ty)775 Type DropTypeHelper(Type ty) {
776   Type element_ty = getElementTypeOrSelf(ty);
777   auto composed_type = element_ty.dyn_cast<ComposedType>();
778   if (!composed_type) return ty;
779 
780   Type default_ty = GetDefaultTypeOf(composed_type);
781   if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
782     return RankedTensorType::get(ranked_ty.getShape(), default_ty);
783   } else if (ty.dyn_cast<UnrankedTensorType>()) {
784     return UnrankedTensorType::get(default_ty);
785   } else {
786     return default_ty;
787   }
788 }
789 
// Drops the subtype list from a resource/variant (element) type.
Type DropSubTypes(Type ty) {
  return DropTypeHelper<TensorFlowTypeWithSubtype>(ty);
}
793 
DropRefType(Type ty)794 Type DropRefType(Type ty) { return DropTypeHelper<TensorFlowRefType>(ty); }
795 
DropRefAndSubTypes(Type ty)796 Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
797 
798 }  // namespace tf_type
799 }  // namespace mlir
800