• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstddef>
16 #include <string>
17 
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_shape.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/byte_order.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 
27 namespace tensorflow {
28 namespace {
29 template <typename T>
CopyToBuffer(const T & value,uint8 * output)30 inline void CopyToBuffer(const T& value, uint8* output) {
31   // Memcpy to string is endian-dependent. We choose little-endian as
32   // standard. On big-endian machines, bytes should be reversed.
33 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
34   static_assert(port::kLittleEndian, "");
35   std::memcpy(output, &value, sizeof(value));
36 #else
37   static_assert(!port::kLittleEndian, "");
38   std::reverse_copy(reinterpret_cast<const uint8*>(&value),
39                     reinterpret_cast<const uint8*>(&value + 1), output);
40 #endif
41 }
42 
FarmhashFingerprint64(TTypes<uint8,2>::ConstTensor input,TTypes<uint8,2>::Matrix output)43 void FarmhashFingerprint64(TTypes<uint8, 2>::ConstTensor input,
44                            TTypes<uint8, 2>::Matrix output) {
45   DCHECK_EQ(output.dimension(0), input.dimension(0));
46   DCHECK_EQ(output.dimension(1), sizeof(uint64));
47   for (int64_t i = 0; i < output.dimension(0); ++i) {
48     const uint64 fingerprint =
49         Fingerprint64({reinterpret_cast<const char*>(&input(i, 0)),
50                        static_cast<std::size_t>(input.dimension(1))});
51     CopyToBuffer(fingerprint, &output(i, 0));
52   }
53 }
54 
FarmhashFingerprint64(TTypes<tstring>::ConstFlat input,TTypes<uint8,2>::Matrix output)55 void FarmhashFingerprint64(TTypes<tstring>::ConstFlat input,
56                            TTypes<uint8, 2>::Matrix output) {
57   DCHECK_EQ(output.dimension(0), input.dimension(0));
58   DCHECK_EQ(output.dimension(1), sizeof(uint64));
59   for (int64_t i = 0; i < input.dimension(0); ++i) {
60     const uint64 fingerprint =
61         Fingerprint64({input(i).data(), input(i).size()});
62     CopyToBuffer(fingerprint, &output(i, 0));
63   }
64 }
65 
66 class FingerprintOp : public OpKernel {
67  public:
FingerprintOp(OpKernelConstruction * context)68   explicit FingerprintOp(OpKernelConstruction* context) : OpKernel(context) {
69     DataType dtype;
70     OP_REQUIRES_OK(context, context->GetAttr("T", &dtype));
71     OP_REQUIRES(context, DataTypeCanUseMemcpy(dtype) || dtype == DT_STRING,
72                 errors::InvalidArgument("Data type not supported: ",
73                                         DataTypeString(dtype)));
74   }
75 
Compute(tensorflow::OpKernelContext * context)76   void Compute(tensorflow::OpKernelContext* context) override {
77     const Tensor& method_tensor = context->input(1);
78     OP_REQUIRES(context, TensorShapeUtils::IsScalar(method_tensor.shape()),
79                 errors::InvalidArgument("`method` should be a scalar string: ",
80                                         method_tensor.shape()));
81     // For now, farmhash64 is the only function supported.
82     const tstring& method = method_tensor.scalar<tstring>()();
83     OP_REQUIRES(
84         context, method == "farmhash64",
85         errors::InvalidArgument("Unsupported fingerprint method: ", method));
86 
87     const Tensor& input = context->input(0);
88     OP_REQUIRES(
89         context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
90         errors::InvalidArgument("`data` should have at least one dimension: ",
91                                 input.shape()));
92 
93     const int64_t dim0 = input.shape().dim_size(0);
94     int64_t dim1;
95     if (dim0 == 0) {
96       dim1 = 0;
97     } else {
98       dim1 = input.shape().num_elements() / dim0;
99     }
100 
101     Tensor* output;
102     OP_REQUIRES_OK(context,
103                    context->allocate_output(
104                        0, TensorShape{dim0, kFingerprintSize}, &output));
105 
106     if (input.dtype() == DT_STRING) {
107       if (dim1 > 1) {
108         Tensor temp;
109         OP_REQUIRES_OK(context, context->allocate_temp(
110                                     DT_UINT8,
111                                     TensorShape{input.shape().num_elements(),
112                                                 kFingerprintSize},
113                                     &temp));
114         // `temp` is a matrix of shape {input.num_elements, fingerprint_size},
115         // and each row contains the fingerprint value of corresponding string.
116         // To compute fingerprints of multiple strings, this op fingerprints the
117         // buffer containing the string fingerprints.
118         FarmhashFingerprint64(input.flat<tstring>(), temp.tensor<uint8, 2>());
119         FarmhashFingerprint64(static_cast<const Tensor&>(temp).shaped<uint8, 2>(
120                                   {dim0, dim1 * kFingerprintSize}),
121                               output->matrix<uint8>());
122       } else {
123         // In case dim1 == 1, each string computes into its own fingerprint
124         // value. There is no need to fingerprint twice.
125         FarmhashFingerprint64(input.flat<tstring>(), output->matrix<uint8>());
126       }
127     } else {
128       auto data = input.bit_casted_shaped<uint8, 2>(
129           {dim0, dim1 * DataTypeSize(input.dtype())});
130       FarmhashFingerprint64(data, output->matrix<uint8>());
131     }
132   }
133 
134  private:
135   static constexpr int kFingerprintSize = sizeof(uint64);
136 };
137 
138 REGISTER_KERNEL_BUILDER(Name("Fingerprint").Device(tensorflow::DEVICE_CPU),
139                         FingerprintOp);
140 }  // namespace
141 }  // namespace tensorflow
142