• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstddef>
17 #include <memory>
18 
19 #include "absl/strings/str_cat.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/iterator_range.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Location.h"  // from @llvm-project
28 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
29 #include "mlir/IR/Operation.h"  // from @llvm-project
30 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
31 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
32 #include "mlir/Pass/PassManager.h"  // from @llvm-project
33 #include "mlir/Support/LLVM.h"  // from @llvm-project
34 #include "tensorflow/c/c_api.h"
35 #include "tensorflow/c/eager/abstract_context.h"
36 #include "tensorflow/c/eager/abstract_operation.h"
37 #include "tensorflow/c/eager/abstract_tensor_handle.h"
38 #include "tensorflow/c/eager/c_api.h"
39 #include "tensorflow/c/eager/c_api_internal.h"
40 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
41 #include "tensorflow/c/tensor_interface.h"
42 #include "tensorflow/c/tf_status.h"
43 #include "tensorflow/c/tf_status_helper.h"
44 #include "tensorflow/c/tf_status_internal.h"
45 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
50 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
52 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
53 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
55 #include "tensorflow/core/framework/node_def_util.h"
56 #include "tensorflow/core/framework/tensor_shape.h"
57 #include "tensorflow/core/framework/types.pb.h"
58 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
59 #include "tensorflow/core/platform/errors.h"
60 
61 namespace mlir {
62 namespace TF {
63 using tensorflow::AbstractFunction;
64 using tensorflow::AbstractOperation;
65 using tensorflow::AbstractTensorHandle;
66 using tensorflow::AbstractTensorInterface;
67 using tensorflow::dyn_cast;
68 using tensorflow::OutputList;
69 using tensorflow::string;
70 using tensorflow::errors::FailedPrecondition;
71 using tensorflow::errors::InvalidArgument;
72 using tensorflow::errors::Unimplemented;
73 using tensorflow::tracing::TracingContext;
74 using tensorflow::tracing::TracingOperation;
75 using tensorflow::tracing::TracingTensorHandle;
76 
77 namespace {
78 
RegisterDialects(mlir::MLIRContext & ctx)79 void RegisterDialects(mlir::MLIRContext& ctx) {
80   mlir::DialectRegistry registry;
81   mlir::RegisterAllTensorFlowDialects(registry);
82   ctx.appendDialectRegistry(registry);
83   ctx.loadAllAvailableDialects();
84 }
85 
ConvertDataTypeToTensor(tensorflow::DataType dtype,Builder builder,Type * type)86 Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
87                                Type* type) {
88   Status s = tensorflow::ConvertDataType(dtype, builder, type);
89   if (s.ok()) *type = UnrankedTensorType::get(*type);
90   return s;
91 }
92 
93 class MlirTensor : public TracingTensorHandle {
94  public:
MlirTensor(Value value)95   explicit MlirTensor(Value value)
96       : TracingTensorHandle(kMlir), value_(value) {}
97 
DataType() const98   tensorflow::DataType DataType() const override {
99     tensorflow::DataType type;
100     Status s = ConvertToDataType(value_.getType(), &type);
101     if (!s.ok()) {
102       return tensorflow::DT_INVALID;
103     }
104     return type;
105   }
106 
Shape(tensorflow::PartialTensorShape * shape) const107   tensorflow::Status Shape(
108       tensorflow::PartialTensorShape* shape) const override {
109     // TODO(b/173074167): Implement this and enable tests in
110     // unified_api_test.cc.
111     return Unimplemented("MlirTensor::Shape is not implemented yet.");
112   }
113 
getValue()114   Value getValue() { return value_; }
getElementType()115   Type getElementType() {
116     return value_.getType().cast<ShapedType>().getElementType();
117   }
118 
119   // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)120   static bool classof(const AbstractTensorHandle* ptr) {
121     return ptr->getKind() == kMlir;
122   }
123 
124  private:
125   Value value_;
126 };
127 
128 class MlirFunctionContext;
129 
130 class MlirAbstractOp : public TracingOperation {
131  public:
MlirAbstractOp(MLIRContext * context,MlirFunctionContext * function_context)132   explicit MlirAbstractOp(MLIRContext* context,
133                           MlirFunctionContext* function_context)
134       : TracingOperation(kMlir),
135         context_(context),
136         function_context_(function_context) {}
137 
Release()138   void Release() override { delete this; }
139 
140   Status Reset(const char* op, const char* raw_device_name) override;
141 
142   const string& Name() const override;
143 
144   const string& DeviceName() const override;
145 
146   Status SetDeviceName(const char* name) override;
147 
148   Status AddInput(AbstractTensorHandle* input) override;
149   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
150   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
151                  int* num_retvals) override;
152 
153   Status SetAttrString(const char* attr_name, const char* data,
154                        size_t length) override;
155   Status SetAttrInt(const char* attr_name, int64_t value) override;
156   Status SetAttrFloat(const char* attr_name, float value) override;
157   Status SetAttrBool(const char* attr_name, bool value) override;
158   Status SetAttrType(const char* attr_name,
159                      tensorflow::DataType dtype) override;
160   Status SetAttrShape(const char* attr_name, const int64_t* dims,
161                       const int num_dims) override;
162   Status SetAttrFunction(const char* attr_name,
163                          const AbstractOperation* value) override;
164   Status SetAttrFunctionName(const char* attr_name, const char* value,
165                              size_t length) override;
166   Status SetAttrTensor(const char* attr_name,
167                        AbstractTensorInterface* tensor) override;
168   Status SetAttrStringList(const char* attr_name, const void* const* values,
169                            const size_t* lengths, int num_values) override;
170   Status SetAttrFloatList(const char* attr_name, const float* values,
171                           int num_values) override;
172   Status SetAttrIntList(const char* attr_name, const int64_t* values,
173                         int num_values) override;
174   Status SetAttrTypeList(const char* attr_name,
175                          const tensorflow::DataType* values,
176                          int num_values) override;
177   Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
178                          int num_values) override;
179   Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
180                           const int* num_dims, int num_values) override;
181   Status SetAttrFunctionList(
182       const char* attr_name,
183       absl::Span<const AbstractOperation*> values) override;
184 
185   Status SetOpName(const char* const op_name) override;
186 
GetContext()187   MLIRContext* GetContext() { return context_; }
188 
189   Status AddRef(Type type, Type* output_type);
190 
191   Status Create(ArrayRef<Value> operands, OperationState**);
192 
193   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)194   static bool classof(const AbstractOperation* ptr) {
195     return ptr->getKind() == kMlir;
196   }
197 
198  private:
199   // Return true is there are still unfilled ODS slots for adding more inputs.
200   bool IsNextODSArgAvailable();
201 
202   MLIRContext* context_;
203   MlirFunctionContext* function_context_;
204   SmallVector<Value, 8> operands_;
205   llvm::StringMap<Attribute> attrs_;
206   std::unique_ptr<OperationState> state_;
207   // This is the index of the next ODS operand that will be added with AddInput
208   // or AddInput;
209   int current_ods_input_ = 0;
210   const tensorflow::OpDef* op_def_ = nullptr;
211   const char* op_name_ = nullptr;
212   string tf_op_type_;
213   // TODO(srbs): Use this.
214   string device_name_;
215 };
216 
217 // MlirFunction is a thin wrapper over a FuncOp.
218 class MlirFunction : public AbstractFunction {
219  public:
MlirFunction(std::unique_ptr<MLIRContext> context,OwningOpRef<mlir::ModuleOp> module,func::FuncOp func)220   explicit MlirFunction(std::unique_ptr<MLIRContext> context,
221                         OwningOpRef<mlir::ModuleOp> module, func::FuncOp func)
222       : AbstractFunction(kMlir),
223         context_(std::move(context)),
224         module_(std::move(module)),
225         func_(func) {}
226 
227   Status GetFunctionDef(tensorflow::FunctionDef** f) override;
228 
229   // For LLVM style RTTI.
classof(const AbstractFunction * ptr)230   static bool classof(const AbstractFunction* ptr) {
231     return ptr->getKind() == kMlir;
232   }
233 
234  private:
235   std::unique_ptr<MLIRContext> context_;
236   OwningOpRef<mlir::ModuleOp> module_;
237   func::FuncOp func_;
238   std::unique_ptr<tensorflow::FunctionDef> fdef_;
239 };
240 
241 class MlirFunctionContext : public TracingContext {
242  public:
MlirFunctionContext(const char * name)243   explicit MlirFunctionContext(const char* name)
244       : TracingContext(kMlir),
245         context_(std::make_unique<MLIRContext>()),
246         builder_(context_.get()) {
247     RegisterDialects(*context_);
248     // TODO(aminim) figure out the location story here
249     module_ = ModuleOp::create(builder_.getUnknownLoc());
250     func_ =
251         func::FuncOp::create(builder_.getUnknownLoc(), name,
252                              builder_.getFunctionType(llvm::None, llvm::None));
253     module_->push_back(func_);
254     builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
255   }
256 
Release()257   void Release() override { delete this; }
258 
CreateOperation()259   AbstractOperation* CreateOperation() override {
260     return new MlirAbstractOp(context_.get(), this);
261   }
262   Status AddParameter(tensorflow::DataType dtype,
263                       const tensorflow::PartialTensorShape& shape,
264                       TracingTensorHandle** handle) override;
265 
266   Status Finalize(OutputList* outputs, AbstractFunction** f) override;
267 
RegisterFunction(AbstractFunction * func)268   Status RegisterFunction(AbstractFunction* func) override {
269     return Unimplemented(
270         "Registering graph functions has not been implemented yet.");
271   }
272 
RemoveFunction(const string & func)273   Status RemoveFunction(const string& func) override {
274     return Unimplemented(
275         "MlirFunctionContext::RemoveFunction has not been implemented yet.");
276   }
277 
278   Operation* CreateOperationFromState(const OperationState& state);
279 
280  private:
281   std::unique_ptr<MLIRContext> context_;
282   OpBuilder builder_;
283   func::FuncOp func_;
284   OwningOpRef<mlir::ModuleOp> module_;
285 };
286 
Reset(const char * op,const char * device_name)287 Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
288   if (state_) {
289     return FailedPrecondition("Reset called on already built op.");
290   }
291   TF_RETURN_IF_ERROR(
292       tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_));
293   assert(op_def_);
294 
295   tf_op_type_ = op;
296   std::string name = "tf.";
297   name += op;
298   // TODO(aminim) figure out the location story here
299   state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
300   return ::tensorflow::OkStatus();
301 }
302 
SetAttrType(const char * attr_name,tensorflow::DataType dtype)303 Status MlirAbstractOp::SetAttrType(const char* attr_name,
304                                    tensorflow::DataType dtype) {
305   if (!state_)
306     return FailedPrecondition(
307         "op_type must be specified before specifying attrs.");
308   Type mlir_type;
309   Builder builder(context_);
310   TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type));
311   attrs_[attr_name] = TypeAttr::get(mlir_type);
312   return ::tensorflow::OkStatus();
313 }
314 
SetOpName(const char * const op_name)315 Status MlirAbstractOp::SetOpName(const char* const op_name) {
316   // TODO(aminim): should we use a location?
317   if (op_name_) {
318     return FailedPrecondition("SetOpName called on already built op.");
319   }
320   op_name_ = op_name;
321   return ::tensorflow::OkStatus();
322 }
323 
AddRef(Type type,Type * output_type)324 Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
325   Type elt_type = getElementTypeOrSelf(type);
326   if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
327     return InvalidArgument("Requested reference to a reference type");
328   }
329   elt_type = TensorFlowRefType::get(elt_type);
330   if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
331     *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type);
332   }
333   *output_type = UnrankedTensorType::get(elt_type);
334   return ::tensorflow::OkStatus();
335 }
336 
Create(ArrayRef<Value> operands,OperationState ** state)337 Status MlirAbstractOp::Create(ArrayRef<Value> operands,
338                               OperationState** state) {
339   state_->operands = llvm::to_vector<4>(operands);
340   Builder builder(context_);
341 
342   if (current_ods_input_ != op_def_->input_arg_size())
343     return InvalidArgument(absl::StrCat("Mismatch in operands number: got ",
344                                         current_ods_input_, " expected ",
345                                         op_def_->input_arg_size(), " ; for op ",
346                                         state_->name.getStringRef().str()));
347 
348   // Process results according to the op_def and infer types for derived
349   // attributes.
350   for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) {
351     int original_size = state_->types.size();
352     if (!output_arg.number_attr().empty()) {
353       // Same type repeated "repeats" times.
354       Attribute repeats_attr = attrs_[output_arg.number_attr()];
355       if (!repeats_attr)
356         return InvalidArgument("Missing attribute '", output_arg.number_attr(),
357                                "' required for output list '",
358                                output_arg.name(), "'");
359       if (!repeats_attr.isa<IntegerAttr>())
360         return InvalidArgument("Attribute '", output_arg.number_attr(),
361                                "' required for output list '",
362                                output_arg.name(), "' isn't an integer");
363       int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
364 
365       if (!output_arg.type_attr().empty()) {
366         // Same type repeated "repeats" times.
367         Attribute attr = attrs_[output_arg.type_attr()];
368         if (!attr)
369           return InvalidArgument("Missing attribute '", output_arg.type_attr(),
370                                  "' required for output '", output_arg.name(),
371                                  "'");
372         TypedAttr type_attr = attr.dyn_cast<TypedAttr>();
373         if (!type_attr)
374           return InvalidArgument("Attribute '", output_arg.type_attr(),
375                                  "' required for output '", output_arg.name(),
376                                  "' isn't a type attribute");
377         for (int i = 0; i < repeats; ++i)
378           state_->types.push_back(UnrankedTensorType::get(type_attr.getType()));
379       } else if (output_arg.type() != tensorflow::DT_INVALID) {
380         for (int i = 0; i < repeats; ++i) {
381           Type type;
382           TF_RETURN_IF_ERROR(
383               ConvertDataType(output_arg.type(), builder, &type));
384           state_->types.push_back(type);
385         }
386       } else {
387         return InvalidArgument("Missing type or type_attr field in ",
388                                output_arg.ShortDebugString());
389       }
390     } else if (!output_arg.type_attr().empty()) {
391       Attribute attr = attrs_[output_arg.type_attr()];
392       if (!attr)
393         return InvalidArgument("Missing attribute '", output_arg.type_attr(),
394                                "' required for output '", output_arg.name(),
395                                "'");
396       TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
397       if (!type_attr)
398         return InvalidArgument("Attribute '", output_arg.type_attr(),
399                                "' required for output '", output_arg.name(),
400                                "' isn't a type attribute");
401       state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
402     } else if (!output_arg.type_list_attr().empty()) {
403       // This is pointing to an attribute which is an array of types.
404       Attribute attr = attrs_[output_arg.type_list_attr()];
405       if (!attr)
406         return InvalidArgument(
407             "Missing attribute '", output_arg.type_list_attr(),
408             "' required for output '", output_arg.name(), "'");
409       ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
410       if (!array_attr)
411         return InvalidArgument("Attribute '", output_arg.type_list_attr(),
412                                "' required for output '", output_arg.name(),
413                                "' isn't an array attribute");
414       for (Attribute attr : array_attr) {
415         TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
416         if (!type_attr)
417           return InvalidArgument("Array Attribute '",
418                                  output_arg.type_list_attr(),
419                                  "' required for output '", output_arg.name(),
420                                  "' has a non-Type element");
421         state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
422       }
423     } else if (output_arg.type() != tensorflow::DT_INVALID) {
424       Type type;
425       Builder builder(context_);
426       TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type));
427       state_->types.push_back(type);
428     } else {
429       return InvalidArgument("No type fields in ",
430                              output_arg.ShortDebugString());
431     }
432     if (output_arg.is_ref()) {
433       // For all types that were added by this function call, make them refs.
434       for (Type& type : llvm::make_range(&state_->types[original_size],
435                                          state_->types.end())) {
436         Type output_type;
437         TF_RETURN_IF_ERROR(AddRef(type, &output_type));
438         type = output_type;
439       }
440     }
441   }
442   for (auto& it : attrs_) state_->addAttribute(it.first(), it.second);
443   *state = state_.get();
444   return ::tensorflow::OkStatus();
445 }
446 
Name() const447 const string& MlirAbstractOp::Name() const { return tf_op_type_; }
448 
DeviceName() const449 const string& MlirAbstractOp::DeviceName() const { return device_name_; }
450 
SetDeviceName(const char * name)451 Status MlirAbstractOp::SetDeviceName(const char* name) {
452   device_name_ = name;
453   return ::tensorflow::OkStatus();
454 }
455 
SetAttrString(const char * attr_name,const char * data,size_t length)456 Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
457                                      size_t length) {
458   return Unimplemented("SetAttrString has not been implemented yet.");
459 }
SetAttrInt(const char * attr_name,int64_t value)460 Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) {
461   return Unimplemented("SetAttrInt has not been implemented yet.");
462 }
SetAttrFloat(const char * attr_name,float value)463 Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
464   return Unimplemented("SetAttrFloat has not been implemented yet.");
465 }
SetAttrBool(const char * attr_name,bool value)466 Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
467   attrs_[attr_name] = BoolAttr::get(context_, value);
468   return ::tensorflow::OkStatus();
469 }
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)470 Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
471                                     const int num_dims) {
472   return Unimplemented("SetAttrShape has not been implemented yet.");
473 }
SetAttrFunction(const char * attr_name,const AbstractOperation * value)474 Status MlirAbstractOp::SetAttrFunction(const char* attr_name,
475                                        const AbstractOperation* value) {
476   return Unimplemented("SetAttrFunction has not been implemented yet.");
477 }
SetAttrFunctionName(const char * attr_name,const char * value,size_t length)478 Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name,
479                                            const char* value, size_t length) {
480   return Unimplemented("SetAttrFunctionName has not been implemented yet.");
481 }
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)482 Status MlirAbstractOp::SetAttrTensor(const char* attr_name,
483                                      AbstractTensorInterface* tensor) {
484   return Unimplemented("SetAttrTensor has not been implemented yet.");
485 }
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)486 Status MlirAbstractOp::SetAttrStringList(const char* attr_name,
487                                          const void* const* values,
488                                          const size_t* lengths,
489                                          int num_values) {
490   return Unimplemented("SetAttrStringList has not been implemented yet.");
491 }
SetAttrFloatList(const char * attr_name,const float * values,int num_values)492 Status MlirAbstractOp::SetAttrFloatList(const char* attr_name,
493                                         const float* values, int num_values) {
494   return Unimplemented("SetAttrFloatList has not been implemented yet.");
495 }
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)496 Status MlirAbstractOp::SetAttrIntList(const char* attr_name,
497                                       const int64_t* values, int num_values) {
498   return Unimplemented("SetAttrIntList has not been implemented yet.");
499 }
SetAttrTypeList(const char * attr_name,const tensorflow::DataType * values,int num_values)500 Status MlirAbstractOp::SetAttrTypeList(const char* attr_name,
501                                        const tensorflow::DataType* values,
502                                        int num_values) {
503   return Unimplemented("SetAttrTypeList has not been implemented yet.");
504 }
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)505 Status MlirAbstractOp::SetAttrBoolList(const char* attr_name,
506                                        const unsigned char* values,
507                                        int num_values) {
508   return Unimplemented("SetAttrBoolList has not been implemented yet.");
509 }
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)510 Status MlirAbstractOp::SetAttrShapeList(const char* attr_name,
511                                         const int64_t** dims,
512                                         const int* num_dims, int num_values) {
513   return Unimplemented("SetAttrShapeList has not been implemented yet.");
514 }
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)515 Status MlirAbstractOp::SetAttrFunctionList(
516     const char* attr_name, absl::Span<const AbstractOperation*> values) {
517   return Unimplemented("SetAttrFunctionList has not been implemented yet.");
518 }
519 
GetFunctionDef(tensorflow::FunctionDef ** f)520 Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
521   if (fdef_) {
522     *f = fdef_.get();
523     return ::tensorflow::OkStatus();
524   }
525   PassManager pm(func_.getContext());
526   ::tensorflow::applyTensorflowAndCLOptions(pm);
527   pm.addNestedPass<func::FuncOp>(
528       CreateFunctionalToExecutorDialectConversionPass());
529   pm.addPass(CreateBreakUpIslandsPass());
530 
531   // In case of failure, the `diag_handler` converts MLIR errors emitted to
532   // the MLIRContext into a tensorflow::Status.
533   StatusScopedDiagnosticHandler diag_handler(func_.getContext());
534   LogicalResult result = pm.run(func_->getParentOfType<ModuleOp>());
535   (void)result;
536   TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());
537 
538   tensorflow::GraphExportConfig configs;
539   fdef_.reset(new tensorflow::FunctionDef());
540   TF_RETURN_IF_ERROR(
541       ConvertMlirFunctionToFunctionLibraryDef(func_, configs, fdef_.get()));
542   *f = fdef_.get();
543   return ::tensorflow::OkStatus();
544 }
545 
Execute(absl::Span<AbstractTensorHandle * > retvals,int * num_retvals)546 Status MlirAbstractOp::Execute(absl::Span<AbstractTensorHandle*> retvals,
547                                int* num_retvals) {
548   OperationState* state;
549   TF_RETURN_IF_ERROR(Create(operands_, &state));
550   Operation* op = function_context_->CreateOperationFromState(*state);
551   *num_retvals = op->getNumResults();
552   for (int i = 0; i < *num_retvals; i++)
553     retvals[i] = new MlirTensor(op->getResult(i));
554   return ::tensorflow::OkStatus();
555 }
556 
CreateOperationFromState(const OperationState & state)557 Operation* MlirFunctionContext::CreateOperationFromState(
558     const OperationState& state) {
559   return builder_.create(state);
560 }
561 
AddParameter(tensorflow::DataType dtype,const tensorflow::PartialTensorShape & shape,TracingTensorHandle ** handle)562 Status MlirFunctionContext::AddParameter(
563     tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
564     TracingTensorHandle** handle) {
565   // TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
566   // resolved.
567   Type type;
568   TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
569   *handle =
570       new MlirTensor(func_.getBody().front().addArgument(type, func_.getLoc()));
571   return ::tensorflow::OkStatus();
572 }
573 
AddInput(AbstractTensorHandle * input)574 Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
575   if (current_ods_input_ >= op_def_->input_arg_size())
576     return InvalidArgument(
577         absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
578                      op_def_->input_arg_size(), " allowed input_args ; for op ",
579                      state_->name.getStringRef().str()));
580 
581   auto* operand = dyn_cast<MlirTensor>(input);
582   if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
583   operands_.push_back(operand->getValue());
584 
585   // Get the next ArgDef and use it to infer the derived attributes associated
586   // to this input.
587   const tensorflow::OpDef::ArgDef& arg_def =
588       op_def_->input_arg(current_ods_input_++);
589   Type expected_type;
590   if (arg_def.type() != tensorflow::DT_INVALID) {
591     Builder builder(context_);
592     TF_RETURN_IF_ERROR(
593         tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type));
594     if (arg_def.is_ref()) {
595       Type output_type;
596       TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
597       expected_type = output_type;
598     }
599   } else {
600     expected_type = cast<MlirTensor>(input)->getElementType();
601   }
602   if (!arg_def.type_attr().empty())
603     attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
604 
605   return ::tensorflow::OkStatus();
606 }
607 
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)608 Status MlirAbstractOp::AddInputList(
609     absl::Span<AbstractTensorHandle* const> inputs) {
610   if (current_ods_input_ >= op_def_->input_arg_size())
611     return InvalidArgument(
612         absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
613                      op_def_->input_arg_size(), " allowed input_args"));
614 
615   for (AbstractTensorHandle* input : inputs) {
616     auto* operand = dyn_cast<MlirTensor>(input);
617     if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
618     operands_.push_back(operand->getValue());
619   }
620 
621   // Get the next ArgDef and use it to infer the derived attributes associated
622   // to this input.
623   const tensorflow::OpDef::ArgDef& arg_def =
624       op_def_->input_arg(current_ods_input_++);
625   if (!arg_def.number_attr().empty()) {
626     Builder builder(context_);
627     attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size());
628     // TODO(aminim): handle ref variable.
629     if (arg_def.type() != tensorflow::DT_INVALID) {
630       // TODO(aminim): check type wrt input
631       Type arg_def_type;
632       TF_RETURN_IF_ERROR(
633           ConvertDataType(arg_def.type(), builder, &arg_def_type));
634       // Ensure each of the type in the list matches the op def type.
635       // TODO(aminim): can we improve the error message with the actual types?
636       for (AbstractTensorHandle* input : inputs)
637         if (arg_def_type != cast<MlirTensor>(input)->getElementType())
638           return InvalidArgument(
639               "Invalid input list: type mismatch the op def expectation");
640     } else if (!inputs.empty()) {
641       if (arg_def.type_attr().empty())
642         return FailedPrecondition(
643             "Invalid opdef type constraint: either type or type_attr required");
644 
645       attrs_[arg_def.type_attr()] =
646           TypeAttr::get(cast<MlirTensor>(inputs.front())->getElementType());
647     }
648   } else if (!arg_def.type_list_attr().empty()) {
649     // TODO(aminim): handle ref variable.
650     SmallVector<Attribute, 8> types;
651     types.reserve(inputs.size());
652     for (AbstractTensorHandle* input : inputs)
653       types.push_back(TypeAttr::get(cast<MlirTensor>(input)->getElementType()));
654     attrs_[arg_def.type_list_attr()] = ArrayAttr::get(GetContext(), types);
655   }
656   return ::tensorflow::OkStatus();
657 }
658 
Finalize(OutputList * outputs,AbstractFunction ** f)659 Status MlirFunctionContext::Finalize(OutputList* outputs,
660                                      AbstractFunction** f) {
661   Block& body = func_.getBody().front();
662   SmallVector<Value, 8> ret_operands;
663   for (auto* output : outputs->outputs) {
664     auto* operand = dyn_cast<MlirTensor>(output);
665     if (!operand)
666       return InvalidArgument("Capturing eager tensors is not supported yet.");
667     if (operand->getValue().getContext() != context_.get())
668       return InvalidArgument(
669           "Capturing tensors from other context is not supported.");
670     ret_operands.push_back(operand->getValue());
671   }
672   builder_.create<func::ReturnOp>(func_.getLoc(), ret_operands);
673 
674   auto arg_types = body.getArgumentTypes();
675   auto result_types = body.getTerminator()->getOperandTypes();
676   func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
677   *f = new MlirFunction(std::move(context_), std::move(module_), func_);
678   return ::tensorflow::OkStatus();
679 }
680 
681 extern "C" {
MlirTracingFactory(const char * fn_name,TF_Status * s)682 TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
683   return new MlirFunctionContext(fn_name);
684 }
685 }
686 
687 }  // namespace
688 }  // namespace TF
689 }  // namespace mlir
690