diff --git a/src/install/installationManager.ts b/src/install/installationManager.ts index 1f951d751..754dfbc80 100644 --- a/src/install/installationManager.ts +++ b/src/install/installationManager.ts @@ -11,7 +11,7 @@ import { useAppState } from '../main-process/appState'; import type { AppWindow } from '../main-process/appWindow'; import { ComfyInstallation } from '../main-process/comfyInstallation'; import { createInstallStageInfo } from '../main-process/installStages'; -import type { InstallOptions, InstallValidation } from '../preload'; +import type { InstallOptions, InstallValidation, TorchDeviceType } from '../preload'; import { CmCli } from '../services/cmCli'; import { captureSentryException } from '../services/sentry'; import { type HasTelemetry, ITelemetry, trackEvent } from '../services/telemetry'; @@ -23,7 +23,7 @@ import { InstallWizard } from './installWizard'; import { Troubleshooting } from './troubleshooting'; const execAsync = promisify(exec); -const NVIDIA_DRIVER_MIN_VERSION = '580'; +export const NVIDIA_DRIVER_MIN_VERSION = '580'; /** * Extracts the NVIDIA driver version from `nvidia-smi` output. @@ -47,6 +47,36 @@ export function isNvidiaDriverBelowMinimum( return compareVersions(driverVersion, minimumVersion) < 0; } +type NvidiaDriverWarningContext = { + platform: NodeJS.Platform; + selectedDevice?: TorchDeviceType; + suppressWarningFor?: string; + driverVersion?: string; + minimumVersion?: string; +}; + +/** + * Determines whether the NVIDIA driver warning should be shown. + * @param context The driver warning context. + * @return `true` when the warning should be shown. + */ +export function shouldWarnAboutNvidiaDriver(context: NvidiaDriverWarningContext): boolean { + const { + platform, + selectedDevice, + suppressWarningFor, + driverVersion, + minimumVersion = NVIDIA_DRIVER_MIN_VERSION, + } = context; + + if (platform !== 'win32') return false; + if (selectedDevice !== 'nvidia') return false; + if (suppressWarningFor === minimumVersion) return false; + if (!driverVersion) return false; + + return isNvidiaDriverBelowMinimum(driverVersion, minimumVersion); +} + /** High-level / UI control over the installation of ComfyUI server. */ export class InstallationManager implements HasTelemetry { constructor( @@ -398,19 +428,36 @@ export class InstallationManager implements HasTelemetry { if (process.platform !== 'win32') return; if (installation.virtualEnvironment.selectedDevice !== 'nvidia') return; + const config = useDesktopConfig(); + const suppressWarningFor = config.get('suppressNvidiaDriverWarningFor'); + if (suppressWarningFor === NVIDIA_DRIVER_MIN_VERSION) return; + const driverVersion = (await this.getNvidiaDriverVersionFromSmi()) ?? (await this.getNvidiaDriverVersionFromSmiFallback()); - if (!driverVersion) return; - - if (!isNvidiaDriverBelowMinimum(driverVersion)) return; + if ( + !shouldWarnAboutNvidiaDriver({ + platform: process.platform, + selectedDevice: installation.virtualEnvironment.selectedDevice, + suppressWarningFor, + driverVersion, + }) + ) { + return; + } - await this.appWindow.showMessageBox({ + const { checkboxChecked } = await this.appWindow.showMessageBox({ type: 'warning', title: 'Update NVIDIA Driver', message: 'Your NVIDIA driver may be too old for PyTorch 2.9.1 + cu130.', detail: `Detected driver version: ${driverVersion}\nRecommended minimum: ${NVIDIA_DRIVER_MIN_VERSION}\n\nPlease consider updating your NVIDIA drivers and retrying if you run into issues.`, buttons: ['OK'], + checkboxLabel: "Don't show again", + checkboxChecked: false, }); + + if (checkboxChecked) { + config.set('suppressNvidiaDriverWarningFor', NVIDIA_DRIVER_MIN_VERSION); + } } /** diff --git a/src/store/desktopSettings.ts b/src/store/desktopSettings.ts index 6010253f5..21bb13e22 100644 --- a/src/store/desktopSettings.ts +++ b/src/store/desktopSettings.ts @@ -35,4 +35,6 @@ export type DesktopSettings = { versionConsentedMetrics?: string; /** Whether the user has generated an image successfully. */ hasGeneratedSuccessfully?: boolean; + /** The minimum NVIDIA driver version for which the warning was dismissed. */ + suppressNvidiaDriverWarningFor?: string; }; diff --git a/tests/unit/install/installationManager.test.ts b/tests/unit/install/installationManager.test.ts index af9f92969..3a3d4b938 100644 --- a/tests/unit/install/installationManager.test.ts +++ b/tests/unit/install/installationManager.test.ts @@ -7,8 +7,10 @@ import { ComfySettings } from '@/config/comfySettings'; import { IPC_CHANNELS } from '@/constants'; import { InstallationManager, + NVIDIA_DRIVER_MIN_VERSION, isNvidiaDriverBelowMinimum, parseNvidiaDriverVersionFromSmiOutput, + shouldWarnAboutNvidiaDriver, } from '@/install/installationManager'; import type { AppWindow } from '@/main-process/appWindow'; import { ComfyInstallation } from '@/main-process/comfyInstallation'; @@ -33,13 +35,15 @@ vi.mock('node:fs/promises', () => ({ })); const config = { - get: vi.fn((key: string) => { + get: vi.fn((key: string): string | undefined => { if (key === 'installState') return 'installed'; if (key === 'basePath') return 'valid/base'; + return undefined; }), - set: vi.fn((key: string, value: string) => { - if (key !== 'basePath') throw new Error(`Unexpected key: ${key}`); - if (!value) throw new Error(`Unexpected value: [${value}]`); + set: vi.fn((key: string, value: unknown) => { + const allowedKeys = new Set(['basePath', 'suppressNvidiaDriverWarningFor']); + if (!allowedKeys.has(key)) throw new Error(`Unexpected key: ${key}`); + if (key === 'basePath' && !value) throw new Error(`Unexpected value: [${value}]`); }), }; vi.mock('@/store/desktopConfig', () => ({ @@ -110,6 +114,7 @@ const createMockAppWindow = () => { send: vi.fn(), loadPage: vi.fn(() => Promise.resolve(null)), showOpenDialog: vi.fn(), + showMessageBox: vi.fn(() => Promise.resolve({ response: 0, checkboxChecked: false })), maximize: vi.fn(), }; return mock as unknown as AppWindow; @@ -251,6 +256,81 @@ describe('InstallationManager', () => { cleanup?.(); }); }); + + describe('shouldWarnAboutNvidiaDriver', () => { + type NvidiaDriverWarningInput = Parameters[0]; + + const scenarios: Array<{ + scenario: string; + input: NvidiaDriverWarningInput; + expected: boolean; + }> = [ + { + scenario: 'returns false on non-Windows platforms', + input: { + platform: 'darwin', + selectedDevice: 'nvidia', + driverVersion: '570.0', + suppressWarningFor: undefined, + }, + expected: false, + }, + { + scenario: 'returns false when device is not nvidia', + input: { + platform: 'win32', + selectedDevice: 'cpu', + driverVersion: '570.0', + suppressWarningFor: undefined, + }, + expected: false, + }, + { + scenario: 'returns false when warning is suppressed for the minimum version', + input: { + platform: 'win32', + selectedDevice: 'nvidia', + driverVersion: '570.0', + suppressWarningFor: NVIDIA_DRIVER_MIN_VERSION, + }, + expected: false, + }, + { + scenario: 'returns false when driver version is missing', + input: { + platform: 'win32', + selectedDevice: 'nvidia', + driverVersion: undefined, + suppressWarningFor: undefined, + }, + expected: false, + }, + { + scenario: 'returns true when driver is below minimum', + input: { + platform: 'win32', + selectedDevice: 'nvidia', + driverVersion: '579.0.0', + suppressWarningFor: undefined, + }, + expected: true, + }, + { + scenario: 'returns false when driver meets minimum', + input: { + platform: 'win32', + selectedDevice: 'nvidia', + driverVersion: '580.0.0', + suppressWarningFor: undefined, + }, + expected: false, + }, + ]; + + it.each(scenarios)('$scenario', ({ input, expected }) => { + expect(shouldWarnAboutNvidiaDriver(input)).toBe(expected); + }); + }); }); describe('parseNvidiaDriverVersionFromSmiOutput', () => {