• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstddef>
17 #include <memory>
18 
19 #include "absl/strings/str_cat.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/iterator_range.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Location.h"  // from @llvm-project
28 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
29 #include "mlir/IR/Operation.h"  // from @llvm-project
30 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
31 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
32 #include "mlir/Pass/PassManager.h"  // from @llvm-project
33 #include "mlir/Support/LLVM.h"  // from @llvm-project
34 #include "tensorflow/c/c_api.h"
35 #include "tensorflow/c/eager/abstract_context.h"
36 #include "tensorflow/c/eager/abstract_operation.h"
37 #include "tensorflow/c/eager/abstract_tensor_handle.h"
38 #include "tensorflow/c/eager/c_api.h"
39 #include "tensorflow/c/eager/c_api_internal.h"
40 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
41 #include "tensorflow/c/tensor_interface.h"
42 #include "tensorflow/c/tf_status.h"
43 #include "tensorflow/c/tf_status_helper.h"
44 #include "tensorflow/c/tf_status_internal.h"
45 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
50 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
52 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
53 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
55 #include "tensorflow/core/framework/node_def_util.h"
56 #include "tensorflow/core/framework/tensor_shape.h"
57 #include "tensorflow/core/framework/types.pb.h"
58 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
59 #include "tensorflow/core/platform/errors.h"
60 
61 namespace mlir {
62 namespace TF {
63 using tensorflow::AbstractFunction;
64 using tensorflow::AbstractOperation;
65 using tensorflow::AbstractTensorHandle;
66 using tensorflow::AbstractTensorInterface;
67 using tensorflow::dyn_cast;
68 using tensorflow::OutputList;
69 using tensorflow::string;
70 using tensorflow::errors::FailedPrecondition;
71 using tensorflow::errors::InvalidArgument;
72 using tensorflow::errors::Unimplemented;
73 using tensorflow::tracing::TracingContext;
74 using tensorflow::tracing::TracingOperation;
75 using tensorflow::tracing::TracingTensorHandle;
76 
77 namespace {
78 
RegisterDialects(mlir::MLIRContext & ctx)79 void RegisterDialects(mlir::MLIRContext& ctx) {
80   mlir::DialectRegistry registry;
81   mlir::RegisterAllTensorFlowDialects(registry);
82   ctx.appendDialectRegistry(registry);
83   ctx.loadAllAvailableDialects();
84 }
85 
ConvertDataTypeToTensor(tensorflow::DataType dtype,Builder builder,Type * type)86 Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
87                                Type* type) {
88   Status s = tensorflow::ConvertDataType(dtype, builder, type);
89   if (s.ok()) *type = UnrankedTensorType::get(*type);
90   return s;
91 }
92 
93 class MlirTensor : public TracingTensorHandle {
94  public:
MlirTensor(Value value)95   explicit MlirTensor(Value value)
96       : TracingTensorHandle(kMlir), value_(value) {}
97 
DataType() const98   tensorflow::DataType DataType() const override {
99     tensorflow::DataType type;
100     Status s = ConvertToDataType(value_.getType(), &type);
101     if (!s.ok()) {
102       return tensorflow::DT_INVALID;
103     }
104     return type;
105   }
106 
Shape(tensorflow::PartialTensorShape * shape) const107   tensorflow::Status Shape(
108       tensorflow::PartialTensorShape* shape) const override {
109     // TODO(b/173074167): Implement this and enable tests in
110     // unified_api_test.cc.
111     return Unimplemented("MlirTensor::Shape is not implemented yet.");
112   }
113 
getValue()114   Value getValue() { return value_; }
getElementType()115   Type getElementType() {
116     return value_.getType().cast<ShapedType>().getElementType();
117   }
118 
119   // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)120   static bool classof(const AbstractTensorHandle* ptr) {
121     return ptr->getKind() == kMlir;
122   }
123 
124  private:
125   Value value_;
126 };
127 
128 class MlirFunctionContext;
129 
130 class MlirAbstractOp : public TracingOperation {
131  public:
MlirAbstractOp(MLIRContext * context,MlirFunctionContext * function_context)132   explicit MlirAbstractOp(MLIRContext* context,
133                           MlirFunctionContext* function_context)
134       : TracingOperation(kMlir),
135         context_(context),
136         function_context_(function_context) {}
137 
Release()138   void Release() override { delete this; }
139 
140   Status Reset(const char* op, const char* raw_device_name) override;
141 
142   const string& Name() const override;
143 
144   const string& DeviceName() const override;
145 
146   Status SetDeviceName(const char* name) override;
147 
148   Status AddInput(AbstractTensorHandle* input) override;
149   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
150   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
151                  int* num_retvals) override;
152 
153   Status SetAttrString(const char* attr_name, const char* data,
154                        size_t length) override;
155   Status SetAttrInt(const char* attr_name, int64_t value) override;
156   Status SetAttrFloat(const char* attr_name, float value) override;
157   Status SetAttrBool(const char* attr_name, bool value) override;
158   Status SetAttrType(const char* attr_name,
159                      tensorflow::DataType dtype) override;
160   Status SetAttrShape(const char* attr_name, const int64_t* dims,
161                       const int num_dims) override;
162   Status SetAttrFunction(const char* attr_name,
163                          const AbstractOperation* value) override;
164   Status SetAttrFunctionName(const char* attr_name, const char* value,
165                              size_t length) override;
166   Status SetAttrTensor(const char* attr_name,
167                        AbstractTensorInterface* tensor) override;
168   Status SetAttrStringList(const char* attr_name, const void* const* values,
169                            const size_t* lengths, int num_values) override;
170   Status SetAttrFloatList(const char* attr_name, const float* values,
171                           int num_values) override;
172   Status SetAttrIntList(const char* attr_name, const int64_t* values,
173                         int num_values) override;
174   Status SetAttrTypeList(const char* attr_name,
175                          const tensorflow::DataType* values,
176                          int num_values) override;
177   Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
178                          int num_values) override;
179   Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
180                           const int* num_dims, int num_values) override;
181   Status SetAttrFunctionList(
182       const char* attr_name,
183       absl::Span<const AbstractOperation*> values) override;
184 
185   Status SetOpName(const char* const op_name) override;
186 
GetContext()187   MLIRContext* GetContext() { return context_; }
188 
189   Status AddRef(Type type, Type* output_type);
190 
191   Status Create(ArrayRef<Value> operands, OperationState**);
192 
193   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)194   static bool classof(const AbstractOperation* ptr) {
195     return ptr->getKind() == kMlir;
196   }
197 
198  private:
199   // Return true is there are still unfilled ODS slots for adding more inputs.
200   bool IsNextODSArgAvailable();
201 
202   MLIRContext* context_;
203   MlirFunctionContext* function_context_;
204   SmallVector<Value, 8> operands_;
205   llvm::StringMap<Attribute> attrs_;
206   std::unique_ptr<OperationState> state_;
207   // This is the index of the next ODS operand that will be added with AddInput
208   // or AddInput;
209   int current_ods_input_ = 0;
210   const tensorflow::OpDef* op_def_ = nullptr;
211   const char* op_name_ = nullptr;
212   string tf_op_type_;
213   // TODO(srbs): Use this.
214   string device_name_;
215 };
216 
217 // MlirFunction is a thin wrapper over a FuncOp.
218 class MlirFunction : public AbstractFunction {
219  public:
MlirFunction(std::unique_ptr<MLIRContext> context,OwningModuleRef module,FuncOp func)220   explicit MlirFunction(std::unique_ptr<MLIRContext> context,
221                         OwningModuleRef module, FuncOp func)
222       : AbstractFunction(kMlir),
223         context_(std::move(context)),
224         module_(std::move(module)),
225         func_(func) {}
226 
227   Status GetFunctionDef(tensorflow::FunctionDef** f) override;
228 
229   // For LLVM style RTTI.
classof(const AbstractFunction * ptr)230   static bool classof(const AbstractFunction* ptr) {
231     return ptr->getKind() == kMlir;
232   }
233 
234  private:
235   std::unique_ptr<MLIRContext> context_;
236   OwningModuleRef module_;
237   FuncOp func_;
238   std::unique_ptr<tensorflow::FunctionDef> fdef_;
239 };
240 
241 class MlirFunctionContext : public TracingContext {
242  public:
MlirFunctionContext(const char * name)243   explicit MlirFunctionContext(const char* name)
244       : TracingContext(kMlir),
245         context_(std::make_unique<MLIRContext>()),
246         builder_(context_.get()) {
247     RegisterDialects(*context_);
248     // TODO(aminim) figure out the location story here
249     module_ = ModuleOp::create(builder_.getUnknownLoc());
250     func_ = FuncOp::create(builder_.getUnknownLoc(), name,
251                            builder_.getFunctionType(llvm::None, llvm::None));
252     module_->push_back(func_);
253     builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
254   }
255 
Release()256   void Release() override { delete this; }
257 
CreateOperation()258   AbstractOperation* CreateOperation() override {
259     return new MlirAbstractOp(context_.get(), this);
260   }
261   Status AddParameter(tensorflow::DataType dtype,
262                       const tensorflow::PartialTensorShape& shape,
263                       TracingTensorHandle** handle) override;
264 
265   Status Finalize(OutputList* outputs, AbstractFunction** f) override;
266 
RegisterFunction(AbstractFunction * func)267   Status RegisterFunction(AbstractFunction* func) override {
268     return Unimplemented(
269         "Registering graph functions has not been implemented yet.");
270   }
271 
RemoveFunction(const string & func)272   Status RemoveFunction(const string& func) override {
273     return Unimplemented(
274         "MlirFunctionContext::RemoveFunction has not been implemented yet.");
275   }
276 
277   Operation* CreateOperationFromState(const OperationState& state);
278 
279  private:
280   std::unique_ptr<MLIRContext> context_;
281   OpBuilder builder_;
282   FuncOp func_;
283   OwningModuleRef module_;
284 };
285 
Reset(const char * op,const char * device_name)286 Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
287   if (state_) {
288     return FailedPrecondition("Reset called on already built op.");
289   }
290   TF_RETURN_IF_ERROR(
291       tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_));
292   assert(op_def_);
293 
294   tf_op_type_ = op;
295   std::string name = "tf.";
296   name += op;
297   // TODO(aminim) figure out the location story here
298   state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
299   return Status::OK();
300 }
301 
SetAttrType(const char * attr_name,tensorflow::DataType dtype)302 Status MlirAbstractOp::SetAttrType(const char* attr_name,
303                                    tensorflow::DataType dtype) {
304   if (!state_)
305     return FailedPrecondition(
306         "op_type must be specified before specifying attrs.");
307   Type mlir_type;
308   Builder builder(context_);
309   TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type));
310   attrs_[attr_name] = TypeAttr::get(mlir_type);
311   return Status::OK();
312 }
313 
SetOpName(const char * const op_name)314 Status MlirAbstractOp::SetOpName(const char* const op_name) {
315   // TODO(aminim): should we use a location?
316   if (op_name_) {
317     return FailedPrecondition("SetOpName called on already built op.");
318   }
319   op_name_ = op_name;
320   return Status::OK();
321 }
322 
AddRef(Type type,Type * output_type)323 Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
324   Type elt_type = getElementTypeOrSelf(type);
325   if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
326     return InvalidArgument("Requested reference to a reference type");
327   }
328   elt_type = TensorFlowRefType::get(elt_type);
329   if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
330     *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type);
331   }
332   *output_type = UnrankedTensorType::get(elt_type);
333   return Status::OK();
334 }
335 
Create(ArrayRef<Value> operands,OperationState ** state)336 Status MlirAbstractOp::Create(ArrayRef<Value> operands,
337                               OperationState** state) {
338   state_->operands = llvm::to_vector<4>(operands);
339   Builder builder(context_);
340 
341   if (current_ods_input_ != op_def_->input_arg_size())
342     return InvalidArgument(absl::StrCat("Mismatch in operands number: got ",
343                                         current_ods_input_, " expected ",
344                                         op_def_->input_arg_size(), " ; for op ",
345                                         state_->name.getStringRef().str()));
346 
347   // Process results according to the op_def and infer types for derived
348   // attributes.
349   for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) {
350     int original_size = state_->types.size();
351     if (!output_arg.number_attr().empty()) {
352       // Same type repeated "repeats" times.
353       Attribute repeats_attr = attrs_[output_arg.number_attr()];
354       if (!repeats_attr)
355         return InvalidArgument("Missing attribute '", output_arg.number_attr(),
356                                "' required for output list '",
357                                output_arg.name(), "'");
358       if (!repeats_attr.isa<IntegerAttr>())
359         return InvalidArgument("Attribute '", output_arg.number_attr(),
360                                "' required for output list '",
361                                output_arg.name(), "' isn't an integer");
362       int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
363 
364       if (!output_arg.type_attr().empty()) {
365         // Same type repeated "repeats" times.
366         Attribute attr = attrs_[output_arg.type_attr()];
367         if (!attr)
368           return InvalidArgument("Missing attribute '", output_arg.type_attr(),
369                                  "' required for output '", output_arg.name(),
370                                  "'");
371         TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
372         if (!type_attr)
373           return InvalidArgument("Attribute '", output_arg.type_attr(),
374                                  "' required for output '", output_arg.name(),
375                                  "' isn't a type attribute");
376         for (int i = 0; i < repeats; ++i)
377           state_->types.push_back(UnrankedTensorType::get(type_attr.getType()));
378       } else if (output_arg.type() != tensorflow::DT_INVALID) {
379         for (int i = 0; i < repeats; ++i) {
380           Type type;
381           TF_RETURN_IF_ERROR(
382               ConvertDataType(output_arg.type(), builder, &type));
383           state_->types.push_back(type);
384         }
385       } else {
386         return InvalidArgument("Missing type or type_attr field in ",
387                                output_arg.ShortDebugString());
388       }
389     } else if (!output_arg.type_attr().empty()) {
390       Attribute attr = attrs_[output_arg.type_attr()];
391       if (!attr)
392         return InvalidArgument("Missing attribute '", output_arg.type_attr(),
393                                "' required for output '", output_arg.name(),
394                                "'");
395       TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
396       if (!type_attr)
397         return InvalidArgument("Attribute '", output_arg.type_attr(),
398                                "' required for output '", output_arg.name(),
399                                "' isn't a type attribute");
400       state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
401     } else if (!output_arg.type_list_attr().empty()) {
402       // This is pointing to an attribute which is an array of types.
403       Attribute attr = attrs_[output_arg.type_list_attr()];
404       if (!attr)
405         return InvalidArgument(
406             "Missing attribute '", output_arg.type_list_attr(),
407             "' required for output '", output_arg.name(), "'");
408       ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
409       if (!array_attr)
410         return InvalidArgument("Attribute '", output_arg.type_list_attr(),
411                                "' required for output '", output_arg.name(),
412                                "' isn't an array attribute");
413       for (Attribute attr : array_attr) {
414         TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
415         if (!type_attr)
416           return InvalidArgument("Array Attribute '",
417                                  output_arg.type_list_attr(),
418                                  "' required for output '", output_arg.name(),
419                                  "' has a non-Type element");
420         state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
421       }
422     } else if (output_arg.type() != tensorflow::DT_INVALID) {
423       Type type;
424       Builder builder(context_);
425       TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type));
426       state_->types.push_back(type);
427     } else {
428       return InvalidArgument("No type fields in ",
429                              output_arg.ShortDebugString());
430     }
431     if (output_arg.is_ref()) {
432       // For all types that were added by this function call, make them refs.
433       for (Type& type : llvm::make_range(&state_->types[original_size],
434                                          state_->types.end())) {
435         Type output_type;
436         TF_RETURN_IF_ERROR(AddRef(type, &output_type));
437         type = output_type;
438       }
439     }
440   }
441   for (auto& it : attrs_) state_->addAttribute(it.first(), it.second);
442   *state = state_.get();
443   return Status::OK();
444 }
445 
Name() const446 const string& MlirAbstractOp::Name() const { return tf_op_type_; }
447 
DeviceName() const448 const string& MlirAbstractOp::DeviceName() const { return device_name_; }
449 
SetDeviceName(const char * name)450 Status MlirAbstractOp::SetDeviceName(const char* name) {
451   device_name_ = name;
452   return Status::OK();
453 }
454 
SetAttrString(const char * attr_name,const char * data,size_t length)455 Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
456                                      size_t length) {
457   return Unimplemented("SetAttrString has not been implemented yet.");
458 }
SetAttrInt(const char * attr_name,int64_t value)459 Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) {
460   return Unimplemented("SetAttrInt has not been implemented yet.");
461 }
SetAttrFloat(const char * attr_name,float value)462 Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
463   return Unimplemented("SetAttrFloat has not been implemented yet.");
464 }
SetAttrBool(const char * attr_name,bool value)465 Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
466   attrs_[attr_name] = BoolAttr::get(context_, value);
467   return Status::OK();
468 }
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)469 Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
470                                     const int num_dims) {
471   return Unimplemented("SetAttrShape has not been implemented yet.");
472 }
SetAttrFunction(const char * attr_name,const AbstractOperation * value)473 Status MlirAbstractOp::SetAttrFunction(const char* attr_name,
474                                        const AbstractOperation* value) {
475   return Unimplemented("SetAttrFunction has not been implemented yet.");
476 }
SetAttrFunctionName(const char * attr_name,const char * value,size_t length)477 Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name,
478                                            const char* value, size_t length) {
479   return Unimplemented("SetAttrFunctionName has not been implemented yet.");
480 }
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)481 Status MlirAbstractOp::SetAttrTensor(const char* attr_name,
482                                      AbstractTensorInterface* tensor) {
483   return Unimplemented("SetAttrTensor has not been implemented yet.");
484 }
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)485 Status MlirAbstractOp::SetAttrStringList(const char* attr_name,
486                                          const void* const* values,
487                                          const size_t* lengths,
488                                          int num_values) {
489   return Unimplemented("SetAttrStringList has not been implemented yet.");
490 }
SetAttrFloatList(const char * attr_name,const float * values,int num_values)491 Status MlirAbstractOp::SetAttrFloatList(const char* attr_name,
492                                         const float* values, int num_values) {
493   return Unimplemented("SetAttrFloatList has not been implemented yet.");
494 }
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)495 Status MlirAbstractOp::SetAttrIntList(const char* attr_name,
496                                       const int64_t* values, int num_values) {
497   return Unimplemented("SetAttrIntList has not been implemented yet.");
498 }
SetAttrTypeList(const char * attr_name,const tensorflow::DataType * values,int num_values)499 Status MlirAbstractOp::SetAttrTypeList(const char* attr_name,
500                                        const tensorflow::DataType* values,
501                                        int num_values) {
502   return Unimplemented("SetAttrTypeList has not been implemented yet.");
503 }
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)504 Status MlirAbstractOp::SetAttrBoolList(const char* attr_name,
505                                        const unsigned char* values,
506                                        int num_values) {
507   return Unimplemented("SetAttrBoolList has not been implemented yet.");
508 }
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)509 Status MlirAbstractOp::SetAttrShapeList(const char* attr_name,
510                                         const int64_t** dims,
511                                         const int* num_dims, int num_values) {
512   return Unimplemented("SetAttrShapeList has not been implemented yet.");
513 }
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)514 Status MlirAbstractOp::SetAttrFunctionList(
515     const char* attr_name, absl::Span<const AbstractOperation*> values) {
516   return Unimplemented("SetAttrFunctionList has not been implemented yet.");
517 }
518 
GetFunctionDef(tensorflow::FunctionDef ** f)519 Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
520   if (fdef_) {
521     *f = fdef_.get();
522     return Status::OK();
523   }
524   PassManager pm(func_.getContext());
525   ::tensorflow::applyTensorflowAndCLOptions(pm);
526   pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
527   pm.addPass(CreateBreakUpIslandsPass());
528 
529   // In case of failure, the `diag_handler` converts MLIR errors emitted to
530   // the MLIRContext into a tensorflow::Status.
531   StatusScopedDiagnosticHandler diag_handler(func_.getContext());
532   LogicalResult result = pm.run(func_->getParentOfType<ModuleOp>());
533   (void)result;
534   TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());
535 
536   tensorflow::GraphExportConfig configs;
537   fdef_.reset(new tensorflow::FunctionDef());
538   TF_RETURN_IF_ERROR(
539       ConvertMlirFunctionToFunctionLibraryDef(func_, configs, fdef_.get()));
540   *f = fdef_.get();
541   return Status::OK();
542 }
543 
Execute(absl::Span<AbstractTensorHandle * > retvals,int * num_retvals)544 Status MlirAbstractOp::Execute(absl::Span<AbstractTensorHandle*> retvals,
545                                int* num_retvals) {
546   OperationState* state;
547   TF_RETURN_IF_ERROR(Create(operands_, &state));
548   Operation* op = function_context_->CreateOperationFromState(*state);
549   *num_retvals = op->getNumResults();
550   for (int i = 0; i < *num_retvals; i++)
551     retvals[i] = new MlirTensor(op->getResult(i));
552   return Status::OK();
553 }
554 
CreateOperationFromState(const OperationState & state)555 Operation* MlirFunctionContext::CreateOperationFromState(
556     const OperationState& state) {
557   return builder_.createOperation(state);
558 }
559 
AddParameter(tensorflow::DataType dtype,const tensorflow::PartialTensorShape & shape,TracingTensorHandle ** handle)560 Status MlirFunctionContext::AddParameter(
561     tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
562     TracingTensorHandle** handle) {
563   // TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
564   // resolved.
565   Type type;
566   TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
567   *handle = new MlirTensor(func_.getBody().front().addArgument(type));
568   return Status::OK();
569 }
570 
AddInput(AbstractTensorHandle * input)571 Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
572   if (current_ods_input_ >= op_def_->input_arg_size())
573     return InvalidArgument(
574         absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
575                      op_def_->input_arg_size(), " allowed input_args ; for op ",
576                      state_->name.getStringRef().str()));
577 
578   auto* operand = dyn_cast<MlirTensor>(input);
579   if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
580   operands_.push_back(operand->getValue());
581 
582   // Get the next ArgDef and use it to infer the derived attributes associated
583   // to this input.
584   const tensorflow::OpDef::ArgDef& arg_def =
585       op_def_->input_arg(current_ods_input_++);
586   Type expected_type;
587   if (arg_def.type() != tensorflow::DT_INVALID) {
588     Builder builder(context_);
589     TF_RETURN_IF_ERROR(
590         tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type));
591     if (arg_def.is_ref()) {
592       Type output_type;
593       TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
594       expected_type = output_type;
595     }
596   } else {
597     expected_type = cast<MlirTensor>(input)->getElementType();
598   }
599   if (!arg_def.type_attr().empty())
600     attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
601 
602   return Status::OK();
603 }
604 
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)605 Status MlirAbstractOp::AddInputList(
606     absl::Span<AbstractTensorHandle* const> inputs) {
607   if (current_ods_input_ >= op_def_->input_arg_size())
608     return InvalidArgument(
609         absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
610                      op_def_->input_arg_size(), " allowed input_args"));
611 
612   for (AbstractTensorHandle* input : inputs) {
613     auto* operand = dyn_cast<MlirTensor>(input);
614     if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
615     operands_.push_back(operand->getValue());
616   }
617 
618   // Get the next ArgDef and use it to infer the derived attributes associated
619   // to this input.
620   const tensorflow::OpDef::ArgDef& arg_def =
621       op_def_->input_arg(current_ods_input_++);
622   if (!arg_def.number_attr().empty()) {
623     Builder builder(context_);
624     attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size());
625     // TODO(aminim): handle ref variable.
626     if (arg_def.type() != tensorflow::DT_INVALID) {
627       // TODO(aminim): check type wrt input
628       Type arg_def_type;
629       TF_RETURN_IF_ERROR(
630           ConvertDataType(arg_def.type(), builder, &arg_def_type));
631       // Ensure each of the type in the list matches the op def type.
632       // TODO(aminim): can we improve the error message with the actual types?
633       for (AbstractTensorHandle* input : inputs)
634         if (arg_def_type != cast<MlirTensor>(input)->getElementType())
635           return InvalidArgument(
636               "Invalid input list: type mismatch the op def expectation");
637     } else if (!inputs.empty()) {
638       if (arg_def.type_attr().empty())
639         return FailedPrecondition(
640             "Invalid opdef type constraint: either type or type_attr required");
641 
642       attrs_[arg_def.type_attr()] =
643           TypeAttr::get(cast<MlirTensor>(inputs.front())->getElementType());
644     }
645   } else if (!arg_def.type_list_attr().empty()) {
646     // TODO(aminim): handle ref variable.
647     SmallVector<Attribute, 8> types;
648     types.reserve(inputs.size());
649     for (AbstractTensorHandle* input : inputs)
650       types.push_back(TypeAttr::get(cast<MlirTensor>(input)->getElementType()));
651     attrs_[arg_def.type_list_attr()] = ArrayAttr::get(GetContext(), types);
652   }
653   return Status::OK();
654 }
655 
Finalize(OutputList * outputs,AbstractFunction ** f)656 Status MlirFunctionContext::Finalize(OutputList* outputs,
657                                      AbstractFunction** f) {
658   Block& body = func_.getBody().front();
659   SmallVector<Value, 8> ret_operands;
660   for (auto* output : outputs->outputs) {
661     auto* operand = dyn_cast<MlirTensor>(output);
662     if (!operand)
663       return InvalidArgument("Capturing eager tensors is not supported yet.");
664     if (operand->getValue().getContext() != context_.get())
665       return InvalidArgument(
666           "Capturing tensors from other context is not supported.");
667     ret_operands.push_back(operand->getValue());
668   }
669   builder_.create<ReturnOp>(func_.getLoc(), ret_operands);
670 
671   auto arg_types = body.getArgumentTypes();
672   auto result_types = body.getTerminator()->getOperandTypes();
673   func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
674   *f = new MlirFunction(std::move(context_), std::move(module_), func_);
675   return Status::OK();
676 }
677 
678 extern "C" {
MlirTracingFactory(const char * fn_name,TF_Status * s)679 TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
680   return new MlirFunctionContext(fn_name);
681 }
682 }
683 
684 }  // namespace
685 }  // namespace TF
686 }  // namespace mlir
687