// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
// =============================================================================

// TensorFlow kernels and Ops for computing a masked matrix product.

#include <algorithm>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"

using tensorflow::DEVICE_CPU;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT64;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TensorShapeUtils;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

typedef Eigen::Map<
    Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    EigenMatInt64Map;
typedef Eigen::Map<
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    EigenMatFloatMap;
typedef Eigen::Map<
    const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    ConstEigenMatInt64Map;
typedef Eigen::Map<
    const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    ConstEigenMatFloatMap;

class MaskedMatmulOp : public OpKernel {
 public:
  explicit MaskedMatmulOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(
        context,
        context->MatchSignature(
            {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    // Computes the product a * b, but only for indices (i, j) in mask_indices.
    // The result is stored in prod_values, a 1-tensor, such that for all i,
    // prod_values[i] = (a * b)[mask_indices[i, 0], mask_indices[i, 1]].
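    //
    // As a hypothetical illustration (not from the original source): with
    //   a = [[1, 2], [3, 4]], b = [[5, 6], [7, 8]],
    //   mask_indices = [[0, 1], [1, 0]],
    // and transpose_a = transpose_b = false, the dense product is
    //   a * b = [[19, 22], [43, 50]],
    // so the op returns prod_values = [22, 43].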
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    const Tensor& mask_indices = context->input(2);
    const Tensor& transpose_a = context->input(3);
    const Tensor& transpose_b = context->input(4);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
                InvalidArgument("Input a should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
                InvalidArgument("Input b should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(mask_indices.shape()),
                InvalidArgument("Input mask_indices should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(transpose_a.shape()),
                InvalidArgument("Input transpose_a should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(transpose_b.shape()),
                InvalidArgument("Input transpose_b should be a scalar."));

    const bool adj_a = transpose_a.scalar<bool>()();
    const bool adj_b = transpose_b.scalar<bool>()();
    const int64 a_dim_0 = a.dim_size(adj_a ? 1 : 0);
    const int64 a_dim_1 = a.dim_size(adj_a ? 0 : 1);
    const int64 b_dim_0 = b.dim_size(adj_b ? 1 : 0);
    const int64 b_dim_1 = b.dim_size(adj_b ? 0 : 1);
    const int64 num_nonzero_elements = mask_indices.dim_size(0);

    OP_REQUIRES(context, a_dim_1 == b_dim_0,
                InvalidArgument("Matrix shapes are incompatible: a has shape ",
                                a.shape().DebugString(), ", while b has shape ",
                                b.shape().DebugString(), "."));
    OP_REQUIRES(context, mask_indices.dim_size(1) == 2,
                InvalidArgument("mask_indices should be a matrix of shape ",
                                "[nnz, 2], where nnz is the number of non-zero ",
                                "elements."));

    ConstEigenMatFloatMap a_mat(a.matrix<float>().data(), a.dim_size(0),
                                a.dim_size(1));
    ConstEigenMatFloatMap b_mat(b.matrix<float>().data(), b.dim_size(0),
                                b.dim_size(1));
    ConstEigenMatInt64Map indices_mat(mask_indices.matrix<int64>().data(),
                                      num_nonzero_elements, 2);

    Tensor* prod_values_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_nonzero_elements}),
                                &prod_values_tensor));
    EigenMatFloatMap prod_values(prod_values_tensor->vec<float>().data(), 1,
                                 num_nonzero_elements);

    auto get_a_index = [&indices_mat, &a_dim_0](int64 i) {
      int64 a_index = internal::SubtleMustCopy(indices_mat(i, 0));
      CHECK(FastBoundsCheck(a_index, a_dim_0))
          << "In mask_indices[" << i << ", :], the row index " << a_index
          << " is out of bounds [0, " << a_dim_0 << ").";
      return a_index;
    };
    auto get_b_index = [&indices_mat, &b_dim_1](int64 i) {
      int64 b_index = internal::SubtleMustCopy(indices_mat(i, 1));
      CHECK(FastBoundsCheck(b_index, b_dim_1))
          << "In mask_indices[" << i << ", :], the column index " << b_index
          << " is out of bounds [0, " << b_dim_1 << ").";
      return b_index;
    };
    // Computes the (i, j) entry of op(a) * op(b), i.e. the dot product of row
    // i of op(a) with column j of op(b), where op transposes its argument when
    // the corresponding adj flag is set. A transposed row (resp. column) of
    // the effective matrix is a column (resp. row) of the stored matrix.
    auto get_dot_product = [&adj_a, &adj_b, &a_mat, &b_mat](int64 i, int64 j) {
      if (adj_a) {
        if (adj_b) {
          return a_mat.col(i).dot(b_mat.row(j));
        } else {
          return a_mat.col(i).dot(b_mat.col(j));
        }
      } else {
        if (adj_b) {
          return a_mat.row(i).dot(b_mat.row(j));
        } else {
          return a_mat.row(i).dot(b_mat.col(j));
        }
      }
    };

    std::vector<int64> perm(num_nonzero_elements);
    std::iota(perm.begin(), perm.end(), 0);
    // TODO(walidk): improve performance in the case adj_a and not adj_b
    // TODO(walidk): benchmark smaller inputs, and potentially skip the sort
    // when the input fits in L3 cache.
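    // perm holds the order in which the masked entries are processed below;
    // it starts as the identity permutation. As a hypothetical illustration
    // of the reordering that follows: if adj_a and adj_b are both false and
    // the column indices in mask_indices are [2, 0, 2, 1], the sort produces
    // perm = [1, 3, 0, 2], so dot products that read the same column of b
    // are computed consecutively.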
    // Compute a permutation to sort either the a or b matrix, to take
    // advantage of CPU caching. Since row access is efficient (given the
    // RowMajor ordering), we prefer to sort according to a when a is
    // transposed, and according to b when b is not transposed.
    auto compare_a_index = [&get_a_index](int64 i, int64 j) {
      return get_a_index(i) < get_a_index(j);
    };
    auto compare_b_index = [&get_b_index](int64 i, int64 j) {
      return get_b_index(i) < get_b_index(j);
    };
    if (adj_a) {
      std::stable_sort(perm.begin(), perm.end(), compare_a_index);
    } else if (!adj_b) {
      std::stable_sort(perm.begin(), perm.end(), compare_b_index);
    }

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // Based on benchmarks, the cost is on the order of 20 cycles per dimension.
    const int64 cost_per_unit = 20 * a_dim_1;
    // Lambda encapsulating the per-shard computation.
    auto work = [&](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        const int64 p = perm[i];
        const int64 a_index = get_a_index(p);
        const int64 b_index = get_b_index(p);
        prod_values(p) = get_dot_product(a_index, b_index);
      }
    };
    // Shard the work.
    worker_threads.workers->ParallelFor(num_nonzero_elements, cost_per_unit,
                                        work);
  }
};
REGISTER_KERNEL_BUILDER(Name("MaskedMatmul").Device(DEVICE_CPU),
                        MaskedMatmulOp);

}  // namespace tensorflow
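
// For reference, a minimal standalone sketch (not part of the original file;
// plain Eigen with hypothetical shapes) of the access pattern get_dot_product
// uses when transpose_a is true and transpose_b is false: the (i, j) entry of
// a^T * b is the dot product of column i of the stored a with column j of the
// stored b.
//
//   Eigen::MatrixXf a(3, 2), b(3, 2);  // a^T is 2x3, so a^T * b is 2x2.
//   a.setRandom();
//   b.setRandom();
//   const long i = 1, j = 0;           // A hypothetical masked index pair.
//   const float value = a.col(i).dot(b.col(j));
//   // value == (a.transpose() * b)(i, j)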