1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
18
19 #include <string>
20
21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/OpImplementation.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30
31 namespace mlir {
32 namespace TF {
33
34 //===----------------------------------------------------------------------===//
35 // TensorFlow Contraction Fusion.
36 //===----------------------------------------------------------------------===//
37
38 struct ContractionFusion {
39 explicit ContractionFusion(
40 StringRef output_kernel, ArrayRef<int> additional_arguments = {},
41 ArrayRef<NamedAttribute> additional_attributes = {})
42 : output_kernel(output_kernel.str()),
43 additional_arguments(additional_arguments.begin(),
44 additional_arguments.end()),
45 additional_attributes(additional_attributes.begin(),
46 additional_attributes.end()) {}
47
48 // Name of the output kernel implementing the contraction fusion.
49 std::string output_kernel;
50
51 // Indices of additional arguments that will be forwarded to the fused
52 // operation (e.g. forward bias vector if fusing BiasAdd operation).
53 SmallVector<int, 4> additional_arguments;
54
55 // Add additional attributes to the fused node.
56 SmallVector<NamedAttribute, 4> additional_attributes;
57 };
58
59 //===----------------------------------------------------------------------===//
60 // TensorFlow Resource Handles.
61 //===----------------------------------------------------------------------===//
62
IsResourceHandleAnonymous(StringRef name)63 inline bool IsResourceHandleAnonymous(StringRef name) {
64 return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME;
65 }
66
67 // Helper struct representing an identifier for a resource handle. For resource
68 // handles created explicitly and shared across resource allocator ops,
69 // `container`, `name`, and `device` can be set. If an resource handle is tied
70 // to an instance of an operation (e.g. TensorFlow runtime operation caching),
71 // `op` can be set instead.
72 struct ResourceHandle {
ResourceHandleResourceHandle73 ResourceHandle(StringRef container, StringRef name, StringRef device,
74 Operation* op)
75 : container(container), name(name), device(device), op(op) {}
76
77 bool operator==(const ResourceHandle& rhs) const {
78 return container == rhs.container && name == rhs.name &&
79 device == rhs.device && op == rhs.op;
80 }
81
82 // Make ResourceHandle hashable.
83 friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
84
85 StringRef container;
86 StringRef name;
87 StringRef device;
88 Operation* op = nullptr;
89 };
90
91 // Make ResourceHandle hashable.
hash_value(const ResourceHandle & resource_handle)92 inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
93 return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
94 resource_handle.device, resource_handle.op);
95 }
96
97 // Helper struct holding a resource handle value and unique id associated to the
98 // resource handle.
99 struct ResourceHandleValueAndId {
ResourceHandleValueAndIdResourceHandleValueAndId100 ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
101
102 Value value;
103 int64_t id = -1;
104 };
105
106 //===----------------------------------------------------------------------===//
107 // TF op helper functions for handling resource handles and ids.
108 //===----------------------------------------------------------------------===//
109
110 // Returns device of op if present. If op has no device set, an empty string ref
111 // is returned instead.
112 llvm::StringRef GetDeviceOrEmpty(Operation* op);
113
114 // Returns resource handle value and id for resource op based on attributes. If
115 // a resource handle is anonymous, a new id is always returned.
116 ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
117 llvm::StringRef container, llvm::StringRef shared_name,
118 llvm::StringRef device, Value resource,
119 llvm::SmallDenseMap<ResourceHandle, int64_t>& resource_handle_id_map,
120 int64_t& next_id);
121
122 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
123 } // namespace TF
124 } // namespace mlir
125
126 namespace llvm {
127 template <>
128 struct DenseMapInfo<mlir::TF::ResourceHandle> {
129 static mlir::TF::ResourceHandle getEmptyKey() {
130 return {/*container=*/"", /*name=*/"", /*device=*/"",
131 /*op=*/DenseMapInfo<mlir::Operation*>::getEmptyKey()};
132 }
133
134 static mlir::TF::ResourceHandle getTombstoneKey() {
135 return {/*container=*/"", /*name=*/"", /*device=*/"",
136 /*op=*/DenseMapInfo<mlir::Operation*>::getTombstoneKey()};
137 }
138
139 static unsigned getHashValue(
140 const mlir::TF::ResourceHandle& resource_handle) {
141 return mlir::TF::hash_value(resource_handle);
142 }
143
144 static bool isEqual(const mlir::TF::ResourceHandle& lhs,
145 const mlir::TF::ResourceHandle& rhs) {
146 return lhs == rhs;
147 }
148 };
149 } // namespace llvm
150
151 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
152