1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
17
18 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
21 #include "tensorflow/core/framework/rng_alg.h"
22
23 namespace mlir {
24 namespace TF {
25
26 namespace {
27 // Returns int, float or complex DenseElementsAttr with scalar shape with the
28 // given element type and the integer value.
GetScalarOfType(Type ty,int64_t raw_value)29 static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
30 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
31 if (auto float_ty = ty.dyn_cast<FloatType>()) {
32 FloatAttr attr = FloatAttr::get(float_ty, raw_value);
33 return DenseElementsAttr::get(scalar_ty, attr);
34 }
35
36 if (auto int_ty = ty.dyn_cast<IntegerType>()) {
37 IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
38 return DenseElementsAttr::get(scalar_ty, attr);
39 }
40
41 if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
42 Type complex_element_ty = complex_ty.getElementType();
43 if (complex_element_ty.isF32()) {
44 return DenseElementsAttr::get(
45 scalar_ty, static_cast<std::complex<float>>(raw_value));
46 } else if (complex_element_ty.isF64()) {
47 return DenseElementsAttr::get(
48 scalar_ty, static_cast<std::complex<double>>(raw_value));
49 }
50 }
51 llvm_unreachable("unsupported type");
52 }
53
54 // Returns subtype of `resource` if present. Otherwise an unranked tensor type
55 // of `element_type` is returned.
GetResourceSubtypeOrDefault(Value resource,Type element_type)56 static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
57 auto resource_type = resource.getType()
58 .cast<TensorType>()
59 .getElementType()
60 .cast<ResourceType>();
61 if (resource_type.getSubtypes().size() == 1)
62 return resource_type.getSubtypes().front();
63
64 return UnrankedTensorType::get(element_type);
65 }
66
HasResourceSubtype(Value resource)67 static bool HasResourceSubtype(Value resource) {
68 return resource.getType()
69 .cast<TensorType>()
70 .getElementType()
71 .cast<ResourceType>()
72 .getSubtypes()
73 .size() == 1;
74 }
75
GetResourceSubtype(Value resource)76 static Type GetResourceSubtype(Value resource) {
77 return resource.getType()
78 .cast<TensorType>()
79 .getElementType()
80 .cast<ResourceType>()
81 .getSubtypes()
82 .front();
83 }
84
// Decomposes tf.RngReadAndSkip into reads, arithmetic and writes on the
// resource variable.
//
// For Philox, the resource variable holds a tensor<3xi64> with the state:
// [counter_lo, counter_hi, key]
//
// RngReadAndSkip increments the 128-bit counter value by 256 * delta and
// returns the original state value.
//
// For Threefry, the resource variable holds a tensor<2xi64> with the state:
// [counter, key]
//
// RngReadAndSkip increments the 64-bit counter value by 256 * delta and
// returns a tensor<3xi64> value [counter, key, 0].
//
// In both cases the product 256 * delta is formed as a single 64-bit unsigned
// value (a tf.Mul with a ui64 scalar result). Only the lowest counter word
// receives that increment. Each higher counter word receives the carry (0 or
// 1) out of the word below it.
//
// The pattern only fires when:
//   * `alg` is a compile-time constant scalar naming Philox or Threefry,
//   * the op result type is tensor<(RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE)xi64>,
//   * the resource has a single subtype equal to
//     tensor<(counter size + RNG_KEY_SIZE)xi64>.
class DecomposeRngReadAndSkipOp : public RewritePattern {
 public:
  // Roots the pattern at tf.RngReadAndSkip with benefit 1. The list names every
  // op kind that matchAndRewrite may create.
  explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
      : RewritePattern(RngReadAndSkipOp::getOperationName(), 1, context,
                       {
                           AddV2Op::getOperationName(),
                           AssignVariableOp::getOperationName(),
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           LessOp::getOperationName(),
                           MulOp::getOperationName(),
                           PadOp::getOperationName(),
                           PackOp::getOperationName(),
                           ReadVariableOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           UnpackOp::getOperationName(),
                       }) {}

  // Replaces `op` (a tf.RngReadAndSkip) with the decomposed sequence.
  // Returns failure, without modifying the IR, when any precondition listed in
  // the class comment does not hold.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rng_op = cast<RngReadAndSkipOp>(op);

    // The algorithm selects the counter size, so it must be known statically.
    DenseIntElementsAttr alg_constant;
    if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
      return rewriter.notifyMatchFailure(
          op, "unable to determine algorithm statically");
    }

    if (alg_constant.getNumElements() != 1) {
      return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
    }

    // Zero-extended, so a negative constant cannot match either enumerator and
    // lands in the "unsupported alg" branch.
    uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
    tensorflow::Algorithm alg;
    if (tensorflow::RNG_ALG_PHILOX == alg_value) {
      alg = tensorflow::RNG_ALG_PHILOX;
    } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
      alg = tensorflow::RNG_ALG_THREEFRY;
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported alg");
    }

    // The op result always has room for the largest counter plus the key,
    // regardless of the algorithm.
    Type state_element_type = rewriter.getI64Type();
    RankedTensorType op_type = RankedTensorType::get(
        {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
        state_element_type);
    if (op_type != rng_op.getType()) {
      return rewriter.notifyMatchFailure(op, "unexpected op type");
    }

    if (!HasResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "missing resource subtype");
    }

    // The stored state is sized for this algorithm's counter plus the key.
    int counter_size = tensorflow::GetCounterSize(alg);
    int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
    RankedTensorType res_type =
        RankedTensorType::get({state_size}, state_element_type);
    if (res_type != GetResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
    }

    Location loc = op->getLoc();

    // Read the state value from the resource.
    Value state =
        rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());

    // Extract the key and counter from the state.
    // The state is split along axis 0 into `state_size` i64 scalars. The
    // counter words come first and the key is the last element.
    RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
    auto unpacked = rewriter.create<UnpackOp>(
        loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
    Value key = unpacked.getResult(counter_size);

    SmallVector<Value, 4> counter;
    for (int i = 0; i < counter_size; ++i) {
      counter.push_back(unpacked.getResult(i));
    }

    // Set the increment to 256 * delta.
    // The product is a ui64 scalar, i.e. it is computed modulo 2^64.
    Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
    RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
    Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
    Value increment =
        rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());

    // Increment the counter.
    // Words are added in unsigned arithmetic so that wraparound can be
    // detected. An unsigned sum smaller than its input means the word
    // overflowed. That carry (1 or 0) becomes the increment for the next,
    // more significant, word.
    SmallVector<Value, 4> pack_args;
    RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
    Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
    Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
    for (int i = 0; i < counter_size; ++i) {
      Value word = counter[i];
      Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
      Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
      Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
      pack_args.push_back(new_word);

      Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
      increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
    }

    // Save the new state value to the resource.
    // The key is appended unchanged after the updated counter words.
    pack_args.push_back(key);
    Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
    rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);

    // Pad the original state as necessary to fill the output shape.
    // The ORIGINAL (pre-increment) state is returned. It is padded with
    // (RNG_MAX_COUNTER_SIZE - counter_size) trailing zeros on its single
    // dimension, so that a shorter counter still yields `op_type`.
    int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
    Type i64 = rewriter.getI64Type();
    RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
    std::vector<int64_t> paddings_values = {0, pad};
    Value paddings = rewriter.create<ConstOp>(
        loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
    Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);

    rewriter.replaceOp(op, output);
    return success();
  }
};
218
219 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
220 } // namespace
221
PopulateDecomposeResourceOpsPatterns(MLIRContext * context,OwningRewritePatternList * patterns)222 void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
223 OwningRewritePatternList *patterns) {
224 patterns->insert<DecomposeRngReadAndSkipOp>(context);
225 populateWithGenerated(*patterns);
226 }
227
228 } // namespace TF
229 } // namespace mlir
230