/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <memory>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace mlir {
namespace TF {
using tensorflow::AbstractFunction;
using tensorflow::AbstractOperation;
using tensorflow::AbstractTensorHandle;
using tensorflow::AbstractTensorInterface;
using tensorflow::dyn_cast;
using tensorflow::OutputList;
using tensorflow::string;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::errors::Unimplemented;
using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;

namespace {

void RegisterDialects(mlir::MLIRContext& ctx) {
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  ctx.appendDialectRegistry(registry);
  ctx.loadAllAvailableDialects();
}

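// Converts a TF DataType to the corresponding MLIR element type and wraps it
// in an unranked tensor type.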
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
                               Type* type) {
  Status s = tensorflow::ConvertDataType(dtype, builder, type);
  if (s.ok()) *type = UnrankedTensorType::get(*type);
  return s;
}

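// MlirTensor wraps the mlir::Value produced while tracing; it is the symbolic
// tensor handle flowing between the ops recorded into the traced function.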
class MlirTensor : public TracingTensorHandle {
 public:
  explicit MlirTensor(Value value)
      : TracingTensorHandle(kMlir), value_(value) {}

  tensorflow::DataType DataType() const override {
    tensorflow::DataType type;
    Status s = ConvertToDataType(value_.getType(), &type);
    if (!s.ok()) {
      return tensorflow::DT_INVALID;
    }
    return type;
  }

  tensorflow::Status Shape(
      tensorflow::PartialTensorShape* shape) const override {
    // TODO(b/173074167): Implement this and enable tests in
    // unified_api_test.cc.
    return Unimplemented("MlirTensor::Shape is not implemented yet.");
  }

  Value getValue() { return value_; }
  Type getElementType() {
    return value_.getType().cast<ShapedType>().getElementType();
  }

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kMlir;
  }

 private:
  Value value_;
};

class MlirFunctionContext;

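// MlirAbstractOp accumulates the operands and attributes set through the
// tracing API and materializes them into an MLIR OperationState when
// Execute() is called; result types are inferred from the OpDef at that point.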
class MlirAbstractOp : public TracingOperation {
 public:
  explicit MlirAbstractOp(MLIRContext* context,
                          MlirFunctionContext* function_context)
      : TracingOperation(kMlir),
        context_(context),
        function_context_(function_context) {}

  void Release() override { delete this; }

  Status Reset(const char* op, const char* raw_device_name) override;

  const string& Name() const override;

  const string& DeviceName() const override;

  Status SetDeviceName(const char* name) override;

  Status AddInput(AbstractTensorHandle* input) override;
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override;

  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override;
  Status SetAttrInt(const char* attr_name, int64_t value) override;
  Status SetAttrFloat(const char* attr_name, float value) override;
  Status SetAttrBool(const char* attr_name, bool value) override;
  Status SetAttrType(const char* attr_name,
                     tensorflow::DataType dtype) override;
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override;
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override;
  Status SetAttrFunctionName(const char* attr_name, const char* value,
                             size_t length) override;
  Status SetAttrTensor(const char* attr_name,
                       AbstractTensorInterface* tensor) override;
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override;
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override;
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override;
  Status SetAttrTypeList(const char* attr_name,
                         const tensorflow::DataType* values,
                         int num_values) override;
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override;
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override;
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override;

  Status SetOpName(const char* const op_name) override;

  MLIRContext* GetContext() { return context_; }

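  // Wraps the element type of `type` in a TF ref type and returns the
  // resulting tensor type through `output_type`.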
  Status AddRef(Type type, Type* output_type);

  Status Create(ArrayRef<Value> operands, OperationState**);

  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kMlir;
  }

 private:
  // Returns true if there are still unfilled ODS slots for adding more inputs.
  bool IsNextODSArgAvailable();

  MLIRContext* context_;
  MlirFunctionContext* function_context_;
  SmallVector<Value, 8> operands_;
  llvm::StringMap<Attribute> attrs_;
  std::unique_ptr<OperationState> state_;
  // This is the index of the next ODS operand that will be added with AddInput
  // or AddInputList.
  int current_ods_input_ = 0;
  const tensorflow::OpDef* op_def_ = nullptr;
  const char* op_name_ = nullptr;
  string tf_op_type_;
  // TODO(srbs): Use this.
  string device_name_;
};

// MlirFunction is a thin wrapper over a FuncOp.
class MlirFunction : public AbstractFunction {
 public:
  explicit MlirFunction(std::unique_ptr<MLIRContext> context,
                        OwningModuleRef module, FuncOp func)
      : AbstractFunction(kMlir),
        context_(std::move(context)),
        module_(std::move(module)),
        func_(func) {}

  Status GetFunctionDef(tensorflow::FunctionDef** f) override;

  // For LLVM style RTTI.
  static bool classof(const AbstractFunction* ptr) {
    return ptr->getKind() == kMlir;
  }

 private:
  std::unique_ptr<MLIRContext> context_;
  OwningModuleRef module_;
  FuncOp func_;
  std::unique_ptr<tensorflow::FunctionDef> fdef_;
};

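// MlirFunctionContext owns a fresh MLIRContext and module and traces eager
// API calls into the body of a single FuncOp; Finalize() hands the result off
// as a MlirFunction that can be exported to a FunctionDef.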
class MlirFunctionContext : public TracingContext {
 public:
  explicit MlirFunctionContext(const char* name)
      : TracingContext(kMlir),
        context_(std::make_unique<MLIRContext>()),
        builder_(context_.get()) {
    RegisterDialects(*context_);
    // TODO(aminim) figure out the location story here
    module_ = ModuleOp::create(builder_.getUnknownLoc());
    func_ = FuncOp::create(builder_.getUnknownLoc(), name,
                           builder_.getFunctionType(llvm::None, llvm::None));
    module_->push_back(func_);
    builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
  }

  void Release() override { delete this; }

  AbstractOperation* CreateOperation() override {
    return new MlirAbstractOp(context_.get(), this);
  }
  Status AddParameter(tensorflow::DataType dtype,
                      const tensorflow::PartialTensorShape& shape,
                      TracingTensorHandle** handle) override;

  Status Finalize(OutputList* outputs, AbstractFunction** f) override;

  Status RegisterFunction(AbstractFunction* func) override {
    return Unimplemented(
        "Registering graph functions has not been implemented yet.");
  }

  Status RemoveFunction(const string& func) override {
    return Unimplemented(
        "MlirFunctionContext::RemoveFunction has not been implemented yet.");
  }

  Operation* CreateOperationFromState(const OperationState& state);

 private:
  std::unique_ptr<MLIRContext> context_;
  OpBuilder builder_;
  FuncOp func_;
  OwningModuleRef module_;
};

Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
  if (state_) {
    return FailedPrecondition("Reset called on already built op.");
  }
  TF_RETURN_IF_ERROR(
      tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_));
  assert(op_def_);

  tf_op_type_ = op;
  std::string name = "tf.";
  name += op;
  // TODO(aminim) figure out the location story here
  state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
  return Status::OK();
}

Status MlirAbstractOp::SetAttrType(const char* attr_name,
                                   tensorflow::DataType dtype) {
  if (!state_)
    return FailedPrecondition(
        "op_type must be specified before specifying attrs.");
  Type mlir_type;
  Builder builder(context_);
  TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type));
  attrs_[attr_name] = TypeAttr::get(mlir_type);
  return Status::OK();
}

Status MlirAbstractOp::SetOpName(const char* const op_name) {
  // TODO(aminim): should we use a location?
  if (op_name_) {
    return FailedPrecondition("SetOpName called on already built op.");
  }
  op_name_ = op_name;
  return Status::OK();
}

Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
  Type elt_type = getElementTypeOrSelf(type);
  if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
    return InvalidArgument("Requested reference to a reference type");
  }
  elt_type = TensorFlowRefType::get(elt_type);
  if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
    *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type);
    return Status::OK();
  }
  *output_type = UnrankedTensorType::get(elt_type);
  return Status::OK();
}

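// Builds the OperationState for the traced op: operands are the values
// accumulated by AddInput/AddInputList, and result types are derived from the
// OpDef output_args (via number_attr, type_attr, type_list_attr, or a fixed
// type), wrapping ref outputs with AddRef.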
Status MlirAbstractOp::Create(ArrayRef<Value> operands,
                              OperationState** state) {
  state_->operands = llvm::to_vector<4>(operands);
  Builder builder(context_);

  if (current_ods_input_ != op_def_->input_arg_size())
    return InvalidArgument(absl::StrCat("Mismatch in operands number: got ",
                                        current_ods_input_, " expected ",
                                        op_def_->input_arg_size(), " ; for op ",
                                        state_->name.getStringRef().str()));

  // Process results according to the op_def and infer types for derived
  // attributes.
  for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) {
    int original_size = state_->types.size();
    if (!output_arg.number_attr().empty()) {
      // Same type repeated "repeats" times.
      Attribute repeats_attr = attrs_[output_arg.number_attr()];
      if (!repeats_attr)
        return InvalidArgument("Missing attribute '", output_arg.number_attr(),
                               "' required for output list '",
                               output_arg.name(), "'");
      if (!repeats_attr.isa<IntegerAttr>())
        return InvalidArgument("Attribute '", output_arg.number_attr(),
                               "' required for output list '",
                               output_arg.name(), "' isn't an integer");
      int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();

      if (!output_arg.type_attr().empty()) {
        // Same type repeated "repeats" times.
        Attribute attr = attrs_[output_arg.type_attr()];
        if (!attr)
          return InvalidArgument("Missing attribute '", output_arg.type_attr(),
                                 "' required for output '", output_arg.name(),
                                 "'");
        TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
        if (!type_attr)
          return InvalidArgument("Attribute '", output_arg.type_attr(),
                                 "' required for output '", output_arg.name(),
                                 "' isn't a type attribute");
        for (int i = 0; i < repeats; ++i)
          state_->types.push_back(
              UnrankedTensorType::get(type_attr.getValue()));
      } else if (output_arg.type() != tensorflow::DT_INVALID) {
        for (int i = 0; i < repeats; ++i) {
          Type type;
          TF_RETURN_IF_ERROR(
              ConvertDataType(output_arg.type(), builder, &type));
          state_->types.push_back(type);
        }
      } else {
        return InvalidArgument("Missing type or type_attr field in ",
                               output_arg.ShortDebugString());
      }
    } else if (!output_arg.type_attr().empty()) {
      Attribute attr = attrs_[output_arg.type_attr()];
      if (!attr)
        return InvalidArgument("Missing attribute '", output_arg.type_attr(),
                               "' required for output '", output_arg.name(),
                               "'");
      TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
      if (!type_attr)
        return InvalidArgument("Attribute '", output_arg.type_attr(),
                               "' required for output '", output_arg.name(),
                               "' isn't a type attribute");
      state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
    } else if (!output_arg.type_list_attr().empty()) {
      // This is pointing to an attribute which is an array of types.
      Attribute attr = attrs_[output_arg.type_list_attr()];
      if (!attr)
        return InvalidArgument(
            "Missing attribute '", output_arg.type_list_attr(),
            "' required for output '", output_arg.name(), "'");
      ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
      if (!array_attr)
        return InvalidArgument("Attribute '", output_arg.type_list_attr(),
                               "' required for output '", output_arg.name(),
                               "' isn't an array attribute");
      for (Attribute attr : array_attr) {
        TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
        if (!type_attr)
          return InvalidArgument("Array Attribute '",
                                 output_arg.type_list_attr(),
                                 "' required for output '", output_arg.name(),
                                 "' has a non-Type element");
        state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
      }
    } else if (output_arg.type() != tensorflow::DT_INVALID) {
      Type type;
      Builder builder(context_);
      TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type));
      state_->types.push_back(type);
    } else {
      return InvalidArgument("No type fields in ",
                             output_arg.ShortDebugString());
    }
    if (output_arg.is_ref()) {
      // For all types that were added by this function call, make them refs.
      for (Type& type : llvm::make_range(&state_->types[original_size],
                                         state_->types.end())) {
        Type output_type;
        TF_RETURN_IF_ERROR(AddRef(type, &output_type));
        type = output_type;
      }
    }
  }
  for (auto& it : attrs_) state_->addAttribute(it.first(), it.second);
  *state = state_.get();
  return Status::OK();
}

const string& MlirAbstractOp::Name() const { return tf_op_type_; }

const string& MlirAbstractOp::DeviceName() const { return device_name_; }

Status MlirAbstractOp::SetDeviceName(const char* name) {
  device_name_ = name;
  return Status::OK();
}

Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
                                     size_t length) {
  return Unimplemented("SetAttrString has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) {
  return Unimplemented("SetAttrInt has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
  return Unimplemented("SetAttrFloat has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
  attrs_[attr_name] = BoolAttr::get(context_, value);
  return Status::OK();
}
Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
                                    const int num_dims) {
  return Unimplemented("SetAttrShape has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrFunction(const char* attr_name,
                                       const AbstractOperation* value) {
  return Unimplemented("SetAttrFunction has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name,
                                           const char* value, size_t length) {
  return Unimplemented("SetAttrFunctionName has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrTensor(const char* attr_name,
                                     AbstractTensorInterface* tensor) {
  return Unimplemented("SetAttrTensor has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrStringList(const char* attr_name,
                                         const void* const* values,
                                         const size_t* lengths,
                                         int num_values) {
  return Unimplemented("SetAttrStringList has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrFloatList(const char* attr_name,
                                        const float* values, int num_values) {
  return Unimplemented("SetAttrFloatList has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrIntList(const char* attr_name,
                                      const int64_t* values, int num_values) {
  return Unimplemented("SetAttrIntList has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrTypeList(const char* attr_name,
                                       const tensorflow::DataType* values,
                                       int num_values) {
  return Unimplemented("SetAttrTypeList has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrBoolList(const char* attr_name,
                                       const unsigned char* values,
                                       int num_values) {
  return Unimplemented("SetAttrBoolList has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrShapeList(const char* attr_name,
                                        const int64_t** dims,
                                        const int* num_dims, int num_values) {
  return Unimplemented("SetAttrShapeList has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  return Unimplemented("SetAttrFunctionList has not been implemented yet.");
}

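// Lowers the traced function from its functional form to the tf_executor
// dialect (functional-to-executor conversion followed by island breakup) and
// exports it to a tensorflow::FunctionDef, caching the result in fdef_.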
Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
  if (fdef_) {
    *f = fdef_.get();
    return Status::OK();
  }
  PassManager pm(func_.getContext());
  ::tensorflow::applyTensorflowAndCLOptions(pm);
  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
  pm.addPass(CreateBreakUpIslandsPass());

  // In case of failure, the `diag_handler` converts MLIR errors emitted to
  // the MLIRContext into a tensorflow::Status.
  StatusScopedDiagnosticHandler diag_handler(func_.getContext());
  LogicalResult result = pm.run(func_->getParentOfType<ModuleOp>());
  (void)result;
  TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());

  tensorflow::GraphExportConfig configs;
  fdef_.reset(new tensorflow::FunctionDef());
  TF_RETURN_IF_ERROR(
      ConvertMlirFunctionToFunctionLibraryDef(func_, configs, fdef_.get()));
  *f = fdef_.get();
  return Status::OK();
}

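// Finalizes the pending OperationState, inserts the op into the traced
// function body, and wraps each result in a new MlirTensor handle.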
Status MlirAbstractOp::Execute(absl::Span<AbstractTensorHandle*> retvals,
                               int* num_retvals) {
  OperationState* state;
  TF_RETURN_IF_ERROR(Create(operands_, &state));
  Operation* op = function_context_->CreateOperationFromState(*state);
  *num_retvals = op->getNumResults();
  for (int i = 0; i < *num_retvals; i++)
    retvals[i] = new MlirTensor(op->getResult(i));
  return Status::OK();
}

Operation* MlirFunctionContext::CreateOperationFromState(
    const OperationState& state) {
  return builder_.createOperation(state);
}

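// Adds a new block argument to the traced function body and returns it as an
// MlirTensor parameter (the requested shape is currently ignored, see the
// TODO below).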
Status MlirFunctionContext::AddParameter(
    tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
    TracingTensorHandle** handle) {
  // TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
  // resolved.
  Type type;
  TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
  *handle = new MlirTensor(func_.getBody().front().addArgument(type));
  return Status::OK();
}

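// Records one operand and, when the matching OpDef ArgDef names a type_attr,
// infers that derived type attribute from the input's element type.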
Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
  if (current_ods_input_ >= op_def_->input_arg_size())
    return InvalidArgument(
        absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
                     op_def_->input_arg_size(), " allowed input_args ; for op ",
                     state_->name.getStringRef().str()));

  auto* operand = dyn_cast<MlirTensor>(input);
  if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
  operands_.push_back(operand->getValue());

  // Get the next ArgDef and use it to infer the derived attributes associated
  // with this input.
  const tensorflow::OpDef::ArgDef& arg_def =
      op_def_->input_arg(current_ods_input_++);
  Type expected_type;
  if (arg_def.type() != tensorflow::DT_INVALID) {
    Builder builder(context_);
    TF_RETURN_IF_ERROR(
        tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type));
    if (arg_def.is_ref()) {
      Type output_type;
      TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
      expected_type = output_type;
    }
  } else {
    expected_type = cast<MlirTensor>(input)->getElementType();
  }
  if (!arg_def.type_attr().empty())
    attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);

  return Status::OK();
}

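// Records a variadic operand list; depending on the ArgDef this sets the
// derived number_attr, type_attr, or type_list_attr from the inputs.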
Status MlirAbstractOp::AddInputList(
    absl::Span<AbstractTensorHandle* const> inputs) {
  if (current_ods_input_ >= op_def_->input_arg_size())
    return InvalidArgument(
        absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
                     op_def_->input_arg_size(), " allowed input_args"));

  for (AbstractTensorHandle* input : inputs) {
    auto* operand = dyn_cast<MlirTensor>(input);
    if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
    operands_.push_back(operand->getValue());
  }

  // Get the next ArgDef and use it to infer the derived attributes associated
  // with this input.
  const tensorflow::OpDef::ArgDef& arg_def =
      op_def_->input_arg(current_ods_input_++);
  if (!arg_def.number_attr().empty()) {
    Builder builder(context_);
    attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size());
    // TODO(aminim): handle ref variable.
    if (arg_def.type() != tensorflow::DT_INVALID) {
      // TODO(aminim): check type wrt input
      Type arg_def_type;
      TF_RETURN_IF_ERROR(
          ConvertDataType(arg_def.type(), builder, &arg_def_type));
      // Ensure each type in the list matches the op def type.
      // TODO(aminim): can we improve the error message with the actual types?
      for (AbstractTensorHandle* input : inputs)
        if (arg_def_type != cast<MlirTensor>(input)->getElementType())
          return InvalidArgument(
              "Invalid input list: type mismatch the op def expectation");
    } else if (!inputs.empty()) {
      if (arg_def.type_attr().empty())
        return FailedPrecondition(
            "Invalid opdef type constraint: either type or type_attr required");

      attrs_[arg_def.type_attr()] =
          TypeAttr::get(cast<MlirTensor>(inputs.front())->getElementType());
    }
  } else if (!arg_def.type_list_attr().empty()) {
    // TODO(aminim): handle ref variable.
    SmallVector<Attribute, 8> types;
    types.reserve(inputs.size());
    for (AbstractTensorHandle* input : inputs)
      types.push_back(TypeAttr::get(cast<MlirTensor>(input)->getElementType()));
    attrs_[arg_def.type_list_attr()] = ArrayAttr::get(GetContext(), types);
  }
  return Status::OK();
}

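// Emits the terminating ReturnOp for the traced outputs, updates the function
// signature from the block arguments and return operands, and transfers
// ownership of the context and module to a new MlirFunction.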
Status MlirFunctionContext::Finalize(OutputList* outputs,
                                     AbstractFunction** f) {
  Block& body = func_.getBody().front();
  SmallVector<Value, 8> ret_operands;
  for (auto* output : outputs->outputs) {
    auto* operand = dyn_cast<MlirTensor>(output);
    if (!operand)
      return InvalidArgument("Capturing eager tensors is not supported yet.");
    if (operand->getValue().getContext() != context_.get())
      return InvalidArgument(
          "Capturing tensors from other context is not supported.");
    ret_operands.push_back(operand->getValue());
  }
  builder_.create<ReturnOp>(func_.getLoc(), ret_operands);

  auto arg_types = body.getArgumentTypes();
  auto result_types = body.getTerminator()->getOperandTypes();
  func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
  *f = new MlirFunction(std::move(context_), std::move(module_), func_);
  return Status::OK();
}

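// A minimal sketch of how this tracing path is typically driven through the
// interfaces implemented above (illustrative only: error handling is elided,
// the "Identity" op and DT_FLOAT parameter are arbitrary examples, and the
// caller owns the returned handles):
//
//   TF_Status* s = TF_NewStatus();
//   TracingContext* ctx = MlirTracingFactory("my_fn", s);
//   TracingTensorHandle* arg = nullptr;
//   ctx->AddParameter(tensorflow::DT_FLOAT, tensorflow::PartialTensorShape(),
//                     &arg);
//   AbstractOperation* op = ctx->CreateOperation();
//   op->Reset("Identity", /*raw_device_name=*/nullptr);
//   op->AddInput(arg);
//   int num_retvals = 1;
//   AbstractTensorHandle* ret = nullptr;
//   op->Execute(absl::MakeSpan(&ret, 1), &num_retvals);
//   OutputList outputs;
//   outputs.outputs.push_back(ret);
//   AbstractFunction* fn = nullptr;
//   ctx->Finalize(&outputs, &fn);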
extern "C" {
TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
  return new MlirFunctionContext(fn_name);
}
}

}  // namespace
}  // namespace TF
}  // namespace mlir