Skip to content

Commit

Permalink
refactor: _new_child_ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
HK-SHAO committed May 12, 2024
1 parent 83f8782 commit 1390806
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 18 deletions.
10 changes: 5 additions & 5 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,11 @@ def e(self):
yield "EF666"

g = G()
assert g._generator.send(None) == "A"
assert g._generator.send(None) == "B"
assert g._generator.send(None) == "C"
assert g._generator.send(None) == "D"
assert g._generator.send(None) == EmptyString.join(g.e())
assert iter(g).send(None) == "A"
assert iter(g).send(None) == "B"
assert iter(g).send(None) == "C"
assert iter(g).send(None) == "D"
assert iter(g).send(None) == EmptyString.join(g.e())
assert EmptyString.join(g) == "123"


Expand Down
2 changes: 1 addition & 1 deletion yieldlang/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def join(sep: Symbol, to_seq: Symbol, depth: int = -1) -> ProxySymbol:
Args:
sep (Symbol): The separator symbol.
to_seq (Symbol): The symbol to join.
depth (int | None): The maximum depth to flatten. If negative, flatten all symbols.
depth (int): The maximum depth to flatten. If negative, flatten all symbols.
Returns:
ProxySymbol: The joined symbol.
"""
Expand Down
33 changes: 21 additions & 12 deletions yieldlang/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,18 @@ def __init__(self, sampler: BaseSampler | None = None) -> None:
"""The sampler to use for sampling symbols."""
self._top_ctx = YContextTree(max_depth=-1, cur_depth=0)
"""The root context for flattening symbols."""
self._generator: YGenerator = self.__iter_symbol(self.top)
"""The iterator to generate text."""
self._generator: YGenerator = self._iter_symbol(self.top)
"""The generator for the text."""

def __iter__(self) -> IteratorStr:
"""Get the iterator."""
def __iter__(self) -> YGenerator:
"""Get the generator."""
return self._generator

def __next__(self) -> str:
"""Get the next token."""
return next(self._generator)

def __iter_symbol(self, symbol: Symbol) -> YGenerator:
def _iter_symbol(self, symbol: Symbol) -> YGenerator:
"""Iterate over a symbol."""
try:
self._top_ctx.ret_value = []
Expand Down Expand Up @@ -103,13 +103,7 @@ def _flatten(self, symbol: Symbol, ctx: YContextTree) -> IteratorSymbol:
symbol (Symbol): The symbol to flatten.
ctx (FlattenContext): The context for flattening.
"""
child = YContextTree()
child.cur_depth = ctx.cur_depth + 1
child.max_depth = ctx.max_depth
ctx.children.append(child)
child.parent = ctx
ctx = child

ctx = self._new_child_ctx(ctx)
if ctx.max_depth > -1 and ctx.cur_depth > ctx.max_depth:
ctx.ret_value = symbol
yield symbol
Expand Down Expand Up @@ -139,3 +133,18 @@ def _flatten(self, symbol: Symbol, ctx: YContextTree) -> IteratorSymbol:
case _:
ctx.name = f"Invalid: {symbol}"
raise TypeError(f"Invalid symbol: {symbol}")

def _new_child_ctx(self, parent: YContextTree) -> YContextTree:
"""Create a new child context.
Args:
parent (YContextTree): The parent context.
Returns:
YContextTree: The child context.
"""
child = YContextTree()
child.cur_depth = parent.cur_depth + 1
child.max_depth = parent.max_depth
parent.children.append(child)
child.parent = parent
return child

0 comments on commit 1390806

Please sign in to comment.