/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace tf_device {

//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation* call, Operation* callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  // Returns whether it is legal to inline the 'src' region into the 'dest'
  // region attached to a TF Device operation.
  bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }

  // Defines the legality of inlining TF Device operations.
  bool isLegalToInline(Operation*, Region*, bool,
                       BlockAndValueMapping&) const final {
    // For now, enable inlining all operations.
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect and a callable region. This method should generate an
  // operation that takes 'input' as the only operand and produces a single
  // result of 'resultType'. If a conversion cannot be generated, nullptr
  // should be returned.
  // This currently reuses the same logic as the TensorFlow dialect.
  Operation* materializeCallConversion(OpBuilder& builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
      return nullptr;
    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                      /*truncate=*/builder.getBoolAttr(false));
  }
};

// Checks if a block wraps a single operation and the single operation results
// are perfectly forwarded to the block's terminator.
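// For example, the following block (a sketch; the op name is illustrative)
// wraps a single operation:
//
//   %0 = "tf.SomeOp"(%arg0) : (tensor<i32>) -> tensor<i32>
//   tf_device.return %0 : tensor<i32>
//
// whereas a block whose terminator reorders, drops, or duplicates the wrapped
// operation's results does not.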
bool BlockWrapsSingleOp(Block* block) {
  auto body = block->without_terminator();
  if (!llvm::hasSingleElement(body)) return false;

  Operation& wrapped_op = *body.begin();
  Operation* terminator = block->getTerminator();
  return wrapped_op.getNumResults() == terminator->getNumOperands() &&
         std::equal(wrapped_op.getResults().begin(),
                    wrapped_op.getResults().end(),
                    terminator->getOperands().begin());
}
}  // end anonymous namespace

TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf_device", context,
              TypeID::get<TensorFlowDeviceDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
      >();

  addInterfaces<TFInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

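// An illustrative tf_device.launch wrapping a single operation, in generic
// form (a sketch; the op name and device string are assumptions):
//
//   %0 = "tf_device.launch"() ({
//     %1 = "tf.OpA"(%arg0) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
//     : () -> tensor<i32>
//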
// Checks if a tf_device.launch wraps a single operation and the single
// operation results are perfectly forwarded to the launch return.
bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }

//===----------------------------------------------------------------------===//
// tf_device.parallel_execute
//===----------------------------------------------------------------------===//

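// An illustrative tf_device.parallel_execute with two regions, in generic
// form (a sketch; the op names are assumptions). The op results are the
// concatenation of the per-region return operands, which is exactly what the
// verifier below checks:
//
//   %0:2 = "tf_device.parallel_execute"() ({
//     %1 = "tf.OpA"() : () -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }, {
//     %2 = "tf.OpB"() : () -> tensor<f32>
//     tf_device.return %2 : tensor<f32>
//   }) : () -> (tensor<i32>, tensor<f32>)
//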
LogicalResult ParallelExecuteOp::verify() {
  ParallelExecuteOp op = *this;
  const auto& regions = op.getOperation()->getRegions();
  if (regions.empty()) {
    return op.emitOpError() << "must have at least one region.";
  }

  int output_index = 0;
  for (auto& region_and_index : llvm::enumerate(regions)) {
    auto& region = region_and_index.value();
    auto* region_terminator = region.front().getTerminator();

    // Check that output types of regions match return operand types.
    for (auto result_type : region_terminator->getOperandTypes()) {
      if (result_type !=
          op.getOperation()->getResult(output_index++).getType()) {
        return op.emitOpError() << "output types must be a concatenated "
                                << "list of output types for each region.";
      }
    }
  }

  // Check that the total number of outputs from regions matches the output
  // types of the parallel_execute op.
  const int num_output_types = op.getOperation()->getNumResults();
  if (num_output_types != output_index) {
    return op.emitOpError()
           << "number of output types (" << num_output_types << ") "
           << "must match the total number of outputs from all "
           << "regions (" << output_index << ").";
  }

  return success();
}

// static
void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state,
                              int num_regions, TypeRange output_types) {
  DCHECK_GE(num_regions, 1);
  for (int i = 0; i < num_regions; ++i) {
    Region* region = state.addRegion();
    region->push_back(new Block);
  }
  state.addTypes(output_types);
}

Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
  return getOperation()->getRegion(index).front();
}

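// Returns the results of this parallel_execute produced by the region at
// `region_index`, computed by skipping past the results contributed by all
// preceding regions.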
Operation::result_range ParallelExecuteOp::GetRegionOutputs(
    unsigned region_index) {
  int num_region_results =
      GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands();

  int return_value_offset = 0;
  for (int region_id = 0; region_id < region_index; ++region_id)
    return_value_offset +=
        GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands();

  return getResults().slice(return_value_offset, num_region_results);
}

bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) {
  return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index));
}

//===----------------------------------------------------------------------===//
// tf_device.replicate
//===----------------------------------------------------------------------===//

namespace {
ParseResult ParseReplicateOpOperands(
    OpAsmParser* parser, OperationState* state,
    llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8>>*
        replicated_inputs,
    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand>* packed_inputs,
    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand>* region_args,
    llvm::SmallVectorImpl<Type>* region_arg_types) {
  // No operands or empty operand list.
  bool parsed_l_paren = succeeded(parser->parseOptionalLParen());
  if (!parsed_l_paren || succeeded(parser->parseOptionalRParen()))
    return success();

  // Parse comma separated operands of the following format:
  //   replicated_input
  //     [%a, ...] as %block_arg0: type
  //   packed_input
  //     %b as %block_arg1: type
  //
  // Replicated inputs are placed before packed inputs when forming the op.
  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> replicated_region_args;
  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> packed_region_args;
  llvm::SmallVector<Type, 8> replicated_region_arg_types;
  llvm::SmallVector<Type, 8> packed_region_arg_types;
  do {
    OpAsmParser::UnresolvedOperand operand_type;
    if (parser->parseOptionalOperand(operand_type).hasValue()) {
      packed_inputs->emplace_back(operand_type);
      if (parser->parseKeyword("as",
                               " between packed input and block argument") ||
          parser->parseOperand(packed_region_args.emplace_back(),
                               /*allowResultNumber=*/false) ||
          parser->parseColonType(packed_region_arg_types.emplace_back()))
        return failure();
    } else if (parser->parseOperandList(replicated_inputs->emplace_back(),
                                        OpAsmParser::Delimiter::Square) ||
               parser->parseKeyword(
                   "as", " between replicated inputs and block argument") ||
               parser->parseOperand(replicated_region_args.emplace_back(),
                                    /*allowResultNumber=*/false) ||
               parser->parseColonType(
                   replicated_region_arg_types.emplace_back())) {
      return failure();
    }
  } while (succeeded(parser->parseOptionalComma()));

  region_args->reserve(replicated_region_args.size() +
                       packed_region_args.size());
  region_args->append(replicated_region_args.begin(),
                      replicated_region_args.end());
  region_args->append(packed_region_args.begin(), packed_region_args.end());

  region_arg_types->reserve(replicated_region_arg_types.size() +
                            packed_region_arg_types.size());
  region_arg_types->append(replicated_region_arg_types.begin(),
                           replicated_region_arg_types.end());
  region_arg_types->append(packed_region_arg_types.begin(),
                           packed_region_arg_types.end());

  // Parse the remaining `)` that closes the operand list.
  return parser->parseRParen();
}

ParseResult SetReplicateOpOperands(
    llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
    llvm::ArrayRef<llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8>>
        replicated_inputs,
    llvm::ArrayRef<OpAsmParser::UnresolvedOperand> packed_inputs,
    llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
  for (const auto& attr : state->attributes)
    if (attr.getName().strref() == "n")
      if (auto n_attr = attr.getValue().dyn_cast<IntegerAttr>())
        *n = n_attr.getInt();

  if (*n < 2)
    return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;

  if (replicated_inputs.empty() && packed_inputs.empty()) return success();

  for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
    const int32_t idx = replicated_input_and_idx.index();
    const auto& replicated_input = replicated_input_and_idx.value();
    // Check that the number of operands for this replicated input matches `n`.
    if (replicated_input.size() != *n)
      return parser->emitError(loc)
             << "expects number of operands for replicated input " << idx
             << " to be 'n' (" << *n << "), got " << replicated_input.size();

    // Resolve replicated input and block argument type.
    if (parser->resolveOperands(replicated_input, region_arg_types[idx],
                                state->operands))
      return failure();
  }

  const int32_t num_replicated_block_args = replicated_inputs.size();
  for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
    const int32_t idx = packed_input_and_idx.index();
    const auto& packed_input = packed_input_and_idx.value();

    // Resolve packed input and block argument type.
    if (parser->resolveOperand(
            packed_input, region_arg_types[idx + num_replicated_block_args],
            state->operands))
      return failure();
  }

  return success();
}

}  // namespace

static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";

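// An illustrative custom form handled by parse/print below (a sketch; the op
// name and types are assumptions). With n = 2, one replicated input pair and
// two terminator operands, the op yields 2 * 2 = 4 results:
//
//   %0:4 = tf_device.replicate([%a, %b] as %input: tensor<i32>)
//       {n = 2 : i32} {
//     %1 = "tf.OpA"(%input) : (tensor<i32>) -> tensor<i32>
//     %2 = "tf.OpB"(%input) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %1, %2 : tensor<i32>, tensor<i32>
//   }
//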
ParseResult ReplicateOp::parse(OpAsmParser& parser, OperationState& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();

  // Parse operands, attributes, and region of op.
  llvm::SmallVector<llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8>, 8>
      replicated_inputs;
  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> packed_inputs;
  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> region_args;
  llvm::SmallVector<Type, 8> region_arg_types;
  int32_t n = 0;
  Region& body = *result.addRegion();
  if (ParseReplicateOpOperands(&parser, &result, &replicated_inputs,
                               &packed_inputs, &region_args,
                               &region_arg_types) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      SetReplicateOpOperands(loc, &parser, &result, replicated_inputs,
                             packed_inputs, region_arg_types, &n))
    return failure();

  SmallVector<OpAsmParser::Argument> packed_args;
  for (auto argAndType : llvm::zip(region_args, region_arg_types)) {
    auto& arg = packed_args.emplace_back();
    arg.ssaName = std::get<0>(argAndType);
    arg.type = std::get<1>(argAndType);
  }
  if (parser.parseRegion(body, packed_args)) return failure();

  // Add derived `operand_segment_sizes` attribute based on parsed operands.
  if (!result.attributes.get(kOperandSegmentSizesAttr)) {
    int32_t num_replicated_inputs = replicated_inputs.size() * n;
    int32_t num_packed_inputs = packed_inputs.size();
    auto attr = parser.getBuilder().getDenseI32ArrayAttr(
        {num_replicated_inputs, num_packed_inputs});
    result.addAttribute(kOperandSegmentSizesAttr, attr);
  }

  // Ensure that the region is well formed: it contains at least a block with
  // a ReturnOp terminator.
  ReplicateOp::ensureTerminator(body, parser.getBuilder(), result.location);

  if (!llvm::hasSingleElement(body))
    return parser.emitError(loc) << "expects a single block region";

  Operation& terminator = body.front().back();
  if (!isa<ReturnOp>(terminator))
    return parser.emitError(loc) << "expects a tf_device.return terminator";

  // Derive the result types from the terminator operand types inside the
  // replicate, each replicated `n` times.
  result.types.reserve(terminator.getNumOperands() * n);
  for (const auto& type : terminator.getOperandTypes())
    result.types.append(n, type);

  return success();
}

void ReplicateOp::print(OpAsmPrinter& p) {
  // Print comma separated operands of the following format:
  //   replicated_input
  //     [%a, ...] as %block_arg0: type
  //   packed_input
  //     %b as %block_arg1: type
  const int32_t n = this->n();
  const int32_t num_replicated_inputs = operand_segment_sizes()[0];
  const int32_t num_replicated_block_args = num_replicated_inputs / n;

  if (getNumOperands()) {
    p << '(';
    Block& block = body().front();
    interleaveComma(block.getArguments(), p, [&](BlockArgument arg) {
      const int block_arg_num = arg.getArgNumber();
      if (block_arg_num < num_replicated_block_args) {
        p << '[';
        p.printOperands(
            std::next(replicated_inputs().begin(), block_arg_num * n),
            std::next(replicated_inputs().begin(), (block_arg_num + 1) * n));
        p << ']';
      } else {
        p.printOperand(*std::next(packed_inputs().begin(),
                                  block_arg_num - num_replicated_block_args));
      }
      p << " as " << arg << ": " << arg.getType();
    });
    p << ')';
  }

  // Skip the derived `operand_segment_sizes` attribute as the custom print
  // format of the operands holds enough information to calculate these
  // variadic operand list lengths.
  p.printOptionalAttrDict(
      getOperation()->getAttrs(),
      /*elidedAttrs=*/ArrayRef<StringRef>{kOperandSegmentSizesAttr});
  p << ' ';
  p.printRegion(body(), /*printEntryBlockArgs=*/false);
}

namespace {

// Checks if two types are compatible (compatible shapes and same elemental
// type).
LogicalResult VerifyCompatibleTypes(Type a, Type b) {
  if (failed(verifyCompatibleShape(a, b)) ||
      getElementTypeOrSelf(a) != getElementTypeOrSelf(b))
    return failure();

  return success();
}

void BuildReplicateOp(
    Builder* builder, OperationState* state, int n,
    llvm::Optional<DictionaryAttr> devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  DCHECK_GE(n, 2);
  state->addAttribute("n", builder->getI32IntegerAttr(n));

  if (devices.has_value()) state->addAttribute("devices", devices.getValue());

  Region* region = state->addRegion();
  region->push_back(new Block);
  Block& block = region->front();

  for (auto& replicated_input : replicated_inputs) {
    DCHECK_EQ(llvm::size(replicated_input.first), n);
    for (auto input : replicated_input.first) {
      DCHECK(succeeded(
          VerifyCompatibleTypes(input.getType(), replicated_input.second)));
      state->addOperands(input);
    }
    block.addArgument(replicated_input.second, state->location);
  }

  for (auto packed_input : packed_inputs) {
    state->addOperands(packed_input);
    block.addArgument(packed_input.getType(), state->location);
  }

  // Add derived `operand_segment_sizes` attribute.
  int32_t num_replicated_inputs = replicated_inputs.size() * n;
  int32_t num_packed_inputs = packed_inputs.size();
  auto operand_segment_sizes = builder->getDenseI32ArrayAttr(
      {num_replicated_inputs, num_packed_inputs});
  state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);

  for (const auto& output_type : replica_output_types)
    state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
}

}  // anonymous namespace

LogicalResult ReplicateOp::verify() {
  ReplicateOp op = *this;
  int32_t n = op.n();

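  // An illustrative `devices` attribute, mapping a device alias to one device
  // per replica (a sketch assuming n = 2; the alias and device names are
  // assumptions):
  //
  //   devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}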
  // Check that the number of devices, if set, matches `n`.
  if (op.devices().has_value()) {
    for (auto device_attr : op.devices().getValue().getValue()) {
      auto device_list = device_attr.getValue().dyn_cast_or_null<ArrayAttr>();
      if (!device_list)
        return op.emitError() << "expects 'devices' to be a map of alias to "
                                 "device name list.";

      bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
        return attr.dyn_cast_or_null<StringAttr>();
      });
      if (!is_device_string)
        return op.emitOpError() << "expects 'devices' to consist of string "
                                   "lists as values.";

      if (device_list.size() != n)
        return op.emitOpError()
               << "expects number of devices (" << device_list.size()
               << ") to be equal to 'n' (" << n << ")";
    }
  }

  Block& block = op.body().front();

  auto operand_segment_sizes = op.operand_segment_sizes();
  const int32_t num_replicated_inputs = operand_segment_sizes[0];
  const int32_t num_packed_inputs = operand_segment_sizes[1];

  if (num_replicated_inputs % n != 0)
    return op.emitOpError()
           << "expects number of replicated inputs (" << num_replicated_inputs
           << ") to be evenly divisible by 'n' (" << n << ")";

  const int32_t num_replicated_block_args = num_replicated_inputs / n;
  if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
    return op.emitOpError()
           << "expects number of block arguments (" << block.getNumArguments()
           << ") to be equal to number of replicated inputs ("
           << num_replicated_inputs << ") / 'n' (" << n
           << ") + number of packed inputs (" << num_packed_inputs << ")";

  // Check that input types match block argument types.
  auto verify_operand_types = [&](BlockArgument block_arg,
                                  int32_t op_operand_idx) -> LogicalResult {
    Type op_operand_type = op.getOperand(op_operand_idx).getType();
    if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
      return op.emitOpError()
             << "expects operand " << op_operand_idx << " (" << op_operand_type
             << ") and block argument " << block_arg.getArgNumber() << " ("
             << block_arg.getType() << ") to have compatible types";

    return success();
  };
  for (auto block_arg : block.getArguments()) {
    if (block_arg.getArgNumber() < num_replicated_block_args) {
      for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
        if (failed(verify_operand_types(block_arg, i))) return failure();
    } else {
      const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
                          num_replicated_inputs;
      if (failed(verify_operand_types(block_arg, idx))) return failure();
    }
  }

  Operation& terminator = block.back();

  // Check that the number of results matches `n` * number of return operands.
  if (op.getNumResults() != n * terminator.getNumOperands())
    return op.emitOpError()
           << "expects number of results (" << op.getNumResults()
           << ") to be equal to 'n' * number of terminator operands (" << n
           << " * " << terminator.getNumOperands() << ")";

  // Check that replicated output types match return operand types.
  for (auto operand_type_and_idx :
       llvm::enumerate(terminator.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int32_t operand_idx = operand_type_and_idx.index();
    for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
      if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
        return op.emitOpError() << "incompatible types for result " << i
                                << " and terminator operand " << operand_idx;
  }

  return success();
}

void ReplicateOp::build(
    OpBuilder& builder, OperationState& state, int n,
    const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
        devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  llvm::Optional<DictionaryAttr> devices_attr;
  if (!devices.empty()) {
    llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
    device_list.reserve(devices.size());
    for (auto alias_and_devices : devices) {
      NamedAttribute device_name_attr = builder.getNamedAttr(
          alias_and_devices.getFirst(),
          builder.getStrArrayAttr(alias_and_devices.getSecond()));
      device_list.emplace_back(device_name_attr);
    }
    devices_attr.emplace(builder.getDictionaryAttr(device_list));
  }

  BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
                   packed_inputs, replica_output_types);
}

void ReplicateOp::build(
    OpBuilder& builder, OperationState& state, int n,
    llvm::Optional<DictionaryAttr> devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
                   packed_inputs, replica_output_types);
}

// Returns the number of packed block arguments.
unsigned ReplicateOp::GetNumPackedBlockArguments() {
  return packed_inputs().size();
}

// Returns the number of replicated block arguments.
unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
  return GetBody().getNumArguments() - GetNumPackedBlockArguments();
}

// Returns the replicated block arguments. A copy should be made if the
// replicate op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
  return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
}

// Returns the packed block arguments. A copy should be made if the replicate
// op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
  return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
}

// Checks if a block argument is replicated (forwarding replicated inputs).
bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
  assert(block_arg.getOwner() == &GetBody());
  return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
}

// Checks if a block argument is packed (forwarding a packed input).
bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
  return !IsReplicatedBlockArgument(block_arg);
}

// Returns the operand index of the operand being forwarded as a
// replicated/packed block argument for a given replica. This assumes a valid
// block argument (of the replicate op) and a valid replica is provided.
unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
    BlockArgument block_arg, unsigned replica) {
  MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
  if (operands.size() == 1) return operands.front().getOperandNumber();

  return operands[replica].getOperandNumber();
}

// Returns the operand being forwarded as a replicated/packed block argument
// for a given replica. This assumes a valid block argument (of the replicate
// op) and a valid replica is provided.
Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
                                                     unsigned replica) {
  MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
  if (operands.size() == 1) return operands.front().get();

  return operands[replica].get();
}

// Returns the list of replicate op operands that map to the given block
// argument. Returns a list with `n` (num_replicas) elements for replicated
// operands and a list with a single element for packed operands.
//
// Requires that the block argument is of this replicate op.
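//
// For example (an illustrative layout, assuming n = 2, two replicated block
// arguments %a and %b, and one packed block argument %p), the op operands are
// ordered as:
//
//   [%a_replica0, %a_replica1, %b_replica0, %b_replica1, %p]
//
// so %b maps to operands [2, 4) and %p maps to operand 4.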
MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
    BlockArgument block_arg) {
  assert(block_arg.getOwner() == &GetBody());

  unsigned arg_number = block_arg.getArgNumber();
  unsigned num_replicated_args = GetNumReplicatedBlockArguments();
  int32_t num_replicas = nAttr().getInt();
  MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();

  // All replicated arguments are before packed arguments, so return the
  // replicated operands if the given argument is one of the replicated
  // arguments.
  if (arg_number < num_replicated_args)
    return operands.slice(arg_number * num_replicas, num_replicas);

  operands = operands.drop_front(num_replicated_args * num_replicas);
  arg_number -= num_replicated_args;
  return operands.slice(arg_number, 1);
}

// Checks if a tf_device.replicate wraps a single operation and the single
// operation results are perfectly forwarded to the replicate return.
bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }

//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_device.cluster
//===----------------------------------------------------------------------===//

namespace {

// Eliminates cluster op results that are not defined within the cluster (i.e.
// pass through results defined outside) by rewriting the cluster op without
// those results.
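//
// Illustrative rewrite (a sketch; this assumes %arg0 has an unsupported
// element type such as a TF string, and its corresponding result is not
// consumed by a resource write):
//
//   %0 = "tf_device.cluster"() ({
//     tf_device.return %arg0 : tensor<!tf.string>
//   }) : () -> tensor<!tf.string>
//   use(%0)
//
// becomes
//
//   "tf_device.cluster"() ({
//     tf_device.return
//   }) : () -> ()
//   use(%arg0)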
static LogicalResult EliminatePassThroughResults(ClusterOp op,
                                                 PatternRewriter& rewriter) {
  mlir::Block& body = op.GetBody();
  Operation* return_op = body.getTerminator();
  int num_results = return_op->getNumOperands();

  // Values defined within the cluster.
  llvm::SmallVector<Value, 4> cluster_vals;
  cluster_vals.reserve(num_results);

  // `new_results` stores the values to use when replacing the old cluster op.
  llvm::SmallVector<Value, 4> new_results;
  new_results.reserve(num_results);
  for (OpOperand& operand : return_op->getOpOperands()) {
    // If the corresponding result of the cluster op is used in some resource
    // update op, do not eliminate the result. Such assignment ops could be for
    // device resources and are required during fusing of the execute op and
    // the resource update ops.
    bool is_used_for_resource_write = llvm::any_of(
        op.getResult(operand.getOperandNumber()).getUsers(),
        [](Operation* user) { return isa<TF::AssignVariableOp>(user); });

    // TODO(b/186717563): Eliminate all pass through results once XLA correctly
    // handles empty computations. Another approach could be to drop empty
    // clusters within MLIR, but that seems to trigger other failures and can
    // be considered again later.
    // The old bridge only removes unsupported TPU types (only string for now)
    // during outside compilation extraction, so this should be enough for
    // parity.
    bool is_unsupported_type = getElementTypeOrSelf(operand.get().getType())
                                   .isa<mlir::TF::StringType>();
    Value result = operand.get();
    if (is_unsupported_type && result.getParentBlock() != &body &&
        !is_used_for_resource_write) {
      // Pass through result.
      new_results.push_back(result);
    } else {
      // This result will be populated with the new result after rewriting the
      // cluster op.
      new_results.push_back(nullptr);
      cluster_vals.push_back(result);
    }
  }

  // Return failure if there are no pass through results and the op is already
  // canonical.
  if (cluster_vals.size() == num_results) return failure();

  // Rewrite the return op in the cluster.
  rewriter.setInsertionPoint(return_op);
  auto new_return =
      rewriter.replaceOpWithNewOp<tf_device::ReturnOp>(return_op, cluster_vals);

  // Rewrite the cluster op.
  rewriter.setInsertionPoint(op);
  auto new_op = rewriter.create<tf_device::ClusterOp>(
      op->getLoc(), new_return.getOperandTypes(), op->getOperands(),
      op->getAttrs());
  rewriter.inlineRegionBefore(op.getBodyRegion(), new_op.getBodyRegion(),
                              new_op.getBodyRegion().end());

  int idx = 0;
  for (Value& result : new_results) {
    if (result == nullptr) result = new_op.getResult(idx++);
  }
  rewriter.replaceOp(op, new_results);
  return success();
}
}  // anonymous namespace

void ClusterOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                            MLIRContext* context) {
  results.add(EliminatePassThroughResults);
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches LaunchOps whose body contains only a ReturnOp (i.e.
// empty launches) and remaps the results of the LaunchOp to the operands of
// the ReturnOp.
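//
// Illustrative rewrite (a sketch; the device string is an assumption):
//
//   %0 = "tf_device.launch"() ({
//     tf_device.return %arg0 : tensor<i32>
//   }) {device = "/device:CPU:0"} : () -> tensor<i32>
//   use(%0)
//
// becomes
//
//   use(%arg0)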
struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter& rewriter) const override {
    Block& block = op.GetBody();
    // Check if the launch body only has a return.
    if (&block.front() != &block.back()) return failure();

    // Map launch results to return operands.
    rewriter.replaceOp(op, block.front().getOperands());

    return success();
  }
};
}  // anonymous namespace

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                           MLIRContext* context) {
  results.add<DropEmptyLaunch>(context);
}

}  // namespace tf_device
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"