/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace tf_device {

//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation* call, Operation* callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  // Returns whether it is legal to inline the 'src' region into the 'dest'
  // region attached to a TF Device operation.
  bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }

  // Defines the legality of inlining TF Device operations.
  bool isLegalToInline(Operation*, Region*, bool,
                       BlockAndValueMapping&) const final {
    // For now, enable inlining all operations.
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect and a callable region. This method should generate an
  // operation that takes 'input' as the only operand and produces a single
  // result of 'resultType'. If a conversion cannot be generated, nullptr
  // should be returned.
  // This currently reuses the same logic as the TensorFlow dialect.
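  // For example (an illustrative sketch; the tensor types are arbitrary), a
  // mismatch between tensor<4xi32> and tensor<?xi32> would be bridged by
  // materializing:
  //   %cast = "tf.Cast"(%input) {Truncate = false}
  //       : (tensor<4xi32>) -> tensor<?xi32>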
  Operation* materializeCallConversion(OpBuilder& builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
      return nullptr;
    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                      /*truncate=*/builder.getBoolAttr(false));
  }
};

// Checks if a block wraps a single operation and the single operation results
// are perfectly forwarded to the block's terminator.
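// For example (an illustrative sketch; the op and types are arbitrary), the
// following block wraps a single op:
//
//   %0 = "tf.SomeOp"(%arg0) : (tensor<i32>) -> tensor<i32>
//   tf_device.return %0 : tensor<i32>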
bool BlockWrapsSingleOp(Block* block) {
  auto body = block->without_terminator();
  if (!hasSingleElement(body)) return false;

  Operation& wrapped_op = *body.begin();
  Operation* terminator = block->getTerminator();
  return wrapped_op.getNumResults() == terminator->getNumOperands() &&
         std::equal(wrapped_op.getResults().begin(),
                    wrapped_op.getResults().end(),
                    terminator->getOperands().begin());
}
}  // end anonymous namespace

TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf_device", context,
              TypeID::get<TensorFlowDeviceDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
      >();

  addInterfaces<TFInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

// Checks if a tf_device.launch wraps a single operation and the single
// operation results are perfectly forwarded to the launch return.
bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }

//===----------------------------------------------------------------------===//
// tf_device.parallel_execute
//===----------------------------------------------------------------------===//

namespace {

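// The results of a parallel_execute are the per-region output types
// concatenated in region order. For example (an illustrative sketch), if the
// first region returns tensor<i32> and the second returns two tensor<f32>
// values, the op has result types (tensor<i32>, tensor<f32>, tensor<f32>).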
LogicalResult Verify(ParallelExecuteOp op) {
  const auto& regions = op.getOperation()->getRegions();
  if (regions.size() < 2) {
    return op.emitOpError() << "must have at least two regions.";
  }

  int output_index = 0;
  for (auto& region_and_index : llvm::enumerate(regions)) {
    auto& region = region_and_index.value();
    auto* region_terminator = region.front().getTerminator();

    // Check that output types of regions match return operand types.
    for (auto result_type : region_terminator->getOperandTypes()) {
      if (result_type !=
          op.getOperation()->getResult(output_index++).getType()) {
        return op.emitOpError() << "output types must be a concatenated "
                                << "list of output types for each region.";
      }
    }
  }

  // Check that the total number of outputs from the regions matches the
  // number of results of the parallel_execute op.
  const int num_output_types = op.getOperation()->getNumResults();
  if (num_output_types != output_index) {
    return op.emitOpError()
           << "number of output types (" << num_output_types << ") "
           << "must match the total number of outputs from all "
           << "regions (" << output_index << ").";
  }

  return success();
}

}  // namespace

// static
void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state,
                              int num_regions, TypeRange output_types) {
  DCHECK_GE(num_regions, 2);
  for (int i = 0; i < num_regions; ++i) {
    Region* region = state.addRegion();
    region->push_back(new Block);
  }
  state.addTypes(output_types);
}

Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
  return getOperation()->getRegion(index).front();
}

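// Returns the slice of op results produced by the region at `region_index`.
// For example (illustrative): if region 0 returns two values and region 1
// returns three, GetRegionOutputs(1) yields the op's results 2 through 4.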
Operation::result_range ParallelExecuteOp::GetRegionOutputs(
    unsigned region_index) {
  int num_region_results =
      GetRegionBlockWithIndex(region_index).getTerminator()->getNumOperands();

  int return_value_offset = 0;
  for (int region_id = 0; region_id < region_index; ++region_id)
    return_value_offset +=
        GetRegionBlockWithIndex(region_id).getTerminator()->getNumOperands();

  return getResults().slice(return_value_offset, num_region_results);
}

bool ParallelExecuteOp::RegionWrapsSingleOp(unsigned index) {
  return BlockWrapsSingleOp(&GetRegionBlockWithIndex(index));
}

//===----------------------------------------------------------------------===//
// tf_device.replicate
//===----------------------------------------------------------------------===//

namespace {
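// Example of the custom textual format parsed and printed below (an
// illustrative sketch; values, types, and the body are arbitrary). With n = 2,
// replicated inputs [%a, %b], packed input %c, and a terminator with two
// operands, the op has 2 * 2 = 4 results:
//
//   %0:4 = tf_device.replicate([%a, %b] as %ri: tensor<i32>,
//                              %c as %pi: tensor<f32>) {n = 2 : i32} {
//     %1 = "tf.SomeOp"(%ri, %pi) : (tensor<i32>, tensor<f32>) -> tensor<i32>
//     tf_device.return %ri, %1 : tensor<i32>, tensor<i32>
//   }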
ParseResult ParseReplicateOpOperands(
    OpAsmParser* parser, OperationState* state,
    llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>*
        replicated_inputs,
    llvm::SmallVectorImpl<OpAsmParser::OperandType>* packed_inputs,
    llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args,
    llvm::SmallVectorImpl<Type>* region_arg_types) {
  // No operands or empty operand list.
  bool parsed_l_paren = succeeded(parser->parseOptionalLParen());
  if (!parsed_l_paren || succeeded(parser->parseOptionalRParen()))
    return success();

  // Parse comma separated operands of the following format:
  //   replicated_input
  //     [%a, ...] as %block_arg0: type
  //   packed_input
  //     %b as %block_arg1: type
  //
  // Replicated inputs are placed before packed inputs when forming the op.
  llvm::SmallVector<OpAsmParser::OperandType, 8> replicated_region_args;
  llvm::SmallVector<OpAsmParser::OperandType, 8> packed_region_args;
  llvm::SmallVector<Type, 8> replicated_region_arg_types;
  llvm::SmallVector<Type, 8> packed_region_arg_types;
  do {
    OpAsmParser::OperandType operand_type;
    if (parser->parseOptionalOperand(operand_type).hasValue()) {
      packed_inputs->emplace_back(operand_type);
      if (parser->parseKeyword("as",
                               " between packed input and block argument") ||
          parser->parseRegionArgument(packed_region_args.emplace_back()) ||
          parser->parseColonType(packed_region_arg_types.emplace_back()))
        return failure();
    } else if (parser->parseOperandList(replicated_inputs->emplace_back(),
                                        OpAsmParser::Delimiter::Square) ||
               parser->parseKeyword(
                   "as", " between replicated inputs and block argument") ||
               parser->parseRegionArgument(
                   replicated_region_args.emplace_back()) ||
               parser->parseColonType(
                   replicated_region_arg_types.emplace_back())) {
      return failure();
    }
  } while (succeeded(parser->parseOptionalComma()));

  region_args->reserve(replicated_region_args.size() +
                       packed_region_args.size());
  region_args->append(replicated_region_args.begin(),
                      replicated_region_args.end());
  region_args->append(packed_region_args.begin(), packed_region_args.end());

  region_arg_types->reserve(replicated_region_arg_types.size() +
                            packed_region_arg_types.size());
  region_arg_types->append(replicated_region_arg_types.begin(),
                           replicated_region_arg_types.end());
  region_arg_types->append(packed_region_arg_types.begin(),
                           packed_region_arg_types.end());

  // Parse the remaining `)` surrounding the operands.
  return parser->parseRParen();
}

ParseResult SetReplicateOpOperands(
    llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
    llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>>
        replicated_inputs,
    llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,
    llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
  for (const auto& attr : state->attributes)
    if (attr.first.strref() == "n")
      if (auto n_attr = attr.second.dyn_cast<IntegerAttr>())
        *n = n_attr.getInt();

  if (*n < 2)
    return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;

  if (replicated_inputs.empty() && packed_inputs.empty()) return success();

  for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
    const int32_t idx = replicated_input_and_idx.index();
    const auto& replicated_input = replicated_input_and_idx.value();
    // Check that the replicated input count matches `n`.
    if (replicated_input.size() != *n)
      return parser->emitError(loc)
             << "expects number of operands for replicated input " << idx
             << " to be 'n' (" << *n << "), got " << replicated_input.size();

    // Resolve replicated input and block argument type.
    if (parser->resolveOperands(replicated_input, region_arg_types[idx],
                                state->operands))
      return failure();
  }

  const int32_t num_replicated_block_args = replicated_inputs.size();
  for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
    const int32_t idx = packed_input_and_idx.index();
    const auto& packed_input = packed_input_and_idx.value();

    // Resolve packed input and block argument type.
    if (parser->resolveOperand(
            packed_input, region_arg_types[idx + num_replicated_block_args],
            state->operands))
      return failure();
  }

  return success();
}

constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";

ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
  llvm::SMLoc loc = parser->getCurrentLocation();

  // Parse operands, attributes, and region of op.
  llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8>
      replicated_inputs;
  llvm::SmallVector<OpAsmParser::OperandType, 8> packed_inputs;
  llvm::SmallVector<OpAsmParser::OperandType, 8> region_args;
  llvm::SmallVector<Type, 8> region_arg_types;
  int32_t n = 0;
  Region& body = *state->addRegion();
  if (ParseReplicateOpOperands(parser, state, &replicated_inputs,
                               &packed_inputs, &region_args,
                               &region_arg_types) ||
      parser->parseOptionalAttrDict(state->attributes) ||
      SetReplicateOpOperands(loc, parser, state, replicated_inputs,
                             packed_inputs, region_arg_types, &n) ||
      parser->parseRegion(body, region_args, region_arg_types))
    return failure();

  // Add derived `operand_segment_sizes` attribute based on parsed operands.
  if (!state->attributes.get(kOperandSegmentSizesAttr)) {
    int32_t num_replicated_inputs = replicated_inputs.size() * n;
    int32_t num_packed_inputs = packed_inputs.size();
    auto attr = DenseIntElementsAttr::get(
        VectorType::get({2}, parser->getBuilder().getI32Type()),
        {num_replicated_inputs, num_packed_inputs});
    state->addAttribute(kOperandSegmentSizesAttr, attr);
  }

  // Ensure that the region is well formed: it contains at least a block with
  // a ReturnOp terminator.
  ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);

  if (!llvm::hasSingleElement(body))
    return parser->emitError(loc) << "expects a single block region";

  Operation& terminator = body.front().back();
  if (!isa<ReturnOp>(terminator))
    return parser->emitError(loc) << "expects a tf_device.return terminator";

  // Get the result types from the terminator operands inside the replicate,
  // each replicated `n` times.
  state->types.reserve(terminator.getNumOperands() * n);
  for (const auto& type : terminator.getOperandTypes())
    state->types.append(n, type);

  return success();
}

void Print(ReplicateOp op, OpAsmPrinter* p) {
  *p << op.getOperationName();

  // Print comma separated operands of the following format:
  //   replicated_input
  //     [%a, ...] as %block_arg0: type
  //   packed_input
  //     %b as %block_arg1: type
  const int32_t n = op.n();
  const int32_t num_replicated_inputs =
      (*op.operand_segment_sizes().int_value_begin()).getSExtValue();
  const int32_t num_replicated_block_args = num_replicated_inputs / n;

  if (op.getNumOperands()) {
    *p << '(';
    Block& block = op.body().front();
    interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
      const int block_arg_num = arg.getArgNumber();
      if (block_arg_num < num_replicated_block_args) {
        *p << '[';
        p->printOperands(
            std::next(op.replicated_inputs().begin(), block_arg_num * n),
            std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n));
        *p << "]";
      } else {
        p->printOperand(*std::next(op.packed_inputs().begin(),
                                   block_arg_num - num_replicated_block_args));
      }
      *p << " as " << arg << ": " << arg.getType();
    });
    *p << ')';
  }

  // Skip the derived `operand_segment_sizes` attribute as the custom print
  // format of the operands holds enough information to compute these variadic
  // operand list lengths.
  p->printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
                               kOperandSegmentSizesAttr});
  p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
}

// Checks if two types are compatible (compatible shapes and the same element
// type).
LogicalResult VerifyCompatibleTypes(Type a, Type b) {
  if (failed(verifyCompatibleShape(a, b)) ||
      getElementTypeOrSelf(a) != getElementTypeOrSelf(b))
    return failure();

  return success();
}

LogicalResult Verify(ReplicateOp op) {
  int32_t n = op.n();

  // Check that the number of devices per alias, if set, matches `n`.
  if (op.devices().hasValue()) {
    for (auto device_attr : op.devices().getValue().getValue()) {
      auto device_list = device_attr.second.dyn_cast_or_null<ArrayAttr>();
      if (!device_list)
        return op.emitError()
               << "expects 'devices' to be a map from alias to list of "
                  "device names.";

      bool is_device_string = llvm::all_of(device_list, [](Attribute attr) {
        return attr.dyn_cast_or_null<StringAttr>();
      });
      if (!is_device_string)
        return op.emitOpError() << "expects 'devices' values to be lists of "
                                   "strings.";

      if (device_list.size() != n)
        return op.emitOpError()
               << "expects number of devices (" << device_list.size()
               << ") to be equal to 'n' (" << n << ")";
    }
  }

  Block& block = op.body().front();

  auto operand_segment_sizes = op.operand_segment_sizes();
  const int32_t num_replicated_inputs =
      operand_segment_sizes.getValue<IntegerAttr>({0}).getInt();
  const int32_t num_packed_inputs =
      operand_segment_sizes.getValue<IntegerAttr>({1}).getInt();

  if (num_replicated_inputs % n != 0)
    return op.emitOpError()
           << "expects number of replicated inputs (" << num_replicated_inputs
           << ") to be evenly divisible by 'n' (" << n << ")";

  const int32_t num_replicated_block_args = num_replicated_inputs / n;
  if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
    return op.emitOpError()
           << "expects number of block arguments (" << block.getNumArguments()
           << ") to be equal to number of replicated inputs ("
           << num_replicated_inputs << ") / 'n' (" << n
           << ") + number of packed inputs (" << num_packed_inputs << ")";

  // Check that input types match block argument types.
  auto verify_operand_types = [&](BlockArgument block_arg,
                                  int32_t op_operand_idx) -> LogicalResult {
    Type op_operand_type = op.getOperand(op_operand_idx).getType();
    if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
      return op.emitOpError()
             << "expects operand " << op_operand_idx << " (" << op_operand_type
             << ") and block argument " << block_arg.getArgNumber() << " ("
             << block_arg.getType() << ") to have compatible types";

    return success();
  };
  for (auto block_arg : block.getArguments()) {
    if (block_arg.getArgNumber() < num_replicated_block_args) {
      for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
        if (failed(verify_operand_types(block_arg, i))) return failure();
    } else {
      const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
                          num_replicated_inputs;
      if (failed(verify_operand_types(block_arg, idx))) return failure();
    }
  }

  Operation& terminator = block.back();

  // Check that the number of results matches `n` * the number of return
  // operands.
  if (op.getNumResults() != n * terminator.getNumOperands())
    return op.emitOpError()
           << "expects number of results (" << op.getNumResults()
           << ") to be equal to 'n' * number of terminator operands (" << n
           << " * " << terminator.getNumOperands() << ")";

  // Check that replicated output types match the return operand types.
  for (auto operand_type_and_idx :
       llvm::enumerate(terminator.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int32_t operand_idx = operand_type_and_idx.index();
    for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
      if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
        return op.emitOpError() << "incompatible types for result " << i
                                << " and terminator operand " << operand_idx;
  }

  return success();
}

void BuildReplicateOp(
    Builder* builder, OperationState* state, int n,
    llvm::Optional<DictionaryAttr> devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  DCHECK_GE(n, 2);
  state->addAttribute("n", builder->getI32IntegerAttr(n));

  if (devices.hasValue()) state->addAttribute("devices", devices.getValue());

  Region* region = state->addRegion();
  region->push_back(new Block);
  Block& block = region->front();

  for (auto& replicated_input : replicated_inputs) {
    DCHECK_EQ(llvm::size(replicated_input.first), n);
    for (auto input : replicated_input.first) {
      DCHECK(succeeded(
          VerifyCompatibleTypes(input.getType(), replicated_input.second)));
      state->addOperands(input);
    }
    block.addArgument(replicated_input.second);
  }

  for (auto packed_input : packed_inputs) {
    state->addOperands(packed_input);
    block.addArgument(packed_input.getType());
  }

  // Add derived `operand_segment_sizes` attribute.
  int32_t num_replicated_inputs = replicated_inputs.size() * n;
  int32_t num_packed_inputs = packed_inputs.size();
  auto operand_segment_sizes =
      DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()),
                                {num_replicated_inputs, num_packed_inputs});
  state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);

  for (const auto& output_type : replica_output_types)
    state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
}
}  // anonymous namespace

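// Builds a replicate op from a map of device aliases to lists of device names,
// one name per replica. For example (an illustrative sketch; the alias and
// device names are arbitrary), with n = 2 the resulting `devices` attribute
// could be:
//   {devices = {SOME_ALIAS = ["/device:TPU:0", "/device:TPU:1"]}}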
void ReplicateOp::build(
    OpBuilder& builder, OperationState& state, int n,
    const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
        devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  llvm::Optional<DictionaryAttr> devices_attr;
  if (!devices.empty()) {
    llvm::SmallVector<mlir::NamedAttribute, 1> device_list;
    device_list.reserve(devices.size());
    for (auto alias_and_devices : devices) {
      NamedAttribute device_name_attr = builder.getNamedAttr(
          alias_and_devices.getFirst(),
          builder.getStrArrayAttr(alias_and_devices.getSecond()));
      device_list.emplace_back(device_name_attr);
    }
    devices_attr.emplace(builder.getDictionaryAttr(device_list));
  }

  BuildReplicateOp(&builder, &state, n, devices_attr, replicated_inputs,
                   packed_inputs, replica_output_types);
}

void ReplicateOp::build(
    OpBuilder& builder, OperationState& state, int n,
    llvm::Optional<DictionaryAttr> devices,
    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
    ValueRange packed_inputs, TypeRange replica_output_types) {
  BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
                   packed_inputs, replica_output_types);
}

// Returns the number of packed block arguments.
unsigned ReplicateOp::GetNumPackedBlockArguments() {
  return packed_inputs().size();
}

// Returns the number of replicated block arguments.
unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
  return GetBody().getNumArguments() - GetNumPackedBlockArguments();
}

// Returns the replicated block arguments. A copy should be made if the
// replicate op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
  return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
}

// Returns the packed block arguments. A copy should be made if the replicate
// op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
  return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
}

// Checks if a block argument is replicated (forwarding replicated inputs).
bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
  assert(block_arg.getOwner() == &GetBody());
  return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
}

// Checks if a block argument is packed (forwarding a packed input).
bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
  return !IsReplicatedBlockArgument(block_arg);
}

// Returns the operand index of the operand being forwarded as a
// replicated/packed block argument for a given replica. This assumes a valid
// block argument (of the replicate op) and a valid replica is provided.
unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
    BlockArgument block_arg, unsigned replica) {
  MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
  if (operands.size() == 1) return operands.front().getOperandNumber();

  return operands[replica].getOperandNumber();
}

// Returns the operand being forwarded as a replicated/packed block argument
// for a given replica. This assumes a valid block argument (of the replicate
// op) and a valid replica is provided.
Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
                                                     unsigned replica) {
  MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
  if (operands.size() == 1) return operands.front().get();

  return operands[replica].get();
}

// Returns the list of replicate op operands that map to the given block
// argument. Returns a list with num_replicas elements for replicated operands
// and a list with a single element for packed operands.
//
// Requires that the block argument is of this replicate op.
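//
// For example (illustrative): with n = 2, two replicated block arguments, and
// one packed block argument, the op's operands are laid out as
//   [arg0_replica0, arg0_replica1, arg1_replica0, arg1_replica1, packed0]
// so the second replicated argument maps to operands 2 and 3, and the packed
// argument maps to operand 4 alone.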
MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
    BlockArgument block_arg) {
  assert(block_arg.getOwner() == &GetBody());

  unsigned arg_number = block_arg.getArgNumber();
  unsigned num_replicated_args = GetNumReplicatedBlockArguments();
  int32_t num_replicas = nAttr().getInt();
  MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();

  // All replicated arguments are before packed arguments, so return replicated
  // operands if the given argument is one of the replicated arguments.
  if (arg_number < num_replicated_args)
    return operands.slice(arg_number * num_replicas, num_replicas);

  operands = operands.drop_front(num_replicated_args * num_replicas);
  arg_number -= num_replicated_args;
  return operands.slice(arg_number, 1);
}

// Checks if a tf_device.replicate wraps a single operation and the single
// operation results are perfectly forwarded to the replicate return.
bool ReplicateOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); }

//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_device.cluster
//===----------------------------------------------------------------------===//

namespace {

// Eliminates cluster op results that are not defined within the cluster (i.e.,
// values defined outside and merely passed through to the cluster return). The
// cluster op can be rewritten to remove those results.
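// For example (an illustrative sketch; note that only string-typed
// pass-through values are currently eliminated, per the TODO below):
//
//   %0 = "tf.Const"() {...} : () -> tensor<!tf.string>
//   %1:2 = "tf_device.cluster"() ({
//     %2 = "tf.SomeOp"() : () -> tensor<i32>
//     tf_device.return %0, %2 : tensor<!tf.string>, tensor<i32>
//   }) {...} : () -> (tensor<!tf.string>, tensor<i32>)
//
// is rewritten so the cluster returns only %2, and uses of %1#0 are replaced
// with %0 directly.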
static LogicalResult EliminatePassThroughResults(ClusterOp op,
                                                 PatternRewriter& rewriter) {
  mlir::Block& body = op.GetBody();
  Operation* return_op = body.getTerminator();
  int num_results = return_op->getNumOperands();

  // Values defined within the cluster.
  llvm::SmallVector<Value, 4> cluster_vals;
  cluster_vals.reserve(num_results);

  // `new_results` stores the values to use when replacing the old cluster op.
  llvm::SmallVector<Value, 4> new_results;
  new_results.reserve(num_results);
  for (OpOperand& operand : return_op->getOpOperands()) {
    // If the corresponding result of the cluster op is used in some resource
    // update op, do not eliminate the result. Such assignment ops could be for
    // device resources and are required during fusing of the execute op and
    // the resource update ops.
    bool is_used_for_resource_write = llvm::any_of(
        op.getResult(operand.getOperandNumber()).getUsers(),
        [](Operation* user) { return isa<TF::AssignVariableOp>(user); });

    // TODO(b/186717563): Eliminate all pass through results once XLA correctly
    // handles empty computations. Another approach could be to drop empty
    // clusters within MLIR, but that seems to trigger other failures; it can
    // be considered again.
    // The old bridge only removes unsupported TPU types (only string for now)
    // during outside compilation extraction, so this should be enough for
    // parity.
    bool is_unsupported_type = getElementTypeOrSelf(operand.get().getType())
                                   .isa<mlir::TF::StringType>();
    Value result = operand.get();
    if (is_unsupported_type && result.getParentBlock() != &body &&
        !is_used_for_resource_write) {
      // Pass through result.
      new_results.push_back(result);
    } else {
      // This result will be populated with the new result after rewriting the
      // cluster op.
      new_results.push_back(nullptr);
      cluster_vals.push_back(result);
    }
  }

  // Return failure if there are no pass through results and the op is already
  // canonical.
  if (cluster_vals.size() == num_results) return failure();

  // Rewrite the return op in the cluster.
  rewriter.setInsertionPoint(return_op);
  auto new_return =
      rewriter.replaceOpWithNewOp<tf_device::ReturnOp>(return_op, cluster_vals);

  // Rewrite the cluster op.
  rewriter.setInsertionPoint(op);
  auto new_op = rewriter.create<tf_device::ClusterOp>(
      op->getLoc(), new_return.getOperandTypes(), op->getOperands(),
      op->getAttrs());
  rewriter.inlineRegionBefore(op.getBodyRegion(), new_op.getBodyRegion(),
                              new_op.getBodyRegion().end());

  int idx = 0;
  for (Value& result : new_results) {
    if (result == nullptr) result = new_op.getResult(idx++);
  }
  rewriter.replaceOp(op, new_results);
  return success();
}
}  // anonymous namespace

void ClusterOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
  results.insert(EliminatePassThroughResults);
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches LaunchOps whose body contains only a ReturnOp and
// remaps the results of the LaunchOp to the operands of the ReturnOp.
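// For example (an illustrative sketch), the launch below does no work, so its
// results can be replaced with %a and %b directly:
//
//   %0:2 = "tf_device.launch"() ({
//     tf_device.return %a, %b : tensor<i32>, tensor<f32>
//   }) {device = "..."} : () -> (tensor<i32>, tensor<f32>)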
struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter& rewriter) const override {
    Block& block = op.GetBody();
    // Check if the launch body only has a return.
    if (&block.front() != &block.back()) return failure();

    // Map launch results to return operands.
    rewriter.replaceOp(op, block.front().getOperands());

    return success();
  }
};
}  // anonymous namespace

void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<DropEmptyLaunch>(context);
}

}  // namespace tf_device
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"