1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package slices_test 6 7import ( 8 "cmp" 9 "fmt" 10 "math" 11 "math/rand" 12 . "slices" 13 "strconv" 14 "strings" 15 "testing" 16) 17 18var ints = [...]int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586} 19var float64s = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8, 74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3} 20var strs = [...]string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} 21 22func TestSortIntSlice(t *testing.T) { 23 data := Clone(ints[:]) 24 Sort(data) 25 if !IsSorted(data) { 26 t.Errorf("sorted %v", ints) 27 t.Errorf(" got %v", data) 28 } 29} 30 31func TestSortFuncIntSlice(t *testing.T) { 32 data := Clone(ints[:]) 33 SortFunc(data, func(a, b int) int { return a - b }) 34 if !IsSorted(data) { 35 t.Errorf("sorted %v", ints) 36 t.Errorf(" got %v", data) 37 } 38} 39 40func TestSortFloat64Slice(t *testing.T) { 41 data := Clone(float64s[:]) 42 Sort(data) 43 if !IsSorted(data) { 44 t.Errorf("sorted %v", float64s) 45 t.Errorf(" got %v", data) 46 } 47} 48 49func TestSortStringSlice(t *testing.T) { 50 data := Clone(strs[:]) 51 Sort(data) 52 if !IsSorted(data) { 53 t.Errorf("sorted %v", strs) 54 t.Errorf(" got %v", data) 55 } 56} 57 58func TestSortLarge_Random(t *testing.T) { 59 n := 1000000 60 if testing.Short() { 61 n /= 100 62 } 63 data := make([]int, n) 64 for i := 0; i < len(data); i++ { 65 data[i] = rand.Intn(100) 66 } 67 if IsSorted(data) { 68 t.Fatalf("terrible rand.rand") 69 } 70 Sort(data) 71 if !IsSorted(data) { 72 t.Errorf("sort didn't sort - 1M ints") 73 } 74} 75 76type intPair struct { 77 a, b int 78} 79 80type intPairs []intPair 81 82// Pairs compare on a only. 83func intPairCmp(x, y intPair) int { 84 return x.a - y.a 85} 86 87// Record initial order in B. 88func (d intPairs) initB() { 89 for i := range d { 90 d[i].b = i 91 } 92} 93 94// InOrder checks if a-equal elements were not reordered. 95// If reversed is true, expect reverse ordering. 96func (d intPairs) inOrder(reversed bool) bool { 97 lastA, lastB := -1, 0 98 for i := 0; i < len(d); i++ { 99 if lastA != d[i].a { 100 lastA = d[i].a 101 lastB = d[i].b 102 continue 103 } 104 if !reversed { 105 if d[i].b <= lastB { 106 return false 107 } 108 } else { 109 if d[i].b >= lastB { 110 return false 111 } 112 } 113 lastB = d[i].b 114 } 115 return true 116} 117 118func TestStability(t *testing.T) { 119 n, m := 100000, 1000 120 if testing.Short() { 121 n, m = 1000, 100 122 } 123 data := make(intPairs, n) 124 125 // random distribution 126 for i := 0; i < len(data); i++ { 127 data[i].a = rand.Intn(m) 128 } 129 if IsSortedFunc(data, intPairCmp) { 130 t.Fatalf("terrible rand.rand") 131 } 132 data.initB() 133 SortStableFunc(data, intPairCmp) 134 if !IsSortedFunc(data, intPairCmp) { 135 t.Errorf("Stable didn't sort %d ints", n) 136 } 137 if !data.inOrder(false) { 138 t.Errorf("Stable wasn't stable on %d ints", n) 139 } 140 141 // already sorted 142 data.initB() 143 SortStableFunc(data, intPairCmp) 144 if !IsSortedFunc(data, intPairCmp) { 145 t.Errorf("Stable shuffled sorted %d ints (order)", n) 146 } 147 if !data.inOrder(false) { 148 t.Errorf("Stable shuffled sorted %d ints (stability)", n) 149 } 150 151 // sorted reversed 152 for i := 0; i < len(data); i++ { 153 data[i].a = len(data) - i 154 } 155 data.initB() 156 SortStableFunc(data, intPairCmp) 157 if !IsSortedFunc(data, intPairCmp) { 158 t.Errorf("Stable didn't sort %d ints", n) 159 } 160 if !data.inOrder(false) { 161 t.Errorf("Stable wasn't stable on %d ints", n) 162 } 163} 164 165type S struct { 166 a int 167 b string 168} 169 170func cmpS(s1, s2 S) int { 171 return cmp.Compare(s1.a, s2.a) 172} 173 174func TestMinMax(t *testing.T) { 175 intCmp := func(a, b int) int { return a - b } 176 177 tests := []struct { 178 data []int 179 wantMin int 180 wantMax int 181 }{ 182 {[]int{7}, 7, 7}, 183 {[]int{1, 2}, 1, 2}, 184 {[]int{2, 1}, 1, 2}, 185 {[]int{1, 2, 3}, 1, 3}, 186 {[]int{3, 2, 1}, 1, 3}, 187 {[]int{2, 1, 3}, 1, 3}, 188 {[]int{2, 2, 3}, 2, 3}, 189 {[]int{3, 2, 3}, 2, 3}, 190 {[]int{0, 2, -9}, -9, 2}, 191 } 192 for _, tt := range tests { 193 t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) { 194 gotMin := Min(tt.data) 195 if gotMin != tt.wantMin { 196 t.Errorf("Min got %v, want %v", gotMin, tt.wantMin) 197 } 198 199 gotMinFunc := MinFunc(tt.data, intCmp) 200 if gotMinFunc != tt.wantMin { 201 t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin) 202 } 203 204 gotMax := Max(tt.data) 205 if gotMax != tt.wantMax { 206 t.Errorf("Max got %v, want %v", gotMax, tt.wantMax) 207 } 208 209 gotMaxFunc := MaxFunc(tt.data, intCmp) 210 if gotMaxFunc != tt.wantMax { 211 t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax) 212 } 213 }) 214 } 215 216 svals := []S{ 217 {1, "a"}, 218 {2, "a"}, 219 {1, "b"}, 220 {2, "b"}, 221 } 222 223 gotMin := MinFunc(svals, cmpS) 224 wantMin := S{1, "a"} 225 if gotMin != wantMin { 226 t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin) 227 } 228 229 gotMax := MaxFunc(svals, cmpS) 230 wantMax := S{2, "a"} 231 if gotMax != wantMax { 232 t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax) 233 } 234} 235 236func TestMinMaxNaNs(t *testing.T) { 237 fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14} 238 if Min(fs) != -400.4 { 239 t.Errorf("got min %v, want -400.4", Min(fs)) 240 } 241 if Max(fs) != 999.9 { 242 t.Errorf("got max %v, want 999.9", Max(fs)) 243 } 244 245 // No matter which element of fs is replaced with a NaN, both Min and Max 246 // should propagate the NaN to their output. 247 for i := 0; i < len(fs); i++ { 248 testfs := Clone(fs) 249 testfs[i] = math.NaN() 250 251 fmin := Min(testfs) 252 if !math.IsNaN(fmin) { 253 t.Errorf("got min %v, want NaN", fmin) 254 } 255 256 fmax := Max(testfs) 257 if !math.IsNaN(fmax) { 258 t.Errorf("got max %v, want NaN", fmax) 259 } 260 } 261} 262 263func TestMinMaxPanics(t *testing.T) { 264 intCmp := func(a, b int) int { return a - b } 265 emptySlice := []int{} 266 267 if !panics(func() { Min(emptySlice) }) { 268 t.Errorf("Min([]): got no panic, want panic") 269 } 270 271 if !panics(func() { Max(emptySlice) }) { 272 t.Errorf("Max([]): got no panic, want panic") 273 } 274 275 if !panics(func() { MinFunc(emptySlice, intCmp) }) { 276 t.Errorf("MinFunc([]): got no panic, want panic") 277 } 278 279 if !panics(func() { MaxFunc(emptySlice, intCmp) }) { 280 t.Errorf("MaxFunc([]): got no panic, want panic") 281 } 282} 283 284func TestBinarySearch(t *testing.T) { 285 str1 := []string{"foo"} 286 str2 := []string{"ab", "ca"} 287 str3 := []string{"mo", "qo", "vo"} 288 str4 := []string{"ab", "ad", "ca", "xy"} 289 290 // slice with repeating elements 291 strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"} 292 293 // slice with all element equal 294 strSame := []string{"xx", "xx", "xx"} 295 296 tests := []struct { 297 data []string 298 target string 299 wantPos int 300 wantFound bool 301 }{ 302 {[]string{}, "foo", 0, false}, 303 {[]string{}, "", 0, false}, 304 305 {str1, "foo", 0, true}, 306 {str1, "bar", 0, false}, 307 {str1, "zx", 1, false}, 308 309 {str2, "aa", 0, false}, 310 {str2, "ab", 0, true}, 311 {str2, "ad", 1, false}, 312 {str2, "ca", 1, true}, 313 {str2, "ra", 2, false}, 314 315 {str3, "bb", 0, false}, 316 {str3, "mo", 0, true}, 317 {str3, "nb", 1, false}, 318 {str3, "qo", 1, true}, 319 {str3, "tr", 2, false}, 320 {str3, "vo", 2, true}, 321 {str3, "xr", 3, false}, 322 323 {str4, "aa", 0, false}, 324 {str4, "ab", 0, true}, 325 {str4, "ac", 1, false}, 326 {str4, "ad", 1, true}, 327 {str4, "ax", 2, false}, 328 {str4, "ca", 2, true}, 329 {str4, "cc", 3, false}, 330 {str4, "dd", 3, false}, 331 {str4, "xy", 3, true}, 332 {str4, "zz", 4, false}, 333 334 {strRepeats, "da", 2, true}, 335 {strRepeats, "db", 5, false}, 336 {strRepeats, "ma", 6, true}, 337 {strRepeats, "mb", 8, false}, 338 339 {strSame, "xx", 0, true}, 340 {strSame, "ab", 0, false}, 341 {strSame, "zz", 3, false}, 342 } 343 for _, tt := range tests { 344 t.Run(tt.target, func(t *testing.T) { 345 { 346 pos, found := BinarySearch(tt.data, tt.target) 347 if pos != tt.wantPos || found != tt.wantFound { 348 t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 349 } 350 } 351 352 { 353 pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare) 354 if pos != tt.wantPos || found != tt.wantFound { 355 t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 356 } 357 } 358 }) 359 } 360} 361 362func TestBinarySearchInts(t *testing.T) { 363 data := []int{20, 30, 40, 50, 60, 70, 80, 90} 364 tests := []struct { 365 target int 366 wantPos int 367 wantFound bool 368 }{ 369 {20, 0, true}, 370 {23, 1, false}, 371 {43, 3, false}, 372 {80, 6, true}, 373 } 374 for _, tt := range tests { 375 t.Run(strconv.Itoa(tt.target), func(t *testing.T) { 376 { 377 pos, found := BinarySearch(data, tt.target) 378 if pos != tt.wantPos || found != tt.wantFound { 379 t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 380 } 381 } 382 383 { 384 cmp := func(a, b int) int { 385 return a - b 386 } 387 pos, found := BinarySearchFunc(data, tt.target, cmp) 388 if pos != tt.wantPos || found != tt.wantFound { 389 t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 390 } 391 } 392 }) 393 } 394} 395 396func TestBinarySearchFloats(t *testing.T) { 397 data := []float64{math.NaN(), -0.25, 0.0, 1.4} 398 tests := []struct { 399 target float64 400 wantPos int 401 wantFound bool 402 }{ 403 {math.NaN(), 0, true}, 404 {math.Inf(-1), 1, false}, 405 {-0.25, 1, true}, 406 {0.0, 2, true}, 407 {1.4, 3, true}, 408 {1.5, 4, false}, 409 } 410 for _, tt := range tests { 411 t.Run(fmt.Sprintf("%v", tt.target), func(t *testing.T) { 412 { 413 pos, found := BinarySearch(data, tt.target) 414 if pos != tt.wantPos || found != tt.wantFound { 415 t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 416 } 417 } 418 }) 419 } 420} 421 422func TestBinarySearchFunc(t *testing.T) { 423 data := []int{1, 10, 11, 2} // sorted lexicographically 424 cmp := func(a int, b string) int { 425 return strings.Compare(strconv.Itoa(a), b) 426 } 427 pos, found := BinarySearchFunc(data, "2", cmp) 428 if pos != 3 || !found { 429 t.Errorf("BinarySearchFunc(%v, %q, cmp) = %v, %v, want %v, %v", data, "2", pos, found, 3, true) 430 } 431} 432