Skip to content

Commit fdb2f63

Browse files
committed
Fix bug #2
1 parent 8937241 commit fdb2f63

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

pyfluent_iterables/fluent.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _iterable(self) -> Iterable[T]:
8888
def map(self, transform: Callable[[T], R]) -> "FluentIterable[R]":
8989
"""Returns a FluentIterable containing the results of applying the given transform function to each element in this Iterable."""
9090
iterable = self._iterable()
91-
return FluentFactoryWrapper(lambda: map(transform, iterable))
91+
return FluentFactoryWrapper(lambda: map(transform, iterable), maybe_sized=iterable)
9292

9393
def filter(self, predicate: Optional[Callable[[T], Any]] = None) -> "FluentIterable[T]":
9494
"""Returns a FluentIterable containing only elements matching the given predicate. If no predicate is given, returns only truthy elements."""
@@ -107,7 +107,7 @@ def filter_false(self, predicate: Optional[Callable[[T], Any]] = None) -> "Fluen
107107
def enumerate(self, start: int = 0) -> "FluentIterable[Tuple[int, T]]":
108108
"""Returns a FluentIterable over pairs of (index, element) for elements in the original Iterable. Indices start with the value of `start`."""
109109
iterable = self._iterable()
110-
return FluentFactoryWrapper(lambda: enumerate(iterable, start=start))
110+
return FluentFactoryWrapper(lambda: enumerate(iterable, start=start), maybe_sized=iterable)
111111

112112
def zip(self, *with_iterables: Iterable) -> "FluentIterable[Tuple]":
113113
"""Returns a sequence of tuples built from the elements of this iterable and other given iterables with the same index. The resulting Iterable ends as soon as the shortest input Iterable ends."""
@@ -229,10 +229,10 @@ def reversed(self) -> "FluentIterable[T]":
229229
"""
230230
iterable = self._iterable()
231231
if hasattr(iterable, "__reversed__"):
232-
return FluentFactoryWrapper(lambda: iterable.__reversed__()) # type: ignore[attr-defined]
232+
return FluentFactoryWrapper(lambda: iterable.__reversed__(), maybe_sized=iterable) # type: ignore[attr-defined]
233233
else:
234234
copy = list(iterable)
235-
return FluentFactoryWrapper(lambda: copy.__reversed__())
235+
return FluentFactoryWrapper(lambda: copy.__reversed__(), maybe_sized=copy)
236236

237237
def grouped(self, n: int) -> "FluentIterable[List[T]]":
238238
"""
@@ -356,42 +356,45 @@ def not_empty(self) -> bool:
356356
"""
357357
return any(True for _ in self._iterable())
358358

359-
def len(self) -> int:
359+
def len(self) -> int:
360360
"""Returns the number of elements in this iterable.
361361
Note that evaluation may result in iterating over the iterable if the wrapped collections doesn't implement the Sized contract.
362362
"""
363363
it = self.__iter__()
364364
if hasattr(it, "__len__"):
365-
return it.__len__() # type: ignore
365+
return cast(Sized, it).__len__()
366366
else:
367367
count = 0
368368
for _ in it:
369369
count += 1
370370
return count
371371

372372
def __len__(self) -> int:
373-
"""Returns the number of elements in this iterable.
374-
Note that evaluation may result in iterating over the iterable if the wrapped collections doesn't implement the Sized contract.
375-
"""
376-
return self.len()
373+
"""Returns the number of elements in this iterable, if it's known for the underlying iterable. Otherwise throws TypeError."""
374+
it = self.__iter__()
375+
if hasattr(it, "__len__"):
376+
return cast(Sized, it).__len__()
377+
# Raising exception instead of defering to len().
378+
# This is necessary, e.g., to work around undocumented behavior of len() which assumes __len__() is present only if size is known in advance
379+
raise TypeError(f"object of type '{type(self).__name__}' has no len()")
377380

378381
def sum(self):
379382
"""Returns the sum of elements in this iterable with the sum() built-in function"""
380-
return sum(self._iterable()) # type: ignore
383+
return sum(self._iterable()) # type: ignore
381384

382385
def min(self, key: Optional[Callable[[T], Any]] = None, default: Optional[T] = None):
383386
"""
384387
Return the smallest item in this iterable. The arguments have identical meaning to the min() built-in function:
385388
`key` specifies a function used to extract a comparison key, `default` specifies result value if this iterable is empty.
386389
"""
387-
return min(self._iterable(), key=key, default=default) # type: ignore
390+
return min(self._iterable(), key=key, default=default) # type: ignore
388391

389392
def max(self, key: Optional[Callable[[T], Any]] = None, default: Optional[T] = None):
390393
"""
391394
Return the smallest item in this iterable. The arguments have identical meaning to the min() built-in function:
392395
`key` specifies a function used to extract a comparison key, `default` specifies result value if this iterable is empty.
393396
"""
394-
return max(self._iterable(), key=key, default=default) # type: ignore
397+
return max(self._iterable(), key=key, default=default) # type: ignore
395398

396399
def reduce(
397400
self,
@@ -429,16 +432,22 @@ def _iterable(self):
429432
return self.inner
430433

431434

432-
433435
class FluentFactoryWrapper(FluentIterable[T]):
434436
"""Implementation for cases where a known factory to a non-reusable Iterator is available"""
435437

436-
def __init__(self, factory: Callable[[], Iterator[T]]):
438+
def __init__(self, factory: Callable[[], Iterator[T]], *, maybe_sized: Optional[Iterable] = None):
437439
self._factory = factory
440+
self._maybe_sized = maybe_sized
438441

439442
def __iter__(self) -> Iterator[T]:
440443
return self._factory()
441444

445+
def __len__(self) -> int:
446+
"""Returns the number of elements in this iterable, if it's known for the underlying iterable. Otherwise throws TypeError."""
447+
if hasattr(self._maybe_sized, "__len__"):
448+
return cast(Sized, self._maybe_sized).__len__()
449+
return super().__len__()
450+
442451
def _iterable(self):
443452
return self
444453

tests/test_fluent_iterable.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
from operator import mul
33
from typing import Iterable, Any, Callable, Mapping
4+
from unittest.mock import Mock
45

56
import pytest
67

@@ -762,6 +763,26 @@ def test_iterable_supports_contains():
762763
assert 9 not in f
763764

764765

766+
######
767+
# Regression tests
768+
######
769+
@pytest.mark.parametrize(
770+
"elements,operation",
771+
[
772+
([1, 2], lambda f: f.to_list()),
773+
([1, 2], lambda f: f.to_set()),
774+
([1, 2], lambda f: f.to_frozenset()),
775+
([1, 2], lambda f: f.to_tuple()),
776+
([("a", 1), ("b", 2)], lambda f: f.to_dict()),
777+
],
778+
)
779+
def test_does_not_invoke_map_transform_unnecessarily(elements, operation: Callable[[FluentIterable], Any]):
780+
transform = Mock(side_effect=lambda x: x)
781+
f = fluent(elements).map(transform)
782+
operation(f)
783+
assert transform.call_count == len(elements)
784+
785+
765786
######
766787
# Helper functions
767788
######

0 commit comments

Comments
 (0)