• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
17 
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/platform/status.h"
23 
24 namespace mlir {
25 namespace TFL {
26 
CreateConstOpWithSingleValue(PatternRewriter * rewriter,Location loc,ShapedType shaped_type,int value)27 stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
28     PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
29     int value) {
30   Type element_type = shaped_type.getElementType();
31   ShapedType scalar_type = RankedTensorType::get({}, element_type);
32   Attribute attr;
33   if (element_type.isF16()) {
34     auto floatType = mlir::FloatType::getF16(element_type.getContext());
35     auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
36     std::vector<Attribute> floatValues({floatAttr});
37     attr = DenseElementsAttr::get(scalar_type, floatValues);
38   } else if (element_type.isBF16()) {
39     auto floatType = mlir::FloatType::getBF16(element_type.getContext());
40     auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
41     std::vector<Attribute> floatValues({floatAttr});
42     attr = DenseElementsAttr::get(scalar_type, floatValues);
43   } else if (element_type.isF32()) {
44     attr =
45         DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
46   } else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
47     auto etype = complex_type.getElementType();
48     if (etype.isF32()) {
49       auto dialect = etype.getContext()->getLoadedDialect("tf");
50       tensorflow::TensorProto repr;
51       repr.set_dtype(tensorflow::DT_COMPLEX64);
52 
53       tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
54       shape->set_unknown_rank(false);
55       shape->add_dim()->set_size(int64_t{1});
56       std::string content;
57       auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
58       content.assign(reinterpret_cast<const char*>(&complex_value),
59                      sizeof(complex_value));
60       repr.set_tensor_content(content);
61       std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
62 
63       attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
64     } else {
65       return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
66                                 "Unsupported type");
67     }
68   } else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
69     if (element_type.isSignedInteger()) {
70       switch (itype.getWidth()) {
71         case 8:
72           attr = DenseElementsAttr::get<int8_t>(scalar_type,
73                                                 static_cast<int8_t>(value));
74           break;
75         case 16:
76           attr = DenseElementsAttr::get<int16_t>(scalar_type,
77                                                  static_cast<int16_t>(value));
78           break;
79         case 32:
80           attr = DenseElementsAttr::get<int32_t>(scalar_type,
81                                                  static_cast<int32_t>(value));
82           break;
83         case 64:
84           attr = DenseElementsAttr::get<int64_t>(scalar_type,
85                                                  static_cast<int64_t>(value));
86           break;
87         default:
88           return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
89                                     "Unsupported type");
90       }
91     } else {
92       switch (itype.getWidth()) {
93         case 8:
94           attr = DenseElementsAttr::get<uint8_t>(scalar_type,
95                                                  static_cast<uint8_t>(value));
96           break;
97         case 16:
98           attr = DenseElementsAttr::get<uint16_t>(scalar_type,
99                                                   static_cast<uint16_t>(value));
100           break;
101         case 32:
102           attr = DenseElementsAttr::get<uint32_t>(scalar_type,
103                                                   static_cast<uint32_t>(value));
104           break;
105         case 64:
106           attr = DenseElementsAttr::get<uint64_t>(scalar_type,
107                                                   static_cast<uint64_t>(value));
108           break;
109         default:
110           return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
111                                     "Unsupported type");
112       }
113     }
114   } else {
115     return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
116                               "Unsupported type");
117   }
118   return rewriter->create<ConstantOp>(loc, scalar_type, attr);
119 }
120 
121 }  // namespace TFL
122 }  // namespace mlir
123