1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ 18 19 #include "llvm/ADT/ArrayRef.h" 20 #include "mlir/IR/Builders.h" // from @llvm-project 21 #include "mlir/IR/Location.h" // from @llvm-project 22 #include "mlir/IR/Value.h" // from @llvm-project 23 #include "mlir/Support/LLVM.h" // from @llvm-project 24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" 25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" 26 27 namespace mlir { 28 namespace TF { 29 namespace collection_ops_util { 30 31 // This file includes utilities for decomposing collection ops (stack, tensor 32 // list, tensor array) in TF. We represent such a data structure as a buffer of 33 // shape [max_element_count, element_shape]. 34 35 // Creates an i32 scalar tf.Const. 36 Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc); 37 38 // Creates an integer vector tf.Const. 39 Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc, 40 int bitwidth = 32); 41 42 // Returns the type of the size tensor used to track a data structure's element 43 // count. It is a tensor<1xi32>, and we use R1 instead of a scalar because it is 44 // easier to concat it with other offsets. 45 TensorType GetSizeType(OpBuilder builder); 46 47 // Reshapes a scalar value to match the size type tensor<i32>. 48 Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc); 49 50 // Creates ops that represent the indices of the slice for an element in the 51 // buffer. Requires `index` to have tensor<1xi32> type. 52 Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, 53 Location loc); 54 55 // Creates ops that slice the element out of a buffer at the given index. 56 // Requires `index` to have tensor<1xi32> type. 57 Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, 58 bool keep_slice_shape = false); 59 60 // Creates ops that copy the buffer and update an element at the given index. 61 // Requires `index` to have tensor<1xi32> type. 62 Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, 63 Location loc); 64 65 // Creates the buffer for the data structure with given element shape, type and 66 // maximum size. 67 LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape, 68 int64_t max_size, Operation* op, 69 Type element_dtype, OpBuilder builder, 70 Value* buffer); 71 72 // Same as above, but uses a Value as max_size and check if it is a constant. 73 LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape, 74 Value max_size, Operation* op, 75 Type element_dtype, OpBuilder builder, 76 Value* buffer); 77 78 // Tries to infer the element type with full shape based its write accesses. 79 // `infer_from_user` should check if the provided op is an accessing op that 80 // could be used to infer the type. 81 llvm::Optional<RankedTensorType> GetElementTypeFromAccess( 82 Value collection, ModuleOp module, 83 llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op); 84 85 // Creates a ReadVariableOp on a local variable. 86 Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc); 87 88 // Creates an AssignVariableOp on a local variable. 89 TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value, 90 OpBuilder builder, Location loc); 91 92 // Adds two values, or creates a logical-or if they are boolean type. 93 Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc); 94 95 // Gathers elements in buffer with the indices. 96 Value GatherElements(Value indices, Value buffer, OpBuilder builder, 97 Location loc); 98 99 // Scatters elements into buffer, where each scattered element is accumulated 100 // with the old value in buffer. 101 Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, 102 OpBuilder builder, Location loc); 103 104 } // namespace collection_ops_util 105 } // namespace TF 106 } // namespace mlir 107 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ 108