@@ -752,6 +752,43 @@ def __init__(self, system, config):
752752 else :
753753 self ._gcmc_kwargs = None
754754
755+ # Limit the number of CPU threads available to Sire when running in parallel.
756+ if self ._is_gpu :
757+ # First get the total number of threads that are available to Sire.
758+ total_threads = _sr .legacy .Base .get_max_num_threads ()
759+
760+ # Get the number of GPU devices.
761+ devices = self ._get_gpu_devices (
762+ self ._config .platform ,
763+ log = False ,
764+ )
765+
766+ # Work out the number of GPU workers.
767+ num_gpu_workers = len (devices ) * self ._config .oversubscription_factor
768+
769+ # Adjust based on the maximum number of GPUs.
770+ if self ._config .max_gpus is not None :
771+ num_gpu_workers = min (
772+ self ._config .max_gpus * self ._config .oversubscription_factor ,
773+ num_gpu_workers ,
774+ )
775+
776+ # Divide the threads by the number of GPUs and oversubscribe factor.
777+ sire_threads = max (1 , total_threads // num_gpu_workers )
778+
779+ if self ._config .max_sire_threads is not None :
780+ if self ._config .max_sire_threads > sire_threads :
781+ _logger .warning (
782+ f"Requested 'max_sire_threads' of { self ._config .max_sire_threads } exceeds "
783+ f"the calculated maximum of { sire_threads } "
784+ )
785+ sire_threads = self ._config .max_sire_threads
786+
787+ _logger .info (f"Setting maximum Sire CPU threads to { sire_threads } " )
788+
789+ # Update the maximum number of threads.
790+ _sr .legacy .Base .set_max_num_threads (sire_threads )
791+
755792 def _check_space (self ):
756793 """
757794 Check if the system has a periodic space.
@@ -1423,7 +1460,7 @@ def _systems_are_same(system0, system1, num_gcmc_waters=0):
14231460 return True , None
14241461
14251462 @staticmethod
1426- def _get_gpu_devices (platform , oversubscription_factor = 1 ):
1463+ def _get_gpu_devices (platform , oversubscription_factor = 1 , log = True ):
14271464 """
14281465 Get list of available GPUs from CUDA_VISIBLE_DEVICES,
14291466 OPENCL_VISIBLE_DEVICES, or HIP_VISIBLE_DEVICES.
@@ -1437,6 +1474,9 @@ def _get_gpu_devices(platform, oversubscription_factor=1):
14371474 oversubscription_factor: int
14381475 The number of concurrent workers per GPU. Default is 1.
14391476
1477+ log: bool
1478+ Whether to log the available devices. Default is True.
1479+
14401480 Returns
14411481 --------
14421482
@@ -1459,23 +1499,30 @@ def _get_gpu_devices(platform, oversubscription_factor=1):
14591499 raise ValueError ("CUDA_VISIBLE_DEVICES not set" )
14601500 else :
14611501 available_devices = _os .environ .get ("CUDA_VISIBLE_DEVICES" ).split ("," )
1462- _logger .info (f"CUDA_VISIBLE_DEVICES set to { available_devices } " )
1502+ if log :
1503+ _logger .info (f"CUDA_VISIBLE_DEVICES set to { available_devices } " )
14631504 elif platform == "opencl" :
14641505 if _os .environ .get ("OPENCL_VISIBLE_DEVICES" ) is None :
14651506 raise ValueError ("OPENCL_VISIBLE_DEVICES not set" )
14661507 else :
14671508 available_devices = _os .environ .get ("OPENCL_VISIBLE_DEVICES" ).split ("," )
1468- _logger .info (f"OPENCL_VISIBLE_DEVICES set to { available_devices } " )
1509+ if log :
1510+ _logger .info (f"OPENCL_VISIBLE_DEVICES set to { available_devices } " )
14691511 elif platform == "hip" :
14701512 if _os .environ .get ("HIP_VISIBLE_DEVICES" ) is None :
14711513 raise ValueError ("HIP_VISIBLE_DEVICES not set" )
14721514 else :
14731515 available_devices = _os .environ .get ("HIP_VISIBLE_DEVICES" ).split ("," )
1474- _logger .info (f"HIP_VISIBLE_DEVICES set to { available_devices } " )
1516+ if log :
1517+ _logger .info (f"HIP_VISIBLE_DEVICES set to { available_devices } " )
14751518
14761519 num_gpus = len (available_devices )
1477- _logger .info (f"Number of GPUs available: { num_gpus } " )
1478- _logger .info (f"Number of concurrent workers per GPU: { oversubscription_factor } " )
1520+
1521+ if log :
1522+ _logger .info (f"Number of GPUs available: { num_gpus } " )
1523+ _logger .info (
1524+ f"Number of concurrent workers per GPU: { oversubscription_factor } "
1525+ )
14791526
14801527 return available_devices
14811528
0 commit comments