Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions src/install/installationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { useAppState } from '../main-process/appState';
import type { AppWindow } from '../main-process/appWindow';
import { ComfyInstallation } from '../main-process/comfyInstallation';
import { createInstallStageInfo } from '../main-process/installStages';
import type { InstallOptions, InstallValidation } from '../preload';
import type { InstallOptions, InstallValidation, TorchDeviceType } from '../preload';
import { CmCli } from '../services/cmCli';
import { captureSentryException } from '../services/sentry';
import { type HasTelemetry, ITelemetry, trackEvent } from '../services/telemetry';
Expand All @@ -23,7 +23,7 @@ import { InstallWizard } from './installWizard';
import { Troubleshooting } from './troubleshooting';

const execAsync = promisify(exec);
const NVIDIA_DRIVER_MIN_VERSION = '580';
export const NVIDIA_DRIVER_MIN_VERSION = '580';

/**
* Extracts the NVIDIA driver version from `nvidia-smi` output.
Expand All @@ -47,6 +47,36 @@ export function isNvidiaDriverBelowMinimum(
return compareVersions(driverVersion, minimumVersion) < 0;
}

type NvidiaDriverWarningContext = {
platform: NodeJS.Platform;
selectedDevice?: TorchDeviceType;
suppressWarningFor?: string;
driverVersion?: string;
minimumVersion?: string;
};
Comment on lines +50 to +56
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Prefer an exported interface for the warning context.

This shape is part of the exported API via shouldWarnAboutNvidiaDriver, so make it explicit and extensible.

♻️ Suggested refactor
- type NvidiaDriverWarningContext = {
+ /** Context for {`@link` shouldWarnAboutNvidiaDriver}. */
+ export interface NvidiaDriverWarningContext {
   platform: NodeJS.Platform;
   selectedDevice?: TorchDeviceType;
   suppressWarningFor?: string;
   driverVersion?: string;
   minimumVersion?: string;
- };
+ }

As per coding guidelines, prefer interfaces for public object shapes.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
type NvidiaDriverWarningContext = {
platform: NodeJS.Platform;
selectedDevice?: TorchDeviceType;
suppressWarningFor?: string;
driverVersion?: string;
minimumVersion?: string;
};
/** Context for {`@link` shouldWarnAboutNvidiaDriver}. */
export interface NvidiaDriverWarningContext {
platform: NodeJS.Platform;
selectedDevice?: TorchDeviceType;
suppressWarningFor?: string;
driverVersion?: string;
minimumVersion?: string;
}
🤖 Prompt for AI Agents
In `@src/install/installationManager.ts` around lines 50 - 56, The exported shape
NvidiaDriverWarningContext (used by shouldWarnAboutNvidiaDriver) is currently
declared as a local type; change it to an exported interface so the API is
explicit and extensible: replace the type alias NvidiaDriverWarningContext with
an exported interface NvidiaDriverWarningContext including the same fields
(platform, selectedDevice?, suppressWarningFor?, driverVersion?,
minimumVersion?) and ensure any references (e.g., the
shouldWarnAboutNvidiaDriver function signature) import/accept the interface;
export the interface from the module so consumers can extend/implement it.


/**
* Determines whether the NVIDIA driver warning should be shown.
* @param context The driver warning context.
* @return `true` when the warning should be shown.
*/
export function shouldWarnAboutNvidiaDriver(context: NvidiaDriverWarningContext): boolean {
const {
platform,
selectedDevice,
suppressWarningFor,
driverVersion,
minimumVersion = NVIDIA_DRIVER_MIN_VERSION,
} = context;

if (platform !== 'win32') return false;
if (selectedDevice !== 'nvidia') return false;
if (suppressWarningFor === minimumVersion) return false;
if (!driverVersion) return false;

return isNvidiaDriverBelowMinimum(driverVersion, minimumVersion);
}

/** High-level / UI control over the installation of ComfyUI server. */
export class InstallationManager implements HasTelemetry {
constructor(
Expand Down Expand Up @@ -398,19 +428,36 @@ export class InstallationManager implements HasTelemetry {
if (process.platform !== 'win32') return;
if (installation.virtualEnvironment.selectedDevice !== 'nvidia') return;

const config = useDesktopConfig();
const suppressWarningFor = config.get('suppressNvidiaDriverWarningFor');
if (suppressWarningFor === NVIDIA_DRIVER_MIN_VERSION) return;

const driverVersion =
(await this.getNvidiaDriverVersionFromSmi()) ?? (await this.getNvidiaDriverVersionFromSmiFallback());
if (!driverVersion) return;

if (!isNvidiaDriverBelowMinimum(driverVersion)) return;
if (
!shouldWarnAboutNvidiaDriver({
platform: process.platform,
selectedDevice: installation.virtualEnvironment.selectedDevice,
suppressWarningFor,
driverVersion,
})
) {
return;
}

await this.appWindow.showMessageBox({
const { checkboxChecked } = await this.appWindow.showMessageBox({
type: 'warning',
title: 'Update NVIDIA Driver',
message: 'Your NVIDIA driver may be too old for PyTorch 2.9.1 + cu130.',
detail: `Detected driver version: ${driverVersion}\nRecommended minimum: ${NVIDIA_DRIVER_MIN_VERSION}\n\nPlease consider updating your NVIDIA drivers and retrying if you run into issues.`,
buttons: ['OK'],
checkboxLabel: "Don't show again",
checkboxChecked: false,
});

if (checkboxChecked) {
config.set('suppressNvidiaDriverWarningFor', NVIDIA_DRIVER_MIN_VERSION);
}
}

/**
Expand Down
2 changes: 2 additions & 0 deletions src/store/desktopSettings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ export type DesktopSettings = {
versionConsentedMetrics?: string;
/** Whether the user has generated an image successfully. */
hasGeneratedSuccessfully?: boolean;
/** The minimum NVIDIA driver version for which the warning was dismissed. */
suppressNvidiaDriverWarningFor?: string;
};
88 changes: 84 additions & 4 deletions tests/unit/install/installationManager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import { ComfySettings } from '@/config/comfySettings';
import { IPC_CHANNELS } from '@/constants';
import {
InstallationManager,
NVIDIA_DRIVER_MIN_VERSION,
isNvidiaDriverBelowMinimum,
parseNvidiaDriverVersionFromSmiOutput,
shouldWarnAboutNvidiaDriver,
} from '@/install/installationManager';
import type { AppWindow } from '@/main-process/appWindow';
import { ComfyInstallation } from '@/main-process/comfyInstallation';
Expand All @@ -33,13 +35,15 @@ vi.mock('node:fs/promises', () => ({
}));

const config = {
get: vi.fn((key: string) => {
get: vi.fn((key: string): string | undefined => {
if (key === 'installState') return 'installed';
if (key === 'basePath') return 'valid/base';
return undefined;
}),
set: vi.fn((key: string, value: string) => {
if (key !== 'basePath') throw new Error(`Unexpected key: ${key}`);
if (!value) throw new Error(`Unexpected value: [${value}]`);
set: vi.fn((key: string, value: unknown) => {
const allowedKeys = new Set(['basePath', 'suppressNvidiaDriverWarningFor']);
if (!allowedKeys.has(key)) throw new Error(`Unexpected key: ${key}`);
if (key === 'basePath' && !value) throw new Error(`Unexpected value: [${value}]`);
}),
};
vi.mock('@/store/desktopConfig', () => ({
Expand Down Expand Up @@ -110,6 +114,7 @@ const createMockAppWindow = () => {
send: vi.fn(),
loadPage: vi.fn(() => Promise.resolve(null)),
showOpenDialog: vi.fn(),
showMessageBox: vi.fn(() => Promise.resolve({ response: 0, checkboxChecked: false })),
maximize: vi.fn(),
};
return mock as unknown as AppWindow;
Expand Down Expand Up @@ -251,6 +256,81 @@ describe('InstallationManager', () => {
cleanup?.();
});
});

describe('shouldWarnAboutNvidiaDriver', () => {
type NvidiaDriverWarningInput = Parameters<typeof shouldWarnAboutNvidiaDriver>[0];

const scenarios: Array<{
scenario: string;
input: NvidiaDriverWarningInput;
expected: boolean;
}> = [
{
scenario: 'returns false on non-Windows platforms',
input: {
platform: 'darwin',
selectedDevice: 'nvidia',
driverVersion: '570.0',
suppressWarningFor: undefined,
},
expected: false,
},
{
scenario: 'returns false when device is not nvidia',
input: {
platform: 'win32',
selectedDevice: 'cpu',
driverVersion: '570.0',
suppressWarningFor: undefined,
},
expected: false,
},
{
scenario: 'returns false when warning is suppressed for the minimum version',
input: {
platform: 'win32',
selectedDevice: 'nvidia',
driverVersion: '570.0',
suppressWarningFor: NVIDIA_DRIVER_MIN_VERSION,
},
expected: false,
},
{
scenario: 'returns false when driver version is missing',
input: {
platform: 'win32',
selectedDevice: 'nvidia',
driverVersion: undefined,
suppressWarningFor: undefined,
},
expected: false,
},
{
scenario: 'returns true when driver is below minimum',
input: {
platform: 'win32',
selectedDevice: 'nvidia',
driverVersion: '579.0.0',
suppressWarningFor: undefined,
},
expected: true,
},
{
scenario: 'returns false when driver meets minimum',
input: {
platform: 'win32',
selectedDevice: 'nvidia',
driverVersion: '580.0.0',
suppressWarningFor: undefined,
},
expected: false,
},
];

it.each(scenarios)('$scenario', ({ input, expected }) => {
expect(shouldWarnAboutNvidiaDriver(input)).toBe(expected);
});
});
});

describe('parseNvidiaDriverVersionFromSmiOutput', () => {
Expand Down
Loading