1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/types/dialect.h"
17
18 #include <cstdint>
19 #include <string>
20
21 #include "absl/strings/escaping.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/SMLoc.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "mlir/Dialect/Traits.h" // from @llvm-project
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/Dialect.h" // from @llvm-project
36 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
37 #include "mlir/IR/FunctionImplementation.h" // from @llvm-project
38 #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project
39 #include "mlir/IR/MLIRContext.h" // from @llvm-project
40 #include "mlir/IR/OpImplementation.h" // from @llvm-project
41 #include "mlir/IR/OperationSupport.h" // from @llvm-project
42 #include "mlir/Support/LogicalResult.h" // from @llvm-project
43
44 #define GET_ATTRDEF_CLASSES
45 #include "tensorflow/core/ir/types/attributes.cc.inc"
46 #include "tensorflow/core/ir/types/attributes_enum.cc.inc"
47
48 #define GET_TYPEDEF_CLASSES
49 #include "tensorflow/core/ir/types/types.cc.inc"
50
51 // Generated definitions.
52 #include "tensorflow/core/ir/types/dialect.cpp.inc"
53
54 namespace mlir {
55 namespace tf_type {
56
57 //===----------------------------------------------------------------------===//
58 // TFType dialect.
59 //===----------------------------------------------------------------------===//
60
61 // Dialect construction: there is one instance per context and it registers its
62 // operations, types, and interfaces here.
// Dialect initialization, called once per MLIRContext: registers all the
// attributes and types this dialect owns.
void TFTypeDialect::initialize() {
  // Register every attribute from the TableGen-generated list.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "tensorflow/core/ir/types/attributes.cc.inc"
      >();
  // Register the control token and opaque tensor types plus one type per
  // entry in types.def (the X-macros expand to a comma-separated type list).
  addTypes<ControlType, OpaqueTensorType,
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/core/ir/types/types.def"
           >();
}
74
75 namespace {
// Parses a type that optionally carries inferred subtypes, i.e. either the
// bare keyword form (`!tf_type.variant`) or the angle-bracketed form
// (`!tf_type.variant<tensor<f32>, ...>`). Each subtype must be a valid
// TensorFlow tensor type. Returns a null Type on failure (after emitting a
// diagnostic on the parser).
template <typename TypeWithSubtype>
Type ParseTypeWithSubtype(MLIRContext *context, DialectAsmParser &parser) {
  // Default type without inferred subtypes.
  if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);

  // Most types with subtypes have only one subtype.
  SmallVector<TensorType, 1> subtypes;
  do {
    TensorType tensor_ty;
    if (parser.parseType(tensor_ty)) return Type();

    // Each of the subtypes should be a valid TensorFlow type.
    // TODO(jpienaar): Remove duplication.
    if (!IsValidTFTensorType(tensor_ty)) {
      parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
      return Type();
    }
    subtypes.push_back(tensor_ty);
  } while (succeeded(parser.parseOptionalComma()));

  if (parser.parseGreater()) return Type();

  return TypeWithSubtype::get(subtypes, context);
}
100
101 template <typename TypeWithSubtype>
PrintTypeWithSubtype(StringRef type,TypeWithSubtype ty,DialectAsmPrinter & os)102 void PrintTypeWithSubtype(StringRef type, TypeWithSubtype ty,
103 DialectAsmPrinter &os) {
104 os << type;
105 ArrayRef<TensorType> subtypes = ty.getSubtypes();
106 if (subtypes.empty()) return;
107
108 os << "<";
109 interleaveComma(subtypes, os);
110 os << ">";
111 }
ParseResourceType(MLIRContext * context,DialectAsmParser & parser)112 Type ParseResourceType(MLIRContext *context, DialectAsmParser &parser) {
113 return ParseTypeWithSubtype<ResourceType>(context, parser);
114 }
115
PrintResourceType(ResourceType ty,DialectAsmPrinter & os)116 void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) {
117 return PrintTypeWithSubtype("resource", ty, os);
118 }
119
ParseVariantType(MLIRContext * context,DialectAsmParser & parser)120 Type ParseVariantType(MLIRContext *context, DialectAsmParser &parser) {
121 return ParseTypeWithSubtype<VariantType>(context, parser);
122 }
123
PrintVariantType(VariantType ty,DialectAsmPrinter & os)124 void PrintVariantType(VariantType ty, DialectAsmPrinter &os) {
125 return PrintTypeWithSubtype("variant", ty, os);
126 }
127
128 } // namespace
129
130 // Entry point for Type parsing, TableGen generated code will handle the
131 // dispatch to the individual classes.
Type TFTypeDialect::parseType(DialectAsmParser &parser) const {
  StringRef type_tag;
  llvm::SMLoc loc = parser.getNameLoc();

  // First try the TableGen-generated parser for ODS-defined types.
  Type genType;
  auto parse_result = generatedTypeParser(parser, &type_tag, genType);
  if (parse_result.hasValue()) return genType;

  // Then try the simple keyword-only types from types.def. The custom-syntax
  // types (resource/variant) are excluded via the empty HANDLE_CUSTOM_TF_TYPE
  // expansion and handled explicitly below.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (type_tag == name) return tftype##Type::get(getContext());
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // Resource and variant optionally carry a `<subtype, ...>` suffix, so they
  // need dedicated parsers.
  if (type_tag.startswith("resource")) {
    Type ret = ParseResourceType(getContext(), parser);
    if (!ret) parser.emitError(loc, "invalid resource type");
    return ret;
  }
  if (type_tag.startswith("variant")) {
    Type ret = ParseVariantType(getContext(), parser);
    if (!ret) parser.emitError(loc, "invalid variant type");
    return ret;
  }

  parser.emitError(parser.getNameLoc(),
                   "unknown type in TF graph dialect: " + type_tag);
  return {};
}
161
162 // Entry point for Type parsing, TableGen generated code will handle the
163 // dispatch to the individual classes.
void TFTypeDialect::printType(Type type, DialectAsmPrinter &printer) const {
  // Keyword-only types print just their name.
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    printer << name;                                     \
    return;                                              \
  }
  // Custom-syntax types (resource/variant) delegate to their dedicated
  // printer, which also emits the optional subtype list.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)   \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    Print##tftype##Type(derived_ty, printer);            \
    return;                                              \
  }
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // Fall back to the TableGen-generated printer for ODS-defined types.
  if (failed(generatedTypePrinter(type, printer)))
    llvm::report_fatal_error("unexpected tensorflow graph type kind");
}
181
182 //===----------------------------------------------------------------------===//
183 // Attributes
184 //===----------------------------------------------------------------------===//
185
// Parses a version attribute of the form:
//   <producer = 42, min_consumer = 12[, bad_consumers = [1, 2]]>
// Returns a null attribute on failure.
Attribute VersionAttr::parse(AsmParser &parser, Type) {
  if (failed(parser.parseLess())) return {};

  int32_t producer, min_consumer;
  if (parser.parseKeyword("producer", " in tf_type version") ||
      parser.parseEqual() || parser.parseInteger(producer) ||
      parser.parseComma() ||
      parser.parseKeyword("min_consumer", " in tf_type version") ||
      parser.parseEqual() || parser.parseInteger(min_consumer))
    return {};

  // The bad_consumers list is optional and present iff a comma follows
  // min_consumer. NOTE: ParseResult converts to `true` on *failure*, so
  // `!parseOptionalComma()` reads as "a comma was consumed".
  SmallVector<int32_t, 4> bad_consumers;
  if (!parser.parseOptionalComma()) {
    if (parser.parseKeyword("bad_consumers", " in tf_type version") ||
        parser.parseEqual() || parser.parseLSquare())
      return {};
    do {
      int32_t bad_consumer;
      if (parser.parseInteger(bad_consumer)) return {};
      bad_consumers.push_back(bad_consumer);
    } while (!parser.parseOptionalComma());
    if (parser.parseRSquare()) return {};
  }
  if (failed(parser.parseGreater())) return {};

  return VersionAttr::get(parser.getContext(), producer, min_consumer,
                          bad_consumers);
}
214
print(AsmPrinter & printer) const215 void VersionAttr::print(AsmPrinter &printer) const {
216 llvm::raw_ostream &os = printer.getStream();
217 os << "<producer = " << getProducer()
218 << ", min_consumer = " << getMinConsumer();
219 ArrayRef<int32_t> badConsumers = getBadConsumers();
220 if (!badConsumers.empty()) {
221 os << ", bad_consumers = [";
222 llvm::interleaveComma(badConsumers, os);
223 os << "]";
224 }
225 os << ">";
226 }
227
// Parses the body of a FullTypeAttr (without the outer `<...>` of the
// top-level attribute): a type_id keyword, an optional `<args...>` list of
// recursively nested full types, and an optional trailing attribute payload.
FailureOr<FullTypeAttr> RawFullTypeAttrParser(AsmParser &parser) {
  SmallVector<FullTypeAttr> args;

  // Parse variable 'type_id'
  llvm::StringRef type_id_str;
  if (failed(parser.parseKeyword(&type_id_str))) {
    parser.emitError(
        parser.getCurrentLocation(),
        "failed to parse TFType_FullTypeAttr parameter keyword for "
        "'type_id'");
    return failure();
  }
  // Map the keyword onto the FullTypeId enum; unknown keywords are an error.
  Optional<FullTypeId> type_id = symbolizeFullTypeId(type_id_str);
  if (!type_id) {
    parser.emitError(parser.getCurrentLocation(),
                     "failed to parse TFType_FullTypeAttr parameter "
                     "'type_id'");
    return failure();
  }

  // Parse variable 'args': the angle-bracketed list is optional and each
  // element recurses into this parser.
  if (parser.parseCommaSeparatedList(
          AsmParser::Delimiter::OptionalLessGreater, [&]() {
            FailureOr<tf_type::FullTypeAttr> arg = RawFullTypeAttrParser(parser);
            if (failed(arg)) return failure();
            args.push_back(*arg);
            return success();
          }))
    return failure();

  // Parse variable 'attr': an optional attribute payload; absence leaves
  // `attr` null, which FullTypeAttr::get accepts.
  Attribute attr;
  parser.parseOptionalAttribute(attr);
  return FullTypeAttr::get(parser.getContext(), static_cast<int32_t>(*type_id),
                           args, attr);
}
264
// Parses a full_type attribute: `<` body `>` with the body handled by
// RawFullTypeAttrParser. Returns a null attribute on failure.
Attribute FullTypeAttr::parse(AsmParser &parser, Type odsType) {
  if (failed(parser.parseLess())) return {};
  FailureOr<tf_type::FullTypeAttr> ret = RawFullTypeAttrParser(parser);
  // On success the closing `>` must follow; on body failure we fall through
  // and return a null FullTypeAttr via getValueOr below.
  if (succeeded(ret) && failed(parser.parseGreater())) return {};
  return ret.getValueOr(FullTypeAttr());
}
271
RawFullTypeAttrPrint(FullTypeAttr tfattr,AsmPrinter & printer)272 static void RawFullTypeAttrPrint(FullTypeAttr tfattr, AsmPrinter &printer) {
273 printer << stringifyFullTypeId(tf_type::FullTypeId(tfattr.getTypeId()));
274 if (!tfattr.getArgs().empty()) {
275 printer << "<";
276 llvm::interleaveComma(tfattr.getArgs(), printer, [&](Attribute arg) {
277 if (auto t = arg.dyn_cast<FullTypeAttr>())
278 RawFullTypeAttrPrint(t, printer);
279 else
280 printer << "<<INVALID ARG>>";
281 });
282 printer << ">";
283 }
284 if (tfattr.getAttr()) {
285 printer << ' ';
286 printer.printStrippedAttrOrType(tfattr.getAttr());
287 }
288 }
289
print(AsmPrinter & printer) const290 void FullTypeAttr::print(AsmPrinter &printer) const {
291 printer << "<";
292 RawFullTypeAttrPrint(*this, printer);
293 printer << ">";
294 }
295
296 // Print a #tf.func attribute of the following format:
297 //
298 // #tf.func<@symbol, {attr = "value"}>
299 // or
300 // #tf.func<"", {attr = "value"}>
301 // in case of null symbol ref.
print(AsmPrinter & os) const302 void FuncAttr::print(AsmPrinter &os) const {
303 if (getName().getRootReference().getValue().empty())
304 os << "<\"\", " << getAttrs() << ">";
305 else
306 os << "<" << getName() << ", " << getAttrs() << ">";
307 }
308
309 // Parses a #tf.func attribute of the following format:
310 //
311 // #tf.func<@symbol, {attr = "value"}>
312 //
313 // where the first element is a SymbolRefAttr and the second element is a
314 // DictionaryAttr.
Attribute FuncAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess())) return {};
  llvm::SMLoc loc = parser.getCurrentLocation();
  Attribute name, dict;
  if (failed(parser.parseAttribute(name))) {
    parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
    return {};
  }
  // A string is only accepted when it is empty; it denotes the null symbol
  // and is canonicalized into an empty SymbolRefAttr.
  if (auto func_name_str = name.dyn_cast<StringAttr>()) {
    if (!func_name_str.getValue().empty()) {
      parser.emitError(loc)
          << "expected empty string or symbol while parsing tf.func "
             "attribute";
      return {};
    }
    name = SymbolRefAttr::get(parser.getContext(), "");
  }
  // Anything other than a symbol (or the empty string handled above) is
  // rejected.
  if (!name.isa<SymbolRefAttr>()) {
    parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
    return {};
  }
  if (failed(parser.parseComma())) return {};
  loc = parser.getCurrentLocation();
  if (failed(parser.parseAttribute(dict)) || !dict.isa<DictionaryAttr>()) {
    parser.emitError(loc)
        << "expected Dictionary attribute while parsing tf.func attribute";
    return {};
  }
  if (failed(parser.parseGreater())) return {};
  return FuncAttr::get(parser.getContext(), name.cast<SymbolRefAttr>(),
                       dict.cast<DictionaryAttr>());
}
347
// Reports the nested sub-attributes to the sub-element interface. The
// dictionary is always reported first (index 0); the symbol name is only
// reported when non-empty, so a walk may see either one or two attributes.
void FuncAttr::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  // Walk the dictionary attribute first, so that its index is always 0.
  walkAttrsFn(getAttrs());
  // Walk the symbol ref attribute if it isn't empty.
  if (!getName().getRootReference().getValue().empty()) walkAttrsFn(getName());
}
356
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const357 Attribute FuncAttr::replaceImmediateSubElements(
358 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
359 assert(replAttrs.size() == 2 && "invalid number of replacement attributes");
360 return get(getContext(), replAttrs[1].cast<SymbolRefAttr>(),
361 replAttrs[0].cast<DictionaryAttr>());
362 }
363
print(AsmPrinter & os) const364 void PlaceholderAttr::print(AsmPrinter &os) const {
365 os << "<" << StringAttr::get(getContext(), getValue()) << ">";
366 }
367
parse(AsmParser & parser,Type type)368 Attribute PlaceholderAttr::parse(AsmParser &parser, Type type) {
369 if (failed(parser.parseLess())) return {};
370 std::string content;
371 if (failed(parser.parseOptionalString(&content))) {
372 parser.emitError(parser.getCurrentLocation())
373 << "expected string while parsing tf.placeholder attribute";
374 return {};
375 }
376 if (failed(parser.parseGreater())) return {};
377 return PlaceholderAttr::get(parser.getContext(), content);
378 }
379
print(AsmPrinter & os) const380 void ShapeAttr::print(AsmPrinter &os) const {
381 os << "<";
382 if (hasRank()) {
383 auto print_dim = [&](int64_t dim) {
384 if (dim != -1)
385 os << dim;
386 else
387 os << "?";
388 };
389 llvm::interleave(getShape(), os, print_dim, "x");
390 } else {
391 os << "*";
392 }
393 os << ">";
394 }
395
// Parses a shape attribute: `<*>` (unranked) or `<d0xd1x...>` where each
// dimension is an integer or `?` (dynamic). `<>` yields the rank-0 shape.
Attribute ShapeAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess())) return {};

  // `*` denotes an unranked shape.
  if (succeeded(parser.parseOptionalStar())) {
    if (failed(parser.parseGreater())) {
      parser.emitError(parser.getCurrentLocation())
          << "expected `>` after `*` when parsing a tf.shape "
             "attribute";
      return {};
    }
    return ShapeAttr::get(parser.getContext(), llvm::None);
  }

  // Otherwise parse an `x`-separated dimension list; an immediate `>` means
  // the empty (scalar) shape.
  SmallVector<int64_t> shape;
  if (failed(parser.parseOptionalGreater())) {
    // Parses one dimension into the back of `shape`: either `?` (dynamic) or
    // an integer literal.
    auto parse_element = [&]() {
      shape.emplace_back();
      llvm::SMLoc loc = parser.getCurrentLocation();
      if (succeeded(parser.parseOptionalQuestion())) {
        shape.back() = ShapedType::kDynamicSize;
      } else if (failed(parser.parseInteger(shape.back()))) {
        parser.emitError(loc)
            << "expected an integer or `?` when parsing a tf.shape attribute";
        return failure();
      }
      return success();
    };
    if (failed(parse_element())) return {};
    while (failed(parser.parseOptionalGreater())) {
      if (failed(parser.parseXInDimensionList()) || failed(parse_element()))
        return {};
    }
  }
  return ShapeAttr::get(parser.getContext(), llvm::makeArrayRef(shape));
}
431
432 // Get or create a shape attribute.
get(MLIRContext * context,llvm::Optional<ArrayRef<int64_t>> shape)433 ShapeAttr ShapeAttr::get(MLIRContext *context,
434 llvm::Optional<ArrayRef<int64_t>> shape) {
435 if (shape) return Base::get(context, *shape, /*unranked=*/false);
436
437 return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
438 }
439
440 // Get or create a shape attribute.
get(MLIRContext * context,ShapedType shaped_type)441 ShapeAttr ShapeAttr::get(MLIRContext *context, ShapedType shaped_type) {
442 if (shaped_type.hasRank())
443 return Base::get(context, shaped_type.getShape(), /*unranked=*/false);
444
445 return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
446 }
447
getValue() const448 llvm::Optional<ArrayRef<int64_t>> ShapeAttr::getValue() const {
449 if (hasRank()) return getShape();
450 return llvm::None;
451 }
452
hasRank() const453 bool ShapeAttr::hasRank() const { return !getImpl()->unranked; }
454
getRank() const455 int64_t ShapeAttr::getRank() const {
456 assert(hasRank());
457 return getImpl()->shape.size();
458 }
459
hasStaticShape() const460 bool ShapeAttr::hasStaticShape() const {
461 if (!hasRank()) return false;
462
463 for (auto dim : getShape()) {
464 if (dim < 0) return false;
465 }
466
467 return true;
468 }
469
470 namespace {
471 // Returns the shape of the given value if it's ranked; returns llvm::None
472 // otherwise.
GetShape(Value value)473 Optional<ArrayRef<int64_t>> GetShape(Value value) {
474 auto shaped_type = value.getType().cast<ShapedType>();
475 if (shaped_type.hasRank()) return shaped_type.getShape();
476 return llvm::None;
477 }
478
479 // Merges cast compatible shapes and returns a more refined shape. The two
480 // shapes are cast compatible if they have the same rank and at each dimension,
481 // either both have same size or one of them is dynamic. Returns false if the
482 // given shapes are not cast compatible. The refined shape is same or more
483 // precise than the two input shapes.
GetCastCompatibleShape(ArrayRef<int64_t> a_shape,ArrayRef<int64_t> b_shape,SmallVectorImpl<int64_t> * refined_shape)484 bool GetCastCompatibleShape(ArrayRef<int64_t> a_shape,
485 ArrayRef<int64_t> b_shape,
486 SmallVectorImpl<int64_t> *refined_shape) {
487 if (a_shape.size() != b_shape.size()) return false;
488 int64_t rank = a_shape.size();
489 refined_shape->reserve(rank);
490 for (auto dims : llvm::zip(a_shape, b_shape)) {
491 int64_t dim1 = std::get<0>(dims);
492 int64_t dim2 = std::get<1>(dims);
493
494 if (ShapedType::isDynamic(dim1)) {
495 refined_shape->push_back(dim2);
496 continue;
497 }
498 if (ShapedType::isDynamic(dim2)) {
499 refined_shape->push_back(dim1);
500 continue;
501 }
502 if (dim1 == dim2) {
503 refined_shape->push_back(dim1);
504 continue;
505 }
506 return false;
507 }
508 return true;
509 }
510
511 } // namespace
512
513 //===----------------------------------------------------------------------===//
514 // Utility iterators
515 //===----------------------------------------------------------------------===//
516
// Iterator adaptor mapping an operation's operands to their (optional)
// shapes via GetShape.
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}

// Iterator adaptor mapping an operation's results to their (optional)
// shapes via GetShape.
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
526
527 //===----------------------------------------------------------------------===//
528 // TF types helper functions
529 //===----------------------------------------------------------------------===//
530
// A TensorFlowType is any type owned by the tf_type dialect.
bool TensorFlowType::classof(Type type) {
  return llvm::isa<TFTypeDialect>(type.getDialect());
}

// A TensorFlowRefType is any of the *Ref types from types.def: the non-ref
// handler expands to nothing, so only ref types end up in the isa<> list.
bool TensorFlowRefType::classof(Type type) {
  return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
      // NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
      >();
}
543
// Returns the ref type corresponding to `type`'s element type (e.g. f32 ->
// !tf_type.f32ref). Aborts on types that have no ref counterpart.
TensorFlowType TensorFlowRefType::get(Type type) {
  MLIRContext *ctx = type.getContext();
  // Shaped types are mapped based on their element type.
  type = getElementTypeOrSelf(type);
  if (type.isF16()) {
    return HalfRefType::get(ctx);
  } else if (type.isF32()) {
    return FloatRefType::get(ctx);
  } else if (type.isF64()) {
    return DoubleRefType::get(ctx);
  } else if (type.isBF16()) {
    return Bfloat16RefType::get(ctx);
  } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
    // Only f32/f64 complex types have TF ref counterparts.
    Type etype = complex_type.getElementType();
    if (etype.isF32()) {
      return Complex64RefType::get(ctx);
    } else if (etype.isF64()) {
      return Complex128RefType::get(ctx);
    }
    llvm_unreachable("unexpected complex type");
  } else if (auto itype = type.dyn_cast<IntegerType>()) {
    // Integer widths map to the signed ref type unless explicitly unsigned.
    switch (itype.getWidth()) {
      case 1:
        return BoolRefType::get(ctx);
      case 8:
        return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                                  : Int8RefType::get(ctx);
      case 16:
        return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                                  : Int16RefType::get(ctx);
      case 32:
        return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                                  : Int32RefType::get(ctx);
      case 64:
        return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                                  : Int64RefType::get(ctx);
      default:
        llvm_unreachable("unexpected integer type");
    }
  }
  // TF-dialect types (string, qint8, ...) map to their generated *Ref
  // counterparts; ref types themselves are excluded by the empty
  // HANDLE_TF_REF_TYPE expansion.
#define HANDLE_TF_TYPE(tftype, enumerant, name)        \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
    return tftype##RefType::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
  // NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
  llvm_unreachable("unexpected type kind");
}
592
// Returns the non-ref type corresponding to this ref type, e.g.
// !tf_type.f32ref -> f32. Aborts if the ref kind is unknown.
Type TensorFlowRefType::RemoveRef() {
  MLIRContext *ctx = getContext();
  // Builtin float element types.
  if (isa<HalfRefType>()) return FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return FloatType::getBF16(ctx);
  // Builtin integer element types (signless for signed, explicit unsigned).
  if (isa<BoolRefType>()) return IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return IntegerType::get(ctx, 64, IntegerType::Unsigned);
  // Complex element types.
  if (isa<Complex64RefType>()) return ComplexType::get(FloatType::getF32(ctx));
  if (isa<Complex128RefType>()) return ComplexType::get(FloatType::getF64(ctx));
  // Remaining TF-dialect ref types map back to their non-ref counterparts.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
  // NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}
622
classof(Type type)623 bool TensorFlowTypeWithSubtype::classof(Type type) {
624 return type.isa<ResourceType, VariantType>();
625 }
626
RemoveSubtypes()627 Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
628 MLIRContext *ctx = getContext();
629 if (isa<VariantType>()) return VariantType::get(ctx);
630 if (isa<ResourceType>()) return ResourceType::get(ctx);
631 llvm_unreachable("unexpected tensorflow type with subtypes kind");
632 }
633
clone(ArrayRef<TensorType> new_subtypes)634 TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
635 ArrayRef<TensorType> new_subtypes) {
636 MLIRContext *ctx = getContext();
637 if (isa<VariantType>())
638 return VariantType::get(new_subtypes, ctx)
639 .cast<TensorFlowTypeWithSubtype>();
640 if (isa<ResourceType>())
641 return ResourceType::get(new_subtypes, ctx)
642 .cast<TensorFlowTypeWithSubtype>();
643 llvm_unreachable("unexpected tensorflow type with subtypes kind");
644 }
645
GetSubtypes()646 ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
647 if (auto variant_type = dyn_cast<VariantType>())
648 return variant_type.getSubtypes();
649 if (auto resource_type = dyn_cast<ResourceType>())
650 return resource_type.getSubtypes();
651 llvm_unreachable("unexpected tensorflow type with subtypes kind");
652 }
653
654 // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
655 // similar structure that could be extracted into helper method.
// Returns whether two equally-sized lists of types are pairwise broadcast
// compatible (ref-ness ignored; variant subtypes checked recursively).
// NOTE(review): the trailing `return getBroadcastedShape(...)` executes
// inside the loop, so the function decides on the first pair where both
// sides are ranked tensors and never inspects later pairs — confirm this is
// intended before relying on it for multi-result ops.
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto types : llvm::zip(lhs, rhs)) {
    // Drop ref types because they don't affect broadcast compatibility. E.g.,
    // `tensor<!tf_type.f32ref>` and `tensor<f32>` should be considered
    // broadcast compatible.
    auto lhs_type = DropRefType(std::get<0>(types));
    auto rhs_type = DropRefType(std::get<1>(types));

    // This should be true for all TF ops:
    auto lhs_tt = lhs_type.dyn_cast<TensorType>();
    auto rhs_tt = rhs_type.dyn_cast<TensorType>();
    if (!lhs_tt || !rhs_tt) {
      // Non-tensor types must match exactly.
      if (lhs_type != rhs_type) return false;
      continue;
    }

    // Verify matching element types. These should be identical, except for
    // variant type where unknown subtype is considered compatible with all
    // subtypes.
    auto lhs_et = lhs_tt.getElementType();
    auto rhs_et = rhs_tt.getElementType();
    if (lhs_et != rhs_et) {
      // If either does not have subtypes, then the element types don't match.
      auto lhs_wst = lhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
      auto rhs_wst = rhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
      if (!lhs_wst || !rhs_wst) return false;

      // Consider the subtype of variant types.
      auto lhs_wst_st = lhs_wst.GetSubtypes();
      auto rhs_wst_st = rhs_wst.GetSubtypes();
      if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
        for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
          if (!BroadcastCompatible(std::get<0>(subtypes),
                                   std::get<1>(subtypes)))
            return false;
        }
      }
    }

    // If either side is unranked the pair is trivially compatible; otherwise
    // the shapes must actually broadcast together.
    auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
    auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
    if (!lhs_rt || !rhs_rt) return true;
    SmallVector<int64_t, 4> shape;
    return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
                                              rhs_rt.getShape(), shape);
  }
  return true;
}
705
706 // Given two types `a` and `b`, returns a refined type which is cast compatible
707 // with both `a` and `b` and is equal to or more precise than both of them. It
708 // returns empty Type if the input types are not cast compatible.
709 //
710 // The two types are considered cast compatible if they have dynamically equal
711 // shapes and element type. For element types that do not have subtypes, they
712 // must be equal. However for TensorFlow types such as Resource and Variant,
713 // that also have subtypes, we recursively check for subtype compatibility for
714 // Resource types and assume all variant types are cast compatible. If either
715 // one of `a` or `b` have empty subtypes, they are considered cast compatible.
716 //
717 // The returned type is same or more precise than the input types. For example,
718 // if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
719 // tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
720 //
721 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
722 // might allow operands to either be same as result type or be a ref type
723 // corresponding to it.
Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) {
  // Fast path if everything is equal.
  if (a == b) return b;

  auto a_tt = a.dyn_cast<TensorType>();
  auto b_tt = b.dyn_cast<TensorType>();

  // If only one of a or b is a tensor type, they are incompatible.
  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

  // For non-tensor types, we do not need to worry about shape and can return
  // early.
  if (!a_tt && !b_tt) {
    // Remove ref types.
    if (may_ignore_ref_type_a) {
      if (auto ref_type = a.dyn_cast<TensorFlowRefType>()) {
        a = ref_type.RemoveRef();
        if (a == b) return a;
      }
    }
    // Distinct kinds (e.g. resource vs. variant) are never compatible.
    if (a.getTypeID() != b.getTypeID()) return nullptr;

    // If either is not a type that contain subtypes then the types are not cast
    // compatible.
    auto a_wst = a.dyn_cast<TensorFlowTypeWithSubtype>();
    auto b_wst = b.dyn_cast<TensorFlowTypeWithSubtype>();
    if (!a_wst || !b_wst) return nullptr;

    // For Variant types we are more permissive right now and accept all pairs
    // of Variant types. If we are more constrainted and check compatibility of
    // subtypes, we might reject valid graphs.
    // TODO(prakalps): Variant doesn't have a subtype, we assign it
    // one, so we should only assign it one when we know the subtype. Then we
    // can be more constrained and check subtypes for cast compatibility as
    // well.
    if (a.isa<VariantType>()) return a;
    if (b.isa<VariantType>()) return b;

    // For Resource types, we recursively check the subtypes for cast
    // compatibility, if possible. Otherwise treat them as compatible.
    auto a_wst_st = a_wst.GetSubtypes();
    auto b_wst_st = b_wst.GetSubtypes();
    if (a_wst_st.empty()) return b;
    if (b_wst_st.empty()) return a;
    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
    // Refine each subtype pair; any incompatible pair poisons the result.
    SmallVector<TensorType, 4> refined_subtypes;
    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
      Type refined_st =
          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                /*may_ignore_ref_type_a=*/false);
      if (!refined_st) return nullptr;
      refined_subtypes.push_back(refined_st.cast<TensorType>());
    }

    return ResourceType::get(refined_subtypes, a.getContext());
  }

  // For tensor types, check compatibility of both element type and shape.
  Type refined_element_ty = GetCastCompatibleType(
      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
  if (!refined_element_ty) return nullptr;

  // A ranked side refines an unranked side; two ranked sides merge per-dim.
  if (!a_tt.hasRank() && !b_tt.hasRank()) {
    return UnrankedTensorType::get(refined_element_ty);
  }
  if (!a_tt.hasRank()) {
    return RankedTensorType::get(b_tt.getShape(), refined_element_ty);
  }
  if (!b_tt.hasRank()) {
    return RankedTensorType::get(a_tt.getShape(), refined_element_ty);
  }

  SmallVector<int64_t, 4> refined_shape;
  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
    return nullptr;

  return RankedTensorType::get(refined_shape, refined_element_ty);
}
802
HasCompatibleElementTypes(Type lhs,Type rhs,bool may_ignore_ref_type_lhs)803 bool HasCompatibleElementTypes(Type lhs, Type rhs,
804 bool may_ignore_ref_type_lhs) {
805 return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
806 }
807
AreCastCompatible(TypeRange types)808 bool AreCastCompatible(TypeRange types) {
809 Type common = types.front();
810 for (auto type : types.drop_front()) {
811 Type refined_type =
812 GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
813 if (!refined_type) return false;
814 common = refined_type;
815 }
816 return true;
817 }
818
ArraysAreCastCompatible(TypeRange lhs,TypeRange rhs)819 bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) {
820 if (lhs.size() != rhs.size()) return false;
821 for (auto pair : llvm::zip(lhs, rhs)) {
822 auto lhs_i = std::get<0>(pair);
823 auto rhs_i = std::get<1>(pair);
824 if (!AreCastCompatible({lhs_i, rhs_i})) return false;
825 }
826 return true;
827 }
828
829 // Returns the corresponding TensorFlow or standard type from TensorFlowRef
830 // type.
GetDefaultTypeOf(TensorFlowRefType type)831 static Type GetDefaultTypeOf(TensorFlowRefType type) {
832 return type.RemoveRef();
833 }
834
835 // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
836 // type for a composed type (such as a ref type or a type with subtypes).
837 template <typename ComposedType>
DropTypeHelper(Type ty)838 Type DropTypeHelper(Type ty) {
839 Type element_ty = getElementTypeOrSelf(ty);
840 auto composed_type = element_ty.dyn_cast<ComposedType>();
841 if (!composed_type) return ty;
842
843 Type default_ty = GetDefaultTypeOf(composed_type);
844 if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
845 return RankedTensorType::get(ranked_ty.getShape(), default_ty);
846 } else if (ty.dyn_cast<UnrankedTensorType>()) {
847 return UnrankedTensorType::get(default_ty);
848 } else {
849 return default_ty;
850 }
851 }
852
DropSubTypes(Type ty)853 Type DropSubTypes(Type ty) {
854 return DropTypeHelper<TensorFlowTypeWithSubtype>(ty);
855 }
856
DropRefType(Type ty)857 Type DropRefType(Type ty) { return DropTypeHelper<TensorFlowRefType>(ty); }
858
DropRefAndSubTypes(Type ty)859 Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
860
parse(AsmParser & parser,Type type)861 Attribute TensorProtoAttr::parse(AsmParser &parser, Type type) {
862 if (parser.parseColon()) {
863 return nullptr;
864 }
865
866 std::string data;
867 if (parser.parseString(&data)) {
868 return nullptr;
869 }
870 if (data.size() < 2 || data.substr(0, 2) != "0x") {
871 parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`");
872 return nullptr;
873 }
874
875 std::string bytes_data = absl::HexStringToBytes(data.substr(2));
876 return TensorProtoAttr::get(type, bytes_data);
877 }
878
print(mlir::AsmPrinter & printer) const879 void TensorProtoAttr::print(mlir::AsmPrinter &printer) const {
880 StringRef bytes_str = getValue();
881 printer << " : \"0x" << llvm::toHex(bytes_str) << "\"";
882 }
883
884 } // namespace tf_type
885 } // namespace mlir
886