/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace tf_device {

//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation* call, Operation* callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  // Returns whether it is legal to inline the 'src' region into the 'dest'
  // region attached to a TF Device operation.
  bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }

  // Defines the legality of inlining TF Device operations.
  bool isLegalToInline(Operation*, Region*, bool,
                       BlockAndValueMapping&) const final {
    // For now, enable inlining all operations.
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect and a callable region. This method should generate an
  // operation that takes 'input' as the only operand and produces a single
  // result of 'resultType'. If a conversion cannot be generated, nullptr
  // should be returned.
  // This currently reuses the same logic as the TensorFlow dialect.
  Operation* materializeCallConversion(OpBuilder& builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
      return nullptr;
    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                      /*truncate=*/builder.getBoolAttr(false));
  }
};

// Checks if a block wraps a single operation and the single operation results
// are perfectly forwarded to the block's terminator.
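// For example (illustrative IR sketch; op and value names are hypothetical), a
// block of the form
//
//   %a = "tf.OpA"(%arg0) : (tensor<i1>) -> tensor<i1>
//   tf_device.return %a : tensor<i1>
//
// wraps a single op, while a block whose terminator drops or reorders the
// wrapped op's results does not.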
bool BlockWrapsSingleOp(Block* block) {
  auto body = block->without_terminator();
  if (!hasSingleElement(body)) return false;

  Operation& wrapped_op = *body.begin();
  Operation* terminator = block->getTerminator();
  return wrapped_op.getNumResults() == terminator->getNumOperands() &&
         std::equal(wrapped_op.getResults().begin(),
                    wrapped_op.getResults().end(),
                    terminator->getOperands().begin());
}
}  // end anonymous namespace

TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf_device", context,
              TypeID::get<TensorFlowDeviceDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
      >();

  addInterfaces<TFInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

// Checks if a tf_device.launch wraps a single operation and the single
// operation results are perfectly forwarded to the launch return.
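// For example (illustrative sketch in generic form; the device string is
// hypothetical):
//
//   %0 = "tf_device.launch"() ({
//     %a = "tf.OpA"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %a : tensor<i1>
//   }) {device = "/device:TPU:0"} : () -> tensor<i1>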
bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }

//===----------------------------------------------------------------------===//
// tf_device.parallel_execute
//===----------------------------------------------------------------------===//
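//
// The results of a parallel_execute op are the concatenation of the results
// returned by each of its regions. Illustrative sketch in generic form (op
// names are hypothetical):
//
//   %0:2 = "tf_device.parallel_execute"() ({
//     %1 = "tf.OpA"() : () -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }, {
//     %2 = "tf.OpB"() : () -> tensor<f32>
//     tf_device.return %2 : tensor<f32>
//   }) : () -> (tensor<i32>, tensor<f32>)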

namespace {

LogicalResult Verify(ParallelExecuteOp op) {
  const auto& regions = op.getOperation()->getRegions();
  if (regions.size() < 2) {
    return op.emitOpError() << "must have at least two regions.";
  }

  int output_index = 0;
  for (auto& region_and_index : llvm::enumerate(regions)) {
    auto& region = region_and_index.value();
    auto* region_terminator = region.front().getTerminator();

    // Check that output types of regions match return operand types.
    for (auto result_type : region_terminator->getOperandTypes()) {
      if (result_type !=
          op.getOperation()->getResult(output_index++).getType()) {
        return op.emitOpError() << "output types must be a concatenated "
                                << "list of output types for each region.";
      }
    }
  }

  // Check that the total number of outputs from the regions matches the
  // number of output types of the parallel_execute op.
  const int num_output_types = op.getOperation()->getNumResults();
  if (num_output_types != output_index) {
    return op.emitOpError()
           << "number of output types (" << num_output_types << ") "
           << "must match the total number of outputs from all "
           << "regions (" << output_index << ").";
  }

  return success();
}

}  // namespace

// static
void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state,
                              int num_regions, TypeRange output_types) {
  DCHECK_GE(num_regions, 2);
  for (int i = 0; i < num_regions; ++i) {
    Region* region = state.addRegion();
    region->push_back(new Block);
  }
  state.addTypes(output_types);
}

Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
  return getOperation()->getRegion(index).front();
}

Operation::result_range ParallelExecuteOp::GetRegionOutputs(
    unsigned region_index) {
  int num_region_results =
      GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands();

  int return_value_offset = 0;
  for (int region_id = 0; region_id < region_index; ++region_id)
    return_value_offset +=
        GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands();

  return getResults().slice(return_value_offset, num_region_results);
}

bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) {
  return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index));
}

//===----------------------------------------------------------------------===//
// tf_device.replicate
//===----------------------------------------------------------------------===//
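//
// Illustrative custom textual form, matching the parser/printer below (operand
// and block-argument names are hypothetical): each replicated input is a list
// of `n` operands, each packed input a single operand, and each terminator
// operand yields `n` op results.
//
//   tf_device.replicate([%a, %b] as %ri: tensor<i32>, %c as %pi: tensor<f32>)
//       {n = 2 : i32} {
//     ...
//     tf_device.return ...
//   }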

namespace {
ParseResult ParseReplicateOpOperands(
    OpAsmParser* parser, OperationState* state,
    llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>*
        replicated_inputs,
    llvm::SmallVectorImpl<OpAsmParser::OperandType>* packed_inputs,
    llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args,
    llvm::SmallVectorImpl<Type>* region_arg_types) {
  // No operands or empty operand list.
  bool parsed_l_paren = succeeded(parser->parseOptionalLParen());
  if (!parsed_l_paren || succeeded(parser->parseOptionalRParen()))
    return success();

  // Parse comma separated operands of the following format:
  //   replicated_input
  //     [%a, ...] as %block_arg0: type
  //   packed_input
  //     %b as %block_arg1: type
  //
  // Replicated inputs are placed before packed inputs when forming the op.
  llvm::SmallVector<OpAsmParser::OperandType, 8> replicated_region_args;
  llvm::SmallVector<OpAsmParser::OperandType, 8> packed_region_args;
  llvm::SmallVector<Type, 8> replicated_region_arg_types;
  llvm::SmallVector<Type, 8> packed_region_arg_types;
  do {
    OpAsmParser::OperandType operand_type;
    if (parser->parseOptionalOperand(operand_type).hasValue()) {
      packed_inputs->emplace_back(operand_type);
      if (parser->parseKeyword("as",
                               " between packed input and block argument") ||
          parser->parseRegionArgument(packed_region_args.emplace_back()) ||
          parser->parseColonType(packed_region_arg_types.emplace_back()))
        return failure();
    } else if (parser->parseOperandList(replicated_inputs->emplace_back(),
                                        OpAsmParser::Delimiter::Square) ||
               parser->parseKeyword(
                   "as", " between replicated inputs and block argument") ||
               parser->parseRegionArgument(
                   replicated_region_args.emplace_back()) ||
               parser->parseColonType(
                   replicated_region_arg_types.emplace_back())) {
      return failure();
    }
  } while (succeeded(parser->parseOptionalComma()));

  region_args->reserve(replicated_region_args.size() +
                       packed_region_args.size());
  region_args->append(replicated_region_args.begin(),
                      replicated_region_args.end());
  region_args->append(packed_region_args.begin(), packed_region_args.end());

  region_arg_types->reserve(replicated_region_arg_types.size() +
                            packed_region_arg_types.size());
  region_arg_types->append(replicated_region_arg_types.begin(),
                           replicated_region_arg_types.end());
  region_arg_types->append(packed_region_arg_types.begin(),
                           packed_region_arg_types.end());

  // Parse remaining `)` surrounding operands.
  return parser->parseRParen();
}

ParseResult SetReplicateOpOperands(
    llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
    llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>>
        replicated_inputs,
    llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,
    llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
  for (const auto& attr : state->attributes)
    if (attr.first.strref() == "n")
      if (auto n_attr = attr.second.dyn_cast<IntegerAttr>())
        *n = n_attr.getInt();

  if (*n < 2)
    return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;

  if (replicated_inputs.empty() && packed_inputs.empty()) return success();

  for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
    const int32_t idx = replicated_input_and_idx.index();
    const auto& replicated_input = replicated_input_and_idx.value();
    // Check if replicated input matches `n`.
    if (replicated_input.size() != *n)
      return parser->emitError(loc)
             << "expects number of operands for replicated input " << idx
             << " to be 'n' (" << *n << "), got " << replicated_input.size();

    // Resolve replicated input and block argument type.
    if (parser->resolveOperands(replicated_input, region_arg_types[idx],
                                state->operands))
      return failure();
  }

  const int32_t num_replicated_block_args = replicated_inputs.size();
  for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
    const int32_t idx = packed_input_and_idx.index();
    const auto& packed_input = packed_input_and_idx.value();

    // Resolve packed input and block argument type.
    if (parser->resolveOperand(
            packed_input, region_arg_types[idx + num_replicated_block_args],
            state->operands))
      return failure();
  }

  return success();
}

constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";

ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
  llvm::SMLoc loc = parser->getCurrentLocation();

  // Parse operands, attributes, and region of op.
  llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8>
      replicated_inputs;
  llvm::SmallVector<OpAsmParser::OperandType, 8> packed_inputs;
  llvm::SmallVector<OpAsmParser::OperandType, 8> region_args;
  llvm::SmallVector<Type, 8> region_arg_types;
  int32_t n = 0;
  Region& body = *state->addRegion();
  if (ParseReplicateOpOperands(parser, state, &replicated_inputs,
                               &packed_inputs, &region_args,
                               &region_arg_types) ||
      parser->parseOptionalAttrDict(state->attributes) ||
      SetReplicateOpOperands(loc, parser, state, replicated_inputs,
                             packed_inputs, region_arg_types, &n) ||
      parser->parseRegion(body, region_args, region_arg_types))
    return failure();

  // Add derived `operand_segment_sizes` attribute based on parsed operands.
  if (!state->attributes.get(kOperandSegmentSizesAttr)) {
    int32_t num_replicated_inputs = replicated_inputs.size() * n;
    int32_t num_packed_inputs = packed_inputs.size();
    auto attr = DenseIntElementsAttr::get(
        VectorType::get({2}, parser->getBuilder().getI32Type()),
        {num_replicated_inputs, num_packed_inputs});
    state->addAttribute(kOperandSegmentSizesAttr, attr);
  }

  // Ensure that the region is well formed: it contains at least a block with
  // a ReturnOp terminator.
  ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);

  if (!llvm::hasSingleElement(body))
    return parser->emitError(loc) << "expects a single block region";

  Operation& terminator = body.front().back();
  if (!isa<ReturnOp>(terminator))
    return parser->emitError(loc) << "expects a tf_device.return terminator";

  // Derive the result types from the terminator operand types inside the
  // replicate, each replicated `n` times.
  state->types.reserve(terminator.getNumOperands() * n);
  for (const auto& type : terminator.getOperandTypes())
    state->types.append(n, type);

  return success();
}

void Print(ReplicateOp op, OpAsmPrinter* p) {
  *p << op.getOperationName();

  // Print comma separated operands of the following format:
  //   replicated_input
  //     [%a, ...] as %block_arg0: type
  //   packed_input
  //     %b as %block_arg1: type
  const int32_t n = op.n();
  const int32_t num_replicated_inputs =
      (*op.operand_segment_sizes().int_value_begin()).getSExtValue();
  const int32_t num_replicated_block_args = num_replicated_inputs / n;

  if (op.getNumOperands()) {
    *p << '(';
    Block& block = op.body().front();
    interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
      const int block_arg_num = arg.getArgNumber();
      if (block_arg_num < num_replicated_block_args) {
        *p << '[';
        p->printOperands(
            std::next(op.replicated_inputs().begin(), block_arg_num * n),
            std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n));
        *p << "]";
      } else {
        p->printOperand(*std::next(op.packed_inputs().begin(),
                                   block_arg_num - num_replicated_block_args));
      }
      *p << " as " << arg << ": " << arg.getType();
    });
    *p << ')';
  }

  // Skip the derived `operand_segment_sizes` attribute, as the custom print
  // format of the operands holds enough information to compute these variadic
  // operand list lengths.
  p->printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
                               kOperandSegmentSizesAttr});
  p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
}

// Checks if two types are compatible (compatible shapes and same elemental
// type).
LogicalResult VerifyCompatibleTypes(Type a, Type b) {
  if (failed(verifyCompatibleShape(a, b)) ||
      getElementTypeOrSelf(a) != getElementTypeOrSelf(b))
    return failure();

  return success();
}

LogicalResult Verify(ReplicateOp op) {
  int32_t n = op.n();

  // Check number of devices, if set, matches `n`.
  if (op.devices().hasValue()) {
    for (auto device_attr : op.devices().getValue().getValue()) {
      auto device_list = device_attr.second.dyn_cast_or_null<ArrayAttr>();
      if (!device_list)
        return op.emitError()
               << "expects 'devices' to be a map from alias to list of "
                  "device names.";

      bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
        return attr.dyn_cast_or_null<StringAttr>();
      });
      if (!is_device_string)
        return op.emitOpError()
               << "expects 'devices' to consist of string lists as values.";

      if (device_list.size() != n)
        return op.emitOpError()
               << "expects number of devices (" << device_list.size()
               << ") to be equal to 'n' (" << n << ")";
    }
  }

  Block& block = op.body().front();

  auto operand_segment_sizes = op.operand_segment_sizes();
  const int32_t num_replicated_inputs =
      operand_segment_sizes.getValue<IntegerAttr>({0}).getInt();
  const int32_t num_packed_inputs =
      operand_segment_sizes.getValue<IntegerAttr>({1}).getInt();

  if (num_replicated_inputs % n != 0)
    return op.emitOpError()
           << "expects number of replicated inputs (" << num_replicated_inputs
           << ") to be evenly divisible by 'n' (" << n << ")";

  const int32_t num_replicated_block_args = num_replicated_inputs / n;
  if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
    return op.emitOpError()
           << "expects number of block arguments (" << block.getNumArguments()
           << ") to be equal to number of replicated inputs ("
           << num_replicated_inputs << ") / 'n' (" << n
           << ") + number of packed inputs (" << num_packed_inputs << ")";

  // Check input types match block argument types.
  auto verify_operand_types = [&](BlockArgument block_arg,
                                  int32_t op_operand_idx) -> LogicalResult {
    Type op_operand_type = op.getOperand(op_operand_idx).getType();
    if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
      return op.emitOpError()
             << "expects operand " << op_operand_idx << " (" << op_operand_type
             << ") and block argument " << block_arg.getArgNumber() << " ("
             << block_arg.getType() << ") to have compatible types";

    return success();
  };
  for (auto block_arg : block.getArguments()) {
    if (block_arg.getArgNumber() < num_replicated_block_args) {
      for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
        if (failed(verify_operand_types(block_arg, i))) return failure();
    } else {
      const int32_t idx = block_arg.getArgNumber() -
                          num_replicated_block_args + num_replicated_inputs;
      if (failed(verify_operand_types(block_arg, idx))) return failure();
    }
  }

  Operation& terminator = block.back();

  // Check number of results matches `n` * number of return operands.
  if (op.getNumResults() != n * terminator.getNumOperands())
    return op.emitOpError()
           << "expects number of results (" << op.getNumResults()
           << ") to be equal to 'n' * number of terminator operands (" << n
           << " * " << terminator.getNumOperands() << ")";

  // Check replicated output types match return operand types.
  for (auto operand_type_and_idx :
       llvm::enumerate(terminator.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int32_t operand_idx = operand_type_and_idx.index();
    for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
      if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
        return op.emitOpError() << "incompatible types for result " << i
                                << " and terminator operand " << operand_idx;
  }

  return success();
}

void BuildReplicateOp(
    Builder* builder, OperationState* state, int n,
    llvm::Optional<DictionaryAttr> devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  DCHECK_GE(n, 2);
  state->addAttribute("n", builder->getI32IntegerAttr(n));

  if (devices.hasValue()) state->addAttribute("devices", devices.getValue());

  Region* region = state->addRegion();
  region->push_back(new Block);
  Block& block = region->front();

  for (auto& replicated_input : replicated_inputs) {
    DCHECK_EQ(llvm::size(replicated_input.first), n);
    for (auto input : replicated_input.first) {
      DCHECK(succeeded(
          VerifyCompatibleTypes(input.getType(), replicated_input.second)));
      state->addOperands(input);
    }
    block.addArgument(replicated_input.second);
  }

  for (auto packed_input : packed_inputs) {
    state->addOperands(packed_input);
    block.addArgument(packed_input.getType());
  }

  // Add derived `operand_segment_sizes` attribute.
  int32_t num_replicated_inputs = replicated_inputs.size() * n;
  int32_t num_packed_inputs = packed_inputs.size();
  auto operand_segment_sizes =
      DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()),
                                {num_replicated_inputs, num_packed_inputs});
  state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);

  for (const auto& output_type : replica_output_types)
    state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
}
}  // anonymous namespace

void ReplicateOp::build(
    OpBuilder& builder, OperationState& state, int n,
    const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
        devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  llvm::Optional<DictionaryAttr> devices_attr;
  if (!devices.empty()) {
    llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
    device_list.reserve(devices.size());
    for (auto alias_and_devices : devices) {
      NamedAttribute device_name_attr = builder.getNamedAttr(
          alias_and_devices.getFirst(),
          builder.getStrArrayAttr(alias_and_devices.getSecond()));
      device_list.emplace_back(device_name_attr);
    }
    devices_attr.emplace(builder.getDictionaryAttr(device_list));
  }

  BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
                   packed_inputs, replica_output_types);
}

void ReplicateOp::build(
    OpBuilder& builder, OperationState& state, int n,
    llvm::Optional<DictionaryAttr> devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
                   packed_inputs, replica_output_types);
}

// Returns the number of packed block arguments.
unsigned ReplicateOp::GetNumPackedBlockArguments() {
  return packed_inputs().size();
}

// Returns the number of replicated block arguments.
unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
  return GetBody().getNumArguments() - GetNumPackedBlockArguments();
}

// Returns the replicated block arguments. A copy should be made if the
// replicate op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
  return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
}

// Returns the packed block arguments. A copy should be made if the replicate
// op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
  return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
}

// Checks if a block argument is replicated (forwarding replicated inputs).
bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
  assert(block_arg.getOwner() == &GetBody());
  return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
}

// Checks if a block argument is packed (forwarding a packed input).
bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
  return !IsReplicatedBlockArgument(block_arg);
}

// Returns the operand index of the operand being forwarded as a
// replicated/packed block argument for a given replica. This assumes a valid
// block argument (of the replicate op) and a valid replica is provided.
unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
    BlockArgument block_arg, unsigned replica) {
  MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
  if (operands.size() == 1) return operands.front().getOperandNumber();

  return operands[replica].getOperandNumber();
}

// Returns the operand being forwarded as a replicated/packed block argument
// for a given replica. This assumes a valid block argument (of the replicate
// op) and a valid replica is provided.
Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
                                                     unsigned replica) {
  MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
  if (operands.size() == 1) return operands.front().get();

  return operands[replica].get();
}

// Returns the list of replica op operands that maps to the given block
// argument: a list with num_replicas elements for replicated operands and a
// list with a single element for packed operands.
//
// Requires that the block argument is of this replicate op.
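//
// For example (illustrative; with n = 2, two replicated block arguments, and
// one packed block argument), the operands are laid out as
// [r0_0, r0_1, r1_0, r1_1, p0]: block argument 1 maps to [r1_0, r1_1] and
// block argument 2 maps to [p0].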
MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
    BlockArgument block_arg) {
  assert(block_arg.getOwner() == &GetBody());

  unsigned arg_number = block_arg.getArgNumber();
  unsigned num_replicated_args = GetNumReplicatedBlockArguments();
  int32_t num_replicas = nAttr().getInt();
  MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();

  // All replicated arguments are before packed arguments so return replicated
  // operands if the given argument is one of the replicated arguments.
  if (arg_number < num_replicated_args)
    return operands.slice(arg_number * num_replicas, num_replicas);

  operands = operands.drop_front(num_replicated_args * num_replicas);
  arg_number -= num_replicated_args;
  return operands.slice(arg_number, 1);
}

// Checks if a tf_device.replicate wraps a single operation and the single
// operation results are perfectly forwarded to the replicate return.
bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }

//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_device.cluster
//===----------------------------------------------------------------------===//

namespace {

// Eliminates cluster op results that are not defined within the cluster (i.e.,
// values defined outside and merely passed through); the cluster op is
// rewritten to drop those results.
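//
// For example (illustrative sketch in generic form, with %out defined above
// the cluster):
//
//   %0 = "tf_device.cluster"() ({
//     tf_device.return %out : tensor<!tf.string>
//   }) : () -> tensor<!tf.string>
//
// uses of %0 can be replaced with %out and the result dropped. Currently this
// is only applied to unsupported string types; see the TODO below.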
static LogicalResult EliminatePassThroughResults(ClusterOp op,
                                                 PatternRewriter& rewriter) {
  mlir::Block& body = op.GetBody();
  Operation* return_op = body.getTerminator();
  int num_results = return_op->getNumOperands();

  // Values defined within the cluster.
  llvm::SmallVector<Value, 4> cluster_vals;
  cluster_vals.reserve(num_results);

  // `new_results` stores the values to use when replacing the old cluster op.
  llvm::SmallVector<Value, 4> new_results;
  new_results.reserve(num_results);
  for (OpOperand& operand : return_op->getOpOperands()) {
    // If the corresponding result of the cluster op is used in some resource
    // update op, do not eliminate the result. Such assignment ops could be for
    // device resources and are required during fusing of the execute op and
    // the resource update ops.
    bool is_used_for_resource_write = llvm::any_of(
        op.getResult(operand.getOperandNumber()).getUsers(),
        [](Operation* user) { return isa<TF::AssignVariableOp>(user); });

    // TODO(b/186717563): Eliminate all pass through results once XLA correctly
    // handles empty computations. Another approach could be to drop empty
    // clusters within MLIR, but that seems to trigger other failures; it can
    // be considered again. The old bridge only removes unsupported TPU types
    // (only string for now) during outside compilation extraction, so this
    // should be enough for parity.
    bool is_unsupported_type = getElementTypeOrSelf(operand.get().getType())
                                   .isa<mlir::TF::StringType>();
    Value result = operand.get();
    if (is_unsupported_type && result.getParentBlock() != &body &&
        !is_used_for_resource_write) {
      // Pass through result.
      new_results.push_back(result);
    } else {
      // This result will be populated with the new result after rewriting the
      // cluster op.
      new_results.push_back(nullptr);
      cluster_vals.push_back(result);
    }
  }

  // Return failure if there are no pass through results and op is already
  // canonical.
  if (cluster_vals.size() == num_results) return failure();

  // Rewrite return op in the cluster.
  rewriter.setInsertionPoint(return_op);
  auto new_return =
      rewriter.replaceOpWithNewOp<tf_device::ReturnOp>(return_op, cluster_vals);

  // Rewrite the cluster op.
  rewriter.setInsertionPoint(op);
  auto new_op = rewriter.create<tf_device::ClusterOp>(
      op->getLoc(), new_return.getOperandTypes(), op->getOperands(),
      op->getAttrs());
  rewriter.inlineRegionBefore(op.getBodyRegion(), new_op.getBodyRegion(),
                              new_op.getBodyRegion().end());

  int idx = 0;
  for (Value& result : new_results) {
    if (result == nullptr) result = new_op.getResult(idx++);
  }
  rewriter.replaceOp(op, new_results);
  return success();
}
}  // anonymous namespace

void ClusterOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
  results.insert(EliminatePassThroughResults);
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches LaunchOps whose body contains only a ReturnOp (i.e., an
// empty launch) and remaps the results of the LaunchOp to the operands of the
// ReturnOp.
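//
// For example (illustrative sketch; the device string is hypothetical), the
// launch below folds away, with uses of %0 replaced by %arg0:
//
//   %0 = "tf_device.launch"() ({
//     tf_device.return %arg0 : tensor<i32>
//   }) {device = "/device:CPU:0"} : () -> tensor<i32>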
struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter& rewriter) const override {
    Block& block = op.GetBody();
    // Check if launch only has a return.
    if (&block.front() != &block.back()) return failure();

    // Map launch results to return operands.
    rewriter.replaceOp(op, block.front().getOperands());

    return success();
  }
};
}  // anonymous namespace

void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<DropEmptyLaunch>(context);
}

}  // namespace tf_device
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"