/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_EXPERIMENTAL_IPOSTOP
#define ARM_COMPUTE_EXPERIMENTAL_IPOSTOP

#include <memory>
#include <numeric>
#include <vector>

namespace arm_compute
{
namespace experimental
{
/** Type of Post Op */
enum class PostOpType
{
    Activation,
    Eltwise_Add,
    Eltwise_PRelu
};
/** An ordered sequence of types of Post Ops */
using PostOpTypeSequence = std::vector<PostOpType>;
/** An elementwise n-ary operation that can be appended to and fused with (at kernel level) other operators.
 * It contains:
 * 1. The attributes of the original operator.
 * 2. Any additional tensor arguments.
 * 3. The position of the previous op's dst tensor in its argument list ( @ref prev_dst_pos )
 *
 * For example, a series of chained ops:
 *
 *      div(src1, relu(conv(src0, weights, bias, conv_info), act_info), div_info)
 *
 * translates to
 *
 *      dst = conv(src0, weights, bias, conv_info) // main op
 *      dst = relu(dst, act_info)                  // previous dst is placed in the first (and only) argument
 *      dst = div(src1, dst, div_info)             // previous dst is placed in the second argument
 *
 * which in turn translates to:
 *
 *      main op:  conv(src0, weights, bias, conv_info)
 *      post op1: relu(act_info, prev_dst_pos = 0)
 *      post op2: div(div_info, src1, prev_dst_pos = 1)
 *
 * (An illustrative sketch further below shows how a chain like this maps onto a @ref PostOpList.)
 *
 * @note: On Broadcasting
 *      For n-ary post ops, the tensor arguments must not "widen" the dst tensor of the main op.
 *      For example, for a dst of shape [14, 1, 34]:
 *      * post_op_arg1 = [1, 1, 34]   is allowed: broadcast in dim 0
 *      * post_op_arg1 = [14, 1, 34]  is allowed: no broadcast
 *      * post_op_arg1 = [1, 1, 1]    is allowed: broadcast in dims 0 and 2
 *      * post_op_arg1 = [14, 15, 34] is NOT allowed: broadcast widens the dst tensor
 *
 * @note: On Data layout
 *      All post ops are data layout agnostic. This means post ops do not have an inherent idea of "width", "height" and so on.
 *      Should we want to perform a post op with 2 tensors of different data layouts (where data layouts are significant to both),
 *      then we need to perform the necessary permutation ops beforehand to unify their data layouts before they can be fused with a post op.
 *
 *      Note that although post ops themselves should be able to support any data layout, the main op they fuse to may impose
 *      additional restrictions in the presence of post ops. For example, the implementation of a gemm op may only allow
 *      NHWC data layout if post ops are provided. Such restrictions are main op implementation specific.
 *
 * @note: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type
 * @note: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime
 *      and the lifetime of its copies. This is almost guaranteed as IPostOp is only meant to be used at configure time
 *      after the ITensor or ITensorInfo objects are already constructed
 */
template <typename TensorRelatedT>
struct IPostOp
{
    /** Get the arity of the post op
     * @note: This is one fewer than the arity of the original op, because we implicitly pass the previous op's dst
     *      tensor as one of the arguments
     */
    size_t arity() const
    {
        return arguments().size();
    }
    /** The position of the previous op's dst in the current op's argument list */
    virtual int prev_dst_pos() const = 0;
    /** The IPostOp type */
    virtual PostOpType type() const = 0;
    /** The argument tensors
     * The order of the argument tensors is strictly preserved
     */
    virtual std::vector<TensorRelatedT *>       arguments()       = 0;
    virtual std::vector<const TensorRelatedT *> arguments() const = 0;
    /** Clone method used in cases where PostOps are owned by unique_ptr
     * @note: This performs a shallow copy of the TensorRelatedT if TensorRelatedT points to a resource
     */
    virtual std::unique_ptr<IPostOp<TensorRelatedT>> clone() const = 0;
    virtual ~IPostOp()
    {
    }
};
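
/* Illustrative sketch only, not part of this interface: a minimal concrete IPostOp for an
 * activation post op, assuming arm_compute::ActivationLayerInfo for the op attributes and the
 * hypothetical name ExamplePostOpAct. It carries no additional tensor arguments, so arity() == 0
 * and the previous op's dst is its only (implicit) argument:
 *
 *     template <typename TensorRelatedT>
 *     struct ExamplePostOpAct : public IPostOp<TensorRelatedT>
 *     {
 *         explicit ExamplePostOpAct(const ActivationLayerInfo &act_info)
 *             : _act_info{ act_info }
 *         {
 *         }
 *         int prev_dst_pos() const override
 *         {
 *             return 0; // the previous op's dst is the sole (implicit) argument
 *         }
 *         PostOpType type() const override
 *         {
 *             return PostOpType::Activation;
 *         }
 *         std::vector<TensorRelatedT *> arguments() override
 *         {
 *             return {}; // no additional tensor arguments
 *         }
 *         std::vector<const TensorRelatedT *> arguments() const override
 *         {
 *             return {};
 *         }
 *         std::unique_ptr<IPostOp<TensorRelatedT>> clone() const override
 *         {
 *             // shallow copy: any TensorRelatedT pointers are copied, not the resources they point to
 *             return std::make_unique<ExamplePostOpAct<TensorRelatedT>>(*this);
 *         }
 *         ActivationLayerInfo _act_info;
 *     };
 */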
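
/* Illustrative usage sketch only, with hypothetical concrete op types (ExamplePostOpAct from the
 * sketch above and an assumed ExamplePostOpEltwiseAdd whose constructor takes the addend tensor
 * and the prev dst position): building a @ref PostOpList, declared below, for a chain of the form
 * dst = add(src1, relu(conv(src0, weights, bias))), e.g. with TensorRelatedT = ITensorInfo *:
 *
 *     PostOpList<ITensorInfo *> post_ops{};
 *     post_ops.push_back_op<ExamplePostOpAct<ITensorInfo *>>(act_info);             // post op1: relu, prev_dst_pos = 0
 *     post_ops.push_back_op<ExamplePostOpEltwiseAdd<ITensorInfo *>>(&src1_info, 1); // post op2: add, prev dst in position 1
 *
 * The list owns clones of its post ops but never the tensors they point to: src1_info must remain
 * valid until the main op consuming post_ops has been configured.
 */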

/** A sequence of PostOps that can be appended to the end of other operators */
template <typename TensorRelatedT>
class PostOpList
{
public:
    /** Constructor */
    PostOpList() = default;
    /** Destructor */
    ~PostOpList() = default;
    PostOpList(const PostOpList &other)
    {
        for(const auto &op : other._post_ops)
        {
            this->_post_ops.push_back(op->clone());
        }
    }
    PostOpList &operator=(const PostOpList &other)
    {
        PostOpList tmp{ other };
        std::swap(tmp, *this);
        return *this;
    }
    PostOpList(PostOpList &&other) = default;
    PostOpList &operator=(PostOpList &&other) = default;

    /** Add a new post op at the end of the list */
    template <typename OpT, typename... Args>
    void push_back_op(Args &&... args)
    {
        _post_ops.push_back(std::make_unique<OpT>(std::forward<Args>(args)...));
    }

    /** Number of post ops */
    size_t size() const
    {
        return _post_ops.size();
    }

    /** Total number of arguments across all post ops */
    size_t total_num_arguments() const
    {
        return std::accumulate(_post_ops.begin(), _post_ops.end(), static_cast<size_t>(0), [](size_t op1_arity, const auto &op2)
        {
            return op1_arity + op2->arity();
        });
    }

    /** Get the underlying post op list */
    std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> &get_list()
    {
        return _post_ops;
    }
    const std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> &get_list() const
    {
        return _post_ops;
    }

private:
    std::vector<std::unique_ptr<IPostOp<TensorRelatedT>>> _post_ops{};
};

} // namespace experimental
} // namespace arm_compute
#endif //ARM_COMPUTE_EXPERIMENTAL_IPOSTOP