Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Introduce list sorting for comparable elements #2609

Closed
wants to merge 13 commits into from
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ what we publish.

### ⭐️ New

- Add a `sort` function for list of `ComparableCollectionElement`s.
[PR #2609](https://github.com/modularml/mojo/pull/2609) by
[@mzaks](https://github.com/mzaks)

- Mojo has introduced `@parameter for`, a new feature for compile-time
programming. `@parameter for` defines a for loop where the sequence and the
induction values in the sequence must be parameter values. For example:
Expand Down
63 changes: 61 additions & 2 deletions stdlib/src/builtin/bool.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ trait Boolable:
@register_passable("trivial")
struct Bool(
Stringable,
CollectionElement,
ComparableCollectionElement,
Boolable,
EqualityComparable,
Intable,
Indexer,
):
Expand Down Expand Up @@ -200,6 +199,66 @@ struct Bool(
self._as_scalar_bool(), rhs._as_scalar_bool()
)

@always_inline("nodebug")
fn __lt__(self, rhs: Self) -> Bool:
"""Compare this Bool to RHS using less-than comparison.

Args:
rhs: The rhs of the operation.

Returns:
True if self is False and rhs is True.
"""

return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred lt>`](
self._as_scalar_bool(), rhs._as_scalar_bool()
)

@always_inline("nodebug")
fn __le__(self, rhs: Self) -> Bool:
"""Compare this Bool to RHS using less-than-or-equal comparison.

Args:
rhs: The rhs of the operation.

Returns:
True if self is False and rhs is True or False.
"""

return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred le>`](
self._as_scalar_bool(), rhs._as_scalar_bool()
)

@always_inline("nodebug")
fn __gt__(self, rhs: Self) -> Bool:
"""Compare this Bool to RHS using greater-than comparison.

Args:
rhs: The rhs of the operation.

Returns:
True if self is True and rhs is False.
"""

return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred gt>`](
self._as_scalar_bool(), rhs._as_scalar_bool()
)

@always_inline("nodebug")
fn __ge__(self, rhs: Self) -> Bool:
"""Compare this Bool to RHS using greater-than-or-equal comparison.

Args:
rhs: The rhs of the operation.

Returns:
True if self is True and rhs is True or False.
"""

return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred ge>`](
self._as_scalar_bool(), rhs._as_scalar_bool()
)

# ===-------------------------------------------------------------------===#
# Bitwise operations
# ===-------------------------------------------------------------------===#
Expand Down
88 changes: 88 additions & 0 deletions stdlib/src/builtin/sort.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,91 @@ fn _small_sort[
_sort_partial_3[type, cmp_fn](array, 0, 2, 3)
_sort_partial_3[type, cmp_fn](array, 1, 2, 3)
return


# ===----------------------------------------------------------------------=== #
# Comparable elements list sorting
# ===----------------------------------------------------------------------=== #


@always_inline
fn insertion_sort[type: ComparableCollectionElement](inout list: List[type]):
"""Sort list of the order comparable elements in-place with insertion sort algorithm.

Parameters:
type: The order comparable collection element type.

Args:
list: The list of the order comparable elements which will be sorted in-place.
"""
for i in range(1, len(list)):
var key = list[i]
var j = i - 1
while j >= 0 and key < list[j]:
list[j + 1] = list[j]
j -= 1
list[j + 1] = key


fn _quick_sort[
type: ComparableCollectionElement
](inout list: List[type], low: Int, high: Int):
"""Sort section of the list, between low and high, with quick sort algorithm in-place.

Parameters:
type: The order comparable collection element type.

Args:
list: The list of the order comparable elements which will be sorted in-place.
low: Int value identifying the lowest index of the list section to be sorted.
high: Int value identifying the highest index of the list section to be sorted.
"""

@always_inline
@parameter
fn _partition(low: Int, high: Int) -> Int:
var pivot = list[high]
var i = low - 1
for j in range(low, high):
if list[j] <= pivot:
i += 1
list[j], list[i] = list[i], list[j]
list[i + 1], list[high] = list[high], list[i + 1]
JoeLoser marked this conversation as resolved.
Show resolved Hide resolved
return i + 1

if low < high:
var pi = _partition(low, high)
_quick_sort(list, low, pi - 1)
_quick_sort(list, pi + 1, high)


@always_inline
fn quick_sort[type: ComparableCollectionElement](inout list: List[type]):
"""Sort list of the order comparable elements in-place with quick sort algorithm.

Parameters:
type: The order comparable collection element type.

Args:
list: The list of the order comparable elements which will be sorted in-place.
"""
_quick_sort(list, 0, len(list) - 1)


fn sort[
type: ComparableCollectionElement, slist_ub: Int = 64
](inout list: List[type]):
"""Sort list of the order comparable elements in-place. This function picks the best algorithm based on the list length.

Parameters:
type: The order comparable collection element type.
slist_ub: The upper bound for a list size which is considered small.

Args:
list: The list of the scalars which will be sorted in-place.
"""
var count = len(list)
if count <= slist_ub:
insertion_sort(list) # small lists are best sorted with insertion sort
else:
quick_sort(list) # others are best sorted with quick sort
1 change: 1 addition & 0 deletions stdlib/src/builtin/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ struct String(
Representable,
IntableRaising,
KeyElement,
Comparable,
Boolable,
Formattable,
ToFormatter,
Expand Down
1 change: 1 addition & 0 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct StringLiteral(
KeyElement,
Boolable,
Formattable,
Comparable,
):
"""This type represents a string literal.

Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/builtin/value.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ trait StringableCollectionElement(CollectionElement, Stringable):
pass


trait ComparableCollectionElement(CollectionElement, EqualityComparable):
trait ComparableCollectionElement(CollectionElement, Comparable):
"""
This trait is a temporary solution to enable comparison of
collection elements as utilized in the `index` and `count` methods of
Expand Down
8 changes: 7 additions & 1 deletion stdlib/src/memory/unsafe_pointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ from memory.memory import _free, _malloc
@register_passable("trivial")
struct UnsafePointer[
T: AnyType, address_space: AddressSpace = AddressSpace.GENERIC
](Boolable, CollectionElement, Stringable, Intable, EqualityComparable):
](
Boolable,
CollectionElement,
Stringable,
Intable,
Comparable,
):
"""This is a pointer type that can point to any generic value that is movable.

Parameters:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/utils/index.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ fn _bool_tuple_reduce[

@value
@register_passable("trivial")
struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable):
struct StaticIntTuple[size: Int](Sized, Stringable, Comparable):
"""A base struct that implements size agnostic index functions.

Parameters:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/utils/stringref.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct StringRef(
Stringable,
Hashable,
Boolable,
EqualityComparable,
Comparable,
):
"""
Represent a constant reference to a string, i.e. a sequence of characters
Expand Down
33 changes: 33 additions & 0 deletions stdlib/test/builtin/test_bool.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,38 @@ def test_indexer():
assert_equal(0, Bool.__index__(False))


def test_comparisons():
assert_true(False == False)
assert_true(True == True)
assert_false(False == True)
assert_false(True == False)

assert_true(False != True)
assert_true(True != False)
assert_false(False != False)
assert_false(True != True)

assert_true(True > False)
assert_false(False > True)
assert_false(False > False)
assert_false(True > True)

assert_true(True >= False)
assert_false(False >= True)
assert_true(False >= False)
assert_true(True >= True)

assert_false(True < False)
assert_true(False < True)
assert_false(False < False)
assert_false(True < True)

assert_false(True <= False)
assert_true(False <= True)
assert_true(False <= False)
assert_true(True <= True)


def main():
test_bool_cast_to_int()
test_bool_none()
Expand All @@ -113,3 +145,4 @@ def main():
test_bitwise()
test_neg()
test_indexer()
test_comparisons()