1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_ATTR_UTIL_H_
16 #define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_ATTR_UTIL_H_
17
#include <cstdint>
#include <vector>

#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tfrt/bef/bef_encoding.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
31
32 namespace tensorflow {
33 namespace tfd {
34
35 // Converts a TFRT string_view to the Abseil version.
ToAbslStringView(tfrt::string_view sv)36 inline absl::string_view ToAbslStringView(tfrt::string_view sv) {
37 return absl::string_view(sv.data(), sv.size());
38 }
39
40 // Parses the string representation of the DataType in `dtype` into `data_type`.
41 // Aborts the program for unsupported dtypes.
42 tensorflow::Status ParseTfDataType(absl::string_view dtype,
43 DataType* data_type);
44
45 // The following 2 functions convert between Tensorflow DataTypes and
46 // OpAttrTypes. The mapping between OpAttrType and DataType is defined in
47 // attr_type.def. Aborts on unsupported types.
48 DataType ConvertToTfDataType(tfrt::OpAttrType op_attr_type);
49 tfrt::OpAttrType ConvertFromTfDataType(DataType data_type);
50
51 // The following 2 functions convert between BEF attribute types and Tensorflow
52 // DataTypes. Aborts on unsupported datatypes.
53 DataType ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type);
54 tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type);
55
56 // Parses the tensor valued `attr_value` and constructs the tensor with its
57 // contents in `tensor`. Returns OK status on success, INVALID_ARGUMENT on
58 // failure.
59 tensorflow::Status ParseTensorAttrValue(absl::string_view attr_value,
60 tensorflow::Tensor* tensor);
61
62 // Parses a string of the form "[1,2,3,...]" in `attr_value` and returns the
63 // constituent dimension sizes (shape) in `int_list_val`. Returns
64 // INVALID_ARGUMENT on invalid input.
65 tensorflow::Status ParseTensorShapeAttrValue(absl::string_view attr_value,
66 std::vector<int64_t>* shape_val);
67
68 // Parses a boolean from `attr_value` into `bool_val` and returns OK status on
69 // success. Returns INVALID_ARGUMENT on invalid input.
70 tensorflow::Status ParseBoolAttrValue(absl::string_view attr_value,
71 bool* bool_val);
72
73 // Parses an int64_t from `attr_value` into `int_val` and returns OK status on
74 // success. Returns INVLAID_ARGUMENT on invalid input.
75 tensorflow::Status ParseIntAttrValue(absl::string_view attr_value,
76 int64_t* int_val);
77
AttrValueSplit(absl::string_view str)78 inline std::vector<absl::string_view> AttrValueSplit(absl::string_view str) {
79 return absl::StrSplit(str, absl::MaxSplits('$', 1));
80 }
81
82 // Returns true if `attr_name` is an attribute that is not required by TFRT
83 // (usually added by stages higher in the lowering process)
84 bool IsUnusedAttribute(absl::string_view attr_name);
85
86 // Fills in the passed in AttrValueMap `attr_value_map` with attributes from
87 // `attrs`.
88 llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs,
89 tfrt::HostContext* host,
90 AttrValueMap* attr_value_map);
91
92 // Fills in the passed in AttrValueMap `attr_value_map`.
93 tensorflow::Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,
94 tfrt::AggregateAttr op_func_attr_array,
95 tensorflow::AttrValueMap* attr_value_map);
96
97 } // namespace tfd
98 } // namespace tensorflow
99
100 #endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_ATTR_UTIL_H_
101