Skip to content

Commit 36fa3f8

Browse files
committed
fix: GraphQL complexity validator exponential fragment traversal DoS (GHSA-mfj6-6p54-m98c)
1 parent 458b718 commit 36fa3f8

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

spec/GraphQLQueryComplexity.spec.js

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,34 @@ describe('graphql query complexity', () => {
179179
});
180180
});
181181

182+
describe('fragment fan-out', () => {
183+
it('should reject query with exponential fragment fan-out efficiently', async () => {
184+
await setupGraphQL({
185+
requestComplexity: { graphQLFields: 100 },
186+
});
187+
// Binary fan-out: each fragment spreads the next one twice.
188+
// Without fix: 2^(levels-1) field visits = 2^25 ≈ 33M (hangs event loop).
189+
// With fix (memoization): O(levels) traversal, same field count, instant rejection.
190+
const levels = 26;
191+
let query = 'query Q { ...F0 }\n';
192+
for (let i = 0; i < levels; i++) {
193+
if (i === levels - 1) {
194+
query += `fragment F${i} on Query { __typename }\n`;
195+
} else {
196+
query += `fragment F${i} on Query { ...F${i + 1} ...F${i + 1} }\n`;
197+
}
198+
}
199+
const start = Date.now();
200+
const result = await graphqlRequest(query);
201+
const elapsed = Date.now() - start;
202+
// Must complete in under 5 seconds (without fix it would take seconds or hang)
203+
expect(elapsed).toBeLessThan(5000);
204+
// Field count is 2^(levels-1) = 16777216, which exceeds the limit of 100
205+
expect(result.errors).toBeDefined();
206+
expect(result.errors[0].message).toMatch(/Number of GraphQL fields .* exceeds maximum allowed/);
207+
});
208+
});
209+
182210
describe('where argument breadth', () => {
183211
it('should enforce depth and field limits regardless of where argument breadth', async () => {
184212
await setupGraphQL({

src/GraphQL/helpers/queryComplexity.js

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
import { GraphQLError } from 'graphql';
22
import logger from '../../logger';
33

4-
function calculateQueryComplexity(operation, fragments) {
4+
function calculateQueryComplexity(operation, fragments, limits = {}) {
55
let maxDepth = 0;
66
let totalFields = 0;
7+
const fragmentCache = new Map();
8+
const { maxDepth: allowedMaxDepth, maxFields: allowedMaxFields } = limits;
79

810
function visitSelectionSet(selectionSet, depth, visitedFragments) {
911
if (!selectionSet) {
1012
return;
1113
}
14+
if (
15+
(allowedMaxFields !== undefined && allowedMaxFields !== -1 && totalFields > allowedMaxFields) ||
16+
(allowedMaxDepth !== undefined && allowedMaxDepth !== -1 && maxDepth > allowedMaxDepth)
17+
) {
18+
return;
19+
}
1220
for (const selection of selectionSet.selections) {
1321
if (selection.kind === 'Field') {
1422
totalFields++;
@@ -23,14 +31,30 @@ function calculateQueryComplexity(operation, fragments) {
2331
visitSelectionSet(selection.selectionSet, depth, visitedFragments);
2432
} else if (selection.kind === 'FragmentSpread') {
2533
const name = selection.name.value;
34+
if (fragmentCache.has(name)) {
35+
const cached = fragmentCache.get(name);
36+
totalFields += cached.fields;
37+
const adjustedDepth = depth + cached.maxDepthDelta;
38+
if (adjustedDepth > maxDepth) {
39+
maxDepth = adjustedDepth;
40+
}
41+
continue;
42+
}
2643
if (visitedFragments.has(name)) {
2744
continue;
2845
}
2946
const fragment = fragments[name];
3047
if (fragment) {
31-
const branchVisited = new Set(visitedFragments);
32-
branchVisited.add(name);
33-
visitSelectionSet(fragment.selectionSet, depth, branchVisited);
48+
visitedFragments.add(name);
49+
const savedFields = totalFields;
50+
const savedMaxDepth = maxDepth;
51+
maxDepth = depth;
52+
visitSelectionSet(fragment.selectionSet, depth, visitedFragments);
53+
const fieldsContribution = totalFields - savedFields;
54+
const maxDepthDelta = maxDepth - depth;
55+
fragmentCache.set(name, { fields: fieldsContribution, maxDepthDelta });
56+
maxDepth = Math.max(savedMaxDepth, maxDepth);
57+
visitedFragments.delete(name);
3458
}
3559
}
3660
}
@@ -69,7 +93,8 @@ function createComplexityValidationPlugin(getConfig) {
6993

7094
const { depth, fields } = calculateQueryComplexity(
7195
requestContext.operation,
72-
fragments
96+
fragments,
97+
{ maxDepth: graphQLDepth, maxFields: graphQLFields }
7398
);
7499

75100
if (graphQLDepth !== -1 && depth > graphQLDepth) {

0 commit comments

Comments
 (0)