• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
17 
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/platform/status.h"
23 
24 namespace mlir {
25 namespace TFL {
26 
CreateConstOpWithSingleValue(PatternRewriter * rewriter,Location loc,ShapedType shaped_type,int value)27 stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
28     PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
29     int value) {
30   Type element_type = shaped_type.getElementType();
31   ShapedType scalar_type = RankedTensorType::get({}, element_type);
32   Attribute attr;
33   if (element_type.isF16()) {
34     auto floatType = mlir::FloatType::getF16(element_type.getContext());
35     auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
36     std::vector<Attribute> floatValues({floatAttr});
37     attr = DenseElementsAttr::get(scalar_type, floatValues);
38   } else if (element_type.isBF16()) {
39     auto floatType = mlir::FloatType::getBF16(element_type.getContext());
40     auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
41     std::vector<Attribute> floatValues({floatAttr});
42     attr = DenseElementsAttr::get(scalar_type, floatValues);
43   } else if (element_type.isF32()) {
44     attr =
45         DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
46   } else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
47     auto etype = complex_type.getElementType();
48     if (etype.isF32()) {
49       auto dialect = etype.getContext()->getLoadedDialect("tf");
50       tensorflow::TensorProto repr;
51       repr.set_dtype(tensorflow::DT_COMPLEX64);
52 
53       tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
54       shape->set_unknown_rank(false);
55       shape->add_dim()->set_size(int64_t{1});
56       std::string content;
57       auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
58       content.assign(reinterpret_cast<const char*>(&complex_value),
59                      sizeof(complex_value));
60       repr.set_tensor_content(content);
61       std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
62 
63       attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
64     } else {
65       return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
66                                 "Unsupported type");
67     }
68   } else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
69     switch (itype.getWidth()) {
70       case 8:
71         attr = DenseElementsAttr::get<int8_t>(scalar_type,
72                                               static_cast<int8_t>(value));
73         break;
74       case 16:
75         attr = DenseElementsAttr::get<int16_t>(scalar_type,
76                                                static_cast<int16_t>(value));
77         break;
78       case 32:
79         attr = DenseElementsAttr::get<int32_t>(scalar_type,
80                                                static_cast<int32_t>(value));
81         break;
82       case 64:
83         attr = DenseElementsAttr::get<int64_t>(scalar_type,
84                                                static_cast<int64_t>(value));
85         break;
86       default:
87         return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
88                                   "Unsupported type");
89     }
90   } else {
91     return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
92                               "Unsupported type");
93   }
94   return rewriter->create<ConstantOp>(loc, scalar_type, attr);
95 }
96 
97 }  // namespace TFL
98 }  // namespace mlir
99