#!/usr/bin/env python

# Copyright Jim Bosch & Ankit Daftery 2010-2012.
# Distributed under the Boost Software License, Version 1.0.
# (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)

"""Unit tests for the ``ndarray_ext`` extension module.

Each test builds or queries arrays through ``ndarray_ext`` and compares the
result against the equivalent operation performed directly with NumPy.
"""

import ndarray_ext
import unittest
import numpy


class TestNdarray(unittest.TestCase):
    """Compare ``ndarray_ext`` array helpers against their NumPy equivalents."""

    # Element types exercised by the construction tests.
    DTYPES = (numpy.int16, numpy.int32, numpy.float32, numpy.complex128)
    # Every shape holds exactly 60 elements, so a 60-element vector can be
    # reshaped into any of them.
    SHAPES = ((60,), (6, 10), (4, 3, 5), (2, 2, 3, 5))

    def testNdzeros(self):
        """``zeros(shape, dtype)`` yields an all-zero array of that shape."""
        for dtp in self.DTYPES:
            v = numpy.zeros(60, dtype=dtp)
            dt = numpy.dtype(dtp)
            for shape in self.SHAPES:
                a1 = ndarray_ext.zeros(shape, dt)
                a2 = v.reshape(a1.shape)
                self.assertEqual(shape, a1.shape)
                self.assertTrue((a1 == a2).all())

    def testNdzeros_matrix(self):
        """``zeros_matrix`` matches a zero-filled ``numpy.matrix`` in shape,
        values and Python type."""
        shape = (6, 10)
        for dtp in self.DTYPES:
            dt = numpy.dtype(dtp)
            a1 = ndarray_ext.zeros_matrix(shape, dt)
            a2 = numpy.matrix(numpy.zeros(shape, dtype=dtp))
            self.assertEqual(shape, a1.shape)
            self.assertTrue((a1 == a2).all())
            self.assertEqual(type(a1), type(a2))

    def testNdarray(self):
        """``array(seq)`` and ``array(seq, dtype)`` match ``numpy.array``
        and can be reshaped into every 60-element test shape."""
        a = range(0, 60)
        for dtp in self.DTYPES:
            v = numpy.array(a, dtype=dtp)
            dt = numpy.dtype(dtp)
            a1 = ndarray_ext.array(a)
            a2 = ndarray_ext.array(a, dt)
            self.assertTrue((a1 == v).all())
            self.assertTrue((a2 == v).all())
            for shape in self.SHAPES:
                a1 = a1.reshape(shape)
                self.assertEqual(shape, a1.shape)
                a2 = a2.reshape(shape)
                self.assertEqual(shape, a2.shape)

    def testNdempty(self):
        """``empty`` and ``c_empty`` allocate arrays of the requested shape."""
        for dtp in self.DTYPES:
            dt = numpy.dtype(dtp)
            for shape in self.SHAPES:
                a1 = ndarray_ext.empty(shape, dt)
                a2 = ndarray_ext.c_empty(shape, dt)
                self.assertEqual(shape, a1.shape)
                self.assertEqual(shape, a2.shape)

    def testTranspose(self):
        """``transpose`` produces the same shape as ``ndarray.transpose``."""
        for dtp in self.DTYPES:
            dt = numpy.dtype(dtp)
            # Multi-dimensional shapes only (the 1-D shape is skipped).
            for shape in self.SHAPES[1:]:
                a1 = numpy.empty(shape, dt)
                a2 = a1.transpose()
                a1 = ndarray_ext.transpose(a1)
                self.assertEqual(a1.shape, a2.shape)

    def testSqueeze(self):
        """``squeeze`` drops length-1 axes like ``ndarray.squeeze``."""
        a1 = numpy.array([[[3, 4, 5]]])
        a2 = a1.squeeze()
        a1 = ndarray_ext.squeeze(a1)
        self.assertEqual(a1.shape, a2.shape)

    def testReshape(self):
        """``reshape`` returns an array with the requested shape."""
        a1 = numpy.empty((2, 2))
        a2 = ndarray_ext.reshape(a1, (1, 4))
        self.assertEqual(a2.shape, (1, 4))

    def testShapeIndex(self):
        """``shape_index(a, i)`` equals ``a.shape[i]`` for every valid
        positive and negative axis, and raises IndexError past the end."""
        a = numpy.arange(24).reshape(1, 2, 3, 4)
        # Valid axes: 0..3 and -1..-4.
        for i in list(range(4)) + list(range(-1, -5, -1)):
            self.assertEqual(ndarray_ext.shape_index(a, i), a.shape[i],
                             msg="axis %d" % i)
        # Axis 4 is out of bounds for a 4-D array.
        with self.assertRaises(IndexError):
            ndarray_ext.shape_index(a, 4)

    def testStridesIndex(self):
        """``strides_index(a, i)`` equals ``a.strides[i]`` for every valid
        positive and negative axis, and raises IndexError past the end."""
        a = numpy.arange(24).reshape(1, 2, 3, 4)
        # Valid axes: 0..3 and -1..-4.
        for i in list(range(4)) + list(range(-1, -5, -1)):
            self.assertEqual(ndarray_ext.strides_index(a, i), a.strides[i],
                             msg="axis %d" % i)
        # Axis 4 is out of bounds for a 4-D array.
        with self.assertRaises(IndexError):
            ndarray_ext.strides_index(a, 4)


if __name__ == "__main__":
    unittest.main()