1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"

#include <algorithm>
#include <string>

#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"
36
37 namespace tensorflow {
38
39 using gtl::MakeCleanup;
40
// Checks `expr` (a TF_Status* or a Status) with IsOk() and returns
// mlir::failure() from the enclosing function when it is not OK.
// The do/while(0) wrapper makes the expansion a single statement, so the
// macro is safe inside an unbraced if/else and requires a trailing semicolon.
#define RETURN_FAILURE_IF_ERROR(expr) \
  do {                                \
    if (!IsOk(expr)) {                \
      return mlir::failure();         \
    }                                 \
  } while (0)
45
IsOk(const TF_Status * s)46 static bool IsOk(const TF_Status* s) {
47 if (TF_GetCode(s) == TF_OK) return true;
48 VLOG(2) << TF_Message(s);
49 return false;
50 }
51
IsOk(const Status & s)52 static bool IsOk(const Status& s) {
53 if (s.ok()) return true;
54 VLOG(2) << s.error_message();
55 return false;
56 }
57
EvaluateOperation(mlir::Operation * inst,llvm::ArrayRef<mlir::ElementsAttr> operands,TFE_Context * context,llvm::SmallVectorImpl<mlir::Attribute> * results)58 mlir::LogicalResult EvaluateOperation(
59 mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
60 TFE_Context* context, llvm::SmallVectorImpl<mlir::Attribute>* results) {
61 if (!context) {
62 VLOG(1) << "Can't evaluate with null context.";
63 return mlir::failure();
64 }
65 // If any operand is nullptr returns true for a failure.
66 // TODO(b/120678030): remove this constraint if we find operators can be
67 // evaluated with some unknown operands.
68 if (std::any_of(operands.begin(), operands.end(),
69 [](mlir::Attribute operand) { return !operand; })) {
70 VLOG(1) << "Can't evaluate since not all operands are constant.";
71 return mlir::failure();
72 }
73
74 TF_Status* status = TF_NewStatus();
75 auto clean_status = MakeCleanup([status] { TF_DeleteStatus(status); });
76
77 // Builds TF operation and sets all the attributes.
78 std::string node_name = "unnamed";
79 if (auto attr = inst->getAttrOfType<mlir::StringAttr>("name")) {
80 node_name = std::string(attr.getValue());
81 }
82 auto node_def_or = ConvertTFDialectOpToNodeDef(
83 inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true);
84 RETURN_FAILURE_IF_ERROR(node_def_or.status());
85 const auto& node_def = node_def_or.ValueOrDie();
86
87 TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status);
88 RETURN_FAILURE_IF_ERROR(status);
89 auto clean_op = MakeCleanup([op] { TFE_DeleteOp(op); });
90
91 // Explicitly set device to Host CPU instead of the device present in device
92 // attribute of the MLIR op. The assigned device might be remote, not
93 // available during compilation or compilation only device for on demand
94 // execution which may create a recursion if used for constant folding.
95 constexpr char kHostCpu[] = "/job:localhost/replica:0/task:0/CPU:0";
96 TFE_OpSetDevice(op, kHostCpu, status);
97 RETURN_FAILURE_IF_ERROR(status);
98 for (const auto& attr : node_def->attr()) {
99 SetOpAttrValueScalar(context, op, attr.second, attr.first.c_str(), status);
100 RETURN_FAILURE_IF_ERROR(status);
101 }
102
103 VLOG(1) << "Start to evaluate node: " << node_def->DebugString();
104
105 // Adds inputs to the TF operation.
106 for (const auto operand : operands) {
107 Tensor tensor;
108 RETURN_FAILURE_IF_ERROR(ConvertToTensor(operand, &tensor));
109 TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status->status);
110 RETURN_FAILURE_IF_ERROR(status);
111 auto clean_tensor =
112 MakeCleanup([tf_tensor] { TF_DeleteTensor(tf_tensor); });
113 TFE_TensorHandle* input_handle = TFE_NewTensorHandle(tf_tensor, status);
114 RETURN_FAILURE_IF_ERROR(status);
115 auto clean_input_handle =
116 MakeCleanup([input_handle] { TFE_DeleteTensorHandle(input_handle); });
117 TFE_OpAddInput(op, input_handle, status);
118 RETURN_FAILURE_IF_ERROR(status);
119 }
120
121 // Executes the TF operation.
122 int num_outputs = inst->getNumResults();
123 absl::InlinedVector<TFE_TensorHandle*, 2> outputs(num_outputs);
124 TFE_Execute(op, outputs.data(), &num_outputs, status);
125 RETURN_FAILURE_IF_ERROR(status);
126 auto clean_outputs = MakeCleanup([&outputs] {
127 for (TFE_TensorHandle* tensor_handle : outputs) {
128 TFE_DeleteTensorHandle(tensor_handle);
129 }
130 });
131
132 // Converts the outputs to MLIR attributes.
133 mlir::Builder builder(inst->getContext());
134 for (TFE_TensorHandle* tensor_handle : outputs) {
135 TF_Tensor* tf_tensor = TFE_TensorHandleResolve(tensor_handle, status);
136 RETURN_FAILURE_IF_ERROR(status);
137 auto clean_tensor =
138 MakeCleanup([tf_tensor] { TF_DeleteTensor(tf_tensor); });
139 Tensor tensor;
140 RETURN_FAILURE_IF_ERROR(TF_TensorToTensor(tf_tensor, &tensor));
141 auto attr_or = ConvertTensor(tensor, &builder);
142 RETURN_FAILURE_IF_ERROR(attr_or.status());
143 results->push_back(attr_or.ValueOrDie());
144 }
145
146 VLOG(1) << "Evaluate node " << node_name << " successfully!";
147
148 return mlir::success();
149 }
150
151 #undef RETURN_FAILURE_IF_ERROR
152 } // namespace tensorflow
153