/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Location.h" // from @llvm-project
26 #include "mlir/IR/MLIRContext.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/OperationSupport.h" // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
30 #include "mlir/IR/Types.h" // from @llvm-project
31 #include "mlir/IR/UseDefLists.h" // from @llvm-project
32 #include "mlir/IR/Value.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
40 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
43 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
44 #include "tensorflow/core/util/device_name_utils.h"
45
46 namespace mlir {
47 namespace TFTPU {
48
49 namespace {
50
// Op attribute carrying the op's assigned device, e.g. "/device:CPU:0".
constexpr char kDeviceAttr[] = "device";
// Device type that supported input producers must be placed on.
constexpr char kDeviceCPU[] = "CPU";
// Function-argument attribute carrying the argument's assigned device.
constexpr char kFuncDeviceAttr[] = "tf.device";
54
// A pass that allows TPU input layout to be determined after JIT compilation.
// This is done by adding run-time ops that interpret the compilation result
// and copy the input to the device with that layout.
//
// Example: original program:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
//
// Without this pass, later TF graph partitioning passes will insert send/recv
// between %input and %execute and data will be copied to device in a fixed
// layout. With this pass, the program will be transformed into:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
//   %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
//     {device = "/TPU:0"}
//   %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
//     {device = "/TPU:0"}
//
// This way, %compile will determine the layout, which will be respected by
// %copy_to_device. There will not be send/recv ops added by later passes,
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
  // Runs the transform on `func`, using resource alias analysis to determine
  // which iterator resources are safely known to live on CPU.
  void runOnFunction(
      FuncOp func,
      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
};
88
89 // Checks if the input producer op is supported in this transform. Right now, we
90 // only check if it is a tf.IteratorGetNext where resource input is coming from
91 // a VarHandle on CPU or a function argument assigned to CPU.
IsSupportedInputOp(Operation * op,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)92 bool IsSupportedInputOp(
93 Operation* op,
94 const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
95 TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
96 if (!iterator_op) return false;
97
98 Value resource_iterator = iterator_op.iterator();
99
100 if (resource_alias_analysis.IsUnknownResource(resource_iterator))
101 return false;
102 llvm::SmallSetVector<Value, 8> aliases =
103 resource_alias_analysis.GetResourceAliases(resource_iterator);
104
105 auto is_generator = [](Value val) {
106 if (val.isa<BlockArgument>()) return true;
107 Operation* definition = val.getDefiningOp();
108 return definition->getNumOperands() == 0 &&
109 definition->getNumResults() == 1;
110 };
111
112 // Check all generator aliases (ops or function argument) are on CPU.
113 FuncOp func = iterator_op->getParentOfType<FuncOp>();
114 return llvm::all_of(aliases, [&](Value alias) {
115 // Ignore non-generator aliases.
116 if (!is_generator(alias)) return true;
117
118 StringAttr device;
119 if (auto arg = alias.dyn_cast<BlockArgument>()) {
120 device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
121 kFuncDeviceAttr);
122 } else {
123 device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
124 }
125
126 if (!device) return false;
127 tensorflow::DeviceNameUtils::ParsedName parsed_device;
128 if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
129 &parsed_device)) {
130 return false;
131 }
132 return parsed_device.has_type && parsed_device.type == kDeviceCPU;
133 });
134 }
135
CreateBuilderAfterOp(Operation * op)136 OpBuilder CreateBuilderAfterOp(Operation* op) {
137 return OpBuilder(op->getBlock(), ++Block::iterator(op));
138 }
139
140 // Builds a TPUGetLayoutOp with the given compile op and input index.
BuildGetLayout(const int64_t execute_arg_index,Value compilation_key,tf_device::LaunchOp compile_launch,OpBuilder * builder)141 TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index,
142 Value compilation_key,
143 tf_device::LaunchOp compile_launch,
144 OpBuilder* builder) {
145 return builder->create<TF::TPUGetLayoutOp>(
146 compile_launch.getLoc(),
147 llvm::ArrayRef<Type>{RankedTensorType::get({ShapedType::kDynamicSize},
148 builder->getIntegerType(64))},
149 llvm::ArrayRef<Value>{compilation_key},
150 llvm::ArrayRef<NamedAttribute>{
151 builder->getNamedAttr("index",
152 builder->getI64IntegerAttr(execute_arg_index)),
153 builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
154 }
155
156 // Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
BuildCopyWithLayout(tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch,TF::TPUGetLayoutOp get_layout,Value input,OpBuilder * builder)157 TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
158 tf_device::LaunchOp compile_launch,
159 TF::TPUGetLayoutOp get_layout,
160 Value input, OpBuilder* builder) {
161 return builder->create<TF::TPUCopyWithLayoutOp>(
162 execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
163 llvm::ArrayRef<Value>{input, get_layout.layout()});
164 }
165
166 // Performs transformation for a non-replicated input.
HandleInput(Value input,const int64_t execute_arg_index,TF::TPUExecuteOp execute,tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch)167 void HandleInput(Value input, const int64_t execute_arg_index,
168 TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch,
169 tf_device::LaunchOp compile_launch) {
170 OpBuilder builder = CreateBuilderAfterOp(compile_launch);
171 auto get_layout = BuildGetLayout(execute_arg_index, execute.key(),
172 compile_launch, &builder);
173 builder.setInsertionPoint(execute_launch);
174 auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
175 get_layout, input, &builder);
176 copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
177 execute.setOperand(execute_arg_index, copy_with_layout);
178 }
179
// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleReplicatedInputs(
    const int64_t execute_arg_index, Value compilation_key,
    tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
    mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(replicate_arg);
  // Bail out unless every per-replica input is produced by a supported op
  // (tf.IteratorGetNext with its resource on CPU).
  for (auto entry : llvm::enumerate(inputs)) {
    auto input_op = entry.value().get().getDefiningOp();
    if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
      return false;
  }
  // A single TPUGetLayoutOp, inserted after the compile, serves all replicas.
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
                                   compile_launch, &builder);
  // The copies must dominate the replicate op that consumes them.
  builder.setInsertionPoint(replicate);
  for (auto entry : llvm::enumerate(inputs)) {
    auto copy_with_layout =
        BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
                            entry.value().get(), &builder);

    // Assign each copy the per-replica device: look up the device list for
    // the execute launch's device alias and index it by replica number.
    auto device_list = replicate.devices()
                           .getValue()
                           .get(execute_launch.getDevice())
                           .cast<ArrayAttr>();
    copy_with_layout->setAttr(kDeviceAttr,
                              device_list.getValue()[entry.index()]);

    // Rewire this replica's operand to go through the copy.
    entry.value().set(copy_with_layout);
  }
  return true;
}
217
// Performs transformation on a compile and associated execute(s) ops. The
// compile should not have other uses.
void HandleCompileAndExecutes(
    tf_device::LaunchOp compile_launch,
    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  auto compile =
      llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
  // Decode compile metadata to map each execute operand to its metadata arg.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(compile.metadata().str());
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings =
      tensorflow::GetMetadataArgumentMapping(metadata);

  bool metadata_updated = false;
  // Non-null iff the executes are enclosed in a tf_device.replicate.
  auto maybe_replicate =
      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();

  for (auto execute_and_input_mapping :
       llvm::zip(execute_launches, input_mappings)) {
    auto& execute_launch = std::get<0>(execute_and_input_mapping);
    auto execute =
        llvm::cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
    const auto& input_mapping = std::get<1>(execute_and_input_mapping);

    for (auto& input_and_idx : llvm::enumerate(execute.args())) {
      Value input = input_and_idx.value();
      const int64_t execute_arg_index = input_and_idx.index();
      if (auto block_arg = input.dyn_cast<BlockArgument>()) {
        // For a block argument, consider transforms only when it is a
        // replicated input (defining ops will be outside the replicate node).
        if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
            !HandleReplicatedInputs(execute_arg_index, execute.key(),
                                    execute_launch, compile_launch, block_arg,
                                    maybe_replicate, resource_alias_analysis)) {
          continue;
        }
      } else {
        // For an op output, consider transforms only when 1) there is no
        // replication or 2) it is outside the replicate node that encloses the
        // execute node. (Because if the op is inside replicate, it is probably
        // not on the host.)
        auto* input_op = input.getDefiningOp();
        if (maybe_replicate &&
            maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
          continue;
        }
        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
        HandleInput(input, execute_arg_index, execute, execute_launch,
                    compile_launch);
      }

      // Reaching here means this input was rewritten; mark the corresponding
      // metadata arg so the compiler is free to choose its layout.
      metadata.mutable_args(input_mapping[execute_arg_index])
          ->set_unrestricted_layout(true);
      metadata_updated = true;
    }
  }

  // Re-serialize the updated metadata back onto the compile op.
  if (metadata_updated)
    compile->setAttr("metadata", StringAttr::get(compile.getContext(),
                                                 metadata.SerializeAsString()));
}
279
// Walks `func` looking for tf._TPUCompileMlir ops whose programs feed
// tf.TPUExecute ops, and applies the dynamic-layout transform to each match.
void TPUDynamicLayoutPass::runOnFunction(
    FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  func.walk([&](TF::_TPUCompileMlirOp compile) {
    // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
    auto compile_launch =
        llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
    if (!compile_launch || !compile_launch.WrapsSingleOp()) return;

    // Collect the launch wrapping each execute. Result 0 of the compile
    // launch is skipped (llvm::drop_begin); each remaining result must have
    // exactly one use, which must be a TPUExecute wrapped in its own launch —
    // otherwise this compile is left untransformed.
    llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
    execute_launches.reserve(compile_launch.getNumResults() - 1);
    for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) {
      if (!program_result.hasOneUse()) return;
      Operation* user = *program_result.user_begin();
      auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
      if (!execute) return;
      auto execute_launch =
          llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
      if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
      execute_launches.push_back(execute_launch);
    }

    HandleCompileAndExecutes(compile_launch, execute_launches,
                             resource_alias_analysis);
  });
}
306
307 } // namespace
308
CreateTPUDynamicLayoutPass()309 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
310 return std::make_unique<TPUDynamicLayoutPass>();
311 }
312
// Registers the pass under the -tf-tpu-dynamic-layout-pass flag.
static PassRegistration<TPUDynamicLayoutPass> pass(
    "tf-tpu-dynamic-layout-pass",
    "Adds ops that allow TPU program inputs to have layouts determined at JIT "
    "compile time.");
317
318 } // namespace TFTPU
319 } // namespace mlir
320