1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/types/dialect.h"
17
18 #include <cstdint>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/SMLoc.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Traits.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
31 #include "mlir/IR/Dialect.h" // from @llvm-project
32 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
33 #include "mlir/IR/FunctionImplementation.h" // from @llvm-project
34 #include "mlir/IR/FunctionSupport.h" // from @llvm-project
35 #include "mlir/IR/MLIRContext.h" // from @llvm-project
36 #include "mlir/IR/OperationSupport.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38
39 #define GET_ATTRDEF_CLASSES
40 #include "tensorflow/core/ir/types/attributes.cc.inc"
41
42 #define GET_TYPEDEF_CLASSES
43 #include "tensorflow/core/ir/types/types.cc.inc"
44
45 // Generated definitions.
46 #include "tensorflow/core/ir/types/dialect.cpp.inc"
47
48 namespace mlir {
49 namespace tf_type {
50
51 //===----------------------------------------------------------------------===//
52 // TFType dialect.
53 //===----------------------------------------------------------------------===//
54
// Dialect construction: there is one instance per context, and it registers
// the dialect's attributes and types here.
// Registers every attribute and type of the dialect. Attributes come from the
// TableGen-generated GET_ATTRDEF_LIST. Types are ControlType and
// OpaqueTensorType followed by one `<Name>Type` per entry of types.def; the
// final entry is expanded through HANDLE_LAST_TF_TYPE, which omits the comma
// so the template argument list stays well-formed.
void TFTypeDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "tensorflow/core/ir/types/attributes.cc.inc"
      >();
  addTypes<ControlType, OpaqueTensorType,
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/core/ir/types/types.def"
           >();
}
68
69 // Entry point for Attribute parsing, TableGen generated code will handle the
70 // dispatch to the individual classes.
parseAttribute(DialectAsmParser & parser,Type type) const71 Attribute TFTypeDialect::parseAttribute(DialectAsmParser &parser,
72 Type type) const {
73 StringRef attr_tag;
74 if (failed(parser.parseKeyword(&attr_tag))) return Attribute();
75 {
76 Attribute attr;
77 auto parse_result =
78 generatedAttributeParser(getContext(), parser, attr_tag, type, attr);
79 if (parse_result.hasValue()) return attr;
80 }
81 parser.emitError(parser.getNameLoc(), "unknown tf_type attribute");
82 return Attribute();
83 }
84
85 // Entry point for Attribute printing, TableGen generated code will handle the
86 // dispatch to the individual classes.
printAttribute(Attribute attr,DialectAsmPrinter & os) const87 void TFTypeDialect::printAttribute(Attribute attr,
88 DialectAsmPrinter &os) const {
89 (void)generatedAttributePrinter(attr, os);
90 }
91
92 namespace {
93 template <typename TypeWithSubtype>
ParseTypeWithSubtype(MLIRContext * context,DialectAsmParser & parser)94 Type ParseTypeWithSubtype(MLIRContext *context, DialectAsmParser &parser) {
95 // Default type without inferred subtypes.
96 if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
97
98 // Most types with subtypes have only one subtype.
99 SmallVector<TensorType, 1> subtypes;
100 do {
101 TensorType tensor_ty;
102 if (parser.parseType(tensor_ty)) return Type();
103
104 // Each of the subtypes should be a valid TensorFlow type.
105 // TODO(jpienaar): Remove duplication.
106 if (!IsValidTFTensorType(tensor_ty)) {
107 parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
108 return Type();
109 }
110 subtypes.push_back(tensor_ty);
111 } while (succeeded(parser.parseOptionalComma()));
112
113 if (parser.parseGreater()) return Type();
114
115 return TypeWithSubtype::get(subtypes, context);
116 }
117
118 template <typename TypeWithSubtype>
PrintTypeWithSubtype(StringRef type,TypeWithSubtype ty,DialectAsmPrinter & os)119 void PrintTypeWithSubtype(StringRef type, TypeWithSubtype ty,
120 DialectAsmPrinter &os) {
121 os << type;
122 ArrayRef<TensorType> subtypes = ty.getSubtypes();
123 if (subtypes.empty()) return;
124
125 os << "<";
126 interleaveComma(subtypes, os);
127 os << ">";
128 }
ParseResourceType(MLIRContext * context,DialectAsmParser & parser)129 Type ParseResourceType(MLIRContext *context, DialectAsmParser &parser) {
130 return ParseTypeWithSubtype<ResourceType>(context, parser);
131 }
132
PrintResourceType(ResourceType ty,DialectAsmPrinter & os)133 void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) {
134 return PrintTypeWithSubtype("resource", ty, os);
135 }
136
ParseVariantType(MLIRContext * context,DialectAsmParser & parser)137 Type ParseVariantType(MLIRContext *context, DialectAsmParser &parser) {
138 return ParseTypeWithSubtype<VariantType>(context, parser);
139 }
140
PrintVariantType(VariantType ty,DialectAsmPrinter & os)141 void PrintVariantType(VariantType ty, DialectAsmPrinter &os) {
142 return PrintTypeWithSubtype("variant", ty, os);
143 }
144
145 } // namespace
146
147 // Entry point for Type parsing, TableGen generated code will handle the
148 // dispatch to the individual classes.
parseType(DialectAsmParser & parser) const149 Type TFTypeDialect::parseType(DialectAsmParser &parser) const {
150 StringRef type_tag;
151 llvm::SMLoc loc = parser.getNameLoc();
152 if (failed(parser.parseKeyword(&type_tag))) return Type();
153
154 #define HANDLE_TF_TYPE(tftype, enumerant, name) \
155 if (type_tag == name) return tftype##Type::get(getContext());
156 #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
157 // NOLINTNEXTLINE: intended redundant include.
158 #include "tensorflow/core/ir/types/types.def"
159
160 if (type_tag.startswith("resource")) {
161 Type ret = ParseResourceType(getContext(), parser);
162 if (!ret) parser.emitError(loc, "invalid resource type");
163 return ret;
164 }
165 if (type_tag.startswith("variant")) {
166 Type ret = ParseVariantType(getContext(), parser);
167 if (!ret) parser.emitError(loc, "invalid variant type");
168 return ret;
169 }
170
171 Type genType;
172 auto parse_result =
173 generatedTypeParser(getContext(), parser, type_tag, genType);
174 if (parse_result.hasValue()) return genType;
175 parser.emitError(parser.getNameLoc(),
176 "unknown type in TF graph dialect: " + type_tag);
177 return {};
178 }
179
// Entry point for Type printing, TableGen generated code will handle the
// dispatch to the individual classes.
void TFTypeDialect::printType(Type type, DialectAsmPrinter &printer) const {
  // Simple types from types.def print as their bare mnemonic.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    printer << name;                                     \
    return;                                              \
  }
  // Custom types from types.def delegate to the matching Print<Name>Type
  // helper (e.g. PrintResourceType / PrintVariantType).
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) { \
    Print##tftype##Type(derived_ty, printer);            \
    return;                                              \
  }
// NOLINTNEXTLINE: intended redundant include.
#include "tensorflow/core/ir/types/types.def"

  // Types not covered by types.def go to the TableGen-generated printer; a
  // type it does not know about is a fatal error.
  if (failed(generatedTypePrinter(type, printer)))
    llvm::report_fatal_error("unexpected tensorflow graph type kind");
}
199
200 //===----------------------------------------------------------------------===//
201 // Attributes
202 //===----------------------------------------------------------------------===//
203
parse(MLIRContext * context,DialectAsmParser & parser,Type)204 Attribute VersionAttr::parse(MLIRContext *context, DialectAsmParser &parser,
205 Type) {
206 if (failed(parser.parseLess())) return {};
207
208 int32_t producer, min_consumer;
209 if (parser.parseKeyword("producer", " in tf_type version") ||
210 parser.parseEqual() || parser.parseInteger(producer) ||
211 parser.parseComma() ||
212 parser.parseKeyword("min_consumer", " in tf_type version") ||
213 parser.parseEqual() || parser.parseInteger(min_consumer))
214 return {};
215
216 SmallVector<int32_t, 4> bad_consumers;
217 if (!parser.parseOptionalComma()) {
218 if (parser.parseKeyword("bad_consumers", " in tf_type version") ||
219 parser.parseEqual() || parser.parseLSquare())
220 return {};
221 do {
222 int32_t bad_consumer;
223 if (parser.parseInteger(bad_consumer)) return {};
224 bad_consumers.push_back(bad_consumer);
225 } while (!parser.parseOptionalComma());
226 if (parser.parseRSquare()) return {};
227 }
228 if (failed(parser.parseGreater())) return {};
229
230 return VersionAttr::get(context, producer, min_consumer, bad_consumers);
231 }
232
print(DialectAsmPrinter & printer) const233 void VersionAttr::print(DialectAsmPrinter &printer) const {
234 llvm::raw_ostream &os = printer.getStream();
235 os << getMnemonic();
236 os << "<producer = " << getProducer()
237 << ", min_consumer = " << getMinConsumer();
238 ArrayRef<int32_t> badConsumers = getBadConsumers();
239 if (!badConsumers.empty()) {
240 os << ", bad_consumers = [";
241 llvm::interleaveComma(badConsumers, os);
242 os << "]";
243 }
244 os << ">";
245 }
246
// Print a #tf_type.func attribute of the following format:
//
//   #tf_type.func<@symbol, {attr = "value"}>
// or
//   #tf_type.func<"", {attr = "value"}>
// in case of null symbol ref.
print(DialectAsmPrinter & os) const253 void FuncAttr::print(DialectAsmPrinter &os) const {
254 if (getName().getRootReference().empty())
255 os << "func<\"\", " << getAttrs() << ">";
256 else
257 os << "func<" << getName() << ", " << getAttrs() << ">";
258 }
259
// Parses a #tf_type.func attribute of the following format:
//
//   #tf_type.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr (or the empty string "", which
// denotes a null symbol ref) and the second element is a DictionaryAttr.
parse(MLIRContext * context,DialectAsmParser & parser,Type type)266 Attribute FuncAttr::parse(MLIRContext *context, DialectAsmParser &parser,
267 Type type) {
268 if (failed(parser.parseLess())) return {};
269 llvm::SMLoc loc = parser.getCurrentLocation();
270 Attribute name, dict;
271 if (failed(parser.parseAttribute(name))) {
272 parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
273 return {};
274 }
275 if (auto func_name_str = name.dyn_cast<StringAttr>()) {
276 if (!func_name_str.getValue().empty()) {
277 parser.emitError(loc)
278 << "expected empty string or symbol while parsing tf.func "
279 "attribute";
280 return {};
281 }
282 name = SymbolRefAttr::get(context, "");
283 }
284 if (!name.isa<SymbolRefAttr>()) {
285 parser.emitError(loc) << "expected symbol while parsing tf.func attribute";
286 return {};
287 }
288 if (failed(parser.parseComma())) return {};
289 loc = parser.getCurrentLocation();
290 if (failed(parser.parseAttribute(dict)) || !dict.isa<DictionaryAttr>()) {
291 parser.emitError(loc)
292 << "expected Dictionary attribute while parsing tf.func attribute";
293 return {};
294 }
295 if (failed(parser.parseGreater())) return {};
296 return FuncAttr::get(context, name.cast<SymbolRefAttr>(),
297 dict.cast<DictionaryAttr>());
298 }
299
print(DialectAsmPrinter & os) const300 void PlaceholderAttr::print(DialectAsmPrinter &os) const {
301 os << "placeholder<" << StringAttr::get(getContext(), getValue()) << ">";
302 }
303
parse(MLIRContext * context,DialectAsmParser & parser,Type type)304 Attribute PlaceholderAttr::parse(MLIRContext *context, DialectAsmParser &parser,
305 Type type) {
306 if (failed(parser.parseLess())) return {};
307 StringRef content;
308 if (failed(parser.parseOptionalString(&content))) {
309 parser.emitError(parser.getCurrentLocation())
310 << "expected string while parsing tf.placeholder attribute";
311 return {};
312 }
313 if (failed(parser.parseGreater())) return {};
314 return PlaceholderAttr::get(context, content);
315 }
316
print(DialectAsmPrinter & os) const317 void ShapeAttr::print(DialectAsmPrinter &os) const {
318 os << "shape<";
319 if (hasRank()) {
320 auto print_dim = [&](int64_t dim) {
321 if (dim > -1)
322 os << dim;
323 else
324 os << "?";
325 };
326 llvm::interleave(getShape(), os, print_dim, "x");
327 } else {
328 os << "*";
329 }
330 os << ">";
331 }
332
parse(MLIRContext * context,DialectAsmParser & parser,Type type)333 Attribute ShapeAttr::parse(MLIRContext *context, DialectAsmParser &parser,
334 Type type) {
335 if (failed(parser.parseLess())) return {};
336
337 if (succeeded(parser.parseOptionalStar())) {
338 if (failed(parser.parseGreater())) {
339 parser.emitError(parser.getCurrentLocation())
340 << "expected `>` after `*` when parsing a tf.shape "
341 "attribute";
342 return {};
343 }
344 return ShapeAttr::get(context, llvm::None);
345 }
346
347 SmallVector<int64_t> shape;
348 if (failed(parser.parseOptionalGreater())) {
349 auto parse_element = [&]() {
350 shape.emplace_back();
351 llvm::SMLoc loc = parser.getCurrentLocation();
352 if (succeeded(parser.parseOptionalQuestion())) {
353 shape.back() = ShapedType::kDynamicSize;
354 } else if (failed(parser.parseInteger(shape.back())) ||
355 shape.back() < 0) {
356 parser.emitError(loc) << "expected a positive integer or `?` when "
357 "parsing a tf.shape attribute";
358 return failure();
359 }
360 return success();
361 };
362 if (failed(parse_element())) return {};
363 while (failed(parser.parseOptionalGreater())) {
364 if (failed(parser.parseXInDimensionList()) || failed(parse_element()))
365 return {};
366 }
367 }
368 return ShapeAttr::get(context, llvm::makeArrayRef(shape));
369 }
370
371 // Get or create a shape attribute.
get(MLIRContext * context,llvm::Optional<ArrayRef<int64_t>> shape)372 ShapeAttr ShapeAttr::get(MLIRContext *context,
373 llvm::Optional<ArrayRef<int64_t>> shape) {
374 if (shape) return Base::get(context, *shape, /*unranked=*/false);
375
376 return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
377 }
378
379 // Get or create a shape attribute.
get(MLIRContext * context,ShapedType shaped_type)380 ShapeAttr ShapeAttr::get(MLIRContext *context, ShapedType shaped_type) {
381 if (shaped_type.hasRank())
382 return Base::get(context, shaped_type.getShape(), /*unranked=*/false);
383
384 return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
385 }
386
getValue() const387 llvm::Optional<ArrayRef<int64_t>> ShapeAttr::getValue() const {
388 if (hasRank()) return getShape();
389 return llvm::None;
390 }
391
hasRank() const392 bool ShapeAttr::hasRank() const { return !getImpl()->unranked; }
393
getRank() const394 int64_t ShapeAttr::getRank() const {
395 assert(hasRank());
396 return getImpl()->shape.size();
397 }
398
hasStaticShape() const399 bool ShapeAttr::hasStaticShape() const {
400 if (!hasRank()) return false;
401
402 for (auto dim : getShape()) {
403 if (dim < 0) return false;
404 }
405
406 return true;
407 }
408
409 namespace {
410 // Returns the shape of the given value if it's ranked; returns llvm::None
411 // otherwise.
GetShape(Value value)412 llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(Value value) {
413 auto shaped_type = value.getType().cast<ShapedType>();
414 if (shaped_type.hasRank()) return shaped_type.getShape();
415 return llvm::None;
416 }
417
418 // Merges cast compatible shapes and returns a more refined shape. The two
419 // shapes are cast compatible if they have the same rank and at each dimension,
420 // either both have same size or one of them is dynamic. Returns false if the
421 // given shapes are not cast compatible. The refined shape is same or more
422 // precise than the two input shapes.
GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,llvm::ArrayRef<int64_t> b_shape,llvm::SmallVectorImpl<int64_t> * refined_shape)423 bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
424 llvm::ArrayRef<int64_t> b_shape,
425 llvm::SmallVectorImpl<int64_t> *refined_shape) {
426 if (a_shape.size() != b_shape.size()) return false;
427 int64_t rank = a_shape.size();
428 refined_shape->reserve(rank);
429 for (auto dims : llvm::zip(a_shape, b_shape)) {
430 int64_t dim1 = std::get<0>(dims);
431 int64_t dim2 = std::get<1>(dims);
432
433 if (ShapedType::isDynamic(dim1)) {
434 refined_shape->push_back(dim2);
435 continue;
436 }
437 if (ShapedType::isDynamic(dim2)) {
438 refined_shape->push_back(dim1);
439 continue;
440 }
441 if (dim1 == dim2) {
442 refined_shape->push_back(dim1);
443 continue;
444 }
445 return false;
446 }
447 return true;
448 }
449
450 } // namespace
451
452 //===----------------------------------------------------------------------===//
453 // Utility iterators
454 //===----------------------------------------------------------------------===//
455
// Wraps an operand iterator so that dereferencing yields GetShape(operand):
// the operand's shape when its type is ranked, llvm::None otherwise. GetShape
// casts the type to ShapedType, so operands are presumably expected to have
// shaped types; TODO confirm at call sites.
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
460
// Wraps a result iterator so that dereferencing yields GetShape(result): the
// result's shape when its type is ranked, llvm::None otherwise. GetShape casts
// the type to ShapedType, so results are presumably expected to have shaped
// types; TODO confirm at call sites.
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}
465
466 //===----------------------------------------------------------------------===//
467 // TF types helper functions
468 //===----------------------------------------------------------------------===//
469
classof(Type type)470 bool TensorFlowType::classof(Type type) {
471 return llvm::isa<TFTypeDialect>(type.getDialect());
472 }
// Returns true if `type` is one of the ref types listed in types.def. Plain
// HANDLE_TF_TYPE entries expand to nothing here; HANDLE_TF_REF_TYPE entries
// contribute `<Name>Type,` and the final entry (presumably itself a ref type;
// TODO confirm in types.def) contributes `<Name>Type` with no trailing comma.
bool TensorFlowRefType::classof(Type type) {
  return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
      >();
}
482
get(Type type)483 TensorFlowType TensorFlowRefType::get(Type type) {
484 MLIRContext *ctx = type.getContext();
485 type = getElementTypeOrSelf(type);
486 if (type.isF16()) {
487 return HalfRefType::get(ctx);
488 } else if (type.isF32()) {
489 return FloatRefType::get(ctx);
490 } else if (type.isF64()) {
491 return DoubleRefType::get(ctx);
492 } else if (type.isBF16()) {
493 return Bfloat16RefType::get(ctx);
494 } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
495 Type etype = complex_type.getElementType();
496 if (etype.isF32()) {
497 return Complex64RefType::get(ctx);
498 } else if (etype.isF64()) {
499 return Complex128RefType::get(ctx);
500 }
501 llvm_unreachable("unexpected complex type");
502 } else if (auto itype = type.dyn_cast<IntegerType>()) {
503 switch (itype.getWidth()) {
504 case 1:
505 return BoolRefType::get(ctx);
506 case 8:
507 return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
508 : Int8RefType::get(ctx);
509 case 16:
510 return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
511 : Int16RefType::get(ctx);
512 case 32:
513 return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
514 : Int32RefType::get(ctx);
515 case 64:
516 return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
517 : Int64RefType::get(ctx);
518 default:
519 llvm_unreachable("unexpected integer type");
520 }
521 }
522 #define HANDLE_TF_TYPE(tftype, enumerant, name) \
523 if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
524 return tftype##RefType::get(ctx);
525
526 #define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
527 // NOLINTNEXTLINE
528 #include "tensorflow/core/ir/types/types.def"
529 llvm_unreachable("unexpected type kind");
530 }
531
// Maps this ref type back to its value type. Numeric refs map to builtin float,
// complex and integer types: IntegerType::get's default signedness for IntN
// and IntegerType::Unsigned for UintN. Every other `<Name>RefType` from
// types.def maps to `<Name>Type`. Any other kind is unreachable.
Type TensorFlowRefType::RemoveRef() {
  MLIRContext *ctx = getContext();
  if (isa<HalfRefType>()) return FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return FloatType::getBF16(ctx);
  if (isa<BoolRefType>()) return IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return IntegerType::get(ctx, 64, IntegerType::Unsigned);
  if (isa<Complex64RefType>()) return ComplexType::get(FloatType::getF32(ctx));
  if (isa<Complex128RefType>()) return ComplexType::get(FloatType::getF64(ctx));
  // Non-numeric TF types: `<Name>RefType` -> `<Name>Type` for each
  // HANDLE_TF_TYPE entry of types.def (ref entries themselves expand to
  // nothing).
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/core/ir/types/types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}
561
classof(Type type)562 bool TensorFlowTypeWithSubtype::classof(Type type) {
563 return type.isa<ResourceType, VariantType>();
564 }
565
RemoveSubtypes()566 Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
567 MLIRContext *ctx = getContext();
568 if (isa<VariantType>()) return VariantType::get(ctx);
569 if (isa<ResourceType>()) return ResourceType::get(ctx);
570 llvm_unreachable("unexpected tensorflow type with subtypes kind");
571 }
572
clone(ArrayRef<TensorType> new_subtypes)573 TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
574 ArrayRef<TensorType> new_subtypes) {
575 MLIRContext *ctx = getContext();
576 if (isa<VariantType>())
577 return VariantType::get(new_subtypes, ctx)
578 .cast<TensorFlowTypeWithSubtype>();
579 if (isa<ResourceType>())
580 return ResourceType::get(new_subtypes, ctx)
581 .cast<TensorFlowTypeWithSubtype>();
582 llvm_unreachable("unexpected tensorflow type with subtypes kind");
583 }
584
GetSubtypes()585 ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
586 if (auto variant_type = dyn_cast<VariantType>())
587 return variant_type.getSubtypes();
588 if (auto resource_type = dyn_cast<ResourceType>())
589 return resource_type.getSubtypes();
590 llvm_unreachable("unexpected tensorflow type with subtypes kind");
591 }
592
593 // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
594 // similar structure that could be extracted into helper method.
BroadcastCompatible(TypeRange lhs,TypeRange rhs)595 bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) {
596 if (lhs.size() != rhs.size()) return false;
597 for (auto types : llvm::zip(lhs, rhs)) {
598 // Drop ref types because they don't affect broadcast compatibility. E.g.,
599 // `tensor<!tf_type.f32ref>` and `tensor<f32>` should be considered
600 // broadcast compatible.
601 auto lhs_type = DropRefType(std::get<0>(types));
602 auto rhs_type = DropRefType(std::get<1>(types));
603
604 // This should be true for all TF ops:
605 auto lhs_tt = lhs_type.dyn_cast<TensorType>();
606 auto rhs_tt = rhs_type.dyn_cast<TensorType>();
607 if (!lhs_tt || !rhs_tt) {
608 if (lhs_type != rhs_type) return false;
609 continue;
610 }
611
612 // Verify matching element types. These should be identical, except for
613 // variant type where unknown subtype is considered compatible with all
614 // subtypes.
615 auto lhs_et = lhs_tt.getElementType();
616 auto rhs_et = rhs_tt.getElementType();
617 if (lhs_et != rhs_et) {
618 // If either does not have subtypes, then the element types don't match.
619 auto lhs_wst = lhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
620 auto rhs_wst = rhs_et.dyn_cast<TensorFlowTypeWithSubtype>();
621 if (!lhs_wst || !rhs_wst) return false;
622
623 // Consider the subtype of variant types.
624 auto lhs_wst_st = lhs_wst.GetSubtypes();
625 auto rhs_wst_st = rhs_wst.GetSubtypes();
626 if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
627 for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
628 if (!BroadcastCompatible(std::get<0>(subtypes),
629 std::get<1>(subtypes)))
630 return false;
631 }
632 }
633 }
634
635 auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
636 auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
637 if (!lhs_rt || !rhs_rt) return true;
638 SmallVector<int64_t, 4> shape;
639 return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
640 rhs_rt.getShape(), shape);
641 }
642 return true;
643 }
644
645 // Given two types `a` and `b`, returns a refined type which is cast compatible
646 // with both `a` and `b` and is equal to or more precise than both of them. It
647 // returns empty Type if the input types are not cast compatible.
648 //
649 // The two types are considered cast compatible if they have dynamically equal
650 // shapes and element type. For element types that do not have subtypes, they
651 // must be equal. However for TensorFlow types such as Resource and Variant,
652 // that also have subtypes, we recursively check for subtype compatibility for
653 // Resource types and assume all variant types are cast compatible. If either
654 // one of `a` or `b` have empty subtypes, they are considered cast compatible.
655 //
656 // The returned type is same or more precise than the input types. For example,
657 // if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
658 // tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
659 //
660 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
661 // might allow operands to either be same as result type or be a ref type
662 // corresponding to it.
GetCastCompatibleType(Type a,Type b,bool may_ignore_ref_type_a)663 Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) {
664 // Fast path if everything is equal.
665 if (a == b) return b;
666
667 auto a_tt = a.dyn_cast<TensorType>();
668 auto b_tt = b.dyn_cast<TensorType>();
669
670 // If only one of a or b is a tensor type, they are incompatible.
671 if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
672
673 // For non-tensor types, we do not need to worry about shape and can return
674 // early.
675 if (!a_tt && !b_tt) {
676 // Remove ref types.
677 if (may_ignore_ref_type_a) {
678 if (auto ref_type = a.dyn_cast<TensorFlowRefType>()) {
679 a = ref_type.RemoveRef();
680 if (a == b) return a;
681 }
682 }
683 if (a.getTypeID() != b.getTypeID()) return nullptr;
684
685 // If either is not a type that contain subtypes then the types are not cast
686 // compatible.
687 auto a_wst = a.dyn_cast<TensorFlowTypeWithSubtype>();
688 auto b_wst = b.dyn_cast<TensorFlowTypeWithSubtype>();
689 if (!a_wst || !b_wst) return nullptr;
690
691 // For Variant types we are more permissive right now and accept all pairs
692 // of Variant types. If we are more constrainted and check compatibility of
693 // subtypes, we might reject valid graphs.
694 // TODO(prakalps): Variant doesn't have a subtype, we assign it
695 // one, so we should only assign it one when we know the subtype. Then we
696 // can be more constrained and check subtypes for cast compatibility as
697 // well.
698 if (a.isa<VariantType>()) return a;
699
700 // For Resource types, we recursively check the subtypes for cast
701 // compatibility, if possible. Otherwise treat them as compatible.
702 auto a_wst_st = a_wst.GetSubtypes();
703 auto b_wst_st = b_wst.GetSubtypes();
704 if (a_wst_st.empty() || b_wst_st.empty()) return a;
705 if (a_wst_st.size() != b_wst_st.size()) return nullptr;
706 llvm::SmallVector<TensorType, 4> refined_subtypes;
707 for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
708 Type refined_st =
709 GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
710 /*may_ignore_ref_type_a=*/false);
711 if (!refined_st) return nullptr;
712 refined_subtypes.push_back(refined_st.cast<TensorType>());
713 }
714
715 return ResourceType::get(refined_subtypes, a.getContext());
716 }
717
718 // For tensor types, check compatibility of both element type and shape.
719 Type refined_element_ty = GetCastCompatibleType(
720 a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
721 if (!refined_element_ty) return nullptr;
722
723 if (!a_tt.hasRank() && !b_tt.hasRank()) {
724 return UnrankedTensorType::get(refined_element_ty);
725 }
726 if (!a_tt.hasRank()) {
727 return RankedTensorType::get(b_tt.getShape(), refined_element_ty);
728 }
729 if (!b_tt.hasRank()) {
730 return RankedTensorType::get(a_tt.getShape(), refined_element_ty);
731 }
732
733 llvm::SmallVector<int64_t, 8> refined_shape;
734 if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
735 return nullptr;
736
737 return RankedTensorType::get(refined_shape, refined_element_ty);
738 }
739
HasCompatibleElementTypes(Type lhs,Type rhs,bool may_ignore_ref_type_lhs)740 bool HasCompatibleElementTypes(Type lhs, Type rhs,
741 bool may_ignore_ref_type_lhs) {
742 return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
743 }
744
AreCastCompatible(TypeRange types)745 bool AreCastCompatible(TypeRange types) {
746 Type common = types.front();
747 for (auto type : types.drop_front()) {
748 Type refined_type =
749 GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
750 if (!refined_type) return false;
751 common = refined_type;
752 }
753 return true;
754 }
755
ArraysAreCastCompatible(TypeRange lhs,TypeRange rhs)756 bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) {
757 if (lhs.size() != rhs.size()) return false;
758 for (auto pair : llvm::zip(lhs, rhs)) {
759 auto lhs_i = std::get<0>(pair);
760 auto rhs_i = std::get<1>(pair);
761 if (!AreCastCompatible({lhs_i, rhs_i})) return false;
762 }
763 return true;
764 }
765
766 // Returns the corresponding TensorFlow or standard type from TensorFlowRef
767 // type.
GetDefaultTypeOf(TensorFlowRefType type)768 static Type GetDefaultTypeOf(TensorFlowRefType type) {
769 return type.RemoveRef();
770 }
771
772 // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
773 // type for a composed type (such as a ref type or a type with subtypes).
774 template <typename ComposedType>
DropTypeHelper(Type ty)775 Type DropTypeHelper(Type ty) {
776 Type element_ty = getElementTypeOrSelf(ty);
777 auto composed_type = element_ty.dyn_cast<ComposedType>();
778 if (!composed_type) return ty;
779
780 Type default_ty = GetDefaultTypeOf(composed_type);
781 if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
782 return RankedTensorType::get(ranked_ty.getShape(), default_ty);
783 } else if (ty.dyn_cast<UnrankedTensorType>()) {
784 return UnrankedTensorType::get(default_ty);
785 } else {
786 return default_ty;
787 }
788 }
789
DropSubTypes(Type ty)790 Type DropSubTypes(Type ty) {
791 return DropTypeHelper<TensorFlowTypeWithSubtype>(ty);
792 }
793
DropRefType(Type ty)794 Type DropRefType(Type ty) { return DropTypeHelper<TensorFlowRefType>(ty); }
795
DropRefAndSubTypes(Type ty)796 Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
797
798 } // namespace tf_type
799 } // namespace mlir
800