Skip to content

Commit 4b03093

Browse files
authored
Merge pull request #109 from OpenBioSim/fix_thread_limit
Divide Sire threads between number of GPU workers
2 parents 9aefda4 + e60dd05 commit 4b03093

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

src/somd2/config/_config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def __init__(
130130
platform="auto",
131131
max_threads=None,
132132
max_gpus=None,
133+
max_sire_threads=None,
133134
opencl_platform_index=0,
134135
oversubscription_factor=1,
135136
replica_exchange=False,
@@ -346,6 +347,11 @@ def __init__(
346347
Maximum number of GPUs to use for simulation (Default None, uses all available.)
347348
Does nothing if platform is set to CPU.
348349
350+
max_sire_threads: int
351+
Maximum number of CPU threads to use within Sire (e.g. for I/O operations).
352+
(Default None, divides the total available threads between the number of
353+
GPUs multiplied by the oversubscription factor.)
354+
349355
opencl_platform_index: int
350356
The OpenCL platform index to use when multiple OpenCL implementations are
351357
available on the system.
@@ -529,6 +535,7 @@ def __init__(
529535
self.platform = platform
530536
self.max_threads = max_threads
531537
self.max_gpus = max_gpus
538+
self.max_sire_threads = max_sire_threads
532539
self.opencl_platform_index = opencl_platform_index
533540
self.oversubscription_factor = oversubscription_factor
534541
self.replica_exchange = replica_exchange
@@ -1552,6 +1559,20 @@ def max_gpus(self, max_gpus):
15521559
"CPU platform requested but max_gpus set - ignoring max_gpus"
15531560
)
15541561

1562+
@property
def max_sire_threads(self):
    """
    Maximum number of CPU threads to use within Sire, or None when no
    explicit limit has been configured.
    """
    return self._max_sire_threads

@max_sire_threads.setter
def max_sire_threads(self, max_sire_threads):
    """
    Set the maximum number of CPU threads to use within Sire.

    Parameters
    ----------

    max_sire_threads: int or None
        Any value that ``int()`` can convert. None clears the limit.

    Raises
    ------

    ValueError
        If the value cannot be converted to an integer.
    """
    if max_sire_threads is None:
        self._max_sire_threads = None
        return

    # Only trap the errors that int() raises for bad input. A bare
    # 'except' would also swallow KeyboardInterrupt and SystemExit.
    try:
        value = int(max_sire_threads)
    except (TypeError, ValueError, OverflowError) as err:
        # Chain the original error so the root cause stays visible.
        raise ValueError("'max_sire_threads' must be of type 'int'") from err

    self._max_sire_threads = value
1575+
15551576
@property
15561577
def opencl_platform_index(self):
15571578
return self._opencl_platform_index

src/somd2/runner/_base.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,43 @@ def __init__(self, system, config):
752752
else:
753753
self._gcmc_kwargs = None
754754

755+
# Limit the number of CPU threads available to Sire when running in parallel.
756+
if self._is_gpu:
757+
# First get the total number of threads that are available to Sire.
758+
total_threads = _sr.legacy.Base.get_max_num_threads()
759+
760+
# Get the number of GPU devices.
761+
devices = self._get_gpu_devices(
762+
self._config.platform,
763+
log=False,
764+
)
765+
766+
# Work out the number of GPU workers.
767+
num_gpu_workers = len(devices) * self._config.oversubscription_factor
768+
769+
# Adjust based on the maximum number of GPUs.
770+
if self._config.max_gpus is not None:
771+
num_gpu_workers = min(
772+
self._config.max_gpus * self._config.oversubscription_factor,
773+
num_gpu_workers,
774+
)
775+
776+
# Divide the threads by the number of GPUs and the oversubscription factor.
777+
sire_threads = max(1, total_threads // num_gpu_workers)
778+
779+
if self._config.max_sire_threads is not None:
780+
if self._config.max_sire_threads > sire_threads:
781+
_logger.warning(
782+
f"Requested 'max_sire_threads' of {self._config.max_sire_threads} exceeds "
783+
f"the calculated maximum of {sire_threads}"
784+
)
785+
sire_threads = self._config.max_sire_threads
786+
787+
_logger.info(f"Setting maximum Sire CPU threads to {sire_threads}")
788+
789+
# Update the maximum number of threads.
790+
_sr.legacy.Base.set_max_num_threads(sire_threads)
791+
755792
def _check_space(self):
756793
"""
757794
Check if the system has a periodic space.
@@ -1423,7 +1460,7 @@ def _systems_are_same(system0, system1, num_gcmc_waters=0):
14231460
return True, None
14241461

14251462
@staticmethod
1426-
def _get_gpu_devices(platform, oversubscription_factor=1):
1463+
def _get_gpu_devices(platform, oversubscription_factor=1, log=True):
14271464
"""
14281465
Get list of available GPUs from CUDA_VISIBLE_DEVICES,
14291466
OPENCL_VISIBLE_DEVICES, or HIP_VISIBLE_DEVICES.
@@ -1437,6 +1474,9 @@ def _get_gpu_devices(platform, oversubscription_factor=1):
14371474
oversubscription_factor: int
14381475
The number of concurrent workers per GPU. Default is 1.
14391476
1477+
log: bool
1478+
Whether to log the available devices. Default is True.
1479+
14401480
Returns
14411481
--------
14421482
@@ -1459,23 +1499,30 @@ def _get_gpu_devices(platform, oversubscription_factor=1):
14591499
raise ValueError("CUDA_VISIBLE_DEVICES not set")
14601500
else:
14611501
available_devices = _os.environ.get("CUDA_VISIBLE_DEVICES").split(",")
1462-
_logger.info(f"CUDA_VISIBLE_DEVICES set to {available_devices}")
1502+
if log:
1503+
_logger.info(f"CUDA_VISIBLE_DEVICES set to {available_devices}")
14631504
elif platform == "opencl":
14641505
if _os.environ.get("OPENCL_VISIBLE_DEVICES") is None:
14651506
raise ValueError("OPENCL_VISIBLE_DEVICES not set")
14661507
else:
14671508
available_devices = _os.environ.get("OPENCL_VISIBLE_DEVICES").split(",")
1468-
_logger.info(f"OPENCL_VISIBLE_DEVICES set to {available_devices}")
1509+
if log:
1510+
_logger.info(f"OPENCL_VISIBLE_DEVICES set to {available_devices}")
14691511
elif platform == "hip":
14701512
if _os.environ.get("HIP_VISIBLE_DEVICES") is None:
14711513
raise ValueError("HIP_VISIBLE_DEVICES not set")
14721514
else:
14731515
available_devices = _os.environ.get("HIP_VISIBLE_DEVICES").split(",")
1474-
_logger.info(f"HIP_VISIBLE_DEVICES set to {available_devices}")
1516+
if log:
1517+
_logger.info(f"HIP_VISIBLE_DEVICES set to {available_devices}")
14751518

14761519
num_gpus = len(available_devices)
1477-
_logger.info(f"Number of GPUs available: {num_gpus}")
1478-
_logger.info(f"Number of concurrent workers per GPU: {oversubscription_factor}")
1520+
1521+
if log:
1522+
_logger.info(f"Number of GPUs available: {num_gpus}")
1523+
_logger.info(
1524+
f"Number of concurrent workers per GPU: {oversubscription_factor}"
1525+
)
14791526

14801527
return available_devices
14811528

0 commit comments

Comments
 (0)