// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License.  You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
// License for the specific language governing permissions and limitations under
// the License.
// =============================================================================

// TensorFlow kernels and Ops for computing a masked matrix product.

#include <algorithm>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"

using tensorflow::DEVICE_CPU;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT64;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TensorShapeUtils;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

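// Row-major Eigen matrix views over raw tensor buffers. These let the kernel
// read the inputs and write the output without copying any data.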
typedef Eigen::Map<
    Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    EigenMatInt64Map;
typedef Eigen::Map<
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    EigenMatFloatMap;
typedef Eigen::Map<
    const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    ConstEigenMatInt64Map;
typedef Eigen::Map<
    const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    ConstEigenMatFloatMap;

class MaskedMatmulOp : public OpKernel {
 public:
  explicit MaskedMatmulOp(OpKernelConstruction* context) : OpKernel(context) {
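    // Fail construction unless the op signature matches the expected types:
    // inputs (a: float, b: float, mask_indices: int64, transpose_a: bool,
    // transpose_b: bool) and a single float output (prod_values).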
    OP_REQUIRES_OK(
        context,
        context->MatchSignature(
            {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    // Computes the product a * b, but only for indices (i, j) in mask_indices.
    // The result is stored in prod_values, a 1-tensor, such that for all i,
    // prod_values[i] = (a * b)[mask_indices[i, 0], mask_indices[i, 1]].
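    // For example, with a = [[1, 2], [3, 4]], b = [[5, 6], [7, 8]] (so that
    // a * b = [[19, 22], [43, 50]]) and mask_indices = [[0, 1], [1, 0]], the
    // output is prod_values = [22, 43].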
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    const Tensor& mask_indices = context->input(2);
    const Tensor& transpose_a = context->input(3);
    const Tensor& transpose_b = context->input(4);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
                InvalidArgument("Input a should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
                InvalidArgument("Input b should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(mask_indices.shape()),
                InvalidArgument("Input mask_indices should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(transpose_a.shape()),
                InvalidArgument("Input transpose_a should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(transpose_b.shape()),
                InvalidArgument("Input transpose_b should be a scalar."));

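    // adj_a and adj_b indicate whether a and b should be transposed. The
    // *_dim_* values below are the dimensions of op(a) and op(b), i.e. of a
    // and b after the optional transposition.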
    const bool adj_a = transpose_a.scalar<bool>()();
    const bool adj_b = transpose_b.scalar<bool>()();
    const int64 a_dim_0 = a.dim_size(adj_a ? 1 : 0);
    const int64 a_dim_1 = a.dim_size(adj_a ? 0 : 1);
    const int64 b_dim_0 = b.dim_size(adj_b ? 1 : 0);
    const int64 b_dim_1 = b.dim_size(adj_b ? 0 : 1);
    const int64 num_nonzero_elements = mask_indices.dim_size(0);

    OP_REQUIRES(context, a_dim_1 == b_dim_0,
                InvalidArgument("Matrix shapes are incompatible: a has shape ",
                                a.shape().DebugString(), ", while b has shape ",
                                b.shape().DebugString(), "."));
    OP_REQUIRES(context, mask_indices.dim_size(1) == 2,
                InvalidArgument("mask_indices should be a matrix of shape ",
                                "[nnz, 2], where nnz is the number of ",
                                "non-zero elements."));

    ConstEigenMatFloatMap a_mat(a.matrix<float>().data(), a.dim_size(0),
                                a.dim_size(1));
    ConstEigenMatFloatMap b_mat(b.matrix<float>().data(), b.dim_size(0),
                                b.dim_size(1));
    ConstEigenMatInt64Map indices_mat(mask_indices.matrix<int64>().data(),
                                      num_nonzero_elements, 2);

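    // Allocate the 1-D output tensor, holding one product value per row of
    // mask_indices.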
    Tensor* prod_values_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_nonzero_elements}),
                                &prod_values_tensor));
    EigenMatFloatMap prod_values(prod_values_tensor->vec<float>().data(), 1,
                                 num_nonzero_elements);

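    // Helpers that read the row index (into op(a)) and column index (into
    // op(b)) of the i-th masked entry, bounds-checking them against the
    // corresponding matrix dimensions.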
    auto get_a_index = [&indices_mat, &a_dim_0](int64 i) {
      int64 a_index = internal::SubtleMustCopy(indices_mat(i, 0));
      CHECK(FastBoundsCheck(a_index, a_dim_0))
          << "In mask_indices[" << i << ", :], the row index " << a_index
          << " is out of bounds [0, " << a_dim_0 << ").";
      return a_index;
    };
    auto get_b_index = [&indices_mat, &b_dim_1](int64 i) {
      int64 b_index = internal::SubtleMustCopy(indices_mat(i, 1));
      CHECK(FastBoundsCheck(b_index, b_dim_1))
          << "In mask_indices[" << i << ", :], the column index " << b_index
          << " is out of bounds [0, " << b_dim_1 << ").";
      return b_index;
    };
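    // Returns the dot product of row i of op(a) and column j of op(b). When a
    // is transposed, row i of op(a) is column i of a_mat; similarly, when b is
    // transposed, column j of op(b) is row j of b_mat.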
    auto get_dot_product = [&adj_a, &adj_b, &a_mat, &b_mat](int64 i, int64 j) {
      if (adj_a) {
        if (adj_b) {
          return a_mat.col(i).dot(b_mat.row(j));
        } else {
          return a_mat.col(i).dot(b_mat.col(j));
        }
      } else {
        if (adj_b) {
          return a_mat.row(i).dot(b_mat.row(j));
        } else {
          return a_mat.row(i).dot(b_mat.col(j));
        }
      }
    };

    std::vector<int64> perm(num_nonzero_elements);
    std::iota(perm.begin(), perm.end(), 0);
    // TODO(walidk): improve performance in the case adj_a and not adj_b.
    // TODO(walidk): benchmark smaller inputs, and potentially skip the sort
    // when the input fits in L3 cache.
    // Compute a permutation to sort the mask indices, to take advantage of
    // CPU caching. Row access is efficient (given the RowMajor ordering) while
    // column access is strided, so we prefer to
    //   sort according to a when a is transposed (column access into a_mat),
    //   sort according to b when b is not transposed (column access into b_mat).
    auto compare_a_index = [&get_a_index](int64 i, int64 j) {
      return get_a_index(i) < get_a_index(j);
    };
    auto compare_b_index = [&get_b_index](int64 i, int64 j) {
      return get_b_index(i) < get_b_index(j);
    };
    if (adj_a) {
      std::stable_sort(perm.begin(), perm.end(), compare_a_index);
    } else if (!adj_b) {
      std::stable_sort(perm.begin(), perm.end(), compare_b_index);
    }

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // Based on benchmarks, the cost per output value is on the order of 20
    // cycles per element of the inner dimension.
    const int64 cost_per_unit = 20 * a_dim_1;
    // Lambda encapsulating the per-shard computation.
    auto work = [&](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        const int64 p = perm[i];
        const int64 a_index = get_a_index(p);
        const int64 b_index = get_b_index(p);
        prod_values(p) = get_dot_product(a_index, b_index);
      }
    };
    // Shard the work.
    worker_threads.workers->ParallelFor(num_nonzero_elements, cost_per_unit,
                                        work);
  }
};
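// Register the CPU kernel for the MaskedMatmul op.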
REGISTER_KERNEL_BUILDER(Name("MaskedMatmul").Device(DEVICE_CPU),
                        MaskedMatmulOp);

}  // namespace tensorflow