@@ -163,17 +163,29 @@ def _get_gpu_id(task_context: TaskContext) -> int:
163163 return gpu_id
164164
165165
166+ # When changing default rmm memory resources we retain the old ones
167+ # in this global array singleton so that any (C++) allocations using them can
168+ # invoke the corresponding deallocate methods. They will get cleaned up only when
169+ # the process exits. This avoids a segfault in the case of creating a new
170+ # SAM resource with a smaller headroom.
171+ _old_memory_resources = []
172+
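173+ # Keep track of the last SAM headroom size that was configured, so we can check
173+ # whether the requested headroom changed and a new SAM memory resource is needed.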
174+ _last_sam_headroom_size = None
175+
176+
166177def _configure_memory_resource (
167178 uvm_enabled : bool = False ,
168179 sam_enabled : bool = False ,
169180 sam_headroom : Optional [int ] = None ,
170- force_sam_headroom : bool = False ,
171181) -> None :
172182 import cupy as cp
173183 import rmm
174184 from cuda .bindings import runtime
175185 from rmm .allocators .cupy import rmm_cupy_allocator
176186
187+ global _last_sam_headroom_size
188+
177189 _SYSTEM_MEMORY_SUPPORTED = rmm ._cuda .gpu .getDeviceAttribute (
178190 runtime .cudaDeviceAttr .cudaDevAttrPageableMemoryAccess ,
179191 rmm ._cuda .gpu .getDevice (),
@@ -193,19 +205,24 @@ def _configure_memory_resource(
193205 if not type (rmm .mr .get_current_device_resource ()) == type (
194206 rmm .mr .SystemMemoryResource ()
195207 ):
208+ _old_memory_resources .append (rmm .mr .get_current_device_resource ())
209+ _last_sam_headroom_size = None
196210 mr = rmm .mr .SystemMemoryResource ()
197211 rmm .mr .set_current_device_resource (mr )
198212 elif sam_enabled and sam_headroom is not None :
199- if force_sam_headroom or not type (rmm .mr .get_current_device_resource ()) == type (
200- rmm .mr .SamHeadroomMemoryResource (headroom = sam_headroom )
201- ):
213+ if sam_headroom != _last_sam_headroom_size or not type (
214+ rmm .mr .get_current_device_resource ()
215+ ) == type (rmm .mr .SamHeadroomMemoryResource (headroom = sam_headroom )):
216+ _old_memory_resources .append (rmm .mr .get_current_device_resource ())
217+ _last_sam_headroom_size = sam_headroom
202218 mr = rmm .mr .SamHeadroomMemoryResource (headroom = sam_headroom )
203219 rmm .mr .set_current_device_resource (mr )
204220
205221 if uvm_enabled :
206222 if not type (rmm .mr .get_current_device_resource ()) == type (
207223 rmm .mr .ManagedMemoryResource ()
208224 ):
225+ _old_memory_resources .append (rmm .mr .get_current_device_resource ())
209226 rmm .mr .set_current_device_resource (rmm .mr .ManagedMemoryResource ())
210227
211228 if sam_enabled or uvm_enabled :
0 commit comments