/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/type_to_shape.h"

#include <string>

#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

using mlir::IntegerType;
using mlir::MemRefType;
using mlir::RankedTensorType;
using mlir::VectorType;
using tensorflow::int64;
using xla::PrimitiveType;
using xla::ShapeUtil;

namespace xla {

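// Returns the XLA PrimitiveType corresponding to the given MLIR scalar type
// (bfloat16/floating-point, complex, or integer), or PRIMITIVE_TYPE_INVALID
// when there is no direct equivalent.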
PrimitiveType TypeToPrimitiveType(mlir::Type type) {
  if (type.isBF16()) {
    return PrimitiveType::BF16;
  } else if (type.isF16()) {
    return PrimitiveType::F16;
  } else if (type.isF32()) {
    return PrimitiveType::F32;
  } else if (type.isF64()) {
    return PrimitiveType::F64;
  } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    mlir::Type element_ty = complex_type.getElementType();
    if (element_ty.isF32()) {
      return PrimitiveType::C64;
    } else if (element_ty.isF64()) {
      return PrimitiveType::C128;
    }
    return PrimitiveType::PRIMITIVE_TYPE_INVALID;
  } else if (auto integer_type = type.dyn_cast<mlir::IntegerType>()) {
    bool is_unsigned = integer_type.isUnsigned();
    switch (integer_type.getWidth()) {
      case 1:
        return PrimitiveType::PRED;
      case 8:
        return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
      case 16:
        return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
      case 32:
        return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
      case 64:
        return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
      default:
        return PrimitiveType::PRIMITIVE_TYPE_INVALID;
    }
  }
  return PrimitiveType::PRIMITIVE_TYPE_INVALID;
}

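// Converts an MLIR type to an XLA Shape through the caller-supplied
// shape-representation callback. Only fully-defined (static) shapes are
// accepted; a partially-known shape produces an InvalidArgument error.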
StatusOr<Shape> TypeToShape(
    mlir::Type type, CustomShapeRepresentationFn shape_representation_fn) {
  tensorflow::PartialTensorShape partial_tensor_shape =
      tensorflow::ConvertTypeToTensorShape(type);

  tensorflow::TensorShape fully_defined_tensor_shape;
  if (!partial_tensor_shape.AsTensorShape(&fully_defined_tensor_shape)) {
    return tensorflow::errors::InvalidArgument(
        "XLA HLO only allows fully-defined shape");
  }

  tensorflow::DataType dtype;
  TF_RETURN_IF_ERROR(tensorflow::ConvertToDataType(type, &dtype));

  return shape_representation_fn(fully_defined_tensor_shape, dtype);
}

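// Converts an MLIR type to the equivalent XLA Shape. Scalar primitive types,
// vectors, memrefs (including strided layouts expressed through a single
// affine map), statically shaped ranked tensors, tuples, and mhlo token types
// are handled; an empty Shape signals an unsupported type.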
Shape TypeToShape(mlir::Type type) {
  PrimitiveType ptype = TypeToPrimitiveType(type);
  if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
    return ShapeUtil::MakeShape(ptype, {});

  if (type.isIntOrFloat()) {
    auto* context = type.getContext();
    mlir::emitError(mlir::UnknownLoc::get(context))
        << "lowering should have been handled by primitive type lowering for "
        << debugString(type);
  } else if (auto v = type.dyn_cast<mlir::VectorType>()) {
    llvm::SmallVector<int64, 4> span(v.getShape().begin(), v.getShape().end());
    mlir::Type element_type = v.getElementType();
    PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
    if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
      return ShapeUtil::MakeShape(primitive_type, span);
  } else if (auto m = type.dyn_cast<mlir::MemRefType>()) {
    llvm::SmallVector<int64, 6> span(m.getShape().begin(), m.getShape().end());
    mlir::Type element_type = m.getElementType();
    // Treat a memref of a vector as if it was a memref of primitive type with
    // the vector dimensions at the end.
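    // For example, memref<2x3xvector<4xf32>> is converted as if it were
    // memref<2x3x4xf32>, i.e. the shape f32[2,3,4].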
    if (auto v = element_type.dyn_cast<mlir::VectorType>()) {
      element_type = v.getElementType();
      span.insert(span.end(), v.getShape().begin(), v.getShape().end());
    }
    PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
    if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {};
    // For the primitive type case, the shape of the memref is similar to the
    // vector type case (i.e., it is, modulo the layout, the same dimensions
    // and primitive type).
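    // For example, memref<5x4xf32> with the trivial (identity) layout becomes
    // f32[5,4] with XLA's default layout.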
    if (m.getAffineMaps().empty())
      return ShapeUtil::MakeShape(primitive_type, span);

    if (m.getAffineMaps().size() == 1) {
      llvm::SmallVector<int64_t, 4> strides;
      int64_t offset;
      if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {};

      llvm::SmallVector<std::pair<int64_t, int>, 4> strides_with_indices;
      for (const auto& e : llvm::enumerate(strides)) {
        strides_with_indices.push_back({e.value(), e.index()});
      }
      std::stable_sort(strides_with_indices.begin(),
                       strides_with_indices.end());

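      // Walking the strides from smallest to largest recovers the dimensions
      // from most minor to most major. For example, memref<5x4xf32> with
      // (row-major) strides [4, 1] yields minor_to_major = [1, 0].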
      llvm::SmallVector<int64, 4> minor_to_major;
      int64_t stride = 1;
      for (const auto& pr : strides_with_indices) {
        minor_to_major.push_back(pr.second);

        // Either the affine map is not perfectly strided, or the dimensions
        // recovered from strides don't match the actual dimensions in shapes.
        if (stride != pr.first && m.getShape()[pr.second] != 1) return {};

        stride *= m.getShape()[pr.second];
      }

      llvm::SmallVector<int64, 4> dimensions(m.getShape().begin(),
                                             m.getShape().end());
      return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions,
                                                   minor_to_major);
    }
  } else if (auto t = type.dyn_cast<mlir::RankedTensorType>()) {
    // TODO(jpienaar): This is only handling the base case with primitive
    // element type.
    llvm::SmallVector<int64, 4> span(t.getShape().begin(), t.getShape().end());
    // Only fully static shapes are supported.
    // TODO(b/115638799): Update once xla::Shape can support dynamic shapes.
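    // For example, tensor<8x128xf32> becomes f32[8,128], while a dynamic shape
    // such as tensor<?x128xf32> falls through and yields an empty Shape.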
    if (std::find(t.getShape().begin(), t.getShape().end(), -1) !=
        t.getShape().end())
      return {};
    mlir::Type element_type = t.getElementType();
    PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
    // Only primitive element type supported.
    if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
      return ShapeUtil::MakeShape(primitive_type, span);
  } else if (auto tuple_type = type.dyn_cast<mlir::TupleType>()) {
    llvm::SmallVector<Shape, 4> shapes;
    shapes.reserve(tuple_type.size());
    for (mlir::Type sub_type : tuple_type.getTypes()) {
      shapes.push_back(TypeToShape(sub_type));
    }
    return ShapeUtil::MakeTupleShape(shapes);

  } else if (type.isa<mlir::mhlo::TokenType>()) {
    return ShapeUtil::MakeTokenShape();
  }

  // Return an empty XLA shape to signify an error. No MLIR Type maps to an
  // empty Shape.
  return {};
}

}  // namespace xla