1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/iterator_range.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
41 #include "mlir/Dialect/Traits.h" // from @llvm-project
42 #include "mlir/IR/Attributes.h" // from @llvm-project
43 #include "mlir/IR/Builders.h" // from @llvm-project
44 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
45 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
46 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
47 #include "mlir/IR/Diagnostics.h" // from @llvm-project
48 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
49 #include "mlir/IR/Identifier.h" // from @llvm-project
50 #include "mlir/IR/Location.h" // from @llvm-project
51 #include "mlir/IR/MLIRContext.h" // from @llvm-project
52 #include "mlir/IR/Matchers.h" // from @llvm-project
53 #include "mlir/IR/OpDefinition.h" // from @llvm-project
54 #include "mlir/IR/OpImplementation.h" // from @llvm-project
55 #include "mlir/IR/PatternMatch.h" // from @llvm-project
56 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
57 #include "mlir/IR/Types.h" // from @llvm-project
58 #include "mlir/IR/Value.h" // from @llvm-project
59 #include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project
60 #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
61 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
62 #include "mlir/Parser.h" // from @llvm-project
63 #include "mlir/Support/LLVM.h" // from @llvm-project
64 #include "mlir/Support/LogicalResult.h" // from @llvm-project
65 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
70 #include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h"
71 #include "tensorflow/core/framework/op.h"
72 #include "tensorflow/core/framework/op_def_builder.h"
73 #include "tensorflow/core/platform/logging.h"
74 #include "tensorflow/core/util/device_name_utils.h"
75 #include "tensorflow/core/util/tensor_format.h"
76
77 // These are currently aliases and the alias will be removed, verified
78 // equivalent until then.
79 // TODO(b/178519687): Remove once addressed.
80 static_assert(std::is_same<tensorflow::int64, std::int64_t>::value,
81 "tensorflow::int64 is expected to match std::int64_t");
82
83 namespace mlir {
84 namespace TF {
85
86 //===----------------------------------------------------------------------===//
87 // TF Dialect Interfaces
88 //===----------------------------------------------------------------------===//
89
90 namespace {
91
92 struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterfacemlir::TF::__anonaa709cd90111::TFConstantFoldInterface93 TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
foldmlir::TF::__anonaa709cd90111::TFConstantFoldInterface94 LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
95 SmallVectorImpl<OpFoldResult> &results) const final {
96 return TensorFlowDialect::constantFold(op, operands, results);
97 }
98 };
99
100 struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
TFDecodeAttributesInterfacemlir::TF::__anonaa709cd90111::TFDecodeAttributesInterface101 TFDecodeAttributesInterface(Dialect *dialect)
102 : DialectDecodeAttributesInterface(dialect) {}
decodemlir::TF::__anonaa709cd90111::TFDecodeAttributesInterface103 LogicalResult decode(OpaqueElementsAttr input,
104 ElementsAttr &output) const override {
105 return TensorFlowDialect::decode(input, output);
106 }
107 };
108
109 // Helper function that implements the multi-device inlining policy behavior
110 // for the inliner hook. In particular, for all function body nodes set unset
111 // placement attributes to match the function call node.
MultiDeviceProcessInlinedCallBlocks(Operation * call,iterator_range<Region::iterator> inlinedBlocks)112 void MultiDeviceProcessInlinedCallBlocks(
113 Operation *call, iterator_range<Region::iterator> inlinedBlocks) {
114 using DeviceNameUtils = tensorflow::DeviceNameUtils;
115
116 // Duplicate of the logic in MultiDeviceFunctionBodyPlacer::BodyNodeDevice
117 // LINT.IfChange
118 auto device_id = Identifier::get("device", call->getContext());
119 auto caller_device = call->getAttrOfType<StringAttr>(device_id);
120 if (!caller_device) return;
121
122 DeviceNameUtils::ParsedName caller_parsed_device;
123 if (!DeviceNameUtils::ParseFullName(caller_device.getValue().str(),
124 &caller_parsed_device))
125 return;
126
127 MLIRContext *context = call->getContext();
128 auto node_device = [&](Operation *n) -> StringAttr {
129 auto device = n->getAttrOfType<StringAttr>(device_id);
130 if (!device || device.getValue().empty()) return caller_device;
131
132 DeviceNameUtils::ParsedName ndef_parsed_device;
133 if (!DeviceNameUtils::ParseFullName(device.getValue().str(),
134 &ndef_parsed_device))
135 return device;
136 DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device,
137 caller_parsed_device);
138 return StringAttr::get(
139 context, DeviceNameUtils::ParsedNameToString(ndef_parsed_device));
140 };
141 // LINT.ThenChange(../../../../core/common_runtime/inline_function_utils.cc)
142
143 for (Block &block : inlinedBlocks) {
144 block.walk([&](Operation *op) {
145 if (op->getDialect() == call->getDialect())
146 op->setAttr(device_id, node_device(op));
147 });
148 }
149 }
150
151 struct TFInlinerInterface : public DialectInlinerInterface {
152 using DialectInlinerInterface::DialectInlinerInterface;
153
154 //===--------------------------------------------------------------------===//
155 // Analysis Hooks
156 //===--------------------------------------------------------------------===//
157
158 // Returns if it's legal to inline 'callable' into the 'call', where 'call' is
159 // a TF operation.
isLegalToInlinemlir::TF::__anonaa709cd90111::TFInlinerInterface160 bool isLegalToInline(Operation *call, Operation *callable,
161 bool wouldBeCloned) const final {
162 // Check that the TF call operation is one that is legal to inline.
163 return !isa<TPUPartitionedCallOp>(call);
164 }
165
166 // Returns if its legal to inline 'src' region into the 'dest' region
167 // attached to a TF operation.
isLegalToInlinemlir::TF::__anonaa709cd90111::TFInlinerInterface168 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
169 BlockAndValueMapping &valueMapping) const final {
170 // Allow inlining in regions attached to region based control flow
171 // operations only if the src region is a single block region
172 return isa<IfRegionOp, WhileRegionOp>(dest->getParentOp()) &&
173 llvm::hasSingleElement(*src);
174 }
175
176 // Returns true if its legal to inline a TF operation `op` into the `dest`
177 // region.
isLegalToInlinemlir::TF::__anonaa709cd90111::TFInlinerInterface178 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
179 BlockAndValueMapping &) const final {
180 // An op is legal to inline if either of the following conditions is true:
181 // (a) Its legal to duplicate the Op.
182 // (b) The Op is inside a single use function. If that function is inlined,
183 // post inlining, the function will be dead and eliminated from the IR.
184 // So there won't be any code duplication.
185 // plus the function caller op can be replaced by inlined ops.
186 return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
187 }
188
189 //===--------------------------------------------------------------------===//
190 // Transformation Hooks
191 //===--------------------------------------------------------------------===//
192
193 // Attempts to materialize a conversion for a type mismatch between a call
194 // from this dialect, and a callable region. This method should generate an
195 // operation that takes 'input' as the only operand, and produces a single
196 // result of 'resultType'. If a conversion can not be generated, nullptr
197 // should be returned.
materializeCallConversionmlir::TF::__anonaa709cd90111::TFInlinerInterface198 Operation *materializeCallConversion(OpBuilder &builder, Value input,
199 Type result_type,
200 Location conversion_loc) const final {
201 if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
202 return nullptr;
203 return builder.create<TF::CastOp>(conversion_loc, result_type, input,
204 /*truncate=*/builder.getBoolAttr(false));
205 }
206
processInlinedCallBlocksmlir::TF::__anonaa709cd90111::TFInlinerInterface207 void processInlinedCallBlocks(
208 Operation *call,
209 iterator_range<Region::iterator> inlinedBlocks) const final {
210 bool has_lower_as_multi_device_function_attr = false;
211 if (auto lower = call->getAttrOfType<BoolAttr>(
212 tensorflow::LowerFunctionalOpsConstants::
213 kLowerAsMultiDeviceFunctionAttr))
214 has_lower_as_multi_device_function_attr = lower.getValue();
215 tensorflow::FunctionCallInlinePolicy policy =
216 tensorflow::GetFunctionCallInlinePolicy(
217 isa<PartitionedCallOp, StatefulPartitionedCallOp>(call),
218 has_lower_as_multi_device_function_attr);
219
220 if (policy == tensorflow::FunctionCallInlinePolicy::kMultiDevicePlacer)
221 return MultiDeviceProcessInlinedCallBlocks(call, inlinedBlocks);
222 }
223 };
224 } // end anonymous namespace
225
226 //===----------------------------------------------------------------------===//
227 // TF Dialect
228 //===----------------------------------------------------------------------===//
229
230 // Returns true if the op can be duplicated.
CanDuplicate(Operation * op)231 bool TensorFlowDialect::CanDuplicate(Operation *op) {
232 // If the op is marked with the cannot duplicate trait, it cannot be
233 // duplicated.
234 if (op->hasTrait<OpTrait::TF::CannotDuplicate>()) return false;
235
236 // If the op has no memory side effects, it can be duplicated.
237 if (MemoryEffectOpInterface::hasNoEffect(op)) return true;
238
239 // If the op is marked stateless using the `is_stateless` attribute, that
240 // attribute determines if the op can be duplicated.
241 if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
242 return is_stateless.getValue();
243
244 // Assume ops can be duplicated if modelled.
245 return op->isRegistered();
246 }
247
248 // TF dialect fallback for MemoryEffectOpInterface. The filtering for returning
249 // the interface is done in the return below and here it is empty as it is only
250 // returned for known not-stateful and unmodelled ops.
251 struct TensorFlowRegistryEffectInterfaceFallback
252 : public MemoryEffectOpInterface::FallbackModel<
253 TensorFlowRegistryEffectInterfaceFallback> {
classofmlir::TF::TensorFlowRegistryEffectInterfaceFallback254 static bool classof(Operation *op) { return true; }
getEffectsmlir::TF::TensorFlowRegistryEffectInterfaceFallback255 void getEffects(
256 Operation *op,
257 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
258 &effects) const {}
259 };
260
getRegisteredInterfaceForOp(mlir::TypeID interface,mlir::OperationName opName)261 void *TensorFlowDialect::getRegisteredInterfaceForOp(
262 mlir::TypeID interface, mlir::OperationName opName) {
263 if (interface == TypeID::get<mlir::MemoryEffectOpInterface>()) {
264 // Don't use fallback for modelled ops.
265 if (opName.getAbstractOperation()) return nullptr;
266
267 // Only use fallback interface for known not-stateful ops.
268 const tensorflow::OpRegistrationData *op_reg_data = nullptr;
269 tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp(
270 opName.stripDialect().str(), &op_reg_data);
271 return (s.ok() && !op_reg_data->op_def.is_stateful())
272 ? fallback_effect_op_interface_
273 : nullptr;
274 }
275
276 return nullptr;
277 }
278
279 // Returns true if the op can have side effects.
CanHaveSideEffects(Operation * op)280 bool TensorFlowDialect::CanHaveSideEffects(Operation *op) {
281 // If the op has no memory side effects, it has no side effects
282 if (MemoryEffectOpInterface::hasNoEffect(op)) return false;
283
284 // If the op is marked stateless using the `is_stateless` attribute, then
285 // it has no side effects.
286 if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
287 return !is_stateless.getValue();
288
289 // Terminators defined in the TF dialect do not have side effects.
290 if (op->hasTrait<OpTrait::IsTerminator>()) return false;
291
292 // Otherwise assume that the op can have side effects.
293 return true;
294 }
295
296 std::vector<TensorFlowDialect::AdditionalOpFunction>
GetAdditionalOperationHooks()297 *TensorFlowDialect::GetAdditionalOperationHooks() {
298 static auto *const additional_operation_hooks =
299 new std::vector<TensorFlowDialect::AdditionalOpFunction>();
300 return additional_operation_hooks;
301 }
302
303 TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
304 TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
305
TensorFlowDialect(MLIRContext * context)306 TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
307 : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
308 context->getOrLoadDialect<::mlir::tf_type::TFTypeDialect>();
309 addOperations<
310 #define GET_OP_LIST
311 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc"
312 >();
313 addOperations<
314 #define GET_OP_LIST
315 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
316 >();
317 addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
318 TFConstantFoldInterface>();
319 fallback_effect_op_interface_ =
320 new TensorFlowRegistryEffectInterfaceFallback();
321
322 // Support unknown operations because not all TensorFlow operations are
323 // registered.
324 allowUnknownOperations();
325
326 for (const auto &hook : *GetAdditionalOperationHooks()) {
327 hook(*this);
328 }
329 }
330
~TensorFlowDialect()331 TensorFlowDialect::~TensorFlowDialect() {
332 delete fallback_effect_op_interface_;
333 }
334
parseType(DialectAsmParser & parser) const335 Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
336 StringRef spec = parser.getFullSymbolSpec();
337 llvm::SMLoc loc = parser.getCurrentLocation();
338 parser.emitError(
339 loc, "tf dialect has no types, potentially meant !tf_type." + spec);
340 return nullptr;
341 }
342
parseAttribute(DialectAsmParser & parser,Type type) const343 Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
344 Type type) const {
345 StringRef spec = parser.getFullSymbolSpec();
346 llvm::SMLoc loc = parser.getCurrentLocation();
347 parser.emitError(
348 loc, "tf dialect has no attributes, potentially meant #tf_type." + spec);
349 return nullptr;
350 }
351
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)352 Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
353 Attribute value, Type type,
354 Location loc) {
355 return builder.create<ConstOp>(loc, type, value);
356 }
357
358 } // namespace TF
359 } // namespace mlir
360