|
7 | 7 | from threading import Thread, Event |
8 | 8 |
|
9 | 9 | from numba.extending import overload_method, typeof_impl, as_numba_type, models, register_model, \ |
10 | | - make_attribute_wrapper, overload_attribute, unbox, NativeValue, box |
11 | | -from .numba_atomic import atomic_add |
| 10 | + make_attribute_wrapper, overload_attribute, unbox, NativeValue, box, lower_getattr, lower_setattr |
| 11 | +from .numba_atomic import atomic_add, atomic_xchg |
12 | 12 | from numba import types |
13 | 13 | from numba.core import cgutils |
14 | 14 | from numba.core.boxing import unbox_array |
@@ -92,18 +92,24 @@ def close(self): |
92 | 92 | self._tqdm.close() |
93 | 93 |
|
94 | 94 | @property |
95 | | - def value(self): |
| 95 | + def n(self): |
96 | 96 | return self.hook[0] |
| 97 | + |
| 98 | + def set(self, n=0): |
| 99 | + atomic_xchg(self.hook, 0, n) |
| 100 | + self._update_tqdm() |
97 | 101 |
|
98 | 102 | def update(self, n=1): |
99 | 103 | atomic_add(self.hook, 0, n) |
100 | 104 | self._update_tqdm() |
101 | 105 |
|
102 | 106 | def _update_tqdm(self): |
103 | | - value = self.value |
104 | | - diff = value - self._last_value |
105 | | - self._last_value = value |
106 | | - self._tqdm.update(diff) |
| 107 | + value = self.hook[0] |
| 108 | + #diff = value - self._last_value |
| 109 | + #self._last_value = value |
| 110 | + self._tqdm.n = value |
| 111 | + self._tqdm.refresh() |
| 112 | + #self._tqdm.update(diff) |
107 | 113 |
|
108 | 114 | def _update_function(self): |
109 | 115 | """Background thread for updating the progress bar. |
@@ -151,11 +157,12 @@ def __init__(self, dmm, fe_type): |
151 | 157 | make_attribute_wrapper(ProgressBarTypeImpl, 'hook', 'hook') |
152 | 158 |
|
153 | 159 |
|
154 | | -@overload_attribute(ProgressBarTypeImpl, 'value') |
| 160 | + |
| 161 | +@overload_attribute(ProgressBarTypeImpl, 'n') |
155 | 162 | def get_value(progress_bar): |
156 | | - def getter(progress_bar): |
157 | | - return progress_bar.hook[0] |
158 | | - return getter |
| 163 | + def getter(progress_bar): |
| 164 | + return progress_bar.hook[0] |
| 165 | + return getter |
159 | 166 |
|
160 | 167 |
|
161 | 168 | @unbox(ProgressBarTypeImpl) |
@@ -186,5 +193,15 @@ def _ol_update(self, n=1): |
186 | 193 | def _update_impl(self, n=1): |
187 | 194 | atomic_add(self.hook, 0, n) |
188 | 195 | return _update_impl |
| 196 | + |
| 197 | +@overload_method(ProgressBarTypeImpl, "set", jit_options={"nogil": True}) |
| 198 | +def _ol_set(self, n=0): |
| 199 | + """ |
| 200 | + Numpy implementation of the update method. |
| 201 | + """ |
| 202 | + if isinstance(self, ProgressBarTypeImpl): |
| 203 | + def _set_impl(self, n=0): |
| 204 | + atomic_xchg(self.hook, 0, n) |
| 205 | + return _set_impl |
189 | 206 |
|
190 | 207 |
|
0 commit comments