/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kDeviceCPU[] = "CPU";
constexpr char kFuncDeviceAttr[] = "tf.device";

// A pass that allows the TPU input layout to be determined after JIT
// compilation. This is done by adding run-time ops that interpret the
// compilation result and copy the input to the device with that layout.
//
// Example: original program:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
//
// Without this pass, later TF graph partitioning passes will insert send/recv
// between %input and %execute, and data will be copied to the device in a
// fixed layout. With this pass, the program is transformed into:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
//   %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
//       {device = "/TPU:0"}
//   %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
//       {device = "/TPU:0"}
//
// This way, %compile determines the layout, which is respected by
// %copy_to_device. Later passes will not add send/recv ops, because
// tf.TPUCopyWithLayout accepts a host input and produces a device output.
struct TPUDynamicLayoutPass
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
  void runOnFunction(
      FuncOp func,
      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);

  StringRef getArgument() const final { return "tf-tpu-dynamic-layout-pass"; }

  StringRef getDescription() const final {
    return "Adds ops that allow TPU program inputs to have layouts determined "
           "at JIT compile time.";
  }
};

// Checks whether the input producer op is supported by this transform.
// Currently we only accept a tf.IteratorGetNext whose resource input comes
// from a VarHandle on the CPU or from a function argument assigned to the CPU.
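//
// For illustration only (op and device names below are hypothetical, not
// taken from any specific graph), a supported producer chain looks like:
//
//   %iterator = "tf.VarHandleOp"() {device = "/job:worker/device:CPU:0"}
//   %input = "tf.IteratorGetNext"(%iterator) {device = "/CPU:0"}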
bool IsSupportedInputOp(
    Operation* op,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iterator_op) return false;

  Value resource_iterator = iterator_op.iterator();

  if (resource_alias_analysis.IsUnknownResource(resource_iterator))
    return false;
  llvm::SmallSetVector<Value, 8> aliases =
      resource_alias_analysis.GetResourceAliases(resource_iterator);

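  // Here a "generator" is a value that originates the resource: a block or
  // function argument, or an op with no operands and a single result (e.g. a
  // VarHandle-style op).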
  auto is_generator = [](Value val) {
    if (val.isa<BlockArgument>()) return true;
    Operation* definition = val.getDefiningOp();
    return definition->getNumOperands() == 0 &&
           definition->getNumResults() == 1;
  };

  // Check that all generator aliases (ops or function arguments) are on the
  // CPU.
  FuncOp func = iterator_op->getParentOfType<FuncOp>();
  return llvm::all_of(aliases, [&](Value alias) {
    // Ignore non-generator aliases.
    if (!is_generator(alias)) return true;

    StringAttr device;
    if (auto arg = alias.dyn_cast<BlockArgument>()) {
      device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
                                                       kFuncDeviceAttr);
    } else {
      device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
    }

    if (!device) return false;
    tensorflow::DeviceNameUtils::ParsedName parsed_device;
    if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                    &parsed_device)) {
      return false;
    }
    return parsed_device.has_type && parsed_device.type == kDeviceCPU;
  });
}

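// Returns a builder whose insertion point is immediately after `op` in its
// parent block.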
OpBuilder CreateBuilderAfterOp(Operation* op) {
  return OpBuilder(op->getBlock(), ++Block::iterator(op));
}

// Builds a TPUGetLayoutOp with the given compile op and input index.
TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index,
                                  Value compilation_key,
                                  tf_device::LaunchOp compile_launch,
                                  OpBuilder* builder) {
  return builder->create<TF::TPUGetLayoutOp>(
      compile_launch.getLoc(),
      llvm::ArrayRef<Type>{RankedTensorType::get({ShapedType::kDynamicSize},
                                                 builder->getIntegerType(64))},
      llvm::ArrayRef<Value>{compilation_key},
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr("index",
                                builder->getI64IntegerAttr(execute_arg_index)),
          builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
}

// Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
                                            tf_device::LaunchOp compile_launch,
                                            TF::TPUGetLayoutOp get_layout,
                                            Value input, OpBuilder* builder) {
  return builder->create<TF::TPUCopyWithLayoutOp>(
      execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
      llvm::ArrayRef<Value>{input, get_layout.layout()});
}

// Performs the transformation for a non-replicated input.
void HandleInput(Value input, const int64_t execute_arg_index,
                 TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch,
                 tf_device::LaunchOp compile_launch) {
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, execute.key(),
                                   compile_launch, &builder);
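  // Place the copy right before the execute launch: it consumes the host
  // input but is assigned the execute launch's device, producing the
  // device-side value in the compiled layout.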
  builder.setInsertionPoint(execute_launch);
  auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
                                              get_layout, input, &builder);
  copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
  execute.setOperand(execute_arg_index, copy_with_layout);
}

// Performs the transformation for replicated inputs. Returns true if this is
// a supported case (and the transform was applied).
bool HandleReplicatedInputs(
    const int64_t execute_arg_index, Value compilation_key,
    tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
    mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(replicate_arg);
  for (auto entry : llvm::enumerate(inputs)) {
    auto input_op = entry.value().get().getDefiningOp();
    if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
      return false;
  }
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
                                   compile_launch, &builder);
  builder.setInsertionPoint(replicate);
  for (auto entry : llvm::enumerate(inputs)) {
    auto copy_with_layout =
        BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
                            entry.value().get(), &builder);

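    // replicate.devices() maps each logical device alias to its per-replica
    // list of physical devices; pick the list for this execute launch's
    // device and index it by the current replica.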
    auto device_list = replicate.devices()
                           .getValue()
                           .get(execute_launch.getDevice())
                           .cast<ArrayAttr>();
    copy_with_layout->setAttr(kDeviceAttr,
                              device_list.getValue()[entry.index()]);

    entry.value().set(copy_with_layout);
  }
  return true;
}

// Performs the transformation on a compile op and its associated execute
// op(s). The compile op should have no other uses.
void HandleCompileAndExecutes(
    tf_device::LaunchOp compile_launch,
    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  auto compile =
      llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(compile.metadata().str());
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings =
      tensorflow::GetMetadataArgumentMapping(metadata);
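  // As used below, input_mappings holds one vector per execute launch,
  // mapping that execute's operand indices to argument indices in the
  // compilation metadata; the metadata update at the end keys on it.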

  bool metadata_updated = false;
  auto maybe_replicate =
      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();

  for (auto execute_and_input_mapping :
       llvm::zip(execute_launches, input_mappings)) {
    auto& execute_launch = std::get<0>(execute_and_input_mapping);
    auto execute =
        llvm::cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
    const auto& input_mapping = std::get<1>(execute_and_input_mapping);

    for (auto& input_and_idx : llvm::enumerate(execute.args())) {
      Value input = input_and_idx.value();
      const int64_t execute_arg_index = input_and_idx.index();
      if (auto block_arg = input.dyn_cast<BlockArgument>()) {
        // For a block argument, consider transforms only when it is a
        // replicated input (defining ops will be outside the replicate node).
        if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
            !HandleReplicatedInputs(execute_arg_index, execute.key(),
                                    execute_launch, compile_launch, block_arg,
                                    maybe_replicate, resource_alias_analysis)) {
          continue;
        }
      } else {
        // For an op output, consider transforms only when 1) there is no
        // replication or 2) it is outside the replicate node that encloses
        // the execute node. (If the op is inside the replicate, it is
        // probably not on the host.)
        auto* input_op = input.getDefiningOp();
        if (maybe_replicate &&
            maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
          continue;
        }
        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
        HandleInput(input, execute_arg_index, execute, execute_launch,
                    compile_launch);
      }

      metadata.mutable_args(input_mapping[execute_arg_index])
          ->set_unrestricted_layout(true);
      metadata_updated = true;
    }
  }

  if (metadata_updated)
    compile->setAttr("metadata", StringAttr::get(compile.getContext(),
                                                 metadata.SerializeAsString()));
}

void TPUDynamicLayoutPass::runOnFunction(
    FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  func.walk([&](TF::_TPUCompileMlirOp compile) {
    // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
    auto compile_launch =
        llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
    if (!compile_launch || !compile_launch.WrapsSingleOp()) return;

    llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
    execute_launches.reserve(compile_launch.getNumResults() - 1);
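    // The first compile result is the compilation status; each remaining
    // result is a program key that must be consumed by exactly one TPUExecute
    // wrapped in its own launch, hence the checks below.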
    for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) {
      if (!program_result.hasOneUse()) return;
      Operation* user = *program_result.user_begin();
      auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
      if (!execute) return;
      auto execute_launch =
          llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
      if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
      execute_launches.push_back(execute_launch);
    }

    HandleCompileAndExecutes(compile_launch, execute_launches,
                             resource_alias_analysis);
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
  return std::make_unique<TPUDynamicLayoutPass>();
}

static PassRegistration<TPUDynamicLayoutPass> pass;

}  // namespace TFTPU
}  // namespace mlir