• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for experimental sql input op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import sqlite3
24
25from tensorflow.contrib.data.python.ops import readers
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.ops import array_ops
29from tensorflow.python.platform import test
30
31
class SqlDatasetTest(test.TestCase):
  """Tests for `readers.SqlDataset` reading from a temporary SQLite database."""

  def _createSqlDataset(self, output_types, num_repeats=1):
    """Builds an initializable iterator over a `SqlDataset`.

    The dataset is configured from `self.driver_name`,
    `self.data_source_name` and `self.query`, all created in `setUp`.
    `self.query` is a placeholder, so each test feeds the SQL text when it
    runs the returned initializer.

    Args:
      output_types: A tuple of `DType`s, one per column of the query's
        result set.
      num_repeats: Number of times the dataset is repeated. Defaults to 1.

    Returns:
      A pair `(init_op, get_next)`: the iterator initializer and the
      tensors that produce the next row.
    """
    dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
                                 self.query, output_types).repeat(num_repeats)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()
    return init_op, get_next

  def setUp(self):
    """Creates and populates the `students`, `people` and `townspeople` tables.

    Existing tables are dropped first so every test starts from the same
    contents.
    """
    self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
    # Defaults to "sqlite" but is a placeholder so a test can feed an
    # invalid driver name.
    self.driver_name = array_ops.placeholder_with_default(
        array_ops.constant("sqlite", dtypes.string), shape=[])
    self.query = array_ops.placeholder(dtypes.string, shape=[])

    conn = sqlite3.connect(self.data_source_name)
    # Always close the connection, even if creating or filling a table fails.
    try:
      c = conn.cursor()
      c.execute("DROP TABLE IF EXISTS students")
      c.execute("DROP TABLE IF EXISTS people")
      c.execute("DROP TABLE IF EXISTS townspeople")
      c.execute(
          "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL "
          "PRIMARY KEY, first_name VARCHAR(100), last_name VARCHAR(100), "
          "motto VARCHAR(100), school_id VARCHAR(100), "
          "favorite_nonsense_word VARCHAR(100), desk_number INTEGER, "
          "income INTEGER, favorite_number INTEGER, "
          "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
          "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
          "account_balance INTEGER, registration_complete INTEGER)")
      # The two rows hold boundary values (min/max of int8/16/32/64, uint8/16)
      # and strings with embedded NUL characters.
      c.executemany(
          "INSERT INTO students (first_name, last_name, motto, school_id, "
          "favorite_nonsense_word, desk_number, income, favorite_number, "
          "favorite_big_number, favorite_negative_number, "
          "favorite_medium_sized_number, brownie_points, account_balance, "
          "registration_complete) "
          "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
          [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
            9223372036854775807, -2, 32767, 0, 0, 1),
           ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
            -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
      c.execute(
          "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
          "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
      c.executemany(
          "INSERT INTO PEOPLE (first_name, last_name, state) "
          "VALUES (?, ?, ?)",
          [("Benjamin", "Franklin", "Pennsylvania"),
           ("John", "Doe", "California")])
      c.execute(
          "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
          "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
          "FLOAT, accolades FLOAT, triumphs FLOAT)")
      c.executemany(
          "INSERT INTO townspeople (first_name, last_name, victories, "
          "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
          [("George", "Washington", 20.00,
            1331241.321342132321324589798264627463827647382647382643874,
            9007199254740991.0),
           ("John", "Adams", -19.95,
            1331241321342132321324589798264627463827647382647382643874.0,
            9007199254740992.0)])
      conn.commit()
    finally:
      conn.close()

  def testReadResultSet(self):
    """Tests that `SqlDataset` can read from a database table."""
    # `num_repeats=2`: every row is produced twice per initialization.
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string), 2)
    with self.test_session() as sess:
      for _ in range(2):  # Run twice to verify statelessness of db operations.
        sess.run(
            init_op,
            feed_dict={
                self.query: "SELECT first_name, last_name, motto FROM students "
                            "ORDER BY first_name DESC"
            })
        for _ in range(2):  # Dataset is repeated twice (`num_repeats` above).
          self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
          self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testReadResultSetJoinQuery(self):
    """Tests that `SqlDataset` works on a join query."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT students.first_name, state, motto FROM students "
                  "INNER JOIN people "
                  "ON students.first_name = people.first_name "
                  "AND students.last_name = people.last_name"
          })
      self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetNullTerminator(self):
    """Tests reading text containing a NUL character into a `string` tensor.

    The NUL appears in the middle of one value and at the end of the other;
    both must be returned in full rather than truncated.
    """
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, last_name, favorite_nonsense_word "
                  "FROM students ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
      self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetReuseSqlDataset(self):
    """Tests that one `SqlDataset` can be re-initialized with another query.

    Output types are fixed when the graph is built, so both queries must
    return the same number and types of columns.
    """
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, last_name, motto FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
      self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, last_name, state FROM people "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
      self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
                       sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadEmptyResultSet(self):
    """Tests that the first `get_next` raises `OutOfRangeError` on no rows."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, last_name, motto FROM students "
                          "WHERE first_name = 'Nonexistent'"
          })
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetWithInvalidDriverName(self):
    """Tests that initialization fails when `driver_name` is invalid."""
    init_op = self._createSqlDataset((dtypes.string, dtypes.string,
                                      dtypes.string))[0]
    with self.test_session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            init_op,
            feed_dict={
                self.driver_name: "sqlfake",
                self.query: "SELECT first_name, last_name, motto FROM students "
                            "ORDER BY first_name DESC"
            })

  def testReadResultSetWithInvalidColumnName(self):
    """Tests that a nonexistent column in `query` raises on `get_next`."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, last_name, fake_column FROM students "
                  "ORDER BY first_name DESC"
          })
      with self.assertRaises(errors.UnknownError):
        sess.run(get_next)

  def testReadResultSetOfQueryWithSyntaxError(self):
    """Tests that a syntax error in `query` raises on `get_next`."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELEmispellECT first_name, last_name, motto FROM students "
                  "ORDER BY first_name DESC"
          })
      with self.assertRaises(errors.UnknownError):
        sess.run(get_next)

  def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
    """Tests that a column count differing from `output_types` raises."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, last_name FROM students "
                          "ORDER BY first_name DESC"
          })
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)

  def testReadResultSetOfInsertQuery(self):
    """Tests that an INSERT query (rather than a SELECT) raises an error.

    The error comes from the number of output types not matching the number
    of columns in the query's result set, which is 0 for an INSERT.
    """
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.string))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "INSERT INTO students (first_name, last_name, motto) "
                  "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
          })
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)

  def testReadResultSetInt8(self):
    """Tests reading an integer column into an `int8` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, desk_number FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 9), sess.run(get_next))
      self.assertEqual((b"Jane", 127), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt8NegativeAndZero(self):
    """Tests reading negative and 0-valued integers into `int8` tensors."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
                                                dtypes.int8))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, income, favorite_negative_number "
                          "FROM students "
                          "WHERE first_name = 'John' ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 0, -2), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt8MaxValues(self):
    """Tests reading the extreme `int8` values into `int8` tensors."""
    init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT desk_number, favorite_negative_number FROM students "
                  "ORDER BY first_name DESC"
          })
      self.assertEqual((9, -2), sess.run(get_next))
      # Max and min values of int8
      self.assertEqual((127, -128), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt16(self):
    """Tests reading an integer column into an `int16` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, desk_number FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 9), sess.run(get_next))
      self.assertEqual((b"Jane", 127), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt16NegativeAndZero(self):
    """Tests reading negative and 0-valued integers into `int16` tensors."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
                                                dtypes.int16))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, income, favorite_negative_number "
                          "FROM students "
                          "WHERE first_name = 'John' ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 0, -2), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt16MaxValues(self):
    """Tests reading the extreme `int16` values into an `int16` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, favorite_medium_sized_number "
                          "FROM students ORDER BY first_name DESC"
          })
      # Max value of int16
      self.assertEqual((b"John", 32767), sess.run(get_next))
      # Min value of int16
      self.assertEqual((b"Jane", -32768), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt32(self):
    """Tests reading an integer column into an `int32` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, desk_number FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 9), sess.run(get_next))
      self.assertEqual((b"Jane", 127), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt32NegativeAndZero(self):
    """Tests reading negative and 0-valued integers into an `int32` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, income FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 0), sess.run(get_next))
      self.assertEqual((b"Jane", -20000), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt32MaxValues(self):
    """Tests reading the extreme `int32` values into an `int32` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, favorite_number FROM students "
                          "ORDER BY first_name DESC"
          })
      # Max value of int32
      self.assertEqual((b"John", 2147483647), sess.run(get_next))
      # Min value of int32
      self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt32VarCharColumnAsInt(self):
    """Tests reading a numeric `VARCHAR` column into an `int32` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, school_id FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 123), sess.run(get_next))
      self.assertEqual((b"Jane", 1000), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt64(self):
    """Tests reading an integer column into an `int64` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, desk_number FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 9), sess.run(get_next))
      self.assertEqual((b"Jane", 127), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt64NegativeAndZero(self):
    """Tests reading negative and 0-valued integers into an `int64` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, income FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 0), sess.run(get_next))
      self.assertEqual((b"Jane", -20000), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetInt64MaxValues(self):
    """Tests reading the extreme `int64` values into an `int64` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, favorite_big_number FROM students "
                  "ORDER BY first_name DESC"
          })
      # Max value of int64
      self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
      # Min value of int64
      self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetUInt8(self):
    """Tests reading an integer column into a `uint8` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, desk_number FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 9), sess.run(get_next))
      self.assertEqual((b"Jane", 127), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetUInt8MinAndMaxValues(self):
    """Tests reading the minimum and maximum `uint8` values."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, brownie_points FROM students "
                          "ORDER BY first_name DESC"
          })
      # Min value of uint8
      self.assertEqual((b"John", 0), sess.run(get_next))
      # Max value of uint8
      self.assertEqual((b"Jane", 255), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetUInt16(self):
    """Tests reading an integer column into a `uint16` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, desk_number FROM students "
                          "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", 9), sess.run(get_next))
      self.assertEqual((b"Jane", 127), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetUInt16MinAndMaxValues(self):
    """Tests reading the minimum and maximum `uint16` values."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, account_balance FROM students "
                          "ORDER BY first_name DESC"
          })
      # Min value of uint16
      self.assertEqual((b"John", 0), sess.run(get_next))
      # Max value of uint16
      self.assertEqual((b"Jane", 65535), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetBool(self):
    """Tests reading 1-valued and 0-valued integers into `bool` tensors.

    1 must be read as `True` and 0 as `False`.
    """
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, registration_complete FROM students "
                  "ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", True), sess.run(get_next))
      self.assertEqual((b"Jane", False), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetBoolNotZeroOrOne(self):
    """Tests that integers other than 0 and 1 are read as `True` bools."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query: "SELECT first_name, favorite_medium_sized_number "
                          "FROM students ORDER BY first_name DESC"
          })
      self.assertEqual((b"John", True), sess.run(get_next))
      self.assertEqual((b"Jane", True), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetFloat64(self):
    """Tests reading a float column into a `float64` tensor."""
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.float64))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, last_name, victories FROM townspeople "
                  "ORDER BY first_name"
          })
      self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
      self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetFloat64OverlyPrecise(self):
    """Tests reading floats written with more precision than a float64 holds.

    The read must not raise, and the value read must compare equal to the
    same literal rounded to float64 by Python.
    """
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.float64))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, last_name, accolades FROM townspeople "
                  "ORDER BY first_name"
          })
      self.assertEqual(
          (b"George", b"Washington",
           1331241.321342132321324589798264627463827647382647382643874),
          sess.run(get_next))
      self.assertEqual(
          (b"John", b"Adams",
           1331241321342132321324589798264627463827647382647382643874.0),
          sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
    """Tests that adjacent whole numbers at the float64 limit stay distinct.

    9007199254740992.0 (2**53) is the largest integer representable as a
    64-bit IEEE float whose predecessor, 9007199254740991.0, is also
    representable. Each row must be read back as exactly its own value, not
    the neighbouring one.
    """
    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                dtypes.float64))
    with self.test_session() as sess:
      sess.run(
          init_op,
          feed_dict={
              self.query:
                  "SELECT first_name, last_name, triumphs FROM townspeople "
                  "ORDER BY first_name"
          })
      row = sess.run(get_next)
      self.assertNotEqual((b"George", b"Washington", 9007199254740992.0), row)
      self.assertEqual((b"George", b"Washington", 9007199254740991.0), row)
      row = sess.run(get_next)
      self.assertNotEqual((b"John", b"Adams", 9007199254740991.0), row)
      self.assertEqual((b"John", b"Adams", 9007199254740992.0), row)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
653
654
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
657