Skip to content

Commit 7313155

Browse files
authored
Merge pull request #10 from mortacious/feature/reset_bar
Add a set() function to the ProgressBar class to allow resetting the bar to a specific value
2 parents 03a9266 + 1a71ba2 commit 7313155

File tree

7 files changed

+65
-17
lines changed

7 files changed

+65
-17
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .progress import ProgressBar
2+
from ._version import __version__

examples/example_nested_loops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sleep import usleep
2+
import numba as nb
3+
from numba_progress import ProgressBar
4+
5+
@nb.njit(nogil=True)
6+
def numba_sleeper(num_iterations, sleep_us, progress):
7+
for i in range(num_iterations):
8+
progress[0].update()
9+
for j in range(num_iterations):
10+
usleep(sleep_us)
11+
progress[1].update(1)
12+
# reset the second progress bar to 0
13+
progress[1].set(0)
14+
15+
16+
if __name__ == "__main__":
17+
num_iterations = 30
18+
sleep_time_us = 25_000
19+
with ProgressBar(total=num_iterations, ncols=80) as numba_progress1, ProgressBar(total=num_iterations, ncols=80) as numba_progress2:
20+
# note: progressbar object must be passed as a tuple (a list will not work due to different treatment in numba)
21+
numba_sleeper(num_iterations, sleep_time_us, (numba_progress1, numba_progress2))

examples/example_parallel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55

66
@nb.njit(nogil=True, parallel=True)
7-
def numba_parallel_sleeper(num_iterations, sleep_us, progress_hook):
7+
def numba_parallel_sleeper(num_iterations, sleep_us, progress_hook=None):
88
for i in nb.prange(num_iterations):
99
usleep(sleep_us)
10-
progress_hook.update(1)
10+
if progress_hook is not None:
11+
progress_hook.update(1)
1112

1213

1314
if __name__ == "__main__":

examples/example_sequential.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55

66
@nb.njit(nogil=True)
7-
def numba_sleeper(num_iterations, sleep_us, progress_hook):
7+
def numba_sleeper(num_iterations, sleep_us, progress_hook=None):
88
for i in range(num_iterations):
99
usleep(sleep_us)
10-
progress_hook.update(1)
10+
if progress_hook is not None:
11+
progress_hook.update(1)
1112

1213

1314
if __name__ == "__main__":

numba_progress/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0"
1+
__version__ = "1.1.0"

numba_progress/numba_atomic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from numba.extending import lower_builtin, type_callable
2222
from numba.np.arrayobj import basic_indexing, make_array, normalize_indices
2323

24-
__all__ = ["atomic_add", "atomic_sub", "atomic_max", "atomic_min"]
24+
__all__ = ["atomic_add", "atomic_sub", "atomic_max", "atomic_min", "atomic_xchg"]
2525

2626

2727
def atomic_rmw(context, builder, op, arrayty, val, ptr):
@@ -147,4 +147,10 @@ def atomic_min(ary, i, v):
147147
"""
148148
orig = ary[i]
149149
ary[i] = min(ary[i], v)
150+
return orig
151+
152+
@declare_atomic_array_op("xchg", "xchg", "xchg")
153+
def atomic_xchg(ary, i, v):
154+
orig = ary[i]
155+
ary[i] = v
150156
return orig

numba_progress/progress.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from threading import Thread, Event
88

99
from numba.extending import overload_method, typeof_impl, as_numba_type, models, register_model, \
10-
make_attribute_wrapper, overload_attribute, unbox, NativeValue, box
11-
from .numba_atomic import atomic_add
10+
make_attribute_wrapper, overload_attribute, unbox, NativeValue, box, lower_getattr, lower_setattr
11+
from .numba_atomic import atomic_add, atomic_xchg
1212
from numba import types
1313
from numba.core import cgutils
1414
from numba.core.boxing import unbox_array
@@ -92,18 +92,24 @@ def close(self):
9292
self._tqdm.close()
9393

9494
@property
95-
def value(self):
95+
def n(self):
9696
return self.hook[0]
97+
98+
def set(self, n=0):
99+
atomic_xchg(self.hook, 0, n)
100+
self._update_tqdm()
97101

98102
def update(self, n=1):
99103
atomic_add(self.hook, 0, n)
100104
self._update_tqdm()
101105

102106
def _update_tqdm(self):
103-
value = self.value
104-
diff = value - self._last_value
105-
self._last_value = value
106-
self._tqdm.update(diff)
107+
value = self.hook[0]
108+
#diff = value - self._last_value
109+
#self._last_value = value
110+
self._tqdm.n = value
111+
self._tqdm.refresh()
112+
#self._tqdm.update(diff)
107113

108114
def _update_function(self):
109115
"""Background thread for updating the progress bar.
@@ -151,11 +157,12 @@ def __init__(self, dmm, fe_type):
151157
make_attribute_wrapper(ProgressBarTypeImpl, 'hook', 'hook')
152158

153159

154-
@overload_attribute(ProgressBarTypeImpl, 'value')
160+
161+
@overload_attribute(ProgressBarTypeImpl, 'n')
155162
def get_value(progress_bar):
156-
def getter(progress_bar):
157-
return progress_bar.hook[0]
158-
return getter
163+
def getter(progress_bar):
164+
return progress_bar.hook[0]
165+
return getter
159166

160167

161168
@unbox(ProgressBarTypeImpl)
@@ -186,5 +193,15 @@ def _ol_update(self, n=1):
186193
def _update_impl(self, n=1):
187194
atomic_add(self.hook, 0, n)
188195
return _update_impl
196+
197+
@overload_method(ProgressBarTypeImpl, "set", jit_options={"nogil": True})
198+
def _ol_set(self, n=0):
199+
"""
200+
Numpy implementation of the update method.
201+
"""
202+
if isinstance(self, ProgressBarTypeImpl):
203+
def _set_impl(self, n=0):
204+
atomic_xchg(self.hook, 0, n)
205+
return _set_impl
189206

190207

0 commit comments

Comments
 (0)