Skip to content

Commit 954f141

Browse files
authored
fix: pass request headers for the graph (#1219)
## Summary by Sourcery Enable forwarding custom request headers through the pagination client and streamline pagination behavior New Features: - Allow passing custom request headers to underlying GraphQL requests via client.query and pagination Enhancements: - Refactor and simplify ListField interface by removing unused fields - Simplify pagination loop to always continue when a full batch is returned, eliminating special-case directive tracking Tests: - Add tests for passing headers and using request-options object with @fetchall - Update existing pagination tests to include headers parameter in mock calls
1 parent 7b80592 commit 954f141

File tree

2 files changed

+97
-61
lines changed

2 files changed

+97
-61
lines changed

sdk/thegraph/src/utils/pagination.test.ts

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ interface RequestVariables {
189189
};
190190
}
191191

192-
const requestMock = mock(async (document: DocumentNode, variables: RequestVariables) => {
192+
const requestMock = mock(async (document: DocumentNode, variables: RequestVariables, _requestHeaders?: HeadersInit) => {
193193
// Check if the "holders" field in the query has "skip" or "first" arguments set
194194
let shouldFetchHolders = false;
195195
let holdersFieldHasPagination = false;
@@ -278,7 +278,6 @@ describe("createTheGraphClientWithPagination", () => {
278278
symbol
279279
}
280280
}`),
281-
{},
282281
);
283282
expect(result.tokens).toHaveLength(TEST_TOKENS.length);
284283
expect(result.tokens).toEqual(TEST_TOKENS);
@@ -289,6 +288,26 @@ describe("createTheGraphClientWithPagination", () => {
289288
);
290289
});
291290

291+
it("should return all tokens if @fetchAll is used with a request options object", async () => {
292+
const result = await client.query({
293+
document: theGraphGraphql(`
294+
query {
295+
tokens @fetchAll {
296+
name
297+
symbol
298+
}
299+
}`),
300+
variables: {},
301+
});
302+
expect(result.tokens).toHaveLength(TEST_TOKENS.length);
303+
expect(result.tokens).toEqual(TEST_TOKENS);
304+
const expectedCalls = Math.ceil(TEST_TOKENS.length / DEFAULT_PAGE_SIZE);
305+
// If the number of tokens is an exact multiple of the page size, we expect one extra call
306+
expect(requestMock).toHaveBeenCalledTimes(
307+
TEST_TOKENS.length % DEFAULT_PAGE_SIZE === 0 ? expectedCalls + 1 : expectedCalls,
308+
);
309+
});
310+
292311
it("should use existing pagination variables if they are passed in", async () => {
293312
const limit = 100;
294313
const skip = 1;
@@ -476,23 +495,54 @@ describe("createTheGraphClientWithPagination", () => {
476495
};
477496
}),
478497
);
479-
expect(requestMock).toHaveBeenNthCalledWith(1, expect.anything(), {
480-
orderBy: "name",
481-
orderDirection: "asc",
482-
where: {
483-
name: "Token",
498+
expect(requestMock).toHaveBeenNthCalledWith(
499+
1,
500+
expect.anything(),
501+
{
502+
orderBy: "name",
503+
orderDirection: "asc",
504+
where: {
505+
name: "Token",
506+
},
507+
first: DEFAULT_PAGE_SIZE,
508+
skip: 0,
484509
},
485-
first: DEFAULT_PAGE_SIZE,
486-
skip: 0,
487-
});
488-
expect(requestMock).toHaveBeenNthCalledWith(2, expect.anything(), {
489-
orderBy: "name",
490-
orderDirection: "asc",
491-
where: {
492-
name: "Token",
510+
undefined,
511+
);
512+
expect(requestMock).toHaveBeenNthCalledWith(
513+
2,
514+
expect.anything(),
515+
{
516+
orderBy: "name",
517+
orderDirection: "asc",
518+
where: {
519+
name: "Token",
520+
},
521+
first: DEFAULT_PAGE_SIZE,
522+
skip: DEFAULT_PAGE_SIZE,
523+
},
524+
undefined,
525+
);
526+
});
527+
528+
it("should pass request headers to the request", async () => {
529+
const result = await client.query(
530+
theGraphGraphql(`
531+
query {
532+
tokens @fetchAll {
533+
name
534+
symbol
535+
}
536+
}`),
537+
{},
538+
{
539+
"x-api-key": "test",
493540
},
494-
first: DEFAULT_PAGE_SIZE,
495-
skip: DEFAULT_PAGE_SIZE,
541+
);
542+
expect(result.tokens).toHaveLength(TEST_TOKENS.length);
543+
expect(result.tokens).toEqual(TEST_TOKENS);
544+
expect(requestMock).toHaveBeenNthCalledWith(1, expect.anything(), expect.anything(), {
545+
"x-api-key": "test",
496546
});
497547
});
498548
});

sdk/thegraph/src/utils/pagination.ts

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { sortBy } from "es-toolkit";
22
import { get, isArray, isEmpty, set } from "es-toolkit/compat";
33
import type { TadaDocumentNode } from "gql.tada";
4-
import { type ArgumentNode, type DocumentNode, Kind, parse, type SelectionNode, visit } from "graphql";
4+
import { type ArgumentNode, type DocumentNode, Kind, parse, visit } from "graphql";
55
import type { GraphQLClient, RequestDocument, RequestOptions, Variables } from "graphql-request";
66

77
// Constants for TheGraph limits
@@ -10,16 +10,12 @@ const FIRST_ARG = "first";
1010
const SKIP_ARG = "skip";
1111
const FETCH_ALL_DIRECTIVE = "fetchAll";
1212

13-
interface ListField {
13+
interface ListFieldWithFetchAllDirective {
1414
path: string[];
1515
fieldName: string;
16-
alias?: string;
1716
firstValue?: number;
1817
skipValue?: number;
1918
otherArgs: ArgumentNode[];
20-
selections?: ReadonlyArray<SelectionNode>;
21-
hasFetchAllDirective?: boolean;
22-
firstValueIsDefault?: boolean; // Track if first value was defaulted
2319
}
2420

2521
/**
@@ -115,8 +111,8 @@ function extractFetchAllFields(
115111
document: DocumentNode,
116112
variables?: Variables,
117113
fetchAllFields?: Set<string>,
118-
): ListField[] {
119-
const fields: ListField[] = [];
114+
): ListFieldWithFetchAllDirective[] {
115+
const fields: ListFieldWithFetchAllDirective[] = [];
120116
const pathStack: string[] = [];
121117

122118
visit(document, {
@@ -174,13 +170,9 @@ function extractFetchAllFields(
174170
fields.push({
175171
path: [...pathStack],
176172
fieldName: node.name.value,
177-
alias: node.alias?.value,
178-
firstValue: hasFetchAllDirective && (firstValue ?? THE_GRAPH_LIMIT),
179-
skipValue: hasFetchAllDirective && (skipValue ?? 0),
173+
firstValue: firstValue ?? THE_GRAPH_LIMIT,
174+
skipValue: skipValue ?? 0,
180175
otherArgs,
181-
selections: node.selectionSet?.selections,
182-
hasFetchAllDirective,
183-
firstValueIsDefault: hasFetchAllDirective ? firstValue === undefined : false,
184176
});
185177
}
186178
},
@@ -196,7 +188,7 @@ function extractFetchAllFields(
196188
// Create a query for a single field with specific pagination
197189
function createSingleFieldQuery(
198190
document: DocumentNode,
199-
targetField: ListField,
191+
targetField: ListFieldWithFetchAllDirective,
200192
skip: number,
201193
first: number,
202194
): DocumentNode {
@@ -251,7 +243,7 @@ function createSingleFieldQuery(
251243
}
252244

253245
// Create query without list fields
254-
function createNonListQuery(document: DocumentNode, listFields: ListField[]): DocumentNode | null {
246+
function createNonListQuery(document: DocumentNode, listFields: ListFieldWithFetchAllDirective[]): DocumentNode | null {
255247
let hasFields = false;
256248
const pathStack: string[] = [];
257249

@@ -319,7 +311,8 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
319311
async function executeListFieldPagination(
320312
document: DocumentNode,
321313
variables: Variables | undefined,
322-
field: ListField,
314+
field: ListFieldWithFetchAllDirective,
315+
requestHeaders?: HeadersInit,
323316
): Promise<unknown[]> {
324317
const results: unknown[] = [];
325318
let currentSkip = field.skipValue || 0;
@@ -332,11 +325,15 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
332325
while (hasMore) {
333326
const query = createSingleFieldQuery(document, field, currentSkip, batchSize);
334327
const existingVariables = filterVariables(variables, query) ?? {};
335-
const response = await theGraphClient.request(query, {
336-
...existingVariables,
337-
first: batchSize,
338-
skip: currentSkip,
339-
});
328+
const response = await theGraphClient.request(
329+
query,
330+
{
331+
...existingVariables,
332+
first: batchSize,
333+
skip: currentSkip,
334+
},
335+
requestHeaders,
336+
);
340337

341338
// Use array path format for es-toolkit's get function
342339
const data = get(response, field.path) ?? get(response, field.fieldName);
@@ -352,24 +349,8 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
352349
if (isArray(data) && data.length > 0) {
353350
results.push(...data);
354351

355-
// Continue fetching if:
356-
// 1. We have @fetchAll directive (fetch everything)
357-
// 2. We have an explicit first value > THE_GRAPH_LIMIT and haven't reached it
358-
// 3. We have a defaulted first value and got a full batch (treating it as "no explicit value")
359-
// 4. We have no first value and got a full batch
360-
if (field.hasFetchAllDirective) {
361-
// With @fetchAll, continue if we got a full batch
362-
hasMore = data.length === batchSize;
363-
} else if (field.firstValue && !field.firstValueIsDefault) {
364-
// With explicit first value (not defaulted), only continue if:
365-
// - We haven't reached the requested amount yet
366-
// - We got a full batch (indicating more data might exist)
367-
hasMore = data.length === batchSize && results.length < field.firstValue;
368-
} else {
369-
// When first is not specified or was defaulted (using default batch size),
370-
// continue if we got a full batch (standard TheGraph pagination behavior)
371-
hasMore = data.length === batchSize;
372-
}
352+
// With @fetchAll, continue if we got a full batch
353+
hasMore = data.length === batchSize;
373354
} else {
374355
hasMore = false;
375356
}
@@ -384,16 +365,20 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
384365
async query<TResult, TVariables extends Variables>(
385366
documentOrOptions: TadaDocumentNode<TResult, TVariables> | RequestDocument | RequestOptions<TVariables, TResult>,
386367
variablesRaw?: Omit<TVariables, "skip" | "first">,
368+
requestHeadersRaw?: HeadersInit,
387369
): Promise<TResult> {
388370
let document: TadaDocumentNode<TResult, TVariables> | RequestDocument;
389371
let variables: Omit<TVariables, "skip" | "first">;
372+
let requestHeaders: HeadersInit | undefined;
390373

391374
if (isRequestOptions(documentOrOptions)) {
392375
document = documentOrOptions.document;
393-
variables = documentOrOptions.variables as TVariables;
376+
variables = (documentOrOptions.variables ?? {}) as TVariables;
377+
requestHeaders = documentOrOptions.requestHeaders;
394378
} else {
395379
document = documentOrOptions;
396380
variables = variablesRaw ?? ({} as TVariables);
381+
requestHeaders = requestHeadersRaw;
397382
}
398383

399384
// First, detect and strip @fetchAll directives
@@ -404,7 +389,7 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
404389

405390
// If no list fields, execute normally
406391
if (listFields.length === 0) {
407-
return theGraphClient.request(processedDocument, variables as Variables);
392+
return theGraphClient.request(processedDocument, variables as Variables, requestHeaders);
408393
}
409394

410395
// Execute paginated queries for all list fields
@@ -416,7 +401,7 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
416401
// Process list fields in parallel for better performance
417402
const fieldDataPromises = sortedFields.map(async (field) => ({
418403
field,
419-
data: await executeListFieldPagination(processedDocument, variables, field),
404+
data: await executeListFieldPagination(processedDocument, variables, field, requestHeaders),
420405
}));
421406

422407
const fieldResults = await Promise.all(fieldDataPromises);
@@ -434,6 +419,7 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
434419
const nonListResult = await theGraphClient.request(
435420
nonListQuery,
436421
filterVariables(variables, nonListQuery) ?? {},
422+
requestHeaders,
437423
);
438424

439425
// Merge results, preserving list data
@@ -447,5 +433,5 @@ export function createTheGraphClientWithPagination(theGraphClient: Pick<GraphQLC
447433
}
448434

449435
function isRequestOptions(args: unknown): args is RequestOptions<Variables, unknown> {
450-
return typeof args === "object" && args !== null && "document" in args && "variables" in args;
436+
return typeof args === "object" && args !== null && "document" in args;
451437
}

0 commit comments

Comments
 (0)