diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 72625be5fad8..6b8d0364726a 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -8,38 +8,35 @@ body: value: > DuckDB has several repositories for different components, please make sure you're raising your issue in the correct one: - * [Documentation/website](https://github.com/duckdb/duckdb-web/issues/new) - * APIs: - * [duckdb-java](https://github.com/duckdb/duckdb-java/issues/new) - * [duckdb-node](https://github.com/duckdb/duckdb-node/issues/new) - * [duckdb-node-neo](https://github.com/duckdb/duckdb-node-neo/issues/new) - * [duckdb-odbc](https://github.com/duckdb/duckdb-odbc/issues/new) - * [duckdb-python](https://github.com/duckdb/duckdb-python/issues/new) - * [duckdb-r](https://github.com/duckdb/duckdb-r/issues/new) - * [duckdb-rs](https://github.com/duckdb/duckdb-rs/issues/new) - * [duckdb-swift](https://github.com/duckdb/duckdb-swift/issues/new) - * [duckdb-wasm](https://github.com/duckdb/duckdb-wasm/issues/new) - * [go-duckdb](https://github.com/marcboeker/go-duckdb/issues/new) + * [Documentation](https://github.com/duckdb/duckdb-web/issues/new) + * Clients: + [Go](https://github.com/duckdb/duckdb-go/issues/new), + [Java (JDBC)](https://github.com/duckdb/duckdb-java/issues/new), + [Node.js](https://github.com/duckdb/duckdb-node-neo/issues/new), + [ODBC](https://github.com/duckdb/duckdb-odbc/issues/new), + [Python](https://github.com/duckdb/duckdb-python/issues/new), + [R](https://github.com/duckdb/duckdb-r/issues/new), + [Rust](https://github.com/duckdb/duckdb-rs/issues/new), + [WebAssembly (Wasm)](https://github.com/duckdb/duckdb-wasm/issues/new) * Extensions: - * [AWS extension](https://github.com/duckdb/duckdb-aws/issues/new) - * [Azure extension](https://github.com/duckdb/duckdb-azure/issues/new) - * [Delta extension](https://github.com/duckdb/duckdb-delta/issues/new) - * [Encodings extension](https://github.com/duckdb/duckdb-encodings/issues/new) - * [Excel extension](https://github.com/duckdb/duckdb-excel/issues/new) - * [fts (full text search) extension](https://github.com/duckdb/duckdb-fts/issues/new) - * [httpfs extension](https://github.com/duckdb/duckdb-httpfs/issues/new) - * [Iceberg extension](https://github.com/duckdb/duckdb-iceberg/issues/new) - * [inet extension](https://github.com/duckdb/duckdb-inet/issues/new) - * [MySQL extension](https://github.com/duckdb/duckdb-mysql/issues/new) - * [Postgres scanner](https://github.com/duckdb/duckdb-postgres/issues/new) - * [Spatial extension](https://github.com/duckdb/duckdb-spatial/issues/new) - * [SQLite scanner](https://github.com/duckdb/duckdb-sqlite/issues/new) - * [UI](https://github.com/duckdb/duckdb-ui/issues/new) - * [VSS extension](https://github.com/duckdb/duckdb-vss/issues/new) - * Connectors: - * [dbt-duckdb](https://github.com/duckdb/dbt-duckdb) + [`aws`](https://github.com/duckdb/duckdb-aws/issues/new), + [`azure`](https://github.com/duckdb/duckdb-azure/issues/new), + [`delta`](https://github.com/duckdb/duckdb-delta/issues/new), + [`ducklake`](https://github.com/duckdb/duckdb-ducklake/issues/new), + [`encodings`](https://github.com/duckdb/duckdb-encodings/issues/new), + [`excel`](https://github.com/duckdb/duckdb-excel/issues/new), + [`fts`](https://github.com/duckdb/duckdb-fts/issues/new), + [`httpfs`](https://github.com/duckdb/duckdb-httpfs/issues/new), + [`iceberg`](https://github.com/duckdb/duckdb-iceberg/issues/new), + [`inet`](https://github.com/duckdb/duckdb-inet/issues/new), + [`mysql`](https://github.com/duckdb/duckdb-mysql/issues/new), + [`postgres`](https://github.com/duckdb/duckdb-postgres/issues/new), + [`spatial`](https://github.com/duckdb/duckdb-spatial/issues/new), + [`sqlite`](https://github.com/duckdb/duckdb-sqlite/issues/new), + [`ui`](https://github.com/duckdb/duckdb-ui/issues/new), + [`vss`](https://github.com/duckdb/duckdb-vss/issues/new) - If none of the above repositories are applicable, feel free to raise it in this one + If the issue occurs in core DuckDB (e.g., a SQL query crashes or returns incorrect results) or if the issue is in the DuckDB command line client, feel free to raise it in this repository. Please report security vulnerabilities using GitHub's [report vulnerability form](https://github.com/duckdb/duckdb/security/advisories/new). diff --git a/.github/config/extensions/avro.cmake b/.github/config/extensions/avro.cmake index 980d9b3564fb..d4fdcf9c0e31 100644 --- a/.github/config/extensions/avro.cmake +++ b/.github/config/extensions/avro.cmake @@ -2,6 +2,7 @@ if (NOT MINGW) duckdb_extension_load(avro LOAD_TESTS DONT_LINK GIT_URL https://github.com/duckdb/duckdb-avro - GIT_TAG 0c97a61781f63f8c5444cf3e0c6881ecbaa9fe13 + GIT_TAG 7b75062f6345d11c5342c09216a75c57342c2e82 + APPLY_PATCHES ) endif() diff --git a/.github/config/extensions/aws.cmake b/.github/config/extensions/aws.cmake index 3260e350874e..ee0547d8089e 100644 --- a/.github/config/extensions/aws.cmake +++ b/.github/config/extensions/aws.cmake @@ -2,6 +2,6 @@ if (NOT MINGW AND NOT ${WASM_ENABLED}) duckdb_extension_load(aws ### TODO: re-enable LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb-aws - GIT_TAG 812ce80fde0bfa6e4641b6fd798087349a610795 + GIT_TAG 18803d5e55b9f9f6dda5047d0fdb4f4238b6801d ) endif() diff --git a/.github/config/extensions/azure.cmake b/.github/config/extensions/azure.cmake index 108fd6666a0a..fe6dcfd19914 100644 --- a/.github/config/extensions/azure.cmake +++ b/.github/config/extensions/azure.cmake @@ -2,6 +2,6 @@ if (NOT MINGW AND NOT ${WASM_ENABLED}) duckdb_extension_load(azure LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb-azure - GIT_TAG 5e458fcc466d2bc421922b11f4316564e3017800 + GIT_TAG 0709c0fa1cf67a668b58b1f06ff3e5fc1696e10a ) endif() diff --git a/.github/config/extensions/ducklake.cmake b/.github/config/extensions/ducklake.cmake index eb4ceb0a2c81..42ff34e3c3e6 100644 --- a/.github/config/extensions/ducklake.cmake +++ b/.github/config/extensions/ducklake.cmake @@ -1,5 +1,5 @@ duckdb_extension_load(ducklake DONT_LINK GIT_URL https://github.com/duckdb/ducklake - GIT_TAG 45788f0a875844ac8fed048c99b87f7f4b1c2ac1 + GIT_TAG f134ad86f2f6e7cdf4133086c38ecd9c48f1a772 ) diff --git a/.github/config/extensions/httpfs.cmake b/.github/config/extensions/httpfs.cmake index 63f201ec7b91..7cf901e05a2a 100644 --- a/.github/config/extensions/httpfs.cmake +++ b/.github/config/extensions/httpfs.cmake @@ -1,7 +1,7 @@ duckdb_extension_load(httpfs LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb-httpfs - GIT_TAG 0518838dae609ab8e8ae66960ce982b839754075 - INCLUDE_DIR src/include + GIT_TAG 8356a9017444f54018159718c8017ff7db4ea756 APPLY_PATCHES + INCLUDE_DIR src/include ) diff --git a/.github/config/extensions/iceberg.cmake b/.github/config/extensions/iceberg.cmake index 3463f5de82b5..ac47b9f604c6 100644 --- a/.github/config/extensions/iceberg.cmake +++ b/.github/config/extensions/iceberg.cmake @@ -4,7 +4,6 @@ IF (NOT WIN32) else () set(LOAD_ICEBERG_TESTS "") endif() - if (NOT MINGW AND NOT ${WASM_ENABLED}) duckdb_extension_load(iceberg # ${LOAD_ICEBERG_TESTS} TODO: re-enable once autoloading test is fixed diff --git a/.github/config/extensions/inet.cmake b/.github/config/extensions/inet.cmake index baa0dce2606c..7b112317c3fa 100644 --- a/.github/config/extensions/inet.cmake +++ b/.github/config/extensions/inet.cmake @@ -4,4 +4,5 @@ duckdb_extension_load(inet GIT_TAG f6a2a14f061d2dfccdb4283800b55fef3fcbb128 INCLUDE_DIR src/include TEST_DIR test/sql + APPLY_PATCHES ) diff --git a/.github/config/extensions/mysql_scanner.cmake b/.github/config/extensions/mysql_scanner.cmake index 987420db70df..581cac266d26 100644 --- a/.github/config/extensions/mysql_scanner.cmake +++ b/.github/config/extensions/mysql_scanner.cmake @@ -3,6 +3,6 @@ if (NOT MINGW AND NOT ${WASM_ENABLED} AND NOT ${MUSL_ENABLED}) DONT_LINK LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb-mysql - GIT_TAG 8a32d4e069438585e80494e296e407653aebfed3 + GIT_TAG c80647b33972c150f0bd0001c35085cefdc82d1e ) endif() diff --git a/.github/config/extensions/spatial.cmake b/.github/config/extensions/spatial.cmake index bc9b60e22643..5d27cf96f70d 100644 --- a/.github/config/extensions/spatial.cmake +++ b/.github/config/extensions/spatial.cmake @@ -6,5 +6,6 @@ duckdb_extension_load(spatial GIT_TAG a6a607fe3a98ef9ad4bed218490b770f725fbc12 INCLUDE_DIR src/spatial TEST_DIR test/sql + APPLY_PATCHES ) endif() diff --git a/.github/config/extensions/sqlite_scanner.cmake b/.github/config/extensions/sqlite_scanner.cmake index 2ae4f8b2260d..59d852240978 100644 --- a/.github/config/extensions/sqlite_scanner.cmake +++ b/.github/config/extensions/sqlite_scanner.cmake @@ -8,5 +8,5 @@ endif() duckdb_extension_load(sqlite_scanner ${STATIC_LINK_SQLITE} LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb-sqlite - GIT_TAG 833e105cbcaa0f6e8d34d334f3b920ce86f6fdf9 + GIT_TAG 0c93d610af1e1f66292559fcf0f01a93597a98b6 ) diff --git a/.github/patches/extensions/avro/fix.patch b/.github/patches/extensions/avro/fix.patch new file mode 100644 index 000000000000..5f1f295d11a0 --- /dev/null +++ b/.github/patches/extensions/avro/fix.patch @@ -0,0 +1,13 @@ +diff --git a/src/field_ids.cpp b/src/field_ids.cpp +index d197f8d..52fb48c 100644 +--- a/src/field_ids.cpp ++++ b/src/field_ids.cpp +@@ -5,6 +5,8 @@ namespace duckdb { + + namespace avro { + ++constexpr const char *FieldID::DUCKDB_FIELD_ID; ++ + FieldID::FieldID() : set(false) { + } + diff --git a/.github/patches/extensions/httpfs/fix.patch b/.github/patches/extensions/httpfs/fix.patch index dbc2066b561a..0327dd513e33 100644 --- a/.github/patches/extensions/httpfs/fix.patch +++ b/.github/patches/extensions/httpfs/fix.patch @@ -1,157 +1,38 @@ -diff --git a/src/httpfs.cpp b/src/httpfs.cpp -index 802581e..a11af95 100644 ---- a/src/httpfs.cpp -+++ b/src/httpfs.cpp -@@ -729,7 +729,7 @@ void HTTPFileHandle::LoadFileInfo() { - return; - } else { - // HEAD request fail, use Range request for another try (read only one byte) -- if (flags.OpenForReading() && res->status != HTTPStatusCode::NotFound_404) { -+ if (flags.OpenForReading() && res->status != HTTPStatusCode::NotFound_404 && res->status != HTTPStatusCode::MovedPermanently_301) { - auto range_res = hfs.GetRangeRequest(*this, path, {}, 0, nullptr, 2); - if (range_res->status != HTTPStatusCode::PartialContent_206 && - range_res->status != HTTPStatusCode::Accepted_202 && range_res->status != HTTPStatusCode::OK_200) { -diff --git a/src/httpfs_extension.cpp b/src/httpfs_extension.cpp -index 79d9923..9070621 100644 ---- a/src/httpfs_extension.cpp -+++ b/src/httpfs_extension.cpp -@@ -70,7 +70,7 @@ static void LoadInternal(ExtensionLoader &loader) { - config.AddExtensionOption("ca_cert_file", "Path to a custom certificate file for self-signed certificates.", - LogicalType::VARCHAR, Value("")); - // Global S3 config -- config.AddExtensionOption("s3_region", "S3 Region", LogicalType::VARCHAR, Value("us-east-1")); -+ config.AddExtensionOption("s3_region", "S3 Region", LogicalType::VARCHAR); - config.AddExtensionOption("s3_access_key_id", "S3 Access Key ID", LogicalType::VARCHAR); - config.AddExtensionOption("s3_secret_access_key", "S3 Access Key", LogicalType::VARCHAR); - config.AddExtensionOption("s3_session_token", "S3 Session Token", LogicalType::VARCHAR); -diff --git a/src/include/s3fs.hpp b/src/include/s3fs.hpp -index 525e0dd..a7e933e 100644 ---- a/src/include/s3fs.hpp -+++ b/src/include/s3fs.hpp -@@ -231,7 +231,7 @@ public: - return true; - } - -- static string GetS3BadRequestError(S3AuthParams &s3_auth_params); -+ static string GetS3BadRequestError(S3AuthParams &s3_auth_params, string correct_region = ""); - static string GetS3AuthError(S3AuthParams &s3_auth_params); - static string GetGCSAuthError(S3AuthParams &s3_auth_params); - static HTTPException GetS3Error(S3AuthParams &s3_auth_params, const HTTPResponse &response, const string &url); diff --git a/src/s3fs.cpp b/src/s3fs.cpp -index cbdecba..72eddc3 100644 +index 72eddc3..601ecba 100644 --- a/src/s3fs.cpp +++ b/src/s3fs.cpp -@@ -872,6 +872,7 @@ void S3FileHandle::Initialize(optional_ptr opener) { - ErrorData error(ex); - bool refreshed_secret = false; - if (error.Type() == ExceptionType::IO || error.Type() == ExceptionType::HTTP) { -+ // legacy endpoint (no region) returns 400 - auto context = opener->TryGetClientContext(); - if (context) { - auto transaction = CatalogTransaction::GetSystemCatalogTransaction(*context); -@@ -887,9 +888,13 @@ void S3FileHandle::Initialize(optional_ptr opener) { - auto &extra_info = error.ExtraInfo(); - auto entry = extra_info.find("status_code"); - if (entry != extra_info.end()) { -- if (entry->second == "400") { -- // 400: BAD REQUEST -- auto extra_text = S3FileSystem::GetS3BadRequestError(auth_params); -+ if (entry->second == "301" || entry->second == "400") { -+ auto new_region = extra_info.find("header_x-amz-bucket-region"); -+ string correct_region = ""; -+ if (new_region != extra_info.end()) { -+ correct_region = new_region->second; -+ } -+ auto extra_text = S3FileSystem::GetS3BadRequestError(auth_params, correct_region); - throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info); +@@ -895,7 +895,7 @@ void S3FileHandle::Initialize(optional_ptr opener) { + correct_region = new_region->second; + } + auto extra_text = S3FileSystem::GetS3BadRequestError(auth_params, correct_region); +- throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info); ++ throw Exception(extra_info, error.Type(), error.RawMessage() + extra_text); } if (entry->second == "403") { -@@ -1138,12 +1143,15 @@ bool S3FileSystem::ListFiles(const string &directory, const std::function opener) { + } else { + extra_text = S3FileSystem::GetS3AuthError(auth_params); + } +- throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info); ++ throw Exception(extra_info, error.Type(), error.RawMessage() + extra_text); + } + } + throw; +@@ -941,13 +941,13 @@ bool S3FileSystem::CanHandleFile(const string &fpath) { + void S3FileSystem::RemoveFile(const string &path, optional_ptr opener) { + auto handle = OpenFile(path, FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS, opener); + if (!handle) { +- throw IOException("Could not remove file \"%s\": %s", {{"errno", "404"}}, path, "No such file or directory"); ++ throw IOException({{"errno", "404"}}, "Could not remove file \"%s\": %s", path, "No such file or directory"); + } + + auto &s3fh = handle->Cast(); + auto res = DeleteRequest(*handle, s3fh.path, {}); + if (res->status != HTTPStatusCode::OK_200 && res->status != HTTPStatusCode::NoContent_204) { +- throw IOException("Could not remove file \"%s\": %s", {{"errno", to_string(static_cast(res->status))}}, ++ throw IOException({{"errno", to_string(static_cast(res->status))}}, "Could not remove file \"%s\": %s", + path, res->GetError()); } - return extra_text; } -diff --git a/test/sql/copy/csv/test_csv_remote.test b/test/sql/copy/csv/test_csv_remote.test -index 4144082..9e51585 100644 ---- a/test/sql/copy/csv/test_csv_remote.test -+++ b/test/sql/copy/csv/test_csv_remote.test -@@ -7,15 +7,26 @@ require httpfs - statement ok - PRAGMA enable_verification - -+# Test load from url with query string -+query IIIIIIIIIIII -+FROM sniff_csv('https://github.com/duckdb/duckdb/raw/main/data/csv/customer.csv?v=1') -+---- -+, " (empty) \n (empty) 0 0 [{'name': column00, 'type': BIGINT}, {'name': column01, 'type': VARCHAR}, {'name': column02, 'type': BIGINT}, {'name': column03, 'type': BIGINT}, {'name': column04, 'type': BIGINT}, {'name': column05, 'type': BIGINT}, {'name': column06, 'type': BIGINT}, {'name': column07, 'type': VARCHAR}, {'name': column08, 'type': VARCHAR}, {'name': column09, 'type': VARCHAR}, {'name': column10, 'type': VARCHAR}, {'name': column11, 'type': BIGINT}, {'name': column12, 'type': BIGINT}, {'name': column13, 'type': BIGINT}, {'name': column14, 'type': VARCHAR}, {'name': column15, 'type': VARCHAR}, {'name': column16, 'type': VARCHAR}, {'name': column17, 'type': BIGINT}] NULL NULL NULL FROM read_csv('https://github.com/duckdb/duckdb/raw/main/data/csv/customer.csv?v=1', auto_detect=false, delim=',', quote='"', escape='', new_line='\n', skip=0, comment='', header=false, columns={'column00': 'BIGINT', 'column01': 'VARCHAR', 'column02': 'BIGINT', 'column03': 'BIGINT', 'column04': 'BIGINT', 'column05': 'BIGINT', 'column06': 'BIGINT', 'column07': 'VARCHAR', 'column08': 'VARCHAR', 'column09': 'VARCHAR', 'column10': 'VARCHAR', 'column11': 'BIGINT', 'column12': 'BIGINT', 'column13': 'BIGINT', 'column14': 'VARCHAR', 'column15': 'VARCHAR', 'column16': 'VARCHAR', 'column17': 'BIGINT'}); -+ -+ -+# This test abuses the LOCAL_EXTENSION_REPO env to make sure tests are only run when running extension tests -+# in duckdb/duckdb. Otherwise you need to pass a data dir when exex -+ -+require-env LOCAL_EXTENSION_REPO - - # regular csv file - query ITTTIITITTIIII nosort webpagecsv --SELECT * FROM read_csv_auto('duckdb/data/csv/real/web_page.csv') ORDER BY 1; -+SELECT * FROM read_csv_auto('data/csv/real/web_page.csv') ORDER BY 1; - ---- - - # file with gzip - query IIIIIIIIIIIIIII nosort lineitemcsv --SELECT * FROM read_csv_auto('duckdb/data/csv/lineitem1k.tbl.gz') ORDER BY ALL; -+SELECT * FROM read_csv_auto('data/csv/lineitem1k.tbl.gz') ORDER BY ALL; - ---- - - query ITTTIITITTIIII nosort webpagecsv -@@ -25,10 +36,3 @@ SELECT * FROM read_csv_auto('https://raw.githubusercontent.com/duckdb/duckdb/mai - query IIIIIIIIIIIIIII nosort lineitemcsv - select * from read_csv_auto('https://raw.githubusercontent.com/duckdb/duckdb/main/data/csv/lineitem1k.tbl.gz') ORDER BY ALL; - ---- -- -- --# Test load from url with query string --query IIIIIIIIIIII --FROM sniff_csv('https://github.com/duckdb/duckdb/raw/main/data/csv/customer.csv?v=1') ------ --, " (empty) \n (empty) 0 0 [{'name': column00, 'type': BIGINT}, {'name': column01, 'type': VARCHAR}, {'name': column02, 'type': BIGINT}, {'name': column03, 'type': BIGINT}, {'name': column04, 'type': BIGINT}, {'name': column05, 'type': BIGINT}, {'name': column06, 'type': BIGINT}, {'name': column07, 'type': VARCHAR}, {'name': column08, 'type': VARCHAR}, {'name': column09, 'type': VARCHAR}, {'name': column10, 'type': VARCHAR}, {'name': column11, 'type': BIGINT}, {'name': column12, 'type': BIGINT}, {'name': column13, 'type': BIGINT}, {'name': column14, 'type': VARCHAR}, {'name': column15, 'type': VARCHAR}, {'name': column16, 'type': VARCHAR}, {'name': column17, 'type': BIGINT}] NULL NULL NULL FROM read_csv('https://github.com/duckdb/duckdb/raw/main/data/csv/customer.csv?v=1', auto_detect=false, delim=',', quote='"', escape='', new_line='\n', skip=0, comment='', header=false, columns={'column00': 'BIGINT', 'column01': 'VARCHAR', 'column02': 'BIGINT', 'column03': 'BIGINT', 'column04': 'BIGINT', 'column05': 'BIGINT', 'column06': 'BIGINT', 'column07': 'VARCHAR', 'column08': 'VARCHAR', 'column09': 'VARCHAR', 'column10': 'VARCHAR', 'column11': 'BIGINT', 'column12': 'BIGINT', 'column13': 'BIGINT', 'column14': 'VARCHAR', 'column15': 'VARCHAR', 'column16': 'VARCHAR', 'column17': 'BIGINT'}); -diff --git a/test/sql/copy/s3/url_encode.test b/test/sql/copy/s3/url_encode.test -index 66cbd5c..f5a5912 100644 ---- a/test/sql/copy/s3/url_encode.test -+++ b/test/sql/copy/s3/url_encode.test -@@ -132,12 +132,20 @@ set s3_endpoint=''; - statement error - SELECT * FROM 's3://test-bucket/whatever.parquet'; - ---- --:.*Unknown error for HTTP HEAD to 'http://test-bucket.s3.eu-west-1.amazonaws.com/whatever.parquet'.* -+:.*HTTP Error: Unable to connect to URL .*http://test-bucket.s3.eu-west-1.amazonaws.com/whatever.parquet.*: 301 .Moved Permanently..* -+.* -+.*Bad Request - this can be caused by the S3 region being set incorrectly.* -+.*Provided region is: .eu-west-1.* -+.*Correct region is: .us-east-1.* - - statement error - SELECT * FROM 'r2://test-bucket/whatever.parquet'; - ---- --:.*Unknown error for HTTP HEAD to 'http://test-bucket.s3.eu-west-1.amazonaws.com/whatever.parquet'.* -+:.*HTTP Error: Unable to connect to URL .*http://test-bucket.s3.eu-west-1.amazonaws.com/whatever.parquet.*: 301 .Moved Permanently..* -+.* -+.*Bad Request - this can be caused by the S3 region being set incorrectly.* -+.*Provided region is: .eu-west-1.* -+.*Correct region is: .us-east-1.* - - statement error - SELECT * FROM 'gcs://test-bucket/whatever.parquet'; diff --git a/.github/patches/extensions/inet/hugeint_fixes.patch b/.github/patches/extensions/inet/hugeint_fixes.patch new file mode 100644 index 000000000000..4b4375d116d7 --- /dev/null +++ b/.github/patches/extensions/inet/hugeint_fixes.patch @@ -0,0 +1,19 @@ +diff --git a/src/inet_functions.cpp b/src/inet_functions.cpp +index da92a4c..afa7446 100644 +--- a/src/inet_functions.cpp ++++ b/src/inet_functions.cpp +@@ -185,11 +185,12 @@ static INET_TYPE AddImplementation(INET_TYPE ip, hugeint_t val) { + if (val > 0) { + address_out = + AddOperatorOverflowCheck::Operation( +- address_in, val); ++ address_in, (uhugeint_t)val); + } else { ++ // TODO: this is off for when val is the minimal uhugeint_t value + address_out = + SubtractOperatorOverflowCheck::Operation(address_in, -val); ++ uhugeint_t>(address_in, (uhugeint_t)(-val)); + } + + if (addr_type == IPAddressType::IP_ADDRESS_V4 && diff --git a/.github/patches/extensions/spatial/fix.patch b/.github/patches/extensions/spatial/fix.patch new file mode 100644 index 000000000000..26a4b8ddf223 --- /dev/null +++ b/.github/patches/extensions/spatial/fix.patch @@ -0,0 +1,16 @@ +diff --git a/src/spatial/modules/main/spatial_functions_scalar.cpp b/src/spatial/modules/main/spatial_functions_scalar.cpp +index 60ca7373ce..a44cfc7a82 100644 +--- a/src/spatial/modules/main/spatial_functions_scalar.cpp ++++ b/src/spatial/modules/main/spatial_functions_scalar.cpp +@@ -9243,6 +9243,11 @@ struct ST_MMin : VertexAggFunctionBase { + static constexpr auto ORDINATE = VertexOrdinate::M; + }; + ++constexpr const char * ST_M::NAME; ++constexpr const char * ST_X::NAME; ++constexpr const char * ST_Y::NAME; ++constexpr const char * ST_Z::NAME; ++ + } // namespace + + // Helper to access the constant distance from the bind data diff --git a/.github/workflows/BundleStaticLibs.yml b/.github/workflows/BundleStaticLibs.yml index d79b820087b8..7b25b5eb04a0 100644 --- a/.github/workflows/BundleStaticLibs.yml +++ b/.github/workflows/BundleStaticLibs.yml @@ -44,16 +44,18 @@ jobs: strategy: matrix: include: - - version: "macos-13" + - xcode_target_flag: "x86_64" architecture: "amd64" - - version: "macos-14" + - xcode_target_flag: "arm64" architecture: "arm64" - runs-on: ${{ matrix.version }} + runs-on: macos-latest env: EXTENSION_CONFIGS: '${GITHUB_WORKSPACE}/.github/config/bundled_extensions.cmake' ENABLE_EXTENSION_AUTOLOADING: 1 ENABLE_EXTENSION_AUTOINSTALL: 1 GEN: ninja + OSX_BUILD_ARCH: ${{ matrix.xcode_target_flag }} + DUCKDB_PLATFORM: osx_${{ matrix.architecture }} steps: - uses: actions/checkout@v4 @@ -83,10 +85,6 @@ jobs: run: | make gather-libs - - name: Print platform - shell: bash - run: ./build/release/duckdb -c "PRAGMA platform;" - - name: Deploy shell: bash env: @@ -103,6 +101,7 @@ jobs: path: | static-libs-osx-${{ matrix.architecture }}.zip + bundle-mingw-static-lib: name: Windows MingW static libs runs-on: windows-latest diff --git a/.github/workflows/CodeQuality.yml b/.github/workflows/CodeQuality.yml index 5b5f3d00e3fe..67f410030ba6 100644 --- a/.github/workflows/CodeQuality.yml +++ b/.github/workflows/CodeQuality.yml @@ -18,7 +18,8 @@ on: - '.github/workflows/**' - '!.github/workflows/lcov_exclude' - '!.github/workflows/CodeQuality.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' merge_group: pull_request: types: [opened, reopened, ready_for_review, converted_to_draft] @@ -29,7 +30,8 @@ on: - '.github/workflows/**' - '!.github/workflows/lcov_exclude' - '!.github/workflows/CodeQuality.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} diff --git a/.github/workflows/Julia.yml b/.github/workflows/Julia.yml index 0fb9b4e69be7..abb3a78bfab6 100644 --- a/.github/workflows/Julia.yml +++ b/.github/workflows/Julia.yml @@ -16,7 +16,8 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Julia.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' merge_group: pull_request: types: [opened, reopened, ready_for_review, converted_to_draft] @@ -29,7 +30,8 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Julia.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} diff --git a/.github/workflows/Main.yml b/.github/workflows/Main.yml index df6d04de7896..8ef63205977d 100644 --- a/.github/workflows/Main.yml +++ b/.github/workflows/Main.yml @@ -14,7 +14,10 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Main.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' + - '!.github/patches/extensions/fts/*.patch' # fts used in some jobs + - '!.github/config/extensions/fts.cmake' merge_group: pull_request: types: [opened, reopened, ready_for_review, converted_to_draft] @@ -25,7 +28,10 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Main.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' + - '!.github/patches/extensions/fts/*.patch' # fts used in some jobs + - '!.github/config/extensions/fts.cmake' concurrency: @@ -464,3 +470,9 @@ jobs: shell: bash run: | ./build/release/test/unittest --test-config test/configs/prefetch_all_storage.json + + - name: Test peg_parser + if: (success() || failure()) && steps.build.conclusion == 'success' + shell: bash + run: | + ./build/release/test/unittest --test-config test/configs/peg_parser.json diff --git a/.github/workflows/Regression.yml b/.github/workflows/Regression.yml index 25098e1f45a0..1145a155090a 100644 --- a/.github/workflows/Regression.yml +++ b/.github/workflows/Regression.yml @@ -23,6 +23,10 @@ on: - '.github/workflows/**' - '!.github/workflows/Regression.yml' - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' + - '!.github/patches/extensions/httpfs/*.patch' # httpfs used in some jobs + - '!.github/config/extensions/httpfs.cmake' merge_group: pull_request: types: [opened, reopened, ready_for_review, converted_to_draft] @@ -33,7 +37,10 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Regression.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' + - '!.github/patches/extensions/httpfs/*.patch' # httpfs used in some jobs + - '!.github/config/extensions/httpfs.cmake' concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} diff --git a/.github/workflows/Swift.yml b/.github/workflows/Swift.yml index a58a3e953e96..a8e9dba1bb7c 100644 --- a/.github/workflows/Swift.yml +++ b/.github/workflows/Swift.yml @@ -16,7 +16,8 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Swift.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' merge_group: pull_request: types: [opened, reopened, ready_for_review, converted_to_draft] @@ -29,7 +30,8 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Swift.yml' - - '.github/config/out_of_tree_extensions.cmake' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} diff --git a/.github/workflows/Windows.yml b/.github/workflows/Windows.yml index be377475f8e5..fb0c10f4d125 100644 --- a/.github/workflows/Windows.yml +++ b/.github/workflows/Windows.yml @@ -33,6 +33,8 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Windows.yml' + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' merge_group: pull_request: @@ -45,7 +47,8 @@ on: - '.github/patches/duckdb-wasm/**' - '.github/workflows/**' - '!.github/workflows/Windows.yml' - + - '.github/config/extensions/*.cmake' + - '.github/patches/extensions/**/*.patch' concurrency: group: windows-${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }}-${{ inputs.override_git_describe }} diff --git a/CMakeLists.txt b/CMakeLists.txt index b85e8c1eceb9..94e16b44e440 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,14 +37,24 @@ set(CMAKE_VERBOSE_MAKEFILE OFF) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_MACOSX_RPATH 1) -find_program(CCACHE_PROGRAM ccache) -if(CCACHE_PROGRAM) - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}") -else() - find_program(CCACHE_PROGRAM sccache) - if(CCACHE_PROGRAM) - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}") - endif() +if(NOT DEFINED CMAKE_C_COMPILER_LAUNCHER) + find_program(COMPILER_LAUNCHER NAMES ccache sccache) + if(COMPILER_LAUNCHER) + message(STATUS "Using ${COMPILER_LAUNCHER} as C compiler launcher") + set(CMAKE_C_COMPILER_LAUNCHER + "${COMPILER_LAUNCHER}" + CACHE STRING "" FORCE) + endif() +endif() + +if(NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER) + find_program(COMPILER_LAUNCHER NAMES ccache sccache) + if(COMPILER_LAUNCHER) + message(STATUS "Using ${COMPILER_LAUNCHER} as C++ compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER + "${COMPILER_LAUNCHER}" + CACHE STRING "" FORCE) + endif() endif() # Determine install paths diff --git a/data/csv/afl/3981/case_0.csv b/data/csv/afl/3981/case_0.csv deleted file mode 100644 index 59390ec49901..000000000000 Binary files a/data/csv/afl/3981/case_0.csv and /dev/null differ diff --git a/data/csv/afl/3981/case_1.csv b/data/csv/afl/3981/case_1.csv deleted file mode 100644 index a8919290cb72..000000000000 Binary files a/data/csv/afl/3981/case_1.csv and /dev/null differ diff --git a/data/csv/afl/3981/case_2.csv b/data/csv/afl/3981/case_2.csv deleted file mode 100644 index 2154533db63a..000000000000 Binary files a/data/csv/afl/3981/case_2.csv and /dev/null differ diff --git a/data/csv/afl/3981/case_3.csv b/data/csv/afl/3981/case_3.csv deleted file mode 100644 index 7fb006c47f72..000000000000 Binary files a/data/csv/afl/3981/case_3.csv and /dev/null differ diff --git a/data/csv/afl/3981/case_4.csv b/data/csv/afl/3981/case_4.csv deleted file mode 100644 index d73484ecd2ed..000000000000 Binary files a/data/csv/afl/3981/case_4.csv and /dev/null differ diff --git a/data/csv/afl/3981/case_5.csv b/data/csv/afl/3981/case_5.csv deleted file mode 100644 index 7e5b80b63bdd..000000000000 Binary files a/data/csv/afl/3981/case_5.csv and /dev/null differ diff --git a/data/csv/afl/3981/case_6.csv b/data/csv/afl/3981/case_6.csv deleted file mode 100644 index 3a200bab6ab5..000000000000 Binary files a/data/csv/afl/3981/case_6.csv and /dev/null differ diff --git a/data/parquet-testing/broken/internal_6129.parquet b/data/parquet-testing/broken/internal_6129.parquet new file mode 100644 index 000000000000..bf426a659c49 Binary files /dev/null and b/data/parquet-testing/broken/internal_6129.parquet differ diff --git a/data/parquet-testing/broken/internal_6165.parquet b/data/parquet-testing/broken/internal_6165.parquet new file mode 100644 index 000000000000..00f245b9e9ab Binary files /dev/null and b/data/parquet-testing/broken/internal_6165.parquet differ diff --git a/data/storage/cte_v1.db.gz b/data/storage/cte_v1.db.gz new file mode 100644 index 000000000000..c04621a345c0 Binary files /dev/null and b/data/storage/cte_v1.db.gz differ diff --git a/data/storage/cte_v1_4.db.gz b/data/storage/cte_v1_4.db.gz new file mode 100644 index 000000000000..caf302585d57 Binary files /dev/null and b/data/storage/cte_v1_4.db.gz differ diff --git a/extension/autocomplete/CMakeLists.txt b/extension/autocomplete/CMakeLists.txt index 544e65aae9fa..741975363037 100644 --- a/extension/autocomplete/CMakeLists.txt +++ b/extension/autocomplete/CMakeLists.txt @@ -8,6 +8,9 @@ set(AUTOCOMPLETE_EXTENSION_FILES autocomplete_extension.cpp matcher.cpp tokenizer.cpp keyword_helper.cpp keyword_map.cpp) +add_subdirectory(transformer) +add_subdirectory(parser) + build_static_extension(autocomplete ${AUTOCOMPLETE_EXTENSION_FILES}) set(PARAMETERS "-warnings") build_loadable_extension(autocomplete ${PARAMETERS} diff --git a/extension/autocomplete/autocomplete_extension.cpp b/extension/autocomplete/autocomplete_extension.cpp index 580f43dd1190..a565661b57bd 100644 --- a/extension/autocomplete/autocomplete_extension.cpp +++ b/extension/autocomplete/autocomplete_extension.cpp @@ -10,6 +10,7 @@ #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/main/extension/extension_loader.hpp" +#include "transformer/peg_transformer.hpp" #include "duckdb/parser/keyword_helper.hpp" #include "matcher.hpp" #include "duckdb/catalog/default/builtin_types/types.hpp" @@ -491,7 +492,8 @@ static duckdb::unique_ptr GenerateSuggestions(Clien // tokenize the input vector tokens; vector suggestions; - MatchState state(tokens, suggestions); + ParseResultAllocator parse_allocator; + MatchState state(tokens, suggestions, parse_allocator); vector unicode_spaces; string clean_sql; const string &sql_ref = StripUnicodeSpaces(sql, clean_sql) ? clean_sql : sql; @@ -618,11 +620,11 @@ class ParserTokenizer : public BaseTokenizer { statements.push_back(std::move(tokens)); tokens.clear(); } - void OnLastToken(TokenizeState state, string last_word, idx_t) override { + void OnLastToken(TokenizeState state, string last_word, idx_t last_pos) override { if (last_word.empty()) { return; } - tokens.push_back(std::move(last_word)); + tokens.emplace_back(std::move(last_word), last_pos); } vector> statements; @@ -654,7 +656,8 @@ static duckdb::unique_ptr CheckPEGParserBind(ClientContext &contex continue; } vector suggestions; - MatchState state(tokens, suggestions); + ParseResultAllocator parse_allocator; + MatchState state(tokens, suggestions, parse_allocator); MatcherAllocator allocator; auto &matcher = Matcher::RootMatcher(allocator); @@ -681,6 +684,43 @@ static duckdb::unique_ptr CheckPEGParserBind(ClientContext &contex void CheckPEGParserFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { } +class PEGParserExtension : public ParserExtension { +public: + PEGParserExtension() { + parser_override = PEGParser; + } + + static ParserOverrideResult PEGParser(ParserExtensionInfo *info, const string &query) { + vector root_tokens; + string clean_sql; + + ParserTokenizer tokenizer(query, root_tokens); + tokenizer.TokenizeInput(); + tokenizer.statements.push_back(std::move(root_tokens)); + + vector> result; + try { + for (auto &tokenized_statement : tokenizer.statements) { + if (tokenized_statement.empty()) { + continue; + } + auto &transformer = PEGTransformerFactory::GetInstance(); + auto statement = transformer.Transform(tokenized_statement, "Statement"); + if (statement) { + statement->stmt_location = NumericCast(tokenized_statement[0].offset); + statement->stmt_length = + NumericCast(tokenized_statement[tokenized_statement.size() - 1].offset + + tokenized_statement[tokenized_statement.size() - 1].length); + } + result.push_back(std::move(statement)); + } + return ParserOverrideResult(std::move(result)); + } catch (std::exception &e) { + return ParserOverrideResult(e); + } + } +}; + static void LoadInternal(ExtensionLoader &loader) { TableFunction auto_complete_fun("sql_auto_complete", {LogicalType::VARCHAR}, SQLAutoCompleteFunction, SQLAutoCompleteBind, SQLAutoCompleteInit); @@ -689,6 +729,9 @@ static void LoadInternal(ExtensionLoader &loader) { TableFunction check_peg_parser_fun("check_peg_parser", {LogicalType::VARCHAR}, CheckPEGParserFunction, CheckPEGParserBind, nullptr); loader.RegisterFunction(check_peg_parser_fun); + + auto &config = DBConfig::GetConfig(loader.GetDatabaseInstance()); + config.parser_extensions.push_back(PEGParserExtension()); } void AutocompleteExtension::Load(ExtensionLoader &loader) { diff --git a/extension/autocomplete/include/ast/setting_info.hpp b/extension/autocomplete/include/ast/setting_info.hpp new file mode 100644 index 000000000000..16e41d813983 --- /dev/null +++ b/extension/autocomplete/include/ast/setting_info.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "duckdb/common/enums/set_scope.hpp" +#include "duckdb/common/string.hpp" + +namespace duckdb { + +struct SettingInfo { + string name; + SetScope scope = SetScope::AUTOMATIC; // Default value is defined here +}; + +} // namespace duckdb diff --git a/extension/autocomplete/include/keyword_helper.hpp b/extension/autocomplete/include/keyword_helper.hpp index 1de4c220045e..fad021f06bf0 100644 --- a/extension/autocomplete/include/keyword_helper.hpp +++ b/extension/autocomplete/include/keyword_helper.hpp @@ -4,7 +4,7 @@ #include "duckdb/common/string.hpp" namespace duckdb { -enum class KeywordCategory : uint8_t { +enum class PEGKeywordCategory : uint8_t { KEYWORD_NONE, KEYWORD_UNRESERVED, KEYWORD_RESERVED, @@ -12,14 +12,14 @@ enum class KeywordCategory : uint8_t { KEYWORD_COL_NAME }; -class KeywordHelper { +class PEGKeywordHelper { public: - static KeywordHelper &Instance(); - bool KeywordCategoryType(const string &text, KeywordCategory type) const; + static PEGKeywordHelper &Instance(); + bool KeywordCategoryType(const string &text, PEGKeywordCategory type) const; void InitializeKeywordMaps(); private: - KeywordHelper(); + PEGKeywordHelper(); bool initialized; case_insensitive_set_t reserved_keyword_map; case_insensitive_set_t unreserved_keyword_map; diff --git a/extension/autocomplete/include/matcher.hpp b/extension/autocomplete/include/matcher.hpp index 35f3037eeb44..7eb58119bd2f 100644 --- a/extension/autocomplete/include/matcher.hpp +++ b/extension/autocomplete/include/matcher.hpp @@ -11,8 +11,10 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/reference_map.hpp" +#include "transformer/parse_result.hpp" namespace duckdb { +class ParseResultAllocator; class Matcher; class MatcherAllocator; @@ -73,11 +75,14 @@ enum class TokenType { WORD }; struct MatcherToken { // NOLINTNEXTLINE: allow implicit conversion from text - MatcherToken(string text_p) : text(std::move(text_p)) { + MatcherToken(string text_p, idx_t offset_p) : text(std::move(text_p)), offset(offset_p) { + length = text.length(); } TokenType type = TokenType::WORD; string text; + idx_t offset = 0; + idx_t length = 0; }; struct MatcherSuggestion { @@ -96,17 +101,19 @@ struct MatcherSuggestion { }; struct MatchState { - MatchState(vector &tokens, vector &suggestions) - : tokens(tokens), suggestions(suggestions), token_index(0) { + MatchState(vector &tokens, vector &suggestions, ParseResultAllocator &allocator) + : tokens(tokens), suggestions(suggestions), token_index(0), allocator(allocator) { } MatchState(MatchState &state) - : tokens(state.tokens), suggestions(state.suggestions), token_index(state.token_index) { + : tokens(state.tokens), suggestions(state.suggestions), token_index(state.token_index), + allocator(state.allocator) { } vector &tokens; vector &suggestions; reference_set_t added_suggestions; idx_t token_index; + ParseResultAllocator &allocator; void AddSuggestion(MatcherSuggestion suggestion); }; @@ -121,6 +128,7 @@ class Matcher { //! Match virtual MatchResultType Match(MatchState &state) const = 0; + virtual optional_ptr MatchParseResult(MatchState &state) const = 0; virtual SuggestionType AddSuggestion(MatchState &state) const; virtual SuggestionType AddSuggestionInternal(MatchState &state) const = 0; virtual string ToString() const = 0; @@ -166,4 +174,12 @@ class MatcherAllocator { vector> matchers; }; +class ParseResultAllocator { +public: + optional_ptr Allocate(unique_ptr parse_result); + +private: + vector> parse_results; +}; + } // namespace duckdb diff --git a/extension/autocomplete/include/parser/peg_parser.hpp b/extension/autocomplete/include/parser/peg_parser.hpp new file mode 100644 index 000000000000..b6723cdc86a3 --- /dev/null +++ b/extension/autocomplete/include/parser/peg_parser.hpp @@ -0,0 +1,66 @@ + +#pragma once +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/string_map_set.hpp" + +namespace duckdb { +enum class PEGRuleType { + LITERAL, // literal rule ('Keyword'i) + REFERENCE, // reference to another rule (Rule) + OPTIONAL, // optional rule (Rule?) + OR, // or rule (Rule1 / Rule2) + REPEAT // repeat rule (Rule1* +}; + +enum class PEGTokenType { + LITERAL, // literal token ('Keyword'i) + REFERENCE, // reference token (Rule) + OPERATOR, // operator token (/ or ) + FUNCTION_CALL, // start of function call (i.e. Function(...)) + REGEX // regular expression ([ \t\n\r] or <[a-z_]i[a-z0-9_]i>) +}; + +struct PEGToken { + PEGTokenType type; + string_t text; +}; + +struct PEGRule { + string_map_t parameters; + vector tokens; + + void Clear() { + parameters.clear(); + tokens.clear(); + } +}; + +struct PEGParser { +public: + void ParseRules(const char *grammar); + void AddRule(string_t rule_name, PEGRule rule); + + case_insensitive_map_t rules; +}; + +enum class PEGParseState { + RULE_NAME, // Rule name + RULE_SEPARATOR, // look for <- + RULE_DEFINITION // part of rule definition +}; + +inline bool IsPEGOperator(char c) { + switch (c) { + case '/': + case '?': + case '(': + case ')': + case '*': + case '!': + return true; + default: + return false; + } +} + +} // namespace duckdb diff --git a/extension/autocomplete/include/transformer/parse_result.hpp b/extension/autocomplete/include/transformer/parse_result.hpp new file mode 100644 index 000000000000..7f69285ea9f1 --- /dev/null +++ b/extension/autocomplete/include/transformer/parse_result.hpp @@ -0,0 +1,320 @@ +#pragma once +#include "duckdb/common/arena_linked_list.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +class PEGTransformer; // Forward declaration + +enum class ParseResultType : uint8_t { + LIST, + OPTIONAL, + REPEAT, + CHOICE, + EXPRESSION, + IDENTIFIER, + KEYWORD, + OPERATOR, + STATEMENT, + EXTENSION, + NUMBER, + STRING, + INVALID +}; + +inline const char *ParseResultToString(ParseResultType type) { + switch (type) { + case ParseResultType::LIST: + return "LIST"; + case ParseResultType::OPTIONAL: + return "OPTIONAL"; + case ParseResultType::REPEAT: + return "REPEAT"; + case ParseResultType::CHOICE: + return "CHOICE"; + case ParseResultType::EXPRESSION: + return "EXPRESSION"; + case ParseResultType::IDENTIFIER: + return "IDENTIFIER"; + case ParseResultType::KEYWORD: + return "KEYWORD"; + case ParseResultType::OPERATOR: + return "OPERATOR"; + case ParseResultType::STATEMENT: + return "STATEMENT"; + case ParseResultType::EXTENSION: + return "EXTENSION"; + case ParseResultType::NUMBER: + return "NUMBER"; + case ParseResultType::STRING: + return "STRING"; + case ParseResultType::INVALID: + return "INVALID"; + } + return "INVALID"; +} + +class ParseResult { +public: + explicit ParseResult(ParseResultType type) : type(type) { + } + virtual ~ParseResult() = default; + + template + TARGET &Cast() { + if (TARGET::TYPE != ParseResultType::INVALID && type != TARGET::TYPE) { + throw InternalException("Failed to cast parse result of type %s to type %s for rule %s", + ParseResultToString(TARGET::TYPE), ParseResultToString(type), name); + } + return reinterpret_cast(*this); + } + + ParseResultType type; + string name; + + virtual void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const { + ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type); + if (!name.empty()) { + ss << " (" << name << ")"; + } + } + + // The public entry point + std::string ToString() const { + std::stringstream ss; + std::unordered_set visited; + // The root is always the "last" element at its level + ToStringInternal(ss, visited, "", true); + return ss.str(); + } +}; + +struct IdentifierParseResult : ParseResult { + static constexpr ParseResultType TYPE = ParseResultType::IDENTIFIER; + string identifier; + + explicit IdentifierParseResult(string identifier_p) : ParseResult(TYPE), identifier(std::move(identifier_p)) { + } + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ParseResult::ToStringInternal(ss, visited, indent, is_last); + ss << ": \"" << identifier << "\"\n"; + } +}; + +struct KeywordParseResult : ParseResult { + static constexpr ParseResultType TYPE = ParseResultType::KEYWORD; + string keyword; + + explicit KeywordParseResult(string keyword_p) : ParseResult(TYPE), keyword(std::move(keyword_p)) { + } + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ParseResult::ToStringInternal(ss, visited, indent, is_last); + ss << ": \"" << keyword << "\"\n"; + } +}; + +struct ListParseResult : ParseResult { + static constexpr ParseResultType TYPE = ParseResultType::LIST; + vector> children; + +public: + explicit ListParseResult(vector> results_p, string name_p) + : ParseResult(TYPE), children(std::move(results_p)) { + name = name_p; + } + + optional_ptr GetChild(idx_t index) { + if (index >= children.size()) { + throw InternalException("Child index out of bounds"); + } + return children[index]; + } + + template + T &Child(idx_t index) { + auto child_ptr = GetChild(index); + return child_ptr->Cast(); + } + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ss << indent << (is_last ? "└─" : "├─"); + + if (visited.count(this)) { + ss << " List (" << name << ") [... already printed ...]\n"; + return; + } + visited.insert(this); + + ss << " " << ParseResultToString(type); + if (!name.empty()) { + ss << " (" << name << ")"; + } + ss << " [" << children.size() << " children]\n"; + + std::string child_indent = indent + (is_last ? " " : "│ "); + for (size_t i = 0; i < children.size(); ++i) { + if (children[i]) { + children[i]->ToStringInternal(ss, visited, child_indent, i == children.size() - 1); + } else { + ss << child_indent << (i == children.size() - 1 ? "└─" : "├─") << " [nullptr]\n"; + } + } + } +}; + +struct RepeatParseResult : ParseResult { + static constexpr ParseResultType TYPE = ParseResultType::REPEAT; + vector> children; + + explicit RepeatParseResult(vector> results_p) + : ParseResult(TYPE), children(std::move(results_p)) { + } + + template + T &Child(idx_t index) { + if (index >= children.size()) { + throw InternalException("Child index out of bounds"); + } + return children[index]->Cast(); + } + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ss << indent << (is_last ? "└─" : "├─"); + + if (visited.count(this)) { + ss << " Repeat (" << name << ") [... already printed ...]\n"; + return; + } + visited.insert(this); + + ss << " " << ParseResultToString(type); + if (!name.empty()) { + ss << " (" << name << ")"; + } + ss << " [" << children.size() << " children]\n"; + + std::string child_indent = indent + (is_last ? " " : "│ "); + for (size_t i = 0; i < children.size(); ++i) { + if (children[i]) { + children[i]->ToStringInternal(ss, visited, child_indent, i == children.size() - 1); + } else { + ss << child_indent << (i == children.size() - 1 ? "└─" : "├─") << " [nullptr]\n"; + } + } + } +}; + +struct OptionalParseResult : ParseResult { + static constexpr ParseResultType TYPE = ParseResultType::OPTIONAL; + optional_ptr optional_result; + + explicit OptionalParseResult() : ParseResult(TYPE), optional_result(nullptr) { + } + explicit OptionalParseResult(optional_ptr result_p) : ParseResult(TYPE), optional_result(result_p) { + name = result_p->name; + } + + bool HasResult() const { + return optional_result != nullptr; + } + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + if (HasResult()) { + // The optional node has a value, so we "collapse" it by just printing its child. + // We pass the same indentation and is_last status, so it takes the place of the Optional node. + optional_result->ToStringInternal(ss, visited, indent, is_last); + } else { + // The optional node is empty, which is useful information, so we print it. + ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type) << " [empty]\n"; + } + } +}; + +class ChoiceParseResult : public ParseResult { +public: + static constexpr ParseResultType TYPE = ParseResultType::CHOICE; + + explicit ChoiceParseResult(optional_ptr parse_result_p, idx_t selected_idx_p) + : ParseResult(TYPE), result(parse_result_p), selected_idx(selected_idx_p) { + name = parse_result_p->name; + } + + optional_ptr result; + idx_t selected_idx; + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + if (result) { + // The choice was resolved. We print a marker and then print the child below it. + ss << indent << (is_last ? "└─" : "├─") << " [" << ParseResultToString(type) << " (idx: " << selected_idx + << ")] ->\n"; + + // The child is now on a new indentation level and is the only child of our marker. + std::string child_indent = indent + (is_last ? " " : "│ "); + result->ToStringInternal(ss, visited, child_indent, true); + } else { + // The choice had no result. + ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type) << " [no result]\n"; + } + } +}; + +class NumberParseResult : public ParseResult { +public: + static constexpr ParseResultType TYPE = ParseResultType::NUMBER; + + explicit NumberParseResult(string number_p) : ParseResult(TYPE), number(std::move(number_p)) { + } + // TODO(dtenwolde): Should probably be stored as a size_t, int32_t or float_t depending on what number is. + string number; + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ParseResult::ToStringInternal(ss, visited, indent, is_last); + ss << ": " << number << "\n"; + } +}; + +class StringLiteralParseResult : public ParseResult { +public: + static constexpr ParseResultType TYPE = ParseResultType::STRING; + + explicit StringLiteralParseResult(string string_p) : ParseResult(TYPE), result(std::move(string_p)) { + } + string result; + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ParseResult::ToStringInternal(ss, visited, indent, is_last); + ss << ": \"" << result << "\"\n"; + } +}; + +class OperatorParseResult : public ParseResult { +public: + static constexpr ParseResultType TYPE = ParseResultType::OPERATOR; + + explicit OperatorParseResult(string operator_p) : ParseResult(TYPE), operator_token(std::move(operator_p)) { + } + string operator_token; + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ParseResult::ToStringInternal(ss, visited, indent, is_last); + ss << ": " << operator_token << "\n"; + } +}; + +} // namespace duckdb diff --git a/extension/autocomplete/include/transformer/peg_transformer.hpp b/extension/autocomplete/include/transformer/peg_transformer.hpp new file mode 100644 index 000000000000..4db5de8af9a8 --- /dev/null +++ b/extension/autocomplete/include/transformer/peg_transformer.hpp @@ -0,0 +1,172 @@ +#pragma once + +#include "tokenizer.hpp" +#include "parse_result.hpp" +#include "transform_enum_result.hpp" +#include "transform_result.hpp" +#include "ast/setting_info.hpp" +#include "duckdb/function/macro_function.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/parsed_data/create_type_info.hpp" +#include "duckdb/parser/parsed_data/transaction_info.hpp" +#include "duckdb/parser/statement/copy_database_statement.hpp" +#include "duckdb/parser/statement/set_statement.hpp" +#include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "parser/peg_parser.hpp" +#include "duckdb/storage/arena_allocator.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/statement/drop_statement.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" + +namespace duckdb { + +// Forward declare +struct QualifiedName; +struct MatcherToken; + +struct PEGTransformerState { + explicit PEGTransformerState(const vector &tokens_p) : tokens(tokens_p), token_index(0) { + } + + const vector &tokens; + idx_t token_index; +}; + +class PEGTransformer { +public: + using AnyTransformFunction = + std::function(PEGTransformer &, optional_ptr)>; + + PEGTransformer(ArenaAllocator &allocator, PEGTransformerState &state, + const case_insensitive_map_t &transform_functions, + const case_insensitive_map_t &grammar_rules, + const case_insensitive_map_t> &enum_mappings) + : allocator(allocator), state(state), grammar_rules(grammar_rules), transform_functions(transform_functions), + enum_mappings(enum_mappings) { + } + +public: + template + T Transform(optional_ptr parse_result) { + auto it = transform_functions.find(parse_result->name); + if (it == transform_functions.end()) { + throw NotImplementedException("No transformer function found for rule '%s'", parse_result->name); + } + auto &func = it->second; + + unique_ptr base_result = func(*this, parse_result); + if (!base_result) { + throw InternalException("Transformer for rule '%s' returned a nullptr.", parse_result->name); + } + + auto *typed_result_ptr = dynamic_cast *>(base_result.get()); + if (!typed_result_ptr) { + throw InternalException("Transformer for rule '" + parse_result->name + "' returned an unexpected type."); + } + + return std::move(typed_result_ptr->value); + } + + template + T Transform(ListParseResult &parse_result, idx_t child_index) { + auto child_parse_result = parse_result.GetChild(child_index); + return Transform(child_parse_result); + } + + template + T TransformEnum(optional_ptr parse_result) { + auto enum_rule_name = parse_result->name; + + auto rule_value = enum_mappings.find(enum_rule_name); + if (rule_value == enum_mappings.end()) { + throw ParserException("Enum transform failed: could not find mapping for '%s'", enum_rule_name); + } + + auto *typed_enum_ptr = dynamic_cast *>(rule_value->second.get()); + if (!typed_enum_ptr) { + throw InternalException("Enum mapping for rule '%s' has an unexpected type.", enum_rule_name); + } + + return typed_enum_ptr->value; + } + + template + void TransformOptional(ListParseResult &list_pr, idx_t child_idx, T &target) { + auto &opt = list_pr.Child(child_idx); + if (opt.HasResult()) { + target = Transform(opt.optional_result); + } + } + + // Make overloads return raw pointers, as ownership is handled by the ArenaAllocator. + template + T *Make(Args &&...args) { + return allocator.Make(std::forward(args)...); + } + + void ClearParameters(); + static void ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type); + void SetParam(const string &name, idx_t index, PreparedParamType type); + bool GetParam(const string &name, idx_t &index, PreparedParamType type); + +public: + ArenaAllocator &allocator; + PEGTransformerState &state; + const case_insensitive_map_t &grammar_rules; + const case_insensitive_map_t &transform_functions; + const case_insensitive_map_t> &enum_mappings; + case_insensitive_map_t named_parameter_map; + idx_t prepared_statement_parameter_index = 0; + PreparedParamType last_param_type = PreparedParamType::INVALID; +}; + +class PEGTransformerFactory { +public: + static PEGTransformerFactory &GetInstance(); + explicit PEGTransformerFactory(); + static unique_ptr Transform(vector &tokens, const char *root_rule = "Statement"); + +private: + template + void RegisterEnum(const string &rule_name, T value) { + auto existing_rule = enum_mappings.find(rule_name); + if (existing_rule != enum_mappings.end()) { + throw InternalException("EnumRule %s already exists", rule_name); + } + enum_mappings[rule_name] = make_uniq>(value); + } + + template + void Register(const string &rule_name, FUNC function) { + auto existing_rule = sql_transform_functions.find(rule_name); + if (existing_rule != sql_transform_functions.end()) { + throw InternalException("Rule %s already exists", rule_name); + } + sql_transform_functions[rule_name] = + [function](PEGTransformer &transformer, + optional_ptr parse_result) -> unique_ptr { + auto result_value = function(transformer, parse_result); + return make_uniq>(std::move(result_value)); + }; + } + + PEGTransformerFactory(const PEGTransformerFactory &) = delete; + + static unique_ptr TransformStatement(PEGTransformer &, optional_ptr list); + + // use.gram + static unique_ptr TransformUseStatement(PEGTransformer &transformer, + optional_ptr parse_result); + static QualifiedName TransformUseTarget(PEGTransformer &transformer, optional_ptr parse_result); + +private: + PEGParser parser; + case_insensitive_map_t sql_transform_functions; + case_insensitive_map_t> enum_mappings; +}; + +} // namespace duckdb diff --git a/extension/autocomplete/include/transformer/transform_enum_result.hpp b/extension/autocomplete/include/transformer/transform_enum_result.hpp new file mode 100644 index 000000000000..31e058997938 --- /dev/null +++ b/extension/autocomplete/include/transformer/transform_enum_result.hpp @@ -0,0 +1,15 @@ +#pragma once + +namespace duckdb { +struct TransformEnumValue { + virtual ~TransformEnumValue() = default; +}; + +template +struct TypedTransformEnumResult : public TransformEnumValue { + explicit TypedTransformEnumResult(T value_p) : value(std::move(value_p)) { + } + T value; +}; + +} // namespace duckdb diff --git a/extension/autocomplete/include/transformer/transform_result.hpp b/extension/autocomplete/include/transformer/transform_result.hpp new file mode 100644 index 000000000000..2b9529eccc04 --- /dev/null +++ b/extension/autocomplete/include/transformer/transform_result.hpp @@ -0,0 +1,16 @@ +#pragma once + +namespace duckdb { + +struct TransformResultValue { + virtual ~TransformResultValue() = default; +}; + +template +struct TypedTransformResult : public TransformResultValue { + explicit TypedTransformResult(T value_p) : value(std::move(value_p)) { + } + T value; +}; + +} // namespace duckdb diff --git a/extension/autocomplete/inline_grammar.py b/extension/autocomplete/inline_grammar.py index 6e953b47e09d..2b9699382c27 100644 --- a/extension/autocomplete/inline_grammar.py +++ b/extension/autocomplete/inline_grammar.py @@ -82,7 +82,7 @@ def load_keywords(filepath): f.write("/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */\n") f.write("#include \"keyword_helper.hpp\"\n\n") f.write("namespace duckdb {\n") - f.write("void KeywordHelper::InitializeKeywordMaps() { // Renamed for clarity\n") + f.write("void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity\n") f.write("\tif (initialized) {\n\t\treturn;\n\t};\n") f.write("\tinitialized = true;\n\n") diff --git a/extension/autocomplete/keyword_helper.cpp b/extension/autocomplete/keyword_helper.cpp index dda2e4c4a7dd..893c96671283 100644 --- a/extension/autocomplete/keyword_helper.cpp +++ b/extension/autocomplete/keyword_helper.cpp @@ -1,30 +1,30 @@ #include "keyword_helper.hpp" namespace duckdb { -KeywordHelper &KeywordHelper::Instance() { - static KeywordHelper instance; +PEGKeywordHelper &PEGKeywordHelper::Instance() { + static PEGKeywordHelper instance; return instance; } -KeywordHelper::KeywordHelper() { +PEGKeywordHelper::PEGKeywordHelper() { InitializeKeywordMaps(); } -bool KeywordHelper::KeywordCategoryType(const std::string &text, const KeywordCategory type) const { +bool PEGKeywordHelper::KeywordCategoryType(const std::string &text, const PEGKeywordCategory type) const { switch (type) { - case KeywordCategory::KEYWORD_RESERVED: { + case PEGKeywordCategory::KEYWORD_RESERVED: { auto it = reserved_keyword_map.find(text); return it != reserved_keyword_map.end(); } - case KeywordCategory::KEYWORD_UNRESERVED: { + case PEGKeywordCategory::KEYWORD_UNRESERVED: { auto it = unreserved_keyword_map.find(text); return it != unreserved_keyword_map.end(); } - case KeywordCategory::KEYWORD_TYPE_FUNC: { + case PEGKeywordCategory::KEYWORD_TYPE_FUNC: { auto it = typefunc_keyword_map.find(text); return it != typefunc_keyword_map.end(); } - case KeywordCategory::KEYWORD_COL_NAME: { + case PEGKeywordCategory::KEYWORD_COL_NAME: { auto it = colname_keyword_map.find(text); return it != colname_keyword_map.end(); } diff --git a/extension/autocomplete/keyword_map.cpp b/extension/autocomplete/keyword_map.cpp index 5531b0142e8d..622474585051 100644 --- a/extension/autocomplete/keyword_map.cpp +++ b/extension/autocomplete/keyword_map.cpp @@ -2,7 +2,7 @@ #include "keyword_helper.hpp" namespace duckdb { -void KeywordHelper::InitializeKeywordMaps() { // Renamed for clarity +void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity if (initialized) { return; }; diff --git a/extension/autocomplete/matcher.cpp b/extension/autocomplete/matcher.cpp index c357d0e745fe..adcd12a3c78a 100644 --- a/extension/autocomplete/matcher.cpp +++ b/extension/autocomplete/matcher.cpp @@ -10,6 +10,8 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/exception/parser_exception.hpp" #include "tokenizer.hpp" +#include "parser/peg_parser.hpp" +#include "transformer/parse_result.hpp" #ifdef PEG_PARSER_SOURCE_FILE #include #else @@ -17,7 +19,6 @@ #endif namespace duckdb { -struct PEGParser; SuggestionType Matcher::AddSuggestion(MatchState &state) const { auto entry = state.added_suggestions.find(*this); @@ -53,14 +54,19 @@ class KeywordMatcher : public Matcher { } MatchResultType Match(MatchState &state) const override { - auto &token = state.tokens[state.token_index]; - if (StringUtil::CIEquals(keyword, token.text)) { - // move to the next token - state.token_index++; - return MatchResultType::SUCCESS; - } else { + if (!MatchKeyword(state)) { return MatchResultType::FAIL; } + return MatchResultType::SUCCESS; + } + + optional_ptr MatchParseResult(MatchState &state) const override { + if (!MatchKeyword(state)) { + return nullptr; + } + auto result = state.allocator.Allocate(make_uniq(keyword)); + result->name = name; + return result; } SuggestionType AddSuggestionInternal(MatchState &state) const override { @@ -74,6 +80,20 @@ class KeywordMatcher : public Matcher { return "'" + keyword + "'"; } +private: + bool MatchKeyword(MatchState &state) const { + if (state.token_index >= state.tokens.size()) { + return false; + } + auto &token = state.tokens[state.token_index]; + if (StringUtil::CIEquals(keyword, token.text)) { + // move to the next token + state.token_index++; + return true; + } + return false; + } + private: string keyword; int32_t score_bonus; @@ -122,6 +142,22 @@ class ListMatcher : public Matcher { return MatchResultType::SUCCESS; } + optional_ptr MatchParseResult(MatchState &state) const override { + MatchState list_state(state); + vector> results; + + for (const auto &child_matcher : matchers) { + auto child_result = child_matcher.get().MatchParseResult(list_state); + if (!child_result) { + return nullptr; + } + results.push_back(child_result); + } + state.token_index = list_state.token_index; + // Empty name implies it's a subrule, e.g. 'SET'i (StandardAssignment / SetTimeZone) + return state.allocator.Allocate(make_uniq(std::move(results), name)); + } + SuggestionType AddSuggestionInternal(MatchState &state) const override { for (auto &matcher : matchers) { auto suggestion_result = matcher.get().AddSuggestion(state); @@ -160,7 +196,7 @@ class OptionalMatcher : public Matcher { MatchResultType Match(MatchState &state) const override { MatchState child_state(state); auto child_match = matcher.Match(child_state); - if (child_match != MatchResultType::SUCCESS) { + if (child_match == MatchResultType::FAIL) { // did not succeed in matching - go back up (but return success anyway) return MatchResultType::SUCCESS; } @@ -169,6 +205,18 @@ class OptionalMatcher : public Matcher { return MatchResultType::SUCCESS; } + optional_ptr MatchParseResult(MatchState &state) const override { + MatchState child_state(state); + auto child_match = matcher.MatchParseResult(child_state); + if (child_match == nullptr) { + // did not succeed in matching - go back up (simply return a nullptr) + return state.allocator.Allocate(make_uniq()); + } + // propagate the child state upwards + state.token_index = child_state.token_index; + return state.allocator.Allocate(make_uniq(child_match)); + } + SuggestionType AddSuggestionInternal(MatchState &state) const override { matcher.AddSuggestion(state); return SuggestionType::OPTIONAL; @@ -205,6 +253,20 @@ class ChoiceMatcher : public Matcher { return MatchResultType::FAIL; } + optional_ptr MatchParseResult(MatchState &state) const override { + for (idx_t i = 0; i < matchers.size(); i++) { + MatchState choice_state(state); + auto child_result = matchers[i].get().MatchParseResult(choice_state); + if (child_result != nullptr) { + // we matched this child - propagate upwards + state.token_index = choice_state.token_index; + auto result = state.allocator.Allocate(make_uniq(child_result, i)); + return result; + } + } + return nullptr; + } + SuggestionType AddSuggestionInternal(MatchState &state) const override { for (auto &child_matcher : matchers) { child_matcher.get().AddSuggestion(state); @@ -266,6 +328,41 @@ class RepeatMatcher : public Matcher { } } + optional_ptr MatchParseResult(MatchState &state) const override { + MatchState repeat_state(state); + vector> results; + + // First, we MUST match the element at least once. + auto first_result = element.MatchParseResult(repeat_state); + if (!first_result) { + // The first match failed, so the whole repeat fails. + return nullptr; + } + results.push_back(first_result); + + // After the first success, the overall result is a success. + // Now, we continue matching the element as many times as possible. + while (true) { + // Propagate the new state upwards. + state.token_index = repeat_state.token_index; + + // Check if there are any tokens left. + if (repeat_state.token_index >= state.tokens.size()) { + break; + } + + // Try to match the element again. + auto next_result = element.MatchParseResult(repeat_state); + if (!next_result) { + break; + } + results.push_back(next_result); + } + + // Return all collected results in a RepeatParseResult. + return state.allocator.Allocate(make_uniq(std::move(results))); + } + SuggestionType AddSuggestionInternal(MatchState &state) const override { element.AddSuggestion(state); return SuggestionType::MANDATORY; @@ -301,44 +398,23 @@ class IdentifierMatcher : public Matcher { } MatchResultType Match(MatchState &state) const override { - // variable matchers match anything except for reserved keywords - auto &token_text = state.tokens[state.token_index].text; - const auto &keyword_helper = KeywordHelper::Instance(); - switch (suggestion_type) { - case SuggestionState::SUGGEST_TYPE_NAME: - if (keyword_helper.KeywordCategoryType(token_text, KeywordCategory::KEYWORD_RESERVED) || - keyword_helper.KeywordCategoryType(token_text, GetBannedCategory())) { - return MatchResultType::FAIL; - } - break; - default: { - const auto banned_category = GetBannedCategory(); - const auto allowed_override_category = banned_category == KeywordCategory::KEYWORD_COL_NAME - ? KeywordCategory::KEYWORD_TYPE_FUNC - : KeywordCategory::KEYWORD_COL_NAME; - - const bool is_reserved = keyword_helper.KeywordCategoryType(token_text, KeywordCategory::KEYWORD_RESERVED); - const bool has_extra_banned_category = keyword_helper.KeywordCategoryType(token_text, banned_category); - const bool has_banned_flag = is_reserved || has_extra_banned_category; - - const bool is_unreserved = - keyword_helper.KeywordCategoryType(token_text, KeywordCategory::KEYWORD_UNRESERVED); - const bool has_override_flag = keyword_helper.KeywordCategoryType(token_text, allowed_override_category); - const bool has_allowed_flag = is_unreserved || has_override_flag; - - if (has_banned_flag && !has_allowed_flag) { - return MatchResultType::FAIL; - } - break; - } - } - if (!IsIdentifier(token_text)) { + if (!MatchIdentifier(state)) { return MatchResultType::FAIL; } - state.token_index++; return MatchResultType::SUCCESS; } + optional_ptr MatchParseResult(MatchState &state) const override { + if (state.token_index >= state.tokens.size()) { + return nullptr; + } + auto &token_text = state.tokens[state.token_index].text; + if (!MatchIdentifier(state)) { + return nullptr; + } + return state.allocator.Allocate(make_uniq(token_text)); + } + bool SupportsStringLiteral() const { switch (suggestion_type) { case SuggestionState::SUGGEST_TABLE_NAME: @@ -349,13 +425,13 @@ class IdentifierMatcher : public Matcher { } } - KeywordCategory GetBannedCategory() const { + PEGKeywordCategory GetBannedCategory() const { switch (suggestion_type) { case SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME: case SuggestionState::SUGGEST_TABLE_FUNCTION_NAME: - return KeywordCategory::KEYWORD_COL_NAME; + return PEGKeywordCategory::KEYWORD_COL_NAME; default: - return KeywordCategory::KEYWORD_TYPE_FUNC; + return PEGKeywordCategory::KEYWORD_TYPE_FUNC; } } @@ -395,6 +471,47 @@ class IdentifierMatcher : public Matcher { } } +private: + bool MatchIdentifier(MatchState &state) const { + // variable matchers match anything except for reserved keywords + auto &token_text = state.tokens[state.token_index].text; + const auto &keyword_helper = PEGKeywordHelper::Instance(); + switch (suggestion_type) { + case SuggestionState::SUGGEST_TYPE_NAME: + if (keyword_helper.KeywordCategoryType(token_text, PEGKeywordCategory::KEYWORD_RESERVED) || + keyword_helper.KeywordCategoryType(token_text, GetBannedCategory())) { + return false; + } + break; + default: { + const auto banned_category = GetBannedCategory(); + const auto allowed_override_category = banned_category == PEGKeywordCategory::KEYWORD_COL_NAME + ? PEGKeywordCategory::KEYWORD_TYPE_FUNC + : PEGKeywordCategory::KEYWORD_COL_NAME; + + const bool is_reserved = + keyword_helper.KeywordCategoryType(token_text, PEGKeywordCategory::KEYWORD_RESERVED); + const bool has_extra_banned_category = keyword_helper.KeywordCategoryType(token_text, banned_category); + const bool has_banned_flag = is_reserved || has_extra_banned_category; + + const bool is_unreserved = + keyword_helper.KeywordCategoryType(token_text, PEGKeywordCategory::KEYWORD_UNRESERVED); + const bool has_override_flag = keyword_helper.KeywordCategoryType(token_text, allowed_override_category); + const bool has_allowed_flag = is_unreserved || has_override_flag; + + if (has_banned_flag && !has_allowed_flag) { + return false; + } + break; + } + } + if (!IsIdentifier(token_text)) { + return false; + } + state.token_index++; + return true; + } + SuggestionState suggestion_type; }; @@ -407,13 +524,28 @@ class ReservedIdentifierMatcher : public IdentifierMatcher { } MatchResultType Match(MatchState &state) const override { - // reserved variable matchers match anything + if (!MatchReservedIdentifier(state)) { + return MatchResultType::FAIL; + } + return MatchResultType::SUCCESS; + } + + optional_ptr MatchParseResult(MatchState &state) const override { + auto &token_text = state.tokens[state.token_index].text; + if (!MatchReservedIdentifier(state)) { + return nullptr; + } + return state.allocator.Allocate(make_uniq(token_text)); + } + +private: + bool MatchReservedIdentifier(MatchState &state) const { auto &token_text = state.tokens[state.token_index].text; if (!IsIdentifier(token_text)) { - return MatchResultType::FAIL; + return false; } state.token_index++; - return MatchResultType::SUCCESS; + return true; } }; @@ -427,12 +559,25 @@ class StringLiteralMatcher : public Matcher { MatchResultType Match(MatchState &state) const override { // variable matchers match anything except for reserved keywords + if (!MatchStringLiteral(state)) { + return MatchResultType::FAIL; + } + return MatchResultType::SUCCESS; + } + + optional_ptr MatchParseResult(MatchState &state) const override { + if (state.token_index >= state.tokens.size()) { + return nullptr; + } auto &token_text = state.tokens[state.token_index].text; - if (token_text.size() >= 2 && token_text.front() == '\'' && token_text.back() == '\'') { - state.token_index++; - return MatchResultType::SUCCESS; + if (!MatchStringLiteral(state)) { + return nullptr; } - return MatchResultType::FAIL; + string stripped_string = token_text.substr(1, token_text.length() - 2); + + auto result = state.allocator.Allocate(make_uniq(stripped_string)); + result->name = name; + return result; } SuggestionType AddSuggestionInternal(MatchState &state) const override { @@ -442,6 +587,16 @@ class StringLiteralMatcher : public Matcher { string ToString() const override { return "STRING_LITERAL"; } + +private: + static bool MatchStringLiteral(MatchState &state) { + auto &token_text = state.tokens[state.token_index].text; + if (token_text.size() >= 2 && token_text.front() == '\'' && token_text.back() == '\'') { + state.token_index++; + return true; + } + return false; + } }; class NumberLiteralMatcher : public Matcher { @@ -454,19 +609,25 @@ class NumberLiteralMatcher : public Matcher { MatchResultType Match(MatchState &state) const override { // variable matchers match anything except for reserved keywords - auto &token_text = state.tokens[state.token_index].text; - if (!BaseTokenizer::CharacterIsInitialNumber(token_text[0])) { + if (!MatchNumberLiteral(state)) { return MatchResultType::FAIL; } - for (idx_t i = 1; i < token_text.size(); i++) { - if (!BaseTokenizer::CharacterIsNumber(token_text[i])) { - return MatchResultType::FAIL; - } - } - state.token_index++; return MatchResultType::SUCCESS; } + optional_ptr MatchParseResult(MatchState &state) const override { + if (state.token_index >= state.tokens.size()) { + return nullptr; + } + auto &token_text = state.tokens[state.token_index].text; + if (!MatchNumberLiteral(state)) { + return nullptr; + } + auto result = state.allocator.Allocate(make_uniq(token_text)); + result->name = name; + return result; + } + SuggestionType AddSuggestionInternal(MatchState &state) const override { return SuggestionType::MANDATORY; } @@ -474,6 +635,21 @@ class NumberLiteralMatcher : public Matcher { string ToString() const override { return "NUMBER_LITERAL"; } + +private: + static bool MatchNumberLiteral(MatchState &state) { + auto &token_text = state.tokens[state.token_index].text; + if (!BaseTokenizer::CharacterIsInitialNumber(token_text[0])) { + return false; + } + for (idx_t i = 1; i < token_text.size(); i++) { + if (!BaseTokenizer::CharacterIsNumber(token_text[i])) { + return false; + } + } + state.token_index++; + return true; + } }; class OperatorMatcher : public Matcher { @@ -485,6 +661,33 @@ class OperatorMatcher : public Matcher { } MatchResultType Match(MatchState &state) const override { + if (!MatchOperator(state)) { + return MatchResultType::FAIL; + } + return MatchResultType::SUCCESS; + } + + optional_ptr MatchParseResult(MatchState &state) const override { + if (state.token_index >= state.tokens.size()) { + return nullptr; + } + auto &token_text = state.tokens[state.token_index].text; + if (!MatchOperator(state)) { + return nullptr; + } + return state.allocator.Allocate(make_uniq(token_text)); + } + + SuggestionType AddSuggestionInternal(MatchState &state) const override { + return SuggestionType::MANDATORY; + } + + string ToString() const override { + return "OPERATOR"; + } + +private: + static bool MatchOperator(MatchState &state) { auto &token_text = state.tokens[state.token_index].text; for (auto &c : token_text) { switch (c) { @@ -504,19 +707,11 @@ class OperatorMatcher : public Matcher { case '|': break; default: - return MatchResultType::FAIL; + return false; } } state.token_index++; - return MatchResultType::SUCCESS; - } - - SuggestionType AddSuggestionInternal(MatchState &state) const override { - return SuggestionType::MANDATORY; - } - - string ToString() const override { - return "OPERATOR"; + return true; } }; @@ -526,6 +721,12 @@ Matcher &MatcherAllocator::Allocate(unique_ptr matcher) { return result; } +optional_ptr ParseResultAllocator::Allocate(unique_ptr parse_result) { + auto result_ptr = parse_result.get(); + parse_results.push_back(std::move(parse_result)); + return optional_ptr(result_ptr); +} + //! Class for building matchers class MatcherFactory { friend struct MatcherList; @@ -674,253 +875,6 @@ Matcher &MatcherFactory::Operator() const { return allocator.Allocate(make_uniq()); } -enum class PEGRuleType { - LITERAL, // literal rule ('Keyword'i) - REFERENCE, // reference to another rule (Rule) - OPTIONAL, // optional rule (Rule?) - OR, // or rule (Rule1 / Rule2) - REPEAT // repeat rule (Rule1* -}; - -enum class PEGTokenType { - LITERAL, // literal token ('Keyword'i) - REFERENCE, // reference token (Rule) - OPERATOR, // operator token (/ or ) - FUNCTION_CALL, // start of function call (i.e. Function(...)) - REGEX // regular expression ([ \t\n\r] or <[a-z_]i[a-z0-9_]i>) -}; - -struct PEGToken { - PEGTokenType type; - string_t text; -}; - -struct PEGRule { - string_map_t parameters; - vector tokens; - - void Clear() { - parameters.clear(); - tokens.clear(); - } -}; - -struct PEGParser { -public: - void ParseRules(const char *grammar); - - void AddRule(string_t rule_name, PEGRule rule) { - auto entry = rules.find(rule_name); - if (entry != rules.end()) { - throw InternalException("Failed to parse grammar - duplicate rule name %s", rule_name.GetString()); - } - rules.insert(make_pair(rule_name, std::move(rule))); - } - - string_map_t rules; -}; - -enum class PEGParseState { - RULE_NAME, // Rule name - RULE_SEPARATOR, // look for <- - RULE_DEFINITION // part of rule definition -}; - -bool IsPEGOperator(char c) { - switch (c) { - case '/': - case '?': - case '(': - case ')': - case '*': - case '!': - return true; - default: - return false; - } -} - -void PEGParser::ParseRules(const char *grammar) { - string_t rule_name; - PEGRule rule; - PEGParseState parse_state = PEGParseState::RULE_NAME; - idx_t bracket_count = 0; - bool in_or_clause = false; - // look for the rules - idx_t c = 0; - while (grammar[c]) { - if (grammar[c] == '#') { - // comment - ignore until EOL - while (grammar[c] && !StringUtil::CharacterIsNewline(grammar[c])) { - c++; - } - continue; - } - if (parse_state == PEGParseState::RULE_DEFINITION && StringUtil::CharacterIsNewline(grammar[c]) && - bracket_count == 0 && !in_or_clause && !rule.tokens.empty()) { - // if we see a newline while we are parsing a rule definition we can complete the rule - AddRule(rule_name, std::move(rule)); - rule_name = string_t(); - rule.Clear(); - // look for the subsequent rule - parse_state = PEGParseState::RULE_NAME; - c++; - continue; - } - if (StringUtil::CharacterIsSpace(grammar[c])) { - // skip whitespace - c++; - continue; - } - switch (parse_state) { - case PEGParseState::RULE_NAME: { - // look for alpha-numerics - idx_t start_pos = c; - if (grammar[c] == '%') { - // rules can start with % (%whitespace) - c++; - } - while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) { - c++; - } - if (c == start_pos) { - throw InternalException("Failed to parse grammar - expected an alpha-numeric rule name (pos %d)", c); - } - rule_name = string_t(grammar + start_pos, c - start_pos); - rule.Clear(); - parse_state = PEGParseState::RULE_SEPARATOR; - break; - } - case PEGParseState::RULE_SEPARATOR: { - if (grammar[c] == '(') { - if (!rule.parameters.empty()) { - throw InternalException("Failed to parse grammar - multiple parameters at position %d", c); - } - // parameter - c++; - idx_t parameter_start = c; - while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) { - c++; - } - if (parameter_start == c) { - throw InternalException("Failed to parse grammar - expected a parameter at position %d", c); - } - rule.parameters.insert( - make_pair(string_t(grammar + parameter_start, c - parameter_start), rule.parameters.size())); - if (grammar[c] != ')') { - throw InternalException("Failed to parse grammar - expected closing bracket at position %d", c); - } - c++; - } else { - if (grammar[c] != '<' || grammar[c + 1] != '-') { - throw InternalException("Failed to parse grammar - expected a rule definition (<-) (pos %d)", c); - } - c += 2; - parse_state = PEGParseState::RULE_DEFINITION; - } - break; - } - case PEGParseState::RULE_DEFINITION: { - // we parse either: - // (1) a literal ('Keyword'i) - // (2) a rule reference (Rule) - // (3) an operator ( '(' '/' '?' '*' ')') - in_or_clause = false; - if (grammar[c] == '\'') { - // parse literal - c++; - idx_t literal_start = c; - while (grammar[c] && grammar[c] != '\'') { - if (grammar[c] == '\\') { - // escape - c++; - } - c++; - } - if (!grammar[c]) { - throw InternalException("Failed to parse grammar - did not find closing ' (pos %d)", c); - } - PEGToken token; - token.text = string_t(grammar + literal_start, c - literal_start); - token.type = PEGTokenType::LITERAL; - rule.tokens.push_back(token); - c++; - } else if (StringUtil::CharacterIsAlphaNumeric(grammar[c])) { - // alphanumeric character - this is a rule reference - idx_t rule_start = c; - while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) { - c++; - } - PEGToken token; - token.text = string_t(grammar + rule_start, c - rule_start); - if (grammar[c] == '(') { - // this is a function call - c++; - bracket_count++; - token.type = PEGTokenType::FUNCTION_CALL; - } else { - token.type = PEGTokenType::REFERENCE; - } - rule.tokens.push_back(token); - } else if (grammar[c] == '[' || grammar[c] == '<') { - // regular expression- [^"] or <...> - idx_t rule_start = c; - char final_char = grammar[c] == '[' ? ']' : '>'; - while (grammar[c] && grammar[c] != final_char) { - if (grammar[c] == '\\') { - // handle escapes - c++; - } - if (grammar[c]) { - c++; - } - } - c++; - PEGToken token; - token.text = string_t(grammar + rule_start, c - rule_start); - token.type = PEGTokenType::REGEX; - rule.tokens.push_back(token); - } else if (IsPEGOperator(grammar[c])) { - if (grammar[c] == '(') { - bracket_count++; - } else if (grammar[c] == ')') { - if (bracket_count == 0) { - throw InternalException("Failed to parse grammar - unclosed bracket at position %d in rule %s", - c, rule_name.GetString()); - } - bracket_count--; - } else if (grammar[c] == '/') { - in_or_clause = true; - } - // operator - operators are always length 1 - PEGToken token; - token.text = string_t(grammar + c, 1); - token.type = PEGTokenType::OPERATOR; - rule.tokens.push_back(token); - c++; - } else { - throw InternalException("Unrecognized rule contents in rule %s (character %s)", rule_name.GetString(), - string(1, grammar[c])); - } - } - default: - break; - } - if (!grammar[c]) { - break; - } - } - if (parse_state == PEGParseState::RULE_SEPARATOR) { - throw InternalException("Failed to parse grammar - rule %s does not have a definition", rule_name.GetString()); - } - if (parse_state == PEGParseState::RULE_DEFINITION) { - if (rule.tokens.empty()) { - throw InternalException("Failed to parse grammar - rule %s is empty", rule_name.GetString()); - } - AddRule(rule_name, std::move(rule)); - } -} - Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name) { vector> parameters; return CreateMatcher(parser, rule_name, parameters); @@ -1022,7 +976,7 @@ Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name, ve } } // look up the rule - auto entry = parser.rules.find(rule_name); + auto entry = parser.rules.find(rule_name.GetString()); if (entry == parser.rules.end()) { throw InternalException("Failed to create matcher for rule %s - rule is missing", rule_name.GetString()); } @@ -1095,25 +1049,28 @@ Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name, ve } case '/': { // OR operator - this signifies a choice between the last rule and the next rule - auto &last_matcher = list.GetLastRootMatcher().matcher; - if (last_matcher.Type() != MatcherType::LIST) { + auto &last_root_matcher = list.GetLastRootMatcher().matcher; + if (last_root_matcher.Type() != MatcherType::LIST) { throw InternalException("OR expected a list matcher"); } - auto &list_matcher = last_matcher.Cast(); + auto &list_matcher = last_root_matcher.Cast(); if (list_matcher.matchers.empty()) { throw InternalException("OR rule found as first token"); } - auto &final_matcher = list_matcher.matchers.back(); - vector> choice_matchers; - choice_matchers.push_back(final_matcher); - auto &choice_matcher = Choice(choice_matchers); + auto &previous_matcher = list_matcher.matchers.back(); - // the choice matcher gets added to the list matcher (instead of the previous matcher) - list_matcher.matchers.pop_back(); - list_matcher.matchers.push_back(choice_matcher); - // then it gets pushed onto the stack of matchers - // the next rule will then get pushed onto the choice matcher - list.AddRootMatcher(choice_matcher); + if (previous_matcher.get().Type() == MatcherType::CHOICE) { + list.AddRootMatcher(previous_matcher); + } else { + vector> choice_options; + choice_options.push_back(previous_matcher); + auto &new_choice_matcher = Choice(choice_options); + + list_matcher.matchers.pop_back(); + list_matcher.matchers.push_back(new_choice_matcher); + + list.AddRootMatcher(new_choice_matcher); + } break; } case '(': { diff --git a/extension/autocomplete/parser/CMakeLists.txt b/extension/autocomplete/parser/CMakeLists.txt new file mode 100644 index 000000000000..8486f809fb0f --- /dev/null +++ b/extension/autocomplete/parser/CMakeLists.txt @@ -0,0 +1,4 @@ +add_library_unity(duckdb_peg_parser OBJECT peg_parser.cpp) +set(AUTOCOMPLETE_EXTENSION_FILES + ${AUTOCOMPLETE_EXTENSION_FILES} $ + PARENT_SCOPE) diff --git a/extension/autocomplete/parser/peg_parser.cpp b/extension/autocomplete/parser/peg_parser.cpp new file mode 100644 index 000000000000..f0f72f2456bc --- /dev/null +++ b/extension/autocomplete/parser/peg_parser.cpp @@ -0,0 +1,194 @@ +#include "parser/peg_parser.hpp" + +namespace duckdb { + +void PEGParser::AddRule(string_t rule_name, PEGRule rule) { + auto entry = rules.find(rule_name.GetString()); + if (entry != rules.end()) { + throw InternalException("Failed to parse grammar - duplicate rule name %s", rule_name.GetString()); + } + rules.insert(make_pair(rule_name, std::move(rule))); +} + +void PEGParser::ParseRules(const char *grammar) { + string_t rule_name; + PEGRule rule; + PEGParseState parse_state = PEGParseState::RULE_NAME; + idx_t bracket_count = 0; + bool in_or_clause = false; + // look for the rules + idx_t c = 0; + while (grammar[c]) { + if (grammar[c] == '#') { + // comment - ignore until EOL + while (grammar[c] && !StringUtil::CharacterIsNewline(grammar[c])) { + c++; + } + continue; + } + if (parse_state == PEGParseState::RULE_DEFINITION && StringUtil::CharacterIsNewline(grammar[c]) && + bracket_count == 0 && !in_or_clause && !rule.tokens.empty()) { + // if we see a newline while we are parsing a rule definition we can complete the rule + AddRule(rule_name, std::move(rule)); + rule_name = string_t(); + rule.Clear(); + // look for the subsequent rule + parse_state = PEGParseState::RULE_NAME; + c++; + continue; + } + if (StringUtil::CharacterIsSpace(grammar[c])) { + // skip whitespace + c++; + continue; + } + switch (parse_state) { + case PEGParseState::RULE_NAME: { + // look for alpha-numerics + idx_t start_pos = c; + if (grammar[c] == '%') { + // rules can start with % (%whitespace) + c++; + } + while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) { + c++; + } + if (c == start_pos) { + throw InternalException("Failed to parse grammar - expected an alpha-numeric rule name (pos %d)", c); + } + rule_name = string_t(grammar + start_pos, c - start_pos); + rule.Clear(); + parse_state = PEGParseState::RULE_SEPARATOR; + break; + } + case PEGParseState::RULE_SEPARATOR: { + if (grammar[c] == '(') { + if (!rule.parameters.empty()) { + throw InternalException("Failed to parse grammar - multiple parameters at position %d", c); + } + // parameter + c++; + idx_t parameter_start = c; + while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) { + c++; + } + if (parameter_start == c) { + throw InternalException("Failed to parse grammar - expected a parameter at position %d", c); + } + rule.parameters.insert( + make_pair(string_t(grammar + parameter_start, c - parameter_start), rule.parameters.size())); + if (grammar[c] != ')') { + throw InternalException("Failed to parse grammar - expected closing bracket at position %d", c); + } + c++; + } else { + if (grammar[c] != '<' || grammar[c + 1] != '-') { + throw InternalException("Failed to parse grammar - expected a rule definition (<-) (pos %d)", c); + } + c += 2; + parse_state = PEGParseState::RULE_DEFINITION; + } + break; + } + case PEGParseState::RULE_DEFINITION: { + // we parse either: + // (1) a literal ('Keyword'i) + // (2) a rule reference (Rule) + // (3) an operator ( '(' '/' '?' '*' ')') + in_or_clause = false; + if (grammar[c] == '\'') { + // parse literal + c++; + idx_t literal_start = c; + while (grammar[c] && grammar[c] != '\'') { + if (grammar[c] == '\\') { + // escape + c++; + } + c++; + } + if (!grammar[c]) { + throw InternalException("Failed to parse grammar - did not find closing ' (pos %d)", c); + } + PEGToken token; + token.text = string_t(grammar + literal_start, c - literal_start); + token.type = PEGTokenType::LITERAL; + rule.tokens.push_back(token); + c++; + } else if (StringUtil::CharacterIsAlphaNumeric(grammar[c])) { + // alphanumeric character - this is a rule reference + idx_t rule_start = c; + while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) { + c++; + } + PEGToken token; + token.text = string_t(grammar + rule_start, c - rule_start); + if (grammar[c] == '(') { + // this is a function call + c++; + bracket_count++; + token.type = PEGTokenType::FUNCTION_CALL; + } else { + token.type = PEGTokenType::REFERENCE; + } + rule.tokens.push_back(token); + } else if (grammar[c] == '[' || grammar[c] == '<') { + // regular expression- [^"] or <...> + idx_t rule_start = c; + char final_char = grammar[c] == '[' ? ']' : '>'; + while (grammar[c] && grammar[c] != final_char) { + if (grammar[c] == '\\') { + // handle escapes + c++; + } + if (grammar[c]) { + c++; + } + } + c++; + PEGToken token; + token.text = string_t(grammar + rule_start, c - rule_start); + token.type = PEGTokenType::REGEX; + rule.tokens.push_back(token); + } else if (IsPEGOperator(grammar[c])) { + if (grammar[c] == '(') { + bracket_count++; + } else if (grammar[c] == ')') { + if (bracket_count == 0) { + throw InternalException("Failed to parse grammar - unclosed bracket at position %d in rule %s", + c, rule_name.GetString()); + } + bracket_count--; + } else if (grammar[c] == '/') { + in_or_clause = true; + } + // operator - operators are always length 1 + PEGToken token; + token.text = string_t(grammar + c, 1); + token.type = PEGTokenType::OPERATOR; + rule.tokens.push_back(token); + c++; + } else { + throw InternalException("Unrecognized rule contents in rule %s (character %s)", rule_name.GetString(), + string(1, grammar[c])); + } + } + default: + break; + } + if (!grammar[c]) { + break; + } + } + if (parse_state == PEGParseState::RULE_SEPARATOR) { + throw InternalException("Failed to parse grammar - rule %s does not have a definition", rule_name.GetString()); + } + if (parse_state == PEGParseState::RULE_DEFINITION) { + if (rule.tokens.empty()) { + throw InternalException("Failed to parse grammar - rule %s is empty", rule_name.GetString()); + } + AddRule(rule_name, std::move(rule)); + } +} + +} // namespace duckdb diff --git a/extension/autocomplete/tokenizer.cpp b/extension/autocomplete/tokenizer.cpp index e6c63ce7d881..8331b1973b8e 100644 --- a/extension/autocomplete/tokenizer.cpp +++ b/extension/autocomplete/tokenizer.cpp @@ -134,7 +134,7 @@ void BaseTokenizer::PushToken(idx_t start, idx_t end) { return; } string last_token = sql.substr(start, end - start); - tokens.emplace_back(std::move(last_token)); + tokens.emplace_back(std::move(last_token), start); } bool BaseTokenizer::IsValidDollarTagCharacter(char c) { @@ -229,14 +229,14 @@ bool BaseTokenizer::TokenizeInput() { idx_t op_len; if (IsSpecialOperator(i, op_len)) { // special operator - push the special operator - tokens.emplace_back(sql.substr(i, op_len)); + tokens.emplace_back(sql.substr(i, op_len), last_pos); i += op_len - 1; last_pos = i + 1; break; } if (IsSingleByteOperator(c)) { // single-byte operator - directly push the token - tokens.emplace_back(string(1, c)); + tokens.emplace_back(string(1, c), last_pos); last_pos = i + 1; break; } @@ -358,7 +358,7 @@ bool BaseTokenizer::TokenizeInput() { size_t full_marker_len = dollar_quote_marker.size() + 2; string quoted = sql.substr(last_pos, (start + dollar_quote_marker.size() + 1) - last_pos); quoted = "'" + quoted.substr(full_marker_len, quoted.size() - 2 * full_marker_len) + "'"; - tokens.emplace_back(quoted); + tokens.emplace_back(quoted, full_marker_len); dollar_quote_marker = string(); state = TokenizeState::STANDARD; i = end; diff --git a/extension/autocomplete/transformer/CMakeLists.txt b/extension/autocomplete/transformer/CMakeLists.txt new file mode 100644 index 000000000000..26626d0d9769 --- /dev/null +++ b/extension/autocomplete/transformer/CMakeLists.txt @@ -0,0 +1,5 @@ +add_library_unity(duckdb_peg_transformer OBJECT peg_transformer.cpp + peg_transformer_factory.cpp transform_use.cpp) +set(AUTOCOMPLETE_EXTENSION_FILES + ${AUTOCOMPLETE_EXTENSION_FILES} $ + PARENT_SCOPE) diff --git a/extension/autocomplete/transformer/peg_transformer.cpp b/extension/autocomplete/transformer/peg_transformer.cpp new file mode 100644 index 000000000000..5d7bee835070 --- /dev/null +++ b/extension/autocomplete/transformer/peg_transformer.cpp @@ -0,0 +1,47 @@ +#include "transformer/peg_transformer.hpp" + +#include "duckdb/parser/statement/set_statement.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +void PEGTransformer::ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type) { + // Mixing positional/auto-increment and named parameters is not supported + if (last_type == PreparedParamType::INVALID) { + return; + } + if (last_type == PreparedParamType::NAMED) { + if (new_type != PreparedParamType::NAMED) { + throw NotImplementedException("Mixing named and positional parameters is not supported yet"); + } + } + if (last_type != PreparedParamType::NAMED) { + if (new_type == PreparedParamType::NAMED) { + throw NotImplementedException("Mixing named and positional parameters is not supported yet"); + } + } +} + +bool PEGTransformer::GetParam(const string &identifier, idx_t &index, PreparedParamType type) { + ParamTypeCheck(last_param_type, type); + auto entry = named_parameter_map.find(identifier); + if (entry == named_parameter_map.end()) { + return false; + } + index = entry->second; + return true; +} + +void PEGTransformer::SetParam(const string &identifier, idx_t index, PreparedParamType type) { + ParamTypeCheck(last_param_type, type); + last_param_type = type; + D_ASSERT(!named_parameter_map.count(identifier)); + named_parameter_map[identifier] = index; +} + +void PEGTransformer::ClearParameters() { + prepared_statement_parameter_index = 0; + named_parameter_map.clear(); +} + +} // namespace duckdb diff --git a/extension/autocomplete/transformer/peg_transformer_factory.cpp b/extension/autocomplete/transformer/peg_transformer_factory.cpp new file mode 100644 index 000000000000..34d96339b5e1 --- /dev/null +++ b/extension/autocomplete/transformer/peg_transformer_factory.cpp @@ -0,0 +1,68 @@ +#include "transformer/peg_transformer.hpp" +#include "matcher.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/tableref/showref.hpp" + +namespace duckdb { + +unique_ptr PEGTransformerFactory::TransformStatement(PEGTransformer &transformer, + optional_ptr parse_result) { + auto &list_pr = parse_result->Cast(); + auto &choice_pr = list_pr.Child(0); + return transformer.Transform>(choice_pr.result); +} + +unique_ptr PEGTransformerFactory::Transform(vector &tokens, const char *root_rule) { + string token_stream; + for (auto &token : tokens) { + token_stream += token.text + " "; + } + + vector suggestions; + ParseResultAllocator parse_result_allocator; + MatchState state(tokens, suggestions, parse_result_allocator); + MatcherAllocator allocator; + auto &matcher = Matcher::RootMatcher(allocator); + auto match_result = matcher.MatchParseResult(state); + if (match_result == nullptr || state.token_index < state.tokens.size()) { + // TODO(dtenwolde) add error handling + string token_list; + for (idx_t i = 0; i < tokens.size(); i++) { + if (!token_list.empty()) { + token_list += "\n"; + } + if (i < 10) { + token_list += " "; + } + token_list += to_string(i) + ":" + tokens[i].text; + } + throw ParserException("Failed to parse query - did not consume all tokens (got to token %d - %s)\nTokens:\n%s", + state.token_index, tokens[state.token_index].text, token_list); + } + + match_result->name = root_rule; + ArenaAllocator transformer_allocator(Allocator::DefaultAllocator()); + PEGTransformerState transformer_state(tokens); + auto &factory = GetInstance(); + PEGTransformer transformer(transformer_allocator, transformer_state, factory.sql_transform_functions, + factory.parser.rules, factory.enum_mappings); + auto result = transformer.Transform>(match_result); + return transformer.Transform>(match_result); +} + +#define REGISTER_TRANSFORM(FUNCTION) Register(string(#FUNCTION).substr(9), &FUNCTION) + +PEGTransformerFactory &PEGTransformerFactory::GetInstance() { + static PEGTransformerFactory instance; + return instance; +} + +PEGTransformerFactory::PEGTransformerFactory() { + REGISTER_TRANSFORM(TransformStatement); + + // use.gram + REGISTER_TRANSFORM(TransformUseStatement); + REGISTER_TRANSFORM(TransformUseTarget); +} +} // namespace duckdb diff --git a/extension/autocomplete/transformer/transform_use.cpp b/extension/autocomplete/transformer/transform_use.cpp new file mode 100644 index 000000000000..c4d88f700652 --- /dev/null +++ b/extension/autocomplete/transformer/transform_use.cpp @@ -0,0 +1,51 @@ +#include "transformer/peg_transformer.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +// UseStatement <- 'USE' UseTarget +unique_ptr PEGTransformerFactory::TransformUseStatement(PEGTransformer &transformer, + optional_ptr parse_result) { + auto &list_pr = parse_result->Cast(); + auto qn = transformer.Transform(list_pr, 1); + + string value_str; + if (IsInvalidSchema(qn.schema)) { + value_str = qn.name; + } else { + value_str = qn.schema + "." + qn.name; + } + + auto value_expr = make_uniq(Value(value_str)); + return make_uniq("schema", std::move(value_expr), SetScope::AUTOMATIC); +} + +// UseTarget <- (CatalogName '.' ReservedSchemaName) / SchemaName / CatalogName +QualifiedName PEGTransformerFactory::TransformUseTarget(PEGTransformer &transformer, + optional_ptr parse_result) { + auto &list_pr = parse_result->Cast(); + auto &choice_pr = list_pr.Child(0); + QualifiedName result; + if (choice_pr.result->type == ParseResultType::LIST) { + vector entries; + auto use_target_children = choice_pr.result->Cast(); + for (auto &child : use_target_children.children) { + if (child->type == ParseResultType::IDENTIFIER) { + entries.push_back(child->Cast().identifier); + } + } + if (entries.size() == 2) { + result.catalog = INVALID_CATALOG; + result.schema = entries[0]; + result.name = entries[1]; + } else { + throw InternalException("Invalid amount of entries for use statement"); + } + } else if (choice_pr.result->type == ParseResultType::IDENTIFIER) { + result.name = choice_pr.result->Cast().identifier; + } else { + throw InternalException("Unexpected parse result type encountered in UseTarget"); + } + return result; +} +} // namespace duckdb diff --git a/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/extension/core_functions/aggregate/distributive/arg_min_max.cpp index d2bdfbe54443..6bddb106fe1f 100644 --- a/extension/core_functions/aggregate/distributive/arg_min_max.cpp +++ b/extension/core_functions/aggregate/distributive/arg_min_max.cpp @@ -15,7 +15,7 @@ namespace duckdb { namespace { struct ArgMinMaxStateBase { - ArgMinMaxStateBase() : is_initialized(false), arg_null(false) { + ArgMinMaxStateBase() : is_initialized(false), arg_null(false), val_null(false) { } template @@ -34,6 +34,7 @@ struct ArgMinMaxStateBase { bool is_initialized; bool arg_null; + bool val_null; }; // Out-of-line specialisations @@ -81,7 +82,7 @@ struct ArgMinMaxState : public ArgMinMaxStateBase { } }; -template +template struct ArgMinMaxBase { template static void Initialize(STATE &state) { @@ -94,25 +95,48 @@ struct ArgMinMaxBase { } template - static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, + static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, const bool y_null, AggregateInputData &aggregate_input_data) { - if (IGNORE_NULL) { + D_ASSERT(aggregate_input_data.bind_data); + const auto &bind_data = aggregate_input_data.bind_data->Cast(); + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL) { STATE::template AssignValue(state.arg, x, aggregate_input_data); STATE::template AssignValue(state.value, y, aggregate_input_data); } else { state.arg_null = x_null; + state.val_null = y_null; if (!state.arg_null) { STATE::template AssignValue(state.arg, x, aggregate_input_data); } - STATE::template AssignValue(state.value, y, aggregate_input_data); + if (!state.val_null) { + STATE::template AssignValue(state.value, y, aggregate_input_data); + } } } template static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) { + D_ASSERT(binary.input.bind_data); + const auto &bind_data = binary.input.bind_data->Cast(); if (!state.is_initialized) { - if (IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) { - Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), binary.input); + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && + binary.left_mask.RowIsValid(binary.lidx) && binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); + state.is_initialized = true; + return; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL && + binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); + state.is_initialized = true; + return; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); state.is_initialized = true; } } else { @@ -122,8 +146,14 @@ struct ArgMinMaxBase { template static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) { - if ((IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) && COMPARATOR::Operation(y_data, state.value)) { - Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), binary.input); + D_ASSERT(binary.input.bind_data); + const auto &bind_data = binary.input.bind_data->Cast(); + + if (binary.right_mask.RowIsValid(binary.ridx) && COMPARATOR::Operation(y_data, state.value)) { + if (bind_data.null_handling != ArgMinMaxNullHandling::IGNORE_ANY_NULL || + binary.left_mask.RowIsValid(binary.lidx)) { + Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), false, binary.input); + } } } @@ -132,8 +162,10 @@ struct ArgMinMaxBase { if (!source.is_initialized) { return; } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - Assign(target, source.arg, source.value, source.arg_null, aggregate_input_data); + + if (!target.is_initialized || target.val_null || + (!source.val_null && COMPARATOR::Operation(source.value, target.value))) { + Assign(target, source.arg, source.value, source.arg_null, false, aggregate_input_data); target.is_initialized = true; } } @@ -148,9 +180,10 @@ struct ArgMinMaxBase { } static bool IgnoreNull() { - return IGNORE_NULL; + return false; } + template static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { @@ -158,7 +191,9 @@ struct ArgMinMaxBase { } function.arguments[0] = arguments[0]->return_type; function.return_type = arguments[0]->return_type; - return nullptr; + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } }; @@ -186,12 +221,14 @@ struct GenericArgMinMaxState { } }; -template -struct VectorArgMinMaxBase : ArgMinMaxBase { +template +struct VectorArgMinMaxBase : ArgMinMaxBase { template static void Update(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count, Vector &state_vector, idx_t count) { + D_ASSERT(aggregate_input_data.bind_data); + const auto &bind_data = aggregate_input_data.bind_data->Cast(); + auto &arg = inputs[0]; UnifiedVectorFormat adata; arg.ToUnifiedFormat(count, adata); @@ -213,21 +250,36 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { auto states = UnifiedVectorFormat::GetData(sdata); for (idx_t i = 0; i < count; i++) { - const auto bidx = bdata.sel->get_index(i); - if (!bdata.validity.RowIsValid(bidx)) { - continue; - } - const auto bval = bys[bidx]; + const auto sidx = sdata.sel->get_index(i); + auto &state = *states[sidx]; const auto aidx = adata.sel->get_index(i); const auto arg_null = !adata.validity.RowIsValid(aidx); - if (IGNORE_NULL && arg_null) { + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && arg_null) { continue; } - const auto sidx = sdata.sel->get_index(i); - auto &state = *states[sidx]; - if (!state.is_initialized || COMPARATOR::template Operation(bval, state.value)) { + const auto bidx = bdata.sel->get_index(i); + + if (!bdata.validity.RowIsValid(bidx)) { + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL && !state.is_initialized) { + state.is_initialized = true; + state.val_null = true; + if (!arg_null) { + if (&state == last_state) { + assign_count--; + } + assign_sel[assign_count++] = UnsafeNumericCast(i); + last_state = &state; + } + } + continue; + } + + const auto bval = bys[bidx]; + + if (!state.is_initialized || state.val_null || COMPARATOR::template Operation(bval, state.value)) { STATE::template AssignValue(state.value, bval, aggregate_input_data); state.arg_null = arg_null; // micro-adaptivity: it is common we overwrite the same state repeatedly @@ -270,8 +322,12 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { if (!source.is_initialized) { return; } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - STATE::template AssignValue(target.value, source.value, aggregate_input_data); + if (!target.is_initialized || target.val_null || + (!source.val_null && COMPARATOR::Operation(source.value, target.value))) { + target.val_null = source.val_null; + if (!target.val_null) { + STATE::template AssignValue(target.value, source.value, aggregate_input_data); + } target.arg_null = source.arg_null; if (!target.arg_null) { STATE::template AssignValue(target.arg, source.arg, aggregate_input_data); @@ -290,6 +346,7 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { } } + template static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { @@ -297,31 +354,48 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { } function.arguments[0] = arguments[0]->return_type; function.return_type = arguments[0]->return_type; - return nullptr; + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } }; template -AggregateFunction GetGenericArgMinMaxFunction() { +bind_aggregate_function_t GetBindFunction(const ArgMinMaxNullHandling null_handling) { + switch (null_handling) { + case ArgMinMaxNullHandling::HANDLE_ARG_NULL: + return OP::template Bind; + case ArgMinMaxNullHandling::HANDLE_ANY_NULL: + return OP::template Bind; + default: + return OP::template Bind; + } +} + +template +AggregateFunction GetGenericArgMinMaxFunction(const ArgMinMaxNullHandling null_handling) { using STATE = ArgMinMaxState; + auto bind = GetBindFunction(null_handling); return AggregateFunction( {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::template Update, - AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, bind, AggregateFunction::StateDestroy); } template -AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { #ifndef DUCKDB_SMALLER_BINARY using STATE = ArgMinMaxState; + auto bind = GetBindFunction(null_handling); return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::template Update, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateVoidFinalize, nullptr, bind, AggregateFunction::StateDestroy); #else - auto function = GetGenericArgMinMaxFunction(); + auto function = GetGenericArgMinMaxFunction(null_handling); function.arguments = {type, by_type}; function.return_type = type; return function; @@ -330,18 +404,19 @@ AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, #ifndef DUCKDB_SMALLER_BINARY template -AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { switch (by_type.InternalType()) { case PhysicalType::INT32: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT64: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT128: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::DOUBLE: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::VARCHAR: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); default: throw InternalException("Unimplemented arg_min/arg_max aggregate"); } @@ -356,19 +431,21 @@ const vector ArgMaxByTypes() { } template -void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { +void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { #ifndef DUCKDB_SMALLER_BINARY - fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type, null_handling)); #else - fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling)); #endif } } template -AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { #ifndef DUCKDB_SMALLER_BINARY using STATE = ArgMinMaxState; auto function = @@ -377,9 +454,9 @@ AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) { function.destructor = AggregateFunction::StateDestroy; } - function.bind = OP::Bind; + function.bind = GetBindFunction(null_handling); #else - auto function = GetGenericArgMinMaxFunction(); + auto function = GetGenericArgMinMaxFunction(null_handling); function.arguments = {type, by_type}; function.return_type = type; #endif @@ -388,18 +465,19 @@ AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const #ifndef DUCKDB_SMALLER_BINARY template -AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { switch (by_type.InternalType()) { case PhysicalType::INT32: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT64: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT128: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::DOUBLE: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::VARCHAR: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); default: throw InternalException("Unimplemented arg_min/arg_max by aggregate"); } @@ -407,37 +485,38 @@ AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const Logic #endif template -void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { +void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, ArgMinMaxNullHandling null_handling) { auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { #ifndef DUCKDB_SMALLER_BINARY - fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type, null_handling)); #else - fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type)); + fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type, null_handling)); #endif } } template -AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type, + ArgMinMaxNullHandling null_handling) { D_ASSERT(type.id() == LogicalTypeId::DECIMAL); #ifndef DUCKDB_SMALLER_BINARY switch (type.InternalType()) { case PhysicalType::INT16: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); case PhysicalType::INT32: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); case PhysicalType::INT64: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); default: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); } #else - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); #endif } -template +template unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function, vector> &arguments) { auto decimal_type = arguments[0]->return_type; @@ -469,51 +548,69 @@ unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateF } auto name = std::move(function.name); - function = GetDecimalArgMinMaxFunction(by_type, decimal_type); + function = GetDecimalArgMinMaxFunction(by_type, decimal_type, NULL_HANDLING); function.name = std::move(name); function.return_type = decimal_type; - return nullptr; + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } template -void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type) { - fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, BindDecimalArgMinMax)); +void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type, + const ArgMinMaxNullHandling null_handling) { + switch (null_handling) { + case ArgMinMaxNullHandling::IGNORE_ANY_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + case ArgMinMaxNullHandling::HANDLE_ARG_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + case ArgMinMaxNullHandling::HANDLE_ANY_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + } } template -void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun) { - fun.AddFunction(GetGenericArgMinMaxFunction()); +void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) { + fun.AddFunction(GetGenericArgMinMaxFunction(null_handling)); } -template -void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { - using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; +template +void AddArgMinMaxFunctions(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) { + using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; #ifndef DUCKDB_SMALLER_BINARY - using OP = ArgMinMaxBase; - using VECTOR_OP = VectorArgMinMaxBase; + using OP = ArgMinMaxBase; + using VECTOR_OP = VectorArgMinMaxBase; #else using OP = GENERIC_VECTOR_OP; using VECTOR_OP = GENERIC_VECTOR_OP; #endif - AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); - AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT); - AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE); - AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR); - AddArgMinMaxFunctionBy(fun, LogicalType::DATE); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ); - AddArgMinMaxFunctionBy(fun, LogicalType::BLOB); + AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::DATE, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::BLOB, null_handling); auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { - AddDecimalArgMinMaxFunctionBy(fun, by_type); + AddDecimalArgMinMaxFunctionBy(fun, by_type, null_handling); } - AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); + AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY, null_handling); // we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest - AddGenericArgMinMaxFunction(fun); + AddGenericArgMinMaxFunction(fun, null_handling); } //------------------------------------------------------------------------------ @@ -547,6 +644,8 @@ class ArgMinMaxNState { template void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { + D_ASSERT(aggr_input.bind_data); + const auto &bind_data = aggr_input.bind_data->Cast(); auto &val_vector = inputs[0]; auto &arg_vector = inputs[1]; @@ -560,8 +659,8 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); - STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format); + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last); + STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last); n_vector.ToUnifiedFormat(count, n_format); state_vector.ToUnifiedFormat(count, state_format); @@ -571,9 +670,16 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp for (idx_t i = 0; i < count; i++) { const auto arg_idx = arg_format.sel->get_index(i); const auto val_idx = val_format.sel->get_index(i); - if (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx)) { + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && + (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx))) { + continue; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL && + !val_format.validity.RowIsValid(val_idx)) { continue; } + const auto state_idx = state_format.sel->get_index(i); auto &state = *states[state_idx]; @@ -671,7 +777,77 @@ void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, } } -template +template +void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) { + using STATE = ArgMinMaxNState; + using OP = MinMaxNOperation; + + function.state_size = AggregateFunction::StateSize; + function.initialize = AggregateFunction::StateInitialize; + function.combine = AggregateFunction::StateCombine; + function.destructor = AggregateFunction::StateDestroy; + + function.finalize = MinMaxNOperation::Finalize; + function.update = ArgMinMaxNUpdate; +} + +template +void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &function) { + switch (arg_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNullNFunction(function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; +#endif + default: + SpecializeArgMinMaxNullNFunction(function); + break; + } +} + +template +void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) { + switch (val_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNullNFunction(arg_type, function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; +#endif + default: + SpecializeArgMinMaxNullNFunction(arg_type, function); + break; + } +} + +template unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunction &function, vector> &arguments) { for (auto &arg : arguments) { @@ -682,19 +858,24 @@ unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunctio const auto val_type = arguments[0]->return_type.InternalType(); const auto arg_type = arguments[1]->return_type.InternalType(); + function.return_type = LogicalType::LIST(arguments[0]->return_type); // Specialize the function based on the input types - SpecializeArgMinMaxNFunction(val_type, arg_type, function); + auto function_data = make_uniq(NULL_HANDLING, NULLS_LAST); + if (NULL_HANDLING != ArgMinMaxNullHandling::IGNORE_ANY_NULL) { + SpecializeArgMinMaxNullNFunction(val_type, arg_type, function); + } else { + SpecializeArgMinMaxNFunction(val_type, arg_type, function); + } - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return nullptr; + return unique_ptr(std::move(function_data)); } -template +template void AddArgMinMaxNFunction(AggregateFunctionSet &set) { AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, ArgMinMaxNBind); + nullptr, ArgMinMaxNBind); return set.AddFunction(function); } @@ -707,27 +888,41 @@ void AddArgMinMaxNFunction(AggregateFunctionSet &set) { AggregateFunctionSet ArgMinFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - AddArgMinMaxNFunction(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMaxFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - AddArgMinMaxNFunction(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMinNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL); return fun; } AggregateFunctionSet ArgMaxNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL); + return fun; +} + +AggregateFunctionSet ArgMinNullsLastFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL); + AddArgMinMaxNFunction(fun); + return fun; +} + +AggregateFunctionSet ArgMaxNullsLastFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } diff --git a/extension/core_functions/aggregate/distributive/functions.json b/extension/core_functions/aggregate/distributive/functions.json index d8e614c50243..9dcc4910660f 100644 --- a/extension/core_functions/aggregate/distributive/functions.json +++ b/extension/core_functions/aggregate/distributive/functions.json @@ -21,6 +21,13 @@ "example": "arg_min_null(A, B)", "type": "aggregate_function_set" }, + { + "name": "arg_min_nulls_last", + "parameters": "arg,val,N", + "description": "Finds the rows with N minimum vals, including nulls. Calculates the arg expression at that row.", + "example": "arg_min_null_val(A, B, N)", + "type": "aggregate_function_set" + }, { "name": "arg_max", "parameters": "arg,val", @@ -36,6 +43,13 @@ "example": "arg_max_null(A, B)", "type": "aggregate_function_set" }, + { + "name": "arg_max_nulls_last", + "parameters": "arg,val,N", + "description": "Finds the rows with N maximum vals, including nulls. Calculates the arg expression at that row.", + "example": "arg_min_null_val(A, B, N)", + "type": "aggregate_function_set" + }, { "name": "bit_and", "parameters": "arg", diff --git a/extension/core_functions/function_list.cpp b/extension/core_functions/function_list.cpp index a8ba52658155..810b020ab880 100644 --- a/extension/core_functions/function_list.cpp +++ b/extension/core_functions/function_list.cpp @@ -73,8 +73,10 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_AGGREGATE_FUNCTION(ApproxTopKFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullsLastFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullsLastFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), diff --git a/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp b/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp index 39bc9459c0af..4add0a00db27 100644 --- a/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp +++ b/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp @@ -57,6 +57,16 @@ struct ArgMinNullFun { static AggregateFunctionSet GetFunctions(); }; +struct ArgMinNullsLastFun { + static constexpr const char *Name = "arg_min_nulls_last"; + static constexpr const char *Parameters = "arg,val,N"; + static constexpr const char *Description = "Finds the rows with N minimum vals, including nulls. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_min_null_val(A, B, N)"; + static constexpr const char *Categories = ""; + + static AggregateFunctionSet GetFunctions(); +}; + struct ArgMaxFun { static constexpr const char *Name = "arg_max"; static constexpr const char *Parameters = "arg,val"; @@ -89,6 +99,16 @@ struct ArgMaxNullFun { static AggregateFunctionSet GetFunctions(); }; +struct ArgMaxNullsLastFun { + static constexpr const char *Name = "arg_max_nulls_last"; + static constexpr const char *Parameters = "arg,val,N"; + static constexpr const char *Description = "Finds the rows with N maximum vals, including nulls. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_min_null_val(A, B, N)"; + static constexpr const char *Categories = ""; + + static AggregateFunctionSet GetFunctions(); +}; + struct BitAndFun { static constexpr const char *Name = "bit_and"; static constexpr const char *Parameters = "arg"; diff --git a/extension/json/json_functions/json_create.cpp b/extension/json/json_functions/json_create.cpp index 8387ef750387..36eacfd1870f 100644 --- a/extension/json/json_functions/json_create.cpp +++ b/extension/json/json_functions/json_create.cpp @@ -111,11 +111,11 @@ static unique_ptr JSONCreateBindParams(ScalarFunction &bound_funct auto &type = arguments[i]->return_type; if (arguments[i]->HasParameter()) { throw ParameterNotResolvedException(); - } else if (type == LogicalTypeId::SQLNULL) { - // This is needed for macro's - bound_function.arguments.push_back(type); } else if (object && i % 2 == 0) { - // Key, must be varchar + if (type != LogicalType::VARCHAR) { + throw BinderException("json_object() keys must be VARCHAR, add an explicit cast to argument \"%s\"", + arguments[i]->GetName()); + } bound_function.arguments.push_back(LogicalType::VARCHAR); } else { // Value, cast to types that we can put in JSON @@ -128,7 +128,7 @@ static unique_ptr JSONCreateBindParams(ScalarFunction &bound_funct static unique_ptr JSONObjectBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() % 2 != 0) { - throw InvalidInputException("json_object() requires an even number of arguments"); + throw BinderException("json_object() requires an even number of arguments"); } return JSONCreateBindParams(bound_function, arguments, true); } @@ -141,7 +141,7 @@ static unique_ptr JSONArrayBind(ClientContext &context, ScalarFunc static unique_ptr ToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("to_json() takes exactly one argument"); + throw BinderException("to_json() takes exactly one argument"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -149,14 +149,14 @@ static unique_ptr ToJSONBind(ClientContext &context, ScalarFunctio static unique_ptr ArrayToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("array_to_json() takes exactly one argument"); + throw BinderException("array_to_json() takes exactly one argument"); } auto arg_id = arguments[0]->return_type.id(); if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } if (arg_id != LogicalTypeId::LIST && arg_id != LogicalTypeId::SQLNULL) { - throw InvalidInputException("array_to_json() argument type must be LIST"); + throw BinderException("array_to_json() argument type must be LIST"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -164,14 +164,14 @@ static unique_ptr ArrayToJSONBind(ClientContext &context, ScalarFu static unique_ptr RowToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("row_to_json() takes exactly one argument"); + throw BinderException("row_to_json() takes exactly one argument"); } auto arg_id = arguments[0]->return_type.id(); if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } if (arguments[0]->return_type.id() != LogicalTypeId::STRUCT && arg_id != LogicalTypeId::SQLNULL) { - throw InvalidInputException("row_to_json() argument type must be STRUCT"); + throw BinderException("row_to_json() argument type must be STRUCT"); } return JSONCreateBindParams(bound_function, arguments, false); } diff --git a/extension/parquet/CMakeLists.txt b/extension/parquet/CMakeLists.txt index 23b3ebbdb070..db4f53d60e0a 100644 --- a/extension/parquet/CMakeLists.txt +++ b/extension/parquet/CMakeLists.txt @@ -26,6 +26,7 @@ set(PARQUET_EXTENSION_FILES parquet_timestamp.cpp parquet_writer.cpp parquet_shredding.cpp + parquet_column_schema.cpp serialize_parquet.cpp zstd_file_system.cpp geo_parquet.cpp) diff --git a/extension/parquet/column_writer.cpp b/extension/parquet/column_writer.cpp index 365cb9a9eb9c..25470ea6e10b 100644 --- a/extension/parquet/column_writer.cpp +++ b/extension/parquet/column_writer.cpp @@ -108,10 +108,9 @@ void ColumnWriterStatistics::WriteGeoStats(duckdb_parquet::GeospatialStatistics //===--------------------------------------------------------------------===// // ColumnWriter //===--------------------------------------------------------------------===// -ColumnWriter::ColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : writer(writer), column_schema(column_schema), schema_path(std::move(schema_path_p)), - can_have_nulls(can_have_nulls) { +ColumnWriter::ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema_p, vector schema_path_p) + : writer(writer), column_schema(std::move(column_schema_p)), schema_path(std::move(schema_path_p)) { + can_have_nulls = column_schema.repetition_type == duckdb_parquet::FieldRepetitionType::OPTIONAL; } ColumnWriter::~ColumnWriter() { } @@ -244,19 +243,22 @@ void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterStat // Create Column Writer //===--------------------------------------------------------------------===// -ParquetColumnSchema ColumnWriter::FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, - optional_ptr shredding_types, idx_t max_repeat, - idx_t max_define, bool can_have_nulls) { - auto null_type = can_have_nulls ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; +unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, + vector path_in_schema, const LogicalType &type, + const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, + idx_t max_repeat, idx_t max_define, bool can_have_nulls) { + path_in_schema.push_back(name); + if (!can_have_nulls) { max_define--; } - idx_t schema_idx = schemas.size(); + auto null_type = can_have_nulls ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; optional_ptr field_id; optional_ptr child_field_ids; + optional_ptr shredding_type; if (field_ids) { auto field_id_it = field_ids->ids->find(name); if (field_id_it != field_ids->ids->end()) { @@ -264,22 +266,14 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vectorchild_field_ids; } } - optional_ptr shredding_type; if (shredding_types) { shredding_type = shredding_types->GetChild(name); } - if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { - // variant type - // variants are stored as follows: - // group VARIANT { - // metadata BYTE_ARRAY, - // value BYTE_ARRAY, - // [] - // } - + if (type.id() == LogicalTypeId::VARIANT) { const bool is_shredded = shredding_type != nullptr; + //! Build the child types for the Parquet VARIANT child_list_t child_types; child_types.emplace_back("metadata", LogicalType::BLOB); child_types.emplace_back("value", LogicalType::BLOB); @@ -291,23 +285,16 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vector> child_writers; + child_writers.reserve(child_types.size()); + + //! Then construct the child writers for the Parquet VARIANT + for (auto &entry : child_types) { + auto &child_name = entry.first; + auto &child_type = entry.second; bool is_optional; if (child_name == "metadata") { is_optional = false; @@ -322,286 +309,185 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vector(writer, std::move(variant_column), path_in_schema, + std::move(child_writers)); } if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { - auto &child_types = StructType::GetChildTypes(type); - // set up the schema element for this struct - duckdb_parquet::SchemaElement schema_element; - schema_element.repetition_type = null_type; - schema_element.num_children = UnsafeNumericCast(child_types.size()); - schema_element.__isset.num_children = true; - schema_element.__isset.type = false; - schema_element.__isset.repetition_type = true; - schema_element.name = name; + auto struct_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); if (field_id && field_id->set) { - schema_element.__isset.field_id = true; - schema_element.field_id = field_id->field_id; + struct_column.field_id = field_id->field_id; } - schemas.push_back(std::move(schema_element)); - ParquetColumnSchema struct_column(name, type, max_define, max_repeat, schema_idx, 0); // construct the child schemas recursively - struct_column.children.reserve(child_types.size()); - for (auto &child_type : child_types) { - struct_column.children.emplace_back(FillParquetSchema(schemas, child_type.second, child_type.first, - child_field_ids, shredding_type, max_repeat, - max_define + 1, true)); + auto &child_types = StructType::GetChildTypes(type); + vector> child_writers; + child_writers.reserve(child_types.size()); + for (auto &entry : child_types) { + auto &child_type = entry.second; + auto &child_name = entry.first; + child_writers.push_back(CreateWriterRecursive(context, writer, path_in_schema, child_type, child_name, + allow_geometry, child_field_ids, shredding_type, max_repeat, + max_define + 1, true)); } - return struct_column; + return make_uniq(writer, std::move(struct_column), std::move(path_in_schema), + std::move(child_writers)); } + if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::ARRAY) { auto is_list = type.id() == LogicalTypeId::LIST; auto &child_type = is_list ? ListType::GetChildType(type) : ArrayType::GetChildType(type); - // set up the two schema elements for the list - // for some reason we only set the converted type in the OPTIONAL element - // first an OPTIONAL element - duckdb_parquet::SchemaElement optional_element; - optional_element.repetition_type = null_type; - optional_element.num_children = 1; - optional_element.converted_type = ConvertedType::LIST; - optional_element.__isset.num_children = true; - optional_element.__isset.type = false; - optional_element.__isset.repetition_type = true; - optional_element.__isset.converted_type = true; - optional_element.name = name; - if (field_id && field_id->set) { - optional_element.__isset.field_id = true; - optional_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(optional_element)); - - // then a REPEATED element - duckdb_parquet::SchemaElement repeated_element; - repeated_element.repetition_type = FieldRepetitionType::REPEATED; - repeated_element.num_children = 1; - repeated_element.__isset.num_children = true; - repeated_element.__isset.type = false; - repeated_element.__isset.repetition_type = true; - repeated_element.name = "list"; - schemas.push_back(std::move(repeated_element)); - - ParquetColumnSchema list_column(name, type, max_define, max_repeat, schema_idx, 0); - list_column.children.push_back(FillParquetSchema(schemas, child_type, "element", child_field_ids, - shredding_type, max_repeat + 1, max_define + 2, true)); - return list_column; - } - if (type.id() == LogicalTypeId::MAP) { - // map type - // maps are stored as follows: - // group (MAP) { - // repeated group key_value { - // required key; - // value; - // } - // } - // top map element - duckdb_parquet::SchemaElement top_element; - top_element.repetition_type = null_type; - top_element.num_children = 1; - top_element.converted_type = ConvertedType::MAP; - top_element.__isset.repetition_type = true; - top_element.__isset.num_children = true; - top_element.__isset.converted_type = true; - top_element.__isset.type = false; - top_element.name = name; - if (field_id && field_id->set) { - top_element.__isset.field_id = true; - top_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(top_element)); - - // key_value element - duckdb_parquet::SchemaElement kv_element; - kv_element.repetition_type = FieldRepetitionType::REPEATED; - kv_element.num_children = 2; - kv_element.__isset.repetition_type = true; - kv_element.__isset.num_children = true; - kv_element.__isset.type = false; - kv_element.name = "key_value"; - schemas.push_back(std::move(kv_element)); - - // construct the child types recursively - vector kv_types {MapType::KeyType(type), MapType::ValueType(type)}; - vector kv_names {"key", "value"}; - - ParquetColumnSchema map_column(name, type, max_define, max_repeat, schema_idx, 0); - map_column.children.reserve(2); - for (idx_t i = 0; i < 2; i++) { - // key needs to be marked as REQUIRED - bool is_key = i == 0; - auto child_schema = FillParquetSchema(schemas, kv_types[i], kv_names[i], child_field_ids, shredding_type, - max_repeat + 1, max_define + 2, !is_key); - - map_column.children.push_back(std::move(child_schema)); - } - return map_column; - } - duckdb_parquet::SchemaElement schema_element; - schema_element.type = ParquetWriter::DuckDBTypeToParquetType(type); - schema_element.repetition_type = null_type; - schema_element.__isset.num_children = false; - schema_element.__isset.type = true; - schema_element.__isset.repetition_type = true; - schema_element.name = name; - if (field_id && field_id->set) { - schema_element.__isset.field_id = true; - schema_element.field_id = field_id->field_id; - } - ParquetWriter::SetSchemaProperties(type, schema_element); - schemas.push_back(std::move(schema_element)); - return ParquetColumnSchema(name, type, max_define, max_repeat, schema_idx, 0); -} - -unique_ptr -ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, - const vector &parquet_schemas, - const ParquetColumnSchema &schema, vector path_in_schema) { - auto &type = schema.type; - auto can_have_nulls = parquet_schemas[schema.schema_index].repetition_type == FieldRepetitionType::OPTIONAL; - path_in_schema.push_back(schema.name); + path_in_schema.push_back("list"); + auto child_writer = + CreateWriterRecursive(context, writer, path_in_schema, child_type, "element", allow_geometry, + child_field_ids, shredding_type, max_repeat + 1, max_define + 2, true); - if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { - vector> child_writers; - child_writers.reserve(schema.children.size()); - for (idx_t i = 0; i < schema.children.size(); i++) { - child_writers.push_back( - CreateWriterRecursive(context, writer, parquet_schemas, schema.children[i], path_in_schema)); + auto list_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + if (field_id && field_id->set) { + list_column.field_id = field_id->field_id; } - return make_uniq(writer, schema, path_in_schema, std::move(child_writers), can_have_nulls); - } - if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { - // construct the child writers recursively - vector> child_writers; - child_writers.reserve(schema.children.size()); - for (auto &child_column : schema.children) { - child_writers.push_back( - CreateWriterRecursive(context, writer, parquet_schemas, child_column, path_in_schema)); - } - return make_uniq(writer, schema, std::move(path_in_schema), std::move(child_writers), - can_have_nulls); - } - if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::ARRAY) { - auto is_list = type.id() == LogicalTypeId::LIST; - path_in_schema.push_back("list"); - auto child_writer = CreateWriterRecursive(context, writer, parquet_schemas, schema.children[0], path_in_schema); if (is_list) { - return make_uniq(writer, schema, std::move(path_in_schema), std::move(child_writer), - can_have_nulls); + return make_uniq(writer, std::move(list_column), std::move(path_in_schema), + std::move(child_writer)); } else { - return make_uniq(writer, schema, std::move(path_in_schema), std::move(child_writer), - can_have_nulls); + return make_uniq(writer, std::move(list_column), std::move(path_in_schema), + std::move(child_writer)); } } + if (type.id() == LogicalTypeId::MAP) { path_in_schema.push_back("key_value"); + // construct the child types recursively + child_list_t key_value; + key_value.reserve(2); + key_value.emplace_back("key", MapType::KeyType(type)); + key_value.emplace_back("value", MapType::ValueType(type)); + auto key_value_type = LogicalType::STRUCT(key_value); + + auto map_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + if (field_id && field_id->set) { + map_column.field_id = field_id->field_id; + } + vector> child_writers; child_writers.reserve(2); for (idx_t i = 0; i < 2; i++) { // key needs to be marked as REQUIRED + bool is_key = i == 0; + auto &child_name = key_value[i].first; + auto &child_type = key_value[i].second; auto child_writer = - CreateWriterRecursive(context, writer, parquet_schemas, schema.children[i], path_in_schema); + CreateWriterRecursive(context, writer, path_in_schema, child_type, child_name, allow_geometry, + child_field_ids, shredding_type, max_repeat + 1, max_define + 2, !is_key); + child_writers.push_back(std::move(child_writer)); } - auto struct_writer = - make_uniq(writer, schema, path_in_schema, std::move(child_writers), can_have_nulls); - return make_uniq(writer, schema, path_in_schema, std::move(struct_writer), can_have_nulls); + + auto key_value_schema = + ParquetColumnSchema::FromLogicalType("key_value", key_value_type, max_define + 1, max_repeat + 1, 0, + FieldRepetitionType::REPEATED, allow_geometry); + auto struct_writer = make_uniq(writer, std::move(key_value_schema), path_in_schema, + std::move(child_writers)); + return make_uniq(writer, std::move(map_column), path_in_schema, std::move(struct_writer)); + } + + auto schema = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + if (field_id && field_id->set) { + schema.field_id = field_id->field_id; } if (type.id() == LogicalTypeId::BLOB && type.GetAlias() == "WKB_BLOB") { - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); } switch (type.id()) { case LogicalTypeId::BOOLEAN: - return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::TINYINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::SMALLINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::INTEGER: case LogicalTypeId::DATE: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::BIGINT: case LogicalTypeId::TIME: case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: case LogicalTypeId::TIMESTAMP_MS: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::TIME_TZ: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::HUGEINT: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UHUGEINT: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::TIMESTAMP_NS: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::TIMESTAMP_SEC: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UTINYINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::USMALLINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::UINTEGER: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UBIGINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::FLOAT: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::DOUBLE: return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::DECIMAL: switch (type.InternalType()) { case PhysicalType::INT16: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case PhysicalType::INT32: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case PhysicalType::INT64: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); default: - return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq(writer, std::move(schema), std::move(path_in_schema)); } case LogicalTypeId::BLOB: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::VARCHAR: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UUID: return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::INTERVAL: return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::ENUM: - return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq(writer, std::move(schema), std::move(path_in_schema)); default: throw InternalException("Unsupported type \"%s\" in Parquet writer", type.ToString()); } diff --git a/extension/parquet/geo_parquet.cpp b/extension/parquet/geo_parquet.cpp index 48e2b047f773..a2cc2a82168d 100644 --- a/extension/parquet/geo_parquet.cpp +++ b/extension/parquet/geo_parquet.cpp @@ -43,17 +43,19 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ throw InvalidInputException("Geoparquet metadata is not an object"); } - auto result = make_uniq(); + // We dont actually care about the version for now, as we only support V1+native + auto result = make_uniq(GeoParquetVersion::BOTH); // Check and parse the version const auto version_val = yyjson_obj_get(root, "version"); if (!yyjson_is_str(version_val)) { throw InvalidInputException("Geoparquet metadata does not have a version"); } - result->version = yyjson_get_str(version_val); - if (StringUtil::StartsWith(result->version, "2")) { - // Guard against a breaking future 2.0 version - throw InvalidInputException("Geoparquet version %s is not supported", result->version); + + auto version = yyjson_get_str(version_val); + if (StringUtil::StartsWith(version, "3")) { + // Guard against a breaking future 3.0 version + throw InvalidInputException("Geoparquet version %s is not supported", version); } // Check and parse the geometry columns @@ -177,7 +179,20 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) yyjson_mut_doc_set_root(doc, root); // Add the version - yyjson_mut_obj_add_strncpy(doc, root, "version", version.c_str(), version.size()); + switch (version) { + case GeoParquetVersion::V1: + case GeoParquetVersion::BOTH: + yyjson_mut_obj_add_strcpy(doc, root, "version", "1.0.0"); + break; + case GeoParquetVersion::V2: + yyjson_mut_obj_add_strcpy(doc, root, "version", "2.0.0"); + break; + case GeoParquetVersion::NONE: + default: + // Should never happen, we should not be writing anything + yyjson_mut_doc_free(doc); + throw InternalException("GeoParquetVersion::NONE should not write metadata"); + } // Add the primary column yyjson_mut_obj_add_strncpy(doc, root, "primary_column", primary_geometry_column.c_str(), diff --git a/extension/parquet/include/column_writer.hpp b/extension/parquet/include/column_writer.hpp index bc7d1b82d003..1463137add33 100644 --- a/extension/parquet/include/column_writer.hpp +++ b/extension/parquet/include/column_writer.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #include "parquet_types.h" #include "parquet_column_schema.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" namespace duckdb { class MemoryStream; @@ -63,13 +64,32 @@ class ColumnWriterPageState { } }; +struct ParquetAnalyzeSchemaState { +public: + ParquetAnalyzeSchemaState() { + } + virtual ~ParquetAnalyzeSchemaState() { + } + +public: + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + class ColumnWriter { protected: static constexpr uint16_t PARQUET_DEFINE_VALID = UINT16_C(65535); public: - ColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path, - bool can_have_nulls); + ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); virtual ~ColumnWriter(); public: @@ -79,8 +99,12 @@ class ColumnWriter { const ParquetColumnSchema &Schema() const { return column_schema; } + ParquetColumnSchema &Schema() { + return column_schema; + } inline idx_t SchemaIndex() const { - return column_schema.schema_index; + D_ASSERT(column_schema.schema_index.IsValid()); + return column_schema.schema_index.GetIndex(); } inline idx_t MaxDefine() const { return column_schema.max_define; @@ -88,16 +112,49 @@ class ColumnWriter { idx_t MaxRepeat() const { return column_schema.max_repeat; } + virtual bool HasTransform() { + for (auto &child_writer : child_writers) { + if (child_writer->HasTransform()) { + throw NotImplementedException("ColumnWriter of type '%s' requires a transform, but is not a root " + "column, this isn't supported currently", + child_writer->Type()); + } + } + return false; + } + virtual LogicalType TransformedType() { + throw NotImplementedException("Writer does not have a transformed type"); + } + virtual unique_ptr TransformExpression(unique_ptr expr) { + throw NotImplementedException("Writer does not have a transform expression"); + } + + virtual unique_ptr AnalyzeSchemaInit() { + return nullptr; + } + + const vector> &ChildWriters() const { + return child_writers; + } + + virtual void AnalyzeSchema(ParquetAnalyzeSchemaState &state, Vector &input, idx_t count) { + throw NotImplementedException("Writer doesn't require an AnalyzeSchema pass"); + } + + virtual void AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState &state) { + throw NotImplementedException("Writer doesn't require an AnalyzeSchemaFinalize pass"); + } + + virtual void FinalizeSchema(vector &schemas) = 0; - static ParquetColumnSchema - FillParquetSchema(vector &schemas, const LogicalType &type, const string &name, - optional_ptr field_ids, optional_ptr shredding_types, - idx_t max_repeat = 0, idx_t max_define = 1, bool can_have_nulls = true); //! Create the column writer for a specific type recursively static unique_ptr CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, - const vector &parquet_schemas, - const ParquetColumnSchema &schema, - vector path_in_schema); + vector path_in_schema, const LogicalType &type, + const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, + idx_t max_repeat = 0, idx_t max_define = 1, + bool can_have_nulls = true); virtual unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) = 0; @@ -132,7 +189,7 @@ class ColumnWriter { public: ParquetWriter &writer; - const ParquetColumnSchema &column_schema; + ParquetColumnSchema column_schema; vector schema_path; bool can_have_nulls; diff --git a/extension/parquet/include/geo_parquet.hpp b/extension/parquet/include/geo_parquet.hpp index 424e7c324c48..0e236c73aa24 100644 --- a/extension/parquet/include/geo_parquet.hpp +++ b/extension/parquet/include/geo_parquet.hpp @@ -33,6 +33,31 @@ enum class GeoParquetColumnEncoding : uint8_t { MULTIPOLYGON, }; +enum class GeoParquetVersion : uint8_t { + // Write GeoParquet 1.0 metadata + // GeoParquet 1.0 has the widest support among readers and writers + V1, + + // Write GeoParquet 2.0 + // The GeoParquet 2.0 options is identical to GeoParquet 1.0 except the underlying storage + // of spatial columns is Parquet native geometry, where the Parquet writer will include + // native statistics according to the underlying Parquet options. Compared to 'BOTH', this will + // actually write the metadata as containing GeoParquet version 2.0.0 + // However, V2 isnt standardized yet, so this option is still a bit experimental + V2, + + // Write GeoParquet 1.0 metadata, with native Parquet geometry types + // This is a bit of a hold-over option for compatibility with systems that + // reject GeoParquet 2.0 metadata, but can read Parquet native geometry types as they simply ignore the extra + // logical type. DuckDB v1.4.0 falls into this category. + BOTH, + + // Do not write GeoParquet metadata + // This option suppresses GeoParquet metadata; however, spatial types will be written as + // Parquet native Geometry/Geography. + NONE, +}; + struct GeoParquetColumnMetadata { // The encoding of the geometry column GeoParquetColumnEncoding geometry_encoding; @@ -49,6 +74,8 @@ struct GeoParquetColumnMetadata { class GeoParquetFileMetadata { public: + explicit GeoParquetFileMetadata(GeoParquetVersion geo_parquet_version) : version(geo_parquet_version) { + } void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStatsData &stats); void Write(duckdb_parquet::FileMetaData &file_meta_data); @@ -68,8 +95,8 @@ class GeoParquetFileMetadata { private: mutex write_lock; - string version = "1.1.0"; unordered_map geometry_columns; + GeoParquetVersion version; }; } // namespace duckdb diff --git a/extension/parquet/include/parquet_column_schema.hpp b/extension/parquet/include/parquet_column_schema.hpp index d467e2a0263d..263183117667 100644 --- a/extension/parquet/include/parquet_column_schema.hpp +++ b/extension/parquet/include/parquet_column_schema.hpp @@ -12,6 +12,12 @@ namespace duckdb { +using namespace duckdb_parquet; // NOLINT + +using duckdb_parquet::ConvertedType; +using duckdb_parquet::FieldRepetitionType; +using duckdb_parquet::SchemaElement; + using duckdb_parquet::FileMetaData; struct ParquetOptions; @@ -30,29 +36,60 @@ enum class ParquetExtraTypeInfo { }; struct ParquetColumnSchema { +public: ParquetColumnSchema() = default; - ParquetColumnSchema(idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t file_index, - ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); - ParquetColumnSchema(string name, LogicalType type, idx_t max_define, idx_t max_repeat, idx_t schema_index, - idx_t column_index, ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); - ParquetColumnSchema(ParquetColumnSchema parent, LogicalType result_type, ParquetColumnSchemaType schema_type); + ParquetColumnSchema(ParquetColumnSchema &&other) = default; + ParquetColumnSchema(const ParquetColumnSchema &other) = default; + ParquetColumnSchema &operator=(ParquetColumnSchema &&other) = default; - ParquetColumnSchemaType schema_type; +public: + //! Writer constructors + static ParquetColumnSchema FromLogicalType(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t column_index, + duckdb_parquet::FieldRepetitionType::type repetition_type, + bool allow_geometry, + ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); + +public: + //! Reader constructors + static ParquetColumnSchema FromSchemaElement(const SchemaElement &element, idx_t max_define, idx_t max_repeat, + idx_t schema_index, idx_t column_index, ParquetColumnSchemaType type, + const ParquetOptions &options); + static ParquetColumnSchema FromParentSchema(ParquetColumnSchema parent, LogicalType result_type, + ParquetColumnSchemaType schema_type); + static ParquetColumnSchema FromChildSchemas(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t schema_index, idx_t column_index, + vector &&children, + ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); + static ParquetColumnSchema FileRowNumber(); + +public: + unique_ptr Stats(const FileMetaData &file_meta_data, const ParquetOptions &parquet_options, + idx_t row_group_idx_p, const vector &columns) const; + +public: + void SetSchemaIndex(idx_t schema_idx); + +public: string name; - LogicalType type; idx_t max_define; idx_t max_repeat; - idx_t schema_index; + //! Populated by FinalizeSchema if used in the parquet_writer path + optional_idx schema_index; idx_t column_index; + ParquetColumnSchemaType schema_type; + LogicalType type; optional_idx parent_schema_index; uint32_t type_length = 0; uint32_t type_scale = 0; duckdb_parquet::Type::type parquet_type = duckdb_parquet::Type::INT32; ParquetExtraTypeInfo type_info = ParquetExtraTypeInfo::NONE; vector children; - - unique_ptr Stats(const FileMetaData &file_meta_data, const ParquetOptions &parquet_options, - idx_t row_group_idx_p, const vector &columns) const; + optional_idx field_id; + //! Whether a column is nullable or not + duckdb_parquet::FieldRepetitionType::type repetition_type = duckdb_parquet::FieldRepetitionType::OPTIONAL; + //! Whether the column can be recognized as a GEOMETRY type + bool allow_geometry = false; }; } // namespace duckdb diff --git a/extension/parquet/include/parquet_dbp_decoder.hpp b/extension/parquet/include/parquet_dbp_decoder.hpp index 31fb26cc9c28..775160215c87 100644 --- a/extension/parquet/include/parquet_dbp_decoder.hpp +++ b/extension/parquet/include/parquet_dbp_decoder.hpp @@ -18,7 +18,7 @@ class DbpDecoder { : buffer_(buffer, buffer_len), // block_size_in_values(ParquetDecodeUtils::VarintDecode(buffer_)), - number_of_miniblocks_per_block(ParquetDecodeUtils::VarintDecode(buffer_)), + number_of_miniblocks_per_block(DecodeNumberOfMiniblocksPerBlock(buffer_)), number_of_values_in_a_miniblock(block_size_in_values / number_of_miniblocks_per_block), total_value_count(ParquetDecodeUtils::VarintDecode(buffer_)), previous_value(ParquetDecodeUtils::ZigzagToInt(ParquetDecodeUtils::VarintDecode(buffer_))), @@ -31,7 +31,7 @@ class DbpDecoder { number_of_values_in_a_miniblock % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0)) { throw InvalidInputException("Parquet file has invalid block sizes for DELTA_BINARY_PACKED"); } - }; + } ByteBuffer BufferPtr() const { return buffer_; @@ -68,6 +68,15 @@ class DbpDecoder { } private: + static idx_t DecodeNumberOfMiniblocksPerBlock(ByteBuffer &buffer) { + auto res = ParquetDecodeUtils::VarintDecode(buffer); + if (res == 0) { + throw InvalidInputException( + "Parquet file has invalid number of miniblocks per block for DELTA_BINARY_PACKED"); + } + return res; + } + template void GetBatchInternal(const data_ptr_t target_values_ptr, const idx_t batch_size) { if (batch_size == 0) { diff --git a/extension/parquet/include/parquet_reader.hpp b/extension/parquet/include/parquet_reader.hpp index de905c70cf2a..09ec66fa4566 100644 --- a/extension/parquet/include/parquet_reader.hpp +++ b/extension/parquet/include/parquet_reader.hpp @@ -195,6 +195,8 @@ class ParquetReader : public BaseFileReader { static unique_ptr ReadStatistics(const ParquetUnionData &union_data, const string &name); LogicalType DeriveLogicalType(const SchemaElement &s_ele, ParquetColumnSchema &schema) const; + static LogicalType DeriveLogicalType(const SchemaElement &s_ele, const ParquetOptions &options, + ParquetColumnSchema &schema); void AddVirtualColumn(column_t virtual_column_id) override; diff --git a/extension/parquet/include/parquet_writer.hpp b/extension/parquet/include/parquet_writer.hpp index be784288fe85..decf436f73ff 100644 --- a/extension/parquet/include/parquet_writer.hpp +++ b/extension/parquet/include/parquet_writer.hpp @@ -56,6 +56,53 @@ enum class ParquetVersion : uint8_t { V2 = 2, //! Includes the encodings above }; +class ParquetWriteTransformData { +public: + ParquetWriteTransformData(ClientContext &context, vector types, + vector> expressions); + +public: + ColumnDataCollection &ApplyTransform(ColumnDataCollection &input); + +private: + //! The buffer to store the transformed chunks of a rowgroup + ColumnDataCollection buffer; + //! The expression(s) to apply to the input chunk + vector> expressions; + //! The expression executor used to transform the input chunk + ExpressionExecutor executor; + //! The intermediate chunk to target the transform to + DataChunk chunk; +}; + +struct ParquetWriteLocalState : public LocalFunctionData { +public: + explicit ParquetWriteLocalState(ClientContext &context, const vector &types); + +public: + ColumnDataCollection buffer; + ColumnDataAppendState append_state; + //! If any of the column writers require a transformation to a different shape, this will be initialized and used + unique_ptr transform_data; +}; + +struct ParquetWriteGlobalState : public GlobalFunctionData { +public: + ParquetWriteGlobalState() { + } + +public: + void LogFlushingRowGroup(const ColumnDataCollection &buffer, const string &reason); + +public: + unique_ptr writer; + optional_ptr op; + mutex lock; + unique_ptr combine_buffer; + //! If any of the column writers require a transformation to a different shape, this will be initialized and used + unique_ptr transform_data; +}; + class ParquetWriter { public: ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, @@ -64,17 +111,19 @@ class ParquetWriter { shared_ptr encryption_config, optional_idx dictionary_size_limit, idx_t string_dictionary_page_size_limit, bool enable_bloom_filters, double bloom_filter_false_positive_ratio, int64_t compression_level, bool debug_use_openssl, - ParquetVersion parquet_version); + ParquetVersion parquet_version, GeoParquetVersion geoparquet_version); ~ParquetWriter(); public: - void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result); + void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result, + unique_ptr &transform_data); void FlushRowGroup(PreparedRowGroup &row_group); - void Flush(ColumnDataCollection &buffer); + void Flush(ColumnDataCollection &buffer, unique_ptr &transform_data); void Finalize(); static duckdb_parquet::Type::type DuckDBTypeToParquetType(const LogicalType &duckdb_type); - static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele); + static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry); ClientContext &GetContext() { return context; @@ -118,9 +167,13 @@ class ParquetWriter { ParquetVersion GetParquetVersion() const { return parquet_version; } + GeoParquetVersion GetGeoParquetVersion() const { + return geoparquet_version; + } const string &GetFileName() const { return file_name; } + void AnalyzeSchema(ColumnDataCollection &buffer, vector> &column_writers); uint32_t Write(const duckdb_apache::thrift::TBase &object); uint32_t WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size); @@ -134,6 +187,8 @@ class ParquetWriter { void SetWrittenStatistics(CopyFunctionFileStatistics &written_stats); void FlushColumnStats(idx_t col_idx, duckdb_parquet::ColumnChunk &chunk, optional_ptr writer_stats); + void InitializePreprocessing(unique_ptr &transform_data); + void InitializeSchemaElements(); private: void GatherWrittenStatistics(); @@ -155,7 +210,7 @@ class ParquetWriter { bool debug_use_openssl; shared_ptr encryption_util; ParquetVersion parquet_version; - vector column_schemas; + GeoParquetVersion geoparquet_version; unique_ptr writer; //! Atomics to reduce contention when rotating writes to multiple Parquet files diff --git a/extension/parquet/include/writer/array_column_writer.hpp b/extension/parquet/include/writer/array_column_writer.hpp index 1ebb16c04036..404430e1e292 100644 --- a/extension/parquet/include/writer/array_column_writer.hpp +++ b/extension/parquet/include/writer/array_column_writer.hpp @@ -14,9 +14,9 @@ namespace duckdb { class ArrayColumnWriter : public ListColumnWriter { public: - ArrayColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - unique_ptr child_writer_p, bool can_have_nulls) - : ListColumnWriter(writer, column_schema, std::move(schema_path_p), std::move(child_writer_p), can_have_nulls) { + ArrayColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + unique_ptr child_writer_p) + : ListColumnWriter(writer, std::move(column_schema), std::move(schema_path_p), std::move(child_writer_p)) { } ~ArrayColumnWriter() override = default; diff --git a/extension/parquet/include/writer/boolean_column_writer.hpp b/extension/parquet/include/writer/boolean_column_writer.hpp index eeaa3d23c30f..a5606a125b8c 100644 --- a/extension/parquet/include/writer/boolean_column_writer.hpp +++ b/extension/parquet/include/writer/boolean_column_writer.hpp @@ -14,8 +14,7 @@ namespace duckdb { class BooleanColumnWriter : public PrimitiveColumnWriter { public: - BooleanColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - bool can_have_nulls); + BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~BooleanColumnWriter() override = default; public: diff --git a/extension/parquet/include/writer/decimal_column_writer.hpp b/extension/parquet/include/writer/decimal_column_writer.hpp index 38c696571310..91ced28995a1 100644 --- a/extension/parquet/include/writer/decimal_column_writer.hpp +++ b/extension/parquet/include/writer/decimal_column_writer.hpp @@ -14,8 +14,7 @@ namespace duckdb { class FixedDecimalColumnWriter : public PrimitiveColumnWriter { public: - FixedDecimalColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls); + FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~FixedDecimalColumnWriter() override = default; public: diff --git a/extension/parquet/include/writer/enum_column_writer.hpp b/extension/parquet/include/writer/enum_column_writer.hpp index 4e3e6e3aaa1b..ba0f6c4549f9 100644 --- a/extension/parquet/include/writer/enum_column_writer.hpp +++ b/extension/parquet/include/writer/enum_column_writer.hpp @@ -15,8 +15,7 @@ class EnumWriterPageState; class EnumColumnWriter : public PrimitiveColumnWriter { public: - EnumColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - bool can_have_nulls); + EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~EnumColumnWriter() override = default; uint32_t bit_width; diff --git a/extension/parquet/include/writer/list_column_writer.hpp b/extension/parquet/include/writer/list_column_writer.hpp index 902d3001ce89..df7ecf276825 100644 --- a/extension/parquet/include/writer/list_column_writer.hpp +++ b/extension/parquet/include/writer/list_column_writer.hpp @@ -26,9 +26,9 @@ class ListColumnWriterState : public ColumnWriterState { class ListColumnWriter : public ColumnWriter { public: - ListColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - unique_ptr child_writer_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + ListColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + unique_ptr child_writer_p) + : ColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { child_writers.push_back(std::move(child_writer_p)); } ~ListColumnWriter() override = default; @@ -44,6 +44,7 @@ class ListColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + void FinalizeSchema(vector &schemas) override; protected: ColumnWriter &GetChildWriter(); diff --git a/extension/parquet/include/writer/primitive_column_writer.hpp b/extension/parquet/include/writer/primitive_column_writer.hpp index 28b217692219..36874cf6da66 100644 --- a/extension/parquet/include/writer/primitive_column_writer.hpp +++ b/extension/parquet/include/writer/primitive_column_writer.hpp @@ -57,8 +57,7 @@ class PrimitiveColumnWriterState : public ColumnWriterState { //! Base class for writing non-compound types (ex. numerics, strings) class PrimitiveColumnWriter : public ColumnWriter { public: - PrimitiveColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path, - bool can_have_nulls); + PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); ~PrimitiveColumnWriter() override = default; //! We limit the uncompressed page size to 100MB @@ -75,6 +74,7 @@ class PrimitiveColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + void FinalizeSchema(vector &schemas) override; protected: static void WriteLevels(Allocator &allocator, WriteStream &temp_writer, const unsafe_vector &levels, diff --git a/extension/parquet/include/writer/struct_column_writer.hpp b/extension/parquet/include/writer/struct_column_writer.hpp index bbb6cd06b61f..a3d433467d1d 100644 --- a/extension/parquet/include/writer/struct_column_writer.hpp +++ b/extension/parquet/include/writer/struct_column_writer.hpp @@ -14,9 +14,9 @@ namespace duckdb { class StructColumnWriter : public ColumnWriter { public: - StructColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - vector> child_writers_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + StructColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + vector> child_writers_p) + : ColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { child_writers = std::move(child_writers_p); } ~StructColumnWriter() override = default; @@ -32,6 +32,7 @@ class StructColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + void FinalizeSchema(vector &schemas) override; }; } // namespace duckdb diff --git a/extension/parquet/include/writer/templated_column_writer.hpp b/extension/parquet/include/writer/templated_column_writer.hpp index 4c9f1d8aa6c8..e9bd8ad35d15 100644 --- a/extension/parquet/include/writer/templated_column_writer.hpp +++ b/extension/parquet/include/writer/templated_column_writer.hpp @@ -116,10 +116,8 @@ class StandardWriterPageState : public ColumnWriterPageState { template class StandardColumnWriter : public PrimitiveColumnWriter { public: - StandardColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, // NOLINT - bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + StandardColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } ~StandardColumnWriter() override = default; diff --git a/extension/parquet/include/writer/variant_column_writer.hpp b/extension/parquet/include/writer/variant_column_writer.hpp index 74fdda60872d..07250c4ac06d 100644 --- a/extension/parquet/include/writer/variant_column_writer.hpp +++ b/extension/parquet/include/writer/variant_column_writer.hpp @@ -10,21 +10,100 @@ #include "struct_column_writer.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" namespace duckdb { +using variant_type_map = array(VariantLogicalType::ENUM_SIZE)>; + +struct ObjectAnalyzeData; +struct ArrayAnalyzeData; + +struct VariantAnalyzeData { +public: + VariantAnalyzeData() { + } + +public: + //! Map for every value what type it is + variant_type_map type_map = {}; + //! Map for every decimal value what physical type it has + array decimal_type_map = {}; + unique_ptr object_data = nullptr; + unique_ptr array_data = nullptr; +}; + +struct ObjectAnalyzeData { +public: + ObjectAnalyzeData() { + } + +public: + case_insensitive_map_t fields; +}; + +struct ArrayAnalyzeData { +public: + ArrayAnalyzeData() { + } + +public: + VariantAnalyzeData child; +}; + +struct VariantAnalyzeSchemaState : public ParquetAnalyzeSchemaState { +public: + VariantAnalyzeSchemaState() { + } + ~VariantAnalyzeSchemaState() override { + } + +public: + VariantAnalyzeData analyze_data; +}; + class VariantColumnWriter : public StructColumnWriter { public: - VariantColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - vector> child_writers_p, bool can_have_nulls) - : StructColumnWriter(writer, column_schema, std::move(schema_path_p), std::move(child_writers_p), - can_have_nulls) { + VariantColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + vector> child_writers_p) + : StructColumnWriter(writer, std::move(column_schema), std::move(schema_path_p), std::move(child_writers_p)) { } ~VariantColumnWriter() override = default; +public: + void FinalizeSchema(vector &schemas) override; + unique_ptr AnalyzeSchemaInit() override; + void AnalyzeSchema(ParquetAnalyzeSchemaState &state, Vector &input, idx_t count) override; + void AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState &state) override; + + bool HasTransform() override { + return true; + } + LogicalType TransformedType() override { + child_list_t children; + for (auto &writer : child_writers) { + auto &child_name = writer->Schema().name; + auto &child_type = writer->Schema().type; + children.emplace_back(child_name, child_type); + } + return LogicalType::STRUCT(std::move(children)); + } + unique_ptr TransformExpression(unique_ptr expr) override { + vector> arguments; + arguments.push_back(unique_ptr_cast(std::move(expr))); + + return make_uniq(TransformedType(), GetTransformFunction(), std::move(arguments), + nullptr, false); + } + public: static ScalarFunction GetTransformFunction(); static LogicalType TransformTypedValueRecursive(const LogicalType &type); + +private: + //! Whether the schema of the variant has been analyzed already + bool is_analyzed = false; }; } // namespace duckdb diff --git a/extension/parquet/parquet_column_schema.cpp b/extension/parquet/parquet_column_schema.cpp new file mode 100644 index 000000000000..64409c1abbbd --- /dev/null +++ b/extension/parquet/parquet_column_schema.cpp @@ -0,0 +1,113 @@ +#include "parquet_column_schema.hpp" +#include "parquet_reader.hpp" + +namespace duckdb { + +void ParquetColumnSchema::SetSchemaIndex(idx_t schema_idx) { + D_ASSERT(!schema_index.IsValid()); + schema_index = schema_idx; +} + +//! Writer constructors + +ParquetColumnSchema ParquetColumnSchema::FromLogicalType(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t column_index, + duckdb_parquet::FieldRepetitionType::type repetition_type, + bool allow_geometry, ParquetColumnSchemaType schema_type) { + ParquetColumnSchema res; + res.name = name; + res.max_define = max_define; + res.max_repeat = max_repeat; + res.column_index = column_index; + res.repetition_type = repetition_type; + res.schema_type = schema_type; + res.type = type; + res.allow_geometry = allow_geometry; + return res; +} + +//! Reader constructors + +ParquetColumnSchema ParquetColumnSchema::FromSchemaElement(const duckdb_parquet::SchemaElement &element, + idx_t max_define, idx_t max_repeat, idx_t schema_index, + idx_t column_index, ParquetColumnSchemaType schema_type, + const ParquetOptions &options) { + ParquetColumnSchema res; + res.name = element.name; + res.max_define = max_define; + res.max_repeat = max_repeat; + res.schema_index = schema_index; + res.column_index = column_index; + res.schema_type = schema_type; + res.type = ParquetReader::DeriveLogicalType(element, options, res); + return res; +} + +ParquetColumnSchema ParquetColumnSchema::FromParentSchema(ParquetColumnSchema parent, LogicalType result_type, + ParquetColumnSchemaType schema_type) { + ParquetColumnSchema res; + res.name = parent.name; + res.max_define = parent.max_define; + res.max_repeat = parent.max_repeat; + D_ASSERT(parent.schema_index.IsValid()); + res.schema_index = parent.schema_index; + res.column_index = parent.column_index; + res.schema_type = schema_type; + res.type = result_type; + res.children.push_back(std::move(parent)); + return res; +} + +ParquetColumnSchema ParquetColumnSchema::FromChildSchemas(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t schema_index, idx_t column_index, + vector &&children, + ParquetColumnSchemaType schema_type) { + ParquetColumnSchema res; + res.name = name; + res.max_define = max_define; + res.max_repeat = max_repeat; + res.schema_index = schema_index; + res.column_index = column_index; + res.schema_type = schema_type; + res.type = type; + res.children = std::move(children); + return res; +} + +ParquetColumnSchema ParquetColumnSchema::FileRowNumber() { + ParquetColumnSchema res; + res.name = "file_row_number"; + res.max_define = 0; + res.max_repeat = 0; + res.schema_index = 0; + res.column_index = 0; + res.schema_type = ParquetColumnSchemaType::FILE_ROW_NUMBER; + res.type = LogicalType::BIGINT, res.repetition_type = duckdb_parquet::FieldRepetitionType::type::OPTIONAL; + return res; +} + +unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_meta_data, + const ParquetOptions &parquet_options, idx_t row_group_idx_p, + const vector &columns) const { + if (schema_type == ParquetColumnSchemaType::EXPRESSION) { + return nullptr; + } + if (schema_type == ParquetColumnSchemaType::FILE_ROW_NUMBER) { + auto stats = NumericStats::CreateUnknown(type); + auto &row_groups = file_meta_data.row_groups; + D_ASSERT(row_group_idx_p < row_groups.size()); + idx_t row_group_offset_min = 0; + for (idx_t i = 0; i < row_group_idx_p; i++) { + row_group_offset_min += row_groups[i].num_rows; + } + + NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); + NumericStats::SetMax(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min + + row_groups[row_group_idx_p].num_rows))); + stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + return stats.ToUnique(); + } + return ParquetStatisticsUtils::TransformColumnStatistics(*this, columns, parquet_options.can_have_nan); +} + +} // namespace duckdb diff --git a/extension/parquet/parquet_extension.cpp b/extension/parquet/parquet_extension.cpp index 6ce7733a7858..77c78b3a8d55 100644 --- a/extension/parquet/parquet_extension.cpp +++ b/extension/parquet/parquet_extension.cpp @@ -94,36 +94,27 @@ struct ParquetWriteBindData : public TableFunctionData { //! Which encodings to include when writing ParquetVersion parquet_version = ParquetVersion::V1; -}; - -struct ParquetWriteGlobalState : public GlobalFunctionData { - unique_ptr writer; - optional_ptr op; - void LogFlushingRowGroup(const ColumnDataCollection &buffer, const string &reason) { - if (!op) { - return; - } - DUCKDB_LOG(writer->GetContext(), PhysicalOperatorLogType, *op, "ParquetWriter", "FlushRowGroup", - {{"file", writer->GetFileName()}, - {"rows", to_string(buffer.Count())}, - {"size", to_string(buffer.SizeInBytes())}, - {"reason", reason}}); - } - - mutex lock; - unique_ptr combine_buffer; + //! Which geo-parquet version to use when writing + GeoParquetVersion geoparquet_version = GeoParquetVersion::V1; }; -struct ParquetWriteLocalState : public LocalFunctionData { - explicit ParquetWriteLocalState(ClientContext &context, const vector &types) : buffer(context, types) { - buffer.SetPartitionIndex(0); // Makes the buffer manager less likely to spill this data - buffer.InitializeAppend(append_state); +void ParquetWriteGlobalState::LogFlushingRowGroup(const ColumnDataCollection &buffer, const string &reason) { + if (!op) { + return; } + DUCKDB_LOG(writer->GetContext(), PhysicalOperatorLogType, *op, "ParquetWriter", "FlushRowGroup", + {{"file", writer->GetFileName()}, + {"rows", to_string(buffer.Count())}, + {"size", to_string(buffer.SizeInBytes())}, + {"reason", reason}}); +} - ColumnDataCollection buffer; - ColumnDataAppendState append_state; -}; +ParquetWriteLocalState::ParquetWriteLocalState(ClientContext &context, const vector &types) + : buffer(context, types) { + buffer.SetPartitionIndex(0); // Makes the buffer manager less likely to spill this data + buffer.InitializeAppend(append_state); +} static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &input) { auto ©_options = input.options; @@ -147,6 +138,7 @@ static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &inp copy_options["binary_as_string"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["file_row_number"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["can_have_nan"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); + copy_options["geoparquet_version"] = CopyOption(LogicalType::VARCHAR, CopyOptionMode::WRITE_ONLY); copy_options["shredding"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); } @@ -219,10 +211,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } else { case_insensitive_set_t variant_names; for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { - if (sql_types[col_idx].id() != LogicalTypeId::STRUCT) { - continue; - } - if (sql_types[col_idx].GetAlias() != "PARQUET_VARIANT") { + if (sql_types[col_idx].id() != LogicalTypeId::VARIANT) { continue; } variant_names.emplace(names[col_idx]); @@ -333,6 +322,19 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } else { throw BinderException("Expected parquet_version 'V1' or 'V2'"); } + } else if (loption == "geoparquet_version") { + const auto roption = StringUtil::Upper(option.second[0].ToString()); + if (roption == "NONE") { + bind_data->geoparquet_version = GeoParquetVersion::NONE; + } else if (roption == "V1") { + bind_data->geoparquet_version = GeoParquetVersion::V1; + } else if (roption == "V2") { + bind_data->geoparquet_version = GeoParquetVersion::V2; + } else if (roption == "BOTH") { + bind_data->geoparquet_version = GeoParquetVersion::BOTH; + } else { + throw BinderException("Expected geoparquet_version 'NONE', 'V1' or 'BOTH'"); + } } else { throw InternalException("Unrecognized option for PARQUET: %s", option.first.c_str()); } @@ -365,7 +367,7 @@ static unique_ptr ParquetWriteInitializeGlobal(ClientContext parquet_bind.encryption_config, parquet_bind.dictionary_size_limit, parquet_bind.string_dictionary_page_size_limit, parquet_bind.enable_bloom_filters, parquet_bind.bloom_filter_false_positive_ratio, parquet_bind.compression_level, parquet_bind.debug_use_openssl, - parquet_bind.parquet_version); + parquet_bind.parquet_version, parquet_bind.geoparquet_version); return std::move(global_state); } @@ -391,7 +393,7 @@ static void ParquetWriteSink(ExecutionContext &context, FunctionData &bind_data_ global_state.LogFlushingRowGroup(local_state.buffer, reason); // if the chunk collection exceeds a certain size (rows/bytes) we flush it to the parquet file local_state.append_state.current_chunk_state.handles.clear(); - global_state.writer->Flush(local_state.buffer); + global_state.writer->Flush(local_state.buffer, local_state.transform_data); local_state.buffer.InitializeAppend(local_state.append_state); } } @@ -406,7 +408,7 @@ static void ParquetWriteCombine(ExecutionContext &context, FunctionData &bind_da local_state.buffer.SizeInBytes() >= bind_data.row_group_size_bytes / 2) { // local state buffer is more than half of the row_group_size(_bytes), just flush it global_state.LogFlushingRowGroup(local_state.buffer, "Combine"); - global_state.writer->Flush(local_state.buffer); + global_state.writer->Flush(local_state.buffer, local_state.transform_data); return; } @@ -421,7 +423,7 @@ static void ParquetWriteCombine(ExecutionContext &context, FunctionData &bind_da guard.unlock(); global_state.LogFlushingRowGroup(*owned_combine_buffer, "Combine"); // Lock free, of course - global_state.writer->Flush(*owned_combine_buffer); + global_state.writer->Flush(*owned_combine_buffer, local_state.transform_data); } return; } @@ -435,7 +437,7 @@ static void ParquetWriteFinalize(ClientContext &context, FunctionData &bind_data // flush the combine buffer (if it's there) if (global_state.combine_buffer) { global_state.LogFlushingRowGroup(*global_state.combine_buffer, "Finalize"); - global_state.writer->Flush(*global_state.combine_buffer); + global_state.writer->Flush(*global_state.combine_buffer, global_state.transform_data); } // finalize: write any additional metadata to the file here @@ -534,6 +536,39 @@ ParquetVersion EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template <> +const char *EnumUtil::ToChars(GeoParquetVersion value) { + switch (value) { + case GeoParquetVersion::NONE: + return "NONE"; + case GeoParquetVersion::V1: + return "V1"; + case GeoParquetVersion::V2: + return "V2"; + case GeoParquetVersion::BOTH: + return "BOTH"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); + } +} + +template <> +GeoParquetVersion EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NONE")) { + return GeoParquetVersion::NONE; + } + if (StringUtil::Equals(value, "V1")) { + return GeoParquetVersion::V1; + } + if (StringUtil::Equals(value, "V2")) { + return GeoParquetVersion::V2; + } + if (StringUtil::Equals(value, "BOTH")) { + return GeoParquetVersion::BOTH; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + static optional_idx SerializeCompressionLevel(const int64_t compression_level) { return compression_level < 0 ? NumericLimits::Maximum() - NumericCast(AbsValue(compression_level)) : NumericCast(compression_level); @@ -587,7 +622,9 @@ static void ParquetCopySerialize(Serializer &serializer, const FunctionData &bin serializer.WritePropertyWithDefault(115, "string_dictionary_page_size_limit", bind_data.string_dictionary_page_size_limit, default_value.string_dictionary_page_size_limit); - serializer.WriteProperty(116, "shredding_types", bind_data.shredding_types); + serializer.WritePropertyWithDefault(116, "geoparquet_version", bind_data.geoparquet_version, + default_value.geoparquet_version); + serializer.WriteProperty(117, "shredding_types", bind_data.shredding_types); } static unique_ptr ParquetCopyDeserialize(Deserializer &deserializer, CopyFunction &function) { @@ -620,7 +657,9 @@ static unique_ptr ParquetCopyDeserialize(Deserializer &deserialize deserializer.ReadPropertyWithExplicitDefault(114, "parquet_version", default_value.parquet_version); data->string_dictionary_page_size_limit = deserializer.ReadPropertyWithExplicitDefault( 115, "string_dictionary_page_size_limit", default_value.string_dictionary_page_size_limit); - data->shredding_types = deserializer.ReadProperty(116, "shredding_types"); + data->geoparquet_version = + deserializer.ReadPropertyWithExplicitDefault(116, "geoparquet_version", default_value.geoparquet_version); + data->shredding_types = deserializer.ReadProperty(117, "shredding_types"); return std::move(data); } @@ -657,7 +696,8 @@ static unique_ptr ParquetWritePrepareBatch(ClientContext &con unique_ptr collection) { auto &global_state = gstate.Cast(); auto result = make_uniq(); - global_state.writer->PrepareRowGroup(*collection, result->prepared_row_group); + unique_ptr transform_data; + global_state.writer->PrepareRowGroup(*collection, result->prepared_row_group, transform_data); return std::move(result); } @@ -751,38 +791,6 @@ static bool IsGeometryType(const LogicalType &type, ClientContext &context) { return GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context); } -static string GetShredding(case_insensitive_map_t> &options, const string &col_name) { - //! At this point, the options haven't been parsed yet, so we have to parse them ourselves. - auto it = options.find("shredding"); - if (it == options.end()) { - return string(); - } - auto &shredding = it->second; - if (shredding.empty()) { - return string(); - } - - auto &shredding_val = shredding[0]; - if (shredding_val.type().id() != LogicalTypeId::STRUCT) { - return string(); - } - - auto &shredded_variants = StructType::GetChildTypes(shredding_val.type()); - auto &values = StructValue::GetChildren(shredding_val); - for (idx_t i = 0; i < shredded_variants.size(); i++) { - auto &shredded_variant = shredded_variants[i]; - if (shredded_variant.first != col_name) { - continue; - } - auto &shredded_val = values[i]; - if (shredded_val.type().id() != LogicalTypeId::VARCHAR) { - return string(); - } - return shredded_val.GetValue(); - } - return string(); -} - static vector> ParquetWriteSelect(CopyToSelectInput &input) { auto &context = input.context; @@ -807,23 +815,6 @@ static vector> ParquetWriteSelect(CopyToSelectInput &inpu cast_expr->SetAlias(name); result.push_back(std::move(cast_expr)); any_change = true; - } else if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::VARIANT) { - vector> arguments; - arguments.push_back(std::move(expr)); - - auto shredded_type_str = GetShredding(input.options, name); - if (!shredded_type_str.empty()) { - arguments.push_back(make_uniq(Value(shredded_type_str))); - } - - auto transform_func = VariantColumnWriter::GetTransformFunction(); - transform_func.bind(context, transform_func, arguments); - - auto func_expr = make_uniq(transform_func.return_type, transform_func, - std::move(arguments), nullptr, false); - func_expr->SetAlias(name); - result.push_back(std::move(func_expr)); - any_change = true; } // If this is an EXPORT DATABASE statement, we dont want to write "lossy" types, instead cast them to VARCHAR else if (input.copy_to_type == CopyToType::EXPORT_DATABASE && TypeVisitor::Contains(type, IsTypeLossy)) { diff --git a/extension/parquet/parquet_field_id.cpp b/extension/parquet/parquet_field_id.cpp index d1ff138cc711..642fc26c76ae 100644 --- a/extension/parquet/parquet_field_id.cpp +++ b/extension/parquet/parquet_field_id.cpp @@ -3,6 +3,8 @@ namespace duckdb { +constexpr const char *FieldID::DUCKDB_FIELD_ID; + ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { } diff --git a/extension/parquet/parquet_reader.cpp b/extension/parquet/parquet_reader.cpp index 2ce1a44038a5..99971575c364 100644 --- a/extension/parquet/parquet_reader.cpp +++ b/extension/parquet/parquet_reader.cpp @@ -175,6 +175,11 @@ LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &fi } LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, ParquetColumnSchema &schema) const { + return DeriveLogicalType(s_ele, parquet_options, schema); +} + +LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, const ParquetOptions &parquet_options, + ParquetColumnSchema &schema) { // inner node if (s_ele.type == Type::FIXED_LEN_BYTE_ARRAY && !s_ele.__isset.type_length) { throw IOException("FIXED_LEN_BYTE_ARRAY requires length to be set"); @@ -396,10 +401,8 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, Parquet ParquetColumnSchema ParquetReader::ParseColumnSchema(const SchemaElement &s_ele, idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t column_index, ParquetColumnSchemaType type) { - ParquetColumnSchema schema(max_define, max_repeat, schema_index, column_index, type); - schema.name = s_ele.name; - schema.type = DeriveLogicalType(s_ele, schema); - return schema; + return ParquetColumnSchema::FromSchemaElement(s_ele, max_define, max_repeat, schema_index, column_index, type, + parquet_options); } unique_ptr ParquetReader::CreateReaderRecursive(ClientContext &context, @@ -466,8 +469,8 @@ unique_ptr ParquetReader::CreateReader(ClientContext &context) { auto column_id = entry.first; auto &expression = entry.second; auto child_reader = std::move(root_struct_reader.child_readers[column_id]); - auto expr_schema = make_uniq(child_reader->Schema(), expression->return_type, - ParquetColumnSchemaType::EXPRESSION); + auto expr_schema = make_uniq(ParquetColumnSchema::FromParentSchema( + child_reader->Schema(), expression->return_type, ParquetColumnSchemaType::EXPRESSION)); auto expr_reader = make_uniq(context, std::move(child_reader), expression->Copy(), std::move(expr_schema)); root_struct_reader.child_readers[column_id] = std::move(expr_reader); @@ -475,49 +478,6 @@ unique_ptr ParquetReader::CreateReader(ClientContext &context) { return ret; } -ParquetColumnSchema::ParquetColumnSchema(idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t column_index, - ParquetColumnSchemaType schema_type) - : ParquetColumnSchema(string(), LogicalTypeId::INVALID, max_define, max_repeat, schema_index, column_index, - schema_type) { -} - -ParquetColumnSchema::ParquetColumnSchema(string name_p, LogicalType type_p, idx_t max_define, idx_t max_repeat, - idx_t schema_index, idx_t column_index, ParquetColumnSchemaType schema_type) - : schema_type(schema_type), name(std::move(name_p)), type(std::move(type_p)), max_define(max_define), - max_repeat(max_repeat), schema_index(schema_index), column_index(column_index) { -} - -ParquetColumnSchema::ParquetColumnSchema(ParquetColumnSchema parent, LogicalType result_type, - ParquetColumnSchemaType schema_type) - : schema_type(schema_type), name(parent.name), type(std::move(result_type)), max_define(parent.max_define), - max_repeat(parent.max_repeat), schema_index(parent.schema_index), column_index(parent.column_index) { - children.push_back(std::move(parent)); -} - -unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_meta_data, - const ParquetOptions &parquet_options, idx_t row_group_idx_p, - const vector &columns) const { - if (schema_type == ParquetColumnSchemaType::EXPRESSION) { - return nullptr; - } - if (schema_type == ParquetColumnSchemaType::FILE_ROW_NUMBER) { - auto stats = NumericStats::CreateUnknown(type); - auto &row_groups = file_meta_data.row_groups; - D_ASSERT(row_group_idx_p < row_groups.size()); - idx_t row_group_offset_min = 0; - for (idx_t i = 0; i < row_group_idx_p; i++) { - row_group_offset_min += row_groups[i].num_rows; - } - - NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); - NumericStats::SetMax(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min + - row_groups[row_group_idx_p].num_rows))); - stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - return stats.ToUnique(); - } - return ParquetStatisticsUtils::TransformColumnStatistics(*this, columns, parquet_options.can_have_nan); -} - static bool IsVariantType(const SchemaElement &root, const vector &children) { if (children.size() < 2) { return false; @@ -594,8 +554,8 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d // geoarrow types, although geometry columns, are structs and have children and are handled below. if (metadata->geo_metadata && metadata->geo_metadata->IsGeometryColumn(s_ele.name) && s_ele.num_children == 0) { auto root_schema = ParseColumnSchema(s_ele, max_define, max_repeat, this_idx, next_file_idx++); - return ParquetColumnSchema(std::move(root_schema), GeoParquetFileMetadata::GeometryType(), - ParquetColumnSchemaType::GEOMETRY); + return ParquetColumnSchema::FromParentSchema(std::move(root_schema), GeoParquetFileMetadata::GeometryType(), + ParquetColumnSchemaType::GEOMETRY); } } @@ -650,14 +610,12 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d throw IOException("MAP_KEY_VALUE needs to be repeated"); } auto result_type = LogicalType::MAP(child_schemas[0].type, child_schemas[1].type); - ParquetColumnSchema struct_schema(s_ele.name, ListType::GetChildType(result_type), max_define - 1, - max_repeat - 1, this_idx, next_file_idx); - struct_schema.children = std::move(child_schemas); - - ParquetColumnSchema map_schema(s_ele.name, std::move(result_type), max_define, max_repeat, this_idx, - next_file_idx); - map_schema.children.push_back(std::move(struct_schema)); - return map_schema; + vector map_children; + map_children.emplace_back(ParquetColumnSchema::FromChildSchemas( + s_ele.name, ListType::GetChildType(result_type), max_define - 1, max_repeat - 1, this_idx, + next_file_idx, std::move(child_schemas))); + return ParquetColumnSchema::FromChildSchemas(s_ele.name, result_type, max_define, max_repeat, this_idx, + next_file_idx, std::move(map_children)); } ParquetColumnSchema result; if (child_schemas.size() > 1 || (!is_list && !is_map && !is_repeated)) { @@ -672,13 +630,10 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d } else { result_type = LogicalType::STRUCT(std::move(struct_types)); } - ParquetColumnSchema struct_schema(s_ele.name, std::move(result_type), max_define, max_repeat, this_idx, - next_file_idx); - struct_schema.children = std::move(child_schemas); - if (is_variant) { - struct_schema.schema_type = ParquetColumnSchemaType::VARIANT; - } - result = std::move(struct_schema); + ParquetColumnSchemaType schema_type = + is_variant ? ParquetColumnSchemaType::VARIANT : ParquetColumnSchemaType::COLUMN; + result = ParquetColumnSchema::FromChildSchemas(s_ele.name, result_type, max_define, max_repeat, this_idx, + next_file_idx, std::move(child_schemas), schema_type); } else { // if we have a struct with only a single type, pull up result = std::move(child_schemas[0]); @@ -686,10 +641,9 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d } if (is_repeated) { auto list_type = LogicalType::LIST(result.type); - ParquetColumnSchema list_schema(s_ele.name, std::move(list_type), max_define, max_repeat, this_idx, - next_file_idx); - list_schema.children.push_back(std::move(result)); - result = std::move(list_schema); + vector list_child = {std::move(result)}; + result = ParquetColumnSchema::FromChildSchemas(s_ele.name, std::move(list_type), max_define, max_repeat, + this_idx, next_file_idx, std::move(list_child)); } result.parent_schema_index = this_idx; return result; @@ -702,17 +656,16 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d auto result = ParseColumnSchema(s_ele, max_define, max_repeat, this_idx, next_file_idx++); if (s_ele.repetition_type == FieldRepetitionType::REPEATED) { auto list_type = LogicalType::LIST(result.type); - ParquetColumnSchema list_schema(s_ele.name, std::move(list_type), max_define, max_repeat, this_idx, - next_file_idx); - list_schema.children.push_back(std::move(result)); - return list_schema; + vector list_child = {std::move(result)}; + return ParquetColumnSchema::FromChildSchemas(s_ele.name, std::move(list_type), max_define, max_repeat, + this_idx, next_file_idx, std::move(list_child)); } // Convert to geometry type if possible if (s_ele.__isset.logicalType && (s_ele.logicalType.__isset.GEOMETRY || s_ele.logicalType.__isset.GEOGRAPHY) && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - return ParquetColumnSchema(std::move(result), GeoParquetFileMetadata::GeometryType(), - ParquetColumnSchemaType::GEOMETRY); + return ParquetColumnSchema::FromParentSchema(std::move(result), GeoParquetFileMetadata::GeometryType(), + ParquetColumnSchemaType::GEOMETRY); } return result; @@ -720,8 +673,7 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d } static ParquetColumnSchema FileRowNumberSchema() { - return ParquetColumnSchema("file_row_number", LogicalType::BIGINT, 0, 0, 0, 0, - ParquetColumnSchemaType::FILE_ROW_NUMBER); + return ParquetColumnSchema::FileRowNumber(); } unique_ptr ParquetReader::ParseSchema(ClientContext &context) { @@ -730,23 +682,28 @@ unique_ptr ParquetReader::ParseSchema(ClientContext &contex idx_t next_file_idx = 0; if (file_meta_data->schema.empty()) { - throw IOException("Parquet reader: no schema elements found"); + throw IOException("Failed to read Parquet file \"%s\": no schema elements found", file.path); } if (file_meta_data->schema[0].num_children == 0) { - throw IOException("Parquet reader: root schema element has no children"); + throw IOException("Failed to read Parquet file \"%s\": root schema element has no children", file.path); } auto root = ParseSchemaRecursive(0, 0, 0, next_schema_idx, next_file_idx, context); if (root.type.id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("Root element of Parquet file must be a struct"); + throw InvalidInputException("Failed to read Parquet file \"%s\": Root element of Parquet file must be a struct", + file.path); } D_ASSERT(next_schema_idx == file_meta_data->schema.size() - 1); - D_ASSERT(file_meta_data->row_groups.empty() || next_file_idx == file_meta_data->row_groups[0].columns.size()); + if (!file_meta_data->row_groups.empty() && next_file_idx != file_meta_data->row_groups[0].columns.size()) { + throw InvalidInputException("Failed to read Parquet file \"%s\": row group does not have enough columns", + file.path); + } if (parquet_options.file_row_number) { for (auto &column : root.children) { auto &name = column.name; if (StringUtil::CIEquals(name, "file_row_number")) { - throw BinderException( - "Using file_row_number option on file with column named file_row_number is not supported"); + throw BinderException("Failed to read Parquet file \"%s\": Using file_row_number option on file with " + "column named file_row_number is not supported", + file.path); } } root.children.push_back(FileRowNumberSchema()); @@ -761,7 +718,8 @@ MultiFileColumnDefinition ParquetReader::ParseColumnDefinition(const FileMetaDat result.identifier = Value::INTEGER(MultiFileReader::ORDINAL_FIELD_ID); return result; } - auto &column_schema = file_meta_data.schema[element.schema_index]; + D_ASSERT(element.schema_index.IsValid()); + auto &column_schema = file_meta_data.schema[element.schema_index.GetIndex()]; if (column_schema.__isset.field_id) { result.identifier = Value::INTEGER(column_schema.field_id); diff --git a/extension/parquet/parquet_writer.cpp b/extension/parquet/parquet_writer.cpp index 1aeb319c7bc7..24fdb3bf4974 100644 --- a/extension/parquet/parquet_writer.cpp +++ b/extension/parquet/parquet_writer.cpp @@ -144,7 +144,8 @@ Type::type ParquetWriter::DuckDBTypeToParquetType(const LogicalType &duckdb_type throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); } -void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele) { +void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry) { if (duckdb_type.IsJSONType()) { schema_ele.converted_type = ConvertedType::JSON; schema_ele.__isset.converted_type = true; @@ -152,7 +153,7 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_p schema_ele.logicalType.__set_JSON(duckdb_parquet::JsonType()); return; } - if (duckdb_type.GetAlias() == "WKB_BLOB") { + if (duckdb_type.GetAlias() == "WKB_BLOB" && allow_geometry) { schema_ele.__isset.logicalType = true; schema_ele.logicalType.__isset.GEOMETRY = true; // TODO: Set CRS in the future @@ -328,13 +329,34 @@ class ParquetStatsAccumulator { vector> stats_unifiers; }; +ParquetWriteTransformData::ParquetWriteTransformData(ClientContext &context, vector types, + vector> expressions_p) + : buffer(context, types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR), expressions(std::move(expressions_p)), + executor(context, expressions) { + chunk.Initialize(buffer.GetAllocator(), types); +} + +//! TODO: this doesnt work.. the ParquetWriteTransformData is shared with all threads, the method is stateful, but has +//! no locks Either every local state needs its own copy of this or we need a lock so its used by one thread at a time.. +//! The former has my preference +ColumnDataCollection &ParquetWriteTransformData::ApplyTransform(ColumnDataCollection &input) { + buffer.Reset(); + for (auto &input_chunk : input.Chunks()) { + chunk.Reset(); + executor.Execute(input_chunk, chunk); + buffer.Append(chunk); + } + return buffer; +} + ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, ShreddingType shredding_types_p, const vector> &kv_metadata, shared_ptr encryption_config_p, optional_idx dictionary_size_limit_p, idx_t string_dictionary_page_size_limit_p, bool enable_bloom_filters_p, double bloom_filter_false_positive_ratio_p, - int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version) + int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version, + GeoParquetVersion geoparquet_version) : context(context), file_name(std::move(file_name_p)), sql_types(std::move(types_p)), column_names(std::move(names_p)), codec(codec), field_ids(std::move(field_ids_p)), shredding_types(std::move(shredding_types_p)), encryption_config(std::move(encryption_config_p)), @@ -342,7 +364,8 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file string_dictionary_page_size_limit(string_dictionary_page_size_limit_p), enable_bloom_filters(enable_bloom_filters_p), bloom_filter_false_positive_ratio(bloom_filter_false_positive_ratio_p), compression_level(compression_level_p), - debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), total_written(0), num_row_groups(0) { + debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), geoparquet_version(geoparquet_version), + total_written(0), num_row_groups(0) { // initialize the file writer writer = make_uniq(fs, file_name.c_str(), @@ -375,8 +398,6 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file file_meta_data.created_by = StringUtil::Format("DuckDB version %s (build %s)", DuckDB::LibraryVersion(), DuckDB::SourceID()); - file_meta_data.schema.resize(1); - for (auto &kv_pair : kv_metadata) { duckdb_parquet::KeyValue kv; kv.__set_key(kv_pair.first); @@ -385,34 +406,132 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file file_meta_data.__isset.key_value_metadata = true; } - // populate root schema object - file_meta_data.schema[0].name = "duckdb_schema"; - file_meta_data.schema[0].num_children = NumericCast(sql_types.size()); - file_meta_data.schema[0].__isset.num_children = true; - file_meta_data.schema[0].repetition_type = duckdb_parquet::FieldRepetitionType::REQUIRED; - file_meta_data.schema[0].__isset.repetition_type = true; - auto &unique_names = column_names; VerifyUniqueNames(unique_names); - // construct the child schemas + // V1 GeoParquet stores geometries as blobs, no logical type + auto allow_geometry = geoparquet_version != GeoParquetVersion::V1; + + // construct the column writers + D_ASSERT(sql_types.size() == unique_names.size()); for (idx_t i = 0; i < sql_types.size(); i++) { - auto child_schema = ColumnWriter::FillParquetSchema(file_meta_data.schema, sql_types[i], unique_names[i], - &field_ids, &shredding_types); - column_schemas.push_back(std::move(child_schema)); - } - // now construct the writers based on the schemas - for (auto &child_schema : column_schemas) { vector path_in_schema; - column_writers.push_back( - ColumnWriter::CreateWriterRecursive(context, *this, file_meta_data.schema, child_schema, path_in_schema)); + column_writers.push_back(ColumnWriter::CreateWriterRecursive(context, *this, path_in_schema, sql_types[i], + unique_names[i], allow_geometry, &field_ids, + &shredding_types)); } } ParquetWriter::~ParquetWriter() { } -void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result) { +void ParquetWriter::AnalyzeSchema(ColumnDataCollection &buffer, vector> &column_writers) { + D_ASSERT(buffer.ColumnCount() == column_writers.size()); + vector> states; + bool needs_analyze = false; + lock_guard glock(lock); + + vector column_ids; + for (idx_t i = 0; i < column_writers.size(); i++) { + auto &writer = column_writers[i]; + auto state = writer->AnalyzeSchemaInit(); + if (state) { + needs_analyze = true; + states.push_back(std::move(state)); + column_ids.push_back(i); + } else { + states.push_back(nullptr); + } + } + + if (!needs_analyze) { + return; + } + + for (auto &chunk : buffer.Chunks(column_ids)) { + idx_t index = 0; + for (idx_t i = 0; i < column_writers.size(); i++) { + auto &state = states[i]; + if (!state) { + continue; + } + auto &writer = column_writers[i]; + writer->AnalyzeSchema(*state, chunk.data[index++], chunk.size()); + } + } + + for (idx_t i = 0; i < column_writers.size(); i++) { + auto &writer = column_writers[i]; + auto &state = states[i]; + if (!state) { + continue; + } + writer->AnalyzeSchemaFinalize(*state); + } +} + +void ParquetWriter::InitializePreprocessing(unique_ptr &transform_data) { + if (transform_data) { + return; + } + + vector transformed_types; + vector> transform_expressions; + for (idx_t col_idx = 0; col_idx < column_writers.size(); col_idx++) { + auto &column_writer = *column_writers[col_idx]; + auto &original_type = sql_types[col_idx]; + auto expr = make_uniq(original_type, col_idx); + if (!column_writer.HasTransform()) { + transformed_types.push_back(original_type); + transform_expressions.push_back(std::move(expr)); + continue; + } + transformed_types.push_back(column_writer.TransformedType()); + transform_expressions.push_back(column_writer.TransformExpression(std::move(expr))); + } + transform_data = make_uniq(context, transformed_types, std::move(transform_expressions)); +} + +void ParquetWriter::InitializeSchemaElements() { + //! Populate the schema elements of the parquet file we're writing + lock_guard glock(lock); + if (!file_meta_data.schema.empty()) { + return; + } + // populate root schema object + file_meta_data.schema.resize(1); + file_meta_data.schema[0].name = "duckdb_schema"; + file_meta_data.schema[0].num_children = NumericCast(sql_types.size()); + file_meta_data.schema[0].__isset.num_children = true; + file_meta_data.schema[0].repetition_type = duckdb_parquet::FieldRepetitionType::REQUIRED; + file_meta_data.schema[0].__isset.repetition_type = true; + + for (auto &column_writer : column_writers) { + column_writer->FinalizeSchema(file_meta_data.schema); + } +} + +void ParquetWriter::PrepareRowGroup(ColumnDataCollection &raw_buffer, PreparedRowGroup &result, + unique_ptr &transform_data) { + AnalyzeSchema(raw_buffer, column_writers); + + bool requires_transform = false; + for (auto &writer_p : column_writers) { + auto &writer = *writer_p; + + if (writer.HasTransform()) { + requires_transform = true; + break; + } + } + + reference buffer_ref(raw_buffer); + if (requires_transform) { + InitializePreprocessing(transform_data); + buffer_ref = transform_data->ApplyTransform(raw_buffer); + } + auto &buffer = buffer_ref.get(); + // We write 8 columns at a time so that iterating over ColumnDataCollection is more efficient static constexpr idx_t COLUMNS_PER_PASS = 8; @@ -424,6 +543,8 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro row_group.num_rows = NumericCast(buffer.Count()); row_group.__isset.file_offset = true; + InitializeSchemaElements(); + auto &states = result.states; // iterate over each of the columns of the chunk collection and write them D_ASSERT(buffer.ColumnCount() == column_writers.size()); @@ -551,13 +672,13 @@ void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { ++num_row_groups; } -void ParquetWriter::Flush(ColumnDataCollection &buffer) { +void ParquetWriter::Flush(ColumnDataCollection &buffer, unique_ptr &transform_data) { if (buffer.Count() == 0) { return; } PreparedRowGroup prepared_row_group; - PrepareRowGroup(buffer, prepared_row_group); + PrepareRowGroup(buffer, prepared_row_group, transform_data); buffer.Reset(); FlushRowGroup(prepared_row_group); @@ -804,20 +925,25 @@ static unique_ptr GetBaseStatsUnifier(const LogicalType &typ } } -static void GetStatsUnifier(const ParquetColumnSchema &schema, vector> &unifiers, +static void GetStatsUnifier(const ColumnWriter &column_writer, vector> &unifiers, string base_name = string()) { - if (!base_name.empty()) { - base_name += "."; + auto &schema = column_writer.Schema(); + if (schema.repetition_type != duckdb_parquet::FieldRepetitionType::REPEATED) { + if (!base_name.empty()) { + base_name += "."; + } + base_name += KeywordHelper::WriteQuoted(schema.name, '\"'); } - base_name += KeywordHelper::WriteQuoted(schema.name, '\"'); - if (schema.children.empty()) { + + auto &children = column_writer.ChildWriters(); + if (children.empty()) { auto unifier = GetBaseStatsUnifier(schema.type); unifier->column_name = std::move(base_name); unifiers.push_back(std::move(unifier)); return; } - for (auto &child_schema : schema.children) { - GetStatsUnifier(child_schema, unifiers, base_name); + for (auto &child_writer : children) { + GetStatsUnifier(*child_writer, unifiers, base_name); } } @@ -915,6 +1041,7 @@ void ParquetWriter::GatherWrittenStatistics() { } void ParquetWriter::Finalize() { + InitializeSchemaElements(); // dump the bloom filters right before footer, not if stuff is encrypted @@ -956,7 +1083,8 @@ void ParquetWriter::Finalize() { } // Add geoparquet metadata to the file metadata - if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { + if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context) && + geoparquet_version != GeoParquetVersion::NONE) { geoparquet_data->Write(file_meta_data); } @@ -986,7 +1114,7 @@ void ParquetWriter::Finalize() { GeoParquetFileMetadata &ParquetWriter::GetGeoParquetData() { if (!geoparquet_data) { - geoparquet_data = make_uniq(); + geoparquet_data = make_uniq(geoparquet_version); } return *geoparquet_data; } @@ -1007,7 +1135,7 @@ void ParquetWriter::SetWrittenStatistics(CopyFunctionFileStatistics &written_sta stats_accumulator = make_uniq(); // create the per-column stats unifiers for (auto &column_writer : column_writers) { - GetStatsUnifier(column_writer->Schema(), stats_accumulator->stats_unifiers); + GetStatsUnifier(*column_writer, stats_accumulator->stats_unifiers); } } diff --git a/extension/parquet/writer/boolean_column_writer.cpp b/extension/parquet/writer/boolean_column_writer.cpp index 5994a5d275ac..1157e4bd6e34 100644 --- a/extension/parquet/writer/boolean_column_writer.cpp +++ b/extension/parquet/writer/boolean_column_writer.cpp @@ -35,9 +35,9 @@ class BooleanWriterPageState : public ColumnWriterPageState { uint8_t byte_pos = 0; }; -BooleanColumnWriter::BooleanColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { +BooleanColumnWriter::BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } unique_ptr BooleanColumnWriter::InitializeStatsState() { diff --git a/extension/parquet/writer/decimal_column_writer.cpp b/extension/parquet/writer/decimal_column_writer.cpp index 5f70697b71f8..4710a9fe6e37 100644 --- a/extension/parquet/writer/decimal_column_writer.cpp +++ b/extension/parquet/writer/decimal_column_writer.cpp @@ -66,9 +66,9 @@ class FixedDecimalStatistics : public ColumnWriterStatistics { } }; -FixedDecimalColumnWriter::FixedDecimalColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { +FixedDecimalColumnWriter::FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } unique_ptr FixedDecimalColumnWriter::InitializeStatsState() { diff --git a/extension/parquet/writer/enum_column_writer.cpp b/extension/parquet/writer/enum_column_writer.cpp index b08d2f56673c..3ba5d9b28874 100644 --- a/extension/parquet/writer/enum_column_writer.cpp +++ b/extension/parquet/writer/enum_column_writer.cpp @@ -16,9 +16,9 @@ class EnumWriterPageState : public ColumnWriterPageState { bool written_value; }; -EnumColumnWriter::EnumColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { +EnumColumnWriter::EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { bit_width = RleBpDecoder::ComputeBitWidth(EnumType::GetSize(Type())); } diff --git a/extension/parquet/writer/list_column_writer.cpp b/extension/parquet/writer/list_column_writer.cpp index b043a94bcb37..a54017f235f1 100644 --- a/extension/parquet/writer/list_column_writer.cpp +++ b/extension/parquet/writer/list_column_writer.cpp @@ -2,6 +2,11 @@ namespace duckdb { +using namespace duckdb_parquet; // NOLINT + +using duckdb_parquet::ConvertedType; +using duckdb_parquet::FieldRepetitionType; + unique_ptr ListColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { auto result = make_uniq(row_group, row_group.columns.size()); result->child_state = GetChildWriter().InitializeWriteState(row_group); @@ -141,4 +146,50 @@ ColumnWriter &ListColumnWriter::GetChildWriter() { return *child_writers[0]; } +void ListColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = column_schema; + schema.SetSchemaIndex(schema_idx); + + auto null_type = schema.repetition_type; + auto &name = schema.name; + auto &field_id = schema.field_id; + auto &type = schema.type; + + // set up the two schema elements for the list + // for some reason we only set the converted type in the OPTIONAL element + // first an OPTIONAL element + duckdb_parquet::SchemaElement optional_element; + optional_element.repetition_type = null_type; + optional_element.num_children = 1; + optional_element.converted_type = (type.id() == LogicalTypeId::MAP) ? ConvertedType::MAP : ConvertedType::LIST; + optional_element.__isset.num_children = true; + optional_element.__isset.type = false; + optional_element.__isset.repetition_type = true; + optional_element.__isset.converted_type = true; + optional_element.name = name; + if (field_id.IsValid()) { + optional_element.__isset.field_id = true; + optional_element.field_id = field_id.GetIndex(); + } + schemas.push_back(std::move(optional_element)); + + if (type.id() != LogicalTypeId::MAP) { + duckdb_parquet::SchemaElement repeated_element; + repeated_element.repetition_type = FieldRepetitionType::REPEATED; + repeated_element.__isset.num_children = true; + repeated_element.__isset.type = false; + repeated_element.__isset.repetition_type = true; + repeated_element.num_children = 1; + repeated_element.name = "list"; + schemas.push_back(std::move(repeated_element)); + } else { + //! When we're describing a MAP, we skip the dummy "list" element + //! Instead, the "key_value" struct will be marked as REPEATED + D_ASSERT(GetChildWriter().Schema().repetition_type == FieldRepetitionType::REPEATED); + } + GetChildWriter().FinalizeSchema(schemas); +} + } // namespace duckdb diff --git a/extension/parquet/writer/primitive_column_writer.cpp b/extension/parquet/writer/primitive_column_writer.cpp index 16189ab24320..c23680e73c8d 100644 --- a/extension/parquet/writer/primitive_column_writer.cpp +++ b/extension/parquet/writer/primitive_column_writer.cpp @@ -7,9 +7,12 @@ namespace duckdb { using duckdb_parquet::Encoding; using duckdb_parquet::PageType; -PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path), can_have_nulls) { +constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_PAGE_SIZE; +constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE; + +PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path) + : ColumnWriter(writer, std::move(column_schema), std::move(schema_path)) { } unique_ptr PrimitiveColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { @@ -111,7 +114,7 @@ void PrimitiveColumnWriter::BeginWrite(ColumnWriterState &state_p) { hdr.type = PageType::DATA_PAGE; hdr.__isset.data_page_header = true; - hdr.data_page_header.num_values = UnsafeNumericCast(page_info.row_count); + hdr.data_page_header.num_values = NumericCast(page_info.row_count); hdr.data_page_header.encoding = GetEncoding(state); hdr.data_page_header.definition_level_encoding = Encoding::RLE; hdr.data_page_header.repetition_level_encoding = Encoding::RLE; @@ -304,12 +307,24 @@ void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &sta } if (state.stats_state->HasGeoStats()) { - column_chunk.meta_data.__isset.geospatial_statistics = true; - state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); - // Add the geospatial statistics to the extra GeoParquet metadata - writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, - *state.stats_state->GetGeoStats()); + auto gpq_version = writer.GetGeoParquetVersion(); + + const auto has_real_stats = gpq_version == GeoParquetVersion::NONE || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; + const auto has_json_stats = gpq_version == GeoParquetVersion::V1 || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; + + if (has_real_stats) { + // Write the parquet native geospatial statistics + column_chunk.meta_data.__isset.geospatial_statistics = true; + state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); + } + if (has_json_stats) { + // Add the geospatial statistics to the extra GeoParquet metadata + writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, + *state.stats_state->GetGeoStats()); + } } for (const auto &write_info : state.write_info) { @@ -417,4 +432,33 @@ void PrimitiveColumnWriter::WriteDictionary(PrimitiveColumnWriterState &state, u state.write_info.insert(state.write_info.begin(), std::move(write_info)); } +void PrimitiveColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = column_schema; + schema.SetSchemaIndex(schema_idx); + + auto &repetition_type = schema.repetition_type; + auto &name = schema.name; + auto &field_id = schema.field_id; + auto &type = schema.type; + auto allow_geometry = schema.allow_geometry; + + duckdb_parquet::SchemaElement schema_element; + schema_element.type = ParquetWriter::DuckDBTypeToParquetType(type); + schema_element.repetition_type = repetition_type; + schema_element.__isset.num_children = false; + schema_element.__isset.type = true; + schema_element.__isset.repetition_type = true; + schema_element.name = name; + if (field_id.IsValid()) { + schema_element.__isset.field_id = true; + schema_element.field_id = field_id.GetIndex(); + } + ParquetWriter::SetSchemaProperties(type, schema_element, allow_geometry); + schemas.push_back(std::move(schema_element)); + + D_ASSERT(child_writers.empty()); +} + } // namespace duckdb diff --git a/extension/parquet/writer/struct_column_writer.cpp b/extension/parquet/writer/struct_column_writer.cpp index c9b6bcf9d7c2..a792b736bef8 100644 --- a/extension/parquet/writer/struct_column_writer.cpp +++ b/extension/parquet/writer/struct_column_writer.cpp @@ -2,6 +2,11 @@ namespace duckdb { +using namespace duckdb_parquet; // NOLINT + +using duckdb_parquet::ConvertedType; +using duckdb_parquet::FieldRepetitionType; + class StructColumnWriterState : public ColumnWriterState { public: StructColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) @@ -100,4 +105,33 @@ void StructColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { } } +void StructColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = column_schema; + schema.SetSchemaIndex(schema_idx); + + auto &repetition_type = schema.repetition_type; + auto &name = schema.name; + auto &field_id = schema.field_id; + + // set up the schema element for this struct + duckdb_parquet::SchemaElement schema_element; + schema_element.repetition_type = repetition_type; + schema_element.num_children = child_writers.size(); + schema_element.__isset.num_children = true; + schema_element.__isset.type = false; + schema_element.__isset.repetition_type = true; + schema_element.name = name; + if (field_id.IsValid()) { + schema_element.__isset.field_id = true; + schema_element.field_id = field_id.GetIndex(); + } + schemas.push_back(std::move(schema_element)); + + for (auto &child_writer : child_writers) { + child_writer->FinalizeSchema(schemas); + } +} + } // namespace duckdb diff --git a/extension/parquet/writer/variant/CMakeLists.txt b/extension/parquet/writer/variant/CMakeLists.txt index 955efb46c2a6..a255f061c722 100644 --- a/extension/parquet/writer/variant/CMakeLists.txt +++ b/extension/parquet/writer/variant/CMakeLists.txt @@ -1,4 +1,5 @@ -add_library_unity(duckdb_parquet_writer_variant OBJECT convert_variant.cpp) +add_library_unity(duckdb_parquet_writer_variant OBJECT convert_variant.cpp + analyze_variant.cpp) set(PARQUET_EXTENSION_FILES ${PARQUET_EXTENSION_FILES} $ diff --git a/extension/parquet/writer/variant/analyze_variant.cpp b/extension/parquet/writer/variant/analyze_variant.cpp new file mode 100644 index 000000000000..c7575d09fb6d --- /dev/null +++ b/extension/parquet/writer/variant/analyze_variant.cpp @@ -0,0 +1,202 @@ +#include "writer/variant_column_writer.hpp" +#include "parquet_writer.hpp" +#include "duckdb/common/types/decimal.hpp" + +namespace duckdb { + +unique_ptr VariantColumnWriter::AnalyzeSchemaInit() { + if (child_writers.size() == 2 && !is_analyzed) { + return make_uniq(); + } + //! Variant is already shredded explicitly, no need to analyze + return nullptr; +} + +static void AnalyzeSchemaInternal(VariantAnalyzeData &state, UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_index) { + if (!variant.RowIsValid(row)) { + state.type_map[static_cast(VariantLogicalType::VARIANT_NULL)]++; + return; + } + + auto type_id = variant.GetTypeId(row, values_index); + state.type_map[static_cast(type_id)]++; + + if (type_id == VariantLogicalType::OBJECT) { + if (!state.object_data) { + state.object_data = make_uniq(); + } + auto &object_data = *state.object_data; + + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto child_values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + auto child_key_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + + auto &key = variant.GetKey(row, child_key_index); + auto &child_state = object_data.fields[key.GetString()]; + AnalyzeSchemaInternal(child_state, variant, row, child_values_index); + } + } else if (type_id == VariantLogicalType::ARRAY) { + if (!state.array_data) { + state.array_data = make_uniq(); + } + auto &array_data = *state.array_data; + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto child_values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + auto &child_state = array_data.child; + AnalyzeSchemaInternal(child_state, variant, row, child_values_index); + } + } else if (type_id == VariantLogicalType::DECIMAL) { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + auto physical_type = decimal_data.GetPhysicalType(); + switch (physical_type) { + case PhysicalType::INT32: + state.decimal_type_map[0]++; + break; + case PhysicalType::INT64: + state.decimal_type_map[1]++; + break; + case PhysicalType::INT128: + state.decimal_type_map[2]++; + break; + default: + break; + } + } else if (type_id == VariantLogicalType::BOOL_FALSE) { + //! Move it to bool_true to have the counts all in one place + state.type_map[static_cast(VariantLogicalType::BOOL_TRUE)]++; + state.type_map[static_cast(VariantLogicalType::BOOL_FALSE)]--; + } +} + +void VariantColumnWriter::AnalyzeSchema(ParquetAnalyzeSchemaState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(input, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + for (idx_t i = 0; i < count; i++) { + AnalyzeSchemaInternal(state.analyze_data, variant, i, 0); + } +} + +namespace { + +struct ShredAnalysisState { + idx_t highest_count = 0; + LogicalTypeId type_id; + PhysicalType decimal_type; +}; + +} // namespace + +template +static void CheckPrimitive(const VariantAnalyzeData &state, ShredAnalysisState &result) { + auto count = state.type_map[static_cast(VARIANT_TYPE)]; + if (VARIANT_TYPE == VariantLogicalType::DECIMAL) { + if (!count) { + return; + } + auto int32_count = state.decimal_type_map[0]; + if (int32_count > result.highest_count) { + result.type_id = LogicalTypeId::DECIMAL; + result.decimal_type = PhysicalType::INT32; + } + auto int64_count = state.decimal_type_map[1]; + if (int64_count > result.highest_count) { + result.type_id = LogicalTypeId::DECIMAL; + result.decimal_type = PhysicalType::INT64; + } + auto int128_count = state.decimal_type_map[2]; + if (int128_count > result.highest_count) { + result.type_id = LogicalTypeId::DECIMAL; + result.decimal_type = PhysicalType::INT128; + } + } else { + if (count > result.highest_count) { + result.highest_count = count; + result.type_id = SHREDDED_TYPE; + } + } +} + +static LogicalType ConstructShreddedType(const VariantAnalyzeData &state) { + ShredAnalysisState result; + + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + //! FIXME: It's not enough for decimals to have the same PhysicalType, their width+scale has to match in order to + //! shred on the type. + // CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + + auto array_count = state.type_map[static_cast(VariantLogicalType::ARRAY)]; + auto object_count = state.type_map[static_cast(VariantLogicalType::OBJECT)]; + if (array_count > object_count) { + if (array_count > result.highest_count) { + auto &array_data = *state.array_data; + return LogicalType::LIST(ConstructShreddedType(array_data.child)); + } + } else { + if (object_count > result.highest_count) { + auto &object_data = *state.object_data; + + //! TODO: implement some logic to determine which fields are worth shredding, considering the overhead when + //! only 10% of rows make use of the field + child_list_t field_types; + for (auto &field : object_data.fields) { + field_types.emplace_back(field.first, ConstructShreddedType(field.second)); + } + return LogicalType::STRUCT(field_types); + } + } + + if (result.type_id == LogicalTypeId::DECIMAL) { + //! TODO: what should the scale be??? + if (result.decimal_type == PhysicalType::INT32) { + return LogicalType::DECIMAL(DecimalWidth::max, 0); + } else if (result.decimal_type == PhysicalType::INT64) { + return LogicalType::DECIMAL(DecimalWidth::max, 0); + } else if (result.decimal_type == PhysicalType::INT128) { + return LogicalType::DECIMAL(DecimalWidth::max, 0); + } + } + return result.type_id; +} + +void VariantColumnWriter::AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState &state_p) { + auto &state = state_p.Cast(); + auto shredded_type = ConstructShreddedType(state.analyze_data); + + auto typed_value = TransformTypedValueRecursive(shredded_type); + is_analyzed = true; + + auto &schema = Schema(); + auto &context = writer.GetContext(); + D_ASSERT(child_writers.size() == 2); + child_writers.pop_back(); + //! Recreate the column writer for 'value' because this is now "optional" + child_writers.push_back(ColumnWriter::CreateWriterRecursive(context, writer, schema_path, LogicalType::BLOB, + "value", false, nullptr, nullptr, schema.max_repeat, + schema.max_define + 1, true)); + child_writers.push_back(ColumnWriter::CreateWriterRecursive(context, writer, schema_path, typed_value, + "typed_value", false, nullptr, nullptr, + schema.max_repeat, schema.max_define + 1, true)); +} + +} // namespace duckdb diff --git a/extension/parquet/writer/variant/convert_variant.cpp b/extension/parquet/writer/variant/convert_variant.cpp index 836229d19c3b..ace71d26542c 100644 --- a/extension/parquet/writer/variant/convert_variant.cpp +++ b/extension/parquet/writer/variant/convert_variant.cpp @@ -1119,6 +1119,33 @@ static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &r } } +void VariantColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = Schema(); + schema.SetSchemaIndex(schema_idx); + + auto &repetition_type = schema.repetition_type; + auto &name = schema.name; + + // variant group + duckdb_parquet::SchemaElement top_element; + top_element.repetition_type = repetition_type; + top_element.num_children = child_writers.size(); + top_element.logicalType.__isset.VARIANT = true; + top_element.logicalType.VARIANT.__isset.specification_version = true; + top_element.logicalType.VARIANT.specification_version = 1; + top_element.__isset.logicalType = true; + top_element.__isset.num_children = true; + top_element.__isset.repetition_type = true; + top_element.name = name; + schemas.push_back(std::move(top_element)); + + for (auto &child_writer : child_writers) { + child_writer->FinalizeSchema(schemas); + } +} + LogicalType VariantColumnWriter::TransformTypedValueRecursive(const LogicalType &type) { switch (type.id()) { case LogicalTypeId::STRUCT: { diff --git a/scripts/generate_metric_enums.py b/scripts/generate_metric_enums.py index fb05478d8cd6..f57a1b1ec9dc 100644 --- a/scripts/generate_metric_enums.py +++ b/scripts/generate_metric_enums.py @@ -14,41 +14,56 @@ optimizer_file = os.path.join("..", "src", "include", "duckdb", "common", "enums", "optimizer_type.hpp") metrics = [ - "QUERY_NAME", + "ATTACH_LOAD_STORAGE_LATENCY", + "ATTACH_REPLAY_WAL_LATENCY", "BLOCKED_THREAD_TIME", + "CHECKPOINT_LATENCY", "CPU_TIME", - "EXTRA_INFO", "CUMULATIVE_CARDINALITY", - "OPERATOR_TYPE", - "OPERATOR_CARDINALITY", "CUMULATIVE_ROWS_SCANNED", + "EXTRA_INFO", + "LATENCY", + "OPERATOR_CARDINALITY", + "OPERATOR_NAME", "OPERATOR_ROWS_SCANNED", "OPERATOR_TIMING", + "OPERATOR_TYPE", + "QUERY_NAME", "RESULT_SET_SIZE", - "LATENCY", "ROWS_RETURNED", - "OPERATOR_NAME", "SYSTEM_PEAK_BUFFER_MEMORY", "SYSTEM_PEAK_TEMP_DIR_SIZE", "TOTAL_BYTES_READ", "TOTAL_BYTES_WRITTEN", + "WAITING_TO_ATTACH_LATENCY", ] phase_timing_metrics = [ "ALL_OPTIMIZERS", "CUMULATIVE_OPTIMIZER_TIMING", - "PLANNER", - "PLANNER_BINDING", "PHYSICAL_PLANNER", "PHYSICAL_PLANNER_COLUMN_BINDING", - "PHYSICAL_PLANNER_RESOLVE_TYPES", "PHYSICAL_PLANNER_CREATE_PLAN", + "PHYSICAL_PLANNER_RESOLVE_TYPES", + "PLANNER", + "PLANNER_BINDING", ] query_global_metrics = [ + "ATTACH_LOAD_STORAGE_LATENCY", + "ATTACH_REPLAY_WAL_LATENCY", "BLOCKED_THREAD_TIME", + "CHECKPOINT_LATENCY", "SYSTEM_PEAK_BUFFER_MEMORY", "SYSTEM_PEAK_TEMP_DIR_SIZE", + "WAITING_TO_ATTACH_LATENCY", +] + +excluded_query_global_metrics = [ + "ATTACH_LOAD_STORAGE_LATENCY", + "ATTACH_REPLAY_WAL_LATENCY", + "CHECKPOINT_LATENCY", + "WAITING_TO_ATTACH_LATENCY", ] optimizer_types = [] @@ -382,7 +397,13 @@ def write_custom_profiling_optimizer(f): metrics.sort() for metric in metrics: - f.write(f'"{metric}": "true"\n') + skip = False + for excluded_metric in excluded_query_global_metrics: + if metric == excluded_metric: + skip = True + break + if not skip: + f.write(f'"{metric}": "true"\n') f.write("\n") write_statement( diff --git a/scripts/sqllogictest/__init__.py b/scripts/sqllogictest/__init__.py deleted file mode 100644 index 6053ce836ec1..000000000000 --- a/scripts/sqllogictest/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -from .token import TokenType, Token -from .base_statement import BaseStatement -from .test import SQLLogicTest -from .base_decorator import BaseDecorator -from .statement import ( - Statement, - Require, - Mode, - Halt, - Load, - Set, - Query, - HashThreshold, - Loop, - Foreach, - Endloop, - RequireEnv, - Restart, - Reconnect, - Sleep, - SleepUnit, - Skip, - Unzip, - Unskip, -) -from .decorator import SkipIf, OnlyIf -from .expected_result import ExpectedResult -from .parser import SQLLogicParser, SQLParserException - -__all__ = [ - TokenType, - Token, - BaseStatement, - SQLLogicTest, - BaseDecorator, - Statement, - ExpectedResult, - Require, - Mode, - Halt, - Load, - Set, - Query, - HashThreshold, - Loop, - Foreach, - Endloop, - RequireEnv, - Restart, - Reconnect, - Sleep, - SleepUnit, - Skip, - Unzip, - Unskip, - SkipIf, - OnlyIf, - SQLLogicParser, - SQLParserException, -] diff --git a/scripts/sqllogictest/base_decorator.py b/scripts/sqllogictest/base_decorator.py deleted file mode 100644 index 93222ce43289..000000000000 --- a/scripts/sqllogictest/base_decorator.py +++ /dev/null @@ -1,6 +0,0 @@ -from sqllogictest.token import Token - - -class BaseDecorator: - def __init__(self, token: Token): - self.token: Token = token diff --git a/scripts/sqllogictest/base_statement.py b/scripts/sqllogictest/base_statement.py deleted file mode 100644 index ed80dbf94c84..000000000000 --- a/scripts/sqllogictest/base_statement.py +++ /dev/null @@ -1,25 +0,0 @@ -from sqllogictest.token import Token, TokenType -from sqllogictest.base_decorator import BaseDecorator -from typing import List - - -class BaseStatement: - def __init__(self, header: Token, line: int): - self.header: Token = header - self.query_line: int = line - self.decorators: List[BaseDecorator] = [] - - def add_decorators(self, decorators: List[BaseDecorator]): - self.decorators = decorators - - def get_decorators(self) -> List[BaseDecorator]: - return self.decorators - - def get_query_line(self) -> int: - return self.query_line - - def get_type(self) -> TokenType: - return self.header.type - - def get_parameters(self) -> List[str]: - return self.header.parameters diff --git a/scripts/sqllogictest/decorator/__init__.py b/scripts/sqllogictest/decorator/__init__.py deleted file mode 100644 index 4457291e1da1..000000000000 --- a/scripts/sqllogictest/decorator/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .skip_if import SkipIf -from .only_if import OnlyIf - -__all__ = [SkipIf, OnlyIf] diff --git a/scripts/sqllogictest/decorator/only_if.py b/scripts/sqllogictest/decorator/only_if.py deleted file mode 100644 index 9f1d8ced499b..000000000000 --- a/scripts/sqllogictest/decorator/only_if.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_decorator import BaseDecorator -from sqllogictest.token import Token - - -class OnlyIf(BaseDecorator): - def __init__(self, token: Token): - super().__init__(token) diff --git a/scripts/sqllogictest/decorator/skip_if.py b/scripts/sqllogictest/decorator/skip_if.py deleted file mode 100644 index 2ad62288e381..000000000000 --- a/scripts/sqllogictest/decorator/skip_if.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_decorator import BaseDecorator -from sqllogictest.token import Token - - -class SkipIf(BaseDecorator): - def __init__(self, token: Token): - super().__init__(token) diff --git a/scripts/sqllogictest/expected_result.py b/scripts/sqllogictest/expected_result.py deleted file mode 100644 index ecfd5826903f..000000000000 --- a/scripts/sqllogictest/expected_result.py +++ /dev/null @@ -1,23 +0,0 @@ -from enum import Enum, auto -from typing import Optional, List - - -class ExpectedResult: - class Type(Enum): - SUCCESS = auto() - ERROR = auto() - UNKNOWN = auto() - - def __init__(self, type: "ExpectedResult.Type"): - self.type = type - self.lines: Optional[List[str]] = None - self.column_count: Optional[int] = None - - def add_lines(self, lines: List[str]): - self.lines = lines - - def set_expected_column_count(self, column_count: int): - self.column_count = column_count - - def get_expected_column_count(self) -> Optional[int]: - return self.column_count diff --git a/scripts/sqllogictest/logger.py b/scripts/sqllogictest/logger.py deleted file mode 100644 index aa83897b68eb..000000000000 --- a/scripts/sqllogictest/logger.py +++ /dev/null @@ -1,187 +0,0 @@ -import logging -import termcolor -from typing import Union -from duckdb import tokenize, token_type -from .statement import Query, Statement - - -class SQLLogicTestLogger: - def __init__(self, context, command: Union[Query, Statement], file_name: str): - self.file_name = file_name - self.context = context - self.query_line = command.query_line - self.sql_query = '\n'.join(command.lines) - - def log(self, message): - logging.error(message) - - def print_expected_result(self, values, columns, row_wise): - if row_wise: - for value in values: - print(value) - else: - c = 0 - for value in values: - if c != 0: - print("\t", end="") - print(value, end="") - c += 1 - if c >= columns: - c = 0 - print() - - def print_line_sep(self): - line_sep = "=" * 80 - print(termcolor.colored(line_sep, 'grey')) - - def print_header(self, header): - print(termcolor.colored(header, 'white', attrs=['bold'])) - - def print_file_header(self): - self.print_header(f"File {self.file_name}:{self.query_line})") - - def print_sql(self): - query = self.sql_query.strip() - if not query.endswith(";"): - query += ";" - query = self.context.replace_keywords(query) - print(query) - - def print_sql_formatted(self): - print(termcolor.colored("SQL Query", attrs=['bold'])) - query = self.context.replace_keywords(self.sql_query) - tokens = tokenize(query) - for i, token in enumerate(tokens): - next_token_start = tokens[i + 1].start if i + 1 < len(tokens) else len(query) - token_text = query[token.start : next_token_start] - # Apply highlighting based on token type - if token.type in [token_type.identifier, token_type.numeric_const, token_type.string_const]: - print(termcolor.colored(token_text, 'yellow'), end="") - elif token.type == token_type.keyword: - print(termcolor.colored(token_text, 'green', attrs=['bold']), end="") - else: - print(token_text, end="") - print() - - def print_error_header(self, description): - self.print_line_sep() - print(termcolor.colored(description, 'red', attrs=['bold']), end=" ") - print(termcolor.colored(f"({self.file_name}:{self.query_line})!", attrs=['bold'])) - - def print_result_error(self, result_values, values, expected_column_count, row_wise): - self.print_header("Expected result:") - self.print_line_sep() - self.print_expected_result(values, expected_column_count, row_wise) - self.print_line_sep() - self.print_header("Actual result:") - self.print_line_sep() - self.print_expected_result(result_values, expected_column_count, False) - - def unexpected_failure(self, result): - self.print_line_sep() - print(f"Query unexpectedly failed ({self.file_name}:{self.query_line})\n") - self.print_line_sep() - self.print_sql() - self.print_line_sep() - print(result) # FIXME - - def output_result(self, result, result_values_string): - for column_name in result.names: - print(column_name, end="\t") - print() - for column_type in result.types: - print(column_type.to_string(), end="\t") - print() - self.print_line_sep() - for r in range(result.row_count): - for c in range(result.column_count): - print(result_values_string[r * result.column_count + c], end="\t") - print() - - def output_hash(self, hash_value): - self.print_line_sep() - self.print_sql() - self.print_line_sep() - print(hash_value) - self.print_line_sep() - - def column_count_mismatch(self, result, result_values_string, expected_column_count, row_wise): - self.print_error_header("Wrong column count in query!") - print( - f"Expected {termcolor.colored(expected_column_count, 'white', attrs=['bold'])} columns, but got {termcolor.colored(result.column_count, 'white', attrs=['bold'])} columns" - ) - self.print_line_sep() - self.print_sql() - self.print_line_sep() - self.print_result_error(result_values_string, result._result, expected_column_count, row_wise) - - def not_cleanly_divisible(self, expected_column_count, actual_column_count): - self.print_error_header("Error in test!") - print(f"Expected {expected_column_count} columns, but {actual_column_count} values were supplied") - print("This is not cleanly divisible (i.e. the last row does not have enough values)") - - def wrong_row_count(self, expected_rows, result_values_string, comparison_values, expected_column_count, row_wise): - self.print_error_header("Wrong row count in query!") - row_count = len(result_values_string) - print( - f"Expected {termcolor.colored(int(expected_rows), 'white', attrs=['bold'])} rows, but got {termcolor.colored(row_count, 'white', attrs=['bold'])} rows" - ) - self.print_line_sep() - self.print_sql() - self.print_line_sep() - self.print_result_error(result_values_string, comparison_values, expected_column_count, row_wise) - - def column_count_mismatch_correct_result(self, original_expected_columns, expected_column_count, result): - self.print_line_sep() - self.print_error_header("Wrong column count in query!") - print( - f"Expected {termcolor.colored(original_expected_columns, 'white', attrs=['bold'])} columns, but got {termcolor.colored(expected_column_count, 'white', attrs=['bold'])} columns" - ) - self.print_line_sep() - self.print_sql() - print(f"The expected result {termcolor.colored('matched', 'white', attrs=['bold'])} the query result.") - print( - f"Suggested fix: modify header to \"{termcolor.colored('query', 'green')} {'I' * result.column_count}{termcolor.colored('', 'white')}\"" - ) - self.print_line_sep() - - def split_mismatch(self, row_number, expected_column_count, split_count): - self.print_line_sep() - self.print_error_header(f"Error in test! Column count mismatch after splitting on tab on row {row_number}!") - print( - f"Expected {termcolor.colored(int(expected_column_count), 'white', attrs=['bold'])} columns, but got {termcolor.colored(split_count, 'white', attrs=['bold'])} columns" - ) - print("Does the result contain tab values? In that case, place every value on a single row.") - self.print_line_sep() - - def wrong_result_hash(self, expected_hash_value, hash_value): - self.print_error_header("Wrong result hash when comparing to previous query result!") - self.print_line_sep() - self.print_sql() - self.print_line_sep() - self.print_header("Expected hash value:") - self.print_line_sep() - print(expected_hash_value) - self.print_line_sep() - self.print_header("Actual hash value:") - self.print_line_sep() - print(hash_value) - self.print_line_sep() - - def unexpected_statement(self, expect_ok, result): - description = "Query unexpectedly succeeded!" if not expect_ok else "Query unexpectedly failed!" - self.print_error_header(description) - self.print_line_sep() - self.print_sql() - self.print_line_sep() - result.print() - - def expected_error_mismatch(self, expected_error, result): - self.print_error_header( - f"Query failed, but error message did not match expected error message: {expected_error}" - ) - self.print_line_sep() - self.print_sql() - self.print_header("Actual result:") - self.print_line_sep() - result.print() diff --git a/scripts/sqllogictest/parser/__init__.py b/scripts/sqllogictest/parser/__init__.py deleted file mode 100644 index 90c98254ec73..000000000000 --- a/scripts/sqllogictest/parser/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parser import SQLLogicParser, SQLParserException, SQLLogicTest - -__all__ = [SQLLogicParser, SQLParserException, SQLLogicTest] diff --git a/scripts/sqllogictest/parser/parser.py b/scripts/sqllogictest/parser/parser.py deleted file mode 100644 index 021693de3898..000000000000 --- a/scripts/sqllogictest/parser/parser.py +++ /dev/null @@ -1,635 +0,0 @@ -import os - -from typing import List, Optional - -from ..token import Token, TokenType - -from ..expected_result import ExpectedResult - -from ..statement import ( - Statement, - Require, - Mode, - Halt, - Set, - Load, - Query, - HashThreshold, - Loop, - Foreach, - Endloop, - RequireEnv, - Restart, - Reconnect, - Sleep, - Skip, - Unzip, - Unskip, - SortStyle, -) -from ..statement.sleep import get_sleep_unit, SleepUnit - -from ..decorator import SkipIf, OnlyIf - -from ..base_decorator import BaseDecorator -from ..base_statement import BaseStatement -from ..test import SQLLogicTest - - -def create_formatted_list(items) -> str: - res = '' - for i, option in enumerate(items): - if i + 1 == len(items): - spacer = ' or ' - elif i != 0: - spacer = ', ' - else: - spacer = '' - res += f"{spacer}'{option}'" - return res - - -def is_space(char: str): - return char == ' ' or char == '\t' or char == '\n' or char == '\v' or char == '\f' or char == '\r' - - -### -------- PARSER ---------- -class SQLParserException(Exception): - def __init__(self, message): - self.message = "Parser Error: " + message - super().__init__(self.message) - - -class SQLLogicParser: - def reset(self): - self.current_line = 0 - self.seen_statement = False - self.lines = [] - self.current_test = None - - def __init__(self): - self.reset() - self.STATEMENTS = { - TokenType.SQLLOGIC_STATEMENT: self.statement_statement, - TokenType.SQLLOGIC_QUERY: self.statement_query, - TokenType.SQLLOGIC_REQUIRE: self.statement_require, - TokenType.SQLLOGIC_HASH_THRESHOLD: self.statement_hash_threshold, - TokenType.SQLLOGIC_HALT: self.statement_halt, - TokenType.SQLLOGIC_MODE: self.statement_mode, - TokenType.SQLLOGIC_SET: self.statement_set, - TokenType.SQLLOGIC_LOOP: self.statement_loop, - TokenType.SQLLOGIC_CONCURRENT_LOOP: self.statement_loop, - TokenType.SQLLOGIC_FOREACH: self.statement_foreach, - TokenType.SQLLOGIC_CONCURRENT_FOREACH: self.statement_foreach, - TokenType.SQLLOGIC_ENDLOOP: self.statement_endloop, - TokenType.SQLLOGIC_REQUIRE_ENV: self.statement_require_env, - TokenType.SQLLOGIC_LOAD: self.statement_load, - TokenType.SQLLOGIC_RESTART: self.statement_restart, - TokenType.SQLLOGIC_RECONNECT: self.statement_reconnect, - TokenType.SQLLOGIC_SLEEP: self.statement_sleep, - TokenType.SQLLOGIC_UNZIP: self.statement_unzip, - TokenType.SQLLOGIC_INVALID: None, - } - self.DECORATORS = { - TokenType.SQLLOGIC_SKIP_IF: self.decorator_skipif, - TokenType.SQLLOGIC_ONLY_IF: self.decorator_onlyif, - } - self.FOREACH_COLLECTIONS = { - "": [ - "none", - "uncompressed", - "rle", - "bitpacking", - "dictionary", - "fsst", - "dict_fsst", - "alp", - "alprd", - ], - "": ["bool", "interval", "varchar"], - "": ["float", "double"], - "": ["tinyint", "smallint", "integer", "bigint", "hugeint"], - "": ["tinyint", "smallint", "integer", "bigint", "hugeint"], - "": ["utinyint", "usmallint", "uinteger", "ubigint", "uhugeint"], - "": [ - "bool", - "tinyint", - "smallint", - "int", - "bigint", - "hugeint", - "uhugeint", - "utinyint", - "usmallint", - "uint", - "ubigint", - "date", - "time", - "timestamp", - "timestamp_s", - "timestamp_ms", - "timestamp_ns", - "time_tz", - "timestamp_tz", - "float", - "double", - "dec_4_1", - "dec_9_4", - "dec_18_6", - "dec38_10", - "uuid", - "interval", - "varchar", - "blob", - "bit", - "small_enum", - "medium_enum", - "large_enum", - "int_array", - "double_array", - "date_array", - "timestamp_array", - "timestamptz_array", - "varchar_array", - "nested_int_array", - "struct", - "struct_of_arrays", - "array_of_structs", - "map", - "union", - "fixed_int_array", - "fixed_varchar_array", - "fixed_nested_int_array", - "fixed_nested_varchar_array", - "fixed_struct_array", - "struct_of_fixed_array", - "fixed_array_of_int_list", - "list_of_fixed_int_array", - ], - } - - def peek(self): - return self.peek_no_strip().strip() - - def peek_no_strip(self): - if self.current_line >= len(self.lines): - raise SQLParserException("File already fully consumed") - return self.lines[self.current_line] - - def consume(self): - if self.current_line >= len(self.lines): - raise SQLParserException("File already fully consumed") - self.current_line += 1 - - def fail(self, message): - file_path = self.current_test.path - error_message = f"{file_path}:{self.current_line + 1}: {message}" - raise SQLParserException(error_message) - - def get_expected_result(self, statement_type: str) -> ExpectedResult: - type_map = { - 'ok': ExpectedResult.Type.SUCCESS, - 'error': ExpectedResult.Type.ERROR, - 'maybe': ExpectedResult.Type.UNKNOWN, - } - if statement_type not in type_map: - error = 'statement argument should be ' + create_formatted_list(type_map.keys()) - self.fail(error) - return ExpectedResult(type_map[statement_type]) - - def extract_expected_lines(self) -> Optional[List[str]]: - end_of_file = self.current_line >= len(self.lines) - if end_of_file or self.peek() != "----": - return None - - self.consume() - result = [] - while self.current_line < len(self.lines) and self.peek_no_strip().strip('\n'): - result.append(self.peek_no_strip().strip('\n')) - self.consume() - return result - - def statement_statement(self, header: Token) -> Optional[BaseStatement]: - options = ['ok', 'error', 'maybe'] - if len(header.parameters) < 1: - self.fail(f"statement requires at least one parameter ({create_formatted_list(options)})") - expected_result = self.get_expected_result(header.parameters[0]) - - statement = Statement(header, self.current_line + 1) - statement.file_name = self.current_test.path - - self.next_line() - statement_text = self.extract_statement() - if statement_text == []: - self.fail("Unexpected empty statement text") - statement.add_lines(statement_text) - - expected_lines: Optional[List[str]] = self.extract_expected_lines() - if expected_result.type == ExpectedResult.Type.SUCCESS: - if expected_lines != None: - if len(expected_lines) != 0: - self.fail( - "Failed to parse statement: only statement error can have an expected error message, not statement ok" - ) - expected_result.add_lines(expected_lines) - elif expected_result.type == ExpectedResult.Type.ERROR or expected_result.type == ExpectedResult.Type.UNKNOWN: - if expected_lines != None: - expected_result.add_lines(expected_lines) - elif not self.current_test.is_sqlite_test(): - print(statement) - self.fail('Failed to parse statement: statement error needs to have an expected error message') - else: - self.fail(f"Unexpected ExpectedResult Type: {expected_result.type.name}") - - statement.expected_result = expected_result - if len(header.parameters) >= 2: - statement.set_connection(header.parameters[1]) - return statement - - def statement_query(self, header: Token) -> BaseStatement: - if len(header.parameters) < 1: - self.fail("query requires at least one parameter (query III)") - query = Query(header, self.current_line + 1) - - # parse the expected column count - query.expected_column_count = 0 - column_text = header.parameters[0] - accepted_chars = ['T', 'I', 'R'] - if not all(x in accepted_chars for x in column_text): - self.fail(f"Found unknown character in {column_text}, expected {create_formatted_list(accepted_chars)}") - expected_column_count = len(column_text) - - query.expected_column_count = expected_column_count - if query.expected_column_count == 0: - self.fail("Query requires at least a single column in the result") - - query.file_name = self.current_test.path - query.query_line = self.current_line + 1 - # extract the SQL statement - self.next_line() - statement_text = self.extract_statement() - query.add_lines(statement_text) - - # extract the expected result - expected_result = self.get_expected_result('ok') - expected_lines: Optional[List[str]] = self.extract_expected_lines() - if expected_lines != None: - expected_result.add_lines(expected_lines) - expected_result.set_expected_column_count(expected_column_count) - query.expected_result = expected_result - - def get_sort_style(parameters: List[str]) -> SortStyle: - sort_style = SortStyle.NO_SORT - if len(parameters) > 1: - sort_style = parameters[1] - if sort_style == "nosort": - # Do no sorting - sort_style = SortStyle.NO_SORT - elif sort_style == "rowsort" or sort_style == "sort": - # Row-oriented sorting - sort_style = SortStyle.ROW_SORT - elif sort_style == "valuesort": - # Sort all values independently - sort_style = SortStyle.VALUE_SORT - else: - sort_style = SortStyle.UNKNOWN - return sort_style - - # figure out the sort style - sort_style = get_sort_style(header.parameters) - if sort_style == SortStyle.UNKNOWN: - sort_style = SortStyle.NO_SORT - query.set_connection(header.parameters[1]) - query.set_sortstyle(sort_style) - - # check the label of the query - if len(header.parameters) > 2: - query.set_label(header.parameters[2]) - return query - - def statement_hash_threshold(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) != 1: - self.fail("hash-threshold requires a parameter") - threshold = int(header.parameters[0]) - return HashThreshold(header, self.current_line + 1, threshold) - - def statement_halt(self, header: Token) -> Optional[BaseStatement]: - return Halt(header, self.current_line + 1) - - def statement_mode(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) != 1: - self.fail("mode requires one parameter") - parameter = header.parameters[0] - if parameter == "skip": - return Skip(header, self.current_line + 1) - elif parameter == "unskip": - return Unskip(header, self.current_line + 1) - else: - return Mode(header, self.current_line + 1, parameter) - - def statement_require(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) < 1: - self.fail("require requires a single parameter") - return Require(header, self.current_line + 1) - - def statement_set(self, header: Token) -> Optional[BaseStatement]: - parameters = header.parameters - if len(parameters) < 1: - self.fail("set requires at least 1 parameter (e.g. set ignore_error_messages HTTP Error)") - accepted_options = ['ignore_error_messages', 'always_fail_error_messages', 'seed'] - if parameters[0] in accepted_options: - error_messages = [] - # Parse the parameter list as a comma separated list of strings that can contain spaces - # e.g. `set ignore_error_messages This is an error message, This_is_another, and another` - tmp = [[y.strip() for y in x.split(',') if y.strip() != ''] for x in parameters[1:]] - for x in tmp: - error_messages.extend(x) - statement = Set(header, self.current_line + 1) - statement.add_error_messages(error_messages) - return statement - else: - self.fail( - f"unrecognized set parameter: {parameters[0]}, expected {create_formatted_list(accepted_options)}" - ) - - def statement_load(self, header: Token) -> Optional[BaseStatement]: - statement = Load(header, self.current_line + 1) - if len(header.parameters) > 1 and header.parameters[1] == "readonly": - statement.set_readonly() - if len(header.parameters) > 2: - statement.set_version(header.parameters[2]) - return statement - - def statement_loop(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) != 3: - self.fail("Expected loop [iterator_name] [start] [end] (e.g. loop i 1 300)") - is_parallel = header.type == TokenType.SQLLOGIC_CONCURRENT_LOOP - statement = Loop(header, self.current_line + 1, is_parallel) - statement.set_name(header.parameters[0]) - statement.set_start(int(header.parameters[1])) - statement.set_end(int(header.parameters[2])) - return statement - - def statement_foreach(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) < 2: - self.fail( - "Expected foreach [iterator_name] [m1] [m2] [etc...] (e.g. foreach type integer " "smallint float)" - ) - is_parallel = header.type == TokenType.SQLLOGIC_CONCURRENT_FOREACH - statement = Foreach(header, self.current_line + 1, is_parallel) - statement.set_name(header.parameters[0]) - raw_values = header.parameters[1:] - - def add_tokens(result, param): - token_name = param.lower().strip() - - if token_name in self.FOREACH_COLLECTIONS: - result.extend(self.FOREACH_COLLECTIONS[token_name]) - else: - result.append(param) - - foreach_tokens = [] - for value in raw_values: - add_tokens(foreach_tokens, value) - - statement.set_values(foreach_tokens) - return statement - - def statement_endloop(self, header: Token) -> Optional[BaseStatement]: - return Endloop(header, self.current_line + 1) - - def statement_require_env(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) != 1 and len(header.parameters) != 2: - self.fail("require-env requires 1 argument: [optional: ]") - return RequireEnv(header, self.current_line + 1) - - def statement_restart(self, header: Token) -> Optional[BaseStatement]: - return Restart(header, self.current_line + 1) - - def statement_reconnect(self, header: Token) -> Optional[BaseStatement]: - return Reconnect(header, self.current_line + 1) - - def statement_sleep(self, header: Token) -> Optional[BaseStatement]: - if len(header.parameters) != 2: - self.fail("sleep requires two parameter (e.g. sleep 1 second)") - sleep_duration = int(header.parameters[0]) - sleep_unit = get_sleep_unit(header.parameters[1]) - if sleep_unit == SleepUnit.UNKNOWN: - options = ['second', 'millisecond', 'microsecond', 'nanosecond'] - self.fail(f"Unrecognized sleep mode - expected {create_formatted_list(options)}") - return Sleep(header, self.current_line + 1, sleep_duration, sleep_unit) - - def statement_unzip(self, header: Token) -> Optional[BaseStatement]: - params = header.parameters - if len(params) != 1 and len(params) != 2: - docs = """ - unzip requires 1 parameter, the path to a (g)zipped file. - Optionally a destination location can be provided, defaulting to '__TEST_DIR__/' - """ - self.fail(docs) - - source = params[0] - - accepted_filetypes = {'.gz'} - - basename = os.path.basename(source) - stem, extension = os.path.splitext(basename) - if extension not in accepted_filetypes: - accepted_options = ", ".join(list(accepted_filetypes)) - self.fail( - f"unzip: input does not end in a valid file extension ({extension}), accepted options are: {accepted_options}" - ) - destination = params[1] if len(params) == 2 else f'__TEST_DIR__/{stem}' - return Unzip(header, self.current_line + 1, source, destination) - - # Decorators - - def decorator_skipif(self, token: Token) -> Optional[BaseDecorator]: - return SkipIf(token) - - def decorator_onlyif(self, token: Token) -> Optional[BaseDecorator]: - return OnlyIf(token) - - def parse(self, file_path: str) -> Optional[SQLLogicTest]: - if not self.open_file(file_path): - return None - - while self.next_statement(): - token = self.tokenize() - - # throw explicit error on single line statements that are not separated by a comment or newline - if self.is_single_line_statement(token) and not self.next_line_empty_or_comment(): - self.fail("All test statements need to be separated by an empty line") - - # Parse any number of decorators first - parse_method = self.DECORATORS.get(token.type) - decorators: List[BaseDecorator] = [] - while parse_method != None: - decorator = parse_method(token) - if not decorator: - self.fail(f"Parser did not produce a decorator for {token.type.name}") - decorators.append(decorator) - self.next_line() - token = self.tokenize() - parse_method = self.DECORATORS.get(token.type) - - # Then parse the statement - parse_method = self.STATEMENTS.get(token.type) - if parse_method: - statement = parse_method(token) - else: - self.fail(f"Unexpected token type: {token.type.name}") - if not statement: - self.fail(f"Parser did not produce a statement for {token.type.name}") - statement.add_decorators(decorators) - self.current_test.add_statement(statement) - return self.current_test - - def open_file(self, path): - self.reset() - self.current_test = SQLLogicTest(path) - try: - with open(path, 'r') as infile: - self.lines = [line.replace("\r", "") for line in infile.readlines()] - return True - except IOError: - return False - except UnicodeDecodeError: - return False - - def empty_or_comment(self, line): - return not line.strip('\n') or line.startswith("#") - - def next_line_empty_or_comment(self): - if self.current_line + 1 >= len(self.lines): - return True - else: - return self.empty_or_comment(self.lines[self.current_line + 1]) - - def eof(self): - return self.current_line >= len(self.lines) - - def next_statement(self): - if self.seen_statement: - while not self.eof() and not self.empty_or_comment(self.peek()): - self.consume() - self.seen_statement = True - - while not self.eof() and self.empty_or_comment(self.peek()): - self.consume() - - return not self.eof() - - def next_line(self): - self.consume() - - def extract_statement(self): - statement = [] - - while not self.eof() and not self.empty_or_comment(self.peek_no_strip()): - line = self.peek_no_strip() - if line.strip('\n') == "----": - break - statement.append(line.strip('\n')) - self.consume() - return statement - - def tokenize(self): - result = Token() - if self.current_line >= len(self.lines): - result.type = TokenType.SQLLOGIC_INVALID - return result - - line = self.peek_no_strip() - argument_list = line.split() - argument_list = [x for x in line.strip('\n').split() if not is_space(x)] - - if not argument_list: - self.fail("Empty line!?") - - result.type = self.command_to_token(argument_list[0]) - result.parameters.extend(argument_list[1:]) - return result - - def is_single_line_statement(self, token): - single_line_statements = [ - TokenType.SQLLOGIC_HASH_THRESHOLD, - TokenType.SQLLOGIC_HALT, - TokenType.SQLLOGIC_MODE, - TokenType.SQLLOGIC_SET, - TokenType.SQLLOGIC_LOOP, - TokenType.SQLLOGIC_FOREACH, - TokenType.SQLLOGIC_CONCURRENT_LOOP, - TokenType.SQLLOGIC_CONCURRENT_FOREACH, - TokenType.SQLLOGIC_ENDLOOP, - TokenType.SQLLOGIC_REQUIRE, - TokenType.SQLLOGIC_REQUIRE_ENV, - TokenType.SQLLOGIC_LOAD, - TokenType.SQLLOGIC_RESTART, - TokenType.SQLLOGIC_RECONNECT, - TokenType.SQLLOGIC_SLEEP, - TokenType.SQLLOGIC_UNZIP, - ] - - if token.type in single_line_statements: - return True - elif token.type in [ - TokenType.SQLLOGIC_SKIP_IF, - TokenType.SQLLOGIC_ONLY_IF, - TokenType.SQLLOGIC_INVALID, - TokenType.SQLLOGIC_STATEMENT, - TokenType.SQLLOGIC_QUERY, - ]: - return False - else: - raise RuntimeError("Unknown SQLLogic token found!") - - def command_to_token(self, token): - token_map = { - "skipif": TokenType.SQLLOGIC_SKIP_IF, - "onlyif": TokenType.SQLLOGIC_ONLY_IF, - "statement": TokenType.SQLLOGIC_STATEMENT, - "query": TokenType.SQLLOGIC_QUERY, - "hash-threshold": TokenType.SQLLOGIC_HASH_THRESHOLD, - "halt": TokenType.SQLLOGIC_HALT, - "mode": TokenType.SQLLOGIC_MODE, - "set": TokenType.SQLLOGIC_SET, - "loop": TokenType.SQLLOGIC_LOOP, - "concurrentloop": TokenType.SQLLOGIC_CONCURRENT_LOOP, - "foreach": TokenType.SQLLOGIC_FOREACH, - "concurrentforeach": TokenType.SQLLOGIC_CONCURRENT_FOREACH, - "endloop": TokenType.SQLLOGIC_ENDLOOP, - "require": TokenType.SQLLOGIC_REQUIRE, - "require-env": TokenType.SQLLOGIC_REQUIRE_ENV, - "load": TokenType.SQLLOGIC_LOAD, - "restart": TokenType.SQLLOGIC_RESTART, - "reconnect": TokenType.SQLLOGIC_RECONNECT, - "unzip": TokenType.SQLLOGIC_UNZIP, - "sleep": TokenType.SQLLOGIC_SLEEP, - } - - if token in token_map: - return token_map[token] - else: - self.fail(f"Unrecognized parameter {token}") - return TokenType.SQLLOGIC_INVALID - - -import argparse - - -def main(): - parser = argparse.ArgumentParser(description="SQL Logic Parser") - parser.add_argument("filename", type=str, help="Path to the SQL logic file") - args = parser.parse_args() - - filename = args.filename - - parser = SQLLogicParser() - out: Optional[SQLLogicTest] = parser.parse(filename) - if not out: - raise SQLParserException(f"Test {filename} could not be parsed") - - -if __name__ == "__main__": - main() diff --git a/scripts/sqllogictest/result.py b/scripts/sqllogictest/result.py deleted file mode 100644 index a631860e51b1..000000000000 --- a/scripts/sqllogictest/result.py +++ /dev/null @@ -1,1365 +0,0 @@ -from hashlib import md5 -import gc - -from .base_statement import BaseStatement -from .test import SQLLogicTest -from .statement import ( - Statement, - Require, - Mode, - Halt, - Set, - Load, - Query, - HashThreshold, - Loop, - Foreach, - Endloop, - RequireEnv, - Restart, - Reconnect, - Sleep, - SleepUnit, - Skip, - Unzip, - SortStyle, - Unskip, -) - -from .expected_result import ExpectedResult -from typing import Optional, Any, Tuple, List, Dict, Generator -import typing - -from .logger import SQLLogicTestLogger -import duckdb -import os -import math -import time -import threading - -import re -from functools import cmp_to_key -from enum import Enum - -### Helper structs - - -class RequireResult(Enum): - MISSING = 0 - PRESENT = 1 - - -class ExecuteResult: - class Type(Enum): - SUCCESS = 0 - ERROR = 1 - SKIPPED = 2 - - def __init__(self, type: "ExecuteResult.Type"): - self.type = type - - -### Exceptions - -BUILTIN_EXTENSIONS = [ - 'json', - 'parquet', - 'icu', -] - -from duckdb import DuckDBPyConnection - -# def patch_execute(method): -# def patched_execute(self, *args, **kwargs): -# print(*args) -# return method(self, *args, **kwargs) -# return patched_execute - -# patched_execute = patch_execute(getattr(DuckDBPyConnection, "execute")) -# setattr(DuckDBPyConnection, "execute", patched_execute) - - -class SQLLogicStatementData: - # Context information about a statement - def __init__(self, test: SQLLogicTest, statement: BaseStatement): - self.test = test - self.statement = statement - - def __str__(self) -> str: - return f'{self.test.path}:{self.statement.get_query_line()}' - - __repr__ = __str__ - - -class TestException(Exception): - __test__ = False - __slots__ = ['data', 'message', 'result'] - - def __init__(self, data: SQLLogicStatementData, message: str, result: ExecuteResult): - self.message = message - super().__init__(self.message) - self.data = data - self.result = result - - def handle_result(self) -> ExecuteResult: - return self.result - - -class SkipException(TestException): - def __init__(self, data: SQLLogicStatementData, message: str): - super().__init__(data, message, ExecuteResult(ExecuteResult.Type.SKIPPED)) - - -class FailException(TestException): - def __init__(self, data: SQLLogicStatementData, message: str): - super().__init__(data, message, ExecuteResult(ExecuteResult.Type.ERROR)) - - -### Result primitive - - -class QueryResult: - def __init__(self, result: List[Tuple[Any]], types: List[str], error: Optional[Exception] = None): - self._result = result - self.types = types - self.error = error - if not error: - self._column_count = len(self.types) - self._row_count = len(result) - if self._row_count > 0: - assert self._column_count == len(self._result[0]) - - def get_value(self, column, row): - return self._result[row][column] - - def row_count(self) -> int: - return self._row_count - - @property - def column_count(self) -> int: - assert self._column_count != 0 - return self._column_count - - def has_error(self) -> bool: - return self.error != None - - def get_error(self) -> Optional[Exception]: - return self.error - - def check(self, context, query: Query) -> None: - expected_column_count = query.expected_result.get_expected_column_count() - values = query.expected_result.lines - sort_style = query.get_sortstyle() - query_label = query.get_label() - query_has_label = query_label != None - runner = context.runner - - logger = SQLLogicTestLogger(context, query, runner.test.path) - - # If the result has an error, log it - if self.has_error(): - logger.unexpected_failure() - if runner.skip_error_message(self.get_error()): - runner.finished_processing_file = True - return - context.fail(self.get_error()) - - row_count = self.row_count() - column_count = self.column_count - total_value_count = row_count * column_count - - if len(values) == 1 and result_is_hash(values[0]): - compare_hash = True - is_hash = True - else: - compare_hash = query_has_label or (runner.hash_threshold > 0 and total_value_count > runner.hash_threshold) - is_hash = False - - result_values_string = duck_db_convert_result(self, runner.original_sqlite_test) - - if runner.output_result_mode: - logger.output_result(self, result_values_string) - - if sort_style == SortStyle.ROW_SORT: - ncols = self.column_count - nrows = int(total_value_count / ncols) - rows = [result_values_string[i * ncols : (i + 1) * ncols] for i in range(nrows)] - - # Define the comparison function - def compare_rows(a, b): - for col_idx, val in enumerate(a): - a_val = val - b_val = b[col_idx] - if a_val != b_val: - return -1 if a_val < b_val else 1 - return 0 - - # Sort the individual rows based on element comparison - sorted_rows = sorted(rows, key=cmp_to_key(compare_rows)) - rows = sorted_rows - - for row_idx, row in enumerate(rows): - for col_idx, val in enumerate(row): - result_values_string[row_idx * ncols + col_idx] = val - elif sort_style == SortStyle.VALUE_SORT: - result_values_string.sort() - - comparison_values = [] - if len(values) == 1 and result_is_file(values[0]): - fname = context.replace_keywords(values[0]) - try: - comparison_values = load_result_from_file(fname, self) - # FIXME this is kind of dumb - # We concatenate it with tabs just so we can split it again later - for x in range(len(comparison_values)): - comparison_values[x] = "\t".join(list(comparison_values[x])) - except duckdb.Error as e: - logger.print_error_header(str(e)) - context.fail(f"Failed to load result from {fname}") - else: - comparison_values = values - - hash_value = "" - if runner.output_hash_mode or compare_hash: - hash_context = md5() - for val in result_values_string: - hash_context.update(str(val).encode()) - hash_context.update("\n".encode()) - digest = hash_context.hexdigest() - hash_value = f"{total_value_count} values hashing to {digest}" - if runner.output_hash_mode: - logger.output_hash(hash_value) - return - - if not compare_hash: - original_expected_columns = expected_column_count - column_count_mismatch = False - - if expected_column_count != self.column_count: - expected_column_count = self.column_count - column_count_mismatch = True - - expected_rows = len(comparison_values) / expected_column_count - row_wise = expected_column_count > 1 and len(comparison_values) == self.row_count() - - if not row_wise: - all_tabs = all("\t" in val for val in comparison_values) - row_wise = all_tabs - - if row_wise: - expected_rows = len(comparison_values) - row_wise = True - elif len(comparison_values) % expected_column_count != 0: - if column_count_mismatch: - logger.column_count_mismatch(self, values, original_expected_columns, row_wise) - else: - logger.not_cleanly_divisible(expected_column_count, len(comparison_values)) - # FIXME: the logger should just create the strings to send to self.fail()/self.skip() - context.fail("") - - if expected_rows != self.row_count(): - if column_count_mismatch: - logger.column_count_mismatch(self, values, original_expected_columns, row_wise) - else: - logger.wrong_row_count( - expected_rows, result_values_string, comparison_values, expected_column_count, row_wise - ) - context.fail("") - - if row_wise: - current_row = 0 - for i, val in enumerate(comparison_values): - splits = [x for x in val.split("\t") if x != ''] - if len(splits) != expected_column_count: - if column_count_mismatch: - logger.column_count_mismatch(self, values, original_expected_columns, row_wise) - logger.split_mismatch(i + 1, expected_column_count, len(splits)) - context.fail("") - for c, split_val in enumerate(splits): - lvalue_str = result_values_string[current_row * expected_column_count + c] - rvalue_str = split_val - success = compare_values(self, lvalue_str, split_val, c) - if not success: - logger.print_error_header("Wrong result in query!") - logger.print_line_sep() - logger.print_sql() - logger.print_line_sep() - print(f"Mismatch on row {current_row + 1}, column {c + 1}") - print(f"{lvalue_str} <> {rvalue_str}") - logger.print_line_sep() - logger.print_result_error(result_values_string, values, expected_column_count, row_wise) - context.fail("") - # Increment the assertion counter - assert success - current_row += 1 - else: - current_row, current_column = 0, 0 - for i, val in enumerate(comparison_values): - lvalue_str = result_values_string[current_row * expected_column_count + current_column] - rvalue_str = val - success = compare_values(self, lvalue_str, rvalue_str, current_column) - if not success: - logger.print_error_header("Wrong result in query!") - logger.print_line_sep() - logger.print_sql() - logger.print_line_sep() - print(f"Mismatch on row {current_row + 1}, column {current_column + 1}") - print(f"{lvalue_str} <> {rvalue_str}") - logger.print_line_sep() - logger.print_result_error(result_values_string, values, expected_column_count, row_wise) - context.fail("") - # Increment the assertion counter - assert success - - current_column += 1 - if current_column == expected_column_count: - current_row += 1 - current_column = 0 - - if column_count_mismatch: - logger.column_count_mismatch_correct_result(original_expected_columns, expected_column_count, self) - context.fail("") - else: - hash_compare_error = False - expected_hash_value = None - if query_has_label: - expected_hash_value = runner.hash_label_map.get(query_label) - if expected_hash_value is None: - runner.hash_label_map[query_label] = hash_value - runner.result_label_map[query_label] = self - else: - hash_compare_error = expected_hash_value != hash_value - - if is_hash and not hash_compare_error: - expected_hash_value = values[0] - hash_compare_error = values[0] != hash_value - - if hash_compare_error: - expected_result = runner.result_label_map.get(query_label) - logger.wrong_result_hash(expected_hash_value, hash_value) - - if expected_result: - logger.print_result_error( - result_values_string, - duck_db_convert_result(expected_result, runner.original_sqlite_test), - expected_result.column_count, - False, - ) - context.fail("") - - assert not hash_compare_error - - -class SQLLogicConnectionPool: - __slots__ = [ - 'connection', - 'cursors', - ] - - def __init__(self, con: duckdb.DuckDBPyConnection): - assert con - self.cursors = {} - self.connection = con - - def initialize_connection(self, context: "SQLLogicContext", con: duckdb.DuckDBPyConnection): - runner = context.runner - if runner.test.is_sqlite_test(): - con.execute("SET integer_division=true") - try: - con.execute("SET timezone='UTC'") - except duckdb.Error: - pass - env_var = os.getenv("LOCAL_EXTENSION_REPO") - if env_var: - con.execute("SET autoload_known_extensions=True") - con.execute(f"SET autoinstall_extension_repository='{env_var}'") - - def get_connection(self, name: Optional[str] = None) -> duckdb.DuckDBPyConnection: - """ - Either fetch the 'self.connection' object if name is None - Or get-or-create the cursor identified by name - """ - assert self.connection - if name is None: - return self.connection - - if name not in self.cursors: - # TODO: do we need to run any set up on a new named connection ?? - self.cursors[name] = self.connection.cursor() - return self.cursors[name] - - -class SQLLogicDatabase: - __slots__ = ['path', 'database', 'config'] - - def __init__( - self, path: str, context: Optional["SQLLogicContext"] = None, additional_config: Optional[Dict[str, str]] = None - ): - """ - Connection Hierarchy: - - database - └── connection - └── cursor1 - └── cursor2 - └── cursor3 - - 'connection' is a cursor of 'database'. - Every entry of 'cursors' is a cursor created from 'connection'. - - This is important to understand how ClientConfig settings affect each cursor. - """ - self.reset() - if additional_config: - self.config.update(additional_config) - self.path = path - - # Now re-open the current database - read_only = 'access_mode' in self.config and self.config['access_mode'] == 'read_only' - if 'access_mode' not in self.config: - self.config['access_mode'] = 'automatic' - self.database = duckdb.connect(path, read_only, self.config) - - # Load any previously loaded extensions again - if context: - for extension in context.runner.extensions: - self.load_extension(context, extension) - - def reset(self): - self.database: Optional[duckdb.DuckDBPyConnection] = None - self.config: Dict[str, Any] = { - 'allow_unsigned_extensions': True, - 'allow_unredacted_secrets': True, - } - self.path = '' - - def load_extension(self, context: "SQLLogicContext", extension: str): - if extension in BUILTIN_EXTENSIONS: - # No need to load - return - path = context.get_extension_path(extension) - # Serialize it as a POSIX compliant path - query = f"LOAD '{path}'" - self.database.execute(query) - - def connect(self) -> SQLLogicConnectionPool: - return SQLLogicConnectionPool(self.database.cursor()) - - -def is_regex(input: str) -> bool: - return input.startswith(":") or input.startswith(":") - - -def matches_regex(input: str, actual_str: str) -> bool: - if input.startswith(":"): - should_match = True - regex_str = input.replace(":", "") - else: - should_match = False - regex_str = input.replace(":", "") - # The exact match will never be the same, allow leading and trailing messages - if regex_str[:2] != '.*': - regex_str = ".*" + regex_str - if regex_str[-2:] != '.*': - regex_str = regex_str + '.*' - - re_options = re.DOTALL - re_pattern = re.compile(regex_str, re_options) - regex_matches = bool(re_pattern.fullmatch(actual_str)) - return regex_matches == should_match - - -def has_external_access(conn): - # this is required for the python tester to work, as we make use of replacement scans - try: - res = conn.sql("select current_setting('enable_external_access')").fetchone()[0] - return res - except duckdb.TransactionException: - return True - except duckdb.BinderException: - return True - except duckdb.InvalidInputException: - return True - - -def compare_values(result: QueryResult, actual_str, expected_str, current_column): - error = False - - if actual_str == expected_str: - return True - - if is_regex(expected_str): - return matches_regex(expected_str, actual_str) - - sql_type = result.types[current_column] - - def is_numeric(type) -> bool: - NUMERIC_TYPES = [ - "TINYINT", - "SMALLINT", - "INTEGER", - "BIGINT", - "HUGEINT", - "FLOAT", - "DOUBLE", - "DECIMAL", - "UTINYINT", - "USMALLINT", - "UINTEGER", - "UBIGINT", - "UHUGEINT", - ] - if str(type) in NUMERIC_TYPES: - return True - return 'DECIMAL' in str(type) - - if is_numeric(sql_type): - if sql_type in [duckdb.typing.FLOAT, duckdb.typing.DOUBLE]: - # ApproxEqual - expected = convert_value(expected_str, sql_type) - actual = convert_value(actual_str, sql_type) - if expected == actual: - return True - if math.isnan(expected) and math.isnan(actual): - return True - epsilon = abs(actual) * 0.01 + 0.00000001 - if abs(expected - actual) <= epsilon: - return True - return False - expected = convert_value(expected_str, sql_type) - actual = convert_value(actual_str, sql_type) - return expected == actual - - if sql_type == duckdb.typing.BOOLEAN or sql_type.id == 'timestamp with time zone': - expected = convert_value(expected_str, sql_type) - actual = convert_value(actual_str, sql_type) - return expected == actual - expected = sql_logic_test_convert_value(expected_str, sql_type, False) - actual = actual_str - error = actual != expected - - if error: - return False - return True - - -def result_is_hash(result): - parts = result.split() - if len(parts) != 5: - return False - if not parts[0].isdigit(): - return False - if parts[1] != "values" or parts[2] != "hashing" or len(parts[4]) != 32: - return False - return all([x.islower() or x.isnumeric() for x in parts[4]]) - - -def result_is_file(result: str): - return result.startswith(':') - - -def load_result_from_file(fname, result: QueryResult): - con = duckdb.connect() - con.execute(f"PRAGMA threads={os.cpu_count()}") - column_count = result.column_count - - fname = fname.replace(":", "") - - struct_definition = "STRUCT_PACK(" - for i in range(column_count): - if i > 0: - struct_definition += ", " - struct_definition += f"c{i} := VARCHAR" - struct_definition += ")" - - csv_result = con.execute( - f""" - SELECT * FROM read_csv( - '{fname}', - header=1, - sep='|', - columns={struct_definition}, - auto_detect=false, - all_varchar=true - ) - """ - ) - - return csv_result.fetchall() - - -def convert_value(value, type: str): - if value is None or value == 'NULL': - return 'NULL' - query = f'select $1::{type}' - return duckdb.execute(query, [value]).fetchone()[0] - - -def sql_logic_test_convert_value(value, sql_type, is_sqlite_test: bool) -> str: - if value is None or value == 'NULL': - return 'NULL' - if is_sqlite_test: - if sql_type in [ - duckdb.typing.BOOLEAN, - duckdb.typing.DOUBLE, - duckdb.typing.FLOAT, - ] or any([type_str in str(sql_type) for type_str in ['DECIMAL', 'HUGEINT']]): - return convert_value(value, 'BIGINT::VARCHAR') - if sql_type == duckdb.typing.BOOLEAN: - return "1" if convert_value(value, sql_type) else "0" - else: - res = convert_value(value, 'VARCHAR') - if len(res) == 0: - res = "(empty)" - else: - res = res.replace("\0", "\\0") - return res - - -def duck_db_convert_result(result: QueryResult, is_sqlite_test: bool) -> List[str]: - out_result = [] - row_count = result.row_count() - column_count = result.column_count - - for r in range(row_count): - for c in range(column_count): - value = result.get_value(c, r) - converted_value = sql_logic_test_convert_value(value, result.types[c], is_sqlite_test) - out_result.append(converted_value) - - return out_result - - -class SQLLogicRunner: - __slots__ = [ - 'skipped', - 'error', - 'skip_level', - 'loaded_databases', - 'database', - 'extensions', - 'environment_variables', - 'test', - 'hash_threshold', - 'hash_label_map', - 'result_label_map', - 'required_requires', - 'output_hash_mode', - 'output_result_mode', - 'debug_mode', - 'finished_processing_file', - 'ignore_error_messages', - 'always_fail_error_messages', - 'original_sqlite_test', - 'build_directory', - 'skip_reload', # <-- used for 'force_reload' and 'force_storage', unused for now - ] - - def reset(self): - self.skip_level: int = 0 - - # The set of databases that have been loaded by this runner at any point - # Used for cleanup - self.loaded_databases: typing.Set[str] = set() - self.database: Optional[SQLLogicDatabase] = None - self.extensions = set(BUILTIN_EXTENSIONS) - self.environment_variables: Dict[str, str] = {} - self.test: Optional[SQLLogicTest] = None - - self.hash_threshold: int = 0 - self.hash_label_map: Dict[str, str] = {} - self.result_label_map: Dict[str, Any] = {} - - # FIXME: create a CLI argument for this - self.required_requires: set = set() - self.output_hash_mode = False - self.output_result_mode = False - self.debug_mode = False - - self.finished_processing_file = False - # If these error messages occur in a test, the test will abort but still count as passed - self.ignore_error_messages = {"HTTP", "Unable to connect"} - # If these error messages occur in a statement that is expected to fail, the test will fail - self.always_fail_error_messages = {"differs from original result!", "INTERNAL"} - - self.original_sqlite_test = False - - def skip_error_message(self, message): - for error_message in self.ignore_error_messages: - if error_message in str(message): - return True - return False - - def __init__(self, build_directory: Optional[str] = None): - self.reset() - self.build_directory = build_directory - - def skip(self): - self.skip_level += 1 - - def unskip(self): - self.skip_level -= 1 - - def skip_active(self) -> bool: - return self.skip_level > 0 - - def is_required(self, param): - return param in self.required_requires - - -class SQLLogicContext: - __slots__ = [ - 'iterator', - 'runner', - 'generator', - 'STATEMENTS', - 'pool', - 'statements', - 'current_statement', - 'keywords', - 'error', - 'is_loop', - 'is_parallel', - 'build_directory', - 'cached_config_settings', - ] - - def reset(self): - self.iterator = 0 - - def replace_keywords(self, input: str): - # Apply a replacement for every registered keyword - if '__BUILD_DIRECTORY__' in input: - self.skiptest("Test contains __BUILD_DIRECTORY__ which isnt supported") - for key, value in self.keywords.items().__reversed__(): - input = input.replace(key, value) - return input - - def get_extension_path(self, extension: str): - if self.runner.build_directory is None: - self.skiptest("Tried to load an extension, but --build-dir was not set!") - root = self.runner.build_directory - path = os.path.join(root, "extension", extension, f"{extension}.duckdb_extension") - return path - - def __init__( - self, - pool: SQLLogicConnectionPool, - runner: SQLLogicRunner, - statements: List[BaseStatement], - keywords: Dict[str, str], - iteration_generator, - ): - self.statements = statements - self.runner = runner - self.is_loop = True - self.is_parallel = False - self.error: Optional[TestException] = None - self.generator: Generator[Any] = iteration_generator - self.keywords = keywords - self.cached_config_settings: List[Tuple[str, str]] = [] - self.current_statement: Optional[SQLLogicStatementData] = None - self.pool: Optional[SQLLogicConnectionPool] = pool - self.STATEMENTS = { - Query: self.execute_query, - Statement: self.execute_statement, - RequireEnv: self.execute_require_env, - Require: self.execute_require, - Load: self.execute_load, - Skip: self.execute_skip, - Unskip: self.execute_unskip, - Mode: self.execute_mode, - Sleep: self.execute_sleep, - Reconnect: self.execute_reconnect, - Halt: self.execute_halt, - Restart: self.execute_restart, - HashThreshold: self.execute_hash_threshold, - Set: self.execute_set, - Unzip: self.execute_unzip, - Loop: self.execute_loop, - Foreach: self.execute_foreach, - Endloop: None, # <-- should never be encountered outside of Loop/Foreach - } - - def add_keyword(self, key, value): - # Make sure that loop names can't silently collide - key = f'${{{key}}}' - assert key not in self.keywords - self.keywords[key] = str(value) - - def remove_keyword(self, key): - key = f'${{{key}}}' - assert key in self.keywords - self.keywords.pop(key) - - def fail(self, message: str): - self.error = FailException(self.current_statement, message) - raise self.error - - def skiptest(self, message: str): - self.error = SkipException(self.current_statement, message) - raise self.error - - def in_loop(self) -> bool: - return self.is_loop - - def get_connection(self, name: Optional[str] = None) -> duckdb.DuckDBPyConnection: - return self.pool.get_connection(name) - - def execute_load(self, load: Load): - if self.in_loop(): - # FIXME: should add support for this, the CPP tester supports this - self.skiptest("load cannot be called in a loop") - # self.fail("load cannot be called in a loop") - - readonly = load.readonly - - if load.header.parameters: - dbpath = load.header.parameters[0] - dbpath = self.replace_keywords(dbpath) - if not readonly: - # delete the target database file, if it exists - self.runner.delete_database(dbpath) - else: - dbpath = "" - self.runner.loaded_databases.add(dbpath) - - # set up the config file - additional_config = {} - if readonly: - additional_config['temp_directory'] = "" - additional_config['access_mode'] = 'read_only' - else: - additional_config['access_mode'] = 'automatic' - - if load.version: - additional_config['storage_compatibility_version'] = str(load.version) - - self.pool = None - self.runner.database = None - self.runner.database = SQLLogicDatabase(dbpath, self, additional_config) - self.pool = self.runner.database.connect() - - def execute_query(self, query: Query): - assert isinstance(query, Query) - conn = self.get_connection(query.connection_name) - if not has_external_access(conn): - self.skiptest("enable_external_access is explicitly disabled by the test") - sql_query = '\n'.join(query.lines) - sql_query = self.replace_keywords(sql_query) - - expected_result = query.expected_result - assert expected_result.type == ExpectedResult.Type.SUCCESS - - try: - statements = conn.extract_statements(sql_query) - statement = statements[-1] - if 'pivot' in sql_query and len(statements) != 1: - self.skiptest("Can not deal properly with a PIVOT statement") - - def returns_changed_rows(sql_query, statement) -> bool: - if duckdb.ExpectedResultType.CHANGED_ROWS not in statement.expected_result_type: - return False - if statement.type in [ - duckdb.StatementType.DELETE, - duckdb.StatementType.UPDATE, - duckdb.StatementType.INSERT, - duckdb.StatementType.MERGE_INTO, - ]: - if 'returning' in sql_query.lower(): - return False - return True - if statement.type in [duckdb.StatementType.COPY]: - if 'return_files' in sql_query.lower(): - return False - if 'return_stats' in sql_query.lower(): - return False - return True - return len(statement.expected_result_type) == 1 - - if returns_changed_rows(sql_query, statement): - conn.execute(sql_query) - result = conn.fetchall() - query_result = QueryResult(result, [duckdb.typing.BIGINT]) - elif duckdb.ExpectedResultType.QUERY_RESULT in statement.expected_result_type: - original_rel = conn.query(sql_query) - if original_rel is None: - query_result = QueryResult([(0,)], ['BIGINT']) - else: - original_types = original_rel.types - # We create new names for the columns, because they might be duplicated - aliased_columns = [f'c{i}' for i in range(len(original_types))] - - expressions = [f'"{name}"::VARCHAR' for name, sql_type in zip(aliased_columns, original_types)] - aliased_table = ", ".join(aliased_columns) - expression_list = ", ".join(expressions) - try: - # Select from the result, converting the Values to the right type for comparison - transformed_query = ( - f"select {expression_list} from original_rel unnamed_subquery_blabla({aliased_table})" - ) - stringified_rel = conn.query(transformed_query) - except duckdb.Error as e: - self.fail(f"Could not select from the ValueRelation: {str(e)}") - result = stringified_rel.fetchall() - query_result = QueryResult(result, original_types) - else: - conn.execute(sql_query) - result = conn.fetchall() - query_result = QueryResult(result, []) - if expected_result.lines == None: - return - except duckdb.Error as e: - print(e) - query_result = QueryResult([], [], e) - - query_result.check(self, query) - - def execute_skip(self, statement: Skip): - self.runner.skip() - - def execute_unzip(self, statement: Unzip): - import gzip - import shutil - - source = self.replace_keywords(statement.source) - destination = self.replace_keywords(statement.destination) - - with gzip.open(source, 'rb') as f_in: - with open(destination, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) - print(f"Extracted to '{destination}'") - - def execute_unskip(self, statement: Unskip): - self.runner.unskip() - - def execute_halt(self, statement: Halt): - self.skiptest("HALT was encountered in file") - - def execute_restart(self, statement: Restart): - if self.is_parallel: - self.fail("Cannot restart database in parallel") - - old_settings = self.cached_config_settings - - path = self.runner.database.path - self.pool = None - self.runner.database = None - gc.collect() - self.runner.database = SQLLogicDatabase(path, self) - self.pool = self.runner.database.connect() - con = self.pool.get_connection() - - for setting in old_settings: - name, value = setting - if name in [ - 'access_mode', - 'enable_external_access', - 'allow_unsigned_extensions', - 'allow_unredacted_secrets', - 'duckdb_api', - ]: - # Cannot be set after initialization - continue - - # If enable_profiling is NULL, skip setting custom_profiling_settings to not - # accidentally enable profiling. - # In that case, custom_profiling_settings is set to the default value anyway. - if name == "custom_profiling_settings" and "enable_profiling" not in old_settings: - continue - - query = f"SET {name}='{value}'" - con.execute(query) - - def execute_set(self, statement: Set): - option = statement.header.parameters[0] - if option == 'ignore_error_messages': - string_set = ( - self.runner.ignore_error_messages - if option == "ignore_error_messages" - else self.runner.always_fail_error_messages - ) - string_set.clear() - string_set = statement.error_messages - elif option == 'seed': - con = self.get_connection() - con.execute(f"SELECT SETSEED({statement.header.parameters[1]})") - self.runner.skip_reload = True - else: - self.skiptest(f"SET '{option}' is not implemented!") - - def execute_hash_threshold(self, statement: HashThreshold): - self.runner.hash_threshold = statement.threshold - - def execute_reconnect(self, statement: Reconnect): - if self.is_parallel: - self.fail("reconnect can not be used inside a parallel loop") - self.pool = None - self.pool = self.runner.database.connect() - con = self.pool.get_connection() - self.pool.initialize_connection(self, con) - - def execute_sleep(self, statement: Sleep): - def calculate_sleep_time(duration: float, unit: SleepUnit) -> float: - if unit == SleepUnit.SECOND: - return duration - elif unit == SleepUnit.MILLISECOND: - return duration / 1000 - elif unit == SleepUnit.MICROSECOND: - return duration / 1000000 - elif unit == SleepUnit.NANOSECOND: - return duration / 1000000000 - else: - raise ValueError("Unknown sleep unit") - - unit = statement.get_unit() - duration = statement.get_duration() - - time_to_sleep = calculate_sleep_time(duration, unit) - time.sleep(time_to_sleep) - - def execute_mode(self, statement: Mode): - parameter = statement.header.parameters[0] - if parameter == "output_hash": - self.runner.output_hash_mode = True - elif parameter == "output_result": - self.runner.output_result_mode = True - elif parameter == "no_output": - self.runner.output_hash_mode = False - self.runner.output_result_mode = False - elif parameter == "debug": - self.runner.debug_mode = True - else: - raise RuntimeError("unrecognized mode: " + parameter) - - def execute_statement(self, statement: Statement): - assert isinstance(statement, Statement) - conn = self.get_connection(statement.connection_name) - if not has_external_access(conn): - self.skiptest("enable_external_access is explicitly disabled by the test") - - sql_query = '\n'.join(statement.lines) - sql_query = self.replace_keywords(sql_query) - - expected_result = statement.expected_result - try: - conn.execute(sql_query) - result = conn.fetchall() - if expected_result.type == ExpectedResult.Type.ERROR: - self.fail(f"Query unexpectedly succeeded") - if expected_result.type != ExpectedResult.Type.UNKNOWN: - assert expected_result.lines == None - except duckdb.Error as e: - if expected_result.type == ExpectedResult.Type.SUCCESS: - self.fail(f"Query unexpectedly failed: {str(e)}") - if expected_result.lines == None: - return - expected = '\n'.join(expected_result.lines) - if is_regex(expected): - if not matches_regex(expected, str(e)): - self.fail( - f"Query failed, but did not produce the right error: {expected}\nInstead it produced: {str(e)}" - ) - else: - # Sanitize the expected error - if expected.startswith('Dependency Error: '): - expected = expected.split('Dependency Error: ')[1] - if expected not in str(e): - self.fail( - f"Query failed, but did not produce the right error: {expected}\nInstead it produced: {str(e)}" - ) - - def check_require(self, statement: Require) -> RequireResult: - not_an_extension = [ - "notmingw", - "mingw", - "notwindows", - "windows", - "longdouble", - "64bit", - "noforcestorage", - "nothreadsan", - "strinline", - "vector_size", - "exact_vector_size", - "block_size", - "skip_reload", - "no_alternative_verify", - ] - param = statement.header.parameters[0].lower() - if param in not_an_extension: - if param == 'vector_size': - required_vector_size = int(statement.header.parameters[1]) - if duckdb.__standard_vector_size__ < required_vector_size: - return RequireResult.MISSING - return RequireResult.PRESENT - if param == 'exact_vector_size': - required_vector_size = int(statement.header.parameters[1]) - if duckdb.__standard_vector_size__ == required_vector_size: - return RequireResult.PRESENT - return RequireResult.MISSING - if param == 'skip_reload': - self.runner.skip_reload = True - return RequireResult.PRESENT - return RequireResult.MISSING - - # Already loaded - if param in self.runner.extensions: - return RequireResult.PRESENT - - if self.runner.build_directory is None: - return RequireResult.MISSING - - connection = self.pool.get_connection() - autoload_known_extensions = connection.execute( - "select value::BOOLEAN from duckdb_settings() where name == 'autoload_known_extensions'" - ).fetchone()[0] - if param == "no_extension_autoloading": - if autoload_known_extensions: - # If autoloading is on, we skip this test - return RequireResult.MISSING - return RequireResult.PRESENT - - allow_unsigned_extensions = connection.execute( - "select value::BOOLEAN from duckdb_settings() where name == 'allow_unsigned_extensions'" - ).fetchone()[0] - if param == "allow_unsigned_extensions": - if allow_unsigned_extensions == False: - # If extension validation is turned on (that is allow_unsigned_extensions=False), skip test - return RequireResult.MISSING - return RequireResult.PRESENT - - excluded_from_autoloading = True - for ext in self.runner.AUTOLOADABLE_EXTENSIONS: - if ext == param: - excluded_from_autoloading = False - break - - if autoload_known_extensions == False: - try: - self.runner.database.load_extension(self, param) - self.runner.extensions.add(param) - except duckdb.Error: - return RequireResult.MISSING - elif excluded_from_autoloading: - return RequireResult.MISSING - - return RequireResult.PRESENT - - def execute_require(self, statement: Require): - require_result = self.check_require(statement) - if require_result != RequireResult.MISSING: - return - param = statement.header.parameters[0].lower() - if self.runner.is_required(param): - # This extension / setting was explicitly required - self.fail("require {}: FAILED".format(param)) - self.skiptest(f"require {param}: Missing") - - def execute_require_env(self, statement: RequireEnv): - key = statement.header.parameters[0] - res = os.getenv(key) - if self.in_loop(): - # FIXME: we can just remove the keyword at the end of the loop - # I think we should support this - # ... actually the way we set up keywords here, this is already the behavior - # inside the python sqllogic runner, since contexts are created and destroyed at loop start and end - self.skiptest(f"require-env can not be called in a loop") - if res is None: - self.skiptest(f"require-env {key} failed, not set") - if len(statement.header.parameters) != 1: - expected = statement.header.parameters[1] - if res != expected: - self.skiptest(f"require-env {key} failed, expected '{expected}', but found '{res}'") - self.add_keyword(key, res) - - def get_loop_statements(self): - saved_iterator = self.iterator - # Loop until EndLoop is found - statement = None - depth = 0 - while self.iterator < len(self.statements): - statement = self.next_statement() - if statement.__class__ in [Foreach, Loop]: - depth += 1 - if statement.__class__ == Endloop: - if depth == 0: - break - depth -= 1 - if not statement or statement.__class__ != Endloop: - raise Exception("no corresponding 'endloop' found before the end of the file!") - statements = self.statements[saved_iterator : self.iterator - 1] - return statements - - def execute_parallel(self, context: "SQLLogicContext", key, value): - context.is_parallel = True - try: - # For some reason the lambda won't capture the 'value' when created outside of 'execute_parallel' - def update_value(context: SQLLogicContext) -> Generator[Any, Any, Any]: - context.add_keyword(key, value) - yield None - context.remove_keyword(key) - - context.generator = update_value - context.execute() - except TestException: - assert context.error is not None - - def execute_loop(self, loop: Loop): - statements = self.get_loop_statements() - - if not loop.parallel: - # Every iteration the 'value' of the loop key needs to change - def update_value(context: SQLLogicContext) -> Generator[Any, Any, Any]: - key = loop.name - for val in range(loop.start, loop.end): - context.add_keyword(key, val) - yield None - context.remove_keyword(key) - - loop_context = SQLLogicContext(self.pool, self.runner, statements, self.keywords.copy(), update_value) - try: - loop_context.execute() - except TestException: - self.error = loop_context.error - else: - contexts: Dict[Tuple[str, int], Any] = {} - for val in range(loop.start, loop.end): - # FIXME: these connections are expected to have the same settings - # So we need to apply the cached settings to them - contexts[(loop.name, val)] = SQLLogicContext( - self.runner.database.connect(), - self.runner, - statements, - self.keywords.copy(), - None, # generator, can't be created yet - ) - - threads = [] - for keyval, context in contexts.items(): - key, value = keyval - t = threading.Thread(target=self.execute_parallel, args=(context, key, value)) - threads.append(t) - t.start() - - for thread in threads: - thread.join() - - for _, context in contexts.items(): - if context.error is not None: - # Propagate the exception - self.error = context.error - raise self.error - - def execute_foreach(self, foreach: Foreach): - statements = self.get_loop_statements() - - if not foreach.parallel: - # Every iteration the 'value' of the loop key needs to change - def update_value(context: SQLLogicContext) -> Generator[Any, Any, Any]: - loop_keys = foreach.name.split(',') - - for val in foreach.values: - if len(loop_keys) != 1: - values = val.split(',') - else: - values = [val] - assert len(values) == len(loop_keys) - for i, key in enumerate(loop_keys): - context.add_keyword(key, values[i]) - yield None - for key in loop_keys: - context.remove_keyword(key) - - loop_context = SQLLogicContext(self.pool, self.runner, statements, self.keywords.copy(), update_value) - loop_context.execute() - else: - # parallel loop: launch threads - contexts: List[Tuple[str, int, Any]] = [] - loop_keys = foreach.name.split(',') - for val in foreach.values: - if len(loop_keys) != 1: - values = val.split(',') - else: - values = [val] - - assert len(values) == len(loop_keys) - for i, key in enumerate(loop_keys): - contexts.append( - ( - foreach.name, - values[i], - SQLLogicContext( - self.runner.database.connect(), - self.runner, - statements, - self.keywords.copy(), - None, # generator, can't be created yet - ), - ) - ) - - threads = [] - for x in contexts: - key, value, context = x - t = threading.Thread(target=self.execute_parallel, args=(context, key, value)) - threads.append(t) - t.start() - - for thread in threads: - thread.join() - - for x in contexts: - _, _, context = x - if context.error is not None: - self.error = context.error - raise self.error - - def next_statement(self): - if self.iterator >= len(self.statements): - raise Exception("'next_statement' out of range, statements already consumed") - statement = self.statements[self.iterator] - self.iterator += 1 - return statement - - def verify_statements(self) -> None: - unsupported_statements = [ - statement for statement in self.statements if statement.__class__ not in self.STATEMENTS.keys() - ] - if unsupported_statements == []: - return - types = set([x.__class__ for x in unsupported_statements]) - error = f'skipped because the following statement types are not supported: {str(list([x for x in types]))}' - self.skiptest(error) - - def update_settings(self): - # Because we need to fire a query to get the settings required for 'restart' - # we do this preemptively before executing a statement - con = self.pool.get_connection() - try: - self.cached_config_settings = con.execute( - "select name, value from duckdb_settings() where value != 'NULL' and value != ''" - ).fetchall() - except duckdb.Error: - pass - - def execute(self): - try: - for _ in self.generator(self): - self.reset() - while self.iterator < len(self.statements): - statement = self.next_statement() - self.current_statement = SQLLogicStatementData(self.runner.test, statement) - if self.runner.skip_active() and statement.__class__ != Unskip: - # Keep skipping until Unskip is found - continue - if statement.get_decorators() != []: - self.skiptest("Decorators are not supported yet") - method = self.STATEMENTS.get(statement.__class__) - if not method: - self.skiptest("Not supported by the runner") - self.update_settings() - method(statement) - except TestException as e: - raise (e) - return ExecuteResult(ExecuteResult.Type.SUCCESS) diff --git a/scripts/sqllogictest/statement/__init__.py b/scripts/sqllogictest/statement/__init__.py deleted file mode 100644 index 2600fe110963..000000000000 --- a/scripts/sqllogictest/statement/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -from .statement import Statement -from .require import Require -from .mode import Mode -from .halt import Halt -from .load import Load -from .set import Set -from .load import Load -from .query import Query, SortStyle -from .hash_threshold import HashThreshold -from .loop import Loop -from .foreach import Foreach -from .endloop import Endloop -from .require_env import RequireEnv -from .restart import Restart -from .reconnect import Reconnect -from .sleep import Sleep, SleepUnit -from .unzip import Unzip - -from .skip import Skip, Unskip - -__all__ = [ - Statement, - Require, - Mode, - Halt, - Load, - Set, - Query, - HashThreshold, - Loop, - Foreach, - Endloop, - RequireEnv, - Restart, - Reconnect, - Sleep, - SleepUnit, - Skip, - Unzip, - Unskip, - SortStyle, -] diff --git a/scripts/sqllogictest/statement/endloop.py b/scripts/sqllogictest/statement/endloop.py deleted file mode 100644 index 89b415791124..000000000000 --- a/scripts/sqllogictest/statement/endloop.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Endloop(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/foreach.py b/scripts/sqllogictest/statement/foreach.py deleted file mode 100644 index 7f52beea2a1a..000000000000 --- a/scripts/sqllogictest/statement/foreach.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token -from typing import Optional, List - - -class Foreach(BaseStatement): - def __init__(self, header: Token, line: int, parallel: bool): - super().__init__(header, line) - self.parallel = parallel - self.values: List[str] = [] - self.name: Optional[str] = None - - def set_name(self, name: str): - self.name = name - - def set_values(self, values: List[str]): - self.values = values diff --git a/scripts/sqllogictest/statement/halt.py b/scripts/sqllogictest/statement/halt.py deleted file mode 100644 index 8be9e46e29a7..000000000000 --- a/scripts/sqllogictest/statement/halt.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Halt(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/hash_threshold.py b/scripts/sqllogictest/statement/hash_threshold.py deleted file mode 100644 index 0134b6c8158c..000000000000 --- a/scripts/sqllogictest/statement/hash_threshold.py +++ /dev/null @@ -1,8 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class HashThreshold(BaseStatement): - def __init__(self, header: Token, line: int, threshold: int): - super().__init__(header, line) - self.threshold = threshold diff --git a/scripts/sqllogictest/statement/load.py b/scripts/sqllogictest/statement/load.py deleted file mode 100644 index 98ac08439836..000000000000 --- a/scripts/sqllogictest/statement/load.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Load(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) - self.readonly: bool = False - self.version: Optional[int] = None - - def set_readonly(self): - self.readonly = True - - def set_version(self, version: str): - self.version = version diff --git a/scripts/sqllogictest/statement/loop.py b/scripts/sqllogictest/statement/loop.py deleted file mode 100644 index a0a2a696963c..000000000000 --- a/scripts/sqllogictest/statement/loop.py +++ /dev/null @@ -1,21 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token -from typing import Optional, List - - -class Loop(BaseStatement): - def __init__(self, header: Token, line: int, parallel: bool): - super().__init__(header, line) - self.parallel = parallel - self.name: Optional[str] = None - self.start: Optional[int] = None - self.end: Optional[int] = None - - def set_name(self, name: str): - self.name = name - - def set_start(self, start: List[str]): - self.start = start - - def set_end(self, end: List[str]): - self.end = end diff --git a/scripts/sqllogictest/statement/mode.py b/scripts/sqllogictest/statement/mode.py deleted file mode 100644 index be4155883632..000000000000 --- a/scripts/sqllogictest/statement/mode.py +++ /dev/null @@ -1,8 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Mode(BaseStatement): - def __init__(self, header: Token, line: int, parameter: str): - super().__init__(header, line) - self.parameter = parameter diff --git a/scripts/sqllogictest/statement/query.py b/scripts/sqllogictest/statement/query.py deleted file mode 100644 index 29ac003a94a3..000000000000 --- a/scripts/sqllogictest/statement/query.py +++ /dev/null @@ -1,44 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.expected_result import ExpectedResult -from sqllogictest.token import Token -from typing import Optional, List -from enum import Enum - - -class SortStyle(Enum): - NO_SORT = 0 - ROW_SORT = 1 - VALUE_SORT = 2 - UNKNOWN = 3 - - -class Query(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) - self.label: Optional[str] = None - self.lines: List[str] = [] - self.expected_result: Optional[ExpectedResult] = None - self.connection_name: Optional[str] = None - self.sortstyle: Optional[SortStyle] = None - self.label: Optional[str] = None - - def add_lines(self, lines: List[str]): - self.lines.extend(lines) - - def set_connection(self, connection: str): - self.connection_name = connection - - def set_expected_result(self, expected_result: ExpectedResult): - self.expected_result = expected_result - - def set_sortstyle(self, sortstyle: SortStyle): - self.sortstyle = sortstyle - - def get_sortstyle(self) -> Optional[SortStyle]: - return self.sortstyle - - def set_label(self, label: str): - self.label = label - - def get_label(self) -> Optional[str]: - return self.label diff --git a/scripts/sqllogictest/statement/reconnect.py b/scripts/sqllogictest/statement/reconnect.py deleted file mode 100644 index 558e4c17c35b..000000000000 --- a/scripts/sqllogictest/statement/reconnect.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Reconnect(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/require.py b/scripts/sqllogictest/statement/require.py deleted file mode 100644 index 25ab0dc66e70..000000000000 --- a/scripts/sqllogictest/statement/require.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Require(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/require_env.py b/scripts/sqllogictest/statement/require_env.py deleted file mode 100644 index 712ac8a707f3..000000000000 --- a/scripts/sqllogictest/statement/require_env.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class RequireEnv(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/restart.py b/scripts/sqllogictest/statement/restart.py deleted file mode 100644 index 5e271180271b..000000000000 --- a/scripts/sqllogictest/statement/restart.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Restart(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/set.py b/scripts/sqllogictest/statement/set.py deleted file mode 100644 index e19488a0cd7b..000000000000 --- a/scripts/sqllogictest/statement/set.py +++ /dev/null @@ -1,12 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token -from typing import List - - -class Set(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) - self.error_messages = [] - - def add_error_messages(self, error_messages: List[str]): - self.error_messages.extend(error_messages) diff --git a/scripts/sqllogictest/statement/skip.py b/scripts/sqllogictest/statement/skip.py deleted file mode 100644 index 88c4e930cc83..000000000000 --- a/scripts/sqllogictest/statement/skip.py +++ /dev/null @@ -1,12 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Skip(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) - - -class Unskip(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) diff --git a/scripts/sqllogictest/statement/sleep.py b/scripts/sqllogictest/statement/sleep.py deleted file mode 100644 index dae4a4b43de0..000000000000 --- a/scripts/sqllogictest/statement/sleep.py +++ /dev/null @@ -1,41 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token -from enum import Enum, auto - - -class SleepUnit(Enum): - SECOND = auto() - MILLISECOND = auto() - MICROSECOND = auto() - NANOSECOND = auto() - UNKNOWN = auto() - - -def get_sleep_unit(unit): - seconds = ["second", "seconds", "sec"] - milliseconds = ["millisecond", "milliseconds", "milli"] - microseconds = ["microsecond", "microseconds", "micro"] - nanoseconds = ["nanosecond", "nanoseconds", "nano"] - if unit in seconds: - return SleepUnit.SECOND - elif unit in milliseconds: - return SleepUnit.MILLISECOND - elif unit in microseconds: - return SleepUnit.MICROSECOND - elif unit in nanoseconds: - return SleepUnit.NANOSECOND - else: - return SleepUnit.UNKNOWN - - -class Sleep(BaseStatement): - def __init__(self, header: Token, line: int, duration: int, unit: SleepUnit): - super().__init__(header, line) - self.duration = duration - self.unit = unit - - def get_duration(self) -> int: - return self.duration - - def get_unit(self) -> SleepUnit: - return self.unit diff --git a/scripts/sqllogictest/statement/statement.py b/scripts/sqllogictest/statement/statement.py deleted file mode 100644 index db1027067acf..000000000000 --- a/scripts/sqllogictest/statement/statement.py +++ /dev/null @@ -1,21 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.expected_result import ExpectedResult -from sqllogictest.token import Token -from typing import List, Optional - - -class Statement(BaseStatement): - def __init__(self, header: Token, line: int): - super().__init__(header, line) - self.lines: List[str] = [] - self.expected_result: Optional[ExpectedResult] = None - self.connection_name: Optional[str] = None - - def add_lines(self, lines: List[str]): - self.lines.extend(lines) - - def set_connection(self, connection: str): - self.connection_name = connection - - def set_expected_result(self, expected_result: ExpectedResult): - self.expected_result = expected_result diff --git a/scripts/sqllogictest/statement/unzip.py b/scripts/sqllogictest/statement/unzip.py deleted file mode 100644 index 030e1b7b4e70..000000000000 --- a/scripts/sqllogictest/statement/unzip.py +++ /dev/null @@ -1,9 +0,0 @@ -from sqllogictest.base_statement import BaseStatement -from sqllogictest.token import Token - - -class Unzip(BaseStatement): - def __init__(self, header: Token, line: int, source: str, destination: str): - super().__init__(header, line) - self.source = source - self.destination = destination diff --git a/scripts/sqllogictest/test.py b/scripts/sqllogictest/test.py deleted file mode 100644 index 21d08d7a12ea..000000000000 --- a/scripts/sqllogictest/test.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import List -from .base_statement import BaseStatement - - -class SQLLogicTest: - __slots__ = ['path', 'statements'] - - def __init__(self, path: str): - self.path: str = path - self.statements: List[BaseStatement] = [] - - def add_statement(self, statement: BaseStatement): - self.statements.append(statement) - - def is_sqlite_test(self): - return 'test/sqlite/select' in self.path or 'third_party/sqllogictest' in self.path diff --git a/scripts/sqllogictest/token.py b/scripts/sqllogictest/token.py deleted file mode 100644 index 88d42d156d38..000000000000 --- a/scripts/sqllogictest/token.py +++ /dev/null @@ -1,31 +0,0 @@ -from enum import Enum, auto - - -class TokenType(Enum): - SQLLOGIC_INVALID = auto() - SQLLOGIC_SKIP_IF = auto() - SQLLOGIC_ONLY_IF = auto() - SQLLOGIC_STATEMENT = auto() - SQLLOGIC_QUERY = auto() - SQLLOGIC_HASH_THRESHOLD = auto() - SQLLOGIC_HALT = auto() - SQLLOGIC_MODE = auto() - SQLLOGIC_SET = auto() - SQLLOGIC_LOOP = auto() - SQLLOGIC_CONCURRENT_LOOP = auto() - SQLLOGIC_FOREACH = auto() - SQLLOGIC_CONCURRENT_FOREACH = auto() - SQLLOGIC_ENDLOOP = auto() - SQLLOGIC_REQUIRE = auto() - SQLLOGIC_REQUIRE_ENV = auto() - SQLLOGIC_LOAD = auto() - SQLLOGIC_RESTART = auto() - SQLLOGIC_RECONNECT = auto() - SQLLOGIC_SLEEP = auto() - SQLLOGIC_UNZIP = auto() - - -class Token: - def __init__(self): - self.type = TokenType.SQLLOGIC_INVALID - self.parameters = [] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1915e92081df..72e996107f64 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,7 +24,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") set(EXIT_TIME_DESTRUCTORS_WARNING TRUE) set(CMAKE_CXX_FLAGS_DEBUG - "${CMAKE_CXX_FLAGS_DEBUG} -Wexit-time-destructors -Wimplicit-int-conversion -Wshorten-64-to-32 -Wnarrowing -Wsign-conversion -Wsign-compare -Wconversion" + "${CMAKE_CXX_FLAGS_DEBUG} -Wexit-time-destructors -Wimplicit-int-conversion -Wshorten-64-to-32 -Wnarrowing -Wsign-conversion -Wsign-compare -Wconversion -Wtype-limits" ) endif() diff --git a/src/catalog/catalog_entry/copy_function_catalog_entry.cpp b/src/catalog/catalog_entry/copy_function_catalog_entry.cpp index 25544a343987..6d639bef695e 100644 --- a/src/catalog/catalog_entry/copy_function_catalog_entry.cpp +++ b/src/catalog/catalog_entry/copy_function_catalog_entry.cpp @@ -3,6 +3,8 @@ namespace duckdb { +constexpr const char *CopyFunctionCatalogEntry::Name; + CopyFunctionCatalogEntry::CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCopyFunctionInfo &info) : StandardEntry(CatalogType::COPY_FUNCTION_ENTRY, schema, catalog, info.name), function(info.function) { diff --git a/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp b/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp index ff247dcb07d2..9d9789192367 100644 --- a/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp +++ b/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" namespace duckdb { +constexpr const char *PragmaFunctionCatalogEntry::Name; PragmaFunctionCatalogEntry::PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreatePragmaFunctionInfo &info) diff --git a/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp b/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp index 49b20f677aa6..e5778ad4ce2d 100644 --- a/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp +++ b/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp @@ -5,6 +5,8 @@ namespace duckdb { +constexpr const char *ScalarFunctionCatalogEntry::Name; + ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateScalarFunctionInfo &info) : FunctionEntry(CatalogType::SCALAR_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { diff --git a/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/src/catalog/catalog_entry/sequence_catalog_entry.cpp index 6153a8e8ad7a..d6a548a26765 100644 --- a/src/catalog/catalog_entry/sequence_catalog_entry.cpp +++ b/src/catalog/catalog_entry/sequence_catalog_entry.cpp @@ -13,6 +13,8 @@ namespace duckdb { +constexpr const char *SequenceCatalogEntry::Name; + SequenceData::SequenceData(CreateSequenceInfo &info) : usage_count(info.usage_count), counter(info.start_value), last_value(info.start_value), increment(info.increment), start_value(info.start_value), min_value(info.min_value), max_value(info.max_value), cycle(info.cycle) { diff --git a/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/catalog/catalog_entry/table_catalog_entry.cpp index 5ed480e86ce7..8582fa93c6ac 100644 --- a/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -19,6 +19,8 @@ namespace duckdb { +constexpr const char *TableCatalogEntry::Name; + TableCatalogEntry::TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info) : StandardEntry(CatalogType::TABLE_ENTRY, schema, catalog, info.table), columns(std::move(info.columns)), constraints(std::move(info.constraints)) { diff --git a/src/catalog/catalog_entry/table_function_catalog_entry.cpp b/src/catalog/catalog_entry/table_function_catalog_entry.cpp index a6a41ff6197e..f06ef164eddc 100644 --- a/src/catalog/catalog_entry/table_function_catalog_entry.cpp +++ b/src/catalog/catalog_entry/table_function_catalog_entry.cpp @@ -4,6 +4,8 @@ namespace duckdb { +constexpr const char *TableFunctionCatalogEntry::Name; + TableFunctionCatalogEntry::TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableFunctionInfo &info) : FunctionEntry(CatalogType::TABLE_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { diff --git a/src/catalog/catalog_entry/type_catalog_entry.cpp b/src/catalog/catalog_entry/type_catalog_entry.cpp index 0bb4a3f3af71..324413b7c280 100644 --- a/src/catalog/catalog_entry/type_catalog_entry.cpp +++ b/src/catalog/catalog_entry/type_catalog_entry.cpp @@ -9,6 +9,8 @@ namespace duckdb { +constexpr const char *TypeCatalogEntry::Name; + TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info) : StandardEntry(CatalogType::TYPE_ENTRY, schema, catalog, info.name), user_type(info.type), bind_function(info.bind_function) { diff --git a/src/catalog/catalog_set.cpp b/src/catalog/catalog_set.cpp index deff8daae445..d374f6999d70 100644 --- a/src/catalog/catalog_set.cpp +++ b/src/catalog/catalog_set.cpp @@ -401,8 +401,6 @@ bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name); } - entry->OnDrop(); - // create a new tombstone entry and replace the currently stored one // set the timestamp to the timestamp of the current transaction // and point it at the tombstone node @@ -454,6 +452,7 @@ void CatalogSet::VerifyExistenceOfDependency(transaction_t commit_id, CatalogEnt void CatalogSet::CommitDrop(transaction_t commit_id, transaction_t start_time, CatalogEntry &entry) { auto &duck_catalog = GetCatalog(); + entry.OnDrop(); // Make sure that we don't see any uncommitted changes auto transaction_id = MAX_TRANSACTION_ID; // This will allow us to see all committed changes made before this COMMIT happened diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 539ed2bd3178..03da02062631 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -13,7 +13,6 @@ add_subdirectory(tree_renderer) add_subdirectory(row_operations) add_subdirectory(serializer) add_subdirectory(sort) -add_subdirectory(sorting) add_subdirectory(types) add_subdirectory(value_operations) add_subdirectory(vector_operations) diff --git a/src/common/allocator.cpp b/src/common/allocator.cpp index 977087939852..609558785272 100644 --- a/src/common/allocator.cpp +++ b/src/common/allocator.cpp @@ -35,6 +35,8 @@ namespace duckdb { +constexpr const idx_t Allocator::MAXIMUM_ALLOC_SIZE; + AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { } diff --git a/src/common/csv_writer.cpp b/src/common/csv_writer.cpp index bb9ff81d216d..0f3126691a46 100644 --- a/src/common/csv_writer.cpp +++ b/src/common/csv_writer.cpp @@ -16,7 +16,7 @@ CSVWriterState::CSVWriterState() } CSVWriterState::CSVWriterState(ClientContext &context, idx_t flush_size_p) - : flush_size(flush_size_p), stream(make_uniq(Allocator::Get(context))) { + : flush_size(flush_size_p), stream(make_uniq(Allocator::Get(context), flush_size)) { } CSVWriterState::CSVWriterState(DatabaseInstance &db, idx_t flush_size_p) @@ -198,18 +198,6 @@ void CSVWriter::ResetInternal(optional_ptr local_state) { bytes_written = 0; } -unique_ptr CSVWriter::InitializeLocalWriteState(ClientContext &context, idx_t flush_size) { - auto res = make_uniq(context, flush_size); - res->stream = make_uniq(); - return res; -} - -unique_ptr CSVWriter::InitializeLocalWriteState(DatabaseInstance &db, idx_t flush_size) { - auto res = make_uniq(db, flush_size); - res->stream = make_uniq(); - return res; -} - idx_t CSVWriter::BytesWritten() { if (shared) { lock_guard flock(lock); diff --git a/src/common/enum_util.cpp b/src/common/enum_util.cpp index 1005c6f91dff..5003cc5815e6 100644 --- a/src/common/enum_util.cpp +++ b/src/common/enum_util.cpp @@ -2798,32 +2798,36 @@ MetaPipelineType EnumUtil::FromString(const char *value) { const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetricsType::QUERY_NAME), "QUERY_NAME" }, + { static_cast(MetricsType::ATTACH_LOAD_STORAGE_LATENCY), "ATTACH_LOAD_STORAGE_LATENCY" }, + { static_cast(MetricsType::ATTACH_REPLAY_WAL_LATENCY), "ATTACH_REPLAY_WAL_LATENCY" }, { static_cast(MetricsType::BLOCKED_THREAD_TIME), "BLOCKED_THREAD_TIME" }, + { static_cast(MetricsType::CHECKPOINT_LATENCY), "CHECKPOINT_LATENCY" }, { static_cast(MetricsType::CPU_TIME), "CPU_TIME" }, - { static_cast(MetricsType::EXTRA_INFO), "EXTRA_INFO" }, { static_cast(MetricsType::CUMULATIVE_CARDINALITY), "CUMULATIVE_CARDINALITY" }, - { static_cast(MetricsType::OPERATOR_TYPE), "OPERATOR_TYPE" }, - { static_cast(MetricsType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, { static_cast(MetricsType::CUMULATIVE_ROWS_SCANNED), "CUMULATIVE_ROWS_SCANNED" }, + { static_cast(MetricsType::EXTRA_INFO), "EXTRA_INFO" }, + { static_cast(MetricsType::LATENCY), "LATENCY" }, + { static_cast(MetricsType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, + { static_cast(MetricsType::OPERATOR_NAME), "OPERATOR_NAME" }, { static_cast(MetricsType::OPERATOR_ROWS_SCANNED), "OPERATOR_ROWS_SCANNED" }, { static_cast(MetricsType::OPERATOR_TIMING), "OPERATOR_TIMING" }, + { static_cast(MetricsType::OPERATOR_TYPE), "OPERATOR_TYPE" }, + { static_cast(MetricsType::QUERY_NAME), "QUERY_NAME" }, { static_cast(MetricsType::RESULT_SET_SIZE), "RESULT_SET_SIZE" }, - { static_cast(MetricsType::LATENCY), "LATENCY" }, { static_cast(MetricsType::ROWS_RETURNED), "ROWS_RETURNED" }, - { static_cast(MetricsType::OPERATOR_NAME), "OPERATOR_NAME" }, { static_cast(MetricsType::SYSTEM_PEAK_BUFFER_MEMORY), "SYSTEM_PEAK_BUFFER_MEMORY" }, { static_cast(MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE), "SYSTEM_PEAK_TEMP_DIR_SIZE" }, { static_cast(MetricsType::TOTAL_BYTES_READ), "TOTAL_BYTES_READ" }, { static_cast(MetricsType::TOTAL_BYTES_WRITTEN), "TOTAL_BYTES_WRITTEN" }, + { static_cast(MetricsType::WAITING_TO_ATTACH_LATENCY), "WAITING_TO_ATTACH_LATENCY" }, { static_cast(MetricsType::ALL_OPTIMIZERS), "ALL_OPTIMIZERS" }, { static_cast(MetricsType::CUMULATIVE_OPTIMIZER_TIMING), "CUMULATIVE_OPTIMIZER_TIMING" }, - { static_cast(MetricsType::PLANNER), "PLANNER" }, - { static_cast(MetricsType::PLANNER_BINDING), "PLANNER_BINDING" }, { static_cast(MetricsType::PHYSICAL_PLANNER), "PHYSICAL_PLANNER" }, { static_cast(MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING), "PHYSICAL_PLANNER_COLUMN_BINDING" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, { static_cast(MetricsType::PHYSICAL_PLANNER_CREATE_PLAN), "PHYSICAL_PLANNER_CREATE_PLAN" }, + { static_cast(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, + { static_cast(MetricsType::PLANNER), "PLANNER" }, + { static_cast(MetricsType::PLANNER_BINDING), "PLANNER_BINDING" }, { static_cast(MetricsType::OPTIMIZER_EXPRESSION_REWRITER), "OPTIMIZER_EXPRESSION_REWRITER" }, { static_cast(MetricsType::OPTIMIZER_FILTER_PULLUP), "OPTIMIZER_FILTER_PULLUP" }, { static_cast(MetricsType::OPTIMIZER_FILTER_PUSHDOWN), "OPTIMIZER_FILTER_PUSHDOWN" }, @@ -2860,12 +2864,12 @@ const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { template<> const char* EnumUtil::ToChars(MetricsType value) { - return StringUtil::EnumToString(GetMetricsTypeValues(), 56, "MetricsType", static_cast(value)); + return StringUtil::EnumToString(GetMetricsTypeValues(), 60, "MetricsType", static_cast(value)); } template<> MetricsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 56, "MetricsType", value)); + return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 60, "MetricsType", value)); } const StringUtil::EnumStringLiteral *GetMultiFileColumnMappingModeValues() { diff --git a/src/common/enums/metric_type.cpp b/src/common/enums/metric_type.cpp index a4b7c73385b6..84b552037d39 100644 --- a/src/common/enums/metric_type.cpp +++ b/src/common/enums/metric_type.cpp @@ -50,12 +50,12 @@ profiler_settings_t MetricsUtils::GetPhaseTimingMetrics() { return { MetricsType::ALL_OPTIMIZERS, MetricsType::CUMULATIVE_OPTIMIZER_TIMING, - MetricsType::PLANNER, - MetricsType::PLANNER_BINDING, MetricsType::PHYSICAL_PLANNER, MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, MetricsType::PHYSICAL_PLANNER_CREATE_PLAN, + MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, + MetricsType::PLANNER, + MetricsType::PLANNER_BINDING, }; } @@ -235,12 +235,12 @@ bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { switch(type) { case MetricsType::ALL_OPTIMIZERS: case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricsType::PLANNER: - case MetricsType::PLANNER_BINDING: case MetricsType::PHYSICAL_PLANNER: case MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: case MetricsType::PHYSICAL_PLANNER_CREATE_PLAN: + case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: + case MetricsType::PLANNER: + case MetricsType::PLANNER_BINDING: return true; default: return false; @@ -249,9 +249,13 @@ bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { bool MetricsUtils::IsQueryGlobalMetric(MetricsType type) { switch(type) { + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: case MetricsType::BLOCKED_THREAD_TIME: + case MetricsType::CHECKPOINT_LATENCY: case MetricsType::SYSTEM_PEAK_BUFFER_MEMORY: case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: + case MetricsType::WAITING_TO_ATTACH_LATENCY: return true; default: return false; diff --git a/src/common/error_data.cpp b/src/common/error_data.cpp index 2ddf94af69aa..07b039ea83f2 100644 --- a/src/common/error_data.cpp +++ b/src/common/error_data.cpp @@ -80,9 +80,9 @@ void ErrorData::Throw(const string &prepended_message) const { D_ASSERT(initialized); if (!prepended_message.empty()) { string new_message = prepended_message + raw_message; - throw Exception(type, new_message, extra_info); + throw Exception(extra_info, type, new_message); } else { - throw Exception(type, raw_message, extra_info); + throw Exception(extra_info, type, raw_message); } } diff --git a/src/common/exception.cpp b/src/common/exception.cpp index 2012c1fcce2e..1d7345cf4476 100644 --- a/src/common/exception.cpp +++ b/src/common/exception.cpp @@ -19,17 +19,17 @@ Exception::Exception(ExceptionType exception_type, const string &message) : std::runtime_error(ToJSON(exception_type, message)) { } -Exception::Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info) - : std::runtime_error(ToJSON(exception_type, message, extra_info)) { +Exception::Exception(const unordered_map &extra_info, ExceptionType exception_type, + const string &message) + : std::runtime_error(ToJSON(extra_info, exception_type, message)) { } string Exception::ToJSON(ExceptionType type, const string &message) { unordered_map extra_info; - return ToJSON(type, message, extra_info); + return ToJSON(extra_info, type, message); } -string Exception::ToJSON(ExceptionType type, const string &message, const unordered_map &extra_info) { +string Exception::ToJSON(const unordered_map &extra_info, ExceptionType type, const string &message) { #ifndef DUCKDB_DEBUG_STACKTRACE // by default we only enable stack traces for internal exceptions if (type == ExceptionType::INTERNAL || type == ExceptionType::FATAL) @@ -240,9 +240,8 @@ TypeMismatchException::TypeMismatchException(const LogicalType &type_1, const Lo TypeMismatchException::TypeMismatchException(optional_idx error_location, const LogicalType &type_1, const LogicalType &type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". " + msg, - Exception::InitializeExtraInfo(error_location)) { + : Exception(Exception::InitializeExtraInfo(error_location), ExceptionType::MISMATCH_TYPE, + "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". " + msg) { } TypeMismatchException::TypeMismatchException(const string &msg) : Exception(ExceptionType::MISMATCH_TYPE, msg) { @@ -306,8 +305,8 @@ DependencyException::DependencyException(const string &msg) : Exception(Exceptio IOException::IOException(const string &msg) : Exception(ExceptionType::IO, msg) { } -IOException::IOException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::IO, msg, extra_info) { +IOException::IOException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::IO, msg) { } MissingExtensionException::MissingExtensionException(const string &msg) @@ -342,17 +341,17 @@ InternalException::InternalException(const string &msg) : Exception(ExceptionTyp InvalidInputException::InvalidInputException(const string &msg) : Exception(ExceptionType::INVALID_INPUT, msg) { } -InvalidInputException::InvalidInputException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_INPUT, msg, extra_info) { +InvalidInputException::InvalidInputException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::INVALID_INPUT, msg) { } InvalidConfigurationException::InvalidConfigurationException(const string &msg) : Exception(ExceptionType::INVALID_CONFIGURATION, msg) { } -InvalidConfigurationException::InvalidConfigurationException(const string &msg, - const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_CONFIGURATION, msg, extra_info) { +InvalidConfigurationException::InvalidConfigurationException(const unordered_map &extra_info, + const string &msg) + : Exception(extra_info, ExceptionType::INVALID_CONFIGURATION, msg) { } OutOfMemoryException::OutOfMemoryException(const string &msg) diff --git a/src/common/exception/binder_exception.cpp b/src/common/exception/binder_exception.cpp index 62dca06fb5af..70f71a52b4ab 100644 --- a/src/common/exception/binder_exception.cpp +++ b/src/common/exception/binder_exception.cpp @@ -7,8 +7,8 @@ namespace duckdb { BinderException::BinderException(const string &msg) : Exception(ExceptionType::BINDER, msg) { } -BinderException::BinderException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::BINDER, msg, extra_info) { +BinderException::BinderException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::BINDER, msg) { } BinderException BinderException::ColumnNotFound(const string &name, const vector &similar_bindings, @@ -20,7 +20,7 @@ BinderException BinderException::ColumnNotFound(const string &name, const vector extra_info["candidates"] = StringUtil::Join(similar_bindings, ","); } return BinderException( - StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", name, candidate_str), extra_info); + extra_info, StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", name, candidate_str)); } BinderException BinderException::NoMatchingFunction(const string &catalog_name, const string &schema_name, @@ -45,15 +45,14 @@ BinderException BinderException::NoMatchingFunction(const string &catalog_name, extra_info["candidates"] = StringUtil::Join(candidates, ","); } return BinderException( + extra_info, StringUtil::Format("No function matches the given name and argument types '%s'. You might need to add " "explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str), - extra_info); + call_str, candidate_str)); } BinderException BinderException::Unsupported(ParsedExpression &expr, const string &message) { auto extra_info = Exception::InitializeExtraInfo("UNSUPPORTED", expr.GetQueryLocation()); - return BinderException(message, extra_info); + return BinderException(extra_info, message); } - } // namespace duckdb diff --git a/src/common/exception/catalog_exception.cpp b/src/common/exception/catalog_exception.cpp index 5d890f1cdb60..b1cd4caf79cf 100644 --- a/src/common/exception/catalog_exception.cpp +++ b/src/common/exception/catalog_exception.cpp @@ -9,8 +9,8 @@ namespace duckdb { CatalogException::CatalogException(const string &msg) : Exception(ExceptionType::CATALOG, msg) { } -CatalogException::CatalogException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::CATALOG, msg, extra_info) { +CatalogException::CatalogException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::CATALOG, msg) { } CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion) { @@ -35,9 +35,9 @@ CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_in if (!suggestion.empty()) { extra_info["candidates"] = suggestion; } - return CatalogException(StringUtil::Format("%s with name %s does not exist%s!%s", CatalogTypeToString(type), name, - version_info, did_you_mean), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("%s with name %s does not exist%s!%s", CatalogTypeToString(type), name, + version_info, did_you_mean)); } CatalogException CatalogException::MissingEntry(CatalogType type, const string &name, const string &suggestion, @@ -55,17 +55,17 @@ CatalogException CatalogException::MissingEntry(const string &type, const string if (!suggestions.empty()) { extra_info["candidates"] = StringUtil::Join(suggestions, ", "); } - return CatalogException(StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, - StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean")), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, + StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean"))); } CatalogException CatalogException::EntryAlreadyExists(CatalogType type, const string &name, QueryErrorContext context) { auto extra_info = Exception::InitializeExtraInfo("ENTRY_ALREADY_EXISTS", optional_idx()); extra_info["name"] = name; extra_info["type"] = CatalogTypeToString(type); - return CatalogException(StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name)); } } // namespace duckdb diff --git a/src/common/exception/conversion_exception.cpp b/src/common/exception/conversion_exception.cpp index 013dbdb9e5df..bf021b4eb0e6 100644 --- a/src/common/exception/conversion_exception.cpp +++ b/src/common/exception/conversion_exception.cpp @@ -17,7 +17,7 @@ ConversionException::ConversionException(const string &msg) : Exception(Exceptio } ConversionException::ConversionException(optional_idx error_location, const string &msg) - : Exception(ExceptionType::CONVERSION, msg, Exception::InitializeExtraInfo(error_location)) { + : Exception(Exception::InitializeExtraInfo(error_location), ExceptionType::CONVERSION, msg) { } } // namespace duckdb diff --git a/src/common/exception/parser_exception.cpp b/src/common/exception/parser_exception.cpp index f3875da3890d..3afb2ea3d201 100644 --- a/src/common/exception/parser_exception.cpp +++ b/src/common/exception/parser_exception.cpp @@ -7,13 +7,12 @@ namespace duckdb { ParserException::ParserException(const string &msg) : Exception(ExceptionType::PARSER, msg) { } -ParserException::ParserException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::PARSER, msg, extra_info) { +ParserException::ParserException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::PARSER, msg) { } ParserException ParserException::SyntaxError(const string &query, const string &error_message, optional_idx error_location) { - return ParserException(error_message, Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location)); + return ParserException(Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location), error_message); } - } // namespace duckdb diff --git a/src/common/exception_format_value.cpp b/src/common/exception_format_value.cpp index 51e34ec0e085..27b4eb4659ce 100644 --- a/src/common/exception_format_value.cpp +++ b/src/common/exception_format_value.cpp @@ -28,65 +28,61 @@ ExceptionFormatValue::ExceptionFormatValue(uhugeint_t uhuge_val) ExceptionFormatValue::ExceptionFormatValue(string str_val) : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(std::move(str_val)) { } -ExceptionFormatValue::ExceptionFormatValue(String str_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(str_val.ToStdString()) { +ExceptionFormatValue::ExceptionFormatValue(const String &str_val) : ExceptionFormatValue(str_val.ToStdString()) { } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const PhysicalType &value) { return ExceptionFormatValue(TypeIdToString(value)); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(LogicalType value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const LogicalType &value) { return ExceptionFormatValue(value.ToString()); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value) { - return ExceptionFormatValue(double(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const float &value) { + return ExceptionFormatValue(static_cast(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value) { - return ExceptionFormatValue(double(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const double &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value) { - return ExceptionFormatValue(std::move(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(String value) { - return ExceptionFormatValue(std::move(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLString value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value) { return KeywordHelper::WriteQuoted(value.raw_string, '\''); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLIdentifier value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLIdentifier &value) { return KeywordHelper::WriteOptionallyQuoted(value.raw_string, '"'); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value) { return ExceptionFormatValue(string(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *const &value) { return ExceptionFormatValue(string(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const idx_t &value) { return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const hugeint_t &value) { return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const uhugeint_t &value) { return ExceptionFormatValue(value); } diff --git a/src/common/local_file_system.cpp b/src/common/local_file_system.cpp index 8733e0162046..5829cb54848e 100644 --- a/src/common/local_file_system.cpp +++ b/src/common/local_file_system.cpp @@ -369,7 +369,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF if (flags.ReturnNullIfExists() && errno == EEXIST) { return nullptr; } - throw IOException("Cannot open file \"%s\": %s", {{"errno", std::to_string(errno)}}, path, strerror(errno)); + throw IOException({{"errno", std::to_string(errno)}}, "Cannot open file \"%s\": %s", path, strerror(errno)); } #if defined(__DARWIN__) || defined(__APPLE__) @@ -436,7 +436,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF extended_error += ". Also, failed closing file"; } extended_error += ". See also https://duckdb.org/docs/stable/connect/concurrency"; - throw IOException("Could not set lock on file \"%s\": %s", {{"errno", std::to_string(retained_errno)}}, + throw IOException({{"errno", std::to_string(retained_errno)}}, "Could not set lock on file \"%s\": %s", path, extended_error); } } @@ -454,7 +454,7 @@ void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { int fd = handle.Cast().fd; off_t offset = lseek(fd, UnsafeNumericCast(location), SEEK_SET); if (offset == (off_t)-1) { - throw IOException("Could not seek to location %lld for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Could not seek to location %lld for file \"%s\": %s", location, handle.path, strerror(errno)); } } @@ -463,7 +463,7 @@ idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { int fd = handle.Cast().fd; off_t position = lseek(fd, 0, SEEK_CUR); if (position == (off_t)-1) { - throw IOException("Could not get file position file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Could not get file position file \"%s\": %s", handle.path, strerror(errno)); } return UnsafeNumericCast(position); @@ -477,7 +477,7 @@ void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, i int64_t bytes_read = pread(fd, read_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not read from file \"%s\": %s", handle.path, strerror(errno)); } if (bytes_read == 0) { @@ -498,7 +498,7 @@ int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes int fd = unix_handle.fd; int64_t bytes_read = read(fd, buffer, UnsafeNumericCast(nr_bytes)); if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not read from file \"%s\": %s", handle.path, strerror(errno)); } @@ -519,12 +519,13 @@ void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, int64_t bytes_written = pwrite(fd, write_buffer, UnsafeNumericCast(bytes_to_write), UnsafeNumericCast(current_location)); if (bytes_written < 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } if (bytes_written == 0) { - throw IOException("Could not write to file \"%s\" - attempted to write 0 bytes: %s", - {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); + throw IOException({{"errno", std::to_string(errno)}}, + "Could not write to file \"%s\" - attempted to write 0 bytes: %s", handle.path, + strerror(errno)); } write_buffer += bytes_written; bytes_to_write -= bytes_written; @@ -544,7 +545,7 @@ int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_byte MinValue(idx_t(NumericLimits::Maximum()), idx_t(bytes_to_write)); int64_t current_bytes_written = write(fd, buffer, bytes_to_write_this_call); if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } buffer = (void *)(data_ptr_cast(buffer) + current_bytes_written); @@ -577,7 +578,7 @@ int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { int fd = handle.Cast().fd; struct stat s; if (fstat(fd, &s) == -1) { - throw IOException("Failed to get file size for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Failed to get file size for file \"%s\": %s", handle.path, strerror(errno)); } return s.st_size; @@ -587,7 +588,7 @@ timestamp_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { int fd = handle.Cast().fd; struct stat s; if (fstat(fd, &s) == -1) { - throw IOException("Failed to get last modified time for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Failed to get last modified time for file \"%s\": %s", handle.path, strerror(errno)); } return Timestamp::FromEpochSeconds(s.st_mtime); @@ -601,7 +602,7 @@ FileType LocalFileSystem::GetFileType(FileHandle &handle) { void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { int fd = handle.Cast().fd; if (ftruncate(fd, new_size) != 0) { - throw IOException("Could not truncate file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not truncate file \"%s\": %s", handle.path, strerror(errno)); } } @@ -628,12 +629,12 @@ void LocalFileSystem::CreateDirectory(const string &directory, optional_ptr opener) { auto normalized_file = NormalizeLocalPath(filename); if (std::remove(normalized_file) != 0) { - throw IOException("Could not remove file \"%s\": %s", {{"errno", std::to_string(errno)}}, filename, + throw IOException({{"errno", std::to_string(errno)}}, "Could not remove file \"%s\": %s", filename, strerror(errno)); } } @@ -767,8 +768,7 @@ void LocalFileSystem::FileSync(FileHandle &handle) { } // For other types of errors, throw normal IO exception. - throw IOException("Could not fsync file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.GetPath(), - strerror(errno)); + throw IOException("Could not fsync file \"%s\": %s", handle.GetPath(), strerror(errno)); } void LocalFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { @@ -776,7 +776,7 @@ void LocalFileSystem::MoveFile(const string &source, const string &target, optio auto normalized_target = NormalizeLocalPath(target); //! FIXME: rename does not guarantee atomicity or overwriting target file if it exists if (rename(normalized_source, normalized_target) != 0) { - throw IOException("Could not rename file!", {{"errno", std::to_string(errno)}}); + throw IOException({{"errno", std::to_string(errno)}}, "Could not rename file!"); } } @@ -1052,7 +1052,7 @@ static int64_t FSWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t n auto bytes_to_write = MinValue(idx_t(NumericLimits::Maximum()), idx_t(nr_bytes)); DWORD current_bytes_written = FSInternalWrite(handle, hFile, buffer, bytes_to_write, location); if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } bytes_written += current_bytes_written; diff --git a/src/common/progress_bar/unscented_kalman_filter.cpp b/src/common/progress_bar/unscented_kalman_filter.cpp index 551cdbb0f97d..98b2a1319616 100644 --- a/src/common/progress_bar/unscented_kalman_filter.cpp +++ b/src/common/progress_bar/unscented_kalman_filter.cpp @@ -254,11 +254,11 @@ void UnscentedKalmanFilter::UpdateInternal(double measured_progress) { } // Ensure progress stays in bounds - x[0] = std::max(0.0, std::min(1.0, x[0])); + x[0] = std::max(0.0, std::min(scale_factor, x[0])); } double UnscentedKalmanFilter::GetProgress() const { - return x[0]; + return x[0] / scale_factor; } double UnscentedKalmanFilter::GetVelocity() const { diff --git a/src/common/row_operations/CMakeLists.txt b/src/common/row_operations/CMakeLists.txt index b07cd84aa466..f534c5d0b147 100644 --- a/src/common/row_operations/CMakeLists.txt +++ b/src/common/row_operations/CMakeLists.txt @@ -1,14 +1,5 @@ -add_library_unity( - duckdb_row_operations - OBJECT - row_aggregate.cpp - row_scatter.cpp - row_gather.cpp - row_matcher.cpp - row_external.cpp - row_radix_scatter.cpp - row_heap_scatter.cpp - row_heap_gather.cpp) +add_library_unity(duckdb_row_operations OBJECT row_aggregate.cpp + row_matcher.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ PARENT_SCOPE) diff --git a/src/common/row_operations/row_external.cpp b/src/common/row_operations/row_external.cpp deleted file mode 100644 index e4e3ec87d865..000000000000 --- a/src/common/row_operations/row_external.cpp +++ /dev/null @@ -1,157 +0,0 @@ -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Load heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = Load(heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string pointer with the within-row offset (if not inlined) - Store(UnsafeNumericCast(Load(string_ptr) - heap_row_ptrs[i]), - string_ptr); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data pointer with the within-row offset - Store(UnsafeNumericCast(Load(col_ptr) - heap_row_ptrs[i]), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset) { - const idx_t row_width = layout.GetRowWidth(); - row_ptr += layout.GetHeapOffset(); - idx_t cumulative_offset = 0; - for (idx_t i = 0; i < count; i++) { - Store(base_offset + cumulative_offset, row_ptr); - cumulative_offset += Load(heap_base_ptr + cumulative_offset); - row_ptr += row_width; - } -} - -void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - const auto heap_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - // Figure out source and size - const auto source_heap_ptr = Load(row_ptr + heap_offset); - const auto size = Load(source_heap_ptr); - D_ASSERT(size >= sizeof(uint32_t)); - - // Copy and swizzle - memcpy(heap_ptr, source_heap_ptr, size); - Store(UnsafeNumericCast(heap_ptr - heap_base_ptr), row_ptr + heap_offset); - - // Increment for next iteration - row_ptr += row_width; - heap_ptr += size; - } -} - -void RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); - heap_ptr_ptr += row_width; - } -} - -static inline void VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { -#ifdef DEBUG - if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - ValidityBytes row_mask(row_ptr, layout.ColumnCount()); - if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); - str.Verify(); - } -#endif -} - -void RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Restore heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); - Store(heap_row_ptrs[i], heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string offset with the pointer (if not inlined) - Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); - VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data offset with the pointer - Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -} // namespace duckdb diff --git a/src/common/row_operations/row_gather.cpp b/src/common/row_operations/row_gather.cpp deleted file mode 100644 index 8e5ed315b924..000000000000 --- a/src/common/row_operations/row_gather.cpp +++ /dev/null @@ -1,176 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/constant_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - idx_t build_size) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - data[col_idx] = Load(row + col_offset); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. - col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } - } -} - -static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, - data_ptr_t base_heap_ptr) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - auto col_ptr = row + col_offset; - data[col_idx] = Load(col_ptr); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. - col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { - // Not inline, so unswizzle the copied pointer the pointer - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; - Store(heap_row_ptr + Load(string_ptr), string_ptr); -#ifdef DEBUG - data[col_idx].Verify(); -#endif - } - } -} - -static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - data_ptr_t base_heap_ptr) { - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - auto ptrs = FlatVector::GetData(rows); - - // Build the gather locations - auto data_locations = make_unsafe_uniq_array_uninitialized(count); - auto mask_locations = make_unsafe_uniq_array_uninitialized(count); - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - mask_locations[i] = row; - auto col_ptr = ptrs[row_idx] + col_offset; - if (base_heap_ptr) { - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - data_locations[i] = heap_row_ptr + Load(col_ptr); - } else { - data_locations[i] = Load(col_ptr); - } - } - - // Deserialise into the selected locations - NestedValidity parent_validity(mask_locations.get(), col_no); - RowOperations::HeapGather(col, count, col_sel, data_locations.get(), &parent_validity); -} - -void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, - data_ptr_t heap_ptr) { - D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rows.GetType().id() == LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" - - col.SetVectorType(VectorType::FLAT_VECTOR); - switch (col.GetType().InternalType()) { - case PhysicalType::UINT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::FLOAT: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::DOUBLE: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INTERVAL: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::VARCHAR: - GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr); - break; - default: - throw InternalException("Unimplemented type for RowOperations::Gather"); - } -} - -} // namespace duckdb diff --git a/src/common/row_operations/row_heap_gather.cpp b/src/common/row_operations/row_heap_gather.cpp deleted file mode 100644 index fa433c64e120..000000000000 --- a/src/common/row_operations/row_heap_gather.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -template -static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) { - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < count; ++i) { - const auto col_idx = sel.get_index(i); - target[col_idx] = Load(key_locations[i]); - key_locations[i] += sizeof(T); - } -} - -static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - auto len = Load(key_locations[i]); - key_locations[i] += sizeof(uint32_t); - target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len)); - key_locations[i] += len; - } -} - -static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // struct must have a validitymask for its fields - auto &child_types = StructType::GetChildTypes(v.GetType()); - const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < vcount; i++) { - // use key_locations as the validitymask, and create struct_key_locations - struct_validitymask_locations[i] = key_locations[i]; - key_locations[i] += struct_validitymask_size; - } - - // now deserialize into the struct vectors - auto &children = StructVector::GetEntries(v); - for (idx_t i = 0; i < child_types.size(); i++) { - NestedValidity parent_validity(struct_validitymask_locations, i); - RowOperations::HeapGather(*children[i], vcount, sel, key_locations, &parent_validity); - } -} - -static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - - auto child_type = ListType::GetChildType(v.GetType()); - auto list_data = ListVector::GetData(v); - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - uint64_t entry_offset = ListVector::GetListSize(v); - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - // read list length - auto entry_remaining = Load(key_locations[i]); - key_locations[i] += sizeof(uint64_t); - // set list entry attributes - list_data[col_idx].length = entry_remaining; - list_data[col_idx].offset = entry_offset; - // skip over the validity mask - data_ptr_t validitymask_location = key_locations[i]; - idx_t offset_in_byte = 0; - key_locations[i] += (entry_remaining + 7) / 8; - // entry sizes - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type.InternalType())) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += entry_remaining * sizeof(idx_t); - } - - // now read the list data - while (entry_remaining > 0) { - auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); - - // initialize a new vector to append - Vector append_vector(v.GetType()); - append_vector.SetVectorType(v.GetVectorType()); - - auto &list_vec_to_append = ListVector::GetEntry(append_vector); - - // set validity - //! Since we are constructing the vector, this will always be a flat vector. - auto &append_validity = FlatVector::Validity(list_vec_to_append); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); - if (++offset_in_byte == 8) { - validitymask_location++; - offset_in_byte = 0; - } - } - - // compute entry sizes and set locations where the list entries are - if (TypeIsConstantSize(child_type.InternalType())) { - // constant size list entries - const idx_t type_size = GetTypeIdSize(child_type.InternalType()); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now deserialize and add to listvector - RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), - list_entry_locations, nullptr); - ListVector::Append(v, list_vec_to_append, next); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapGatherArrayVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // Setup - auto &child_type = ArrayType::GetChildType(v.GetType()); - auto array_size = ArrayType::GetSize(v.GetType()); - auto &child_vector = ArrayVector::GetEntry(v); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < vcount; i++) { - // Setup validity mask - data_ptr_t array_validitymask_location = key_locations[i]; - key_locations[i] += array_validitymask_size; - - NestedValidity parent_validity(array_validitymask_location); - - // The size of each variable size entry is stored after the validity mask - // (if the child type is variable size) - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // row idx - const auto row_idx = sel.get_index(i); - - idx_t array_start = row_idx * array_size; - idx_t elem_remaining = array_size; - - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - SelectionVector array_sel(STANDARD_VECTOR_SIZE); - - if (child_type_is_var_size) { - // variable size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } else { - // constant size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } - - // Pass on this array's validity mask to the child vector - RowOperations::HeapGather(child_vector, chunk_size, array_sel, array_entry_locations, &parent_validity); - - elem_remaining -= chunk_size; - array_start += chunk_size; - parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t *key_locations, - optional_ptr parent_validity) { - v.SetVectorType(VectorType::FLAT_VECTOR); - - auto &validity = FlatVector::Validity(v); - if (parent_validity) { - for (idx_t i = 0; i < vcount; i++) { - const auto valid = parent_validity->IsValid(i); - const auto col_idx = sel.get_index(i); - validity.Set(col_idx, valid); - } - } - - auto type = v.GetType().InternalType(); - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::FLOAT: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::DOUBLE: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INTERVAL: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::VARCHAR: - HeapGatherStringVector(v, vcount, sel, key_locations); - break; - case PhysicalType::STRUCT: - HeapGatherStructVector(v, vcount, sel, key_locations); - break; - case PhysicalType::LIST: - HeapGatherListVector(v, vcount, sel, key_locations); - break; - case PhysicalType::ARRAY: - HeapGatherArrayVector(v, vcount, sel, key_locations); - break; - default: - throw NotImplementedException("Unimplemented deserialize from row-format"); - } -} - -} // namespace duckdb diff --git a/src/common/row_operations/row_heap_scatter.cpp b/src/common/row_operations/row_heap_scatter.cpp deleted file mode 100644 index 01cf7b5897ea..000000000000 --- a/src/common/row_operations/row_heap_scatter.cpp +++ /dev/null @@ -1,581 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -NestedValidity::NestedValidity(data_ptr_t validitymask_location) - : list_validity_location(validitymask_location), struct_validity_locations(nullptr), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { -} - -NestedValidity::NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index) - : list_validity_location(nullptr), struct_validity_locations(validitymask_locations), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { - ValidityBytes::GetEntryIndex(child_vector_index, entry_idx, idx_in_entry); -} - -void NestedValidity::SetInvalid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = ~(1UL << list_idx_in_entry); - list_validity_location[list_entry_idx] &= bit; - } else { - // Is Struct - const auto bit = ~(1UL << idx_in_entry); - *(struct_validity_locations[idx] + entry_idx) &= bit; - } -} - -void NestedValidity::OffsetListBy(idx_t offset) { - list_validity_offset += offset; -} - -bool NestedValidity::IsValid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = (1UL << list_idx_in_entry); - return list_validity_location[list_entry_idx] & bit; - } else { - // Is Struct - const auto bit = (1UL << idx_in_entry); - return *(struct_validity_locations[idx] + entry_idx) & bit; - } -} - -static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, - const SelectionVector &sel, const idx_t offset) { - auto strings = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto str_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(str_idx)) { - entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); - } - } -} - -static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - // obtain child vectors - idx_t num_children; - auto &children = StructVector::GetEntries(v); - num_children = children.size(); - // add struct validitymask size - const idx_t struct_validitymask_size = (num_children + 7) / 8; - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += struct_validitymask_size; - } - // compute size of child vectors - for (auto &struct_vector : children) { - RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); - } -} - -static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto list_entry = list_data[source_idx]; - - // make room for list length, list validitymask - entry_sizes[i] += sizeof(list_entry.length); - entry_sizes[i] += (list_entry.length + 7) / 8; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += list_entry.length * sizeof(list_entry.length); - } - - // compute size of each the elements in list_entry and sum them - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // compute and add to the total - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t list_idx = 0; list_idx < next; list_idx++) { - entry_sizes[i] += list_entry_sizes[list_idx]; - } - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } - } -} - -static void ComputeArrayEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_vector = ArrayVector::GetEntry(v); - - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - const idx_t array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - - // Validity for the array elements - entry_sizes[i] += array_validitymask_size; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ArrayType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += array_size * sizeof(idx_t); - } - - auto elem_idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(elem_idx + offset); - - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - // the array could span multiple vectors, so we divide it into chunks - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // compute and add to the total - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t arr_elem_idx = 0; arr_elem_idx < chunk_size; arr_elem_idx++) { - entry_sizes[i] += array_entry_sizes[arr_elem_idx]; - } - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset) { - const auto physical_type = v.GetType().InternalType(); - if (TypeIsConstantSize(physical_type)) { - const auto type_size = GetTypeIdSize(physical_type); - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += type_size; - } - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::STRUCT: - ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); - break; - case PhysicalType::LIST: - ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::ARRAY: - ComputeArrayEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); -} - -template -static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - - // set the validitymask - if (!vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto strings = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } - } - } else { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } else { - // set the validitymask - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto &children = StructVector::GetEntries(v); - idx_t num_children = children.size(); - - // struct must have a validitymask for its fields - const idx_t struct_validitymask_size = (num_children + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - // initialize the struct validity mask - struct_validitymask_locations[i] = key_locations[i]; - memset(struct_validitymask_locations[i], -1, struct_validitymask_size); - key_locations[i] += struct_validitymask_size; - - // set whether the whole struct is null - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - - // now serialize the struct vectors - for (idx_t i = 0; i < children.size(); i++) { - auto &struct_vector = *children[i]; - NestedValidity struct_validity(struct_validitymask_locations, i); - RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, key_locations, &struct_validity, offset); - } -} - -static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - - UnifiedVectorFormat list_vdata; - child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); - auto child_type = ListType::GetChildType(v.GetType()).InternalType(); - - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (!vdata.validity.RowIsValid(source_idx)) { - if (parent_validity) { - // set the row validitymask for this column to invalid - parent_validity->SetInvalid(i); - } - continue; - } - auto list_entry = list_data[source_idx]; - - // store list length - Store(list_entry.length, key_locations[i]); - key_locations[i] += sizeof(list_entry.length); - - // make room for the validitymask - data_ptr_t list_validitymask_location = key_locations[i]; - idx_t entry_offset_in_byte = 0; - idx_t validitymask_size = (list_entry.length + 7) / 8; - memset(list_validitymask_location, -1, validitymask_size); - key_locations[i] += validitymask_size; - - // serialize size of each entry (if non-constant size) - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type)) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += list_entry.length * sizeof(idx_t); - } - - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // serialize list validity - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); - if (!list_vdata.validity.RowIsValid(list_idx)) { - *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); - } - if (++entry_offset_in_byte == 8) { - list_validitymask_location++; - entry_offset_in_byte = 0; - } - } - - if (TypeIsConstantSize(child_type)) { - // constant size list entries: set list entry locations - const idx_t type_size = GetTypeIdSize(child_type); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries: compute entry sizes and set list entry locations - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += list_entry_sizes[entry_idx]; - Store(list_entry_sizes[entry_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now serialize to the locations - RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), - *FlatVector::IncrementalSelectionVector(), next, list_entry_locations, nullptr, - entry_offset); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapScatterArrayVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_type = ArrayType::GetChildType(v.GetType()); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - UnifiedVectorFormat child_vdata; - child_vector.ToUnifiedFormat(ArrayVector::GetTotalSize(v), child_vdata); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - // Set if the whole array itself is null in the parent entry - auto source_idx = vdata.sel->get_index(sel.get_index(i) + offset); - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - - // Now we can serialize the array itself - // Every array starts with a validity mask for the children - data_ptr_t array_validitymask_location = key_locations[i]; - memset(array_validitymask_location, -1, array_validitymask_size); - key_locations[i] += array_validitymask_size; - - NestedValidity array_parent_validity(array_validitymask_location); - - // If the array contains variable size entries, we reserve spaces for them here - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // Then comes the elements - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - while (elem_remaining > 0) { - // the array elements can span multiple vectors, so we divide it into chunks - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // Setup the locations for the elements - if (child_type_is_var_size) { - // The elements are variable sized - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += array_entry_sizes[elem_idx]; - - // Now store the size of the entry - Store(array_entry_sizes[elem_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } else { - // The elements are constant sized - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - } - } - - RowOperations::HeapScatter(child_vector, ArrayVector::GetTotalSize(v), - *FlatVector::IncrementalSelectionVector(), chunk_size, array_entry_locations, - &array_parent_validity, array_start); - - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - array_parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset) { - if (TypeIsConstantSize(v.GetType().InternalType())) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, key_locations, - parent_validity, offset); - } else { - switch (v.GetType().InternalType()) { - case PhysicalType::VARCHAR: - HeapScatterStringVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::STRUCT: - HeapScatterStructVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::LIST: - HeapScatterListVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::ARRAY: - HeapScatterArrayVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Serialization of variable length vector with type %s", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::FLOAT: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::DOUBLE: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INTERVAL: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); - } -} - -} // namespace duckdb diff --git a/src/common/row_operations/row_radix_scatter.cpp b/src/common/row_operations/row_radix_scatter.cpp deleted file mode 100644 index a85a7199776e..000000000000 --- a/src/common/row_operations/row_radix_scatter.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -template -void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeData(key_locations[i] + 1, source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < sizeof(T) + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', sizeof(T)); - } - key_locations[i] += sizeof(T) + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeData(key_locations[i], source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < sizeof(T); s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += sizeof(T); - } - } -} - -void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < prefix_len + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', prefix_len); - } - key_locations[i] += prefix_len + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < prefix_len; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += prefix_len; - } - } -} - -void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, const idx_t width, const idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - auto list_size = ListVector::GetListSize(v); - child_vector.Flatten(list_size); - - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - auto &list_entry = list_data[source_idx]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 2, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 2); - key_location += width - 2; - } - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto &list_entry = list_data[source_idx]; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, array_offset); - - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width, array_offset); - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', width - 1); - } - key_locations[i]++; - } - width--; - } - // serialize the struct - auto &child_vector = *StructVector::GetEntries(v)[0]; - RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count, - key_locations, false, true, false, prefix_len, width, offset); - // invert bits if desc - if (desc) { - for (idx_t i = 0; i < add_count; i++) { - for (idx_t s = 0; s < width; s++) { - *(key_locations[i] - width + s) = ~*(key_locations[i] - width + s); - } - } - } -} - -void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, - idx_t prefix_len, idx_t width, idx_t offset) { -#ifdef DEBUG - // initialize to verify written width later - auto key_locations_copy = make_uniq_array(ser_count); - for (idx_t i = 0; i < ser_count; i++) { - key_locations_copy[i] = key_locations[i]; - } -#endif - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - switch (v.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::FLOAT: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::DOUBLE: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INTERVAL: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::VARCHAR: - RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset); - break; - case PhysicalType::LIST: - RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width, - offset); - break; - case PhysicalType::STRUCT: - RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - case PhysicalType::ARRAY: - RadixScatterArrayVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - default: - throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); - } - -#ifdef DEBUG - for (idx_t i = 0; i < ser_count; i++) { - D_ASSERT(key_locations[i] == key_locations_copy[i] + width); - } -#endif -} - -} // namespace duckdb diff --git a/src/common/row_operations/row_scatter.cpp b/src/common/row_operations/row_scatter.cpp deleted file mode 100644 index 1912d248474a..000000000000 --- a/src/common/row_operations/row_scatter.cpp +++ /dev/null @@ -1,230 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const idx_t count, - const idx_t col_offset, const idx_t col_no, const idx_t col_count) { - auto data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - if (!col.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - auto isnull = !col.validity.RowIsValid(col_idx); - T store_value = isnull ? NullValue() : data[col_idx]; - Store(store_value, row + col_offset); - if (isnull) { - ValidityBytes col_mask(ptrs[idx], col_count); - col_mask.SetInvalidUnsafe(col_no); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - Store(data[col_idx], row + col_offset); - } - } -} - -static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, - const idx_t count, const idx_t offset = 0) { - auto data = UnifiedVectorFormat::GetData(col); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx) + offset; - const auto &str = data[col_idx]; - if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { - entry_sizes[i] += str.GetSize(); - } - } -} - -static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t col_count) { - auto string_data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - // Write out zero length to avoid swizzling problems. - const string_t null(nullptr, 0); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - if (!col.validity.RowIsValid(col_idx)) { - ValidityBytes col_mask(row, col_count); - col_mask.SetInvalidUnsafe(col_no); - Store(null, row + col_offset); - } else if (string_data[col_idx].IsInlined()) { - Store(string_data[col_idx], row + col_offset); - } else { - const auto &str = string_data[col_idx]; - string_t inserted(const_char_ptr_cast(str_locations[i]), UnsafeNumericCast(str.GetSize())); - memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); - str_locations[i] += str.GetSize(); - inserted.Finalize(); - Store(inserted, row + col_offset); - } - } -} - -static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t vcount) { - // Store pointers to the data in the row - // Do this first because SerializeVector destroys the locations - auto ptrs = FlatVector::GetData(rows); - data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto row = ptrs[idx]; - validitymask_locations[i] = row; - - Store(data_locations[i], row + col_offset); - } - - // Serialise the data - NestedValidity parent_validity(validitymask_locations, col_no); - RowOperations::HeapScatter(vec, vcount, sel, count, data_locations, &parent_validity); -} - -void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { - if (count == 0) { - return; - } - - // Set the validity mask for each row before inserting data - idx_t column_count = layout.ColumnCount(); - auto ptrs = FlatVector::GetData(rows); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - ValidityBytes(row, column_count).SetAllValid(layout.ColumnCount()); - } - - const auto vcount = columns.size(); - auto &offsets = layout.GetOffsets(); - auto &types = layout.GetTypes(); - - // Compute the entry size of the variable size columns - vector handles; - data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; - if (!layout.AllConstant()) { - idx_t entry_sizes[STANDARD_VECTOR_SIZE]; - std::fill_n(entry_sizes, count, sizeof(uint32_t)); - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - if (TypeIsConstantSize(types[col_no].InternalType())) { - continue; - } - - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - switch (types[col_no].InternalType()) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(col, entry_sizes, sel, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } - - // Build out the buffer space - handles = string_heap.Build(count, data_locations, entry_sizes); - - // Serialize information that is needed for swizzling if the computation goes out-of-core - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - // Pointer to this row in the heap block - Store(data_locations[i], row + heap_pointer_offset); - // Row size is stored in the heap in front of each row - Store(NumericCast(entry_sizes[i]), data_locations[i]); - data_locations[i] += sizeof(uint32_t); - } - } - - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - auto col_offset = offsets[col_no]; - - switch (types[col_no].InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::FLOAT: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::DOUBLE: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INTERVAL: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::VARCHAR: - ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } -} - -} // namespace duckdb diff --git a/src/common/sort/CMakeLists.txt b/src/common/sort/CMakeLists.txt index c0459bd4572d..8d61e5c3d359 100644 --- a/src/common/sort/CMakeLists.txt +++ b/src/common/sort/CMakeLists.txt @@ -1,11 +1,5 @@ -add_library_unity( - duckdb_sort - OBJECT - comparators.cpp - merge_sorter.cpp - radix_sort.cpp - sort_state.cpp - sorted_block.cpp) +add_library_unity(duckdb_sort OBJECT hashed_sort.cpp sort.cpp sorted_run.cpp + sorted_run_merger.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/common/sort/comparators.cpp b/src/common/sort/comparators.cpp deleted file mode 100644 index 4df4cccc4430..000000000000 --- a/src/common/sort/comparators.cpp +++ /dev/null @@ -1,507 +0,0 @@ -#include "duckdb/common/sort/comparators.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) { - const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - // Check if the blob is NULL - ValidityBytes row_mask(row_ptr, sort_layout.column_count); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - // Can't break a NULL tie - return false; - } - auto &row_layout = sort_layout.blob_layout; - if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { - // Nested type, must be broken - return true; - } - const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; - auto tie_string = Load(row_ptr + tie_col_offset); - if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col] && tie_string.GetSize() > 0) { - // No need to break the tie - we already compared the full string - return false; - } - return true; -} - -int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { - // Compare the sorting columns one by one - int comp_res = 0; - data_ptr_t l_ptr_offset = l_ptr; - data_ptr_t r_ptr_offset = r_ptr; - for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { - comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); - if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { - comp_res = BreakBlobTie(col_idx, left, right, sort_layout, external_sort); - } - if (comp_res != 0) { - break; - } - l_ptr_offset += sort_layout.column_sizes[col_idx]; - r_ptr_offset += sort_layout.column_sizes[col_idx]; - } - return comp_res; -} - -int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::VARCHAR: - return TemplatedCompareVal(l_ptr, r_ptr); - case PhysicalType::LIST: - case PhysicalType::ARRAY: - case PhysicalType::STRUCT: { - auto l_nested_ptr = Load(l_ptr); - auto r_nested_ptr = Load(r_ptr); - return CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); - } - default: - throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); - } -} - -int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external) { - data_ptr_t l_data_ptr = left.DataPtr(*left.sb->blob_sorting_data); - data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); - if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout) && !TieIsBreakable(tie_col, r_data_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return 0; - } - // Align the pointers - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - l_data_ptr += tie_col_offset; - r_data_ptr += tie_col_offset; - // Do the comparison - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1; - const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; - int result; - if (external) { - // Store heap pointers - data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); - data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); - // Unswizzle offset to pointer - UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); - UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); - // Compare - result = CompareVal(l_data_ptr, r_data_ptr, type); - // Swizzle the pointers back to offsets - SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); - SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); - } else { - result = CompareVal(l_data_ptr, r_data_ptr, type); - } - return order * result; -} - -template -int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { - const auto left_val = Load(left_ptr); - const auto right_val = Load(right_ptr); - if (Equals::Operation(left_val, right_val)) { - return 0; - } else if (LessThan::Operation(left_val, right_val)) { - return -1; - } else { - return 1; - } -} - -int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::FLOAT: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::DOUBLE: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INTERVAL: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::VARCHAR: - return CompareStringAndAdvance(l_ptr, r_ptr, valid); - case PhysicalType::LIST: - return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); - case PhysicalType::STRUCT: - return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); - case PhysicalType::ARRAY: - return CompareArrayAndAdvance(l_ptr, r_ptr, ArrayType::GetChildType(type), valid, ArrayType::GetSize(type)); - default: - throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); - } -} - -template -int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) { - auto result = TemplatedCompareVal(left_ptr, right_ptr); - left_ptr += sizeof(T); - right_ptr += sizeof(T); - return result; -} - -int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { - if (!valid) { - return 0; - } - uint32_t left_string_size = Load(left_ptr); - uint32_t right_string_size = Load(right_ptr); - left_ptr += sizeof(uint32_t); - right_ptr += sizeof(uint32_t); - auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), - std::min(left_string_size, right_string_size)); - - left_ptr += left_string_size; - right_ptr += right_string_size; - - if (memcmp_res != 0) { - return memcmp_res; - } - if (left_string_size == right_string_size) { - return 0; - } - if (left_string_size < right_string_size) { - return -1; - } - return 1; -} - -int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid) { - idx_t count = types.size(); - // Load validity masks - ValidityBytes left_validity(left_ptr, types.size()); - ValidityBytes right_validity(right_ptr, types.size()); - left_ptr += (count + 7) / 8; - right_ptr += (count + 7) / 8; - // Initialize variables - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Compare - int comp_res = 0; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - auto &type = types[i].second; - if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { - comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); - } - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -int Comparators::CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid, idx_t array_size) { - if (!valid) { - return 0; - } - - // Load array validity masks - ValidityBytes left_validity(left_ptr, array_size); - ValidityBytes right_validity(right_ptr, array_size); - left_ptr += (array_size + 7) / 8; - right_ptr += (array_size + 7) / 8; - - int comp_res = 0; - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT8: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT128: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INTERVAL: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized array entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += array_size * sizeof(idx_t); - right_ptr += array_size * sizeof(idx_t); - for (idx_t i = 0; i < array_size; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareArrayAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - return comp_res; -} - -int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid) { - if (!valid) { - return 0; - } - // Load list lengths - auto left_len = Load(left_ptr); - auto right_len = Load(right_ptr); - left_ptr += sizeof(idx_t); - right_ptr += sizeof(idx_t); - // Load list validity masks - ValidityBytes left_validity(left_ptr, left_len); - ValidityBytes right_validity(right_ptr, right_len); - left_ptr += (left_len + 7) / 8; - right_ptr += (right_len + 7) / 8; - // Compare - int comp_res = 0; - idx_t count = MinValue(left_len, right_len); - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INTERVAL: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized list entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += left_len * sizeof(idx_t); - right_ptr += right_len * sizeof(idx_t); - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - // All values that we looped over were equal - if (comp_res == 0 && left_len != right_len) { - // Smaller lists first - if (left_len < right_len) { - comp_res = -1; - } else { - comp_res = 1; - } - } - return comp_res; -} - -template -int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const ValidityBytes &left_validity, const ValidityBytes &right_validity, - const idx_t &count) { - int comp_res = 0; - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(heap_ptr + Load(data_ptr), data_ptr); -} - -void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(UnsafeNumericCast(Load(data_ptr) - heap_ptr), data_ptr); -} - -} // namespace duckdb diff --git a/src/common/sorting/hashed_sort.cpp b/src/common/sort/hashed_sort.cpp similarity index 100% rename from src/common/sorting/hashed_sort.cpp rename to src/common/sort/hashed_sort.cpp diff --git a/src/common/sort/merge_sorter.cpp b/src/common/sort/merge_sorter.cpp deleted file mode 100644 index c670fd574381..000000000000 --- a/src/common/sort/merge_sorter.cpp +++ /dev/null @@ -1,667 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) - : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { -} - -void MergeSorter::PerformInMergeRound() { - while (true) { - // Check for interrupts after merging a partition - if (state.context.interrupted) { - throw InterruptException(); - } - { - lock_guard pair_guard(state.lock); - if (state.pair_idx == state.num_pairs) { - break; - } - GetNextPartition(); - } - MergePartition(); - } -} - -void MergeSorter::MergePartition() { - auto &left_block = *left->sb; - auto &right_block = *right->sb; -#ifdef DEBUG - D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); - if (!state.payload_layout.AllConstant() && state.external) { - D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); - D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); - } - if (!sort_layout.all_constant) { - D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); - if (state.external) { - D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == - left_block.blob_sorting_data->heap_blocks.size()); - D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == - right_block.blob_sorting_data->heap_blocks.size()); - } - } -#endif - // Set up the write block - // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less - result->InitializeWrite(); - // Initialize arrays to store merge data - bool left_smaller[STANDARD_VECTOR_SIZE]; - idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; - // Merge loop -#ifdef DEBUG - auto l_count = left->Remaining(); - auto r_count = right->Remaining(); -#endif - while (true) { - auto l_remaining = left->Remaining(); - auto r_remaining = right->Remaining(); - if (l_remaining + r_remaining == 0) { - // Done - break; - } - const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); - if (l_remaining != 0 && r_remaining != 0) { - // Compute the merge (not needed if one side is exhausted) - ComputeMerge(next, left_smaller); - } - // Actually merge the data (radix, blob, and payload) - MergeRadix(next, left_smaller); - if (!sort_layout.all_constant) { - MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, - left_smaller, next_entry_sizes, true); - D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); - } - MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, next, left_smaller, - next_entry_sizes, false); - D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); - } -#ifdef DEBUG - D_ASSERT(result->Count() == l_count + r_count); -#endif -} - -void MergeSorter::GetNextPartition() { - // Create result block - state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); - result = state.sorted_blocks_temp[state.pair_idx].back().get(); - // Determine which blocks must be merged - auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; - auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; - const idx_t l_count = left_block.Count(); - const idx_t r_count = right_block.Count(); - // Initialize left and right reader - left = make_uniq(buffer_manager, state); - right = make_uniq(buffer_manager, state); - // Compute the work that this thread must do using Merge Path - idx_t l_end; - idx_t r_end; - if (state.l_start + state.r_start + state.block_capacity < l_count + r_count) { - left->sb = state.sorted_blocks[state.pair_idx * 2].get(); - right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); - const idx_t intersection = state.l_start + state.r_start + state.block_capacity; - GetIntersection(intersection, l_end, r_end); - D_ASSERT(l_end <= l_count); - D_ASSERT(r_end <= r_count); - D_ASSERT(intersection == l_end + r_end); - } else { - l_end = l_count; - r_end = r_count; - } - // Create slices of the data that this thread must merge - left->SetIndices(0, 0); - right->SetIndices(0, 0); - left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); - right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); - left->sb = left_input.get(); - right->sb = right_input.get(); - state.l_start = l_end; - state.r_start = r_end; - D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); - // Update global state - if (state.l_start == l_count && state.r_start == r_count) { - // Delete references to previous pair - state.sorted_blocks[state.pair_idx * 2] = nullptr; - state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; - // Advance pair - state.pair_idx++; - state.l_start = 0; - state.r_start = 0; - } -} - -int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { - D_ASSERT(l_idx < l.sb->Count()); - D_ASSERT(r_idx < r.sb->Count()); - - // Easy comparison using the previous result (intersections must increase monotonically) - if (l_idx < state.l_start) { - return -1; - } - if (r_idx < state.r_start) { - return 1; - } - - l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); - r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); - - l.PinRadix(l.block_idx); - r.PinRadix(r.block_idx); - data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; - data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; - - int comp_res; - if (sort_layout.all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); - } else { - l.PinData(*l.sb->blob_sorting_data); - r.PinData(*r.sb->blob_sorting_data); - comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); - } - return comp_res; -} - -void MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { - const idx_t l_count = left->sb->Count(); - const idx_t r_count = right->sb->Count(); - // Cover some edge cases - // Code coverage off because these edge cases cannot happen unless other code changes - // Edge cases have been tested extensively while developing Merge Path in a script - // LCOV_EXCL_START - if (diagonal >= l_count + r_count) { - l_idx = l_count; - r_idx = r_count; - return; - } else if (diagonal == 0) { - l_idx = 0; - r_idx = 0; - return; - } else if (l_count == 0) { - l_idx = 0; - r_idx = diagonal; - return; - } else if (r_count == 0) { - r_idx = 0; - l_idx = diagonal; - return; - } - // LCOV_EXCL_STOP - // Determine offsets for the binary search - const idx_t l_offset = MinValue(l_count, diagonal); - const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; - D_ASSERT(l_offset + r_offset == diagonal); - const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? l_count + r_count - diagonal - : MinValue(diagonal, MinValue(l_count, r_count)); - // Double binary search - idx_t li = 0; - idx_t ri = search_space - 1; - idx_t middle; - int comp_res; - while (li <= ri) { - middle = (li + ri) / 2; - l_idx = l_offset - middle; - r_idx = r_offset + middle; - if (l_idx == l_count || r_idx == 0) { - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (comp_res > 0) { - l_idx--; - r_idx++; - } else { - return; - } - if (l_idx == 0 || r_idx == r_count) { - // This case is incredibly difficult to cover as it is dependent on parallelism randomness - // But it has been tested extensively during development in a script - // LCOV_EXCL_START - return; - // LCOV_EXCL_STOP - } else { - break; - } - } - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); - if (comp_res > 0) { - li = middle + 1; - } else { - ri = middle - 1; - } - } - int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); - int l_min1_r = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (l_r_min1 > 0 && l_min1_r < 0) { - return; - } else if (l_r_min1 > 0) { - l_idx--; - r_idx++; - } else if (l_min1_r < 0) { - l_idx++; - r_idx--; - } -} - -void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - auto &l_sorted_block = *l.sb; - auto &r_sorted_block = *r.sb; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - // Data pointers for both sides - data_ptr_t l_radix_ptr; - data_ptr_t r_radix_ptr; - // Compute the merge of the next 'count' tuples - idx_t compared = 0; - while (compared < count) { - // Move to the next block (if needed) - if (l.block_idx < l_sorted_block.radix_sorting_data.size() && - l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_sorted_block.radix_sorting_data.size() && - r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); - const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); - if (l_done || r_done) { - // One of the sides is exhausted, no need to compare - break; - } - // Pin the radix sorting data - left->PinRadix(l.block_idx); - l_radix_ptr = left->RadixPtr(); - right->PinRadix(r.block_idx); - r_radix_ptr = right->RadixPtr(); - - const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; - const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; - // Compute the merge - if (sort_layout.all_constant) { - // All sorting columns are constant size - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } else { - // Pin the blob data - left->PinData(*l_sorted_block.blob_sorting_data); - right->PinData(*r_sorted_block.blob_sorting_data); - // Merge with variable size sorting columns - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = - Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - auto &l_blocks = l.sb->radix_sorting_data; - auto &r_blocks = r.sb->radix_sorting_data; - RowDataBlock *l_block = nullptr; - RowDataBlock *r_block = nullptr; - - data_ptr_t l_ptr; - data_ptr_t r_ptr; - - RowDataBlock *result_block = result->radix_sorting_data.back().get(); - auto result_handle = buffer_manager.Pin(result_block->block); - data_ptr_t result_ptr = result_handle.Ptr() + result_block->count * sort_layout.entry_size; - - idx_t copied = 0; - while (copied < count) { - // Move to the next block (if needed) - if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_blocks[l.block_idx]->block = nullptr; - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_blocks[r.block_idx]->block = nullptr; - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_blocks.size(); - const bool r_done = r.block_idx == r_blocks.size(); - // Pin the radix sortable blocks - idx_t l_count; - if (!l_done) { - l_block = l_blocks[l.block_idx].get(); - left->PinRadix(l.block_idx); - l_ptr = l.RadixPtr(); - l_count = l_block->count; - } else { - l_count = 0; - } - idx_t r_count; - if (!r_done) { - r_block = r_blocks[r.block_idx].get(); - r.PinRadix(r.block_idx); - r_ptr = r.RadixPtr(); - r_count = r_block->count; - } else { - r_count = 0; - } - // Copy using computed merge - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, - sort_layout.entry_size, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - const auto &layout = result_data.layout; - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - - // Left and right row data to merge - data_ptr_t l_ptr; - data_ptr_t r_ptr; - // Accompanying left and right heap data (if needed) - data_ptr_t l_heap_ptr; - data_ptr_t r_heap_ptr; - - // Result rows to write to - RowDataBlock *result_data_block = result_data.data_blocks.back().get(); - auto result_data_handle = buffer_manager.Pin(result_data_block->block); - data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; - // Result heap to write to (if needed) - RowDataBlock *result_heap_block = nullptr; - BufferHandle result_heap_handle; - data_ptr_t result_heap_ptr; - if (!layout.AllConstant() && state.external) { - result_heap_block = result_data.heap_blocks.back().get(); - result_heap_handle = buffer_manager.Pin(result_heap_block->block); - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - - idx_t copied = 0; - while (copied < count) { - // Move to new data blocks (if needed) - if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_data.data_blocks[l.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - l_data.heap_blocks[l.block_idx]->block = nullptr; - } - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_data.data_blocks[r.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - r_data.heap_blocks[r.block_idx]->block = nullptr; - } - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_data.data_blocks.size(); - const bool r_done = r.block_idx == r_data.data_blocks.size(); - // Pin the row data blocks - if (!l_done) { - l.PinData(l_data); - l_ptr = l.DataPtr(l_data); - } - if (!r_done) { - r.PinData(r_data); - r_ptr = r.DataPtr(r_data); - } - const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; - const idx_t &r_count = !r_done ? r_data.data_blocks[r.block_idx]->count : 0; - // Perform the merge - if (layout.AllConstant() || !state.external) { - // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, - row_width, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); - } - } else { - // External sorting with variable size data. Pin the heap blocks too - if (!l_done) { - l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); - D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); - D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); - } - if (!r_done) { - r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); - D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); - D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); - } - // Both the row and heap data need to be dealt with - if (!l_done && !r_done) { - // Both sides have data - merge - idx_t l_idx_copy = l.entry_idx; - idx_t r_idx_copy = r.entry_idx; - data_ptr_t result_data_ptr_copy = result_data_ptr; - idx_t copied_copy = copied; - // Merge row data - MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, - result_data_ptr_copy, row_width, left_smaller, copied_copy, count); - const idx_t merged = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t l_heap_ptr_copy = l_heap_ptr; - data_ptr_t r_heap_ptr_copy = r_heap_ptr; - for (idx_t i = 0; i < merged; i++) { - // Store base heap offset in the row data - Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); - result_data_ptr += row_width; - // Compute entry size and add to total - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - auto &entry_size = next_entry_sizes[copied + i]; - entry_size = - l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(NumericCast(l_heap_ptr_copy - l.BaseHeapPtr(l_data)) + l_smaller * entry_size <= - l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(NumericCast(r_heap_ptr_copy - r.BaseHeapPtr(r_data)) + r_smaller * entry_size <= - r_data.heap_blocks[r.block_idx]->byte_offset); - l_heap_ptr_copy += l_smaller * entry_size; - r_heap_ptr_copy += r_smaller * entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { - idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; - buffer_manager.ReAllocate(result_heap_block->block, new_capacity); - result_heap_block->capacity = new_capacity; - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); - // Now copy the heap data - for (idx_t i = 0; i < merged; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - const auto &entry_size = next_entry_sizes[copied + i]; - memcpy(result_heap_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + - r_smaller * CastPointerToValue(r_heap_ptr)), - entry_size); - D_ASSERT(Load(result_heap_ptr) == entry_size); - result_heap_ptr += entry_size; - l_heap_ptr += l_smaller * entry_size; - r_heap_ptr += r_smaller * entry_size; - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - } - // Update result indices and pointers - result_heap_block->count += merged; - result_heap_block->byte_offset += copy_bytes; - copied += merged; - } else if (r_done) { - // Right side is exhausted - flush left - FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } else { - // Left side is exhausted - flush right - FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } - D_ASSERT(result_data_block->count == result_heap_block->count); - } - } - if (reset_indices) { - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); - } -} - -void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, - idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, - data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, - const idx_t &count) { - const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - idx_t i; - for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to copy an entry from either side - FastMemcpy( - target_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), - entry_size); - target_ptr += entry_size; - // Use the comparison bool to increment entries and pointers - l_entry_idx += l_smaller; - r_entry_idx += r_smaller; - l_ptr += l_smaller * entry_size; - r_ptr += r_smaller * entry_size; - } - // Update counts - target_block.count += i; - copied += i; -} - -void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count) { - // Compute how many entries we can fit - idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - next = MinValue(next, source_count - source_entry_idx); - // Copy them all in a single memcpy - const idx_t copy_bytes = next * entry_size; - memcpy(target_ptr, source_ptr, copy_bytes); - target_ptr += copy_bytes; - source_ptr += copy_bytes; - // Update counts - source_entry_idx += next; - target_block.count += next; - copied += next; -} - -void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, - BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, - const idx_t &count) { - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - idx_t source_entry_idx_copy = source_entry_idx; - data_ptr_t target_data_ptr_copy = target_data_ptr; - idx_t copied_copy = copied; - // Flush row data - FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, - copied_copy, count); - const idx_t flushed = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t source_heap_ptr_copy = source_heap_ptr; - for (idx_t i = 0; i < flushed; i++) { - // Store base heap offset in the row data - Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); - target_data_ptr += row_width; - // Compute entry size and add to total - auto entry_size = Load(source_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - source_heap_ptr_copy += entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { - idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; - buffer_manager.ReAllocate(target_heap_block.block, new_capacity); - target_heap_block.capacity = new_capacity; - target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; - } - D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); - // Copy the heap data in one go - memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); - target_heap_ptr += copy_bytes; - source_heap_ptr += copy_bytes; - source_entry_idx += flushed; - copied += flushed; - // Update result indices and pointers - target_heap_block.count += flushed; - target_heap_block.byte_offset += copy_bytes; - D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); -} - -} // namespace duckdb diff --git a/src/common/sort/radix_sort.cpp b/src/common/sort/radix_sort.cpp deleted file mode 100644 index b193cee619df..000000000000 --- a/src/common/sort/radix_sort.cpp +++ /dev/null @@ -1,352 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/duckdb_pdqsort.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -//! Calls std::sort on strings that are tied by their prefix after the radix sort -static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, - const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { - const auto row_width = sort_layout.blob_layout.GetRowWidth(); - // Locate the first blob row in question - data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; - data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; - if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return; - } - // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array_uninitialized(end - start); - auto entry_ptrs = (data_ptr_t *)ptr_block.get(); - for (idx_t i = start; i < end; i++) { - entry_ptrs[i - start] = row_ptr; - row_ptr += sort_layout.entry_size; - } - // Slow pointer-based sorting - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1; - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; - std::sort(entry_ptrs, entry_ptrs + end - start, - [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, - const data_ptr_t r) { - idx_t left_idx = Load(l + sort_layout.comparison_size); - idx_t right_idx = Load(r + sort_layout.comparison_size); - data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; - data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; - return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; - }); - // Re-order - auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); - data_ptr_t temp_ptr = temp_block.get(); - for (idx_t i = 0; i < end - start; i++) { - FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); - temp_ptr += sort_layout.entry_size; - } - memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); - // Determine if there are still ties (if this is not the last column) - if (tie_col < sort_layout.column_count - 1) { - data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; - // Load current entry - data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - for (idx_t i = 0; i < end - start - 1; i++) { - // Load next entry and compare - idx_ptr += sort_layout.entry_size; - data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; - current_ptr = next_ptr; - } - } -} - -//! Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them -static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, - const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); - auto blob_handle = buffer_manager.Pin(blob_block.block); - const data_ptr_t blob_ptr = blob_handle.Ptr(); - - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); - i = j; - } -} - -//! Returns whether there are any 'true' values in the ties[] array -static bool AnyTies(bool ties[], const idx_t &count) { - D_ASSERT(!ties[count - 1]); - bool any_ties = false; - for (idx_t i = 0; i < count - 1; i++) { - any_ties = any_ties || ties[i]; - } - return any_ties; -} - -//! Compares subsequent rows to check for ties -static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, - bool ties[], const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); - // Align dataptr - dataptr += col_offset; - for (idx_t i = 0; i < count - 1; i++) { - ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; - dataptr += sort_layout.entry_size; - } -} - -//! Textbook LSD radix sort -void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &sorting_size) { - auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); - bool swap = false; - - idx_t counts[SortConstants::VALUES_PER_RADIX]; - for (idx_t r = 1; r <= sorting_size; r++) { - // Init counts to 0 - memset(counts, 0, sizeof(counts)); - // Const some values for convenience - const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; - const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); - const idx_t offset = col_offset + sorting_size - r; - // Collect counts - data_ptr_t offset_ptr = source_ptr + offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute offsets from counts - idx_t max_count = counts[0]; - for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { - max_count = MaxValue(max_count, counts[val]); - counts[val] = counts[val] + counts[val - 1]; - } - if (max_count == count) { - continue; - } - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; - for (idx_t i = 0; i < count; i++) { - idx_t &radix_offset = --counts[*(row_ptr + offset)]; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr -= row_width; - } - swap = !swap; - } - // Move data back to original buffer (if it was swapped) - if (swap) { - memcpy(dataptr, temp_block.get(), count * row_width); - } -} - -//! Insertion sort, used when count of values is low -inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, - const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, - const idx_t &offset, bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; - if (count > 1) { - const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array_uninitialized(row_width); - const data_ptr_t val = temp_val.get(); - const auto comp_width = total_comp_width - offset; - for (idx_t i = 1; i < count; i++) { - FastMemcpy(val, source_ptr + i * row_width, row_width); - idx_t j = i; - while (j > 0 && - FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { - FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); - j--; - } - FastMemcpy(source_ptr + j * row_width, val, row_width); - } - } - if (swap) { - memcpy(target_ptr, source_ptr, count * row_width); - } -} - -//! MSD radix sort that switches to insertion sort with low bucket sizes -void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; - // Init counts to 0 - memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); - idx_t *counts = locations + 1; - // Collect counts - const idx_t total_offset = col_offset + offset; - data_ptr_t offset_ptr = source_ptr + total_offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute locations from counts - idx_t max_count = 0; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - max_count = MaxValue(max_count, counts[radix]); - counts[radix] += locations[radix]; - } - if (max_count != count) { - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr; - for (idx_t i = 0; i < count; i++) { - const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr += row_width; - } - swap = !swap; - } - // Check if done - if (offset == comp_width - 1) { - if (swap) { - memcpy(orig_ptr, temp_ptr, count * row_width); - } - return; - } - if (max_count == count) { - RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - return; - } - // Recurse - idx_t radix_count = locations[0]; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - const idx_t loc = (locations[radix] - radix_count) * row_width; - if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { - RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - } else if (radix_count != 0) { - InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - swap); - } - radix_count = locations[radix + 1] - locations[radix]; - } -} - -//! Calls different sort functions, depending on the count and sorting sizes -void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { - - if (contains_string) { - auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); - auto end = begin + count; - duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - return duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } - - if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - return InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); - } - - if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - return RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } - - const auto block_size = buffer_manager.GetBlockSize(); - auto temp_block = - buffer_manager.Allocate(MemoryTag::ORDER_BY, MaxValue(count * sort_layout.entry_size, block_size)); - auto pre_allocated_array = - make_unsafe_uniq_array_uninitialized(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - pre_allocated_array.get(), false); -} - -//! Identifies sequences of rows that are tied, and calls radix sort on these -static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, - const idx_t &col_offset, const idx_t &sorting_size, bool ties[], - const SortLayout &sort_layout, bool contains_string) { - D_ASSERT(!ties[count - 1]); - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, - sort_layout, contains_string); - i = j; - } -} - -void LocalSortState::SortInMemory() { - auto &sb = *sorted_blocks.back(); - auto &block = *sb.radix_sorting_data.back(); - const auto &count = block.count; - auto handle = buffer_manager->Pin(block.block); - const auto dataptr = handle.Ptr(); - // Assign an index to each row - data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; - for (uint32_t i = 0; i < count; i++) { - Store(i, idx_dataptr); - idx_dataptr += sort_layout->entry_size; - } - // Radix sort and break ties until no more ties, or until all columns are sorted - idx_t sorting_size = 0; - idx_t col_offset = 0; - unsafe_unique_array ties_ptr; - bool *ties = nullptr; - bool contains_string = false; - for (idx_t i = 0; i < sort_layout->column_count; i++) { - sorting_size += sort_layout->column_sizes[i]; - contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; - if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { - // Add columns to the sorting size until we reach a variable size column, or the last column - continue; - } - - if (!ties) { - // This is the first sort - RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array_uninitialized(count); - ties = ties_ptr.get(); - std::fill_n(ties, count - 1, true); - ties[count - 1] = false; - } else { - // For subsequent sorts, we only have to subsort the tied tuples - SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, - contains_string); - } - - contains_string = false; - - if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { - // All columns are sorted, no ties to break because last column is constant size - break; - } - - ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); - if (!AnyTies(ties, count)) { - // No ties, stop sorting - break; - } - - if (!sort_layout->constant_size[i]) { - SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); - if (!AnyTies(ties, count)) { - // No more ties after tie-breaking, stop - break; - } - } - - col_offset += sorting_size; - sorting_size = 0; - } -} - -} // namespace duckdb diff --git a/src/common/sorting/sort.cpp b/src/common/sort/sort.cpp similarity index 96% rename from src/common/sorting/sort.cpp rename to src/common/sort/sort.cpp index 56bde8499a28..8f2a1e6e7b5c 100644 --- a/src/common/sorting/sort.cpp +++ b/src/common/sort/sort.cpp @@ -377,6 +377,15 @@ class SortGlobalSourceState : public GlobalSourceState { return merger_global_state ? merger_global_state->MaxThreads() : 1; } + void Destroy() { + if (!merger_global_state) { + return; + } + auto guard = merger_global_state->Lock(); + merger.sorted_runs.clear(); + sink.temporary_memory_state.reset(); + } + public: //! The global sink state SortGlobalSinkState &sink; @@ -476,16 +485,26 @@ SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, Operator } // Merge into global output collection - auto guard = gstate.Lock(); - if (!gstate.column_data) { - gstate.column_data = std::move(local_column_data); - } else { - gstate.column_data->Merge(*local_column_data); + { + auto guard = gstate.Lock(); + if (!gstate.column_data) { + gstate.column_data = std::move(local_column_data); + } else { + gstate.column_data->Merge(*local_column_data); + } } + // Destroy local state before returning + input.local_state.Cast().merger_local_state.reset(); + // Return type indicates whether materialization is done const auto progress_data = GetProgress(context.client, input.global_state); - return progress_data.done == progress_data.total ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + if (progress_data.done == progress_data.total) { + // Destroy global state before returning + gstate.Destroy(); + return SourceResultType::FINISHED; + } + return SourceResultType::HAVE_MORE_OUTPUT; } unique_ptr Sort::GetColumnData(OperatorSourceInput &input) const { diff --git a/src/common/sort/sort_state.cpp b/src/common/sort/sort_state.cpp deleted file mode 100644 index 369f032f197c..000000000000 --- a/src/common/sort/sort_state.cpp +++ /dev/null @@ -1,487 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" - -#include -#include - -namespace duckdb { - -idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { - auto physical_type = type.InternalType(); - if (TypeIsConstantSize(physical_type)) { - col_size += GetTypeIdSize(physical_type); - return 0; - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: { - // Nested strings are between 4 and 11 chars long for alignment - auto size_before_str = col_size; - col_size += 11; - col_size -= (col_size - 12) % 8; - return col_size - size_before_str; - } - case PhysicalType::LIST: - // Lists get 2 bytes (null and empty list) - col_size += 2; - return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); - case PhysicalType::STRUCT: - // Structs get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); - case PhysicalType::ARRAY: - // Arrays get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, ArrayType::GetChildType(type)); - default: - throw NotImplementedException("Unable to order column with type %s", type.ToString()); - } - } -} - -SortLayout::SortLayout(const vector &orders) - : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { - vector blob_layout_types; - for (idx_t i = 0; i < column_count; i++) { - const auto &order = orders[i]; - - order_types.push_back(order.type); - order_by_null_types.push_back(order.null_order); - auto &expr = *order.expression; - logical_types.push_back(expr.return_type); - - auto physical_type = expr.return_type.InternalType(); - constant_size.push_back(TypeIsConstantSize(physical_type)); - - if (order.stats) { - stats.push_back(order.stats.get()); - has_null.push_back(stats.back()->CanHaveNull()); - } else { - stats.push_back(nullptr); - has_null.push_back(true); - } - - idx_t col_size = has_null.back() ? 1 : 0; - prefix_lengths.push_back(0); - if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { - prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); - } else if (physical_type == PhysicalType::VARCHAR) { - idx_t size_before = col_size; - if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { - col_size += StringStats::MaxStringLength(*stats.back()); - if (col_size > 12) { - col_size = 12; - } else { - constant_size.back() = true; - } - } else { - col_size = 12; - } - prefix_lengths.back() = col_size - size_before; - } else { - col_size += GetTypeIdSize(physical_type); - } - - comparison_size += col_size; - column_sizes.push_back(col_size); - } - entry_size = comparison_size + sizeof(uint32_t); - - // 8-byte alignment - if (entry_size % 8 != 0) { - // First assign more bytes to strings instead of aligning - idx_t bytes_to_fill = 8 - (entry_size % 8); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - if (bytes_to_fill == 0) { - break; - } - if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && - StringStats::HasMaxStringLength(*stats[col_idx])) { - idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; - if (diff > 0) { - // Increase all sizes accordingly - idx_t increase = MinValue(bytes_to_fill, diff); - column_sizes[col_idx] += increase; - prefix_lengths[col_idx] += increase; - constant_size[col_idx] = increase == diff; - comparison_size += increase; - entry_size += increase; - bytes_to_fill -= increase; - } - } - } - entry_size = AlignValue(entry_size); - } - - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - all_constant = all_constant && constant_size[col_idx]; - if (!constant_size[col_idx]) { - sorting_to_blob_col[col_idx] = blob_layout_types.size(); - blob_layout_types.push_back(logical_types[col_idx]); - } - } - - blob_layout.Initialize(blob_layout_types); -} - -SortLayout SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { - SortLayout result; - result.column_count = num_prefix_cols; - result.all_constant = true; - result.comparison_size = 0; - for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { - result.order_types.push_back(order_types[col_idx]); - result.order_by_null_types.push_back(order_by_null_types[col_idx]); - result.logical_types.push_back(logical_types[col_idx]); - - result.all_constant = result.all_constant && constant_size[col_idx]; - result.constant_size.push_back(constant_size[col_idx]); - - result.comparison_size += column_sizes[col_idx]; - result.column_sizes.push_back(column_sizes[col_idx]); - - result.prefix_lengths.push_back(prefix_lengths[col_idx]); - result.stats.push_back(stats[col_idx]); - result.has_null.push_back(has_null[col_idx]); - } - result.entry_size = entry_size; - result.blob_layout = blob_layout; - result.sorting_to_blob_col = sorting_to_blob_col; - return result; -} - -LocalSortState::LocalSortState() : initialized(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("Sorting is not supported on big endian architectures"); - } -} - -void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { - sort_layout = &global_sort_state.sort_layout; - payload_layout = &global_sort_state.payload_layout; - buffer_manager = &buffer_manager_p; - const auto block_size = buffer_manager->GetBlockSize(); - - // Radix sorting data - auto entries_per_block = RowDataCollection::EntriesPerBlock(sort_layout->entry_size, block_size); - radix_sorting_data = make_uniq(*buffer_manager, entries_per_block, sort_layout->entry_size); - - // Blob sorting data - if (!sort_layout->all_constant) { - auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(blob_row_width, block_size); - blob_sorting_data = make_uniq(*buffer_manager, entries_per_block, blob_row_width); - blob_sorting_heap = make_uniq(*buffer_manager, block_size, 1U, true); - } - - // Payload data - auto payload_row_width = payload_layout->GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(payload_row_width, block_size); - payload_data = make_uniq(*buffer_manager, entries_per_block, payload_row_width); - payload_heap = make_uniq(*buffer_manager, block_size, 1U, true); - initialized = true; -} - -void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { - D_ASSERT(sort.size() == payload.size()); - // Build and serialize sorting data to radix sortable rows - auto data_pointers = FlatVector::GetData(addresses); - auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - bool has_null = sort_layout->has_null[sort_col]; - bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; - bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; - RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, - has_null, nulls_first, sort_layout->prefix_lengths[sort_col], - sort_layout->column_sizes[sort_col]); - } - - // Also fully serialize blob sorting columns (to be able to break ties - if (!sort_layout->all_constant) { - DataChunk blob_chunk; - blob_chunk.SetCardinality(sort.size()); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - if (!sort_layout->constant_size[sort_col]) { - blob_chunk.data.emplace_back(sort.data[sort_col]); - } - } - handles = blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); - auto blob_data = blob_chunk.ToUnifiedFormat(); - RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, - sel_ptr, blob_chunk.size()); - D_ASSERT(blob_sorting_heap->keep_pinned); - } - - // Finally, serialize payload data - handles = payload_data->Build(payload.size(), data_pointers, nullptr); - auto input_data = payload.ToUnifiedFormat(); - RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, - payload.size()); - D_ASSERT(payload_heap->keep_pinned); -} - -idx_t LocalSortState::SizeInBytes() const { - idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); - if (!sort_layout->all_constant) { - size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); - } - if (!payload_layout->AllConstant()) { - size_in_bytes += payload_heap->SizeInBytes(); - } - return size_in_bytes; -} - -void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { - D_ASSERT(radix_sorting_data->count == payload_data->count); - if (radix_sorting_data->count == 0) { - return; - } - // Move all data to a single SortedBlock - sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); - auto &sb = *sorted_blocks.back(); - // Fixed-size sorting data - auto sorting_block = ConcatenateBlocks(*radix_sorting_data); - sb.radix_sorting_data.push_back(std::move(sorting_block)); - // Variable-size sorting data - if (!sort_layout->all_constant) { - auto &blob_data = *blob_sorting_data; - auto new_block = ConcatenateBlocks(blob_data); - sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); - } - // Payload data - auto payload_block = ConcatenateBlocks(*payload_data); - sb.payload_data->data_blocks.push_back(std::move(payload_block)); - // Now perform the actual sort - SortInMemory(); - // Re-order before the merge sort - ReOrder(global_sort_state, reorder_heap); -} - -unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { - // Don't copy and delete if there is only one block. - if (row_data.blocks.size() == 1) { - auto new_block = std::move(row_data.blocks[0]); - row_data.blocks.clear(); - row_data.count = 0; - return new_block; - } - // Create block with the correct capacity - auto &buffer_manager = row_data.buffer_manager; - const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue((buffer_manager.GetBlockSize() + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, entry_size); - new_block->count = row_data.count; - auto new_block_handle = buffer_manager.Pin(new_block->block); - data_ptr_t new_block_ptr = new_block_handle.Ptr(); - // Copy the data of the blocks into a single block - for (idx_t i = 0; i < row_data.blocks.size(); i++) { - auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager.Pin(block->block); - memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); - new_block_ptr += block->count * entry_size; - block.reset(); - } - row_data.blocks.clear(); - row_data.count = 0; - return new_block; -} - -void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap) { - sd.swizzled = reorder_heap; - auto &unordered_data_block = sd.data_blocks.back(); - const idx_t count = unordered_data_block->count; - auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); - const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); - // Create new block that will hold re-ordered row data - auto ordered_data_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, - unordered_data_block->capacity, unordered_data_block->entry_size); - ordered_data_block->count = count; - auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); - data_ptr_t ordered_data_ptr = ordered_data_handle.Ptr(); - // Re-order fixed-size row layout - const idx_t row_width = sd.layout.GetRowWidth(); - const idx_t sorting_entry_size = gstate.sort_layout.entry_size; - for (idx_t i = 0; i < count; i++) { - auto index = Load(sorting_ptr); - FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); - ordered_data_ptr += row_width; - sorting_ptr += sorting_entry_size; - } - ordered_data_block->block->SetSwizzling( - sd.layout.AllConstant() || !sd.swizzled ? nullptr : "LocalSortState::ReOrder.ordered_data"); - // Replace the unordered data block with the re-ordered data block - sd.data_blocks.clear(); - sd.data_blocks.push_back(std::move(ordered_data_block)); - // Deal with the heap (if necessary) - if (!sd.layout.AllConstant() && reorder_heap) { - // Swizzle the column pointers to offsets - RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); - sd.data_blocks.back()->block->SetSwizzling(nullptr); - // Create a single heap block to store the ordered heap - idx_t total_byte_offset = - std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, buffer_manager->GetBlockSize()); - auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); - ordered_heap_block->count = count; - ordered_heap_block->byte_offset = total_byte_offset; - auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); - data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); - // Fill the heap in order - ordered_data_ptr = ordered_data_handle.Ptr(); - const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); - auto heap_row_size = Load(heap_row_ptr); - memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); - ordered_heap_ptr += heap_row_size; - ordered_data_ptr += row_width; - } - // Swizzle the base pointer to the offset of each row in the heap - RowOperations::SwizzleHeapPointer(sd.layout, ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); - // Move the re-ordered heap to the SortedData, and clear the local heap - sd.heap_blocks.push_back(std::move(ordered_heap_block)); - heap.pinned_blocks.clear(); - heap.blocks.clear(); - heap.count = 0; - } -} - -void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { - auto &sb = *sorted_blocks.back(); - auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); - const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; - // Re-order variable size sorting columns - if (!gstate.sort_layout.all_constant) { - ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); - } - // And the payload - ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); -} - -GlobalSortState::GlobalSortState(ClientContext &context_p, const vector &orders, - RowLayout &payload_layout) - : context(context_p), buffer_manager(BufferManager::GetBufferManager(context)), sort_layout(SortLayout(orders)), - payload_layout(payload_layout), block_capacity(0), external(false) { -} - -void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { - if (!local_sort_state.radix_sorting_data) { - return; - } - - // Sort accumulated data - // we only re-order the heap when the data is expected to not fit in memory - // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data - // when data fits in memory, doing random access on reads is cheaper than re-shuffling - local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); - - // Append local state sorted data to this global state - lock_guard append_guard(lock); - for (auto &sb : local_sort_state.sorted_blocks) { - sorted_blocks.push_back(std::move(sb)); - } - auto &payload_heap = local_sort_state.payload_heap; - for (idx_t i = 0; i < payload_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(payload_heap->blocks[i])); - pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); - } - if (!sort_layout.all_constant) { - auto &blob_heap = local_sort_state.blob_sorting_heap; - for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(blob_heap->blocks[i])); - pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); - } - } -} - -void GlobalSortState::PrepareMergePhase() { - // Determine if we need to use do an external sort - idx_t total_heap_size = - std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size * 4 > buffer_manager.GetQueryMaxMemory())) { - external = true; - } - // Use the data that we have to determine which partition size to use during the merge - if (external && total_heap_size > 0) { - // If we have variable size data we need to be conservative, as there might be skew - idx_t max_block_size = 0; - for (auto &sb : sorted_blocks) { - idx_t size_in_bytes = sb->SizeInBytes(); - if (size_in_bytes > max_block_size) { - max_block_size = size_in_bytes; - block_capacity = sb->Count(); - } - } - } else { - for (auto &sb : sorted_blocks) { - block_capacity = MaxValue(block_capacity, sb->Count()); - } - } - // Unswizzle and pin heap blocks if we can fit everything in memory - if (!external) { - for (auto &sb : sorted_blocks) { - sb->blob_sorting_data->Unswizzle(); - sb->payload_data->Unswizzle(); - } - } -} - -void GlobalSortState::InitializeMergeRound() { - D_ASSERT(sorted_blocks_temp.empty()); - // If we reverse this list, the blocks that were merged last will be merged first in the next round - // These are still in memory, therefore this reduces the amount of read/write to disk! - std::reverse(sorted_blocks.begin(), sorted_blocks.end()); - // Uneven number of blocks - keep one on the side - if (sorted_blocks.size() % 2 == 1) { - odd_one_out = std::move(sorted_blocks.back()); - sorted_blocks.pop_back(); - } - // Init merge path path indices - pair_idx = 0; - num_pairs = sorted_blocks.size() / 2; - l_start = 0; - r_start = 0; - // Allocate room for merge results - for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { - sorted_blocks_temp.emplace_back(); - } -} - -void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { - sorted_blocks.clear(); - for (auto &sorted_block_vector : sorted_blocks_temp) { - sorted_blocks.push_back(make_uniq(buffer_manager, *this)); - sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); - } - sorted_blocks_temp.clear(); - if (odd_one_out) { - sorted_blocks.push_back(std::move(odd_one_out)); - odd_one_out = nullptr; - } - // Only one block left: Done! - if (sorted_blocks.size() == 1 && !keep_radix_data) { - sorted_blocks[0]->radix_sorting_data.clear(); - sorted_blocks[0]->blob_sorting_data = nullptr; - } -} -void GlobalSortState::Print() { - PayloadScanner scanner(*this, false); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - for (;;) { - scanner.Scan(chunk); - const auto count = chunk.size(); - if (!count) { - break; - } - chunk.Print(); - } -} - -} // namespace duckdb diff --git a/src/common/sort/sorted_block.cpp b/src/common/sort/sorted_block.cpp deleted file mode 100644 index c4766c956b40..000000000000 --- a/src/common/sort/sorted_block.cpp +++ /dev/null @@ -1,387 +0,0 @@ -#include "duckdb/common/sort/sorted_block.hpp" - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" - -#include - -namespace duckdb { - -SortedData::SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, - GlobalSortState &state) - : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { -} - -idx_t SortedData::Count() { - idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!layout.AllConstant() && state.external) { - D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; })); - } - return count; -} - -void SortedData::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); - data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); - if (!layout.AllConstant() && state.external) { - heap_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_size, 1U)); - D_ASSERT(data_blocks.size() == heap_blocks.size()); - } -} - -unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { - // Add the corresponding blocks to the result - auto result = make_uniq(type, layout, buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->data_blocks.push_back(data_blocks[i]->Copy()); - if (!layout.AllConstant() && state.external) { - result->heap_blocks.push_back(heap_blocks[i]->Copy()); - } - } - // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) - for (idx_t i = 0; i < start_block_index; i++) { - data_blocks[i]->block = nullptr; - if (!layout.AllConstant() && state.external) { - heap_blocks[i]->block = nullptr; - } - } - // Use start and end entry indices to set the boundaries - D_ASSERT(end_entry_index <= result->data_blocks.back()->count); - result->data_blocks.back()->count = end_entry_index; - if (!layout.AllConstant() && state.external) { - result->heap_blocks.back()->count = end_entry_index; - } - return result; -} - -void SortedData::Unswizzle() { - if (layout.AllConstant() || !swizzled) { - return; - } - for (idx_t i = 0; i < data_blocks.size(); i++) { - auto &data_block = data_blocks[i]; - auto &heap_block = heap_blocks[i]; - D_ASSERT(data_block->block->IsSwizzled()); - auto data_handle_p = buffer_manager.Pin(data_block->block); - auto heap_handle_p = buffer_manager.Pin(heap_block->block); - RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); - state.heap_blocks.push_back(std::move(heap_block)); - state.pinned_blocks.push_back(std::move(heap_handle_p)); - } - swizzled = false; - heap_blocks.clear(); -} - -SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), - payload_layout(state.payload_layout) { - blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); - payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); -} - -idx_t SortedBlock::Count() const { - idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!sort_layout.all_constant) { - D_ASSERT(count == blob_sorting_data->Count()); - } - D_ASSERT(count == payload_data->Count()); - return count; -} - -void SortedBlock::InitializeWrite() { - CreateBlock(); - if (!sort_layout.all_constant) { - blob_sorting_data->CreateBlock(); - } - payload_data->CreateBlock(); -} - -void SortedBlock::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + sort_layout.entry_size - 1) / sort_layout.entry_size, state.block_capacity); - radix_sorting_data.push_back( - make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, sort_layout.entry_size)); -} - -void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { - D_ASSERT(Count() == 0); - for (auto &sb : sorted_blocks) { - for (auto &radix_block : sb->radix_sorting_data) { - radix_sorting_data.push_back(std::move(radix_block)); - } - if (!sort_layout.all_constant) { - for (auto &blob_block : sb->blob_sorting_data->data_blocks) { - blob_sorting_data->data_blocks.push_back(std::move(blob_block)); - } - for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { - blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); - } - } - for (auto &payload_data_block : sb->payload_data->data_blocks) { - payload_data->data_blocks.push_back(std::move(payload_data_block)); - } - if (!payload_data->layout.AllConstant()) { - for (auto &payload_heap_block : sb->payload_data->heap_blocks) { - payload_data->heap_blocks.push_back(std::move(payload_heap_block)); - } - } - } -} - -void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { - if (global_idx == Count()) { - local_block_index = radix_sorting_data.size() - 1; - local_entry_index = radix_sorting_data.back()->count; - return; - } - D_ASSERT(global_idx < Count()); - local_entry_index = global_idx; - for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { - const idx_t &block_count = radix_sorting_data[local_block_index]->count; - if (local_entry_index >= block_count) { - local_entry_index -= block_count; - } else { - break; - } - } - D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); -} - -unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { - // Identify blocks/entry indices of this slice - idx_t start_block_index; - idx_t start_entry_index; - GlobalToLocalIndex(start, start_block_index, start_entry_index); - idx_t end_block_index; - idx_t end_entry_index; - GlobalToLocalIndex(end, end_block_index, end_entry_index); - // Add the corresponding blocks to the result - auto result = make_uniq(buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); - } - // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) - for (idx_t i = 0; i < start_block_index; i++) { - radix_sorting_data[i]->block = nullptr; - } - // Use start and end entry indices to set the boundaries - entry_idx = start_entry_index; - D_ASSERT(end_entry_index <= result->radix_sorting_data.back()->count); - result->radix_sorting_data.back()->count = end_entry_index; - // Same for the var size sorting data - if (!sort_layout.all_constant) { - result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - } - // And the payload data - result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - return result; -} - -idx_t SortedBlock::HeapSize() const { - idx_t result = 0; - if (!sort_layout.all_constant) { - for (auto &block : blob_sorting_data->heap_blocks) { - result += block->capacity; - } - } - if (!payload_layout.AllConstant()) { - for (auto &block : payload_data->heap_blocks) { - result += block->capacity; - } - } - return result; -} - -idx_t SortedBlock::SizeInBytes() const { - idx_t bytes = 0; - for (idx_t i = 0; i < radix_sorting_data.size(); i++) { - bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; - if (!sort_layout.all_constant) { - bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); - bytes += blob_sorting_data->heap_blocks[i]->capacity; - } - bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); - if (!payload_layout.AllConstant()) { - bytes += payload_data->heap_blocks[i]->capacity; - } - } - return bytes; -} - -SBScanState::SBScanState(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { -} - -void SBScanState::PinRadix(idx_t block_idx_to) { - auto &radix_sorting_data = sb->radix_sorting_data; - D_ASSERT(block_idx_to < radix_sorting_data.size()); - auto &block = radix_sorting_data[block_idx_to]; - if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { - radix_handle = buffer_manager.Pin(block->block); - } -} - -void SBScanState::PinData(SortedData &sd) { - D_ASSERT(block_idx < sd.data_blocks.size()); - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - - auto &data_block = sd.data_blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = buffer_manager.Pin(data_block->block); - } - if (sd.layout.AllConstant() || !state.external) { - return; - } - auto &heap_block = sd.heap_blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = buffer_manager.Pin(heap_block->block); - } -} - -data_ptr_t SBScanState::RadixPtr() const { - return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; -} - -data_ptr_t SBScanState::DataPtr(SortedData &sd) const { - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && - data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); - return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); -} - -data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { - return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); -} - -data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - D_ASSERT(!sd.layout.AllConstant() && state.external); - D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && - heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); - return heap_handle.Ptr(); -} - -idx_t SBScanState::Remaining() const { - const auto &blocks = sb->radix_sorting_data; - idx_t remaining = 0; - if (block_idx < blocks.size()) { - remaining += blocks[block_idx]->count - entry_idx; - for (idx_t i = block_idx + 1; i < blocks.size(); i++) { - remaining += blocks[i]->count; - } - } - return remaining; -} - -void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { - block_idx = block_idx_to; - entry_idx = entry_idx_to; -} - -PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { - auto count = sorted_data.Count(); - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant()) { - heap->count = count; - } - - if (flush_p) { - // If we are flushing, we can just move the data - rows->blocks = std::move(sorted_data.data_blocks); - if (!layout.AllConstant()) { - heap->blocks = std::move(sorted_data.heap_blocks); - } - } else { - // Not flushing, create references to the blocks - for (auto &block : sorted_data.data_blocks) { - rows->blocks.emplace_back(block->Copy()); - } - if (!layout.AllConstant()) { - for (auto &block : sorted_data.heap_blocks) { - heap->blocks.emplace_back(block->Copy()); - } - } - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) - : PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { - auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; - auto count = sorted_data.data_blocks[block_idx]->count; - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (flush_p) { - rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); - } else { - rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); - } - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { - if (flush_p) { - heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); - } else { - heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); - } - heap->count = count; - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -void PayloadScanner::Scan(DataChunk &chunk) { - scanner->Scan(chunk); -} - -int SBIterator::ComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for IEJoin!"); - } -} - -static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { - D_ASSERT(!gss.sorted_blocks.empty()); - return gss.sorted_blocks[0]->radix_sorting_data.size(); -} - -SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) - : sort_layout(gss.sort_layout), block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), - entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), external(gss.external), - cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), entry_ptr(nullptr) { - - scan.sb = gss.sorted_blocks[0].get(); - scan.block_idx = block_count; - SetIndex(entry_idx_p); -} - -} // namespace duckdb diff --git a/src/common/sorting/sorted_run.cpp b/src/common/sort/sorted_run.cpp similarity index 100% rename from src/common/sorting/sorted_run.cpp rename to src/common/sort/sorted_run.cpp diff --git a/src/common/sorting/sorted_run_merger.cpp b/src/common/sort/sorted_run_merger.cpp similarity index 99% rename from src/common/sorting/sorted_run_merger.cpp rename to src/common/sort/sorted_run_merger.cpp index d87cef470b8f..874a7fc0415d 100644 --- a/src/common/sorting/sorted_run_merger.cpp +++ b/src/common/sort/sorted_run_merger.cpp @@ -844,6 +844,7 @@ SortedRunMerger::SortedRunMerger(const Sort &sort_p, vector SortedRunMerger::GetLocalSourceState(ExecutionContext &, GlobalSourceState &gstate_p) const { auto &gstate = gstate_p.Cast(); + auto guard = gstate.Lock(); return make_uniq(gstate); } diff --git a/src/common/sorting/CMakeLists.txt b/src/common/sorting/CMakeLists.txt deleted file mode 100644 index 09366bb04bf0..000000000000 --- a/src/common/sorting/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -add_library_unity(duckdb_sorting OBJECT hashed_sort.cpp sort.cpp sorted_run.cpp - sorted_run_merger.cpp) - -set(ALL_OBJECT_FILES - ${ALL_OBJECT_FILES} $ - PARENT_SCOPE) diff --git a/src/common/types.cpp b/src/common/types.cpp index 15fd9364e08d..6542fafc2772 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -31,6 +31,9 @@ namespace duckdb { +constexpr idx_t ArrayType::MAX_ARRAY_SIZE; +const idx_t UnionType::MAX_UNION_MEMBERS; + LogicalType::LogicalType() : LogicalType(LogicalTypeId::INVALID) { } diff --git a/src/common/types/decimal.cpp b/src/common/types/decimal.cpp index 5ecb39a0a818..8fa2264551c8 100644 --- a/src/common/types/decimal.cpp +++ b/src/common/types/decimal.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/types/cast_helpers.hpp" namespace duckdb { +constexpr uint8_t Decimal::MAX_WIDTH_DECIMAL; template string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { diff --git a/src/common/types/geometry.cpp b/src/common/types/geometry.cpp index e05816546fef..5d66bed291f9 100644 --- a/src/common/types/geometry.cpp +++ b/src/common/types/geometry.cpp @@ -748,6 +748,8 @@ void ToStringRecursive(BlobReader &reader, TextWriter &writer, idx_t depth, bool //---------------------------------------------------------------------------------------------------------------------- namespace duckdb { +constexpr const idx_t Geometry::MAX_RECURSION_DEPTH; + bool Geometry::FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict) { TextReader reader(wkt_text.GetData(), static_cast(wkt_text.GetSize())); BlobWriter writer; diff --git a/src/common/types/row/CMakeLists.txt b/src/common/types/row/CMakeLists.txt index 384de7fd9bcc..d10285942c9f 100644 --- a/src/common/types/row/CMakeLists.txt +++ b/src/common/types/row/CMakeLists.txt @@ -8,9 +8,6 @@ add_library_unity( OBJECT block_iterator.cpp partitioned_tuple_data.cpp - row_data_collection.cpp - row_data_collection_scanner.cpp - row_layout.cpp tuple_data_allocator.cpp tuple_data_collection.cpp tuple_data_iterator.cpp diff --git a/src/common/types/row/row_data_collection.cpp b/src/common/types/row/row_data_collection.cpp deleted file mode 100644 index b178b7fb50bd..000000000000 --- a/src/common/types/row/row_data_collection.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection.hpp" - -namespace duckdb { - -RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, - bool keep_pinned) - : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), - keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > buffer_manager.GetBlockSize()); -} - -idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, - vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { - idx_t append_count = 0; - data_ptr_t dataptr; - if (entry_sizes) { - D_ASSERT(entry_size == 1); - // compute how many entries fit if entry size is variable - dataptr = handle.Ptr() + block.byte_offset; - for (idx_t i = 0; i < remaining; i++) { - if (block.byte_offset + entry_sizes[i] > block.capacity) { - if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { - // special case: single entry is bigger than block capacity - // resize current block to fit the entry, append it, and move to the next block - block.capacity = entry_sizes[i]; - buffer_manager.ReAllocate(block.block, block.capacity); - dataptr = handle.Ptr(); - append_count++; - block.byte_offset += entry_sizes[i]; - } - break; - } - append_count++; - block.byte_offset += entry_sizes[i]; - } - } else { - append_count = MinValue(remaining, block.capacity - block.count); - dataptr = handle.Ptr() + block.count * entry_size; - } - append_entries.emplace_back(dataptr, append_count); - block.count += append_count; - return append_count; -} - -RowDataBlock &RowDataCollection::CreateBlock() { - blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_capacity, entry_size)); - return *blocks.back(); -} - -vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], - const SelectionVector *sel) { - vector handles; - vector append_entries; - - // first allocate space of where to serialize the keys and payload columns - idx_t remaining = added_count; - { - // first append to the last block (if any) - lock_guard append_lock(rdc_lock); - count += added_count; - - if (!blocks.empty()) { - auto &last_block = *blocks.back(); - if (last_block.count < last_block.capacity) { - // last block has space: pin the buffer of this block - auto handle = buffer_manager.Pin(last_block.block); - // now append to the block - idx_t append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); - remaining -= append_count; - handles.push_back(std::move(handle)); - } - } - while (remaining > 0) { - // now for the remaining data, allocate new buffers to store the data and append there - auto &new_block = CreateBlock(); - auto handle = buffer_manager.Pin(new_block.block); - - // offset the entry sizes array if we have added entries already - idx_t *offset_entry_sizes = entry_sizes ? entry_sizes + added_count - remaining : nullptr; - - idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); - D_ASSERT(new_block.count > 0); - remaining -= append_count; - - if (keep_pinned) { - pinned_blocks.push_back(std::move(handle)); - } else { - handles.push_back(std::move(handle)); - } - } - } - // now set up the key_locations based on the append entries - idx_t append_idx = 0; - for (auto &append_entry : append_entries) { - idx_t next = append_idx + append_entry.count; - if (entry_sizes) { - for (; append_idx < next; append_idx++) { - key_locations[append_idx] = append_entry.baseptr; - append_entry.baseptr += entry_sizes[append_idx]; - } - } else { - for (; append_idx < next; append_idx++) { - auto idx = sel->get_index(append_idx); - key_locations[idx] = append_entry.baseptr; - append_entry.baseptr += entry_size; - } - } - } - // return the unique pointers to the handles because they must stay pinned - return handles; -} - -void RowDataCollection::Merge(RowDataCollection &other) { - if (other.count == 0) { - return; - } - RowDataCollection temp(buffer_manager, buffer_manager.GetBlockSize(), 1); - { - // One lock at a time to avoid deadlocks - lock_guard read_lock(other.rdc_lock); - temp.count = other.count; - temp.block_capacity = other.block_capacity; - temp.entry_size = other.entry_size; - temp.blocks = std::move(other.blocks); - temp.pinned_blocks = std::move(other.pinned_blocks); - } - other.Clear(); - - lock_guard write_lock(rdc_lock); - count += temp.count; - block_capacity = MaxValue(block_capacity, temp.block_capacity); - entry_size = MaxValue(entry_size, temp.entry_size); - for (auto &block : temp.blocks) { - blocks.emplace_back(std::move(block)); - } - for (auto &handle : temp.pinned_blocks) { - pinned_blocks.emplace_back(std::move(handle)); - } -} - -} // namespace duckdb diff --git a/src/common/types/row/row_data_collection_scanner.cpp b/src/common/types/row/row_data_collection_scanner.cpp deleted file mode 100644 index 9b3a4be06efc..000000000000 --- a/src/common/types/row/row_data_collection_scanner.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -#include - -namespace duckdb { - -void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, - RowDataCollection &swizzled_string_heap, - RowDataCollection &block_collection, RowDataCollection &string_heap, - const RowLayout &layout) { - if (block_collection.count == 0) { - return; - } - - if (layout.AllConstant()) { - // No heap blocks! Just merge fixed-size data - swizzled_block_collection.Merge(block_collection); - return; - } - - // We create one heap block per data block and swizzle the pointers - D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); - auto &buffer_manager = block_collection.buffer_manager; - auto &heap_blocks = string_heap.blocks; - idx_t heap_block_idx = 0; - idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; - for (auto &data_block : block_collection.blocks) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - - // Pin the data block and swizzle the pointers within the rows - auto data_handle = buffer_manager.Pin(data_block->block); - auto data_ptr = data_handle.Ptr(); - if (!string_heap.keep_pinned) { - D_ASSERT(!data_block->block->IsSwizzled()); - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - data_block->block->SetSwizzling(nullptr); - } - // At this point the data block is pinned and the heap pointer is valid - // so we can copy heap data as needed - - // We want to copy as little of the heap data as possible, check how the data and heap blocks line up - if (heap_block_remaining >= data_block->count) { - // Easy: current heap block contains all strings for this data block, just copy (reference) the block - swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); - swizzled_string_heap.blocks.back()->count = data_block->count; - - // Swizzle the heap pointer if we are not pinning the heap - auto &heap_block = swizzled_string_heap.blocks.back()->block; - auto heap_handle = buffer_manager.Pin(heap_block); - if (!swizzled_string_heap.keep_pinned) { - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, - NumericCast(heap_offset)); - } else { - swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); - } - - // Update counter - heap_block_remaining -= data_block->count; - } else { - // Strings for this data block are spread over the current heap block and the next (and possibly more) - if (string_heap.keep_pinned) { - // The heap is changing underneath the data block, - // so swizzle the string pointers to make them portable. - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - } - idx_t data_block_remaining = data_block->count; - vector> ptrs_and_sizes; - idx_t total_size = 0; - const auto base_row_ptr = data_ptr; - while (data_block_remaining > 0) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - auto next = MinValue(data_block_remaining, heap_block_remaining); - - // Figure out where to start copying strings, and how many bytes we need to copy - auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_end_ptr = - Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - auto size = NumericCast(heap_end_ptr - heap_start_ptr + Load(heap_end_ptr)); - ptrs_and_sizes.emplace_back(heap_start_ptr, size); - D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); - - // Swizzle the heap pointer - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); - total_size += size; - - // Update where we are in the data and heap blocks - data_ptr += next * layout.GetRowWidth(); - data_block_remaining -= next; - heap_block_remaining -= next; - } - - // Finally, we allocate a new heap block and copy data to it - swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, buffer_manager.GetBlockSize()), 1U)); - auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); - auto new_heap_ptr = new_heap_handle.Ptr(); - for (auto &ptr_and_size : ptrs_and_sizes) { - memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); - new_heap_ptr += ptr_and_size.second; - } - new_heap_ptr = new_heap_handle.Ptr(); - if (swizzled_string_heap.keep_pinned) { - // Since the heap blocks are pinned, we can unswizzle the data again. - swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); - RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); - RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); - } - } - } - - // We're done with variable-sized data, now just merge the fixed-size data - swizzled_block_collection.Merge(block_collection); - D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); - - // Update counts and cleanup - swizzled_string_heap.count = string_heap.count; - string_heap.Clear(); -} - -void RowDataCollectionScanner::ScanState::PinData() { - auto &rows = scanner.rows; - D_ASSERT(block_idx < rows.blocks.size()); - auto &data_block = rows.blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = rows.buffer_manager.Pin(data_block->block); - } - if (scanner.layout.AllConstant() || !scanner.external) { - return; - } - - auto &heap = scanner.heap; - D_ASSERT(block_idx < heap.blocks.size()); - auto &heap_block = heap.blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = heap.buffer_manager.Pin(heap_block->block); - } -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - ValidateUnscannedBlock(); -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, idx_t block_idx, - bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - D_ASSERT(block_idx < rows.blocks.size()); - read_state.block_idx = block_idx; - read_state.entry_idx = 0; - - // Pretend that we have scanned up to the start block - // and will stop at the end - auto begin = rows.blocks.begin(); - auto end = begin + NumericCast(block_idx); - total_scanned = - std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - total_count = total_scanned + (*end)->count; - - ValidateUnscannedBlock(); -} - -void RowDataCollectionScanner::SwizzleBlockInternal(RowDataBlock &data_block, RowDataBlock &heap_block) { - // Pin the data block and swizzle the pointers within the rows - D_ASSERT(!data_block.block->IsSwizzled()); - auto data_handle = rows.buffer_manager.Pin(data_block.block); - auto data_ptr = data_handle.Ptr(); - RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); - data_block.block->SetSwizzling(nullptr); - - // Swizzle the heap pointers - auto heap_handle = heap.buffer_manager.Pin(heap_block.block); - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); -} - -void RowDataCollectionScanner::SwizzleBlock(idx_t block_idx) { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - auto &data_block = rows.blocks[block_idx]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[block_idx]); - } -} - -void RowDataCollectionScanner::ReSwizzle() { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - for (idx_t i = 0; i < rows.blocks.size(); ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } -} - -void RowDataCollectionScanner::ValidateUnscannedBlock() const { - if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { - D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); - } -} - -void RowDataCollectionScanner::Scan(DataChunk &chunk) { - auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); - if (count == 0) { - chunk.SetCardinality(count); - return; - } - - // Only flush blocks we processed. - const auto flush_block_idx = read_state.block_idx; - - const idx_t &row_width = layout.GetRowWidth(); - // Set up a batch of pointers to scan data from - idx_t scanned = 0; - auto data_pointers = FlatVector::GetData(addresses); - - // We must pin ALL blocks we are going to gather from - vector pinned_blocks; - while (scanned < count) { - read_state.PinData(); - auto &data_block = rows.blocks[read_state.block_idx]; - idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); - const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; - // Set up the next pointers - data_ptr_t row_ptr = data_ptr; - for (idx_t i = 0; i < next; i++) { - data_pointers[scanned + i] = row_ptr; - row_ptr += row_width; - } - // Unswizzle the offsets back to pointers (if needed) - if (unswizzling) { - RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); - rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); - } - // Update state indices - read_state.entry_idx += next; - scanned += next; - total_scanned += next; - if (read_state.entry_idx == data_block->count) { - // Pin completed blocks so we don't lose them - pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); - if (unswizzling) { - auto &heap_block = heap.blocks[read_state.block_idx]; - pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); - } - read_state.block_idx++; - read_state.entry_idx = 0; - ValidateUnscannedBlock(); - } - } - D_ASSERT(scanned == count); - // Deserialize the payload data - for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { - RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], - *FlatVector::IncrementalSelectionVector(), count, layout, col_no); - } - chunk.SetCardinality(count); - chunk.Verify(); - - // Switch to a new set of pinned blocks - read_state.pinned_blocks.swap(pinned_blocks); - - if (flush) { - // Release blocks we have passed. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - rows.blocks[i]->block = nullptr; - if (unswizzling) { - heap.blocks[i]->block = nullptr; - } - } - } else if (unswizzling) { - // Reswizzle blocks we have passed so they can be flushed safely. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } - } -} - -void RowDataCollectionScanner::Reset(bool flush_p) { - flush = flush_p; - total_scanned = 0; - - read_state.block_idx = 0; - read_state.entry_idx = 0; -} - -} // namespace duckdb diff --git a/src/common/types/row/row_layout.cpp b/src/common/types/row/row_layout.cpp deleted file mode 100644 index 3add8e425d98..000000000000 --- a/src/common/types/row/row_layout.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_layout.cpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/common/types/row/row_layout.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { -} - -void RowLayout::Initialize(vector types_p, bool align) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (const auto &type : types) { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - - // This enables pointer swizzling for out-of-core computation. - if (!all_constant) { - // When unswizzled, the pointer lives here. - // When swizzled, the pointer is replaced by an offset. - heap_pointer_offset = row_width; - // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. - // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. - row_width += sizeof(idx_t); - } - - // Data columns. No alignment required. - for (const auto &type : types) { - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). - row_width += sizeof(idx_t); - } - } - - data_width = row_width - flag_width; - - // Alignment padding for the next row - if (align) { - row_width = AlignValue(row_width); - } -} - -} // namespace duckdb diff --git a/src/common/types/string_type.cpp b/src/common/types/string_type.cpp index f5a236557760..bea85327a5c8 100644 --- a/src/common/types/string_type.cpp +++ b/src/common/types/string_type.cpp @@ -6,6 +6,8 @@ #include "utf8proc_wrapper.hpp" namespace duckdb { +constexpr idx_t string_t::MAX_STRING_SIZE; +constexpr idx_t string_t::INLINE_LENGTH; void string_t::Verify() const { #ifdef DEBUG diff --git a/src/common/vector_operations/is_distinct_from.cpp b/src/common/vector_operations/is_distinct_from.cpp index e57f9738d1e1..d2370229acc4 100644 --- a/src/common/vector_operations/is_distinct_from.cpp +++ b/src/common/vector_operations/is_distinct_from.cpp @@ -1,6 +1,8 @@ #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" namespace duckdb { @@ -289,6 +291,7 @@ template idx_t DistinctSelect(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel, optional_ptr null_mask) { if (!sel) { + D_ASSERT(count <= STANDARD_VECTOR_SIZE); sel = FlatVector::IncrementalSelectionVector(); } @@ -478,21 +481,22 @@ void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, } void ExtractNestedMask(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, - ValidityMask *child_mask, optional_ptr null_mask) { + ValidityMask *child_mask_p, optional_ptr null_mask) { - if (!child_mask) { + if (!child_mask_p) { return; } + auto &child_mask = *child_mask_p; for (idx_t i = 0; i < count; ++i) { const auto slice_idx = slice_sel.get_index(i); const auto result_idx = sel.get_index(slice_idx); - if (child_mask && !child_mask->RowIsValid(slice_idx)) { + if (!child_mask.RowIsValid(slice_idx)) { null_mask->SetInvalid(result_idx); } } - child_mask->Reset(null_mask->Capacity()); + child_mask.Reset(null_mask->Capacity()); } void DensifyNestedSelection(const SelectionVector &dense_sel, const idx_t count, SelectionVector &slice_sel) { @@ -890,6 +894,7 @@ idx_t DistinctSelectNested(Vector &left, Vector &right, optional_ptr(l_not_null, r_not_null, count, match_count, *sel, maybe_vec, true_opt, false_opt, null_mask); - switch (left.GetType().InternalType()) { + auto &left_type = left.GetType(); + switch (left_type.InternalType()) { case PhysicalType::LIST: match_count += DistinctSelectList(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt, null_mask); diff --git a/src/execution/expression_executor.cpp b/src/execution/expression_executor.cpp index 6707e28e96a8..8ea2bba0909a 100644 --- a/src/execution/expression_executor.cpp +++ b/src/execution/expression_executor.cpp @@ -182,6 +182,7 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co VectorOperations::DefaultCast(vector, intermediate, count, true); } intermediate.Verify(count); + //! FIXME: this is probably also where we want to test 'variant_normalize' Vector result(vector.GetType(), true, false, count); //! Then cast back into the original type diff --git a/src/execution/index/art/base_leaf.cpp b/src/execution/index/art/base_leaf.cpp index a694ca3b593d..4a9332fc9d17 100644 --- a/src/execution/index/art/base_leaf.cpp +++ b/src/execution/index/art/base_leaf.cpp @@ -30,8 +30,10 @@ void BaseLeaf::InsertByteInternal(BaseLeaf &n, const uint8_t byt } template -BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, Node &node, const uint8_t byte) { - auto &n = Node::Ref(art, node, node.GetType()); +NodeHandle> BaseLeaf::DeleteByteInternal(ART &art, Node &node, + const uint8_t byte) { + NodeHandle> handle(art, node); + auto &n = handle.Get(); uint8_t child_pos = 0; for (; child_pos < n.count; child_pos++) { @@ -45,7 +47,7 @@ BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, for (uint8_t i = child_pos; i < n.count; i++) { n.key[i] = n.key[i + 1]; } - return n; + return handle; } //===--------------------------------------------------------------------===// @@ -53,27 +55,36 @@ BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, //===--------------------------------------------------------------------===// void Node7Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. Grow to Node15. - auto &n7 = Node::Ref(art, node, NODE_7_LEAF); - if (n7.count == CAPACITY) { - auto node7 = node; - Node15Leaf::GrowNode7Leaf(art, node, node7); - Node15Leaf::InsertByte(art, node, byte); - return; - } + { + NodeHandle handle(art, node); + auto &n7 = handle.Get(); - InsertByteInternal(n7, byte); + if (n7.count != CAPACITY) { + InsertByteInternal(n7, byte); + return; + } + } + // The node is full. Grow to Node15. + auto node7 = node; + Node15Leaf::GrowNode7Leaf(art, node, node7); + Node15Leaf::InsertByte(art, node, byte); } void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byte, const ARTKey &row_id) { - auto &n7 = DeleteByteInternal(art, node, byte); + idx_t remainder; + { + auto n7_handle = DeleteByteInternal(art, node, byte); + auto &n7 = n7_handle.Get(); + + if (n7.count != 1) { + return; + } - // Compress one-way nodes. - if (n7.count == 1) { + // Compress one-way nodes. D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); // Get the remaining row ID. - auto remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; + remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; remainder |= UnsafeNumericCast(n7.key[0]); // Free the prefix (nodes) and inline the remainder. @@ -82,23 +93,27 @@ void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byt Leaf::New(prefix, UnsafeNumericCast(remainder)); return; } - - // Free the Node7Leaf and inline the remainder. - Node::FreeNode(art, node); - Leaf::New(node, UnsafeNumericCast(remainder)); } + // Free the Node7Leaf and inline the remainder. + Node::FreeNode(art, node); + Leaf::New(node, UnsafeNumericCast(remainder)); } void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) { - auto &n7 = New(art, node7_leaf); - auto &n15 = Node::Ref(art, node15_leaf, NType::NODE_15_LEAF); - node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + { + auto n7_handle = New(art, node7_leaf); + auto &n7 = n7_handle.Get(); - n7.count = n15.count; - for (uint8_t i = 0; i < n15.count; i++) { - n7.key[i] = n15.key[i]; - } + NodeHandle n15_handle(art, node15_leaf); + auto &n15 = n15_handle.Get(); + node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + + n7.count = n15.count; + for (uint8_t i = 0; i < n15.count; i++) { + n7.key[i] = n15.key[i]; + } + } Node::FreeNode(art, node15_leaf); } @@ -107,54 +122,66 @@ void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) //===--------------------------------------------------------------------===// void Node15Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. Grow to Node256Leaf. - auto &n15 = Node::Ref(art, node, NODE_15_LEAF); - if (n15.count == CAPACITY) { - auto node15 = node; - Node256Leaf::GrowNode15Leaf(art, node, node15); - Node256Leaf::InsertByte(art, node, byte); - return; + { + NodeHandle n15_handle(art, node); + auto &n15 = n15_handle.Get(); + if (n15.count != CAPACITY) { + InsertByteInternal(n15, byte); + return; + } } - - InsertByteInternal(n15, byte); + auto node15 = node; + Node256Leaf::GrowNode15Leaf(art, node, node15); + Node256Leaf::InsertByte(art, node, byte); } void Node15Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) { - auto &n15 = DeleteByteInternal(art, node, byte); - - // Shrink node to Node7. - if (n15.count < Node7Leaf::CAPACITY) { - auto node15 = node; - Node7Leaf::ShrinkNode15Leaf(art, node, node15); + { + auto n15_handle = DeleteByteInternal(art, node, byte); + auto &n15 = n15_handle.Get(); + if (n15.count >= Node7Leaf::CAPACITY) { + return; + } } + auto node15 = node; + Node7Leaf::ShrinkNode15Leaf(art, node, node15); } void Node15Leaf::GrowNode7Leaf(ART &art, Node &node15_leaf, Node &node7_leaf) { - auto &n7 = Node::Ref(art, node7_leaf, NType::NODE_7_LEAF); - auto &n15 = New(art, node15_leaf); - node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + { + NodeHandle n7_handle(art, node7_leaf); + auto &n7 = n7_handle.Get(); - n15.count = n7.count; - for (uint8_t i = 0; i < n7.count; i++) { - n15.key[i] = n7.key[i]; - } + auto n15_handle = New(art, node15_leaf); + auto &n15 = n15_handle.Get(); + node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + n15.count = n7.count; + for (uint8_t i = 0; i < n7.count; i++) { + n15.key[i] = n7.key[i]; + } + } Node::FreeNode(art, node7_leaf); } void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_leaf) { - auto &n15 = New(art, node15_leaf); - auto &n256 = Node::Ref(art, node256_leaf, NType::NODE_256_LEAF); - node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); - - ValidityMask mask(&n256.mask[0], Node256::CAPACITY); - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (mask.RowIsValid(i)) { - n15.key[n15.count] = UnsafeNumericCast(i); - n15.count++; + { + auto n15_handle = New(art, node15_leaf); + auto &n15 = n15_handle.Get(); + + NodeHandle n256_handle(art, node256_leaf); + auto &n256 = n256_handle.Get(); + + node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); + + ValidityMask mask(&n256.mask[0], Node256::CAPACITY); + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + if (mask.RowIsValid(i)) { + n15.key[n15.count] = UnsafeNumericCast(i); + n15.count++; + } } } - Node::FreeNode(art, node256_leaf); } diff --git a/src/execution/index/fixed_size_allocator.cpp b/src/execution/index/fixed_size_allocator.cpp index dd4758bb971a..cffd0b61c7dc 100644 --- a/src/execution/index/fixed_size_allocator.cpp +++ b/src/execution/index/fixed_size_allocator.cpp @@ -4,9 +4,9 @@ namespace duckdb { -FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager) - : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), segment_size(segment_size), - total_segment_count(0) { +FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager, MemoryTag memory_tag) + : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), memory_tag(memory_tag), + segment_size(segment_size), total_segment_count(0) { if (segment_size > block_manager.GetBlockSize() - sizeof(validity_t)) { throw InternalException("The maximum segment size of fixed-size allocators is " + @@ -48,7 +48,7 @@ IndexPointer FixedSizeAllocator::New() { if (!buffer_with_free_space.IsValid()) { // Add a new buffer. auto buffer_id = GetAvailableBufferId(); - buffers[buffer_id] = make_uniq(block_manager); + buffers[buffer_id] = make_uniq(block_manager, memory_tag); buffers_with_free_space.insert(buffer_id); buffer_with_free_space = buffer_id; diff --git a/src/execution/index/fixed_size_buffer.cpp b/src/execution/index/fixed_size_buffer.cpp index 82bbccac2783..26e56cd5ce03 100644 --- a/src/execution/index/fixed_size_buffer.cpp +++ b/src/execution/index/fixed_size_buffer.cpp @@ -35,12 +35,12 @@ void PartialBlockForIndex::Clear() { constexpr idx_t FixedSizeBuffer::BASE[]; constexpr uint8_t FixedSizeBuffer::SHIFT[]; -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) +FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, MemoryTag memory_tag) : block_manager(block_manager), readers(0), segment_count(0), allocation_size(0), dirty(false), vacuum(false), loaded(false), block_pointer(), block_handle(nullptr) { auto &buffer_manager = block_manager.buffer_manager; - buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, &block_manager, false); + buffer_handle = buffer_manager.Allocate(memory_tag, &block_manager, false); block_handle = buffer_handle.GetBlockHandle(); // Zero-initialize the buffer as it might get serialized to storage. diff --git a/src/execution/operator/helper/physical_buffered_batch_collector.cpp b/src/execution/operator/helper/physical_buffered_batch_collector.cpp index 404d143431b5..d08332a72ac6 100644 --- a/src/execution/operator/helper/physical_buffered_batch_collector.cpp +++ b/src/execution/operator/helper/physical_buffered_batch_collector.cpp @@ -94,7 +94,7 @@ unique_ptr PhysicalBufferedBatchCollector::GetLocalSinkState(Exe unique_ptr PhysicalBufferedBatchCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); + state->buffered_data = make_shared_ptr(context); return std::move(state); } diff --git a/src/execution/operator/helper/physical_buffered_collector.cpp b/src/execution/operator/helper/physical_buffered_collector.cpp index 7795230dc514..f0bdea11cb75 100644 --- a/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/execution/operator/helper/physical_buffered_collector.cpp @@ -48,7 +48,7 @@ SinkCombineResultType PhysicalBufferedCollector::Combine(ExecutionContext &conte unique_ptr PhysicalBufferedCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); + state->buffered_data = make_shared_ptr(context); return std::move(state); } diff --git a/src/execution/operator/helper/physical_limit.cpp b/src/execution/operator/helper/physical_limit.cpp index 5a4339c63052..3963ff36e601 100644 --- a/src/execution/operator/helper/physical_limit.cpp +++ b/src/execution/operator/helper/physical_limit.cpp @@ -8,6 +8,8 @@ namespace duckdb { +constexpr const idx_t PhysicalLimit::MAX_LIMIT_VALUE; + PhysicalLimit::PhysicalLimit(PhysicalPlan &physical_plan, vector types, BoundLimitNode limit_val_p, BoundLimitNode offset_val_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::LIMIT, std::move(types), estimated_cardinality), diff --git a/src/execution/operator/helper/physical_reset.cpp b/src/execution/operator/helper/physical_reset.cpp index 1f5baf75d1f5..9476b4e23c48 100644 --- a/src/execution/operator/helper/physical_reset.cpp +++ b/src/execution/operator/helper/physical_reset.cpp @@ -36,8 +36,7 @@ SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &ch auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name); entry = config.extension_parameters.find(name.ToStdString()); if (entry == config.extension_parameters.end()) { - throw InvalidInputException("Extension parameter %s was not found after autoloading", - name.ToStdString()); + throw InvalidInputException("Extension parameter %s was not found after autoloading", name); } } ResetExtensionVariable(context, config, entry->second); diff --git a/src/execution/operator/join/physical_asof_join.cpp b/src/execution/operator/join/physical_asof_join.cpp index 31fc85160048..08d2ed320b39 100644 --- a/src/execution/operator/join/physical_asof_join.cpp +++ b/src/execution/operator/join/physical_asof_join.cpp @@ -82,10 +82,8 @@ class AsOfGlobalSinkState : public GlobalSinkState { using HashedSortPtr = unique_ptr; using HashedSinkPtr = unique_ptr; using PartitionMarkers = vector; - using HashGroupPtr = unique_ptr; - using HashGroups = vector; - AsOfGlobalSinkState(ClientContext &client, const PhysicalAsOfJoin &op) : is_outer(IsRightOuterJoin(op.join_type)) { + AsOfGlobalSinkState(ClientContext &client, const PhysicalAsOfJoin &op) { // Set up partitions for both sides hashed_sorts.reserve(2); hashed_sinks.reserve(2); @@ -101,8 +99,6 @@ class AsOfGlobalSinkState : public GlobalSinkState { rhs.estimated_cardinality, true); hashed_sinks.emplace_back(sort->GetGlobalSinkState(client)); hashed_sorts.emplace_back(std::move(sort)); - - hash_groups.resize(2); } //! The child that is being materialised (right/1 then left/0) @@ -111,12 +107,6 @@ class AsOfGlobalSinkState : public GlobalSinkState { vector hashed_sorts; //! The child's partitioning buffer vector hashed_sinks; - //! The child's hash groups - vector hash_groups; - //! Whether the right side is outer - const bool is_outer; - //! The right outer join markers (one per partition) - vector right_outers; }; class AsOfLocalSinkState : public LocalSinkState { @@ -203,6 +193,8 @@ OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// +enum class AsOfJoinSourceStage : uint8_t { INNER, RIGHT, DONE }; + class AsOfPayloadScanner { public: using Types = vector; @@ -306,11 +298,79 @@ AsOfPayloadScanner::AsOfPayloadScanner(const SortedRun &sorted_run, const Hashed } } +class AsOfLocalSourceState; + +class AsOfGlobalSourceState : public GlobalSourceState { +public: + using HashGroupPtr = unique_ptr; + using HashGroups = vector; + + AsOfGlobalSourceState(ClientContext &client, const PhysicalAsOfJoin &op); + + //! Assign a new task to the local state + bool AssignTask(AsOfLocalSourceState &lsource); + //! Can we shift to the next stage? + bool TryPrepareNextStage(); + + //! The parent operator + const PhysicalAsOfJoin &op; + //! For synchronizing the external hash join + atomic stage; + //! The child's hash groups + vector hashed_groups; + //! Whether the right side is outer + const bool is_right_outer; + //! The right outer join markers (one per partition) + vector right_outers; + //! The next buffer to flush + atomic next_left; + //! The number of flushed buffers + atomic flushed_left; + //! The right outer output read position. + atomic next_right; + //! The right outer output read position. + atomic flushed_right; + +public: + idx_t MaxThreads() override { + return hashed_groups[1].size(); + } +}; + +AsOfGlobalSourceState::AsOfGlobalSourceState(ClientContext &client, const PhysicalAsOfJoin &op) + : op(op), stage(AsOfJoinSourceStage::INNER), is_right_outer(IsRightOuterJoin(op.join_type)), next_left(0), + flushed_left(0), next_right(0), flushed_right(0) { + + // Take ownership of the hash groups + auto &gsink = op.sink_state->Cast(); + hashed_groups.resize(2); + for (idx_t child = 0; child < 2; ++child) { + auto &hashed_sort = *gsink.hashed_sorts[child]; + auto &hashed_sink = *gsink.hashed_sinks[child]; + auto hashed_source = hashed_sort.GetGlobalSourceState(client, hashed_sink); + auto &sorted_runs = hashed_sort.GetSortedRuns(*hashed_source); + auto &hash_groups = hashed_groups[child]; + hash_groups.resize(sorted_runs.size()); + + for (idx_t group_idx = 0; group_idx < sorted_runs.size(); ++group_idx) { + hash_groups[group_idx] = std::move(sorted_runs[group_idx]); + } + } + + // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple + auto &rhs_groups = hashed_groups[1]; + right_outers.reserve(rhs_groups.size()); + for (const auto &hash_group : rhs_groups) { + right_outers.emplace_back(OuterJoinMarker(is_right_outer)); + right_outers.back().Initialize(hash_group ? hash_group->Count() : 0); + } +} + class AsOfProbeBuffer { public: using Orders = vector; - AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); + AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op, AsOfGlobalSourceState &gsource); public: // Comparison utilities @@ -376,6 +436,8 @@ class AsOfProbeBuffer { ClientContext &client; const PhysicalAsOfJoin &op; + //! The source state + AsOfGlobalSourceState &gsource; //! Is the inequality strict? const bool strict; @@ -412,9 +474,9 @@ class AsOfProbeBuffer { bool fetch_next_left; }; -AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op) - : client(client), op(op), strict(IsStrictComparison(op.comparison_type)), left_outer(IsLeftOuterJoin(op.join_type)), - lhs_executor(client), rhs_executor(client), fetch_next_left(true) { +AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op, AsOfGlobalSourceState &gsource) + : client(client), op(op), gsource(gsource), strict(IsStrictComparison(op.comparison_type)), + left_outer(IsLeftOuterJoin(op.join_type)), lhs_executor(client), rhs_executor(client), fetch_next_left(true) { lhs_keys.Initialize(client, op.join_key_types); for (const auto &cond : op.conditions) { @@ -444,14 +506,14 @@ void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { auto &gsink = op.sink_state->Cast(); // Always set right_bin too for memory management - auto &rhs_groups = gsink.hash_groups[1]; + auto &rhs_groups = gsource.hashed_groups[1]; if (scan_bin < rhs_groups.size()) { right_bin = scan_bin; } else { right_bin = rhs_groups.size(); } - auto &lhs_groups = gsink.hash_groups[0]; + auto &lhs_groups = gsource.hashed_groups[0]; if (scan_bin < lhs_groups.size()) { left_bin = scan_bin; } else { @@ -509,7 +571,7 @@ void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { right_pos = 0; if (right_bin < rhs_groups.size()) { right_group = rhs_groups[right_bin].get(); - right_outer = gsink.right_outers.data() + right_bin; + right_outer = gsource.right_outers.data() + right_bin; if (right_group && right_group->Count()) { right_itr = CreateIteratorState(*right_group); rhs_scanner = make_uniq(*right_group, *gsink.hashed_sorts[1]); @@ -574,15 +636,13 @@ bool AsOfProbeBuffer::NextLeft() { } void AsOfProbeBuffer::EndLeftScan() { - auto &gsink = op.sink_state->Cast(); - right_group = nullptr; right_itr.reset(); rhs_scanner.reset(); right_outer = nullptr; - auto &rhs_groups = gsink.hash_groups[1]; - if (!gsink.is_outer && right_bin < rhs_groups.size()) { + auto &rhs_groups = gsource.hashed_groups[1]; + if (!gsource.is_right_outer && right_bin < rhs_groups.size()) { rhs_groups[right_bin].reset(); } @@ -590,7 +650,7 @@ void AsOfProbeBuffer::EndLeftScan() { left_itr.reset(); lhs_scanner.reset(); - auto &lhs_groups = gsink.hash_groups[0]; + auto &lhs_groups = gsource.hashed_groups[0]; if (left_bin < lhs_groups.size()) { lhs_groups[left_bin].reset(); } @@ -814,54 +874,8 @@ void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { } } -class AsOfGlobalSourceState : public GlobalSourceState { -public: - AsOfGlobalSourceState(ClientContext &client, AsOfGlobalSinkState &gsink_p); - - AsOfGlobalSinkState &gsink; - //! The next buffer to flush - atomic next_left; - //! The number of flushed buffers - atomic flushed; - //! The right outer output read position. - atomic next_right; - -public: - idx_t MaxThreads() override { - return gsink.hash_groups[1].size(); - } -}; - -AsOfGlobalSourceState::AsOfGlobalSourceState(ClientContext &client, AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_left(0), flushed(0), next_right(0) { - - // Take ownership of the hash groups - for (idx_t child = 0; child < 2; ++child) { - auto &hashed_sort = *gsink.hashed_sorts[child]; - auto &hashed_sink = *gsink.hashed_sinks[child]; - auto hashed_source = hashed_sort.GetGlobalSourceState(client, hashed_sink); - auto &sorted_runs = hashed_sort.GetSortedRuns(*hashed_source); - auto &hash_groups = gsink.hash_groups[child]; - hash_groups.resize(sorted_runs.size()); - - for (idx_t group_idx = 0; group_idx < sorted_runs.size(); ++group_idx) { - hash_groups[group_idx] = std::move(sorted_runs[group_idx]); - } - } - - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = gsink.hash_groups[1]; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.size()); - for (const auto &hash_group : rhs_partition) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group ? hash_group->Count() : 0); - } -} - unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &client) const { - auto &gsink = sink_state->Cast(); - return make_uniq(client, gsink); + return make_uniq(client, *this); } class AsOfLocalSourceState : public LocalSourceState { @@ -870,6 +884,26 @@ class AsOfLocalSourceState : public LocalSourceState { AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op); + //! Task management + bool TaskFinished() const { + if (hash_group) { + return !scanner.get(); + } else { + return !probe_buffer.Scanning(); + } + } + + void ExecuteInnerTask(DataChunk &chunk); + void ExecuteOuterTask(DataChunk &chunk); + + void ExecuteTask(DataChunk &chunk) { + if (hash_group) { + ExecuteOuterTask(chunk); + } else { + ExecuteInnerTask(chunk); + } + } + idx_t BeginRightScan(const idx_t hash_bin); AsOfGlobalSourceState &gsource; @@ -878,24 +912,30 @@ class AsOfLocalSourceState : public LocalSourceState { //! The left side partition being probed AsOfProbeBuffer probe_buffer; - //! The read partition + //! The rhs group idx_t hash_bin; HashGroupPtr hash_group; //! The read cursor unique_ptr scanner; + //! The right outer buffer + DataChunk rhs_chunk; + //! The right outer slicer + SelectionVector rsel; //! Pointer to the right marker const bool *rhs_matches = {}; }; AsOfLocalSourceState::AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op) - : gsource(gsource), context(context), probe_buffer(context.client, op) { + : gsource(gsource), context(context), probe_buffer(context.client, op, gsource), rsel(STANDARD_VECTOR_SIZE) { + + rhs_chunk.Initialize(context.client, op.children[1].get().GetTypes()); } idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { hash_bin = hash_bin_p; - auto &rhs_groups = gsource.gsink.hash_groups[1]; + auto &rhs_groups = gsource.hashed_groups[1]; if (hash_bin >= rhs_groups.size()) { return 0; } @@ -904,9 +944,10 @@ idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { if (!hash_group || !hash_group->Count()) { return 0; } - scanner = make_uniq(*hash_group, *gsource.gsink.hashed_sorts[1]); + auto &gsink = gsource.op.sink_state->Cast(); + scanner = make_uniq(*hash_group, *gsink.hashed_sorts[1]); - rhs_matches = gsource.gsink.right_outers[hash_bin].GetMatches(); + rhs_matches = gsource.right_outers[hash_bin].GetMatches(); return scanner->Remaining(); } @@ -917,106 +958,134 @@ unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionCont return make_uniq(context, gsource, *this); } -SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gsource = input.global_state.Cast(); - auto &lsource = input.local_state.Cast(); - auto &rhs_groups = gsource.gsink.hash_groups[1]; - auto &client = context.client; - - // Step 1: Join the partitions - auto &lhs_groups = gsource.gsink.hash_groups[0]; - const auto left_bins = lhs_groups.size(); - while (gsource.flushed < left_bins) { - // Make sure we have something to flush - if (!lsource.probe_buffer.Scanning()) { - const auto left_bin = gsource.next_left++; - if (left_bin < left_bins) { - // More to flush - lsource.probe_buffer.BeginLeftScan(left_bin); - } else if (!IsRightOuterJoin(join_type) || client.interrupted) { - return SourceResultType::FINISHED; +bool AsOfGlobalSourceState::TryPrepareNextStage() { + // Inside the lock. + auto &lhs_groups = hashed_groups[0]; + auto &rhs_groups = hashed_groups[1]; + switch (stage.load()) { + case AsOfJoinSourceStage::INNER: + if (flushed_left >= lhs_groups.size()) { + stage = IsRightOuterJoin(op.join_type) ? AsOfJoinSourceStage::RIGHT : AsOfJoinSourceStage::DONE; + return true; + } + break; + case AsOfJoinSourceStage::RIGHT: + if (flushed_right >= rhs_groups.size()) { + stage = AsOfJoinSourceStage::DONE; + return true; + } + break; + default: + break; + } + return false; +} + +bool AsOfGlobalSourceState::AssignTask(AsOfLocalSourceState &lsource) { + auto guard = Lock(); + + auto &lhs_groups = hashed_groups[0]; + auto &rhs_groups = hashed_groups[1]; + + switch (stage.load()) { + case AsOfJoinSourceStage::INNER: + while (next_left < lhs_groups.size()) { + // More to flush + const auto left_bin = next_left++; + lsource.probe_buffer.BeginLeftScan(left_bin); + if (!lsource.TaskFinished()) { + return true; } else { - // Wait for all threads to finish - // TODO: How to implement a spin wait correctly? - // Returning BLOCKED seems to hang the system. - TaskScheduler::GetScheduler(client).YieldThread(); - continue; + ++flushed_left; } } - - lsource.probe_buffer.GetData(context, chunk); - if (chunk.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else if (lsource.probe_buffer.HasMoreData()) { - // Join the next partition - continue; - } else { - lsource.probe_buffer.EndLeftScan(); - gsource.flushed++; + break; + case AsOfJoinSourceStage::RIGHT: + while (next_right < rhs_groups.size()) { + const auto right_bin = next_right++; + lsource.BeginRightScan(right_bin); + if (!lsource.TaskFinished()) { + return true; + } else { + ++flushed_right; + } } + break; + default: + break; } - // Step 2: Emit right join matches - if (!IsRightOuterJoin(join_type)) { - return SourceResultType::FINISHED; + return false; +} + +void AsOfLocalSourceState::ExecuteInnerTask(DataChunk &chunk) { + while (probe_buffer.HasMoreData()) { + probe_buffer.GetData(context, chunk); + if (chunk.size()) { + return; + } } + probe_buffer.EndLeftScan(); + gsource.flushed_left++; +} - DataChunk rhs_chunk; - rhs_chunk.Initialize(context.client, children[1].get().GetTypes()); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - - while (chunk.size() == 0) { - // Move to the next bin if we are done. - while (!lsource.scanner || !lsource.scanner->Remaining()) { - lsource.scanner.reset(); - lsource.hash_group.reset(); - auto hash_bin = gsource.next_right++; - if (hash_bin >= rhs_groups.size()) { - return SourceResultType::FINISHED; - } +SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); - for (; hash_bin < rhs_groups.size(); hash_bin = gsource.next_right++) { - if (rhs_groups[hash_bin]) { - break; - } + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gsource.stage != AsOfJoinSourceStage::DONE && chunk.size() == 0) { + if (!lsource.TaskFinished() || gsource.AssignTask(lsource)) { + lsource.ExecuteTask(chunk); + } else { + auto guard = gsource.Lock(); + if (gsource.TryPrepareNextStage() || gsource.stage == AsOfJoinSourceStage::DONE) { + gsource.UnblockTasks(guard); + } else { + return gsource.BlockSource(guard, input.interrupt_state); } - lsource.BeginRightScan(hash_bin); } - const auto rhs_position = lsource.scanner->Scanned(); - lsource.scanner->Scan(rhs_chunk); + } + + return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; +} + +void AsOfLocalSourceState::ExecuteOuterTask(DataChunk &chunk) { + idx_t result_count = 0; + while (!result_count) { + const auto rhs_position = scanner->Scanned(); + scanner->Scan(rhs_chunk); const auto count = rhs_chunk.size(); if (count == 0) { - return SourceResultType::FINISHED; + scanner.reset(); + ++gsource.flushed_right; + return; } // figure out which tuples didn't find a match in the RHS - auto rhs_matches = lsource.rhs_matches; - idx_t result_count = 0; + result_count = 0; for (idx_t i = 0; i < count; i++) { if (!rhs_matches[rhs_position + i]) { rsel.set_index(result_count++, i); } } - - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0].get().GetTypes().size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - for (idx_t col_idx = 0; col_idx < right_projection_map.size(); ++col_idx) { - const auto rhs_idx = right_projection_map[col_idx]; - chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); - } - chunk.SetCardinality(result_count); - break; - } } - return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; + // if there were any tuples that didn't find a match, output them + const auto &op = gsource.op; + const idx_t left_column_count = op.children[0].get().GetTypes().size(); + for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } + for (idx_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); + } + chunk.SetCardinality(result_count); } //===--------------------------------------------------------------------===// diff --git a/src/execution/operator/join/physical_iejoin.cpp b/src/execution/operator/join/physical_iejoin.cpp index 24981993ee2e..d4ea9f38f5e2 100644 --- a/src/execution/operator/join/physical_iejoin.cpp +++ b/src/execution/operator/join/physical_iejoin.cpp @@ -229,6 +229,8 @@ OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, Da //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// +enum class IEJoinSourceStage : uint8_t { INIT, INNER, OUTER, DONE }; + struct IEJoinUnion { using SortedTable = PhysicalRangeJoin::GlobalSortedTable; using ChunkRange = std::pair; @@ -804,13 +806,66 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse return result_count; } +class IEJoinLocalSourceState; + +class IEJoinGlobalSourceState : public GlobalSourceState { +public: + IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) + : op(op), gsink(gsink), stage(IEJoinSourceStage::INIT), next_pair(0), completed(0), left_outers(0), + next_left(0), right_outers(0), next_right(0) { + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; + + left_blocks = left_table.BlockCount(); + left_ranges = (left_blocks + left_per_thread - 1) / left_per_thread; + + right_blocks = right_table.BlockCount(); + right_ranges = (right_blocks + right_per_thread - 1) / right_per_thread; + + pair_count = left_ranges * right_ranges; + } + + void Initialize(); + bool TryPrepareNextStage(); + bool AssignTask(ExecutionContext &context, IEJoinLocalSourceState &lstate); + +public: + idx_t MaxThreads() override; + + ProgressData GetProgress() const; + + const PhysicalIEJoin &op; + IEJoinGlobalState &gsink; + + atomic stage; + + // Join queue state + idx_t left_blocks = 0; + idx_t left_ranges = 0; + const idx_t left_per_thread = 1024; + idx_t right_blocks = 0; + idx_t right_ranges = 0; + const idx_t right_per_thread = 1024; + idx_t pair_count; + atomic next_pair; + atomic completed; + + // Outer joins + atomic left_outers; + atomic next_left; + + atomic right_outers; + atomic next_right; +}; + class IEJoinLocalSourceState : public LocalSourceState { public: - IEJoinLocalSourceState(ClientContext &client, const PhysicalIEJoin &op) - : op(op), lsel(STANDARD_VECTOR_SIZE), rsel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), + IEJoinLocalSourceState(ClientContext &client, IEJoinGlobalSourceState &gsource) + : gsource(gsource), lsel(STANDARD_VECTOR_SIZE), rsel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), left_executor(client), right_executor(client), left_matches(nullptr), right_matches(nullptr) { + auto &op = gsource.op; auto &allocator = Allocator::Get(client); unprojected.InitializeEmpty(op.unprojected_types); lpayload.Initialize(allocator, op.children[0].get().GetTypes()); @@ -864,7 +919,21 @@ class IEJoinLocalSourceState : public LocalSourceState { return count; } - const PhysicalIEJoin &op; + // Are we executing a task? + bool TaskFinished() const { + return !joiner && !left_matches && !right_matches; + } + + // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + void ResolveComplexJoin(ExecutionContext &context, DataChunk &result); + // Resolve left join results + void ExecuteLeftTask(ExecutionContext &context, DataChunk &result); + // Resolve right join results + void ExecuteRightTask(ExecutionContext &context, DataChunk &result); + // Execute the current task + void ExecuteTask(ExecutionContext &context, DataChunk &result); + + IEJoinGlobalSourceState &gsource; // Joining unique_ptr joiner; @@ -903,45 +972,46 @@ class IEJoinLocalSourceState : public LocalSourceState { bool *right_matches; }; -void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const { - auto &state = state_p.Cast(); - auto &ie_sink = sink_state->Cast(); +void IEJoinLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &result) { + if (joiner) { + ResolveComplexJoin(context, result); + } else if (left_matches != nullptr) { + ExecuteLeftTask(context, result); + } else if (right_matches != nullptr) { + ExecuteRightTask(context, result); + } +} - auto &chunk = state.unprojected; +void IEJoinLocalSourceState::ResolveComplexJoin(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + const auto &conditions = op.conditions; + + auto &chunk = unprojected; auto &left_table = *ie_sink.tables[0]; - auto &lsel = state.lsel; - auto &lpayload = state.lpayload; - auto &left_iterator = *state.left_iterator; - auto &left_chunk_state = state.left_chunk_state; - auto &left_block_index = state.left_block_index; - auto &left_scan_state = *state.left_scan_state; - const auto left_cols = children[0].get().GetTypes().size(); + const auto left_cols = op.children[0].get().GetTypes().size(); auto &right_table = *ie_sink.tables[1]; - auto &rsel = state.rsel; - auto &rpayload = state.rpayload; - auto &right_iterator = *state.right_iterator; - auto &right_chunk_state = state.right_chunk_state; - auto &right_block_index = state.right_block_index; - auto &right_scan_state = *state.right_scan_state; do { - auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel); + auto result_count = joiner->JoinComplexBlocks(lsel, rsel); if (result_count == 0) { // exhausted this pair + joiner.reset(); + ++gsource.completed; return; } // found matches: extract them - left_table.Repin(left_iterator); - right_table.Repin(right_iterator); + left_table.Repin(*left_iterator); + right_table.Repin(*right_iterator); - SliceSortedPayload(lpayload, left_table, left_iterator, left_chunk_state, left_block_index, lsel, result_count, - left_scan_state); - SliceSortedPayload(rpayload, right_table, right_iterator, right_chunk_state, right_block_index, rsel, - result_count, right_scan_state); + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, lsel, + result_count, *left_scan_state); + op.SliceSortedPayload(rpayload, right_table, *right_iterator, right_chunk_state, right_block_index, rsel, + result_count, *right_scan_state); auto sel = FlatVector::IncrementalSelectionVector(); if (conditions.size() > 2) { @@ -950,24 +1020,25 @@ void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &re // to we can compute the values for comparison. const auto tail_cols = conditions.size() - 2; - state.left_executor.SetChunk(lpayload); - state.right_executor.SetChunk(rpayload); + left_executor.SetChunk(lpayload); + right_executor.SetChunk(rpayload); auto tail_count = result_count; - auto true_sel = &state.true_sel; + auto match_sel = &true_sel; for (size_t cmp_idx = 0; cmp_idx < tail_cols; ++cmp_idx) { - auto &left = state.left_keys.data[cmp_idx]; - state.left_executor.ExecuteExpression(cmp_idx, left); + auto &left = left_keys.data[cmp_idx]; + left_executor.ExecuteExpression(cmp_idx, left); - auto &right = state.right_keys.data[cmp_idx]; - state.right_executor.ExecuteExpression(cmp_idx, right); + auto &right = right_keys.data[cmp_idx]; + right_executor.ExecuteExpression(cmp_idx, right); if (tail_count < result_count) { left.Slice(*sel, tail_count); right.Slice(*sel, tail_count); } - tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel); - sel = true_sel; + tail_count = + op.SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, match_sel); + sel = match_sel; } if (tail_count < result_count) { @@ -990,79 +1061,88 @@ void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &re // We need all of the data to compute other predicates, // but we only return what is in the projection map - ProjectResult(chunk, result); + op.ProjectResult(chunk, result); // found matches: mark the found matches if required if (left_table.found_match) { for (idx_t i = 0; i < result_count; i++) { - left_table.found_match[state.left_base + lsel[sel->get_index(i)]] = true; + left_table.found_match[left_base + lsel[sel->get_index(i)]] = true; } } if (right_table.found_match) { for (idx_t i = 0; i < result_count; i++) { - right_table.found_match[state.right_base + rsel[sel->get_index(i)]] = true; + right_table.found_match[right_base + rsel[sel->get_index(i)]] = true; } } result.Verify(); } while (result.size() == 0); } -class IEJoinGlobalSourceState : public GlobalSourceState { -public: - IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) - : op(op), gsink(gsink), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), - right_outers(0), next_right(0) { +void IEJoinGlobalSourceState::Initialize() { + auto guard = Lock(); + if (stage != IEJoinSourceStage::INIT) { + return; } - void Initialize(ClientContext &client) { - auto guard = Lock(); - if (initialized) { - return; - } - - // Compute the starting row for each block - auto &left_table = *gsink.tables[0]; - const auto left_blocks = left_table.BlockCount(); - - auto &right_table = *gsink.tables[1]; - const auto right_blocks = right_table.BlockCount(); + // Compute the starting row for each block + auto &left_table = *gsink.tables[0]; + const auto left_blocks = left_table.BlockCount(); - // Outer join block counts - if (left_table.found_match) { - left_outers = left_blocks; - } + auto &right_table = *gsink.tables[1]; + const auto right_blocks = right_table.BlockCount(); - if (right_table.found_match) { - right_outers = right_blocks; - } + // Outer join block counts + if (left_table.found_match) { + left_outers = left_blocks; + } - // Ready for action - initialized = true; + if (right_table.found_match) { + right_outers = right_blocks; } -public: - idx_t MaxThreads() override { - // We can't leverage any more threads than block pairs. - const auto &sink_state = (op.sink_state->Cast()); - return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); + // Ready for action + stage = IEJoinSourceStage::INNER; +} +bool IEJoinGlobalSourceState::TryPrepareNextStage() { + // Inside lock + switch (stage.load()) { + case IEJoinSourceStage::INNER: + if (completed >= pair_count) { + stage = IEJoinSourceStage::OUTER; + return true; + } + break; + case IEJoinSourceStage::OUTER: + if (next_left >= left_outers && next_right >= right_outers) { + stage = IEJoinSourceStage::DONE; + return true; + } + break; + default: + break; } - void GetNextPair(ExecutionContext &context, IEJoinLocalSourceState &lstate) { - using ChunkRange = IEJoinUnion::ChunkRange; - auto &left_table = *gsink.tables[0]; - auto &right_table = *gsink.tables[1]; + return false; +} - const auto left_blocks = left_table.BlockCount(); - const auto left_ranges = (left_blocks + left_per_thread - 1) / left_per_thread; +idx_t IEJoinGlobalSourceState::MaxThreads() { + // We can't leverage any more threads than block pairs. + const auto &sink_state = (op.sink_state->Cast()); + return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); +} - const auto right_blocks = right_table.BlockCount(); - const auto right_ranges = (right_blocks + right_per_thread - 1) / right_per_thread; +bool IEJoinGlobalSourceState::AssignTask(ExecutionContext &context, IEJoinLocalSourceState &lstate) { + auto guard = Lock(); - const auto pair_count = left_ranges * right_ranges; + using ChunkRange = IEJoinUnion::ChunkRange; + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; - // Regular block - const auto i = next_pair++; - if (i < pair_count) { + // Regular block + switch (stage.load()) { + case IEJoinSourceStage::INNER: + if (next_pair < pair_count) { + const auto i = next_pair++; const auto b1 = (i / right_ranges) * left_per_thread; const auto b2 = (i % right_ranges) * right_per_thread; @@ -1075,22 +1155,13 @@ class IEJoinGlobalSourceState : public GlobalSourceState { lstate.right_base = right_table.BlockStart(r_range.first); lstate.joiner = make_uniq(context, op, left_table, l_range, right_table, r_range); - return; - } - - // Outer joins - if (!left_outers && !right_outers) { - return; - } - - // Spin wait for regular blocks to finish(!) - while (completed < pair_count) { - TaskScheduler::GetScheduler(context.client).YieldThread(); + return true; } - + break; + case IEJoinSourceStage::OUTER: // Left outer blocks - const auto l = next_left++; - if (l < left_outers) { + if (next_left < left_outers) { + const auto l = next_left++; lstate.joiner = nullptr; lstate.left_block_index = l; lstate.left_base = left_table.BlockStart(l); @@ -1098,14 +1169,14 @@ class IEJoinGlobalSourceState : public GlobalSourceState { lstate.left_matches = left_table.found_match.get() + lstate.left_base; lstate.outer_idx = 0; lstate.outer_count = left_table.BlockSize(l); - return; + return true; } else { lstate.left_matches = nullptr; } - // Right outer block - const auto r = next_right++; - if (r < right_outers) { + // Right outer blocks + if (next_right < right_outers) { + const auto r = next_right++; lstate.joiner = nullptr; lstate.right_block_index = r; lstate.right_base = right_table.BlockStart(r); @@ -1113,61 +1184,34 @@ class IEJoinGlobalSourceState : public GlobalSourceState { lstate.right_matches = right_table.found_match.get() + lstate.right_base; lstate.outer_idx = 0; lstate.outer_count = right_table.BlockSize(r); - return; + return true; } else { lstate.right_matches = nullptr; } + break; + default: + break; } - void PairCompleted(ExecutionContext &context, IEJoinLocalSourceState &lstate) { - lstate.joiner.reset(); - ++completed; - GetNextPair(context, lstate); - } - - ProgressData GetProgress() const { - auto &left_table = *gsink.tables[0]; - auto &right_table = *gsink.tables[1]; - - const auto left_blocks = left_table.BlockCount(); - const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; + return false; +} - const auto count = pair_count + left_outers + right_outers; +ProgressData IEJoinGlobalSourceState::GetProgress() const { + const auto count = pair_count + left_outers + right_outers; - const auto l = MinValue(next_left.load(), left_outers.load()); - const auto r = MinValue(next_right.load(), right_outers.load()); - const auto returned = completed.load() + l + r; + const auto l = MinValue(next_left.load(), left_outers.load()); + const auto r = MinValue(next_right.load(), right_outers.load()); + const auto returned = completed.load() + l + r; - ProgressData res; - if (count) { - res.done = double(returned); - res.total = double(count); - } else { - res.SetInvalid(); - } - return res; + ProgressData res; + if (count) { + res.done = double(returned); + res.total = double(count); + } else { + res.SetInvalid(); } - - const PhysicalIEJoin &op; - IEJoinGlobalState &gsink; - - bool initialized = false; - - // Join queue state - const idx_t left_per_thread = 1024; - const idx_t right_per_thread = 1024; - atomic next_pair; - atomic completed; - - // Outer joins - atomic left_outers; - atomic next_left; - - atomic right_outers; - atomic next_right; -}; - + return res; +} unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const { auto &gsink = sink_state->Cast(); return make_uniq(*this, gsink); @@ -1175,7 +1219,8 @@ unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(context.client, *this); + auto &gsource = gstate.Cast(); + return make_uniq(context.client, gsource); } ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { @@ -1185,104 +1230,97 @@ ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceSta SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result, OperatorSourceInput &input) const { - auto &ie_sink = sink_state->Cast(); - auto &ie_gstate = input.global_state.Cast(); - auto &ie_lstate = input.local_state.Cast(); + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); - ie_gstate.Initialize(context.client); + gsource.Initialize(); - if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { - ie_gstate.GetNextPair(context, ie_lstate); + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gsource.stage != IEJoinSourceStage::DONE && result.size() == 0) { + if (!lsource.TaskFinished() || gsource.AssignTask(context, lsource)) { + lsource.ExecuteTask(context, result); + } else { + auto guard = gsource.Lock(); + if (gsource.TryPrepareNextStage() || gsource.stage == IEJoinSourceStage::DONE) { + gsource.UnblockTasks(guard); + } else { + return gsource.BlockSource(guard, input.interrupt_state); + } + } } + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} - // Process INNER results - while (ie_lstate.joiner) { - ResolveComplexJoin(context, result, ie_lstate); +void IEJoinLocalSourceState::ExecuteLeftTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); - if (result.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } + const auto left_cols = op.children[0].get().GetTypes().size(); + auto &chunk = unprojected; - ie_gstate.PairCompleted(context, ie_lstate); + const idx_t count = SelectOuterRows(left_matches); + if (!count) { + left_matches = nullptr; + return; } - // Process LEFT OUTER results - const auto left_cols = children[0].get().GetTypes().size(); - auto &chunk = ie_lstate.unprojected; - while (ie_lstate.left_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); - if (!count) { - ie_gstate.GetNextPair(context, ie_lstate); - continue; - } - auto &left_table = *ie_sink.tables[0]; - auto &lpayload = ie_lstate.lpayload; - auto &left_iterator = *ie_lstate.left_iterator; - auto &left_chunk_state = ie_lstate.left_chunk_state; - auto &left_block_index = ie_lstate.left_block_index; - auto &left_scan_state = *ie_lstate.left_scan_state; + auto &left_table = *ie_sink.tables[0]; - left_table.Repin(left_iterator); - SliceSortedPayload(lpayload, left_table, left_iterator, left_chunk_state, left_block_index, ie_lstate.true_sel, - count, left_scan_state); + left_table.Repin(*left_iterator); + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, true_sel, count, + *left_scan_state); - // Fill in NULLs to the right - chunk.Reset(); - for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { - if (col_idx < left_cols) { - chunk.data[col_idx].Reference(lpayload.data[col_idx]); - } else { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } + // Fill in NULLs to the right + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; } - // Process RIGHT OUTER results - while (ie_lstate.right_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); - if (!count) { - ie_gstate.GetNextPair(context, ie_lstate); - continue; - } + op.ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); +} - auto &right_table = *ie_sink.tables[1]; - auto &rsel = ie_lstate.true_sel; - auto &rpayload = ie_lstate.rpayload; - auto &right_iterator = *ie_lstate.right_iterator; - auto &right_chunk_state = ie_lstate.right_chunk_state; - auto &right_block_index = ie_lstate.right_block_index; - auto &right_scan_state = *ie_lstate.right_scan_state; - - right_table.Repin(right_iterator); - SliceSortedPayload(rpayload, right_table, right_iterator, right_chunk_state, right_block_index, rsel, count, - right_scan_state); - - // Fill in NULLs to the left - chunk.Reset(); - for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { - if (col_idx < left_cols) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } else { - chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); - } - } +void IEJoinLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + const auto left_cols = op.children[0].get().GetTypes().size(); - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); + auto &chunk = unprojected; - break; + const idx_t count = SelectOuterRows(right_matches); + if (!count) { + right_matches = nullptr; + return; } - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + auto &right_table = *ie_sink.tables[1]; + auto &rsel = true_sel; + + right_table.Repin(*right_iterator); + op.SliceSortedPayload(rpayload, right_table, *right_iterator, right_chunk_state, right_block_index, rsel, count, + *right_scan_state); + + // Fill in NULLs to the left + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } else { + chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); + } + } + + op.ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); } //===--------------------------------------------------------------------===// diff --git a/src/function/aggregate/distributive/minmax.cpp b/src/function/aggregate/distributive/minmax.cpp index ce5ef12afaec..1d6ccfe7919a 100644 --- a/src/function/aggregate/distributive/minmax.cpp +++ b/src/function/aggregate/distributive/minmax.cpp @@ -368,7 +368,7 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f // Bind function like arg_min/arg_max. function.arguments[0] = arguments[0]->return_type; function.return_type = arguments[0]->return_type; - return nullptr; + return make_uniq(); } } @@ -431,7 +431,6 @@ class MinMaxNState { template void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { - auto &val_vector = inputs[0]; auto &n_vector = inputs[1]; @@ -441,7 +440,7 @@ void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_ auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, true); n_vector.ToUnifiedFormat(count, n_format); state_vector.ToUnifiedFormat(count, state_format); diff --git a/src/function/cast/variant/to_json.cpp b/src/function/cast/variant/to_json.cpp index 85d9578671fd..482fa90c26c6 100644 --- a/src/function/cast/variant/to_json.cpp +++ b/src/function/cast/variant/to_json.cpp @@ -10,6 +10,7 @@ #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/variant_visitor.hpp" using namespace duckdb_yyjson; // NOLINT @@ -17,262 +18,211 @@ namespace duckdb { //! ------------ Variant -> JSON ------------ -yyjson_mut_val *VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const RecursiveUnifiedVectorFormat &source, - idx_t row, uint32_t values_idx) { - auto index = source.unified.sel->get_index(row); - if (!source.unified.validity.RowIsValid(index)) { - return yyjson_mut_null(doc); - } +namespace { + +struct JSONConverter { + using result_type = yyjson_mut_val *; - //! values - auto &values = UnifiedVariantVector::GetValues(source); - auto values_data = values.GetData(values); - - //! type_ids - auto &type_ids = UnifiedVariantVector::GetValuesTypeId(source); - auto type_ids_data = type_ids.GetData(type_ids); - - //! byte_offsets - auto &byte_offsets = UnifiedVariantVector::GetValuesByteOffset(source); - auto byte_offsets_data = byte_offsets.GetData(byte_offsets); - - //! children - auto &children = UnifiedVariantVector::GetChildren(source); - auto children_data = children.GetData(children); - - //! values_index - auto &values_index = UnifiedVariantVector::GetChildrenValuesIndex(source); - auto values_index_data = values_index.GetData(values_index); - - //! keys_index - auto &keys_index = UnifiedVariantVector::GetChildrenKeysIndex(source); - auto keys_index_data = keys_index.GetData(keys_index); - - //! keys - auto &keys = UnifiedVariantVector::GetKeys(source); - auto keys_data = keys.GetData(keys); - auto &keys_entry = UnifiedVariantVector::GetKeysEntry(source); - auto keys_entry_data = keys_entry.GetData(keys_entry); - - //! list entries - auto keys_list_entry = keys_data[keys.sel->get_index(row)]; - auto children_list_entry = children_data[children.sel->get_index(row)]; - auto values_list_entry = values_data[values.sel->get_index(row)]; - - //! The 'values' data of the value we're currently converting - values_idx += values_list_entry.offset; - auto type_id = static_cast(type_ids_data[type_ids.sel->get_index(values_idx)]); - auto byte_offset = byte_offsets_data[byte_offsets.sel->get_index(values_idx)]; - - //! The blob data of the Variant, accessed by byte offset retrieved above ^ - auto &value = UnifiedVariantVector::GetData(source); - auto value_data = value.GetData(value); - auto &blob = value_data[value.sel->get_index(row)]; - auto blob_data = const_data_ptr_cast(blob.GetData()); - - auto ptr = const_data_ptr_cast(blob_data + byte_offset); - switch (type_id) { - case VariantLogicalType::VARIANT_NULL: + static yyjson_mut_val *VisitNull(yyjson_mut_doc *doc) { return yyjson_mut_null(doc); - case VariantLogicalType::BOOL_TRUE: - return yyjson_mut_true(doc); - case VariantLogicalType::BOOL_FALSE: - return yyjson_mut_false(doc); - case VariantLogicalType::INT8: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); - } - case VariantLogicalType::INT16: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); } - case VariantLogicalType::INT32: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitBoolean(bool val, yyjson_mut_doc *doc) { + return val ? yyjson_mut_true(doc) : yyjson_mut_false(doc); } - case VariantLogicalType::INT64: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + template + static yyjson_mut_val *VisitInteger(T val, yyjson_mut_doc *doc) { + throw InternalException("JSONConverter::VisitInteger not implemented!"); } - case VariantLogicalType::INT128: { - auto val = Load(ptr); - auto val_str = val.ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitTime(dtime_t val, yyjson_mut_doc *doc) { + auto val_str = Time::ToString(val); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT8: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimeNanos(dtime_ns_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIME_NS(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT16: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimeTZ(dtime_tz_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMETZ(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT32: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimestampSec(timestamp_sec_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPSEC(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT64: { - auto val = Load(ptr); - return yyjson_mut_uint(doc, val); + + static yyjson_mut_val *VisitTimestampMs(timestamp_ms_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPMS(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT128: { - auto val = Load(ptr); - auto val_str = val.ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitTimestamp(timestamp_t val, yyjson_mut_doc *doc) { + auto val_str = Timestamp::ToString(val); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UUID: { - auto val = Value::UUID(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitTimestampNanos(timestamp_ns_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPNS(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::INTERVAL: { - auto val = Value::INTERVAL(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitTimestampTZ(timestamp_tz_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPTZ(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::FLOAT: { - auto val = Load(ptr); + + static yyjson_mut_val *VisitFloat(float val, yyjson_mut_doc *doc) { return yyjson_mut_real(doc, val); } - case VariantLogicalType::DOUBLE: { - auto val = Load(ptr); + + static yyjson_mut_val *VisitDouble(double val, yyjson_mut_doc *doc) { return yyjson_mut_real(doc, val); } - case VariantLogicalType::DATE: { - auto val = Load(ptr); - auto val_str = Date::ToString(date_t(val)); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::BLOB: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BLOB(const_data_ptr_cast(string_data), string_length).ToString(); + + static yyjson_mut_val *VisitUUID(hugeint_t val, yyjson_mut_doc *doc) { + auto val_str = Value::UUID(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::GEOMETRY: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::GEOMETRY(const_data_ptr_cast(string_data), string_length).ToString(); + + static yyjson_mut_val *VisitDate(date_t val, yyjson_mut_doc *doc) { + auto val_str = Date::ToString(val); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::VARCHAR: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return yyjson_mut_strncpy(doc, string_data, static_cast(string_length)); - } - case VariantLogicalType::DECIMAL: { - auto width = NumericCast(VarintDecode(ptr)); - auto scale = NumericCast(VarintDecode(ptr)); - string val_str; - if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else { - val_str = Decimal::ToString(Load(ptr), width, scale); - } - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::TIME_MICROS: { - auto val = Load(ptr); - auto val_str = Time::ToString(val); + static yyjson_mut_val *VisitInterval(interval_t val, yyjson_mut_doc *doc) { + auto val_str = Value::INTERVAL(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIME_MICROS_TZ: { - auto val = Value::TIMETZ(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitString(const string_t &str, yyjson_mut_doc *doc) { + return yyjson_mut_strncpy(doc, str.GetData(), str.GetSize()); } - case VariantLogicalType::TIMESTAMP_MICROS: { - auto val = Load(ptr); - auto val_str = Timestamp::ToString(val); + + static yyjson_mut_val *VisitBlob(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_SEC: { - auto val = Value::TIMESTAMPSEC(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitBignum(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BIGNUM(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_NANOS: { - auto val = Value::TIMESTAMPNS(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitGeometry(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_MILIS: { - auto val = Value::TIMESTAMPMS(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitBitstring(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_MICROS_TZ: { - auto val = Value::TIMESTAMPTZ(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + template + static yyjson_mut_val *VisitDecimal(T val, uint32_t width, uint32_t scale, yyjson_mut_doc *doc) { + string val_str; + if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else { + throw InternalException("Unhandled decimal type"); + } + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::ARRAY: { - auto count = VarintDecode(ptr); + + static yyjson_mut_val *VisitArray(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, yyjson_mut_doc *doc) { auto arr = yyjson_mut_arr(doc); - if (!count) { - return arr; - } - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto index = values_index.sel->get_index(children_list_entry.offset + child_index_start + i); - auto child_index = values_index_data[index]; -#ifdef DEBUG - auto key_id_index = keys_index.sel->get_index(children_list_entry.offset + child_index_start + i); - D_ASSERT(!keys_index.validity.RowIsValid(key_id_index)); -#endif - auto val = ConvertVariantToJSON(doc, source, row, child_index); - if (!val) { - return nullptr; - } - yyjson_mut_arr_add_val(arr, val); + auto array_items = VariantVisitor::VisitArrayItems(variant, row, nested_data, doc); + for (auto &entry : array_items) { + yyjson_mut_arr_add_val(arr, entry); } return arr; } - case VariantLogicalType::OBJECT: { - auto count = VarintDecode(ptr); + + static yyjson_mut_val *VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, yyjson_mut_doc *doc) { auto obj = yyjson_mut_obj(doc); - if (!count) { - return obj; - } - auto child_index_start = VarintDecode(ptr); - - for (idx_t i = 0; i < count; i++) { - auto children_index = values_index.sel->get_index(children_list_entry.offset + child_index_start + i); - auto child_value_idx = values_index_data[children_index]; - auto val = ConvertVariantToJSON(doc, source, row, child_value_idx); - if (!val) { - return nullptr; - } - auto keys_index_index = keys_index.sel->get_index(children_list_entry.offset + child_index_start + i); - D_ASSERT(keys_index.validity.RowIsValid(keys_index_index)); - auto child_key_id = keys_index_data[keys_index_index]; - auto &key = keys_entry_data[keys_entry.sel->get_index(keys_list_entry.offset + child_key_id)]; - yyjson_mut_obj_put(obj, yyjson_mut_strncpy(doc, key.GetData(), key.GetSize()), val); + auto object_items = VariantVisitor::VisitObjectItems(variant, row, nested_data, doc); + for (auto &entry : object_items) { + yyjson_mut_obj_put(obj, yyjson_mut_strncpy(doc, entry.first.c_str(), entry.first.size()), entry.second); } return obj; } - case VariantLogicalType::BITSTRING: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BIT(const_data_ptr_cast(string_data), string_length).ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::BIGNUM: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BIGNUM(const_data_ptr_cast(string_data), string_length).ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); - } - default: - throw InternalException("VariantLogicalType(%d) not handled", static_cast(type_id)); + + static yyjson_mut_val *VisitDefault(VariantLogicalType type_id, const_data_ptr_t, yyjson_mut_doc *) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); } +}; + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int8_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} - return nullptr; +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int16_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int32_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int64_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(hugeint_t val, yyjson_mut_doc *doc) { + auto val_str = val.ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint8_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint16_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint32_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint64_t val, yyjson_mut_doc *doc) { + return yyjson_mut_uint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uhugeint_t val, yyjson_mut_doc *doc) { + auto val_str = val.ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); +} + +} // namespace + +yyjson_mut_val *VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const RecursiveUnifiedVectorFormat &source, + idx_t row, uint32_t values_idx) { + UnifiedVariantVectorData variant(source); + return VariantVisitor::Visit(variant, row, values_idx, doc); } } // namespace duckdb diff --git a/src/function/cast/variant/to_variant.cpp b/src/function/cast/variant/to_variant.cpp index 813d2483598d..b724ade78a9a 100644 --- a/src/function/cast/variant/to_variant.cpp +++ b/src/function/cast/variant/to_variant.cpp @@ -84,39 +84,6 @@ static void InitializeVariants(DataChunk &offsets, Vector &result, SelectionVect selvec_size = keys_offset; } -static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, SelectionVector &sel, - idx_t sel_size) { - auto &keys = VariantVector::GetKeys(variant); - auto &keys_entry = ListVector::GetEntry(keys); - auto keys_entry_data = FlatVector::GetData(keys_entry); - - bool already_sorted = true; - - vector unsorted_to_sorted(dictionary.size()); - auto it = dictionary.begin(); - for (uint32_t sorted_idx = 0; sorted_idx < dictionary.size(); sorted_idx++) { - auto unsorted_idx = it->second; - if (unsorted_idx != sorted_idx) { - already_sorted = false; - } - unsorted_to_sorted[unsorted_idx] = sorted_idx; - D_ASSERT(sorted_idx < ListVector::GetListSize(keys)); - keys_entry_data[sorted_idx] = it->first; - auto size = static_cast(keys_entry_data[sorted_idx].GetSize()); - keys_entry_data[sorted_idx].SetSizeAndFinalize(size, size); - it++; - } - - if (!already_sorted) { - //! Adjust the selection vector to point to the right dictionary index - for (idx_t i = 0; i < sel_size; i++) { - auto &entry = sel[i]; - auto sorted_idx = unsorted_to_sorted[entry]; - entry = sorted_idx; - } - } -} - static bool GatherOffsetsAndSizes(ToVariantSourceData &source, ToVariantGlobalResultData &result, idx_t count) { InitializeOffsets(result.offsets, count); //! First pass - collect sizes/offsets @@ -166,7 +133,7 @@ static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParam } } - FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); //! Finalize the 'data' auto &blob = VariantVector::GetData(result); auto blob_data = FlatVector::GetData(blob); diff --git a/src/function/function_list.cpp b/src/function/function_list.cpp index d73467d3ac67..74dbb347b4d2 100644 --- a/src/function/function_list.cpp +++ b/src/function/function_list.cpp @@ -177,6 +177,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_ALIAS(UcaseFun), DUCKDB_SCALAR_FUNCTION(UpperFun), DUCKDB_SCALAR_FUNCTION_SET(VariantExtractFun), + DUCKDB_SCALAR_FUNCTION(VariantNormalizeFun), DUCKDB_SCALAR_FUNCTION(VariantTypeofFun), DUCKDB_SCALAR_FUNCTION_SET(WriteLogFun), DUCKDB_SCALAR_FUNCTION(ConcatOperatorFun), diff --git a/src/function/macro_function.cpp b/src/function/macro_function.cpp index 2f407c0256a2..66e36181b5be 100644 --- a/src/function/macro_function.cpp +++ b/src/function/macro_function.cpp @@ -47,7 +47,7 @@ MacroBindResult MacroFunction::BindMacroFunction( InsertionOrderPreservingMap> &named_arguments, idx_t depth) { ExpressionBinder expr_binder(binder, binder.context); - + expr_binder.lambda_bindings = binder.lambda_bindings; // Find argument types and separate positional and default arguments vector positional_arg_types; InsertionOrderPreservingMap named_arg_types; diff --git a/src/function/scalar/create_sort_key.cpp b/src/function/scalar/create_sort_key.cpp index 2f5463e3f3ca..d9127d359b1c 100644 --- a/src/function/scalar/create_sort_key.cpp +++ b/src/function/scalar/create_sort_key.cpp @@ -696,13 +696,15 @@ void PrepareSortData(Vector &result, idx_t size, SortKeyLengthInfo &key_lengths, } } -void FinalizeSortData(Vector &result, idx_t size) { +void FinalizeSortData(Vector &result, idx_t size, const SortKeyLengthInfo &key_lengths, + const unsafe_vector &offsets) { switch (result.GetType().id()) { case LogicalTypeId::BLOB: { auto result_data = FlatVector::GetData(result); // call Finalize on the result for (idx_t r = 0; r < size; r++) { - result_data[r].Finalize(); + result_data[r].SetSizeAndFinalize(NumericCast(offsets[r]), + key_lengths.variable_lengths[r] + key_lengths.constant_length); } break; } @@ -739,7 +741,7 @@ void CreateSortKeyInternal(vector> &sort_key_data, SortKeyConstructInfo info(modifiers[c], offsets, data_pointers.get()); ConstructSortKey(*sort_key_data[c], info); } - FinalizeSortData(result, row_count); + FinalizeSortData(result, row_count, key_lengths, offsets); } } // namespace diff --git a/src/function/scalar/operator/arithmetic.cpp b/src/function/scalar/operator/arithmetic.cpp index 82cd9b5b7051..1dde43871c4e 100644 --- a/src/function/scalar/operator/arithmetic.cpp +++ b/src/function/scalar/operator/arithmetic.cpp @@ -1220,7 +1220,7 @@ hugeint_t InterpolateOperator::Operation(const hugeint_t &lo, const double d, co template <> uhugeint_t InterpolateOperator::Operation(const uhugeint_t &lo, const double d, const uhugeint_t &hi) { - return Hugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); + return Uhugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); } static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT diff --git a/src/function/scalar/variant/CMakeLists.txt b/src/function/scalar/variant/CMakeLists.txt index 76336e4faf70..a29c46d67dfb 100644 --- a/src/function/scalar/variant/CMakeLists.txt +++ b/src/function/scalar/variant/CMakeLists.txt @@ -1,5 +1,5 @@ add_library_unity(duckdb_func_variant_main OBJECT variant_utils.cpp - variant_extract.cpp variant_typeof.cpp) + variant_extract.cpp variant_typeof.cpp variant_normalize.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ PARENT_SCOPE) diff --git a/src/function/scalar/variant/functions.json b/src/function/scalar/variant/functions.json index bcc46f17faf3..a92790571019 100644 --- a/src/function/scalar/variant/functions.json +++ b/src/function/scalar/variant/functions.json @@ -39,6 +39,16 @@ ], "type": "scalar_function_set" }, + { + "name": "variant_normalize", + "parameters": "input_variant", + "description": "Normalizes the `input_variant` to a canonical representation.", + "example": "variant_normalize({'b': [1,2,3], 'a': 42})::VARIANT)", + "categories": [ + "variant" + ], + "type": "scalar_function" + }, { "name": "variant_typeof", "parameters": "input_variant", @@ -49,4 +59,4 @@ ], "type": "scalar_function" } -] +] \ No newline at end of file diff --git a/src/function/scalar/variant/variant_extract.cpp b/src/function/scalar/variant/variant_extract.cpp index 2c6ff14cf2ee..76d5c84cd32b 100644 --- a/src/function/scalar/variant/variant_extract.cpp +++ b/src/function/scalar/variant/variant_extract.cpp @@ -12,6 +12,7 @@ struct BindData : public FunctionData { public: explicit BindData(const string &str); explicit BindData(uint32_t index); + BindData(const BindData &other) = default; public: unique_ptr Copy() const override; @@ -36,10 +37,7 @@ BindData::BindData(uint32_t index) : FunctionData() { } unique_ptr BindData::Copy() const { - if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { - return make_uniq(component.index); - } - return make_uniq(component.key); + return make_uniq(*this); } bool BindData::Equals(const FunctionData &other) const { diff --git a/src/function/scalar/variant/variant_normalize.cpp b/src/function/scalar/variant/variant_normalize.cpp new file mode 100644 index 000000000000..ef79e38e3708 --- /dev/null +++ b/src/function/scalar/variant/variant_normalize.cpp @@ -0,0 +1,311 @@ +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/function/scalar/variant_functions.hpp" +#include "duckdb/function/scalar/regexp.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" + +#include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_visitor.hpp" + +namespace duckdb { + +namespace { + +struct VariantNormalizerState { +public: + VariantNormalizerState(idx_t result_row, VariantVectorData &source, OrderedOwningStringMap &dictionary, + SelectionVector &keys_selvec) + : source(source), dictionary(dictionary), keys_selvec(keys_selvec), + keys_index_validity(source.keys_index_validity) { + auto keys_list_entry = source.keys_data[result_row]; + auto values_list_entry = source.values_data[result_row]; + auto children_list_entry = source.children_data[result_row]; + + keys_offset = keys_list_entry.offset; + children_offset = children_list_entry.offset; + + blob_data = data_ptr_cast(source.blob_data[result_row].GetDataWriteable()); + type_ids = source.type_ids_data + values_list_entry.offset; + byte_offsets = source.byte_offset_data + values_list_entry.offset; + values_indexes = source.values_index_data + children_list_entry.offset; + keys_indexes = source.keys_index_data + children_list_entry.offset; + } + +public: + data_ptr_t GetDestination() { + return blob_data + blob_size; + } + uint32_t GetOrCreateIndex(const string_t &key) { + auto unsorted_idx = dictionary.size(); + //! This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') + return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; + } + +public: + uint32_t keys_size = 0; + uint32_t children_size = 0; + uint32_t values_size = 0; + uint32_t blob_size = 0; + + VariantVectorData &source; + OrderedOwningStringMap &dictionary; + SelectionVector &keys_selvec; + + uint64_t keys_offset; + uint64_t children_offset; + ValidityMask &keys_index_validity; + + data_ptr_t blob_data; + uint8_t *type_ids; + uint32_t *byte_offsets; + uint32_t *values_indexes; + uint32_t *keys_indexes; +}; + +struct VariantNormalizer { + using result_type = void; + + static void VisitNull(VariantNormalizerState &state) { + return; + } + static void VisitBoolean(bool val, VariantNormalizerState &state) { + return; + } + + static void VisitMetadata(VariantLogicalType type_id, VariantNormalizerState &state) { + state.type_ids[state.values_size] = static_cast(type_id); + state.byte_offsets[state.values_size] = state.blob_size; + state.values_size++; + } + + template + static void VisitInteger(T val, VariantNormalizerState &state) { + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + static void VisitFloat(float val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitDouble(double val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitUUID(hugeint_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitDate(date_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitInterval(interval_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTime(dtime_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimeNanos(dtime_ns_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimeTZ(dtime_tz_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampSec(timestamp_sec_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampMs(timestamp_ms_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestamp(timestamp_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampNanos(timestamp_ns_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampTZ(timestamp_tz_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + + static void WriteStringInternal(const string_t &str, VariantNormalizerState &state) { + } + + static void VisitString(const string_t &str, VariantNormalizerState &state) { + auto length = str.GetSize(); + state.blob_size += VarintEncode(length, state.GetDestination()); + memcpy(state.GetDestination(), str.GetData(), length); + state.blob_size += length; + } + static void VisitBlob(const string_t &blob, VariantNormalizerState &state) { + return VisitString(blob, state); + } + static void VisitBignum(const string_t &bignum, VariantNormalizerState &state) { + return VisitString(bignum, state); + } + static void VisitGeometry(const string_t &geom, VariantNormalizerState &state) { + return VisitString(geom, state); + } + static void VisitBitstring(const string_t &bits, VariantNormalizerState &state) { + return VisitString(bits, state); + } + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, VariantNormalizerState &state) { + state.blob_size += VarintEncode(width, state.GetDestination()); + state.blob_size += VarintEncode(scale, state.GetDestination()); + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (!nested_data.child_count) { + return; + } + idx_t result_children_idx = state.children_size; + state.blob_size += VarintEncode(result_children_idx, state.GetDestination()); + state.children_size += nested_data.child_count; + + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto source_children_idx = nested_data.children_idx + i; + auto values_index = variant.GetValuesIndex(row, source_children_idx); + + //! Set the 'values_index' for the child, and set the 'keys_index' to NULL + state.values_indexes[result_children_idx] = state.values_size; + state.keys_index_validity.SetInvalid(state.children_offset + result_children_idx); + result_children_idx++; + + //! Visit the child value + VariantVisitor::Visit(variant, row, values_index, state); + } + } + + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (!nested_data.child_count) { + return; + } + uint32_t children_idx = state.children_size; + uint32_t keys_idx = state.keys_size; + state.blob_size += VarintEncode(children_idx, state.GetDestination()); + state.children_size += nested_data.child_count; + state.keys_size += nested_data.child_count; + + //! First iterate through all fields to populate the map of key -> field + map sorted_fields; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, nested_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + sorted_fields.emplace(key, i); + } + + //! Then visit the fields in sorted order + for (auto &entry : sorted_fields) { + auto source_children_idx = nested_data.children_idx + entry.second; + + //! Add the key of the field to the result + auto keys_index = variant.GetKeysIndex(row, source_children_idx); + auto &key = variant.GetKey(row, keys_index); + auto dict_index = state.GetOrCreateIndex(key); + state.keys_selvec.set_index(state.keys_offset + keys_idx, dict_index); + + //! Visit the child value + auto values_index = variant.GetValuesIndex(row, source_children_idx); + state.values_indexes[children_idx] = state.values_size; + state.keys_indexes[children_idx] = keys_idx; + children_idx++; + keys_idx++; + VariantVisitor::Visit(variant, row, values_index, state); + } + } + + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, VariantNormalizerState &state) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); + } +}; + +} // namespace + +static void VariantNormalizeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto count = input.size(); + + D_ASSERT(input.ColumnCount() == 1); + auto &variant_vec = input.data[0]; + D_ASSERT(variant_vec.GetType() == LogicalType::VARIANT()); + + //! Set up the access helper for the source VARIANT + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, count, source_format); + UnifiedVariantVectorData variant(source_format); + + //! Take the original sizes of the lists, the result will be similar size, never bigger + auto original_keys_size = ListVector::GetListSize(VariantVector::GetKeys(variant_vec)); + auto original_children_size = ListVector::GetListSize(VariantVector::GetChildren(variant_vec)); + auto original_values_size = ListVector::GetListSize(VariantVector::GetValues(variant_vec)); + + auto &keys = VariantVector::GetKeys(result); + auto &children = VariantVector::GetChildren(result); + auto &values = VariantVector::GetValues(result); + auto &data = VariantVector::GetData(result); + + ListVector::Reserve(keys, original_keys_size); + ListVector::SetListSize(keys, 0); + ListVector::Reserve(children, original_children_size); + ListVector::SetListSize(children, 0); + ListVector::Reserve(values, original_values_size); + ListVector::SetListSize(values, 0); + + //! Initialize the dictionary + auto &keys_entry = ListVector::GetEntry(keys); + OrderedOwningStringMap dictionary(StringVector::GetStringBuffer(keys_entry).GetStringAllocator()); + + VariantVectorData variant_data(result); + SelectionVector keys_selvec; + keys_selvec.Initialize(original_keys_size); + + for (idx_t i = 0; i < count; i++) { + if (!variant.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + //! Allocate for the new data, use the same size as source + auto &blob_data = variant_data.blob_data[i]; + auto original_data = variant.GetData(i); + blob_data = StringVector::EmptyString(data, original_data.GetSize()); + + auto &keys_list_entry = variant_data.keys_data[i]; + keys_list_entry.offset = ListVector::GetListSize(keys); + + auto &children_list_entry = variant_data.children_data[i]; + children_list_entry.offset = ListVector::GetListSize(children); + + auto &values_list_entry = variant_data.values_data[i]; + values_list_entry.offset = ListVector::GetListSize(values); + + //! Visit the source to populate the result + VariantNormalizerState visitor_state(i, variant_data, dictionary, keys_selvec); + VariantVisitor::Visit(variant, i, 0, visitor_state); + + blob_data.SetSizeAndFinalize(visitor_state.blob_size, original_data.GetSize()); + keys_list_entry.length = visitor_state.keys_size; + children_list_entry.length = visitor_state.children_size; + values_list_entry.length = visitor_state.values_size; + + ListVector::SetListSize(keys, ListVector::GetListSize(keys) + visitor_state.keys_size); + ListVector::SetListSize(children, ListVector::GetListSize(children) + visitor_state.children_size); + ListVector::SetListSize(values, ListVector::GetListSize(values) + visitor_state.values_size); + } + + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, ListVector::GetListSize(keys)); + keys_entry.Slice(keys_selvec, ListVector::GetListSize(keys)); + + if (input.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +ScalarFunction VariantNormalizeFun::GetFunction() { + auto variant_type = LogicalType::VARIANT(); + return ScalarFunction("variant_normalize", {variant_type}, variant_type, VariantNormalizeFunction); +} + +} // namespace duckdb diff --git a/src/function/scalar/variant/variant_utils.cpp b/src/function/scalar/variant/variant_utils.cpp index b9450188fcde..435ed40f8262 100644 --- a/src/function/scalar/variant/variant_utils.cpp +++ b/src/function/scalar/variant/variant_utils.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/types/variant_visitor.hpp" namespace duckdb { @@ -74,7 +75,6 @@ vector VariantUtils::GetObjectKeys(const UnifiedVariantVectorData &varia return object_keys; } -//! FIXME: this shouldn't return a "result", it should populate a validity mask instead. void VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, const VariantPathComponent &component, optional_ptr sel, SelectionVector &res, ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count) { @@ -163,133 +163,204 @@ VariantUtils::CollectNestedData(const UnifiedVariantVectorData &variant, Variant return VariantNestedDataCollectionResult(); } -Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, idx_t values_idx) { - if (!variant.RowIsValid(row)) { - return Value(LogicalTypeId::SQLNULL); +namespace { + +struct ValueConverter { + using result_type = Value; + + static Value VisitNull() { + return Value(LogicalType::SQLNULL); } - //! The 'values' data of the value we're currently converting - auto type_id = variant.GetTypeId(row, values_idx); - auto byte_offset = variant.GetByteOffset(row, values_idx); + static Value VisitBoolean(bool val) { + return Value::BOOLEAN(val); + } - //! The blob data of the Variant, accessed by byte offset retrieved above ^ - auto blob_data = const_data_ptr_cast(variant.GetData(row).GetData()); + template + static Value VisitInteger(T val) { + throw InternalException("ValueConverter::VisitInteger not implemented!"); + } - auto ptr = const_data_ptr_cast(blob_data + byte_offset); - switch (type_id) { - case VariantLogicalType::VARIANT_NULL: - return Value(LogicalType::SQLNULL); - case VariantLogicalType::BOOL_TRUE: - return Value::BOOLEAN(true); - case VariantLogicalType::BOOL_FALSE: - return Value::BOOLEAN(false); - case VariantLogicalType::INT8: - return Value::TINYINT(Load(ptr)); - case VariantLogicalType::INT16: - return Value::SMALLINT(Load(ptr)); - case VariantLogicalType::INT32: - return Value::INTEGER(Load(ptr)); - case VariantLogicalType::INT64: - return Value::BIGINT(Load(ptr)); - case VariantLogicalType::INT128: - return Value::HUGEINT(Load(ptr)); - case VariantLogicalType::UINT8: - return Value::UTINYINT(Load(ptr)); - case VariantLogicalType::UINT16: - return Value::USMALLINT(Load(ptr)); - case VariantLogicalType::UINT32: - return Value::UINTEGER(Load(ptr)); - case VariantLogicalType::UINT64: - return Value::UBIGINT(Load(ptr)); - case VariantLogicalType::UINT128: - return Value::UHUGEINT(Load(ptr)); - case VariantLogicalType::UUID: - return Value::UUID(Load(ptr)); - case VariantLogicalType::INTERVAL: - return Value::INTERVAL(Load(ptr)); - case VariantLogicalType::FLOAT: - return Value::FLOAT(Load(ptr)); - case VariantLogicalType::DOUBLE: - return Value::DOUBLE(Load(ptr)); - case VariantLogicalType::DATE: - return Value::DATE(date_t(Load(ptr))); - case VariantLogicalType::BLOB: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return Value::BLOB(const_data_ptr_cast(string_data), string_length); - } - case VariantLogicalType::VARCHAR: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return Value(string_t(string_data, string_length)); - } - case VariantLogicalType::DECIMAL: { - auto width = NumericCast(VarintDecode(ptr)); - auto scale = NumericCast(VarintDecode(ptr)); - - if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); + static Value VisitTime(dtime_t val) { + return Value::TIME(val); + } + + static Value VisitTimeNanos(dtime_ns_t val) { + return Value::TIME_NS(val); + } + + static Value VisitTimeTZ(dtime_tz_t val) { + return Value::TIMETZ(val); + } + + static Value VisitTimestampSec(timestamp_sec_t val) { + return Value::TIMESTAMPSEC(val); + } + + static Value VisitTimestampMs(timestamp_ms_t val) { + return Value::TIMESTAMPMS(val); + } + + static Value VisitTimestamp(timestamp_t val) { + return Value::TIMESTAMP(val); + } + + static Value VisitTimestampNanos(timestamp_ns_t val) { + return Value::TIMESTAMPNS(val); + } + + static Value VisitTimestampTZ(timestamp_tz_t val) { + return Value::TIMESTAMPTZ(val); + } + + static Value VisitFloat(float val) { + return Value::FLOAT(val); + } + static Value VisitDouble(double val) { + return Value::DOUBLE(val); + } + static Value VisitUUID(hugeint_t val) { + return Value::UUID(val); + } + static Value VisitDate(date_t val) { + return Value::DATE(val); + } + static Value VisitInterval(interval_t val) { + return Value::INTERVAL(val); + } + + static Value VisitString(const string_t &str) { + return Value(str); + } + static Value VisitBlob(const string_t &str) { + return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitBignum(const string_t &str) { + return Value::BIGNUM(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitGeometry(const string_t &str) { + return Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitBitstring(const string_t &str) { + return Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + + template + static Value VisitDecimal(T val, uint32_t width, uint32_t scale) { + if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); } else { - return Value::DECIMAL(Load(ptr), width, scale); + throw InternalException("Unhandled decimal type"); } } - case VariantLogicalType::TIME_MICROS: - return Value::TIME(Load(ptr)); - case VariantLogicalType::TIME_MICROS_TZ: - return Value::TIMETZ(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MICROS: - return Value::TIMESTAMP(Load(ptr)); - case VariantLogicalType::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(Load(ptr)); - case VariantLogicalType::TIMESTAMP_NANOS: - return Value::TIMESTAMPNS(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MILIS: - return Value::TIMESTAMPMS(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MICROS_TZ: - return Value::TIMESTAMPTZ(Load(ptr)); - case VariantLogicalType::ARRAY: { - auto count = VarintDecode(ptr); - vector array_items; - if (count) { - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto child_index = variant.GetValuesIndex(row, child_index_start + i); - array_items.emplace_back(ConvertVariantToValue(variant, row, child_index)); - } - } + + static Value VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data) { + auto array_items = VariantVisitor::VisitArrayItems(variant, row, nested_data); return Value::LIST(LogicalType::VARIANT(), std::move(array_items)); } - case VariantLogicalType::OBJECT: { - auto count = VarintDecode(ptr); - child_list_t object_children; - if (count) { - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto child_value_idx = variant.GetValuesIndex(row, child_index_start + i); - auto val = ConvertVariantToValue(variant, row, child_value_idx); - - auto child_key_id = variant.GetKeysIndex(row, child_index_start + i); - auto &key = variant.GetKey(row, child_key_id); - object_children.emplace_back(key.GetString(), std::move(val)); - } - } + static Value VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data) { + auto object_children = VariantVisitor::VisitObjectItems(variant, row, nested_data); return Value::STRUCT(std::move(object_children)); } - case VariantLogicalType::BITSTRING: { - auto string_length = VarintDecode(ptr); - return Value::BIT(ptr, string_length); + + static Value VisitDefault(VariantLogicalType type_id, const_data_ptr_t) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); } - case VariantLogicalType::BIGNUM: { - auto string_length = VarintDecode(ptr); - return Value::BIGNUM(ptr, string_length); +}; + +template <> +Value ValueConverter::VisitInteger(int8_t val) { + return Value::TINYINT(val); +} + +template <> +Value ValueConverter::VisitInteger(int16_t val) { + return Value::SMALLINT(val); +} + +template <> +Value ValueConverter::VisitInteger(int32_t val) { + return Value::INTEGER(val); +} + +template <> +Value ValueConverter::VisitInteger(int64_t val) { + return Value::BIGINT(val); +} + +template <> +Value ValueConverter::VisitInteger(hugeint_t val) { + return Value::HUGEINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint8_t val) { + return Value::UTINYINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint16_t val) { + return Value::USMALLINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint32_t val) { + return Value::UINTEGER(val); +} + +template <> +Value ValueConverter::VisitInteger(uint64_t val) { + return Value::UBIGINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uhugeint_t val) { + return Value::UHUGEINT(val); +} + +} // namespace + +Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx) { + return VariantVisitor::Visit(variant, row, values_idx); +} + +void VariantUtils::FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, + SelectionVector &sel, idx_t sel_size) { + auto &keys = VariantVector::GetKeys(variant); + auto &keys_entry = ListVector::GetEntry(keys); + auto keys_entry_data = FlatVector::GetData(keys_entry); + + bool already_sorted = true; + + vector unsorted_to_sorted(dictionary.size()); + auto it = dictionary.begin(); + for (uint32_t sorted_idx = 0; sorted_idx < dictionary.size(); sorted_idx++) { + auto unsorted_idx = it->second; + if (unsorted_idx != sorted_idx) { + already_sorted = false; + } + unsorted_to_sorted[unsorted_idx] = sorted_idx; + D_ASSERT(sorted_idx < ListVector::GetListSize(keys)); + keys_entry_data[sorted_idx] = it->first; + auto size = static_cast(keys_entry_data[sorted_idx].GetSize()); + keys_entry_data[sorted_idx].SetSizeAndFinalize(size, size); + it++; } - default: - throw InternalException("VariantLogicalType(%d) not handled", static_cast(type_id)); + + if (!already_sorted) { + //! Adjust the selection vector to point to the right dictionary index + for (idx_t i = 0; i < sel_size; i++) { + auto &entry = sel[i]; + auto sorted_idx = unsorted_to_sorted[entry]; + entry = sorted_idx; + } } } diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index 600e50f49436..1ffd6e7ee4d5 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -280,7 +280,31 @@ struct GlobalWriteCSVData : public GlobalFunctionData { return writer.FileSize(); } + unique_ptr GetLocalState(ClientContext &context, const idx_t flush_size) { + { + lock_guard guard(local_state_lock); + if (!local_states.empty()) { + auto result = std::move(local_states.back()); + local_states.pop_back(); + return result; + } + } + auto result = make_uniq(context, flush_size); + result->require_manual_flush = true; + return result; + } + + void StoreLocalState(unique_ptr lstate) { + lock_guard guard(local_state_lock); + lstate->Reset(); + local_states.push_back(std::move(lstate)); + } + CSVWriter writer; + +private: + mutex local_state_lock; + vector> local_states; }; static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { @@ -371,9 +395,7 @@ CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, b // Prepare Batch //===--------------------------------------------------------------------===// struct WriteCSVBatchData : public PreparedBatchData { - explicit WriteCSVBatchData(ClientContext &context, const idx_t flush_size) - : writer_local_state(make_uniq(context, flush_size)) { - writer_local_state->require_manual_flush = true; + explicit WriteCSVBatchData(unique_ptr writer_state) : writer_local_state(std::move(writer_state)) { } //! The thread-local buffer to write data into @@ -397,7 +419,8 @@ unique_ptr WriteCSVPrepareBatch(ClientContext &context, Funct auto &global_state = gstate.Cast(); // write CSV chunks to the batch data - auto batch = make_uniq(context, NextPowerOfTwo(collection->SizeInBytes())); + auto local_writer_state = global_state.GetLocalState(context, NextPowerOfTwo(collection->SizeInBytes())); + auto batch = make_uniq(std::move(local_writer_state)); for (auto &chunk : collection->Chunks()) { WriteCSVChunkInternal(global_state.writer, *batch->writer_local_state, cast_chunk, chunk, executor); } @@ -412,6 +435,7 @@ void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalF auto &csv_batch = batch.Cast(); auto &global_state = gstate.Cast(); global_state.writer.Flush(*csv_batch.writer_local_state); + global_state.StoreLocalState(std::move(csv_batch.writer_local_state)); } //===--------------------------------------------------------------------===// diff --git a/src/function/table/system/duckdb_functions.cpp b/src/function/table/system/duckdb_functions.cpp index b0c7656fe0bb..044a9c8e665f 100644 --- a/src/function/table/system/duckdb_functions.cpp +++ b/src/function/table/system/duckdb_functions.cpp @@ -17,6 +17,7 @@ #include "duckdb/main/client_data.hpp" namespace duckdb { +constexpr const char *AggregateFunctionCatalogEntry::Name; struct DuckDBFunctionsData : public GlobalTableFunctionState { DuckDBFunctionsData() : offset(0), offset_in_entry(0) { diff --git a/src/include/duckdb.h b/src/include/duckdb.h index ccf5ad5ac553..954f017be1b1 100644 --- a/src/include/duckdb.h +++ b/src/include/duckdb.h @@ -1314,10 +1314,10 @@ DUCKDB_C_API duckdb_result_type duckdb_result_return_type(duckdb_result result); // Safe Fetch Functions //===--------------------------------------------------------------------===// -// These functions will perform conversions if necessary. -// On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned. -// Note that these functions are slow since they perform bounds checking and conversion -// For fast access of values prefer using `duckdb_result_get_chunk` +// This function group is deprecated. +// To access the values in a result, use `duckdb_fetch_chunk` repeatedly. +// For each chunk, use the `duckdb_data_chunk` interface to access any columns and their values. + #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -1446,8 +1446,7 @@ DUCKDB_C_API duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_ DUCKDB_C_API duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains null -bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. * @return The text value at the specified location as a null-terminated string, or nullptr if the value cannot be converted. The result must be freed with `duckdb_free`. @@ -1457,16 +1456,12 @@ DUCKDB_C_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. -No support for nested types, and for other complex types. -The resulting field "string.data" must be freed with `duckdb_free.` - * @return The string value at the specified location. Attempts to cast the result value to string. */ DUCKDB_C_API duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains -null bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. * @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. @@ -1476,8 +1471,8 @@ The result must NOT be freed. DUCKDB_C_API char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains -null bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. + * @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. @@ -3336,23 +3331,26 @@ Returns the size of the child vector of the list. DUCKDB_C_API idx_t duckdb_list_vector_get_size(duckdb_vector vector); /*! -Sets the total size of the underlying child-vector of a list vector. +Sets the size of the underlying child-vector of a list vector. +Note that this does NOT reserve the memory in the child buffer, +and that it is possible to set a size exceeding the capacity. +To set the capacity, use `duckdb_list_vector_reserve`. * @param vector The list vector. * @param size The size of the child list. -* @return The duckdb state. Returns DuckDBError if the vector is nullptr. +* @return The duckdb state. Returns DuckDBError, if the vector is nullptr. */ DUCKDB_C_API duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size); /*! -Sets the total capacity of the underlying child-vector of a list. - -After calling this method, you must call `duckdb_vector_get_validity` and `duckdb_vector_get_data` to obtain current -data and validity pointers +Sets the capacity of the underlying child-vector of a list vector. +We increment to the next power of two, based on the required capacity. +Thus, the capacity might not match the size of the list (capacity >= size), +which is set via `duckdb_list_vector_set_size`. * @param vector The list vector. -* @param required_capacity the total capacity to reserve. -* @return The duckdb state. Returns DuckDBError if the vector is nullptr. +* @param required_capacity The child buffer capacity to reserve. +* @return The duckdb state. Returns DuckDBError, if the vector is nullptr. */ DUCKDB_C_API duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity); @@ -4675,6 +4673,14 @@ Check if the column at 'index' index of the table has a DEFAULT expression. */ DUCKDB_C_API duckdb_state duckdb_column_has_default(duckdb_table_description table_description, idx_t index, bool *out); +/*! +Return the number of columns of the described table. + +* @param table_description The table_description to query. +* @return The column count. +*/ +DUCKDB_C_API idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description); + /*! Obtain the column name at 'index'. The out result must be destroyed with `duckdb_free`. @@ -4685,6 +4691,17 @@ The out result must be destroyed with `duckdb_free`. */ DUCKDB_C_API char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index); +/*! +Obtain the column type at 'index'. +The return value must be destroyed with `duckdb_destroy_logical_type`. + +* @param table_description The table_description to query. +* @param index The index of the column to query. +* @return The column type. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, + idx_t index); + //===--------------------------------------------------------------------===// // Arrow Interface //===--------------------------------------------------------------------===// diff --git a/src/include/duckdb/common/assert.hpp b/src/include/duckdb/common/assert.hpp index dbf0744e74c2..4bf4b90e54b7 100644 --- a/src/include/duckdb/common/assert.hpp +++ b/src/include/duckdb/common/assert.hpp @@ -38,3 +38,6 @@ DUCKDB_API void DuckDBAssertInternal(bool condition, const char *condition_name, #define D_ASSERT_IS_ENABLED #endif + +//! Force assertion implementation, which always asserts whatever build type is used. +#define ALWAYS_ASSERT(condition) duckdb::DuckDBAssertInternal(bool(condition), #condition, __FILE__, __LINE__) diff --git a/src/include/duckdb/common/csv_writer.hpp b/src/include/duckdb/common/csv_writer.hpp index b2d0e066e1eb..188d1de7f12a 100644 --- a/src/include/duckdb/common/csv_writer.hpp +++ b/src/include/duckdb/common/csv_writer.hpp @@ -90,9 +90,6 @@ class CSVWriter { //! Closes the writer, optionally writes a postfix void Close(); - unique_ptr InitializeLocalWriteState(ClientContext &context, idx_t flush_size); - unique_ptr InitializeLocalWriteState(DatabaseInstance &db, idx_t flush_size); - vector> string_casts; idx_t BytesWritten(); diff --git a/src/include/duckdb/common/enums/metric_type.hpp b/src/include/duckdb/common/enums/metric_type.hpp index d0d82b4c22eb..110c9d53f030 100644 --- a/src/include/duckdb/common/enums/metric_type.hpp +++ b/src/include/duckdb/common/enums/metric_type.hpp @@ -20,32 +20,36 @@ namespace duckdb { enum class MetricsType : uint8_t { - QUERY_NAME, + ATTACH_LOAD_STORAGE_LATENCY, + ATTACH_REPLAY_WAL_LATENCY, BLOCKED_THREAD_TIME, + CHECKPOINT_LATENCY, CPU_TIME, - EXTRA_INFO, CUMULATIVE_CARDINALITY, - OPERATOR_TYPE, - OPERATOR_CARDINALITY, CUMULATIVE_ROWS_SCANNED, + EXTRA_INFO, + LATENCY, + OPERATOR_CARDINALITY, + OPERATOR_NAME, OPERATOR_ROWS_SCANNED, OPERATOR_TIMING, + OPERATOR_TYPE, + QUERY_NAME, RESULT_SET_SIZE, - LATENCY, ROWS_RETURNED, - OPERATOR_NAME, SYSTEM_PEAK_BUFFER_MEMORY, SYSTEM_PEAK_TEMP_DIR_SIZE, TOTAL_BYTES_READ, TOTAL_BYTES_WRITTEN, + WAITING_TO_ATTACH_LATENCY, ALL_OPTIMIZERS, CUMULATIVE_OPTIMIZER_TIMING, - PLANNER, - PLANNER_BINDING, PHYSICAL_PLANNER, PHYSICAL_PLANNER_COLUMN_BINDING, - PHYSICAL_PLANNER_RESOLVE_TYPES, PHYSICAL_PLANNER_CREATE_PLAN, + PHYSICAL_PLANNER_RESOLVE_TYPES, + PLANNER, + PLANNER_BINDING, OPTIMIZER_EXPRESSION_REWRITER, OPTIMIZER_FILTER_PULLUP, OPTIMIZER_FILTER_PUSHDOWN, diff --git a/src/include/duckdb/common/exception.hpp b/src/include/duckdb/common/exception.hpp index 480dd23853df..13fa9bc12bb6 100644 --- a/src/include/duckdb/common/exception.hpp +++ b/src/include/duckdb/common/exception.hpp @@ -94,15 +94,16 @@ enum class ExceptionType : uint8_t { class Exception : public std::runtime_error { public: DUCKDB_API Exception(ExceptionType exception_type, const string &message); - DUCKDB_API Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info); + + DUCKDB_API Exception(const unordered_map &extra_info, ExceptionType exception_type, + const string &message); public: DUCKDB_API static string ExceptionTypeToString(ExceptionType type); DUCKDB_API static ExceptionType StringToExceptionType(const string &type); template - static string ConstructMessage(const string &msg, ARGS... params) { + static string ConstructMessage(const string &msg, ARGS const &...params) { const std::size_t num_args = sizeof...(ARGS); if (num_args == 0) { return msg; @@ -122,8 +123,9 @@ class Exception : public std::runtime_error { //! Whether this exception type can occur during execution of a query DUCKDB_API static bool IsExecutionError(ExceptionType type); DUCKDB_API static string ToJSON(ExceptionType type, const string &message); - DUCKDB_API static string ToJSON(ExceptionType type, const string &message, - const unordered_map &extra_info); + + DUCKDB_API static string ToJSON(const unordered_map &extra_info, ExceptionType type, + const string &message); DUCKDB_API static bool InvalidatesTransaction(ExceptionType exception_type); DUCKDB_API static bool InvalidatesDatabase(ExceptionType exception_type); @@ -131,8 +133,8 @@ class Exception : public std::runtime_error { DUCKDB_API static string ConstructMessageRecursive(const string &msg, std::vector &values); template - static string ConstructMessageRecursive(const string &msg, std::vector &values, T param, - ARGS... params) { + static string ConstructMessageRecursive(const string &msg, std::vector &values, + const T ¶m, ARGS &&...params) { values.push_back(ExceptionFormatValue::CreateFormatValue(param)); return ConstructMessageRecursive(msg, values, params...); } @@ -155,8 +157,8 @@ class ConnectionException : public Exception { DUCKDB_API explicit ConnectionException(const string &msg); template - explicit ConnectionException(const string &msg, ARGS... params) - : ConnectionException(ConstructMessage(msg, params...)) { + explicit ConnectionException(const string &msg, ARGS &&...params) + : ConnectionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -165,8 +167,8 @@ class PermissionException : public Exception { DUCKDB_API explicit PermissionException(const string &msg); template - explicit PermissionException(const string &msg, ARGS... params) - : PermissionException(ConstructMessage(msg, params...)) { + explicit PermissionException(const string &msg, ARGS &&...params) + : PermissionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -175,8 +177,8 @@ class OutOfRangeException : public Exception { DUCKDB_API explicit OutOfRangeException(const string &msg); template - explicit OutOfRangeException(const string &msg, ARGS... params) - : OutOfRangeException(ConstructMessage(msg, params...)) { + explicit OutOfRangeException(const string &msg, ARGS &&...params) + : OutOfRangeException(ConstructMessage(msg, std::forward(params)...)) { } DUCKDB_API OutOfRangeException(const int64_t value, const PhysicalType orig_type, const PhysicalType new_type); DUCKDB_API OutOfRangeException(const hugeint_t value, const PhysicalType orig_type, const PhysicalType new_type); @@ -189,8 +191,8 @@ class OutOfMemoryException : public Exception { DUCKDB_API explicit OutOfMemoryException(const string &msg); template - explicit OutOfMemoryException(const string &msg, ARGS... params) - : OutOfMemoryException(ConstructMessage(msg, params...)) { + explicit OutOfMemoryException(const string &msg, ARGS &&...params) + : OutOfMemoryException(ConstructMessage(msg, std::forward(params)...)) { } private: @@ -202,7 +204,8 @@ class SyntaxException : public Exception { DUCKDB_API explicit SyntaxException(const string &msg); template - explicit SyntaxException(const string &msg, ARGS... params) : SyntaxException(ConstructMessage(msg, params...)) { + explicit SyntaxException(const string &msg, ARGS &&...params) + : SyntaxException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -211,8 +214,8 @@ class ConstraintException : public Exception { DUCKDB_API explicit ConstraintException(const string &msg); template - explicit ConstraintException(const string &msg, ARGS... params) - : ConstraintException(ConstructMessage(msg, params...)) { + explicit ConstraintException(const string &msg, ARGS &&...params) + : ConstraintException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -221,25 +224,27 @@ class DependencyException : public Exception { DUCKDB_API explicit DependencyException(const string &msg); template - explicit DependencyException(const string &msg, ARGS... params) - : DependencyException(ConstructMessage(msg, params...)) { + explicit DependencyException(const string &msg, ARGS &&...params) + : DependencyException(ConstructMessage(msg, std::forward(params)...)) { } }; class IOException : public Exception { public: DUCKDB_API explicit IOException(const string &msg); - DUCKDB_API explicit IOException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit IOException(const unordered_map &extra_info, const string &msg); explicit IOException(ExceptionType exception_type, const string &msg) : Exception(exception_type, msg) { } template - explicit IOException(const string &msg, ARGS... params) : IOException(ConstructMessage(msg, params...)) { + explicit IOException(const string &msg, ARGS &&...params) + : IOException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit IOException(const string &msg, const unordered_map &extra_info, ARGS... params) - : IOException(ConstructMessage(msg, params...), extra_info) { + explicit IOException(const unordered_map &extra_info, const string &msg, ARGS &&...params) + : IOException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; @@ -248,8 +253,8 @@ class MissingExtensionException : public Exception { DUCKDB_API explicit MissingExtensionException(const string &msg); template - explicit MissingExtensionException(const string &msg, ARGS... params) - : MissingExtensionException(ConstructMessage(msg, params...)) { + explicit MissingExtensionException(const string &msg, ARGS &&...params) + : MissingExtensionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -258,8 +263,8 @@ class NotImplementedException : public Exception { DUCKDB_API explicit NotImplementedException(const string &msg); template - explicit NotImplementedException(const string &msg, ARGS... params) - : NotImplementedException(ConstructMessage(msg, params...)) { + explicit NotImplementedException(const string &msg, ARGS &&...params) + : NotImplementedException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -273,8 +278,8 @@ class SerializationException : public Exception { DUCKDB_API explicit SerializationException(const string &msg); template - explicit SerializationException(const string &msg, ARGS... params) - : SerializationException(ConstructMessage(msg, params...)) { + explicit SerializationException(const string &msg, ARGS &&...params) + : SerializationException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -283,8 +288,8 @@ class SequenceException : public Exception { DUCKDB_API explicit SequenceException(const string &msg); template - explicit SequenceException(const string &msg, ARGS... params) - : SequenceException(ConstructMessage(msg, params...)) { + explicit SequenceException(const string &msg, ARGS &&...params) + : SequenceException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -298,14 +303,15 @@ class FatalException : public Exception { explicit FatalException(const string &msg) : FatalException(ExceptionType::FATAL, msg) { } template - explicit FatalException(const string &msg, ARGS... params) : FatalException(ConstructMessage(msg, params...)) { + explicit FatalException(const string &msg, ARGS &&...params) + : FatalException(ConstructMessage(msg, std::forward(params)...)) { } protected: DUCKDB_API explicit FatalException(ExceptionType type, const string &msg); template - explicit FatalException(ExceptionType type, const string &msg, ARGS... params) - : FatalException(type, ConstructMessage(msg, params...)) { + explicit FatalException(ExceptionType type, const string &msg, ARGS &&...params) + : FatalException(type, ConstructMessage(msg, std::forward(params)...)) { } }; @@ -314,23 +320,25 @@ class InternalException : public Exception { DUCKDB_API explicit InternalException(const string &msg); template - explicit InternalException(const string &msg, ARGS... params) - : InternalException(ConstructMessage(msg, params...)) { + explicit InternalException(const string &msg, ARGS &&...params) + : InternalException(ConstructMessage(msg, std::forward(params)...)) { } }; class InvalidInputException : public Exception { public: DUCKDB_API explicit InvalidInputException(const string &msg); - DUCKDB_API explicit InvalidInputException(const string &msg, const unordered_map &extra_info); + DUCKDB_API explicit InvalidInputException(const unordered_map &extra_info, const string &msg); template - explicit InvalidInputException(const string &msg, ARGS... params) - : InvalidInputException(ConstructMessage(msg, params...)) { + explicit InvalidInputException(const string &msg, ARGS &&...params) + : InvalidInputException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit InvalidInputException(const Expression &expr, const string &msg, ARGS... params) - : InvalidInputException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit InvalidInputException(const Expression &expr, const string &msg, ARGS &&...params) + : InvalidInputException(Exception::InitializeExtraInfo(expr), + ConstructMessage(msg, std::forward(params)...)) { } }; @@ -339,24 +347,26 @@ class ExecutorException : public Exception { DUCKDB_API explicit ExecutorException(const string &msg); template - explicit ExecutorException(const string &msg, ARGS... params) - : ExecutorException(ConstructMessage(msg, params...)) { + explicit ExecutorException(const string &msg, ARGS &&...params) + : ExecutorException(ConstructMessage(msg, std::forward(params)...)) { } }; class InvalidConfigurationException : public Exception { public: DUCKDB_API explicit InvalidConfigurationException(const string &msg); - DUCKDB_API explicit InvalidConfigurationException(const string &msg, - const unordered_map &extra_info); + + DUCKDB_API explicit InvalidConfigurationException(const unordered_map &extra_info, + const string &msg); template - explicit InvalidConfigurationException(const string &msg, ARGS... params) - : InvalidConfigurationException(ConstructMessage(msg, params...)) { + explicit InvalidConfigurationException(const string &msg, ARGS &&...params) + : InvalidConfigurationException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit InvalidConfigurationException(const Expression &expr, const string &msg, ARGS... params) - : InvalidConfigurationException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit InvalidConfigurationException(const Expression &expr, const string &msg, ARGS &&...params) + : InvalidConfigurationException(ConstructMessage(msg, std::forward(params)...), + Exception::InitializeExtraInfo(expr)) { } }; @@ -381,8 +391,8 @@ class ParameterNotAllowedException : public Exception { DUCKDB_API explicit ParameterNotAllowedException(const string &msg); template - explicit ParameterNotAllowedException(const string &msg, ARGS... params) - : ParameterNotAllowedException(ConstructMessage(msg, params...)) { + explicit ParameterNotAllowedException(const string &msg, ARGS &&...params) + : ParameterNotAllowedException(ConstructMessage(msg, std::forward(params)...)) { } }; diff --git a/src/include/duckdb/common/exception/binder_exception.hpp b/src/include/duckdb/common/exception/binder_exception.hpp index 2590cb094a7f..fd7158f87336 100644 --- a/src/include/duckdb/common/exception/binder_exception.hpp +++ b/src/include/duckdb/common/exception/binder_exception.hpp @@ -15,31 +15,39 @@ namespace duckdb { class BinderException : public Exception { public: - DUCKDB_API explicit BinderException(const string &msg, const unordered_map &extra_info); DUCKDB_API explicit BinderException(const string &msg); + DUCKDB_API explicit BinderException(const unordered_map &extra_info, const string &msg); + template - explicit BinderException(const string &msg, ARGS... params) : BinderException(ConstructMessage(msg, params...)) { + explicit BinderException(const string &msg, ARGS &&...params) + : BinderException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(const TableRef &ref, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(ref)) { + explicit BinderException(const TableRef &ref, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(ref), ConstructMessage(msg, std::forward(params)...)) { } template - explicit BinderException(const ParsedExpression &expr, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit BinderException(const ParsedExpression &expr, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(const Expression &expr, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit BinderException(const Expression &expr, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(QueryErrorContext error_context, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_context)) { + explicit BinderException(QueryErrorContext error_context, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(error_context), + ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(optional_idx error_location, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_location)) { + explicit BinderException(optional_idx error_location, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(error_location), + ConstructMessage(msg, std::forward(params)...)) { } static BinderException ColumnNotFound(const string &name, const vector &similar_bindings, diff --git a/src/include/duckdb/common/exception/catalog_exception.hpp b/src/include/duckdb/common/exception/catalog_exception.hpp index 498fafd19b73..1095531d00f6 100644 --- a/src/include/duckdb/common/exception/catalog_exception.hpp +++ b/src/include/duckdb/common/exception/catalog_exception.hpp @@ -19,14 +19,18 @@ struct EntryLookupInfo; class CatalogException : public Exception { public: DUCKDB_API explicit CatalogException(const string &msg); - DUCKDB_API explicit CatalogException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit CatalogException(const unordered_map &extra_info, const string &msg); template - explicit CatalogException(const string &msg, ARGS... params) : CatalogException(ConstructMessage(msg, params...)) { + explicit CatalogException(const string &msg, ARGS &&...params) + : CatalogException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit CatalogException(QueryErrorContext error_context, const string &msg, ARGS... params) - : CatalogException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_context)) { + explicit CatalogException(QueryErrorContext error_context, const string &msg, ARGS &&...params) + : CatalogException(Exception::InitializeExtraInfo(error_context), + ConstructMessage(msg, std::forward(params)...)) { } static CatalogException MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion); diff --git a/src/include/duckdb/common/exception/conversion_exception.hpp b/src/include/duckdb/common/exception/conversion_exception.hpp index 5330f46e66b6..9252d07907c9 100644 --- a/src/include/duckdb/common/exception/conversion_exception.hpp +++ b/src/include/duckdb/common/exception/conversion_exception.hpp @@ -12,22 +12,24 @@ #include "duckdb/common/optional_idx.hpp" namespace duckdb { - class ConversionException : public Exception { public: DUCKDB_API explicit ConversionException(const string &msg); + DUCKDB_API explicit ConversionException(optional_idx error_location, const string &msg); + DUCKDB_API ConversionException(const PhysicalType orig_type, const PhysicalType new_type); + DUCKDB_API ConversionException(const LogicalType &orig_type, const LogicalType &new_type); template - explicit ConversionException(const string &msg, ARGS... params) - : ConversionException(ConstructMessage(msg, params...)) { + explicit ConversionException(const string &msg, ARGS &&...params) + : ConversionException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit ConversionException(optional_idx error_location, const string &msg, ARGS... params) - : ConversionException(error_location, ConstructMessage(msg, params...)) { + explicit ConversionException(optional_idx error_location, const string &msg, ARGS &&...params) + : ConversionException(error_location, ConstructMessage(msg, std::forward(params)...)) { } }; - } // namespace duckdb diff --git a/src/include/duckdb/common/exception/http_exception.hpp b/src/include/duckdb/common/exception/http_exception.hpp index aff00d23de68..b0d0e9c2d70d 100644 --- a/src/include/duckdb/common/exception/http_exception.hpp +++ b/src/include/duckdb/common/exception/http_exception.hpp @@ -24,9 +24,9 @@ class HTTPException : public Exception { } template ::status = 0, typename... ARGS> - explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + explicit HTTPException(RESPONSE &response, const string &msg, ARGS &&...params) : HTTPException(static_cast(response.status), response.body, response.headers, response.reason, msg, - params...) { + std::forward(params)...) { } template @@ -35,16 +35,16 @@ class HTTPException : public Exception { }; template ::code = 0, typename... ARGS> - explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + explicit HTTPException(RESPONSE &response, const string &msg, ARGS &&...params) : HTTPException(static_cast(response.code), response.body, response.headers, response.error, msg, - params...) { + std::forward(params)...) { } template explicit HTTPException(int status_code, const string &response_body, const HEADERS &headers, const string &reason, - const string &msg, ARGS... params) - : Exception(ExceptionType::HTTP, ConstructMessage(msg, params...), - HTTPExtraInfo(status_code, response_body, headers, reason)) { + const string &msg, ARGS &&...params) + : Exception(HTTPExtraInfo(status_code, response_body, headers, reason), ExceptionType::HTTP, + ConstructMessage(msg, std::forward(params)...)) { } template diff --git a/src/include/duckdb/common/exception/parser_exception.hpp b/src/include/duckdb/common/exception/parser_exception.hpp index 363a34457901..26ce6c585cb6 100644 --- a/src/include/duckdb/common/exception/parser_exception.hpp +++ b/src/include/duckdb/common/exception/parser_exception.hpp @@ -17,18 +17,21 @@ namespace duckdb { class ParserException : public Exception { public: DUCKDB_API explicit ParserException(const string &msg); - DUCKDB_API explicit ParserException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit ParserException(const unordered_map &extra_info, const string &msg); template - explicit ParserException(const string &msg, ARGS... params) : ParserException(ConstructMessage(msg, params...)) { + explicit ParserException(const string &msg, ARGS &&...params) + : ParserException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit ParserException(optional_idx error_location, const string &msg, ARGS... params) - : ParserException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_location)) { + explicit ParserException(optional_idx error_location, const string &msg, ARGS &&...params) + : ParserException(Exception::InitializeExtraInfo(error_location), + ConstructMessage(msg, std::forward(params)...)) { } template - explicit ParserException(const ParsedExpression &expr, const string &msg, ARGS... params) - : ParserException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit ParserException(const ParsedExpression &expr, const string &msg, ARGS &&...params) + : ParserException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } static ParserException SyntaxError(const string &query, const string &error_message, optional_idx error_location); diff --git a/src/include/duckdb/common/exception/transaction_exception.hpp b/src/include/duckdb/common/exception/transaction_exception.hpp index f0164df696f2..5ca0be62bb97 100644 --- a/src/include/duckdb/common/exception/transaction_exception.hpp +++ b/src/include/duckdb/common/exception/transaction_exception.hpp @@ -11,15 +11,13 @@ #include "duckdb/common/exception.hpp" namespace duckdb { - class TransactionException : public Exception { public: DUCKDB_API explicit TransactionException(const string &msg); template - explicit TransactionException(const string &msg, ARGS... params) - : TransactionException(ConstructMessage(msg, params...)) { + explicit TransactionException(const string &msg, ARGS &&...params) + : TransactionException(ConstructMessage(msg, std::forward(params)...)) { } }; - } // namespace duckdb diff --git a/src/include/duckdb/common/exception_format_value.hpp b/src/include/duckdb/common/exception_format_value.hpp index 3693db54ceca..7beeead6e733 100644 --- a/src/include/duckdb/common/exception_format_value.hpp +++ b/src/include/duckdb/common/exception_format_value.hpp @@ -49,13 +49,13 @@ enum class ExceptionFormatValueType : uint8_t { }; struct ExceptionFormatValue { - DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT - DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT - DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT - DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT - DUCKDB_API ExceptionFormatValue(String str_val); // NOLINT - DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT - DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT + DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT + DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT + DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT + DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(const String &str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT + DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT ExceptionFormatValueType type; @@ -65,37 +65,37 @@ struct ExceptionFormatValue { public: template - static ExceptionFormatValue CreateFormatValue(T value) { + static ExceptionFormatValue CreateFormatValue(const T &value) { return int64_t(value); } static string Format(const string &msg, std::vector &values); }; template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const PhysicalType &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLString value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLIdentifier value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLIdentifier &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(LogicalType value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const LogicalType &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const float &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const double &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(String value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *const &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const idx_t &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const hugeint_t &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const uhugeint_t &value); } // namespace duckdb diff --git a/src/include/duckdb/common/hugeint.hpp b/src/include/duckdb/common/hugeint.hpp index c9b54bd95053..acdc4fb4b8a2 100644 --- a/src/include/duckdb/common/hugeint.hpp +++ b/src/include/duckdb/common/hugeint.hpp @@ -76,7 +76,7 @@ struct hugeint_t { // NOLINT: use numeric casing DUCKDB_API explicit operator int16_t() const; DUCKDB_API explicit operator int32_t() const; DUCKDB_API explicit operator int64_t() const; - DUCKDB_API operator uhugeint_t() const; // NOLINT: Allow implicit conversion from `hugeint_t` + DUCKDB_API explicit operator uhugeint_t() const; }; } // namespace duckdb diff --git a/src/include/duckdb/common/multi_file/multi_file_data.hpp b/src/include/duckdb/common/multi_file/multi_file_data.hpp index fd6380a7e31a..523084d6ee0a 100644 --- a/src/include/duckdb/common/multi_file/multi_file_data.hpp +++ b/src/include/duckdb/common/multi_file/multi_file_data.hpp @@ -139,7 +139,7 @@ struct MultiFileLocalColumnId { } public: - operator idx_t() { // NOLINT: allow implicit conversion + operator idx_t() const { // NOLINT: allow implicit conversion return column_id; } idx_t GetId() const { @@ -170,7 +170,7 @@ struct MultiFileLocalIndex { } public: - operator idx_t() { // NOLINT: allow implicit conversion + operator idx_t() const { // NOLINT: allow implicit conversion return index; } idx_t GetIndex() const { diff --git a/src/include/duckdb/common/operator/comparison_operators.hpp b/src/include/duckdb/common/operator/comparison_operators.hpp index f1a6f6eb33c5..a847e217bbfb 100644 --- a/src/include/duckdb/common/operator/comparison_operators.hpp +++ b/src/include/duckdb/common/operator/comparison_operators.hpp @@ -210,15 +210,4 @@ inline bool GreaterThan::Operation(const interval_t &left, const interval_t &rig return Interval::GreaterThan(left, right); } -//===--------------------------------------------------------------------===// -// Specialized Hugeint Comparison Operators -//===--------------------------------------------------------------------===// -template <> -inline bool Equals::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::Equals(left, right); -} -template <> -inline bool GreaterThan::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::GreaterThan(left, right); -} } // namespace duckdb diff --git a/src/include/duckdb/common/profiler.hpp b/src/include/duckdb/common/profiler.hpp index 5fb65337af1a..3a5cb402edb1 100644 --- a/src/include/duckdb/common/profiler.hpp +++ b/src/include/duckdb/common/profiler.hpp @@ -13,36 +13,52 @@ namespace duckdb { -//! The profiler can be used to measure elapsed time +//! Profiler class to measure the elapsed time. template class BaseProfiler { public: - //! Starts the timer + //! Start the timer. void Start() { finished = false; + ran = true; start = Tick(); } - //! Finishes timing + //! End the timer. void End() { end = Tick(); finished = true; } + //! Reset the timer. + void Reset() { + finished = false; + ran = false; + } - //! Returns the elapsed time in seconds. If End() has been called, returns - //! the total elapsed time. Otherwise returns how far along the timer is - //! right now. + //! Returns the elapsed time in seconds. + //! If ran is false, it returns 0. + //! If End() has been called, it returns the total elapsed time, + //! otherwise, returns how far along the timer is right now. double Elapsed() const { + if (!ran) { + return 0; + } auto measured_end = finished ? end : Tick(); return std::chrono::duration_cast>(measured_end - start).count(); } private: + //! Current time point. time_point Tick() const { return T::now(); } + //! Start time point. time_point start; + //! End time point. time_point end; + //! True, if end End() been called. bool finished = false; + //! True, if the timer was ran. + bool ran = false; }; using Profiler = BaseProfiler; diff --git a/src/include/duckdb/common/row_operations/row_operations.hpp b/src/include/duckdb/common/row_operations/row_operations.hpp index 557e9cd5b10e..ee1d11afbb32 100644 --- a/src/include/duckdb/common/row_operations/row_operations.hpp +++ b/src/include/duckdb/common/row_operations/row_operations.hpp @@ -25,22 +25,6 @@ struct SelectionVector; class StringHeap; struct UnifiedVectorFormat; -// The NestedValidity class help to set/get the validity from inside nested vectors -class NestedValidity { - data_ptr_t list_validity_location; - data_ptr_t *struct_validity_locations; - idx_t entry_idx; - idx_t idx_in_entry; - idx_t list_validity_offset; - -public: - explicit NestedValidity(data_ptr_t validitymask_location); - NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index); - void SetInvalid(idx_t idx); - bool IsValid(idx_t idx); - void OffsetListBy(idx_t offset); -}; - struct RowOperationsState { explicit RowOperationsState(ArenaAllocator &allocator) : allocator(allocator) { } @@ -49,7 +33,7 @@ struct RowOperationsState { unique_ptr addresses; // Re-usable vector for row_aggregate.cpp }; -// RowOperations contains a set of operations that operate on data using a RowLayout +// RowOperations contains a set of operations that operate on data using a TupleDataLayout struct RowOperations { //===--------------------------------------------------------------------===// // Aggregation Operators @@ -70,66 +54,6 @@ struct RowOperations { //! finalize - unaligned addresses, updated static void FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, DataChunk &result, idx_t aggr_idx); - - //===--------------------------------------------------------------------===// - // Read/Write Operators - //===--------------------------------------------------------------------===// - //! Scatter group data to the rows. Initialises the ValidityMask. - static void Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count); - //! Gather a single column. - //! If heap_ptr is not null, then the data is assumed to contain swizzled pointers, - //! which will be unswizzled in memory. - static void Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size = 0, - data_ptr_t heap_ptr = nullptr); - - //===--------------------------------------------------------------------===// - // Heap Operators - //===--------------------------------------------------------------------===// - //! Compute the entry sizes of a vector with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset = 0); - //! Compute the entry sizes of vector data with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset = 0); - //! Scatter vector with variable size type to the heap. - static void HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset = 0); - //! Scatter vector data with variable size type to the heap. - static void HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset = 0); - //! Gather a single column with variable size type from the heap. - static void HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t key_locations[], - optional_ptr parent_validity); - - //===--------------------------------------------------------------------===// - // Sorting Operators - //===--------------------------------------------------------------------===// - //! Scatter vector data to the rows in radix-sortable format. - static void RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t key_locations[], bool desc, bool has_null, bool nulls_first, idx_t prefix_len, - idx_t width, idx_t offset = 0); - - //===--------------------------------------------------------------------===// - // Out-of-Core Operators - //===--------------------------------------------------------------------===// - //! Swizzles blob pointers to offset within heap row - static void SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count); - //! Swizzles the base pointer of each row to offset within heap block - static void SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset = 0); - //! Copies 'count' heap rows that are pointed to by the rows at 'row_ptr' to 'heap_ptr' and swizzles the pointers - static void CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count); - - //! Unswizzles the base offset within heap block the rows to pointers - static void UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); - //! Unswizzles all offsets back to pointers - static void UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); }; } // namespace duckdb diff --git a/src/include/duckdb/common/serializer/varint.hpp b/src/include/duckdb/common/serializer/varint.hpp index 8d0316a32e29..8cccd6f56c0b 100644 --- a/src/include/duckdb/common/serializer/varint.hpp +++ b/src/include/duckdb/common/serializer/varint.hpp @@ -35,7 +35,8 @@ uint8_t GetVarintSize(T val) { } template -void VarintEncode(T val, data_ptr_t ptr) { +idx_t VarintEncode(T val, data_ptr_t ptr) { + idx_t size = 0; do { uint8_t byte = val & 127; val >>= 7; @@ -44,11 +45,14 @@ void VarintEncode(T val, data_ptr_t ptr) { } *ptr = byte; ptr++; + size++; } while (val != 0); + return size; } template -void VarintEncode(T val, MemoryStream &ser) { +idx_t VarintEncode(T val, MemoryStream &ser) { + idx_t size = 0; do { uint8_t byte = val & 127; val >>= 7; @@ -56,7 +60,9 @@ void VarintEncode(T val, MemoryStream &ser) { byte |= 128; } ser.WriteData(&byte, sizeof(uint8_t)); + size++; } while (val != 0); + return size; } } // namespace duckdb diff --git a/src/include/duckdb/common/sort/comparators.hpp b/src/include/duckdb/common/sort/comparators.hpp deleted file mode 100644 index 5f3cd38071a3..000000000000 --- a/src/include/duckdb/common/sort/comparators.hpp +++ /dev/null @@ -1,65 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/comparators.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -struct SortLayout; -struct SBScanState; - -using ValidityBytes = RowLayout::ValidityBytes; - -struct Comparators { -public: - //! Whether a tie between two blobs can be broken - static bool TieIsBreakable(const idx_t &col_idx, const data_ptr_t &row_ptr, const SortLayout &sort_layout); - //! Compares the tuples that a being read from in the 'left' and 'right blocks during merge sort - //! (only in case we cannot simply 'memcmp' - if there are blob columns) - static int CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort); - //! Compare two blob values - static int CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type); - -private: - //! Compares two blob values that were initially tied by their prefix - static int BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external); - //! Compare two fixed-size values - template - static int TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr); - - //! Compare two values at the pointers (can be recursive if nested type) - static int CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid); - //! Compares two fixed-size values at the given pointers - template - static int TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr); - //! Compares two string values at the given pointers - static int CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid); - //! Compares two struct values at the given pointers (recursive) - static int CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid); - static int CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid, - idx_t array_size); - //! Compare two list values at the pointers (can be recursive if nested type) - static int CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid); - //! Compares a list of fixed-size values - template - static int TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const ValidityBytes &left_validity, - const ValidityBytes &right_validity, const idx_t &count); - - //! Unwizzles an offset into a pointer - static void UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); - //! Swizzles a pointer into an offset - static void SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); -}; - -} // namespace duckdb diff --git a/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/include/duckdb/common/sort/duckdb_pdqsort.hpp deleted file mode 100644 index c935a713aa47..000000000000 --- a/src/include/duckdb/common/sort/duckdb_pdqsort.hpp +++ /dev/null @@ -1,710 +0,0 @@ -/* -pdqsort.h - Pattern-defeating quicksort. - -Copyright (c) 2021 Orson Peters - -This software is provided 'as-is', without any express or implied warranty. In no event will the -authors be held liable for any damages arising from the use of this software. - -Permission is granted to anyone to use this software for any purpose, including commercial -applications, and to alter it and redistribute it freely, subject to the following restrictions: - -1. The origin of this software must not be misrepresented; you must not claim that you wrote the - original software. If you use this software in a product, an acknowledgment in the product - documentation would be appreciated but is not required. - -2. Altered source versions must be plainly marked as such, and must not be misrepresented as - being the original software. - -3. This notice may not be removed or altered from any source distribution. -*/ - -#pragma once - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/unique_ptr.hpp" - -#include -#include -#include -#include -#include - -namespace duckdb_pdqsort { - -using duckdb::data_ptr_t; -using duckdb::data_t; -using duckdb::FastMemcmp; -using duckdb::FastMemcpy; -using duckdb::idx_t; -using duckdb::make_unsafe_uniq_array_uninitialized; -using duckdb::unique_ptr; -using duckdb::unsafe_unique_array; - -// NOLINTBEGIN - -enum { - // Partitions below this size are sorted using insertion sort. - insertion_sort_threshold = 24, - - // Partitions above this size use Tukey's ninther to select the pivot. - ninther_threshold = 128, - - // When we detect an already sorted partition, attempt an insertion sort that allows this - // amount of element moves before giving up. - partial_insertion_sort_limit = 8, - - // Must be multiple of 8 due to loop unrolling, and < 256 to fit in unsigned char. - block_size = 64, - - // Cacheline size, assumes power of two. - cacheline_size = 64 - -}; - -// Returns floor(log2(n)), assumes n > 0. -template -inline int log2(T n) { - int log = 0; - while (n >>= 1) { - ++log; - } - return log; -} - -struct PDQConstants { - PDQConstants(idx_t entry_size, idx_t comp_offset, idx_t comp_size, data_ptr_t end) - : entry_size(entry_size), comp_offset(comp_offset), comp_size(comp_size), - tmp_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), tmp_buf(tmp_buf_ptr.get()), - iter_swap_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - iter_swap_buf(iter_swap_buf_ptr.get()), - swap_offsets_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - swap_offsets_buf(swap_offsets_buf_ptr.get()), end(end) { - } - - const duckdb::idx_t entry_size; - const idx_t comp_offset; - const idx_t comp_size; - - unsafe_unique_array tmp_buf_ptr; - const data_ptr_t tmp_buf; - - unsafe_unique_array iter_swap_buf_ptr; - const data_ptr_t iter_swap_buf; - - unsafe_unique_array swap_offsets_buf_ptr; - const data_ptr_t swap_offsets_buf; - - const data_ptr_t end; -}; - -struct PDQIterator { - PDQIterator(data_ptr_t ptr, const idx_t &entry_size) : ptr(ptr), entry_size(entry_size) { - } - - inline PDQIterator(const PDQIterator &other) : ptr(other.ptr), entry_size(other.entry_size) { - } - - inline const data_ptr_t &operator*() const { - return ptr; - } - - inline PDQIterator &operator++() { - ptr += entry_size; - return *this; - } - - inline PDQIterator &operator--() { - ptr -= entry_size; - return *this; - } - - inline PDQIterator operator++(int) { - auto tmp = *this; - ptr += entry_size; - return tmp; - } - - inline PDQIterator operator--(int) { - auto tmp = *this; - ptr -= entry_size; - return tmp; - } - - inline PDQIterator operator+(const idx_t &i) const { - auto result = *this; - result.ptr += i * entry_size; - return result; - } - - inline PDQIterator operator-(const idx_t &i) const { - PDQIterator result = *this; - result.ptr -= i * entry_size; - return result; - } - - inline PDQIterator &operator=(const PDQIterator &other) { - D_ASSERT(entry_size == other.entry_size); - ptr = other.ptr; - return *this; - } - - inline friend idx_t operator-(const PDQIterator &lhs, const PDQIterator &rhs) { - D_ASSERT(duckdb::NumericCast(*lhs - *rhs) % lhs.entry_size == 0); - D_ASSERT(*lhs - *rhs >= 0); - return duckdb::NumericCast(*lhs - *rhs) / lhs.entry_size; - } - - inline friend bool operator<(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs < *rhs; - } - - inline friend bool operator>(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs > *rhs; - } - - inline friend bool operator>=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs >= *rhs; - } - - inline friend bool operator<=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs <= *rhs; - } - - inline friend bool operator==(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs == *rhs; - } - - inline friend bool operator!=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs != *rhs; - } - -private: - data_ptr_t ptr; - const idx_t &entry_size; -}; - -static inline bool comp(const data_ptr_t &l, const data_ptr_t &r, const PDQConstants &constants) { - D_ASSERT(l == constants.tmp_buf || l == constants.swap_offsets_buf || l < constants.end); - D_ASSERT(r == constants.tmp_buf || r == constants.swap_offsets_buf || r < constants.end); - return FastMemcmp(l + constants.comp_offset, r + constants.comp_offset, constants.comp_size) < 0; -} - -static inline const data_ptr_t &GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.tmp_buf, src, constants.entry_size); - return constants.tmp_buf; -} - -static inline const data_ptr_t &SWAP_OFFSETS_GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.swap_offsets_buf, src, constants.entry_size); - return constants.swap_offsets_buf; -} - -static inline void MOVE(const data_ptr_t &dest, const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(dest == constants.tmp_buf || dest == constants.swap_offsets_buf || dest < constants.end); - D_ASSERT(src == constants.tmp_buf || src == constants.swap_offsets_buf || src < constants.end); - FastMemcpy(dest, src, constants.entry_size); -} - -static inline void iter_swap(const PDQIterator &lhs, const PDQIterator &rhs, const PDQConstants &constants) { - D_ASSERT(*lhs < constants.end); - D_ASSERT(*rhs < constants.end); - FastMemcpy(constants.iter_swap_buf, *lhs, constants.entry_size); - FastMemcpy(*lhs, *rhs, constants.entry_size); - FastMemcpy(*rhs, constants.iter_swap_buf, constants.entry_size); -} - -// Sorts [begin, end) using insertion sort with the given comparison function. -inline void insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Sorts [begin, end) using insertion sort with the given comparison function. Assumes -// *(begin - 1) is an element smaller than or equal to any element in [begin, end). -inline void unguarded_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Attempts to use insertion sort on [begin, end). Will return false if more than -// partial_insertion_sort_limit elements were moved, and abort sorting. Otherwise it will -// successfully sort and return true. -inline bool partial_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return true; - } - - std::size_t limit = 0; - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - limit += cur - sift; - } - - if (limit > partial_insertion_sort_limit) { - return false; - } - } - - return true; -} - -inline void sort2(const PDQIterator &a, const PDQIterator &b, const PDQConstants &constants) { - if (comp(*b, *a, constants)) { - iter_swap(a, b, constants); - } -} - -// Sorts the elements *a, *b and *c using comparison function comp. -inline void sort3(const PDQIterator &a, const PDQIterator &b, const PDQIterator &c, const PDQConstants &constants) { - sort2(a, b, constants); - sort2(b, c, constants); - sort2(a, b, constants); -} - -template -inline T *align_cacheline(T *p) { -#if defined(UINTPTR_MAX) && __cplusplus >= 201103L - std::uintptr_t ip = reinterpret_cast(p); -#else - std::size_t ip = reinterpret_cast(p); -#endif - ip = (ip + cacheline_size - 1) & -duckdb::UnsafeNumericCast(cacheline_size); - return reinterpret_cast(ip); -} - -inline void swap_offsets(const PDQIterator &first, const PDQIterator &last, unsigned char *offsets_l, - unsigned char *offsets_r, size_t num, bool use_swaps, const PDQConstants &constants) { - if (use_swaps) { - // This case is needed for the descending distribution, where we need - // to have proper swapping for pdqsort to remain O(n). - for (size_t i = 0; i < num; ++i) { - iter_swap(first + offsets_l[i], last - offsets_r[i], constants); - } - } else if (num > 0) { - PDQIterator l = first + offsets_l[0]; - PDQIterator r = last - offsets_r[0]; - const auto &tmp = SWAP_OFFSETS_GET_TMP(*l, constants); - MOVE(*l, *r, constants); - for (size_t i = 1; i < num; ++i) { - l = first + offsets_l[i]; - MOVE(*r, *l, constants); - r = last - offsets_r[i]; - MOVE(*l, *r, constants); - } - MOVE(*r, tmp, constants); - } -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. Uses branchless partitioning. -inline std::pair partition_right_branchless(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. - if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - if (!already_partitioned) { - iter_swap(first, last, constants); - ++first; - - // The following branchless partitioning is derived from "BlockQuicksort: How Branch - // Mispredictions don’t affect Quicksort" by Stefan Edelkamp and Armin Weiss, but - // heavily micro-optimized. - unsigned char offsets_l_storage[block_size + cacheline_size]; - unsigned char offsets_r_storage[block_size + cacheline_size]; - unsigned char *offsets_l = align_cacheline(offsets_l_storage); - unsigned char *offsets_r = align_cacheline(offsets_r_storage); - - PDQIterator offsets_l_base = first; - PDQIterator offsets_r_base = last; - size_t num_l, num_r, start_l, start_r; - num_l = num_r = start_l = start_r = 0; - - while (first < last) { - // Fill up offset blocks with elements that are on the wrong side. - // First we determine how much elements are considered for each offset block. - size_t num_unknown = last - first; - size_t left_split = num_l == 0 ? (num_r == 0 ? num_unknown / 2 : num_unknown) : 0; - size_t right_split = num_r == 0 ? (num_unknown - left_split) : 0; - - // Fill the offset blocks. - if (left_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } else { - for (unsigned char i = 0; i < left_split;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } - - if (right_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } else { - for (unsigned char i = 0; i < right_split;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } - - // Swap elements and update block sizes and first/last boundaries. - size_t num = std::min(num_l, num_r); - swap_offsets(offsets_l_base, offsets_r_base, offsets_l + start_l, offsets_r + start_r, num, num_l == num_r, - constants); - num_l -= num; - num_r -= num; - start_l += num; - start_r += num; - - if (num_l == 0) { - start_l = 0; - offsets_l_base = first; - } - - if (num_r == 0) { - start_r = 0; - offsets_r_base = last; - } - } - - // We have now fully identified [first, last)'s proper position. Swap the last elements. - if (num_l) { - offsets_l += start_l; - while (num_l--) { - iter_swap(offsets_l_base + offsets_l[num_l], --last, constants); - } - first = last; - } - if (num_r) { - offsets_r += start_r; - while (num_r--) { - iter_swap(offsets_r_base - offsets_r[num_r], first, constants), ++first; - } - last = first; - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. -inline std::pair partition_right(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. - if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - - // Keep swapping pairs of elements that are on the wrong side of the pivot. Previously - // swapped pairs guard the searches, which is why the first iteration is special-cased - // above. - while (first < last) { - iter_swap(first, last, constants); - while (comp(*++first, pivot, constants)) { - } - while (!comp(*--last, pivot, constants)) { - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Similar function to the one above, except elements equal to the pivot are put to the left of -// the pivot and it doesn't check or return if the passed sequence already was partitioned. -// Since this is rarely used (the many equal case), and in that case pdqsort already has O(n) -// performance, no block quicksort is applied here for simplicity. -inline PDQIterator partition_left(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - while (comp(pivot, *--last, constants)) { - } - - if (last + 1 == end) { - while (first < last && !comp(pivot, *++first, constants)) { - } - } else { - while (!comp(pivot, *++first, constants)) { - } - } - - while (first < last) { - iter_swap(first, last, constants); - while (comp(pivot, *--last, constants)) { - } - while (!comp(pivot, *++first, constants)) { - } - } - - PDQIterator pivot_pos = last; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return pivot_pos; -} - -template -inline void pdqsort_loop(PDQIterator begin, const PDQIterator &end, const PDQConstants &constants, int bad_allowed, - bool leftmost = true) { - // Use a while loop for tail recursion elimination. - while (true) { - idx_t size = end - begin; - - // Insertion sort is faster for small arrays. - if (size < insertion_sort_threshold) { - if (leftmost) { - insertion_sort(begin, end, constants); - } else { - unguarded_insertion_sort(begin, end, constants); - } - return; - } - - // Choose pivot as median of 3 or pseudomedian of 9. - idx_t s2 = size / 2; - if (size > ninther_threshold) { - sort3(begin, begin + s2, end - 1, constants); - sort3(begin + 1, begin + (s2 - 1), end - 2, constants); - sort3(begin + 2, begin + (s2 + 1), end - 3, constants); - sort3(begin + (s2 - 1), begin + s2, begin + (s2 + 1), constants); - iter_swap(begin, begin + s2, constants); - } else { - sort3(begin + s2, begin, end - 1, constants); - } - - // If *(begin - 1) is the end of the right partition of a previous partition operation - // there is no element in [begin, end) that is smaller than *(begin - 1). Then if our - // pivot compares equal to *(begin - 1) we change strategy, putting equal elements in - // the left partition, greater elements in the right partition. We do not have to - // recurse on the left partition, since it's sorted (all equal). - if (!leftmost && !comp(*(begin - 1), *begin, constants)) { - begin = partition_left(begin, end, constants) + 1; - continue; - } - - // Partition and get results. - std::pair part_result = - Branchless ? partition_right_branchless(begin, end, constants) : partition_right(begin, end, constants); - PDQIterator pivot_pos = part_result.first; - bool already_partitioned = part_result.second; - - // Check for a highly unbalanced partition. - idx_t l_size = pivot_pos - begin; - idx_t r_size = end - (pivot_pos + 1); - bool highly_unbalanced = l_size < size / 8 || r_size < size / 8; - - // If we got a highly unbalanced partition we shuffle elements to break many patterns. - if (highly_unbalanced) { - // If we had too many bad partitions, switch to heapsort to guarantee O(n log n). - // if (--bad_allowed == 0) { - // std::make_heap(begin, end, comp); - // std::sort_heap(begin, end, comp); - // return; - // } - - if (l_size >= insertion_sort_threshold) { - iter_swap(begin, begin + l_size / 4, constants); - iter_swap(pivot_pos - 1, pivot_pos - l_size / 4, constants); - - if (l_size > ninther_threshold) { - iter_swap(begin + 1, begin + (l_size / 4 + 1), constants); - iter_swap(begin + 2, begin + (l_size / 4 + 2), constants); - iter_swap(pivot_pos - 2, pivot_pos - (l_size / 4 + 1), constants); - iter_swap(pivot_pos - 3, pivot_pos - (l_size / 4 + 2), constants); - } - } - - if (r_size >= insertion_sort_threshold) { - iter_swap(pivot_pos + 1, pivot_pos + (1 + r_size / 4), constants); - iter_swap(end - 1, end - r_size / 4, constants); - - if (r_size > ninther_threshold) { - iter_swap(pivot_pos + 2, pivot_pos + (2 + r_size / 4), constants); - iter_swap(pivot_pos + 3, pivot_pos + (3 + r_size / 4), constants); - iter_swap(end - 2, end - (1 + r_size / 4), constants); - iter_swap(end - 3, end - (2 + r_size / 4), constants); - } - } - } else { - // If we were decently balanced and we tried to sort an already partitioned - // sequence try to use insertion sort. - if (already_partitioned && partial_insertion_sort(begin, pivot_pos, constants) && - partial_insertion_sort(pivot_pos + 1, end, constants)) { - return; - } - } - - // Sort the left partition first using recursion and do tail recursion elimination for - // the right-hand partition. - pdqsort_loop(begin, pivot_pos, constants, bad_allowed, leftmost); - begin = pivot_pos + 1; - leftmost = false; - } -} - -inline void pdqsort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} - -inline void pdqsort_branchless(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} -// NOLINTEND - -} // namespace duckdb_pdqsort diff --git a/src/include/duckdb/common/sort/sort.hpp b/src/include/duckdb/common/sort/sort.hpp deleted file mode 100644 index 188ea21273fa..000000000000 --- a/src/include/duckdb/common/sort/sort.hpp +++ /dev/null @@ -1,290 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sort.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class RowLayout; -struct LocalSortState; - -struct SortConstants { - static constexpr idx_t VALUES_PER_RADIX = 256; - static constexpr idx_t MSD_RADIX_LOCATIONS = VALUES_PER_RADIX + 1; - static constexpr idx_t INSERTION_SORT_THRESHOLD = 24; - static constexpr idx_t MSD_RADIX_SORT_SIZE_THRESHOLD = 4; -}; - -struct SortLayout { -public: - SortLayout() { - } - explicit SortLayout(const vector &orders); - SortLayout GetPrefixComparisonLayout(idx_t num_prefix_cols) const; - -public: - idx_t column_count; - vector order_types; - vector order_by_null_types; - vector logical_types; - - bool all_constant; - vector constant_size; - vector column_sizes; - vector prefix_lengths; - vector stats; - vector has_null; - - idx_t comparison_size; - idx_t entry_size; - - RowLayout blob_layout; - unordered_map sorting_to_blob_col; -}; - -struct GlobalSortState { -public: - GlobalSortState(ClientContext &context, const vector &orders, RowLayout &payload_layout); - - //! Add local state sorted data to this global state - void AddLocalState(LocalSortState &local_sort_state); - //! Prepares the GlobalSortState for the merge sort phase (after completing radix sort phase) - void PrepareMergePhase(); - //! Initializes the global sort state for another round of merging - void InitializeMergeRound(); - //! Completes the cascaded merge sort round. - //! Pass true if you wish to use the radix data for further comparisons. - void CompleteMergeRound(bool keep_radix_data = false); - //! Print the sorted data to the console. - void Print(); - -public: - //! The client context - ClientContext &context; - //! The lock for updating the order global state - mutex lock; - //! The buffer manager - BufferManager &buffer_manager; - - //! Sorting and payload layouts - const SortLayout sort_layout; - const RowLayout payload_layout; - - //! Sorted data - vector> sorted_blocks; - vector>> sorted_blocks_temp; - unique_ptr odd_one_out; - - //! Pinned heap data (if sorting in memory) - vector> heap_blocks; - vector pinned_blocks; - - //! Capacity (number of rows) used to initialize blocks - idx_t block_capacity; - //! Whether we are doing an external sort - bool external; - - //! Progress in merge path stage - idx_t pair_idx; - idx_t num_pairs; - idx_t l_start; - idx_t r_start; -}; - -struct LocalSortState { -public: - LocalSortState(); - - //! Initialize the layouts and RowDataCollections - void Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p); - //! Sink one DataChunk into the local sort state - void SinkChunk(DataChunk &sort, DataChunk &payload); - //! Size of accumulated data in bytes - idx_t SizeInBytes() const; - //! Sort the data accumulated so far - void Sort(GlobalSortState &global_sort_state, bool reorder_heap); - //! Concatenate the blocks held by a RowDataCollection into a single block - static unique_ptr ConcatenateBlocks(RowDataCollection &row_data); - -private: - //! Sorts the data in the newly created SortedBlock - void SortInMemory(); - //! Re-order the local state after sorting - void ReOrder(GlobalSortState &gstate, bool reorder_heap); - //! Re-order a SortedData object after sorting - void ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap); - -public: - //! Whether this local state has been initialized - bool initialized; - //! The buffer manager - BufferManager *buffer_manager; - //! The sorting and payload layouts - const SortLayout *sort_layout; - const RowLayout *payload_layout; - //! Radix/memcmp sortable data - unique_ptr radix_sorting_data; - //! Variable sized sorting data and accompanying heap - unique_ptr blob_sorting_data; - unique_ptr blob_sorting_heap; - //! Payload data and accompanying heap - unique_ptr payload_data; - unique_ptr payload_heap; - //! Sorted data - vector> sorted_blocks; - -private: - //! Selection vector and addresses for scattering the data to rows - const SelectionVector &sel_ptr = *FlatVector::IncrementalSelectionVector(); - Vector addresses = Vector(LogicalType::POINTER); -}; - -struct MergeSorter { -public: - MergeSorter(GlobalSortState &state, BufferManager &buffer_manager); - - //! Finds and merges partitions until the current cascaded merge round is finished - void PerformInMergeRound(); - -private: - //! The global sorting state - GlobalSortState &state; - //! The sorting and payload layouts - BufferManager &buffer_manager; - const SortLayout &sort_layout; - - //! The left and right reader - unique_ptr left; - unique_ptr right; - - //! Input and output blocks - unique_ptr left_input; - unique_ptr right_input; - SortedBlock *result; - -private: - //! Computes the left and right block that will be merged next (Merge Path partition) - void GetNextPartition(); - //! Finds the boundary of the next partition using binary search - void GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx); - //! Compare values within SortedBlocks using a global index - int CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx); - - //! Finds the next partition and merges it - void MergePartition(); - - //! Computes how the next 'count' tuples should be merged by setting the 'left_smaller' array - void ComputeMerge(const idx_t &count, bool left_smaller[]); - - //! Merges the radix sorting blocks according to the 'left_smaller' array - void MergeRadix(const idx_t &count, const bool left_smaller[]); - //! Merges SortedData according to the 'left_smaller' array - void MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices); - //! Merges constant size rows according to the 'left_smaller' array - void MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, idx_t &r_entry_idx, - const idx_t &r_count, RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, - const bool left_smaller[], idx_t &copied, const idx_t &count); - //! Flushes constant size rows into the result - void FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count); - //! Flushes blob rows and accompanying heap - void FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, BufferHandle &target_heap_handle, - data_ptr_t &target_heap_ptr, idx_t &copied, const idx_t &count); -}; - -struct SBIterator { - static int ComparisonValue(ExpressionType comparison); - - SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p = 0); - - inline idx_t GetIndex() const { - return entry_idx; - } - - inline void SetIndex(idx_t entry_idx_p) { - const auto new_block_idx = entry_idx_p / block_capacity; - if (new_block_idx != scan.block_idx) { - scan.SetIndices(new_block_idx, 0); - if (new_block_idx < block_count) { - scan.PinRadix(scan.block_idx); - block_ptr = scan.RadixPtr(); - if (!all_constant) { - scan.PinData(*scan.sb->blob_sorting_data); - } - } - } - - scan.entry_idx = entry_idx_p % block_capacity; - entry_ptr = block_ptr + scan.entry_idx * entry_size; - entry_idx = entry_idx_p; - } - - inline SBIterator &operator++() { - if (++scan.entry_idx < block_capacity) { - entry_ptr += entry_size; - ++entry_idx; - } else { - SetIndex(entry_idx + 1); - } - - return *this; - } - - inline SBIterator &operator--() { - if (scan.entry_idx) { - --scan.entry_idx; - --entry_idx; - entry_ptr -= entry_size; - } else { - SetIndex(entry_idx - 1); - } - - return *this; - } - - inline bool Compare(const SBIterator &other, const SortLayout &prefix) const { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(entry_ptr, other.entry_ptr, prefix.comparison_size); - } else { - comp_res = Comparators::CompareTuple(scan, other.scan, entry_ptr, other.entry_ptr, prefix, external); - } - - return comp_res <= cmp; - } - - inline bool Compare(const SBIterator &other) const { - return Compare(other, sort_layout); - } - - // Fixed comparison parameters - const SortLayout &sort_layout; - const idx_t block_count; - const idx_t block_capacity; - const size_t entry_size; - const bool all_constant; - const bool external; - const int cmp; - - // Iteration state - SBScanState scan; - idx_t entry_idx; - data_ptr_t block_ptr; - data_ptr_t entry_ptr; -}; - -} // namespace duckdb diff --git a/src/include/duckdb/common/sort/sorted_block.hpp b/src/include/duckdb/common/sort/sorted_block.hpp deleted file mode 100644 index b6941bda20d8..000000000000 --- a/src/include/duckdb/common/sort/sorted_block.hpp +++ /dev/null @@ -1,165 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sorted_block.hpp -// -// -//===----------------------------------------------------------------------===// -#pragma once - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" - -namespace duckdb { - -class BufferManager; -struct RowDataBlock; -struct SortLayout; -struct GlobalSortState; - -enum class SortedDataType { BLOB, PAYLOAD }; - -//! Object that holds sorted rows, and an accompanying heap if there are blobs -struct SortedData { -public: - SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, GlobalSortState &state); - //! Number of rows that this object holds - idx_t Count(); - //! Initialize new block to write to - void CreateBlock(); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index); - //! Unswizzles all - void Unswizzle(); - -public: - const SortedDataType type; - //! Layout of this data - const RowLayout layout; - //! Data and heap blocks - vector> data_blocks; - vector> heap_blocks; - //! Whether the pointers in this sorted data are swizzled - bool swizzled; - -private: - //! The buffer manager - BufferManager &buffer_manager; - //! The global state - GlobalSortState &state; -}; - -//! Block that holds sorted rows: radix, blob and payload data -struct SortedBlock { -public: - SortedBlock(BufferManager &buffer_manager, GlobalSortState &gstate); - //! Number of rows that this object holds - idx_t Count() const; - //! Initialize this block to write data to - void InitializeWrite(); - //! Init new block to write to - void CreateBlock(); - //! Fill this sorted block by appending the blocks held by a vector of sorted blocks - void AppendSortedBlocks(vector> &sorted_blocks); - //! Locate the block and entry index of a row in this block, - //! given an index between 0 and the total number of rows in this block - void GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx); - - //! Size (in bytes) of the heap of this block - idx_t HeapSize() const; - //! Total size (in bytes) of this block - idx_t SizeInBytes() const; - -public: - //! Radix/memcmp sortable data - vector> radix_sorting_data; - //! Variable sized sorting data - unique_ptr blob_sorting_data; - //! Payload data - unique_ptr payload_data; - -private: - //! Buffer manager, global state, and sorting layout constants - BufferManager &buffer_manager; - GlobalSortState &state; - const SortLayout &sort_layout; - const RowLayout &payload_layout; -}; - -//! State used to scan a SortedBlock e.g. during merge sort -struct SBScanState { -public: - SBScanState(BufferManager &buffer_manager, GlobalSortState &state); - - void PinRadix(idx_t block_idx_to); - void PinData(SortedData &sd); - - data_ptr_t RadixPtr() const; - data_ptr_t DataPtr(SortedData &sd) const; - data_ptr_t HeapPtr(SortedData &sd) const; - data_ptr_t BaseHeapPtr(SortedData &sd) const; - - idx_t Remaining() const; - - void SetIndices(idx_t block_idx_to, idx_t entry_idx_to); - -public: - BufferManager &buffer_manager; - const SortLayout &sort_layout; - GlobalSortState &state; - - SortedBlock *sb; - - idx_t block_idx; - idx_t entry_idx; - - BufferHandle radix_handle; - - BufferHandle blob_sorting_data_handle; - BufferHandle blob_sorting_heap_handle; - - BufferHandle payload_data_handle; - BufferHandle payload_heap_handle; -}; - -//! Used to scan the data into DataChunks after sorting -struct PayloadScanner { -public: - PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush = true); - explicit PayloadScanner(GlobalSortState &global_sort_state, bool flush = true); - - //! Scan a single block - PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush = false); - - //! The type layout of the payload - inline const vector &GetPayloadTypes() const { - return scanner->GetTypes(); - } - - //! The number of rows scanned so far - inline idx_t Scanned() const { - return scanner->Scanned(); - } - - //! The number of remaining rows - inline idx_t Remaining() const { - return scanner->Remaining(); - } - - //! Scans the next data chunk from the sorted data - void Scan(DataChunk &chunk); - -private: - //! The sorted data being scanned - unique_ptr rows; - unique_ptr heap; - //! The actual scanner - unique_ptr scanner; -}; - -} // namespace duckdb diff --git a/src/include/duckdb/common/types/geometry.hpp b/src/include/duckdb/common/types/geometry.hpp index 5b9bcd1f8406..228729a3ed63 100644 --- a/src/include/duckdb/common/types/geometry.hpp +++ b/src/include/duckdb/common/types/geometry.hpp @@ -186,7 +186,7 @@ class GeometryExtent { class Geometry { public: - static constexpr auto MAX_RECURSION_DEPTH = 16; + static constexpr idx_t MAX_RECURSION_DEPTH = 16; //! Convert from WKT DUCKDB_API static bool FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict); diff --git a/src/include/duckdb/common/types/hugeint.hpp b/src/include/duckdb/common/types/hugeint.hpp index 3720bf844ec2..9fa5d447ba89 100644 --- a/src/include/duckdb/common/types/hugeint.hpp +++ b/src/include/duckdb/common/types/hugeint.hpp @@ -129,38 +129,38 @@ class Hugeint { static int Sign(hugeint_t n); static hugeint_t Abs(hugeint_t n); // comparison operators - static bool Equals(hugeint_t lhs, hugeint_t rhs) { + static bool Equals(const hugeint_t &lhs, const hugeint_t &rhs) { bool lower_equals = lhs.lower == rhs.lower; bool upper_equals = lhs.upper == rhs.upper; return lower_equals && upper_equals; } - static bool NotEquals(hugeint_t lhs, hugeint_t rhs) { + static bool NotEquals(const hugeint_t &lhs, const hugeint_t &rhs) { return !Equals(lhs, rhs); } - static bool GreaterThan(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger = lhs.lower > rhs.lower; return upper_bigger || (upper_equal && lower_bigger); } - static bool GreaterThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger_equals = lhs.lower >= rhs.lower; return upper_bigger || (upper_equal && lower_bigger_equals); } - static bool LessThan(hugeint_t lhs, hugeint_t rhs) { + static bool LessThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller = lhs.lower < rhs.lower; return upper_smaller || (upper_equal && lower_smaller); } - static bool LessThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool LessThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller_equals = lhs.lower <= rhs.lower; diff --git a/src/include/duckdb/common/types/variant_visitor.hpp b/src/include/duckdb/common/types/variant_visitor.hpp new file mode 100644 index 000000000000..950980aefc85 --- /dev/null +++ b/src/include/duckdb/common/types/variant_visitor.hpp @@ -0,0 +1,232 @@ +#pragma once + +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/enum_util.hpp" + +#include + +namespace duckdb { + +template +class VariantVisitor { + // Detects if T has a static VisitMetadata with signature + // void VisitMetadata(VariantLogicalType, Args...) + template + class has_visit_metadata { + private: + template + static auto test(int) -> decltype(U::VisitMetadata(std::declval(), std::declval()...), + std::true_type {}); + + template + static std::false_type test(...); + + public: + static constexpr bool value = decltype(test(0))::value; + }; + +public: + template + static ReturnType Visit(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, Args &&...args) { + if (!variant.RowIsValid(row)) { + return Visitor::VisitNull(std::forward(args)...); + } + + auto type_id = variant.GetTypeId(row, values_idx); + auto byte_offset = variant.GetByteOffset(row, values_idx); + auto blob_data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = const_data_ptr_cast(blob_data + byte_offset); + + VisitMetadata(type_id, std::forward(args)...); + + switch (type_id) { + case VariantLogicalType::VARIANT_NULL: + return Visitor::VisitNull(std::forward(args)...); + case VariantLogicalType::BOOL_TRUE: + return Visitor::VisitBoolean(true, std::forward(args)...); + case VariantLogicalType::BOOL_FALSE: + return Visitor::VisitBoolean(false, std::forward(args)...); + case VariantLogicalType::INT8: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT16: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT32: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT64: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT128: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT8: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT16: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT32: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT64: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT128: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::FLOAT: + return Visitor::VisitFloat(Load(ptr), std::forward(args)...); + case VariantLogicalType::DOUBLE: + return Visitor::VisitDouble(Load(ptr), std::forward(args)...); + case VariantLogicalType::UUID: + return Visitor::VisitUUID(Load(ptr), std::forward(args)...); + case VariantLogicalType::DATE: + return Visitor::VisitDate(date_t(Load(ptr)), std::forward(args)...); + case VariantLogicalType::INTERVAL: + return Visitor::VisitInterval(Load(ptr), std::forward(args)...); + case VariantLogicalType::VARCHAR: + case VariantLogicalType::BLOB: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::GEOMETRY: + return VisitString(type_id, variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::DECIMAL: + return VisitDecimal(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::ARRAY: + return VisitArray(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::OBJECT: + return VisitObject(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::TIME_MICROS: + return Visitor::VisitTime(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIME_NANOS: + return Visitor::VisitTimeNanos(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIME_MICROS_TZ: + return Visitor::VisitTimeTZ(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_SEC: + return Visitor::VisitTimestampSec(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MILIS: + return Visitor::VisitTimestampMs(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MICROS: + return Visitor::VisitTimestamp(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_NANOS: + return Visitor::VisitTimestampNanos(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + return Visitor::VisitTimestampTZ(Load(ptr), std::forward(args)...); + default: + return Visitor::VisitDefault(type_id, ptr, std::forward(args)...); + } + } + + // Non-void version + template + static typename std::enable_if::value, vector>::type + VisitArrayItems(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &array_data, + Args &&...args) { + vector array_items; + array_items.reserve(array_data.child_count); + for (idx_t i = 0; i < array_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, array_data.children_idx + i); + array_items.emplace_back(Visit(variant, row, values_index, std::forward(args)...)); + } + return array_items; + } + + // Void version + template + static typename std::enable_if::value, void>::type + VisitArrayItems(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &array_data, + Args &&...args) { + for (idx_t i = 0; i < array_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, array_data.children_idx + i); + Visit(variant, row, values_index, std::forward(args)...); + } + } + + template + static child_list_t VisitObjectItems(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &object_data, Args &&...args) { + child_list_t object_items; + for (idx_t i = 0; i < object_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, object_data.children_idx + i); + auto val = Visit(variant, row, values_index, std::forward(args)...); + + auto keys_index = variant.GetKeysIndex(row, object_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + + object_items.emplace_back(key.GetString(), std::move(val)); + } + return object_items; + } + +private: + template + static typename std::enable_if::value, void>::type + VisitMetadata(VariantLogicalType type_id, Args &&...args) { + Visitor::VisitMetadata(type_id, std::forward(args)...); + } + + // Fallback if the method does not exist + template + static typename std::enable_if::value, void>::type VisitMetadata(VariantLogicalType, + Args &&...) { + // do nothing + } + + template + static ReturnType VisitArray(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_nested_data = VariantUtils::DecodeNestedData(variant, row, values_idx); + return Visitor::VisitArray(variant, row, decoded_nested_data, std::forward(args)...); + } + + template + static ReturnType VisitObject(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_nested_data = VariantUtils::DecodeNestedData(variant, row, values_idx); + return Visitor::VisitObject(variant, row, decoded_nested_data, std::forward(args)...); + } + + template + static ReturnType VisitString(VariantLogicalType type_id, const UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_idx, Args &&...args) { + auto decoded_string = VariantUtils::DecodeStringData(variant, row, values_idx); + if (type_id == VariantLogicalType::VARCHAR) { + return Visitor::VisitString(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BLOB) { + return Visitor::VisitBlob(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BIGNUM) { + return Visitor::VisitBignum(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::GEOMETRY) { + return Visitor::VisitGeometry(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BITSTRING) { + return Visitor::VisitBitstring(decoded_string, std::forward(args)...); + } + throw InternalException("String-backed variant type (%s) not handled", EnumUtil::ToString(type_id)); + } + + template + static ReturnType VisitDecimal(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_decimal = VariantUtils::DecodeDecimalData(variant, row, values_idx); + auto &width = decoded_decimal.width; + auto &scale = decoded_decimal.scale; + auto &ptr = decoded_decimal.value_ptr; + if (width > DecimalWidth::max) { + throw InternalException("Can't handle decimal of width: %d", width); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } + } +}; + +} // namespace duckdb diff --git a/src/include/duckdb/execution/index/art/base_leaf.hpp b/src/include/duckdb/execution/index/art/base_leaf.hpp index 209d022dc034..797c18469d81 100644 --- a/src/include/duckdb/execution/index/art/base_leaf.hpp +++ b/src/include/duckdb/execution/index/art/base_leaf.hpp @@ -31,13 +31,15 @@ class BaseLeaf { public: //! Get a new BaseLeaf and initialize it. - static BaseLeaf &New(ART &art, Node &node) { + static NodeHandle New(ART &art, Node &node) { node = Node::GetAllocator(art, TYPE).New(); node.SetMetadata(static_cast(TYPE)); - auto &n = Node::Ref(art, node, TYPE); + NodeHandle handle(art, node); + auto &n = handle.Get(); + n.count = 0; - return n; + return handle; } //! Returns true, if the byte exists, else false. @@ -70,7 +72,7 @@ class BaseLeaf { private: static void InsertByteInternal(BaseLeaf &n, const uint8_t byte); - static BaseLeaf &DeleteByteInternal(ART &art, Node &node, const uint8_t byte); + static NodeHandle DeleteByteInternal(ART &art, Node &node, const uint8_t byte); }; //! Node7Leaf holds up to seven sorted bytes. diff --git a/src/include/duckdb/execution/index/fixed_size_allocator.hpp b/src/include/duckdb/execution/index/fixed_size_allocator.hpp index 691a4aac65f2..65ffd167f32d 100644 --- a/src/include/duckdb/execution/index/fixed_size_allocator.hpp +++ b/src/include/duckdb/execution/index/fixed_size_allocator.hpp @@ -30,7 +30,8 @@ class FixedSizeAllocator { public: //! Construct a new fixed-size allocator - FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager); + FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager, + MemoryTag memory_tag = MemoryTag::ART_INDEX); //! Block manager of the database instance BlockManager &block_manager; @@ -152,6 +153,8 @@ class FixedSizeAllocator { void VerifyBuffers(); private: + //! Memory tag of memory that is allocated through the allocator + MemoryTag memory_tag; //! Allocation size of one segment in a buffer //! We only need this value to calculate bitmask_count, bitmask_offset, and //! available_segments_per_buffer diff --git a/src/include/duckdb/execution/index/fixed_size_buffer.hpp b/src/include/duckdb/execution/index/fixed_size_buffer.hpp index e7c5b6aa9cd3..6ca7dc1aa35e 100644 --- a/src/include/duckdb/execution/index/fixed_size_buffer.hpp +++ b/src/include/duckdb/execution/index/fixed_size_buffer.hpp @@ -43,7 +43,7 @@ class FixedSizeBuffer { public: //! Constructor for a new in-memory buffer - explicit FixedSizeBuffer(BlockManager &block_manager); + explicit FixedSizeBuffer(BlockManager &block_manager, MemoryTag memory_tag); //! Constructor for deserializing buffer metadata from disk FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, const BlockPointer &block_pointer); diff --git a/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp b/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp index bacabfc4fca3..16b027ad6886 100644 --- a/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp +++ b/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp @@ -41,13 +41,15 @@ class FullLinePosition { return {}; } string result; - if (end.buffer_idx == begin.buffer_idx) { - if (buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { + if (end.buffer_idx == begin.buffer_idx || begin.buffer_pos == begin.buffer_size) { + idx_t buffer_idx = end.buffer_idx; + if (buffer_handles.find(buffer_idx) == buffer_handles.end()) { return {}; } - auto buffer = buffer_handles[begin.buffer_idx]->Ptr(); - first_char_nl = buffer[begin.buffer_pos] == '\n' || buffer[begin.buffer_pos] == '\r'; - for (idx_t i = begin.buffer_pos + first_char_nl; i < end.buffer_pos; i++) { + idx_t start_pos = begin.buffer_pos == begin.buffer_size ? 0 : begin.buffer_pos; + auto buffer = buffer_handles[buffer_idx]->Ptr(); + first_char_nl = buffer[start_pos] == '\n' || buffer[start_pos] == '\r'; + for (idx_t i = start_pos + first_char_nl; i < end.buffer_pos; i++) { result += buffer[i]; } } else { @@ -55,6 +57,9 @@ class FullLinePosition { buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { return {}; } + if (begin.buffer_pos >= begin.buffer_size) { + throw InternalException("CSV reader: buffer pos out of range for buffer"); + } auto first_buffer = buffer_handles[begin.buffer_idx]->Ptr(); auto first_buffer_size = buffer_handles[begin.buffer_idx]->actual_size; auto second_buffer = buffer_handles[end.buffer_idx]->Ptr(); diff --git a/src/include/duckdb/execution/operator/join/physical_iejoin.hpp b/src/include/duckdb/execution/operator/join/physical_iejoin.hpp index b57fe772d3a0..a93109a0ee3a 100644 --- a/src/include/duckdb/execution/operator/join/physical_iejoin.hpp +++ b/src/include/duckdb/execution/operator/join/physical_iejoin.hpp @@ -70,10 +70,6 @@ class PhysicalIEJoin : public PhysicalRangeJoin { public: void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; - -private: - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) - void ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state) const; }; } // namespace duckdb diff --git a/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp b/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp index a26772819c43..cc3d0c368be0 100644 --- a/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp +++ b/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp @@ -242,6 +242,30 @@ class BinaryAggregateHeap { idx_t size; }; +enum class ArgMinMaxNullHandling { IGNORE_ANY_NULL, HANDLE_ARG_NULL, HANDLE_ANY_NULL }; + +struct ArgMinMaxFunctionData : FunctionData { + explicit ArgMinMaxFunctionData(ArgMinMaxNullHandling null_handling_p = ArgMinMaxNullHandling::IGNORE_ANY_NULL, + bool nulls_last_p = true) + : null_handling(null_handling_p), nulls_last(nulls_last_p) { + } + + unique_ptr Copy() const override { + auto copy = make_uniq(); + copy->null_handling = null_handling; + copy->nulls_last = nulls_last; + return std::move(copy); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return other.null_handling == null_handling && other.nulls_last == nulls_last; + } + + ArgMinMaxNullHandling null_handling; + bool nulls_last; +}; + //------------------------------------------------------------------------------ // Specializations for fixed size types, strings, and anything else (using sortkey) //------------------------------------------------------------------------------ @@ -254,7 +278,7 @@ struct MinMaxFixedValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { FlatVector::GetData(vector)[idx] = value; } @@ -263,7 +287,8 @@ struct MinMaxFixedValue { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format) { + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, + const bool nulls_last) { input.ToUnifiedFormat(count, format); } }; @@ -276,7 +301,7 @@ struct MinMaxStringValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { FlatVector::GetData(vector)[idx] = StringVector::AddStringOrBlob(vector, value); } @@ -285,7 +310,8 @@ struct MinMaxStringValue { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format) { + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, + const bool nulls_last) { input.ToUnifiedFormat(count, format); } }; @@ -299,8 +325,9 @@ struct MinMaxFallbackValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { - OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { + auto order_by_null_type = nulls_last ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); CreateSortKeyHelpers::DecodeSortKey(value, vector, idx, modifiers); } @@ -308,14 +335,61 @@ struct MinMaxFallbackValue { return Vector(LogicalTypeId::BLOB); } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format) { - const OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + const bool nulls_last) { + auto order_by_null_type = nulls_last ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + const OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); CreateSortKeyHelpers::CreateSortKeyWithValidity(input, extra_state, modifiers, count); input.Flatten(count); extra_state.ToUnifiedFormat(count, format); } }; +template +struct ValueOrNull { + T value; + bool is_valid; + + bool operator==(const ValueOrNull &other) const { + return is_valid == other.is_valid && value == other.value; + } + + bool operator>(const ValueOrNull &other) const { + if (is_valid && other.is_valid) { + return value > other.value; + } + if (!is_valid && !other.is_valid) { + return false; + } + + return is_valid ^ !NULLS_LAST; + } +}; + +template +struct MinMaxFixedValueOrNull { + using TYPE = ValueOrNull; + using EXTRA_STATE = bool; + + static TYPE Create(const UnifiedVectorFormat &format, const idx_t idx) { + return TYPE {UnifiedVectorFormat::GetData(format)[idx], format.validity.RowIsValid(idx)}; + } + + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { + FlatVector::Validity(vector).Set(idx, value.is_valid); + FlatVector::GetData(vector)[idx] = value.value; + } + + static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + return false; + } + + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + const bool nulls_last) { + input.ToUnifiedFormat(count, format); + } +}; + //------------------------------------------------------------------------------ // MinMaxN Operation (common for both ArgMinMaxN and MinMaxN) //------------------------------------------------------------------------------ @@ -343,7 +417,11 @@ struct MinMaxNOperation { } template - static void Finalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { + static void Finalize(Vector &state_vector, AggregateInputData &input_data, Vector &result, idx_t count, + idx_t offset) { + // We only expect bind data from arg_max, otherwise nulls last is the default + const bool nulls_last = + input_data.bind_data ? input_data.bind_data->Cast().nulls_last : true; UnifiedVectorFormat state_format; state_vector.ToUnifiedFormat(count, state_format); @@ -387,7 +465,7 @@ struct MinMaxNOperation { auto heap = state.heap.SortAndGetHeap(); for (idx_t slot = 0; slot < state.heap.Size(); slot++) { - STATE::VAL_TYPE::Assign(child_data, current_offset++, state.heap.GetValue(heap[slot])); + STATE::VAL_TYPE::Assign(child_data, current_offset++, state.heap.GetValue(heap[slot]), nulls_last); } } diff --git a/src/include/duckdb/function/cast/variant/variant_to_variant.hpp b/src/include/duckdb/function/cast/variant/variant_to_variant.hpp index cdbf698bdf48..d97954118c3e 100644 --- a/src/include/duckdb/function/cast/variant/variant_to_variant.hpp +++ b/src/include/duckdb/function/cast/variant/variant_to_variant.hpp @@ -1,93 +1,246 @@ #pragma once #include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_visitor.hpp" namespace duckdb { namespace variant { -static bool VariantIsTrivialPrimitive(VariantLogicalType type) { - switch (type) { - case VariantLogicalType::INT8: - case VariantLogicalType::INT16: - case VariantLogicalType::INT32: - case VariantLogicalType::INT64: - case VariantLogicalType::INT128: - case VariantLogicalType::UINT8: - case VariantLogicalType::UINT16: - case VariantLogicalType::UINT32: - case VariantLogicalType::UINT64: - case VariantLogicalType::UINT128: - case VariantLogicalType::FLOAT: - case VariantLogicalType::DOUBLE: - case VariantLogicalType::UUID: - case VariantLogicalType::DATE: - case VariantLogicalType::TIME_MICROS: - case VariantLogicalType::TIME_NANOS: - case VariantLogicalType::TIMESTAMP_SEC: - case VariantLogicalType::TIMESTAMP_MILIS: - case VariantLogicalType::TIMESTAMP_MICROS: - case VariantLogicalType::TIMESTAMP_NANOS: - case VariantLogicalType::TIME_MICROS_TZ: - case VariantLogicalType::TIMESTAMP_MICROS_TZ: - case VariantLogicalType::INTERVAL: - return true; - default: - return false; +namespace { + +struct AnalyzeState { +public: + explicit AnalyzeState(uint32_t &children_offset) : children_offset(children_offset) { } -} -static uint32_t VariantTrivialPrimitiveSize(VariantLogicalType type) { - switch (type) { - case VariantLogicalType::INT8: - return sizeof(int8_t); - case VariantLogicalType::INT16: - return sizeof(int16_t); - case VariantLogicalType::INT32: - return sizeof(int32_t); - case VariantLogicalType::INT64: - return sizeof(int64_t); - case VariantLogicalType::INT128: - return sizeof(hugeint_t); - case VariantLogicalType::UINT8: - return sizeof(uint8_t); - case VariantLogicalType::UINT16: - return sizeof(uint16_t); - case VariantLogicalType::UINT32: - return sizeof(uint32_t); - case VariantLogicalType::UINT64: - return sizeof(uint64_t); - case VariantLogicalType::UINT128: - return sizeof(uhugeint_t); - case VariantLogicalType::FLOAT: +public: + uint32_t &children_offset; +}; + +struct WriteState { +public: + WriteState(uint32_t &keys_offset, uint32_t &children_offset, uint32_t &blob_offset, data_ptr_t blob_data, + uint32_t &blob_size) + : keys_offset(keys_offset), children_offset(children_offset), blob_offset(blob_offset), blob_data(blob_data), + blob_size(blob_size) { + } + +public: + inline data_ptr_t GetDestination() { + return blob_data + blob_offset + blob_size; + } + +public: + uint32_t &keys_offset; + uint32_t &children_offset; + uint32_t &blob_offset; + data_ptr_t blob_data; + uint32_t &blob_size; +}; + +struct VariantToVariantSizeAnalyzer { + using result_type = uint32_t; + + static uint32_t VisitNull(AnalyzeState &state) { + return 0; + } + static uint32_t VisitBoolean(bool, AnalyzeState &state) { + return 0; + } + + template + static uint32_t VisitInteger(T, AnalyzeState &state) { + return sizeof(T); + } + + static uint32_t VisitFloat(float, AnalyzeState &state) { return sizeof(float); - case VariantLogicalType::DOUBLE: + } + static uint32_t VisitDouble(double, AnalyzeState &state) { return sizeof(double); - case VariantLogicalType::UUID: + } + static uint32_t VisitUUID(hugeint_t, AnalyzeState &state) { return sizeof(hugeint_t); - case VariantLogicalType::DATE: + } + static uint32_t VisitDate(date_t, AnalyzeState &state) { return sizeof(int32_t); - case VariantLogicalType::TIME_MICROS: + } + static uint32_t VisitInterval(interval_t, AnalyzeState &state) { + return sizeof(interval_t); + } + + static uint32_t VisitTime(dtime_t, AnalyzeState &state) { return sizeof(dtime_t); - case VariantLogicalType::TIME_NANOS: + } + static uint32_t VisitTimeNanos(dtime_ns_t, AnalyzeState &state) { return sizeof(dtime_ns_t); - case VariantLogicalType::TIMESTAMP_SEC: + } + static uint32_t VisitTimeTZ(dtime_tz_t, AnalyzeState &state) { + return sizeof(dtime_tz_t); + } + static uint32_t VisitTimestampSec(timestamp_sec_t, AnalyzeState &state) { return sizeof(timestamp_sec_t); - case VariantLogicalType::TIMESTAMP_MILIS: + } + static uint32_t VisitTimestampMs(timestamp_ms_t, AnalyzeState &state) { return sizeof(timestamp_ms_t); - case VariantLogicalType::TIMESTAMP_MICROS: + } + static uint32_t VisitTimestamp(timestamp_t, AnalyzeState &state) { return sizeof(timestamp_t); - case VariantLogicalType::TIMESTAMP_NANOS: + } + static uint32_t VisitTimestampNanos(timestamp_ns_t, AnalyzeState &state) { return sizeof(timestamp_ns_t); - case VariantLogicalType::TIME_MICROS_TZ: - return sizeof(dtime_tz_t); - case VariantLogicalType::TIMESTAMP_MICROS_TZ: + } + static uint32_t VisitTimestampTZ(timestamp_tz_t, AnalyzeState &state) { return sizeof(timestamp_tz_t); - case VariantLogicalType::INTERVAL: - return sizeof(interval_t); - default: - throw InternalException("VariantLogicalType '%s' is not a trivial primitive", EnumUtil::ToString(type)); } -} + + static uint32_t VisitString(const string_t &str, AnalyzeState &state) { + auto length = static_cast(str.GetSize()); + return GetVarintSize(length) + length; + } + + static uint32_t VisitBlob(const string_t &blob, AnalyzeState &state) { + return VisitString(blob, state); + } + static uint32_t VisitBignum(const string_t &bignum, AnalyzeState &state) { + return VisitString(bignum, state); + } + static uint32_t VisitGeometry(const string_t &geom, AnalyzeState &state) { + return VisitString(geom, state); + } + static uint32_t VisitBitstring(const string_t &bits, AnalyzeState &state) { + return VisitString(bits, state); + } + + template + static uint32_t VisitDecimal(T, uint32_t width, uint32_t scale, AnalyzeState &state) { + uint32_t size = GetVarintSize(width) + GetVarintSize(scale); + size += sizeof(T); + return size; + } + + static uint32_t VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + AnalyzeState &state) { + uint32_t size = GetVarintSize(nested_data.child_count); + if (nested_data.child_count) { + size += GetVarintSize(nested_data.children_idx + state.children_offset); + } + return size; + } + + static uint32_t VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, AnalyzeState &state) { + return VisitArray(variant, row, nested_data, state); + } + + static uint32_t VisitDefault(VariantLogicalType type_id, const_data_ptr_t, AnalyzeState &) { + throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(type_id)); + } +}; + +struct VariantToVariantDataWriter { + using result_type = void; + + static void VisitNull(WriteState &state) { + return; + } + static void VisitBoolean(bool, WriteState &state) { + return; + } + + template + static void VisitInteger(T val, WriteState &state) { + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + static void VisitFloat(float val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitDouble(double val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitUUID(hugeint_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitDate(date_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitInterval(interval_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTime(dtime_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimeNanos(dtime_ns_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimeTZ(dtime_tz_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampSec(timestamp_sec_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampMs(timestamp_ms_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestamp(timestamp_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampNanos(timestamp_ns_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampTZ(timestamp_tz_t val, WriteState &state) { + VisitInteger(val, state); + } + + static void VisitString(const string_t &str, WriteState &state) { + auto length = str.GetSize(); + state.blob_size += VarintEncode(length, state.GetDestination()); + memcpy(state.GetDestination(), str.GetData(), length); + state.blob_size += length; + } + static void VisitBlob(const string_t &blob, WriteState &state) { + return VisitString(blob, state); + } + static void VisitBignum(const string_t &bignum, WriteState &state) { + return VisitString(bignum, state); + } + static void VisitGeometry(const string_t &geom, WriteState &state) { + return VisitString(geom, state); + } + static void VisitBitstring(const string_t &bits, WriteState &state) { + return VisitString(bits, state); + } + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, WriteState &state) { + state.blob_size += VarintEncode(width, state.GetDestination()); + state.blob_size += VarintEncode(scale, state.GetDestination()); + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + WriteState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (nested_data.child_count) { + //! NOTE: The 'child_index' stored in the OBJECT/ARRAY data could require more bits + //! That's the reason we have to rewrite the data in VARIANT->VARIANT cast + state.blob_size += VarintEncode(nested_data.children_idx + state.children_offset, state.GetDestination()); + } + } + + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + WriteState &state) { + return VisitArray(variant, row, nested_data, state); + } + + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, WriteState &) { + throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(type_id)); + } +}; + +} // namespace template bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalResultData &result_data, idx_t count, @@ -168,99 +321,26 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe } } - auto source_blob_data = const_data_ptr_cast(source.GetData(source_index).GetData()); - - //! Then write all values auto source_values_list_entry = source.GetValuesListEntry(source_index); - for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; - source_value_index++) { - auto source_type_id = source.GetTypeId(source_index, source_value_index); - auto source_byte_offset = source.GetByteOffset(source_index, source_value_index); - - //! NOTE: we have to deserialize these in both passes - //! because to figure out the size of the 'data' that is added by the VARIANT, we have to traverse the - //! VARIANT solely because the 'child_index' stored in the OBJECT/ARRAY data could require more bits - WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, - nullptr, 0, source_type_id); - - if (source_type_id == VariantLogicalType::ARRAY || source_type_id == VariantLogicalType::OBJECT) { - auto source_nested_data = VariantUtils::DecodeNestedData(source, source_index, source_value_index); - if (WRITE_DATA) { - VarintEncode(source_nested_data.child_count, blob_data + blob_offset + blob_size); - } - blob_size += GetVarintSize(source_nested_data.child_count); - if (source_nested_data.child_count) { - auto new_child_index = source_nested_data.children_idx + children_offset; - if (WRITE_DATA) { - VarintEncode(new_child_index, blob_data + blob_offset + blob_size); - } - blob_size += GetVarintSize(new_child_index); - } - } else if (source_type_id == VariantLogicalType::VARIANT_NULL || - source_type_id == VariantLogicalType::BOOL_FALSE || - source_type_id == VariantLogicalType::BOOL_TRUE) { - // no-op - } else if (source_type_id == VariantLogicalType::DECIMAL) { - auto decimal_blob_data = source_blob_data + source_byte_offset; - auto width = static_cast(VarintDecode(decimal_blob_data)); - auto width_varint_size = GetVarintSize(width); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data - width_varint_size, - width_varint_size); - } - blob_size += width_varint_size; - auto scale = static_cast(VarintDecode(decimal_blob_data)); - auto scale_varint_size = GetVarintSize(scale); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data - scale_varint_size, - scale_varint_size); - } - blob_size += scale_varint_size; - - if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(hugeint_t)); - } - blob_size += sizeof(hugeint_t); - } else if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int64_t)); - } - blob_size += sizeof(int64_t); - } else if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int32_t)); - } - blob_size += sizeof(int32_t); - } else { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int16_t)); - } - blob_size += sizeof(int16_t); - } - } else if (source_type_id == VariantLogicalType::BITSTRING || - source_type_id == VariantLogicalType::BIGNUM || source_type_id == VariantLogicalType::VARCHAR || - source_type_id == VariantLogicalType::BLOB || source_type_id == VariantLogicalType::GEOMETRY) { - auto str_blob_data = source_blob_data + source_byte_offset; - auto str_length = VarintDecode(str_blob_data); - auto str_length_varint_size = GetVarintSize(str_length); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, str_blob_data - str_length_varint_size, - str_length_varint_size); - } - blob_size += str_length_varint_size; - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, str_blob_data, str_length); - } - blob_size += str_length; - } else if (VariantIsTrivialPrimitive(source_type_id)) { - auto size = VariantTrivialPrimitiveSize(source_type_id); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, source_blob_data + source_byte_offset, size); - } - blob_size += size; - } else { - throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(source_type_id)); + + if (WRITE_DATA) { + WriteState write_state(keys_offset, children_offset, blob_offset, blob_data, blob_size); + for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; + source_value_index++) { + auto source_type_id = source.GetTypeId(source_index, source_value_index); + WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, + nullptr, 0, source_type_id); + + VariantVisitor::Visit(source, source_index, source_value_index, + write_state); + } + } else { + AnalyzeState analyze_state(children_offset); + for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; + source_value_index++) { + values_offset_data[result_index]++; + blob_size += VariantVisitor::Visit(source, source_index, + source_value_index, analyze_state); } } diff --git a/src/include/duckdb/function/scalar/variant_functions.hpp b/src/include/duckdb/function/scalar/variant_functions.hpp index 7c9ce455da40..c318a923697e 100644 --- a/src/include/duckdb/function/scalar/variant_functions.hpp +++ b/src/include/duckdb/function/scalar/variant_functions.hpp @@ -25,6 +25,16 @@ struct VariantExtractFun { static ScalarFunctionSet GetFunctions(); }; +struct VariantNormalizeFun { + static constexpr const char *Name = "variant_normalize"; + static constexpr const char *Parameters = "input_variant"; + static constexpr const char *Description = "Normalizes the `input_variant` to a canonical representation."; + static constexpr const char *Example = "variant_normalize({'b': [1,2,3], 'a': 42})::VARIANT)"; + static constexpr const char *Categories = "variant"; + + static ScalarFunction GetFunction(); +}; + struct VariantTypeofFun { static constexpr const char *Name = "variant_typeof"; static constexpr const char *Parameters = "input_variant"; diff --git a/src/include/duckdb/function/scalar/variant_utils.hpp b/src/include/duckdb/function/scalar/variant_utils.hpp index 3e90c4365e55..1c20b19c0e9f 100644 --- a/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/include/duckdb/function/scalar/variant_utils.hpp @@ -80,8 +80,11 @@ struct VariantUtils { VariantNestedData *child_data, ValidityMask &validity); DUCKDB_API static vector ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, idx_t count, optional_idx row); - DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, idx_t values_idx); + DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_idx); DUCKDB_API static bool Verify(Vector &variant, const SelectionVector &sel_p, idx_t count); + DUCKDB_API static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, + SelectionVector &sel, idx_t sel_size); }; } // namespace duckdb diff --git a/src/include/duckdb/main/attached_database.hpp b/src/include/duckdb/main/attached_database.hpp index a9ec117e9572..cb539e6b0e95 100644 --- a/src/include/duckdb/main/attached_database.hpp +++ b/src/include/duckdb/main/attached_database.hpp @@ -37,9 +37,10 @@ enum class AttachVisibility { SHOWN, HIDDEN }; class DatabaseFilePathManager; struct StoredDatabasePath { - StoredDatabasePath(DatabaseFilePathManager &manager, string path, const string &name); + StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path, const string &name); ~StoredDatabasePath(); + DatabaseManager &db_manager; DatabaseFilePathManager &manager; string path; diff --git a/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp b/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp index c3ffadbe25b9..57ea17d6be10 100644 --- a/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp +++ b/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp @@ -32,7 +32,7 @@ class BatchedBufferedData : public BufferedData { static constexpr const BufferedData::Type TYPE = BufferedData::Type::BATCHED; public: - explicit BatchedBufferedData(weak_ptr context); + explicit BatchedBufferedData(ClientContext &context); public: void Append(const DataChunk &chunk, idx_t batch); diff --git a/src/include/duckdb/main/buffered_data/buffered_data.hpp b/src/include/duckdb/main/buffered_data/buffered_data.hpp index 0f32675ceac9..06a72b0f64dc 100644 --- a/src/include/duckdb/main/buffered_data/buffered_data.hpp +++ b/src/include/duckdb/main/buffered_data/buffered_data.hpp @@ -28,7 +28,7 @@ class BufferedData { enum class Type { SIMPLE, BATCHED }; public: - BufferedData(Type type, weak_ptr context_p); + BufferedData(Type type, ClientContext &context); virtual ~BufferedData(); public: diff --git a/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp b/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp index 967cc1ab7636..40a5a6edebac 100644 --- a/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp +++ b/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp @@ -24,7 +24,7 @@ class SimpleBufferedData : public BufferedData { static constexpr const BufferedData::Type TYPE = BufferedData::Type::SIMPLE; public: - explicit SimpleBufferedData(weak_ptr context); + explicit SimpleBufferedData(ClientContext &context); ~SimpleBufferedData() override; public: diff --git a/src/include/duckdb/main/capi/capi_internal.hpp b/src/include/duckdb/main/capi/capi_internal.hpp index 8307b70a35f8..3b736d88a208 100644 --- a/src/include/duckdb/main/capi/capi_internal.hpp +++ b/src/include/duckdb/main/capi/capi_internal.hpp @@ -51,6 +51,8 @@ struct PreparedStatementWrapper { //! Map of name -> values case_insensitive_map_t values; unique_ptr statement; + bool success = true; + ErrorData error_data; }; struct ExtractStatementsWrapper { diff --git a/src/include/duckdb/main/capi/extension_api.hpp b/src/include/duckdb/main/capi/extension_api.hpp index 2ce10061a1d1..899a331cb968 100644 --- a/src/include/duckdb/main/capi/extension_api.hpp +++ b/src/include/duckdb/main/capi/extension_api.hpp @@ -554,6 +554,11 @@ typedef struct { // New string functions that are added char *(*duckdb_value_to_string)(duckdb_value value); + // New functions around the table description + + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); // New functions around table function binding void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -1044,6 +1049,8 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_scalar_function_bind_get_argument = duckdb_scalar_function_bind_get_argument; result.duckdb_scalar_function_set_bind_data_copy = duckdb_scalar_function_set_bind_data_copy; result.duckdb_value_to_string = duckdb_value_to_string; + result.duckdb_table_description_get_column_count = duckdb_table_description_get_column_count; + result.duckdb_table_description_get_column_type = duckdb_table_description_get_column_type; result.duckdb_table_function_get_client_context = duckdb_table_function_get_client_context; result.duckdb_create_map_value = duckdb_create_map_value; result.duckdb_create_union_value = duckdb_create_union_value; diff --git a/src/include/duckdb/main/capi/header_generation/apis/v1/unstable/new_table_description_functions.json b/src/include/duckdb/main/capi/header_generation/apis/v1/unstable/new_table_description_functions.json new file mode 100644 index 000000000000..28cd47a50bfd --- /dev/null +++ b/src/include/duckdb/main/capi/header_generation/apis/v1/unstable/new_table_description_functions.json @@ -0,0 +1,8 @@ +{ + "version": "unstable_new_table_description_functions", + "description": "New functions around the table description", + "entries": [ + "duckdb_table_description_get_column_count", + "duckdb_table_description_get_column_type" + ] +} \ No newline at end of file diff --git a/src/include/duckdb/main/capi/header_generation/functions/safe_fetch_functions.json b/src/include/duckdb/main/capi/header_generation/functions/safe_fetch_functions.json index c05ba2c9fcbc..c4d24d209095 100644 --- a/src/include/duckdb/main/capi/header_generation/functions/safe_fetch_functions.json +++ b/src/include/duckdb/main/capi/header_generation/functions/safe_fetch_functions.json @@ -1,7 +1,7 @@ { "group": "safe_fetch_functions", "deprecated": true, - "description": "// These functions will perform conversions if necessary.\n// On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned.\n// Note that these functions are slow since they perform bounds checking and conversion\n// For fast access of values prefer using `duckdb_result_get_chunk`", + "description": "// This function group is deprecated.\n// To access the values in a result, use `duckdb_fetch_chunk` repeatedly.\n// For each chunk, use the `duckdb_data_chunk` interface to access any columns and their values.\n\n", "entries": [ { "name": "duckdb_value_boolean", @@ -417,7 +417,7 @@ } ], "comment": { - "description": "**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains null bytes.\n\n", + "description": "**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.\n\n", "return_value": "The text value at the specified location as a null-terminated string, or nullptr if the value cannot be\nconverted. The result must be freed with `duckdb_free`." } }, @@ -439,7 +439,7 @@ } ], "comment": { - "description": "**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.\n\nNo support for nested types, and for other complex types.\nThe resulting field \"string.data\" must be freed with `duckdb_free.`\n\n", + "description": "**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.\n\n", "return_value": "The string value at the specified location. Attempts to cast the result value to string." } }, @@ -461,7 +461,7 @@ } ], "comment": { - "description": "**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains\nnull bytes.\n\n", + "description": "**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.\n\n", "return_value": "The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast.\nIf the column is NOT a VARCHAR column this function will return NULL.\n\nThe result must NOT be freed." } }, @@ -483,7 +483,7 @@ } ], "comment": { - "description": "**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains\nnull bytes.\n", + "description": "**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.\n\n", "return_value": "The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast.\nIf the column is NOT a VARCHAR column this function will return NULL.\n\nThe result must NOT be freed." } }, diff --git a/src/include/duckdb/main/capi/header_generation/functions/table_description.json b/src/include/duckdb/main/capi/header_generation/functions/table_description.json index 2a5ac5038b28..f8219458ea03 100644 --- a/src/include/duckdb/main/capi/header_generation/functions/table_description.json +++ b/src/include/duckdb/main/capi/header_generation/functions/table_description.json @@ -131,6 +131,23 @@ "return_value": "`DuckDBSuccess` on success or `DuckDBError` on failure." } }, + { + "name": "duckdb_table_description_get_column_count", + "return_type": "idx_t", + "params": [ + { + "type": "duckdb_table_description", + "name": "table_description" + } + ], + "comment": { + "description": "Return the number of columns of the described table.\n\n", + "param_comments": { + "table_description": "The table_description to query." + }, + "return_value": "The column count." + } + }, { "name": "duckdb_table_description_get_column_name", "return_type": "char *", @@ -152,6 +169,28 @@ }, "return_value": "The column name." } + }, + { + "name": "duckdb_table_description_get_column_type", + "return_type": "duckdb_logical_type", + "params": [ + { + "type": "duckdb_table_description", + "name": "table_description" + }, + { + "type": "idx_t", + "name": "index" + } + ], + "comment": { + "description": "Obtain the column type at 'index'.\nThe return value must be destroyed with `duckdb_destroy_logical_type`.\n\n", + "param_comments": { + "table_description": "The table_description to query.", + "index": "The index of the column to query." + }, + "return_value": "The column type." + } } ] } \ No newline at end of file diff --git a/src/include/duckdb/main/capi/header_generation/functions/vector_interface.json b/src/include/duckdb/main/capi/header_generation/functions/vector_interface.json index 29520006403b..57c39c3ff1c4 100644 --- a/src/include/duckdb/main/capi/header_generation/functions/vector_interface.json +++ b/src/include/duckdb/main/capi/header_generation/functions/vector_interface.json @@ -212,12 +212,12 @@ } ], "comment": { - "description": "Sets the total size of the underlying child-vector of a list vector.\n\n", + "description": "Sets the size of the underlying child-vector of a list vector.\nNote that this does NOT reserve the memory in the child buffer,\nand that it is possible to set a size exceeding the capacity.\nTo set the capacity, use `duckdb_list_vector_reserve`.\n\n", "param_comments": { "vector": "The list vector.", "size": "The size of the child list." }, - "return_value": "The duckdb state. Returns DuckDBError if the vector is nullptr." + "return_value": "The duckdb state. Returns DuckDBError, if the vector is nullptr." } }, { @@ -234,12 +234,12 @@ } ], "comment": { - "description": "Sets the total capacity of the underlying child-vector of a list.\n\nAfter calling this method, you must call `duckdb_vector_get_validity` and `duckdb_vector_get_data` to obtain current\ndata and validity pointers\n\n", + "description": "Sets the capacity of the underlying child-vector of a list vector.\nWe increment to the next power of two, based on the required capacity.\nThus, the capacity might not match the size of the list (capacity >= size),\nwhich is set via `duckdb_list_vector_set_size`.\n\n", "param_comments": { "vector": "The list vector.", - "required_capacity": "the total capacity to reserve." + "required_capacity": "The child buffer capacity to reserve." }, - "return_value": "The duckdb state. Returns DuckDBError if the vector is nullptr." + "return_value": "The duckdb state. Returns DuckDBError, if the vector is nullptr." } }, { diff --git a/src/include/duckdb/main/client_context.hpp b/src/include/duckdb/main/client_context.hpp index d0ad964ef33c..21f56ccf94b0 100644 --- a/src/include/duckdb/main/client_context.hpp +++ b/src/include/duckdb/main/client_context.hpp @@ -339,9 +339,6 @@ class QueryContext { } QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction } - QueryContext(weak_ptr context) // NOLINT: allow implicit construction - : owning_context(context.lock()), context(owning_context.get()) { - } public: bool Valid() const { @@ -352,7 +349,6 @@ class QueryContext { } private: - shared_ptr owning_context; optional_ptr context; }; diff --git a/src/include/duckdb/main/database_file_path_manager.hpp b/src/include/duckdb/main/database_file_path_manager.hpp index 90d028035b57..3af2f1873bff 100644 --- a/src/include/duckdb/main/database_file_path_manager.hpp +++ b/src/include/duckdb/main/database_file_path_manager.hpp @@ -12,31 +12,35 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" +#include "duckdb/common/enums/access_mode.hpp" +#include "duckdb/common/reference_map.hpp" namespace duckdb { struct AttachInfo; struct AttachOptions; +class DatabaseManager; enum class InsertDatabasePathResult { SUCCESS, ALREADY_EXISTS }; struct DatabasePathInfo { - explicit DatabasePathInfo(string name_p) : name(std::move(name_p)), is_attached(true) { - } + DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode); string name; - bool is_attached; + AccessMode access_mode; + reference_set_t attached_databases; + idx_t reference_count = 1; }; //! The DatabaseFilePathManager is used to ensure we only ever open a single database file once class DatabaseFilePathManager { public: idx_t ApproxDatabaseCount() const; - InsertDatabasePathResult InsertDatabasePath(const string &path, const string &name, OnCreateConflict on_conflict, - AttachOptions &options); + InsertDatabasePathResult InsertDatabasePath(DatabaseManager &manager, const string &path, const string &name, + OnCreateConflict on_conflict, AttachOptions &options); //! Erase a database path - indicating we are done with using it void EraseDatabasePath(const string &path); //! Called when a database is detached, but before it is fully finished being used - void DetachDatabase(const string &path); + void DetachDatabase(DatabaseManager &manager, const string &path); private: //! The lock to add entries to the db_paths map diff --git a/src/include/duckdb/main/error_manager.hpp b/src/include/duckdb/main/error_manager.hpp index aaedffd4b966..065f6399a9a3 100644 --- a/src/include/duckdb/main/error_manager.hpp +++ b/src/include/duckdb/main/error_manager.hpp @@ -34,38 +34,39 @@ enum class ErrorType : uint16_t { class ErrorManager { public: template - string FormatException(ErrorType error_type, ARGS... params) { + string FormatException(ErrorType error_type, ARGS &&...params) { vector values; - return FormatExceptionRecursive(error_type, values, params...); + return FormatExceptionRecursive(error_type, values, std::forward(params)...); } DUCKDB_API string FormatExceptionRecursive(ErrorType error_type, vector &values); template string FormatExceptionRecursive(ErrorType error_type, vector &values, T param, - ARGS... params) { + ARGS &&...params) { values.push_back(ExceptionFormatValue::CreateFormatValue(param)); - return FormatExceptionRecursive(error_type, values, params...); + return FormatExceptionRecursive(error_type, values, std::forward(params)...); } template - static string FormatException(ClientContext &context, ErrorType error_type, ARGS... params) { - return Get(context).FormatException(error_type, params...); + static string FormatException(ClientContext &context, ErrorType error_type, ARGS &&...params) { + return Get(context).FormatException(error_type, std::forward(params)...); } DUCKDB_API static InvalidInputException InvalidUnicodeError(const String &input, const string &context); DUCKDB_API static FatalException InvalidatedDatabase(ClientContext &context, const string &invalidated_msg); + DUCKDB_API static TransactionException InvalidatedTransaction(ClientContext &context); //! Adds a custom error for a specific error type void AddCustomError(ErrorType type, string new_error); DUCKDB_API static ErrorManager &Get(ClientContext &context); + DUCKDB_API static ErrorManager &Get(DatabaseInstance &context); private: map custom_errors; }; - } // namespace duckdb diff --git a/src/include/duckdb/main/extension_entries.hpp b/src/include/duckdb/main/extension_entries.hpp index 4f119c4a0577..7688d575c46a 100644 --- a/src/include/duckdb/main/extension_entries.hpp +++ b/src/include/duckdb/main/extension_entries.hpp @@ -69,8 +69,10 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"approx_top_k", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_max", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_max_null", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"arg_max_nulls_last", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_min", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_min_null", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"arg_min_nulls_last", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"argmax", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"argmin", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"array_agg", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, diff --git a/src/include/duckdb/main/profiling_info.hpp b/src/include/duckdb/main/profiling_info.hpp index 709314375529..e54a13618fd5 100644 --- a/src/include/duckdb/main/profiling_info.hpp +++ b/src/include/duckdb/main/profiling_info.hpp @@ -41,8 +41,8 @@ class ProfilingInfo { public: static profiler_settings_t DefaultSettings(); - static profiler_settings_t DefaultRootSettings(); - static profiler_settings_t DefaultOperatorSettings(); + static profiler_settings_t RootScopeSettings(); + static profiler_settings_t OperatorScopeSettings(); public: void ResetMetrics(); diff --git a/src/include/duckdb/main/query_profiler.hpp b/src/include/duckdb/main/query_profiler.hpp index db27be1e435f..44df56025c24 100644 --- a/src/include/duckdb/main/query_profiler.hpp +++ b/src/include/duckdb/main/query_profiler.hpp @@ -94,7 +94,6 @@ class OperatorProfiler { DUCKDB_API void Flush(const PhysicalOperator &phys_op); DUCKDB_API OperatorInformation &GetOperatorInfo(const PhysicalOperator &phys_op); DUCKDB_API bool OperatorInfoIsInitialized(const PhysicalOperator &phys_op); - DUCKDB_API void AddExtraInfo(InsertionOrderPreservingMap extra_info); public: ClientContext &context; @@ -117,15 +116,35 @@ class OperatorProfiler { struct QueryMetrics { QueryMetrics() : total_bytes_read(0), total_bytes_written(0) {}; + //! Reset the query metrics. + void Reset() { + query = ""; + latency.Reset(); + waiting_to_attach_latency.Reset(); + attach_load_storage_latency.Reset(); + attach_replay_wal_latency.Reset(); + checkpoint_latency.Reset(); + total_bytes_read = 0; + total_bytes_written = 0; + } + ProfilingInfo query_global_info; - //! The SQL string of the query + //! The SQL string of the query. string query; - //! The timer used to time the excution time of the entire query + //! The timer of the execution of the entire query. Profiler latency; - //! The total bytes read by the file system + //! The timer of the delay when waiting to ATTACH a file. + Profiler waiting_to_attach_latency; + //! The timer for loading from storage. + Profiler attach_load_storage_latency; + //! The timer for replaying the WAL file. + Profiler attach_replay_wal_latency; + //! The timer for running checkpoints. + Profiler checkpoint_latency; + //! The total bytes read by the file system. atomic total_bytes_read; - //! The total bytes written by the file system + //! The total bytes written by the file system. atomic total_bytes_written; }; @@ -138,9 +157,6 @@ class QueryProfiler { DUCKDB_API explicit QueryProfiler(ClientContext &context); public: - //! Propagate save_location, enabled, detailed_enabled and automatic_print_format. - void Propagate(QueryProfiler &qp); - DUCKDB_API bool IsEnabled() const; DUCKDB_API bool IsDetailedEnabled() const; DUCKDB_API ProfilerPrintFormat GetPrintFormat(ExplainFormat format = ExplainFormat::DEFAULT) const; @@ -159,6 +175,10 @@ class QueryProfiler { //! Adds nr_bytes bytes to the total bytes written. DUCKDB_API void AddBytesWritten(const idx_t nr_bytes); + //! Start/End a timer for a specific metric type. + DUCKDB_API void StartTimer(MetricsType type); + DUCKDB_API void EndTimer(MetricsType type); + DUCKDB_API void StartExplainAnalyze(); //! Adds the timings gathered by an OperatorProfiler to this query profiler diff --git a/src/include/duckdb/main/relation.hpp b/src/include/duckdb/main/relation.hpp index 9d9e67686b1f..bc383ffe0467 100644 --- a/src/include/duckdb/main/relation.hpp +++ b/src/include/duckdb/main/relation.hpp @@ -78,7 +78,8 @@ class Relation : public enable_shared_from_this { public: DUCKDB_API virtual const vector &Columns() = 0; - DUCKDB_API virtual unique_ptr GetQueryNode(); + DUCKDB_API virtual unique_ptr GetQueryNode() = 0; + DUCKDB_API virtual string GetQuery(); DUCKDB_API virtual BoundStatement Bind(Binder &binder); DUCKDB_API virtual string GetAlias(); diff --git a/src/include/duckdb/main/relation/create_table_relation.hpp b/src/include/duckdb/main/relation/create_table_relation.hpp index 7d5462941bfd..8df59b8d251e 100644 --- a/src/include/duckdb/main/relation/create_table_relation.hpp +++ b/src/include/duckdb/main/relation/create_table_relation.hpp @@ -26,6 +26,8 @@ class CreateTableRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/create_view_relation.hpp b/src/include/duckdb/main/relation/create_view_relation.hpp index cb826a86cd33..aa09b0def99e 100644 --- a/src/include/duckdb/main/relation/create_view_relation.hpp +++ b/src/include/duckdb/main/relation/create_view_relation.hpp @@ -26,6 +26,8 @@ class CreateViewRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/delete_relation.hpp b/src/include/duckdb/main/relation/delete_relation.hpp index c07445ba4e2f..0c25c6576401 100644 --- a/src/include/duckdb/main/relation/delete_relation.hpp +++ b/src/include/duckdb/main/relation/delete_relation.hpp @@ -26,6 +26,8 @@ class DeleteRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/explain_relation.hpp b/src/include/duckdb/main/relation/explain_relation.hpp index 888583b2b6e4..96be08d8ff65 100644 --- a/src/include/duckdb/main/relation/explain_relation.hpp +++ b/src/include/duckdb/main/relation/explain_relation.hpp @@ -24,6 +24,8 @@ class ExplainRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/insert_relation.hpp b/src/include/duckdb/main/relation/insert_relation.hpp index 3695cde7b007..fccb0ae929b3 100644 --- a/src/include/duckdb/main/relation/insert_relation.hpp +++ b/src/include/duckdb/main/relation/insert_relation.hpp @@ -23,6 +23,8 @@ class InsertRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/query_relation.hpp b/src/include/duckdb/main/relation/query_relation.hpp index bdb035652ccf..b1be001b9fe6 100644 --- a/src/include/duckdb/main/relation/query_relation.hpp +++ b/src/include/duckdb/main/relation/query_relation.hpp @@ -28,6 +28,7 @@ class QueryRelation : public Relation { public: static unique_ptr ParseStatement(ClientContext &context, const string &query, const string &error); unique_ptr GetQueryNode() override; + string GetQuery() override; unique_ptr GetTableRef() override; BoundStatement Bind(Binder &binder) override; diff --git a/src/include/duckdb/main/relation/update_relation.hpp b/src/include/duckdb/main/relation/update_relation.hpp index 58ad203b2ef7..91eac246e19d 100644 --- a/src/include/duckdb/main/relation/update_relation.hpp +++ b/src/include/duckdb/main/relation/update_relation.hpp @@ -29,6 +29,8 @@ class UpdateRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/write_csv_relation.hpp b/src/include/duckdb/main/relation/write_csv_relation.hpp index 99d2ebe8e5b0..cf0853ff3968 100644 --- a/src/include/duckdb/main/relation/write_csv_relation.hpp +++ b/src/include/duckdb/main/relation/write_csv_relation.hpp @@ -23,6 +23,8 @@ class WriteCSVRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/main/relation/write_parquet_relation.hpp b/src/include/duckdb/main/relation/write_parquet_relation.hpp index d32089212981..138eee7c7e4b 100644 --- a/src/include/duckdb/main/relation/write_parquet_relation.hpp +++ b/src/include/duckdb/main/relation/write_parquet_relation.hpp @@ -24,6 +24,8 @@ class WriteParquetRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/include/duckdb/optimizer/filter_pushdown.hpp index 5e920ca2de77..29c2f0ac4e6b 100644 --- a/src/include/duckdb/optimizer/filter_pushdown.hpp +++ b/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -104,8 +104,9 @@ class FilterPushdown { void ExtractFilterBindings(const Expression &expr, vector &bindings); //! Generate filters from the current set of filters stored in the FilterCombiner void GenerateFilters(); - //! if there are filters in this FilterPushdown node, push them into the combiner - void PushFilters(); + //! if there are filters in this FilterPushdown node, push them into the combiner. Returns + //! FilterResult::UNSATISFIABLE if the subtree should be stripped, or FilterResult::SUCCESS otherwise + FilterResult PushFilters(); }; } // namespace duckdb diff --git a/src/include/duckdb/optimizer/topn_window_elimination.hpp b/src/include/duckdb/optimizer/topn_window_elimination.hpp index fcb50bac51ea..c871bfc65246 100644 --- a/src/include/duckdb/optimizer/topn_window_elimination.hpp +++ b/src/include/duckdb/optimizer/topn_window_elimination.hpp @@ -25,6 +25,8 @@ struct TopNWindowEliminationParameters { TopNPayloadType payload_type; //! Whether to include row numbers bool include_row_number; + //! Whether the val or arg column contains null values + bool can_be_null = false; }; class TopNWindowElimination : public BaseColumnPruner { @@ -51,7 +53,7 @@ class TopNWindowElimination : public BaseColumnPruner { vector TraverseProjectionBindings(const std::vector &old_bindings, LogicalOperator *&op); unique_ptr CreateAggregateExpression(vector> aggregate_params, bool requires_arg, - OrderType order_type) const; + const TopNWindowEliminationParameters ¶ms) const; unique_ptr CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const; void AddStructExtractExprs(vector> &exprs, const LogicalType &struct_type, const unique_ptr &aggregate_column_ref) const; @@ -59,6 +61,9 @@ class TopNWindowElimination : public BaseColumnPruner { const map &group_idxs, const vector &topmost_bindings, vector &new_bindings, ColumnBindingReplacer &replacer); + TopNWindowEliminationParameters ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + vector> &aggregate_payload); private: ClientContext &context; diff --git a/src/include/duckdb/parser/parser_extension.hpp b/src/include/duckdb/parser/parser_extension.hpp index 61c071307c6f..a3d2dcf64896 100644 --- a/src/include/duckdb/parser/parser_extension.hpp +++ b/src/include/duckdb/parser/parser_extension.hpp @@ -86,12 +86,12 @@ struct ParserOverrideResult { explicit ParserOverrideResult(vector> statements_p) : type(ParserExtensionResultType::PARSE_SUCCESSFUL), statements(std::move(statements_p)) {}; - explicit ParserOverrideResult(const string &error_p) + explicit ParserOverrideResult(std::exception &error_p) : type(ParserExtensionResultType::DISPLAY_EXTENSION_ERROR), error(error_p) {}; ParserExtensionResultType type; vector> statements; - string error; + ErrorData error; }; typedef ParserOverrideResult (*parser_override_function_t)(ParserExtensionInfo *info, const string &query); @@ -103,14 +103,14 @@ class ParserExtension { public: //! The parse function of the parser extension. //! Takes a query string as input and returns ParserExtensionParseData (on success) or an error - parse_function_t parse_function; + parse_function_t parse_function = nullptr; //! The plan function of the parser extension //! Takes as input the result of the parse_function, and outputs various properties of the resulting plan - plan_function_t plan_function; + plan_function_t plan_function = nullptr; //! Override the current parser with a new parser and return a vector of SQL statements - parser_override_function_t parser_override; + parser_override_function_t parser_override = nullptr; //! Additional parser info passed to the parse function shared_ptr parser_info; diff --git a/src/include/duckdb/parser/query_node.hpp b/src/include/duckdb/parser/query_node.hpp index 956bd63f7bc5..5c091b259ba6 100644 --- a/src/include/duckdb/parser/query_node.hpp +++ b/src/include/duckdb/parser/query_node.hpp @@ -60,8 +60,6 @@ class QueryNode { //! CTEs (used by SelectNode and SetOperationNode) CommonTableExpressionMap cte_map; - virtual const vector> &GetSelectList() const = 0; - public: //! Convert the query node to a string virtual string ToString() const = 0; diff --git a/src/include/duckdb/parser/query_node/cte_node.hpp b/src/include/duckdb/parser/query_node/cte_node.hpp index bc997a6c7740..fd2589fd2ab2 100644 --- a/src/include/duckdb/parser/query_node/cte_node.hpp +++ b/src/include/duckdb/parser/query_node/cte_node.hpp @@ -14,6 +14,7 @@ namespace duckdb { +//! DEPRECATED - CTENode is only preserved for backwards compatibility when serializing older databases class CTENode : public QueryNode { public: static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; @@ -23,30 +24,18 @@ class CTENode : public QueryNode { } string ctename; - //! The query of the CTE unique_ptr query; - //! Child unique_ptr child; - //! Aliases of the CTE node vector aliases; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - const vector> &GetSelectList() const override { - return query->GetSelectList(); - } - public: - //! Convert the query node to a string string ToString() const override; bool Equals(const QueryNode *other) const override; - //! Create a copy of this SelectNode unique_ptr Copy() const override; - //! Serializes a QueryNode to a stand-alone binary blob - //! Deserializes a blob back into a QueryNode - void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &source); }; diff --git a/src/include/duckdb/parser/query_node/recursive_cte_node.hpp b/src/include/duckdb/parser/query_node/recursive_cte_node.hpp index 6d73fda4aca6..1f5f16ead15d 100644 --- a/src/include/duckdb/parser/query_node/recursive_cte_node.hpp +++ b/src/include/duckdb/parser/query_node/recursive_cte_node.hpp @@ -33,10 +33,6 @@ class RecursiveCTENode : public QueryNode { //! targets for key variants vector> key_targets; - const vector> &GetSelectList() const override { - return left->GetSelectList(); - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/include/duckdb/parser/query_node/select_node.hpp b/src/include/duckdb/parser/query_node/select_node.hpp index 62aa9c0b2e64..dfc474d14a5b 100644 --- a/src/include/duckdb/parser/query_node/select_node.hpp +++ b/src/include/duckdb/parser/query_node/select_node.hpp @@ -43,10 +43,6 @@ class SelectNode : public QueryNode { //! The SAMPLE clause unique_ptr sample; - const vector> &GetSelectList() const override { - return select_list; - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/include/duckdb/parser/query_node/set_operation_node.hpp b/src/include/duckdb/parser/query_node/set_operation_node.hpp index 960f6c2d678c..3070e224534a 100644 --- a/src/include/duckdb/parser/query_node/set_operation_node.hpp +++ b/src/include/duckdb/parser/query_node/set_operation_node.hpp @@ -29,8 +29,6 @@ class SetOperationNode : public QueryNode { //! The children of the set operation vector> children; - const vector> &GetSelectList() const override; - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/include/duckdb/parser/query_node/statement_node.hpp b/src/include/duckdb/parser/query_node/statement_node.hpp index 9e813335c6ed..26db46a58231 100644 --- a/src/include/duckdb/parser/query_node/statement_node.hpp +++ b/src/include/duckdb/parser/query_node/statement_node.hpp @@ -24,7 +24,6 @@ class StatementNode : public QueryNode { SQLStatement &stmt; public: - const vector> &GetSelectList() const override; //! Convert the query node to a string string ToString() const override; diff --git a/src/include/duckdb/parser/transformer.hpp b/src/include/duckdb/parser/transformer.hpp index 59e4f041988d..1945ebc5a921 100644 --- a/src/include/duckdb/parser/transformer.hpp +++ b/src/include/duckdb/parser/transformer.hpp @@ -80,7 +80,7 @@ class Transformer { //! The set of pivot entries to create vector> pivot_entries; //! Sets of stored CTEs, if any - vector stored_cte_map; + vector> stored_cte_map; //! Whether or not we are currently binding a window definition bool in_window_definition = false; @@ -304,7 +304,6 @@ class Transformer { string TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias); vector TransformStringList(duckdb_libpgquery::PGList *list); void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map); - static unique_ptr TransformMaterializedCTE(unique_ptr root); unique_ptr TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &node, CommonTableExpressionInfo &info); diff --git a/src/include/duckdb/planner/bind_context.hpp b/src/include/duckdb/planner/bind_context.hpp index b17f2666a46b..db5b52c78fc0 100644 --- a/src/include/duckdb/planner/bind_context.hpp +++ b/src/include/duckdb/planner/bind_context.hpp @@ -54,7 +54,7 @@ class BindContext { //! matching ones vector GetSimilarBindings(const string &column_name); - optional_ptr GetCTEBinding(const string &ctename); + optional_ptr GetCTEBinding(const BindingAlias &ctename); //! Binds a column expression to the base table. Returns the bound expression //! or throws an exception if the column could not be bound. BindResult BindColumn(ColumnRefExpression &colref, idx_t depth); @@ -116,8 +116,9 @@ class BindContext { //! Adds a base table with the given alias to the CTE BindContext. //! We need this to correctly bind recursive CTEs with multiple references. - void AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types, - bool using_key = false); + void AddCTEBinding(idx_t index, BindingAlias alias, const vector &names, const vector &types, + CTEType cte_type = CTEType::CAN_BE_REFERENCED); + void AddCTEBinding(unique_ptr binding); //! Add an implicit join condition (e.g. USING (x)) void AddUsingBinding(const string &column_name, UsingColumnSet &set); @@ -173,6 +174,6 @@ class BindContext { //! The set of columns used in USING join conditions case_insensitive_map_t> using_columns; //! The set of CTE bindings - case_insensitive_map_t> cte_bindings; + vector> cte_bindings; }; } // namespace duckdb diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp index 5603ab36bec4..cf01e74d6585 100644 --- a/src/include/duckdb/planner/binder.hpp +++ b/src/include/duckdb/planner/binder.hpp @@ -69,6 +69,7 @@ struct UnpivotEntry; struct CopyInfo; struct CopyOption; struct BoundSetOpChild; +struct BoundCTEData; template class IndexVector; @@ -173,14 +174,14 @@ struct GlobalBinderState { case_insensitive_map_t> replacement_scans; //! Using column sets vector> using_column_sets; + //! The set of parameter expressions bound by this binder + optional_ptr parameters; }; // QueryBinderState is state shared WITHIN a query, a new query-binder state is created when binding inside e.g. a view struct QueryBinderState { //! The vector of active binders vector> active_binders; - //! The set of parameter expressions bound by this binder - optional_ptr parameters; }; //! Bind the parsed query tree to the actual columns present in the catalog. @@ -199,8 +200,6 @@ class Binder : public enable_shared_from_this { //! The client context ClientContext &context; - //! A mapping of names to common table expressions - case_insensitive_set_t CTE_bindings; // NOLINT //! The bind context BindContext bind_context; //! The set of correlated columns bound by this binder (FIXME: this should probably be an unordered_set and not a @@ -260,12 +259,8 @@ class Binder : public enable_shared_from_this { optional_ptr GetCatalogEntry(const string &catalog, const string &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found); - //! Add a common table expression to the binder - void AddCTE(const string &name); //! Find all candidate common table expression by name; returns empty vector if none exists - optional_ptr GetCTEBinding(const string &name); - - bool CTEExists(const string &name); + optional_ptr GetCTEBinding(const BindingAlias &name); //! Add the view to the set of currently bound views - used for detecting recursive view definitions void AddBoundView(ViewCatalogEntry &view); @@ -409,24 +404,20 @@ class Binder : public enable_shared_from_this { unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); - BoundStatement BindCTE(CTENode &statement); + BoundStatement BindCTE(const string &ctename, CommonTableExpressionInfo &info); BoundStatement BindNode(SelectNode &node); BoundStatement BindNode(SetOperationNode &node); BoundStatement BindNode(RecursiveCTENode &node); - BoundStatement BindNode(CTENode &node); BoundStatement BindNode(QueryNode &node); BoundStatement BindNode(StatementNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); - unique_ptr CreatePlan(BoundRecursiveCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node); unique_ptr CreatePlan(BoundSelectNode &statement); unique_ptr CreatePlan(BoundSetOperationNode &node); unique_ptr CreatePlan(BoundQueryNode &node); - BoundSetOpChild BindSetOpChild(QueryNode &child); - unique_ptr BindSetOpNode(SetOperationNode &statement); + void BuildUnionByNameInfo(BoundSetOperationNode &result); BoundStatement BindJoin(Binder &parent, TableRef &ref); BoundStatement Bind(BaseTableRef &ref); @@ -517,8 +508,6 @@ class Binder : public enable_shared_from_this { LogicalType BindLogicalTypeInternal(const LogicalType &type, optional_ptr catalog, const string &schema); BoundStatement BindSelectNode(SelectNode &statement, BoundStatement from_table); - unique_ptr BindSelectNodeInternal(SelectNode &statement); - unique_ptr BindSelectNodeInternal(SelectNode &statement, BoundStatement from_table); unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name); @@ -546,6 +535,9 @@ class Binder : public enable_shared_from_this { static void CheckInsertColumnCountMismatch(idx_t expected_columns, idx_t result_columns, bool columns_provided, const string &tname); + BoundCTEData PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement); + BoundStatement FinishCTE(BoundCTEData &bound_cte, BoundStatement child_data); + private: Binder(ClientContext &context, shared_ptr parent, BinderType binder_type); }; diff --git a/src/include/duckdb/planner/bound_query_node.hpp b/src/include/duckdb/planner/bound_query_node.hpp index cd5a78b6aa94..76c461e78a4c 100644 --- a/src/include/duckdb/planner/bound_query_node.hpp +++ b/src/include/duckdb/planner/bound_query_node.hpp @@ -17,13 +17,8 @@ namespace duckdb { //! Bound equivalent of QueryNode class BoundQueryNode { public: - explicit BoundQueryNode(QueryNodeType type) : type(type) { - } - virtual ~BoundQueryNode() { - } + virtual ~BoundQueryNode() = default; - //! The type of the query node, either SetOperation or Select - QueryNodeType type; //! The result modifiers that should be applied to this query node vector> modifiers; @@ -34,23 +29,6 @@ class BoundQueryNode { public: virtual idx_t GetRootIndex() = 0; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } }; } // namespace duckdb diff --git a/src/include/duckdb/planner/bound_statement.hpp b/src/include/duckdb/planner/bound_statement.hpp index bb1f7bfec82b..23fae54d6f5c 100644 --- a/src/include/duckdb/planner/bound_statement.hpp +++ b/src/include/duckdb/planner/bound_statement.hpp @@ -9,17 +9,31 @@ #pragma once #include "duckdb/common/string.hpp" +#include "duckdb/common/unique_ptr.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { class LogicalOperator; struct LogicalType; +struct BoundStatement; +class ParsedExpression; +class Binder; + +struct ExtraBoundInfo { + SetOperationType setop_type = SetOperationType::NONE; + vector> child_binders; + vector bound_children; + vector> original_expressions; +}; struct BoundStatement { unique_ptr plan; vector types; vector names; + ExtraBoundInfo extra_info; }; } // namespace duckdb diff --git a/src/include/duckdb/planner/bound_tokens.hpp b/src/include/duckdb/planner/bound_tokens.hpp index ac8aef099760..862ef5a114a7 100644 --- a/src/include/duckdb/planner/bound_tokens.hpp +++ b/src/include/duckdb/planner/bound_tokens.hpp @@ -16,8 +16,6 @@ namespace duckdb { class BoundQueryNode; class BoundSelectNode; class BoundSetOperationNode; -class BoundRecursiveCTENode; -class BoundCTENode; //===--------------------------------------------------------------------===// // Expressions diff --git a/src/include/duckdb/planner/query_node/bound_cte_node.hpp b/src/include/duckdb/planner/query_node/bound_cte_node.hpp deleted file mode 100644 index 67c076ab687e..000000000000 --- a/src/include/duckdb/planner/query_node/bound_cte_node.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class BoundCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; - -public: - BoundCTENode() : BoundQueryNode(QueryNodeType::CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - //! The cte node - BoundStatement query; - //! The child node - BoundStatement child; - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the query side of the CTE - shared_ptr query_binder; - //! The binder used by the child side of the CTE - shared_ptr child_binder; - - CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - -public: - idx_t GetRootIndex() override { - return child.plan->GetRootIndex(); - } -}; - -} // namespace duckdb diff --git a/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp b/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp deleted file mode 100644 index 6a18194649e6..000000000000 --- a/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_recursive_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -//! Bound equivalent of SetOperationNode -class BoundRecursiveCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::RECURSIVE_CTE_NODE; - -public: - BoundRecursiveCTENode() : BoundQueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - bool union_all; - //! The left side of the set operation - BoundStatement left; - //! The right side of the set operation - BoundStatement right; - //! Target columns for the recursive key variant - vector> key_targets; - - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the left side of the set operation - shared_ptr left_binder; - //! The binder used by the right side of the set operation - shared_ptr right_binder; - -public: - idx_t GetRootIndex() override { - return setop_index; - } -}; - -} // namespace duckdb diff --git a/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/include/duckdb/planner/query_node/bound_select_node.hpp index d941956987e5..3fdc186e9a1a 100644 --- a/src/include/duckdb/planner/query_node/bound_select_node.hpp +++ b/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -35,12 +35,6 @@ struct BoundUnnestNode { //! Bound equivalent of SelectNode class BoundSelectNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SELECT_NODE; - -public: - BoundSelectNode() : BoundQueryNode(QueryNodeType::SELECT_NODE) { - } - //! Bind information SelectBindState bind_state; //! The projection list diff --git a/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp index 391ca26e6089..675007b50c1c 100644 --- a/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp +++ b/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -13,24 +13,18 @@ #include "duckdb/planner/bound_query_node.hpp" namespace duckdb { -struct BoundSetOpChild; //! Bound equivalent of SetOperationNode class BoundSetOperationNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SET_OPERATION_NODE; - -public: - BoundSetOperationNode() : BoundQueryNode(QueryNodeType::SET_OPERATION_NODE) { - } - ~BoundSetOperationNode() override; - //! The type of set operation SetOperationType setop_type = SetOperationType::NONE; //! whether the ALL modifier was used or not bool setop_all = false; //! The bound children - vector bound_children; + vector bound_children; + //! Child binders + vector> child_binders; //! Index used by the set operation idx_t setop_index; @@ -41,18 +35,4 @@ class BoundSetOperationNode : public BoundQueryNode { } }; -struct BoundSetOpChild { - unique_ptr bound_node; - BoundStatement node; - shared_ptr binder; - //! Original select list (if this was a SELECT statement) - vector> select_list; - //! Exprs used by the UNION BY NAME operations to add a new projection - vector> reorder_expressions; - - const vector &GetNames(); - const vector &GetTypes(); - idx_t GetRootIndex(); -}; - } // namespace duckdb diff --git a/src/include/duckdb/planner/query_node/list.hpp b/src/include/duckdb/planner/query_node/list.hpp index 5c7dbda9492b..dcac81248934 100644 --- a/src/include/duckdb/planner/query_node/list.hpp +++ b/src/include/duckdb/planner/query_node/list.hpp @@ -1,4 +1,2 @@ -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" diff --git a/src/include/duckdb/planner/table_binding.hpp b/src/include/duckdb/planner/table_binding.hpp index 671c5abed526..836f52c417df 100644 --- a/src/include/duckdb/planner/table_binding.hpp +++ b/src/include/duckdb/planner/table_binding.hpp @@ -17,6 +17,7 @@ #include "duckdb/planner/binding_alias.hpp" #include "duckdb/common/column_index.hpp" #include "duckdb/common/table_column.hpp" +#include "duckdb/planner/bound_statement.hpp" namespace duckdb { class BindContext; @@ -36,19 +37,6 @@ struct Binding { Binding(BindingType binding_type, BindingAlias alias, vector types, vector names, idx_t index); virtual ~Binding() = default; - //! The type of Binding - BindingType binding_type; - //! The alias of the binding - BindingAlias alias; - //! The table index of the binding - idx_t index; - //! The types of the bound columns - vector types; - //! Column names of the subquery - vector names; - //! Name -> index for the names - case_insensitive_map_t name_map; - public: bool TryGetBindingIndex(const string &column_name, column_t &column_index); column_t GetBindingIndex(const string &column_name); @@ -58,6 +46,14 @@ struct Binding { virtual optional_ptr GetStandardEntry(); string GetAlias() const; + BindingType GetBindingType(); + const BindingAlias &GetBindingAlias(); + idx_t GetIndex(); + const vector &GetColumnTypes(); + const vector &GetColumnNames(); + idx_t GetColumnCount(); + void SetColumnType(idx_t col_idx, LogicalType type); + static BindingAlias GetAlias(const string &explicit_alias, const StandardEntry &entry); static BindingAlias GetAlias(const string &explicit_alias, optional_ptr entry); @@ -77,6 +73,23 @@ struct Binding { } return reinterpret_cast(*this); } + +protected: + void Initialize(); + +protected: + //! The type of Binding + BindingType binding_type; + //! The alias of the binding + BindingAlias alias; + //! The table index of the binding + idx_t index; + //! The types of the bound columns + vector types; + //! Column names of the subquery + vector names; + //! Name -> index for the names + case_insensitive_map_t name_map; }; struct EntryBinding : public Binding { @@ -148,14 +161,44 @@ struct DummyBinding : public Binding { unique_ptr ParamToArg(ColumnRefExpression &col_ref); }; +enum class CTEType { CAN_BE_REFERENCED, CANNOT_BE_REFERENCED }; +struct CTEBinding; + +struct CTEBindState { + CTEBindState(Binder &parent_binder, QueryNode &cte_def, const vector &aliases); + ~CTEBindState(); + + Binder &parent_binder; + QueryNode &cte_def; + const vector &aliases; + idx_t active_binder_count; + shared_ptr query_binder; + BoundStatement query; + vector names; + vector types; + +public: + bool IsBound() const; + void Bind(CTEBinding &binding); +}; + struct CTEBinding : public Binding { public: static constexpr const BindingType TYPE = BindingType::CTE; public: - CTEBinding(BindingAlias alias, vector types, vector names, idx_t index); + CTEBinding(BindingAlias alias, vector types, vector names, idx_t index, CTEType type); + CTEBinding(BindingAlias alias, shared_ptr bind_state, idx_t index); + +public: + bool CanBeReferenced() const; + bool IsReferenced() const; + void Reference(); +private: + CTEType cte_type; idx_t reference_count; + shared_ptr bind_state; }; } // namespace duckdb diff --git a/src/include/duckdb/storage/block_manager.hpp b/src/include/duckdb/storage/block_manager.hpp index 0fd9df675667..b7907ad6d74a 100644 --- a/src/include/duckdb/storage/block_manager.hpp +++ b/src/include/duckdb/storage/block_manager.hpp @@ -37,6 +37,9 @@ class BlockManager { BufferManager &buffer_manager; public: + BufferManager &GetBufferManager() const { + return buffer_manager; + } //! Creates a new block inside the block manager virtual unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) = 0; virtual unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) = 0; diff --git a/src/include/duckdb/storage/buffer_manager.hpp b/src/include/duckdb/storage/buffer_manager.hpp index fd281e1419f4..3d4f5e5950ce 100644 --- a/src/include/duckdb/storage/buffer_manager.hpp +++ b/src/include/duckdb/storage/buffer_manager.hpp @@ -101,6 +101,8 @@ class BufferManager { //! Set a new swap limit. virtual void SetSwapLimit(optional_idx limit = optional_idx()); + //! Get the block manager used for in-memory data + virtual BlockManager &GetTemporaryBlockManager() = 0; //! Get the temporary file information of each temporary file. virtual vector GetTemporaryFiles(); //! Get the path to the temporary file directory. diff --git a/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/include/duckdb/storage/metadata/metadata_manager.hpp index 06a451a5263d..88ace35be7aa 100644 --- a/src/include/duckdb/storage/metadata/metadata_manager.hpp +++ b/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -62,6 +62,10 @@ class MetadataManager { MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager); ~MetadataManager(); + BufferManager &GetBufferManager() const { + return buffer_manager; + } + MetadataHandle AllocateHandle(); MetadataHandle Pin(const MetadataPointer &pointer); diff --git a/src/include/duckdb/storage/serialization/query_node.json b/src/include/duckdb/storage/serialization/query_node.json index 1fc7f9cf27a0..6242fa91d290 100644 --- a/src/include/duckdb/storage/serialization/query_node.json +++ b/src/include/duckdb/storage/serialization/query_node.json @@ -21,6 +21,11 @@ "name": "cte_map", "type": "CommonTableExpressionMap" } + ], + "finalize_deserialization": [ + "if (type == QueryNodeType::CTE_NODE) {", + "\tresult = std::move(result->Cast().child);", + "}" ] }, { diff --git a/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/include/duckdb/storage/standard_buffer_manager.hpp index f6b91ed1e147..ff4ce4684525 100644 --- a/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -84,6 +84,8 @@ class StandardBufferManager : public BufferManager { //! Returns information about memory usage vector GetMemoryUsageInfo() const override; + BlockManager &GetTemporaryBlockManager() final; + //! Returns a list of all temporary files vector GetTemporaryFiles() final; diff --git a/src/include/duckdb/storage/table/array_column_data.hpp b/src/include/duckdb/storage/table/array_column_data.hpp index f4c943a797ba..c246d68b6ddc 100644 --- a/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/include/duckdb/storage/table/array_column_data.hpp @@ -48,10 +48,10 @@ class ArrayColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/include/duckdb/storage/table/chunk_info.hpp b/src/include/duckdb/storage/table/chunk_info.hpp index 44b92dd7437e..9e33b42019a3 100644 --- a/src/include/duckdb/storage/table/chunk_info.hpp +++ b/src/include/duckdb/storage/table/chunk_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/vector_size.hpp" #include "duckdb/common/atomic.hpp" +#include "duckdb/execution/index/index_pointer.hpp" namespace duckdb { class RowGroup; @@ -20,6 +21,7 @@ struct TransactionData; struct DeleteInfo; class Serializer; class Deserializer; +class FixedSizeAllocator; enum class ChunkInfoType : uint8_t { CONSTANT_INFO, VECTOR_INFO, EMPTY_INFO }; @@ -38,19 +40,19 @@ class ChunkInfo { public: //! Gets up to max_count entries from the chunk info. If the ret is 0>ret>max_count, the selection vector is filled //! with the tuples - virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) = 0; + virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const = 0; virtual idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) = 0; //! Returns whether or not a single row in the ChunkInfo should be used or not for the given transaction virtual bool Fetch(TransactionData transaction, row_t row) = 0; virtual void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) = 0; - virtual idx_t GetCommittedDeletedCount(idx_t max_count) = 0; + virtual idx_t GetCommittedDeletedCount(idx_t max_count) const = 0; virtual bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const; virtual bool HasDeletes() const = 0; virtual void Write(WriteStream &writer) const; - static unique_ptr Read(ReadStream &reader); + static unique_ptr Read(FixedSizeAllocator &allocator, ReadStream &reader); public: template @@ -81,12 +83,12 @@ class ChunkConstantInfo : public ChunkInfo { transaction_t delete_id; public: - idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const override; idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) override; bool Fetch(TransactionData transaction, row_t row) override; void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; - idx_t GetCommittedDeletedCount(idx_t max_count) override; + idx_t GetCommittedDeletedCount(idx_t max_count) const override; bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const override; bool HasDeletes() const override; @@ -105,27 +107,19 @@ class ChunkVectorInfo : public ChunkInfo { static constexpr const ChunkInfoType TYPE = ChunkInfoType::VECTOR_INFO; public: - explicit ChunkVectorInfo(idx_t start); - - //! The transaction ids of the transactions that inserted the tuples (if any) - transaction_t inserted[STANDARD_VECTOR_SIZE]; - transaction_t insert_id; - bool same_inserted_id; - - //! The transaction ids of the transactions that deleted the tuples (if any) - transaction_t deleted[STANDARD_VECTOR_SIZE]; - bool any_deleted; + explicit ChunkVectorInfo(FixedSizeAllocator &allocator, idx_t start, transaction_t insert_id = 0); + ~ChunkVectorInfo() override; public: idx_t GetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const; - idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const override; idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) override; bool Fetch(TransactionData transaction, row_t row) override; void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const override; - idx_t GetCommittedDeletedCount(idx_t max_count) override; + idx_t GetCommittedDeletedCount(idx_t max_count) const override; void Append(idx_t start, idx_t end, transaction_t commit_id); @@ -138,14 +132,32 @@ class ChunkVectorInfo : public ChunkInfo { void CommitDelete(transaction_t commit_id, const DeleteInfo &info); bool HasDeletes() const override; + bool AnyDeleted() const; + bool HasConstantInsertionId() const; + transaction_t ConstantInsertId() const; void Write(WriteStream &writer) const override; - static unique_ptr Read(ReadStream &reader); + static unique_ptr Read(FixedSizeAllocator &allocator, ReadStream &reader); private: template idx_t TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const; + + IndexPointer GetInsertedPointer() const; + IndexPointer GetDeletedPointer() const; + IndexPointer GetInitializedInsertedPointer(); + IndexPointer GetInitializedDeletedPointer(); + +private: + FixedSizeAllocator &allocator; + //! The transaction ids of the transactions that inserted the tuples (if any) + IndexPointer inserted_data; + //! The constant insert id (if there is only one) + transaction_t constant_insert_id; + + //! The transaction ids of the transactions that deleted the tuples (if any) + IndexPointer deleted_data; }; } // namespace duckdb diff --git a/src/include/duckdb/storage/table/column_data.hpp b/src/include/duckdb/storage/table/column_data.hpp index b688f8ed756c..ab8a5970e99d 100644 --- a/src/include/duckdb/storage/table/column_data.hpp +++ b/src/include/duckdb/storage/table/column_data.hpp @@ -156,10 +156,10 @@ class ColumnData { virtual void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx); - virtual void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count); - virtual void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth); + virtual void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count); + virtual void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth); virtual unique_ptr GetUpdateStatistics(); virtual void CommitDropColumn(); @@ -220,8 +220,8 @@ class ColumnData { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result, idx_t scan_count, bool allow_updates, bool scan_committed); void FetchUpdateRow(TransactionData transaction, row_t row_id, Vector &result, idx_t result_idx); - void UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector); + void UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, Vector &base_vector); idx_t FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector); idx_t GetVectorCount(idx_t vector_index) const; diff --git a/src/include/duckdb/storage/table/list_column_data.hpp b/src/include/duckdb/storage/table/list_column_data.hpp index 98d8c662b1e9..621ece4517fe 100644 --- a/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/include/duckdb/storage/table/list_column_data.hpp @@ -46,10 +46,10 @@ class ListColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/include/duckdb/storage/table/row_group.hpp b/src/include/duckdb/storage/table/row_group.hpp index d003d1378336..62b54eeedd17 100644 --- a/src/include/duckdb/storage/table/row_group.hpp +++ b/src/include/duckdb/storage/table/row_group.hpp @@ -171,12 +171,12 @@ class RowGroup : public SegmentBase { void InitializeAppend(RowGroupAppendState &append_state); void Append(RowGroupAppendState &append_state, DataChunk &chunk, idx_t append_count); - void Update(TransactionData transaction, DataChunk &updates, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids); + void Update(TransactionData transaction, DataTable &data_table, DataChunk &updates, row_t *ids, idx_t offset, + idx_t count, const vector &column_ids); //! Update a single column; corresponds to DataTable::UpdateColumn //! This method should only be called from the WAL - void UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path); + void UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path); void MergeStatistics(idx_t column_idx, const BaseStatistics &other); void MergeIntoStatistics(idx_t column_idx, BaseStatistics &other); diff --git a/src/include/duckdb/storage/table/row_group_collection.hpp b/src/include/duckdb/storage/table/row_group_collection.hpp index e5c74582933a..80aa5266831e 100644 --- a/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/include/duckdb/storage/table/row_group_collection.hpp @@ -36,6 +36,7 @@ struct CollectionCheckpointState; struct PersistentCollectionData; class CheckpointTask; class TableIOManager; +class DataTable; class RowGroupCollection { public: @@ -101,9 +102,10 @@ class RowGroupCollection { void RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, Vector &row_identifiers, idx_t count); idx_t Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count); - void Update(TransactionData transaction, row_t *ids, const vector &column_ids, DataChunk &updates); - void UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates); + void Update(TransactionData transaction, DataTable &table, row_t *ids, const vector &column_ids, + DataChunk &updates); + void UpdateColumn(TransactionData transaction, DataTable &table, Vector &row_ids, + const vector &column_path, DataChunk &updates); void Checkpoint(TableDataWriter &writer, TableStatistics &global_stats); diff --git a/src/include/duckdb/storage/table/row_id_column_data.hpp b/src/include/duckdb/storage/table/row_id_column_data.hpp index 3bc6572ae2c6..f839c6b24619 100644 --- a/src/include/duckdb/storage/table/row_id_column_data.hpp +++ b/src/include/duckdb/storage/table/row_id_column_data.hpp @@ -48,10 +48,10 @@ class RowIdColumnData : public ColumnData { void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; void RevertAppend(row_t start_row) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; void CommitDropColumn() override; diff --git a/src/include/duckdb/storage/table/row_version_manager.hpp b/src/include/duckdb/storage/table/row_version_manager.hpp index bb0d0056b9f0..5ba674cc6f49 100644 --- a/src/include/duckdb/storage/table/row_version_manager.hpp +++ b/src/include/duckdb/storage/table/row_version_manager.hpp @@ -12,20 +12,25 @@ #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/common/mutex.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" namespace duckdb { struct DeleteInfo; class MetadataManager; +class BufferManager; struct MetaBlockPointer; class RowVersionManager { public: - explicit RowVersionManager(idx_t start) noexcept; + explicit RowVersionManager(BufferManager &buffer_manager, idx_t start) noexcept; - idx_t GetStart() { + idx_t GetStart() const { return start; } + FixedSizeAllocator &GetAllocator() { + return allocator; + } void SetStart(idx_t start); idx_t GetCommittedDeletedCount(idx_t count); @@ -48,6 +53,7 @@ class RowVersionManager { private: mutex version_lock; + FixedSizeAllocator allocator; idx_t start; vector> vector_info; bool has_changes; diff --git a/src/include/duckdb/storage/table/standard_column_data.hpp b/src/include/duckdb/storage/table/standard_column_data.hpp index 8d233139a3c5..ec06eb30a4cf 100644 --- a/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/include/duckdb/storage/table/standard_column_data.hpp @@ -47,10 +47,10 @@ class StandardColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/include/duckdb/storage/table/struct_column_data.hpp b/src/include/duckdb/storage/table/struct_column_data.hpp index 91c7f1e1919b..798a21326921 100644 --- a/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/include/duckdb/storage/table/struct_column_data.hpp @@ -46,10 +46,10 @@ class StructColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/include/duckdb/storage/table/update_segment.hpp b/src/include/duckdb/storage/table/update_segment.hpp index 75cf25ecfbb0..3f5b9d2119dd 100644 --- a/src/include/duckdb/storage/table/update_segment.hpp +++ b/src/include/duckdb/storage/table/update_segment.hpp @@ -38,8 +38,8 @@ class UpdateSegment { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result); void FetchCommitted(idx_t vector_index, Vector &result); void FetchCommittedRange(idx_t start_row, idx_t count, Vector &result); - void Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, - Vector &base_data); + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update, row_t *ids, + idx_t count, Vector &base_data); void FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx); void RollbackUpdate(UpdateInfo &info); diff --git a/src/include/duckdb/transaction/duck_transaction.hpp b/src/include/duckdb/transaction/duck_transaction.hpp index 12c4d180c25f..b9080f192f71 100644 --- a/src/include/duckdb/transaction/duck_transaction.hpp +++ b/src/include/duckdb/transaction/duck_transaction.hpp @@ -35,14 +35,12 @@ class DuckTransaction : public Transaction { transaction_t transaction_id; //! The commit id of this transaction, if it has successfully been committed transaction_t commit_id; - //! Highest active query when the transaction finished, used for cleaning up - transaction_t highest_active_query; atomic catalog_version; //! Transactions undergo Cleanup, after (1) removing them directly in RemoveTransaction, - //! or (2) after they exist old_transactions. - //! Some (after rollback) enter old_transactions, but do not require Cleanup. + //! or (2) after they enter cleanup_queue. + //! Some (after rollback) enter cleanup_queue, but do not require Cleanup. bool awaiting_cleanup; public: @@ -76,7 +74,7 @@ class DuckTransaction : public Transaction { idx_t base_row); void PushSequenceUsage(SequenceCatalogEntry &entry, const SequenceData &data); void PushAppend(DataTable &table, idx_t row_start, idx_t row_count); - UndoBufferReference CreateUpdateInfo(idx_t type_size, idx_t entries); + UndoBufferReference CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries); bool IsDuckTransaction() const override { return true; @@ -90,6 +88,7 @@ class DuckTransaction : public Transaction { //! Get a shared lock on a table shared_ptr SharedLockTable(DataTableInfo &info); + //! Hold an owning reference of the table, needed to safely reference it inside the transaction commit/undo logic void ModifyTable(DataTable &tbl); private: diff --git a/src/include/duckdb/transaction/duck_transaction_manager.hpp b/src/include/duckdb/transaction/duck_transaction_manager.hpp index 63531ae7d7e2..a3bf3f47a627 100644 --- a/src/include/duckdb/transaction/duck_transaction_manager.hpp +++ b/src/include/duckdb/transaction/duck_transaction_manager.hpp @@ -110,8 +110,6 @@ class DuckTransactionManager : public TransactionManager { vector> active_transactions; //! Set of recently committed transactions vector> recently_committed_transactions; - //! Transactions awaiting GC - vector> old_transactions; //! The lock used for transaction operations mutex transaction_lock; //! The checkpoint lock diff --git a/src/include/duckdb/transaction/update_info.hpp b/src/include/duckdb/transaction/update_info.hpp index 7cccd923e74b..5eb1392611d2 100644 --- a/src/include/duckdb/transaction/update_info.hpp +++ b/src/include/duckdb/transaction/update_info.hpp @@ -17,6 +17,7 @@ namespace duckdb { class UpdateSegment; struct DataTableInfo; +class DataTable; //! UpdateInfo is a class that represents a set of updates applied to a single vector. //! The UpdateInfo struct contains metadata associated with the update. @@ -26,6 +27,8 @@ struct DataTableInfo; struct UpdateInfo { //! The update segment that this update info affects UpdateSegment *segment; + //! The table this was update was made on + DataTable *table; //! The column index of which column we are updating idx_t column_index; //! The version number @@ -87,7 +90,7 @@ struct UpdateInfo { //! Returns the total allocation size for an UpdateInfo entry, together with space for the tuple data static idx_t GetAllocSize(idx_t type_size); //! Initialize an UpdateInfo struct that has been allocated using GetAllocSize (i.e. has extra space after it) - static void Initialize(UpdateInfo &info, transaction_t transaction_id); + static void Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id); }; } // namespace duckdb diff --git a/src/include/duckdb/transaction/wal_write_state.hpp b/src/include/duckdb/transaction/wal_write_state.hpp index aad1a672cb71..4c68da4875a6 100644 --- a/src/include/duckdb/transaction/wal_write_state.hpp +++ b/src/include/duckdb/transaction/wal_write_state.hpp @@ -31,7 +31,7 @@ class WALWriteState { void CommitEntry(UndoFlags type, data_ptr_t data); private: - void SwitchTable(DataTableInfo *table, UndoFlags new_op); + void SwitchTable(DataTableInfo &table, UndoFlags new_op); void WriteCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data); void WriteDelete(DeleteInfo &info); diff --git a/src/include/duckdb/verification/statement_verifier.hpp b/src/include/duckdb/verification/statement_verifier.hpp index a60abf18705f..77fed981540d 100644 --- a/src/include/duckdb/verification/statement_verifier.hpp +++ b/src/include/duckdb/verification/statement_verifier.hpp @@ -85,6 +85,8 @@ class StatementVerifier { private: const vector> empty_select_list = {}; + + const vector> &GetSelectList(QueryNode &node); }; } // namespace duckdb diff --git a/src/include/duckdb_extension.h b/src/include/duckdb_extension.h index 7c51360590ae..f014be548bc8 100644 --- a/src/include/duckdb_extension.h +++ b/src/include/duckdb_extension.h @@ -643,6 +643,13 @@ typedef struct { char *(*duckdb_value_to_string)(duckdb_value value); #endif +// New functions around the table description +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); +#endif + // New functions around table function binding #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -1164,6 +1171,10 @@ typedef struct { // Version unstable_new_string_functions #define duckdb_value_to_string duckdb_ext_api.duckdb_value_to_string +// Version unstable_new_table_description_functions +#define duckdb_table_description_get_column_count duckdb_ext_api.duckdb_table_description_get_column_count +#define duckdb_table_description_get_column_type duckdb_ext_api.duckdb_table_description_get_column_type + // Version unstable_new_table_function_functions #define duckdb_table_function_get_client_context duckdb_ext_api.duckdb_table_function_get_client_context diff --git a/src/main/attached_database.cpp b/src/main/attached_database.cpp index e879e17a178f..70e8be932cab 100644 --- a/src/main/attached_database.cpp +++ b/src/main/attached_database.cpp @@ -14,8 +14,9 @@ namespace duckdb { -StoredDatabasePath::StoredDatabasePath(DatabaseFilePathManager &manager, string path_p, const string &name) - : manager(manager), path(std::move(path_p)) { +StoredDatabasePath::StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path_p, + const string &name) + : db_manager(db_manager), manager(manager), path(std::move(path_p)) { } StoredDatabasePath::~StoredDatabasePath() { @@ -23,7 +24,7 @@ StoredDatabasePath::~StoredDatabasePath() { } void StoredDatabasePath::OnDetach() { - manager.DetachDatabase(path); + manager.DetachDatabase(db_manager, path); } //===--------------------------------------------------------------------===// diff --git a/src/main/buffered_data/batched_buffered_data.cpp b/src/main/buffered_data/batched_buffered_data.cpp index 3c593374cfa2..e9f949098d85 100644 --- a/src/main/buffered_data/batched_buffered_data.cpp +++ b/src/main/buffered_data/batched_buffered_data.cpp @@ -14,9 +14,8 @@ void BatchedBufferedData::BlockSink(const InterruptState &blocked_sink, idx_t ba blocked_sinks.emplace(batch, blocked_sink); } -BatchedBufferedData::BatchedBufferedData(weak_ptr context) - : BufferedData(BufferedData::Type::BATCHED, std::move(context)), buffer_byte_count(0), read_queue_byte_count(0), - min_batch(0) { +BatchedBufferedData::BatchedBufferedData(ClientContext &context) + : BufferedData(BufferedData::Type::BATCHED, context), buffer_byte_count(0), read_queue_byte_count(0), min_batch(0) { read_queue_capacity = (idx_t)(static_cast(total_buffer_size) * 0.6); buffer_capacity = (idx_t)(static_cast(total_buffer_size) * 0.4); } diff --git a/src/main/buffered_data/buffered_data.cpp b/src/main/buffered_data/buffered_data.cpp index 156539815e04..0e01df8dc692 100644 --- a/src/main/buffered_data/buffered_data.cpp +++ b/src/main/buffered_data/buffered_data.cpp @@ -4,9 +4,8 @@ namespace duckdb { -BufferedData::BufferedData(Type type, weak_ptr context_p) : type(type), context(std::move(context_p)) { - auto client_context = context.lock(); - auto &config = ClientConfig::GetConfig(*client_context); +BufferedData::BufferedData(Type type, ClientContext &context_p) : type(type), context(context_p.shared_from_this()) { + auto &config = ClientConfig::GetConfig(context_p); total_buffer_size = config.streaming_buffer_size; } diff --git a/src/main/buffered_data/simple_buffered_data.cpp b/src/main/buffered_data/simple_buffered_data.cpp index 4b6a3a534177..59cde1f43f10 100644 --- a/src/main/buffered_data/simple_buffered_data.cpp +++ b/src/main/buffered_data/simple_buffered_data.cpp @@ -6,8 +6,7 @@ namespace duckdb { -SimpleBufferedData::SimpleBufferedData(weak_ptr context) - : BufferedData(BufferedData::Type::SIMPLE, std::move(context)) { +SimpleBufferedData::SimpleBufferedData(ClientContext &context) : BufferedData(BufferedData::Type::SIMPLE, context) { buffered_count = 0; buffer_size = total_buffer_size; } diff --git a/src/main/capi/data_chunk-c.cpp b/src/main/capi/data_chunk-c.cpp index 7274852c43aa..77f6482ab024 100644 --- a/src/main/capi/data_chunk-c.cpp +++ b/src/main/capi/data_chunk-c.cpp @@ -167,20 +167,20 @@ idx_t duckdb_list_vector_get_size(duckdb_vector vector) { duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size) { if (!vector) { - return duckdb_state::DuckDBError; + return DuckDBError; } auto v = reinterpret_cast(vector); duckdb::ListVector::SetListSize(*v, size); - return duckdb_state::DuckDBSuccess; + return DuckDBSuccess; } duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity) { if (!vector) { - return duckdb_state::DuckDBError; + return DuckDBError; } auto v = reinterpret_cast(vector); duckdb::ListVector::Reserve(*v, required_capacity); - return duckdb_state::DuckDBSuccess; + return DuckDBSuccess; } duckdb_vector duckdb_struct_vector_get_child(duckdb_vector vector, idx_t index) { diff --git a/src/main/capi/prepared-c.cpp b/src/main/capi/prepared-c.cpp index 28b2f011fd7c..ac5b638f8311 100644 --- a/src/main/capi/prepared-c.cpp +++ b/src/main/capi/prepared-c.cpp @@ -88,7 +88,13 @@ duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement) { auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || !wrapper->statement->HasError()) { + if (!wrapper) { + return nullptr; + } + if (!wrapper->success) { + return wrapper->error_data.Message().c_str(); + } + if (!wrapper->statement || !wrapper->statement->HasError()) { return nullptr; } return wrapper->statement->error.Message().c_str(); @@ -191,7 +197,7 @@ const char *duckdb_prepared_statement_column_name(duckdb_prepared_statement prep } auto &names = wrapper->statement->GetNames(); - if (col_idx < 0 || col_idx >= names.size()) { + if (col_idx >= names.size()) { return nullptr; } return strdup(names[col_idx].c_str()); @@ -204,7 +210,7 @@ duckdb_logical_type duckdb_prepared_statement_column_logical_type(duckdb_prepare return nullptr; } auto types = wrapper->statement->GetTypes(); - if (col_idx < 0 || col_idx >= types.size()) { + if (col_idx >= types.size()) { return nullptr; } return reinterpret_cast(new LogicalType(types[col_idx])); @@ -229,9 +235,10 @@ duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx return DuckDBError; } if (param_idx <= 0 || param_idx > wrapper->statement->named_param_map.size()) { - wrapper->statement->error = + wrapper->error_data = duckdb::InvalidInputException("Can not bind to parameter number %d, statement only has %d parameter(s)", param_idx, wrapper->statement->named_param_map.size()); + wrapper->success = false; return DuckDBError; } auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx); diff --git a/src/main/capi/table_description-c.cpp b/src/main/capi/table_description-c.cpp index 26624bbfc1d9..cfcd01c4350d 100644 --- a/src/main/capi/table_description-c.cpp +++ b/src/main/capi/table_description-c.cpp @@ -1,5 +1,5 @@ -#include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/main/capi/capi_internal.hpp" using duckdb::Connection; using duckdb::ErrorData; @@ -68,14 +68,14 @@ const char *duckdb_table_description_error(duckdb_table_description table) { return wrapper->error.c_str(); } -duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, idx_t index) { +duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, duckdb::optional_idx index) { if (!wrapper) { return DuckDBError; } auto &table = wrapper->description; - if (index >= table->columns.size()) { - wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", index, - table->columns.size()); + if (index.IsValid() && index.GetIndex() >= table->columns.size()) { + wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", + index.GetIndex(), table->columns.size()); return DuckDBError; } return DuckDBSuccess; @@ -97,6 +97,16 @@ duckdb_state duckdb_column_has_default(duckdb_table_description table_descriptio return DuckDBSuccess; } +idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, duckdb::optional_idx()) == DuckDBError) { + return 0; + } + + auto &table = wrapper->description; + return table->columns.size(); +} + char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index) { auto wrapper = reinterpret_cast(table_description); if (GetTableDescription(wrapper, index) == DuckDBError) { @@ -113,3 +123,16 @@ char *duckdb_table_description_get_column_name(duckdb_table_description table_de return result; } + +duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, idx_t index) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, index) == DuckDBError) { + return nullptr; + } + + auto &table = wrapper->description; + auto &column = table->columns[index]; + + auto logical_type = new duckdb::LogicalType(column.Type()); + return reinterpret_cast(logical_type); +} diff --git a/src/main/client_data.cpp b/src/main/client_data.cpp index 50df0556352b..0c63cca163b4 100644 --- a/src/main/client_data.cpp +++ b/src/main/client_data.cpp @@ -119,6 +119,9 @@ class ClientBufferManager : public BufferManager { return buffer_manager.SetSwapLimit(limit); } + BlockManager &GetTemporaryBlockManager() override { + return buffer_manager.GetTemporaryBlockManager(); + } vector GetTemporaryFiles() override { return buffer_manager.GetTemporaryFiles(); } diff --git a/src/main/config.cpp b/src/main/config.cpp index 48c1cacdfaf3..27456b649d27 100644 --- a/src/main/config.cpp +++ b/src/main/config.cpp @@ -518,8 +518,7 @@ void DBConfig::CheckLock(const String &name) { return; } // not allowed! - throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", - name.ToStdString()); + throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", name); } idx_t DBConfig::GetSystemMaxThreads(FileSystem &fs) { diff --git a/src/main/database_file_path_manager.cpp b/src/main/database_file_path_manager.cpp index f1825780e557..2e107210a246 100644 --- a/src/main/database_file_path_manager.cpp +++ b/src/main/database_file_path_manager.cpp @@ -5,35 +5,57 @@ namespace duckdb { +DatabasePathInfo::DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode) + : name(std::move(name_p)), access_mode(access_mode) { + attached_databases.insert(manager); +} + idx_t DatabaseFilePathManager::ApproxDatabaseCount() const { lock_guard path_lock(db_paths_lock); return db_paths.size(); } -InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(const string &path, const string &name, - OnCreateConflict on_conflict, +InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(DatabaseManager &manager, const string &path, + const string &name, OnCreateConflict on_conflict, AttachOptions &options) { if (path.empty() || path == IN_MEMORY_PATH) { return InsertDatabasePathResult::SUCCESS; } lock_guard path_lock(db_paths_lock); - auto entry = db_paths.emplace(path, DatabasePathInfo(name)); + auto entry = db_paths.emplace(path, DatabasePathInfo(manager, name, options.access_mode)); if (!entry.second) { auto &existing = entry.first->second; + bool already_exists = false; + bool attached_in_this_system = false; if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT && existing.name == name) { - if (existing.is_attached) { + already_exists = true; + attached_in_this_system = existing.attached_databases.find(manager) != existing.attached_databases.end(); + } + if (options.access_mode == AccessMode::READ_ONLY && existing.access_mode == AccessMode::READ_ONLY) { + if (attached_in_this_system) { return InsertDatabasePathResult::ALREADY_EXISTS; } - throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " - "the process of being detached", - name, path); + // all attaches are in read-only mode - there is no conflict, just increase the reference count + existing.attached_databases.insert(manager); + existing.reference_count++; + } else { + if (already_exists) { + if (attached_in_this_system) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " + "the process of being detached", + name, path); + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " + "attached by database \"%s\"", + name, path, existing.name); } - throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " - "attached by database \"%s\"", - name, path, existing.name); } - options.stored_database_path = make_uniq(*this, path, name); + options.stored_database_path = make_uniq(manager, *this, path, name); return InsertDatabasePathResult::SUCCESS; } @@ -42,17 +64,24 @@ void DatabaseFilePathManager::EraseDatabasePath(const string &path) { return; } lock_guard path_lock(db_paths_lock); - db_paths.erase(path); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + if (entry->second.reference_count <= 1) { + db_paths.erase(entry); + } else { + entry->second.reference_count--; + } + } } -void DatabaseFilePathManager::DetachDatabase(const string &path) { +void DatabaseFilePathManager::DetachDatabase(DatabaseManager &manager, const string &path) { if (path.empty() || path == IN_MEMORY_PATH) { return; } lock_guard path_lock(db_paths_lock); auto entry = db_paths.find(path); if (entry != db_paths.end()) { - entry->second.is_attached = false; + entry->second.attached_databases.erase(manager); } } diff --git a/src/main/database_manager.cpp b/src/main/database_manager.cpp index f59cc2719ed3..1b6b070d43ca 100644 --- a/src/main/database_manager.cpp +++ b/src/main/database_manager.cpp @@ -85,28 +85,38 @@ shared_ptr DatabaseManager::GetDatabaseInternal(const lock_gua shared_ptr DatabaseManager::AttachDatabase(ClientContext &context, AttachInfo &info, AttachOptions &options) { if (options.db_type.empty() || StringUtil::CIEquals(options.db_type, "duckdb")) { + // Start timing the ATTACH-delay step. + auto profiler = context.client_data->profiler; + profiler->StartTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); + while (InsertDatabasePath(info, options) == InsertDatabasePathResult::ALREADY_EXISTS) { // database with this name and path already exists // first check if it exists within this transaction auto &meta_transaction = MetaTransaction::Get(context); auto existing_db = meta_transaction.GetReferencedDatabaseOwning(info.name); if (existing_db) { + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); // it does! return it return existing_db; } + // ... but it might not be done attaching yet! // verify the database has actually finished attaching prior to returning lock_guard guard(databases_lock); auto entry = databases.find(info.name); if (entry != databases.end()) { - // database ACTUALLY exists - return it + // The database ACTUALLY exists, so we return it. + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); return entry->second; } if (context.interrupted) { + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); throw InterruptException(); } } + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); } + auto &config = DBConfig::GetConfig(context); GetDatabaseType(context, info, config, options); if (!options.db_type.empty()) { @@ -270,7 +280,7 @@ idx_t DatabaseManager::ApproxDatabaseCount() { } InsertDatabasePathResult DatabaseManager::InsertDatabasePath(const AttachInfo &info, AttachOptions &options) { - return path_manager->InsertDatabasePath(info.path, info.name, info.on_conflict, options); + return path_manager->InsertDatabasePath(*this, info.path, info.name, info.on_conflict, options); } vector DatabaseManager::GetAttachedDatabasePaths() { diff --git a/src/main/extension.cpp b/src/main/extension.cpp index cd786d863ec7..c982a4bc3071 100644 --- a/src/main/extension.cpp +++ b/src/main/extension.cpp @@ -7,6 +7,8 @@ namespace duckdb { +constexpr const idx_t ParsedExtensionMetaData::FOOTER_SIZE; + Extension::~Extension() { } diff --git a/src/main/profiling_info.cpp b/src/main/profiling_info.cpp index 276e0c198148..b795f829fabf 100644 --- a/src/main/profiling_info.cpp +++ b/src/main/profiling_info.cpp @@ -23,12 +23,12 @@ ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t // Reduce. if (depth == 0) { - auto op_metrics = DefaultOperatorSettings(); + auto op_metrics = OperatorScopeSettings(); for (const auto metric : op_metrics) { settings.erase(metric); } } else { - auto root_metrics = DefaultRootSettings(); + auto root_metrics = RootScopeSettings(); for (const auto metric : root_metrics) { settings.erase(metric); } @@ -37,32 +37,40 @@ ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t } profiler_settings_t ProfilingInfo::DefaultSettings() { - return {MetricsType::QUERY_NAME, - MetricsType::BLOCKED_THREAD_TIME, - MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, - MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE, + return {MetricsType::BLOCKED_THREAD_TIME, MetricsType::CPU_TIME, - MetricsType::EXTRA_INFO, MetricsType::CUMULATIVE_CARDINALITY, - MetricsType::OPERATOR_NAME, - MetricsType::OPERATOR_TYPE, - MetricsType::OPERATOR_CARDINALITY, MetricsType::CUMULATIVE_ROWS_SCANNED, + MetricsType::EXTRA_INFO, + MetricsType::LATENCY, + MetricsType::OPERATOR_CARDINALITY, + MetricsType::OPERATOR_NAME, MetricsType::OPERATOR_ROWS_SCANNED, MetricsType::OPERATOR_TIMING, + MetricsType::OPERATOR_TYPE, MetricsType::RESULT_SET_SIZE, - MetricsType::LATENCY, MetricsType::ROWS_RETURNED, + MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, + MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE, MetricsType::TOTAL_BYTES_READ, - MetricsType::TOTAL_BYTES_WRITTEN}; + MetricsType::TOTAL_BYTES_WRITTEN, + MetricsType::QUERY_NAME}; } -profiler_settings_t ProfilingInfo::DefaultRootSettings() { - return {MetricsType::QUERY_NAME, MetricsType::BLOCKED_THREAD_TIME, MetricsType::LATENCY, - MetricsType::ROWS_RETURNED}; +profiler_settings_t ProfilingInfo::RootScopeSettings() { + return {MetricsType::ATTACH_LOAD_STORAGE_LATENCY, + MetricsType::ATTACH_REPLAY_WAL_LATENCY, + MetricsType::BLOCKED_THREAD_TIME, + MetricsType::CHECKPOINT_LATENCY, + MetricsType::LATENCY, + MetricsType::ROWS_RETURNED, + MetricsType::TOTAL_BYTES_READ, + MetricsType::TOTAL_BYTES_WRITTEN, + MetricsType::WAITING_TO_ATTACH_LATENCY, + MetricsType::QUERY_NAME}; } -profiler_settings_t ProfilingInfo::DefaultOperatorSettings() { +profiler_settings_t ProfilingInfo::OperatorScopeSettings() { return {MetricsType::OPERATOR_CARDINALITY, MetricsType::OPERATOR_ROWS_SCANNED, MetricsType::OPERATOR_TIMING, MetricsType::OPERATOR_NAME, MetricsType::OPERATOR_TYPE}; } @@ -83,6 +91,10 @@ void ProfilingInfo::ResetMetrics() { case MetricsType::BLOCKED_THREAD_TIME: case MetricsType::CPU_TIME: case MetricsType::OPERATOR_TIMING: + case MetricsType::WAITING_TO_ATTACH_LATENCY: + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + case MetricsType::CHECKPOINT_LATENCY: metrics[metric] = Value::CreateValue(0.0); break; case MetricsType::OPERATOR_NAME: @@ -209,7 +221,11 @@ void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest case MetricsType::LATENCY: case MetricsType::BLOCKED_THREAD_TIME: case MetricsType::CPU_TIME: - case MetricsType::OPERATOR_TIMING: { + case MetricsType::OPERATOR_TIMING: + case MetricsType::WAITING_TO_ATTACH_LATENCY: + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + case MetricsType::CHECKPOINT_LATENCY: { yyjson_mut_obj_add_real(doc, dest, key_ptr, metrics[metric].GetValue()); break; } diff --git a/src/main/query_profiler.cpp b/src/main/query_profiler.cpp index 2552248b4652..b4f15a8e68d3 100644 --- a/src/main/query_profiler.cpp +++ b/src/main/query_profiler.cpp @@ -103,9 +103,7 @@ void QueryProfiler::Reset() { phase_timings.clear(); phase_stack.clear(); running = false; - query_metrics.query = ""; - query_metrics.total_bytes_read = 0; - query_metrics.total_bytes_written = 0; + query_metrics.Reset(); } void QueryProfiler::StartQuery(const string &query, bool is_explain_analyze_p, bool start_at_optimizer) { @@ -282,6 +280,21 @@ void QueryProfiler::EndQuery() { if (info.Enabled(settings, MetricsType::RESULT_SET_SIZE)) { info.metrics[MetricsType::RESULT_SET_SIZE] = child_info.metrics[MetricsType::RESULT_SET_SIZE]; } + if (info.Enabled(settings, MetricsType::WAITING_TO_ATTACH_LATENCY)) { + info.metrics[MetricsType::WAITING_TO_ATTACH_LATENCY] = + query_metrics.waiting_to_attach_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::ATTACH_LOAD_STORAGE_LATENCY)) { + info.metrics[MetricsType::ATTACH_LOAD_STORAGE_LATENCY] = + query_metrics.attach_load_storage_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::ATTACH_REPLAY_WAL_LATENCY)) { + info.metrics[MetricsType::ATTACH_REPLAY_WAL_LATENCY] = + query_metrics.attach_replay_wal_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::CHECKPOINT_LATENCY)) { + info.metrics[MetricsType::CHECKPOINT_LATENCY] = query_metrics.checkpoint_latency.Elapsed(); + } MoveOptimizerPhasesToRoot(); if (info.Enabled(settings, MetricsType::CUMULATIVE_OPTIMIZER_TIMING)) { @@ -323,6 +336,52 @@ void QueryProfiler::AddBytesWritten(const idx_t nr_bytes) { } } +void QueryProfiler::StartTimer(MetricsType type) { + if (!IsEnabled()) { + return; + } + + switch (type) { + case MetricsType::WAITING_TO_ATTACH_LATENCY: + query_metrics.waiting_to_attach_latency.Start(); + return; + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + query_metrics.attach_load_storage_latency.Start(); + return; + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + query_metrics.attach_replay_wal_latency.Start(); + return; + case MetricsType::CHECKPOINT_LATENCY: + query_metrics.checkpoint_latency.Start(); + return; + default: + return; + } +} + +void QueryProfiler::EndTimer(MetricsType type) { + if (!IsEnabled()) { + return; + } + + switch (type) { + case MetricsType::WAITING_TO_ATTACH_LATENCY: + query_metrics.waiting_to_attach_latency.End(); + return; + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + query_metrics.attach_load_storage_latency.End(); + return; + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + query_metrics.attach_replay_wal_latency.End(); + return; + case MetricsType::CHECKPOINT_LATENCY: + query_metrics.checkpoint_latency.End(); + return; + default: + return; + } +} + string QueryProfiler::ToString(ExplainFormat explain_format) const { return ToString(GetPrintFormat(explain_format)); } @@ -405,7 +464,7 @@ OperatorProfiler::OperatorProfiler(ClientContext &context) : context(context) { } // Reduce. - auto root_metrics = ProfilingInfo::DefaultRootSettings(); + auto root_metrics = ProfilingInfo::RootScopeSettings(); for (const auto metric : root_metrics) { settings.erase(metric); } @@ -974,7 +1033,4 @@ void QueryProfiler::MoveOptimizerPhasesToRoot() { } } -void QueryProfiler::Propagate(QueryProfiler &) { -} - } // namespace duckdb diff --git a/src/main/relation.cpp b/src/main/relation.cpp index 9a28349e7059..b9e4d50ff49e 100644 --- a/src/main/relation.cpp +++ b/src/main/relation.cpp @@ -394,8 +394,8 @@ string Relation::ToString() { } // LCOV_EXCL_START -unique_ptr Relation::GetQueryNode() { - throw InternalException("Cannot create a query node from this node type"); +string Relation::GetQuery() { + return GetQueryNode()->ToString(); } void Relation::Head(idx_t limit) { diff --git a/src/main/relation/create_table_relation.cpp b/src/main/relation/create_table_relation.cpp index 2492f244b7bc..39aa65e3694a 100644 --- a/src/main/relation/create_table_relation.cpp +++ b/src/main/relation/create_table_relation.cpp @@ -29,6 +29,14 @@ BoundStatement CreateTableRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateTableRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a create table relation"); +} + +string CreateTableRelation::GetQuery() { + return string(); +} + const vector &CreateTableRelation::Columns() { return columns; } diff --git a/src/main/relation/create_view_relation.cpp b/src/main/relation/create_view_relation.cpp index c00deef381c5..6f77f013f8d7 100644 --- a/src/main/relation/create_view_relation.cpp +++ b/src/main/relation/create_view_relation.cpp @@ -35,6 +35,14 @@ BoundStatement CreateViewRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateViewRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string CreateViewRelation::GetQuery() { + return string(); +} + const vector &CreateViewRelation::Columns() { return columns; } diff --git a/src/main/relation/delete_relation.cpp b/src/main/relation/delete_relation.cpp index 64b3f231e820..2ec60f66458a 100644 --- a/src/main/relation/delete_relation.cpp +++ b/src/main/relation/delete_relation.cpp @@ -26,6 +26,14 @@ BoundStatement DeleteRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr DeleteRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a delete relation"); +} + +string DeleteRelation::GetQuery() { + return string(); +} + const vector &DeleteRelation::Columns() { return columns; } diff --git a/src/main/relation/explain_relation.cpp b/src/main/relation/explain_relation.cpp index f91e1d29f023..9f2976c9d081 100644 --- a/src/main/relation/explain_relation.cpp +++ b/src/main/relation/explain_relation.cpp @@ -20,6 +20,14 @@ BoundStatement ExplainRelation::Bind(Binder &binder) { return binder.Bind(explain.Cast()); } +unique_ptr ExplainRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an explain relation"); +} + +string ExplainRelation::GetQuery() { + return string(); +} + const vector &ExplainRelation::Columns() { return columns; } diff --git a/src/main/relation/insert_relation.cpp b/src/main/relation/insert_relation.cpp index 9728570a0598..84ef16ec6e47 100644 --- a/src/main/relation/insert_relation.cpp +++ b/src/main/relation/insert_relation.cpp @@ -24,6 +24,14 @@ BoundStatement InsertRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr InsertRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an insert relation"); +} + +string InsertRelation::GetQuery() { + return string(); +} + const vector &InsertRelation::Columns() { return columns; } diff --git a/src/main/relation/query_relation.cpp b/src/main/relation/query_relation.cpp index e0cf2e280e93..6ebebbc70f60 100644 --- a/src/main/relation/query_relation.cpp +++ b/src/main/relation/query_relation.cpp @@ -49,6 +49,10 @@ unique_ptr QueryRelation::GetQueryNode() { return std::move(select->node); } +string QueryRelation::GetQuery() { + return query; +} + unique_ptr QueryRelation::GetTableRef() { auto subquery_ref = make_uniq(GetSelectStatement(), GetAlias()); return std::move(subquery_ref); @@ -63,7 +67,6 @@ BoundStatement QueryRelation::Bind(Binder &binder) { if (first_bind) { auto &query_node = *select_stmt->node; auto &cte_map = query_node.cte_map; - vector> materialized_ctes; for (auto &kv : replacements) { auto &name = kv.first; auto &tableref = kv.second; @@ -84,28 +87,7 @@ BoundStatement QueryRelation::Bind(Binder &binder) { cte_info->query = std::move(select); cte_map.map[name] = std::move(cte_info); - - // We can not rely on CTE inlining anymore, so we need to add a materialized CTE node - // to the query node to ensure that the CTE exists - auto &cte_entry = cte_map.map[name]; - auto mat_cte = make_uniq(); - mat_cte->ctename = name; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt->node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = root->cte_map.Copy(); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); } - select_stmt->node = std::move(root); } replacements.clear(); binder.SetBindingMode(saved_binding_mode); diff --git a/src/main/relation/update_relation.cpp b/src/main/relation/update_relation.cpp index 9176cf2f20b2..81d85ca89f0b 100644 --- a/src/main/relation/update_relation.cpp +++ b/src/main/relation/update_relation.cpp @@ -35,6 +35,14 @@ BoundStatement UpdateRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr UpdateRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string UpdateRelation::GetQuery() { + return string(); +} + const vector &UpdateRelation::Columns() { return columns; } diff --git a/src/main/relation/write_csv_relation.cpp b/src/main/relation/write_csv_relation.cpp index 4795c7a513ea..f77d6f1eeab0 100644 --- a/src/main/relation/write_csv_relation.cpp +++ b/src/main/relation/write_csv_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteCSVRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteCSVRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write CSV relation"); +} + +string WriteCSVRelation::GetQuery() { + return string(); +} + const vector &WriteCSVRelation::Columns() { return columns; } diff --git a/src/main/relation/write_parquet_relation.cpp b/src/main/relation/write_parquet_relation.cpp index d6e403618dc9..b1dfdb29f911 100644 --- a/src/main/relation/write_parquet_relation.cpp +++ b/src/main/relation/write_parquet_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteParquetRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteParquetRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write parquet relation"); +} + +string WriteParquetRelation::GetQuery() { + return string(); +} + const vector &WriteParquetRelation::Columns() { return columns; } diff --git a/src/main/settings/custom_settings.cpp b/src/main/settings/custom_settings.cpp index 819e1383db02..0f83968b5f89 100644 --- a/src/main/settings/custom_settings.cpp +++ b/src/main/settings/custom_settings.cpp @@ -34,6 +34,14 @@ namespace duckdb { +constexpr const char *LoggingMode::Name; +constexpr const char *LoggingLevel::Name; +constexpr const char *EnableLogging::Name; +constexpr const char *LoggingStorage::Name; +constexpr const char *EnabledLogTypes::Name; +constexpr const char *DisabledLogTypes::Name; +constexpr const char *DisabledFilesystemsSetting::Name; + const string GetDefaultUserAgent() { return StringUtil::Format("duckdb/%s(%s)", DuckDB::LibraryVersion(), DuckDB::Platform()); } diff --git a/src/optimizer/filter_combiner.cpp b/src/optimizer/filter_combiner.cpp index 8e4a295b4889..ddbe82ab0dfb 100644 --- a/src/optimizer/filter_combiner.cpp +++ b/src/optimizer/filter_combiner.cpp @@ -1,5 +1,6 @@ #include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/common/enums/expression_type.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/expression.hpp" @@ -907,6 +908,12 @@ FilterResult FilterCombiner::AddTransitiveFilters(BoundComparisonExpression &com idx_t left_equivalence_set = GetEquivalenceSet(left_node); idx_t right_equivalence_set = GetEquivalenceSet(right_node); if (left_equivalence_set == right_equivalence_set) { + if (comparison.GetExpressionType() == ExpressionType::COMPARE_GREATERTHAN || + comparison.GetExpressionType() == ExpressionType::COMPARE_LESSTHAN) { + // non equal comparison has equal equivalence set, then it is unsatisfiable + // e.g., j > i AND i < j is unsatisfiable + return FilterResult::UNSATISFIABLE; + } // this equality filter already exists, prune it return FilterResult::SUCCESS; } diff --git a/src/optimizer/filter_pushdown.cpp b/src/optimizer/filter_pushdown.cpp index 4fa17f7d049e..7c13386d90b0 100644 --- a/src/optimizer/filter_pushdown.cpp +++ b/src/optimizer/filter_pushdown.cpp @@ -208,17 +208,23 @@ unique_ptr FilterPushdown::PushdownJoin(unique_ptrfilter)); D_ASSERT(result != FilterResult::UNSUPPORTED); - (void)result; + if (result == FilterResult::UNSATISFIABLE) { + // one of the filters is unsatisfiable - abort filter pushdown + return FilterResult::UNSATISFIABLE; + } } filters.clear(); + return FilterResult::SUCCESS; } FilterResult FilterPushdown::AddFilter(unique_ptr expr) { - PushFilters(); + if (PushFilters() == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } // split up the filters by AND predicate vector> expressions; expressions.push_back(std::move(expr)); diff --git a/src/optimizer/pushdown/pushdown_get.cpp b/src/optimizer/pushdown/pushdown_get.cpp index 90dbbb823951..ac4b6532ad94 100644 --- a/src/optimizer/pushdown/pushdown_get.cpp +++ b/src/optimizer/pushdown/pushdown_get.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" namespace duckdb { unique_ptr FilterPushdown::PushdownGet(unique_ptr op) { @@ -48,7 +49,9 @@ unique_ptr FilterPushdown::PushdownGet(unique_ptr(std::move(op)); + } //! We generate the table filters that will be executed during the table scan vector pushdown_results; diff --git a/src/optimizer/rule/regex_optimizations.cpp b/src/optimizer/rule/regex_optimizations.cpp index 24786867b55a..3a0697e99457 100644 --- a/src/optimizer/rule/regex_optimizations.cpp +++ b/src/optimizer/rule/regex_optimizations.cpp @@ -184,6 +184,13 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< if (!escaped_like_string.exists) { return nullptr; } + + // if regexp had options, remove them so the new Contains Expression can be matched for other optimizers. + if (root.children.size() == 3) { + root.children.pop_back(); + D_ASSERT(root.children.size() == 2); + } + auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); auto contains = make_uniq(root.return_type, GetStringContains(), std::move(root.children), nullptr); diff --git a/src/optimizer/topn_window_elimination.cpp b/src/optimizer/topn_window_elimination.cpp index a06cdf8308a9..957cef3675bd 100644 --- a/src/optimizer/topn_window_elimination.cpp +++ b/src/optimizer/topn_window_elimination.cpp @@ -20,6 +20,7 @@ #include "duckdb/planner/expression/bound_unnest_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/function/function_binder.hpp" +#include "duckdb/main/database.hpp" namespace duckdb { @@ -74,20 +75,6 @@ bool BindingsReferenceRowNumber(const vector &bindings, const Log } return false; } -// Window, Filter, new_bindings, aggregate_payload -TopNWindowEliminationParameters ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, - const vector &bindings, - const vector> &aggregate_payload) { - TopNWindowEliminationParameters params; - - auto &limit_expr = filter.expressions[0]->Cast().right; - params.limit = limit_expr->Cast().value.GetValue(); - params.include_row_number = BindingsReferenceRowNumber(bindings, window); - params.payload_type = aggregate_payload.size() > 1 ? TopNPayloadType::STRUCT_PACK : TopNPayloadType::SINGLE_COLUMN; - params.order_type = window.expressions[0]->Cast().orders[0].type; - - return params; -} } // namespace @@ -97,6 +84,11 @@ TopNWindowElimination::TopNWindowElimination(ClientContext &context_p, Optimizer } unique_ptr TopNWindowElimination::Optimize(unique_ptr op) { + auto &extension_manager = context.db->GetExtensionManager(); + if (!extension_manager.ExtensionIsLoaded("core_functions")) { + return op; + } + ColumnBindingReplacer replacer; op = OptimizeInternal(std::move(op), replacer); if (!replacer.replacement_bindings.empty()) { @@ -163,15 +155,24 @@ unique_ptr TopNWindowElimination::OptimizeInternal(unique_ptr(std::move(op)); } -unique_ptr TopNWindowElimination::CreateAggregateExpression(vector> aggregate_params, - const bool requires_arg, - const OrderType order_type) const { +unique_ptr +TopNWindowElimination::CreateAggregateExpression(vector> aggregate_params, + const bool requires_arg, + const TopNWindowEliminationParameters ¶ms) const { auto &catalog = Catalog::GetSystemCatalog(context); FunctionBinder function_binder(context); - D_ASSERT(order_type == OrderType::ASCENDING || order_type == OrderType::DESCENDING); - string fun_name = requires_arg ? "arg_" : ""; - fun_name += order_type == OrderType::ASCENDING ? "min" : "max"; + // If the value column can be null, we must use the nulls_last function to follow null ordering semantics + const bool change_to_arg = !requires_arg && params.can_be_null && params.limit > 1; + if (change_to_arg) { + // Copy value as argument + aggregate_params.insert(aggregate_params.begin() + 1, aggregate_params[0]->Copy()); + } + + D_ASSERT(params.order_type == OrderType::ASCENDING || params.order_type == OrderType::DESCENDING); + string fun_name = requires_arg || change_to_arg ? "arg_" : ""; + fun_name += params.order_type == OrderType::ASCENDING ? "min" : "max"; + fun_name += params.can_be_null && (requires_arg || change_to_arg) ? "_nulls_last" : ""; auto &fun_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, fun_name); const auto fun = fun_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(aggregate_params)); @@ -206,7 +207,7 @@ TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector(Value::BIGINT(params.limit)))); } - auto aggregate_expr = CreateAggregateExpression(std::move(aggregate_params), use_arg, params.order_type); + auto aggregate_expr = CreateAggregateExpression(std::move(aggregate_params), use_arg, params); vector> select_list; select_list.push_back(std::move(aggregate_expr)); @@ -366,10 +367,6 @@ TopNWindowElimination::CreateProjectionOperator(unique_ptr op, } bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { - if (!stats) { - return false; - } - if (op.type != LogicalOperatorType::LOGICAL_FILTER) { return false; } @@ -448,15 +445,9 @@ bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { if (window_expr.orders[0].type != OrderType::DESCENDING && window_expr.orders[0].type != OrderType::ASCENDING) { return false; } - - VisitExpression(&window_expr.orders[0].expression); - for (const auto &column_ref : column_references) { - const auto &column_stats = stats->find(column_ref.first); - if (column_stats == stats->end() || column_stats->second->CanHaveNull()) { - return false; - } + if (window_expr.orders[0].null_order != OrderByNullType::NULLS_LAST) { + return false; } - column_references.clear(); // We have found a grouped top-n window construct! return true; @@ -589,4 +580,32 @@ void TopNWindowElimination::UpdateTopmostBindings(const idx_t window_idx, const } } +TopNWindowEliminationParameters +TopNWindowElimination::ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + vector> &aggregate_payload) { + TopNWindowEliminationParameters params; + + auto &limit_expr = filter.expressions[0]->Cast().right; + params.limit = limit_expr->Cast().value.GetValue(); + params.include_row_number = BindingsReferenceRowNumber(bindings, window); + params.payload_type = aggregate_payload.size() > 1 ? TopNPayloadType::STRUCT_PACK : TopNPayloadType::SINGLE_COLUMN; + auto &window_expr = window.expressions[0]->Cast(); + params.order_type = window_expr.orders[0].type; + + VisitExpression(&window_expr.orders[0].expression); + if (params.payload_type == TopNPayloadType::SINGLE_COLUMN && !aggregate_payload.empty()) { + VisitExpression(&aggregate_payload[0]); + } + for (const auto &column_ref : column_references) { + const auto &column_stats = stats->find(column_ref.first); + if (column_stats == stats->end() || column_stats->second->CanHaveNull()) { + params.can_be_null = true; + } + } + column_references.clear(); + + return params; +} + } // namespace duckdb diff --git a/src/parallel/task_executor.cpp b/src/parallel/task_executor.cpp index fa2c0087c3eb..9487a1427624 100644 --- a/src/parallel/task_executor.cpp +++ b/src/parallel/task_executor.cpp @@ -69,8 +69,10 @@ TaskExecutionResult BaseExecutorTask::Execute(TaskExecutionMode mode) { return TaskExecutionResult::TASK_FINISHED; } try { - TaskNotifier task_notifier {executor.context}; - ExecuteTask(); + { + TaskNotifier task_notifier {executor.context}; + ExecuteTask(); + } executor.FinishTask(); return TaskExecutionResult::TASK_FINISHED; } catch (std::exception &ex) { diff --git a/src/parser/expression/lambdaref_expression.cpp b/src/parser/expression/lambdaref_expression.cpp index fed844feae15..e71debfc0766 100644 --- a/src/parser/expression/lambdaref_expression.cpp +++ b/src/parser/expression/lambdaref_expression.cpp @@ -47,7 +47,7 @@ LambdaRefExpression::FindMatchingBinding(optional_ptr> &lam if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { if ((*lambda_bindings)[i - 1].HasMatchingBinding(column_name)) { - D_ASSERT((*lambda_bindings)[i - 1].alias.IsSet()); + D_ASSERT((*lambda_bindings)[i - 1].GetBindingAlias().IsSet()); return make_uniq(i - 1, column_name); } } diff --git a/src/parser/parsed_expression_iterator.cpp b/src/parser/parsed_expression_iterator.cpp index 7ca38a10ee1c..47e39dc76239 100644 --- a/src/parser/parsed_expression_iterator.cpp +++ b/src/parser/parsed_expression_iterator.cpp @@ -271,12 +271,6 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren( EnumerateQueryNodeChildren(*rcte_node.right, expr_callback, ref_callback); break; } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.query, expr_callback, ref_callback); - EnumerateQueryNodeChildren(*cte_node.child, expr_callback, ref_callback); - break; - } case QueryNodeType::SELECT_NODE: { auto &sel_node = node.Cast(); for (idx_t i = 0; i < sel_node.select_list.size(); i++) { diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 9ee2a675b4f2..8fff8df326e6 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -221,13 +221,26 @@ void Parser::ParseQuery(const string &query) { if (StringUtil::CIEquals(parser_override_option, "strict")) { if (result.type == ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR) { throw ParserException( - "Parser override failed to return a valid statement. Consider restarting the database and " + "Parser override failed to return a valid statement: %s\n\nConsider restarting the " + "database and " "using the setting \"set allow_parser_override_extension=fallback\" to fallback to the " - "default parser."); + "default parser.", + result.error.RawMessage()); } if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { - throw ParserException(result.error); + if (result.error.Type() == ExceptionType::NOT_IMPLEMENTED) { + throw NotImplementedException( + "Parser override has not yet implemented this transformer rule. (Original error: %s)", + result.error.RawMessage()); + } else if (result.error.Type() == ExceptionType::PARSER) { + throw ParserException("Parser override could not parse this query. (Original error: %s)", + result.error.RawMessage()); + } else { + result.error.Throw(); + } } + } else if (StringUtil::CIEquals(parser_override_option, "fallback")) { + continue; } } } @@ -299,7 +312,9 @@ void Parser::ParseQuery(const string &query) { bool parsed_single_statement = false; for (auto &ext : *options.extensions) { D_ASSERT(!parsed_single_statement); - D_ASSERT(ext.parse_function); + if (!ext.parse_function) { + continue; + } auto result = ext.parse_function(ext.parser_info.get(), query_statement); if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { auto statement = make_uniq(ext, std::move(result.parse_data)); diff --git a/src/parser/query_node/cte_node.cpp b/src/parser/query_node/cte_node.cpp index 1e1f0e19909d..29c0599a52ed 100644 --- a/src/parser/query_node/cte_node.cpp +++ b/src/parser/query_node/cte_node.cpp @@ -1,42 +1,20 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { string CTENode::ToString() const { - string result; - result += child->ToString(); - return result; + throw InternalException("CTENode is a legacy type"); } bool CTENode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - if (!query->Equals(other.query.get())) { - return false; - } - if (!child->Equals(other.child.get())) { - return false; - } - return true; + throw InternalException("CTENode is a legacy type"); } unique_ptr CTENode::Copy() const { - auto result = make_uniq(); - result->ctename = ctename; - result->query = query->Copy(); - result->child = child->Copy(); - result->aliases = aliases; - result->materialized = materialized; - this->CopyProperties(*result); - return std::move(result); + throw InternalException("CTENode is a legacy type"); } } // namespace duckdb diff --git a/src/parser/query_node/set_operation_node.cpp b/src/parser/query_node/set_operation_node.cpp index a8b624f21748..cdc188820dab 100644 --- a/src/parser/query_node/set_operation_node.cpp +++ b/src/parser/query_node/set_operation_node.cpp @@ -8,10 +8,6 @@ namespace duckdb { SetOperationNode::SetOperationNode() : QueryNode(QueryNodeType::SET_OPERATION_NODE) { } -const vector> &SetOperationNode::GetSelectList() const { - return children[0]->GetSelectList(); -} - string SetOperationNode::ToString() const { string result; result = cte_map.ToString(); diff --git a/src/parser/query_node/statement_node.cpp b/src/parser/query_node/statement_node.cpp index e27b2e6c03a6..66e7b8e5a602 100644 --- a/src/parser/query_node/statement_node.cpp +++ b/src/parser/query_node/statement_node.cpp @@ -5,9 +5,6 @@ namespace duckdb { StatementNode::StatementNode(SQLStatement &stmt_p) : QueryNode(QueryNodeType::STATEMENT_NODE), stmt(stmt_p) { } -const vector> &StatementNode::GetSelectList() const { - throw InternalException("StatementNode has no select list"); -} //! Convert the query node to a string string StatementNode::ToString() const { return stmt.ToString(); diff --git a/src/parser/statement/relation_statement.cpp b/src/parser/statement/relation_statement.cpp index 9b3801495447..023d3cac920c 100644 --- a/src/parser/statement/relation_statement.cpp +++ b/src/parser/statement/relation_statement.cpp @@ -5,10 +5,7 @@ namespace duckdb { RelationStatement::RelationStatement(shared_ptr relation_p) : SQLStatement(StatementType::RELATION_STATEMENT), relation(std::move(relation_p)) { - if (relation->type == RelationType::QUERY_RELATION) { - auto &query_relation = relation->Cast(); - query = query_relation.query; - } + query = relation->GetQuery(); } unique_ptr RelationStatement::Copy() const { diff --git a/src/parser/transform/expression/transform_subquery.cpp b/src/parser/transform/expression/transform_subquery.cpp index bc8a9762dc51..986e46e25eaa 100644 --- a/src/parser/transform/expression/transform_subquery.cpp +++ b/src/parser/transform/expression/transform_subquery.cpp @@ -24,7 +24,6 @@ unique_ptr Transformer::TransformSubquery(duckdb_libpgquery::P subquery_expr->subquery = TransformSelectStmt(*root.subselect); SetQueryLocation(*subquery_expr, root.location); D_ASSERT(subquery_expr->subquery); - D_ASSERT(!subquery_expr->subquery->node->GetSelectList().empty()); switch (root.subLinkType) { case duckdb_libpgquery::PG_EXISTS_SUBLINK: { diff --git a/src/parser/transform/helpers/transform_cte.cpp b/src/parser/transform/helpers/transform_cte.cpp index 2de5d8334c38..f53c6dbb81cb 100644 --- a/src/parser/transform/helpers/transform_cte.cpp +++ b/src/parser/transform/helpers/transform_cte.cpp @@ -25,7 +25,7 @@ CommonTableExpressionInfo::~CommonTableExpressionInfo() { void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { for (auto &cte_entry : stored_cte_map) { - for (auto &entry : cte_entry->map) { + for (auto &entry : cte_entry.get().map) { auto found_entry = cte_map.map.find(entry.first); if (found_entry != cte_map.map.end()) { // entry already present - use top-most entry @@ -40,7 +40,7 @@ void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { } void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map) { - stored_cte_map.push_back(&cte_map); + stored_cte_map.push_back(cte_map); // TODO: might need to update in case of future lawsuit D_ASSERT(de_with_clause.ctes); diff --git a/src/parser/transform/statement/transform_pivot_stmt.cpp b/src/parser/transform/statement/transform_pivot_stmt.cpp index 07dfb420d9dc..4572a3a36a16 100644 --- a/src/parser/transform/statement/transform_pivot_stmt.cpp +++ b/src/parser/transform/statement/transform_pivot_stmt.cpp @@ -95,7 +95,7 @@ unique_ptr Transformer::GenerateCreateEnumStmt(unique_ptr(); - select->node = TransformMaterializedCTE(std::move(subselect)); + select->node = std::move(subselect); info->query = std::move(select); info->type = LogicalType::INVALID; diff --git a/src/parser/transform/statement/transform_select.cpp b/src/parser/transform/statement/transform_select.cpp index 2e5135ef640d..16cd1a490144 100644 --- a/src/parser/transform/statement/transform_select.cpp +++ b/src/parser/transform/statement/transform_select.cpp @@ -26,13 +26,10 @@ unique_ptr Transformer::TransformSelectNodeInternal(duckdb_libpgquery throw ParserException("SELECT locking clause is not supported!"); } } - unique_ptr stmt = nullptr; if (select.pivot) { - stmt = TransformPivotStatement(select); - } else { - stmt = TransformSelectInternal(select); + return TransformPivotStatement(select); } - return TransformMaterializedCTE(std::move(stmt)); + return TransformSelectInternal(select); } unique_ptr Transformer::TransformSelectStmt(duckdb_libpgquery::PGSelectStmt &select, bool is_select) { diff --git a/src/parser/transformer.cpp b/src/parser/transformer.cpp index 4ab39fca7079..32ddaa87a421 100644 --- a/src/parser/transformer.cpp +++ b/src/parser/transformer.cpp @@ -232,31 +232,6 @@ unique_ptr Transformer::TransformStatementInternal(duckdb_libpgque } } -unique_ptr Transformer::TransformMaterializedCTE(unique_ptr root) { - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - - for (auto &cte : root->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = TransformMaterializedCTE(cte_entry->query->node->Copy()); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return root; -} - void Transformer::SetQueryLocation(ParsedExpression &expr, int query_location) { if (query_location < 0) { return; diff --git a/src/planner/bind_context.cpp b/src/planner/bind_context.cpp index 038eddb4229c..cc1b3d25e459 100644 --- a/src/planner/bind_context.cpp +++ b/src/planner/bind_context.cpp @@ -38,16 +38,17 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name) optional_ptr result; for (auto &binding_ptr : bindings_list) { auto &binding = *binding_ptr; - auto is_using_binding = GetUsingBinding(column_name, binding.alias); + auto is_using_binding = GetUsingBinding(column_name, binding.GetBindingAlias()); if (is_using_binding) { continue; } if (binding.HasMatchingBinding(column_name)) { if (result || is_using_binding) { - throw BinderException("Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " - "or \"%s.%s\")", - column_name, MinimumUniqueAlias(result->alias, binding.alias), column_name, - MinimumUniqueAlias(binding.alias, result->alias), column_name); + throw BinderException( + "Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " + "or \"%s.%s\")", + column_name, MinimumUniqueAlias(result->GetBindingAlias(), binding.GetBindingAlias()), column_name, + MinimumUniqueAlias(binding.GetBindingAlias(), result->GetBindingAlias()), column_name); } result = &binding; } @@ -58,8 +59,8 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name) vector BindContext::GetSimilarBindings(const string &column_name) { vector> scores; for (auto &binding_ptr : bindings_list) { - auto binding = *binding_ptr; - for (auto &name : binding.names) { + auto &binding = *binding_ptr; + for (auto &name : binding.GetColumnNames()) { double distance = StringUtil::SimilarityRating(name, column_name); // check if we need to qualify the column auto matching_bindings = GetMatchingBindings(name); @@ -157,7 +158,7 @@ string BindContext::GetActualColumnName(Binding &binding, const string &column_n throw InternalException("Binding with name \"%s\" does not have a column named \"%s\"", binding.GetAlias(), column_name); } // LCOV_EXCL_STOP - return binding.names[binding_index]; + return binding.GetColumnNames()[binding_index]; } string BindContext::GetActualColumnName(const BindingAlias &binding_alias, const string &column_name) { @@ -200,7 +201,7 @@ unique_ptr BindContext::CreateColumnReference(const string &ta } static bool ColumnIsGenerated(Binding &binding, column_t index) { - if (binding.binding_type != BindingType::TABLE) { + if (binding.GetBindingType() != BindingType::TABLE) { return false; } auto &table_binding = binding.Cast(); @@ -239,10 +240,12 @@ unique_ptr BindContext::CreateColumnReference(const string &ca auto column_index = binding->GetBindingIndex(column_name); if (bind_type == ColumnBindType::EXPAND_GENERATED_COLUMNS && ColumnIsGenerated(*binding, column_index)) { return ExpandGeneratedColumn(binding->Cast(), column_name); - } else if (column_index < binding->names.size() && binding->names[column_index] != column_name) { + } + auto &column_names = binding->GetColumnNames(); + if (column_index < column_names.size() && column_names[column_index] != column_name) { // because of case insensitivity in the binder we rename the column to the original name // as it appears in the binding itself - result->SetAlias(binding->names[column_index]); + result->SetAlias(column_names[column_index]); } return std::move(result); } @@ -253,14 +256,6 @@ unique_ptr BindContext::CreateColumnReference(const string &sc return CreateColumnReference(catalog_name, schema_name, table_name, column_name, bind_type); } -optional_ptr BindContext::GetCTEBinding(const string &ctename) { - auto match = cte_bindings.find(ctename); - if (match == cte_bindings.end()) { - return nullptr; - } - return match->second.get(); -} - string GetCandidateAlias(const BindingAlias &main_alias, const BindingAlias &new_alias) { string candidate; if (!main_alias.GetCatalog().empty() && !new_alias.GetCatalog().empty()) { @@ -279,7 +274,7 @@ vector> BindContext::GetBindings(const BindingAlias &alias, E } vector> matching_bindings; for (auto &binding : bindings_list) { - if (binding->alias.Matches(alias)) { + if (binding->GetBindingAlias().Matches(alias)) { matching_bindings.push_back(*binding); } } @@ -287,7 +282,7 @@ vector> BindContext::GetBindings(const BindingAlias &alias, E // alias not found in this BindContext vector candidates; for (auto &binding : bindings_list) { - candidates.push_back(GetCandidateAlias(alias, binding->alias)); + candidates.push_back(GetCandidateAlias(alias, binding->GetBindingAlias())); } auto main_alias = GetCandidateAlias(alias, alias); string candidate_str = @@ -311,14 +306,14 @@ string BindContext::AmbiguityException(const BindingAlias &alias, const vector handled_using_columns; for (auto &entry : bindings_list) { auto &binding = *entry; - for (auto &column_name : binding.names) { - QualifiedColumnName qualified_column(binding.alias, column_name); + auto &column_names = binding.GetColumnNames(); + auto &binding_alias = binding.GetBindingAlias(); + for (auto &column_name : column_names) { + QualifiedColumnName qualified_column(binding_alias, column_name); if (CheckExclusionList(expr, qualified_column, exclusion_info)) { continue; } // check if this column is a USING column - auto using_binding_ptr = GetUsingBinding(column_name, binding.alias); + auto using_binding_ptr = GetUsingBinding(column_name, binding_alias); if (using_binding_ptr) { auto &using_binding = *using_binding_ptr; // it is! @@ -526,7 +524,7 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, continue; } auto new_expr = - CreateColumnReference(binding.alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + CreateColumnReference(binding_alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); HandleRename(expr, qualified_column, *new_expr); new_select_list.push_back(std::move(new_expr)); } @@ -544,17 +542,20 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, } is_struct_ref = true; } + auto &binding_alias = binding->GetBindingAlias(); + auto &column_names = binding->GetColumnNames(); + auto &column_types = binding->GetColumnTypes(); if (is_struct_ref) { auto col_idx = binding->GetBindingIndex(expr.relation_name); - auto col_type = binding->types[col_idx]; + auto col_type = column_types[col_idx]; if (col_type.id() != LogicalTypeId::STRUCT) { throw BinderException(StringUtil::Format( "Cannot extract field from expression \"%s\" because it is not a struct", expr.ToString())); } auto &struct_children = StructType::GetChildTypes(col_type); vector column_names(3); - column_names[0] = binding->alias.GetAlias(); + column_names[0] = binding->GetAlias(); column_names[1] = expr.relation_name; for (auto &child : struct_children) { QualifiedColumnName qualified_name(child.first); @@ -567,13 +568,13 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, new_select_list.push_back(std::move(new_expr)); } } else { - for (auto &column_name : binding->names) { - QualifiedColumnName qualified_name(binding->alias, column_name); + for (auto &column_name : column_names) { + QualifiedColumnName qualified_name(binding_alias, column_name); if (CheckExclusionList(expr, qualified_name, exclusion_info)) { continue; } auto new_expr = - CreateColumnReference(binding->alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + CreateColumnReference(binding_alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); HandleRename(expr, qualified_name, *new_expr); new_select_list.push_back(std::move(new_expr)); } @@ -609,10 +610,12 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { for (auto &binding_entry : bindings_list) { auto &binding = *binding_entry; - D_ASSERT(binding.names.size() == binding.types.size()); - for (idx_t i = 0; i < binding.names.size(); i++) { - result_names.push_back(binding.names[i]); - result_types.push_back(binding.types[i]); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); + for (idx_t i = 0; i < column_names.size(); i++) { + result_names.push_back(column_names[i]); + result_types.push_back(column_types[i]); } } } @@ -708,19 +711,28 @@ void BindContext::AddGenericBinding(idx_t index, const string &alias, const vect AddBinding(make_uniq(BindingType::BASE, BindingAlias(alias), types, names, index)); } -void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, - const vector &types, bool using_key) { - auto binding = make_uniq(BindingAlias(alias), types, names, index); - - if (cte_bindings.find(alias) != cte_bindings.end()) { - throw BinderException("Duplicate CTE binding \"%s\" in query!", alias); +void BindContext::AddCTEBinding(unique_ptr binding) { + for (auto &cte_binding : cte_bindings) { + if (cte_binding->GetBindingAlias() == binding->GetBindingAlias()) { + throw BinderException("Duplicate CTE binding \"%s\" in query!", binding->GetBindingAlias().ToString()); + } } - cte_bindings[alias] = std::move(binding); + cte_bindings.push_back(std::move(binding)); +} + +void BindContext::AddCTEBinding(idx_t index, BindingAlias alias_p, const vector &names, + const vector &types, CTEType cte_type) { + auto binding = make_uniq(std::move(alias_p), types, names, index, cte_type); + AddCTEBinding(std::move(binding)); +} - if (using_key) { - auto recurring_alias = "recurring." + alias; - cte_bindings[recurring_alias] = make_uniq(BindingAlias(recurring_alias), types, names, index); +optional_ptr BindContext::GetCTEBinding(const BindingAlias &ctename) { + for (auto &binding : cte_bindings) { + if (binding->GetBindingAlias().Matches(ctename)) { + return binding.get(); + } } + return nullptr; } void BindContext::AddContext(BindContext other) { @@ -737,7 +749,7 @@ void BindContext::AddContext(BindContext other) { vector BindContext::GetBindingAliases() { vector result; for (auto &binding : bindings_list) { - result.push_back(BindingAlias(binding->alias)); + result.push_back(binding->GetBindingAlias()); } return result; } @@ -764,7 +776,7 @@ void BindContext::RemoveContext(const vector &aliases) { // remove the binding from the list of bindings auto it = std::remove_if(bindings_list.begin(), bindings_list.end(), - [&](unique_ptr &x) { return x->alias == alias; }); + [&](unique_ptr &x) { return x->GetBindingAlias() == alias; }); bindings_list.erase(it, bindings_list.end()); } } diff --git a/src/planner/binder.cpp b/src/planner/binder.cpp index 07d225378488..9b3708be65ef 100644 --- a/src/planner/binder.cpp +++ b/src/planner/binder.cpp @@ -68,28 +68,9 @@ BoundStatement Binder::BindWithCTE(T &statement) { return Bind(statement); } - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - for (auto &cte : cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = std::move(cte_entry->query->node); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - unique_ptr cte_root = make_uniq(statement); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(cte_root); - cte_root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return Bind(*cte_root); + auto stmt_node = make_uniq(statement); + stmt_node->cte_map = cte_map.Copy(); + return Bind(*stmt_node); } BoundStatement Binder::Bind(SQLStatement &statement) { @@ -152,24 +133,6 @@ BoundStatement Binder::Bind(SQLStatement &statement) { } // LCOV_EXCL_STOP } -BoundStatement Binder::BindNode(QueryNode &node) { - // now we bind the node - switch (node.type) { - case QueryNodeType::SELECT_NODE: - return BindNode(node.Cast()); - case QueryNodeType::RECURSIVE_CTE_NODE: - return BindNode(node.Cast()); - case QueryNodeType::CTE_NODE: - return BindNode(node.Cast()); - case QueryNodeType::SET_OPERATION_NODE: - return BindNode(node.Cast()); - case QueryNodeType::STATEMENT_NODE: - return BindNode(node.Cast()); - default: - throw InternalException("Unsupported query node type"); - } -} - BoundStatement Binder::Bind(QueryNode &node) { return BindNode(node); } @@ -221,34 +184,27 @@ BoundStatement Binder::Bind(TableRef &ref) { return result; } -void Binder::AddCTE(const string &name) { - D_ASSERT(!name.empty()); - CTE_bindings.insert(name); -} - -optional_ptr Binder::GetCTEBinding(const string &name) { +optional_ptr Binder::GetCTEBinding(const BindingAlias &name) { reference current_binder(*this); + optional_ptr result; while (true) { auto ¤t = current_binder.get(); auto entry = current.bind_context.GetCTEBinding(name); if (entry) { - return entry; + // we only directly return the CTE if it can be referenced + // if it cannot be referenced (circular reference) we keep going up the stack + // to look for a CTE that can be referenced + if (entry->CanBeReferenced()) { + return entry; + } + result = entry; } if (!current.parent || current.binder_type != BinderType::REGULAR_BINDER) { - return nullptr; + break; } current_binder = *current.parent; } -} - -bool Binder::CTEExists(const string &name) { - if (CTE_bindings.find(name) != CTE_bindings.end()) { - return true; - } - if (parent && binder_type == BinderType::REGULAR_BINDER) { - return parent->CTEExists(name); - } - return false; + return result; } void Binder::AddBoundView(ViewCatalogEntry &view) { @@ -272,11 +228,11 @@ StatementProperties &Binder::GetStatementProperties() { } optional_ptr Binder::GetParameters() { - return query_binder_state->parameters; + return global_binder_state->parameters; } void Binder::SetParameters(BoundParameterMap ¶meters) { - query_binder_state->parameters = parameters; + global_binder_state->parameters = parameters; } void Binder::PushExpressionBinder(ExpressionBinder &binder) { @@ -343,7 +299,6 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con const string &table_name, const string &column_name, ErrorData &error) { optional_ptr binding; - D_ASSERT(!lambda_bindings); if (macro_binding && table_name == macro_binding->GetAlias()) { binding = optional_ptr(macro_binding.get()); } else { diff --git a/src/planner/binder/expression/bind_columnref_expression.cpp b/src/planner/binder/expression/bind_columnref_expression.cpp index 886a1ff42a63..e8b83e95cec4 100644 --- a/src/planner/binder/expression/bind_columnref_expression.cpp +++ b/src/planner/binder/expression/bind_columnref_expression.cpp @@ -94,12 +94,12 @@ unique_ptr ExpressionBinder::QualifyColumnName(const string &c // bind as a macro column if (is_macro_column) { - return binder.bind_context.CreateColumnReference(binder.macro_binding->alias, column_name); + return binder.bind_context.CreateColumnReference(binder.macro_binding->GetBindingAlias(), column_name); } // bind as a regular column if (table_binding) { - return binder.bind_context.CreateColumnReference(table_binding->alias, column_name); + return binder.bind_context.CreateColumnReference(table_binding->GetBindingAlias(), column_name); } // it's not, find candidates and error @@ -276,11 +276,12 @@ unique_ptr ExpressionBinder::CreateStructPack(ColumnRefExpress } // We found the table, now create the struct_pack expression + auto &column_names = binding->GetColumnNames(); vector> child_expressions; - child_expressions.reserve(binding->names.size()); - for (const auto &column_name : binding->names) { + child_expressions.reserve(column_names.size()); + for (const auto &column_name : column_names) { child_expressions.push_back(binder.bind_context.CreateColumnReference( - binding->alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); + binding->GetBindingAlias(), column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); } return make_uniq("struct_pack", std::move(child_expressions)); } @@ -312,7 +313,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte if (binding) { // part1 is a catalog - the column reference is "catalog.schema.table.column" struct_extract_start = 4; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[3]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[3]); } } ErrorData catalog_table_error; @@ -321,7 +322,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte if (binding) { // part1 is a catalog - the column reference is "catalog.table.column" struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[2]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); } ErrorData schema_table_error; binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], col_ref.column_names[2], @@ -330,7 +331,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte // part1 is a schema - the column reference is "schema.table.column" // any additional fields are turned into struct_extract calls struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[2]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); } ErrorData table_column_error; binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], table_column_error); @@ -339,7 +340,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte // the column reference is "table.column" // any additional fields are turned into struct_extract calls struct_extract_start = 2; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[1]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[1]); } // part1 could be a column ErrorData unused_error; @@ -360,7 +361,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte optional_idx schema_pos; optional_idx table_pos; for (const auto &binding_entry : binder.bind_context.GetBindingsList()) { - auto &alias = binding_entry->alias; + auto &alias = binding_entry->GetBindingAlias(); string catalog = alias.GetCatalog(); string schema = alias.GetSchema(); string table = alias.GetAlias(); @@ -483,7 +484,7 @@ unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpres auto binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], error); if (binding) { // it is! return the column reference directly - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.GetColumnName()); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.GetColumnName()); } // otherwise check if we can turn this into a struct extract diff --git a/src/planner/binder/expression/bind_function_expression.cpp b/src/planner/binder/expression/bind_function_expression.cpp index e0d775db1b46..7274acf1662e 100644 --- a/src/planner/binder/expression/bind_function_expression.cpp +++ b/src/planner/binder/expression/bind_function_expression.cpp @@ -304,11 +304,13 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc for (idx_t i = lambda_bindings->size(); i > 0; i--) { auto &binding = (*lambda_bindings)[i - 1]; - D_ASSERT(binding.names.size() == binding.types.size()); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); - for (idx_t column_idx = binding.names.size(); column_idx > 0; column_idx--) { - auto bound_lambda_param = make_uniq(binding.names[column_idx - 1], - binding.types[column_idx - 1], offset); + for (idx_t column_idx = column_names.size(); column_idx > 0; column_idx--) { + auto bound_lambda_param = make_uniq(column_names[column_idx - 1], + column_types[column_idx - 1], offset); offset++; bound_function_expr.children.push_back(std::move(bound_lambda_param)); } diff --git a/src/planner/binder/expression/bind_lambda.cpp b/src/planner/binder/expression/bind_lambda.cpp index 592daa245659..ad8402824794 100644 --- a/src/planner/binder/expression/bind_lambda.cpp +++ b/src/planner/binder/expression/bind_lambda.cpp @@ -12,25 +12,25 @@ namespace duckdb { -idx_t GetLambdaParamCount(const vector &lambda_bindings) { +idx_t GetLambdaParamCount(vector &lambda_bindings) { idx_t count = 0; for (auto &binding : lambda_bindings) { - count += binding.names.size(); + count += binding.GetColumnCount(); } return count; } -idx_t GetLambdaParamIndex(const vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, +idx_t GetLambdaParamIndex(vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, const BoundLambdaRefExpression &bound_lambda_ref_expr) { D_ASSERT(bound_lambda_ref_expr.lambda_idx < lambda_bindings.size()); idx_t offset = 0; // count the remaining lambda parameters BEFORE the current lambda parameter, // as these will be in front of the current lambda parameter in the input chunk for (idx_t i = bound_lambda_ref_expr.lambda_idx + 1; i < lambda_bindings.size(); i++) { - offset += lambda_bindings[i].names.size(); + offset += lambda_bindings[i].GetColumnCount(); } - offset += - lambda_bindings[bound_lambda_ref_expr.lambda_idx].names.size() - bound_lambda_ref_expr.binding.column_index - 1; + offset += lambda_bindings[bound_lambda_ref_expr.lambda_idx].GetColumnCount() - + bound_lambda_ref_expr.binding.column_index - 1; offset += bound_lambda_expr.parameter_count; return offset; } @@ -148,16 +148,18 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori if (lambda_bindings && bound_lambda_ref.lambda_idx != lambda_bindings->size()) { auto &binding = (*lambda_bindings)[bound_lambda_ref.lambda_idx]; - D_ASSERT(binding.names.size() == binding.types.size()); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); // find the matching dummy column in the lambda binding - for (idx_t column_idx = 0; column_idx < binding.names.size(); column_idx++) { + for (idx_t column_idx = 0; column_idx < binding.GetColumnCount(); column_idx++) { if (column_idx == bound_lambda_ref.binding.column_index) { // now create the replacement auto index = GetLambdaParamIndex(*lambda_bindings, bound_lambda_expr, bound_lambda_ref); - replacement = make_uniq(binding.names[column_idx], - binding.types[column_idx], index); + replacement = + make_uniq(column_names[column_idx], column_types[column_idx], index); return; } } diff --git a/src/planner/binder/expression/bind_macro_expression.cpp b/src/planner/binder/expression/bind_macro_expression.cpp index cce06d712d05..5d2bce798653 100644 --- a/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/planner/binder/expression/bind_macro_expression.cpp @@ -98,6 +98,7 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala // validate the arguments and separate positional and default arguments vector> positional_arguments; InsertionOrderPreservingMap> named_arguments; + binder.lambda_bindings = lambda_bindings; auto bind_result = MacroFunction::BindMacroFunction(binder, macro_func.macros, macro_func.name, function, positional_arguments, named_arguments, depth); if (!bind_result.error.empty()) { diff --git a/src/planner/binder/expression/bind_star_expression.cpp b/src/planner/binder/expression/bind_star_expression.cpp index f48fc14e6710..4cc0e3a234e5 100644 --- a/src/planner/binder/expression/bind_star_expression.cpp +++ b/src/planner/binder/expression/bind_star_expression.cpp @@ -152,10 +152,15 @@ string Binder::ReplaceColumnsAlias(const string &alias, const string &column_nam void TryTransformStarLike(unique_ptr &root) { // detect "* LIKE [literal]" and similar expressions - if (root->GetExpressionClass() != ExpressionClass::FUNCTION) { + bool inverse = root->GetExpressionType() == ExpressionType::OPERATOR_NOT; + auto &expr = inverse ? root->Cast().children[0] : root; + if (!expr) { return; } - auto &function = root->Cast(); + if (expr->GetExpressionClass() != ExpressionClass::FUNCTION) { + return; + } + auto &function = expr->Cast(); if (function.children.size() < 2 || function.children.size() > 3) { return; } @@ -197,7 +202,7 @@ void TryTransformStarLike(unique_ptr &root) { auto original_alias = root->GetAlias(); auto star_expr = std::move(left); unique_ptr child_expr; - if (function.function_name == "regexp_full_match" && star.exclude_list.empty()) { + if (!inverse && function.function_name == "regexp_full_match" && star.exclude_list.empty()) { // * SIMILAR TO '[regex]' is equivalent to COLUMNS('[regex]') so we can just move the expression directly child_expr = std::move(right); } else { @@ -207,13 +212,20 @@ void TryTransformStarLike(unique_ptr &root) { vector named_parameters; named_parameters.push_back("__lambda_col"); function.children[0] = make_uniq("__lambda_col"); + function.children[1] = std::move(right); + + unique_ptr lambda_body = std::move(expr); + if (inverse) { + vector> root_children; + root_children.push_back(std::move(lambda_body)); + lambda_body = make_uniq(ExpressionType::OPERATOR_NOT, std::move(root_children)); + } + auto lambda = make_uniq(std::move(named_parameters), std::move(lambda_body)); - auto lambda = make_uniq(std::move(named_parameters), std::move(root)); vector> filter_children; filter_children.push_back(std::move(star_expr)); filter_children.push_back(std::move(lambda)); - auto list_filter = make_uniq("list_filter", std::move(filter_children)); - child_expr = std::move(list_filter); + child_expr = make_uniq("list_filter", std::move(filter_children)); } auto columns_expr = make_uniq(); diff --git a/src/planner/binder/expression/bind_subquery_expression.cpp b/src/planner/binder/expression/bind_subquery_expression.cpp index 8e15f3b28adc..7f03f0e32d08 100644 --- a/src/planner/binder/expression/bind_subquery_expression.cpp +++ b/src/planner/binder/expression/bind_subquery_expression.cpp @@ -23,10 +23,6 @@ class BoundSubqueryNode : public QueryNode { BoundStatement bound_node; unique_ptr subquery; - const vector> &GetSelectList() const override { - throw InternalException("Cannot get select list of bound subquery node"); - } - string ToString() const override { throw InternalException("Cannot ToString bound subquery node"); } diff --git a/src/planner/binder/query_node/CMakeLists.txt b/src/planner/binder/query_node/CMakeLists.txt index 39efaf89c275..709d23ce1749 100644 --- a/src/planner/binder/query_node/CMakeLists.txt +++ b/src/planner/binder/query_node/CMakeLists.txt @@ -8,8 +8,6 @@ add_library_unity( bind_statement_node.cpp bind_table_macro_node.cpp plan_query_node.cpp - plan_recursive_cte_node.cpp - plan_cte_node.cpp plan_select_node.cpp plan_setop.cpp plan_subquery.cpp) diff --git a/src/planner/binder/query_node/bind_cte_node.cpp b/src/planner/binder/query_node/bind_cte_node.cpp index 92ca383c7e10..8cea37ebc277 100644 --- a/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/planner/binder/query_node/bind_cte_node.cpp @@ -1,97 +1,170 @@ -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression_map.hpp" -#include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/parser/query_node/list.hpp" +#include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { -BoundStatement Binder::BindNode(CTENode &statement) { - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); +struct BoundCTEData { + string ctename; + CTEMaterialize materialized; + idx_t setop_index; + shared_ptr child_binder; + shared_ptr cte_bind_state; +}; + +BoundStatement Binder::BindNode(QueryNode &node) { + reference current_binder(*this); + vector bound_ctes; + for (auto &cte : node.cte_map.map) { + bound_ctes.push_back(current_binder.get().PrepareCTE(cte.first, *cte.second)); + current_binder = *bound_ctes.back().child_binder; + } + BoundStatement result; + // now we bind the node + switch (node.type) { + case QueryNodeType::SELECT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::RECURSIVE_CTE_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::SET_OPERATION_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::STATEMENT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + default: + throw InternalException("Unsupported query node type"); + } + for (idx_t i = bound_ctes.size(); i > 0; i--) { + auto &finish_binder = i == 1 ? *this : *bound_ctes[i - 2].child_binder; + result = finish_binder.FinishCTE(bound_ctes[i - 1], std::move(result)); + } + return result; +} - return BindCTE(statement); +CTEBindState::CTEBindState(Binder &parent_binder_p, QueryNode &cte_def_p, const vector &aliases_p) + : parent_binder(parent_binder_p), cte_def(cte_def_p), aliases(aliases_p), + active_binder_count(parent_binder.GetActiveBinders().size()) { } -BoundStatement Binder::BindCTE(CTENode &statement) { - BoundCTENode result; +CTEBindState::~CTEBindState() { +} - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); +bool CTEBindState::IsBound() const { + return query_binder.get() != nullptr; +} - result.ctename = statement.ctename; - result.materialized = statement.materialized; - result.setop_index = GenerateTableIndex(); +void CTEBindState::Bind(CTEBinding &binding) { + // we are lazily binding the CTE + // we need to bind it as if we were binding it during PrepareCTE + query_binder = Binder::CreateBinder(parent_binder.context, parent_binder); + + // we clear any expression binders that were added in the mean-time, to ensure we are not binding to any newly added + // correlated columns + auto &active_binders = parent_binder.GetActiveBinders(); + vector> stored_binders; + for (idx_t i = active_binder_count; i < active_binders.size(); i++) { + stored_binders.push_back(active_binders[i]); + } + active_binders.erase(active_binders.begin() + UnsafeNumericCast(active_binder_count), + active_binders.end()); - AddCTE(result.ctename); + // add this CTE to the query binder on the RHS with "CANNOT_BE_REFERENCED" to detect recursive references to + // ourselves + query_binder->bind_context.AddCTEBinding(binding.GetIndex(), binding.GetBindingAlias(), vector(), + vector(), CTEType::CANNOT_BE_REFERENCED); - result.query_binder = Binder::CreateBinder(context, this); - result.query = result.query_binder->BindNode(*statement.query); + // bind the actual CTE + query = query_binder->Bind(cte_def); + + // after binding - we add the active binders we removed back so we can leave the binder in its original state + for (auto &stored_binder : stored_binders) { + active_binders.push_back(stored_binder); + } // the result types of the CTE are the types of the LHS - result.types = result.query.types; + types = query.types; // names are picked from the LHS, unless aliases are explicitly specified - result.names = result.query.names; - for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { - result.names[i] = statement.aliases[i]; + names = query.names; + for (idx_t i = 0; i < aliases.size() && i < names.size(); i++) { + names[i] = aliases[i]; } // Rename columns if duplicate names are detected idx_t index = 1; - vector names; + vector new_names; // Use a case-insensitive set to track names case_insensitive_set_t ci_names; - for (auto &n : result.names) { + for (auto &n : names) { string name = n; while (ci_names.find(name) != ci_names.end()) { name = n + "_" + std::to_string(index++); } - names.push_back(name); + new_names.push_back(name); ci_names.insert(name); } + names = std::move(new_names); +} + +BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement) { + BoundCTEData result; + + // first recursively visit the materialized CTE operations + // the left side is visited first and is added to the BindContext of the right side + D_ASSERT(statement.query); + + result.ctename = ctename; + result.materialized = statement.materialized; + result.setop_index = GenerateTableIndex(); - // This allows the right side to reference the CTE - bind_context.AddGenericBinding(result.setop_index, statement.ctename, names, result.types); + // instead of eagerly binding the CTE here we add the CTE bind state to the list of CTE bindings + // the CTE is bound lazily - when referenced for the first time we perform the binding + result.cte_bind_state = make_shared_ptr(*this, *statement.query->node, statement.aliases); result.child_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. // This implements the CTE shadowing behavior. - result.child_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, names, result.types); - - if (statement.child) { - // Move all modifiers to the child node. - for (auto &modifier : statement.modifiers) { - statement.child->modifiers.push_back(std::move(modifier)); - } + auto cte_binding = make_uniq(BindingAlias(ctename), result.cte_bind_state, result.setop_index); + result.child_binder->bind_context.AddCTEBinding(std::move(cte_binding)); + return result; +} - statement.modifiers.clear(); +BoundStatement Binder::FinishCTE(BoundCTEData &bound_cte, BoundStatement child) { + if (!bound_cte.cte_bind_state->IsBound()) { + // CTE was not bound - just ignore it + return child; + } + auto &bind_state = *bound_cte.cte_bind_state; + for (auto &c : bind_state.query_binder->correlated_columns) { + bound_cte.child_binder->AddCorrelatedColumn(c); + } - result.child = result.child_binder->BindNode(*statement.child); - for (auto &c : result.query_binder->correlated_columns) { - result.child_binder->AddCorrelatedColumn(c); - } + BoundStatement result; + // the result types of the CTE are the types of the LHS + result.types = child.types; + result.names = child.names; - // the result types of the CTE are the types of the LHS - result.types = result.child.types; - result.names = result.child.names; + MoveCorrelatedExpressions(*bound_cte.child_binder); + MoveCorrelatedExpressions(*bind_state.query_binder); - MoveCorrelatedExpressions(*result.child_binder); - } + auto cte_query = std::move(bind_state.query.plan); + auto cte_child = std::move(child.plan); - MoveCorrelatedExpressions(*result.query_binder); + auto root = make_uniq(bound_cte.ctename, bound_cte.setop_index, result.types.size(), + std::move(cte_query), std::move(cte_child), bound_cte.materialized); - BoundStatement result_statement; - result_statement.types = result.types; - result_statement.names = result.names; - result_statement.plan = CreatePlan(result); - return result_statement; + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || + bound_cte.child_binder->has_unplanned_dependent_joins || + bind_state.query_binder->has_unplanned_dependent_joins; + result.plan = std::move(root); + return result; } } // namespace duckdb diff --git a/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 8cb62ab73bec..7341d7edebd2 100644 --- a/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -3,13 +3,12 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" namespace duckdb { BoundStatement Binder::BindNode(RecursiveCTENode &statement) { - BoundRecursiveCTENode result; // first recursively visit the recursive CTE operations // the left side is visited first and is added to the BindContext of the right side @@ -19,49 +18,55 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { throw BinderException("UNION ALL cannot be used with USING KEY in recursive CTE."); } - result.ctename = statement.ctename; - result.union_all = statement.union_all; - result.setop_index = GenerateTableIndex(); + auto ctename = statement.ctename; + auto union_all = statement.union_all; + auto setop_index = GenerateTableIndex(); - result.left_binder = Binder::CreateBinder(context, this); - result.left = result.left_binder->BindNode(*statement.left); + auto left_binder = Binder::CreateBinder(context, this); + auto left = left_binder->BindNode(*statement.left); + BoundStatement result; // the result types of the CTE are the types of the LHS - result.types = result.left.types; + result.types = left.types; // names are picked from the LHS, unless aliases are explicitly specified - result.names = result.left.names; + result.names = left.names; for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { result.names[i] = statement.aliases[i]; } // This allows the right side to reference the CTE recursively - bind_context.AddGenericBinding(result.setop_index, statement.ctename, result.names, result.types); + bind_context.AddGenericBinding(setop_index, statement.ctename, result.names, result.types); - result.right_binder = Binder::CreateBinder(context, this); + auto right_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - result.right_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, result.names, result.types, - !statement.key_targets.empty()); + BindingAlias cte_alias(statement.ctename); + right_binder->bind_context.AddCTEBinding(setop_index, std::move(cte_alias), result.names, result.types); + if (!statement.key_targets.empty()) { + BindingAlias recurring_alias("recurring", statement.ctename); + right_binder->bind_context.AddCTEBinding(setop_index, std::move(recurring_alias), result.names, result.types); + } - result.right = result.right_binder->BindNode(*statement.right); - for (auto &c : result.left_binder->correlated_columns) { - result.right_binder->AddCorrelatedColumn(c); + auto right = right_binder->BindNode(*statement.right); + for (auto &c : left_binder->correlated_columns) { + right_binder->AddCorrelatedColumn(c); } // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result.left_binder); - MoveCorrelatedExpressions(*result.right_binder); + MoveCorrelatedExpressions(*left_binder); + MoveCorrelatedExpressions(*right_binder); + vector> key_targets; // bind specified keys to the referenced column auto expression_binder = ExpressionBinder(*this, context); - for (unique_ptr &expr : statement.key_targets) { + for (auto &expr : statement.key_targets) { auto bound_expr = expression_binder.Bind(expr); D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); - result.key_targets.push_back(std::move(bound_expr)); + key_targets.push_back(std::move(bound_expr)); } // now both sides have been bound we can resolve types - if (result.left.types.size() != result.right.types.size()) { + if (left.types.size() != right.types.size()) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -70,11 +75,42 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); } - BoundStatement result_statement; - result_statement.types = result.types; - result_statement.names = result.names; - result_statement.plan = CreatePlan(result); - return result_statement; + // Generate the logical plan for the left and right sides of the set operation + left_binder->is_outside_flattened = is_outside_flattened; + right_binder->is_outside_flattened = is_outside_flattened; + + auto left_node = std::move(left.plan); + auto right_node = std::move(right.plan); + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || left_binder->has_unplanned_dependent_joins || + right_binder->has_unplanned_dependent_joins; + + // for both the left and right sides, cast them to the same types + left_node = CastLogicalOperatorToTypes(left.types, result.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(right.types, result.types, std::move(right_node)); + + auto recurring_binding = right_binder->GetCTEBinding(BindingAlias("recurring", ctename)); + bool ref_recurring = recurring_binding && recurring_binding->IsReferenced(); + if (key_targets.empty() && ref_recurring) { + throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); + } + + // Check if there is a reference to the recursive or recurring table, if not create a set operator. + auto cte_binding = right_binder->GetCTEBinding(BindingAlias(ctename)); + bool ref_cte = cte_binding && cte_binding->IsReferenced(); + if (!ref_cte && !ref_recurring) { + auto root = + make_uniq(setop_index, result.types.size(), std::move(left_node), + std::move(right_node), LogicalOperatorType::LOGICAL_UNION, union_all); + result.plan = std::move(root); + } else { + auto root = make_uniq(ctename, setop_index, result.types.size(), union_all, + std::move(key_targets), std::move(left_node), std::move(right_node)); + root->ref_recurring = ref_recurring; + result.plan = std::move(root); + } + return result; } } // namespace duckdb diff --git a/src/planner/binder/query_node/bind_select_node.cpp b/src/planner/binder/query_node/bind_select_node.cpp index a02f878a9aa9..7bfc4b22ebf3 100644 --- a/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/planner/binder/query_node/bind_select_node.cpp @@ -372,15 +372,6 @@ BoundStatement Binder::BindNode(SelectNode &statement) { return BindSelectNode(statement, std::move(from_table)); } -unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement) { - D_ASSERT(statement.from_table); - - // first bind the FROM table statement - auto from = std::move(statement.from_table); - auto from_table = Bind(*from); - return BindSelectNodeInternal(statement, std::move(from_table)); -} - void Binder::BindWhereStarExpression(unique_ptr &expr) { // expand any expressions in the upper AND recursively if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { @@ -412,7 +403,7 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } -unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement, BoundStatement from_table) { +BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { D_ASSERT(from_table.plan); D_ASSERT(!statement.from_table); auto result_ptr = make_uniq(); @@ -688,16 +679,12 @@ unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions BindModifiers(result, result.projection_index, result.names, internal_sql_types, bind_state); - return result_ptr; -} - -BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { - auto result = BindSelectNodeInternal(statement, std::move(from_table)); BoundStatement result_statement; - result_statement.types = result->types; - result_statement.names = result->names; - result_statement.plan = CreatePlan(*result); + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.original_expressions = std::move(result.bind_state.original_expressions); return result_statement; } diff --git a/src/planner/binder/query_node/bind_setop_node.cpp b/src/planner/binder/query_node/bind_setop_node.cpp index d70f6d2ccffc..91a501b2fb40 100644 --- a/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/planner/binder/query_node/bind_setop_node.cpp @@ -10,42 +10,32 @@ #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/common/enum_util.hpp" namespace duckdb { -BoundSetOperationNode::~BoundSetOperationNode() { -} - struct SetOpAliasGatherer { public: explicit SetOpAliasGatherer(SelectBindState &bind_state_p) : bind_state(bind_state_p) { } - void GatherAliases(BoundSetOpChild &node, const vector &reorder_idx); - void GatherAliases(BoundSetOperationNode &node, const vector &reorder_idx); + void GatherAliases(BoundStatement &stmt, const vector &reorder_idx); + void GatherSetOpAliases(SetOperationType setop_type, const vector &names, + vector &bound_children, const vector &reorder_idx); private: SelectBindState &bind_state; }; -const vector &BoundSetOpChild::GetNames() { - return bound_node ? bound_node->names : node.names; -} -const vector &BoundSetOpChild::GetTypes() { - return bound_node ? bound_node->types : node.types; -} -idx_t BoundSetOpChild::GetRootIndex() { - return bound_node ? bound_node->GetRootIndex() : node.plan->GetRootIndex(); -} -void SetOpAliasGatherer::GatherAliases(BoundSetOpChild &node, const vector &reorder_idx) { - if (node.bound_node) { - GatherAliases(*node.bound_node, reorder_idx); +void SetOpAliasGatherer::GatherAliases(BoundStatement &stmt, const vector &reorder_idx) { + if (stmt.extra_info.setop_type != SetOperationType::NONE) { + GatherSetOpAliases(stmt.extra_info.setop_type, stmt.names, stmt.extra_info.bound_children, reorder_idx); return; } // query node - auto &select_names = node.GetNames(); + auto &select_names = stmt.names; // fill the alias lists with the names D_ASSERT(reorder_idx.size() == select_names.size()); for (idx_t i = 0; i < select_names.size(); i++) { @@ -61,8 +51,9 @@ void SetOpAliasGatherer::GatherAliases(BoundSetOpChild &node, const vector &reorder_idx) { +void SetOpAliasGatherer::GatherSetOpAliases(SetOperationType setop_type, const vector &stmt_names, + vector &bound_children, const vector &reorder_idx) { // create new reorder index - if (setop.setop_type == SetOperationType::UNION_BY_NAME) { + if (setop_type == SetOperationType::UNION_BY_NAME) { + auto &setop_names = stmt_names; // for UNION BY NAME - create a new re-order index case_insensitive_map_t reorder_map; - for (idx_t col_idx = 0; col_idx < setop.names.size(); ++col_idx) { - reorder_map[setop.names[col_idx]] = reorder_idx[col_idx]; + for (idx_t col_idx = 0; col_idx < setop_names.size(); ++col_idx) { + reorder_map[setop_names[col_idx]] = reorder_idx[col_idx]; } // use new reorder index - for (auto &child : setop.bound_children) { + for (auto &child : bound_children) { vector new_reorder_idx; - auto &child_names = child.GetNames(); + auto &child_names = child.names; for (idx_t col_idx = 0; col_idx < child_names.size(); col_idx++) { auto &col_name = child_names[col_idx]; auto entry = reorder_map.find(col_name); @@ -103,22 +96,23 @@ void SetOpAliasGatherer::GatherAliases(BoundSetOperationNode &setop, const vecto GatherAliases(child, new_reorder_idx); } } else { - for (auto &child : setop.bound_children) { + for (auto &child : bound_children) { GatherAliases(child, reorder_idx); } } } -static void GatherAliases(BoundSetOperationNode &node, SelectBindState &bind_state) { +static void GatherAliases(BoundSetOperationNode &root, vector &child_statements, + SelectBindState &bind_state) { SetOpAliasGatherer gatherer(bind_state); vector reorder_idx; - for (idx_t i = 0; i < node.names.size(); i++) { + for (idx_t i = 0; i < root.names.size(); i++) { reorder_idx.push_back(i); } - gatherer.GatherAliases(node, reorder_idx); + gatherer.GatherSetOpAliases(root.setop_type, root.names, child_statements, reorder_idx); } -static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode &result, bool can_contain_nulls) { +void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); vector> node_name_maps; case_insensitive_set_t global_name_set; @@ -127,7 +121,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & // We throw a binder exception if two same name in the SELECT list D_ASSERT(result.names.empty()); for (auto &child : result.bound_children) { - auto &child_names = child.GetNames(); + auto &child_names = child.names; case_insensitive_map_t node_name_map; for (idx_t i = 0; i < child_names.size(); ++i) { auto &col_name = child_names[i]; @@ -155,7 +149,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & auto &col_name = result.names[i]; LogicalType result_type(LogicalTypeId::INVALID); for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { - auto &child_types = result.bound_children[child_idx].GetTypes(); + auto &child_types = result.bound_children[child_idx].types; auto &child_name_map = node_name_maps[child_idx]; // check if the column exists in this child node auto entry = child_name_map.find(col_name); @@ -191,6 +185,8 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & return; } // If reorder is required, generate the expressions for each node + vector>> reorder_expressions; + reorder_expressions.resize(result.bound_children.size()); for (idx_t i = 0; i < new_size; ++i) { auto &col_name = result.names[i]; for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { @@ -205,52 +201,42 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & } else { // the column exists - reference it auto col_idx_in_child = entry->second; - auto &child_col_type = child.GetTypes()[col_idx_in_child]; - expr = make_uniq(child_col_type, - ColumnBinding(child.GetRootIndex(), col_idx_in_child)); + auto &child_col_type = child.types[col_idx_in_child]; + auto root_idx = child.plan->GetRootIndex(); + expr = make_uniq(child_col_type, ColumnBinding(root_idx, col_idx_in_child)); } - child.reorder_expressions.push_back(std::move(expr)); + reorder_expressions[child_idx].push_back(std::move(expr)); } } -} - -BoundSetOpChild Binder::BindSetOpChild(QueryNode &child) { - BoundSetOpChild bound_child; - if (child.type == QueryNodeType::SET_OPERATION_NODE) { - bound_child.bound_node = BindSetOpNode(child.Cast()); - } else { - bound_child.binder = Binder::CreateBinder(context, this); - bound_child.binder->can_contain_nulls = true; - if (child.type == QueryNodeType::SELECT_NODE) { - auto &select_node = child.Cast(); - auto bound_select_node = bound_child.binder->BindSelectNodeInternal(select_node); - for (auto &expr : bound_select_node->bind_state.original_expressions) { - bound_child.select_list.push_back(expr->Copy()); - } - bound_child.node.names = bound_select_node->names; - bound_child.node.types = bound_select_node->types; - bound_child.node.plan = bound_child.binder->CreatePlan(*bound_select_node); - } else { - bound_child.node = bound_child.binder->BindNode(child); + // now push projections for each node + for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { + auto &child = result.bound_children[child_idx]; + auto &child_reorder_expressions = reorder_expressions[child_idx]; + // if we have re-order expressions push a projection + vector child_types; + for (auto &expr : child_reorder_expressions) { + child_types.push_back(expr->return_type); } + auto child_projection = + make_uniq(GenerateTableIndex(), std::move(child_reorder_expressions)); + child_projection->children.push_back(std::move(child.plan)); + child.plan = std::move(child_projection); + child.types = std::move(child_types); } - return bound_child; } -static void GatherSetOpBinders(BoundSetOpChild &setop_child, vector> &binders) { - if (setop_child.binder) { - binders.push_back(*setop_child.binder); - return; +static void GatherSetOpBinders(vector &children, vector> &binders, + vector> &result) { + for (auto &child_binder : binders) { + result.push_back(*child_binder); } - auto &setop_node = *setop_child.bound_node; - for (auto &child : setop_node.bound_children) { - GatherSetOpBinders(child, binders); + for (auto &child_node : children) { + GatherSetOpBinders(child_node.extra_info.bound_children, child_node.extra_info.child_binders, result); } } -unique_ptr Binder::BindSetOpNode(SetOperationNode &statement) { - auto result_ptr = make_uniq(); - auto &result = *result_ptr; +BoundStatement Binder::BindNode(SetOperationNode &statement) { + BoundSetOperationNode result; result.setop_type = statement.setop_type; result.setop_all = statement.setop_all; @@ -265,27 +251,23 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem throw InternalException("Set Operation type must have exactly 2 children - except for UNION/UNION_BY_NAME"); } for (auto &child : statement.children) { - result.bound_children.push_back(BindSetOpChild(*child)); - } - - vector> binders; - for (auto &child : result.bound_children) { - GatherSetOpBinders(child, binders); - } - // move the correlated expressions from the child binders to this binder - for (auto &child_binder : binders) { - MoveCorrelatedExpressions(child_binder.get()); + auto child_binder = Binder::CreateBinder(context, this); + child_binder->can_contain_nulls = true; + auto child_node = child_binder->BindNode(*child); + MoveCorrelatedExpressions(*child_binder); + result.bound_children.push_back(std::move(child_node)); + result.child_binders.push_back(std::move(child_binder)); } if (result.setop_type == SetOperationType::UNION_BY_NAME) { // UNION BY NAME - merge the columns from all sides - BuildUnionByNameInfo(context, result, can_contain_nulls); + BuildUnionByNameInfo(result); } else { // UNION ALL BY POSITION - the columns of both sides must match exactly - result.names = result.bound_children[0].GetNames(); - auto result_columns = result.bound_children[0].GetTypes().size(); + result.names = result.bound_children[0].names; + auto result_columns = result.bound_children[0].types.size(); for (idx_t i = 1; i < result.bound_children.size(); ++i) { - if (result.bound_children[i].GetTypes().size() != result_columns) { + if (result.bound_children[i].types.size() != result_columns) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -293,9 +275,9 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem // figure out the types of the setop result by picking the max of both for (idx_t i = 0; i < result_columns; i++) { - auto result_type = result.bound_children[0].GetTypes()[i]; + auto result_type = result.bound_children[0].types[i]; for (idx_t child_idx = 1; child_idx < result.bound_children.size(); ++child_idx) { - auto &child_types = result.bound_children[child_idx].GetTypes(); + auto &child_types = result.bound_children[child_idx].types; result_type = LogicalType::ForceMaxLogicalType(result_type, child_types[i]); } if (!can_contain_nulls) { @@ -310,7 +292,9 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem SelectBindState bind_state; if (!statement.modifiers.empty()) { // handle the ORDER BY/DISTINCT clauses - GatherAliases(result, bind_state); + vector> binders; + GatherSetOpBinders(result.bound_children, result.child_binders, binders); + GatherAliases(result, result.bound_children, bind_state); // now we perform the actual resolution of the ORDER BY/DISTINCT expressions OrderBinder order_binder(binders, bind_state); @@ -319,16 +303,14 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem // finally bind the types of the ORDER/DISTINCT clause expressions BindModifiers(result, result.setop_index, result.names, result.types, bind_state); - return result_ptr; -} - -BoundStatement Binder::BindNode(SetOperationNode &statement) { - auto result = BindSetOpNode(statement); BoundStatement result_statement; - result_statement.types = result->types; - result_statement.names = result->names; - result_statement.plan = CreatePlan(*result); + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.setop_type = statement.setop_type; + result_statement.extra_info.bound_children = std::move(result.bound_children); + result_statement.extra_info.child_binders = std::move(result.child_binders); return result_statement; } diff --git a/src/planner/binder/query_node/plan_cte_node.cpp b/src/planner/binder/query_node/plan_cte_node.cpp deleted file mode 100644 index dc4cc8770659..000000000000 --- a/src/planner/binder/query_node/plan_cte_node.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "duckdb/common/string_util.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/operator/logical_materialized_cte.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTENode &node) { - // Generate the logical plan for the cte_query and child. - auto cte_query = std::move(node.query.plan); - auto cte_child = std::move(node.child.plan); - - auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), - std::move(cte_query), std::move(cte_child), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/planner/binder/query_node/plan_recursive_cte_node.cpp deleted file mode 100644 index f51a03c50367..000000000000 --- a/src/planner/binder/query_node/plan_recursive_cte_node.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_recursive_cte.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { - // Generate the logical plan for the left and right sides of the set operation - node.left_binder->is_outside_flattened = is_outside_flattened; - node.right_binder->is_outside_flattened = is_outside_flattened; - - auto left_node = std::move(node.left.plan); - auto right_node = std::move(node.right.plan); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.left_binder->has_unplanned_dependent_joins || - node.right_binder->has_unplanned_dependent_joins; - - // for both the left and right sides, cast them to the same types - left_node = CastLogicalOperatorToTypes(node.left.types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(node.right.types, node.types, std::move(right_node)); - - auto recurring_binding = node.right_binder->GetCTEBinding("recurring." + node.ctename); - bool ref_recurring = recurring_binding && recurring_binding->Cast().reference_count > 0; - if (node.key_targets.empty() && ref_recurring) { - throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); - } - - // Check if there is a reference to the recursive or recurring table, if not create a set operator. - auto cte_binding = node.right_binder->GetCTEBinding(node.ctename); - bool ref_cte = cte_binding && cte_binding->Cast().reference_count > 0; - if (!ref_cte && !ref_recurring) { - auto root = - make_uniq(node.setop_index, node.types.size(), std::move(left_node), - std::move(right_node), LogicalOperatorType::LOGICAL_UNION, node.union_all); - return VisitQueryNode(node, std::move(root)); - } - - auto root = - make_uniq(node.ctename, node.setop_index, node.types.size(), node.union_all, - std::move(node.key_targets), std::move(left_node), std::move(right_node)); - root->ref_recurring = ref_recurring; - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/planner/binder/query_node/plan_setop.cpp b/src/planner/binder/query_node/plan_setop.cpp index fec93aa51e8e..a1a7f60b0653 100644 --- a/src/planner/binder/query_node/plan_setop.cpp +++ b/src/planner/binder/query_node/plan_setop.cpp @@ -113,34 +113,16 @@ unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { D_ASSERT(node.bound_children.size() >= 2); vector> children; - for (auto &child : node.bound_children) { - unique_ptr child_node; - if (child.bound_node) { - child_node = CreatePlan(*child.bound_node); - } else { - child.binder->is_outside_flattened = is_outside_flattened; + for (idx_t child_idx = 0; child_idx < node.bound_children.size(); child_idx++) { + auto &child = node.bound_children[child_idx]; + auto &child_binder = *node.child_binders[child_idx]; - // construct the logical plan for the child node - child_node = std::move(child.node.plan); - } - if (!child.reorder_expressions.empty()) { - // if we have re-order expressions push a projection - vector child_types; - for (auto &expr : child.reorder_expressions) { - child_types.push_back(expr->return_type); - } - auto child_projection = - make_uniq(GenerateTableIndex(), std::move(child.reorder_expressions)); - child_projection->children.push_back(std::move(child_node)); - child_node = std::move(child_projection); - - child_node = CastLogicalOperatorToTypes(child_types, node.types, std::move(child_node)); - } else { - // otherwise push only casts - child_node = CastLogicalOperatorToTypes(child.GetTypes(), node.types, std::move(child_node)); - } + // construct the logical plan for the child node + auto child_node = std::move(child.plan); + // push casts for the target types + child_node = CastLogicalOperatorToTypes(child.types, node.types, std::move(child_node)); // check if there are any unplanned subqueries left in any child - if (child.binder && child.binder->has_unplanned_dependent_joins) { + if (child_binder.has_unplanned_dependent_joins) { has_unplanned_dependent_joins = true; } children.push_back(std::move(child_node)); diff --git a/src/planner/binder/statement/bind_copy.cpp b/src/planner/binder/statement/bind_copy.cpp index b7881a0a1d67..1757a2110a0b 100644 --- a/src/planner/binder/statement/bind_copy.cpp +++ b/src/planner/binder/statement/bind_copy.cpp @@ -36,7 +36,7 @@ void IsFormatExtensionKnown(const string &format) { // It's a match, we must throw throw CatalogException( "Copy Function with name \"%s\" is not in the catalog, but it exists in the %s extension.", format, - file_postfixes.extension); + std::string(file_postfixes.extension)); } } } @@ -551,8 +551,8 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { // check if this matches the mode if (copy_option.mode != CopyOptionMode::READ_WRITE && copy_option.mode != copy_mode) { throw InvalidInputException("Option \"%s\" is not supported for %s - only for %s", provided_option, - stmt.info->is_from ? "reading" : "writing", - stmt.info->is_from ? "writing" : "reading"); + std::string(stmt.info->is_from ? "reading" : "writing"), + std::string(stmt.info->is_from ? "writing" : "reading")); } if (copy_option.type.id() != LogicalTypeId::ANY) { if (provided_entry.second.empty()) { diff --git a/src/planner/binder/statement/bind_create.cpp b/src/planner/binder/statement/bind_create.cpp index c63e9bf107ce..91a03aadb49a 100644 --- a/src/planner/binder/statement/bind_create.cpp +++ b/src/planner/binder/statement/bind_create.cpp @@ -119,11 +119,11 @@ void Binder::SearchSchema(CreateInfo &info) { if (!info.temporary) { // non-temporary create: not read only if (info.catalog == TEMP_CATALOG) { - throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", TEMP_CATALOG); + throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", std::string(TEMP_CATALOG)); } } else { if (info.catalog != TEMP_CATALOG) { - throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", TEMP_CATALOG); + throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", std::string(TEMP_CATALOG)); } } } diff --git a/src/planner/binder/statement/bind_create_table.cpp b/src/planner/binder/statement/bind_create_table.cpp index ad70fe14abc0..f27ee73c6f77 100644 --- a/src/planner/binder/statement/bind_create_table.cpp +++ b/src/planner/binder/statement/bind_create_table.cpp @@ -289,7 +289,7 @@ void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { col.SetType(bound_expression->return_type); // Update the type in the binding, for future expansions - table_binding->types[i.index] = col.Type(); + table_binding->SetColumnType(i.index, col.Type()); } bound_indices.insert(i); } diff --git a/src/planner/binder/statement/bind_merge_into.cpp b/src/planner/binder/statement/bind_merge_into.cpp index 1dd59c480039..62867f3525fb 100644 --- a/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/planner/binder/statement/bind_merge_into.cpp @@ -200,9 +200,10 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { vector source_names; for (auto &binding_entry : source_binder->bind_context.GetBindingsList()) { auto &binding = *binding_entry; - for (idx_t c = 0; c < binding.names.size(); c++) { - source_aliases.push_back(binding.alias); - source_names.push_back(binding.names[c]); + auto &column_names = binding.GetColumnNames(); + for (idx_t c = 0; c < column_names.size(); c++) { + source_aliases.push_back(binding.GetBindingAlias()); + source_names.push_back(column_names[c]); } } diff --git a/src/planner/binder/statement/bind_pragma.cpp b/src/planner/binder/statement/bind_pragma.cpp index 3955cf89753e..b5fc04677b1c 100644 --- a/src/planner/binder/statement/bind_pragma.cpp +++ b/src/planner/binder/statement/bind_pragma.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/statement/pragma_statement.hpp" #include "duckdb/planner/operator/logical_pragma.hpp" #include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" @@ -28,16 +29,32 @@ unique_ptr Binder::BindPragma(PragmaInfo &info, QueryErrorConte } // bind the pragma function - auto &entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name); + auto entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::RETURN_NULL); + if (!entry) { + // try to find whether a table extry might exist + auto table_entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, + info.name, OnEntryNotFound::RETURN_NULL); + if (table_entry) { + // there is a table entry with the same name, now throw more explicit error message + throw CatalogException("Pragma Function with name %s does not exist, but a table function with the same " + "name exists, try `CALL %s(...)`", + info.name, info.name); + } + // rebind to throw exception + entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::THROW_EXCEPTION); + } + FunctionBinder function_binder(*this); ErrorData error; - auto bound_idx = function_binder.BindFunction(entry.name, entry.functions, params, error); + auto bound_idx = function_binder.BindFunction(entry->name, entry->functions, params, error); if (!bound_idx.IsValid()) { D_ASSERT(error.HasError()); error.AddQueryLocation(error_context); error.Throw(); } - auto bound_function = entry.functions.GetFunctionByOffset(bound_idx.GetIndex()); + auto bound_function = entry->functions.GetFunctionByOffset(bound_idx.GetIndex()); // bind and check named params BindNamedParameters(bound_function.named_parameters, named_parameters, error_context, bound_function.name); return make_uniq(std::move(bound_function), std::move(params), std::move(named_parameters)); diff --git a/src/planner/binder/tableref/bind_basetableref.cpp b/src/planner/binder/tableref/bind_basetableref.cpp index de775a1980bb..25d0b7b2a082 100644 --- a/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/planner/binder/tableref/bind_basetableref.cpp @@ -121,53 +121,29 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { // CTE name should never be qualified (i.e. schema_name should be empty) // unless we want to refer to the recurring table of "using key". - auto ctebinding = GetCTEBinding(ref.table_name); - if (ctebinding) { + BindingAlias binding_alias(ref.schema_name, ref.table_name); + auto ctebinding = GetCTEBinding(binding_alias); + if (ctebinding && ctebinding->CanBeReferenced()) { + ctebinding->Reference(); + // There is a CTE binding in the BindContext. // This can only be the case if there is a recursive CTE, // or a materialized CTE present. auto index = GenerateTableIndex(); - if (ref.schema_name == "recurring") { - auto recurring_bindings = GetCTEBinding("recurring." + ref.table_name); - if (!recurring_bindings) { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query." - " Hint: RECURRING can only be used with USING KEY in recursive CTE.", - ref.table_name); - } - } - auto alias = ref.alias.empty() ? ref.table_name : ref.alias; - auto names = BindContext::AliasColumnNames(alias, ctebinding->names, ref.column_name_alias); - - bind_context.AddGenericBinding(index, alias, names, ctebinding->types); - - auto cte_ref = reference(ctebinding->Cast()); - if (!ref.schema_name.empty()) { - auto cte_reference = ref.schema_name + "." + ref.table_name; - auto recurring_ref = GetCTEBinding(cte_reference); - if (!recurring_ref) { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query.", - ref.table_name); - } - cte_ref = reference(recurring_ref->Cast()); - } + auto names = BindContext::AliasColumnNames(alias, ctebinding->GetColumnNames(), ref.column_name_alias); + + bind_context.AddGenericBinding(index, alias, names, ctebinding->GetColumnTypes()); - // Update references to CTE - cte_ref.get().reference_count++; bool is_recurring = ref.schema_name == "recurring"; BoundStatement result; - result.types = ctebinding->types; + result.types = ctebinding->GetColumnTypes(); result.names = names; result.plan = - make_uniq(index, ctebinding->index, ctebinding->types, std::move(names), is_recurring); + make_uniq(index, ctebinding->GetIndex(), result.types, std::move(names), is_recurring); return result; - ; } // not a CTE @@ -231,17 +207,13 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { } } - // remember that we did not find a CTE, but there is a CTE with the same name - // this means that there is a circular reference - // Otherwise, re-throw the original exception - if (!ctebinding && ref.schema_name.empty() && CTEExists(ref.table_name)) { - throw BinderException( - error_context, - "Circular reference to CTE \"%s\", There are two possible solutions. \n1. use WITH RECURSIVE to " - "use recursive CTEs. \n2. If " - "you want to use the TABLE name \"%s\" the same as the CTE name, please explicitly add " - "\"SCHEMA\" before table name. You can try \"main.%s\" (main is the duckdb default schema)", - ref.table_name, ref.table_name, ref.table_name); + // if we found a CTE that cannot be referenced that means that there is a circular reference + if (ctebinding) { + D_ASSERT(!ctebinding->CanBeReferenced()); + throw BinderException(error_context, + "Circular reference to CTE \"%s\", use WITH RECURSIVE to " + "use recursive CTEs.", + ref.table_name); } // could not find an alternative: bind again to get the error // note: this will always throw when using DuckDB as a catalog, but a second look-up might succeed @@ -312,29 +284,6 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { // The view may contain CTEs, but maybe only in the cte_map, so we need create CTE nodes for them auto query = view_catalog_entry.GetQuery().Copy(); - auto &select_stmt = query->Cast(); - - vector> materialized_ctes; - for (auto &cte : select_stmt.node->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt.node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - select_stmt.node = std::move(root); - SubqueryRef subquery(unique_ptr_cast(std::move(query))); subquery.alias = ref.alias; diff --git a/src/planner/binder/tableref/bind_joinref.cpp b/src/planner/binder/tableref/bind_joinref.cpp index 258bd3331117..0a6420bfddef 100644 --- a/src/planner/binder/tableref/bind_joinref.cpp +++ b/src/planner/binder/tableref/bind_joinref.cpp @@ -55,7 +55,7 @@ bool Binder::TryFindBinding(const string &using_column, const string &join_side, } throw BinderException(error); } else { - result = binding.get().alias; + result = binding.get().GetBindingAlias(); } } return true; @@ -188,7 +188,7 @@ BoundStatement Binder::Bind(JoinRef &ref) { case_insensitive_set_t lhs_columns; auto &lhs_binding_list = left_binder.bind_context.GetBindingsList(); for (auto &binding : lhs_binding_list) { - for (auto &column_name : binding->names) { + for (auto &column_name : binding->GetColumnNames()) { lhs_columns.insert(column_name); } } @@ -215,7 +215,7 @@ BoundStatement Binder::Bind(JoinRef &ref) { auto &rhs_binding_list = right_binder.bind_context.GetBindingsList(); for (auto &binding_ref : lhs_binding_list) { auto &binding = *binding_ref; - for (auto &column_name : binding.names) { + for (auto &column_name : binding.GetColumnNames()) { if (!left_candidates.empty()) { left_candidates += ", "; } @@ -224,7 +224,7 @@ BoundStatement Binder::Bind(JoinRef &ref) { } for (auto &binding_ref : rhs_binding_list) { auto &binding = *binding_ref; - for (auto &column_name : binding.names) { + for (auto &column_name : binding.GetColumnNames()) { if (!right_candidates.empty()) { right_candidates += ", "; } diff --git a/src/planner/collation_binding.cpp b/src/planner/collation_binding.cpp index 1ddefb9a894f..dd371bbc4284 100644 --- a/src/planner/collation_binding.cpp +++ b/src/planner/collation_binding.cpp @@ -8,6 +8,7 @@ #include "duckdb/function/function_binder.hpp" namespace duckdb { +constexpr const char *CollateCatalogEntry::Name; bool PushVarcharCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, CollationType type) { @@ -109,11 +110,34 @@ bool PushIntervalCollation(ClientContext &context, unique_ptr &sourc return true; } +bool PushVariantCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, + CollationType) { + if (sql_type.id() != LogicalTypeId::VARIANT) { + return false; + } + auto &catalog = Catalog::GetSystemCatalog(context); + auto &function_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "variant_normalize"); + if (function_entry.functions.Size() != 1) { + throw InternalException("variant_normalize should only have a single overload"); + } + auto source_alias = source->GetAlias(); + auto &scalar_function = function_entry.functions.GetFunctionReferenceByOffset(0); + vector> children; + children.push_back(std::move(source)); + + FunctionBinder function_binder(context); + auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); + function->SetAlias(source_alias); + source = std::move(function); + return true; +} + // timetz_byte_comparable CollationBinding::CollationBinding() { RegisterCollation(CollationCallback(PushVarcharCollation)); RegisterCollation(CollationCallback(PushTimeTZCollation)); RegisterCollation(CollationCallback(PushIntervalCollation)); + RegisterCollation(CollationCallback(PushVariantCollation)); } void CollationBinding::RegisterCollation(CollationCallback callback) { diff --git a/src/planner/expression_iterator.cpp b/src/planner/expression_iterator.cpp index 9f67f915c0c4..3d14079001de 100644 --- a/src/planner/expression_iterator.cpp +++ b/src/planner/expression_iterator.cpp @@ -4,8 +4,6 @@ #include "duckdb/planner/expression/list.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/tableref/list.hpp" #include "duckdb/common/enum_util.hpp" diff --git a/src/planner/table_binding.cpp b/src/planner/table_binding.cpp index c60513cc94a3..c55d0be822b4 100644 --- a/src/planner/table_binding.cpp +++ b/src/planner/table_binding.cpp @@ -19,6 +19,10 @@ Binding::Binding(BindingType binding_type, BindingAlias alias_p, vector &Binding::GetColumnTypes() { + return types; +} + +const vector &Binding::GetColumnNames() { + return names; +} + +idx_t Binding::GetColumnCount() { + return GetColumnNames().size(); +} + +void Binding::SetColumnType(idx_t col_idx, LogicalType type_p) { + types[col_idx] = std::move(type_p); +} + string Binding::GetAlias() const { return alias.GetAlias(); } @@ -304,8 +336,42 @@ unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colre return arg; } -CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, idx_t index) - : Binding(BindingType::CTE, std::move(alias), std::move(types), std::move(names), index), reference_count(0) { +CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, idx_t index, + CTEType cte_type) + : Binding(BindingType::CTE, std::move(alias), std::move(types), std::move(names), index), cte_type(cte_type), + reference_count(0) { +} + +CTEBinding::CTEBinding(BindingAlias alias_p, shared_ptr bind_state_p, idx_t index) + : Binding(BindingType::CTE, std::move(alias_p), vector(), vector(), index), + cte_type(CTEType::CAN_BE_REFERENCED), reference_count(0), bind_state(std::move(bind_state_p)) { +} + +bool CTEBinding::CanBeReferenced() const { + return cte_type == CTEType::CAN_BE_REFERENCED; +} + +bool CTEBinding::IsReferenced() const { + return reference_count > 0; +} + +void CTEBinding::Reference() { + if (!CanBeReferenced()) { + throw InternalException("CTE cannot be referenced!"); + } + if (bind_state) { + // we have not bound the CTE yet - bind it + bind_state->Bind(*this); + + // copy over the names / types and initialize the binding + this->names = bind_state->names; + this->types = bind_state->types; + Initialize(); + + // finalize binding + bind_state.reset(); + } + reference_count++; } } // namespace duckdb diff --git a/src/storage/compression/bitpacking.cpp b/src/storage/compression/bitpacking.cpp index ae3550c5eb7b..f3188662ccfe 100644 --- a/src/storage/compression/bitpacking.cpp +++ b/src/storage/compression/bitpacking.cpp @@ -19,6 +19,7 @@ namespace duckdb { +constexpr const idx_t BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; static constexpr const idx_t BITPACKING_METADATA_GROUP_SIZE = STANDARD_VECTOR_SIZE > 512 ? STANDARD_VECTOR_SIZE : 2048; BitpackingMode BitpackingModeFromString(const string &str) { diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index 75f8dd694bc3..7cc6bf3e4dca 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -253,7 +253,7 @@ void DataTable::InitializeScanWithOffset(DuckTransaction &transaction, TableScan const vector &column_ids, idx_t start_row, idx_t end_row) { state.checkpoint_lock = transaction.SharedLockTable(*info); state.Initialize(column_ids); - row_groups->InitializeScanWithOffset(transaction.context, state.table_state, column_ids, start_row, end_row); + row_groups->InitializeScanWithOffset(QueryContext(), state.table_state, column_ids, start_row, end_row); } idx_t DataTable::GetRowGroupSize() const { @@ -1544,7 +1544,7 @@ void DataTable::Update(TableUpdateState &state, ClientContext &context, Vector & row_ids_slice.Slice(row_ids, sel_global_update, n_global_update); row_ids_slice.Flatten(n_global_update); - row_groups->Update(transaction, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); + row_groups->Update(transaction, *this, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); } } @@ -1568,7 +1568,7 @@ void DataTable::UpdateColumn(TableCatalogEntry &table, ClientContext &context, V updates.Flatten(); row_ids.Flatten(updates.size()); - row_groups->UpdateColumn(transaction, row_ids, column_path, updates); + row_groups->UpdateColumn(transaction, *this, row_ids, column_path, updates); } //===--------------------------------------------------------------------===// diff --git a/src/storage/local_storage.cpp b/src/storage/local_storage.cpp index 4c58c2d5d179..39612b09b85c 100644 --- a/src/storage/local_storage.cpp +++ b/src/storage/local_storage.cpp @@ -580,7 +580,7 @@ void LocalStorage::Update(DataTable &table, Vector &row_ids, const vector(row_ids); - storage->GetCollection().Update(TransactionData(0, 0), ids, column_ids, updates); + storage->GetCollection().Update(TransactionData(0, 0), table, ids, column_ids, updates); } void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ptr commit_state) { diff --git a/src/storage/serialization/serialize_query_node.cpp b/src/storage/serialization/serialize_query_node.cpp index 50ab535d252a..25b167558c57 100644 --- a/src/storage/serialization/serialize_query_node.cpp +++ b/src/storage/serialization/serialize_query_node.cpp @@ -38,6 +38,9 @@ unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { } result->modifiers = std::move(modifiers); result->cte_map = std::move(cte_map); + if (type == QueryNodeType::CTE_NODE) { + result = std::move(result->Cast().child); + } return result; } diff --git a/src/storage/standard_buffer_manager.cpp b/src/storage/standard_buffer_manager.cpp index 73705c43afa5..04260029b04d 100644 --- a/src/storage/standard_buffer_manager.cpp +++ b/src/storage/standard_buffer_manager.cpp @@ -642,6 +642,10 @@ bool StandardBufferManager::HasFilesInTemporaryDirectory() const { return found; } +BlockManager &StandardBufferManager::GetTemporaryBlockManager() { + return *temp_block_manager; +} + vector StandardBufferManager::GetTemporaryFiles() { vector result; if (temporary_directory.path.empty()) { diff --git a/src/storage/storage_info.cpp b/src/storage/storage_info.cpp index 616aa3039672..df847fc9d58e 100644 --- a/src/storage/storage_info.cpp +++ b/src/storage/storage_info.cpp @@ -4,6 +4,10 @@ #include "duckdb/common/optional_idx.hpp" namespace duckdb { +constexpr idx_t Storage::MAX_ROW_GROUP_SIZE; +constexpr idx_t Storage::MAX_BLOCK_ALLOC_SIZE; +constexpr idx_t Storage::MIN_BLOCK_ALLOC_SIZE; +constexpr idx_t Storage::DEFAULT_BLOCK_HEADER_SIZE; const uint64_t VERSION_NUMBER = 64; const uint64_t VERSION_NUMBER_LOWER = 64; diff --git a/src/storage/storage_manager.cpp b/src/storage/storage_manager.cpp index d82905452ad9..ebf0472e05b5 100644 --- a/src/storage/storage_manager.cpp +++ b/src/storage/storage_manager.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" #include "duckdb/main/database.hpp" #include "duckdb/storage/checkpoint_manager.hpp" #include "duckdb/storage/in_memory_block_manager.hpp" @@ -322,13 +323,40 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { } } - // load the db from storage + // Start timing the storage load step. + auto client_context = context.GetClientContext(); + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->StartTimer(MetricsType::ATTACH_LOAD_STORAGE_LATENCY); + } + + // Load the checkpoint from storage. auto checkpoint_reader = SingleFileCheckpointReader(*this); checkpoint_reader.LoadFromStorage(); + // End timing the storage load step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->EndTimer(MetricsType::ATTACH_LOAD_STORAGE_LATENCY); + } + + // Start timing the WAL replay step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->StartTimer(MetricsType::ATTACH_REPLAY_WAL_LATENCY); + } + + // Replay the WAL. auto wal_path = GetWALPath(); wal = WriteAheadLog::Replay(fs, db, wal_path); + + // End timing the WAL replay step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->EndTimer(MetricsType::ATTACH_REPLAY_WAL_LATENCY); + } } + if (row_group_size > 122880ULL && GetStorageVersion() < 4) { throw InvalidInputException("Unsupported row group size %llu - row group sizes >= 122_880 are only supported " "with STORAGE_VERSION '1.2.0' or above.\nExplicitly specify a newer storage " @@ -476,17 +504,35 @@ void SingleFileStorageManager::CreateCheckpoint(QueryContext context, Checkpoint if (db.GetStorageExtension()) { db.GetStorageExtension()->OnCheckpointStart(db, options); } + auto &config = DBConfig::Get(db); + // We only need to checkpoint if there is anything in the WAL. if (GetWALSize() > 0 || config.options.force_checkpoint || options.action == CheckpointAction::ALWAYS_CHECKPOINT) { - // we only need to checkpoint if there is anything in the WAL try { + + // Start timing the checkpoint. + auto client_context = context.GetClientContext(); + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->StartTimer(MetricsType::CHECKPOINT_LATENCY); + } + + // Write the checkpoint. auto checkpointer = CreateCheckpointWriter(context, options); checkpointer->CreateCheckpoint(); + + // End timing the checkpoint. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->EndTimer(MetricsType::CHECKPOINT_LATENCY); + } + } catch (std::exception &ex) { ErrorData error(ex); throw FatalException("Failed to create checkpoint because of error: %s", error.RawMessage()); } } + if (!InMemory() && options.wal_action == CheckpointWALAction::DELETE_WAL) { ResetWAL(); } diff --git a/src/storage/table/array_column_data.cpp b/src/storage/table/array_column_data.cpp index d92562a9406c..849e1dec86d7 100644 --- a/src/storage/table/array_column_data.cpp +++ b/src/storage/table/array_column_data.cpp @@ -224,13 +224,14 @@ idx_t ArrayColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resul throw NotImplementedException("Array Fetch"); } -void ArrayColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ArrayColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw NotImplementedException("Array Update is not supported."); } -void ArrayColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ArrayColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw NotImplementedException("Array Update Column is not supported"); } diff --git a/src/storage/table/chunk_info.cpp b/src/storage/table/chunk_info.cpp index 3b7b11d7be84..702b4beb6847 100644 --- a/src/storage/table/chunk_info.cpp +++ b/src/storage/table/chunk_info.cpp @@ -1,10 +1,12 @@ #include "duckdb/storage/table/chunk_info.hpp" + #include "duckdb/transaction/transaction.hpp" #include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/transaction/delete_info.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" namespace duckdb { @@ -40,7 +42,7 @@ void ChunkInfo::Write(WriteStream &writer) const { writer.Write(type); } -unique_ptr ChunkInfo::Read(ReadStream &reader) { +unique_ptr ChunkInfo::Read(FixedSizeAllocator &allocator, ReadStream &reader) { auto type = reader.Read(); switch (type) { case ChunkInfoType::EMPTY_INFO: @@ -48,7 +50,7 @@ unique_ptr ChunkInfo::Read(ReadStream &reader) { case ChunkInfoType::CONSTANT_INFO: return ChunkConstantInfo::Read(reader); case ChunkInfoType::VECTOR_INFO: - return ChunkVectorInfo::Read(reader); + return ChunkVectorInfo::Read(allocator, reader); default: throw SerializationException("Could not deserialize Chunk Info Type: unrecognized type"); } @@ -71,7 +73,7 @@ idx_t ChunkConstantInfo::TemplatedGetSelVector(transaction_t start_time, transac return 0; } -idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { +idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const { return TemplatedGetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); } @@ -95,7 +97,7 @@ bool ChunkConstantInfo::HasDeletes() const { return is_deleted; } -idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) { +idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) const { return delete_id < TRANSACTION_ID_START ? max_count : 0; } @@ -128,49 +130,70 @@ unique_ptr ChunkConstantInfo::Read(ReadStream &reader) { //===--------------------------------------------------------------------===// // Vector info //===--------------------------------------------------------------------===// -ChunkVectorInfo::ChunkVectorInfo(idx_t start) - : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), insert_id(0), same_inserted_id(true), any_deleted(false) { - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - inserted[i] = 0; - deleted[i] = NOT_DELETED_ID; +ChunkVectorInfo::ChunkVectorInfo(FixedSizeAllocator &allocator_p, idx_t start, transaction_t insert_id_p) + : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), allocator(allocator_p), constant_insert_id(insert_id_p) { +} + +ChunkVectorInfo::~ChunkVectorInfo() { + if (AnyDeleted()) { + allocator.Free(deleted_data); + } + if (!HasConstantInsertionId()) { + allocator.Free(inserted_data); } } template idx_t ChunkVectorInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const { - idx_t count = 0; - if (same_inserted_id && !any_deleted) { - // all tuples have the same inserted id: and no tuples were deleted - if (OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { - return max_count; - } else { - return 0; + if (HasConstantInsertionId()) { + if (!AnyDeleted()) { + // all tuples have the same inserted id: and no tuples were deleted + if (OP::UseInsertedVersion(start_time, transaction_id, ConstantInsertId())) { + return max_count; + } else { + return 0; + } } - } else if (same_inserted_id) { - if (!OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { + if (!OP::UseInsertedVersion(start_time, transaction_id, ConstantInsertId())) { return 0; } // have to check deleted flag + idx_t count = 0; + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); for (idx_t i = 0; i < max_count; i++) { if (OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { sel_vector.set_index(count++, i); } } - } else if (!any_deleted) { + return count; + } + if (!AnyDeleted()) { // have to check inserted flag + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + idx_t count = 0; for (idx_t i = 0; i < max_count; i++) { if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i])) { sel_vector.set_index(count++, i); } } - } else { - // have to check both flags - for (idx_t i = 0; i < max_count; i++) { - if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && - OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { - sel_vector.set_index(count++, i); - } + return count; + } + + idx_t count = 0; + // have to check both flags + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + auto delete_segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = delete_segment.GetPtr(); + for (idx_t i = 0; i < max_count; i++) { + if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && + OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { + sel_vector.set_index(count++, i); } } return count; @@ -186,16 +209,76 @@ idx_t ChunkVectorInfo::GetCommittedSelVector(transaction_t min_start_id, transac return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count); } -idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { +idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const { return GetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); } bool ChunkVectorInfo::Fetch(TransactionData transaction, row_t row) { - return UseVersion(transaction, inserted[row]) && !UseVersion(transaction, deleted[row]); + transaction_t fetch_insert_id; + transaction_t fetch_deleted_id; + if (HasConstantInsertionId()) { + fetch_insert_id = ConstantInsertId(); + } else { + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + fetch_insert_id = inserted[row]; + } + if (!AnyDeleted()) { + fetch_deleted_id = NOT_DELETED_ID; + } else { + auto delete_segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = delete_segment.GetPtr(); + fetch_deleted_id = deleted[row]; + } + + return UseVersion(transaction, fetch_insert_id) && !UseVersion(transaction, fetch_deleted_id); +} + +IndexPointer ChunkVectorInfo::GetInsertedPointer() const { + if (HasConstantInsertionId()) { + throw InternalException("ChunkVectorInfo: insert id requested but insertions were not initialized"); + } + return inserted_data; +} + +IndexPointer ChunkVectorInfo::GetDeletedPointer() const { + if (!AnyDeleted()) { + throw InternalException("ChunkVectorInfo: deleted id requested but deletions were not initialized"); + } + return deleted_data; +} + +IndexPointer ChunkVectorInfo::GetInitializedInsertedPointer() { + if (HasConstantInsertionId()) { + transaction_t constant_id = ConstantInsertId(); + + inserted_data = allocator.New(); + inserted_data.SetMetadata(1); + auto segment = allocator.GetHandle(inserted_data); + auto inserted = segment.GetPtr(); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + inserted[i] = constant_id; + } + } + return inserted_data; +} + +IndexPointer ChunkVectorInfo::GetInitializedDeletedPointer() { + if (!AnyDeleted()) { + deleted_data = allocator.New(); + deleted_data.SetMetadata(1); + auto segment = allocator.GetHandle(deleted_data); + auto deleted = segment.GetPtr(); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + deleted[i] = NOT_DELETED_ID; + } + } + return deleted_data; } idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t count) { - any_deleted = true; + auto segment = allocator.GetHandle(GetInitializedDeletedPointer()); + auto deleted = segment.GetPtr(); idx_t deleted_tuples = 0; for (idx_t i = 0; i < count; i++) { @@ -220,6 +303,9 @@ idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t } void ChunkVectorInfo::CommitDelete(transaction_t commit_id, const DeleteInfo &info) { + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); + if (info.is_consecutive) { for (idx_t i = 0; i < info.count; i++) { deleted[i] = commit_id; @@ -234,32 +320,45 @@ void ChunkVectorInfo::CommitDelete(transaction_t commit_id, const DeleteInfo &in void ChunkVectorInfo::Append(idx_t start, idx_t end, transaction_t commit_id) { if (start == 0) { - insert_id = commit_id; - } else if (insert_id != commit_id) { - same_inserted_id = false; - insert_id = NOT_DELETED_ID; + // first insert to this vector - just assign the commit id + constant_insert_id = commit_id; + return; + } + if (HasConstantInsertionId() && ConstantInsertId() == commit_id) { + // we are inserting again, but we have the same id as before - still the same insert id + return; } + + auto segment = allocator.GetHandle(GetInitializedInsertedPointer()); + auto inserted = segment.GetPtr(); for (idx_t i = start; i < end; i++) { inserted[i] = commit_id; } } void ChunkVectorInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) { - if (same_inserted_id) { - insert_id = commit_id; + if (HasConstantInsertionId()) { + constant_insert_id = commit_id; + return; } + auto segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = segment.GetPtr(); + for (idx_t i = start; i < end; i++) { inserted[i] = commit_id; } } bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr &result) const { - if (any_deleted) { + if (AnyDeleted()) { // if any rows are deleted we can't clean-up return false; } // check if the insertion markers have to be used by all transactions going forward - if (!same_inserted_id) { + if (!HasConstantInsertionId()) { + auto segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = segment.GetPtr(); + for (idx_t idx = 1; idx < STANDARD_VECTOR_SIZE; idx++) { if (inserted[idx] > lowest_transaction) { // transaction was inserted after the lowest transaction start @@ -267,7 +366,7 @@ bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr lowest_transaction) { + } else if (ConstantInsertId() > lowest_transaction) { // transaction was inserted after the lowest transaction start // we still need to use an older version - cannot compress return false; @@ -276,13 +375,31 @@ bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr(); + idx_t delete_count = 0; for (idx_t i = 0; i < max_count; i++) { if (deleted[i] < TRANSACTION_ID_START) { @@ -319,15 +436,17 @@ void ChunkVectorInfo::Write(WriteStream &writer) const { mask.Write(writer, STANDARD_VECTOR_SIZE); } -unique_ptr ChunkVectorInfo::Read(ReadStream &reader) { +unique_ptr ChunkVectorInfo::Read(FixedSizeAllocator &allocator, ReadStream &reader) { auto start = reader.Read(); - auto result = make_uniq(start); - result->any_deleted = true; + auto result = make_uniq(allocator, start); ValidityMask mask; mask.Read(reader, STANDARD_VECTOR_SIZE); + + auto segment = allocator.GetHandle(result->GetInitializedDeletedPointer()); + auto deleted = segment.GetPtr(); for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { if (mask.RowIsValid(i)) { - result->deleted[i] = 0; + deleted[i] = 0; } } return std::move(result); diff --git a/src/storage/table/column_data.cpp b/src/storage/table/column_data.cpp index c38fff709b1e..2b48d90cff11 100644 --- a/src/storage/table/column_data.cpp +++ b/src/storage/table/column_data.cpp @@ -293,13 +293,13 @@ void ColumnData::FetchUpdateRow(TransactionData transaction, row_t row_id, Vecto updates->FetchRow(transaction, NumericCast(row_id), result, result_idx); } -void ColumnData::UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector) { +void ColumnData::UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, Vector &base_vector) { lock_guard update_guard(update_lock); if (!updates) { updates = make_uniq(*this); } - updates->Update(transaction, column_index, update_vector, row_ids, update_count, base_vector); + updates->Update(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, @@ -578,20 +578,20 @@ idx_t ColumnData::FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector return fetch_count; } -void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) { Vector base_vector(type); ColumnScanState state; FetchUpdateData(state, row_ids, base_vector); - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } -void ColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) { +void ColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { // this method should only be called at the end of the path in the base column case D_ASSERT(depth >= column_path.size()); - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count); } void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { diff --git a/src/storage/table/list_column_data.cpp b/src/storage/table/list_column_data.cpp index 986b32dc7c1c..5672482c73a0 100644 --- a/src/storage/table/list_column_data.cpp +++ b/src/storage/table/list_column_data.cpp @@ -263,13 +263,14 @@ idx_t ListColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result throw NotImplementedException("List Fetch"); } -void ListColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ListColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw NotImplementedException("List Update is not supported."); } -void ListColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ListColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw NotImplementedException("List Update Column is not supported"); } diff --git a/src/storage/table/row_group.cpp b/src/storage/table/row_group.cpp index 9425a09584bf..169d5c24ed51 100644 --- a/src/storage/table/row_group.cpp +++ b/src/storage/table/row_group.cpp @@ -723,7 +723,8 @@ shared_ptr RowGroup::GetOrCreateVersionInfoInternal() { // version info does not exist - need to create it lock_guard lock(row_group_lock); if (!owned_version_info) { - auto new_info = make_shared_ptr(start); + auto &buffer_manager = GetBlockManager().GetBufferManager(); + auto new_info = make_shared_ptr(buffer_manager, start); SetVersionInfo(std::move(new_info)); } return owned_version_info; @@ -854,8 +855,8 @@ void RowGroup::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_ vinfo.CleanupAppend(lowest_transaction, start, count); } -void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids) { +void RowGroup::Update(TransactionData transaction, DataTable &data_table, DataChunk &update_chunk, row_t *ids, + idx_t offset, idx_t count, const vector &column_ids) { #ifdef DEBUG for (size_t i = offset; i < offset + count; i++) { D_ASSERT(ids[i] >= row_t(this->start) && ids[i] < row_t(this->start + this->count)); @@ -868,16 +869,16 @@ void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_ if (offset > 0) { Vector sliced_vector(update_chunk.data[i], offset, offset + count); sliced_vector.Flatten(count); - col_data.Update(transaction, column.index, sliced_vector, ids + offset, count); + col_data.Update(transaction, data_table, column.index, sliced_vector, ids + offset, count); } else { - col_data.Update(transaction, column.index, update_chunk.data[i], ids, count); + col_data.Update(transaction, data_table, column.index, update_chunk.data[i], ids, count); } MergeStatistics(column.index, *col_data.GetUpdateStatistics()); } } -void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path) { +void RowGroup::UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path) { D_ASSERT(updates.ColumnCount() == 1); auto ids = FlatVector::GetData(row_ids); @@ -887,9 +888,9 @@ void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vec if (offset > 0) { Vector sliced_vector(updates.data[0], offset, offset + count); sliced_vector.Flatten(count); - col_data.UpdateColumn(transaction, column_path, sliced_vector, ids + offset, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, sliced_vector, ids + offset, count, 1); } else { - col_data.UpdateColumn(transaction, column_path, updates.data[0], ids, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, updates.data[0], ids, count, 1); } MergeStatistics(primary_column_idx, *col_data.GetUpdateStatistics()); } diff --git a/src/storage/table/row_group_collection.cpp b/src/storage/table/row_group_collection.cpp index 42c453ea0699..c0dc479ad842 100644 --- a/src/storage/table/row_group_collection.cpp +++ b/src/storage/table/row_group_collection.cpp @@ -271,7 +271,7 @@ bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t return row_group; } -void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const vector &column_ids, - DataChunk &updates) { +void RowGroupCollection::Update(TransactionData transaction, DataTable &data_table, row_t *ids, + const vector &column_ids, DataChunk &updates) { D_ASSERT(updates.size() >= 1); idx_t pos = 0; do { idx_t start = pos; auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->Update(transaction, updates, ids, start, pos - start, column_ids); + row_group->Update(transaction, data_table, updates, ids, start, pos - start, column_ids); auto l = stats.GetLock(); for (idx_t i = 0; i < column_ids.size(); i++) { @@ -770,15 +770,15 @@ void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableInd } } -void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates) { +void RowGroupCollection::UpdateColumn(TransactionData transaction, DataTable &data_table, Vector &row_ids, + const vector &column_path, DataChunk &updates) { D_ASSERT(updates.size() >= 1); auto ids = FlatVector::GetData(row_ids); idx_t pos = 0; do { idx_t start = pos; auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->UpdateColumn(transaction, updates, row_ids, start, pos - start, column_path); + row_group->UpdateColumn(transaction, data_table, updates, row_ids, start, pos - start, column_path); auto lock = stats.GetLock(); auto primary_column_idx = column_path[0]; diff --git a/src/storage/table/row_id_column_data.cpp b/src/storage/table/row_id_column_data.cpp index d869913bf8da..4bc3c4148ded 100644 --- a/src/storage/table/row_id_column_data.cpp +++ b/src/storage/table/row_id_column_data.cpp @@ -138,13 +138,14 @@ void RowIdColumnData::RevertAppend(row_t start_row) { throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void RowIdColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw InternalException("RowIdColumnData cannot be updated"); } -void RowIdColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void RowIdColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw InternalException("RowIdColumnData cannot be updated"); } diff --git a/src/storage/table/row_version_manager.cpp b/src/storage/table/row_version_manager.cpp index df4e463da2d1..67ebbfb8653b 100644 --- a/src/storage/table/row_version_manager.cpp +++ b/src/storage/table/row_version_manager.cpp @@ -7,7 +7,10 @@ namespace duckdb { -RowVersionManager::RowVersionManager(idx_t start) noexcept : start(start), has_changes(false) { +RowVersionManager::RowVersionManager(BufferManager &buffer_manager_p, idx_t start) noexcept + : allocator(STANDARD_VECTOR_SIZE * sizeof(transaction_t), buffer_manager_p.GetTemporaryBlockManager(), + MemoryTag::BASE_TABLE), + start(start), has_changes(false) { } void RowVersionManager::SetStart(idx_t new_start) { @@ -112,7 +115,7 @@ void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t cou optional_ptr new_info; if (!vector_info[vector_idx]) { // first time appending to this vector: create new info - auto insert_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + auto insert_info = make_uniq(allocator, start + vector_idx * STANDARD_VECTOR_SIZE); new_info = insert_info.get(); vector_info[vector_idx] = std::move(insert_info); } else if (vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO) { @@ -188,15 +191,12 @@ ChunkVectorInfo &RowVersionManager::GetVectorInfo(idx_t vector_idx) { if (!vector_info[vector_idx]) { // no info yet: create it - vector_info[vector_idx] = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + vector_info[vector_idx] = make_uniq(allocator, start + vector_idx * STANDARD_VECTOR_SIZE); } else if (vector_info[vector_idx]->type == ChunkInfoType::CONSTANT_INFO) { auto &constant = vector_info[vector_idx]->Cast(); // info exists but it's a constant info: convert to a vector info - auto new_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - new_info->insert_id = constant.insert_id; - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - new_info->inserted[i] = constant.insert_id; - } + auto new_info = + make_uniq(allocator, start + vector_idx * STANDARD_VECTOR_SIZE, constant.insert_id); vector_info[vector_idx] = std::move(new_info); } D_ASSERT(vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO); @@ -262,7 +262,7 @@ shared_ptr RowVersionManager::Deserialize(MetaBlockPointer de if (!delete_pointer.IsValid()) { return nullptr; } - auto version_info = make_shared_ptr(start); + auto version_info = make_shared_ptr(manager.GetBufferManager(), start); MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); auto chunk_count = source.Read(); D_ASSERT(chunk_count > 0); @@ -275,7 +275,7 @@ shared_ptr RowVersionManager::Deserialize(MetaBlockPointer de } version_info->FillVectorInfo(vector_index); - version_info->vector_info[vector_index] = ChunkInfo::Read(source); + version_info->vector_info[vector_index] = ChunkInfo::Read(version_info->GetAllocator(), source); } version_info->has_changes = false; return version_info; diff --git a/src/storage/table/standard_column_data.cpp b/src/storage/table/standard_column_data.cpp index fde7d2463eb8..ad8814ab4732 100644 --- a/src/storage/table/standard_column_data.cpp +++ b/src/storage/table/standard_column_data.cpp @@ -152,8 +152,8 @@ idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &re return scan_count; } -void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void StandardColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { ColumnScanState standard_state, validity_state; Vector base_vector(type); auto standard_fetch = FetchUpdateData(standard_state, row_ids, base_vector); @@ -162,18 +162,19 @@ void StandardColumnData::Update(TransactionData transaction, idx_t column_index, throw InternalException("Unaligned fetch in validity and main column data for update"); } - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); - validity.UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); + validity.UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } -void StandardColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StandardColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { if (depth >= column_path.size()) { // update this column - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count); } else { // update the child column (i.e. the validity column) - validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); + validity.UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, update_count, depth + 1); } } diff --git a/src/storage/table/struct_column_data.cpp b/src/storage/table/struct_column_data.cpp index 65f322e7950b..b1de02b2d984 100644 --- a/src/storage/table/struct_column_data.cpp +++ b/src/storage/table/struct_column_data.cpp @@ -207,17 +207,18 @@ idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resu return scan_count; } -void StructColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - validity.Update(transaction, column_index, update_vector, row_ids, update_count); +void StructColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { + validity.Update(transaction, data_table, column_index, update_vector, row_ids, update_count); auto &child_entries = StructVector::GetEntries(update_vector); for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Update(transaction, column_index, *child_entries[i], row_ids, update_count); + sub_columns[i]->Update(transaction, data_table, column_index, *child_entries[i], row_ids, update_count); } } -void StructColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StructColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { // we can never DIRECTLY update a struct column if (depth >= column_path.size()) { throw InternalException("Attempting to directly update a struct column - this should not be possible"); @@ -225,13 +226,13 @@ void StructColumnData::UpdateColumn(TransactionData transaction, const vector sub_columns.size()) { throw InternalException("Update column_path out of range"); } - sub_columns[update_column - 1]->UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, - depth + 1); + sub_columns[update_column - 1]->UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, + update_count, depth + 1); } } diff --git a/src/storage/table/update_segment.cpp b/src/storage/table/update_segment.cpp index 8056907bc489..c47851ead687 100644 --- a/src/storage/table/update_segment.cpp +++ b/src/storage/table/update_segment.cpp @@ -7,6 +7,7 @@ #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/update_info.hpp" #include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/storage/data_table.hpp" #include @@ -104,9 +105,10 @@ idx_t UpdateInfo::GetAllocSize(idx_t type_size) { return AlignValue(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); } -void UpdateInfo::Initialize(UpdateInfo &info, transaction_t transaction_id) { +void UpdateInfo::Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id) { info.max = STANDARD_VECTOR_SIZE; info.version_number = transaction_id; + info.table = &data_table; info.segment = nullptr; info.prev.entry = nullptr; info.next.entry = nullptr; @@ -1236,11 +1238,11 @@ static idx_t SortSelectionVector(SelectionVector &sel, idx_t count, row_t *ids) return pos; } -UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, idx_t type_size, idx_t count, +UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, DataTable &data_table, idx_t type_size, idx_t count, unsafe_unique_array &data) { data = make_unsafe_uniq_array_uninitialized(UpdateInfo::GetAllocSize(type_size)); auto update_info = reinterpret_cast(data.get()); - UpdateInfo::Initialize(*update_info, transaction.transaction_id); + UpdateInfo::Initialize(*update_info, data_table, transaction.transaction_id); return update_info; } @@ -1258,8 +1260,8 @@ void UpdateSegment::InitializeUpdateInfo(idx_t vector_idx) { } } -void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vector &update_p, row_t *ids, idx_t count, - Vector &base_data) { +void UpdateSegment::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_p, + row_t *ids, idx_t count, Vector &base_data) { // obtain an exclusive lock auto write_lock = lock.GetExclusiveLock(); @@ -1322,10 +1324,10 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect // no updates made yet by this transaction: initially the update info to empty if (transaction.transaction) { auto &dtransaction = transaction.transaction->Cast(); - node_ref = dtransaction.CreateUpdateInfo(type_size, count); + node_ref = dtransaction.CreateUpdateInfo(type_size, data_table, count); node = &UpdateInfo::Get(node_ref); } else { - node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + node = CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data); } node->segment = this; node->vector_index = vector_index; @@ -1360,7 +1362,7 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto handle = root->allocator.Allocate(alloc_size); auto &update_info = UpdateInfo::Get(handle); - UpdateInfo::Initialize(update_info, TRANSACTION_ID_START - 1); + UpdateInfo::Initialize(update_info, data_table, TRANSACTION_ID_START - 1); update_info.column_index = column_index; InitializeUpdateInfo(update_info, ids, sel, count, vector_index, vector_offset); @@ -1370,10 +1372,10 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect UndoBufferReference node_ref; optional_ptr transaction_node; if (transaction.transaction) { - node_ref = transaction.transaction->CreateUpdateInfo(type_size, count); + node_ref = transaction.transaction->CreateUpdateInfo(type_size, data_table, count); transaction_node = &UpdateInfo::Get(node_ref); } else { - transaction_node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + transaction_node = CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data); } InitializeUpdateInfo(*transaction_node, ids, sel, count, vector_index, vector_offset); diff --git a/src/transaction/commit_state.cpp b/src/transaction/commit_state.cpp index 0f5d75bd231c..6eba8ab10988 100644 --- a/src/transaction/commit_state.cpp +++ b/src/transaction/commit_state.cpp @@ -165,6 +165,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::INSERT_TUPLE: { // append: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->table->CommitAppend(commit_id, info->start_row, info->count); break; @@ -172,6 +178,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::DELETE_TUPLE: { // deletion: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->version_info->CommitDelete(info->vector_idx, commit_id, *info); break; @@ -179,6 +191,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::UPDATE_TUPLE: { // update: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } info->version_number = commit_id; break; } diff --git a/src/transaction/duck_transaction.cpp b/src/transaction/duck_transaction.cpp index dc6afccb7674..562faa614d09 100644 --- a/src/transaction/duck_transaction.cpp +++ b/src/transaction/duck_transaction.cpp @@ -32,8 +32,8 @@ TransactionData::TransactionData(transaction_t transaction_id_p, transaction_t s DuckTransaction::DuckTransaction(DuckTransactionManager &manager, ClientContext &context_p, transaction_t start_time, transaction_t transaction_id, idx_t catalog_version_p) : Transaction(manager, context_p), start_time(start_time), transaction_id(transaction_id), commit_id(0), - highest_active_query(0), catalog_version(catalog_version_p), awaiting_cleanup(false), - transaction_manager(manager), undo_buffer(*this, context_p), storage(make_uniq(context_p, *this)) { + catalog_version(catalog_version_p), awaiting_cleanup(false), transaction_manager(manager), + undo_buffer(*this, context_p), storage(make_uniq(context_p, *this)) { } DuckTransaction::~DuckTransaction() { @@ -126,11 +126,11 @@ void DuckTransaction::PushAppend(DataTable &table, idx_t start_row, idx_t row_co append_info->count = row_count; } -UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, idx_t entries) { +UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries) { idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto undo_entry = undo_buffer.CreateEntry(UndoFlags::UPDATE_TUPLE, alloc_size); auto &update_info = UpdateInfo::Get(undo_entry); - UpdateInfo::Initialize(update_info, transaction_id); + UpdateInfo::Initialize(update_info, data_table, transaction_id); return undo_entry; } @@ -246,14 +246,6 @@ ErrorData DuckTransaction::Commit(AttachedDatabase &db, transaction_t new_commit // no need to flush anything if we made no changes return ErrorData(); } - for (auto &entry : modified_tables) { - auto &tbl = entry.first.get(); - if (!tbl.IsMainTable()) { - return ErrorData( - TransactionException("Attempting to modify table %s but another transaction has %s this table", - tbl.GetTableName(), tbl.TableModification())); - } - } D_ASSERT(db.IsSystem() || db.IsTemporary() || !IsReadOnly()); UndoBuffer::IteratorState iterator_state; diff --git a/src/transaction/duck_transaction_manager.cpp b/src/transaction/duck_transaction_manager.cpp index 06d17189b939..2dfe7cd6c3f6 100644 --- a/src/transaction/duck_transaction_manager.cpp +++ b/src/transaction/duck_transaction_manager.cpp @@ -324,7 +324,7 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran } // We do not need to hold the transaction lock during cleanup of transactions, - // as they (1) have been removed, or (2) exited old_transactions. + // as they (1) have been removed, or (2) enter cleanup_info. t_lock.unlock(); { @@ -412,7 +412,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa idx_t t_index = active_transactions.size(); auto lowest_start_time = TRANSACTION_ID_START; auto lowest_transaction_id = MAX_TRANSACTION_ID; - auto lowest_active_query = MAXIMUM_QUERY_ID; for (idx_t i = 0; i < active_transactions.size(); i++) { if (active_transactions[i].get() == &transaction) { t_index = i; @@ -420,8 +419,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa } lowest_start_time = MinValue(lowest_start_time, active_transactions[i]->start_time); lowest_transaction_id = MinValue(lowest_transaction_id, active_transactions[i]->transaction_id); - transaction_t active_query = active_transactions[i]->active_query; - lowest_active_query = MinValue(lowest_active_query, active_query); } lowest_active_start = lowest_start_time; lowest_active_id = lowest_transaction_id; @@ -429,7 +426,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa // Decide if we need to store the transaction, or if we can schedule it for cleanup. auto current_transaction = std::move(active_transactions[t_index]); - auto current_query = DatabaseManager::Get(db).ActiveQueryNumber(); if (store_transaction) { // If the transaction made any changes, we need to keep it around. if (transaction.commit_id != 0) { @@ -438,9 +434,7 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa recently_committed_transactions.push_back(std::move(current_transaction)); } else { // The transaction was aborted. - // We might still need its information; add it to the set of transactions awaiting GC. - current_transaction->highest_active_query = current_query; - old_transactions.push_back(std::move(current_transaction)); + cleanup_info->transactions.push_back(std::move(current_transaction)); } } else if (transaction.ChangesMade()) { // We do not need to store the transaction, directly schedule it for cleanup. @@ -464,18 +458,8 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa break; } - // Changes made BEFORE this transaction are no longer relevant. - // We can schedule the transaction and its undo buffer for cleanup. recently_committed_transactions[i]->awaiting_cleanup = true; - - // HOWEVER: Any currently running QUERY can still be using - // the version information of the transaction. - // If we remove the UndoBuffer immediately, we have a race condition. - - // Store the current highest active query. - recently_committed_transactions[i]->highest_active_query = current_query; - // Move it to the list of transactions awaiting GC. - old_transactions.push_back(std::move(recently_committed_transactions[i])); + cleanup_info->transactions.push_back(std::move(recently_committed_transactions[i])); } if (i > 0) { @@ -485,34 +469,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa recently_committed_transactions.erase(start, end); } - // Check if we can clean up and free the memory of any old transactions. - i = active_transactions.empty() ? old_transactions.size() : 0; - for (; i < old_transactions.size(); i++) { - D_ASSERT(old_transactions[i]); - D_ASSERT(old_transactions[i]->highest_active_query > 0); - if (old_transactions[i]->highest_active_query >= lowest_active_query) { - // There is still a query running that could be using - // this transactions' data. - break; - } - } - - if (i > 0) { - // We garbage-collected old transactions: - // - Remove them from the list and schedule them for cleanup. - - // We can only safely do the actual memory cleanup when all the - // currently active queries have finished running! (actually, - // when all the currently active scans have finished running...). - - // Because we clean up asynchronously, we only clean up once we - // no longer need the transaction for anything (i.e., we can move it). - for (idx_t t_idx = 0; t_idx < i; t_idx++) { - cleanup_info->transactions.push_back(std::move(old_transactions[t_idx])); - } - old_transactions.erase(old_transactions.begin(), old_transactions.begin() + static_cast(i)); - } - return cleanup_info; } diff --git a/src/transaction/undo_buffer.cpp b/src/transaction/undo_buffer.cpp index b584f5401c2e..beec05ec9367 100644 --- a/src/transaction/undo_buffer.cpp +++ b/src/transaction/undo_buffer.cpp @@ -177,7 +177,7 @@ void UndoBuffer::Cleanup(transaction_t lowest_active_transaction) { // the chunks) // (2) there is no active transaction with start_id < commit_id of this // transaction - CleanupState state(transaction.context, lowest_active_transaction); + CleanupState state(QueryContext(), lowest_active_transaction); UndoBuffer::IteratorState iterator_state; IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CleanupEntry(type, data); }); diff --git a/src/transaction/wal_write_state.cpp b/src/transaction/wal_write_state.cpp index 5fe17e0502f9..0036ad0c69b8 100644 --- a/src/transaction/wal_write_state.cpp +++ b/src/transaction/wal_write_state.cpp @@ -27,10 +27,10 @@ WALWriteState::WALWriteState(DuckTransaction &transaction_p, WriteAheadLog &log, : transaction(transaction_p), log(log), commit_state(commit_state), current_table_info(nullptr) { } -void WALWriteState::SwitchTable(DataTableInfo *table_info, UndoFlags new_op) { - if (current_table_info != table_info) { +void WALWriteState::SwitchTable(DataTableInfo &table_info, UndoFlags new_op) { + if (current_table_info != &table_info) { // write the current table to the log - log.WriteSetTable(table_info->GetSchemaName(), table_info->GetTableName()); + log.WriteSetTable(table_info.GetSchemaName(), table_info.GetTableName()); current_table_info = table_info; } } @@ -171,7 +171,7 @@ void WALWriteState::WriteCatalogEntry(CatalogEntry &entry, data_ptr_t dataptr) { void WALWriteState::WriteDelete(DeleteInfo &info) { // switch to the current table, if necessary - SwitchTable(info.table->GetDataTableInfo().get(), UndoFlags::DELETE_TUPLE); + SwitchTable(*info.table->GetDataTableInfo(), UndoFlags::DELETE_TUPLE); if (!delete_chunk) { delete_chunk = make_uniq(); @@ -198,7 +198,7 @@ void WALWriteState::WriteUpdate(UpdateInfo &info) { auto &column_data = info.segment->column_data; auto &table_info = column_data.GetTableInfo(); - SwitchTable(&table_info, UndoFlags::UPDATE_TUPLE); + SwitchTable(table_info, UndoFlags::UPDATE_TUPLE); // initialize the update chunk vector update_types; diff --git a/src/verification/statement_verifier.cpp b/src/verification/statement_verifier.cpp index 81f4c4aba28f..14e4c0491ed8 100644 --- a/src/verification/statement_verifier.cpp +++ b/src/verification/statement_verifier.cpp @@ -1,5 +1,9 @@ #include "duckdb/verification/statement_verifier.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" + #include "duckdb/common/error_data.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/parser/parser.hpp" @@ -15,13 +19,24 @@ namespace duckdb { +const vector> &StatementVerifier::GetSelectList(QueryNode &node) { + switch (node.type) { + case QueryNodeType::SELECT_NODE: + return node.Cast().select_list; + case QueryNodeType::SET_OPERATION_NODE: + return GetSelectList(*node.Cast().children[0]); + default: + return empty_select_list; + } +} + StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p, optional_ptr> parameters_p) : type(type), name(std::move(name)), statement(std::move(statement_p)), select_statement(statement->type == StatementType::SELECT_STATEMENT ? &statement->Cast() : nullptr), parameters(parameters_p), - select_list(select_statement ? select_statement->node->GetSelectList() : empty_select_list) { + select_list(select_statement ? GetSelectList(*select_statement->node) : empty_select_list) { } StatementVerifier::StatementVerifier(unique_ptr statement_p, diff --git a/test/api/capi/test_capi_complex_types.cpp b/test/api/capi/test_capi_complex_types.cpp index f540fb63a898..1f2fa3cfa204 100644 --- a/test/api/capi/test_capi_complex_types.cpp +++ b/test/api/capi/test_capi_complex_types.cpp @@ -541,6 +541,15 @@ TEST_CASE("Binding values", "[capi]") { auto struct_value = duckdb_create_struct_value(struct_type, &value); duckdb_destroy_logical_type(&struct_type); + // Fail with out-of-bounds. + duckdb_prepared_statement prepared_fail; + REQUIRE(duckdb_prepare(tester.connection, "SELECT ?, ?", &prepared_fail) == DuckDBSuccess); + auto state = duckdb_bind_value(prepared_fail, 3, struct_value); + REQUIRE(state == DuckDBError); + auto error_msg = duckdb_prepare_error(prepared_fail); + REQUIRE(StringUtil::Contains(string(error_msg), "Can not bind to parameter number")); + duckdb_destroy_prepare(&prepared_fail); + duckdb::vector list_values {value}; auto list_value = duckdb_create_list_value(member_type, list_values.data(), member_count); diff --git a/test/api/capi/test_capi_profiling.cpp b/test/api/capi/test_capi_profiling.cpp index 4e442b71fb80..936feca75250 100644 --- a/test/api/capi/test_capi_profiling.cpp +++ b/test/api/capi/test_capi_profiling.cpp @@ -320,3 +320,97 @@ TEST_CASE("Test profiling with Extra Info enabled", "[capi]") { duckdb_destroy_value(&map); tester.Cleanup(); } + +TEST_CASE("Test profiling with the appender", "[capi]") { + CAPITester tester; + duckdb::unique_ptr result; + REQUIRE(tester.OpenDatabase(nullptr)); + + tester.Query("CREATE TABLE tbl (i INT PRIMARY KEY, value VARCHAR)"); + REQUIRE_NO_FAIL(tester.Query("PRAGMA enable_profiling = 'no_output'")); + REQUIRE_NO_FAIL(tester.Query("SET profiling_coverage='ALL'")); + duckdb_appender appender; + + string query = "INSERT INTO tbl FROM my_appended_data"; + duckdb_logical_type types[2]; + types[0] = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER); + types[1] = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR); + + auto status = duckdb_appender_create_query(tester.connection, query.c_str(), 2, types, "my_appended_data", nullptr, + &appender); + duckdb_destroy_logical_type(&types[0]); + duckdb_destroy_logical_type(&types[1]); + REQUIRE(status == DuckDBSuccess); + REQUIRE(duckdb_appender_error(appender) == nullptr); + + REQUIRE(duckdb_appender_begin_row(appender) == DuckDBSuccess); + REQUIRE(duckdb_append_int32(appender, 1) == DuckDBSuccess); + REQUIRE(duckdb_append_varchar(appender, "hello world") == DuckDBSuccess); + REQUIRE(duckdb_appender_end_row(appender) == DuckDBSuccess); + + REQUIRE(duckdb_appender_flush(appender) == DuckDBSuccess); + REQUIRE(duckdb_appender_close(appender) == DuckDBSuccess); + REQUIRE(duckdb_appender_destroy(&appender) == DuckDBSuccess); + + auto info = duckdb_get_profiling_info(tester.connection); + REQUIRE(info); + + // Check that the query name matches the appender query. + auto query_name = duckdb_profiling_info_get_value(info, "QUERY_NAME"); + REQUIRE(query_name); + auto query_name_c_str = duckdb_get_varchar(query_name); + auto query_name_str = duckdb::string(query_name_c_str); + REQUIRE(query_name_str == query); + duckdb_destroy_value(&query_name); + duckdb_free(query_name_c_str); + + duckdb::map cumulative_counter; + duckdb::map cumulative_result; + TraverseTree(info, cumulative_counter, cumulative_result, 0); + tester.Cleanup(); +} + +TEST_CASE("Test profiling with the non-query appender", "[capi]") { + CAPITester tester; + duckdb::unique_ptr result; + duckdb_state status; + + REQUIRE(tester.OpenDatabase(nullptr)); + tester.Query("CREATE TABLE test (i INTEGER)"); + REQUIRE_NO_FAIL(tester.Query("PRAGMA enable_profiling = 'no_output'")); + REQUIRE_NO_FAIL(tester.Query("SET profiling_coverage='ALL'")); + + duckdb_appender appender; + REQUIRE(duckdb_appender_create(tester.connection, nullptr, "test", &appender) == DuckDBSuccess); + REQUIRE(duckdb_appender_error(appender) == nullptr); + + // Appending a row. + REQUIRE(duckdb_appender_begin_row(appender) == DuckDBSuccess); + REQUIRE(duckdb_append_int32(appender, 42) == DuckDBSuccess); + // Finish and flush. + REQUIRE(duckdb_appender_end_row(appender) == DuckDBSuccess); + REQUIRE(duckdb_appender_flush(appender) == DuckDBSuccess); + REQUIRE(duckdb_appender_close(appender) == DuckDBSuccess); + REQUIRE(duckdb_appender_destroy(&appender) == DuckDBSuccess); + + auto info = duckdb_get_profiling_info(tester.connection); + REQUIRE(info); + + // Check that the query name matches the appender query. + auto query_name = duckdb_profiling_info_get_value(info, "QUERY_NAME"); + REQUIRE(query_name); + + auto query_name_c_str = duckdb_get_varchar(query_name); + auto query_name_str = duckdb::string(query_name_c_str); + + auto query = "INSERT INTO main.test FROM __duckdb_internal_appended_data"; + REQUIRE(query_name_str == query); + + duckdb_destroy_value(&query_name); + duckdb_free(query_name_c_str); + + duckdb::map cumulative_counter; + duckdb::map cumulative_result; + TraverseTree(info, cumulative_counter, cumulative_result, 0); + tester.Cleanup(); +} diff --git a/test/api/capi/test_capi_table_description.cpp b/test/api/capi/test_capi_table_description.cpp index 5064a422e24a..1273d1750b76 100644 --- a/test/api/capi/test_capi_table_description.cpp +++ b/test/api/capi/test_capi_table_description.cpp @@ -64,18 +64,37 @@ TEST_CASE("Test the table description in the C API", "[capi]") { REQUIRE(status == DuckDBSuccess); REQUIRE(duckdb_table_description_error(table_description) == nullptr); + SECTION("Passing nullptr to get_column_count") { + REQUIRE(duckdb_table_description_get_column_count(nullptr) == 0); + } SECTION("Passing nullptr to get_name") { REQUIRE(duckdb_table_description_get_column_name(nullptr, 0) == nullptr); } + SECTION("Passing nullptr to get_type") { + REQUIRE(duckdb_table_description_get_column_type(nullptr, 0) == nullptr); + } SECTION("Out of range column for get_name") { REQUIRE(duckdb_table_description_get_column_name(table_description, 1) == nullptr); } + SECTION("Out of range column for get_type") { + REQUIRE(duckdb_table_description_get_column_type(table_description, 1) == nullptr); + } + SECTION("get the column count") { + auto column_count = duckdb_table_description_get_column_count(table_description); + REQUIRE(column_count == 1); + } SECTION("In range column - get the name") { auto column_name = duckdb_table_description_get_column_name(table_description, 0); string expected = "my_column"; REQUIRE(!expected.compare(column_name)); duckdb_free(column_name); } + SECTION("In range column - get the type") { + auto column_type = duckdb_table_description_get_column_type(table_description, 0); + auto type_id = duckdb_get_type_id(column_type); + REQUIRE(type_id == DUCKDB_TYPE_INTEGER); + duckdb_destroy_logical_type(&column_type); + } duckdb_table_description_destroy(&table_description); } diff --git a/test/api/test_instance_cache.cpp b/test/api/test_instance_cache.cpp index 0105b298f9aa..7abe20e55406 100644 --- a/test/api/test_instance_cache.cpp +++ b/test/api/test_instance_cache.cpp @@ -110,20 +110,98 @@ TEST_CASE("Test attaching the same database path from different databases", "[ap auto db1 = instance_cache.GetOrCreateInstance(":memory:", config, false); auto db2 = instance_cache.GetOrCreateInstance(":memory:", config, false); - string attach_query = "ATTACH '" + test_path + "' AS db_ref"; - Connection con1(*db1); - REQUIRE_NO_FAIL(con1.Query(attach_query)); + Connection con2(*db2); + SECTION("Regular ATTACH conflict") { + string attach_query = "ATTACH '" + test_path + "' AS db_ref"; + + REQUIRE_NO_FAIL(con1.Query(attach_query)); + + // fails - already attached in db1 + REQUIRE_FAIL(con2.Query(attach_query)); + + // if we detach from con1, we can now attach in con2 + REQUIRE_NO_FAIL(con1.Query("DETACH db_ref")); + + REQUIRE_NO_FAIL(con2.Query(attach_query)); + + // .. but not in con1 anymore! + REQUIRE_FAIL(con1.Query(attach_query)); + } + SECTION("ATTACH IF NOT EXISTS") { + string attach_query = "ATTACH IF NOT EXISTS '" + test_path + "' AS db_ref"; + + REQUIRE_NO_FAIL(con1.Query(attach_query)); - // fails - already attached in db1 + // fails - already attached in db1 + REQUIRE_FAIL(con2.Query(attach_query)); + } +} + +TEST_CASE("Test attaching the same database path from different databases in read-only mode", "[api][.]") { + DBInstanceCache instance_cache; + auto test_path = TestCreatePath("instance_cache_reuse_readonly.db"); + + // create an empty database + { + DuckDB db(test_path); + Connection con(db); + REQUIRE_NO_FAIL(con.Query("CREATE TABLE IF NOT EXISTS integers AS FROM (VALUES (1), (2), (3)) t(i)")); + } + + DBConfig config; + auto db1 = instance_cache.GetOrCreateInstance(":memory:", config, false); + auto db2 = instance_cache.GetOrCreateInstance(":memory:", config, false); + auto db3 = instance_cache.GetOrCreateInstance(":memory:", config, false); + + Connection con1(*db1); Connection con2(*db2); - REQUIRE_FAIL(con2.Query(attach_query)); + Connection con3(*db3); + + SECTION("Regular ATTACH conflict") { + string attach_query = "ATTACH '" + test_path + "' AS db_ref"; + string read_only_attach = attach_query + " (READ_ONLY)"; + + REQUIRE_NO_FAIL(con1.Query(read_only_attach)); + + // succeeds - we can attach the same database multiple times in read-only mode + REQUIRE_NO_FAIL(con2.Query(read_only_attach)); + + // fails - we cannot attach in read-write + REQUIRE_FAIL(con3.Query(attach_query)); - // if we detach from con1, we can now attach in con2 - REQUIRE_NO_FAIL(con1.Query("DETACH db_ref")); + // if we detach from con1, we still cannot attach in read-write in con3 + REQUIRE_NO_FAIL(con1.Query("DETACH db_ref")); + REQUIRE_FAIL(con3.Query(attach_query)); - REQUIRE_NO_FAIL(con2.Query(attach_query)); + // but if we detach in con2, we can attach in read-write mode now + REQUIRE_NO_FAIL(con2.Query("DETACH db_ref")); + REQUIRE_NO_FAIL(con3.Query(attach_query)); - // .. but not in con1 anymore! - REQUIRE_FAIL(con1.Query(attach_query)); + // and now we can no longer attach in read-only mode + REQUIRE_FAIL(con1.Query(read_only_attach)); + } + SECTION("ATTACH IF EXISTS") { + string attach_query = "ATTACH IF NOT EXISTS '" + test_path + "' AS db_ref"; + string read_only_attach = attach_query + " (READ_ONLY)"; + + REQUIRE_NO_FAIL(con1.Query(read_only_attach)); + + // succeeds - we can attach the same database multiple times in read-only mode + REQUIRE_NO_FAIL(con2.Query(read_only_attach)); + + // fails - we cannot attach in read-write + REQUIRE_FAIL(con3.Query(attach_query)); + + // if we detach from con1, we still cannot attach in read-write in con3 + REQUIRE_NO_FAIL(con1.Query("DETACH db_ref")); + REQUIRE_FAIL(con3.Query(attach_query)); + + // but if we detach in con2, we can attach in read-write mode now + REQUIRE_NO_FAIL(con2.Query("DETACH db_ref")); + REQUIRE_NO_FAIL(con3.Query(attach_query)); + + // and now we can no longer attach in read-only mode + REQUIRE_FAIL(con1.Query(read_only_attach)); + } } diff --git a/test/api/test_relation_api.cpp b/test/api/test_relation_api.cpp index ceadf049a7fc..d04357fc40d1 100644 --- a/test/api/test_relation_api.cpp +++ b/test/api/test_relation_api.cpp @@ -1063,6 +1063,15 @@ TEST_CASE("Test Relation Pending Query API", "[relation_api]") { } } +TEST_CASE("Test Relation Query setting query", "[relation_api]") { + DuckDB db; + Connection con(db); + + auto query = con.RelationFromQuery("SELECT current_query()"); + auto result = query->Limit(1)->Execute(); + REQUIRE(!result->Fetch()->GetValue(0, 0).ToString().empty()); +} + TEST_CASE("Construct ValueRelation with RelationContextWrapper and operate on it", "[relation_api][txn][wrapper]") { DuckDB db; Connection con(db); diff --git a/test/configs/enable_verification_for_debug.json b/test/configs/enable_verification_for_debug.json index eddcf41626af..9306cfdbad98 100644 --- a/test/configs/enable_verification_for_debug.json +++ b/test/configs/enable_verification_for_debug.json @@ -724,7 +724,8 @@ "test/sql/logging/logging_file_bind_replace.test", "test/sql/optimizer/test_rowid_pushdown_plan.test", "test/sql/pg_catalog/system_functions.test", - "test/sql/storage/compression/test_using_compression.test" + "test/sql/storage/compression/test_using_compression.test", + "test/sql/error/error_position.test" ] }, { diff --git a/test/configs/latest_storage.json b/test/configs/latest_storage.json index 8c749e31d555..c6b36e6b010b 100644 --- a/test/configs/latest_storage.json +++ b/test/configs/latest_storage.json @@ -3,6 +3,9 @@ "on_init": "ATTACH '__TEST_DIR__/{BASE_TEST_NAME}__test__config__latest__storage.db' AS __test__config__latest__storage (STORAGE_VERSION 'latest'); SET storage_compatibility_version='latest';", "on_new_connection": "USE __test__config__latest__storage;", "on_load": "skip", + "settings": [ + {"name": "storage_compatibility_version", "value": "latest"} + ], "skip_compiled": "true", "skip_tests": [ { diff --git a/test/configs/latest_storage_block_size_16kB.json b/test/configs/latest_storage_block_size_16kB.json index 2e5b70146cd7..5030b7faabcd 100644 --- a/test/configs/latest_storage_block_size_16kB.json +++ b/test/configs/latest_storage_block_size_16kB.json @@ -3,6 +3,9 @@ "on_init": "ATTACH '__TEST_DIR__/{BASE_TEST_NAME}__test__config__latest_storage_block_size_16kB.db' AS __test__config__latest_storage_block_size_16kB (STORAGE_VERSION 'latest'); SET storage_compatibility_version='latest';", "on_new_connection": "USE __test__config__latest_storage_block_size_16kB;", "on_load": "skip", + "settings": [ + {"name": "storage_compatibility_version", "value": "latest"} + ], "skip_compiled": "true", "block_size": "16384", "skip_tests": [ diff --git a/test/configs/peg_parser.json b/test/configs/peg_parser.json new file mode 100644 index 000000000000..e869a4e80a95 --- /dev/null +++ b/test/configs/peg_parser.json @@ -0,0 +1,51 @@ +{ + "description": "Test PEG Parser + transformer", + "skip_compiled": "true", + "on_init": "set allow_parser_override_extension=fallback;", + "statically_loaded_extensions": [ + "core_functions", + "autocomplete" + ], + "skip_tests": [ + { + "reason": "SIGBUS", + "paths": [ + "test/sql/catalog/sequence/sequence_offset_increment.test", + "test/sql/upsert/insert_or_replace/returning_nothing.test", + "test/sql/copy_database/copy_database_different_types.test", + "test/sql/copy_database/copy_table_with_sequence.test", + "test/sql/join/external/external_join_many_duplicates.test_slow" + ] + }, + { + "reason": "Arithmetic expression", + "paths": [ + "test/sql/function/operator/test_arithmetic_sqllogic.test", + "test/sql/projection/test_row_id_expression.test", + "test/sql/function/numeric/test_trigo.test", + "test/sql/function/generic/test_between.test" + ] + }, + { + "reason": "Expression Depth", + "paths": [ + "test/sql/overflow/expression_tree_depth.test" + ] + }, + { + "reason": "Timeout", + "paths": [ + "test/sql/copy/csv/afl/test_fuzz_3981.test_slow", + "test/sql/join/external/tpch_all_tables.test_slow", + "test/sql/storage/encryption/temp_files/encrypted_offloading_block_files.test_slow", + "test/sql/storage/temp_directory/offloading_block_files.test_slow" + ] + }, + { + "reason": "Setting option to 'fallback' changes behavior of first tests", + "paths": [ + "test/extension/loadable_parser_override.test" + ] + } + ] +} \ No newline at end of file diff --git a/test/extension/loadable_extension_demo.cpp b/test/extension/loadable_extension_demo.cpp index 5207f6a93933..20ca60987cf3 100644 --- a/test/extension/loadable_extension_demo.cpp +++ b/test/extension/loadable_extension_demo.cpp @@ -251,7 +251,8 @@ class QuackExtension : public ParserExtension { statements.push_back(std::move(select_statement)); } if (StringUtil::CIEquals(query_input, "over")) { - return ParserOverrideResult("Parser overridden, query equaled \"over\" but not \"override\""); + auto exception = ParserException("Parser overridden, query equaled \"over\" but not \"override\""); + return ParserOverrideResult(exception); } } if (statements.empty()) { diff --git a/test/extension/loadable_parser_override.test b/test/extension/loadable_parser_override.test index d6b235c8ff41..b6589e7bc128 100644 --- a/test/extension/loadable_parser_override.test +++ b/test/extension/loadable_parser_override.test @@ -56,9 +56,9 @@ The DuckDB parser has been overridden statement error over ---- -Parser Error: Parser overridden, query equaled "over" but not "override" +Parser Error: Parser override could not parse this query. (Original error: Parser overridden, query equaled "over" but not "override") statement error SELECT 1; ---- -Parser Error: Parser override failed to return a valid statement. Consider restarting the database and using the setting "set allow_parser_override_extension=fallback" to fallback to the default parser. +:.*Parser Error: Parser override failed.* \ No newline at end of file diff --git a/test/geoparquet/disabled.test b/test/geoparquet/disabled.test index 9a08bb5fb67f..c4ba9b51acd9 100644 --- a/test/geoparquet/disabled.test +++ b/test/geoparquet/disabled.test @@ -51,7 +51,7 @@ query I SELECT (decode(value)) as col FROM parquet_kv_metadata('__TEST_DIR__/data-point-out-enabled.parquet'); ---- -{"version":"1.1.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["Point"],"bbox":[30.0,10.0,40.0,40.0]}}} +{"version":"1.0.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["Point"],"bbox":[30.0,10.0,40.0,40.0]}}} # Now disable conversion diff --git a/test/geoparquet/mixed.test b/test/geoparquet/mixed.test index 68b7b1a46c50..d6679c6a5f50 100644 --- a/test/geoparquet/mixed.test +++ b/test/geoparquet/mixed.test @@ -49,7 +49,7 @@ query I SELECT (decode(value)) as col FROM parquet_kv_metadata('__TEST_DIR__/t1.parquet'); ---- -{"version":"1.1.0","primary_column":"geom","columns":{"geom":{"encoding":"WKB","geometry_types":["Point","LineString","Polygon","MultiPoint","MultiLineString","MultiPolygon","GeometryCollection","Point Z"],"bbox":[0.0,0.0,1.0,3.0,3.0,1.0]}}} +{"version":"1.0.0","primary_column":"geom","columns":{"geom":{"encoding":"WKB","geometry_types":["Point","LineString","Polygon","MultiPoint","MultiLineString","MultiPolygon","GeometryCollection","Point Z"],"bbox":[0.0,0.0,1.0,3.0,3.0,1.0]}}} #------------------------------------------------------------------------------ diff --git a/test/geoparquet/roundtrip.test b/test/geoparquet/roundtrip.test index 8c313267b5a0..fa4c6e5de2c8 100644 --- a/test/geoparquet/roundtrip.test +++ b/test/geoparquet/roundtrip.test @@ -138,4 +138,4 @@ query I SELECT decode(value) as col FROM parquet_kv_metadata('__TEST_DIR__/data-multipolygon-out.parquet') WHERE key = 'geo'; ---- -{"version":"1.1.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["MultiPolygon"],"bbox":[5.0,5.0,45.0,45.0]}}} \ No newline at end of file +{"version":"1.0.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["MultiPolygon"],"bbox":[5.0,5.0,45.0,45.0]}}} \ No newline at end of file diff --git a/test/geoparquet/unsupported.test b/test/geoparquet/unsupported.test index 34ba1825559b..9c01633e2f6d 100644 --- a/test/geoparquet/unsupported.test +++ b/test/geoparquet/unsupported.test @@ -21,7 +21,7 @@ FROM parquet_kv_metadata('__TEST_DIR__/t1.parquet'); # But still a normal parquet file query I -SELECT st_astext(geometry) FROM '__TEST_DIR__/t1.parquet'; +SELECT st_astext(st_geomfromwkb(geometry)) FROM '__TEST_DIR__/t1.parquet'; ---- POINT ZM (0 1 2 3) @@ -35,6 +35,6 @@ FROM parquet_kv_metadata('__TEST_DIR__/t1.parquet'); # But still a normal parquet file query I -SELECT st_astext(geometry) FROM '__TEST_DIR__/t1.parquet'; +SELECT st_astext(st_geomfromwkb(geometry)) FROM '__TEST_DIR__/t1.parquet'; ---- POINT M (0 1 2) \ No newline at end of file diff --git a/test/geoparquet/versions.test b/test/geoparquet/versions.test new file mode 100644 index 000000000000..b724b06026b1 --- /dev/null +++ b/test/geoparquet/versions.test @@ -0,0 +1,90 @@ +# name: test/geoparquet/versions.test +# group: [geoparquet] + +require spatial + +require parquet + +# DEFAULT (V1) + +statement ok +COPY (SELECT st_point(1,2) as geometry) +TO '__TEST_DIR__/test_default.parquet' (FORMAT PARQUET); + +query I +SELECT (decode(value)) as col +FROM parquet_kv_metadata('__TEST_DIR__/test_default.parquet'); +---- +{"version":"1.0.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["Point"],"bbox":[1.0,2.0,1.0,2.0]}}} + +query I +SELECT geo_types from parquet_metadata('__TEST_DIR__/test_default.parquet'); +---- +NULL + +# V1 + +statement ok +COPY (SELECT st_point(1,2) as geometry) +TO '__TEST_DIR__/test_v1.parquet' (FORMAT PARQUET, GEOPARQUET_VERSION 'V1'); + +query I +SELECT (decode(value)) as col +FROM parquet_kv_metadata('__TEST_DIR__/test_v1.parquet'); +---- +{"version":"1.0.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["Point"],"bbox":[1.0,2.0,1.0,2.0]}}} + +query I +SELECT geo_types from parquet_metadata('__TEST_DIR__/test_v1.parquet'); +---- +NULL + +# NONE + +statement ok +COPY (SELECT st_point(1,2) as geometry) +TO '__TEST_DIR__/test_none.parquet' (FORMAT PARQUET, GEOPARQUET_VERSION 'NONE'); + +query I +SELECT (decode(value)) as col +FROM parquet_kv_metadata('__TEST_DIR__/test_none.parquet'); +---- + +query I +SELECT geo_types from parquet_metadata('__TEST_DIR__/test_none.parquet'); +---- +[point] + +# BOTH + +statement ok +COPY (SELECT st_point(1,2) as geometry) +TO '__TEST_DIR__/test_both.parquet' (FORMAT PARQUET, GEOPARQUET_VERSION 'BOTH'); + +query I +SELECT (decode(value)) as col +FROM parquet_kv_metadata('__TEST_DIR__/test_both.parquet'); +---- +{"version":"1.0.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["Point"],"bbox":[1.0,2.0,1.0,2.0]}}} + +query I +SELECT geo_types from parquet_metadata('__TEST_DIR__/test_both.parquet'); +---- +[point] + + +# V2 +statement ok +COPY (SELECT st_point(1,2) as geometry) +TO '__TEST_DIR__/test_v2.parquet' (FORMAT PARQUET, GEOPARQUET_VERSION 'V2'); + +query I +SELECT (decode(value)) as col +FROM parquet_kv_metadata('__TEST_DIR__/test_v2.parquet'); +---- +{"version":"2.0.0","primary_column":"geometry","columns":{"geometry":{"encoding":"WKB","geometry_types":["Point"],"bbox":[1.0,2.0,1.0,2.0]}}} + +query I +SELECT geo_types from parquet_metadata('__TEST_DIR__/test_v2.parquet'); +---- +[point] diff --git a/test/helpers/test_config.cpp b/test/helpers/test_config.cpp index 135930a689c0..11f4bb1b412a 100644 --- a/test/helpers/test_config.cpp +++ b/test/helpers/test_config.cpp @@ -6,6 +6,7 @@ #include "duckdb/common/types/uuid.hpp" #include #include +#include namespace duckdb { @@ -57,6 +58,17 @@ static const TestConfigOption test_config_options[] = { {"storage_version", "Database storage version to use by default", LogicalType::VARCHAR, nullptr}, {"data_location", "Directory where static test files are read (defaults to `data/`)", LogicalType::VARCHAR, nullptr}, + {"select_tag", "Select tests which match named tag (as singleton set; multiple sets are OR'd)", + LogicalType::VARCHAR, TestConfiguration::AppendSelectTagSet}, + {"select_tag_set", "Select tests which match _all_ named tags (multiple sets are OR'd)", + LogicalType::LIST(LogicalType::VARCHAR), TestConfiguration::AppendSelectTagSet}, + {"skip_tag", "Skip tests which match named tag (as singleton set; multiple sets are OR'd)", LogicalType::VARCHAR, + TestConfiguration::AppendSkipTagSet}, + {"skip_tag_set", "Skip tests which match _all_ named tags (multiple sets are OR'd)", + LogicalType::LIST(LogicalType::VARCHAR), TestConfiguration::AppendSkipTagSet}, + {"settings", "Configuration settings to apply", + LogicalType::LIST(LogicalType::STRUCT({{"name", LogicalType::VARCHAR}, {"value", LogicalType::VARCHAR}})), + nullptr}, {nullptr, nullptr, LogicalType::INVALID, nullptr}, }; @@ -407,6 +419,22 @@ string TestConfiguration::GetStorageVersion() { return GetOptionOrDefault("storage_version", string()); } +vector TestConfiguration::GetConfigSettings() { + vector result; + if (options.find("settings") != options.end()) { + auto entry = options["settings"]; + auto list_children = ListValue::GetChildren(entry); + for (const auto &value : list_children) { + auto &struct_children = StructValue::GetChildren(value); + ConfigSetting config_setting; + config_setting.name = StringValue::Get(struct_children[0]); + config_setting.value = StringValue::Get(struct_children[1]); + result.push_back(std::move(config_setting)); + } + } + return result; +} + string TestConfiguration::GetTestEnv(const string &key, const string &default_value) { if (test_env.empty() && options.find("test_env") != options.end()) { auto entry = options["test_env"]; @@ -424,6 +452,10 @@ string TestConfiguration::GetTestEnv(const string &key, const string &default_va return test_env[key]; } +const unordered_map &TestConfiguration::GetTestEnvMap() { + return test_env; +} + DebugVectorVerification TestConfiguration::GetVectorVerification() { return EnumUtil::FromString(GetOptionOrDefault("verify_vector", "NONE")); } @@ -432,6 +464,68 @@ DebugInitialize TestConfiguration::GetDebugInitialize() { return EnumUtil::FromString(GetOptionOrDefault("debug_initialize", "NO_INITIALIZE")); } +vector> TestConfiguration::GetSelectTagSets() { + return select_tag_sets; +} + +vector> TestConfiguration::GetSkipTagSets() { + return skip_tag_sets; +} + +std::unordered_set make_tag_set(const Value &src_val) { + // handle both cases -- singleton VARCHAR/string, and set of strings + auto dst_set = std::unordered_set(); + if (src_val.type() == LogicalType::VARCHAR) { + dst_set.insert(src_val.GetValue()); + } else /* LIST(VARCHAR) */ { + for (auto &tag : ListValue::GetChildren(src_val)) { + dst_set.insert(tag.GetValue()); + } + } + return dst_set; +} + +void TestConfiguration::AppendSelectTagSet(const Value &tag_set) { + TestConfiguration::Get().select_tag_sets.push_back(make_tag_set(tag_set)); +} + +void TestConfiguration::AppendSkipTagSet(const Value &tag_set) { + TestConfiguration::Get().skip_tag_sets.push_back(make_tag_set(tag_set)); +} + +bool is_subset(const unordered_set &sub, const vector &super) { + for (const auto &elt : sub) { + if (std::find(super.begin(), super.end(), elt) == super.end()) { + return false; + } + } + return true; +} + +// NOTE: this model of policy assumes simply that all selects are applied to the All set, then +// all skips are applied to that result. (Typical alternative: CLI ordering where each +// select/skip operation is applied in sequence.) +TestConfiguration::SelectPolicy TestConfiguration::GetPolicyForTagSet(const vector &subject_tag_set) { + // Apply select_tag_set first then skip_tag_set; if both empty always NONE + auto policy = TestConfiguration::SelectPolicy::NONE; + // select: if >= 1 select_tag_set is subset of subject_tag_set + // if count(select_tag_sets) > 0 && no matches, SKIP + for (const auto &select_tag_set : select_tag_sets) { + policy = TestConfiguration::SelectPolicy::SKIP; // >=1 sets => SKIP || SELECT + if (is_subset(select_tag_set, subject_tag_set)) { + policy = TestConfiguration::SelectPolicy::SELECT; + break; + } + } + // skip: if >=1 skip_tag_set is subset of subject_tag_set, else passthrough + for (const auto &skip_tag_set : skip_tag_sets) { + if (is_subset(skip_tag_set, subject_tag_set)) { + return TestConfiguration::SelectPolicy::SKIP; + } + } + return policy; +} + bool TestConfiguration::TestForceStorage() { auto &test_config = TestConfiguration::Get(); return !test_config.GetInitialDBPath().empty(); diff --git a/test/include/test_config.hpp b/test/include/test_config.hpp index bbea1b34faf3..68ed3b67a084 100644 --- a/test/include/test_config.hpp +++ b/test/include/test_config.hpp @@ -16,15 +16,28 @@ #include "duckdb/common/atomic.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/enums/debug_initialize.hpp" +#include +#include namespace duckdb { enum class SortStyle : uint8_t { NO_SORT, ROW_SORT, VALUE_SORT }; +struct ConfigSetting { + string name; + Value value; +}; + class TestConfiguration { public: enum class ExtensionAutoLoadingMode { NONE = 0, AVAILABLE = 1, ALL = 2 }; + enum class SelectPolicy : uint8_t { + NONE, // does not match any explicit policy (default: policy=SELECT) + SELECT, // matches explicit select + SKIP // matches explicit skip + }; + static TestConfiguration &Get(); void Initialize(); @@ -60,6 +73,11 @@ class TestConfiguration { vector ErrorMessagesToBeSkipped(); string GetStorageVersion(); string GetTestEnv(const string &key, const string &default_value); + const unordered_map &GetTestEnvMap(); + vector> GetSelectTagSets(); + vector> GetSkipTagSets(); + SelectPolicy GetPolicyForTagSet(const vector &tag_set); + vector GetConfigSettings(); static bool TestForceStorage(); static bool TestForceReload(); @@ -70,12 +88,17 @@ class TestConfiguration { static void ParseConnectScript(const Value &input); static void CheckSortStyle(const Value &input); static bool TryParseSortStyle(const string &sort_style, SortStyle &result); + static void AppendSelectTagSet(const Value &tag_set); + static void AppendSkipTagSet(const Value &tag_set); private: case_insensitive_map_t options; unordered_set tests_to_be_skipped; unordered_map test_env; + vector> select_tag_sets; + vector> skip_tag_sets; + private: template T GetOptionOrDefault(const string &name, T default_val); diff --git a/test/optimizer/pushdown/issue_18603.test b/test/optimizer/pushdown/issue_18603.test new file mode 100644 index 000000000000..823be9e10e27 --- /dev/null +++ b/test/optimizer/pushdown/issue_18603.test @@ -0,0 +1,46 @@ +# name: test/optimizer/pushdown/issue_18603.test +# description: Test filter pushdown with conflict comparison filters +# group: [pushdown] + +statement ok +pragma enable_verification + +statement ok +CREATE TABLE t0(c0 INT, c1 BOOLEAN); + +statement ok +CREATE TABLE t1(c0 INT); + +statement ok +INSERT INTO t0(c0, c1) VALUES (0, 0); + +statement ok +INSERT INTO t1(c0) VALUES (1); + +# test different order of filters +query III +SELECT * FROM t0 INNER JOIN t1 ON (t0.c1 > t0.c0) AND (t1.c0 > t0.c0) AND (t0.c0 < 7) AND (t0.c1 = t0.c0); +---- + +query II +EXPLAIN SELECT * FROM t0 INNER JOIN t1 ON (t0.c1 > t0.c0) AND (t1.c0 > t0.c0) AND (t0.c0 < 7) AND (t0.c1 = t0.c0); +---- +physical_plan :.*EMPTY_RESULT.* + +query II +EXPLAIN SELECT * FROM t0 INNER JOIN t1 ON (t0.c1 > t0.c0) AND (t1.c0 > t0.c0) AND (t0.c0 < 7) AND (t0.c1 = t0.c0); +---- +physical_plan :.*c0 > c0.* + +query II +EXPLAIN SELECT * FROM t0 INNER JOIN t1 ON (t0.c1 > t0.c0) AND (t1.c0 > t0.c0) AND (t0.c0 < 7) AND (t0.c1 = t0.c0); +---- +physical_plan :.*c0 < 7.* + +query III +SELECT * FROM t0 INNER JOIN t1 ON (t0.c1 = t0.c0) AND (t0.c1 > t0.c0) AND (t1.c0 > t0.c0) AND (t0.c0 < 7); +---- + +query III +SELECT * FROM t0 INNER JOIN t1 ON (t0.c1 = t0.c0) AND (t0.c0 < 7) AND (t0.c1 > t0.c0) AND (t1.c0 > t0.c0); +---- diff --git a/test/optimizer/regex_optimizer.test b/test/optimizer/regex_optimizer.test index 8d8b077df4a4..5491caa70050 100644 --- a/test/optimizer/regex_optimizer.test +++ b/test/optimizer/regex_optimizer.test @@ -228,3 +228,28 @@ statement error select count(s) from test where regexp_matches('aaa'); ---- Binder Error + +# Test regexp_matches with flags (like 'm') is properly optimized to contains +statement ok +DELETE FROM test; + +statement ok +INSERT INTO test VALUES ('hello world'), ('test'), ('hello again'); + +query II +explain analyze SELECT s FROM test WHERE regexp_matches(s, 'hello', 'm'); +---- +analyzed_plan :.*contains\(s, 'hello'\).* + +query I nosort +SELECT s FROM test WHERE regexp_matches(s, 'hello', 'm'); +---- +hello world +hello again + +# This used to trigger a debug assertion +query I +WITH fetch_schema AS ( + SELECT s FROM test WHERE regexp_matches(s, 'hello', 'm') +) SELECT * FROM fetch_schema LIMIT 0; +---- diff --git a/test/optimizer/topn_window_elimination.test b/test/optimizer/topn_window_elimination.test index ea9a3ba559a6..e050c2f41c69 100644 --- a/test/optimizer/topn_window_elimination.test +++ b/test/optimizer/topn_window_elimination.test @@ -50,6 +50,24 @@ EXPLAIN SELECT * FROM window_fun('tbl', ['grp', 'a', 'b'], '${sort_order}', ${to ---- logical_opt :.*FILTER.*WINDOW.* +# min/max and nulls +query II +EXPLAIN SELECT * FROM window_fun('tbl_with_null', ['grp'], '${sort_order}', ${topn}) +---- +logical_opt :.*FILTER.*WINDOW.* + +# arg_min/max, no struct_pack and nulls +query II +EXPLAIN SELECT * FROM window_fun('tbl_with_null', ['grp', 'a'], '${sort_order}', ${topn}) +---- +logical_opt :.*FILTER.*WINDOW.* + +# arg_min/max with struct_pack and nulls +query II +EXPLAIN SELECT * FROM window_fun('tbl_with_null', ['grp', 'a', 'b'], '${sort_order}', ${topn}) +---- +logical_opt :.*FILTER.*WINDOW.* + # test lateral join query II EXPLAIN SELECT * FROM lateral_join('${sort_order}', ${topn}) diff --git a/test/parquet/parquet_fuzzer_issues.test b/test/parquet/parquet_fuzzer_issues.test new file mode 100644 index 000000000000..0296ef8e1998 --- /dev/null +++ b/test/parquet/parquet_fuzzer_issues.test @@ -0,0 +1,17 @@ +# name: test/parquet/parquet_fuzzer_issues.test +# description: Test Parquet fuzzer issues +# group: [parquet] + +require parquet + +# internal issue 6129 +statement error +from 'data/parquet-testing/broken/internal_6129.parquet' +---- +invalid number of miniblocks per block + +# internal issue 6165 +statement error +from 'data/parquet-testing/broken/internal_6165.parquet'; +---- +row group does not have enough columns diff --git a/test/parquet/variant/variant_basic_writing.test b/test/parquet/variant/variant_basic_writing.test index d115f0823342..dc6ca289d40d 100644 --- a/test/parquet/variant/variant_basic_writing.test +++ b/test/parquet/variant/variant_basic_writing.test @@ -69,11 +69,10 @@ NULL "this is a long string" "this is big enough to not be classified as a \"short string\" by parquet VARIANT" -# VARIANT is only supported at the root for now statement error COPY (select [123::VARIANT]) TO '__TEST_DIR__/list_of_variant.parquet' ---- -Not implemented Error: Unimplemented type for Parquet "VARIANT" +Not implemented Error: ColumnWriter of type 'VARIANT' requires a transform, but is not a root column, this isn't supported currently statement ok create macro data() as table ( diff --git a/test/sql/aggregate/aggregates/arg_min_max_nulls_last.test b/test/sql/aggregate/aggregates/arg_min_max_nulls_last.test new file mode 100644 index 000000000000..264ef5a4221b --- /dev/null +++ b/test/sql/aggregate/aggregates/arg_min_max_nulls_last.test @@ -0,0 +1,64 @@ +# name: test/sql/aggregate/aggregates/arg_min_max_nulls_last.test +# description: Test arg_min_nulls_last and arg_max_nulls_last +# group: [aggregates] + +statement ok +CREATE TABLE tbl AS SELECT * FROM VALUES (1, 5, 1), (1, NULL, 2), (1, 3, NULL), (2, NULL, NULL), (3, 1, NULL) t(grp, arg, val) + +query I +SELECT arg_max_nulls_last(arg, val) FROM tbl +---- +NULL + +query I +SELECT arg_max_nulls_last(arg, val, 1) FROM tbl +---- +[NULL] + +query I +SELECT arg_max_nulls_last(val, val, 4) FROM tbl +---- +[2, 1, NULL, NULL] + +query II +SELECT grp, arg_max_nulls_last(arg, val) FROM tbl GROUP BY grp ORDER BY grp +---- +1 NULL +2 NULL +3 1 + +query II +SELECT grp, arg_max_nulls_last(arg, val, 2) FROM tbl GROUP BY grp ORDER BY grp +---- +1 [NULL, 5] +2 [NULL] +3 [1] + +query I +SELECT arg_min_nulls_last(arg, val) FROM tbl +---- +5 + +query I +SELECT arg_min_nulls_last(arg, val, 1) FROM tbl +---- +[5] + +query I +SELECT arg_min_nulls_last(val, val, 4) FROM tbl +---- +[1, 2, NULL, NULL] + +query II +SELECT grp, arg_min_nulls_last(arg, val) FROM tbl GROUP BY grp ORDER BY grp +---- +1 5 +2 NULL +3 1 + +query II +SELECT grp, arg_min_nulls_last(arg, val, 2) FROM tbl GROUP BY grp ORDER BY grp +---- +1 [5, NULL] +2 [NULL] +3 [1] diff --git a/test/sql/aggregate/aggregates/arg_min_max_nulls_last_all_types.test_slow b/test/sql/aggregate/aggregates/arg_min_max_nulls_last_all_types.test_slow new file mode 100644 index 000000000000..d6688545e942 --- /dev/null +++ b/test/sql/aggregate/aggregates/arg_min_max_nulls_last_all_types.test_slow @@ -0,0 +1,43 @@ +# name: test/sql/aggregate/aggregates/arg_min_max_nulls_last_all_types.test_slow +# description: Test the ARG_MIN_NULLS_LAST and ARG_MAX_NULLS_LAST overloads with all types +# group: [aggregates] + +statement ok +PRAGMA enable_verification + +statement ok +create table all_types as from test_all_types() + +foreach col bool tinyint smallint int bigint hugeint uhugeint utinyint usmallint uint ubigint date time timestamp timestamp_s timestamp_ms timestamp_ns time_tz timestamp_tz float double dec_4_1 dec_9_4 dec_18_6 dec38_10 uuid interval varchar blob bit small_enum medium_enum large_enum int_array double_array date_array timestamp_array timestamptz_array varchar_array nested_int_array struct struct_of_arrays array_of_structs map union fixed_int_array fixed_varchar_array fixed_nested_int_array fixed_nested_varchar_array fixed_struct_array struct_of_fixed_array fixed_array_of_int_list list_of_fixed_int_array + +statement ok +CREATE OR REPLACE TABLE asc_ordered AS SELECT "${col}" FROM all_types ORDER BY "${col}" ASC NULLS LAST + +statement ok +CREATE OR REPLACE TABLE desc_ordered AS SELECT "${col}" FROM all_types ORDER BY "${col}" DESC NULLS LAST + +statement ok +CREATE OR REPLACE TABLE arg_min_result AS SELECT unnest(arg_min_nulls_last("${col}", "${col}", 3)) FROM all_types + +statement ok +CREATE OR REPLACE TABLE arg_max_result AS SELECT unnest(arg_max_nulls_last("${col}", "${col}", 3)) FROM all_types + +query I +SELECT * FROM (SELECT * FROM asc_ordered ORDER BY rowid) EXCEPT SELECT * FROM (SELECT * FROM arg_min_result ORDER BY rowid); + +query I +SELECT * FROM (SELECT * FROM desc_ordered ORDER BY rowid) EXCEPT SELECT * FROM (SELECT * FROM arg_max_result ORDER BY rowid); + +statement ok +CREATE OR REPLACE TABLE arg_min_result AS SELECT arg_min_nulls_last("${col}", "${col}") FROM all_types + +statement ok +CREATE OR REPLACE TABLE arg_max_result AS SELECT arg_max_nulls_last("${col}", "${col}") FROM all_types + +query I +SELECT * FROM (SELECT * FROM asc_ordered ORDER BY rowid LIMIT 1) EXCEPT SELECT * FROM (SELECT * FROM arg_min_result ORDER BY rowid); + +query I +SELECT * FROM (SELECT * FROM desc_ordered ORDER BY rowid LIMIT 1) EXCEPT SELECT * FROM (SELECT * FROM arg_max_result ORDER BY rowid); + +endloop diff --git a/test/sql/binder/function_chaining_19035.test b/test/sql/binder/function_chaining_19035.test new file mode 100644 index 000000000000..e0d2242b4299 --- /dev/null +++ b/test/sql/binder/function_chaining_19035.test @@ -0,0 +1,43 @@ +# name: test/sql/binder/function_chaining_19035.test +# description: Lambda expression with macro function chaining +# group: [binder] + +statement ok +PRAGMA enable_verification + +statement ok +CREATE MACRO list_contains_macro(x, y) AS (SELECT list_contains(x, y)) + +statement error +SELECT list_filter([[1, 2, 1], [1, 2, 3], [1, 1, 1]], lambda x: list_contains_macro(x, 3)) +---- +:Binder Error.*subqueries in lambda expressions are not supported.* + +statement ok +CREATE TABLE tbl(a int[]); + +statement ok +INSERT INTO tbl VALUES ([5, 4, 3]), ([1, 2, 3]), (NULL), ([NULL, 101, 12]); + +query I +SELECT list_transform(a, lambda x, i: x + i + list_any_value(a)) FROM tbl; +---- +[11, 11, 11] +[3, 5, 7] +NULL +[NULL, 204, 116] + +query I +select ['a a ', ' b ', ' cc'].list_transform(lambda x: nullif(trim(x), '')) as trimmed_and_nulled +---- +[a a, b, cc] + +query I +select ['a a ', ' b ', ' cc'].list_transform(lambda x: x.trim().nullif('')) as trimmed_and_nulled; +---- +[a a, b, cc] + +query I +select ['a a ', ' b ', ' cd'].list_transform(lambda x: x.trim().nullif('').reverse()) as trimmed_and_nulled; +---- +[a a, b, dc] diff --git a/test/sql/binder/not_similar_to.test b/test/sql/binder/not_similar_to.test new file mode 100644 index 000000000000..0ddc9a417b31 --- /dev/null +++ b/test/sql/binder/not_similar_to.test @@ -0,0 +1,65 @@ +# name: test/sql/binder/not_similar_to.test +# description: Correctly return all columns NOT similar to pattern +# group: [binder] + +statement ok +PRAGMA enable_verification + +statement ok +create or replace table foo as select 'd' as a, 'e' as b, 'f' as c; + +# should select the column a and c +query II +select * similar to '(a|c)' from foo; +---- +d f + +# should select the column b +query I +select * not similar to '(a|c)' from foo; +---- +e + +# should select the column b +query I +SELECT * similar to 'b' FROM (select * not similar to '(a|c)' from foo); +---- +e + +statement error +SELECT * similar to 'a' FROM (select * not similar to '(a|c)' from foo); +---- +Binder Error: No matching columns found that match regex "a" + + +statement ok +CREATE TABLE t0(c0 VARCHAR); + +statement ok +INSERT INTO t0(c0) VALUES (0.1); + +query T +SELECT * FROM t0 WHERE REGEXP_MATCHES(t0.c0, '1'); +---- +0.1 + +query T +SELECT * FROM t0 WHERE NOT REGEXP_MATCHES(t0.c0, '1'); +---- + + +query I +SELECT 'aaa' NOT SIMILAR TO '[b-z]{3}'; +---- +1 + +statement ok +CREATE TABLE integers(col1 INTEGER, col2 INTEGER, k INTEGER) + +statement ok +INSERT INTO integers VALUES (1, 2, 3) + +query II +SELECT * LIKE 'col%' FROM integers +---- +1 2 \ No newline at end of file diff --git a/test/sql/copy/csv/afl/test_fuzz_3981.test_slow b/test/sql/copy/csv/afl/test_fuzz_3981.test_slow deleted file mode 100644 index 19ad0bbcd0fe..000000000000 --- a/test/sql/copy/csv/afl/test_fuzz_3981.test_slow +++ /dev/null @@ -1,40 +0,0 @@ -# name: test/sql/copy/csv/afl/test_fuzz_3981.test_slow -# description: fuzzer generated csv files - should not raise internal exception (by failed assertion). -# group: [afl] - -statement ok -PRAGMA enable_verification - -query I -select count(file) from glob('data/csv/afl/3981/*'); ----- -7 - -statement maybe -FROM read_csv('data/csv/afl/3981/case_0.csv', compression='gzip'); ----- - -statement maybe -FROM read_csv('data/csv/afl/3981/case_1.csv', compression='gzip'); ----- - -statement maybe -FROM read_csv('data/csv/afl/3981/case_2.csv', compression='gzip'); ----- - -statement maybe -FROM read_csv('data/csv/afl/3981/case_3.csv', compression='gzip'); ----- - -statement maybe -FROM read_csv('data/csv/afl/3981/case_4.csv', compression='gzip'); ----- - -statement maybe -FROM read_csv('data/csv/afl/3981/case_5.csv', compression='gzip'); ----- - -statement maybe -FROM read_csv('data/csv/afl/3981/case_6.csv', compression='gzip'); ----- - diff --git a/test/sql/copy/csv/unquoted_escape/human_eval.test b/test/sql/copy/csv/unquoted_escape/human_eval.test new file mode 100644 index 000000000000..099dd3acf86b --- /dev/null +++ b/test/sql/copy/csv/unquoted_escape/human_eval.test @@ -0,0 +1,86 @@ +# name: test/sql/copy/csv/unquoted_escape/human_eval.test +# description: Test the parsing of unquoted escape characters +# group: [unquoted_escape] + +# +# The data file is generated by the following workflow: +# +# duckdb -c "COPY (SELECT REPLACE(COLUMNS(*), ' ', E'\t') FROM read_ndjson_auto('https://raw.githubusercontent.com/openai/human-eval/refs/heads/master/data/HumanEval.jsonl.gz')) to 'HumanEval.csv'" +# +# docker run --rm -d --name tmp-gen-csv \ +# -e MYSQL_ROOT_PASSWORD=root \ +# -p 13316:3306 \ +# mysql:latest \ +# mysqld --secure-file-priv=/tmp +# +# mysql -h127.0.0.1 -uroot -proot -P13316 --local-infile <= 10; + +statement ok +CREATE TABLE human_eval_csv(task_id TEXT, prompt TEXT, entry_point TEXT, canonical_solution TEXT, test TEXT); + +statement ok +CREATE TABLE human_eval_tsv(task_id TEXT, prompt TEXT, entry_point TEXT, canonical_solution TEXT, test TEXT); + +loop buffer_size 10 25 + +statement ok +TRUNCATE human_eval_csv; + +statement ok +TRUNCATE human_eval_tsv; + +# replace the CRLF with LF to pass the test on Windows +statement ok +INSERT INTO human_eval_csv +SELECT replace(COLUMNS(*), E'\r\n', E'\n') +FROM read_csv('data/csv/unquoted_escape/human_eval.csv', quote = '', escape = '\', sep = ',', header = false, strict_mode = false); + +statement ok +INSERT INTO human_eval_tsv +SELECT replace(COLUMNS(*), E'\r\n', E'\n') +FROM read_csv('data/csv/unquoted_escape/human_eval.tsv', quote = '', escape = '\', sep = '\t', header = false, strict_mode = false); + +# Verify that the three copies are the same +query II +SELECT count(*), bool_and( + j.task_id = c.task_id AND j.task_id = t.task_id AND + j.prompt = c.prompt AND j.prompt = t.prompt AND + j.entry_point = c.entry_point AND j.entry_point = t.entry_point AND + j.canonical_solution = c.canonical_solution AND j.canonical_solution = t.canonical_solution AND + j.test = c.test AND j.test = t.test +)::int +FROM human_eval_jsonl j, human_eval_csv c, human_eval_tsv t +WHERE j.task_id = c.task_id AND j.task_id = t.task_id +---- +10 1 + +endloop \ No newline at end of file diff --git a/test/sql/copy/parquet/parquet_encrypted_tpch_httpfs.test_slow b/test/sql/copy/parquet/parquet_encrypted_tpch_httpfs.test_slow new file mode 100644 index 000000000000..46821f545657 --- /dev/null +++ b/test/sql/copy/parquet/parquet_encrypted_tpch_httpfs.test_slow @@ -0,0 +1,96 @@ +# name: test/sql/copy/parquet/parquet_encrypted_tpch_httpfs.test_slow +# description: Test Parquet encryption with OpenSSL for TPC-H +# group: [parquet] + +require parquet + +require httpfs + +require tpch + +require-env DUCKDB_DATA_DIR + +statement ok +CALL dbgen(sf=1) + +statement ok +PRAGMA add_parquet_key('key128', '0123456789112345') + +statement ok +EXPORT DATABASE '__TEST_DIR__/tpch_encrypted' (FORMAT 'parquet', ENCRYPTION_CONFIG {footer_key: 'key128'}) + +load :memory: + +# re-add key upon loading the DB again +statement ok +PRAGMA add_parquet_key('key128', '0123456789112345') + +statement ok +IMPORT DATABASE '__TEST_DIR__/tpch_encrypted' + +loop i 1 9 + +query I +PRAGMA tpch(${i}) +---- +:${DUCKDB_DATA_DIR}/extension/tpch/dbgen/answers/sf1/q0${i}.csv + +endloop + +loop i 10 23 + +query I +PRAGMA tpch(${i}) +---- +:${DUCKDB_DATA_DIR}/extension/tpch/dbgen/answers/sf1/q${i}.csv + +endloop + +# now again without importing the DB, just with views, so we can test projection/filter pushdown +load :memory: + +# re-add key upon loading the DB again +statement ok +PRAGMA add_parquet_key('key128', '0123456789112345') + +statement ok +CREATE VIEW lineitem AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/lineitem.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW orders AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/orders.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW partsupp AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/partsupp.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW part AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/part.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW customer AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/customer.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW supplier AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/supplier.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW nation AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/nation.parquet', encryption_config={footer_key: 'key128'}); + +statement ok +CREATE VIEW region AS SELECT * FROM read_parquet('__TEST_DIR__/tpch_encrypted/region.parquet', encryption_config={footer_key: 'key128'}); + +loop i 1 9 + +query I +PRAGMA tpch(${i}) +---- +:${DUCKDB_DATA_DIR}/extension/tpch/dbgen/answers/sf1/q0${i}.csv + +endloop + +loop i 10 23 + +query I +PRAGMA tpch(${i}) +---- +:${DUCKDB_DATA_DIR}/extension/tpch/dbgen/answers/sf1/q${i}.csv + +endloop \ No newline at end of file diff --git a/test/sql/copy/parquet/parquet_metadata_glob.test_slow b/test/sql/copy/parquet/parquet_metadata_glob.test_slow index 4d3847be27d5..1571bc5ff591 100644 --- a/test/sql/copy/parquet/parquet_metadata_glob.test_slow +++ b/test/sql/copy/parquet/parquet_metadata_glob.test_slow @@ -5,10 +5,17 @@ require parquet statement ok -SELECT * FROM parquet_metadata('data/parquet-testing/**.parquet'); +set variable parquet_files = ( + select list(file) + from glob('data/parquet-testing/**.parquet') + where 'broken' not in file +) + +statement ok +SELECT * FROM parquet_metadata(getvariable('parquet_files')); query II -select row_group_bytes, row_group_compressed_bytes from parquet_metadata('data/parquet-testing/**.parquet') +select row_group_bytes, row_group_compressed_bytes from parquet_metadata(getvariable('parquet_files')) where file_name = 'data/parquet-testing/varchar_stats.parquet' ---- 200 208 diff --git a/test/sql/cte/cte_bc.test b/test/sql/cte/cte_bc.test new file mode 100644 index 000000000000..2f97c238764b --- /dev/null +++ b/test/sql/cte/cte_bc.test @@ -0,0 +1,35 @@ +# name: test/sql/cte/cte_bc.test +# description: Test BC of reading CTEs +# group: [cte] + +# The database is written with a vector size of 2048. +require vector_size 2048 + +unzip data/storage/cte_v1.db.gz __TEST_DIR__/cte_v1.db + +unzip data/storage/cte_v1_4.db.gz __TEST_DIR__/cte_v1_4.db + +statement ok +ATTACH '__TEST_DIR__/cte_v1.db' (READ_ONLY) + +statement ok +ATTACH '__TEST_DIR__/cte_v1_4.db' (READ_ONLY) + +foreach cte_db cte_v1 cte_v1_4 + +query I +FROM ${cte_db}.v1 +---- +42 + +query I +FROM ${cte_db}.v2 +---- +42 + +query I +FROM ${cte_db}.v3 +---- +42 + +endloop diff --git a/test/sql/cte/cte_schema.test b/test/sql/cte/cte_schema.test new file mode 100644 index 000000000000..cc7a18093030 --- /dev/null +++ b/test/sql/cte/cte_schema.test @@ -0,0 +1,18 @@ +# name: test/sql/cte/cte_schema.test +# description: Test conflict between CTE and table in different schema +# group: [cte] + +statement ok +create schema s1; + +statement ok +create table s1.tbl(a varchar); + +statement ok +insert into s1.tbl values ('hello'); + +query II +with tbl as (select 'world' b) +select * from s1.tbl, tbl; +---- +hello world diff --git a/test/sql/cte/lazy_cte_bind.test b/test/sql/cte/lazy_cte_bind.test new file mode 100644 index 000000000000..abb87262aa00 --- /dev/null +++ b/test/sql/cte/lazy_cte_bind.test @@ -0,0 +1,12 @@ +# name: test/sql/cte/lazy_cte_bind.test +# description: Test that CTE binding is lazy +# group: [cte] + +statement ok +PRAGMA enable_verification + +query I +with cte as (select * from read_parquet('does/not/exist/file.parquet')) +select 42 +---- +42 diff --git a/test/sql/cte/recursive_cte_key_variant.test b/test/sql/cte/recursive_cte_key_variant.test index 6bae6a49ea1d..717ed505ccba 100644 --- a/test/sql/cte/recursive_cte_key_variant.test +++ b/test/sql/cte/recursive_cte_key_variant.test @@ -112,45 +112,45 @@ SELECT * FROM recurring.tbl; ---- Catalog Error: Table with name tbl does not exist! +# no using key statement error WITH RECURSIVE tbl2(a) AS (SELECT 1 UNION SELECT a.a + 1 FROM tbl2 AS a, recurring.tbl2 AS b WHERE a.a < 2) SELECT * FROM tbl2; ---- -:.*Binder Error.*cannot be referenced.* +does not exist statement error WITH RECURSIVE tbl2(a,b) AS (SELECT 1, NULL UNION SELECT a.a+1, a.b FROM tbl2 AS a, (SELECT * FROM recurring.tbl2) AS b WHERE a.a < 2) SELECT * FROM tbl2; ---- -:.*Binder Error.*cannot be referenced.* +does not exist # second cte references recurring table of first cte while first cte does not statement error WITH RECURSIVE tbl(a, b) USING KEY (a) AS (SELECT 5, 1 UNION SELECT a, b + 1 FROM tbl WHERE b < a), tbl1(a,b) AS (SELECT * FROM recurring.tbl) SELECT * FROM tbl1; ---- -:.*Binder Error.*cannot be referenced.* +does not exist # second cte references recurring table of first like the first cte statement error WITH RECURSIVE tbl(a, b) USING KEY (a) AS (SELECT * FROM ((VALUES (5, 1), (6,1)) UNION SELECT a, b + 1 FROM recurring.tbl WHERE b < a) WHERE a = 5), tbl1(a,b) AS (SELECT * FROM recurring.tbl) SELECT * FROM tbl1; ---- -:.*Catalog Error.*does not exist.* +does not exist statement error WITH RECURSIVE tbl2(a) AS (SELECT 1 UNION SELECT a.a+1 FROM tbl2 AS a, (SELECT * FROM recurring.tbl2) AS b WHERE a.a < 2) SELECT * FROM tbl2; ---- -:.*Binder Error.*cannot be referenced.* +does not exist statement error WITH RECURSIVE tbl(a, b) USING KEY (a) AS (SELECT 5, 1 UNION SELECT a, b + 1 FROM tbl WHERE b < a),tbl1(a,b) AS (SELECT * FROM recurring.tbl) SELECT * FROM tbl1; ---- -:.*Binder Error.*cannot be referenced.* +does not exist statement error WITH RECURSIVE tbl(a, b) USING KEY (a) AS MATERIALIZED (SELECT 5, 1 UNION SELECT a, b + 1 FROM tbl WHERE b < a),tbl1(a,b) AS (SELECT * FROM recurring.tbl) SELECT * FROM tbl1; ---- -:.*Binder Error.*cannot be referenced.* - +does not exist ####################### # Connected components diff --git a/test/sql/cte/test_cte.test b/test/sql/cte/test_cte.test index e870e5f80da3..044df3572b11 100644 --- a/test/sql/cte/test_cte.test +++ b/test/sql/cte/test_cte.test @@ -97,6 +97,14 @@ SELECT 1 UNION ALL (WITH cte AS (SELECT 42) SELECT * FROM cte); 1 42 +# cte in nested set operation node +query I +SELECT 1 UNION ALL (WITH cte AS (SELECT 42) SELECT * FROM cte UNION ALL SELECT * FROM cte); +---- +1 +42 +42 + # cte in recursive cte query I WITH RECURSIVE cte(d) AS ( diff --git a/test/sql/filter/test_variant_filter.test b/test/sql/filter/test_variant_filter.test new file mode 100644 index 000000000000..999a6311dd95 --- /dev/null +++ b/test/sql/filter/test_variant_filter.test @@ -0,0 +1,89 @@ +# name: test/sql/filter/test_variant_filter.test +# description: test comparison logic for VARIANT columns +# group: [filter] + +query I +WITH cte as ( + select '12'::VARIANT a +) +select IF(a == 12, 1, 0) from cte +---- +0 + +query I +WITH cte as ( + select '12'::VARIANT a +) +select IF(a == [1,2,3], 1, 0) from cte +---- +0 + +query I +WITH cte as ( + select '[1,2,3]'::VARIANT a +) +select IF(a == [1,2,3], 1, 0) from cte +---- +0 + +# Different lists, not equal +query I +WITH cte as ( + select [1,2,3]::VARIANT a +) +select IF(a == [3, 2, 1], 1, 0) from cte +---- +0 + +# Exact equality of list +query I +WITH cte as ( + select [1,2,3]::VARIANT a +) +select IF(a == [1,2,3], 1, 0) from cte +---- +1 + +query I +WITH cte as ( + select [1::VARIANT,'2',3]::VARIANT a +) +select IF(a == [1,2,3], 1, 0) from cte +---- +0 + +# Not even remotely equal +query I +WITH cte as ( + select {a: 21, b: [1,2,3]}::VARIANT a +) +select IF(a == [1,2,3], 1, 0) from cte +---- +0 + +# Compare equal because objects are compared by key, not position +query I +WITH cte as ( + select {a: 21, b: [1,2,3]}::VARIANT a +) +select IF(a == {b: [1,2,3], a: 21}, 1, 0) from cte +---- +1 + +# GreaterThan / LessThan + +query I +WITH cte as ( + select {a: 21, b: [1,2,3]}::VARIANT a +) +select IF(a < [1,2,3], 1, 0) from cte +---- +0 + +query I +WITH cte as ( + select {a: 21, b: [1,2,3]}::VARIANT a +) +select IF(a > [1,2,3], 1, 0) from cte +---- +1 diff --git a/test/sql/function/variant/variant_extract.test b/test/sql/function/variant/variant_extract.test index 96c69d1c69c6..98fc35bbe66e 100644 --- a/test/sql/function/variant/variant_extract.test +++ b/test/sql/function/variant/variant_extract.test @@ -1,6 +1,9 @@ # name: test/sql/function/variant/variant_extract.test # group: [variant] +statement ok +pragma enable_verification; + require json query I diff --git a/test/sql/function/variant/variant_typeof.test b/test/sql/function/variant/variant_typeof.test index 799026c893d2..f2e9ba67947c 100644 --- a/test/sql/function/variant/variant_typeof.test +++ b/test/sql/function/variant/variant_typeof.test @@ -1,6 +1,9 @@ # name: test/sql/function/variant/variant_typeof.test # group: [variant] +statement ok +pragma enable_verification; + query I select variant_typeof({'a': 42}::VARIANT); ---- diff --git a/test/sql/index/art/issues/test_art_fuzzer_persisted.test b/test/sql/index/art/issues/test_art_fuzzer_persisted.test index 3d4085eba592..1ec3380c0e08 100644 --- a/test/sql/index/art/issues/test_art_fuzzer_persisted.test +++ b/test/sql/index/art/issues/test_art_fuzzer_persisted.test @@ -14,7 +14,7 @@ statement ok CREATE INDEX i1 ON t1 (c1); statement ok -PRAGMA MEMORY_LIMIT='2MB'; +PRAGMA MEMORY_LIMIT='4MB'; statement ok CHECKPOINT; diff --git a/test/sql/json/issues/issue19357.test b/test/sql/json/issues/issue19357.test new file mode 100644 index 000000000000..295498ddc0ed --- /dev/null +++ b/test/sql/json/issues/issue19357.test @@ -0,0 +1,20 @@ +# name: test/sql/json/issues/issue19357.test +# description: Test issue 19357 - Expected unified vector format of type VARCHAR, but found type INT32 +# group: [issues] + +require json + +query I +SELECT TO_JSON({'key_1': 'one'}) AS WITHOUT_KEEP_NULL +---- +{"key_1":"one"} + +query I +SELECT JSON_OBJECT('key_1', 'one', 'key_2', NULL) AS KEEP_NULL_1 +---- +{"key_1":"one","key_2":null} + +statement error +SELECT JSON_OBJECT('key_1', 'one', NULL, 'two') AS KEEP_NULL_2 +---- +json_object() keys must be VARCHAR diff --git a/test/sql/json/scalar/test_json_create.test b/test/sql/json/scalar/test_json_create.test index 0816bdb00017..ab0c969528c4 100644 --- a/test/sql/json/scalar/test_json_create.test +++ b/test/sql/json/scalar/test_json_create.test @@ -27,7 +27,7 @@ select to_json({n: 42}) statement error select to_json({n: 42}, {extra: 'argument'}) ---- -Invalid Input Error: to_json() takes exactly one argument +to_json() takes exactly one argument query T select to_json(union_value(n := 42)) @@ -141,7 +141,7 @@ select json_array(a, b, c, d, e) from test [-777,4.2,"goose",[4,2],null] query T -select json_object(a, a, b, b, c, c, d, d, e, e) from test +select json_object(a::varchar, a, b::varchar, b, c, c, d::varchar, d, e::varchar, e) from test ---- {"0":0,"0.5":0.5,"short":"short","[0, 1, 2, 3, 4, 5, 6, 7, 9]":[0,1,2,3,4,5,6,7,9],"33":33} {"42":42,"1.0":1.0,"looooooooooooooong":"looooooooooooooong","[]":[],"42":42} diff --git a/test/sql/logging/test_logging_function.test b/test/sql/logging/test_logging_function.test index 6fb6a6f34295..c8059d76f668 100644 --- a/test/sql/logging/test_logging_function.test +++ b/test/sql/logging/test_logging_function.test @@ -8,6 +8,11 @@ query IIIIIIIIII from duckdb_logs ---- +statement error +PRAGMA enable_logging; +---- +Pragma Function with name enable_logging does not exist, but a table function with the same name exists, try + statement ok CALL enable_logging(); diff --git a/test/sql/peg_parser/transformer/peg_transformer.test b/test/sql/peg_parser/transformer/peg_transformer.test new file mode 100644 index 000000000000..14c068fd5b98 --- /dev/null +++ b/test/sql/peg_parser/transformer/peg_transformer.test @@ -0,0 +1,17 @@ +# name: test/sql/peg_parser/transformer/peg_transformer.test +# description: Test analyze and vacuum statements in peg parser +# group: [transformer] + +require autocomplete + +require skip_reload + +require no_extension_autoloading "FIXME: to be reviewed whether this can be lifted" + +statement ok +set allow_parser_override_extension=strict; + +statement error +select 1; +---- +Not implemented Error: Parser override has not yet implemented this transformer rule. (Original error: No transformer function found for rule 'SelectStatement') diff --git a/test/sql/peg_parser/transformer/use_statement.test b/test/sql/peg_parser/transformer/use_statement.test new file mode 100644 index 000000000000..f99115edda3e --- /dev/null +++ b/test/sql/peg_parser/transformer/use_statement.test @@ -0,0 +1,24 @@ +# name: test/sql/peg_parser/transformer/use_statement.test +# description: Test use statement with new transformer +# group: [transformer] + +require autocomplete + +require skip_reload + +require no_extension_autoloading "FIXME: to be reviewed whether this can be lifted" + +statement ok +set allow_parser_override_extension=fallback; + +statement ok +ATTACH ':memory:' as "my""db"; + +statement ok +CREATE TABLE "my""db".tbl(i int); + +statement ok +INSERT INTO "my""db".tbl VALUES (42) + +statement ok +USE "my""db"; diff --git a/test/sql/pragma/profiling/test_attach_and_checkpoint_latency.test b/test/sql/pragma/profiling/test_attach_and_checkpoint_latency.test new file mode 100644 index 000000000000..97257c196669 --- /dev/null +++ b/test/sql/pragma/profiling/test_attach_and_checkpoint_latency.test @@ -0,0 +1,102 @@ +# name: test/sql/pragma/profiling/test_attach_and_checkpoint_latency.test +# group: [profiling] + +require json + +require noforcestorage + +require skip_reload + +# Setup. + +statement ok +SET threads = 1; + +statement ok +SET wal_autocheckpoint = '1TB'; + +statement ok +PRAGMA disable_checkpoint_on_shutdown; + +statement ok +PRAGMA profiling_output = '__TEST_DIR__/profile_fs.json'; + +statement ok +PRAGMA custom_profiling_settings='{"WAITING_TO_ATTACH_LATENCY": "true", "ATTACH_LOAD_STORAGE_LATENCY": "true", "ATTACH_REPLAY_WAL_LATENCY": "true", "CHECKPOINT_LATENCY": "true"}'; + +statement ok +SET profiling_coverage='ALL'; + +# Finished setup. + +# CHECKPOINT_LATENCY. + +statement ok +ATTACH '__TEST_DIR__/profile_fs.db'; + +statement ok +CREATE TABLE profile_fs.tbl AS SELECT range AS id FROM range(100_000); + +statement ok +PRAGMA enable_profiling = 'json'; + +statement ok +CHECKPOINT profile_fs; + +statement ok +PRAGMA disable_profiling; + +statement ok +CREATE OR REPLACE TABLE metrics_output AS SELECT * FROM '__TEST_DIR__/profile_fs.json'; + +query I +SELECT + CASE WHEN checkpoint_latency > 0 THEN 'true' + ELSE 'false' END +FROM metrics_output; +---- +true + +# WAITING_TO_ATTACH_LATENCY, ATTACH_LOAD_STORAGE_LATENCY and ATTACH_REPLAY_WAL_LATENCY. + +statement ok +CREATE TABLE profile_fs.other_tbl AS SELECT range AS id FROM range(100_000); + +statement ok +DETACH profile_fs; + +statement ok +PRAGMA enable_profiling = 'json'; + +statement ok +ATTACH '__TEST_DIR__/profile_fs.db'; + +statement ok +PRAGMA disable_profiling; + +statement ok +CREATE OR REPLACE TABLE metrics_output AS SELECT * FROM '__TEST_DIR__/profile_fs.json'; + +query I +SELECT + CASE WHEN waiting_to_attach_latency > 0 THEN 'true' + ELSE 'false' END +FROM metrics_output; +---- +true + +query I +SELECT + CASE WHEN attach_load_storage_latency > 0 THEN 'true' + ELSE 'false' END +FROM metrics_output; +---- +true + +query I +SELECT + CASE WHEN attach_replay_wal_latency > 0 THEN 'true' + ELSE 'false' END +FROM metrics_output; +---- +true diff --git a/test/sql/transactions/delete_and_drop_in_same_transaction.test b/test/sql/transactions/delete_and_drop_in_same_transaction.test new file mode 100644 index 000000000000..2090192edcf7 --- /dev/null +++ b/test/sql/transactions/delete_and_drop_in_same_transaction.test @@ -0,0 +1,23 @@ +# name: test/sql/transactions/delete_and_drop_in_same_transaction.test +# group: [transactions] + +statement ok +CREATE OR REPLACE TABLE SampleTable AS +SELECT DISTINCT id +FROM (VALUES + ('one'), + ('two'), + ('three') +) AS t(id); + +statement ok +BEGIN TRANSACTION; + +statement ok +DELETE FROM sampletable; + +statement ok +DROP TABLE SampleTable; + +statement ok +COMMIT TRANSACTION; diff --git a/test/sql/types/variant/variant_distinct.test b/test/sql/types/variant/variant_distinct.test new file mode 100644 index 000000000000..19964ce3dfcb --- /dev/null +++ b/test/sql/types/variant/variant_distinct.test @@ -0,0 +1,800 @@ +# name: test/sql/types/variant/variant_distinct.test +# description: Test VARIANT distinctions +# group: [variant] + +statement ok +PRAGMA enable_verification + +# Constant single integer column distinctions +query T +SELECT [1]::VARIANT IS NOT DISTINCT FROM [2] +---- +false + +query T +SELECT [1]::VARIANT IS NOT DISTINCT FROM [1] +---- +true + +query T +SELECT NULL IS NOT DISTINCT FROM [1]::VARIANT +---- +false + +query T +SELECT [1] IS NOT DISTINCT FROM NULL::VARIANT +---- +false + +query T +SELECT [1]::VARIANT IS DISTINCT FROM [2] +---- +true + +query T +SELECT [1]::VARIANT IS DISTINCT FROM [1] +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM [1] +---- +true + +query T +SELECT [1]::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE VIEW list_int1 AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ([1], [1]), + ([1], [2]), + ([2], [1]), + (NULL, [1]), + ([2], NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM list_int1 +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM list_int1 +---- +false +true +true +true +true +false + +# Constant multiple integer column distinctions + +query T +SELECT [1]::VARIANT IS NOT DISTINCT FROM [1, 2] +---- +false + +query T +SELECT [1]::VARIANT IS NOT DISTINCT FROM [1] +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM [1] +---- +false + +query T +SELECT [1] IS NOT DISTINCT FROM NULL::VARIANT +---- +false + +query T +SELECT [1]::VARIANT IS DISTINCT FROM [1, 2] +---- +true + +query T +SELECT [1]::VARIANT IS DISTINCT FROM [1] +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM [1] +---- +true + +query T +SELECT [1] IS DISTINCT FROM NULL::VARIANT +---- +true + +statement ok +CREATE VIEW list_int AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ([1], [1]), + ([1], [1, 2]), + ([1, 2], [1]), + (NULL, [1]), + ([1, 2], NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM list_int +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM list_int +---- +false +true +true +true +true +false + +# Constant empty integer column distinctions + +query T +SELECT []::VARIANT IS NOT DISTINCT FROM [1, 2] +---- +false + +query T +SELECT []::VARIANT IS NOT DISTINCT FROM [] +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM [] +---- +false + +query T +SELECT []::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT []::VARIANT IS DISTINCT FROM [1, 2] +---- +true + +query T +SELECT []::VARIANT IS DISTINCT FROM [] +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM [] +---- +true + +query T +SELECT []::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE VIEW list_int_empty AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ([], []), + ([], [1, 2]), + ([1, 2], []), + (NULL, []), + ([1, 2], NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM list_int_empty +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM list_int_empty +---- +false +true +true +true +true +false + +# List of strings +query T +SELECT ['duck']::VARIANT IS NOT DISTINCT FROM ['duck', 'goose'] +---- +false + +query T +SELECT ['duck']::VARIANT IS NOT DISTINCT FROM ['duck'] +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM ['duck'] +---- +false + +query T +SELECT ['duck']::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT ['duck']::VARIANT IS DISTINCT FROM ['duck', 'goose'] +---- +true + +query T +SELECT ['duck']::VARIANT IS DISTINCT FROM ['duck'] +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM ['duck'] +---- +true + +query T +SELECT ['duck']::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE VIEW list_str AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + (['duck'], ['duck']), + (['duck'], ['duck', 'goose']), + (['duck', 'goose'], ['duck']), + (NULL, ['duck']), + (['duck', 'goose'], NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM list_str +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM list_str +---- +false +true +true +true +true +false + +# List of structs + +query T +SELECT [{'x': 'duck', 'y': 1}]::VARIANT IS NOT DISTINCT FROM [{'x': 'duck', 'y': 1}, {'x': 'goose', 'y': 2}] +---- +false + +query T +SELECT [{'x': 'duck', 'y': 1}]::VARIANT IS NOT DISTINCT FROM [{'x': 'duck', 'y': 1}] +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM [{'x': 'duck', 'y': 1}] +---- +false + +query T +SELECT [{'x': 'duck', 'y': 1}]::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT [{'x': 'duck', 'y': 1}]::VARIANT IS DISTINCT FROM [{'x': 'duck', 'y': 1}, {'x': 'goose', 'y': 2}] +---- +true + +query T +SELECT [{'x': 'duck', 'y': 1}]::VARIANT IS DISTINCT FROM [{'x': 'duck', 'y': 1}] +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM [{'x': 'duck', 'y': 1}] +---- +true + +query T +SELECT [{'x': 'duck', 'y': 1}]::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE VIEW list_of_struct AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ([{'x': 'duck', 'y': 1}], [{'x': 'duck', 'y': 1}]), + ([{'x': 'duck', 'y': 1}], [{'x': 'duck', 'y': 1}, {'x': 'goose', 'y': 2}]), + ([{'x': 'duck', 'y': 1}, {'x': 'goose', 'y': 2}], [{'x': 'duck', 'y': 1}]), + (NULL, [{'x': 'duck', 'y': 1}]), + ([{'x': 'duck', 'y': 1}, {'x': 'goose', 'y': 2}], NULL), + (NULL, NULL) + ) tbl(l, r); + + +query T +SELECT l IS NOT DISTINCT FROM r FROM list_of_struct +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM list_of_struct +---- +false +true +true +true +true +false + +# Filter by constant +query T +select CASE WHEN a::INT32::VARIANT < 4 THEN [a,a+1,a+2] ELSE NULL END IS NOT DISTINCT FROM [1::INT32,2,3] from range(5) tbl(a); +---- +false +true +false +false +false + +query T +select CASE WHEN a::INT32::VARIANT < 4 THEN [a,a+1,a+2] ELSE NULL END IS DISTINCT FROM [1::INT32,2,3] from range(5) tbl(a); +---- +true +false +true +true +true + +foreach type + +# Constant single integer column distinct +query T +SELECT {'x': 1::${type}}::VARIANT IS NOT DISTINCT FROM {'x': 2::${type}} +---- +false + +query T +SELECT {'x': 1::${type}}::VARIANT IS NOT DISTINCT FROM {'x': 1::${type}} +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM {'x': 1::${type}} +---- +false + +query T +SELECT {'x': 1::${type}}::VARIANT IS DISTINCT FROM {'x': 2::${type}} +---- +true + +query T +SELECT {'x': 1::${type}}::VARIANT IS DISTINCT FROM {'x': 1::${type}} +---- +false + +query T +SELECT {'x': 1::${type}}::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE OR REPLACE VIEW struct_int AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ({'x': 1::${type}}, {'x': 1::${type}}), + ({'x': 1::${type}}, {'x': 2::${type}}), + ({'x': 2::${type}}, {'x': 1::${type}}), + (NULL, {'x': 1::${type}}), + ({'x': 2::${type}}, NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM struct_int +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM struct_int +---- +false +true +true +true +true +false + +endloop + +# Constant single string column distinct +query T +SELECT {'x': 'duck'}::VARIANT IS NOT DISTINCT FROM {'x': 'goose'} +---- +false + +query T +SELECT {'x': 'duck'}::VARIANT IS NOT DISTINCT FROM {'x': 'duck'} +---- +true + + +query T +SELECT {'x': 'duck'}::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM {'x': 'duck'} +---- +false + +query T +SELECT {'x': 'duck'}::VARIANT IS DISTINCT FROM {'x': 'goose'} +---- +true + +query T +SELECT {'x': 'duck'}::VARIANT IS DISTINCT FROM {'x': 'duck'} +---- +false + +query T +SELECT {'x': 'duck'}::VARIANT IS DISTINCT FROM NULL +---- +true + +query T +SELECT NULL::VARIANT IS DISTINCT FROM {'x': 'duck'} +---- +true + +statement ok +CREATE VIEW struct_str AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ({'x': 'duck'}, {'x': 'duck'}), + ({'x': 'duck'}, {'x': 'goose'}), + ({'x': 'goose'}, {'x': 'duck'}), + (NULL, {'x': 'duck'}), + ({'x': 'goose'}, NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM struct_str +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM struct_str +---- +false +true +true +true +true +false + +# Constant string, integer column distinct +query T +SELECT {'x': 'duck', 'y': 1}::VARIANT IS NOT DISTINCT FROM {'x': 'goose', 'y': 2} +---- +false + +query T +SELECT {'x': 'duck', 'y': 1}::VARIANT IS NOT DISTINCT FROM {'x': 'duck', 'y': 1} +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM {'x': 'duck', 'y': 1} +---- +false + +query T +SELECT {'x': 'duck', 'y': 1}::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT {'x': 'duck', 'y': 1}::VARIANT IS DISTINCT FROM {'x': 'goose', 'y': 2} +---- +true + +query T +SELECT {'x': 'duck', 'y': 1}::VARIANT IS DISTINCT FROM {'x': 'duck', 'y': 1} +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM {'x': 'duck', 'y': 1} +---- +true + +query T +SELECT {'x': 'duck', 'y': 1}::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE VIEW struct_str_int AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ({'x': 'duck', 'y': 1}, {'x': 'duck', 'y': 1}), + ({'x': 'duck', 'y': 1}, {'x': 'goose', 'y': 2}), + ({'x': 'goose', 'y': 2}, {'x': 'duck', 'y': 1}), + (NULL, {'x': 'duck', 'y': 1}), + ({'x': 'goose', 'y': 2}, NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM struct_str_int +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM struct_str_int +---- +false +true +true +true +true +false + +# Nested structs + +query T +SELECT {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}::VARIANT IS NOT DISTINCT FROM {'x': 2, 'y': {'a': 'goose', 'b': 2.5}} +---- +false + +query T +SELECT {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}::VARIANT IS NOT DISTINCT FROM {'x': 1, 'y': {'a': 'duck', 'b': 1.5}} +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM {'x': 1, 'y': {'a': 'duck', 'b': 1.5}} +---- +false + +query T +SELECT {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}::VARIANT IS DISTINCT FROM {'x': 2, 'y': {'a': 'goose', 'b': 2.5}} +---- +true + +query T +SELECT {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}::VARIANT IS DISTINCT FROM {'x': 1, 'y': {'a': 'duck', 'b': 1.5}} +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM {'x': 1, 'y': {'a': 'duck', 'b': 1.5}} +---- +true + +query T +SELECT {'x': 1, 'y': {'a': 'duck', 'b': 1.5}} IS DISTINCT FROM NULL::VARIANT +---- +true + +statement ok +CREATE VIEW struct_nested AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ({'x': 1, 'y': {'a': 'duck', 'b': 1.5}}, {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}), + ({'x': 1, 'y': {'a': 'duck', 'b': 1.5}}, {'x': 2, 'y': {'a': 'goose', 'b': 2.5}}), + ({'x': 2, 'y': {'a': 'goose', 'b': 2.5}}, {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}), + (NULL, {'x': 1, 'y': {'a': 'duck', 'b': 1.5}}), + ({'x': 2, 'y': {'a': 'goose', 'b': 2.5}}, NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM struct_nested +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM struct_nested +---- +false +true +true +true +true +false + +# List nested inside struct +query T +SELECT {'x': 1, 'y': ['duck', 'somateria']}::VARIANT IS NOT DISTINCT FROM {'x': 2, 'y': ['goose']} +---- +false + +query T +SELECT {'x': 1, 'y': ['duck', 'somateria']}::VARIANT IS NOT DISTINCT FROM {'x': 1, 'y': ['duck', 'somateria']} +---- +true + +query T +SELECT NULL::VARIANT IS NOT DISTINCT FROM {'x': 1, 'y': ['duck', 'somateria']} +---- +false + +query T +SELECT {'x': 1, 'y': ['duck', 'somateria']}::VARIANT IS NOT DISTINCT FROM NULL +---- +false + +query T +SELECT {'x': 1, 'y': ['duck', 'somateria']}::VARIANT IS DISTINCT FROM {'x': 2, 'y': ['goose']} +---- +true + +query T +SELECT {'x': 1, 'y': ['duck', 'somateria']}::VARIANT IS DISTINCT FROM {'x': 1, 'y': ['duck', 'somateria']} +---- +false + +query T +SELECT NULL::VARIANT IS DISTINCT FROM {'x': 1, 'y': ['duck', 'somateria']} +---- +true + +query T +SELECT {'x': 1, 'y': ['duck', 'somateria']}::VARIANT IS DISTINCT FROM NULL +---- +true + +statement ok +CREATE VIEW list_in_struct AS SELECT COLUMNS(*)::VARIANT FROM (VALUES + ({'x': 1, 'y': ['duck', 'somateria']}, {'x': 1, 'y': ['duck', 'somateria']}), + ({'x': 1, 'y': ['duck', 'somateria']}, {'x': 2, 'y': ['goose']}), + ({'x': 2, 'y': ['goose']}, {'x': 1, 'y': ['duck', 'somateria']}), + (NULL, {'x': 1, 'y': ['duck', 'somateria']}), + ({'x': 2, 'y': ['goose']}, NULL), + (NULL, NULL) + ) tbl(l, r); + +query T +SELECT l IS NOT DISTINCT FROM r FROM list_in_struct +---- +true +false +false +false +false +true + +query T +SELECT l IS DISTINCT FROM r FROM list_in_struct +---- +false +true +true +true +true +false + +# Filter by constant +query T +WITH cte as ( + select a::INT32 a from range(5) tbl(a) +) +select + CASE + WHEN a::VARIANT < 4 + THEN { + 'x': a, + 'y': a+1, + 'z': a+2 + }::VARIANT + ELSE NULL + END IS NOT DISTINCT FROM { + 'x': 1::INT32, + 'y': 2::INT32, + 'z': 3::INT32 + } +from cte; +---- +false +true +false +false +false + +query T +WITH cte as ( + select a::INT32 a from range(5) tbl(a) +) +select CASE WHEN a < 4 THEN {'x': a, 'z': a+2, 'y': a+1}::VARIANT ELSE NULL END IS DISTINCT FROM {'x': 1, 'y': 2, 'z': 3} +from cte; +---- +true +false +true +true +true + +# Without the collation that pushes 'variant_normalize', +# the 'variant_extract' (performed by the ['nested'] call) will break the comparison +query T +WITH cte as ( + select a::INT32 a from range(5) tbl(a) +) +select CASE WHEN a < 4 THEN ({ + 'nested': { + 'x': a, + 'z': a+2, + 'y': a+1 + } +}::VARIANT)['nested'] ELSE NULL END IS DISTINCT FROM { + 'x': 1, + 'y': 2, + 'z': 3 +} +from cte; +---- +true +false +true +true +true diff --git a/test/sqlite/README.md b/test/sqlite/README.md new file mode 100644 index 000000000000..d1a5cd814a85 --- /dev/null +++ b/test/sqlite/README.md @@ -0,0 +1,42 @@ +# SQLLogic Test Runner + +## Origin + +Here you'll find source code originating from +[SQLite's SQLLogicTest](https://sqlite.org/sqllogictest/doc/trunk/about.wiki). +DuckDB has extended functionality in several ways, including several new expressions +(test_env, set/reset, tags). + +## Usage Notes + +### Environment: test_env and require-env + +Environment variables can be managed in 2 ways: `test_env` which allows variables to have defaults set, and `require-env` which is a select/skip predicate for a test file. + +For examples of `test_env` usage see the `duckdb/ducklake` extension tests. + +When a file `require-env FOO`, or `require-env FOO=bar` a test will only execute if FOO is set, or in the latter case, set to `bar`. + +### Tags: explicit and implicit + +SQL test files also support a `tags` attribute of the form: + +```text +tags optimization memory>=64GB +``` + +The tags are free-form, and can be used when executing tests for both selection and skipping, a la: + +```bash +build/release/test/unittest --skip-tag 'slow' --select-tag-set "['memory>=64GB', 'env[TEST_DATA]']" +``` + +Tags can be specified individually, or as a set (which is treated as an `AND` predicate). +Each specification is an `OR`, and selects are processed before skips. + +Additionally some implicit tags are computed when an SQL test file is parsed. +All `require-env` and `test_env` expressions will be added as tags of the form `env[VAR]`, and +`env[VAR]=VALUE` (when specified). + +For an extensive example of tag matching expectations, see the file +`test/sqlite/validate_tags_usage.sh` which unit tests these behaviors. diff --git a/test/sqlite/sqllogic_parser.cpp b/test/sqlite/sqllogic_parser.cpp index d2a72db682a4..df59aedb3770 100644 --- a/test/sqlite/sqllogic_parser.cpp +++ b/test/sqlite/sqllogic_parser.cpp @@ -169,6 +169,7 @@ bool SQLLogicParser::IsSingleLineStatement(SQLLogicToken &token) { case SQLLogicTokenType::SQLLOGIC_RECONNECT: case SQLLogicTokenType::SQLLOGIC_SLEEP: case SQLLogicTokenType::SQLLOGIC_UNZIP: + case SQLLogicTokenType::SQLLOGIC_TAGS: return true; case SQLLogicTokenType::SQLLOGIC_SKIP_IF: @@ -183,6 +184,42 @@ bool SQLLogicParser::IsSingleLineStatement(SQLLogicToken &token) { } } +// (All) Context statements must precede all non-header statements +bool SQLLogicParser::IsTestCommand(SQLLogicTokenType &type) { + switch (type) { + case SQLLogicTokenType::SQLLOGIC_QUERY: + case SQLLogicTokenType::SQLLOGIC_STATEMENT: + return true; + + case SQLLogicTokenType::SQLLOGIC_CONCURRENT_FOREACH: + case SQLLogicTokenType::SQLLOGIC_CONCURRENT_LOOP: + case SQLLogicTokenType::SQLLOGIC_ENDLOOP: + case SQLLogicTokenType::SQLLOGIC_FOREACH: + case SQLLogicTokenType::SQLLOGIC_HALT: + case SQLLogicTokenType::SQLLOGIC_HASH_THRESHOLD: + case SQLLogicTokenType::SQLLOGIC_INVALID: + case SQLLogicTokenType::SQLLOGIC_LOAD: + case SQLLogicTokenType::SQLLOGIC_LOOP: + case SQLLogicTokenType::SQLLOGIC_MODE: + case SQLLogicTokenType::SQLLOGIC_ONLY_IF: + case SQLLogicTokenType::SQLLOGIC_RECONNECT: + case SQLLogicTokenType::SQLLOGIC_REQUIRE: + case SQLLogicTokenType::SQLLOGIC_REQUIRE_ENV: + case SQLLogicTokenType::SQLLOGIC_RESET: + case SQLLogicTokenType::SQLLOGIC_RESTART: + case SQLLogicTokenType::SQLLOGIC_SET: + case SQLLogicTokenType::SQLLOGIC_SKIP_IF: + case SQLLogicTokenType::SQLLOGIC_SLEEP: + case SQLLogicTokenType::SQLLOGIC_TAGS: + case SQLLogicTokenType::SQLLOGIC_TEST_ENV: + case SQLLogicTokenType::SQLLOGIC_UNZIP: + return false; + + default: + throw std::runtime_error("Unknown SQLLogic token found!"); + } +} + SQLLogicTokenType SQLLogicParser::CommandToToken(const string &token) { if (token == "skipif") { return SQLLogicTokenType::SQLLOGIC_SKIP_IF; @@ -228,6 +265,8 @@ SQLLogicTokenType SQLLogicParser::CommandToToken(const string &token) { return SQLLogicTokenType::SQLLOGIC_SLEEP; } else if (token == "unzip") { return SQLLogicTokenType::SQLLOGIC_UNZIP; + } else if (token == "tags") { + return SQLLogicTokenType::SQLLOGIC_TAGS; } Fail("Unrecognized parameter %s", token); return SQLLogicTokenType::SQLLOGIC_INVALID; diff --git a/test/sqlite/sqllogic_parser.hpp b/test/sqlite/sqllogic_parser.hpp index e194b2f10ae4..eff512ad229c 100644 --- a/test/sqlite/sqllogic_parser.hpp +++ b/test/sqlite/sqllogic_parser.hpp @@ -37,7 +37,8 @@ enum class SQLLogicTokenType { SQLLOGIC_RESTART, SQLLOGIC_RECONNECT, SQLLOGIC_SLEEP, - SQLLOGIC_UNZIP + SQLLOGIC_UNZIP, + SQLLOGIC_TAGS }; class SQLLogicToken { @@ -61,6 +62,7 @@ class SQLLogicParser { public: static bool EmptyOrComment(const string &line); static bool IsSingleLineStatement(SQLLogicToken &token); + static bool IsTestCommand(SQLLogicTokenType &type); //! Does the next line contain a comment, empty line, or is the end of the file bool NextLineEmptyOrComment(); diff --git a/test/sqlite/sqllogic_test_runner.cpp b/test/sqlite/sqllogic_test_runner.cpp index 257c2761e4f7..72559950dfa3 100644 --- a/test/sqlite/sqllogic_test_runner.cpp +++ b/test/sqlite/sqllogic_test_runner.cpp @@ -58,6 +58,9 @@ SQLLogicTestRunner::SQLLogicTestRunner(string dbpath) : dbpath(std::move(dbpath) } else if (config->options.autoload_known_extensions) { local_extension_repo = string(DUCKDB_BUILD_DIRECTORY) + "/repository"; } + for (auto &entry : test_config.GetConfigSettings()) { + config->SetOptionByName(entry.name, entry.value); + } } SQLLogicTestRunner::~SQLLogicTestRunner() { @@ -708,6 +711,14 @@ bool TryParseConditions(SQLLogicParser &parser, const string &condition_text, ve return true; } +// add implicit tags from environment variables, with value if available +void add_env_tag(vector &tags, const string &name, const string *value = nullptr) { + tags.emplace_back(StringUtil::Format("env[%s]", name)); + if (value != nullptr) { + tags.emplace_back(StringUtil::Format("env[%s]=%s", name, *value)); + } +} + void SQLLogicTestRunner::ExecuteFile(string script) { auto &test_config = TestConfiguration::Get(); if (test_config.ShouldSkipTest(script)) { @@ -718,6 +729,9 @@ void SQLLogicTestRunner::ExecuteFile(string script) { file_name = script; SQLLogicParser parser; idx_t skip_level = 0; + bool test_expr_executed = false; + bool file_tags_expr_seen = false; + vector file_tags; // gets both implicit and file-spec'd // for the original SQLite tests we convert floating point numbers to integers // for our own tests this is undesirable since it hides certain errors @@ -747,6 +761,13 @@ void SQLLogicTestRunner::ExecuteFile(string script) { FAIL("Could not find test script '" + script + "'. Perhaps run `make sqlite`. "); } + if (StringUtil::EndsWith(script, ".test_slow")) { + file_tags.emplace_back("slow"); + } + if (StringUtil::EndsWith(script, ".test_coverage")) { + file_tags.emplace_back("coverage"); + } + /* Loop over all records in the file */ while (parser.NextStatement()) { // tokenize the current line @@ -757,6 +778,15 @@ void SQLLogicTestRunner::ExecuteFile(string script) { parser.Fail("all test statements need to be separated by an empty line"); } + // Check tags first time we hit test statements, since all explicit & implicit tags now present + if (parser.IsTestCommand(token.type) && !test_expr_executed) { + if (test_config.GetPolicyForTagSet(file_tags) == TestConfiguration::SelectPolicy::SKIP) { + SKIP_TEST("select tag-set"); + return; + } + test_expr_executed = true; + } + vector conditions; bool skip_statement = false; while (token.type == SQLLogicTokenType::SQLLOGIC_SKIP_IF || token.type == SQLLogicTokenType::SQLLOGIC_ONLY_IF) { @@ -1052,7 +1082,9 @@ void SQLLogicTestRunner::ExecuteFile(string script) { if (environment_variables.count(env_var)) { parser.Fail(StringUtil::Format("Environment/Test variable '%s' has already been defined", env_var)); } + environment_variables[env_var] = env_actual; + add_env_tag(file_tags, env_var, &env_actual); } else if (token.type == SQLLogicTokenType::SQLLOGIC_REQUIRE_ENV) { if (InLoop()) { @@ -1087,12 +1119,15 @@ void SQLLogicTestRunner::ExecuteFile(string script) { SKIP_TEST("require-env " + token.parameters[0] + " " + token.parameters[1]); return; } + + file_tags.emplace_back(StringUtil::Format("env[%s]=%s", token.parameters[0], token.parameters[1])); } if (environment_variables.count(env_var)) { parser.Fail(StringUtil::Format("Environment variable '%s' has already been defined", env_var)); } environment_variables[env_var] = env_actual; + add_env_tag(file_tags, token.parameters[0], token.parameters.size() == 2 ? &token.parameters[1] : nullptr); } else if (token.type == SQLLogicTokenType::SQLLOGIC_LOAD) { auto &test_config = TestConfiguration::Get(); @@ -1172,6 +1207,26 @@ void SQLLogicTestRunner::ExecuteFile(string script) { auto command = make_uniq(*this, input_path, extraction_path); ExecuteCommand(std::move(command)); + } else if (token.type == SQLLogicTokenType::SQLLOGIC_TAGS) { + // NOTE: tags-before-test-commands is the low bar right now + // 1 better: all non-command lines precede command lines + // Mo better: parse first, build entire context before execution; allows e.g. + // - implicit tag scans of e.g. strings, vars, etc., like '${ENVVAR}', '__TEST_DIR__', 'ATTACH' + // - faster subset runs + // - tag match runs to generate lists + if (test_expr_executed) { + parser.Fail("tags expression must precede test commands"); + } + if (file_tags_expr_seen) { + parser.Fail("tags may be only specified once"); + } + file_tags_expr_seen = true; + if (token.parameters.empty()) { + parser.Fail("tags requires >= 1 argument, e.g.: [tag2 .. tagN]"); + } + + // extend file_tags for jit eval + file_tags.insert(file_tags.begin(), token.parameters.begin(), token.parameters.end()); } } if (InLoop()) { diff --git a/test/sqlite/tags/tags-1-2-3.test_slow b/test/sqlite/tags/tags-1-2-3.test_slow new file mode 100644 index 000000000000..14acdf7e1e65 --- /dev/null +++ b/test/sqlite/tags/tags-1-2-3.test_slow @@ -0,0 +1,9 @@ +# name: test/sqlite/tags/tags-1-2-3.test_slow +# group: [tags] + +tags 1 2 3 + +require-env VALIDATE_TAGS + +statement ok +SELECT 'tagged 1 2 3'; diff --git a/test/sqlite/tags/tags-1-2.test b/test/sqlite/tags/tags-1-2.test new file mode 100644 index 000000000000..768f5eff319d --- /dev/null +++ b/test/sqlite/tags/tags-1-2.test @@ -0,0 +1,9 @@ +# name: test/sqlite/tags/tags-1-2.test +# group: [tags] + +tags 1 2 + +require-env VALIDATE_TAGS + +statement ok +SELECT 'tagged 1 2'; diff --git a/test/sqlite/tags/tags-1.test b/test/sqlite/tags/tags-1.test new file mode 100644 index 000000000000..1775052735f0 --- /dev/null +++ b/test/sqlite/tags/tags-1.test @@ -0,0 +1,9 @@ +# name: test/sqlite/tags/tags-1.test +# group: [tags] + +tags 1 + +require-env VALIDATE_TAGS + +statement ok +SELECT 'tagged 1'; diff --git a/test/sqlite/tags/tags-a.test b/test/sqlite/tags/tags-a.test new file mode 100644 index 000000000000..07098abd9370 --- /dev/null +++ b/test/sqlite/tags/tags-a.test @@ -0,0 +1,9 @@ +# name: test/sqlite/tags/tags-a.test +# group: [tags] + +tags a + +require-env VALIDATE_TAGS + +statement ok +SELECT 'tagged a'; diff --git a/test/sqlite/test_sqllogictest.cpp b/test/sqlite/test_sqllogictest.cpp index 3c292311b9d2..8466fa11df7e 100644 --- a/test/sqlite/test_sqllogictest.cpp +++ b/test/sqlite/test_sqllogictest.cpp @@ -61,6 +61,11 @@ static void testRunner() { runner.output_sql = Catch::getCurrentContext().getConfig()->outputSQL(); runner.enable_verification = VERIFICATION; + // Copy configured env vars + for (auto &kv : test_config.GetTestEnvMap()) { + runner.environment_variables[kv.first] = kv.second; + } + string prev_directory; // We assume the test working dir for extensions to be one dir above the test/sql. Note that this is very hacky. diff --git a/test/sqlite/validate_tags_usage.sh b/test/sqlite/validate_tags_usage.sh new file mode 100755 index 000000000000..d44039baeea9 --- /dev/null +++ b/test/sqlite/validate_tags_usage.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash + +## +# assumes $SCRIPT_DIR/../../build/debug/test/unittest to be ready to run +# + +ROOT=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." &>/dev/null && pwd) +cd "$ROOT" + +: ${UNITTEST:="build/debug/test/unittest --output-sql=true"} +: ${TESTS_SPEC:='test/sqlite/tags/*'} +export VALIDATE_TAGS=1 + +run() { + $UNITTEST "$@" "$TESTS_SPEC" 2>&1 | grep "SELECT 'tagged" +} + +expect() { + output=$(cat) + local errs=0 + if [[ "$#" -eq 0 ]]; then + [[ $(echo -n "$output" | wc -l) -eq 0 ]] && { + echo -n "✅ - ok" + } || { + printf "\n ❌ - error - matches found but none expected:\n%s" "$output" + } + else + for elt in "$@"; do + echo -n "$output" | grep -q "'tagged $elt'" || { + printf "\n ❌ - error - missing %s" "$elt" + errs=$(($errs + 1)) + } + done + [[ $errs -eq 0 ]] && { + echo -n " ✅ - ok" + } + fi +} + +test() { + local args="$1" + shift + echo -n "test $args -- " + run $args | expect "$@" + echo +} + +test_select() { + # select tags + test "--select-tag 1" "1" "1 2" "1 2 3" + test "--select-tag-set ['1']" "1" "1 2" "1 2 3" + + test "--select-tag 2" "1 2" "1 2 3" + test "--select-tag-set ['2']" "1 2" "1 2 3" + + test "--select-tag 3" "1 2 3" + test "--select-tag-set ['3']" "1 2 3" + + test "--select-tag-set ['1','3']" "1 2 3" + test "--select-tag-set ['2','3']" "1 2 3" + test "--select-tag-set ['1','2']" "1 2" "1 2 3" + test "--select-tag-set ['1','2','3']" "1 2 3" + test "--select-tag-set ['1','2','3','4']" +} + +test_skip() { + # skip tags + test "--skip-tag 1" + test "--skip-tag 1" + + test "--skip-tag 1" + test "--skip-tag-set ['1']" + + test "--skip-tag 2" "1" + test "--skip-tag-set ['2']" "1" + + test "--skip-tag 3" "1" "1 2" + test "--skip-tag-set ['3']" "1" "1 2" + + test "--skip-tag-set ['1','3']" "1" "1 2" + test "--skip-tag-set ['2','3']" "1" "1 2" + test "--skip-tag-set ['1','2']" "1" + test "--skip-tag-set ['1','2','3']" "1" "1 2" + test "--skip-tag-set ['1','2','3','a']" "1" "1 2" "1 2 3" "a" +} + +test_combo() { + # crossover + test "--select-tag 1 --skip-tag 2" "1" + test "--skip-tag-set ['1','2','3'] --select-tag 2" "1 2" + test "--select-tag 1 --skip-tag 1" + test "--select-tag noexist --skip-tag 1" + test "--select-tag 3 --skip-tag noexist" "1 2 3" + + # confirm BNF behavior + test "--select-tag 3 --select-tag a" "1 2 3" "a" + test "--skip-tag 3 --skip-tag a" "1" "1 2" +} + +test_implicit_env() { + test "--select-tag env[VALIDATE_TAGS]" "1" "1" "1 2" "1 2 3" "a" + test "--select-tag env[VALIDATE_TAGS]=0" + # NOTE: =1 not set because it's a require, not a test-env + test "--select-tag env[VALIDATE_TAGS]=1" +} + +test "" "1" "1 2" "1 2 3" +test_select +test_skip +test_combo +test_implicit_env diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index a7e32c8e9b89..5294f426512f 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -17,10 +17,10 @@ if(NOT AMALGAMATION_BUILD) endif() if(NOT WIN32 - AND NOT SUN - AND ${BUILD_UNITTESTS}) - add_subdirectory(imdb) - if(${BUILD_TPCE}) + AND ${BUILD_UNITTESTS} AND ${BUILD_TPCE}) add_subdirectory(tpce-tool) - endif() endif() + +if (${BUILD_BENCHMARKS}) + add_subdirectory(imdb) +endif() \ No newline at end of file diff --git a/tools/juliapkg/Project.toml b/tools/juliapkg/Project.toml index b7c0519587b0..4256dc97084f 100644 --- a/tools/juliapkg/Project.toml +++ b/tools/juliapkg/Project.toml @@ -1,7 +1,7 @@ name = "DuckDB" uuid = "d2f5444f-75bc-4fdf-ac35-56f514c445e1" authors = ["Mark Raasveldt "] -version = "1.3.2" +version = "1.4.1" [deps] DBInterface = "a10d1c49-ce27-4219-8d33-6db1a4562965" @@ -14,7 +14,7 @@ WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [compat] DBInterface = "2.5" -DuckDB_jll = "1.3.2" +DuckDB_jll = "1.4.1" FixedPointDecimals = "0.4, 0.5, 0.6" Tables = "1.7" WeakRefStrings = "1.4"