• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Location.h"  // from @llvm-project
26 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
30 #include "mlir/IR/Types.h"  // from @llvm-project
31 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
32 #include "mlir/IR/Value.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
35 #include "mlir/Support/LLVM.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
40 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
43 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
44 #include "tensorflow/core/util/device_name_utils.h"
45 
46 namespace mlir {
47 namespace TFTPU {
48 
49 namespace {
50 
// Attribute name carrying an op's assigned device string.
constexpr char kDeviceAttr[] = "device";
// Device type expected for host-resident input generators (see
// IsSupportedInputOp below).
constexpr char kDeviceCPU[] = "CPU";
// Function-argument attribute carrying the argument's assigned device.
constexpr char kFuncDeviceAttr[] = "tf.device";
54 
55 // A pass that allows TPU input layout to be determined after JIT compilation.
56 // This is done by adding run-time ops that interpret compilation result and
57 // copy the input to device with that layout.
58 //
59 // Example: original program:
60 //
61 //   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
62 //   %compile:2 = "tf._TPUCompileMlir"(...)
63 //   %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
64 //
65 // Without this pass, later TF graph partitioning passes will insert send/recv
66 // between %input and %execute and data will be copied to device in a fixed
67 // layout. With this pass, the program will be transformed into:
68 //
69 //   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
70 //   %compile:2 = "tf._TPUCompileMlir"(...)
71 //   %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
72 //   %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
73 //       {device = "/TPU:0"}
74 //   %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
75 //       {device = "/TPU:0"}
76 //
77 // This way, %compile will determine the layout, which will be respected by
78 // %copy_to_device. There will not be send/recv ops added by later passes,
79 // because tf.TPUCopyWithLayout accepts a host input and produces a device
80 // output.
struct TPUDynamicLayoutPass
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
  // Rewrites each tf._TPUCompileMlir -> tf.TPUExecute pattern in `func`,
  // consulting `resource_alias_analysis` to decide which input producers are
  // safe to transform.
  void runOnFunction(
      FuncOp func,
      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
};
88 
89 // Checks if the input producer op is supported in this transform. Right now, we
90 // only check if it is a tf.IteratorGetNext where resource input is coming from
91 // a VarHandle on CPU or a function argument assigned to CPU.
IsSupportedInputOp(Operation * op,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)92 bool IsSupportedInputOp(
93     Operation* op,
94     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
95   TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
96   if (!iterator_op) return false;
97 
98   Value resource_iterator = iterator_op.iterator();
99 
100   if (resource_alias_analysis.IsUnknownResource(resource_iterator))
101     return false;
102   llvm::SmallSetVector<Value, 8> aliases =
103       resource_alias_analysis.GetResourceAliases(resource_iterator);
104 
105   auto is_generator = [](Value val) {
106     if (val.isa<BlockArgument>()) return true;
107     Operation* definition = val.getDefiningOp();
108     return definition->getNumOperands() == 0 &&
109            definition->getNumResults() == 1;
110   };
111 
112   // Check all generator aliases (ops or function argument) are on CPU.
113   FuncOp func = iterator_op->getParentOfType<FuncOp>();
114   return llvm::all_of(aliases, [&](Value alias) {
115     // Ignore non-generator aliases.
116     if (!is_generator(alias)) return true;
117 
118     StringAttr device;
119     if (auto arg = alias.dyn_cast<BlockArgument>()) {
120       device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
121                                                        kFuncDeviceAttr);
122     } else {
123       device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
124     }
125 
126     if (!device) return false;
127     tensorflow::DeviceNameUtils::ParsedName parsed_device;
128     if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
129                                                     &parsed_device)) {
130       return false;
131     }
132     return parsed_device.has_type && parsed_device.type == kDeviceCPU;
133   });
134 }
135 
CreateBuilderAfterOp(Operation * op)136 OpBuilder CreateBuilderAfterOp(Operation* op) {
137   return OpBuilder(op->getBlock(), ++Block::iterator(op));
138 }
139 
140 // Builds a TPUGetLayoutOp with the given compile op and input index.
BuildGetLayout(const int64_t execute_arg_index,Value compilation_key,tf_device::LaunchOp compile_launch,OpBuilder * builder)141 TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index,
142                                   Value compilation_key,
143                                   tf_device::LaunchOp compile_launch,
144                                   OpBuilder* builder) {
145   return builder->create<TF::TPUGetLayoutOp>(
146       compile_launch.getLoc(),
147       llvm::ArrayRef<Type>{RankedTensorType::get({ShapedType::kDynamicSize},
148                                                  builder->getIntegerType(64))},
149       llvm::ArrayRef<Value>{compilation_key},
150       llvm::ArrayRef<NamedAttribute>{
151           builder->getNamedAttr("index",
152                                 builder->getI64IntegerAttr(execute_arg_index)),
153           builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
154 }
155 
156 // Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
BuildCopyWithLayout(tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch,TF::TPUGetLayoutOp get_layout,Value input,OpBuilder * builder)157 TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
158                                             tf_device::LaunchOp compile_launch,
159                                             TF::TPUGetLayoutOp get_layout,
160                                             Value input, OpBuilder* builder) {
161   return builder->create<TF::TPUCopyWithLayoutOp>(
162       execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
163       llvm::ArrayRef<Value>{input, get_layout.layout()});
164 }
165 
166 // Performs transformation for a non-replicated input.
HandleInput(Value input,const int64_t execute_arg_index,TF::TPUExecuteOp execute,tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch)167 void HandleInput(Value input, const int64_t execute_arg_index,
168                  TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch,
169                  tf_device::LaunchOp compile_launch) {
170   OpBuilder builder = CreateBuilderAfterOp(compile_launch);
171   auto get_layout = BuildGetLayout(execute_arg_index, execute.key(),
172                                    compile_launch, &builder);
173   builder.setInsertionPoint(execute_launch);
174   auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
175                                               get_layout, input, &builder);
176   copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
177   execute.setOperand(execute_arg_index, copy_with_layout);
178 }
179 
180 // Performs transformation for replicated inputs. Returns true if this is a
181 // supported case (thus transform happened).
HandleReplicatedInputs(const int64_t execute_arg_index,Value compilation_key,tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch,mlir::BlockArgument replicate_arg,tf_device::ReplicateOp replicate,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)182 bool HandleReplicatedInputs(
183     const int64_t execute_arg_index, Value compilation_key,
184     tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
185     mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
186     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
187   // We need to know the devices to copy to.
188   if (!replicate.devices()) return false;
189 
190   MutableArrayRef<OpOperand> inputs =
191       replicate.GetOperandsForBlockArgument(replicate_arg);
192   for (auto entry : llvm::enumerate(inputs)) {
193     auto input_op = entry.value().get().getDefiningOp();
194     if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
195       return false;
196   }
197   OpBuilder builder = CreateBuilderAfterOp(compile_launch);
198   auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
199                                    compile_launch, &builder);
200   builder.setInsertionPoint(replicate);
201   for (auto entry : llvm::enumerate(inputs)) {
202     auto copy_with_layout =
203         BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
204                             entry.value().get(), &builder);
205 
206     auto device_list = replicate.devices()
207                            .getValue()
208                            .get(execute_launch.getDevice())
209                            .cast<ArrayAttr>();
210     copy_with_layout->setAttr(kDeviceAttr,
211                               device_list.getValue()[entry.index()]);
212 
213     entry.value().set(copy_with_layout);
214   }
215   return true;
216 }
217 
// Performs transformation on a compile and associated execute(s) ops. The
// compile should not have other uses.
//
// For each execute argument whose producer is supported, inserts the
// layout-query/copy ops and marks the corresponding metadata argument as
// having an unrestricted layout; the rewritten metadata is written back onto
// the compile op if anything changed.
void HandleCompileAndExecutes(
    tf_device::LaunchOp compile_launch,
    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  auto compile =
      llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(compile.metadata().str());
  // Per-execute mapping from execute arg index to metadata arg index.
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings =
      tensorflow::GetMetadataArgumentMapping(metadata);

  bool metadata_updated = false;
  // Non-null iff the executes live inside a tf_device.replicate.
  auto maybe_replicate =
      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();

  for (auto execute_and_input_mapping :
       llvm::zip(execute_launches, input_mappings)) {
    auto& execute_launch = std::get<0>(execute_and_input_mapping);
    auto execute =
        llvm::cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
    const auto& input_mapping = std::get<1>(execute_and_input_mapping);

    for (auto& input_and_idx : llvm::enumerate(execute.args())) {
      Value input = input_and_idx.value();
      const int64_t execute_arg_index = input_and_idx.index();
      if (auto block_arg = input.dyn_cast<BlockArgument>()) {
        // For a block argument, consider transforms only when it is a
        // replicated input (defining ops will be outside the replicate node).
        // Note: only when HandleReplicatedInputs succeeds do we fall through
        // to the metadata update below.
        if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
            !HandleReplicatedInputs(execute_arg_index, execute.key(),
                                    execute_launch, compile_launch, block_arg,
                                    maybe_replicate, resource_alias_analysis)) {
          continue;
        }
      } else {
        // For an op output, consider transforms only when 1) there is no
        // replication or 2) it is outside the replicate node that encloses the
        // execute node. (Because if the op is inside replicate, it is probably
        // not on the host.)
        auto* input_op = input.getDefiningOp();
        if (maybe_replicate &&
            maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
          continue;
        }
        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
        HandleInput(input, execute_arg_index, execute, execute_launch,
                    compile_launch);
      }

      // Reached only when a transform actually happened for this input.
      metadata.mutable_args(input_mapping[execute_arg_index])
          ->set_unrestricted_layout(true);
      metadata_updated = true;
    }
  }

  // Persist the relaxed-layout metadata back onto the compile op.
  if (metadata_updated)
    compile->setAttr("metadata", StringAttr::get(compile.getContext(),
                                                 metadata.SerializeAsString()));
}
279 
runOnFunction(FuncOp func,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)280 void TPUDynamicLayoutPass::runOnFunction(
281     FuncOp func,
282     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
283   func.walk([&](TF::_TPUCompileMlirOp compile) {
284     // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
285     auto compile_launch =
286         llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
287     if (!compile_launch || !compile_launch.WrapsSingleOp()) return;
288 
289     llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
290     execute_launches.reserve(compile_launch.getNumResults() - 1);
291     for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) {
292       if (!program_result.hasOneUse()) return;
293       Operation* user = *program_result.user_begin();
294       auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
295       if (!execute) return;
296       auto execute_launch =
297           llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
298       if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
299       execute_launches.push_back(execute_launch);
300     }
301 
302     HandleCompileAndExecutes(compile_launch, execute_launches,
303                              resource_alias_analysis);
304   });
305 }
306 
307 }  // namespace
308 
// Factory for the pass, exposed through
// tensorflow/compiler/mlir/tensorflow/transforms/passes.h.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
  return std::make_unique<TPUDynamicLayoutPass>();
}
312 
// Registers the pass under its command-line flag so tf-opt style tools can
// run it by name.
static PassRegistration<TPUDynamicLayoutPass> pass(
    "tf-tpu-dynamic-layout-pass",
    "Adds ops that allow TPU program inputs to have layouts determined at JIT "
    "compile time.");
317 
318 }  // namespace TFTPU
319 }  // namespace mlir
320