• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/types/dialect.h"
17 
18 #include <cstdint>
19 #include <string>
20 
21 #include "absl/strings/escaping.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/SMLoc.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "mlir/Dialect/Traits.h"  // from @llvm-project
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/Dialect.h"  // from @llvm-project
36 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
37 #include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
38 #include "mlir/IR/FunctionInterfaces.h"  // from @llvm-project
39 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
40 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
41 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
42 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
43 
44 #define GET_ATTRDEF_CLASSES
45 #include "tensorflow/core/ir/types/attributes.cc.inc"
46 #include "tensorflow/core/ir/types/attributes_enum.cc.inc"
47 
48 #define GET_TYPEDEF_CLASSES
49 #include "tensorflow/core/ir/types/types.cc.inc"
50 
51 // Generated definitions.
52 #include "tensorflow/core/ir/types/dialect.cpp.inc"
53 
54 namespace mlir {
55 namespace tf_type {
56 
57 //===----------------------------------------------------------------------===//
58 // TFType dialect.
59 //===----------------------------------------------------------------------===//
60 
// Dialect construction: there is one instance per context. This hook registers
// the dialect's attributes and types. Both template argument lists are
// assembled by textual inclusion rather than written out by hand.
void TFTypeDialect::initialize() {
  // The attribute list comes from attributes.cc.inc with GET_ATTRDEF_LIST set.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "tensorflow/core/ir/types/attributes.cc.inc"
      >();
  // ControlType and OpaqueTensorType are named explicitly; the remaining
  // `<tftype>Type` names are appended by expanding types.def. The "last" entry
  // expands without a trailing comma so the argument list stays well-formed.
  addTypes<ControlType, OpaqueTensorType,
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/core/ir/types/types.def"
           >();
}
74 
75 namespace {
76 template <typename TypeWithSubtype>
ParseTypeWithSubtype(MLIRContext * context,DialectAsmParser & parser)77 Type ParseTypeWithSubtype(MLIRContext *context, DialectAsmParser &parser) {
78   // Default type without inferred subtypes.
79   if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
80 
81   // Most types with subtypes have only one subtype.
82   SmallVector<TensorType, 1> subtypes;
83   do {
84     TensorType tensor_ty;
85     if (parser.parseType(tensor_ty)) return Type();
86 
87     // Each of the subtypes should be a valid TensorFlow type.
88     // TODO(jpienaar): Remove duplication.
89     if (!IsValidTFTensorType(tensor_ty)) {
90       parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
91       return Type();
92     }
93     subtypes.push_back(tensor_ty);
94   } while (succeeded(parser.parseOptionalComma()));
95 
96   if (parser.parseGreater()) return Type();
97 
98   return TypeWithSubtype::get(subtypes, context);
99 }
100 
101 template <typename TypeWithSubtype>
PrintTypeWithSubtype(StringRef type,TypeWithSubtype ty,DialectAsmPrinter & os)102 void PrintTypeWithSubtype(StringRef type, TypeWithSubtype ty,
103                           DialectAsmPrinter &os) {
104   os << type;
105   ArrayRef<TensorType> subtypes = ty.getSubtypes();
106   if (subtypes.empty()) return;
107 
108   os << "<";
109   interleaveComma(subtypes, os);
110   os << ">";
111 }
ParseResourceType(MLIRContext * context,DialectAsmParser & parser)112 Type ParseResourceType(MLIRContext *context, DialectAsmParser &parser) {
113   return ParseTypeWithSubtype<ResourceType>(context, parser);
114 }
115 
PrintResourceType(ResourceType ty,DialectAsmPrinter & os)116 void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) {
117   return PrintTypeWithSubtype("resource", ty, os);
118 }
119 
ParseVariantType(MLIRContext * context,DialectAsmParser & parser)120 Type ParseVariantType(MLIRContext *context, DialectAsmParser &parser) {
121   return ParseTypeWithSubtype<VariantType>(context, parser);
122 }
123 
PrintVariantType(VariantType ty,DialectAsmPrinter & os)124 void PrintVariantType(VariantType ty, DialectAsmPrinter &os) {
125   return PrintTypeWithSubtype("variant", ty, os);
126 }
127 
128 }  // namespace
129 
// Entry point for Type parsing, TableGen generated code will handle the
// dispatch to the individual classes. Tags the generated parser does not claim
// are matched against the simple types from types.def and then against the
// subtype-carrying `resource` / `variant` types. Returns a null Type after
// emitting a diagnostic when nothing matches.
Type TFTypeDialect::parseType(DialectAsmParser &parser) const {
  StringRef type_tag;
  // Captured up front so the resource/variant diagnostics point at the type
  // name rather than at wherever subtype parsing stopped.
  llvm::SMLoc loc = parser.getNameLoc();

  Type genType;
  auto parse_result = generatedTypeParser(parser, &type_tag, genType);
  // A present result means the generated parser took this tag; whatever it
  // produced is returned as-is.
  if (parse_result.hasValue()) return genType;

  // Expands to one `if (type_tag == name) return ...;` per simple type. Custom
  // types expand to nothing here and are handled explicitly below.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (type_tag == name) return tftype##Type::get(getContext());
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // NOTE(review): these are prefix matches, so any tag that merely starts with
  // "resource" / "variant" is accepted here — confirm this is intended.
  if (type_tag.startswith("resource")) {
    Type ret = ParseResourceType(getContext(), parser);
    if (!ret) parser.emitError(loc, "invalid resource type");
    return ret;
  }
  if (type_tag.startswith("variant")) {
    Type ret = ParseVariantType(getContext(), parser);
    if (!ret) parser.emitError(loc, "invalid variant type");
    return ret;
  }

  parser.emitError(parser.getNameLoc(),
                   "unknown type in TF graph dialect: " + type_tag);
  return {};
}
161 
// Entry point for Type printing. Types listed in types.def are printed by the
// macro-expanded checks below; every other type is handed to the TableGen
// generated printer.
void TFTypeDialect::printType(Type type, DialectAsmPrinter &printer) const {
  // Simple types print only their `name`.
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    printer << name;                                     \
    return;                                              \
  }
  // Custom types delegate to the matching Print<tftype>Type helper.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)   \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    Print##tftype##Type(derived_ty, printer);            \
    return;                                              \
  }
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // Nothing above matched; a failure from the generated printer is fatal.
  if (failed(generatedTypePrinter(type, printer)))
    llvm::report_fatal_error("unexpected tensorflow graph type kind");
}
181 
182 //===----------------------------------------------------------------------===//
183 // Attributes
184 //===----------------------------------------------------------------------===//
185 
parse(AsmParser & parser,Type)186 Attribute VersionAttr::parse(AsmParser &parser, Type) {
187   if (failed(parser.parseLess())) return {};
188 
189   int32_t producer, min_consumer;
190   if (parser.parseKeyword("producer", " in tf_type version") ||
191       parser.parseEqual() || parser.parseInteger(producer) ||
192       parser.parseComma() ||
193       parser.parseKeyword("min_consumer", " in tf_type version") ||
194       parser.parseEqual() || parser.parseInteger(min_consumer))
195     return {};
196 
197   SmallVector<int32_t, 4> bad_consumers;
198   if (!parser.parseOptionalComma()) {
199     if (parser.parseKeyword("bad_consumers", " in tf_type version") ||
200         parser.parseEqual() || parser.parseLSquare())
201       return {};
202     do {
203       int32_t bad_consumer;
204       if (parser.parseInteger(bad_consumer)) return {};
205       bad_consumers.push_back(bad_consumer);
206     } while (!parser.parseOptionalComma());
207     if (parser.parseRSquare()) return {};
208   }
209   if (failed(parser.parseGreater())) return {};
210 
211   return VersionAttr::get(parser.getContext(), producer, min_consumer,
212                           bad_consumers);
213 }
214 
print(AsmPrinter & printer) const215 void VersionAttr::print(AsmPrinter &printer) const {
216   llvm::raw_ostream &os = printer.getStream();
217   os << "<producer = " << getProducer()
218      << ", min_consumer = " << getMinConsumer();
219   ArrayRef<int32_t> badConsumers = getBadConsumers();
220   if (!badConsumers.empty()) {
221     os << ", bad_consumers = [";
222     llvm::interleaveComma(badConsumers, os);
223     os << "]";
224   }
225   os << ">";
226 }
227 
RawFullTypeAttrParser(AsmParser & parser)228 FailureOr<FullTypeAttr> RawFullTypeAttrParser(AsmParser &parser) {
229   SmallVector<FullTypeAttr> args;
230 
231   // Parse variable 'type_id'
232   llvm::StringRef type_id_str;
233   if (failed(parser.parseKeyword(&type_id_str))) {
234     parser.emitError(
235         parser.getCurrentLocation(),
236         "failed to parse TFType_FullTypeAttr parameter keyword for "
237         "'type_id'");
238     return failure();
239   }
240   Optional<FullTypeId> type_id = symbolizeFullTypeId(type_id_str);
241   if (!type_id) {
242     parser.emitError(parser.getCurrentLocation(),
243                      "failed to parse TFType_FullTypeAttr parameter "
244                      "'type_id'");
245     return failure();
246   }
247 
248   // Parse variable 'args'
249   if (parser.parseCommaSeparatedList(
250       AsmParser::Delimiter::OptionalLessGreater, [&]() {
251         FailureOr<tf_type::FullTypeAttr> arg = RawFullTypeAttrParser(parser);
252         if (failed(arg)) return failure();
253         args.push_back(*arg);
254         return success();
255       }))
256     return failure();
257 
258   // Parse variable 'attr'
259   Attribute attr;
260   parser.parseOptionalAttribute(attr);
261   return FullTypeAttr::get(parser.getContext(), static_cast<int32_t>(*type_id),
262                            args, attr);
263 }
264 
parse(AsmParser & parser,Type odsType)265 Attribute FullTypeAttr::parse(AsmParser &parser, Type odsType) {
266   if (failed(parser.parseLess())) return {};
267   FailureOr<tf_type::FullTypeAttr> ret = RawFullTypeAttrParser(parser);
268   if (succeeded(ret) && failed(parser.parseGreater())) return {};
269   return ret.getValueOr(FullTypeAttr());
270 }
271 
RawFullTypeAttrPrint(FullTypeAttr tfattr,AsmPrinter & printer)272 static void RawFullTypeAttrPrint(FullTypeAttr tfattr, AsmPrinter &printer) {
273   printer << stringifyFullTypeId(tf_type::FullTypeId(tfattr.getTypeId()));
274   if (!tfattr.getArgs().empty()) {
275     printer << "<";
276     llvm::interleaveComma(tfattr.getArgs(), printer, [&](Attribute arg) {
277       if (auto t = arg.dyn_cast<FullTypeAttr>())
278         RawFullTypeAttrPrint(t, printer);
279       else
280         printer << "<<INVALID ARG>>";
281     });
282     printer << ">";
283   }
284   if (tfattr.getAttr()) {
285     printer << ' ';
286     printer.printStrippedAttrOrType(tfattr.getAttr());
287   }
288 }
289 
print(AsmPrinter & printer) const290 void FullTypeAttr::print(AsmPrinter &printer) const {
291   printer << "<";
292   RawFullTypeAttrPrint(*this, printer);
293   printer << ">";
294 }
295 
296 // Print a #tf.func attribute of the following format:
297 //
298 //   #tf.func<@symbol, {attr = "value"}>
299 // or
300 //   #tf.func<"", {attr = "value"}>
301 // in case of null symbol ref.
print(AsmPrinter & os) const302 void FuncAttr::print(AsmPrinter &os) const {
303   if (getName().getRootReference().getValue().empty())
304     os << "<\"\", " << getAttrs() << ">";
305   else
306     os << "<" << getName() << ", " << getAttrs() << ">";
307 }
308 
309 // Parses a #tf.func attribute of the following format:
310 //
311 //   #tf.func<@symbol, {attr = "value"}>
312 //
313 // where the first element is a SymbolRefAttr and the second element is a
314 // DictionaryAttr.
parse(AsmParser & parser,Type type)315 Attribute FuncAttr::parse(AsmParser &parser, Type type) {
316   if (failed(parser.parseLess())) return {};
317   llvm::SMLoc loc = parser.getCurrentLocation();
318   Attribute name, dict;
319   if (failed(parser.parseAttribute(name))) {
320     parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
321     return {};
322   }
323   if (auto func_name_str = name.dyn_cast<StringAttr>()) {
324     if (!func_name_str.getValue().empty()) {
325       parser.emitError(loc)
326           << "expected empty string or symbol while parsing tf.func "
327              "attribute";
328       return {};
329     }
330     name = SymbolRefAttr::get(parser.getContext(), "");
331   }
332   if (!name.isa<SymbolRefAttr>()) {
333     parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
334     return {};
335   }
336   if (failed(parser.parseComma())) return {};
337   loc = parser.getCurrentLocation();
338   if (failed(parser.parseAttribute(dict)) || !dict.isa<DictionaryAttr>()) {
339     parser.emitError(loc)
340         << "expected Dictionary attribute while parsing tf.func attribute";
341     return {};
342   }
343   if (failed(parser.parseGreater())) return {};
344   return FuncAttr::get(parser.getContext(), name.cast<SymbolRefAttr>(),
345                        dict.cast<DictionaryAttr>());
346 }
347 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const348 void FuncAttr::walkImmediateSubElements(
349     function_ref<void(Attribute)> walkAttrsFn,
350     function_ref<void(Type)> walkTypesFn) const {
351   // Walk the dictionary attribute first, so that its index is always 0.
352   walkAttrsFn(getAttrs());
353   // Walk the symbol ref attribute if it isn't empty.
354   if (!getName().getRootReference().getValue().empty()) walkAttrsFn(getName());
355 }
356 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const357 Attribute FuncAttr::replaceImmediateSubElements(
358     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
359   assert(replAttrs.size() == 2 && "invalid number of replacement attributes");
360   return get(getContext(), replAttrs[1].cast<SymbolRefAttr>(),
361              replAttrs[0].cast<DictionaryAttr>());
362 }
363 
print(AsmPrinter & os) const364 void PlaceholderAttr::print(AsmPrinter &os) const {
365   os << "<" << StringAttr::get(getContext(), getValue()) << ">";
366 }
367 
parse(AsmParser & parser,Type type)368 Attribute PlaceholderAttr::parse(AsmParser &parser, Type type) {
369   if (failed(parser.parseLess())) return {};
370   std::string content;
371   if (failed(parser.parseOptionalString(&content))) {
372     parser.emitError(parser.getCurrentLocation())
373         << "expected string while parsing tf.placeholder attribute";
374     return {};
375   }
376   if (failed(parser.parseGreater())) return {};
377   return PlaceholderAttr::get(parser.getContext(), content);
378 }
379 
print(AsmPrinter & os) const380 void ShapeAttr::print(AsmPrinter &os) const {
381   os << "<";
382   if (hasRank()) {
383     auto print_dim = [&](int64_t dim) {
384       if (dim != -1)
385         os << dim;
386       else
387         os << "?";
388     };
389     llvm::interleave(getShape(), os, print_dim, "x");
390   } else {
391     os << "*";
392   }
393   os << ">";
394 }
395 
parse(AsmParser & parser,Type type)396 Attribute ShapeAttr::parse(AsmParser &parser, Type type) {
397   if (failed(parser.parseLess())) return {};
398 
399   if (succeeded(parser.parseOptionalStar())) {
400     if (failed(parser.parseGreater())) {
401       parser.emitError(parser.getCurrentLocation())
402           << "expected `>` after `*` when parsing a tf.shape "
403              "attribute";
404       return {};
405     }
406     return ShapeAttr::get(parser.getContext(), llvm::None);
407   }
408 
409   SmallVector<int64_t> shape;
410   if (failed(parser.parseOptionalGreater())) {
411     auto parse_element = [&]() {
412       shape.emplace_back();
413       llvm::SMLoc loc = parser.getCurrentLocation();
414       if (succeeded(parser.parseOptionalQuestion())) {
415         shape.back() = ShapedType::kDynamicSize;
416       } else if (failed(parser.parseInteger(shape.back()))) {
417         parser.emitError(loc)
418             << "expected an integer or `?` when parsing a tf.shape attribute";
419         return failure();
420       }
421       return success();
422     };
423     if (failed(parse_element())) return {};
424     while (failed(parser.parseOptionalGreater())) {
425       if (failed(parser.parseXInDimensionList()) || failed(parse_element()))
426         return {};
427     }
428   }
429   return ShapeAttr::get(parser.getContext(), llvm::makeArrayRef(shape));
430 }
431 
432 // Get or create a shape attribute.
get(MLIRContext * context,llvm::Optional<ArrayRef<int64_t>> shape)433 ShapeAttr ShapeAttr::get(MLIRContext *context,
434                          llvm::Optional<ArrayRef<int64_t>> shape) {
435   if (shape) return Base::get(context, *shape, /*unranked=*/false);
436 
437   return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
438 }
439 
440 // Get or create a shape attribute.
get(MLIRContext * context,ShapedType shaped_type)441 ShapeAttr ShapeAttr::get(MLIRContext *context, ShapedType shaped_type) {
442   if (shaped_type.hasRank())
443     return Base::get(context, shaped_type.getShape(), /*unranked=*/false);
444 
445   return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
446 }
447 
getValue() const448 llvm::Optional<ArrayRef<int64_t>> ShapeAttr::getValue() const {
449   if (hasRank()) return getShape();
450   return llvm::None;
451 }
452 
hasRank() const453 bool ShapeAttr::hasRank() const { return !getImpl()->unranked; }
454 
getRank() const455 int64_t ShapeAttr::getRank() const {
456   assert(hasRank());
457   return getImpl()->shape.size();
458 }
459 
hasStaticShape() const460 bool ShapeAttr::hasStaticShape() const {
461   if (!hasRank()) return false;
462 
463   for (auto dim : getShape()) {
464     if (dim < 0) return false;
465   }
466 
467   return true;
468 }
469 
470 namespace {
471 // Returns the shape of the given value if it's ranked; returns llvm::None
472 // otherwise.
GetShape(Value value)473 Optional<ArrayRef<int64_t>> GetShape(Value value) {
474   auto shaped_type = value.getType().cast<ShapedType>();
475   if (shaped_type.hasRank()) return shaped_type.getShape();
476   return llvm::None;
477 }
478 
479 // Merges cast compatible shapes and returns a more refined shape. The two
480 // shapes are cast compatible if they have the same rank and at each dimension,
481 // either both have same size or one of them is dynamic. Returns false if the
482 // given shapes are not cast compatible. The refined shape is same or more
483 // precise than the two input shapes.
GetCastCompatibleShape(ArrayRef<int64_t> a_shape,ArrayRef<int64_t> b_shape,SmallVectorImpl<int64_t> * refined_shape)484 bool GetCastCompatibleShape(ArrayRef<int64_t> a_shape,
485                             ArrayRef<int64_t> b_shape,
486                             SmallVectorImpl<int64_t> *refined_shape) {
487   if (a_shape.size() != b_shape.size()) return false;
488   int64_t rank = a_shape.size();
489   refined_shape->reserve(rank);
490   for (auto dims : llvm::zip(a_shape, b_shape)) {
491     int64_t dim1 = std::get<0>(dims);
492     int64_t dim2 = std::get<1>(dims);
493 
494     if (ShapedType::isDynamic(dim1)) {
495       refined_shape->push_back(dim2);
496       continue;
497     }
498     if (ShapedType::isDynamic(dim2)) {
499       refined_shape->push_back(dim1);
500       continue;
501     }
502     if (dim1 == dim2) {
503       refined_shape->push_back(dim1);
504       continue;
505     }
506     return false;
507   }
508   return true;
509 }
510 
511 }  // namespace
512 
513 //===----------------------------------------------------------------------===//
514 // Utility iterators
515 //===----------------------------------------------------------------------===//
516 
// Wraps the operand iterator `it` in a mapped_iterator that applies GetShape
// to each operand value.
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
521 
// Wraps the result iterator `it` in a mapped_iterator that applies GetShape
// to each result value.
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
526 
527 //===----------------------------------------------------------------------===//
528 // TF types helper functions
529 //===----------------------------------------------------------------------===//
530 
classof(Type type)531 bool TensorFlowType::classof(Type type) {
532   return llvm::isa<TFTypeDialect>(type.getDialect());
533 }
// Returns true when `type` is one of the reference types. The template
// argument list of isa<> is assembled by expanding types.def: HANDLE_TF_TYPE
// expands to nothing, HANDLE_TF_REF_TYPE contributes `<tftype>Type,` and
// HANDLE_LAST_TF_TYPE contributes the final `<tftype>Type` without a trailing
// comma.
bool TensorFlowRefType::classof(Type type) {
  return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
      >();
}
543 
get(Type type)544 TensorFlowType TensorFlowRefType::get(Type type) {
545   MLIRContext *ctx = type.getContext();
546   type = getElementTypeOrSelf(type);
547   if (type.isF16()) {
548     return HalfRefType::get(ctx);
549   } else if (type.isF32()) {
550     return FloatRefType::get(ctx);
551   } else if (type.isF64()) {
552     return DoubleRefType::get(ctx);
553   } else if (type.isBF16()) {
554     return Bfloat16RefType::get(ctx);
555   } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
556     Type etype = complex_type.getElementType();
557     if (etype.isF32()) {
558       return Complex64RefType::get(ctx);
559     } else if (etype.isF64()) {
560       return Complex128RefType::get(ctx);
561     }
562     llvm_unreachable("unexpected complex type");
563   } else if (auto itype = type.dyn_cast<IntegerType>()) {
564     switch (itype.getWidth()) {
565       case 1:
566         return BoolRefType::get(ctx);
567       case 8:
568         return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
569                                   : Int8RefType::get(ctx);
570       case 16:
571         return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
572                                   : Int16RefType::get(ctx);
573       case 32:
574         return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
575                                   : Int32RefType::get(ctx);
576       case 64:
577         return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
578                                   : Int64RefType::get(ctx);
579       default:
580         llvm_unreachable("unexpected integer type");
581     }
582   }
583 #define HANDLE_TF_TYPE(tftype, enumerant, name)        \
584   if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
585     return tftype##RefType::get(ctx);
586 
587 #define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
588 // NOLINTNEXTLINE
589 #include "tensorflow/core/ir/types/types.def"
590   llvm_unreachable("unexpected type kind");
591 }
592 
// Returns the non-reference type that corresponds to this reference type.
// Float, bool, integer and complex ref types map to builtin types; every other
// ref type maps to its `<tftype>Type` through types.def. An unhandled ref type
// hits llvm_unreachable.
Type TensorFlowRefType::RemoveRef() {
  MLIRContext *ctx = getContext();
  if (isa<HalfRefType>()) return FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return FloatType::getBF16(ctx);
  // Bool and the Int* ref types build the integer type without an explicit
  // signedness argument; only the Uint* ref types request Unsigned.
  if (isa<BoolRefType>()) return IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return IntegerType::get(ctx, 64, IntegerType::Unsigned);
  if (isa<Complex64RefType>()) return ComplexType::get(FloatType::getF32(ctx));
  if (isa<Complex128RefType>()) return ComplexType::get(FloatType::getF64(ctx));
  // Remaining ref types: one check per non-ref entry in types.def; ref entries
  // expand to nothing.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}
622 
classof(Type type)623 bool TensorFlowTypeWithSubtype::classof(Type type) {
624   return type.isa<ResourceType, VariantType>();
625 }
626 
RemoveSubtypes()627 Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
628   MLIRContext *ctx = getContext();
629   if (isa<VariantType>()) return VariantType::get(ctx);
630   if (isa<ResourceType>()) return ResourceType::get(ctx);
631   llvm_unreachable("unexpected tensorflow type with subtypes kind");
632 }
633 
clone(ArrayRef<TensorType> new_subtypes)634 TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
635     ArrayRef<TensorType> new_subtypes) {
636   MLIRContext *ctx = getContext();
637   if (isa<VariantType>())
638     return VariantType::get(new_subtypes, ctx)
639         .cast<TensorFlowTypeWithSubtype>();
640   if (isa<ResourceType>())
641     return ResourceType::get(new_subtypes, ctx)
642         .cast<TensorFlowTypeWithSubtype>();
643   llvm_unreachable("unexpected tensorflow type with subtypes kind");
644 }
645 
GetSubtypes()646 ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
647   if (auto variant_type = dyn_cast<VariantType>())
648     return variant_type.getSubtypes();
649   if (auto resource_type = dyn_cast<ResourceType>())
650     return resource_type.getSubtypes();
651   llvm_unreachable("unexpected tensorflow type with subtypes kind");
652 }
653 
654 // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
655 // similar structure that could be extracted into helper method.
BroadcastCompatible(TypeRange lhs,TypeRange rhs)656 bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) {
657   if (lhs.size() != rhs.size()) return false;
658   for (auto types : llvm::zip(lhs, rhs)) {
659     // Drop ref types because they don't affect broadcast compatibility. E.g.,
660     // `tensor<!tf_type.f32ref>` and `tensor<f32>` should be considered
661     // broadcast compatible.
662     auto lhs_type = DropRefType(std::get<0>(types));
663     auto rhs_type = DropRefType(std::get<1>(types));
664 
665     // This should be true for all TF ops:
666     auto lhs_tt = lhs_type.dyn_cast<TensorType>();
667     auto rhs_tt = rhs_type.dyn_cast<TensorType>();
668     if (!lhs_tt || !rhs_tt) {
669       if (lhs_type != rhs_type) return false;
670       continue;
671     }
672 
673     // Verify matching element types. These should be identical, except for
674     // variant type where unknown subtype is considered compatible with all
675     // subtypes.
676     auto lhs_et = lhs_tt.getElementType();
677     auto rhs_et = rhs_tt.getElementType();
678     if (lhs_et != rhs_et) {
679       // If either does not have subtypes, then the element types don't match.
680       auto lhs_wst = lhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
681       auto rhs_wst = rhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
682       if (!lhs_wst || !rhs_wst) return false;
683 
684       // Consider the subtype of variant types.
685       auto lhs_wst_st = lhs_wst.GetSubtypes();
686       auto rhs_wst_st = rhs_wst.GetSubtypes();
687       if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
688         for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
689           if (!BroadcastCompatible(std::get<0>(subtypes),
690                                    std::get<1>(subtypes)))
691             return false;
692         }
693       }
694     }
695 
696     auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
697     auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
698     if (!lhs_rt || !rhs_rt) return true;
699     SmallVector<int64_t, 4> shape;
700     return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
701                                               rhs_rt.getShape(), shape);
702   }
703   return true;
704 }
705 
706 // Given two types `a` and `b`, returns a refined type which is cast compatible
707 // with both `a` and `b` and is equal to or more precise than both of them. It
708 // returns empty Type if the input types are not cast compatible.
709 //
710 // The two types are considered cast compatible if they have dynamically equal
711 // shapes and element type. For element types that do not have subtypes, they
712 // must be equal. However for TensorFlow types such as Resource and Variant,
713 // that also have subtypes, we recursively check for subtype compatibility for
714 // Resource types and assume all variant types are cast compatible. If either
715 // one of `a` or `b` have empty subtypes, they are considered cast compatible.
716 //
717 // The returned type is same or more precise than the input types. For example,
718 // if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
719 // tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
720 //
721 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
722 // might allow operands to either be same as result type or be a ref type
723 // corresponding to it.
GetCastCompatibleType(Type a,Type b,bool may_ignore_ref_type_a)724 Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) {
725   // Fast path if everything is equal.
726   if (a == b) return b;
727 
728   auto a_tt = a.dyn_cast<TensorType>();
729   auto b_tt = b.dyn_cast<TensorType>();
730 
731   // If only one of a or b is a tensor type, they are incompatible.
732   if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
733 
734   // For non-tensor types, we do not need to worry about shape and can return
735   // early.
736   if (!a_tt && !b_tt) {
737     // Remove ref types.
738     if (may_ignore_ref_type_a) {
739       if (auto ref_type = a.dyn_cast<TensorFlowRefType>()) {
740         a = ref_type.RemoveRef();
741         if (a == b) return a;
742       }
743     }
744     if (a.getTypeID() != b.getTypeID()) return nullptr;
745 
746     // If either is not a type that contain subtypes then the types are not cast
747     // compatible.
748     auto a_wst = a.dyn_cast<TensorFlowTypeWithSubtype>();
749     auto b_wst = b.dyn_cast<TensorFlowTypeWithSubtype>();
750     if (!a_wst || !b_wst) return nullptr;
751 
752     // For Variant types we are more permissive right now and accept all pairs
753     // of Variant types. If we are more constrainted and check compatibility of
754     // subtypes, we might reject valid graphs.
755     // TODO(prakalps): Variant doesn't have a subtype, we assign it
756     // one, so we should only assign it one when we know the subtype. Then we
757     // can be more constrained and check subtypes for cast compatibility as
758     // well.
759     if (a.isa<VariantType>()) return a;
760     if (b.isa<VariantType>()) return b;
761 
762     // For Resource types, we recursively check the subtypes for cast
763     // compatibility, if possible. Otherwise treat them as compatible.
764     auto a_wst_st = a_wst.GetSubtypes();
765     auto b_wst_st = b_wst.GetSubtypes();
766     if (a_wst_st.empty()) return b;
767     if (b_wst_st.empty()) return a;
768     if (a_wst_st.size() != b_wst_st.size()) return nullptr;
769     SmallVector<TensorType, 4> refined_subtypes;
770     for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
771       Type refined_st =
772           GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
773                                 /*may_ignore_ref_type_a=*/false);
774       if (!refined_st) return nullptr;
775       refined_subtypes.push_back(refined_st.cast<TensorType>());
776     }
777 
778     return ResourceType::get(refined_subtypes, a.getContext());
779   }
780 
781   // For tensor types, check compatibility of both element type and shape.
782   Type refined_element_ty = GetCastCompatibleType(
783       a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
784   if (!refined_element_ty) return nullptr;
785 
786   if (!a_tt.hasRank() && !b_tt.hasRank()) {
787     return UnrankedTensorType::get(refined_element_ty);
788   }
789   if (!a_tt.hasRank()) {
790     return RankedTensorType::get(b_tt.getShape(), refined_element_ty);
791   }
792   if (!b_tt.hasRank()) {
793     return RankedTensorType::get(a_tt.getShape(), refined_element_ty);
794   }
795 
796   SmallVector<int64_t, 4> refined_shape;
797   if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
798     return nullptr;
799 
800   return RankedTensorType::get(refined_shape, refined_element_ty);
801 }
802 
HasCompatibleElementTypes(Type lhs,Type rhs,bool may_ignore_ref_type_lhs)803 bool HasCompatibleElementTypes(Type lhs, Type rhs,
804                                bool may_ignore_ref_type_lhs) {
805   return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
806 }
807 
AreCastCompatible(TypeRange types)808 bool AreCastCompatible(TypeRange types) {
809   Type common = types.front();
810   for (auto type : types.drop_front()) {
811     Type refined_type =
812         GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
813     if (!refined_type) return false;
814     common = refined_type;
815   }
816   return true;
817 }
818 
ArraysAreCastCompatible(TypeRange lhs,TypeRange rhs)819 bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) {
820   if (lhs.size() != rhs.size()) return false;
821   for (auto pair : llvm::zip(lhs, rhs)) {
822     auto lhs_i = std::get<0>(pair);
823     auto rhs_i = std::get<1>(pair);
824     if (!AreCastCompatible({lhs_i, rhs_i})) return false;
825   }
826   return true;
827 }
828 
829 // Returns the corresponding TensorFlow or standard type from TensorFlowRef
830 // type.
GetDefaultTypeOf(TensorFlowRefType type)831 static Type GetDefaultTypeOf(TensorFlowRefType type) {
832   return type.RemoveRef();
833 }
834 
835 // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
836 // type for a composed type (such as a ref type or a type with subtypes).
837 template <typename ComposedType>
DropTypeHelper(Type ty)838 Type DropTypeHelper(Type ty) {
839   Type element_ty = getElementTypeOrSelf(ty);
840   auto composed_type = element_ty.dyn_cast<ComposedType>();
841   if (!composed_type) return ty;
842 
843   Type default_ty = GetDefaultTypeOf(composed_type);
844   if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
845     return RankedTensorType::get(ranked_ty.getShape(), default_ty);
846   } else if (ty.dyn_cast<UnrankedTensorType>()) {
847     return UnrankedTensorType::get(default_ty);
848   } else {
849     return default_ty;
850   }
851 }
852 
DropSubTypes(Type ty)853 Type DropSubTypes(Type ty) {
854   return DropTypeHelper<TensorFlowTypeWithSubtype>(ty);
855 }
856 
DropRefType(Type ty)857 Type DropRefType(Type ty) { return DropTypeHelper<TensorFlowRefType>(ty); }
858 
DropRefAndSubTypes(Type ty)859 Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
860 
parse(AsmParser & parser,Type type)861 Attribute TensorProtoAttr::parse(AsmParser &parser, Type type) {
862   if (parser.parseColon()) {
863     return nullptr;
864   }
865 
866   std::string data;
867   if (parser.parseString(&data)) {
868     return nullptr;
869   }
870   if (data.size() < 2 || data.substr(0, 2) != "0x") {
871     parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`");
872     return nullptr;
873   }
874 
875   std::string bytes_data = absl::HexStringToBytes(data.substr(2));
876   return TensorProtoAttr::get(type, bytes_data);
877 }
878 
print(mlir::AsmPrinter & printer) const879 void TensorProtoAttr::print(mlir::AsmPrinter &printer) const {
880   StringRef bytes_str = getValue();
881   printer << " : \"0x" << llvm::toHex(bytes_str) << "\"";
882 }
883 
884 }  // namespace tf_type
885 }  // namespace mlir
886