/++
This is a submodule of $(MREF mir,ndslice).

Note:
    The combination of
    $(SUBREF topology, pairwise) with lambda `"a <= b"` (`"a < b"`) and $(SUBREF algorithm, all) can be used
    to check if an ndslice is sorted (strictly monotonic).
    $(SUBREF topology, iota) can be used to make an index.
    $(SUBREF topology, map) and $(SUBREF topology, zip) can be used to create Schwartzian transform.
    See also the examples in the module.


See_also: $(SUBREF topology, flattened)

`isSorted` and `isStrictlyMonotonic`

License: $(HTTP www.apache.org/licenses/LICENSE-2.0, Apache-2.0)
Copyright: 2020 Ilya Yaroshenko, Kaleidic Associates Advisory Limited, Symmetry Investments
Authors: Andrei Alexandrescu (Phobos), Ilya Yaroshenko (API, rework, Mir adaptation)

Macros:
    SUBREF = $(REF_ALTTEXT $(TT $2), $2, mir, ndslice, $1)$(NBSP)
+/
module mir.ndslice.sorting;

/// Check if ndslice is sorted, or strictly monotonic.
@safe pure version(mir_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.slice: sliced;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: pairwise;

    auto arr = [1, 1, 2].sliced;

    assert(arr.pairwise!"a <= b".all);
    assert(!arr.pairwise!"a < b".all);

    arr = [4, 3, 2, 1].sliced;

    assert(!arr.pairwise!"a <= b".all);
    assert(!arr.pairwise!"a < b".all);

    sort(arr);

    assert(arr.pairwise!"a <= b".all);
    assert(arr.pairwise!"a < b".all);
}

/// Create index
version(mir_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.allocation: slice;
    import mir.ndslice.slice: sliced;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: iota, pairwise;

    auto arr = [4, 2, 3, 1].sliced;

    auto index = arr.length.iota.slice;
    index.sort!((a, b) => arr[a] < arr[b]);

    assert(arr[index].pairwise!"a <= b".all);
}

/// Schwartzian transform
version(mir_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.allocation: slice;
    import mir.ndslice.slice: sliced;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: zip, map, pairwise;

    alias transform = (a) => (a - 3) ^^ 2;

    auto arr = [4, 2, 3, 1].sliced;

    arr.map!transform.slice.zip(arr).sort!((l, r) => l.a < r.a);

    assert(arr.map!transform.pairwise!"a <= b".all);
}

import mir.ndslice.slice;
import mir.math.common: optmath;

@optmath:

@safe pure version(mir_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.topology: pairwise;

    auto a = [1, 2, 3].sliced;
    assert(a[0 .. 0].pairwise!"a <= b".all);
    assert(a[0 .. 1].pairwise!"a <= b".all);
    assert(a.pairwise!"a <= b".all);
    auto b = [1, 3, 2].sliced;
    assert(!b.pairwise!"a <= b".all);

    // ignores duplicates
    auto c = [1, 1, 2].sliced;
    assert(c.pairwise!"a <= b".all);
}

@safe pure version(mir_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.topology: pairwise;

    assert([1, 2, 3][0 .. 0].sliced.pairwise!"a < b".all);
    assert([1, 2, 3][0 .. 1].sliced.pairwise!"a < b".all);
    assert([1, 2, 3].sliced.pairwise!"a < b".all);
    assert(![1, 3, 2].sliced.pairwise!"a < b".all);
    assert(![1, 1, 2].sliced.pairwise!"a < b".all);
}


/++
Sorts ndslice, array, or series.

The sort is performed in place; every overload returns its (reordered) argument.

Params:
    less = ordering predicate; a string lambda is converted with `naryFun`.

See_also: $(SUBREF topology, flattened).
+/
template sort(alias less = "a < b")
{
    import mir.functional: naryFun;
    import mir.series: Series;
    static if (__traits(isSame, naryFun!less, less))
    {
@optmath:
        /++
        Sort n-dimensional slice.

        The elements are sorted in the order given by
        $(SUBREF topology, flattened). A slice with any empty
        dimension is returned untouched.

        Params:
            slice = slice to sort in place.
        Returns: `slice`.
        +/
        Slice!(Iterator, N, kind) sort(Iterator, size_t N, SliceKind kind)
            (Slice!(Iterator, N, kind) slice)
        {
            // Never executed. NOTE(review): presumably this makes the inferred
            // attributes of `sort` depend on element copy/assignment and on
            // `less`, because `quickSortImpl` itself is `@trusted` - confirm.
            if (false) // break safety
            {
                import mir.utility : swapStars;
                auto elem = typeof(*slice._iterator).init;
                elem = elem;
                auto l = less(elem, elem);
            }
            import mir.ndslice.topology: flattened;
            if (slice.anyEmpty)
                return slice;
            .quickSortImpl!less(slice.flattened);
            return slice;
        }

        /++
        Sort for arrays

        Params:
            ar = array to sort in place.
        Returns: `ar`.
        +/
        T[] sort(T)(T[] ar)
        {
            return .sort!less(ar.sliced).field;
        }

        /++
        Sort for one-dimensional Series.

        Index and data are reordered together; `less` is applied to the index values.

        Params:
            series = series to sort in place.
        Returns: `series`.
        +/
        Series!(IndexIterator, Iterator, N, kind)
            sort(IndexIterator, Iterator, size_t N, SliceKind kind)
            (Series!(IndexIterator, Iterator, N, kind) series)
        if (N == 1)
        {
            import mir.ndslice.sorting: sort;
            import mir.ndslice.topology: zip;
            with(series)
                index.zip(data).sort!((a, b) => less(a.a, b.a));
            return series;
        }

        /++
        Sort for n-dimensional Series.

        The index is sorted with `less`; the resulting permutation is then
        applied to the data using the two scratch buffers.

        Params:
            series = series to sort in place.
            indexBuffer = scratch buffer for the permutation; its length must equal `series.length`.
            dataBuffer = scratch buffer for the data; its length must equal `series.length`.
        Returns: `series`.
        +/
        Series!(IndexIterator, Iterator, N, kind)
            sort(
                IndexIterator,
                Iterator,
                size_t N,
                SliceKind kind,
                SortIndexIterator,
                DataIterator,
                )
            (
                Series!(IndexIterator, Iterator, N, kind) series,
                Slice!SortIndexIterator indexBuffer,
                Slice!DataIterator dataBuffer,
            )
        {
            import mir.algorithm.iteration: each;
            import mir.ndslice.sorting: sort;
            import mir.ndslice.topology: iota, zip, ipack, evertPack;

            assert(indexBuffer.length == series.length);
            assert(dataBuffer.length == series.length);
            // Start from the identity permutation and sort it along with the index.
            indexBuffer[] = indexBuffer.length.iota!(typeof(indexBuffer.front));
            series.index.zip(indexBuffer).sort!((a, b) => less(a.a, b.a));
            // Apply the permutation to each packed sub-slice of the data,
            // going through `dataBuffer` so that source and destination do not alias.
            series.data.ipack!1.evertPack.each!((sl){
            {
                assert(sl.shape == dataBuffer.shape);
                dataBuffer[] = sl[indexBuffer];
                sl[] = dataBuffer;
            }});
            return series;
        }
    }
    else
        alias sort = .sort!(naryFun!less);
}

///
@safe pure version(mir_test) unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.slice;
    import mir.ndslice.sorting: sort;
    import mir.ndslice.topology: pairwise;

    int[10] arr = [7,1,3,2,9,0,5,4,8,6];

    auto data = arr[].sliced(arr.length);
    data.sort();
    assert(data.pairwise!"a <= b".all);
}

/// one-dimensional series
pure version(mir_test) unittest
{
    import mir.series;

    auto index = [4, 2, 1, 3, 0].sliced;
    auto data = [5.6, 3.4, 2.1, 7.8, 0.1].sliced;
    auto series = index.series(data);
    series.sort;
    assert(series.index == [0, 1, 2, 3, 4]);
    assert(series.data == [0.1, 2.1, 3.4, 7.8, 5.6]);
    /// initial index and data are the same
    assert(index.iterator is series.index.iterator);
    assert(data.iterator is series.data.iterator);

    foreach(obs; series)
    {
        static assert(is(typeof(obs) == Observation!(int, double)));
    }
}

/// two-dimensional series
pure version(mir_test)
unittest
{
    import mir.series;
    import mir.ndslice.allocation: uninitSlice;

    auto index = [4, 2, 3, 1].sliced;
    auto data =
        [2.1, 3.4,
         5.6, 7.8,
         3.9, 9.0,
         4.0, 2.0].sliced(4, 2);
    auto series = index.series(data);

    series.sort(
        uninitSlice!size_t(series.length), // index buffer
        uninitSlice!double(series.length), // data buffer
    );

    assert(series.index == [1, 2, 3, 4]);
    assert(series.data ==
        [[4.0, 2.0],
         [5.6, 7.8],
         [3.9, 9.0],
         [2.1, 3.4]]);
    /// initial index and data are the same
    assert(index.iterator is series.index.iterator);
    assert(data.iterator is series.data.iterator);
}

/+
In-place quicksort of a one-dimensional slice.

Slices of at most `naive` elements (and every slice during CTFE) are sorted
with insertion sort. Larger slices are partitioned around a pivot chosen by
`setPivot`; the smaller partition is handled by a recursive call and the
larger one by the next iteration of the loop.

Params:
    less = ordering predicate, called as `less(a, b)`
    slice = slice to sort
+/
void quickSortImpl(alias less, Iterator)(Slice!Iterator slice) @trusted
{
    import mir.utility : swap, swapStars;

    // Insertion-sort threshold: at least 32 elements, or about 1024 bytes worth.
    enum naive_est = 1024 / slice.Element!0.sizeof;
    enum size_t naive = 32 > naive_est ? 32 : naive_est;
    //enum size_t naive = 1;
    static assert(naive >= 1);

    for(;;)
    {
        // Empty and single-element slices are already sorted.
        // The recursive call below can receive an empty partition; without
        // this guard the insertion sort would start at `l - 1` and walk
        // backwards out of bounds, because `p != l` would never become false.
        if (slice.length <= 1)
            return;

        auto l = slice._iterator;
        auto r = l;
        r += slice.length;

        static if (naive > 1)
        {
            if (slice.length <= naive || __ctfe)
            {
                // Insertion sort, growing a sorted suffix from right to left:
                // every `*p` is inserted into the sorted range `(p, r)`.
                auto p = r;
                --p;
                while(p != l)
                {
                    --p;
                    auto d = p;
                    import mir.functional: unref;
                    auto temp = unref(*d);
                    auto c = d;
                    ++c;
                    if (less(*c, temp))
                    {
                        // Shift smaller elements one position to the left
                        // and drop `temp` into the resulting gap.
                        do
                        {
                            d[0] = *c;
                            ++d;
                            ++c;
                        }
                        while (c != r && less(*c, temp));
                        d[0] = temp;
                    }
                }
                return;
            }
        }

        // partition
        auto lessI = l;
        --r; // `r` now refers to the last element
        auto pivotIdx = l + slice.length / 2;
        setPivot!less(slice.length, l, pivotIdx, r);
        import mir.functional: unref;
        auto pivot = unref(*pivotIdx);
        --lessI;
        auto greaterI = r;
        // Park the pivot in the last position; it also stops the upward scan.
        swapStars(pivotIdx, greaterI);

        outer: for (;;)
        {
            do ++lessI;
            while (less(*lessI, pivot));
            assert(lessI <= greaterI, "sort: invalid comparison function.");
            for (;;)
            {
                if (greaterI == lessI)
                    break outer;
                --greaterI;
                if (!less(pivot, *greaterI))
                    break;
            }
            assert(lessI <= greaterI, "sort: invalid comparison function.");
            if (lessI == greaterI)
                break;
            swapStars(lessI, greaterI);
        }

        // Move the pivot to its final position.
        swapStars(r, lessI);

        ptrdiff_t len = lessI - l;
        auto tail = slice[len + 1 .. $];
        slice = slice[0 .. len];
        // Recurse into the smaller part, iterate over the larger one.
        if (tail.length > slice.length)
            swap(slice, tail);
        quickSortImpl!less(tail);
    }
}

/+
Prepares the pivot for `quickSortImpl`.

For `length < 32` nothing is done. For `32 <= length < 512` the elements at
`l`, `mid` and `r` are ordered with the three-argument `medianOf`. For longer
ranges the five-argument `medianOf` is called with
`(l, mid + length / 4, mid, mid - length / 4, r)`, in that argument order.

Params:
    length = number of elements in the range
    l = iterator to the first element
    mid = iterator to the pivot candidate
    r = iterator to the last element
+/
void setPivot(alias less, Iterator)(size_t length, ref Iterator l, ref Iterator mid, ref Iterator r) @trusted
{
    if (length < 512)
    {
        if (length >= 32)
            medianOf!less(l, mid, r);
        return;
    }
    auto quarter = length >> 2;
    auto b = mid - quarter;
    auto e = mid + quarter;
    medianOf!less(l, e, mid, b, r);
}

/+
Orders the two referenced elements so that `!less(*b, *a)` holds.
The `leanRight` parameter is not used by this overload.
+/
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b) @trusted
{
    import mir.utility : swapStars;

    if (less(*b, *a)) {
        swapStars(a, b);
    }
    assert(!less(*b, *a));
}

/+
Orders the three referenced elements so that `!less(*b, *a)` and
`!less(*c, *b)` hold, i.e. the median ends up in `*b`.
The `leanRight` parameter is not used by this overload.
+/
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b, ref Iterator c) @trusted
{
    import mir.utility : swapStars;

    if (less(*c, *a)) // c < a
    {
        if (less(*a, *b)) // c < a < b
        {
            swapStars(a, b);
            swapStars(a, c);
        }
        else // c < a, b <= a
        {
            swapStars(a, c);
            if (less(*b, *a)) swapStars(a, b);
        }
    }
    else // a <= c
    {
        if (less(*b, *a)) // b < a <= c
        {
            swapStars(a, b);
        }
        else // a <= c, a <= b
        {
            if (less(*c, *b)) swapStars(b, c);
        }
    }
    assert(!less(*b, *a));
    assert(!less(*c, *b));
}

/+
Four-element variant. With `leanRight == false` a largest element is moved
to `*d` and the three-argument `medianOf` is applied to `(a, b, c)`; with
`leanRight == true` a smallest element is moved to `*a` and the
three-argument `medianOf` is applied to `(b, c, d)`.
+/
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b, ref Iterator c, ref Iterator d) @trusted
{
    import mir.utility: swapStars;

    static if (!leanRight)
    {
        // Eliminate the rightmost from the competition
        if (less(*d, *c)) swapStars(c, d); // c <= d
        if (less(*d, *b)) swapStars(b, d); // b <= d
        medianOf!less(a, b, c);
    }
    else
    {
        // Eliminate the leftmost from the competition
        if (less(*b, *a)) swapStars(a, b); // a <= b
        if (less(*c, *a)) swapStars(a, c); // a <= c
        medianOf!less(b, c, d);
    }
}

/+
Five-element variant. On return `*c` is a median: neither `*a` nor `*b` is
greater than `*c`, and neither `*d` nor `*e` is less than `*c`.
The `leanRight` parameter is not used by this overload.
+/
void medianOf(alias less, bool leanRight = false, Iterator)
    (ref Iterator a, ref Iterator b, ref Iterator c, ref Iterator d, ref Iterator e) @trusted
{
    import mir.utility: swapStars;   // Credit: Teppo Niinimäki

    version(unittest) scope(success)
    {
        assert(!less(*c, *a));
        assert(!less(*c, *b));
        assert(!less(*d, *c));
        assert(!less(*e, *c));
    }

    if (less(*c, *a)) swapStars(a, c);
    if (less(*d, *b)) swapStars(b, d);
    if (less(*d, *c))
    {
        swapStars(c, d);
        swapStars(a, b);
    }
    if (less(*e, *b)) swapStars(b, e);
    if (less(*e, *c))
    {
        swapStars(c, e);
        if (less(*c, *a)) swapStars(a, c);
    }
    else
    {
        if (less(*c, *b)) swapStars(b, c);
    }
}


/++
Returns: `true` if a sorted array contains the value.

Params:
    test = strict ordering symmetric predicate

For non-symmetric predicates please use a structure with two `opCall`s or an alias of two global functions,
that corresponds to `(array[i], value)` and `(value, array[i])` cases.

See_also: $(LREF transitionIndex).
+/
template assumeSortedContains(alias test = "a < b")
{
    import mir.functional: naryFun;
    static if (__traits(isSame, naryFun!test, test))
    {
@optmath:
        /++
        Params:
            slice = sorted one-dimensional slice or array.
            v = value to test with. It is passed to second argument.
        +/
        bool assumeSortedContains(Iterator, SliceKind kind, V)
            (auto ref Slice!(Iterator, 1, kind) slice, auto ref scope const V v)
        {
            // `ti` is the first position whose element is not ordered before `v`;
            // the value is present iff that element is not ordered after `v` either.
            auto ti = transitionIndex!test(slice, v);
            return ti < slice.length && !test(v, slice[ti]);
        }

        /// ditto
        bool assumeSortedContains(T, V)(scope T[] ar, auto ref scope const V v)
        {
            return .assumeSortedContains!test(ar.sliced, v);
        }
    }
    else
        alias assumeSortedContains = .assumeSortedContains!(naryFun!test);
}

/++
Returns: the smallest index of a sorted array such
    that the index corresponds to the array's element at the index according to the predicate
    and `-1` if the array doesn't contain corresponding element.

Params:
    test = strict ordering symmetric predicate.

For non-symmetric predicates please use a structure with two `opCall`s or an alias of two global functions,
that corresponds to `(array[i], value)` and `(value, array[i])` cases.

See_also: $(LREF transitionIndex).
+/
template assumeSortedEqualIndex(alias test = "a < b")
{
    import mir.functional: naryFun;
    static if (__traits(isSame, naryFun!test, test))
    {
@optmath:
        /++
        Params:
            slice = sorted one-dimensional slice or array.
            v = value to test with. It is passed to second argument.
        +/
        sizediff_t assumeSortedEqualIndex(Iterator, SliceKind kind, V)
            (auto ref Slice!(Iterator, 1, kind) slice, auto ref scope const V v)
        {
            auto ti = transitionIndex!test(slice, v);
            return ti < slice.length && !test(v, slice[ti]) ? ti : -1;
        }

        /// ditto
        sizediff_t assumeSortedEqualIndex(T, V)(scope T[] ar, auto ref scope const V v)
        {
            return .assumeSortedEqualIndex!test(ar.sliced, v);
        }
    }
    else
        alias assumeSortedEqualIndex = .assumeSortedEqualIndex!(naryFun!test);
}

///
version(mir_test)
@safe pure unittest
{
    // sorted: a < b
    auto a = [0, 1, 2, 3, 4, 6];

    assert(a.assumeSortedEqualIndex(2) == 2);
    assert(a.assumeSortedEqualIndex(5) == -1);

    // <= non strict predicates don't work
    assert(a.assumeSortedEqualIndex!"a <= b"(2) == -1);
}

/++
Computes transition index using binary search.
It is low-level API for lower and upper bounds of a sorted array.

Params:
    test = ordering predicate for (`(array[i], value)`) pairs.

See_also: $(LREF assumeSortedEqualIndex).
+/
template transitionIndex(alias test = "a < b")
{
    import mir.functional: naryFun;
    static if (__traits(isSame, naryFun!test, test))
    {
@optmath:
        /++
        Params:
            slice = sorted one-dimensional slice or array.
            v = value to test with. It is passed to second argument.
        Returns:
            the number of leading elements `e` for which `test(e, v)` holds.
        +/
        size_t transitionIndex(Iterator, SliceKind kind, V)
            (auto ref Slice!(Iterator, 1, kind) slice, auto ref scope const V v)
        {
            // Binary search over the half-open window [first, first + count).
            size_t first = 0, count = slice.length;
            while (count > 0)
            {
                immutable step = count / 2, it = first + step;
                if (test(slice[it], v))
                {
                    // The transition is to the right of `it`.
                    first = it + 1;
                    count -= step + 1;
                }
                else
                {
                    // The transition is at `it` or to its left.
                    count = step;
                }
            }
            return first;
        }

        /// ditto
        size_t transitionIndex(T, V)(scope T[] ar, auto ref scope const V v)
        {
            return .transitionIndex!test(ar.sliced, v);
        }

    }
    else
        alias transitionIndex = .transitionIndex!(naryFun!test);
}

///
version(mir_test)
@safe pure unittest
{
    // sorted: a < b
    auto a = [0, 1, 2, 3, 4, 6];

    auto i = a.transitionIndex(2);
    assert(i == 2);
    auto lowerBound = a[0 .. i];

    auto j = a.transitionIndex!"a <= b"(2);
    assert(j == 3);
    auto upperBound = a[j .. $];

    assert(a.transitionIndex(a[$ - 1]) == a.length - 1);
    assert(a.transitionIndex!"a <= b"(a[$ - 1]) == a.length);
}

/++
Computes an index for `r` based on the comparison `less`. The
index is a sorted array of indices into the original
range.

This technique is similar to sorting, but it is more flexible
because (1) it allows "sorting" of immutable collections, (2) allows
binary search even if the original collection does not offer random
access, (3) allows multiple indices, each on a different predicate,
and (4) may be faster when dealing with large objects. However, using
an index may also be slower under certain circumstances due to the
extra indirection, and is always larger than a sorting-based solution
because it needs space for the index in addition to the original
collection. The complexity is the same as `sort`'s.

Can be combined with $(SUBREF topology, indexed) to create a view that is sorted
based on the index.
Params:
    less = The comparison to use.
    r = The slice/array to index.

Returns:
    Index slice/array.

See_Also:
    $(HTTPS numpy.org/doc/stable/reference/generated/numpy.argsort.html, numpy.argsort)
+/
Slice!(I*) makeIndex(I = size_t, alias less = "a < b", Iterator, SliceKind kind)(Slice!(Iterator, 1, kind) r)
{
    import mir.functional: naryFun;
    import mir.ndslice.allocation: slice;
    import mir.ndslice.topology: iota;
    // Allocate the identity permutation 0 .. r.length and sort it
    // by comparing the elements of `r` that the indices refer to.
    return r
        .length
        .iota!I
        .slice
        .sort!((a, b) => naryFun!less(r[a], r[b]));
}

/// Array overload: wraps `r` in a slice and returns the index as a plain array.
I[] makeIndex(I = size_t, alias less = "a < b", T)(scope T[] r)
{
    return .makeIndex!(I, less)(r.sliced).field;
}

///
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.algorithm.iteration: all;
    import mir.ndslice.topology: indexed, pairwise;

    immutable arr = [ 2, 3, 1, 5, 0 ];
    auto index = arr.makeIndex;

    assert(arr.indexed(index).pairwise!"a < b".all);
}

/// Sort based on index created from a separate array
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.algorithm.iteration: equal;
    import mir.ndslice.topology: indexed;

    immutable arr0 = [ 2, 3, 1, 5, 0 ];
    immutable arr1 = [ 1, 5, 4, 2, -1 ];

    auto index = makeIndex(arr0);
    assert(index.equal([4, 2, 0, 1, 3]));
    auto view = arr1.indexed(index);
    assert(view.equal([-1, 4, 1, 5, 2]));
}

/++
Partitions `slice` around `pivot` using comparison function `less`, algorithm
akin to $(LINK2 https://en.wikipedia.org/wiki/Quicksort#Hoare_partition_scheme,
Hoare partition). Specifically, permutes elements of `slice` and returns
an index `k < slice.length` such that:

$(UL

$(LI `slice[pivot]` is swapped to `slice[k]`)


$(LI All elements `e` in subrange `slice[0 .. k]` satisfy `!less(slice[k], e)`
(i.e.
`slice[k]` is greater than or equal to each element to its left according
to predicate `less`))

$(LI All elements `e` in subrange `slice[k .. $]` satisfy
`!less(e, slice[k])` (i.e. `slice[k]` is less than or equal to each element to
its right according to predicate `less`)))

If `slice` contains equivalent elements, multiple permutations of `slice` may
satisfy these constraints. In such cases, `pivotPartition` attempts to
distribute equivalent elements fairly to the left and right of `k` such that `k`
stays close to `slice.length / 2`.

Params:
    less = The predicate used for comparison

Returns:
    The new position of the pivot

See_Also:
    $(HTTP jgrcs.info/index.php/jgrcs/article/view/142, Engineering of a Quicksort
    Partitioning Algorithm), D. Abhyankar, Journal of Global Research in Computer
    Science, February 2011. $(HTTPS youtube.com/watch?v=AxnotgLql0k, ACCU 2016
    Keynote), Andrei Alexandrescu.
+/
@trusted
template pivotPartition(alias less = "a < b")
{
    import mir.functional: naryFun;

    static if (__traits(isSame, naryFun!less, less))
    {
        /++
        Params:
            slice = slice being partitioned
            pivot = The index of the pivot for partitioning, must be less than
            `slice.length` or `0` if `slice.length` is `0`
        +/
        size_t pivotPartition(Iterator, size_t N, SliceKind kind)
                (Slice!(Iterator, N, kind) slice,
                size_t pivot)
        {
            assert(pivot < slice.elementCount || slice.elementCount == 0 && pivot == 0, "pivotPartition: pivot must be less than the elementCount of the slice or the slice must be empty and pivot zero");

            // Nothing to partition.
            if (slice.elementCount <= 1) return 0;

            import mir.ndslice.topology: flattened;

            // Work on the flattened view through an inclusive iterator range.
            auto flattenedSlice = slice.flattened;
            auto frontI = flattenedSlice._iterator;
            auto lastI = frontI + flattenedSlice.length - 1;
            auto pivotI = frontI + pivot;
            pivotPartitionImpl!less(frontI, lastI, pivotI);
            // `pivotI` has been moved to the pivot's new position.
            return pivotI - frontI;
        }
    } else {
        alias pivotPartition = .pivotPartition!(naryFun!less);
    }
}

/// pivotPartition with 1-dimensional Slice
version(mir_test)
@safe pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: all;

    auto x = [5, 3, 2, 6, 4, 1, 3, 7].sliced;
    size_t pivot = pivotPartition(x, x.length / 2);

    assert(x[0 .. pivot].all!(a => a <= x[pivot]));
    assert(x[pivot .. $].all!(a => a >= x[pivot]));
}

/// pivotPartition with 2-dimensional Slice
version(mir_test)
@safe pure
unittest
{
    import mir.ndslice.fuse: fuse;
    import mir.ndslice.topology: flattened;
    import mir.algorithm.iteration: all;

    auto x = [
        [5, 3, 2, 6],
        [4, 1, 3, 7]
    ].fuse;

    size_t pivot = pivotPartition(x, x.elementCount / 2);

    auto xFlattened = x.flattened;
    assert(xFlattened[0 .. pivot].all!(a => a <= xFlattened[pivot]));
    assert(xFlattened[pivot .. $].all!(a => a >= xFlattened[pivot]));
}

version(mir_test)
@safe
unittest
{
    void test(alias less)()
    {
        import mir.ndslice.slice: sliced;
        import mir.algorithm.iteration: all, equal;

        Slice!(int*) x;
        size_t pivot;

        x = [-9, -4, -2, -2, 9].sliced;
        pivot = pivotPartition!less(x, x.length / 2);

        assert(x[0 .. pivot].all!(a => a <= x[pivot]));
        assert(x[pivot .. $].all!(a => a >= x[pivot]));

        x = [9, 2, 8, -5, 5, 4, -8, -4, 9].sliced;
        pivot = pivotPartition!less(x, x.length / 2);

        assert(x[0 .. pivot].all!(a => a <= x[pivot]));
        assert(x[pivot .. $].all!(a => a >= x[pivot]));

        x = [ 42 ].sliced;
        pivot = pivotPartition!less(x, x.length / 2);

        assert(pivot == 0);
        assert(x.equal([ 42 ]));

        x = [ 43, 42 ].sliced;
        pivot = pivotPartition!less(x, 0);
        assert(pivot == 1);
        assert(x.equal([ 42, 43 ]));

        x = [ 43, 42 ].sliced;
        pivot = pivotPartition!less(x, 1);

        assert(pivot == 0);
        assert(x.equal([ 42, 43 ]));

        x = [ 42, 42 ].sliced;
        pivot = pivotPartition!less(x, 0);

        assert(pivot == 0 || pivot == 1);
        assert(x.equal([ 42, 42 ]));

        pivot = pivotPartition!less(x, 1);

        assert(pivot == 0 || pivot == 1);
        assert(x.equal([ 42, 42 ]));
    }
    test!"a < b";
    static bool myLess(int a, int b)
    {
        static bool bogus;
        if (bogus) throw new Exception(""); // just to make it no-nothrow
        return a < b;
    }
    test!myLess;
}

/+
Partitions the inclusive iterator range `[frontI, lastI]` around the element
referenced by `pivotI`.

Params:
    less = The predicate used for comparison
    frontI = iterator to the first element
    lastI = iterator to the last element (inclusive)
    pivotI = iterator to the pivot; must lie within `[frontI, lastI]`.
        On return it refers to the pivot's new position.
+/
@trusted
template pivotPartitionImpl(alias less)
{
    void pivotPartitionImpl(Iterator)
        (ref Iterator frontI,
         ref Iterator lastI,
         ref Iterator pivotI)
    {
        assert(pivotI <= lastI && pivotI >= frontI, "pivotPartition: pivot must be less than the length of slice or slice must be empty and pivot zero");

        // A single element is already partitioned.
        if (frontI == lastI) return;

        import mir.utility: swapStars;

        // Pivot at the front
        swapStars(pivotI, frontI);

        // Fork implementation depending on nothrow copy, assignment, and
        // comparison. If all of these are nothrow, use the specialized
        // implementation discussed at
        // https://youtube.com/watch?v=AxnotgLql0k.
        static if (is(typeof(
            () nothrow { auto x = frontI; x = frontI; return less(*x, *x); }
        )))
        {
            // Plant the pivot in the end as well as a sentinel
            auto loI = frontI;
            auto hiI = lastI;
            auto save = *hiI;
            *hiI = *frontI; // Vacancy is in r[$ - 1] now

            // Start process
            for (;;)
            {
                // Loop invariant
                version(mir_test)
                {
                    // this used to import std.algorithm.all, but we want to
                    // save imports when unittests are enabled if possible.
                    size_t len = lastI - frontI + 1;
                    foreach (x; 0 .. (loI - frontI))
                        assert(!less(*frontI, frontI[x]), "pivotPartition: *frontI must not be less than frontI[x]");
                    foreach (x; (hiI - frontI + 1) .. len)
                        assert(!less(frontI[x], *frontI), "pivotPartition: frontI[x] must not be less than *frontI");
                }
                do ++loI; while (less(*loI, *frontI));
                *(hiI) = *(loI);
                // Vacancy is now in slice[lo]
                do --hiI; while (less(*frontI, *hiI));
                if (loI >= hiI) break;
                *(loI) = *(hiI);
                // Vacancy is now in slice[hi]
            }
            // Fixup
            assert(loI - hiI <= 2, "pivotPartition: Following compare not possible");
            assert(!less(*frontI, *hiI), "pivotPartition: *hiI must not be less than *frontI");
            if (loI - hiI == 2)
            {
                assert(!less(hiI[1], *frontI), "pivotPartition: *(hiI + 1) must not be less than *frontI");
                *(loI) = hiI[1];
                --loI;
            }
            // Put back the element that was displaced by the sentinel.
            *loI = save;
            if (less(*frontI, save)) --loI;
            assert(!less(*frontI, *loI), "pivotPartition: *frontI must not be less than *loI");
        } else {
            // Generic variant: swap out-of-place pairs found from both ends.
            auto loI = frontI;
            ++loI;
            auto hiI = lastI;

            loop: for (;; loI++, hiI--)
            {
                for (;; ++loI)
                {
                    if (loI > hiI) break loop;
                    if (!less(*loI, *frontI)) break;
                }
                // found the left bound:  !less(*loI, *frontI)
                assert(loI <= hiI, "pivotPartition: loI must be less or equal than hiI");
                for (;; --hiI)
                {
                    if (loI >= hiI) break loop;
                    if (!less(*frontI, *hiI)) break;
                }
                // found the right bound: !less(*frontI, *hiI), swap & make progress
                assert(!less(*loI, *hiI), "pivotPartition: *lowI must not be less than *hiI");
                swapStars(loI, hiI);
            }
            --loI;
        }

        // Move the pivot from the front to its final position and report it.
        swapStars(loI, frontI);
        pivotI = loI;
    }
}

version(mir_test)
@safe pure nothrow
unittest {
    import mir.ndslice.sorting: partitionAt;
    import mir.ndslice.allocation: rcslice;
    auto x = rcslice!double(4);
    x[0] = 3;
    x[1] = 2;
    x[2] = 1;
    x[3] = 0;
    partitionAt!("a > b")(x, 2);
}


version(mir_test)
@trusted pure nothrow
unittest
{
    import mir.ndslice.slice: sliced;
    import mir.algorithm.iteration: all;

    auto x = [5, 3, 2, 6, 4, 1, 3, 7].sliced;
    auto frontI = x._iterator;
    auto lastI = x._iterator + x.length - 1;
    auto pivotI = frontI + x.length / 2;
    alias less = (a, b) => (a < b);
    pivotPartitionImpl!less(frontI, lastI, pivotI);
    size_t pivot = pivotI - frontI;

    assert(x[0 .. pivot].all!(a => a <= x[pivot]));
    assert(x[pivot .. $].all!(a => a >= x[pivot]));
}

/++
Partitions `slice`, such that all elements `e1` from `slice[0]` to `slice[nth]`
satisfy `!less(slice[nth], e1)`, and all elements `e2` from `slice[nth]` to
`slice[slice.length]` satisfy `!less(e2, slice[nth])`. This effectively reorders
`slice` such that `slice[nth]` refers to the element that would fall there if
the range were fully sorted. Performs an expected $(BIGOH slice.length)
evaluations of `less` and `swap`, with a worst case of $(BIGOH slice.length^^2).

This function implements the [Fast, Deterministic Selection](https://erdani.com/research/sea2017.pdf)
algorithm that is implemented in the [`topN`](https://dlang.org/library/std/algorithm/sorting/top_n.html)
function in D's standard library (as of version `2.092.0`).

Params:
    less = The predicate to sort by.
See_Also:
    $(LREF pivotPartition), https://erdani.com/research/sea2017.pdf

+/
template partitionAt(alias less = "a < b")
{
    import mir.functional: naryFun;

    static if (__traits(isSame, naryFun!less, less))
    {
        /++
        Params:
            slice = n-dimensional slice
            nth = The index of the element that should be in sorted position after the
                function is finished.
        +/
        void partitionAt(Iterator, size_t N, SliceKind kind)
            (Slice!(Iterator, N, kind) slice, size_t nth) @trusted nothrow @nogc
        {
            import mir.qualifier: lightScope;
            import core.lifetime: move;
            import mir.ndslice.topology: flattened;

            // `nth` is unsigned, so only the upper bound needs checking.
            assert(slice.elementCount > 0, "partitionAt: slice must have elementCount greater than 0");
            assert(nth < slice.elementCount, "partitionAt: nth must be less than the elementCount of the slice");

            bool useSampling = true;
            // Work on the flattened view through an inclusive iterator range.
            auto flattenedSlice = slice.move.flattened;
            auto frontI = flattenedSlice._iterator.lightScope;
            auto lastI = frontI + (flattenedSlice.length - 1);
            partitionAtImpl!less(frontI, lastI, nth, useSampling);
        }
    }
    else
        alias partitionAt = .partitionAt!(naryFun!less);
}

/// Partition 1-dimensional slice at nth
version(mir_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [3, 1, 5, 2, 0].sliced;
    x.partitionAt(nth);
    assert(x[nth] == 2);
}

/// Partition 2-dimensional slice
version(mir_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 4;
    auto x = [3, 1, 5, 2, 0, 7].sliced(3, 2);
    x.partitionAt(nth);
    assert(x[2, 0] == 5);
}

/// Can supply alternate ordering function
version(mir_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [3, 1, 5, 2, 0].sliced;
    x.partitionAt!("a > b")(nth);
    assert(x[nth] == 2);
}

// Check issue #328 fixed
version(mir_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    auto slice = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17].sliced;
    partitionAt(slice, 8);
    partitionAt(slice, 9);
}

version(unittest) {
    // Test helper: for every `nth`, checks that `partitionAt` on a copy of `x`
    // puts the same element at `nth` as a full `sort` does.
    template checkPartitionAtAll(alias less = "a < b")
    {
        import mir.functional: naryFun;
        import mir.ndslice.slice: SliceKind, Slice;

        static if (__traits(isSame, naryFun!less, less))
        {
            @safe pure nothrow
            static bool checkPartitionAtAll
                (Iterator, SliceKind kind)(
                    Slice!(Iterator, 1, kind) x)
            {
                auto x_sorted = x.dup;
                x_sorted.sort!less;

                bool result = true;

                foreach (nth; 0 .. x.length)
                {
                    auto x_i = x.dup;
                    x_i.partitionAt!less(nth);
                    if (x_i[nth] != x_sorted[nth]) {
                        result = false;
                        break;
                    }
                }
                return result;
            }
        } else {
            alias checkPartitionAtAll = .checkPartitionAtAll!(naryFun!less);
        }
    }
}

version(mir_test)
@safe pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    assert(checkPartitionAtAll([2, 2].sliced));

    assert(checkPartitionAtAll([3, 1, 5, 2, 0].sliced));
    assert(checkPartitionAtAll([3, 1, 5, 0, 2].sliced));
    assert(checkPartitionAtAll([0, 0, 4, 3, 3].sliced));
    assert(checkPartitionAtAll([5, 1, 5, 1, 5].sliced));
    assert(checkPartitionAtAll([2, 2, 0, 0, 0].sliced));

    assert(checkPartitionAtAll([ 2, 12, 10,  8,  1, 20, 19,  1,  2,  7].sliced));
    assert(checkPartitionAtAll([ 4, 18, 16,  0, 15,  6,  2, 17, 10, 16].sliced));
    assert(checkPartitionAtAll([ 7,  5,  9,  4,  4,  2, 12, 20, 15, 15].sliced));

    assert(checkPartitionAtAll([17, 87, 58, 50, 34, 98, 25, 77, 88, 79].sliced));

    assert(checkPartitionAtAll([ 6,  7, 10, 25,  5, 10,  9,  0,  2, 15,  7,  9, 11,  8, 13, 18, 17, 13, 25, 22].sliced));
    assert(checkPartitionAtAll([21,  3, 11, 22, 24, 12, 14, 12, 15, 15,  1,  3, 12, 15, 25, 19,  9, 16, 16, 19].sliced));
    assert(checkPartitionAtAll([22,  6, 18,  0,  1,  8, 13, 13, 16, 19, 23, 17,  4,  6, 12, 24, 15, 20, 11, 17].sliced));
    assert(checkPartitionAtAll([19, 23, 14,  5, 12,  3, 13,  7, 25, 25, 24,  9, 21, 25, 12, 22, 15, 22,  7, 11].sliced));
    assert(checkPartitionAtAll([ 0,  2,  7, 16,  2, 20,  1, 11, 17,  5, 22, 17, 25, 13, 14,  5, 22, 21, 24, 14].sliced));
}

/+
Implementation of `partitionAt` over the inclusive iterator range `[loI, hiI]`.

Each iteration partitions the current range around a pivot and then continues
with the side that contains offset `n`, until the pivot lands on `n` or the
range has at most one element.

Params:
    less = The predicate to sort by
    loI = iterator to the first element
    hiI = iterator to the last element (inclusive)
    n = offset from `loI` of the element to put into sorted position
    useSampling = forwarded to the partitioning helpers; cleared for the
        remaining iterations once a quality check on the pivot position fails
+/
private @trusted pure nothrow @nogc
void partitionAtImpl(alias less, Iterator)(
    Iterator loI,
    Iterator hiI,
    size_t n,
    bool useSampling)
{
    assert(loI <= hiI, "partitionAtImpl: frontI must be less than or equal to lastI");

    import mir.utility: swapStars;
    import mir.functional: reverseArgs;

    Iterator pivotI;
    size_t len;

    for (;;) {
        len = hiI - loI + 1;

        if (len <= 1) {
            break;
        }

        if (n == 0) {
            // Looking for the minimum: a linear scan is enough.
            pivotI = loI;
            foreach (i; 1 .. len) {
                if (less(loI[i], *pivotI)) {
                    pivotI = loI + i;
                }
            }
            swapStars(loI + n, pivotI);
            break;
        }

        if (n + 1 == len) {
            // Looking for the maximum: a linear scan is enough.
            pivotI = loI;
            foreach (i; 1 .. len) {
                if (reverseArgs!less(loI[i], *pivotI)) {
                    pivotI = loI + i;
                }
            }
            swapStars(loI + n, pivotI);
            break;
        }

        if (len <= 12) {
            // Short range: plain pivot partition around the middle element.
            pivotI = loI + len / 2;
            pivotPartitionImpl!less(loI, hiI, pivotI);
        } else if (n * 16 <= (len - 1) * 7) {
            // `n` is in the left part of the range.
            pivotI = partitionAtPartitionOffMedian!(less, false)(loI, hiI, n, useSampling);
            // Quality check
            if (useSampling)
            {
                auto pivot = pivotI - loI;
                if (pivot < n)
                {
                    if (pivot * 4 < len)
                    {
                        useSampling = false;
                    }
                }
                else if ((len - pivot) * 8 < len * 3)
                {
                    useSampling = false;
                }
            }
        } else if (n * 16 >= (len - 1) * 9) {
            // `n` is in the right part of the range.
            pivotI = partitionAtPartitionOffMedian!(less, true)(loI, hiI, n, useSampling);
            // Quality check
            if (useSampling)
            {
                auto pivot = pivotI - loI;
                if (pivot < n)
                {
                    if (pivot * 8 < len * 3)
                    {
                        useSampling = false;
                    }
                }
                else if ((len - pivot) * 4 < len)
                {
                    useSampling = false;
                }
            }
        } else {
            // `n` is near the middle of the range.
            pivotI = partitionAtPartition!less(loI, hiI, n, useSampling);
            // Quality check
            if (useSampling) {
                auto pivot = pivotI - loI;
                if (pivot * 9 < len * 2 || pivot * 9 > len * 7)
                {
                    // Failed - abort sampling going forward
                    useSampling = false;
                }
            }
        }

        // Continue with the side of the pivot that contains `n`.
        if (n < (pivotI - loI)) {
            hiI = pivotI - 1;
        } else if (n > (pivotI - loI)) {
            n -= (pivotI - loI + 1);
            loI = pivotI;
            ++loI;
        } else {
            break;
        }
    }
}

version(mir_test)
@trusted pure nothrow
unittest {
    import mir.ndslice.slice: sliced;

    size_t nth = 2;
    auto x = [3, 1, 5, 2, 0].sliced;
    auto frontI = x._iterator;
    auto lastI = frontI + x.elementCount - 1;
    // Partition at the same position that the assertion below inspects.
    partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true);
    assert(x[nth] == 2);
}

version(mir_test)
@trusted pure nothrow
unittest {
    import
mir.ndslice.slice: sliced; 1328 1329 size_t nth = 4; 1330 auto x = [3, 1, 5, 2, 0, 7].sliced(3, 2); 1331 auto frontI = x._iterator; 1332 auto lastI = frontI + x.elementCount - 1; 1333 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1334 assert(x[2, 0] == 5); 1335 } 1336 1337 version(mir_test) 1338 @trusted pure nothrow 1339 unittest { 1340 import mir.ndslice.slice: sliced; 1341 1342 size_t nth = 1; 1343 auto x = [0, 0, 4, 3, 3].sliced; 1344 auto frontI = x._iterator; 1345 auto lastI = frontI + x.elementCount - 1; 1346 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1347 assert(x[nth] == 0); 1348 } 1349 1350 version(mir_test) 1351 @trusted pure nothrow 1352 unittest { 1353 import mir.ndslice.slice: sliced; 1354 1355 size_t nth = 2; 1356 auto x = [0, 0, 4, 3, 3].sliced; 1357 auto frontI = x._iterator; 1358 auto lastI = frontI + x.elementCount - 1; 1359 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1360 assert(x[nth] == 3); 1361 } 1362 1363 version(mir_test) 1364 @trusted pure nothrow 1365 unittest { 1366 import mir.ndslice.slice: sliced; 1367 1368 size_t nth = 3; 1369 auto x = [0, 0, 4, 3, 3].sliced; 1370 auto frontI = x._iterator; 1371 auto lastI = frontI + x.elementCount - 1; 1372 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1373 assert(x[nth] == 3); 1374 } 1375 1376 version(mir_test) 1377 @trusted pure nothrow 1378 unittest { 1379 import mir.ndslice.slice: sliced; 1380 1381 size_t nth = 4; 1382 auto x = [ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced; 1383 auto frontI = x._iterator; 1384 auto lastI = frontI + x.elementCount - 1; 1385 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1386 assert(x[nth] == 7); 1387 } 1388 1389 version(mir_test) 1390 @trusted pure nothrow 1391 unittest { 1392 import mir.ndslice.slice: sliced; 1393 1394 size_t nth = 5; 1395 auto x = [ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced; 1396 auto frontI = x._iterator; 1397 auto lastI = frontI + x.elementCount - 1; 1398 
partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1399 assert(x[nth] == 8); 1400 } 1401 1402 version(mir_test) 1403 @trusted pure nothrow 1404 unittest { 1405 import mir.ndslice.slice: sliced; 1406 1407 size_t nth = 6; 1408 auto x = [ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced; 1409 auto frontI = x._iterator; 1410 auto lastI = frontI + x.elementCount - 1; 1411 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, nth, true); 1412 assert(x[nth] == 10); 1413 } 1414 1415 // Check all partitionAt 1416 version(mir_test) 1417 @trusted pure nothrow 1418 unittest { 1419 import mir.ndslice.slice: sliced; 1420 import mir.ndslice.allocation: slice; 1421 1422 static immutable raw = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22]; 1423 1424 static void fill(T)(T x) { 1425 for (size_t i = 0; i < x.length; i++) { 1426 x[i] = raw[i]; 1427 } 1428 } 1429 auto x = slice!int(raw.length); 1430 fill(x); 1431 auto x_sort = x.dup; 1432 x_sort = x_sort.sort; 1433 size_t i = 0; 1434 while (i < raw.length) { 1435 auto frontI = x._iterator; 1436 auto lastI = frontI + x.length - 1; 1437 partitionAtImpl!((a, b) => (a < b))(frontI, lastI, i, true); 1438 assert(x[i] == x_sort[i]); 1439 fill(x); 1440 i++; 1441 } 1442 } 1443 1444 private @trusted pure nothrow @nogc 1445 Iterator partitionAtPartition(alias less, Iterator)( 1446 ref Iterator frontI, 1447 ref Iterator lastI, 1448 size_t n, 1449 bool useSampling) 1450 { 1451 size_t len = lastI - frontI + 1; 1452 1453 assert(len >= 9 && n < len, "partitionAtPartition: length must be longer than 9 and n must be less than r.length"); 1454 1455 size_t ninth = len / 9; 1456 size_t pivot = ninth / 2; 1457 // Position subrange r[loI .. hiI] to have length equal to ninth and its upper 1458 // median r[loI .. hiI][$ / 2] in exactly the same place as the upper median 1459 // of the entire range r[$ / 2]. This is to improve behavior for searching 1460 // the median in already sorted ranges. 
1461 auto loI = frontI; 1462 loI += len / 2 - pivot; 1463 auto hiI = loI; 1464 hiI += ninth; 1465 1466 // We have either one straggler on the left, one on the right, or none. 1467 assert(loI - frontI <= lastI - hiI + 1 || lastI - hiI <= loI - frontI + 1, "partitionAtPartition: straggler check failed for loI, len, hiI"); 1468 assert(loI - frontI >= ninth * 4, "partitionAtPartition: loI - frontI >= ninth * 4"); 1469 assert((lastI + 1) - hiI >= ninth * 4, "partitionAtPartition: (lastI + 1) - hiI >= ninth * 4"); 1470 1471 // Partition in groups of 3, and the mid tertile again in groups of 3 1472 if (!useSampling) { 1473 auto loI_ = loI; 1474 loI_ -= ninth; 1475 auto hiI_ = hiI; 1476 hiI_ += ninth; 1477 p3!(less, Iterator)(frontI, lastI, loI_, hiI_); 1478 } 1479 p3!(less, Iterator)(frontI, lastI, loI, hiI); 1480 1481 // Get the median of medians of medians 1482 // Map the full interval of n to the full interval of the ninth 1483 pivot = (n * (ninth - 1)) / (len - 1); 1484 if (hiI > loI) { 1485 auto hiI_minus = hiI; 1486 --hiI_minus; 1487 partitionAtImpl!less(loI, hiI_minus, pivot, useSampling); 1488 } 1489 1490 auto pivotI = loI; 1491 pivotI += pivot; 1492 1493 return expandPartition!less(frontI, lastI, loI, pivotI, hiI); 1494 } 1495 1496 version(mir_test) 1497 @trusted pure nothrow 1498 unittest { 1499 import mir.ndslice.slice: sliced; 1500 auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced; 1501 auto x_sort = x.dup; 1502 x_sort = x_sort.sort; 1503 auto frontI = x._iterator; 1504 auto lastI = frontI + x.length - 1; 1505 size_t n = x.length / 2; 1506 partitionAtPartition!((a, b) => (a < b))(frontI, lastI, n, true); 1507 assert(x[n - 1] == x_sort[n - 1]); 1508 } 1509 1510 private @trusted pure nothrow @nogc 1511 Iterator partitionAtPartitionOffMedian(alias less, bool leanRight, Iterator)( 1512 ref Iterator frontI, 1513 ref Iterator lastI, 1514 size_t n, 1515 bool useSampling) 1516 { 1517 size_t len = lastI - frontI + 1; 1518 1519 
assert(len >= 12, "partitionAtPartitionOffMedian: len must be greater than 11"); 1520 assert(n < len, "partitionAtPartitionOffMedian: n must be less than len"); 1521 auto _4 = len / 4; 1522 auto leftLimitI = frontI; 1523 static if (leanRight) 1524 leftLimitI += 2 * _4; 1525 else 1526 leftLimitI += _4; 1527 // Partition in groups of 4, and the left quartile again in groups of 3 1528 if (!useSampling) 1529 { 1530 auto leftLimit_plus_4 = leftLimitI; 1531 leftLimit_plus_4 += _4; 1532 p4!(less, leanRight)(frontI, lastI, leftLimitI, leftLimit_plus_4); 1533 } 1534 auto _12 = _4 / 3; 1535 auto loI = leftLimitI; 1536 loI += _12; 1537 auto hiI = loI; 1538 hiI += _12; 1539 p3!less(frontI, lastI, loI, hiI); 1540 1541 // Get the median of medians of medians 1542 // Map the full interval of n to the full interval of the ninth 1543 auto pivot = (n * (_12 - 1)) / (len - 1); 1544 if (hiI > loI) { 1545 auto hiI_minus = hiI; 1546 --hiI_minus; 1547 partitionAtImpl!less(loI, hiI_minus, pivot, useSampling); 1548 } 1549 auto pivotI = loI; 1550 pivotI += pivot; 1551 return expandPartition!less(frontI, lastI, loI, pivotI, hiI); 1552 } 1553 1554 version(mir_test) 1555 @trusted pure nothrow 1556 unittest { 1557 import mir.ndslice.slice: sliced; 1558 import mir.algorithm.iteration: equal; 1559 1560 auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced; 1561 auto frontI = x._iterator; 1562 auto lastI = frontI + x.length - 1; 1563 partitionAtPartitionOffMedian!((a, b) => (a < b), false)(frontI, lastI, 5, true); 1564 assert(x.equal([6, 7, 8, 9, 5, 0, 2, 7, 9, 15, 10, 25, 11, 10, 13, 18, 17, 13, 25, 22])); 1565 } 1566 1567 version(mir_test) 1568 @trusted pure nothrow 1569 unittest { 1570 import mir.ndslice.slice: sliced; 1571 import mir.algorithm.iteration: equal; 1572 1573 auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced; 1574 auto frontI = x._iterator; 1575 auto lastI = frontI + x.length - 1; 1576 
partitionAtPartitionOffMedian!((a, b) => (a < b), true)(frontI, lastI, 15, true); 1577 assert(x.equal([6, 7, 8, 7, 5, 2, 9, 0, 9, 15, 25, 10, 11, 10, 13, 18, 17, 13, 25, 22])); 1578 } 1579 1580 private @trusted 1581 void p3(alias less, Iterator)( 1582 Iterator frontI, 1583 Iterator lastI, 1584 Iterator loI, 1585 Iterator hiI) 1586 { 1587 assert(loI <= hiI && hiI <= lastI, "p3: loI must be less than or equal to hiI and hiI must be less than or equal to lastI"); 1588 immutable diffI = hiI - loI; 1589 Iterator lo_loI; 1590 Iterator hi_loI; 1591 for (; loI < hiI; ++loI) 1592 { 1593 lo_loI = loI; 1594 lo_loI -= diffI; 1595 hi_loI = loI; 1596 hi_loI += diffI; 1597 assert(lo_loI >= frontI, "p3: lo_loI must be greater than or equal to frontI"); 1598 assert(hi_loI <= lastI, "p3: hi_loI must be less than or equal to lastI"); 1599 medianOf!less(lo_loI, loI, hi_loI); 1600 } 1601 } 1602 1603 version(mir_test) 1604 @trusted pure nothrow 1605 unittest { 1606 import mir.ndslice.slice: sliced; 1607 import mir.algorithm.iteration: equal; 1608 1609 auto x = [3, 4, 0, 5, 2, 1].sliced; 1610 auto frontI = x._iterator; 1611 auto lastI = frontI + x.length - 1; 1612 auto loI = frontI + 2; 1613 auto hiI = frontI + 4; 1614 p3!((a, b) => (a < b))(frontI, lastI, loI, hiI); 1615 assert(x.equal([0, 1, 2, 4, 3, 5])); 1616 } 1617 1618 private @trusted 1619 template p4(alias less, bool leanRight) 1620 { 1621 void p4(Iterator)( 1622 Iterator frontI, 1623 Iterator lastI, 1624 Iterator loI, 1625 Iterator hiI) 1626 { 1627 assert(loI <= hiI && hiI <= lastI, "p4: loI must be less than or equal to hiI and hiI must be less than or equal to lastI"); 1628 1629 immutable diffI = hiI - loI; 1630 immutable diffI2 = diffI * 2; 1631 1632 Iterator lo_loI; 1633 Iterator hi_loI; 1634 1635 static if (leanRight) 1636 Iterator lo2_loI; 1637 else 1638 Iterator hi2_loI; 1639 1640 for (; loI < hiI; ++loI) 1641 { 1642 lo_loI = loI - diffI; 1643 hi_loI = loI + diffI; 1644 1645 assert(lo_loI >= frontI, "p4: lo_loI must be 
greater than or equal to frontI"); 1646 assert(hi_loI <= lastI, "p4: hi_loI must be less than or equal to lastI"); 1647 1648 static if (leanRight) { 1649 lo2_loI = loI - diffI2; 1650 assert(lo2_loI >= frontI, "lo2_loI must be greater than or equal to frontI"); 1651 medianOf!(less, leanRight)(lo2_loI, lo_loI, loI, hi_loI); 1652 } else { 1653 hi2_loI = loI + diffI2; 1654 assert(hi2_loI <= lastI, "hi2_loI must be less than or equal to lastI"); 1655 medianOf!(less, leanRight)(lo_loI, loI, hi_loI, hi2_loI); 1656 } 1657 } 1658 } 1659 } 1660 1661 version(mir_test) 1662 @trusted pure nothrow 1663 unittest { 1664 import mir.ndslice.slice: sliced; 1665 import mir.algorithm.iteration: equal; 1666 1667 auto x = [3, 4, 0, 7, 2, 6, 5, 1, 4].sliced; 1668 auto frontI = x._iterator; 1669 auto lastI = frontI + x.length - 1; 1670 auto loI = frontI + 3; 1671 auto hiI = frontI + 5; 1672 p4!((a, b) => (a < b), false)(frontI, lastI, loI, hiI); 1673 assert(x.equal([3, 1, 0, 4, 2, 6, 4, 7, 5])); 1674 } 1675 1676 version(mir_test) 1677 @trusted pure nothrow 1678 unittest { 1679 import mir.ndslice.slice: sliced; 1680 import mir.algorithm.iteration: equal; 1681 1682 auto x = [3, 4, 0, 8, 2, 7, 5, 1, 4, 3].sliced; 1683 auto frontI = x._iterator; 1684 auto lastI = frontI + x.length - 1; 1685 auto loI = frontI + 4; 1686 auto hiI = frontI + 6; 1687 p4!((a, b) => (a < b), true)(frontI, lastI, loI, hiI); 1688 assert(x.equal([0, 4, 2, 1, 3, 7, 5, 8, 4, 3])); 1689 } 1690 1691 private @trusted 1692 template expandPartition(alias less) 1693 { 1694 Iterator expandPartition(Iterator)( 1695 ref Iterator frontI, 1696 ref Iterator lastI, 1697 ref Iterator loI, 1698 ref Iterator pivotI, 1699 ref Iterator hiI) 1700 { 1701 import mir.algorithm.iteration: all; 1702 1703 assert(frontI <= loI, "expandPartition: frontI must be less than or equal to loI"); 1704 assert(loI <= pivotI, "expandPartition: loI must be less than or equal pivotI"); 1705 assert(pivotI < hiI, "expandPartition: pivotI must be less than hiI"); 
1706 assert(hiI <= lastI, "expandPartition: hiI must be less than or equal to lastI"); 1707 1708 foreach(x; loI .. (pivotI + 1)) 1709 assert(!less(*pivotI, *x), "expandPartition: loI .. (pivotI + 1) failed test"); 1710 foreach(x; (pivotI + 1) .. hiI) 1711 assert(!less(*x, *pivotI), "expandPartition: (pivotI + 1) .. hiI failed test"); 1712 1713 import mir.utility: swapStars; 1714 import mir.algorithm.iteration: all; 1715 // We work with closed intervals! 1716 --hiI; 1717 1718 auto leftI = frontI; 1719 auto rightI = lastI; 1720 loop: for (;; ++leftI, --rightI) 1721 { 1722 for (;; ++leftI) 1723 { 1724 if (leftI == loI) break loop; 1725 if (!less(*leftI, *pivotI)) break; 1726 } 1727 for (;; --rightI) 1728 { 1729 if (rightI == hiI) break loop; 1730 if (!less(*pivotI, *rightI)) break; 1731 } 1732 swapStars(leftI, rightI); 1733 } 1734 1735 foreach(x; loI .. (pivotI + 1)) 1736 assert(!less(*pivotI, *x), "expandPartition: loI .. (pivotI + 1) failed less than test"); 1737 foreach(x; (pivotI + 1) .. (hiI + 1)) 1738 assert(!less(*x, *pivotI), "expandPartition: (pivotI + 1) .. (hiI + 1) failed less than test"); 1739 foreach(x; frontI .. leftI) 1740 assert(!less(*pivotI, *x), "expandPartition: frontI .. leftI failed less than test"); 1741 foreach(x; (rightI + 1) .. (lastI + 1)) 1742 assert(!less(*x, *pivotI), "expandPartition: (rightI + 1) .. (lastI + 1) failed less than test"); 1743 1744 auto oldPivotI = pivotI; 1745 1746 if (leftI < loI) 1747 { 1748 // First loop: spend r[loI .. 
pivot] 1749 for (; loI < pivotI; ++leftI) 1750 { 1751 if (leftI == loI) goto done; 1752 if (!less(*oldPivotI, *leftI)) continue; 1753 --pivotI; 1754 assert(!less(*oldPivotI, *pivotI), "expandPartition: less check failed"); 1755 swapStars(leftI, pivotI); 1756 } 1757 // Second loop: make leftI and pivot meet 1758 for (;; ++leftI) 1759 { 1760 if (leftI == pivotI) goto done; 1761 if (!less(*oldPivotI, *leftI)) continue; 1762 for (;;) 1763 { 1764 if (leftI == pivotI) goto done; 1765 --pivotI; 1766 if (less(*pivotI, *oldPivotI)) 1767 { 1768 swapStars(leftI, pivotI); 1769 break; 1770 } 1771 } 1772 } 1773 } 1774 1775 // First loop: spend r[lo .. pivot] 1776 for (; hiI != pivotI; --rightI) 1777 { 1778 if (rightI == hiI) goto done; 1779 if (!less(*rightI, *oldPivotI)) continue; 1780 ++pivotI; 1781 assert(!less(*pivotI, *oldPivotI), "expandPartition: less check failed"); 1782 swapStars(rightI, pivotI); 1783 } 1784 // Second loop: make leftI and pivotI meet 1785 for (; rightI > pivotI; --rightI) 1786 { 1787 if (!less(*rightI, *oldPivotI)) continue; 1788 while (rightI > pivotI) 1789 { 1790 ++pivotI; 1791 if (less(*oldPivotI, *pivotI)) 1792 { 1793 swapStars(rightI, pivotI); 1794 break; 1795 } 1796 } 1797 } 1798 1799 done: 1800 swapStars(oldPivotI, pivotI); 1801 1802 1803 foreach(x; frontI .. (pivotI + 1)) 1804 assert(!less(*pivotI, *x), "expandPartition: frontI .. (pivotI + 1) failed test"); 1805 foreach(x; (pivotI + 1) .. (lastI + 1)) 1806 assert(!less(*x, *pivotI), "expandPartition: (pivotI + 1) .. 
(lastI + 1) failed test"); 1807 return pivotI; 1808 } 1809 } 1810 1811 version(mir_test) 1812 @trusted pure nothrow 1813 unittest 1814 { 1815 import mir.ndslice.slice: sliced; 1816 1817 auto a = [ 10, 5, 3, 4, 8, 11, 13, 3, 9, 4, 10 ].sliced; 1818 auto frontI = a._iterator; 1819 auto lastI = frontI + a.length - 1; 1820 auto loI = frontI + 4; 1821 auto pivotI = frontI + 5; 1822 auto hiI = frontI + 6; 1823 assert(expandPartition!((a, b) => a < b)(frontI, lastI, loI, pivotI, hiI) == (frontI + 9)); 1824 }