• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/iterator_range.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
41 #include "mlir/Dialect/Traits.h"  // from @llvm-project
42 #include "mlir/IR/Attributes.h"  // from @llvm-project
43 #include "mlir/IR/Builders.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
45 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
46 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
47 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
48 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
49 #include "mlir/IR/Identifier.h"  // from @llvm-project
50 #include "mlir/IR/Location.h"  // from @llvm-project
51 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
52 #include "mlir/IR/Matchers.h"  // from @llvm-project
53 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
54 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
55 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
56 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
57 #include "mlir/IR/Types.h"  // from @llvm-project
58 #include "mlir/IR/Value.h"  // from @llvm-project
59 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"  // from @llvm-project
60 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
61 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
62 #include "mlir/Parser.h"  // from @llvm-project
63 #include "mlir/Support/LLVM.h"  // from @llvm-project
64 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
65 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
70 #include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h"
71 #include "tensorflow/core/framework/op.h"
72 #include "tensorflow/core/framework/op_def_builder.h"
73 #include "tensorflow/core/platform/logging.h"
74 #include "tensorflow/core/util/device_name_utils.h"
75 #include "tensorflow/core/util/tensor_format.h"
76 
77 // These are currently aliases and the alias will be removed, verified
78 // equivalent until then.
79 // TODO(b/178519687): Remove once addressed.
80 static_assert(std::is_same<tensorflow::int64, std::int64_t>::value,
81               "tensorflow::int64 is expected to match std::int64_t");
82 
83 namespace mlir {
84 namespace TF {
85 
86 //===----------------------------------------------------------------------===//
87 // TF Dialect Interfaces
88 //===----------------------------------------------------------------------===//
89 
90 namespace {
91 
92 struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterfacemlir::TF::__anonaa709cd90111::TFConstantFoldInterface93   TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
foldmlir::TF::__anonaa709cd90111::TFConstantFoldInterface94   LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
95                      SmallVectorImpl<OpFoldResult> &results) const final {
96     return TensorFlowDialect::constantFold(op, operands, results);
97   }
98 };
99 
100 struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
TFDecodeAttributesInterfacemlir::TF::__anonaa709cd90111::TFDecodeAttributesInterface101   TFDecodeAttributesInterface(Dialect *dialect)
102       : DialectDecodeAttributesInterface(dialect) {}
decodemlir::TF::__anonaa709cd90111::TFDecodeAttributesInterface103   LogicalResult decode(OpaqueElementsAttr input,
104                        ElementsAttr &output) const override {
105     return TensorFlowDialect::decode(input, output);
106   }
107 };
108 
109 // Helper function that implements the multi-device inlining policy behavior
110 // for the inliner hook. In particular, for all function body nodes set unset
111 // placement attributes to match the function call node.
MultiDeviceProcessInlinedCallBlocks(Operation * call,iterator_range<Region::iterator> inlinedBlocks)112 void MultiDeviceProcessInlinedCallBlocks(
113     Operation *call, iterator_range<Region::iterator> inlinedBlocks) {
114   using DeviceNameUtils = tensorflow::DeviceNameUtils;
115 
116   // Duplicate of the logic in MultiDeviceFunctionBodyPlacer::BodyNodeDevice
117   // LINT.IfChange
118   auto device_id = Identifier::get("device", call->getContext());
119   auto caller_device = call->getAttrOfType<StringAttr>(device_id);
120   if (!caller_device) return;
121 
122   DeviceNameUtils::ParsedName caller_parsed_device;
123   if (!DeviceNameUtils::ParseFullName(caller_device.getValue().str(),
124                                       &caller_parsed_device))
125     return;
126 
127   MLIRContext *context = call->getContext();
128   auto node_device = [&](Operation *n) -> StringAttr {
129     auto device = n->getAttrOfType<StringAttr>(device_id);
130     if (!device || device.getValue().empty()) return caller_device;
131 
132     DeviceNameUtils::ParsedName ndef_parsed_device;
133     if (!DeviceNameUtils::ParseFullName(device.getValue().str(),
134                                         &ndef_parsed_device))
135       return device;
136     DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device,
137                                         caller_parsed_device);
138     return StringAttr::get(
139         context, DeviceNameUtils::ParsedNameToString(ndef_parsed_device));
140   };
141   // LINT.ThenChange(../../../../core/common_runtime/inline_function_utils.cc)
142 
143   for (Block &block : inlinedBlocks) {
144     block.walk([&](Operation *op) {
145       if (op->getDialect() == call->getDialect())
146         op->setAttr(device_id, node_device(op));
147     });
148   }
149 }
150 
151 struct TFInlinerInterface : public DialectInlinerInterface {
152   using DialectInlinerInterface::DialectInlinerInterface;
153 
154   //===--------------------------------------------------------------------===//
155   // Analysis Hooks
156   //===--------------------------------------------------------------------===//
157 
158   // Returns if it's legal to inline 'callable' into the 'call', where 'call' is
159   // a TF operation.
isLegalToInlinemlir::TF::__anonaa709cd90111::TFInlinerInterface160   bool isLegalToInline(Operation *call, Operation *callable,
161                        bool wouldBeCloned) const final {
162     // Check that the TF call operation is one that is legal to inline.
163     return !isa<TPUPartitionedCallOp>(call);
164   }
165 
166   // Returns if its legal to inline 'src' region into the 'dest' region
167   // attached to a TF operation.
isLegalToInlinemlir::TF::__anonaa709cd90111::TFInlinerInterface168   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
169                        BlockAndValueMapping &valueMapping) const final {
170     // Allow inlining in regions attached to region based control flow
171     // operations only if the src region is a single block region
172     return isa<IfRegionOp, WhileRegionOp>(dest->getParentOp()) &&
173            llvm::hasSingleElement(*src);
174   }
175 
176   // Returns true if its legal to inline a TF operation `op` into the `dest`
177   // region.
isLegalToInlinemlir::TF::__anonaa709cd90111::TFInlinerInterface178   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
179                        BlockAndValueMapping &) const final {
180     // An op is legal to inline if either of the following conditions is true:
181     // (a) Its legal to duplicate the Op.
182     // (b) The Op is inside a single use function. If that function is inlined,
183     //     post inlining, the function will be dead and eliminated from the IR.
184     //     So there won't be any code duplication.
185     // plus the function caller op can be replaced by inlined ops.
186     return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
187   }
188 
189   //===--------------------------------------------------------------------===//
190   // Transformation Hooks
191   //===--------------------------------------------------------------------===//
192 
193   // Attempts to materialize a conversion for a type mismatch between a call
194   // from this dialect, and a callable region. This method should generate an
195   // operation that takes 'input' as the only operand, and produces a single
196   // result of 'resultType'. If a conversion can not be generated, nullptr
197   // should be returned.
materializeCallConversionmlir::TF::__anonaa709cd90111::TFInlinerInterface198   Operation *materializeCallConversion(OpBuilder &builder, Value input,
199                                        Type result_type,
200                                        Location conversion_loc) const final {
201     if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
202       return nullptr;
203     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
204                                       /*truncate=*/builder.getBoolAttr(false));
205   }
206 
processInlinedCallBlocksmlir::TF::__anonaa709cd90111::TFInlinerInterface207   void processInlinedCallBlocks(
208       Operation *call,
209       iterator_range<Region::iterator> inlinedBlocks) const final {
210     bool has_lower_as_multi_device_function_attr = false;
211     if (auto lower = call->getAttrOfType<BoolAttr>(
212             tensorflow::LowerFunctionalOpsConstants::
213                 kLowerAsMultiDeviceFunctionAttr))
214       has_lower_as_multi_device_function_attr = lower.getValue();
215     tensorflow::FunctionCallInlinePolicy policy =
216         tensorflow::GetFunctionCallInlinePolicy(
217             isa<PartitionedCallOp, StatefulPartitionedCallOp>(call),
218             has_lower_as_multi_device_function_attr);
219 
220     if (policy == tensorflow::FunctionCallInlinePolicy::kMultiDevicePlacer)
221       return MultiDeviceProcessInlinedCallBlocks(call, inlinedBlocks);
222   }
223 };
224 }  // end anonymous namespace
225 
226 //===----------------------------------------------------------------------===//
227 // TF Dialect
228 //===----------------------------------------------------------------------===//
229 
230 // Returns true if the op can be duplicated.
CanDuplicate(Operation * op)231 bool TensorFlowDialect::CanDuplicate(Operation *op) {
232   // If the op is marked with the cannot duplicate trait, it cannot be
233   // duplicated.
234   if (op->hasTrait<OpTrait::TF::CannotDuplicate>()) return false;
235 
236   // If the op has no memory side effects, it can be duplicated.
237   if (MemoryEffectOpInterface::hasNoEffect(op)) return true;
238 
239   // If the op is marked stateless using the `is_stateless` attribute, that
240   // attribute determines if the op can be duplicated.
241   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
242     return is_stateless.getValue();
243 
244   // Assume ops can be duplicated if modelled.
245   return op->isRegistered();
246 }
247 
248 // TF dialect fallback for MemoryEffectOpInterface. The filtering for returning
249 // the interface is done in the return below and here it is empty as it is only
250 // returned for known not-stateful and unmodelled ops.
251 struct TensorFlowRegistryEffectInterfaceFallback
252     : public MemoryEffectOpInterface::FallbackModel<
253           TensorFlowRegistryEffectInterfaceFallback> {
classofmlir::TF::TensorFlowRegistryEffectInterfaceFallback254   static bool classof(Operation *op) { return true; }
getEffectsmlir::TF::TensorFlowRegistryEffectInterfaceFallback255   void getEffects(
256       Operation *op,
257       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
258           &effects) const {}
259 };
260 
getRegisteredInterfaceForOp(mlir::TypeID interface,mlir::OperationName opName)261 void *TensorFlowDialect::getRegisteredInterfaceForOp(
262     mlir::TypeID interface, mlir::OperationName opName) {
263   if (interface == TypeID::get<mlir::MemoryEffectOpInterface>()) {
264     // Don't use fallback for modelled ops.
265     if (opName.getAbstractOperation()) return nullptr;
266 
267     // Only use fallback interface for known not-stateful ops.
268     const tensorflow::OpRegistrationData *op_reg_data = nullptr;
269     tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp(
270         opName.stripDialect().str(), &op_reg_data);
271     return (s.ok() && !op_reg_data->op_def.is_stateful())
272                ? fallback_effect_op_interface_
273                : nullptr;
274   }
275 
276   return nullptr;
277 }
278 
279 // Returns true if the op can have side effects.
CanHaveSideEffects(Operation * op)280 bool TensorFlowDialect::CanHaveSideEffects(Operation *op) {
281   // If the op has no memory side effects, it has no side effects
282   if (MemoryEffectOpInterface::hasNoEffect(op)) return false;
283 
284   // If the op is marked stateless using the `is_stateless` attribute, then
285   // it has no side effects.
286   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
287     return !is_stateless.getValue();
288 
289   // Terminators defined in the TF dialect do not have side effects.
290   if (op->hasTrait<OpTrait::IsTerminator>()) return false;
291 
292   // Otherwise assume that the op can have side effects.
293   return true;
294 }
295 
296 std::vector<TensorFlowDialect::AdditionalOpFunction>
GetAdditionalOperationHooks()297     *TensorFlowDialect::GetAdditionalOperationHooks() {
298   static auto *const additional_operation_hooks =
299       new std::vector<TensorFlowDialect::AdditionalOpFunction>();
300   return additional_operation_hooks;
301 }
302 
303 TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
304 TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
305 
TensorFlowDialect(MLIRContext * context)306 TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
307     : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
308   context->getOrLoadDialect<::mlir::tf_type::TFTypeDialect>();
309   addOperations<
310 #define GET_OP_LIST
311 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc"
312       >();
313   addOperations<
314 #define GET_OP_LIST
315 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
316       >();
317   addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
318                 TFConstantFoldInterface>();
319   fallback_effect_op_interface_ =
320       new TensorFlowRegistryEffectInterfaceFallback();
321 
322   // Support unknown operations because not all TensorFlow operations are
323   // registered.
324   allowUnknownOperations();
325 
326   for (const auto &hook : *GetAdditionalOperationHooks()) {
327     hook(*this);
328   }
329 }
330 
~TensorFlowDialect()331 TensorFlowDialect::~TensorFlowDialect() {
332   delete fallback_effect_op_interface_;
333 }
334 
parseType(DialectAsmParser & parser) const335 Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
336   StringRef spec = parser.getFullSymbolSpec();
337   llvm::SMLoc loc = parser.getCurrentLocation();
338   parser.emitError(
339       loc, "tf dialect has no types, potentially meant !tf_type." + spec);
340   return nullptr;
341 }
342 
parseAttribute(DialectAsmParser & parser,Type type) const343 Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
344                                             Type type) const {
345   StringRef spec = parser.getFullSymbolSpec();
346   llvm::SMLoc loc = parser.getCurrentLocation();
347   parser.emitError(
348       loc, "tf dialect has no attributes, potentially meant #tf_type." + spec);
349   return nullptr;
350 }
351 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)352 Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
353                                                   Attribute value, Type type,
354                                                   Location loc) {
355   return builder.create<ConstOp>(loc, type, value);
356 }
357 
358 }  // namespace TF
359 }  // namespace mlir
360