1 /*
2 * Copyright (C) 2008 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 /* ---- includes ----------------------------------------------------------- */
18
19 #include "b_TensorEm/Int32Mat.h"
20 #include "b_TensorEm/Functions.h"
21 #include "b_BasicEm/Math.h"
22 #include "b_BasicEm/Functions.h"
23 #include "b_BasicEm/Memory.h"
24
25 /* ------------------------------------------------------------------------- */
26
27 /* ========================================================================= */
28 /* */
29 /* ---- \ghd{ auxiliary functions } ---------------------------------------- */
30 /* */
31 /* ========================================================================= */
32
33 /* ------------------------------------------------------------------------- */
34
/** Right-shifts (with rounding) all elements of an int32 array by a common
 *  amount so that the element of largest magnitude fits the requested width.
 *
 *  The shift is bts_absIntLog2( largest magnitude ) + 1 - nBitsA. When that
 *  value is not positive the array and *bbpPtrA are left untouched.
 *
 *  @param ptrA     array to be reduced in place
 *  @param sizeA    number of elements in ptrA
 *  @param bbpPtrA  value that is decreased by the applied shift
 *                  (presumably the binary point position of the array - confirm)
 *  @param nBitsA   target width used in the shift computation
 */
void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA )
{
    int32 peakL = 0;
    int32 excessL;
    uint32 idxL;

    /* largest absolute value found in the array */
    for( idxL = 0; idxL < sizeA; idxL++ )
    {
        int32 magL = ( ptrA[ idxL ] < 0 ) ? -ptrA[ idxL ] : ptrA[ idxL ];
        if( magL > peakL ) peakL = magL;
    }

    /* number of surplus bits that have to be dropped */
    excessL = bts_absIntLog2( peakL ) + 1 - nBitsA;
    if( excessL <= 0 ) return;

    /* shift by ( excessL - 1 ), add 1, shift once more: rounds to nearest */
    for( idxL = 0; idxL < sizeA; idxL++ )
    {
        ptrA[ idxL ] = ( ( ptrA[ idxL ] >> ( excessL - 1 ) ) + 1 ) >> 1;
    }

    *bbpPtrA -= excessL;
}
66
67 /* ------------------------------------------------------------------------- */
68
69 /* ========================================================================= */
70 /* */
71 /* ---- \ghd{ constructor / destructor } ----------------------------------- */
72 /* */
73 /* ========================================================================= */
74
75 /* ------------------------------------------------------------------------- */
76
bts_Int32Mat_init(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA)77 void bts_Int32Mat_init( struct bbs_Context* cpA,
78 struct bts_Int32Mat* ptrA )
79 {
80 ptrA->widthE = 0;
81 bbs_Int32Arr_init( cpA, &ptrA->arrE );
82 }
83
84 /* ------------------------------------------------------------------------- */
85
bts_Int32Mat_exit(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA)86 void bts_Int32Mat_exit( struct bbs_Context* cpA,
87 struct bts_Int32Mat* ptrA )
88 {
89 ptrA->widthE = 0;
90 bbs_Int32Arr_exit( cpA, &ptrA->arrE );
91 }
92 /* ------------------------------------------------------------------------- */
93
94 /* ========================================================================= */
95 /* */
96 /* ---- \ghd{ operators } -------------------------------------------------- */
97 /* */
98 /* ========================================================================= */
99
100 /* ------------------------------------------------------------------------- */
101
102 /* ========================================================================= */
103 /* */
104 /* ---- \ghd{ query functions } -------------------------------------------- */
105 /* */
106 /* ========================================================================= */
107
108 /* ------------------------------------------------------------------------- */
109
110 /* ========================================================================= */
111 /* */
112 /* ---- \ghd{ modify functions } ------------------------------------------- */
113 /* */
114 /* ========================================================================= */
115
116 /* ------------------------------------------------------------------------- */
117
/** Sets up a square matrix of the given width.
 *
 *  Does nothing when the context already reports an error. Otherwise the
 *  array member is created with widthA * widthA elements and the width is
 *  stored in the object.
 *
 *  @param cpA     context; checked for a pending error and handed on
 *  @param ptrA    matrix object to set up
 *  @param widthA  number of rows (= number of columns)
 *  @param mspA    memory segment handed on to bbs_Int32Arr_create
 */
void bts_Int32Mat_create( struct bbs_Context* cpA,
                          struct bts_Int32Mat* ptrA,
                          int32 widthA,
                          struct bbs_MemSeg* mspA )
{
    if( !bbs_Context_error( cpA ) )
    {
        /* square matrix: one array element per ( row, column ) pair */
        int32 elemCountL = widthA * widthA;

        bbs_Int32Arr_create( cpA, &( ptrA->arrE ), elemCountL, mspA );
        ptrA->widthE = widthA;
    }
}
127
128 /* ------------------------------------------------------------------------- */
129
bts_Int32Mat_copy(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA,const struct bts_Int32Mat * srcPtrA)130 void bts_Int32Mat_copy( struct bbs_Context* cpA,
131 struct bts_Int32Mat* ptrA,
132 const struct bts_Int32Mat* srcPtrA )
133 {
134 if( ptrA->widthE != srcPtrA->widthE )
135 {
136 bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n"
137 "size mismatch" );
138 return;
139 }
140
141 bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE );
142 }
143
144 /* ------------------------------------------------------------------------- */
145
146 /* ========================================================================= */
147 /* */
148 /* ---- \ghd{ I/O } -------------------------------------------------------- */
149 /* */
150 /* ========================================================================= */
151
152 /* ------------------------------------------------------------------------- */
153
bts_Int32Mat_memSize(struct bbs_Context * cpA,const struct bts_Int32Mat * ptrA)154 uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA,
155 const struct bts_Int32Mat *ptrA )
156 {
157 return bbs_SIZEOF16( uint32 )
158 + bbs_SIZEOF16( uint32 ) /* version */
159 + bbs_SIZEOF16( ptrA->widthE )
160 + bbs_Int32Arr_memSize( cpA, &ptrA->arrE );
161 }
162
163 /* ------------------------------------------------------------------------- */
164
bts_Int32Mat_memWrite(struct bbs_Context * cpA,const struct bts_Int32Mat * ptrA,uint16 * memPtrA)165 uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA,
166 const struct bts_Int32Mat* ptrA,
167 uint16* memPtrA )
168 {
169 uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA );
170 memPtrA += bbs_memWrite32( &memSizeL, memPtrA );
171 memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA );
172 memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA );
173 memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA );
174 return memSizeL;
175 }
176
177 /* ------------------------------------------------------------------------- */
178
bts_Int32Mat_memRead(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA,const uint16 * memPtrA,struct bbs_MemSeg * mspA)179 uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA,
180 struct bts_Int32Mat* ptrA,
181 const uint16* memPtrA,
182 struct bbs_MemSeg* mspA )
183 {
184 uint32 memSizeL, versionL;
185 if( bbs_Context_error( cpA ) ) return 0;
186 memPtrA += bbs_memRead32( &memSizeL, memPtrA );
187 memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA );
188 memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA );
189 memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA );
190
191 if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) )
192 {
193 bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n"
194 "size mismatch" );
195 }
196 return memSizeL;
197 }
198
199 /* ------------------------------------------------------------------------- */
200
201 /* ========================================================================= */
202 /* */
203 /* ---- \ghd{ exec functions } --------------------------------------------- */
204 /* */
205 /* ========================================================================= */
206
207 /* ------------------------------------------------------------------------- */
208
bts_Int32Mat_solve(struct bbs_Context * cpA,const int32 * matA,int32 matWidthA,const int32 * inVecA,int32 * outVecA,int32 bbpA,int32 * tmpMatA,int32 * tmpVecA)209 flag bts_Int32Mat_solve( struct bbs_Context* cpA,
210 const int32* matA,
211 int32 matWidthA,
212 const int32* inVecA,
213 int32* outVecA,
214 int32 bbpA,
215 int32* tmpMatA,
216 int32* tmpVecA )
217 {
218 bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) );
219
220 return bts_Int32Mat_solve2( cpA,
221 tmpMatA,
222 matWidthA,
223 inVecA,
224 outVecA,
225 bbpA,
226 tmpVecA );
227 }
228
229 /* ------------------------------------------------------------------------- */
230
bts_Int32Mat_solve2(struct bbs_Context * cpA,int32 * matA,int32 matWidthA,const int32 * inVecA,int32 * outVecA,int32 bbpA,int32 * tmpVecA)231 flag bts_Int32Mat_solve2( struct bbs_Context* cpA,
232 int32* matA,
233 int32 matWidthA,
234 const int32* inVecA,
235 int32* outVecA,
236 int32 bbpA,
237 int32* tmpVecA )
238 {
239 int32 sizeL = matWidthA;
240 int32 bbpL = bbpA;
241 int32 iL, jL, kL;
242 int32 iPivL;
243 int32 jPivL;
244
245 int32* vecL = outVecA;
246 int32* matL = matA;
247 int32* checkArrL = tmpVecA;
248
249 for( iL = 0; iL < sizeL; iL++ )
250 {
251 checkArrL[ iL ] = 0;
252 }
253
254 bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) );
255
256 iPivL = 0;
257
258 for( kL = 0; kL < sizeL; kL++ )
259 {
260 /* find pivot */
261 int32 maxAbsL = 0;
262 int32* pivRowL;
263
264 int32 bbp_pivRowL, bbp_vecL, shiftL;
265
266 jPivL = -1;
267 for( iL = 0; iL < sizeL; iL++ )
268 {
269 if( checkArrL[ iL ] != 1 )
270 {
271 int32* rowL = matL + ( iL * sizeL );
272 for( jL = 0; jL < sizeL; jL++ )
273 {
274 if( checkArrL[ jL ] == 0 )
275 {
276 int32 absElemL = rowL[ jL ];
277 if( absElemL < 0 ) absElemL = -absElemL;
278 if( maxAbsL < absElemL )
279 {
280 maxAbsL = absElemL;
281 iPivL = iL;
282 jPivL = jL;
283 }
284 }
285 else if( checkArrL[ jL ] > 1 )
286 {
287 return FALSE;
288 }
289 }
290 }
291 }
292
293 /* successfull ? */
294 if( jPivL < 0 )
295 {
296 return FALSE;
297 }
298
299 checkArrL[ jPivL ]++;
300
301 /* exchange rows to put pivot on diagonal, if neccessary */
302 if( iPivL != jPivL )
303 {
304 int32* row1PtrL = matL + ( iPivL * sizeL );
305 int32* row2PtrL = matL + ( jPivL * sizeL );
306 for( jL = 0; jL < sizeL; jL++ )
307 {
308 int32 tmpL = *row1PtrL;
309 *row1PtrL++ = *row2PtrL;
310 *row2PtrL++ = tmpL;
311 }
312
313 {
314 int32 tmpL = vecL[ jPivL ];
315 vecL[ jPivL ] = vecL[ iPivL ];
316 vecL[ iPivL ] = tmpL;
317 }
318 }
319 /* now index jPivL specifies pivot row and maximum element */
320
321
322 /** Overflow protection: only if the highest bit of the largest matrix element is set,
323 * we need to shift the whole matrix and the right side vector 1 bit to the right,
324 * to make sure there can be no overflow when the pivot row gets subtracted from the
325 * other rows.
326 * Getting that close to overflow is a rare event, so this shift will happen only
327 * occasionally, or not at all.
328 */
329 if( maxAbsL & 1073741824 ) /*( 1 << 30 )*/
330 {
331 /* right shift matrix by 1 */
332 int32 iL = sizeL * sizeL;
333 int32* ptrL = matL;
334 while( iL-- )
335 {
336 *ptrL = ( *ptrL + 1 ) >> 1;
337 ptrL++;
338 }
339
340 /* right shift right side vector by 1 */
341 iL = sizeL;
342 ptrL = vecL;
343 while( iL-- )
344 {
345 *ptrL = ( *ptrL + 1 ) >> 1;
346 ptrL++;
347 }
348
349 /* decrement bbpL */
350 bbpL--;
351 }
352
353
354 /* reduce elements of pivot row to 15 bit */
355 pivRowL = matL + jPivL * sizeL;
356 bbp_pivRowL = bbpL;
357 bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 );
358
359 /* scale pivot row such that maximum equals 1 */
360 {
361 int32 maxL = pivRowL[ jPivL ];
362 int32 bbp_maxL = bbp_pivRowL;
363 int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/
364
365 for( jL = 0; jL < sizeL; jL++ )
366 {
367 pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15;
368 }
369 bbp_pivRowL = 15;
370
371 /* set to 1 to avoid computational errors */
372 pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL;
373
374 shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] );
375
376 vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL;
377 bbp_vecL = bbpL + shiftL - bbp_maxL;
378
379 bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 );
380 }
381
382 /* subtract pivot row from all other rows */
383 for( iL = 0; iL < sizeL; iL++ )
384 {
385 if( iL != jPivL )
386 {
387 int32* rowPtrL = matL + iL * sizeL;
388
389 int32 tmpL = *( rowPtrL + jPivL );
390 int32 bbp_tmpL = bbpL;
391 bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 );
392
393 shiftL = bbp_tmpL + bbp_pivRowL - bbpL;
394 if( shiftL > 0 )
395 {
396 for( jL = 0; jL < sizeL; jL++ )
397 {
398 *rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
399 }
400 }
401 else
402 {
403 for( jL = 0; jL < sizeL; jL++ )
404 {
405 *rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL;
406 }
407 }
408
409 shiftL = bbp_tmpL + bbp_vecL - bbpL;
410 if( shiftL > 0 )
411 {
412 vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
413 }
414 else
415 {
416 vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL;
417 }
418 }
419 }
420
421 /* change bbp of pivot row back to bbpL */
422 shiftL = bbpL - bbp_pivRowL;
423 if( shiftL >= 0 )
424 {
425 for( jL = 0; jL < sizeL; jL++ )
426 {
427 pivRowL[ jL ] <<= shiftL;
428 }
429 }
430 else
431 {
432 shiftL = -shiftL;
433 for( jL = 0; jL < sizeL; jL++ )
434 {
435 pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
436 }
437 }
438
439 shiftL = bbpL - bbp_vecL;
440 if( shiftL >= 0 )
441 {
442 vecL[ jPivL ] <<= shiftL;
443 }
444 else
445 {
446 shiftL = -shiftL;
447 vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
448 }
449 /*
450 if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL );
451 */
452 } /* of kL */
453
454 /* in case bbpL has been decreased by the overflow protection, change it back now */
455 if( bbpA > bbpL )
456 {
457 /* find largest element of solution vector */
458 int32 maxL = 0;
459 int32 iL, shiftL;
460 for( iL = 0; iL < sizeL; iL++ )
461 {
462 int32 xL = vecL[ iL ];
463 if( xL < 0 ) xL = -xL;
464 if( xL > maxL ) maxL = xL;
465 }
466
467 /* check whether we can left shift without overflow */
468 shiftL = 30 - bts_absIntLog2( maxL );
469 if( shiftL < ( bbpA - bbpL ) )
470 {
471 /*
472 bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to "
473 "compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA );
474 */
475
476 return FALSE;
477 }
478
479 /* shift left */
480 shiftL = bbpA - bbpL;
481 for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL;
482 }
483
484 return TRUE;
485 }
486
487 /* ------------------------------------------------------------------------- */
488
489 /* ========================================================================= */
490
491