Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/azure/MssqlVSCodeAzureSubscriptionProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ export class MssqlVSCodeAzureSubscriptionProvider extends VSCodeAzureSubscriptio
}

protected override async getTenantFilters(): Promise<TenantId[]> {
return this.getSelectedSubscriptions().map((id) => id.split("/")[0]);
// Format is now "account/tenantId/subscriptionId", extract tenantId (index 1)
return this.getSelectedSubscriptions().map((id) => id.split("/")[1]);
}

protected override async getSubscriptionFilters(): Promise<SubscriptionId[]> {
return this.getSelectedSubscriptions().map((id) => id.split("/")[1]);
// Format is now "account/tenantId/subscriptionId", extract subscriptionId (index 2)
return this.getSelectedSubscriptions().map((id) => id.split("/")[2]);
}
}
24 changes: 16 additions & 8 deletions src/connectionconfig/azureHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,18 @@ export async function promptForAzureSubscriptionFilter(
return false;
}

await vscode.workspace.getConfiguration().update(
configSelectedAzureSubscriptions,
selectedSubs.map((s) => `${s.tenantId}/${s.subscriptionId}`),
vscode.ConfigurationTarget.Global,
const filterConfig = selectedSubs.map(
(s) => `${s.group}/${s.tenantId}/${s.subscriptionId}`,
);

await vscode.workspace
.getConfiguration()
.update(
configSelectedAzureSubscriptions,
filterConfig,
vscode.ConfigurationTarget.Global,
);

return true;
} catch (error) {
state.formMessage = { message: l10n.t("Error loading Azure subscriptions.") };
Expand All @@ -287,19 +293,21 @@ export async function getSubscriptionQuickPickItems(
false /* don't use the current filter, 'cause we're gonna set it */,
);

// Get previously selected subscriptions as "account/tenantId/subscriptionId" strings
const prevSelectedSubs = vscode.workspace
.getConfiguration()
.get<string[] | undefined>(configSelectedAzureSubscriptions)
?.map((entry) => entry.split("/")[1]);
.get<string[] | undefined>(configSelectedAzureSubscriptions);

const quickPickItems: SubscriptionPickItem[] = allSubs
.map((sub) => {
const compositeKey = `${sub.account.label}/${sub.tenantId}/${sub.subscriptionId}`;
const isPicked = prevSelectedSubs ? prevSelectedSubs.includes(compositeKey) : true;
return {
label: sub.name,
description: sub.subscriptionId,
description: `${sub.subscriptionId} (${sub.account.label})`,
tenantId: sub.tenantId,
subscriptionId: sub.subscriptionId,
picked: prevSelectedSubs ? prevSelectedSubs.includes(sub.subscriptionId) : true,
picked: isPicked,
group: sub.account.label,
};
})
Expand Down
33 changes: 28 additions & 5 deletions src/connectionconfig/connectionDialogWebviewController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,17 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
});

this.registerReducer("loadAzureServers", async (state, payload) => {
await this.loadAzureServersForSubscription(state, payload.subscriptionId);
// Find the subscription in state to get its tenantId
const subscription = state.azureSubscriptions.find(
(s) => s.id === payload.subscriptionId,
);
if (subscription) {
await this.loadAzureServersForSubscription(
state,
subscription.tenantId,
subscription.id,
);
}

return state;
});
Expand Down Expand Up @@ -1484,7 +1494,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
}

state.loadingAzureSubscriptionsStatus = ApiStatus.Loading;
this.updateState();
this.updateState(state);

// getSubscriptions() below checks this config setting if filtering is specified. If the user has this set, then we use it; if not, we get all subscriptions.
// The specific vscode config setting it uses is hardcoded into the VS Code Azure SDK, so we need to use the same value here.
Expand All @@ -1498,8 +1508,13 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
TelemetryActions.LoadAzureSubscriptions,
);

// Store subscriptions with composite key "tenantId/subscriptionId" to handle cases where
// the same subscription is accessible from multiple tenants/accounts
this._azureSubscriptions = new Map(
(await auth.getSubscriptions(shouldUseFilter)).map((s) => [s.subscriptionId, s]),
(await auth.getSubscriptions(shouldUseFilter)).map((s) => [
`${s.tenantId}/${s.subscriptionId}`,
s,
]),
);
const tenantSubMap = Map.groupBy<string, AzureSubscription>(
Array.from(this._azureSubscriptions.values()),
Expand All @@ -1513,6 +1528,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
subs.push({
id: s.subscriptionId,
name: s.name,
tenantId: s.tenantId,
loaded: false,
});
}
Expand All @@ -1526,6 +1542,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
undefined, // additionalProperties
{
subscriptionCount: subs.length,
tenantCount: tenantSubMap.size,
},
);
this.updateState();
Expand Down Expand Up @@ -1566,7 +1583,11 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
for (const t of tenantSubMap.keys()) {
for (const s of tenantSubMap.get(t)) {
promiseArray.push(
this.loadAzureServersForSubscription(state, s.subscriptionId),
this.loadAzureServersForSubscription(
state,
s.tenantId,
s.subscriptionId,
),
);
}
}
Expand Down Expand Up @@ -1597,9 +1618,11 @@ export class ConnectionDialogWebviewController extends FormWebviewController<

private async loadAzureServersForSubscription(
state: ConnectionDialogWebviewState,
tenantId: string,
subscriptionId: string,
) {
const azSub = this._azureSubscriptions.get(subscriptionId);
const compositeKey = `${tenantId}/${subscriptionId}`;
const azSub = this._azureSubscriptions.get(compositeKey);
const stateSub = state.azureSubscriptions.find((s) => s.id === subscriptionId);

try {
Expand Down
8 changes: 7 additions & 1 deletion src/reactviews/pages/ConnectionDialog/azureBrowsePage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,17 @@ export const AzureBrowsePage = () => {
setSelectedServerWithFormState,
setServerValue,
srvs,
DefaultSelectionMode.SelectFirstIfAny,
DefaultSelectionMode.AlwaysSelectNone,
);
}
}, [locations, selectedLocation, context.state.azureServers]);

// databases
useEffect(() => {
if (!selectedServer) {
setDatabases([]);
setSelectedDatabase(undefined);
setDatabaseValue("");
return; // should not be visible if no server is selected
}

Expand All @@ -209,6 +212,9 @@ export const AzureBrowsePage = () => {
);

if (!server) {
setDatabases([]);
setSelectedDatabase(undefined);
setDatabaseValue("");
return;
}

Expand Down
1 change: 1 addition & 0 deletions src/sharedInterfaces/connectionDialog.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ export interface CreateConnectionGroupDialogProps extends IDialogProps {
export interface AzureSubscriptionInfo {
name: string;
id: string;
tenantId: string;
loaded: boolean;
}

Expand Down
183 changes: 183 additions & 0 deletions test/unit/MssqlVSCodeAzureSubscriptionProvider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import { expect } from "chai";
import * as sinon from "sinon";
import * as vscode from "vscode";
import { MssqlVSCodeAzureSubscriptionProvider } from "../../src/azure/MssqlVSCodeAzureSubscriptionProvider";

suite("MssqlVSCodeAzureSubscriptionProvider Tests", () => {
let sandbox: sinon.SinonSandbox;
let configStub: {
get: sinon.SinonStub;
has: sinon.SinonStub;
inspect: sinon.SinonStub;
update: sinon.SinonStub;
};

setup(() => {
sandbox = sinon.createSandbox();
configStub = {
get: sandbox.stub(),
has: sandbox.stub(),
inspect: sandbox.stub(),
update: sandbox.stub(),
};
sandbox.stub(vscode.workspace, "getConfiguration").returns(configStub);
});

teardown(() => {
sandbox.restore();
});

suite("Filter Methods", () => {
test("getTenantFilters extracts tenant IDs from composite keys", async () => {
const mockFilterConfig = [
"account1@example.com/tenant1/subscription1",
"account2@example.com/tenant2/subscription2",
"account1@example.com/tenant1/subscription3",
];

configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns(mockFilterConfig);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const tenantFilters = await provider["getTenantFilters"]();

expect(tenantFilters).to.have.lengthOf(3);
expect(tenantFilters).to.include("tenant1");
expect(tenantFilters).to.include("tenant2");
});

test("getSubscriptionFilters extracts subscription IDs from composite keys", async () => {
const mockFilterConfig = [
"account1@example.com/tenant1/subscription1",
"account2@example.com/tenant2/subscription2",
"account1@example.com/tenant1/subscription3",
];

configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns(mockFilterConfig);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const subscriptionFilters = await provider["getSubscriptionFilters"]();

expect(subscriptionFilters).to.have.lengthOf(3);
expect(subscriptionFilters).to.include("subscription1");
expect(subscriptionFilters).to.include("subscription2");
expect(subscriptionFilters).to.include("subscription3");
});

test("getTenantFilters handles empty configuration", async () => {
configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns([]);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const tenantFilters = await provider["getTenantFilters"]();

expect(tenantFilters).to.be.an("array").that.is.empty;
});

test("getSubscriptionFilters handles empty configuration", async () => {
configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns([]);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const subscriptionFilters = await provider["getSubscriptionFilters"]();

expect(subscriptionFilters).to.be.an("array").that.is.empty;
});

test("filters correctly handle duplicate subscriptions from different accounts", async () => {
// Same subscription accessible from two different accounts
const mockFilterConfig = [
"account1@example.com/shared-tenant/shared-subscription",
"account2@example.com/shared-tenant/shared-subscription",
];

configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns(mockFilterConfig);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const subscriptionFilters = await provider["getSubscriptionFilters"]();

// Should return both entries (even though subscription ID is the same)
// because they're from different accounts
expect(subscriptionFilters).to.have.lengthOf(2);
expect(subscriptionFilters[0]).to.equal("shared-subscription");
expect(subscriptionFilters[1]).to.equal("shared-subscription");
});

test("filters correctly parse complex email addresses", async () => {
const mockFilterConfig = [
"user+tag@company.com/tenant-guid-1/subscription-guid-1",
"user.name@sub.domain.com/tenant-guid-2/subscription-guid-2",
];

configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns(mockFilterConfig);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();

const tenantFilters = await provider["getTenantFilters"]();
expect(tenantFilters).to.have.lengthOf(2);
expect(tenantFilters).to.include("tenant-guid-1");
expect(tenantFilters).to.include("tenant-guid-2");

const subscriptionFilters = await provider["getSubscriptionFilters"]();
expect(subscriptionFilters).to.have.lengthOf(2);
expect(subscriptionFilters).to.include("subscription-guid-1");
expect(subscriptionFilters).to.include("subscription-guid-2");
});

test("getTenantFilters handles malformed filter strings gracefully", async () => {
const mockFilterConfig = [
"account1@example.com/tenant1/subscription1",
"malformed-string", // No slashes
"account2@example.com/tenant2", // Missing subscription ID (only 2 parts)
];

configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns(mockFilterConfig);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const tenantFilters = await provider["getTenantFilters"]();

// Should still extract what it can, even with malformed entries
expect(tenantFilters).to.have.lengthOf(3);
expect(tenantFilters[0]).to.equal("tenant1");
expect(tenantFilters[1]).to.be.undefined; // malformed-string.split("/")[1]
expect(tenantFilters[2]).to.equal("tenant2");
});

test("getSubscriptionFilters handles malformed filter strings gracefully", async () => {
const mockFilterConfig = [
"account1@example.com/tenant1/subscription1",
"malformed-string", // No slashes
"account2@example.com/tenant2", // Missing subscription ID (only 2 parts)
];

configStub.get
.withArgs("mssql.selectedAzureSubscriptions", sinon.match.any)
.returns(mockFilterConfig);

const provider = MssqlVSCodeAzureSubscriptionProvider.getInstance();
const subscriptionFilters = await provider["getSubscriptionFilters"]();

// Should still extract what it can, even with malformed entries
expect(subscriptionFilters).to.have.lengthOf(3);
expect(subscriptionFilters[0]).to.equal("subscription1");
expect(subscriptionFilters[1]).to.be.undefined; // malformed-string.split("/")[2]
expect(subscriptionFilters[2]).to.be.undefined; // "account2@example.com/tenant2".split("/")[2]
});
});
});
Loading