/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"

#include <algorithm>
#include <iterator>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace tf_executor {

//===----------------------------------------------------------------------===//
// TF Executor Dialect
//===----------------------------------------------------------------------===//

namespace {

struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  // Override the inlining hook to determine if 'src' can be inlined into
  // 'dest'.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &value_mapping) const final {
    // Allow inlining into tf_executor.island regions if the incoming region
    // has a single block.
    return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
           llvm::hasSingleElement(*src);
  }
};

struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. its internal regions
  // may reference values defined in an enclosing region), should be used
  // when materializing constants.
  // In the executor dialect we materialize inside an island.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<tf_executor::IslandOp>(region->getParentOp());
  }
};

}  // namespace

TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_executor", context,
              TypeID::get<TensorFlowExecutorDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
      >();

  addInterfaces<TensorFlowExecutorInlinerInterface,
                TensorFlowExecutorDialectFoldInterface>();

  addTypes<ControlType, TokenType>();
}

Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const {
  StringRef data_type;
  if (parser.parseKeyword(&data_type)) return Type();

  if (data_type == "control") return ControlType::get(getContext());
  if (data_type == "token") return TokenType::get(getContext());
  parser.emitError(parser.getNameLoc())
      << "unknown tf_executor type: " << data_type;
  return nullptr;
}

void TensorFlowExecutorDialect::printType(Type type,
                                          DialectAsmPrinter &os) const {
  if (type.isa<ControlType>()) {
    os << "control";
    return;
  }
  if (type.isa<TokenType>()) {
    os << "token";
    return;
  }
  os << "<unknown tf_executor type>";
}
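
// For illustration: these types are referenced in the IR with the usual
// dialect prefix added by the generic dialect type printer, e.g.
// `!tf_executor.control` and `!tf_executor.token`.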

//===----------------------------------------------------------------------===//
// Implementation for all the operations defined in ODS (op definition spec).
//===----------------------------------------------------------------------===//

namespace {

// Verifies that all control operands are at the end of the operand list.
// Used by the constraint `ControlOperandsAfterAllData` in ODS.
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
  bool found_control = false;
  for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
    if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
      found_control = true;
      continue;
    }
    if (found_control)
      return op->emitOpError() << "found non-control operand #" << operand_idx
                               << " after control operand";
  }
  return success();
}
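
// For illustration (with a hypothetical op and arbitrary SSA names), the
// constraint accepts
//   "tf_executor.SomeOp"(%data0, %data1, %ctl0, %ctl1)
// but rejects
//   "tf_executor.SomeOp"(%data0, %ctl0, %data1)
// since a data operand appears after a control operand.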

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.graph
//===----------------------------------------------------------------------===//

FetchOp GraphOp::GetFetch() { return llvm::cast<FetchOp>(GetBody().back()); }

namespace {

LogicalResult Verify(GraphOp graph) {
  auto *executorDialect = graph->getDialect();

  if (graph.GetBody().empty())
    return graph.emitOpError() << "expects a non-empty body";

  // Only tf_executor dialect operations are allowed to be immediately nested
  // in a tf_executor.graph region.
  for (Operation &op : graph.GetBody()) {
    if (op.getDialect() != executorDialect)
      return op.emitOpError() << "unallowed inside a tf_executor.graph region";
    if (isa<GraphOp>(op))
      return op.emitOpError()
             << "unallowed directly inside another tf_executor.graph";
  }

  Operation &fetch = graph.GetBody().back();
  if (!isa<FetchOp>(fetch))
    return fetch.emitOpError()
           << "invalid tf_executor.graph terminator, fetch expected";

  // Ensure that the fetch terminator operands match the graph result types.
  // All the non-control operands of the fetch operation must match the graph
  // returned values.
  if (fetch.getNumOperands() < graph.getNumResults())
    return fetch.emitOpError() << "does not have enough operands to cover the "
                                  "graph returned values";
  for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
    Value operand = fetch.getOperand(i);
    // Break out of the loop at the first control operand encountered.
    const int64_t num_results = graph.getNumResults();
    if (operand.getType().isa<ControlType>()) {
      if (i != num_results)
        return fetch.emitOpError()
               << "operand #" << i
               << " is a control type, can't be bound to a graph result";
      break;
    }
    if (i >= num_results)
      return fetch.emitOpError()
             << "operand #" << i << " does not have a graph result to bind";
    if (graph.getResult(i).getType() != operand.getType())
      return fetch.emitOpError()
             << "operand #" << i << " type mismatch graph results";
  }
  return success();
}

void Print(GraphOp graph, OpAsmPrinter &p) {
  p << graph.getOperationName();
  p.printRegion(graph.getOperation()->getRegion(0));
  p.printOptionalAttrDict(graph.getAttrs());
}

ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc loc = parser.getCurrentLocation();

  // Parse the body region.
  Region &body = *result.addRegion();
  if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();

  // Ensure that the region is well formed: it contains at least a block with
  // a FetchOp terminator.
  GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);

  if (!llvm::hasSingleElement(body))
    return parser.emitError(loc) << "expects a single block region";

  // Get the result types from the terminator inside the graph.
  Operation &fetch = body.back().back();
  if (!isa<FetchOp>(fetch))
    return parser.emitError(loc) << "expects a tf_executor.fetch terminator";

  // The return values of the graph operation are the non-control operands of
  // the fetch operation.
  result.types.reserve(fetch.getNumOperands());
  for (Type type : fetch.getOperandTypes()) {
    if (type.isa<ControlType>()) break;
    result.types.push_back(type);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  return success();
}
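
// For illustration (a hedged sketch with arbitrary names), the custom form
// handled by the printer/parser above looks roughly like:
//   %result = tf_executor.graph {
//     %out, %ctl = tf_executor.island {
//       ...
//       tf_executor.yield %val : tensor<f32>
//     }
//     tf_executor.fetch %out : tensor<f32>
//   }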

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.fetch
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//

YieldOp IslandOp::GetYield() { return llvm::cast<YieldOp>(GetBody().back()); }

// Checks if a tf_executor.island wraps a single operation and the results of
// that single operation are perfectly forwarded to the island's yield.
bool IslandOp::WrapsSingleOp() {
  auto body = GetBody().without_terminator();
  if (!hasSingleElement(body)) return false;

  Operation &wrapped_op = *body.begin();
  YieldOp yield = GetYield();
  return wrapped_op.getNumResults() == yield.getNumOperands() &&
         std::equal(wrapped_op.getResults().begin(),
                    wrapped_op.getResults().end(), yield.getOperands().begin());
}

namespace {

LogicalResult Verify(IslandOp island) {
  if (!island.GetBody().args_empty())
    return island.emitOpError() << "expects body without any arguments";

  Operation &yield = island.GetBody().back();
  if (!isa<YieldOp>(yield))
    return yield.emitOpError()
           << "invalid tf_executor.island terminator, yield expected";

  // Ensure that the yield terminator operands match the island result types.
  int result_count = island.getNumResults() - 1;  // -1 for the control token
  const int num_operands = yield.getNumOperands();
  if (num_operands != result_count)
    return yield.emitOpError()
           << "has " << yield.getNumOperands()
           << " operand, but island returns " << result_count;
  for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
    if (island.getResult(operand_idx).getType() !=
        yield.getOperand(operand_idx).getType())
      return yield.emitOpError()
             << "operand #" << operand_idx << " type mismatch island results";
  }

  // Check that there aren't any control results other than the last one.
  Type control_type = ControlType::get(island.getContext());
  for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
    if (island.getResult(operand_idx).getType() == control_type)
      return yield.emitOpError()
             << "unexpected control type for operand #" << operand_idx;
  }
  return success();
}

void Print(IslandOp op, OpAsmPrinter &p) {
  p << op.getOperationName();
  if (op.getNumOperands()) {
    // These are always control operands, no explicit type needed.
    p << '(';
    p.printOperands(op.getOperands());
    p << ')';
  }

  // Check if we can print the short "wraps" form: that is, if the island
  // contains a single operation and the results of this operation are
  // perfectly forwarded to the yield.
  if (op.getAttrs().empty() && op.WrapsSingleOp()) {
    Operation &wrapped_op = op.GetBody().front();
    YieldOp yield_op = op.GetYield();
    // The "wraps" syntax only encodes a single location.
    // In order to correctly round-trip, we can only use this syntax when all
    // the locations are identical.
    if (wrapped_op.getLoc() == op.getLoc() &&
        yield_op.getLoc() == op.getLoc()) {
      p << " wraps ";
      p.printGenericOp(&wrapped_op);
      return;
    }
  }
  p.printRegion(op.getOperation()->getRegion(0));
  p.printOptionalAttrDict(op.getAttrs());
}

ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type control_type = ControlType::get(parser.getBuilder().getContext());

  // Parse optional argument list (control dependencies only).
  SmallVector<OpAsmParser::OperandType, 4> op_infos;
  if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
    return failure();
  if (!op_infos.empty()) {
    SmallVector<Type, 2> types(op_infos.size(), control_type);
    parser.resolveOperands(op_infos, types, loc, result.operands);
  }

  // Parse the body region.
  Region &body = *result.addRegion();

  if (succeeded(parser.parseOptionalKeyword("wraps"))) {
    // If we parse the short version of the island, we have an operation in
    // the generic form that follows the "wraps" keyword. Parse it inside the
    // region and forward all of its results as-is to the yield operation.
    body.push_back(new Block);
    Block &block = body.back();
    Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
    if (!wrapped_op) return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    builder.setInsertionPointToEnd(&block);
    builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults());
    result.location = wrapped_op->getLoc();
  } else if (parser.parseRegion(body, llvm::None, llvm::None)) {
    return failure();
  }

  IslandOp::ensureTerminator(body, parser.getBuilder(), result.location);

  // Get the result types for the island from the terminator operands.
  Operation &yield = body.back().back();
  result.types.reserve(yield.getNumOperands() + 1);
  result.types.append(yield.operand_type_begin(), yield.operand_type_end());
  result.types.push_back(control_type);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
  return success();
}
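
// For illustration (a hedged sketch, names arbitrary), the "wraps" short form
// parsed above looks roughly like:
//   %out, %ctl = tf_executor.island(%other_ctl) wraps
//     "tf.Identity"(%arg) : (tensor<f32>) -> tensor<f32>
// which is equivalent to an island whose body contains the wrapped op
// followed by a `tf_executor.yield` of all its results.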

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.yield
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.Switch
//===----------------------------------------------------------------------===//

namespace {

ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  SmallVector<Type, 1> types;
  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser.emitError(parser.getNameLoc())
           << " expects only a single data type";

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data input and the outputs all use this type, and the predicate is a
  // tensor<i1>).
  if (types.front().isa<FunctionType>()) {
    FunctionType type = types.front().cast<FunctionType>();
    if (type.getNumInputs() < 2)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type and a predicate";
    result.types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    if (op_infos.size() < 2)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type and a predicate";
    Type control_type = ControlType::get(parser.getBuilder().getContext());
    result.types.append(2, types[0]);
    result.types.push_back(control_type);
    Type i1_type = parser.getBuilder().getI1Type();
    RankedTensorType predicate_type = RankedTensorType::get({}, i1_type);
    types.push_back(predicate_type);
    types.append(op_infos.size() - 2, control_type);
  }

  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}

void Print(SwitchOp switch_op, OpAsmPrinter &p) {
  p << switch_op.getOperationName() << ' ';
  p.printOperands(switch_op.getOperands());
  Type data_operand_ty = switch_op.data().getType();
  // If the types don't match exactly, print the functional type syntax;
  // otherwise print the shorter single-type form.
  p << " : ";
  if (switch_op.trueOutput().getType() != data_operand_ty ||
      switch_op.falseOutput().getType() != data_operand_ty ||
      switch_op.predicate().getType().isa<UnrankedTensorType>()) {
    p.printFunctionalType(switch_op.getOperation());
  } else {
    p << switch_op.getType(0);
  }
  p.printOptionalAttrDict(switch_op.getAttrs());
}
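
// For illustration (a hedged sketch, names arbitrary), the short form printed
// above looks roughly like:
//   %true, %false, %ctl = tf_executor.Switch %data, %pred : tensor<f32>
// where %pred is a tensor<i1> predicate selecting which output carries %data.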

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.SwitchN
//===----------------------------------------------------------------------===//

namespace {

LogicalResult Verify(SwitchNOp switchn) {
  IntegerAttr num_outs = switchn->getAttrOfType<IntegerAttr>("num_outs");
  if (!num_outs)
    return switchn.emitOpError() << "expects a `num_outs` integer attribute";

  // Expect num_outs results + 1 control output.
  if (switchn.getNumResults() != num_outs.getInt() + 1)
    return switchn.emitOpError()
           << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
           << (switchn.getNumResults() - 1);

  // Check that the operand can be broadcast to each output type.
  auto operand0_type = switchn.getOperand(0).getType();
  TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
  if (!operand0_tensor_type) {
    return switchn.emitOpError()
           << "expects data operand to have tensor type but got "
           << operand0_type;
  }
  for (Type output_type : switchn.getResultTypes()) {
    if (output_type.isa<ControlType>()) break;

    TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
    if (!output_tensor_type) {
      return switchn.emitOpError()
             << "expects outputs to have tensor type but got " << output_type;
    }

    // If the output type is a ref type, then the operand type should also be
    // of the same ref type. However, if the output type is a non-ref type T,
    // then the operand can be a tensor of type T or T_REF.
    bool is_output_ref =
        output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
    if (is_output_ref &&
        !operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
      return switchn.emitOpError()
             << "expects same operand and output element type but got "
             << operand0_tensor_type << " vs " << output_tensor_type;
    }
    Type broadcasted_type = OpTrait::util::getBroadcastedType(
        TF::DropRefAndSubTypes(operand0_tensor_type),
        TF::DropRefAndSubTypes(output_tensor_type));
    if (!broadcasted_type) {
      return switchn.emitOpError()
             << "expects data operand to be broadcastable with all output types"
             << " but got " << operand0_tensor_type << " vs "
             << output_tensor_type;
    }
  }
  return success();
}

void Print(SwitchNOp switchn, OpAsmPrinter &p) {
  p << switchn.getOperationName() << ' ';
  auto operands = switchn.getOperands();
  // Print the 2 data operands.
  p.printOperands(operands.begin(), std::next(operands.begin(), 2));
  p << " of " << (switchn.getNumResults() - 1);
  // Print control dependencies, if any.
  if (!llvm::empty(switchn.controlInputs())) {
    p << " (";
    p.printOperands(switchn.controlInputs());
    p << ")";
  }
  p << " : " << switchn.getType(0);
  p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
}

ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
  // Parsing:
  //   %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
  // Where the first operand is the data to replicate, the second is an i32
  // indicating which output to populate, followed by the keyword `of` and the
  // number of outputs (+1 for the control token).
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  SmallVector<Type, 1> types;
  llvm::SMLoc loc = parser.getCurrentLocation();
  IntegerAttr num_outs;
  Type i64_type = parser.getBuilder().getIntegerType(64);
  if (parser.parseOperandList(op_infos, 2) || parser.parseKeyword("of") ||
      parser.parseAttribute(num_outs, i64_type, "num_outs",
                            result.attributes) ||
      parser.parseOperandList(op_infos,
                              OpAsmParser::Delimiter::OptionalParen) ||
      parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser.emitError(parser.getNameLoc())
           << " expects only a single data type";

  if (num_outs.getInt() <= 0)
    return parser.emitError(parser.getNameLoc())
           << " expects a positive number of outputs";

  // `types` already contains the type for the data; add an i32 for the
  // output_index, and then the optional control inputs.
  auto builder = parser.getBuilder();
  types.push_back(RankedTensorType::get({}, builder.getIntegerType(32)));
  Type control_type = ControlType::get(builder.getContext());
  types.append(op_infos.size() - 2, control_type);

  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  // The output types are the data input type replicated `num_outs` times,
  // plus the control type.
  result.types.append(num_outs.getInt(), types[0]);
  result.types.push_back(control_type);

  return parser.parseOptionalAttrDict(result.attributes);
}

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.Merge
//===----------------------------------------------------------------------===//

namespace {

LogicalResult Verify(MergeOp merge) {
  if (!merge.getNumOperands())
    return merge.emitOpError() << "expects at least one operand";

  Type data_type = merge.getOperand(0).getType();
  if (data_type.isa<ControlType>())
    return merge.emitOpError() << "expects a non-control input";

  // Check that each operand can be individually broadcast to the output type.
  Type output_type = merge.output().getType();
  TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
  if (!output_tensor_ty) {
    return merge.emitOpError()
           << "expects output to have tensor type but got " << output_type;
  }
  bool is_output_ref =
      output_tensor_ty.getElementType().isa<TF::TensorFlowRefType>();
  for (Type operand_type : merge.getOperandTypes()) {
    if (operand_type.isa<ControlType>()) break;

    // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
    // constraint.
    TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
    if (!operand_tensor_ty)
      return merge.emitOpError()
             << "expects data operands to have tensor type but got "
             << operand_type;

    // If the output type is a ref type, then all operand types should also be
    // of the same ref type. However, if the output type is a non-ref type T,
    // operands can be tensors of type T or T_REF.
    if (is_output_ref &&
        !operand_tensor_ty.getElementType().isa<TF::TensorFlowRefType>()) {
      return merge.emitOpError()
             << "expects same operand and output element type but got "
             << operand_tensor_ty << " vs " << output_tensor_ty;
    }
    Type broadcasted_type = OpTrait::util::getBroadcastedType(
        TF::DropRefAndSubTypes(output_tensor_ty),
        TF::DropRefAndSubTypes(operand_tensor_ty));
    if (!broadcasted_type)
      return merge.emitOpError()
             << "expects all operands to be broadcastable with output type"
             << " but got " << operand_tensor_ty << " vs " << output_tensor_ty;
  }
  return success();
}

void Print(MergeOp merge, OpAsmPrinter &p) {
  // Use short form only when there are exactly two data operands and their
  // type matches the output type. Otherwise, use the generic printer.
  bool use_short_form = true;
  int num_data_operands = 0;

  Type output_type = merge.output().getType();
  for (Type operand_type : merge.getOperandTypes()) {
    if (operand_type.isa<ControlType>()) break;
    num_data_operands++;

    if (operand_type != output_type) {
      use_short_form = false;
      break;
    }
  }

  p << merge.getOperationName() << ' ';
  p.printOperands(merge.getOperands());

  // Print the type signature of the operation.
  p << " : ";
  if (!use_short_form || num_data_operands != 2) {
    p.printFunctionalType(merge.getOperation());
  } else {
    p << output_type;
  }

  p.printOptionalAttrDict(merge.getAttrs());
}

ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  SmallVector<Type, 1> types;
  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser.emitError(parser.getNameLoc())
           << " expects only a single data type";

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data inputs and the output all use this type).
  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
    result.types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    // In the short form, the parsed type is used for both data operands; any
    // remaining operands are expected to be control inputs.
    types.push_back(Type(types.front()));
    Type control_type = ControlType::get(parser.getBuilder().getContext());
    types.append(op_infos.size() - 2, control_type);

    RankedTensorType i32_tensor =
        RankedTensorType::get({}, parser.getBuilder().getIntegerType(32));
    result.types = {types.front(), i32_tensor, control_type};
  }

  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
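
// For illustration (a hedged sketch, names arbitrary), the short form handled
// above looks roughly like:
//   %value, %value_index, %ctl = tf_executor.Merge %a, %b : tensor<f32>
// where %value_index is a tensor<i32> reporting which input was taken.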

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.Enter
//===----------------------------------------------------------------------===//

namespace {

// Default value for the parallel_iterations attribute on Enter nodes.
constexpr int kDefaultParallelIterations = 10;

void Print(EnterOp enter, OpAsmPrinter &p) {
  p << enter.getOperationName() << ' ';
  p.printOperands(enter.getOperands());

  p << " frame \"";
  printEscapedString(enter.frame_name(), p.getStream());
  p << "\"";
  if (enter.parallel_iterations() != kDefaultParallelIterations)
    p << " parallel_iterations " << enter.parallel_iterations();
  if (enter.is_constant()) p << " constant ";

  // If the types don't match exactly, print the functional type syntax;
  // otherwise print the shorter single-type form.
  p << " : ";
  if (enter.data().getType() != enter.output().getType()) {
    p.printFunctionalType(enter.getOperation());
  } else {
    p << enter.getType(0);
  }

  p.printOptionalAttrDict(enter.getAttrs(),
                          {"frame_name", "parallel_iterations", "is_constant"});
}

ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  llvm::SMLoc loc = parser.getCurrentLocation();
  MLIRContext *context = parser.getBuilder().getContext();
  if (parser.parseOperandList(op_infos)) return failure();
  if (op_infos.empty())
    return parser.emitError(loc) << " expects at least one data operand";

  Attribute frame;
  if (parser.parseKeyword("frame") ||
      parser.parseAttribute(frame, NoneType::get(context), "frame_name",
                            result.attributes))
    return failure();

  Type i64 = parser.getBuilder().getIntegerType(64);
  if (parser.parseOptionalKeyword("parallel_iterations")) {
    // The keyword is absent: fall back to the default value.
    result.addAttribute("parallel_iterations",
                        IntegerAttr::get(i64, kDefaultParallelIterations));
  } else {
    IntegerAttr parallel_iterations;
    if (parser.parseAttribute(parallel_iterations, i64, "parallel_iterations",
                              result.attributes))
      return failure();
  }
  bool has_constant = succeeded(parser.parseOptionalKeyword("constant"));
  result.addAttribute("is_constant", BoolAttr::get(context, has_constant));

  SmallVector<Type, 1> types;
  if (parser.parseColonTypeList(types)) return failure();
  if (types.size() != 1)
    return parser.emitError(loc) << " expects only a single data type";

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data input and the outputs all use this type).
  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
    // One data input, and any number of control inputs.
    if (type.getNumInputs() >= 1) {
      result.types.assign(type.getResults().begin(), type.getResults().end());
      types.assign(type.getInputs().begin(), type.getInputs().end());
    } else {
      return parser.emitError(parser.getNameLoc()) << " expects a data input";
    }
  } else {
    Type control_type = ControlType::get(context);
    types.append(op_infos.size() - 1, control_type);
    result.addTypes({types.front(), control_type});
  }

  // Extra operands are expected to be control inputs.
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
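
// For illustration (a hedged sketch, names arbitrary), the form handled above
// looks roughly like:
//   %res, %ctl = tf_executor.Enter %data frame "while/ctx" : tensor<f32>
// with optional `parallel_iterations N` and `constant` tokens before the
// colon.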

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.NextIteration.Source
//===----------------------------------------------------------------------===//

namespace {

LogicalResult Verify(NextIterationSourceOp source) {
  Value token = source.token();
  if (!token.hasOneUse())
    return source.emitOpError() << "expects a single user for produced token";
  if (!isa<NextIterationSinkOp>(*token.user_begin()))
    return source.emitOpError() << "token should be consumed by a sink op";
  return success();
}

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.NextIteration.Sink
//===----------------------------------------------------------------------===//

namespace {

LogicalResult Verify(NextIterationSinkOp sink) {
  Value token = sink.token();
  Operation *definingOp = token.getDefiningOp();
  if (!definingOp)
    return sink.emitOpError() << "expects a token directly produced by a "
                                 "tf_executor.NextIteration.Source op: ";
  auto source = dyn_cast<NextIterationSourceOp>(definingOp);
  if (!source)
    return sink.emitOpError() << "expects a token produced by a "
                                 "tf_executor.NextIteration.Source op: ";
  if (source.output().getType() != sink.input().getType())
    return sink.emitOpError()
           << "input type " << sink.input().getType()
           << " mismatch the tf_executor.NextIteration.Source output type: "
           << source.output().getType();
  return success();
}

}  // anonymous namespace

NextIterationSourceOp NextIterationSinkOp::GetSource() {
  return cast<NextIterationSourceOp>(token().getDefiningOp());
}

//===----------------------------------------------------------------------===//
// tf_executor.Exit
//===----------------------------------------------------------------------===//

namespace {

void Print(ExitOp exit, OpAsmPrinter &p) {
  p << exit.getOperationName() << ' ';
  p.printOperands(exit.getOperands());
  p << " : " << exit.getType(0);
  p.printOptionalAttrDict(exit.getAttrs());
}

ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;
  SmallVector<Type, 1> types;

  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
    return failure();

  llvm::SMLoc loc = parser.getCurrentLocation();
  Type control_type = ControlType::get(parser.getBuilder().getContext());
  types.append(op_infos.size() - 1, control_type);
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  result.addTypes({types.front(), control_type});
  return parser.parseOptionalAttrDict(result.attributes);
}
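
// For illustration (a hedged sketch, names arbitrary):
//   %res, %ctl = tf_executor.Exit %data : tensor<f32>
// where any extra operands after %data are control inputs.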

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.LoopCond
//===----------------------------------------------------------------------===//

namespace {

void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
  p << loop_cond.getOperationName() << ' ';
  p.printOperands(loop_cond.getOperands());

  // If the input and output types don't match, print the functional type
  // syntax; otherwise print the shorter single-type form.
  if (loop_cond.input().getType() != loop_cond.output().getType()) {
    p << " : ";
    p.printFunctionalType(loop_cond.getOperation());
  } else {
    p << " : " << loop_cond.input().getType();
  }

  p.printOptionalAttrDict(loop_cond.getAttrs());
}

ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> op_infos;

  if (parser.parseOperandList(op_infos)) return failure();
  if (op_infos.empty())
    return parser.emitError(parser.getNameLoc())
           << "expects at least one operand";

  SmallVector<Type, 1> types;
  if (parser.parseColonTypeList(types)) return failure();

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data input and the outputs all use this type).
  Type control_type = ControlType::get(parser.getBuilder().getContext());
  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
    if (llvm::count_if(type.getInputs(),
                       [=](Type type) { return type != control_type; }) != 1)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type";
    result.types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    if (types.size() != 1)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type";
    types.append(op_infos.size() - 1, control_type);
    result.addTypes({types.front(), control_type});
  }

  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
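
// For illustration (a hedged sketch, names arbitrary):
//   %cond, %ctl = tf_executor.LoopCond %pred : tensor<i1>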

}  // namespace

//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//

// TODO(lyandy): Add canonicalization for dedupping control inputs.

//===----------------------------------------------------------------------===//
// tf_executor.graph
//===----------------------------------------------------------------------===//

namespace {
// Checks if the first operation in a block is of type `InnerOpT`, optionally
// followed only by a terminator.
template <typename InnerOpT>
bool HasSingleOpInBlock(Block *block) {
  if (block->empty()) return false;
  if (!llvm::isa<InnerOpT>(block->front())) return false;
  // Either InnerOpT is the only operation in the block, or it is followed
  // only by a terminator.
  return std::next(block->begin()) == block->end() ||
         std::next(block->begin(), 2) == block->end();
}

// This pattern matches GraphOps whose body contains only a FetchOp (i.e. an
// empty graph) and remaps the results of the GraphOp to the operands of the
// FetchOp.
struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
  using OpRewritePattern<GraphOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GraphOp op,
                                PatternRewriter &rewriter) const override {
    Block &block = op.GetBody();
    // Check if graph only has one fetch.
    if (&block.front() != &block.back()) return failure();

    // Map graph results to fetch operands.
    rewriter.replaceOp(op, op.GetFetch().fetches());

    return success();
  }
};
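
// For illustration (a hedged sketch, names arbitrary), this rewrites
//   %res = tf_executor.graph {
//     tf_executor.fetch %v : tensor<f32>
//   }
// into direct uses of %v, dropping the graph entirely.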

// This pattern matches GraphOps with only one island, pulls out all inner ops
// of the island to the block containing the GraphOp, and then removes the
// GraphOp.
struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
  using OpRewritePattern<GraphOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GraphOp op,
                                PatternRewriter &rewriter) const override {
    Block &block = op.GetBody();
    // Check if graph only has one island.
    if (!HasSingleOpInBlock<IslandOp>(&block)) return failure();

    FetchOp fetch_op = op.GetFetch();
    auto island_op = llvm::cast<IslandOp>(block.front());
    YieldOp yield_op = island_op.GetYield();

    // Map graph results to inner ops results of single island.
    llvm::SmallVector<Value, 8> new_rets;
    for (Value operand : fetch_op.fetches()) {
      // Control results should not be propagated out.
      if (operand.getType().isa<ControlType>()) break;

      if (operand.getDefiningOp() != island_op) {
        // Operand is not from island, simply propagate it out.
        new_rets.push_back(operand);
      } else {
        // Lookup yield operand in island for inner op result.
        auto result = operand.cast<OpResult>();
        new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
      }
    }

    // Move inner ops from island to block containing graph.
    auto &island_body = island_op.GetBody().getOperations();
    Operation *operation = op.getOperation();
    operation->getBlock()->getOperations().splice(
        operation->getIterator(), island_body, island_body.begin(),
        std::prev(island_body.end()));
    rewriter.replaceOp(op, new_rets);

    return success();
  }
};
}  // anonymous namespace

void GraphOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<DropEmptyGraph, HoistInnerOpsSingleIslandGraph>(context);
}

//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and removes IslandOps with no inner ops, no control
// operands and no data results. Control result users will have their relevant
// operands removed.
struct DropEmptyIslandNoOperandNoDataResult
    : public OpRewritePattern<IslandOp> {
  using OpRewritePattern<IslandOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IslandOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0 || op.getNumResults() != 1 ||
        !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
      return failure();

    for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
      use.getOwner()->eraseOperand(use.getOperandNumber());

    rewriter.eraseOp(op);

    return success();
  }
};

// This pattern matches and removes IslandOps with no inner ops, no control
// operands, one data result and no control result users. The single data
// result (the YieldOp's first operand) is forwarded to the IslandOp's data
// result users.
struct DropEmptyIslandNoOperandOneDataResult
    : public OpRewritePattern<IslandOp> {
  using OpRewritePattern<IslandOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IslandOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
        !op.control().use_empty() ||
        !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
      return failure();

    rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr});

    return success();
  }
};
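
// For illustration (a hedged sketch, names arbitrary), this rewrites
//   %out, %ctl = tf_executor.island {
//     tf_executor.yield %v : tensor<f32>
//   }
// (with %ctl unused) into direct uses of %v.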

// TODO(lyandy): Add canonicalization for empty IslandOps with more than one
// control operand and no data results.

}  // anonymous namespace

void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<DropEmptyIslandNoOperandNoDataResult,
                 DropEmptyIslandNoOperandOneDataResult>(context);
}

//===----------------------------------------------------------------------===//
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and removes ControlTriggerOps with no control operands.
// Control result users will have their relevant operands removed.
struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
  using OpRewritePattern<ControlTriggerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ControlTriggerOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0) return failure();

    for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
      use.getOwner()->eraseOperand(use.getOperandNumber());

    rewriter.eraseOp(op);

    return success();
  }
};
}  // anonymous namespace

void ControlTriggerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<DropEmptyControlTrigger>(context);
}

//===----------------------------------------------------------------------===//
// Folders
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//

LogicalResult IslandOp::fold(llvm::ArrayRef<Attribute> operands,
                             llvm::SmallVectorImpl<OpFoldResult> &results) {
  // This folds IslandOps with no inner ops, one control operand and no data
  // results. The single control operand is forwarded to the IslandOp control
  // result users.
  if (getNumOperands() != 1 || getNumResults() != 1 ||
      !HasSingleOpInBlock<YieldOp>(&GetBody()))
    return failure();

  results.emplace_back(getOperand(0));

  return success();
}
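
// For illustration (a hedged sketch, names arbitrary), this folds
//   %ctl_out = tf_executor.island(%ctl_in) {
//     tf_executor.yield
//   }
// so that users of %ctl_out use %ctl_in directly.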

}  // namespace tf_executor
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"