/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"

#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

using gtl::MakeCleanup;

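// Returns mlir::failure() from the enclosing function if `expr` does not
// yield an OK status; the error itself is logged by the IsOk overloads below.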
#define RETURN_FAILURE_IF_ERROR(expr) \
  if (!IsOk(expr)) {                  \
    return mlir::failure();           \
  }

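// Returns true if the status is OK; otherwise logs the error message at
// VLOG level 2 and returns false.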
static bool IsOk(const TF_Status* s) {
  if (TF_GetCode(s) == TF_OK) return true;
  VLOG(2) << TF_Message(s);
  return false;
}

static bool IsOk(const Status& s) {
  if (s.ok()) return true;
  VLOG(2) << s.error_message();
  return false;
}

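// Evaluates `inst` by exporting it to a NodeDef and executing it eagerly on
// the host CPU through `context`. All operands must be constants (non-null
// ElementsAttr); the evaluated outputs are converted back to ElementsAttr
// and appended to `results`.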
mlir::LogicalResult EvaluateOperation(
    mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
    TFE_Context* context, llvm::SmallVectorImpl<mlir::Attribute>* results) {
  if (!context) {
    VLOG(1) << "Can't evaluate with null context.";
    return mlir::failure();
  }
  // A null operand means the value is not a known constant, so evaluation
  // fails.
  // TODO(b/120678030): remove this constraint if we find operators that can
  // be evaluated with some unknown operands.
  if (std::any_of(operands.begin(), operands.end(),
                  [](mlir::Attribute operand) { return !operand; })) {
    VLOG(1) << "Can't evaluate since not all operands are constant.";
    return mlir::failure();
  }

  TF_Status* status = TF_NewStatus();
  auto clean_status = MakeCleanup([status] { TF_DeleteStatus(status); });

  // Builds TF operation and sets all the attributes.
  std::string node_name = "unnamed";
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("name")) {
    node_name = std::string(attr.getValue());
  }
  auto node_def_or = ConvertTFDialectOpToNodeDef(
      inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true);
  RETURN_FAILURE_IF_ERROR(node_def_or.status());
  const auto& node_def = node_def_or.ValueOrDie();

  TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status);
  RETURN_FAILURE_IF_ERROR(status);
  auto clean_op = MakeCleanup([op] { TFE_DeleteOp(op); });

  // Explicitly set the device to the host CPU instead of the device present
  // in the device attribute of the MLIR op. The assigned device might be
  // remote, unavailable during compilation, or a compilation-only device for
  // on-demand execution, which may cause recursion if used for constant
  // folding.
  constexpr char kHostCpu[] = "/job:localhost/replica:0/task:0/CPU:0";
  TFE_OpSetDevice(op, kHostCpu, status);
  RETURN_FAILURE_IF_ERROR(status);
  for (const auto& attr : node_def->attr()) {
    SetOpAttrValueScalar(context, op, attr.second, attr.first.c_str(), status);
    RETURN_FAILURE_IF_ERROR(status);
  }

  VLOG(1) << "Start to evaluate node: " << node_def->DebugString();

  // Adds inputs to the TF operation.
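  // Each constant ElementsAttr operand is converted to a tensorflow::Tensor,
  // wrapped in a TF_Tensor, and fed to the eager op as a tensor handle.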
  for (const auto operand : operands) {
    Tensor tensor;
    RETURN_FAILURE_IF_ERROR(ConvertToTensor(operand, &tensor));
    TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status->status);
    RETURN_FAILURE_IF_ERROR(status);
    auto clean_tensor =
        MakeCleanup([tf_tensor] { TF_DeleteTensor(tf_tensor); });
    TFE_TensorHandle* input_handle = TFE_NewTensorHandle(tf_tensor, status);
    RETURN_FAILURE_IF_ERROR(status);
    auto clean_input_handle =
        MakeCleanup([input_handle] { TFE_DeleteTensorHandle(input_handle); });
    TFE_OpAddInput(op, input_handle, status);
    RETURN_FAILURE_IF_ERROR(status);
  }

  // Executes the TF operation.
  int num_outputs = inst->getNumResults();
  absl::InlinedVector<TFE_TensorHandle*, 2> outputs(num_outputs);
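  // TFE_Execute treats `num_outputs` as in/out: on entry it is the capacity
  // of `outputs`, and on return the number of outputs actually produced.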
  TFE_Execute(op, outputs.data(), &num_outputs, status);
  RETURN_FAILURE_IF_ERROR(status);
  auto clean_outputs = MakeCleanup([&outputs] {
    for (TFE_TensorHandle* tensor_handle : outputs) {
      TFE_DeleteTensorHandle(tensor_handle);
    }
  });

  // Converts the outputs to MLIR attributes.
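  // Each output handle is resolved to a TF_Tensor, converted back to a
  // tensorflow::Tensor, and finally to an ElementsAttr.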
  mlir::Builder builder(inst->getContext());
  for (TFE_TensorHandle* tensor_handle : outputs) {
    TF_Tensor* tf_tensor = TFE_TensorHandleResolve(tensor_handle, status);
    RETURN_FAILURE_IF_ERROR(status);
    auto clean_tensor =
        MakeCleanup([tf_tensor] { TF_DeleteTensor(tf_tensor); });
    Tensor tensor;
    RETURN_FAILURE_IF_ERROR(TF_TensorToTensor(tf_tensor, &tensor));
    auto attr_or = ConvertTensor(tensor, &builder);
    RETURN_FAILURE_IF_ERROR(attr_or.status());
    results->push_back(attr_or.ValueOrDie());
  }

  VLOG(1) << "Evaluate node " << node_name << " successfully!";

  return mlir::success();
}

#undef RETURN_FAILURE_IF_ERROR
}  // namespace tensorflow