1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_ 18 19 #include "llvm/Support/SourceMgr.h" 20 #include "llvm/Support/raw_ostream.h" 21 #include "mlir/IR/Diagnostics.h" // TF:llvm-project 22 #include "mlir/IR/Location.h" // TF:llvm-project 23 #include "mlir/IR/MLIRContext.h" // TF:llvm-project 24 #include "tensorflow/core/lib/core/status.h" 25 26 // Error utilities for MLIR when interacting with code using Status returns. 27 namespace mlir { 28 29 // TensorFlow's Status is used for error reporting back to callers. 30 using tensorflow::Status; 31 32 // Diagnostic handler that collects all the diagnostics reported and can produce 33 // a Status to return to callers. This is for the case where MLIR functions are 34 // called from a function that will return a Status: MLIR code still uses the 35 // default error reporting, and the final return function can return the Status 36 // constructed from the diagnostics collected. 37 class StatusScopedDiagnosticHandler : public SourceMgrDiagnosticHandler { 38 public: 39 // Constructs a diagnostic handler in a context. If propagate is true, then 40 // diagnostics reported are also propagated back to the original diagnostic 41 // handler. 42 explicit StatusScopedDiagnosticHandler(MLIRContext* context, 43 bool propagate = false); 44 // On destruction error consumption is verified. 45 ~StatusScopedDiagnosticHandler(); 46 47 // Returns whether any errors were reported. 48 bool ok() const; 49 50 // Returns Status corresponding to the diagnostics reported. This consumes the 51 // diagnostics reported and returns a Status of type Unknown. It is required 52 // to consume the error status, if there is one, before destroying the object. 53 Status ConsumeStatus(); 54 55 // Returns the combination of the passed in status and consumed diagnostics. 56 // This consumes the diagnostics reported and either appends the diagnostics 57 // to the error message of 'status' (if 'status' is already an error state), 58 // or returns an Unknown status (if diagnostics reported), otherwise OK. 59 Status Combine(Status status); 60 61 private: 62 LogicalResult handler(Diagnostic* diag); 63 64 // String stream to assemble the final error message. 65 std::string diag_str_; 66 llvm::raw_string_ostream diag_stream_; 67 68 // A SourceMgr to use for the base handler class. 69 llvm::SourceMgr source_mgr_; 70 71 // Whether to propagate diagnostics to the old diagnostic handler. 72 bool propagate_; 73 }; 74 } // namespace mlir 75 76 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_ 77