1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #define EIGEN_USE_THREADS 16 17 #include "tensorflow/core/kernels/lookup_table_init_op.h" 18 19 #include <algorithm> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/kernels/lookup_util.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/lib/io/inputbuffer.h" 33 #include "tensorflow/core/lib/strings/numbers.h" 34 #include "tensorflow/core/lib/strings/str_util.h" 35 #include "tensorflow/core/platform/macros.h" 36 37 namespace tensorflow { 38 39 // Kernel to initialize a look table given a key and value tensors. 40 // After this operation, the table becomes read-only. 41 class InitializeTableOp : public OpKernel { 42 public: InitializeTableOp(OpKernelConstruction * context)43 explicit InitializeTableOp(OpKernelConstruction* context) 44 : OpKernel(context) {} 45 Compute(OpKernelContext * ctx)46 void Compute(OpKernelContext* ctx) override { 47 mutex_lock l(mu_); 48 lookup::InitializableLookupTable* table; 49 OP_REQUIRES_OK(ctx, 50 GetInitializableLookupTable("table_handle", ctx, &table)); 51 core::ScopedUnref unref_me(table); 52 53 DataType expected_input_0 = 54 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 55 DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), 56 table->value_dtype()}; 57 DataTypeVector expected_outputs = {}; 58 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); 59 60 const Tensor& keys = ctx->input(1); 61 OP_REQUIRES( 62 ctx, TensorShapeUtils::IsVector(keys.shape()), 63 errors::InvalidArgument("Keys must be a vector, but received shape", 64 keys.shape().DebugString())); 65 66 const Tensor& values = ctx->input(2); 67 OP_REQUIRES( 68 ctx, TensorShapeUtils::IsVector(values.shape()), 69 errors::InvalidArgument("Values must be a vector, but received shape", 70 values.shape().DebugString())); 71 72 OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(), 73 errors::InvalidArgument( 74 "Keys and values must have the same size ", 75 keys.NumElements(), " vs ", values.NumElements())); 76 77 int memory_used_before = 0; 78 if (ctx->track_allocations()) { 79 memory_used_before = table->MemoryUsed(); 80 } 81 OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); 82 if (ctx->track_allocations()) { 83 ctx->record_persistent_memory_allocation(table->MemoryUsed() - 84 memory_used_before); 85 } 86 } 87 88 private: 89 mutex mu_; 90 }; 91 92 REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), 93 InitializeTableOp); 94 REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU), 95 InitializeTableOp); 96 97 // Kernel to initialize a lookup table from a text file. 98 // 99 // After this operation, the table becomes read-only. 100 class InitializeTableFromTextFileOp : public OpKernel { 101 public: InitializeTableFromTextFileOp(OpKernelConstruction * ctx)102 explicit InitializeTableFromTextFileOp(OpKernelConstruction* ctx) 103 : OpKernel(ctx) { 104 OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_)); 105 OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_)); 106 OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_)); 107 string delimiter; 108 OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter)); 109 OP_REQUIRES(ctx, delimiter.size() == 1, 110 errors::InvalidArgument("delimiter should be only 1 char")); 111 delimiter_ = delimiter[0]; 112 } 113 Compute(OpKernelContext * ctx)114 void Compute(OpKernelContext* ctx) override { 115 mutex_lock l(mu_); 116 lookup::InitializableLookupTable* table; 117 OP_REQUIRES_OK(ctx, 118 GetInitializableLookupTable("table_handle", ctx, &table)); 119 core::ScopedUnref unref_me(table); 120 121 DataType expected_input_0 = 122 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 123 DataTypeVector expected_inputs = {expected_input_0, DT_STRING}; 124 DataTypeVector expected_outputs = {}; 125 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); 126 127 const Tensor& vocab_filename_tensor = ctx->input(1); 128 OP_REQUIRES( 129 ctx, TensorShapeUtils::IsScalar(vocab_filename_tensor.shape()), 130 errors::InvalidArgument("filename should be a single string, but got ", 131 vocab_filename_tensor.shape().DebugString())); 132 133 string vocab_filename = vocab_filename_tensor.scalar<string>()(); 134 OP_REQUIRES(ctx, !vocab_filename.empty(), 135 errors::InvalidArgument("filename cannot be empty.")); 136 137 int64 memory_used_before = 0; 138 if (ctx->track_allocations()) { 139 memory_used_before = table->MemoryUsed(); 140 } 141 OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile( 142 vocab_filename, vocab_size_, delimiter_, key_index_, 143 value_index_, ctx->env(), table)); 144 if (ctx->track_allocations()) { 145 ctx->record_persistent_memory_allocation(table->MemoryUsed() - 146 memory_used_before); 147 } 148 } 149 150 private: 151 mutex mu_; 152 int64 vocab_size_; 153 char delimiter_; 154 int64 key_index_; 155 int64 value_index_; 156 157 TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp); 158 }; 159 160 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU), 161 InitializeTableFromTextFileOp); 162 REGISTER_KERNEL_BUILDER( 163 Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU), 164 InitializeTableFromTextFileOp); 165 166 } // namespace tensorflow 167