Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions azure/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ export declare class LocationListStep<T extends ILocationWizardContext> extends
public static getQuickPickDescription?: (location: AzExtLocation) => string | undefined;
}

/**
* A simple cache that deduplicates in-flight requests.
*
* Currently backed by an in-memory Map. Designed so the backing store can be
* swapped to persistent storage (e.g. `vscode.Memento` / globalState) in the
* future to survive across VS Code restarts with a longer TTL (e.g. 7 days).
*/
export declare class LocationCache<T> {
/**
* @param ttlMs Optional time-to-live in milliseconds. When omitted, entries
* never expire (suitable for in-memory caches that reset on extension
* deactivation). Set this when switching to persistent storage.
* @param now Clock function used for TTL checks. Override in tests to avoid
* real timers.
*/
constructor(ttlMs?: number, now?: () => number);
/**
* Get a value from the cache, or fetch it if missing/expired.
* Concurrent calls with the same key share a single in-flight request.
*/
getOrLoad(key: string, loader: () => Promise<T>): Promise<T>;
/** Remove all cached entries. */
clear(): void;
}

/**
* Checks to see if providers (i.e. 'Microsoft.Web') are registered and registers them if they're not
*/
Expand Down
1 change: 1 addition & 0 deletions azure/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export * from './utils/createPortalUri';
export * from './utils/parseAzureResourceId';
export * from './utils/setupAzureLogger';
export * from './utils/uiUtils';
export { LocationCache } from './wizard/LocationCache';
export * from './wizard/LocationListStep';
export * from './wizard/ResourceGroupCreateStep';
export * from './wizard/ResourceGroupListStep';
Expand Down
94 changes: 94 additions & 0 deletions azure/src/wizard/LocationCache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

/**
* A simple cache that deduplicates in-flight requests.
*
* Currently backed by an in-memory Map. Designed so the backing store can be
* swapped to persistent storage (e.g. `vscode.Memento` / globalState) in the
* future to survive across VS Code restarts with a longer TTL (e.g. 7 days).
*/

interface CacheEntry<T> {
data: T;
/** Unix timestamp (ms) when this entry was stored */
storedAt: number;
}

export class LocationCache<T> {
private readonly cache = new Map<string, CacheEntry<T>>();

/**
* In-flight promises keyed the same way as the cache.
* Ensures concurrent callers share the same request instead of firing duplicates.
*/
private readonly inflight = new Map<string, Promise<T>>();

/**
* Monotonically increasing counter incremented on each {@link clear} call.
* In-flight requests captured before a clear will see a stale generation
* and skip writing their result back into the cache.
*/
private generation = 0;

/**
* @param ttlMs Optional time-to-live in milliseconds. When omitted, entries
* never expire (suitable for in-memory caches that reset on extension
* deactivation). Set this when switching to persistent storage.
* @param now Clock function used for TTL checks. Override in tests to avoid
* real timers.
*/
constructor(private readonly ttlMs?: number, private readonly now: () => number = Date.now) { }

/**
* Get a value from the cache, or fetch it if missing/expired.
* Concurrent calls with the same key share a single in-flight request.
*/
getOrLoad(key: string, loader: () => Promise<T>): Promise<T> {
const cached = this.cache.get(key);
if (cached && !this.isExpired(cached)) {
return Promise.resolve(cached.data);
}

// Check for an in-flight request we can piggy-back on
const existing = this.inflight.get(key);
if (existing) {
return existing;
}

const gen = this.generation;

let loaderPromise: Promise<T>;
try {
loaderPromise = loader();
} catch (err) {
return Promise.reject(err instanceof Error ? err : new Error(String(err)));
}

const promise = loaderPromise.then(data => {
if (this.generation === gen) {
this.cache.set(key, { data, storedAt: this.now() });
}
this.inflight.delete(key);
return data;
}).catch(err => {
this.inflight.delete(key);
throw err;
});

this.inflight.set(key, promise);
return promise;
}

/** Remove all cached entries. */
clear(): void {
this.cache.clear();
this.generation++;
}

private isExpired(entry: CacheEntry<T>): boolean {
return this.ttlMs !== undefined && (this.now() - entry.storedAt) > this.ttlMs;
}
}
39 changes: 29 additions & 10 deletions azure/src/wizard/LocationListStep.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import { createResourcesClient, createSubscriptionsClient } from '../clients';
import { resourcesProvider } from '../constants';
import { ext } from '../extensionVariables';
import { uiUtils } from '../utils/uiUtils';
import { LocationCache } from './LocationCache';

const allLocationsCache = new LocationCache<types.AzExtLocation[]>();
const providerLocationCache = new LocationCache<string[]>();

/* eslint-disable @typescript-eslint/naming-convention */
interface ILocationWizardContextInternal extends types.ILocationWizardContext {
Expand Down Expand Up @@ -254,19 +258,34 @@ export class LocationListStep<T extends ILocationWizardContextInternal> extends
}

async function getAllLocations(wizardContext: types.ILocationWizardContext): Promise<types.AzExtLocation[]> {
const client = await createSubscriptionsClient(wizardContext);
const locations = await uiUtils.listAllIterator<Location>(client.subscriptions.listLocations(wizardContext.subscriptionId, { includeExtendedLocations: wizardContext.includeExtendedLocations }));
return locations.filter((l): l is types.AzExtLocation => !!(l.id && l.name && l.displayName));
const includeExtended = !!wizardContext.includeExtendedLocations;
const cacheKey = `${wizardContext.subscriptionId}|${includeExtended}`;

return allLocationsCache.getOrLoad(cacheKey, async () => {
ext.outputChannel.appendLog(`Cache miss for all locations (key: "${cacheKey}"). Fetching from API...`);
const client = await createSubscriptionsClient(wizardContext);
const locations = await uiUtils.listAllIterator<Location>(client.subscriptions.listLocations(wizardContext.subscriptionId, { includeExtendedLocations: includeExtended }));
const filtered = locations.filter((l): l is types.AzExtLocation => !!(l.id && l.name && l.displayName));
ext.outputChannel.appendLog(`Fetched and cached ${filtered.length} locations for subscription "${wizardContext.subscriptionId}".`);
return filtered;
});
}

async function getProviderLocations(wizardContext: types.ILocationWizardContext, provider: string, resourceType: string): Promise<string[]> {
const rgClient = await createResourcesClient(wizardContext);
const providerData = await rgClient.providers.get(provider);
const resourceTypeData = providerData.resourceTypes?.find(rt => rt.resourceType?.toLowerCase() === resourceType.toLowerCase());
if (!resourceTypeData) {
throw new ProviderResourceTypeNotFoundError(providerData, resourceType);
}
return nonNullProp(resourceTypeData, 'locations');
const cacheKey = `${wizardContext.subscriptionId}|${provider.toLowerCase()}|${resourceType.toLowerCase()}`;

return providerLocationCache.getOrLoad(cacheKey, async () => {
ext.outputChannel.appendLog(`Cache miss for provider locations (key: "${cacheKey}"). Fetching from API...`);
const rgClient = await createResourcesClient(wizardContext);
const providerData = await rgClient.providers.get(provider);
const resourceTypeData = providerData.resourceTypes?.find(rt => rt.resourceType?.toLowerCase() === resourceType.toLowerCase());
if (!resourceTypeData) {
throw new ProviderResourceTypeNotFoundError(providerData, resourceType);
}
const locations = nonNullProp(resourceTypeData, 'locations');
ext.outputChannel.appendLog(`Fetched and cached ${locations.length} locations for provider "${provider}/${resourceType}".`);
return locations;
});
}

function compareLocation(l1: types.AzExtLocation, l2: types.AzExtLocation): number {
Expand Down
184 changes: 184 additions & 0 deletions azure/test/LocationCache.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import * as assert from 'assert';
import { LocationCache } from '../src/wizard/LocationCache';

suite('LocationCache', () => {
let cache: LocationCache<string[]>;

setup(() => {
cache = new LocationCache();
});

test('returns data from loader on cache miss', async () => {
const result = await cache.getOrLoad('key1', () => Promise.resolve(['eastus', 'westus']));
assert.deepStrictEqual(result, ['eastus', 'westus']);
});

test('returns cached data on subsequent calls without calling loader again', async () => {
let callCount = 0;
const loader = () => {
callCount++;
return Promise.resolve(['eastus']);
};

const result1 = await cache.getOrLoad('key1', loader);
const result2 = await cache.getOrLoad('key1', loader);

assert.deepStrictEqual(result1, ['eastus']);
assert.deepStrictEqual(result2, ['eastus']);
assert.strictEqual(callCount, 1, 'loader should only be called once');
});

test('uses separate entries for different keys', async () => {
await cache.getOrLoad('sub1|false', () => Promise.resolve(['eastus']));
await cache.getOrLoad('sub2|false', () => Promise.resolve(['westus']));

const result1 = await cache.getOrLoad('sub1|false', () => Promise.reject(new Error('should not call')));
const result2 = await cache.getOrLoad('sub2|false', () => Promise.reject(new Error('should not call')));

assert.deepStrictEqual(result1, ['eastus']);
assert.deepStrictEqual(result2, ['westus']);
});

test('deduplicates concurrent in-flight requests for the same key', async () => {
let callCount = 0;
let resolve: (value: string[]) => void;
const loader = () => {
callCount++;
return new Promise<string[]>(r => { resolve = r; });
};

const p1 = cache.getOrLoad('key1', loader);
const p2 = cache.getOrLoad('key1', loader);

// Both should be waiting on the same promise
assert.strictEqual(callCount, 1, 'loader should only be called once for concurrent requests');

resolve!(['eastus']);
const [result1, result2] = await Promise.all([p1, p2]);

assert.deepStrictEqual(result1, ['eastus']);
assert.deepStrictEqual(result2, ['eastus']);
});

test('clear removes all cached entries', async () => {
let callCount = 0;
const loader = () => {
callCount++;
return Promise.resolve(['eastus']);
};

await cache.getOrLoad('key1', loader);
assert.strictEqual(callCount, 1);

cache.clear();

await cache.getOrLoad('key1', loader);
assert.strictEqual(callCount, 2, 'loader should be called again after clear');
});

test('clear prevents in-flight requests from repopulating the cache', async () => {
let resolve: (value: string[]) => void;
const loader = () => new Promise<string[]>(r => { resolve = r; });

const p1 = cache.getOrLoad('key1', loader);

// Clear while the request is still in-flight
cache.clear();

// Resolve the stale request
resolve!(['stale']);
await p1;

// The stale result should NOT have been cached, so a new loader fires
let callCount = 0;
await cache.getOrLoad('key1', () => { callCount++; return Promise.resolve(['fresh']); });
assert.strictEqual(callCount, 1, 'loader should be called because stale result was not cached');
});

test('expired entries are refreshed (injectable clock)', async () => {
let time = 1000;
const clock = () => time;
const ttlCache = new LocationCache<string[]>(100, clock);

let callCount = 0;
const loader = () => {
callCount++;
return Promise.resolve([`result-${callCount}`]);
};

const result1 = await ttlCache.getOrLoad('key1', loader);
assert.deepStrictEqual(result1, ['result-1']);

// Advance past TTL
time = 1200;

const result2 = await ttlCache.getOrLoad('key1', loader);
assert.deepStrictEqual(result2, ['result-2']);
assert.strictEqual(callCount, 2, 'loader should be called again after expiry');
});

test('entries without TTL never expire', async () => {
let callCount = 0;

await cache.getOrLoad('key1', () => { callCount++; return Promise.resolve(['eastus']); });
await cache.getOrLoad('key1', () => { callCount++; return Promise.resolve(['westus']); });

assert.strictEqual(callCount, 1);
});

test('loader error does not poison the cache', async () => {
let shouldFail = true;
const loader = () => {
if (shouldFail) {
return Promise.reject(new Error('network error'));
}
return Promise.resolve(['eastus']);
};

await assert.rejects(() => cache.getOrLoad('key1', loader), /network error/);

shouldFail = false;
const result = await cache.getOrLoad('key1', loader);
assert.deepStrictEqual(result, ['eastus']);
});

test('loader error is propagated to all concurrent waiters', async () => {
let reject: (err: Error) => void;
const loader = () => new Promise<string[]>((_, r) => { reject = r; });

const p1 = cache.getOrLoad('key1', loader);
const p2 = cache.getOrLoad('key1', loader);

reject!(new Error('boom'));

await assert.rejects(() => p1, /boom/);
await assert.rejects(() => p2, /boom/);
});

test('after error, a new loader call succeeds', async () => {
let callCount = 0;
const failLoader = () => { callCount++; return Promise.reject(new Error('fail')); };
const okLoader = () => { callCount++; return Promise.resolve(['eastus']); };

await assert.rejects(() => cache.getOrLoad('key1', failLoader), /fail/);
const result = await cache.getOrLoad('key1', okLoader);

assert.deepStrictEqual(result, ['eastus']);
assert.strictEqual(callCount, 2);
});

test('synchronous throw from loader is handled', async () => {
const loader = (): Promise<string[]> => { throw new Error('sync boom'); };

await assert.rejects(() => cache.getOrLoad('key1', loader), /sync boom/);

// Cache should not be poisoned
const result = await cache.getOrLoad('key1', () => Promise.resolve(['eastus']));
assert.deepStrictEqual(result, ['eastus']);
});
});
Loading