Skip to content

Commit afdcb85

Browse files
committed
Avoid for loop over sources
It's a tiny improvement but makes it a bit more readable too.
1 parent 91778e6 commit afdcb85

File tree

13 files changed

+141
-106
lines changed

13 files changed

+141
-106
lines changed

library/agent/Context.test.ts

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,27 +107,21 @@ t.test("it clears cache when context is mutated", async (t) => {
107107
const context = { ...sampleContext };
108108

109109
runWithContext(context, () => {
110-
t.same(extractStringsFromUserInputCached(getContext()!, "body"), undefined);
111110
t.same(
112-
extractStringsFromUserInputCached(getContext()!, "query"),
113-
new Set(["abc", "def"])
111+
extractStringsFromUserInputCached(getContext()!),
112+
new Set(["abc", "def", "http://localhost:4000"])
114113
);
115114

116115
updateContext(getContext()!, "query", {});
117-
t.same(extractStringsFromUserInputCached(getContext()!, "body"), undefined);
118116
t.same(
119-
extractStringsFromUserInputCached(getContext()!, "query"),
120-
new Set()
117+
extractStringsFromUserInputCached(getContext()!),
118+
new Set(["http://localhost:4000"])
121119
);
122120

123121
runWithContext({ ...context, body: { a: "z" }, query: { b: "y" } }, () => {
124122
t.same(
125-
extractStringsFromUserInputCached(getContext()!, "body"),
126-
new Set(["a", "z"])
127-
);
128-
t.same(
129-
extractStringsFromUserInputCached(getContext()!, "query"),
130-
new Set(["b", "y"])
123+
extractStringsFromUserInputCached(getContext()!),
124+
new Set(["a", "z", "b", "y", "http://localhost:4000"])
131125
);
132126
});
133127
});

library/agent/Context.ts

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import type { ParsedQs } from "qs";
22
import { extractStringsFromUserInput } from "../helpers/extractStringsFromUserInput";
33
import { ContextStorage } from "./context/ContextStorage";
44
import { AsyncResource } from "async_hooks";
5-
import { Source, SOURCES } from "./Source";
65
import type { Endpoint } from "./Config";
76

87
export type User = { id: string; name?: string };
@@ -25,7 +24,7 @@ export type Context = {
2524
xml?: unknown[];
2625
subdomains?: string[]; // https://expressjs.com/en/5x/api.html#req.subdomains
2726
markUnsafe?: unknown[];
28-
cache?: Map<Source, ReturnType<typeof extractStringsFromUserInput>>;
27+
cache?: ReturnType<typeof extractStringsFromUserInput>;
2928
/**
3029
* Used to store redirects in outgoing http(s) requests that are started by a user-supplied input (hostname and port / url) to prevent SSRF redirect attacks.
3130
*/
@@ -44,10 +43,6 @@ export function getContext(): Readonly<Context> | undefined {
4443
return ContextStorage.getStore();
4544
}
4645

47-
function isSourceKey(key: string): key is Source {
48-
return SOURCES.includes(key as Source);
49-
}
50-
5146
// We need to use a function to mutate the context because we need to clear the cache when the user input changes
5247
export function updateContext<K extends keyof Context>(
5348
context: Context,
@@ -56,9 +51,8 @@ export function updateContext<K extends keyof Context>(
5651
) {
5752
context[key] = value;
5853

59-
if (context.cache && isSourceKey(key)) {
60-
context.cache.delete(key);
61-
}
54+
// Clear all the cached user input strings
55+
delete context.cache;
6256
}
6357

6458
/**
Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
import { Context } from "../agent/Context";
2-
import { Source } from "../agent/Source";
2+
import { SOURCES } from "../agent/Source";
33
import { extractStringsFromUserInput } from "./extractStringsFromUserInput";
44

5+
type ReturnValue = ReturnType<typeof extractStringsFromUserInput>;
6+
57
export function extractStringsFromUserInputCached(
6-
context: Context,
7-
source: Source
8-
): ReturnType<typeof extractStringsFromUserInput> | undefined {
9-
if (!context[source]) {
10-
return undefined;
8+
context: Context
9+
): ReturnValue {
10+
if (context.cache) {
11+
return context.cache;
1112
}
1213

13-
if (!context.cache) {
14-
context.cache = new Map();
15-
}
14+
const userStrings: ReturnValue = new Set();
1615

17-
let result = context.cache.get(source);
16+
for (const source of SOURCES) {
17+
if (!context[source]) {
18+
continue;
19+
}
1820

19-
if (!result) {
20-
result = extractStringsFromUserInput(context[source]);
21-
context.cache.set(source, result);
21+
for (const item of extractStringsFromUserInput(context[source])) {
22+
userStrings.add(item);
23+
}
2224
}
2325

24-
return result;
26+
context.cache = userStrings;
27+
28+
return userStrings;
2529
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import * as t from "tap";
2+
import { Context } from "../agent/Context";
3+
import { getSourceForUserString } from "./getSourceForUserString";
4+
5+
function createContext(): Context {
6+
return {
7+
remoteAddress: "::1",
8+
method: "POST",
9+
url: "http://local.aikido.io",
10+
query: {},
11+
headers: {},
12+
body: {
13+
image: "http://localhost:4000/api/internal",
14+
},
15+
cookies: {},
16+
routeParams: {},
17+
source: "express",
18+
route: "/posts/:id",
19+
};
20+
}
21+
22+
t.test(
23+
"it returns undefined if the user string cannot be found in the context",
24+
async () => {
25+
t.same(getSourceForUserString(createContext(), "unknown"), undefined);
26+
}
27+
);
28+
29+
t.test(
30+
"it returns source if the user string is found in the context",
31+
async () => {
32+
t.same(
33+
getSourceForUserString(
34+
createContext(),
35+
"http://localhost:4000/api/internal"
36+
),
37+
"body"
38+
);
39+
}
40+
);
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { Context } from "../agent/Context";
2+
import { Source, SOURCES } from "../agent/Source";
3+
import { extractStringsFromUserInput } from "./extractStringsFromUserInput";
4+
5+
export function getSourceForUserString(
6+
context: Context,
7+
str: string
8+
): Source | undefined {
9+
for (const source of SOURCES) {
10+
if (!context[source]) {
11+
continue;
12+
}
13+
14+
const userStrings = extractStringsFromUserInput(context[source]);
15+
16+
if (userStrings.has(str)) {
17+
return source;
18+
}
19+
}
20+
21+
return undefined;
22+
}

library/sources/xml/isXmlInContext.ts

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Context } from "../../agent/Context";
22
import { SOURCES } from "../../agent/Source";
3-
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
3+
import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput";
44

55
/**
66
* Checks if the XML string can be found in the context.
@@ -11,15 +11,14 @@ export function isXmlInContext(xml: string, context: Context): boolean {
1111
// Skip parsed XML
1212
continue;
1313
}
14-
const userInput = extractStringsFromUserInputCached(context, source);
15-
if (!userInput) {
14+
15+
if (!context[source]) {
1616
continue;
1717
}
1818

19-
for (const str of userInput) {
20-
if (str === xml) {
21-
return true;
22-
}
19+
const userInput = extractStringsFromUserInput(context[source]);
20+
if (userInput.has(xml)) {
21+
return true;
2322
}
2423
}
2524

library/vulnerabilities/attack-wave-detection/queryParamsContainDangerousPayload.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { Context } from "../../agent/Context";
2-
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
2+
import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput";
33

44
const keywords = [
55
"SELECT (CASE WHEN",
@@ -24,10 +24,15 @@ const keywords = [
2424
* Check the query for some common SQL or path traversal patterns.
2525
*/
2626
export function queryParamsContainDangerousPayload(context: Context): boolean {
27-
const queryStrings = extractStringsFromUserInputCached(context, "query");
27+
if (!context.query) {
28+
return false;
29+
}
30+
31+
const queryStrings = extractStringsFromUserInput(context.query);
2832
if (!queryStrings) {
2933
return false;
3034
}
35+
3136
for (const str of queryStrings) {
3237
// Performance optimization
3338
// Some keywords like ../ are shorter than this min length check

library/vulnerabilities/js-injection/checkContextForJsInjection.ts

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { Context } from "../../agent/Context";
22
import { InterceptorResult } from "../../agent/hooks/InterceptorResult";
3-
import { SOURCES } from "../../agent/Source";
43
import { getPathsToPayload } from "../../helpers/attackPath";
54
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
5+
import { getSourceForUserString } from "../../helpers/getSourceForUserString";
66
import { detectJsInjection } from "./detectJsInjection";
77

88
/**
@@ -18,14 +18,10 @@ export function checkContextForJsInjection({
1818
operation: string;
1919
context: Context;
2020
}): InterceptorResult {
21-
for (const source of SOURCES) {
22-
const userInput = extractStringsFromUserInputCached(context, source);
23-
if (!userInput) {
24-
continue;
25-
}
26-
27-
for (const str of userInput) {
28-
if (detectJsInjection(js, str)) {
21+
for (const str of extractStringsFromUserInputCached(context)) {
22+
if (detectJsInjection(js, str)) {
23+
const source = getSourceForUserString(context, str);
24+
if (source) {
2925
return {
3026
operation: operation,
3127
kind: "code_injection",

library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { Context } from "../../agent/Context";
22
import { InterceptorResult } from "../../agent/hooks/InterceptorResult";
3-
import { SOURCES } from "../../agent/Source";
43
import { getPathsToPayload } from "../../helpers/attackPath";
54
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
5+
import { getSourceForUserString } from "../../helpers/getSourceForUserString";
66
import { detectPathTraversal } from "./detectPathTraversal";
77

88
/**
@@ -26,14 +26,10 @@ export function checkContextForPathTraversal({
2626
return;
2727
}
2828

29-
for (const source of SOURCES) {
30-
const userInput = extractStringsFromUserInputCached(context, source);
31-
if (!userInput) {
32-
continue;
33-
}
34-
35-
for (const str of userInput) {
36-
if (detectPathTraversal(pathString, str, checkPathStart, isUrl)) {
29+
for (const str of extractStringsFromUserInputCached(context)) {
30+
if (detectPathTraversal(pathString, str, checkPathStart, isUrl)) {
31+
const source = getSourceForUserString(context, str);
32+
if (source) {
3733
return {
3834
operation: operation,
3935
kind: "path_traversal",

library/vulnerabilities/shell-injection/checkContextForShellInjection.ts

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { Context } from "../../agent/Context";
22
import { InterceptorResult } from "../../agent/hooks/InterceptorResult";
3-
import { SOURCES } from "../../agent/Source";
43
import { getPathsToPayload } from "../../helpers/attackPath";
54
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
5+
import { getSourceForUserString } from "../../helpers/getSourceForUserString";
66
import { detectShellInjection } from "./detectShellInjection";
77

88
/**
@@ -18,14 +18,10 @@ export function checkContextForShellInjection({
1818
operation: string;
1919
context: Context;
2020
}): InterceptorResult {
21-
for (const source of SOURCES) {
22-
const userInput = extractStringsFromUserInputCached(context, source);
23-
if (!userInput) {
24-
continue;
25-
}
26-
27-
for (const str of userInput) {
28-
if (detectShellInjection(command, str)) {
21+
for (const str of extractStringsFromUserInputCached(context)) {
22+
if (detectShellInjection(command, str)) {
23+
const source = getSourceForUserString(context, str);
24+
if (source) {
2925
return {
3026
operation: operation,
3127
kind: "shell_injection",

0 commit comments

Comments
 (0)