1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/parsing_ops.cc. 17 #include <vector> 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/framework/tensor_shape.h" 21 #include "tensorflow/core/framework/types.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/strings/numbers.h" 24 25 namespace tensorflow { 26 27 class DecodeCSVOp : public OpKernel { 28 public: DecodeCSVOp(OpKernelConstruction * ctx)29 explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 30 string delim; 31 32 OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE", &out_type_)); 33 OP_REQUIRES(ctx, out_type_.size() < std::numeric_limits<int>::max(), 34 errors::InvalidArgument("Out type too large")); 35 OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim", &delim)); 36 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim", &use_quote_delim_)); 37 OP_REQUIRES_OK(ctx, ctx->GetAttr("select_cols", &select_cols_)); 38 OP_REQUIRES( 39 ctx, out_type_.size() == select_cols_.size() || select_cols_.empty(), 40 errors::InvalidArgument("select_cols should match output size")); 41 select_all_cols_ = select_cols_.empty(); 42 for (int i = 1; i < select_cols_.size(); i++) { 43 OP_REQUIRES(ctx, select_cols_[i - 1] < select_cols_[i], 44 errors::InvalidArgument( 45 "select_cols should be strictly increasing indices")); 46 } 47 OP_REQUIRES( 48 ctx, select_cols_.empty() || select_cols_.front() >= 0, 49 errors::InvalidArgument("select_cols should be non-negative indices")); 50 OP_REQUIRES(ctx, delim.size() == 1, 51 errors::InvalidArgument("field_delim should be only 1 char")); 52 delim_ = delim[0]; 53 OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value", &na_value_)); 54 } 55 Compute(OpKernelContext * ctx)56 void Compute(OpKernelContext* ctx) override { 57 const Tensor* records; 58 OpInputList record_defaults; 59 60 OP_REQUIRES_OK(ctx, ctx->input("records", &records)); 61 OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults)); 62 63 for (int i = 0; i < record_defaults.size(); ++i) { 64 OP_REQUIRES(ctx, record_defaults[i].dims() <= 1, 65 errors::InvalidArgument( 66 "Each record default should be at most rank 1")); 67 OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2, 68 errors::InvalidArgument( 69 "There should only be 1 default per field but field ", i, 70 " has ", record_defaults[i].NumElements())); 71 } 72 73 auto records_t = records->flat<tstring>(); 74 int64_t records_size = records_t.size(); 75 76 OpOutputList output; 77 OP_REQUIRES_OK(ctx, ctx->output_list("output", &output)); 78 79 for (int i = 0; i < static_cast<int>(out_type_.size()); ++i) { 80 Tensor* out = nullptr; 81 OP_REQUIRES_OK(ctx, output.allocate(i, records->shape(), &out)); 82 } 83 84 for (int64_t i = 0; i < records_size; ++i) { 85 const StringPiece record(records_t(i)); 86 std::vector<string> fields; 87 ExtractFields(ctx, record, &fields); 88 OP_REQUIRES(ctx, fields.size() == out_type_.size(), 89 errors::InvalidArgument("Expect ", out_type_.size(), 90 " fields but have ", fields.size(), 91 " in record ", i)); 92 93 // Check each field in the record 94 for (int f = 0; f < static_cast<int>(out_type_.size()); ++f) { 95 const DataType& dtype = out_type_[f]; 96 switch (dtype) { 97 case DT_INT32: { 98 // If this field is empty or NA value, check if default is given: 99 // If yes, use default value; Otherwise report error. 100 if (fields[f].empty() || fields[f] == na_value_) { 101 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, 102 errors::InvalidArgument( 103 "Field ", f, 104 " is required but missing in record ", i, "!")); 105 106 output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0); 107 } else { 108 int32_t value; 109 OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value), 110 errors::InvalidArgument( 111 "Field ", f, " in record ", i, 112 " is not a valid int32: ", fields[f])); 113 output[f]->flat<int32>()(i) = value; 114 } 115 break; 116 } 117 case DT_INT64: { 118 // If this field is empty or NA value, check if default is given: 119 // If yes, use default value; Otherwise report error. 120 if (fields[f].empty() || fields[f] == na_value_) { 121 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, 122 errors::InvalidArgument( 123 "Field ", f, 124 " is required but missing in record ", i, "!")); 125 126 output[f]->flat<int64_t>()(i) = 127 record_defaults[f].flat<int64_t>()(0); 128 } else { 129 int64_t value; 130 OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value), 131 errors::InvalidArgument( 132 "Field ", f, " in record ", i, 133 " is not a valid int64: ", fields[f])); 134 output[f]->flat<int64_t>()(i) = value; 135 } 136 break; 137 } 138 case DT_FLOAT: { 139 // If this field is empty or NA value, check if default is given: 140 // If yes, use default value; Otherwise report error. 141 if (fields[f].empty() || fields[f] == na_value_) { 142 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, 143 errors::InvalidArgument( 144 "Field ", f, 145 " is required but missing in record ", i, "!")); 146 output[f]->flat<float>()(i) = record_defaults[f].flat<float>()(0); 147 } else { 148 float value; 149 OP_REQUIRES(ctx, strings::safe_strtof(fields[f], &value), 150 errors::InvalidArgument( 151 "Field ", f, " in record ", i, 152 " is not a valid float: ", fields[f])); 153 output[f]->flat<float>()(i) = value; 154 } 155 break; 156 } 157 case DT_DOUBLE: { 158 // If this field is empty or NA value, check if default is given: 159 // If yes, use default value; Otherwise report error. 160 if (fields[f].empty() || fields[f] == na_value_) { 161 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, 162 errors::InvalidArgument( 163 "Field ", f, 164 " is required but missing in record ", i, "!")); 165 output[f]->flat<double>()(i) = 166 record_defaults[f].flat<double>()(0); 167 } else { 168 double value; 169 OP_REQUIRES(ctx, strings::safe_strtod(fields[f], &value), 170 errors::InvalidArgument( 171 "Field ", f, " in record ", i, 172 " is not a valid double: ", fields[f])); 173 output[f]->flat<double>()(i) = value; 174 } 175 break; 176 } 177 case DT_STRING: { 178 // If this field is empty or NA value, check if default is given: 179 // If yes, use default value; Otherwise report error. 180 if (fields[f].empty() || fields[f] == na_value_) { 181 OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, 182 errors::InvalidArgument( 183 "Field ", f, 184 " is required but missing in record ", i, "!")); 185 output[f]->flat<tstring>()(i) = 186 record_defaults[f].flat<tstring>()(0); 187 } else { 188 output[f]->flat<tstring>()(i) = std::move(fields[f]); 189 } 190 break; 191 } 192 default: 193 OP_REQUIRES(ctx, false, 194 errors::InvalidArgument("csv: data type ", dtype, 195 " not supported in field ", f)); 196 } 197 } 198 } 199 } 200 201 private: 202 std::vector<DataType> out_type_; 203 std::vector<int64_t> select_cols_; 204 char delim_; 205 bool use_quote_delim_; 206 bool select_all_cols_; 207 string na_value_; 208 ExtractFields(OpKernelContext * ctx,StringPiece input,std::vector<string> * result)209 void ExtractFields(OpKernelContext* ctx, StringPiece input, 210 std::vector<string>* result) { 211 int64_t current_idx = 0; 212 int64_t num_fields_parsed = 0; 213 int64_t selector_idx = 0; // Keep track of index into select_cols 214 215 if (!input.empty()) { 216 while (static_cast<size_t>(current_idx) < input.size()) { 217 if (input[current_idx] == '\n' || input[current_idx] == '\r') { 218 current_idx++; 219 continue; 220 } 221 222 bool quoted = false; 223 bool include = 224 (select_all_cols_ || select_cols_[selector_idx] == 225 static_cast<size_t>(num_fields_parsed)); 226 227 if (use_quote_delim_ && input[current_idx] == '"') { 228 quoted = true; 229 current_idx++; 230 } 231 232 // This is the body of the field; 233 string field; 234 if (!quoted) { 235 while (static_cast<size_t>(current_idx) < input.size() && 236 input[current_idx] != delim_) { 237 OP_REQUIRES(ctx, 238 (!use_quote_delim_ || input[current_idx] != '"') && 239 input[current_idx] != '\n' && 240 input[current_idx] != '\r', 241 errors::InvalidArgument( 242 "Unquoted fields cannot have quotes/CRLFs inside")); 243 if (include) field += input[current_idx]; 244 current_idx++; 245 } 246 247 // Go to next field or the end 248 current_idx++; 249 } else if (use_quote_delim_) { 250 // Quoted field needs to be ended with '"' and delim or end 251 while ( 252 (static_cast<size_t>(current_idx) < input.size() - 1) && 253 (input[current_idx] != '"' || input[current_idx + 1] != delim_)) { 254 if (input[current_idx] != '"') { 255 if (include) field += input[current_idx]; 256 current_idx++; 257 } else { 258 OP_REQUIRES( 259 ctx, input[current_idx + 1] == '"', 260 errors::InvalidArgument("Quote inside a string has to be " 261 "escaped by another quote")); 262 if (include) field += '"'; 263 current_idx += 2; 264 } 265 } 266 267 OP_REQUIRES( 268 ctx, 269 (static_cast<size_t>(current_idx) < input.size() && 270 input[current_idx] == '"' && 271 (static_cast<size_t>(current_idx) == input.size() - 1 || 272 input[current_idx + 1] == delim_)), 273 errors::InvalidArgument("Quoted field has to end with quote " 274 "followed by delim or end")); 275 276 current_idx += 2; 277 } 278 279 num_fields_parsed++; 280 if (include) { 281 result->push_back(field); 282 selector_idx++; 283 if (selector_idx == select_cols_.size()) return; 284 } 285 } 286 287 bool include = 288 (select_all_cols_ || select_cols_[selector_idx] == 289 static_cast<size_t>(num_fields_parsed)); 290 // Check if the last field is missing 291 if (include && input[input.size() - 1] == delim_) 292 result->push_back(string()); 293 } 294 } 295 }; 296 297 REGISTER_KERNEL_BUILDER(Name("DecodeCSV").Device(DEVICE_CPU), DecodeCSVOp); 298 299 } // namespace tensorflow 300