• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
18 
19 #include <string>
20 
21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30 
31 namespace mlir {
32 namespace TF {
33 
34 //===----------------------------------------------------------------------===//
35 // TensorFlow Contraction Fusion.
36 //===----------------------------------------------------------------------===//
37 
38 struct ContractionFusion {
39   explicit ContractionFusion(
40       StringRef output_kernel, ArrayRef<int> additional_arguments = {},
41       ArrayRef<NamedAttribute> additional_attributes = {})
42       : output_kernel(output_kernel.str()),
43         additional_arguments(additional_arguments.begin(),
44                              additional_arguments.end()),
45         additional_attributes(additional_attributes.begin(),
46                               additional_attributes.end()) {}
47 
48   // Name of the output kernel implementing the contraction fusion.
49   std::string output_kernel;
50 
51   // Indices of additional arguments that will be forwarded to the fused
52   // operation (e.g. forward bias vector if fusing BiasAdd operation).
53   SmallVector<int, 4> additional_arguments;
54 
55   // Add additional attributes to the fused node.
56   SmallVector<NamedAttribute, 4> additional_attributes;
57 };
58 
59 //===----------------------------------------------------------------------===//
60 // TensorFlow Resource Handles.
61 //===----------------------------------------------------------------------===//
62 
IsResourceHandleAnonymous(StringRef name)63 inline bool IsResourceHandleAnonymous(StringRef name) {
64   return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME;
65 }
66 
67 // Helper struct representing an identifier for a resource handle. For resource
68 // handles created explicitly and shared across resource allocator ops,
69 // `container`, `name`, and `device` can be set. If an resource handle is tied
70 // to an instance of an operation (e.g. TensorFlow runtime operation caching),
71 // `op` can be set instead.
72 struct ResourceHandle {
ResourceHandleResourceHandle73   ResourceHandle(StringRef container, StringRef name, StringRef device,
74                  Operation* op)
75       : container(container), name(name), device(device), op(op) {}
76 
77   bool operator==(const ResourceHandle& rhs) const {
78     return container == rhs.container && name == rhs.name &&
79            device == rhs.device && op == rhs.op;
80   }
81 
82   // Make ResourceHandle hashable.
83   friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
84 
85   StringRef container;
86   StringRef name;
87   StringRef device;
88   Operation* op = nullptr;
89 };
90 
91 // Make ResourceHandle hashable.
hash_value(const ResourceHandle & resource_handle)92 inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
93   return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
94                               resource_handle.device, resource_handle.op);
95 }
96 
97 // Helper struct holding a resource handle value and unique id associated to the
98 // resource handle.
99 struct ResourceHandleValueAndId {
ResourceHandleValueAndIdResourceHandleValueAndId100   ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
101 
102   Value value;
103   int64_t id = -1;
104 };
105 
106 //===----------------------------------------------------------------------===//
107 // TF op helper functions for handling resource handles and ids.
108 //===----------------------------------------------------------------------===//
109 
110 // Returns device of op if present. If op has no device set, an empty string ref
111 // is returned instead.
112 llvm::StringRef GetDeviceOrEmpty(Operation* op);
113 
114 // Returns resource handle value and id for resource op based on attributes. If
115 // a resource handle is anonymous, a new id is always returned.
116 ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
117     llvm::StringRef container, llvm::StringRef shared_name,
118     llvm::StringRef device, Value resource,
119     llvm::SmallDenseMap<ResourceHandle, int64_t>& resource_handle_id_map,
120     int64_t& next_id);
121 
122 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
123 }  // namespace TF
124 }  // namespace mlir
125 
126 namespace llvm {
127 template <>
128 struct DenseMapInfo<mlir::TF::ResourceHandle> {
129   static mlir::TF::ResourceHandle getEmptyKey() {
130     return {/*container=*/"", /*name=*/"", /*device=*/"",
131             /*op=*/DenseMapInfo<mlir::Operation*>::getEmptyKey()};
132   }
133 
134   static mlir::TF::ResourceHandle getTombstoneKey() {
135     return {/*container=*/"", /*name=*/"", /*device=*/"",
136             /*op=*/DenseMapInfo<mlir::Operation*>::getTombstoneKey()};
137   }
138 
139   static unsigned getHashValue(
140       const mlir::TF::ResourceHandle& resource_handle) {
141     return mlir::TF::hash_value(resource_handle);
142   }
143 
144   static bool isEqual(const mlir::TF::ResourceHandle& lhs,
145                       const mlir::TF::ResourceHandle& rhs) {
146     return lhs == rhs;
147   }
148 };
149 }  // namespace llvm
150 
151 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
152