Skip to content

Commit 2aea329

Browse files
committed
safety checks
1 parent a2e5ed3 commit 2aea329

File tree

2 files changed

+55
-18
lines changed

2 files changed

+55
-18
lines changed

README.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ Create a DuckDB SQL Macro and save it somewhere. Here's an [example](https://gis
1212
Load your remote macro onto your system using a URL:
1313

1414
```sql
15-
D SELECT load_macro_from_url('https://gist.github.com/lmangani/518215a68e674ac662537d518799b893/raw/5f305480fdd7468f4ecda3686011bab8e8e711bf/bsky.sql') as res;
16-
┌─────────────────────────────┐
17-
│ res │
18-
varchar
19-
├─────────────────────────────┤
20-
│ Successfully loaded macro │
21-
└─────────────────────────────┘
15+
┌─────────────────────────────────────────┐
16+
│ res │
17+
varchar
18+
├─────────────────────────────────────────┤
19+
│ Successfully loaded macro: search_posts │
20+
└─────────────────────────────────────────┘
2221
```
2322

2423
Use your new macro and have fun:

src/webxtension_extension.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,39 +83,71 @@ static bool ContainsMacroDefinition(const std::string &content) {
8383
// Parse Function Name
8484
static std::string ExtractMacroName(const std::string &macro_sql) {
8585
try {
86-
// Convert to uppercase for case-insensitive matching
8786
std::string upper_sql = StringUtil::Upper(macro_sql);
88-
89-
// Find the MACRO keyword
9087
size_t macro_pos = upper_sql.find("MACRO");
9188
if (macro_pos == std::string::npos) {
9289
return "unknown";
9390
}
94-
95-
// Find the start of the name (after MACRO and any whitespace)
91+
9692
size_t name_start = macro_pos + 5; // length of "MACRO"
9793
while (name_start < upper_sql.length() && std::isspace(upper_sql[name_start])) {
9894
name_start++;
9995
}
100-
101-
// Find the end of the name (before the opening parenthesis)
96+
10297
size_t name_end = upper_sql.find('(', name_start);
10398
if (name_end == std::string::npos) {
10499
return "unknown";
105100
}
106-
107-
// Trim any trailing whitespace
101+
108102
while (name_end > name_start && std::isspace(upper_sql[name_end - 1])) {
109103
name_end--;
110104
}
111-
112-
// Get the original case version of the name from the input string
105+
113106
return macro_sql.substr(name_start, name_end - name_start);
114107
} catch (...) {
115108
return "unknown";
116109
}
117110
}
118111

112+
// Helper function to check for potentially dangerous SQL commands
113+
static std::pair<bool, std::string> ContainsDangerousCommands(const std::string &sql) {
114+
const std::vector<std::string> dangerous_commands = {
115+
"DELETE", "DROP", "TRUNCATE", "ALTER", "GRANT", "REVOKE",
116+
"CREATE USER", "ALTER USER", "DROP USER",
117+
"CREATE DATABASE", "DROP DATABASE",
118+
"EXEC", "EXECUTE",
119+
"SHUTDOWN", "RESTART",
120+
"SET GLOBAL", "SET SYSTEM",
121+
"LOAD EXTENSION", "UNLOAD EXTENSION",
122+
"ATTACH", "DETACH",
123+
"COPY", "EXPORT",
124+
"UPDATE", "MERGE"
125+
};
126+
127+
std::string upper_sql = StringUtil::Upper(sql);
128+
std::vector<std::string> found_commands;
129+
130+
for (const auto& cmd : dangerous_commands) {
131+
if (upper_sql.find(cmd) != std::string::npos) {
132+
found_commands.push_back(cmd);
133+
}
134+
}
135+
136+
if (!found_commands.empty()) {
137+
std::string warning = "Warning: SQL contains potentially dangerous commands: ";
138+
for (size_t i = 0; i < found_commands.size(); i++) {
139+
warning += found_commands[i];
140+
if (i < found_commands.size() - 1) {
141+
warning += ", ";
142+
}
143+
}
144+
warning += ". Please review the macro carefully before using it.";
145+
return std::make_pair(true, warning);
146+
}
147+
148+
return std::make_pair(false, "");
149+
}
150+
119151
// Function to fetch and create macro from URL
120152
static void LoadMacroFromUrlFunction(DataChunk &args, ExpressionState &state, Vector &result, DatabaseInstance *db_instance) {
121153
auto &context = state.GetContext();
@@ -142,6 +174,12 @@ static void LoadMacroFromUrlFunction(DataChunk &args, ExpressionState &state, Ve
142174
// Get the SQL content
143175
std::string macro_sql = res->body;
144176

177+
// Check for dangerous commands
178+
auto dangerous_check = ContainsDangerousCommands(macro_sql);
179+
if (dangerous_check.first) {
180+
throw std::runtime_error(dangerous_check.second);
181+
}
182+
145183
// Replace all \r\n with \n
146184
macro_sql = StringUtil::Replace(macro_sql, "\r\n", "\n");
147185
// Replace any remaining \r with \n

0 commit comments

Comments
 (0)