Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit a2aade6

Browse files
authored
ability to filter by total gpu memory on machine (#389)
1 parent f3a392c commit a2aade6

File tree

2 files changed

+210
-9
lines changed
  • crates
    • shared/src/models
    • validator/src/validators/synthetic_data

2 files changed

+210
-9
lines changed

crates/shared/src/models/node.rs

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@ pub struct ComputeRequirements {
4040
pub struct GpuRequirements {
4141
pub count: Option<u32>,
4242
pub model: Option<String>,
43+
// per Card
4344
pub memory_mb: Option<u32>,
4445
pub memory_mb_min: Option<u32>,
4546
pub memory_mb_max: Option<u32>,
47+
// System wide GPU memory per gpu type
48+
pub total_memory_min: Option<u32>,
49+
pub total_memory_max: Option<u32>,
4650
pub indices: Option<Vec<u32>>,
4751
}
4852

@@ -262,6 +266,47 @@ impl FromStr for ComputeRequirements {
262266
anyhow!("Invalid gpu:memory_mb_max value '{}': {}", value, e)
263267
})?);
264268
}
269+
// --- Total GPU Memory Specifications ---
270+
"gpu:total_memory_min" => {
271+
if !gpu_spec_started {
272+
gpu_spec_started = true;
273+
}
274+
if current_gpu_spec.total_memory_max.is_some()
275+
&& current_gpu_spec.total_memory_max.unwrap()
276+
< value.parse::<u32>().unwrap()
277+
{
278+
return Err(anyhow!(
279+
"Invalid gpu:total_memory_min value '{}': {}",
280+
value,
281+
"min value is greater than max value"
282+
));
283+
}
284+
285+
current_gpu_spec.total_memory_min =
286+
Some(value.parse::<u32>().map_err(|e| {
287+
anyhow!("Invalid gpu:total_memory_min value '{}': {}", value, e)
288+
})?);
289+
}
290+
"gpu:total_memory_max" => {
291+
if !gpu_spec_started {
292+
gpu_spec_started = true;
293+
}
294+
if current_gpu_spec.total_memory_min.is_some()
295+
&& current_gpu_spec.total_memory_min.unwrap()
296+
> value.parse::<u32>().unwrap()
297+
{
298+
return Err(anyhow!(
299+
"Invalid gpu:total_memory_max value '{}': {}",
300+
value,
301+
"max value is less than min value"
302+
));
303+
}
304+
305+
current_gpu_spec.total_memory_max =
306+
Some(value.parse::<u32>().map_err(|e| {
307+
anyhow!("Invalid gpu:total_memory_max value '{}': {}", value, e)
308+
})?);
309+
}
265310
// --- CPU Specifications ---
266311
"cpu:cores" => {
267312
let mut cpu = requirements.cpu.take().unwrap_or_default();
@@ -298,7 +343,9 @@ impl FromStr for ComputeRequirements {
298343
|| current_gpu_spec.model.is_some()
299344
|| current_gpu_spec.memory_mb.is_some()
300345
|| current_gpu_spec.memory_mb_min.is_some()
301-
|| current_gpu_spec.memory_mb_max.is_some())
346+
|| current_gpu_spec.memory_mb_max.is_some()
347+
|| current_gpu_spec.total_memory_min.is_some()
348+
|| current_gpu_spec.total_memory_max.is_some())
302349
{
303350
requirements.gpu.push(current_gpu_spec);
304351
}
@@ -399,8 +446,14 @@ impl GpuSpecs {
399446
.split(',')
400447
.map(|m| m.trim().to_lowercase().replace(' ', "_"))
401448
.any(|normalized_req| {
449+
// Try both with and without underscores for flexible matching
450+
let spec_no_underscore = normalized_spec.replace('_', "");
451+
let req_no_underscore = normalized_req.replace('_', "");
452+
402453
normalized_spec.contains(&normalized_req)
403454
|| normalized_req.contains(&normalized_spec)
455+
|| spec_no_underscore.contains(&req_no_underscore)
456+
|| req_no_underscore.contains(&spec_no_underscore)
404457
})
405458
})) {
406459
return false;
@@ -426,6 +479,25 @@ impl GpuSpecs {
426479
}
427480
}
428481

482+
// Check total memory requirements (count * memory_mb)
483+
if let (Some(req_total_min), Some(gpu_count), Some(gpu_memory)) =
484+
(requirement.total_memory_min, self.count, self.memory_mb)
485+
{
486+
let total_memory = gpu_count * gpu_memory;
487+
if total_memory < req_total_min {
488+
return false;
489+
}
490+
}
491+
492+
if let (Some(req_total_max), Some(gpu_count), Some(gpu_memory)) =
493+
(requirement.total_memory_max, self.count, self.memory_mb)
494+
{
495+
let total_memory = gpu_count * gpu_memory;
496+
if total_memory > req_total_max {
497+
return false;
498+
}
499+
}
500+
429501
// All checked fields meet the requirement
430502
true
431503
}
@@ -969,4 +1041,141 @@ mod tests {
9691041
let req_str = "gpu:memory_mb_min=20000;gpu:memory_mb_max=40000";
9701042
assert!(ComputeRequirements::from_str(req_str).is_ok());
9711043
}
1044+
1045+
#[test]
1046+
fn test_total_memory_parsing() {
1047+
let req_str =
1048+
"gpu:count=4;gpu:model=A100;gpu:total_memory_min=160000;gpu:total_memory_max=320000";
1049+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1050+
1051+
assert_eq!(requirements.gpu.len(), 1);
1052+
let gpu_req = &requirements.gpu[0];
1053+
assert_eq!(gpu_req.count, Some(4));
1054+
assert_eq!(gpu_req.model, Some("A100".to_string()));
1055+
assert_eq!(gpu_req.total_memory_min, Some(160000));
1056+
assert_eq!(gpu_req.total_memory_max, Some(320000));
1057+
}
1058+
1059+
#[test]
1060+
fn test_total_memory_validation() {
1061+
// Test that total_memory_min > total_memory_max is rejected
1062+
let req_str = "gpu:total_memory_min=400000;gpu:total_memory_max=200000";
1063+
assert!(ComputeRequirements::from_str(req_str).is_err());
1064+
1065+
// Test that total_memory_max > total_memory_min is accepted
1066+
let req_str = "gpu:total_memory_min=200000;gpu:total_memory_max=400000";
1067+
assert!(ComputeRequirements::from_str(req_str).is_ok());
1068+
}
1069+
1070+
#[test]
1071+
fn test_meets_total_memory_requirements() {
1072+
// Node has 4x A100 with 40GB each = 160GB total
1073+
let specs =
1074+
create_compute_specs(Some(4), Some("NVIDIA A100"), Some(40000), None, None, None);
1075+
1076+
// Test case 1: Total memory requirement within range
1077+
let req_str =
1078+
"gpu:count=4;gpu:model=A100;gpu:total_memory_min=120000;gpu:total_memory_max=200000";
1079+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1080+
assert!(
1081+
specs.meets(&requirements),
1082+
"Should meet total memory requirements within range"
1083+
);
1084+
1085+
// Test case 2: Total memory requirement too high (min)
1086+
let req_str = "gpu:count=4;gpu:model=A100;gpu:total_memory_min=200000";
1087+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1088+
assert!(
1089+
!specs.meets(&requirements),
1090+
"Should not meet requirements as total memory is below min"
1091+
);
1092+
1093+
// Test case 3: Total memory requirement too low (max)
1094+
let req_str = "gpu:count=4;gpu:model=A100;gpu:total_memory_max=120000";
1095+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1096+
assert!(
1097+
!specs.meets(&requirements),
1098+
"Should not meet requirements as total memory is above max"
1099+
);
1100+
1101+
// Test case 4: Exact total memory match
1102+
let req_str =
1103+
"gpu:count=4;gpu:model=A100;gpu:total_memory_min=160000;gpu:total_memory_max=160000";
1104+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1105+
assert!(
1106+
specs.meets(&requirements),
1107+
"Should meet exact total memory requirements"
1108+
);
1109+
}
1110+
1111+
#[test]
1112+
fn test_meets_total_memory_missing_fields() {
1113+
// Node has count but no memory specified
1114+
let specs_no_memory = ComputeSpecs {
1115+
gpu: Some(GpuSpecs {
1116+
count: Some(4),
1117+
model: Some("A100".to_string()),
1118+
memory_mb: None, // No memory specified
1119+
..Default::default()
1120+
}),
1121+
..Default::default()
1122+
};
1123+
1124+
// Node has memory but no count specified
1125+
let specs_no_count = ComputeSpecs {
1126+
gpu: Some(GpuSpecs {
1127+
count: None, // No count specified
1128+
model: Some("A100".to_string()),
1129+
memory_mb: Some(40000),
1130+
..Default::default()
1131+
}),
1132+
..Default::default()
1133+
};
1134+
1135+
let req_str = "gpu:model=A100;gpu:total_memory_min=120000";
1136+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1137+
1138+
// Both should pass because total memory check is skipped when count or memory is missing
1139+
assert!(
1140+
specs_no_memory.meets(&requirements),
1141+
"Should pass when memory is not specified"
1142+
);
1143+
assert!(
1144+
specs_no_count.meets(&requirements),
1145+
"Should pass when count is not specified"
1146+
);
1147+
}
1148+
1149+
#[test]
1150+
fn test_meets_total_memory_or_logic() {
1151+
// Node has 8x H100 with 80GB each = 640GB total
1152+
let specs =
1153+
create_compute_specs(Some(8), Some("NVIDIA H100"), Some(80000), None, None, None);
1154+
1155+
// Requirements: (4x A100 with 160GB total) OR (8x H100 with 500GB+ total)
1156+
let req_str = "gpu:count=4;gpu:model=A100;gpu:total_memory_min=160000;gpu:count=8;gpu:model=H100;gpu:total_memory_min=500000";
1157+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1158+
1159+
assert_eq!(requirements.gpu.len(), 2);
1160+
assert!(
1161+
specs.meets(&requirements),
1162+
"Should meet the second GPU option with total memory requirement"
1163+
);
1164+
}
1165+
1166+
#[test]
1167+
fn test_complex_total_memory_scenario() {
1168+
// Node has 2x RTX 4090 with 24GB each = 48GB total
1169+
let specs = create_compute_specs(Some(2), Some("RTX 4090"), Some(24000), None, None, None);
1170+
1171+
// Requirements allow multiple options with different total memory requirements
1172+
let req_str = "gpu:count=8;gpu:model=H100;gpu:total_memory_min=600000;gpu:count=4;gpu:model=A100;gpu:total_memory_min=160000;gpu:count=2;gpu:model=RTX4090;gpu:total_memory_min=40000;gpu:total_memory_max=60000";
1173+
let requirements = ComputeRequirements::from_str(req_str).unwrap();
1174+
1175+
assert_eq!(requirements.gpu.len(), 3);
1176+
assert!(
1177+
specs.meets(&requirements),
1178+
"Should meet the third GPU option with total memory range"
1179+
);
1180+
}
9721181
}

crates/validator/src/validators/synthetic_data/mod.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,12 +1395,4 @@ mod tests {
13951395

13961396
Ok(())
13971397
}
1398-
1399-
#[tokio::test]
1400-
async fn test_group_information_from_prod_string() -> Result<(), Error> {
1401-
let file =
1402-
"Qwen/Qwen3-14B/PrimeIntellect/INTELLECT-2-RL-Dataset/1-d4eb155339fc64e-1-20-0.parquet";
1403-
let group_info = GroupInformation::from_str(file)?;
1404-
Ok(())
1405-
}
14061398
}

0 commit comments

Comments
 (0)