/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

namespace mlir {
namespace TF {
namespace collection_ops_util {

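// Creates an i32 scalar tf.Const with the given value.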
Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc) {
  auto attr = DenseIntElementsAttr::get(
      RankedTensorType::get({}, builder.getI32Type()), value);
  return builder.create<TF::ConstOp>(loc, attr);
}

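// Creates a rank-1 integer tf.Const of the given bitwidth holding the values
// in `r1`.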
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
                 int bitwidth) {
  llvm::SmallVector<APInt, 4> values;
  int64_t rank = r1.size();
  values.reserve(rank);
  for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
  auto result_type = RankedTensorType::get(
      {rank}, IntegerType::get(builder.getContext(), bitwidth));
  return builder.create<TF::ConstOp>(
      loc, DenseElementsAttr::get(result_type, values));
}

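// Returns the start-indices tensor used to access the element at `index` in
// `buffer`: `index` followed by zeros for the remaining buffer dimensions.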
Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
                           Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  if (buffer_type.getShape().size() == 1) return index;
  // Create a concat of index and trailing zeros.
  llvm::SmallVector<int64_t, 8> zeros(buffer_type.getShape().size() - 1, 0);
  auto zeros_tensor = GetR1Const(zeros, builder, loc);
  return builder.create<TF::ConcatV2Op>(
      loc,
      ArrayRef<Type>{RankedTensorType::get(
          {static_cast<int64_t>(buffer_type.getShape().size())},
          getElementTypeOrSelf(index.getType()))},
      ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)});
}

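// Reads the element at `index` from `buffer`. If `keep_slice_shape` is true,
// the result keeps the leading dimension of size 1 instead of being reshaped
// to the element shape.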
Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
                 bool keep_slice_shape) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  // Create a slice then reshape to remove the leading trivial dimension of
  // size 1.
  llvm::SmallVector<int64_t, 8> slice_size =
      llvm::to_vector<8>(buffer_type.getShape());
  slice_size[0] = 1;
  auto size_const = GetR1Const(slice_size, builder, loc);
  auto slice_type =
      RankedTensorType::get(slice_size, buffer_type.getElementType());
  auto slice = builder.create<TF::SliceOp>(
      loc, ArrayRef<Type>{slice_type},
      ArrayRef<Value>{buffer, GetIndicesForElement(index, buffer, builder, loc),
                      size_const});
  if (keep_slice_shape) return slice;
  auto element_type = RankedTensorType::get(buffer_type.getShape().drop_front(),
                                            buffer_type.getElementType());
  auto reshape = builder.create<TF::ReshapeOp>(
      loc, ArrayRef<Type>{element_type},
      ArrayRef<Value>{slice,
                      GetR1Const(element_type.getShape(), builder, loc)});
  return reshape.output();
}

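// Writes `element` into `buffer` at position `index` and returns the updated
// buffer.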
Value SetElement(Value index, Value buffer, Value element, OpBuilder builder,
                 Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  // Reshape the element to add a leading dimension of size 1 if the element
  // does not have that dimension, then perform a dynamic update slice.
  auto slice_shape = llvm::to_vector<8>(buffer_type.getShape());
  slice_shape[0] = 1;
  auto slice_type =
      RankedTensorType::get(slice_shape, buffer_type.getElementType());
  auto update_slice = element;
  if (element.getType() != slice_type) {
    update_slice = builder.create<TF::ReshapeOp>(
        loc, ArrayRef<Type>{slice_type},
        ArrayRef<Value>{element, GetR1Const(slice_shape, builder, loc)});
  }
  return builder
      .create<TF::XlaDynamicUpdateSliceOp>(
          loc, ArrayRef<Type>{buffer.getType()},
          ArrayRef<Value>{buffer, update_slice,
                          GetIndicesForElement(index, buffer, builder, loc)})
      .output();
}

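// Returns the 1xi32 type used to represent a size value.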
TensorType GetSizeType(OpBuilder builder) {
  return RankedTensorType::get({1}, builder.getIntegerType(32));
}

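// Reshapes a scalar size value to the 1xi32 size type.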
Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) {
  auto size_type = GetSizeType(builder);
  return builder.create<TF::ReshapeOp>(
      loc, ArrayRef<Type>{size_type},
      ArrayRef<Value>{scalar, GetR1Const(size_type.getShape(), builder, loc)});
}

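// Creates a zero-initialized buffer for elements of the given shape and type,
// where the maximum element count is taken from the constant that defines
// `max_size`. Fails if `max_size` is not defined by a tf.Const.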
LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
                                    Value max_size, Operation* op,
                                    Type element_dtype, OpBuilder builder,
                                    Value* buffer) {
  auto max_count_op = max_size.getDefiningOp();
  if (!max_count_op) return op->emitOpError("unknown max element count");
  auto max_count_const_op = llvm::dyn_cast<TF::ConstOp>(max_count_op);
  if (!max_count_const_op) return op->emitOpError("unknown max element count");
  int64_t max_size_const =
      (*max_count_const_op.value().getValues<APInt>().begin()).getSExtValue();
  return CreateInitBufferValue(element_shape, max_size_const, op, element_dtype,
                               builder, buffer);
}

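// Creates a zero-initialized buffer that can hold `max_size` elements of the
// given shape and type, and returns it through `buffer`.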
LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
                                    int64_t max_size, Operation* op,
                                    Type element_dtype, OpBuilder builder,
                                    Value* buffer) {
  llvm::SmallVector<int64_t, 8> buffer_shape;
  buffer_shape.push_back(max_size);
  for (int64_t dim : element_shape) {
    buffer_shape.push_back(dim);
  }
  auto zero = CreateScalarConst(0, builder, op->getLoc());
  if (getElementTypeOrSelf(zero.getType()) != element_dtype) {
    zero = builder.create<TF::CastOp>(
        op->getLoc(), ArrayRef<Type>{RankedTensorType::get({}, element_dtype)},
        ArrayRef<Value>{zero});
  }
  auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype);
  auto broadcast = builder.create<TF::BroadcastToOp>(
      op->getLoc(), ArrayRef<Type>{buffer_type},
      ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())});
  *buffer = broadcast.output();
  return success();
}

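// Tries to infer the element type of a collection value by walking its uses:
// accesses are followed through While/If/call/Identity ops, and
// `infer_from_op` is queried on other users. Returns a type only if it is
// ranked and has a static shape.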
llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
    Value collection, ModuleOp module,
    llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op) {
  for (auto& use : collection.getUses()) {
    if (auto while_op = llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
      auto body = while_op.body_function();
      assert(body);
      auto type_from_body = GetElementTypeFromAccess(
          body.getArgument(use.getOperandNumber()), module, infer_from_op);
      if (type_from_body.hasValue()) return type_from_body;
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(use.getOwner())) {
      auto then_branch = if_op.then_function();
      auto else_branch = if_op.else_function();
      assert(then_branch && else_branch);
      auto type_from_then = GetElementTypeFromAccess(
          then_branch.getArgument(use.getOperandNumber() - 1), module,
          infer_from_op);
      if (type_from_then.hasValue()) return type_from_then;
      auto type_from_else = GetElementTypeFromAccess(
          else_branch.getArgument(use.getOperandNumber() - 1), module,
          infer_from_op);
      if (type_from_else.hasValue()) return type_from_else;
    } else if (auto call = llvm::dyn_cast<CallOpInterface>(use.getOwner())) {
      auto callee = dyn_cast<FuncOp>(call.resolveCallable());
      auto type_from_callee = GetElementTypeFromAccess(
          callee.getArgument(use.getOperandNumber()), module, infer_from_op);
      if (type_from_callee.hasValue()) return type_from_callee;
    } else if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(use.getOwner())) {
      auto type_from_alias = GetElementTypeFromAccess(
          use.getOwner()->getResult(use.getOperandNumber()), module,
          infer_from_op);
      if (type_from_alias.hasValue()) return type_from_alias;
    } else if (auto type = infer_from_op(use.getOwner())) {
      if (!type) continue;
      auto elem_type = type->dyn_cast<RankedTensorType>();
      if (elem_type && elem_type.hasStaticShape()) return elem_type;
    }
  }
  return llvm::None;
}

// Creates a ReadVariableOp on a local variable.
Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
  return builder
      .create<TF::ReadVariableOp>(
          loc,
          ArrayRef<Type>{getElementTypeOrSelf(local_var.getType())
                             .cast<TF::ResourceType>()
                             .getSubtypes()[0]},
          ArrayRef<Value>{local_var})
      .value();
}

// Creates an AssignVariableOp on a local variable.
TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
                                        OpBuilder builder, Location loc) {
  return builder.create<TF::AssignVariableOp>(
      loc, ArrayRef<Type>{}, ArrayRef<Value>{local_var, value});
}

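// Accumulates two buffers element-wise: a logical-or for boolean buffers and
// an AddV2 otherwise.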
Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) {
  if (getElementTypeOrSelf(a.getType()) == builder.getI1Type()) {
    return builder.create<TF::LogicalOrOp>(loc, ArrayRef<Type>{a.getType()},
                                           ArrayRef<Value>{a, b});
  }
  return builder.create<TF::AddV2Op>(loc, ArrayRef<Type>{a.getType()},
                                     ArrayRef<Value>{a, b});
}

namespace {

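// If `indices` is a constant vector of contiguous values (n, n+1, ...),
// returns the first value n; otherwise returns -1.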
int64_t GetFirstIfIndicesAreContiguous(Value indices) {
  auto type = indices.getType().dyn_cast<RankedTensorType>();
  if (!type) return -1;
  auto indices_op = indices.getDefiningOp();
  if (!indices_op) return -1;
  auto const_op = llvm::dyn_cast<TF::ConstOp>(indices_op);
  if (!const_op) return -1;
  int64_t last_index = -1;
  int64_t first_index = -1;
  for (const auto& ind : const_op.value().getValues<APInt>()) {
    if (last_index == -1) {
      last_index = ind.getSExtValue();
      first_index = last_index;
      continue;
    }
    if (last_index + 1 != ind.getSExtValue()) return -1;
    last_index++;
  }
  return first_index;
}

}  // namespace

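// Gathers the elements of `buffer` at `indices`, using a single slice when the
// indices are a contiguous constant range and a GatherV2 otherwise.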
Value GatherElements(Value indices, Value buffer, OpBuilder builder,
                     Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  auto result_shape = llvm::to_vector<8>(buffer_type.getShape());
  result_shape[0] = indices.getType().cast<RankedTensorType>().getDimSize(0);
  int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices);
  if (maybe_contiguous_start >= 0) {
    llvm::SmallVector<int64_t, 8> slice_starts(result_shape.size(), 0);
    slice_starts[0] = maybe_contiguous_start;
    auto slice_type =
        RankedTensorType::get(result_shape, buffer_type.getElementType());
    return builder.create<TF::SliceOp>(
        loc, ArrayRef<Type>{slice_type},
        ArrayRef<Value>{buffer, GetR1Const(slice_starts, builder, loc),
                        GetR1Const(result_shape, builder, loc)});
  }
  auto result_type =
      RankedTensorType::get(result_shape, buffer_type.getElementType());
  return builder.create<TF::GatherV2Op>(
      loc, ArrayRef<Type>{result_type},
      ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)});
}

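// Scatter-accumulates `updates` into `buffer` at `indices`: each updated slice
// is accumulated with the existing data instead of overwriting it.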
Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
                                OpBuilder builder, Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  auto updates_type = updates.getType().cast<RankedTensorType>();
  int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices);
  if (maybe_contiguous_start == 0 && buffer_type == updates_type) {
    return AccumulateBuffers(buffer, updates, builder, loc);
  }
  // We cannot simply use a TensorScatterUpdate, as it does not accumulate with
  // the old data, and manually adding the old data is also tricky since there
  // could be duplicates in the indices. We follow the old bridge's approach by
  // iterating through the indices.
  auto per_slice_shape = llvm::to_vector<8>(buffer_type.getShape());
  per_slice_shape[0] = 1;
  auto slice_sizes = GetR1Const(per_slice_shape, builder, loc);
  llvm::SmallVector<int64_t, 8> starts_in_update(buffer_type.getRank(), 0);
  for (int64_t i = 0; i < updates_type.getDimSize(0); ++i) {
    auto index = builder.create<TF::SliceOp>(
        loc, ArrayRef<Type>{GetSizeType(builder)},
        ArrayRef<Value>{indices, GetR1Const({i}, builder, loc),
                        GetR1Const({1}, builder, loc)});
    auto old_slice =
        GetElement(index, buffer, builder, loc, /*keep_slice_shape=*/true);
    starts_in_update[0] = i;
    auto update_slice_starts = GetR1Const(starts_in_update, builder, loc);
    auto slice =
        builder
            .create<TF::SliceOp>(
                loc, ArrayRef<Type>{old_slice.getType()},
                ArrayRef<Value>{updates, update_slice_starts, slice_sizes})
            .output();
    slice = AccumulateBuffers(old_slice, slice, builder, loc);
    buffer = SetElement(index, buffer, slice, builder, loc);
  }
  return buffer;
}

}  // namespace collection_ops_util
}  // namespace TF
}  // namespace mlir