Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 44 additions & 38 deletions src/renderer/components/Team/ProviderDetailsModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,37 +28,40 @@ interface ProviderDetailsModalProps {
providerId?: string;
}

// Default configurations for each provider type
const ACCELERATOR_OPTIONS = ['AppleSilicon', 'NVIDIA', 'AMD', 'cpu'];

// Default configurations for each provider type (excluding supported_accelerators,
// which is managed via the dedicated UI field).
const DEFAULT_CONFIGS = {
skypilot: `{
"server_url": "<Your SkyPilot server URL e.g. http://localhost:46580>",
"default_env_vars": {
"SKYPILOT_USER_ID": "<Your SkyPilot user ID>",
"SKYPILOT_USER": "<Your SkyPilot user name>"
},
"default_entrypoint_command": "",
"supported_accelerators": ["NVIDIA"]
"default_entrypoint_command": ""
}`,
slurm: `{
"mode": "ssh",
"ssh_host": "<Machine IP for the SLURM login node>",
"ssh_user": "<Your SLURM user ID - all jobs will run as this user>",
"ssh_key_path": "",
"ssh_port": 22,
"supported_accelerators": ["NVIDIA"]
"ssh_port": 22
}`,
runpod: `{
"api_key": "<Your Runpod API key>",
"api_base_url": "https://rest.runpod.io/v1",
"supported_accelerators": ["NVIDIA"]
}`,
local: `{
"supported_accelerators": ["AppleSilicon", "cpu"]
"api_base_url": "https://rest.runpod.io/v1"
}`,
local: `{}`,
} as const;

const DEFAULT_SUPPORTED_ACCELERATORS: Record<string, string[]> = {
skypilot: ['NVIDIA'],
slurm: ['NVIDIA'],
runpod: ['NVIDIA'],
local: ['AppleSilicon', 'cpu'],
};

const ACCELERATOR_OPTIONS = ['AppleSilicon', 'NVIDIA', 'AMD', 'cpu'];

export default function ProviderDetailsModal({
open,
onClose,
Expand Down Expand Up @@ -148,21 +151,22 @@ export default function ProviderDetailsModal({
setName(providerData.name || '');
setType(providerData.type || '');
// Config is an object, stringify it for display in textarea
const configObj =
const rawConfigObj =
typeof providerData.config === 'string'
? JSON.parse(providerData.config || '{}')
: providerData.config || {};

// Parse SLURM-specific fields if this is a SLURM provider
if (providerData.type === 'slurm') {
parseSlurmConfig(configObj);
// Extract supported_accelerators into dedicated state, but do not show it in raw JSON.
if (rawConfigObj.supported_accelerators) {
setSupportedAccelerators(rawConfigObj.supported_accelerators);
delete rawConfigObj.supported_accelerators;
}

if (configObj.supported_accelerators) {
setSupportedAccelerators(configObj.supported_accelerators);
// Parse SLURM-specific fields if this is a SLURM provider
if (providerData.type === 'slurm') {
parseSlurmConfig(rawConfigObj);
}

setConfig(JSON.stringify(configObj, null, 2));
setConfig(JSON.stringify(rawConfigObj, null, 2));
} else if (!providerId) {
// Reset form when in "add" mode (no providerId)
setName('');
Expand Down Expand Up @@ -203,18 +207,22 @@ export default function ProviderDetailsModal({
DEFAULT_CONFIGS[type as keyof typeof DEFAULT_CONFIGS];
setConfig(defaultConfig);

try {
const configObj = JSON.parse(defaultConfig);
if (configObj.supported_accelerators) {
setSupportedAccelerators(configObj.supported_accelerators);
}
// Initialize default supported accelerators per provider type, but keep them
// out of the raw JSON configuration.
if (DEFAULT_SUPPORTED_ACCELERATORS[type]) {
setSupportedAccelerators(DEFAULT_SUPPORTED_ACCELERATORS[type]);
} else {
setSupportedAccelerators([]);
}

// Parse SLURM defaults
if (type === 'slurm') {
// Parse SLURM defaults from the JSON template
if (type === 'slurm') {
try {
const configObj = JSON.parse(defaultConfig);
parseSlurmConfig(configObj);
} catch (e) {
// Ignore parse errors
}
} catch (e) {
// Ignore parse errors
}
}
}, [type, providerId]);
Expand All @@ -226,14 +234,6 @@ export default function ProviderDetailsModal({
if (type === 'slurm') {
const configObj = buildSlurmConfig();
setConfig(JSON.stringify(configObj, null, 2));
} else if (type && type in DEFAULT_CONFIGS) {
try {
const configObj = JSON.parse(config);
configObj.supported_accelerators = supportedAccelerators;
setConfig(JSON.stringify(configObj, null, 2));
} catch (e) {
// If JSON is invalid (e.g. while typing), don't update
}
}
}
}, [
Expand Down Expand Up @@ -282,6 +282,12 @@ export default function ProviderDetailsModal({
let parsedConfig: any;
if (type === 'slurm') {
parsedConfig = buildSlurmConfig();
} else if (type === 'local') {
// Local providers are configured via supported accelerators only
parsedConfig = {};
if (supportedAccelerators.length > 0) {
parsedConfig.supported_accelerators = supportedAccelerators;
}
} else {
// The API expects an object for config, not a JSON string
parsedConfig = typeof config === 'string' ? JSON.parse(config) : config;
Expand Down Expand Up @@ -542,7 +548,7 @@ export default function ProviderDetailsModal({
)}

{/* Generic JSON config for non-SLURM providers or advanced editing */}
{type !== 'slurm' && (
{type !== 'slurm' && type !== 'local' && (
<FormControl sx={{ mt: 1 }}>
<FormLabel>Configuration</FormLabel>
<Textarea
Expand Down