1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/iterator_range.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
41 #include "mlir/Dialect/Traits.h" // from @llvm-project
42 #include "mlir/IR/Attributes.h" // from @llvm-project
43 #include "mlir/IR/Builders.h" // from @llvm-project
44 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
45 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
46 #include "mlir/IR/Diagnostics.h" // from @llvm-project
47 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
48 #include "mlir/IR/Identifier.h" // from @llvm-project
49 #include "mlir/IR/Location.h" // from @llvm-project
50 #include "mlir/IR/MLIRContext.h" // from @llvm-project
51 #include "mlir/IR/Matchers.h" // from @llvm-project
52 #include "mlir/IR/OpDefinition.h" // from @llvm-project
53 #include "mlir/IR/OpImplementation.h" // from @llvm-project
54 #include "mlir/IR/PatternMatch.h" // from @llvm-project
55 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
56 #include "mlir/IR/Types.h" // from @llvm-project
57 #include "mlir/IR/Value.h" // from @llvm-project
58 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project
59 #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
60 #include "mlir/Parser.h" // from @llvm-project
61 #include "mlir/Support/LLVM.h" // from @llvm-project
62 #include "mlir/Support/LogicalResult.h" // from @llvm-project
63 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
68 #include "tensorflow/core/platform/logging.h"
69 #include "tensorflow/core/util/tensor_format.h"
70
// `tensorflow::int64` is currently an alias of `std::int64_t`, and the alias is
// slated for removal. Until it is removed, verify at compile time that the two
// types are the same type.
// TODO(b/178519687): Remove once addressed.
static_assert(std::is_same<tensorflow::int64, std::int64_t>::value,
              "tensorflow::int64 is expected to match std::int64_t");
76
77 namespace mlir {
78 namespace TF {
79
80 //===----------------------------------------------------------------------===//
81 // TF Dialect Interfaces
82 //===----------------------------------------------------------------------===//
83
84 namespace {
85
86 struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterfacemlir::TF::__anon8c7eb4140111::TFConstantFoldInterface87 TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
foldmlir::TF::__anon8c7eb4140111::TFConstantFoldInterface88 LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
89 SmallVectorImpl<OpFoldResult> &results) const final {
90 return TensorFlowDialect::constantFold(op, operands, results);
91 }
92 };
93
94 struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
TFDecodeAttributesInterfacemlir::TF::__anon8c7eb4140111::TFDecodeAttributesInterface95 TFDecodeAttributesInterface(Dialect *dialect)
96 : DialectDecodeAttributesInterface(dialect) {}
decodemlir::TF::__anon8c7eb4140111::TFDecodeAttributesInterface97 LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const {
98 return TensorFlowDialect::decode(input, output);
99 }
100 };
101
102 struct TFInlinerInterface : public DialectInlinerInterface {
103 using DialectInlinerInterface::DialectInlinerInterface;
104
105 //===--------------------------------------------------------------------===//
106 // Analysis Hooks
107 //===--------------------------------------------------------------------===//
108
109 // Returns if it's legal to inline 'callable' into the 'call', where 'call' is
110 // a TF operation.
isLegalToInlinemlir::TF::__anon8c7eb4140111::TFInlinerInterface111 bool isLegalToInline(Operation *call, Operation *callable,
112 bool wouldBeCloned) const final {
113 // Check that the TF call operation is one that is legal to inline.
114 return !isa<TPUPartitionedCallOp>(call);
115 }
116
117 // Returns if its legal to inline 'src' region into the 'dest' region
118 // attached to a TF operation.
isLegalToInlinemlir::TF::__anon8c7eb4140111::TFInlinerInterface119 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
120 BlockAndValueMapping &valueMapping) const final {
121 // Allow inlining in regions attached to region based control flow
122 // operations only if the src region is a single block region
123 return isa<IfRegionOp, WhileRegionOp>(dest->getParentOp()) &&
124 llvm::hasSingleElement(*src);
125 }
126
127 // Returns true if its legal to inline a TF operation `op` into the `dest`
128 // region.
isLegalToInlinemlir::TF::__anon8c7eb4140111::TFInlinerInterface129 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
130 BlockAndValueMapping &) const final {
131 // An op is legal to inline if either of the following conditions is true:
132 // (a) Its legal to duplicate the Op.
133 // (b) The Op is inside a single use function. If that function is inlined,
134 // post inlining, the function will be dead and eliminated from the IR.
135 // So there won't be any code duplication.
136 // plus the function caller op can be replaced by inlined ops.
137 return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
138 }
139
140 //===--------------------------------------------------------------------===//
141 // Transformation Hooks
142 //===--------------------------------------------------------------------===//
143
144 // Attempts to materialize a conversion for a type mismatch between a call
145 // from this dialect, and a callable region. This method should generate an
146 // operation that takes 'input' as the only operand, and produces a single
147 // result of 'resultType'. If a conversion can not be generated, nullptr
148 // should be returned.
materializeCallConversionmlir::TF::__anon8c7eb4140111::TFInlinerInterface149 Operation *materializeCallConversion(OpBuilder &builder, Value input,
150 Type result_type,
151 Location conversion_loc) const final {
152 if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
153 return nullptr;
154 return builder.create<TF::CastOp>(conversion_loc, result_type, input,
155 /*truncate=*/builder.getBoolAttr(false));
156 }
157 };
158 } // end anonymous namespace
159
160 //===----------------------------------------------------------------------===//
161 // TF Dialect
162 //===----------------------------------------------------------------------===//
163
164 // Returns true if the op can be duplicated.
CanDuplicate(Operation * op)165 bool TensorFlowDialect::CanDuplicate(Operation *op) {
166 // If the op is marked with the cannot duplicate trait, it cannot be
167 // duplicated.
168 if (op->hasTrait<OpTrait::TF::CannotDuplicate>()) return false;
169
170 // If the op has no memory side effects, it can be duplicated.
171 if (MemoryEffectOpInterface::hasNoEffect(op)) return true;
172
173 // If the op is marked stateless using the `is_stateless` attribute, that
174 // attribute determines if the op can be duplicated.
175 if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
176 return is_stateless.getValue();
177
178 // Otherwise, assume ops can be duplicated by default if its registered, else
179 // it cannot be for unknown ops.
180 return op->isRegistered();
181 }
182
183 // Returns true if the op can have side effects.
CanHaveSideEffects(Operation * op)184 bool TensorFlowDialect::CanHaveSideEffects(Operation *op) {
185 // If the op has no memory side effects, it has no side effects
186 if (MemoryEffectOpInterface::hasNoEffect(op)) return false;
187
188 // If the op is marked stateless using the `is_stateless` attribute, then
189 // it has no side effects.
190 if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
191 return !is_stateless.getValue();
192
193 // Terminators defined in the TF dialect do not have side effects.
194 if (op->hasTrait<OpTrait::IsTerminator>()) return false;
195
196 // Otherwise assume that the op can have side effects.
197 return true;
198 }
199
200 std::vector<TensorFlowDialect::AdditionalOpFunction>
GetAdditionalOperationHooks()201 *TensorFlowDialect::GetAdditionalOperationHooks() {
202 static auto *const additional_operation_hooks =
203 new std::vector<TensorFlowDialect::AdditionalOpFunction>();
204 return additional_operation_hooks;
205 }
206
207 TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
208 TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
209
// Constructs the "tf" dialect and registers its contents with `context`:
//   * the ops listed in tf_all_ops.cc.inc and tfrt_ops.cc.inc,
//   * the types enumerated in tf_types.def,
//   * the inliner, decode-attributes and constant-fold interfaces,
//   * the ShapeAttr and FuncAttr attributes.
// Unknown (unregistered) operations are allowed. Finally, every hook collected
// in GetAdditionalOperationHooks() is invoked on the new dialect.
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
  // Each generated .inc file, included with GET_OP_LIST defined, expands
  // inside the template argument list (presumably to a comma-separated list of
  // op classes, following the MLIR ODS convention).
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
      >();
  // tf_types.def expands these macros into the type-class list. Only
  // HANDLE_TF_TYPE appends a comma, so HANDLE_LAST_TF_TYPE presumably emits
  // the final entry and terminates the list.
  addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
      >();
  addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
                TFConstantFoldInterface>();
  addAttributes<ShapeAttr, FuncAttr>();

  // Support unknown operations because not all TensorFlow operations are
  // registered.
  allowUnknownOperations();

  // Let externally registered hooks extend the dialect (presumably to register
  // additional ops). They run after all of the built-in registration above.
  for (const auto &hook : *GetAdditionalOperationHooks()) {
    hook(*this);
  }
}
237
238 namespace {
239
ParseShapeAttr(MLIRContext * context,StringRef spec,Location loc)240 ShapeAttr ParseShapeAttr(MLIRContext *context, StringRef spec, Location loc) {
241 auto emit_error = [&, spec]() {
242 emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
243 return nullptr;
244 };
245
246 if (!spec.consume_front("shape<")) return emit_error();
247
248 if (spec.consume_front("*>"))
249 return mlir::TF::ShapeAttr::get(context, llvm::None);
250
251 SmallVector<int64_t, 4> shape;
252 while (!spec.consume_front(">")) {
253 int64_t dim;
254
255 if (spec.consume_front("?"))
256 dim = -1;
257 else if (spec.consumeInteger(10, dim) || dim < 0)
258 return emit_error();
259
260 spec.consume_front("x");
261
262 shape.push_back(dim);
263 }
264
265 return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
266 }
267
PrintShapeAttr(ShapeAttr attr,DialectAsmPrinter & os)268 void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) { // NOLINT
269 os << "shape";
270
271 os << "<";
272 if (attr.hasRank()) {
273 auto print_dim = [&](int64_t dim) {
274 if (dim > -1)
275 os << dim;
276 else
277 os << "?";
278 };
279 llvm::interleave(attr.getShape(), os, print_dim, "x");
280 } else {
281 os << "*";
282 }
283 os << ">";
284 }
285
286 // Parses a #tf.func attribute of the following format:
287 //
288 // #tf.func<@symbol, {attr = "value"}>
289 //
290 // where the first element is a SymbolRefAttr and the second element is a
291 // DictionaryAttr.
ParseFuncAttr(MLIRContext * context,StringRef spec,Location loc)292 FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) {
293 auto emit_error = [&, spec]() {
294 emitError(loc, "invalid TensorFlow func attribute: ") << spec;
295 return nullptr;
296 };
297
298 if (!spec.consume_front("func<")) return emit_error();
299
300 size_t func_name_num_read = 0;
301 Attribute func_name_attr =
302 mlir::parseAttribute(spec, context, func_name_num_read);
303 if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
304 return emit_error();
305 spec = spec.drop_front(func_name_num_read);
306
307 if (!spec.consume_front(", ")) return emit_error();
308
309 size_t func_attrs_num_read = 0;
310 Attribute func_attrs_attr =
311 mlir::parseAttribute(spec, context, func_attrs_num_read);
312 if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
313 return emit_error();
314 spec = spec.drop_front(func_attrs_num_read);
315
316 if (!spec.consume_front(">")) return emit_error();
317
318 return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
319 func_attrs_attr.cast<DictionaryAttr>());
320 }
321
322 // Prints a #tf.func attribute of the following format:
323 //
324 // #tf.func<@symbol, {attr = "value"}>
PrintFuncAttr(FuncAttr attr,DialectAsmPrinter & os)325 void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) {
326 os << "func<" << attr.GetName() << ", " << attr.GetAttrs() << ">";
327 }
328
329 } // namespace
330
parseAttribute(DialectAsmParser & parser,Type type) const331 Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
332 Type type) const {
333 auto spec = parser.getFullSymbolSpec();
334 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
335
336 if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
337
338 if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
339
340 return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
341 }
342
printAttribute(Attribute attr,DialectAsmPrinter & os) const343 void TensorFlowDialect::printAttribute(Attribute attr,
344 DialectAsmPrinter &os) const {
345 if (auto shape_attr = attr.dyn_cast<ShapeAttr>())
346 PrintShapeAttr(shape_attr, os);
347 else if (auto func_attr = attr.dyn_cast<FuncAttr>())
348 PrintFuncAttr(func_attr, os);
349 else
350 llvm_unreachable("unexpected tensorflow attribute type");
351 }
352
// Parses a type registered to this dialect. The type keyword is read first.
// Keywords of the simple TF types (expanded from tf_types.def through
// HANDLE_TF_TYPE) must match exactly and map directly to their type. Keywords
// starting with "resource" or "variant" are delegated to the subtype-aware
// parsers. Anything else is reported as an unknown TensorFlow type, and a null
// Type is returned.
Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
  StringRef data;
  if (parser.parseKeyword(&data)) return Type();

  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

// Each simple TF type expands to an exact keyword comparison.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"

  // Custom types may be followed by an optional `<subtype, ...>` list, so they
  // are matched by keyword prefix and parsed by dedicated helpers.
  if (data.startswith("resource")) return ParseResourceType(parser, loc);
  if (data.startswith("variant")) return ParseVariantType(parser, loc);
  return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}
372
// Prints a type registered to this dialect. tf_types.def expands one
// `dyn_cast` check per TF type. Simple types print just their keyword, while
// custom types are printed by their dedicated Print<Name>Type member (for
// example, PrintResourceType). Reaching the end means `ty` matched none of
// the registered TF types.
void TensorFlowDialect::printType(Type ty, DialectAsmPrinter &os) const {
  assert(ty.isa<TensorFlowType>());
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
  if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
    os << name;                                        \
    return;                                            \
  }
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)   \
  if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
    Print##tftype##Type(derived_ty, os);               \
    return;                                            \
  }
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"

  llvm_unreachable("unexpected tensorflow type kind");
}
391
392 namespace {
393 template <typename TypeWithSubtype>
ParseTypeWithSubtype(MLIRContext * context,DialectAsmParser & parser,Location loc)394 Type ParseTypeWithSubtype(MLIRContext *context, DialectAsmParser &parser,
395 Location loc) {
396 // Default type without inferred subtypes.
397 if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
398
399 // Most types with subtypes have only one subtype.
400 SmallVector<TensorType, 1> subtypes;
401 do {
402 TensorType tensor_ty;
403 if (parser.parseType(tensor_ty)) return Type();
404 subtypes.push_back(tensor_ty);
405 } while (succeeded(parser.parseOptionalComma()));
406
407 if (parser.parseGreater()) return Type();
408 return TypeWithSubtype::getChecked(subtypes, context, loc);
409 }
410
411 template <typename TypeWithSubtype>
PrintTypeWithSubtype(StringRef type,TypeWithSubtype ty,DialectAsmPrinter & os)412 void PrintTypeWithSubtype(StringRef type, TypeWithSubtype ty,
413 DialectAsmPrinter &os) {
414 os << type;
415 ArrayRef<TensorType> subtypes = ty.getSubtypes();
416 if (subtypes.empty()) return;
417
418 os << "<";
419 interleaveComma(subtypes, os);
420 os << ">";
421 }
422 } // anonymous namespace
423
ParseResourceType(DialectAsmParser & parser,Location loc) const424 Type TensorFlowDialect::ParseResourceType(DialectAsmParser &parser,
425 Location loc) const {
426 return ParseTypeWithSubtype<ResourceType>(getContext(), parser, loc);
427 }
428
PrintResourceType(ResourceType ty,DialectAsmPrinter & os) const429 void TensorFlowDialect::PrintResourceType(ResourceType ty,
430 DialectAsmPrinter &os) const {
431 return PrintTypeWithSubtype("resource", ty, os);
432 }
433
ParseVariantType(DialectAsmParser & parser,Location loc) const434 Type TensorFlowDialect::ParseVariantType(DialectAsmParser &parser,
435 Location loc) const {
436 return ParseTypeWithSubtype<VariantType>(getContext(), parser, loc);
437 }
438
PrintVariantType(VariantType ty,DialectAsmPrinter & os) const439 void TensorFlowDialect::PrintVariantType(VariantType ty,
440 DialectAsmPrinter &os) const {
441 return PrintTypeWithSubtype("variant", ty, os);
442 }
443
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)444 Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
445 Attribute value, Type type,
446 Location loc) {
447 return builder.create<ConstOp>(loc, type, value);
448 }
449
450 } // namespace TF
451 } // namespace mlir
452