• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <utility>
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/SMLoc.h"
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
36 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
37 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
38 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
39 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
40 #include "mlir/IR/Types.h"  // from @llvm-project
41 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
42 #include "mlir/IR/Value.h"  // from @llvm-project
43 #include "mlir/Support/LLVM.h"  // from @llvm-project
44 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
45 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/core/platform/logging.h"
48 
49 namespace mlir {
50 namespace tf_device {
51 
52 //===----------------------------------------------------------------------===//
53 // TF Device Dialect Interfaces
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 struct TFInlinerInterface : public DialectInlinerInterface {
58   using DialectInlinerInterface::DialectInlinerInterface;
59 
60   //===--------------------------------------------------------------------===//
61   // Analysis Hooks
62   //===--------------------------------------------------------------------===//
63 
64   // Allow all call operations to be inlined.
isLegalToInlinemlir::tf_device::__anon5124f7b00111::TFInlinerInterface65   bool isLegalToInline(Operation* call, Operation* callable,
66                        bool wouldBeCloned) const final {
67     return true;
68   }
69 
70   // Returns if its legal to inline 'src' region into the 'dest' region
71   // attached to a TF Device operation.
isLegalToInlinemlir::tf_device::__anon5124f7b00111::TFInlinerInterface72   bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
73                        BlockAndValueMapping& valueMapping) const final {
74     return true;
75   }
76 
77   // Defines the legality of inlining TF Device operations.
isLegalToInlinemlir::tf_device::__anon5124f7b00111::TFInlinerInterface78   bool isLegalToInline(Operation*, Region*, bool,
79                        BlockAndValueMapping&) const final {
80     // For now, enable inlining all operations.
81     return true;
82   }
83 
84   //===--------------------------------------------------------------------===//
85   // Transformation Hooks
86   //===--------------------------------------------------------------------===//
87 
88   // Attempts to materialize a conversion for a type mismatch between a call
89   // from this dialect, and a callable region. This method should generate an
90   // operation that takes 'input' as the only operand, and produces a single
91   // result of 'resultType'. If a conversion can not be generated, nullptr
92   // should be returned.
93   // This is just re-using the same logic as the TensorFlow dialect right now.
materializeCallConversionmlir::tf_device::__anon5124f7b00111::TFInlinerInterface94   Operation* materializeCallConversion(OpBuilder& builder, Value input,
95                                        Type result_type,
96                                        Location conversion_loc) const final {
97     if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
98       return nullptr;
99     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
100                                       /*truncate=*/builder.getBoolAttr(false));
101   }
102 };
103 
104 // Checks if a block wraps a single operation and the single operation results
105 // are perfectly forwarded to the block's terminator.
BlockWrapsSingleOp(Block * block)106 bool BlockWrapsSingleOp(Block* block) {
107   auto body = block->without_terminator();
108   if (!hasSingleElement(body)) return false;
109 
110   Operation& wrapped_op = *body.begin();
111   Operation* terminator = block->getTerminator();
112   return wrapped_op.getNumResults() == terminator->getNumOperands() &&
113          std::equal(wrapped_op.getResults().begin(),
114                     wrapped_op.getResults().end(),
115                     terminator->getOperands().begin());
116 }
117 }  // end anonymous namespace
118 
TensorFlowDeviceDialect(MLIRContext * context)119 TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
120     : Dialect(/*name=*/"tf_device", context,
121               TypeID::get<TensorFlowDeviceDialect>()) {
122   addOperations<
123 #define GET_OP_LIST
124 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
125       >();
126 
127   addInterfaces<TFInlinerInterface>();
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // tf_device.launch
132 //===----------------------------------------------------------------------===//
133 
134 // Checks if a tf_device.launch wraps a single operation and the single
135 // operation results are perfectly forwarded to the launch return.
WrapsSingleOp()136 bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
137 
138 //===----------------------------------------------------------------------===//
139 // tf_device.parallel_execute
140 //===----------------------------------------------------------------------===//
141 
verify()142 LogicalResult ParallelExecuteOp::verify() {
143   ParallelExecuteOp op = *this;
144   const auto& regions = op.getOperation()->getRegions();
145   if (regions.empty()) {
146     return op.emitOpError() << "must have at least one region.";
147   }
148 
149   int output_index = 0;
150   for (auto& region_and_index : llvm::enumerate(regions)) {
151     auto& region = region_and_index.value();
152     auto* region_terminator = region.front().getTerminator();
153 
154     // Check that output types of regions match return operand types.
155     for (auto result_type : region_terminator->getOperandTypes()) {
156       if (result_type !=
157           op.getOperation()->getResult(output_index++).getType()) {
158         return op.emitOpError() << "output types must be a concatenated "
159                                 << "list of output types for each regions.";
160       }
161     }
162   }
163 
164   // Check that total number of outputs from regions match the output types of
165   // the parallel_execute op.
166   const int num_output_types = op.getOperation()->getNumResults();
167   if (num_output_types != output_index) {
168     return op.emitOpError()
169            << "number of output types (" << num_output_types << ") "
170            << "must match the total number of outputs from all "
171            << "regions (" << output_index << ").";
172   }
173 
174   return success();
175 }
176 
177 // static
build(OpBuilder & builder,OperationState & state,int num_regions,TypeRange output_types)178 void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state,
179                               int num_regions, TypeRange output_types) {
180   DCHECK_GE(num_regions, 1);
181   for (int i = 0; i < num_regions; ++i) {
182     Region* region = state.addRegion();
183     region->push_back(new Block);
184   }
185   state.addTypes(output_types);
186 }
187 
GetRegionBlockWithIndex(unsigned index)188 Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
189   return getOperation()->getRegion(index).front();
190 }
191 
GetRegionOutputs(unsigned region_index)192 Operation::result_range ParallelExecuteOp::GetRegionOutputs(
193     unsigned region_index) {
194   int num_region_results =
195       GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands();
196 
197   int return_value_offset = 0;
198   for (int region_id = 0; region_id < region_index; ++region_id)
199     return_value_offset +=
200         GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands();
201 
202   return getResults().slice(return_value_offset, num_region_results);
203 }
204 
RegionWrapsSingleOp(unsigned index)205 bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) {
206   return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index));
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // tf_device.replicate
211 //===----------------------------------------------------------------------===//
212 
213 namespace {
ParseReplicateOpOperands(OpAsmParser * parser,OperationState * state,llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand,8>> * replicated_inputs,llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> * packed_inputs,llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> * region_args,llvm::SmallVectorImpl<Type> * region_arg_types)214 ParseResult ParseReplicateOpOperands(
215     OpAsmParser* parser, OperationState* state,
216     llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8>>*
217         replicated_inputs,
218     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand>* packed_inputs,
219     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand>* region_args,
220     llvm::SmallVectorImpl<Type>* region_arg_types) {
221   // No operands or empty operand list.
222   bool parsed_l_paren = succeeded(parser->parseOptionalLParen());
223   if (!parsed_l_paren || succeeded(parser->parseOptionalRParen()))
224     return success();
225 
226   // Parse comma separated operands of the following format:
227   //   replicated_input
228   //     [%a, ...] as %block_arg0: type
229   //   packed_input
230   //     %b as %block_arg1: type
231   //
232   // Replicated inputs are placed before packed inputs when forming the op.
233   llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> replicated_region_args;
234   llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> packed_region_args;
235   llvm::SmallVector<Type, 8> replicated_region_arg_types;
236   llvm::SmallVector<Type, 8> packed_region_arg_types;
237   do {
238     OpAsmParser::UnresolvedOperand operand_type;
239     if (parser->parseOptionalOperand(operand_type).hasValue()) {
240       packed_inputs->emplace_back(operand_type);
241       if (parser->parseKeyword("as",
242                                " between packed input and block argument") ||
243           parser->parseOperand(packed_region_args.emplace_back(),
244                                /*allowResultNumber=*/false) ||
245           parser->parseColonType(packed_region_arg_types.emplace_back()))
246         return failure();
247     } else if (parser->parseOperandList(replicated_inputs->emplace_back(),
248                                         OpAsmParser::Delimiter::Square) ||
249                parser->parseKeyword(
250                    "as", " between replicated inputs and block argument") ||
251                parser->parseOperand(replicated_region_args.emplace_back(),
252                                     /*allowResultNumber=*/false) ||
253                parser->parseColonType(
254                    replicated_region_arg_types.emplace_back())) {
255       return failure();
256     }
257   } while (succeeded(parser->parseOptionalComma()));
258 
259   region_args->reserve(replicated_region_args.size() +
260                        packed_region_args.size());
261   region_args->append(replicated_region_args.begin(),
262                       replicated_region_args.end());
263   region_args->append(packed_region_args.begin(), packed_region_args.end());
264 
265   region_arg_types->reserve(replicated_region_arg_types.size() +
266                             packed_region_arg_types.size());
267   region_arg_types->append(replicated_region_arg_types.begin(),
268                            replicated_region_arg_types.end());
269   region_arg_types->append(packed_region_arg_types.begin(),
270                            packed_region_arg_types.end());
271 
272   // Parse remaining `)` surrounding operands.
273   return parser->parseRParen();
274 }
275 
SetReplicateOpOperands(llvm::SMLoc loc,OpAsmParser * parser,OperationState * state,llvm::ArrayRef<llvm::SmallVector<OpAsmParser::UnresolvedOperand,8>> replicated_inputs,llvm::ArrayRef<OpAsmParser::UnresolvedOperand> packed_inputs,llvm::ArrayRef<Type> region_arg_types,int32_t * n)276 ParseResult SetReplicateOpOperands(
277     llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
278     llvm::ArrayRef<llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8>>
279         replicated_inputs,
280     llvm::ArrayRef<OpAsmParser::UnresolvedOperand> packed_inputs,
281     llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
282   for (const auto& attr : state->attributes)
283     if (attr.getName().strref() == "n")
284       if (auto n_attr = attr.getValue().dyn_cast<IntegerAttr>())
285         *n = n_attr.getInt();
286 
287   if (*n < 2)
288     return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;
289 
290   if (replicated_inputs.empty() && packed_inputs.empty()) return success();
291 
292   for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
293     const int32_t idx = replicated_input_and_idx.index();
294     const auto& replicated_input = replicated_input_and_idx.value();
295     // Check if replicated input matches `n`.
296     if (replicated_input.size() != *n)
297       return parser->emitError(loc)
298              << "expects number of operands for replicated input " << idx
299              << " to be 'n' (" << *n << "), got " << replicated_input.size();
300 
301     // Resolve replicated input and block argument type.
302     if (parser->resolveOperands(replicated_input, region_arg_types[idx],
303                                 state->operands))
304       return failure();
305   }
306 
307   const int32_t num_replicated_block_args = replicated_inputs.size();
308   for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
309     const int32_t idx = packed_input_and_idx.index();
310     const auto& packed_input = packed_input_and_idx.value();
311 
312     // Resolve packed input and block argument type.
313     if (parser->resolveOperand(
314             packed_input, region_arg_types[idx + num_replicated_block_args],
315             state->operands))
316       return failure();
317   }
318 
319   return success();
320 }
321 
322 }  // namespace
323 
// Name of the derived attribute recording the sizes of the replicate op's two
// operand groups: the number of replicated inputs followed by the number of
// packed inputs.
static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
325 
parse(OpAsmParser & parser,OperationState & result)326 ParseResult ReplicateOp::parse(OpAsmParser& parser, OperationState& result) {
327   llvm::SMLoc loc = parser.getCurrentLocation();
328 
329   // Parse operands, attributes, and region of op.
330   llvm::SmallVector<llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8>, 8>
331       replicated_inputs;
332   llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> packed_inputs;
333   llvm::SmallVector<OpAsmParser::UnresolvedOperand, 8> region_args;
334   llvm::SmallVector<Type, 8> region_arg_types;
335   int32_t n = 0;
336   Region& body = *result.addRegion();
337   if (ParseReplicateOpOperands(&parser, &result, &replicated_inputs,
338                                &packed_inputs, &region_args,
339                                &region_arg_types) ||
340       parser.parseOptionalAttrDict(result.attributes) ||
341       SetReplicateOpOperands(loc, &parser, &result, replicated_inputs,
342                              packed_inputs, region_arg_types, &n))
343     return failure();
344 
345   SmallVector<OpAsmParser::Argument> packed_args;
346   for (auto argAndType : llvm::zip(region_args, region_arg_types)) {
347     auto& arg = packed_args.emplace_back();
348     arg.ssaName = std::get<0>(argAndType);
349     arg.type = std::get<1>(argAndType);
350   }
351   if (parser.parseRegion(body, packed_args)) return failure();
352 
353   // Add derived `operand_segment_sizes` attribute based on parsed operands.
354   if (!result.attributes.get(kOperandSegmentSizesAttr)) {
355     int32_t num_replicated_inputs = replicated_inputs.size() * n;
356     int32_t num_packed_inputs = packed_inputs.size();
357     auto attr = parser.getBuilder().getDenseI32ArrayAttr({num_replicated_inputs, num_packed_inputs});
358     result.addAttribute(kOperandSegmentSizesAttr, attr);
359   }
360 
361   // Ensure that the region is well formed: it contains at least a block with
362   // a ReturnOp terminator.
363   ReplicateOp::ensureTerminator(body, parser.getBuilder(), result.location);
364 
365   if (!llvm::hasSingleElement(body))
366     return parser.emitError(loc) << "expects a single block region";
367 
368   Operation& terminator = body.front().back();
369   if (!isa<ReturnOp>(terminator))
370     return parser.emitError(loc) << "expects a tf_device.return terminator";
371 
372   // Get the results type from the terminator type inside the replicate,
373   // replicated each by `n`.
374   result.types.reserve(terminator.getNumOperands() * n);
375   for (const auto& type : terminator.getOperandTypes())
376     result.types.append(n, type);
377 
378   return success();
379 }
380 
print(OpAsmPrinter & p)381 void ReplicateOp::print(OpAsmPrinter& p) {
382   // Print comma separated operands of the following format:
383   //   replicated_input
384   //     [%a, ...] as %block_arg0: type
385   //   packed_input
386   //     %b as %block_arg1: type
387   const int32_t n = this->n();
388   const int32_t num_replicated_inputs =
389       operand_segment_sizes()[0];
390   const int32_t num_replicated_block_args = num_replicated_inputs / n;
391 
392   if (getNumOperands()) {
393     p << '(';
394     Block& block = body().front();
395     interleaveComma(block.getArguments(), p, [&](BlockArgument arg) {
396       const int block_arg_num = arg.getArgNumber();
397       if (block_arg_num < num_replicated_block_args) {
398         p << '[';
399         p.printOperands(
400             std::next(replicated_inputs().begin(), block_arg_num * n),
401             std::next(replicated_inputs().begin(), (block_arg_num + 1) * n));
402         p << "]";
403       } else {
404         p.printOperand(*std::next(packed_inputs().begin(),
405                                   block_arg_num - num_replicated_block_args));
406       }
407       p << " as " << arg << ": " << arg.getType();
408     });
409     p << ')';
410   }
411 
412   // Skip derived `operand_segment_sizes` attribute as custom print format of
413   // operands holds enough information to calculate these variadic operand list
414   // lengths.
415   p.printOptionalAttrDict(
416       getOperation()->getAttrs(),
417       /*elidedAttrs=*/ArrayRef<StringRef>{kOperandSegmentSizesAttr});
418   p << ' ';
419   p.printRegion(body(), /*printEntryBlockArgs=*/false);
420 }
421 
422 namespace {
423 
424 // Checks if two types are compatible (compatible shapes and same elemental
425 // type).
VerifyCompatibleTypes(Type a,Type b)426 LogicalResult VerifyCompatibleTypes(Type a, Type b) {
427   if (failed(verifyCompatibleShape(a, b)) ||
428       getElementTypeOrSelf(a) != getElementTypeOrSelf(b))
429     return failure();
430 
431   return success();
432 }
433 
BuildReplicateOp(Builder * builder,OperationState * state,int n,llvm::Optional<DictionaryAttr> devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)434 void BuildReplicateOp(
435     Builder* builder, OperationState* state, int n,
436     llvm::Optional<DictionaryAttr> devices,
437     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
438     ValueRange packed_inputs, TypeRange replica_output_types) {
439   DCHECK_GE(n, 2);
440   state->addAttribute("n", builder->getI32IntegerAttr(n));
441 
442   if (devices.has_value()) state->addAttribute("devices", devices.getValue());
443 
444   Region* region = state->addRegion();
445   region->push_back(new Block);
446   Block& block = region->front();
447 
448   for (auto& replicated_input : replicated_inputs) {
449     DCHECK_EQ(llvm::size(replicated_input.first), n);
450     for (auto input : replicated_input.first) {
451       DCHECK(succeeded(
452           VerifyCompatibleTypes(input.getType(), replicated_input.second)));
453       state->addOperands(input);
454     }
455     block.addArgument(replicated_input.second, state->location);
456   }
457 
458   for (auto packed_input : packed_inputs) {
459     state->addOperands(packed_input);
460     block.addArgument(packed_input.getType(), state->location);
461   }
462 
463   // Add derived `operand_segment_sizes` attribute.
464   int32_t num_replicated_inputs = replicated_inputs.size() * n;
465   int32_t num_packed_inputs = packed_inputs.size();
466   auto operand_segment_sizes =
467       builder->getDenseI32ArrayAttr({num_replicated_inputs, num_packed_inputs});
468   state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);
469 
470   for (const auto& output_type : replica_output_types)
471     state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
472 }
473 
474 }  // anonymous namespace
475 
verify()476 LogicalResult ReplicateOp::verify() {
477   ReplicateOp op = *this;
478   int32_t n = op.n();
479 
480   // Check number of devices, if set, matches `n`.
481   if (op.devices().has_value()) {
482     for (auto device_attr : op.devices().getValue().getValue()) {
483       auto device_list = device_attr.getValue().dyn_cast_or_null<ArrayAttr>();
484       if (!device_list)
485         return op.emitError()
486                << "expects 'devices' to be a map alias and device name list.";
487 
488       bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
489         return attr.dyn_cast_or_null<StringAttr>();
490       });
491       if (!is_device_string)
492         return op.emitOpError() << "expects 'devices' to be a consists of "
493                                    "string list as values.";
494 
495       if (device_list.size() != n)
496         return op.emitOpError()
497                << "expects number of devices (" << device_list.size()
498                << ") to be equal to 'n' (" << n << ")";
499     }
500   }
501 
502   Block& block = op.body().front();
503 
504   auto operand_segment_sizes = op.operand_segment_sizes();
505   const int32_t num_replicated_inputs =
506       operand_segment_sizes[0];
507   const int32_t num_packed_inputs =
508       operand_segment_sizes[1];
509 
510   if (num_replicated_inputs % n != 0)
511     return op.emitOpError()
512            << "expects number of replicated inputs (" << num_replicated_inputs
513            << ") to be evenly divisible by 'n' (" << n << ")";
514 
515   const int32_t num_replicated_block_args = num_replicated_inputs / n;
516   if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
517     return op.emitOpError()
518            << "expects number of block arguments (" << block.getNumArguments()
519            << ") to be equal to number of replicated inputs ("
520            << num_replicated_inputs << ") / 'n' (" << n
521            << ") + number of packed inputs (" << num_packed_inputs << ")";
522 
523   // Check input types match block argument types.
524   auto verify_operand_types = [&](BlockArgument block_arg,
525                                   int32_t op_operand_idx) -> LogicalResult {
526     Type op_operand_type = op.getOperand(op_operand_idx).getType();
527     if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
528       return op.emitOpError()
529              << "expects operand " << op_operand_idx << " (" << op_operand_type
530              << ") and block argument " << block_arg.getArgNumber() << " ("
531              << block_arg.getType() << ") to have compatible types";
532 
533     return success();
534   };
535   for (auto block_arg : block.getArguments()) {
536     if (block_arg.getArgNumber() < num_replicated_block_args) {
537       for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
538         if (failed(verify_operand_types(block_arg, i))) return failure();
539     } else {
540       const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
541                           num_replicated_inputs;
542       if (failed(verify_operand_types(block_arg, idx))) return failure();
543     }
544   }
545 
546   Operation& terminator = block.back();
547 
548   // Check number of results matches `n` * number of return operands.
549   if (op.getNumResults() != n * terminator.getNumOperands())
550     return op.emitOpError()
551            << "expects number of results (" << op.getNumResults()
552            << ") to be equal to 'n' * number of terminator operands (" << n
553            << " * " << terminator.getNumOperands() << ")";
554 
555   // Check replicated output types match return operand types.
556   for (auto operand_type_and_idx :
557        llvm::enumerate(terminator.getOperandTypes())) {
558     Type operand_type = operand_type_and_idx.value();
559     int32_t operand_idx = operand_type_and_idx.index();
560     for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
561       if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
562         return op.emitOpError() << "incompatible types for result " << i
563                                 << " and terminator operand " << operand_idx;
564   }
565 
566   return success();
567 }
568 
build(OpBuilder & builder,OperationState & state,int n,const llvm::SmallDenseMap<StringRef,llvm::SmallVector<StringRef,4>> & devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)569 void ReplicateOp::build(
570     OpBuilder& builder, OperationState& state, int n,
571     const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
572         devices,
573     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
574     ValueRange packed_inputs, TypeRange replica_output_types) {
575   llvm::Optional<DictionaryAttr> devices_attr;
576   if (!devices.empty()) {
577     llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
578     device_list.reserve(devices.size());
579     for (auto alias_and_devices : devices) {
580       NamedAttribute device_name_attr = builder.getNamedAttr(
581           alias_and_devices.getFirst(),
582           builder.getStrArrayAttr(alias_and_devices.getSecond()));
583       device_list.emplace_back(device_name_attr);
584     }
585     devices_attr.emplace(builder.getDictionaryAttr(device_list));
586   }
587 
588   BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
589                    packed_inputs, replica_output_types);
590 }
591 
build(OpBuilder & builder,OperationState & state,int n,llvm::Optional<DictionaryAttr> devices,llvm::ArrayRef<std::pair<ValueRange,Type>> replicated_inputs,ValueRange packed_inputs,TypeRange replica_output_types)592 void ReplicateOp::build(
593     OpBuilder& builder, OperationState& state, int n,
594     llvm::Optional<DictionaryAttr> devices,
595     llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
596     ValueRange packed_inputs, TypeRange replica_output_types) {
597   BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
598                    packed_inputs, replica_output_types);
599 }
600 
601 // Returns the number of packed block arguments.
GetNumPackedBlockArguments()602 unsigned ReplicateOp::GetNumPackedBlockArguments() {
603   return packed_inputs().size();
604 }
605 
606 // Returns the number of replicated block arguments.
GetNumReplicatedBlockArguments()607 unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
608   return GetBody().getNumArguments() - GetNumPackedBlockArguments();
609 }
610 
611 // Returns the replicated block arguments. A copy should be made if the
612 // replicate op is being modified.
GetReplicatedBlockArguments()613 llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
614   return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
615 }
616 
617 // Returns the packed block arguments. A copy should be made if the replicate op
618 // is being modified.
GetPackedBlockArguments()619 llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
620   return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
621 }
622 
623 // Checks if a block argument is replicated (forwarding replicated inputs).
IsReplicatedBlockArgument(BlockArgument block_arg)624 bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
625   assert(block_arg.getOwner() == &GetBody());
626   return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
627 }
628 
629 // Checks if a block argument is packed (forwarding a packed input).
IsPackedBlockArgument(BlockArgument block_arg)630 bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
631   return !IsReplicatedBlockArgument(block_arg);
632 }
633 
634 // Returns the operand index of the operand being forwarded as a
635 // replicated/packed block argument for a given replica. This assumes a valid
636 // block argument (of the replicate op) and a valid replica is provided.
GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg,unsigned replica)637 unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
638     BlockArgument block_arg, unsigned replica) {
639   MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
640   if (operands.size() == 1) return operands.front().getOperandNumber();
641 
642   return operands[replica].getOperandNumber();
643 }
644 
645 // Returns the operand being forwarded as a replicated/packed block argument for
646 // a given replica. This assumes a valid block argument (of the replicate op)
647 // and a valid replica is provided.
GetReplicaOperandForBlockArgument(BlockArgument block_arg,unsigned replica)648 Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
649                                                      unsigned replica) {
650   MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
651   if (operands.size() == 1) return operands.front().get();
652 
653   return operands[replica].get();
654 }
655 
656 // Returns the list of replica op operands that maps to the given block
657 // argument. Returns list with num_replicas elements for replicated operands
658 // and list with a single element for packed operands.
659 //
660 // Requires that block argument is of this replicate op.
GetOperandsForBlockArgument(BlockArgument block_arg)661 MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
662     BlockArgument block_arg) {
663   assert(block_arg.getOwner() == &GetBody());
664 
665   unsigned arg_number = block_arg.getArgNumber();
666   unsigned num_replicated_args = GetNumReplicatedBlockArguments();
667   int32_t num_replicas = nAttr().getInt();
668   MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();
669 
670   // All replicated arguments are before packed arguments so return replicated
671   // operands if the given argument is one of the replicated arguments.
672   if (arg_number < num_replicated_args)
673     return operands.slice(arg_number * num_replicas, num_replicas);
674 
675   operands = operands.drop_front(num_replicated_args * num_replicas);
676   arg_number -= num_replicated_args;
677   return operands.slice(arg_number, 1);
678 }
679 
680 // Checks if a tf_device.replicate wraps a single operation and the single
681 // operation results are perfectly forwarded to the replicate return.
WrapsSingleOp()682 bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }
683 
684 //===----------------------------------------------------------------------===//
685 // Canonicalization patterns
686 //===----------------------------------------------------------------------===//
687 
688 //===----------------------------------------------------------------------===//
689 // tf_device.cluster
690 //===----------------------------------------------------------------------===//
691 
692 namespace {
693 
694 // Eliminates cluster op results that are not defined within the cluster and are
695 // defined outside. cluster op can be rewritten to remove those results.
EliminatePassThroughResults(ClusterOp op,PatternRewriter & rewriter)696 static LogicalResult EliminatePassThroughResults(ClusterOp op,
697                                                  PatternRewriter& rewriter) {
698   mlir::Block& body = op.GetBody();
699   Operation* return_op = body.getTerminator();
700   int num_results = return_op->getNumOperands();
701 
702   // Values defined within the cluster.
703   llvm::SmallVector<Value, 4> cluster_vals;
704   cluster_vals.reserve(num_results);
705 
706   // New results stores values to use while replacing the old cluster op.
707   llvm::SmallVector<Value, 4> new_results;
708   new_results.reserve(num_results);
709   for (OpOperand& operand : return_op->getOpOperands()) {
710     // If the corresponding result of the cluster op is used in some resource
711     // update op, do not eliminate the result. Such assignment ops could be for
712     // device resources and are required during fusing of the execute op and
713     // the resource update ops.
714     bool is_used_for_resource_write = llvm::any_of(
715         op.getResult(operand.getOperandNumber()).getUsers(),
716         [](Operation* user) { return isa<TF::AssignVariableOp>(user); });
717 
718     // TODO(b/186717563): Eliminate all pass through results once XLA correctly
719     // handles empty computations. Another approach could be to drop empty
720     // clusters within MLIR but that seems to trigger other failures but can be
721     // considered again.
722     // Old bridge only removes unsupported TPU types (only string for now)
723     // during outside compilation extraction so this should be enough for
724     // the parity.
725     bool is_unsupported_type = getElementTypeOrSelf(operand.get().getType())
726                                    .isa<mlir::TF::StringType>();
727     Value result = operand.get();
728     if (is_unsupported_type && result.getParentBlock() != &body &&
729         !is_used_for_resource_write) {
730       // Pass through result.
731       new_results.push_back(result);
732     } else {
733       // This result will be populated with the new result after rewriting the
734       // cluster op.
735       new_results.push_back(nullptr);
736       cluster_vals.push_back(result);
737     }
738   }
739 
740   // Return failure if there are no pass through results and op is already
741   // canonical.
742   if (cluster_vals.size() == num_results) return failure();
743 
744   // Rewrite return op in the cluster.
745   rewriter.setInsertionPoint(return_op);
746   auto new_return =
747       rewriter.replaceOpWithNewOp<tf_device::ReturnOp>(return_op, cluster_vals);
748 
749   // Rewrite the cluster op.
750   rewriter.setInsertionPoint(op);
751   auto new_op = rewriter.create<tf_device::ClusterOp>(
752       op->getLoc(), new_return.getOperandTypes(), op->getOperands(),
753       op->getAttrs());
754   rewriter.inlineRegionBefore(op.getBodyRegion(), new_op.getBodyRegion(),
755                               new_op.getBodyRegion().end());
756 
757   int idx = 0;
758   for (Value& result : new_results) {
759     if (result == nullptr) result = new_op.getResult(idx++);
760   }
761   rewriter.replaceOp(op, new_results);
762   return success();
763 }
764 }  // anonymous namespace
765 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)766 void ClusterOp::getCanonicalizationPatterns(RewritePatternSet& results,
767                                             MLIRContext* context) {
768   results.add(EliminatePassThroughResults);
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // tf_device.launch
773 //===----------------------------------------------------------------------===//
774 
775 namespace {
776 // This pattern matches LaunchOps with only one ReturnOp (empty) and remaps the
777 // results of the LaunchOp to the operands of the ReturnOp.
778 struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
779   using OpRewritePattern<LaunchOp>::OpRewritePattern;
780 
matchAndRewritemlir::tf_device::__anon5124f7b00911::DropEmptyLaunch781   LogicalResult matchAndRewrite(LaunchOp op,
782                                 PatternRewriter& rewriter) const override {
783     Block& block = op.GetBody();
784     // Check if launch only has a return.
785     if (&block.front() != &block.back()) return failure();
786 
787     // Map launch results to return operands.
788     rewriter.replaceOp(op, block.front().getOperands());
789 
790     return success();
791   }
792 };
793 }  // anonymous namespace
794 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)795 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet& results,
796                                            MLIRContext* context) {
797   results.add<DropEmptyLaunch>(context);
798 }
799 
800 }  // namespace tf_device
801 }  // namespace mlir
802 
803 //===----------------------------------------------------------------------===//
804 // TableGen'd op method definitions
805 //===----------------------------------------------------------------------===//
806 
807 #define GET_OP_CLASSES
808 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
809