• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/parsing_ops.cc.
17 #include <vector>
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_shape.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/strings/numbers.h"
24 
25 namespace tensorflow {
26 
27 class DecodeCSVOp : public OpKernel {
28  public:
DecodeCSVOp(OpKernelConstruction * ctx)29   explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
30     string delim;
31 
32     OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE", &out_type_));
33     OP_REQUIRES(ctx, out_type_.size() < std::numeric_limits<int>::max(),
34                 errors::InvalidArgument("Out type too large"));
35     OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim", &delim));
36     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim", &use_quote_delim_));
37     OP_REQUIRES_OK(ctx, ctx->GetAttr("select_cols", &select_cols_));
38     OP_REQUIRES(
39         ctx, out_type_.size() == select_cols_.size() || select_cols_.empty(),
40         errors::InvalidArgument("select_cols should match output size"));
41     select_all_cols_ = select_cols_.empty();
42     for (int i = 1; i < select_cols_.size(); i++) {
43       OP_REQUIRES(ctx, select_cols_[i - 1] < select_cols_[i],
44                   errors::InvalidArgument(
45                       "select_cols should be strictly increasing indices"));
46     }
47     OP_REQUIRES(
48         ctx, select_cols_.empty() || select_cols_.front() >= 0,
49         errors::InvalidArgument("select_cols should be non-negative indices"));
50     OP_REQUIRES(ctx, delim.size() == 1,
51                 errors::InvalidArgument("field_delim should be only 1 char"));
52     delim_ = delim[0];
53     OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value", &na_value_));
54   }
55 
Compute(OpKernelContext * ctx)56   void Compute(OpKernelContext* ctx) override {
57     const Tensor* records;
58     OpInputList record_defaults;
59 
60     OP_REQUIRES_OK(ctx, ctx->input("records", &records));
61     OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
62 
63     for (int i = 0; i < record_defaults.size(); ++i) {
64       OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
65                   errors::InvalidArgument(
66                       "Each record default should be at most rank 1"));
67       OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
68                   errors::InvalidArgument(
69                       "There should only be 1 default per field but field ", i,
70                       " has ", record_defaults[i].NumElements()));
71     }
72 
73     auto records_t = records->flat<tstring>();
74     int64_t records_size = records_t.size();
75 
76     OpOutputList output;
77     OP_REQUIRES_OK(ctx, ctx->output_list("output", &output));
78 
79     for (int i = 0; i < static_cast<int>(out_type_.size()); ++i) {
80       Tensor* out = nullptr;
81       OP_REQUIRES_OK(ctx, output.allocate(i, records->shape(), &out));
82     }
83 
84     for (int64_t i = 0; i < records_size; ++i) {
85       const StringPiece record(records_t(i));
86       std::vector<string> fields;
87       ExtractFields(ctx, record, &fields);
88       OP_REQUIRES(ctx, fields.size() == out_type_.size(),
89                   errors::InvalidArgument("Expect ", out_type_.size(),
90                                           " fields but have ", fields.size(),
91                                           " in record ", i));
92 
93       // Check each field in the record
94       for (int f = 0; f < static_cast<int>(out_type_.size()); ++f) {
95         const DataType& dtype = out_type_[f];
96         switch (dtype) {
97           case DT_INT32: {
98             // If this field is empty or NA value, check if default is given:
99             // If yes, use default value; Otherwise report error.
100             if (fields[f].empty() || fields[f] == na_value_) {
101               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
102                           errors::InvalidArgument(
103                               "Field ", f,
104                               " is required but missing in record ", i, "!"));
105 
106               output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0);
107             } else {
108               int32_t value;
109               OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value),
110                           errors::InvalidArgument(
111                               "Field ", f, " in record ", i,
112                               " is not a valid int32: ", fields[f]));
113               output[f]->flat<int32>()(i) = value;
114             }
115             break;
116           }
117           case DT_INT64: {
118             // If this field is empty or NA value, check if default is given:
119             // If yes, use default value; Otherwise report error.
120             if (fields[f].empty() || fields[f] == na_value_) {
121               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
122                           errors::InvalidArgument(
123                               "Field ", f,
124                               " is required but missing in record ", i, "!"));
125 
126               output[f]->flat<int64_t>()(i) =
127                   record_defaults[f].flat<int64_t>()(0);
128             } else {
129               int64_t value;
130               OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value),
131                           errors::InvalidArgument(
132                               "Field ", f, " in record ", i,
133                               " is not a valid int64: ", fields[f]));
134               output[f]->flat<int64_t>()(i) = value;
135             }
136             break;
137           }
138           case DT_FLOAT: {
139             // If this field is empty or NA value, check if default is given:
140             // If yes, use default value; Otherwise report error.
141             if (fields[f].empty() || fields[f] == na_value_) {
142               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
143                           errors::InvalidArgument(
144                               "Field ", f,
145                               " is required but missing in record ", i, "!"));
146               output[f]->flat<float>()(i) = record_defaults[f].flat<float>()(0);
147             } else {
148               float value;
149               OP_REQUIRES(ctx, strings::safe_strtof(fields[f], &value),
150                           errors::InvalidArgument(
151                               "Field ", f, " in record ", i,
152                               " is not a valid float: ", fields[f]));
153               output[f]->flat<float>()(i) = value;
154             }
155             break;
156           }
157           case DT_DOUBLE: {
158             // If this field is empty or NA value, check if default is given:
159             // If yes, use default value; Otherwise report error.
160             if (fields[f].empty() || fields[f] == na_value_) {
161               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
162                           errors::InvalidArgument(
163                               "Field ", f,
164                               " is required but missing in record ", i, "!"));
165               output[f]->flat<double>()(i) =
166                   record_defaults[f].flat<double>()(0);
167             } else {
168               double value;
169               OP_REQUIRES(ctx, strings::safe_strtod(fields[f], &value),
170                           errors::InvalidArgument(
171                               "Field ", f, " in record ", i,
172                               " is not a valid double: ", fields[f]));
173               output[f]->flat<double>()(i) = value;
174             }
175             break;
176           }
177           case DT_STRING: {
178             // If this field is empty or NA value, check if default is given:
179             // If yes, use default value; Otherwise report error.
180             if (fields[f].empty() || fields[f] == na_value_) {
181               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
182                           errors::InvalidArgument(
183                               "Field ", f,
184                               " is required but missing in record ", i, "!"));
185               output[f]->flat<tstring>()(i) =
186                   record_defaults[f].flat<tstring>()(0);
187             } else {
188               output[f]->flat<tstring>()(i) = std::move(fields[f]);
189             }
190             break;
191           }
192           default:
193             OP_REQUIRES(ctx, false,
194                         errors::InvalidArgument("csv: data type ", dtype,
195                                                 " not supported in field ", f));
196         }
197       }
198     }
199   }
200 
201  private:
202   std::vector<DataType> out_type_;
203   std::vector<int64_t> select_cols_;
204   char delim_;
205   bool use_quote_delim_;
206   bool select_all_cols_;
207   string na_value_;
208 
ExtractFields(OpKernelContext * ctx,StringPiece input,std::vector<string> * result)209   void ExtractFields(OpKernelContext* ctx, StringPiece input,
210                      std::vector<string>* result) {
211     int64_t current_idx = 0;
212     int64_t num_fields_parsed = 0;
213     int64_t selector_idx = 0;  // Keep track of index into select_cols
214 
215     if (!input.empty()) {
216       while (static_cast<size_t>(current_idx) < input.size()) {
217         if (input[current_idx] == '\n' || input[current_idx] == '\r') {
218           current_idx++;
219           continue;
220         }
221 
222         bool quoted = false;
223         bool include =
224             (select_all_cols_ || select_cols_[selector_idx] ==
225                                      static_cast<size_t>(num_fields_parsed));
226 
227         if (use_quote_delim_ && input[current_idx] == '"') {
228           quoted = true;
229           current_idx++;
230         }
231 
232         // This is the body of the field;
233         string field;
234         if (!quoted) {
235           while (static_cast<size_t>(current_idx) < input.size() &&
236                  input[current_idx] != delim_) {
237             OP_REQUIRES(ctx,
238                         (!use_quote_delim_ || input[current_idx] != '"') &&
239                             input[current_idx] != '\n' &&
240                             input[current_idx] != '\r',
241                         errors::InvalidArgument(
242                             "Unquoted fields cannot have quotes/CRLFs inside"));
243             if (include) field += input[current_idx];
244             current_idx++;
245           }
246 
247           // Go to next field or the end
248           current_idx++;
249         } else if (use_quote_delim_) {
250           // Quoted field needs to be ended with '"' and delim or end
251           while (
252               (static_cast<size_t>(current_idx) < input.size() - 1) &&
253               (input[current_idx] != '"' || input[current_idx + 1] != delim_)) {
254             if (input[current_idx] != '"') {
255               if (include) field += input[current_idx];
256               current_idx++;
257             } else {
258               OP_REQUIRES(
259                   ctx, input[current_idx + 1] == '"',
260                   errors::InvalidArgument("Quote inside a string has to be "
261                                           "escaped by another quote"));
262               if (include) field += '"';
263               current_idx += 2;
264             }
265           }
266 
267           OP_REQUIRES(
268               ctx,
269               (static_cast<size_t>(current_idx) < input.size() &&
270                input[current_idx] == '"' &&
271                (static_cast<size_t>(current_idx) == input.size() - 1 ||
272                 input[current_idx + 1] == delim_)),
273               errors::InvalidArgument("Quoted field has to end with quote "
274                                       "followed by delim or end"));
275 
276           current_idx += 2;
277         }
278 
279         num_fields_parsed++;
280         if (include) {
281           result->push_back(field);
282           selector_idx++;
283           if (selector_idx == select_cols_.size()) return;
284         }
285       }
286 
287       bool include =
288           (select_all_cols_ || select_cols_[selector_idx] ==
289                                    static_cast<size_t>(num_fields_parsed));
290       // Check if the last field is missing
291       if (include && input[input.size() - 1] == delim_)
292         result->push_back(string());
293     }
294   }
295 };
296 
297 REGISTER_KERNEL_BUILDER(Name("DecodeCSV").Device(DEVICE_CPU), DecodeCSVOp);
298 
299 }  // namespace tensorflow
300