diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index b405e88f1fbe..09d65bfe5f47 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,7 @@ name: Bug report description: Create a report to help us improve +labels: + - needs triage body: - type: markdown attributes: @@ -72,21 +74,23 @@ body: # Before Submitting - - type: checkboxes + - type: dropdown attributes: label: Have you tried this on the latest `master` branch? description: | * **Python**: `pip install duckdb --upgrade --pre` * **R**: `install.packages('duckdb', repos=c('https://duckdb.r-universe.dev', 'https://cloud.r-project.org'))` * **Other Platforms**: You can find links to binaries [here](https://duckdb.org/docs/installation/) or compile from source. - options: - - label: I agree - required: true + - I have tested with a master build + - I have tested with a release build (and could not test with a master build) + - I have not tested with any build + validations: + required: true - type: checkboxes attributes: label: Have you tried the steps to reproduce? Do they include all relevant data and configuration? Does the issue you report still appear there? options: - - label: I agree + - label: Yes, I have required: true diff --git a/.github/config/extensions.csv b/.github/config/extensions.csv index 36f4f0405eec..6933b462da33 100644 --- a/.github/config/extensions.csv +++ b/.github/config/extensions.csv @@ -12,4 +12,4 @@ sqlite_scanner,https://github.com/duckdblabs/sqlite_scanner,05dd50b3fc91ee497496 postgres_scanner,https://github.com/duckdblabs/postgres_scanner,cd043b49cdc9e0d3752535b8333c9433e1007a48, substrait,https://github.com/duckdblabs/substrait,53da781310c9c680efb97576d33a5fde89a58870,no-windows arrow,https://github.com/duckdblabs/arrow,1a43a5513b96e4c6ffd92026775ffeb648e71dac,no-windows -spatial,https://github.com/duckdblabs/duckdb_spatial.git,f577b9441793f9170403e489f5d3587e023a945f, +spatial,https://github.com/duckdblabs/duckdb_spatial.git,dc66594776fbe2f0a8a3af30af7f9f8626e6e215, diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake index 1d8a3fe5f6bd..2966b3714c1c 100644 --- a/.github/config/out_of_tree_extensions.cmake +++ b/.github/config/out_of_tree_extensions.cmake @@ -15,13 +15,7 @@ duckdb_extension_load(postgres_scanner GIT_URL https://github.com/duckdblabs/postgres_scanner GIT_TAG cd043b49cdc9e0d3752535b8333c9433e1007a48 ) -duckdb_extension_load(spatial - DONT_LINK - GIT_URL https://github.com/duckdblabs/duckdb_spatial.git - GIT_TAG f577b9441793f9170403e489f5d3587e023a945f - INCLUDE_DIR spatial/include - TEST_DIR spatial/test/sql -) + if (NOT WIN32) duckdb_extension_load(arrow DONT_LINK diff --git a/.github/config/uncovered_files.csv b/.github/config/uncovered_files.csv index 8ed74fd2a146..6c5ab11e3491 100644 --- a/.github/config/uncovered_files.csv +++ b/.github/config/uncovered_files.csv @@ -52,7 +52,7 @@ common/row_operations/row_match.cpp 4 common/serializer/binary_deserializer.cpp 12 common/serializer/binary_serializer.cpp 21 common/serializer/buffered_deserializer.cpp 3 -common/serializer/buffered_file_reader.cpp 9 +common/serializer/buffered_file_reader.cpp 11 common/serializer/buffered_file_writer.cpp 2 common/sort/comparators.cpp 72 common/sort/merge_sorter.cpp 100 @@ -258,7 +258,7 @@ function/scalar/string/regexp.cpp 5 function/scalar/string/regexp/regexp_extract_all.cpp 5 function/scalar/string/substring.cpp 8 function/scalar/struct/struct_extract.cpp 3 
-function/scalar/system/aggregate_export.cpp 8 +function/scalar/system/aggregate_export.cpp 10 function/scalar_function.cpp 7 function/table/arrow.cpp 61 function/table/arrow_conversion.cpp 229 @@ -331,7 +331,7 @@ include/duckdb/common/bitpacking.hpp 11 include/duckdb/common/dl.hpp 3 include/duckdb/common/enum_util.hpp 3 include/duckdb/common/exception.hpp 22 -include/duckdb/common/field_writer.hpp 4 +include/duckdb/common/field_writer.hpp 5 include/duckdb/common/gzip_file_system.hpp 3 include/duckdb/common/hive_partitioning.hpp 3 include/duckdb/common/local_file_system.hpp 5 @@ -397,6 +397,7 @@ include/duckdb/main/capi/cast/from_decimal.hpp 2 include/duckdb/main/capi/cast/to_decimal.hpp 38 include/duckdb/main/client_context.hpp 2 include/duckdb/main/client_context_file_opener.hpp 3 +include/duckdb/main/chunk_scan_state.hpp 4 include/duckdb/main/connection.hpp 32 include/duckdb/main/connection_manager.hpp 3 include/duckdb/main/database.hpp 2 @@ -449,10 +450,12 @@ include/duckdb/verification/no_operator_caching_verifier.hpp 3 include/duckdb/verification/parsed_statement_verifier.hpp 3 main/appender.cpp 23 main/attached_database.cpp 14 +main/chunk_scan_state.cpp 21 main/capi/appender-c.cpp 3 main/capi/arrow-c.cpp 8 main/capi/cast/from_decimal-c.cpp 5 main/capi/cast/utils-c.cpp 3 +main/chunk_scan_state/query_result.cpp 28 main/capi/data_chunk-c.cpp 30 main/capi/duckdb-c.cpp 3 main/capi/duckdb_value-c.cpp 14 @@ -693,7 +696,7 @@ planner/joinside.cpp 12 planner/logical_operator.cpp 30 planner/operator/logical_aggregate.cpp 3 planner/operator/logical_column_data_get.cpp 2 -planner/operator/logical_copy_to_file.cpp 36 +planner/operator/logical_copy_to_file.cpp 37 planner/operator/logical_create_index.cpp 2 planner/operator/logical_cross_product.cpp 2 planner/operator/logical_cteref.cpp 2 @@ -705,7 +708,7 @@ planner/operator/logical_dummy_scan.cpp 2 planner/operator/logical_execute.cpp 3 planner/operator/logical_export.cpp 3 planner/operator/logical_expression_get.cpp 2 -planner/operator/logical_extension_operator.cpp 7 +planner/operator/logical_extension_operator.cpp 18 planner/operator/logical_get.cpp 10 planner/operator/logical_insert.cpp 4 planner/operator/logical_pivot.cpp 3 @@ -748,6 +751,7 @@ storage/data_table.cpp 46 storage/index.cpp 3 storage/local_storage.cpp 10 storage/magic_bytes.cpp 2 +storage/meta_block_reader.cpp 3 storage/optimistic_data_writer.cpp 4 storage/partial_block_manager.cpp 26 storage/single_file_block_manager.cpp 15 diff --git a/.github/config/vcpkg_extensions.cmake b/.github/config/vcpkg_extensions.cmake index b9639d2c1d9e..ccdb328b2888 100644 --- a/.github/config/vcpkg_extensions.cmake +++ b/.github/config/vcpkg_extensions.cmake @@ -17,6 +17,14 @@ duckdb_extension_load(aws GIT_TAG 617a4b1456eec1dee3d668f9ce005a1de9ef21c8 ) +duckdb_extension_load(spatial + DONT_LINK + GIT_URL https://github.com/duckdblabs/duckdb_spatial.git + GIT_TAG dc66594776fbe2f0a8a3af30af7f9f8626e6e215 + INCLUDE_DIR spatial/include + TEST_DIR test/sql +) + # Windows tests for iceberg currently not working if (NOT WIN32) set(LOAD_ICEBERG_TESTS "LOAD_TESTS") @@ -28,4 +36,5 @@ duckdb_extension_load(iceberg ${LOAD_ICEBERG_TESTS} GIT_URL https://github.com/duckdblabs/duckdb_iceberg GIT_TAG 6481aa4dd0ab9d724a8df28a1db66800561dd5f9 + APPLY_PATCHES ) \ No newline at end of file diff --git a/.github/patches/duckdb-wasm/README.md b/.github/patches/duckdb-wasm/README.md index a59ad04ec5ee..bf8cd8f6587b 100644 --- a/.github/patches/duckdb-wasm/README.md +++ b/.github/patches/duckdb-wasm/README.md @@ -1,2 +1,19 
@@ -Welcome! -This is a placeholder +# WASM patches +Patches in this directory are used to smooth the process of introducing changes to DuckDB that break compatibility with +the current duckdb-wasm version as pinned in the `.github/workflows/NightlyTests.yml` workflow. + +# Workflow +Imagine a change to DuckDB is introduced that breaks compatibility with the current duckdb-wasm. The +workflow for this is as follows: + +### PR #1: breaking change to DuckDB +- Commit breaking change to DuckDB +- Fix breakage in duckdb-wasm, producing a patch with the fix (be wary of already existing patches) +- Commit patch in `.github/patches/duckdb-wasm/*.patch` using a descriptive name + +### PR #2: patch to duckdb-wasm +- Apply (all) the patch(es) in `.github/patches/duckdb-wasm/*.patch` to duckdb-wasm. + +### PR #3: update duckdb-wasm in DuckDB +- Remove patches in `.github/patches/duckdb-wasm/*.patch` +- Update hash of duckdb-wasm in `.github/workflows/NightlyTests.yml` \ No newline at end of file diff --git a/.github/patches/extensions/README.md b/.github/patches/extensions/README.md new file mode 100644 index 000000000000..367eacdb43cf --- /dev/null +++ b/.github/patches/extensions/README.md @@ -0,0 +1,42 @@ +# Extension patches +Patches in this directory are used to smooth the process of introducing changes to DuckDB that break compatibility with an +out-of-tree extension. Extensions installed from Git URLs can automatically apply patches found in this directory. The APPLY_PATCHES flag +should be used to explicitly enable this feature. For example, +let's say our extension config looks like this: + +```shell +duckdb_extension_load(spatial + DONT_LINK + GIT_URL https://github.com/duckdblabs/duckdb_spatial.git + GIT_TAG f577b9441793f9170403e489f5d3587e023a945f + APPLY_PATCHES +) +``` +In this example, upon downloading the spatial extension, all patches matching `.github/patches/extensions/spatial/*.patch` +will be applied automatically. + +Note that the APPLY_PATCHES flag must be enabled explicitly to make it easier for developers reading +the extension config to detect that a patch is present. For this reason, the patching mechanism will actually fail if `APPLY_PATCHES` +is set with no patches in `.github/patches/extensions/<extension_name>/*.patch`. + +# Workflow +Imagine a change to DuckDB is introduced that breaks compatibility with extension X. The +workflow for this is as follows: + +### PR #1: breaking change to DuckDB +- Commit breaking change to DuckDB +- Fix breakage in extension X, producing a patch with the fix (be wary of already existing patches) +- Commit patch in `.github/patches/extensions/x/*.patch` using a descriptive name +- Enable APPLY_PATCHES for extension X in `.github/config/out_of_tree_extensions.cmake` (if not already enabled) + +### PR #2: patch to extension X +- Apply (all) the patch(es) in `.github/patches/extensions/x/*.patch` to extension X (see the sketch below for how such a patch file might be produced).
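For illustration, a minimal sketch of how a patch file for the steps above could be produced and verified locally, assuming extension X is checked out next to the DuckDB repository (the checkout directory `duckdb_x` and the patch name `fix_breaking_change.patch` are hypothetical):

```shell
# In a checkout of extension X, with the fix applied to the working tree but not committed:
cd duckdb_x
git diff > ../duckdb/.github/patches/extensions/x/fix_breaking_change.patch

# Sanity check: revert the working tree, then re-apply the patch the same way the build does
# (scripts/apply_extension_patches.py uses these flags).
git stash
git apply --ignore-space-change --ignore-whitespace ../duckdb/.github/patches/extensions/x/fix_breaking_change.patch
```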
+ +### PR #3: update extension X in DuckDB +- Remove patches in `.github/patches/extensions/x/*.patch` +- Remove `APPLY_PATCHES` flag from config +- Update hash of extension in config + + + + diff --git a/.github/patches/extensions/iceberg/filesystem_api_change.patch b/.github/patches/extensions/iceberg/filesystem_api_change.patch new file mode 100644 index 000000000000..25cc90e66d0c --- /dev/null +++ b/.github/patches/extensions/iceberg/filesystem_api_change.patch @@ -0,0 +1,20 @@ +diff --git a/duckdb b/duckdb +index 35fde53..2f131ae 160000 +--- a/duckdb ++++ b/duckdb +@@ -1 +1 @@ +-Subproject commit 35fde53437d93697136953ef4e23d89bf35018b2 ++Subproject commit 2f131ae80ad7b1c10cdab2e132d4d7ac87e4d75f +diff --git a/src/common/iceberg.cpp b/src/common/iceberg.cpp +index 9f43afd..ebea592 100644 +--- a/src/common/iceberg.cpp ++++ b/src/common/iceberg.cpp +@@ -138,7 +138,7 @@ IcebergSnapshot IcebergSnapshot::ParseSnapShot(yyjson_val *snapshot) { + + idx_t IcebergSnapshot::GetTableVersion(string &path, FileSystem &fs) { + auto meta_path = fs.JoinPath(path, "metadata"); +- auto version_file_path = FileSystem::JoinPath(meta_path, "version-hint.text"); ++ auto version_file_path = fs.JoinPath(meta_path, "version-hint.text"); + auto version_file_content = IcebergUtils::FileToString(version_file_path, fs); + + try { diff --git a/.github/workflows/InternalIssuesCreateMirror.yml b/.github/workflows/InternalIssuesCreateMirror.yml new file mode 100644 index 000000000000..9c56ff546a67 --- /dev/null +++ b/.github/workflows/InternalIssuesCreateMirror.yml @@ -0,0 +1,56 @@ +name: Create or Label Mirror Issue +on: + issues: + types: + - labeled + +env: + GH_TOKEN: ${{ secrets.DUCKDBLABS_BOT_TOKEN }} + TITLE_PREFIX: "[duckdb/#${{ github.event.issue.number }}]" + PUBLIC_ISSUE_TITLE: ${{ github.event.issue.title }} + +jobs: + create_or_label_issue: + if: github.event.label.name == 'reproduced' || github.event.label.name == 'under review' + runs-on: ubuntu-latest + steps: + - name: Remove needs triage / under review if reproduced + if: github.event.label.name == 'reproduced' + run: | + gh issue edit --repo duckdb/duckdb ${{ github.event.issue.number }} --remove-label "needs triage" --remove-label "under review" + + - name: Remove needs triage / reproduced if under review + if: github.event.label.name == 'under review' + run: | + gh issue edit --repo duckdb/duckdb ${{ github.event.issue.number }} --remove-label "needs triage" --remove-label "reproduced" + + - name: Get mirror issue number + run: | + gh issue list --repo duckdblabs/duckdb-internal --json title,number --jq ".[] | select(.title | startswith(\"$TITLE_PREFIX\")).number" > mirror_issue_number.txt + echo "MIRROR_ISSUE_NUMBER=$(cat mirror_issue_number.txt)" >> $GITHUB_ENV + + - name: Print whether mirror issue exists + run: | + if [ "$MIRROR_ISSUE_NUMBER" == "" ]; then + echo "Mirror issue with title prefix '$TITLE_PREFIX' does not exist yet" + else + echo "Mirror issue with title prefix '$TITLE_PREFIX' exists with number $MIRROR_ISSUE_NUMBER" + fi + + - name: Set label environment variable + run: | + if ${{ github.event.label.name == 'reproduced' }}; then + echo "LABEL=needs label" >> $GITHUB_ENV + echo "UNLABEL=needs triage" >> $GITHUB_ENV + else + echo "LABEL=needs triage" >> $GITHUB_ENV + echo "UNLABEL=needs label" >> $GITHUB_ENV + fi + + - name: Create or label issue + run: | + if [ "$MIRROR_ISSUE_NUMBER" == "" ]; then + gh issue create --repo duckdblabs/duckdb-internal --label "$LABEL" --title "$TITLE_PREFIX - $PUBLIC_ISSUE_TITLE" --body "See 
https://github.com/duckdb/duckdb/issues/${{ github.event.issue.number }}" + else + gh issue edit --repo duckdblabs/duckdb-internal $MIRROR_ISSUE_NUMBER --remove-label "$UNLABEL" --add-label "$LABEL" + fi diff --git a/.github/workflows/InternalIssuesUpdateMirror.yml b/.github/workflows/InternalIssuesUpdateMirror.yml new file mode 100644 index 000000000000..ba50e59c8d9f --- /dev/null +++ b/.github/workflows/InternalIssuesUpdateMirror.yml @@ -0,0 +1,48 @@ +name: Update Mirror Issue +on: + issues: + types: + - closed + - reopened + +env: + GH_TOKEN: ${{ secrets.DUCKDBLABS_BOT_TOKEN }} + TITLE_PREFIX: "[duckdb/#${{ github.event.issue.number }}]" + +jobs: + update_mirror_issue: + runs-on: ubuntu-latest + steps: + - name: Get mirror issue number + run: | + gh issue list --repo duckdblabs/duckdb-internal --json title,number --jq ".[] | select(.title | startswith(\"$TITLE_PREFIX\")).number" > mirror_issue_number.txt + echo "MIRROR_ISSUE_NUMBER=$(cat mirror_issue_number.txt)" >> $GITHUB_ENV + + - name: Print whether mirror issue exists + run: | + if [ "$MIRROR_ISSUE_NUMBER" == "" ]; then + echo "Mirror issue with title prefix '$TITLE_PREFIX' does not exist yet" + else + echo "Mirror issue with title prefix '$TITLE_PREFIX' exists with number $MIRROR_ISSUE_NUMBER" + fi + + - name: Add comment with status to mirror issue + run: | + if [ "$MIRROR_ISSUE_NUMBER" != "" ]; then + gh issue comment --repo duckdblabs/duckdb-internal $MIRROR_ISSUE_NUMBER --body "The issue has been ${{ github.event.action }}." + fi + + - name: Add closed label to mirror issue + if: github.event.action == 'closed' + run: | + if [ "$MIRROR_ISSUE_NUMBER" != "" ]; then + gh issue edit --repo duckdblabs/duckdb-internal $MIRROR_ISSUE_NUMBER --add-label "public closed" --remove-label "public reopened" + fi + + - name: Reopen mirror issue and add reopened label + if: github.event.action == 'reopened' + run: | + if [ "$MIRROR_ISSUE_NUMBER" != "" ]; then + gh issue reopen --repo duckdblabs/duckdb-internal $MIRROR_ISSUE_NUMBER + gh issue edit --repo duckdblabs/duckdb-internal $MIRROR_ISSUE_NUMBER --add-label "public reopened" --remove-label "public closed" + fi diff --git a/.github/workflows/IssuesCloseStale.yml b/.github/workflows/IssuesCloseStale.yml new file mode 100644 index 000000000000..c94b4e3f5a2e --- /dev/null +++ b/.github/workflows/IssuesCloseStale.yml @@ -0,0 +1,18 @@ +name: Close Stale Issues +on: + repository_dispatch: + workflow_dispatch: + +jobs: + close_stale_issues: + runs-on: ubuntu-latest + steps: + - name: Close stale issues + uses: actions/stale@v8 + with: + stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 30 days.' + close-issue-message: 'This issue was closed because it has been stale for 30 days with no activity.' 
+ days-before-stale: 90 + days-before-close: 30 + operations-per-run: 500 + stale-issue-label: stale diff --git a/.github/workflows/lcov_exclude b/.github/workflows/lcov_exclude index 88832188f2f0..3c0919320605 100644 --- a/.github/workflows/lcov_exclude +++ b/.github/workflows/lcov_exclude @@ -25,3 +25,4 @@ */enum_util.cpp */enums/expression_type.cpp */serialization/* +*/json_enums.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d96f4abd5970..1d0a4f5ff1dd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -553,7 +553,10 @@ endfunction() function(add_extension_definitions) include_directories(${PROJECT_SOURCE_DIR}/extension) - + if(NOT("${TEST_WITH_LOADABLE_EXTENSION}" STREQUAL "")) + add_definitions(-DDUCKDB_EXTENSIONS_TEST_WITH_LOADABLE="${TEST_WITH_LOADABLE_EXTENSION}") + add_definitions(-DDUCKDB_EXTENSIONS_BUILD_PATH="${CMAKE_BINARY_DIR}/extension") + endif() if(NOT("${TEST_REMOTE_INSTALL}" STREQUAL "OFF")) add_definitions(-DDUCKDB_TEST_REMOTE_INSTALL="${TEST_REMOTE_INSTALL}") endif() @@ -746,13 +749,17 @@ function(register_extension NAME DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PA endfunction() # Downloads the external extension repo at the specified commit and calls register_extension -macro(register_external_extension NAME URL COMMIT DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PATH TEST_PATH) +macro(register_external_extension NAME URL COMMIT DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PATH TEST_PATH APPLY_PATCHES) include(FetchContent) + if (${APPLY_PATCHES}) + set(PATCH_COMMAND python3 ${CMAKE_SOURCE_DIR}/scripts/apply_extension_patches.py ${CMAKE_SOURCE_DIR}/.github/patches/extensions/${NAME}/) + endif() FETCHCONTENT_DECLARE( ${NAME}_extension_fc GIT_REPOSITORY ${URL} GIT_TAG ${COMMIT} GIT_SUBMODULES "" + PATCH_COMMAND ${PATCH_COMMAND} ) message(STATUS "Load extension '${NAME}' from ${URL} @ ${COMMIT}") FETCHCONTENT_POPULATE(${NAME}_EXTENSION_FC) @@ -774,7 +781,7 @@ endmacro() function(duckdb_extension_load NAME) # Parameter parsing - set(options DONT_LINK DONT_BUILD LOAD_TESTS) + set(options DONT_LINK DONT_BUILD LOAD_TESTS APPLY_PATCHES) set(oneValueArgs SOURCE_DIR INCLUDE_DIR TEST_DIR GIT_URL GIT_TAG) cmake_parse_arguments(duckdb_extension_load "${options}" "${oneValueArgs}" "" ${ARGN}) @@ -794,7 +801,7 @@ function(duckdb_extension_load NAME) if (NOT "${duckdb_extension_load_GIT_COMMIT}" STREQUAL "") error("Git URL specified but no valid git commit was found for ${NAME} extension") endif() - register_external_extension(${NAME} "${duckdb_extension_load_GIT_URL}" "${duckdb_extension_load_GIT_TAG}" "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${duckdb_extension_load_SOURCE_DIR}" "${duckdb_extension_load_INCLUDE_DIR}" "${duckdb_extension_load_TEST_DIR}") + register_external_extension(${NAME} "${duckdb_extension_load_GIT_URL}" "${duckdb_extension_load_GIT_TAG}" "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${duckdb_extension_load_SOURCE_DIR}" "${duckdb_extension_load_INCLUDE_DIR}" "${duckdb_extension_load_TEST_DIR}" "${duckdb_extension_load_APPLY_PATCHES}") elseif (NOT "${duckdb_extension_load_SOURCE_DIR}" STREQUAL "") # Local extension, custom path message(STATUS "Load extension '${NAME}' from '${duckdb_extension_load_SOURCE_DIR}'") @@ -839,6 +846,14 @@ if(${EXPORT_DLL_SYMBOLS}) add_definitions(-DDUCKDB_BUILD_LIBRARY) endif() +# Log extensions that are built by directly passing cmake variables +foreach(EXT IN LISTS 
DUCKDB_EXTENSION_NAMES) + if (NOT "${EXT}" STREQUAL "") + string(TOUPPER ${EXT} EXTENSION_NAME_UPPERCASE) + message(STATUS "Load extension '${EXT}' from '${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_PATH}'") + endif() +endforeach() + # Load extensions passed through cmake config var foreach(EXT IN LISTS SKIP_EXTENSIONS) if (NOT "${EXT}" STREQUAL "") @@ -882,6 +897,15 @@ endif() # Load base extension config include(${CMAKE_CURRENT_SOURCE_DIR}/extension/extension_config.cmake) +# For extensions whose tests were loaded, but not linked into duckdb, we need to ensure they are registered to have +# the sqllogictest "require" statement load the loadable extensions instead of the baked in static one +foreach(EXT_NAME IN LISTS DUCKDB_EXTENSION_NAMES) + string(TOUPPER ${EXT_NAME} EXT_NAME_UPPERCASE) + if (NOT "${DUCKDB_EXTENSION_${EXT_NAME_UPPERCASE}_SHOULD_LINK}" AND "${DUCKDB_EXTENSION_${EXT_NAME_UPPERCASE}_LOAD_TESTS}") + list(APPEND TEST_WITH_LOADABLE_EXTENSION ${EXT_NAME}) + endif() +endforeach() + if (BUILD_MAIN_DUCKDB_LIBRARY) add_subdirectory(src) add_subdirectory(tools) @@ -901,7 +925,7 @@ foreach(EXT_NAME IN LISTS DUCKDB_EXTENSION_NAMES) endif() # Warning for trying to load vcpkg extensions without having VCPKG_BUILD SET - if (EXISTS "${DUCKDB_EXTENSION_${EXT_NAME_UPPERCASE}_PATH}/vcpkg.json" AND NOT "${VCPKG_BUILD}") + if (EXISTS "${DUCKDB_EXTENSION_${EXT_NAME_UPPERCASE}_PATH}/vcpkg.json" AND NOT DEFINED VCPKG_BUILD) message(WARNING "Extension '${EXT_NAME}' has a vcpkg.json, but build was not run with VCPKG. If build fails, check out VCPKG build instructions in 'duckdb/extension/README.md' or try manually installing the dependencies in ${DUCKDB_EXTENSION_${EXT_NAME_UPPERCASE}_PATH}/vcpkg.json") endif() @@ -934,19 +958,19 @@ endforeach() if(NOT "${LINKED_EXTENSIONS}" STREQUAL "") string(REPLACE ";" ", " EXT_LIST_DEBUG_MESSAGE "${LINKED_EXTENSIONS}") - message(STATUS "Extensions linked into DuckDB: ${EXT_LIST_DEBUG_MESSAGE}") + message(STATUS "Extensions linked into DuckDB: [${EXT_LIST_DEBUG_MESSAGE}]") endif() if(NOT "${NONLINKED_EXTENSIONS}" STREQUAL "") string(REPLACE ";" ", " EXT_LIST_DEBUG_MESSAGE "${NONLINKED_EXTENSIONS}") - message(STATUS "Extensions built but not linked: ${EXT_LIST_DEBUG_MESSAGE}") + message(STATUS "Extensions built but not linked: [${EXT_LIST_DEBUG_MESSAGE}]") endif() if(NOT "${SKIPPED_EXTENSIONS}" STREQUAL "") string(REPLACE ";" ", " EXT_LIST_DEBUG_MESSAGE "${SKIPPED_EXTENSIONS}") - message(STATUS "Extensions explicitly skipped: ${EXT_LIST_DEBUG_MESSAGE}") + message(STATUS "Extensions explicitly skipped: [${EXT_LIST_DEBUG_MESSAGE}]") endif() if(NOT "${TEST_LOADED_EXTENSIONS}" STREQUAL "") string(REPLACE ";" ", " EXT_LIST_DEBUG_MESSAGE "${TEST_LOADED_EXTENSIONS}") - message(STATUS "Tests loaded from extensions: ${EXT_LIST_DEBUG_MESSAGE}") + message(STATUS "Tests loaded for extensions: [${EXT_LIST_DEBUG_MESSAGE}]") endif() # Special build where instead of building duckdb, we produce several artifacts that require parsing the diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 250b4615d4bf..1fb2278e6ded 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -70,8 +70,8 @@ This project and everyone participating in it is governed by a [Code of Conduct] * Use tabs for indentation, spaces for alignment. * Lines should not exceed 120 columns.
-* To make sure the formatting is consistent, please use version 11.0.1, installable through `python3 -m pip install clang-format==11.0.1` -* `clang_format` enforces these rules automatically, use `make format-fix` to run the formatter. +* To make sure the formatting is consistent, please use version 10.0.1, installable through `python3 -m pip install clang-format==10.0.1.1` +* `clang_format` and `black` enforce these rules automatically, use `make format-fix` to run the formatter. * The project also comes with an [`.editorconfig` file](https://editorconfig.org/) that corresponds to these rules. ## C++ Guidelines diff --git a/benchmark/benchmark_runner.cpp b/benchmark/benchmark_runner.cpp index e912ded54479..9710b7ebe99f 100644 --- a/benchmark/benchmark_runner.cpp +++ b/benchmark/benchmark_runner.cpp @@ -173,17 +173,16 @@ void print_help() { enum ConfigurationError { None, BenchmarkNotFound, InfoWithoutBenchmarkName }; -void LoadInterpretedBenchmarks() { +void LoadInterpretedBenchmarks(FileSystem &fs) { // load interpreted benchmarks - duckdb::unique_ptr<FileSystem> fs = FileSystem::CreateLocal(); - listFiles(*fs, "benchmark", [](const string &path) { + listFiles(fs, "benchmark", [](const string &path) { if (endsWith(path, ".benchmark")) { new InterpretedBenchmark(path); } }); } -string parse_root_dir_or_default(const int arg_counter, char const *const *arg_values) { +string parse_root_dir_or_default(const int arg_counter, char const *const *arg_values, FileSystem &fs) { // check if the user specified a different root directory for (int arg_index = 1; arg_index < arg_counter; ++arg_index) { string arg = arg_values[arg_index]; @@ -194,10 +193,10 @@ string parse_root_dir_or_default(const int arg_counter, char const *const *arg_v exit(1); } auto path = arg_values[arg_index + 1]; - if (FileSystem::IsPathAbsolute(path)) { + if (fs.IsPathAbsolute(path)) { return path; } else { - return FileSystem::JoinPath(FileSystem::GetWorkingDirectory(), path); + return fs.JoinPath(FileSystem::GetWorkingDirectory(), path); } } } @@ -335,11 +334,12 @@ void print_error_message(const ConfigurationError &error) { } int main(int argc, char **argv) { + duckdb::unique_ptr<FileSystem> fs = FileSystem::CreateLocal(); // Set the working directory.
We need to scan this before loading the benchmarks or parsing the other arguments - string root_dir = parse_root_dir_or_default(argc, argv); + string root_dir = parse_root_dir_or_default(argc, argv, *fs); FileSystem::SetWorkingDirectory(root_dir); // load interpreted benchmarks before doing anything else - LoadInterpretedBenchmarks(); + LoadInterpretedBenchmarks(*fs); parse_arguments(argc, argv); const auto configuration_error = run_benchmarks(); if (configuration_error != ConfigurationError::None) { diff --git a/data/parquet-testing/encryption/encrypted_column.parquet b/data/parquet-testing/encryption/encrypted_column.parquet new file mode 100644 index 000000000000..40511acbf100 Binary files /dev/null and b/data/parquet-testing/encryption/encrypted_column.parquet differ diff --git a/data/parquet-testing/encryption/encrypted_footer.parquet b/data/parquet-testing/encryption/encrypted_footer.parquet new file mode 100644 index 000000000000..9e695c6c4dd6 Binary files /dev/null and b/data/parquet-testing/encryption/encrypted_footer.parquet differ diff --git a/extension/autocomplete/autocomplete_extension.cpp b/extension/autocomplete/autocomplete_extension.cpp index 96e2000522c5..de7ec6b39dd4 100644 --- a/extension/autocomplete/autocomplete_extension.cpp +++ b/extension/autocomplete/autocomplete_extension.cpp @@ -196,7 +196,7 @@ static vector<string> SuggestFileName(ClientContext &context, str fs.ListFiles(search_dir, [&](const string &fname, bool is_dir) { string suggestion; if (is_dir) { - suggestion = fname + fs.PathSeparator(); + suggestion = fname + fs.PathSeparator(fname); } else { suggestion = fname + "'"; } diff --git a/extension/httpfs/include/httpfs.hpp b/extension/httpfs/include/httpfs.hpp index 0f4a3f1d83a3..3745ce1ab74e 100644 --- a/extension/httpfs/include/httpfs.hpp +++ b/extension/httpfs/include/httpfs.hpp @@ -135,7 +135,9 @@ class HTTPFileSystem : public FileSystem { string GetName() const override { return "HTTPFileSystem"; } - + string PathSeparator(const string &path) override { + return "/"; + } static void Verify(); // Global cache diff --git a/extension/httpfs/s3fs.cpp b/extension/httpfs/s3fs.cpp index 3daad7454f90..8179e4f5bfac 100644 --- a/extension/httpfs/s3fs.cpp +++ b/extension/httpfs/s3fs.cpp @@ -975,7 +975,7 @@ string S3FileSystem::GetName() const { bool S3FileSystem::ListFiles(const string &directory, const std::function<void(const string &, bool)> &callback, FileOpener *opener) { string trimmed_dir = directory; - StringUtil::RTrim(trimmed_dir, PathSeparator()); + StringUtil::RTrim(trimmed_dir, PathSeparator(trimmed_dir)); auto glob_res = Glob(JoinPath(trimmed_dir, "**"), opener); if (glob_res.empty()) { diff --git a/extension/json/CMakeLists.txt b/extension/json/CMakeLists.txt index d5349f309104..b2626205d1ee 100644 --- a/extension/json/CMakeLists.txt +++ b/extension/json/CMakeLists.txt @@ -11,10 +11,12 @@ set(JSON_EXTENSION_FILES buffered_json_reader.cpp json_extension.cpp json_common.cpp + json_enums.cpp json_functions.cpp json_scan.cpp json_serializer.cpp json_deserializer.cpp + serialize_json.cpp json_functions/copy_json.cpp json_functions/json_array_length.cpp json_functions/json_contains.cpp diff --git a/extension/json/buffered_json_reader.cpp b/extension/json/buffered_json_reader.cpp index 5a408b9444d3..a21ba941ac64 100644 --- a/extension/json/buffered_json_reader.cpp +++ b/extension/json/buffered_json_reader.cpp @@ -3,6 +3,8 @@ #include "duckdb/common/field_writer.hpp" #include "duckdb/common/file_opener.hpp" #include "duckdb/common/printer.hpp" +#include
"duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { diff --git a/extension/json/include/buffered_json_reader.hpp b/extension/json/include/buffered_json_reader.hpp index 4918b29a056e..0ad3350a3ef5 100644 --- a/extension/json/include/buffered_json_reader.hpp +++ b/extension/json/include/buffered_json_reader.hpp @@ -14,28 +14,11 @@ #include "duckdb/common/multi_file_reader.hpp" #include "duckdb/common/mutex.hpp" #include "json_common.hpp" +#include "json_enums.hpp" +#include "duckdb/common/enum_util.hpp" namespace duckdb { -enum class JSONFormat : uint8_t { - //! Auto-detect format (UNSTRUCTURED / NEWLINE_DELIMITED) - AUTO_DETECT = 0, - //! One unit after another, newlines can be anywhere - UNSTRUCTURED = 1, - //! Units are separated by newlines, newlines do not occur within Units (NDJSON) - NEWLINE_DELIMITED = 2, - //! File is one big array of units - ARRAY = 3, -}; - -enum class JSONRecordType : uint8_t { - AUTO_DETECT = 0, - //! Sequential objects that are unpacked - RECORDS = 1, - //! Any other JSON type, e.g., ARRAY - VALUES = 2, -}; - struct BufferedJSONReaderOptions { public: //! The format of the JSON @@ -50,6 +33,9 @@ struct BufferedJSONReaderOptions { public: void Serialize(FieldWriter &writer) const; void Deserialize(FieldReader &reader); + + void FormatSerialize(FormatSerializer &serializer) const; + static BufferedJSONReaderOptions FormatDeserialize(FormatDeserializer &deserializer); }; struct JSONBufferHandle { diff --git a/extension/json/include/json.json b/extension/json/include/json.json new file mode 100644 index 000000000000..aa09cdf31f21 --- /dev/null +++ b/extension/json/include/json.json @@ -0,0 +1,120 @@ +[ + { + "class": "BufferedJSONReaderOptions", + "includes": [ + "buffered_json_reader.hpp" + ], + "members": [ + { + "name": "format", + "type": "JSONFormat" + }, + { + "name": "record_type", + "type": "JSONRecordType" + }, + { + "name": "compression", + "type": "FileCompressionType" + }, + { + "name": "file_options", + "type": "MultiFileReaderOptions" + } + ], + "pointer_type": "none" + }, + { + "class": "JSONTransformOptions", + "includes": [ + "json_transform.hpp" + ], + "members": [ + { + "name": "strict_cast", + "type": "bool" + }, + { + "name": "error_duplicate_key", + "type": "bool" + }, + { + "name": "error_missing_key", + "type": "bool" + }, + { + "name": "error_unknown_key", + "type": "bool" + }, + { + "name": "delay_error", + "type": "bool" + } + ], + "pointer_type": "none" + }, + { + "class": "JSONScanData", + "includes": [ + "json_scan.hpp" + ], + "members": [ + { + "name": "json_type", + "type": "JSONScanType", + "property": "type" + }, + { + "name": "options", + "type": "BufferedJSONReaderOptions" + }, + { + "name": "reader_bind", + "type": "MultiFileReaderBindData" + }, + { + "name": "files", + "type": "vector" + }, + { + "name": "ignore_errors", + "type": "bool" + }, + { + "name": "maximum_object_size", + "type": "idx_t" + }, + { + "name": "auto_detect", + "type": "bool" + }, + { + "name": "sample_size", + "type": "idx_t" + }, + { + "name": "max_depth", + "type": "idx_t" + }, + { + "name": "transform_options", + "type": "JSONTransformOptions" + }, + { + "name": "names", + "type": "vector" + }, + { + "name": "date_format", + "type": "string", + "serialize_property": "GetDateFormat()" + }, + { + "name": "timestamp_format", + "type": "string", + "serialize_property": "GetTimestampFormat()" + } + ], + "constructor": ["$ClientContext", "files", "date_format", "timestamp_format"] + 
} +] diff --git a/extension/json/include/json_enums.hpp b/extension/json/include/json_enums.hpp new file mode 100644 index 000000000000..86e8b3093edf --- /dev/null +++ b/extension/json/include/json_enums.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_enums.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +enum class JSONScanType : uint8_t { + INVALID = 0, + //! Read JSON straight to columnar data + READ_JSON = 1, + //! Read JSON values as strings + READ_JSON_OBJECTS = 2, + //! Sample run for schema detection + SAMPLE = 3, +}; + +enum class JSONRecordType : uint8_t { + AUTO_DETECT = 0, + //! Sequential objects that are unpacked + RECORDS = 1, + //! Any other JSON type, e.g., ARRAY + VALUES = 2, +}; + +enum class JSONFormat : uint8_t { + //! Auto-detect format (UNSTRUCTURED / NEWLINE_DELIMITED) + AUTO_DETECT = 0, + //! One unit after another, newlines can be anywhere + UNSTRUCTURED = 1, + //! Units are separated by newlines, newlines do not occur within Units (NDJSON) + NEWLINE_DELIMITED = 2, + //! File is one big array of units + ARRAY = 3, +}; + +template<> +const char* EnumUtil::ToChars<JSONScanType>(JSONScanType value); + +template<> +JSONScanType EnumUtil::FromString<JSONScanType>(const char *value); + +template<> +const char* EnumUtil::ToChars<JSONRecordType>(JSONRecordType value); + +template<> +JSONRecordType EnumUtil::FromString<JSONRecordType>(const char *value); + +template<> +const char* EnumUtil::ToChars<JSONFormat>(JSONFormat value); + +template<> +JSONFormat EnumUtil::FromString<JSONFormat>(const char *value); + +} // namespace duckdb diff --git a/extension/json/include/json_enums.json b/extension/json/include/json_enums.json new file mode 100644 index 000000000000..dc69cde082d7 --- /dev/null +++ b/extension/json/include/json_enums.json @@ -0,0 +1,55 @@ +[ + { + "name": "JSONScanType", + "values": [ + "INVALID", + { + "name": "READ_JSON", + "comment": "Read JSON straight to columnar data" + }, + { + "name": "READ_JSON_OBJECTS", + "comment": "Read JSON values as strings" + }, + { + "name": "SAMPLE", + "comment": "Sample run for schema detection" + } + ] + }, + { + "name": "JSONRecordType", + "values": [ + "AUTO_DETECT", + { + "name": "RECORDS", + "comment": "Sequential objects that are unpacked" + }, + { + "name": "VALUES", + "comment": "Any other JSON type, e.g., ARRAY" + } + ] + }, + { + "name": "JSONFormat", + "values": [ + { + "name": "AUTO_DETECT", + "comment": "Auto-detect format (UNSTRUCTURED / NEWLINE_DELIMITED)" + }, + { + "name": "UNSTRUCTURED", + "comment": "One unit after another, newlines can be anywhere" + }, + { + "name": "NEWLINE_DELIMITED", + "comment": "Units are separated by newlines, newlines do not occur within Units (NDJSON)" + }, + { + "name": "ARRAY", + "comment": "File is one big array of units" + } + ] + } +] diff --git a/extension/json/include/json_scan.hpp b/extension/json/include/json_scan.hpp index de2423c53d02..a92ac76b93a7 100644 --- a/extension/json/include/json_scan.hpp +++ b/extension/json/include/json_scan.hpp @@ -9,6 +9,7 @@ #pragma once #include "buffered_json_reader.hpp" +#include "json_enums.hpp" #include "duckdb/common/multi_file_reader.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/pair.hpp" @@ -19,16 +20,6 @@ namespace duckdb { -enum class JSONScanType : uint8_t {
- INVALID = 0, - //! Read JSON straight to columnar data - READ_JSON = 1, - //! Read JSON values as strings - READ_JSON_OBJECTS = 2, - //! Sample run for schema detection - SAMPLE = 3, -}; - struct JSONString { public: JSONString() { @@ -104,6 +95,9 @@ struct JSONScanData : public TableFunctionData { void Serialize(FieldWriter &writer) const; void Deserialize(ClientContext &context, FieldReader &reader); + void FormatSerialize(FormatSerializer &serializer) const; + static unique_ptr<JSONScanData> FormatDeserialize(FormatDeserializer &deserializer); + public: //! Scan type JSONScanType type; @@ -144,6 +138,12 @@ struct JSONScanData : public TableFunctionData { //! The inferred avg tuple size idx_t avg_tuple_size = 420; + +private: + JSONScanData(ClientContext &context, vector<string> files, string date_format, string timestamp_format); + + string GetDateFormat() const; + string GetTimestampFormat() const; }; struct JSONScanInfo : public TableFunctionInfo { @@ -295,6 +295,10 @@ struct JSONScan { static unique_ptr<FunctionData> Deserialize(PlanDeserializationState &state, FieldReader &reader, TableFunction &function); + static void FormatSerialize(FormatSerializer &serializer, const optional_ptr<FunctionData> bind_data, + const TableFunction &function); + static unique_ptr<FunctionData> FormatDeserialize(FormatDeserializer &deserializer, TableFunction &function); + static void TableFunctionDefaults(TableFunction &table_function); }; diff --git a/extension/json/include/json_transform.hpp b/extension/json/include/json_transform.hpp index 68ada36e1683..e45d8f7cd679 100644 --- a/extension/json/include/json_transform.hpp +++ b/extension/json/include/json_transform.hpp @@ -44,6 +44,9 @@ struct JSONTransformOptions { public: void Serialize(FieldWriter &writer) const; void Deserialize(FieldReader &reader); + + void FormatSerialize(FormatSerializer &serializer) const; + static JSONTransformOptions FormatDeserialize(FormatDeserializer &deserializer); }; struct TryParseDate { diff --git a/extension/json/json_config.py b/extension/json/json_config.py index d3003e20c7ac..543d4f23f59d 100644 --- a/extension/json/json_config.py +++ b/extension/json/json_config.py @@ -9,6 +9,7 @@ os.path.sep.join(x.split('/')) for x in [ 'extension/json/buffered_json_reader.cpp', + 'extension/json/json_enums.cpp', 'extension/json/json_extension.cpp', 'extension/json/json_common.cpp', 'extension/json/json_functions.cpp', @@ -30,5 +31,6 @@ 'extension/json/json_functions/json_serialize_sql.cpp', 'extension/json/json_serializer.cpp', 'extension/json/json_deserializer.cpp', + 'extension/json/serialize_json.cpp', ] ] diff --git a/extension/json/json_enums.cpp b/extension/json/json_enums.cpp new file mode 100644 index 000000000000..06e03f85e3de --- /dev/null +++ b/extension/json/json_enums.cpp @@ -0,0 +1,105 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_enums.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "json_enums.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +template<> +const char* EnumUtil::ToChars<JSONScanType>(JSONScanType value) { + switch(value) { + case JSONScanType::INVALID: + return "INVALID"; + case JSONScanType::READ_JSON: + return "READ_JSON"; + case JSONScanType::READ_JSON_OBJECTS: + return "READ_JSON_OBJECTS"; + case JSONScanType::SAMPLE: + return "SAMPLE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value of type JSONScanType:
'%d' not implemented", value)); + } +} + +template<> +JSONScanType EnumUtil::FromString<JSONScanType>(const char *value) { + if (StringUtil::Equals(value, "INVALID")) { + return JSONScanType::INVALID; + } + if (StringUtil::Equals(value, "READ_JSON")) { + return JSONScanType::READ_JSON; + } + if (StringUtil::Equals(value, "READ_JSON_OBJECTS")) { + return JSONScanType::READ_JSON_OBJECTS; + } + if (StringUtil::Equals(value, "SAMPLE")) { + return JSONScanType::SAMPLE; + } + throw NotImplementedException(StringUtil::Format("Enum value of type JSONScanType: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars<JSONRecordType>(JSONRecordType value) { + switch(value) { + case JSONRecordType::AUTO_DETECT: + return "AUTO_DETECT"; + case JSONRecordType::RECORDS: + return "RECORDS"; + case JSONRecordType::VALUES: + return "VALUES"; + default: + throw NotImplementedException(StringUtil::Format("Enum value of type JSONRecordType: '%d' not implemented", value)); + } +} + +template<> +JSONRecordType EnumUtil::FromString<JSONRecordType>(const char *value) { + if (StringUtil::Equals(value, "AUTO_DETECT")) { + return JSONRecordType::AUTO_DETECT; + } + if (StringUtil::Equals(value, "RECORDS")) { + return JSONRecordType::RECORDS; + } + if (StringUtil::Equals(value, "VALUES")) { + return JSONRecordType::VALUES; + } + throw NotImplementedException(StringUtil::Format("Enum value of type JSONRecordType: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars<JSONFormat>(JSONFormat value) { + switch(value) { + case JSONFormat::AUTO_DETECT: + return "AUTO_DETECT"; + case JSONFormat::UNSTRUCTURED: + return "UNSTRUCTURED"; + case JSONFormat::NEWLINE_DELIMITED: + return "NEWLINE_DELIMITED"; + case JSONFormat::ARRAY: + return "ARRAY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value of type JSONFormat: '%d' not implemented", value)); + } +} + +template<> +JSONFormat EnumUtil::FromString<JSONFormat>(const char *value) { + if (StringUtil::Equals(value, "AUTO_DETECT")) { + return JSONFormat::AUTO_DETECT; + } + if (StringUtil::Equals(value, "UNSTRUCTURED")) { + return JSONFormat::UNSTRUCTURED; + } + if (StringUtil::Equals(value, "NEWLINE_DELIMITED")) { + return JSONFormat::NEWLINE_DELIMITED; + } + if (StringUtil::Equals(value, "ARRAY")) { + return JSONFormat::ARRAY; + } + throw NotImplementedException(StringUtil::Format("Enum value of type JSONFormat: '%s' not implemented", value)); +} + +} // namespace duckdb diff --git a/extension/json/json_functions.cpp b/extension/json/json_functions.cpp index 1303fd63004f..5de50f1d69ed 100644 --- a/extension/json/json_functions.cpp +++ b/extension/json/json_functions.cpp @@ -209,7 +209,8 @@ unique_ptr<TableRef> JSONFunctions::ReadJSONReplacement(ClientContext &context, table_function->function = make_uniq<FunctionExpression>("read_json_auto", std::move(children)); if (!FileSystem::HasGlob(table_name)) { - table_function->alias = FileSystem::ExtractBaseName(table_name); + auto &fs = FileSystem::GetFileSystem(context); + table_function->alias = fs.ExtractBaseName(table_name); } return std::move(table_function); diff --git a/extension/json/json_functions/json_transform.cpp b/extension/json/json_functions/json_transform.cpp index e9a2fa16f850..0138fc180d7c 100644 --- a/extension/json/json_functions/json_transform.cpp +++ b/extension/json/json_functions/json_transform.cpp @@ -8,6 +8,8 @@ #include "duckdb/function/scalar/nested_functions.hpp" #include "json_functions.hpp" #include "json_scan.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp"
namespace duckdb { diff --git a/extension/json/json_scan.cpp b/extension/json/json_scan.cpp index e7d7e950d65e..de0c75f09013 100644 --- a/extension/json/json_scan.cpp +++ b/extension/json/json_scan.cpp @@ -5,12 +5,22 @@ #include "duckdb/main/extension_helper.hpp" #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { JSONScanData::JSONScanData() { } +JSONScanData::JSONScanData(ClientContext &context, vector<string> files_p, string date_format_p, + string timestamp_format_p) + : files(std::move(files_p)), date_format(std::move(date_format_p)), + timestamp_format(std::move(timestamp_format_p)) { + InitializeReaders(context); + InitializeFormats(); +} + void JSONScanData::Bind(ClientContext &context, TableFunctionBindInput &input) { auto &info = input.info->Cast<JSONScanInfo>(); type = info.type; @@ -164,6 +174,26 @@ void JSONScanData::Deserialize(ClientContext &context, FieldReader &reader) { transform_options.date_format_map = &date_format_map; } +string JSONScanData::GetDateFormat() const { + if (!date_format.empty()) { + return date_format; + } else if (date_format_map.HasFormats(LogicalTypeId::DATE)) { + return date_format_map.GetFormat(LogicalTypeId::DATE).format_specifier; + } else { + return string(); + } +} + +string JSONScanData::GetTimestampFormat() const { + if (!timestamp_format.empty()) { + return timestamp_format; + } else if (date_format_map.HasFormats(LogicalTypeId::TIMESTAMP)) { + return date_format_map.GetFormat(LogicalTypeId::TIMESTAMP).format_specifier; + } else { + return string(); + } +} + JSONScanGlobalState::JSONScanGlobalState(ClientContext &context, const JSONScanData &bind_data_p) : bind_data(bind_data_p), transform_options(bind_data.transform_options), allocator(BufferManager::GetBufferManager(context).GetBufferAllocator()), @@ -966,6 +996,18 @@ unique_ptr<FunctionData> JSONScan::Deserialize(PlanDeserializationState &state, return std::move(result); } +void JSONScan::FormatSerialize(FormatSerializer &serializer, const optional_ptr<FunctionData> bind_data_p, + const TableFunction &function) { + auto &bind_data = bind_data_p->Cast<JSONScanData>(); + serializer.WriteProperty("scan_data", bind_data); +} + +unique_ptr<FunctionData> JSONScan::FormatDeserialize(FormatDeserializer &deserializer, TableFunction &function) { + unique_ptr<JSONScanData> result; + deserializer.ReadProperty("scan_data", result); + return std::move(result); +} + void JSONScan::TableFunctionDefaults(TableFunction &table_function) { MultiFileReader::AddParameters(table_function); @@ -980,6 +1022,8 @@ void JSONScan::TableFunctionDefaults(TableFunction &table_function) { table_function.serialize = Serialize; table_function.deserialize = Deserialize; + table_function.format_serialize = FormatSerialize; + table_function.format_deserialize = FormatDeserialize; table_function.projection_pushdown = true; table_function.filter_pushdown = false; diff --git a/extension/json/serialize_json.cpp b/extension/json/serialize_json.cpp new file mode 100644 index 000000000000..02064d21dda9 --- /dev/null +++ b/extension/json/serialize_json.cpp @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/format_serializer.hpp" +#include
"duckdb/common/serializer/format_deserializer.hpp" +#include "buffered_json_reader.hpp" +#include "json_transform.hpp" +#include "json_scan.hpp" + +namespace duckdb { + +void BufferedJSONReaderOptions::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("format", format); + serializer.WriteProperty("record_type", record_type); + serializer.WriteProperty("compression", compression); + serializer.WriteProperty("file_options", file_options); +} + +BufferedJSONReaderOptions BufferedJSONReaderOptions::FormatDeserialize(FormatDeserializer &deserializer) { + BufferedJSONReaderOptions result; + deserializer.ReadProperty("format", result.format); + deserializer.ReadProperty("record_type", result.record_type); + deserializer.ReadProperty("compression", result.compression); + deserializer.ReadProperty("file_options", result.file_options); + return result; +} + +void JSONScanData::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("json_type", type); + serializer.WriteProperty("options", options); + serializer.WriteProperty("reader_bind", reader_bind); + serializer.WriteProperty("files", files); + serializer.WriteProperty("ignore_errors", ignore_errors); + serializer.WriteProperty("maximum_object_size", maximum_object_size); + serializer.WriteProperty("auto_detect", auto_detect); + serializer.WriteProperty("sample_size", sample_size); + serializer.WriteProperty("max_depth", max_depth); + serializer.WriteProperty("transform_options", transform_options); + serializer.WriteProperty("names", names); + serializer.WriteProperty("date_format", GetDateFormat()); + serializer.WriteProperty("timestamp_format", GetTimestampFormat()); +} + +unique_ptr JSONScanData::FormatDeserialize(FormatDeserializer &deserializer) { + auto type = deserializer.ReadProperty("json_type"); + auto options = deserializer.ReadProperty("options"); + auto reader_bind = deserializer.ReadProperty("reader_bind"); + auto files = deserializer.ReadProperty>("files"); + auto ignore_errors = deserializer.ReadProperty("ignore_errors"); + auto maximum_object_size = deserializer.ReadProperty("maximum_object_size"); + auto auto_detect = deserializer.ReadProperty("auto_detect"); + auto sample_size = deserializer.ReadProperty("sample_size"); + auto max_depth = deserializer.ReadProperty("max_depth"); + auto transform_options = deserializer.ReadProperty("transform_options"); + auto names = deserializer.ReadProperty>("names"); + auto date_format = deserializer.ReadProperty("date_format"); + auto timestamp_format = deserializer.ReadProperty("timestamp_format"); + auto result = duckdb::unique_ptr(new JSONScanData(deserializer.Get(), std::move(files), std::move(date_format), std::move(timestamp_format))); + result->type = type; + result->options = options; + result->reader_bind = reader_bind; + result->ignore_errors = ignore_errors; + result->maximum_object_size = maximum_object_size; + result->auto_detect = auto_detect; + result->sample_size = sample_size; + result->max_depth = max_depth; + result->transform_options = transform_options; + result->names = std::move(names); + return result; +} + +void JSONTransformOptions::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("strict_cast", strict_cast); + serializer.WriteProperty("error_duplicate_key", error_duplicate_key); + serializer.WriteProperty("error_missing_key", error_missing_key); + serializer.WriteProperty("error_unknown_key", error_unknown_key); + serializer.WriteProperty("delay_error", delay_error); +} + 
+JSONTransformOptions JSONTransformOptions::FormatDeserialize(FormatDeserializer &deserializer) { + JSONTransformOptions result; + deserializer.ReadProperty("strict_cast", result.strict_cast); + deserializer.ReadProperty("error_duplicate_key", result.error_duplicate_key); + deserializer.ReadProperty("error_missing_key", result.error_missing_key); + deserializer.ReadProperty("error_unknown_key", result.error_unknown_key); + deserializer.ReadProperty("delay_error", result.delay_error); + return result; +} + +} // namespace duckdb diff --git a/extension/parquet/CMakeLists.txt b/extension/parquet/CMakeLists.txt index 80d7405c8525..877ddf8689e5 100644 --- a/extension/parquet/CMakeLists.txt +++ b/extension/parquet/CMakeLists.txt @@ -15,6 +15,7 @@ set(PARQUET_EXTENSION_FILES parquet_timestamp.cpp parquet_writer.cpp parquet_statistics.cpp + serialize_parquet.cpp zstd_file_system.cpp column_reader.cpp) diff --git a/extension/parquet/include/parquet.json b/extension/parquet/include/parquet.json new file mode 100644 index 000000000000..cf17bd5b3303 --- /dev/null +++ b/extension/parquet/include/parquet.json @@ -0,0 +1,23 @@ +[ + { + "class": "ParquetOptions", + "includes": [ + "parquet_reader.hpp" + ], + "members": [ + { + "name": "binary_as_string", + "type": "bool" + }, + { + "name": "file_row_number", + "type": "bool" + }, + { + "name": "file_options", + "type": "MultiFileReaderOptions" + } + ], + "pointer_type": "none" + } +] diff --git a/extension/parquet/include/parquet_reader.hpp b/extension/parquet/include/parquet_reader.hpp index 6fa8f7758184..6a1d5000bcdd 100644 --- a/extension/parquet/include/parquet_reader.hpp +++ b/extension/parquet/include/parquet_reader.hpp @@ -76,6 +76,9 @@ struct ParquetOptions { public: void Serialize(FieldWriter &writer) const; void Deserialize(FieldReader &reader); + + void FormatSerialize(FormatSerializer &serializer) const; + static ParquetOptions FormatDeserialize(FormatDeserializer &deserializer); }; class ParquetReader { diff --git a/extension/parquet/parquet_config.py b/extension/parquet/parquet_config.py index 0848ff7f06ff..a21a1ff1baaf 100644 --- a/extension/parquet/parquet_config.py +++ b/extension/parquet/parquet_config.py @@ -17,6 +17,7 @@ for x in [ 'extension/parquet/parquet_extension.cpp', 'extension/parquet/column_writer.cpp', + 'extension/parquet/serialize_parquet.cpp', 'third_party/parquet/parquet_constants.cpp', 'third_party/parquet/parquet_types.cpp', 'third_party/thrift/thrift/protocol/TProtocol.cpp', diff --git a/extension/parquet/parquet_extension.cpp b/extension/parquet/parquet_extension.cpp index 95b8110029d7..9dd41b00b7e1 100644 --- a/extension/parquet/parquet_extension.cpp +++ b/extension/parquet/parquet_extension.cpp @@ -35,6 +35,8 @@ #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/row_group.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" #endif namespace duckdb { @@ -181,6 +183,8 @@ class ParquetScanFunction { table_function.get_batch_index = ParquetScanGetBatchIndex; table_function.serialize = ParquetScanSerialize; table_function.deserialize = ParquetScanDeserialize; + table_function.format_serialize = ParquetScanFormatSerialize; + table_function.format_deserialize = ParquetScanFormatDeserialize; table_function.get_batch_info = ParquetGetBatchInfo; table_function.projection_pushdown = true; table_function.filter_pushdown = true; @@ -430,6 +434,25 @@ class 
ParquetScanFunction { return ParquetScanBindInternal(context, files, types, names, options); } + static void ParquetScanFormatSerialize(FormatSerializer &serializer, const optional_ptr<FunctionData> bind_data_p, + const TableFunction &function) { + auto &bind_data = bind_data_p->Cast<ParquetReadBindData>(); + serializer.WriteProperty("files", bind_data.files); + serializer.WriteProperty("types", bind_data.types); + serializer.WriteProperty("names", bind_data.names); + serializer.WriteProperty("parquet_options", bind_data.parquet_options); + } + + static unique_ptr<FunctionData> ParquetScanFormatDeserialize(FormatDeserializer &deserializer, + TableFunction &function) { + auto &context = deserializer.Get<ClientContext &>(); + auto files = deserializer.ReadProperty<vector<string>>("files"); + auto types = deserializer.ReadProperty<vector<LogicalType>>("types"); + auto names = deserializer.ReadProperty<vector<string>>("names"); + auto parquet_options = deserializer.ReadProperty<ParquetOptions>("parquet_options"); + return ParquetScanBindInternal(context, files, types, names, parquet_options); + } + static void ParquetScanImplementation(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { if (!data_p.local_state) { return; @@ -938,7 +961,8 @@ unique_ptr<TableRef> ParquetScanReplacement(ClientContext &context, const string table_function->function = make_uniq<FunctionExpression>("parquet_scan", std::move(children)); if (!FileSystem::HasGlob(table_name)) { - table_function->alias = FileSystem::ExtractBaseName(table_name); + auto &fs = FileSystem::GetFileSystem(context); + table_function->alias = fs.ExtractBaseName(table_name); } return std::move(table_function); diff --git a/extension/parquet/parquet_reader.cpp b/extension/parquet/parquet_reader.cpp index 61066c017e41..f6f5723ae18d 100644 --- a/extension/parquet/parquet_reader.cpp +++ b/extension/parquet/parquet_reader.cpp @@ -72,6 +72,9 @@ static shared_ptr<ParquetFileMetadataCache> LoadMetadata(Allocator &allocator, F transport.read((uint8_t *)buf.ptr, 8); if (memcmp(buf.ptr + 4, "PAR1", 4) != 0) { + if (memcmp(buf.ptr + 4, "PARE", 4) == 0) { + throw InvalidInputException("Encrypted Parquet files are not supported for file '%s'", file_handle.path); + } throw InvalidInputException("No magic bytes found at end of file '%s'", file_handle.path); } // read four-byte footer length from just before the end magic bytes diff --git a/extension/parquet/serialize_parquet.cpp b/extension/parquet/serialize_parquet.cpp new file mode 100644 index 000000000000..66b25f99bd89 --- /dev/null +++ b/extension/parquet/serialize_parquet.cpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" +#include "parquet_reader.hpp" + +namespace duckdb { + +void ParquetOptions::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("binary_as_string", binary_as_string); + serializer.WriteProperty("file_row_number", file_row_number); + serializer.WriteProperty("file_options", file_options); +} + +ParquetOptions ParquetOptions::FormatDeserialize(FormatDeserializer &deserializer) { + ParquetOptions result; + deserializer.ReadProperty("binary_as_string", result.binary_as_string); + deserializer.ReadProperty("file_row_number", result.file_row_number); + deserializer.ReadProperty("file_options", result.file_options); + return result;
+}
+
+} // namespace duckdb
diff --git a/scripts/apply_extension_patches.py b/scripts/apply_extension_patches.py
new file mode 100644
index 000000000000..65730f1e5306
--- /dev/null
+++ b/scripts/apply_extension_patches.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+import sys
+import glob
+import subprocess
+
+# Get the directory and construct the patch file pattern
+directory = sys.argv[1]
+patch_pattern = f"{directory}*.patch"
+
+# Find patch files matching the pattern
+patches = glob.glob(patch_pattern)
+
+# Exit if no patches are found
+if not patches:
+    error_message = (
+        f"\nERROR: Extension patching enabled, but no patches found in '{directory}'. "
+        "Please make sure APPLY_PATCHES is only enabled when there are actually patches present. "
+        "See .github/patches/extensions/README.md for more details.\n"
+    )
+    sys.stderr.write(error_message)
+    sys.exit(1)
+
+# Apply each patch file using git apply; check=True aborts the script if a patch fails to apply
+for patch in patches:
+    print(f"Applying patch: {patch}")
+    subprocess.run(["git", "apply", "--ignore-space-change", "--ignore-whitespace", patch], check=True)
diff --git a/scripts/format.py b/scripts/format.py
index c6435a4da700..0b7863962e85 100644
--- a/scripts/format.py
+++ b/scripts/format.py
@@ -10,6 +10,19 @@ import difflib
 import re
 from python_helpers import open_utf8
+from importlib import import_module
+from importlib.metadata import version
+
+try:
+    import_module('black')
+except ImportError as e:
+    print('you need to run `pip install black`', e)
+    exit(-1)
+
+ver = subprocess.check_output(('clang-format', '--version'), text=True)
+if '10.' not in ver:
+    print('you need to run `pip install clang_format==10.0.1.1`', ver)
+    exit(-1)
 
 cpp_format_command = 'clang-format --sort-includes=0 -style=file'
 cmake_format_command = 'cmake-format'
diff --git a/scripts/generate_enums.py b/scripts/generate_enums.py
new file mode 100644
index 000000000000..bd7718c0aa51
--- /dev/null
+++ b/scripts/generate_enums.py
@@ -0,0 +1,161 @@
+import os
+import json
+import re
+
+targets = [{'source': 'extension/json/include/', 'target': 'extension/json'}]
+
+file_list = []
+for target in targets:
+    source_base = os.path.sep.join(target['source'].split('/'))
+    target_base = os.path.sep.join(target['target'].split('/'))
+    for fname in os.listdir(source_base):
+        if '_enums.json' not in fname:
+            continue
+        file_list.append(
+            {
+                'source': os.path.join(source_base, fname),
+                'include_path': fname.replace('.json', '.hpp'),
+                'target_hpp': os.path.join(source_base, fname.replace('.json', '.hpp')),
+                'target_cpp': os.path.join(target_base, fname.replace('.json', '.cpp')),
+            }
+        )
+
+header = '''//===----------------------------------------------------------------------===//
+// This file is automatically generated by scripts/generate_enums.py
+// Do not edit this file manually, your changes will be overwritten
+//===----------------------------------------------------------------------===//
+
+${INCLUDE_LIST}
+namespace duckdb {
+'''
+
+footer = '''
+} // namespace duckdb
+'''
+
+include_base = '#include "${FILENAME}"\n'
+
+enum_header = '\nenum class ${ENUM_NAME} : ${ENUM_TYPE} {\n'
+
+enum_footer = '};'
+
+enum_value = '\t${ENUM_MEMBER} = ${ENUM_VALUE},\n'
+
+enum_util_header = '''
+template<>
+const char* EnumUtil::ToChars<${ENUM_NAME}>(${ENUM_NAME} value);
+
+template<>
+${ENUM_NAME} EnumUtil::FromString<${ENUM_NAME}>(const char *value);
+'''
+
+enum_util_conversion_begin = '''
+template<>
+const char* EnumUtil::ToChars<${ENUM_NAME}>(${ENUM_NAME} value) {
+	switch(value) {
+'''
+
+enum_util_switch = '\tcase 
${ENUM_NAME}::${ENUM_MEMBER}:\n\t\treturn "${ENUM_MEMBER}";\n' + +enum_util_conversion_end = ''' default: + throw NotImplementedException(StringUtil::Format("Enum value of type ${ENUM_NAME}: '%d' not implemented", value)); + } +} +''' + +from_string_begin = ''' +template<> +${ENUM_NAME} EnumUtil::FromString<${ENUM_NAME}>(const char *value) { +''' + +from_string_comparison = ''' if (StringUtil::Equals(value, "${ENUM_MEMBER}")) { + return ${ENUM_NAME}::${ENUM_MEMBER}; + } +''' + +from_string_end = ''' throw NotImplementedException(StringUtil::Format("Enum value of type ${ENUM_NAME}: '%s' not implemented", value)); +} +''' + + +class EnumMember: + def __init__(self, entry, index): + self.comment = None + self.index = index + if type(entry) == str: + self.name = entry + else: + self.name = entry['name'] + if 'comment' in entry: + self.comment = entry['comment'] + if 'index' in entry: + self.index = int(entry['index']) + + +class EnumClass: + def __init__(self, entry): + self.name = entry['name'] + self.type = 'uint8_t' + self.values = [] + index = 0 + for value_entry in entry['values']: + self.values.append(EnumMember(value_entry, index)) + index += 1 + + +for entry in file_list: + source_path = entry['source'] + target_header = entry['target_hpp'] + target_source = entry['target_cpp'] + include_path = entry['include_path'] + with open(source_path, 'r') as f: + json_data = json.load(f) + + include_list = ['duckdb/common/constants.hpp', 'duckdb/common/enum_util.hpp'] + enums = [] + + for entry in json_data: + if 'includes' in entry: + include_list += entry['includes'] + enums.append(EnumClass(entry)) + + with open(target_header, 'w+') as f: + include_text = '#pragma once\n\n' + include_text += ''.join([include_base.replace('${FILENAME}', x) for x in include_list]) + f.write(header.replace('${INCLUDE_LIST}', include_text)) + + for enum in enums: + f.write(enum_header.replace('${ENUM_NAME}', enum.name).replace('${ENUM_TYPE}', enum.type)) + for value in enum.values: + if value.comment is not None: + f.write('\t//! 
' + value.comment + '\n') + f.write(enum_value.replace('${ENUM_MEMBER}', value.name).replace('${ENUM_VALUE}', str(value.index))) + + f.write(enum_footer) + f.write('\n') + + for enum in enums: + f.write(enum_util_header.replace('${ENUM_NAME}', enum.name)) + + f.write(footer) + + with open(target_source, 'w+') as f: + source_include_list = [include_path, 'duckdb/common/string_util.hpp'] + f.write( + header.replace( + '${INCLUDE_LIST}', ''.join([include_base.replace('${FILENAME}', x) for x in source_include_list]) + ) + ) + + for enum in enums: + f.write(enum_util_conversion_begin.replace('${ENUM_NAME}', enum.name)) + for value in enum.values: + f.write(enum_util_switch.replace('${ENUM_MEMBER}', value.name).replace('${ENUM_NAME}', enum.name)) + + f.write(enum_util_conversion_end.replace('${ENUM_NAME}', enum.name)) + f.write(from_string_begin.replace('${ENUM_NAME}', enum.name)) + for value in enum.values: + f.write(from_string_comparison.replace('${ENUM_MEMBER}', value.name).replace('${ENUM_NAME}', enum.name)) + + f.write(from_string_end.replace('${ENUM_NAME}', enum.name)) + f.write(footer) diff --git a/scripts/generate_serialization.py b/scripts/generate_serialization.py index 56cab8a07315..8d7732362397 100644 --- a/scripts/generate_serialization.py +++ b/scripts/generate_serialization.py @@ -2,19 +2,27 @@ import json import re -source_base = os.path.sep.join('src/include/duckdb/storage/serialization'.split('/')) -target_base = os.path.sep.join('src/storage/serialization'.split('/')) +targets = [ + {'source': 'src/include/duckdb/storage/serialization', 'target': 'src/storage/serialization'}, + {'source': 'extension/parquet/include/', 'target': 'extension/parquet'}, + {'source': 'extension/json/include/', 'target': 'extension/json'}, +] file_list = [] -for fname in os.listdir(source_base): - if '.json' not in fname: - continue - file_list.append( - { - 'source': os.path.join(source_base, fname), - 'target': os.path.join(target_base, 'serialize_' + fname.replace('.json', '.cpp')), - } - ) +for target in targets: + source_base = os.path.sep.join(target['source'].split('/')) + target_base = os.path.sep.join(target['target'].split('/')) + for fname in os.listdir(source_base): + if '.json' not in fname: + continue + if '_enums.json' in fname: + continue + file_list.append( + { + 'source': os.path.join(source_base, fname), + 'target': os.path.join(target_base, 'serialize_' + fname.replace('.json', '.cpp')), + } + ) include_base = '#include "${FILENAME}"\n' @@ -397,9 +405,9 @@ def generate_class_code(class_entry): constructor_parameters += ", " type_name = replace_pointer(entry.type) if requires_move(type_name) and not is_reference: - constructor_parameters += 'std::move(' + entry.name + ')' + constructor_parameters += 'std::move(' + entry.deserialize_property + ')' else: - constructor_parameters += entry.name + constructor_parameters += entry.deserialize_property found = True break if constructor_entry.startswith('$'): @@ -426,7 +434,9 @@ def generate_class_code(class_entry): for entry_idx in range(last_constructor_index + 1): entry = class_entry.members[entry_idx] type_name = replace_pointer(entry.type) - class_deserialize += get_deserialize_element(entry.name, entry.name, type_name, entry.optional, 'unique_ptr') + class_deserialize += get_deserialize_element( + entry.deserialize_property, entry.name, type_name, entry.optional, 'unique_ptr' + ) class_deserialize += generate_constructor( class_entry.pointer_type, class_entry.return_class, constructor_parameters @@ -465,7 +475,9 @@ def 
generate_class_code(class_entry): class_entry.pointer_type, ) elif entry.name not in constructor_entries: - class_deserialize += get_deserialize_assignment(entry.name, entry.type, class_entry.pointer_type) + class_deserialize += get_deserialize_assignment( + entry.deserialize_property, entry.type, class_entry.pointer_type + ) class_deserialize += generate_return(class_entry) deserialize_return = get_return_value(class_entry.pointer_type, class_entry.return_type) diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 645dc7495e67..34f07c0a8feb 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -655,17 +655,13 @@ LogicalType Catalog::GetType(ClientContext &context, const string &schema, const if (!type_entry) { return LogicalType::INVALID; } - auto result_type = type_entry->user_type; - EnumType::SetCatalog(result_type, type_entry.get()); - return result_type; + return type_entry->user_type; } LogicalType Catalog::GetType(ClientContext &context, const string &catalog_name, const string &schema, const string &name) { auto &type_entry = Catalog::GetEntry(context, catalog_name, schema, name); - auto result_type = type_entry.user_type; - EnumType::SetCatalog(result_type, &type_entry); - return result_type; + return type_entry.user_type; } vector> Catalog::GetSchemas(ClientContext &context) { diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index a228a836f7b5..3a0296fa66d6 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -527,10 +527,7 @@ unique_ptr DuckTableEntry::DropNotNull(ClientContext &context, Dro } unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context, ChangeColumnTypeInfo &info) { - if (info.target_type.id() == LogicalTypeId::USER) { - info.target_type = - Catalog::GetType(context, catalog.GetName(), schema.name, UserType::GetTypeName(info.target_type)); - } + Binder::BindLogicalType(context, info.target_type, &catalog, schema.name); auto change_idx = GetColumnIndex(info.column_name); auto create_info = make_uniq(schema, name); create_info->temporary = temporary; diff --git a/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/catalog/catalog_entry/table_catalog_entry.cpp index f88b4506d1a3..b01bfb288a60 100644 --- a/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -172,10 +172,6 @@ const ColumnList &TableCatalogEntry::GetColumns() const { return columns; } -ColumnList &TableCatalogEntry::GetColumnsMutable() { - return columns; -} - const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) { return columns.GetColumn(idx); } diff --git a/src/catalog/catalog_entry/type_catalog_entry.cpp b/src/catalog/catalog_entry/type_catalog_entry.cpp index 125cf6f53bc0..f16f86af5a99 100644 --- a/src/catalog/catalog_entry/type_catalog_entry.cpp +++ b/src/catalog/catalog_entry/type_catalog_entry.cpp @@ -5,7 +5,6 @@ #include "duckdb/common/limits.hpp" #include "duckdb/common/field_writer.hpp" #include "duckdb/parser/keyword_helper.hpp" -#include "duckdb/parser/parsed_data/create_sequence_info.hpp" #include "duckdb/common/types/vector.hpp" #include #include @@ -18,31 +17,13 @@ TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, this->internal = info.internal; } -void TypeCatalogEntry::Serialize(Serializer &serializer) const { - D_ASSERT(!internal); - FieldWriter writer(serializer); - writer.WriteString(schema.name); - 
writer.WriteString(name); - if (user_type.id() == LogicalTypeId::ENUM) { - // We have to serialize Enum Values - writer.AddField(); - user_type.SerializeEnumType(writer.GetSerializer()); - } else { - writer.WriteSerializable(user_type); - } - writer.Finalize(); -} - -unique_ptr TypeCatalogEntry::Deserialize(Deserializer &source) { - auto info = make_uniq(); - - FieldReader reader(source); - info->schema = reader.ReadRequired(); - info->name = reader.ReadRequired(); - info->type = reader.ReadRequiredSerializable(); - reader.Finalize(); - - return info; +unique_ptr TypeCatalogEntry::GetInfo() const { + auto result = make_uniq(); + result->catalog = catalog.GetName(); + result->schema = schema.name; + result->name = name; + result->type = user_type; + return std::move(result); } string TypeCatalogEntry::ToSQL() const { diff --git a/src/catalog/catalog_set.cpp b/src/catalog/catalog_set.cpp index ed2415379e80..b44814b81190 100644 --- a/src/catalog/catalog_set.cpp +++ b/src/catalog/catalog_set.cpp @@ -543,67 +543,6 @@ void CatalogSet::UpdateTimestamp(CatalogEntry &entry, transaction_t timestamp) { mapping[entry.name]->timestamp = timestamp; } -void CatalogSet::AdjustUserDependency(CatalogEntry &entry, ColumnDefinition &column, bool remove) { - auto user_type_catalog_p = EnumType::GetCatalog(column.Type()); - if (!user_type_catalog_p) { - return; - } - auto &user_type_catalog = user_type_catalog_p->Cast(); - auto &dependency_manager = catalog.GetDependencyManager(); - if (remove) { - dependency_manager.dependents_map[user_type_catalog].erase(*entry.parent); - dependency_manager.dependencies_map[*entry.parent].erase(user_type_catalog); - } else { - dependency_manager.dependents_map[user_type_catalog].insert(entry); - dependency_manager.dependencies_map[entry].insert(user_type_catalog); - } -} - -void CatalogSet::AdjustDependency(CatalogEntry &entry, TableCatalogEntry &table, ColumnDefinition &column, - bool remove) { - bool found = false; - if (column.Type().id() == LogicalTypeId::ENUM) { - for (auto &old_column : table.GetColumns().Logical()) { - if (old_column.Name() == column.Name() && old_column.Type().id() != LogicalTypeId::ENUM) { - AdjustUserDependency(entry, column, remove); - found = true; - } - } - if (!found) { - AdjustUserDependency(entry, column, remove); - } - } else if (!(column.Type().GetAlias().empty())) { - auto alias = column.Type().GetAlias(); - for (auto &old_column : table.GetColumns().Logical()) { - auto old_alias = old_column.Type().GetAlias(); - if (old_column.Name() == column.Name() && old_alias != alias) { - AdjustUserDependency(entry, column, remove); - found = true; - } - } - if (!found) { - AdjustUserDependency(entry, column, remove); - } - } -} - -void CatalogSet::AdjustTableDependencies(CatalogEntry &entry) { - if (entry.type == CatalogType::TABLE_ENTRY && entry.parent->type == CatalogType::TABLE_ENTRY) { - // If it's a table entry we have to check for possibly removing or adding user type dependencies - auto &old_table = entry.parent->Cast(); - auto &new_table = entry.Cast(); - - for (idx_t i = 0; i < new_table.GetColumns().LogicalColumnCount(); i++) { - auto &new_column = new_table.GetColumnsMutable().GetColumnMutable(LogicalIndex(i)); - AdjustDependency(entry, old_table, new_column, false); - } - for (idx_t i = 0; i < old_table.GetColumns().LogicalColumnCount(); i++) { - auto &old_column = old_table.GetColumnsMutable().GetColumnMutable(LogicalIndex(i)); - AdjustDependency(entry, new_table, old_column, true); - } - } -} - void CatalogSet::Undo(CatalogEntry &entry) 
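// NOTE on the removals above: ENUM columns no longer carry a pointer back to
// their catalog entry (EnumType::SetCatalog/GetCatalog are deleted later in
// this patch), so Undo() below no longer needs AdjustTableDependencies() to
// patch enum dependencies when an ALTER is rolled back. A minimal sketch of
// the invariant this relies on, using names introduced by this patch:
//
//     LogicalType t = LogicalType::ENUM(ordered_data, size);
//     // t owns its dictionary outright; rolling back a catalog entry cannot
//     // leave a dangling enum-catalog reference inside a column type.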
{ lock_guard write_lock(catalog.GetWriteLock()); lock_guard lock(catalog_lock); @@ -614,8 +553,6 @@ void CatalogSet::Undo(CatalogEntry &entry) { // i.e. we have to place (entry) as (entry->parent) again auto &to_be_removed_node = *entry.parent; - AdjustTableDependencies(entry); - if (!to_be_removed_node.deleted) { // delete the entry from the dependency manager as well auto &dependency_manager = catalog.GetDependencyManager(); diff --git a/src/catalog/dependency_manager.cpp b/src/catalog/dependency_manager.cpp index e04a227c0f6b..901604a5fd71 100644 --- a/src/catalog/dependency_manager.cpp +++ b/src/catalog/dependency_manager.cpp @@ -106,52 +106,16 @@ void DependencyManager::AlterObject(CatalogTransaction transaction, CatalogEntry } // add the new object to the dependents_map of each object that it depends on auto &old_dependencies = dependencies_map[old_obj]; - catalog_entry_vector_t to_delete; for (auto &dep : old_dependencies) { auto &dependency = dep.get(); - if (dependency.type == CatalogType::TYPE_ENTRY) { - auto &user_type = dependency.Cast(); - auto &table = new_obj.Cast(); - bool deleted_dependency = true; - for (auto &column : table.GetColumns().Logical()) { - if (column.Type() == user_type.user_type) { - deleted_dependency = false; - break; - } - } - if (deleted_dependency) { - to_delete.push_back(dependency); - continue; - } - } dependents_map[dependency].insert(new_obj); } - for (auto &dep : to_delete) { - auto &dependency = dep.get(); - old_dependencies.erase(dependency); - dependents_map[dependency].erase(old_obj); - } // We might have to add a type dependency - catalog_entry_vector_t to_add; - if (new_obj.type == CatalogType::TABLE_ENTRY) { - auto &table = new_obj.Cast(); - for (auto &column : table.GetColumns().Logical()) { - auto user_type_catalog = EnumType::GetCatalog(column.Type()); - if (user_type_catalog) { - to_add.push_back(*user_type_catalog); - } - } - } // add the new object to the dependency manager dependents_map[new_obj] = dependency_set_t(); dependencies_map[new_obj] = old_dependencies; - for (auto &dependency : to_add) { - dependencies_map[new_obj].insert(dependency); - dependents_map[dependency].insert(new_obj); - } - for (auto &dependency : owned_objects_to_add) { dependents_map[new_obj].insert(Dependency(dependency, DependencyType::DEPENDENCY_OWNS)); dependents_map[dependency].insert(Dependency(new_obj, DependencyType::DEPENDENCY_OWNED_BY)); diff --git a/src/common/adbc/adbc.cpp b/src/common/adbc/adbc.cpp index c63e0eb0e257..252a833bb242 100644 --- a/src/common/adbc/adbc.cpp +++ b/src/common/adbc/adbc.cpp @@ -13,10 +13,15 @@ #endif #include "duckdb/common/adbc/single_batch_array_stream.hpp" +#include "duckdb/common/arrow/arrow_cpp.hpp" #include #include +using duckdb::ArrowArrayCPP; +using duckdb::ArrowArrayStreamCPP; +using duckdb::ArrowSchemaCPP; + // We must leak the symbols of the init function duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::AdbcDriver *driver, struct duckdb_adbc::AdbcError *error) { @@ -64,7 +69,7 @@ struct DuckDBAdbcStatementWrapper { ::duckdb_arrow result; ::duckdb_prepared_statement statement; char *ingestion_table_name; - ArrowArrayStream ingestion_stream; + ArrowArrayStreamCPP ingestion_stream; IngestionMode ingestion_mode = IngestionMode::CREATE; }; @@ -599,7 +604,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, stru auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), duckdb::Value::POINTER((uintptr_t)stream_produce), - 
duckdb::Value::POINTER((uintptr_t)get_schema)}); + duckdb::Value::POINTER((uintptr_t)input->get_schema)}); try { if (ingestion_mode == IngestionMode::CREATE) { // We create the table based on an Arrow Scanner @@ -645,7 +650,7 @@ AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatem statement->private_data = nullptr; - auto statement_wrapper = (DuckDBAdbcStatementWrapper *)malloc(sizeof(DuckDBAdbcStatementWrapper)); + auto statement_wrapper = new DuckDBAdbcStatementWrapper; status = SetErrorMaybe(statement_wrapper, error, "Allocation error"); if (status != ADBC_STATUS_OK) { return status; @@ -655,7 +660,6 @@ AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatem statement_wrapper->connection = (duckdb_connection)connection->private_data; statement_wrapper->statement = nullptr; statement_wrapper->result = nullptr; - statement_wrapper->ingestion_stream.release = nullptr; statement_wrapper->ingestion_table_name = nullptr; statement_wrapper->ingestion_mode = IngestionMode::CREATE; return ADBC_STATUS_OK; @@ -675,15 +679,11 @@ AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcErro duckdb_destroy_arrow(&wrapper->result); wrapper->result = nullptr; } - if (wrapper->ingestion_stream.release) { - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - wrapper->ingestion_stream.release = nullptr; - } if (wrapper->ingestion_table_name) { free(wrapper->ingestion_table_name); wrapper->ingestion_table_name = nullptr; } - free(statement->private_data); + delete wrapper; statement->private_data = nullptr; } return ADBC_STATUS_OK; @@ -742,15 +742,14 @@ AdbcStatusCode GetPreparedParameters(duckdb_connection connection, duckdb::uniqu static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *statement, AdbcError *error) { // See ADBC_INGEST_OPTION_TARGET_TABLE - D_ASSERT(statement->ingestion_stream.release); + D_ASSERT(statement->ingestion_stream.Valid()); D_ASSERT(statement->ingestion_table_name); // Take the input stream from the statement - auto stream = statement->ingestion_stream; - statement->ingestion_stream.release = nullptr; + auto stream = std::move(statement->ingestion_stream); // Ingest into a table from the bound stream - return Ingest(statement->connection, statement->ingestion_table_name, &stream, error, statement->ingestion_mode); + return Ingest(statement->connection, statement->ingestion_table_name, stream, error, statement->ingestion_mode); } AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, @@ -775,7 +774,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr *rows_affected = 0; } - const auto has_stream = wrapper->ingestion_stream.release != nullptr; + const auto has_stream = wrapper->ingestion_stream.Valid(); const auto to_table = wrapper->ingestion_table_name != nullptr; if (has_stream && to_table) { @@ -785,9 +784,8 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr if (has_stream) { // A stream was bound to the statement, use that to bind parameters duckdb::unique_ptr result; - ArrowArrayStream stream = wrapper->ingestion_stream; - wrapper->ingestion_stream.release = nullptr; - auto adbc_res = GetPreparedParameters(wrapper->connection, result, &stream, error); + auto stream = std::move(wrapper->ingestion_stream); + auto adbc_res = GetPreparedParameters(wrapper->connection, result, stream, error); if (adbc_res != ADBC_STATUS_OK) { return adbc_res; } @@ -907,11 +905,8 
@@ AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray return status; } auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - if (wrapper->ingestion_stream.release) { - // Free the stream that was previously bound - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - } - status = BatchToArrayStream(values, schemas, &wrapper->ingestion_stream, error); + auto existing_stream = std::move(wrapper->ingestion_stream); + status = BatchToArrayStream(values, schemas, wrapper->ingestion_stream, error); if (status != ADBC_STATUS_OK) { return status; } @@ -933,10 +928,6 @@ AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct Arrow return status; } auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; - if (wrapper->ingestion_stream.release) { - // Release any resources currently held by the ingestion stream before we overwrite it - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - } wrapper->ingestion_stream = *values; values->release = nullptr; return ADBC_STATUS_OK; diff --git a/src/common/arrow/arrow_wrapper.cpp b/src/common/arrow/arrow_wrapper.cpp index ddc6855d8a42..34049bbed0bf 100644 --- a/src/common/arrow/arrow_wrapper.cpp +++ b/src/common/arrow/arrow_wrapper.cpp @@ -9,6 +9,7 @@ #include "duckdb/common/arrow/result_arrow_wrapper.hpp" #include "duckdb/common/arrow/arrow_appender.hpp" #include "duckdb/main/query_result.hpp" +#include "duckdb/main/chunk_scan_state/query_result.hpp" namespace duckdb { @@ -99,6 +100,7 @@ int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stre } auto my_stream = reinterpret_cast(stream->private_data); auto &result = *my_stream->result; + auto &scan_state = *my_stream->scan_state; if (result.HasError()) { my_stream->last_error = result.GetErrorObject(); return -1; @@ -117,7 +119,8 @@ int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stre } idx_t result_count; PreservedError error; - if (!ArrowUtil::TryFetchChunk(&result, my_stream->batch_size, out, result_count, error)) { + if (!ArrowUtil::TryFetchChunk(scan_state, result.GetArrowOptions(result), my_stream->batch_size, out, result_count, + error)) { D_ASSERT(error); my_stream->last_error = error; return -1; @@ -147,7 +150,7 @@ const char *ResultArrowArrayStreamWrapper::MyStreamGetLastError(struct ArrowArra } ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr result_p, idx_t batch_size_p) - : result(std::move(result_p)) { + : result(std::move(result_p)), scan_state(make_uniq(*result)) { //! We first initialize the private data of the stream stream.private_data = this; //! 
Ceil Approx_Batch_Size/STANDARD_VECTOR_SIZE @@ -162,52 +165,43 @@ ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr &chunk, PreservedError &error) { - if (result.type == QueryResultType::STREAM_RESULT) { - auto &stream_result = result.Cast(); - if (!stream_result.IsOpen()) { - return true; - } - } - return result.TryFetch(chunk, error); -} - -bool ArrowUtil::TryFetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out, idx_t &count, - PreservedError &error) { +bool ArrowUtil::TryFetchChunk(ChunkScanState &scan_state, ArrowOptions options, idx_t batch_size, ArrowArray *out, + idx_t &count, PreservedError &error) { count = 0; - ArrowAppender appender(result->types, chunk_size, QueryResult::GetArrowOptions(*result)); - auto ¤t_chunk = result->current_chunk; - if (current_chunk.Valid()) { + ArrowAppender appender(scan_state.Types(), batch_size, std::move(options)); + auto remaining_tuples_in_chunk = scan_state.RemainingInChunk(); + if (remaining_tuples_in_chunk) { // We start by scanning the non-finished current chunk - // Limit the amount we're fetching to the chunk_size - idx_t cur_consumption = MinValue(current_chunk.RemainingSize(), chunk_size); + idx_t cur_consumption = MinValue(remaining_tuples_in_chunk, batch_size); count += cur_consumption; - appender.Append(*current_chunk.data_chunk, current_chunk.position, current_chunk.position + cur_consumption, - current_chunk.data_chunk->size()); - current_chunk.position += cur_consumption; - } - while (count < chunk_size) { - unique_ptr data_chunk; - if (!TryFetchNext(*result, data_chunk, error)) { - if (result->HasError()) { - error = result->GetErrorObject(); + auto ¤t_chunk = scan_state.CurrentChunk(); + appender.Append(current_chunk, scan_state.CurrentOffset(), scan_state.CurrentOffset() + cur_consumption, + current_chunk.size()); + scan_state.IncreaseOffset(cur_consumption); + } + while (count < batch_size) { + if (!scan_state.LoadNextChunk(error)) { + if (scan_state.HasError()) { + error = scan_state.GetError(); } return false; } - if (!data_chunk || data_chunk->size() == 0) { + if (scan_state.ChunkIsEmpty()) { + // The scan was successful, but an empty chunk was returned break; } - if (count + data_chunk->size() > chunk_size) { - // We have to split the chunk between this and the next batch - idx_t available_space = chunk_size - count; - appender.Append(*data_chunk, 0, available_space, data_chunk->size()); - count += available_space; - current_chunk.data_chunk = std::move(data_chunk); - current_chunk.position = available_space; - } else { - count += data_chunk->size(); - appender.Append(*data_chunk, 0, data_chunk->size(), data_chunk->size()); + auto ¤t_chunk = scan_state.CurrentChunk(); + if (scan_state.Finished() || current_chunk.size() == 0) { + break; } + // The amount we still need to append into this chunk + auto remaining = batch_size - count; + + // The amount remaining, capped by the amount left in the current chunk + auto to_append_to_batch = MinValue(remaining, scan_state.RemainingInChunk()); + appender.Append(current_chunk, 0, to_append_to_batch, current_chunk.size()); + count += to_append_to_batch; + scan_state.IncreaseOffset(to_append_to_batch); } if (count > 0) { *out = appender.Finalize(); @@ -215,10 +209,10 @@ bool ArrowUtil::TryFetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray return true; } -idx_t ArrowUtil::FetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out) { +idx_t ArrowUtil::FetchChunk(ChunkScanState &scan_state, ArrowOptions options, idx_t chunk_size, 
ArrowArray *out) { PreservedError error; idx_t result_count; - if (!TryFetchChunk(result, chunk_size, out, result_count, error)) { + if (!TryFetchChunk(scan_state, std::move(options), chunk_size, out, result_count, error)) { error.Throw(); } return result_count; diff --git a/src/common/extra_type_info.cpp b/src/common/extra_type_info.cpp index 3bb80638f55f..f18ffcc18145 100644 --- a/src/common/extra_type_info.cpp +++ b/src/common/extra_type_info.cpp @@ -299,8 +299,8 @@ PhysicalType EnumTypeInfo::DictType(idx_t size) { template struct EnumTypeInfoTemplated : public EnumTypeInfo { - explicit EnumTypeInfoTemplated(const string &enum_name_p, Vector &values_insert_order_p, idx_t size_p) - : EnumTypeInfo(enum_name_p, values_insert_order_p, size_p) { + explicit EnumTypeInfoTemplated(Vector &values_insert_order_p, idx_t size_p) + : EnumTypeInfo(values_insert_order_p, size_p) { D_ASSERT(values_insert_order_p.GetType().InternalType() == PhysicalType::VARCHAR); UnifiedVectorFormat vdata; @@ -320,17 +320,16 @@ struct EnumTypeInfoTemplated : public EnumTypeInfo { } } - static shared_ptr Deserialize(FieldReader &reader, uint32_t size, string enum_name) { + static shared_ptr Deserialize(FieldReader &reader, uint32_t size) { Vector values_insert_order(LogicalType::VARCHAR, size); values_insert_order.Deserialize(size, reader.GetSource()); - return make_shared(std::move(enum_name), values_insert_order, size); + return make_shared(values_insert_order, size); } static shared_ptr FormatDeserialize(FormatDeserializer &source, uint32_t size) { - auto enum_name = source.ReadProperty("enum_name"); Vector values_insert_order(LogicalType::VARCHAR, size); values_insert_order.FormatDeserialize(source, size); - return make_shared(std::move(enum_name), values_insert_order, size); + return make_shared(values_insert_order, size); } const string_map_t &GetValues() const { @@ -344,23 +343,15 @@ struct EnumTypeInfoTemplated : public EnumTypeInfo { string_map_t values; }; -EnumTypeInfo::EnumTypeInfo(string enum_name_p, Vector &values_insert_order_p, idx_t dict_size_p) +EnumTypeInfo::EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p) : ExtraTypeInfo(ExtraTypeInfoType::ENUM_TYPE_INFO), values_insert_order(values_insert_order_p), - dict_type(EnumDictType::VECTOR_DICT), enum_name(std::move(enum_name_p)), dict_size(dict_size_p) { + dict_type(EnumDictType::VECTOR_DICT), dict_size(dict_size_p) { } const EnumDictType &EnumTypeInfo::GetEnumDictType() const { return dict_type; } -const string &EnumTypeInfo::GetEnumName() const { - return enum_name; -} - -const string EnumTypeInfo::GetSchemaName() const { - return catalog_entry ? 
catalog_entry->schema.name : ""; -} - const Vector &EnumTypeInfo::GetValuesInsertOrder() const { return values_insert_order; } @@ -369,19 +360,19 @@ const idx_t &EnumTypeInfo::GetDictSize() const { return dict_size; } -LogicalType EnumTypeInfo::CreateType(const string &enum_name, Vector &ordered_data, idx_t size) { +LogicalType EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { // Generate EnumTypeInfo shared_ptr info; auto enum_internal_type = EnumTypeInfo::DictType(size); switch (enum_internal_type) { case PhysicalType::UINT8: - info = make_shared>(enum_name, ordered_data, size); + info = make_shared>(ordered_data, size); break; case PhysicalType::UINT16: - info = make_shared>(enum_name, ordered_data, size); + info = make_shared>(ordered_data, size); break; case PhysicalType::UINT32: - info = make_shared>(enum_name, ordered_data, size); + info = make_shared>(ordered_data, size); break; default: throw InternalException("Invalid Physical Type for ENUMs"); @@ -413,37 +404,22 @@ int64_t EnumType::GetPos(const LogicalType &type, const string_t &key) { } } +string_t EnumType::GetString(const LogicalType &type, idx_t pos) { + D_ASSERT(pos < EnumType::GetSize(type)); + return FlatVector::GetData(EnumType::GetValuesInsertOrder(type))[pos]; +} + shared_ptr EnumTypeInfo::Deserialize(FieldReader &reader) { - auto schema_name = reader.ReadRequired(); - auto enum_name = reader.ReadRequired(); - auto deserialize_internals = reader.ReadRequired(); - if (!deserialize_internals) { - // this means the enum should already be in the catalog. - auto &client_context = reader.GetSource().GetContext(); - // See if the serializer has a catalog - auto catalog = reader.GetSource().GetCatalog(); - shared_ptr extra_info; - if (catalog) { - auto enum_type = catalog->GetType(client_context, schema_name, enum_name, OnEntryNotFound::RETURN_NULL); - if (enum_type != LogicalType::INVALID) { - extra_info = enum_type.GetAuxInfoShrPtr(); - } - } - if (!extra_info) { - throw InternalException("Could not find ENUM in the Catalog to deserialize"); - } - return extra_info; - } // deserialize the enum data auto enum_size = reader.ReadRequired(); auto enum_internal_type = EnumTypeInfo::DictType(enum_size); switch (enum_internal_type) { case PhysicalType::UINT8: - return EnumTypeInfoTemplated::Deserialize(reader, enum_size, enum_name); + return EnumTypeInfoTemplated::Deserialize(reader, enum_size); case PhysicalType::UINT16: - return EnumTypeInfoTemplated::Deserialize(reader, enum_size, enum_name); + return EnumTypeInfoTemplated::Deserialize(reader, enum_size); case PhysicalType::UINT32: - return EnumTypeInfoTemplated::Deserialize(reader, enum_size, enum_name); + return EnumTypeInfoTemplated::Deserialize(reader, enum_size); default: throw InternalException("Invalid Physical Type for ENUMs"); } @@ -491,14 +467,16 @@ void EnumTypeInfo::Serialize(FieldWriter &writer) const { if (dict_type != EnumDictType::VECTOR_DICT) { throw InternalException("Cannot serialize non-vector dictionary ENUM types"); } - bool serialize_internals = GetSchemaName().empty() || writer.GetSerializer().is_query_plan; - EnumType::Serialize(writer, *this, serialize_internals); + auto dict_size = GetDictSize(); + // Store Dictionary Size + writer.WriteField(dict_size); + // Store Vector Order By Insertion + ((Vector &)GetValuesInsertOrder()).Serialize(dict_size, writer.GetSerializer()); // NOLINT - FIXME } void EnumTypeInfo::FormatSerialize(FormatSerializer &serializer) const { ExtraTypeInfo::FormatSerialize(serializer); 
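// NOTE: on the read side, "dict_size" has to be consumed before the values,
// because the physical representation is chosen from the dictionary size.
// A sketch of that dispatch, mirroring the switch in EnumTypeInfo::Deserialize
// above (the caller sits outside this hunk, so the property type and call
// shape here are assumptions):
//
//     auto size = deserializer.ReadProperty<idx_t>("dict_size");
//     switch (EnumTypeInfo::DictType(size)) {      // picks UINT8/UINT16/UINT32 storage
//     case PhysicalType::UINT8:
//         return EnumTypeInfoTemplated<uint8_t>::FormatDeserialize(deserializer, size);
//     // ... UINT16 and UINT32 cases analogous
//     }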
serializer.WriteProperty("dict_size", dict_size); - serializer.WriteProperty("enum_name", enum_name); ((Vector &)values_insert_order).FormatSerialize(serializer, dict_size); // NOLINT - FIXME } diff --git a/src/common/field_writer.cpp b/src/common/field_writer.cpp index 0de7b8a78e12..4da5a2c8626e 100644 --- a/src/common/field_writer.cpp +++ b/src/common/field_writer.cpp @@ -8,7 +8,6 @@ namespace duckdb { FieldWriter::FieldWriter(Serializer &serializer_p) : serializer(serializer_p), buffer(make_uniq()), field_count(0), finalized(false) { buffer->SetVersion(serializer.GetVersion()); - buffer->is_query_plan = serializer.is_query_plan; } FieldWriter::~FieldWriter() { diff --git a/src/common/file_system.cpp b/src/common/file_system.cpp index ad7970284300..5d0dc4e91de2 100644 --- a/src/common/file_system.cpp +++ b/src/common/file_system.cpp @@ -71,11 +71,11 @@ string FileSystem::GetEnvVariable(const string &name) { } bool FileSystem::IsPathAbsolute(const string &path) { - auto path_separator = FileSystem::PathSeparator(); + auto path_separator = PathSeparator(path); return PathMatched(path, path_separator); } -string FileSystem::PathSeparator() { +string FileSystem::PathSeparator(const string &path) { return "/"; } @@ -167,7 +167,7 @@ string FileSystem::NormalizeAbsolutePath(const string &path) { return result; } -string FileSystem::PathSeparator() { +string FileSystem::PathSeparator(const string &path) { return "\\"; } @@ -210,11 +210,11 @@ string FileSystem::GetWorkingDirectory() { string FileSystem::JoinPath(const string &a, const string &b) { // FIXME: sanitize paths - return a + PathSeparator() + b; + return a + PathSeparator(a) + b; } string FileSystem::ConvertSeparators(const string &path) { - auto separator_str = PathSeparator(); + auto separator_str = PathSeparator(path); char separator = separator_str[0]; if (separator == '/') { // on unix-based systems we only accept / as a separator @@ -229,7 +229,7 @@ string FileSystem::ExtractName(const string &path) { return string(); } auto normalized_path = ConvertSeparators(path); - auto sep = PathSeparator(); + auto sep = PathSeparator(path); auto splits = StringUtil::Split(normalized_path, sep); D_ASSERT(!splits.empty()); return splits.back(); diff --git a/src/common/filename_pattern.cpp b/src/common/filename_pattern.cpp index 69c58c8bd1d2..e52b9a61859d 100644 --- a/src/common/filename_pattern.cpp +++ b/src/common/filename_pattern.cpp @@ -24,7 +24,7 @@ void FilenamePattern::SetFilenamePattern(const string &pattern) { _pos = std::min(_pos, (idx_t)_base.length()); } -string FilenamePattern::CreateFilename(const FileSystem &fs, const string &path, const string &extension, +string FilenamePattern::CreateFilename(FileSystem &fs, const string &path, const string &extension, idx_t offset) const { string result(_base); string replacement; diff --git a/src/common/gzip_file_system.cpp b/src/common/gzip_file_system.cpp index a1c351166dcf..51774bd8d61c 100644 --- a/src/common/gzip_file_system.cpp +++ b/src/common/gzip_file_system.cpp @@ -145,6 +145,13 @@ void MiniZStreamWrapper::Initialize(CompressedFile &file, bool write) { bool MiniZStreamWrapper::Read(StreamData &sd) { // Handling for the concatenated files if (sd.refresh) { + auto available = (uint32_t)(sd.in_buff_end - sd.in_buff_start); + if (available <= GZIP_FOOTER_SIZE) { + // Only footer is available so we just close and return finished + Close(); + return true; + } + sd.refresh = false; auto body_ptr = sd.in_buff_start + GZIP_FOOTER_SIZE; uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; @@ -200,18 
+207,6 @@ bool MiniZStreamWrapper::Read(StreamData &sd) { // if stream ended, deallocate inflator if (ret == duckdb_miniz::MZ_STREAM_END) { - // Last read from file done and remaining bytes only for footer or less - if ((sd.in_buff_end < sd.in_buff.get() + sd.in_buf_size) && mz_stream_ptr->avail_in <= GZIP_FOOTER_SIZE) { - Close(); - return true; - } - if (mz_stream_ptr->avail_in > GZIP_FOOTER_SIZE) { - // Definitely not concatenated gzip - if (*(sd.in_buff_start + GZIP_FOOTER_SIZE) != 0x1F) { - Close(); - return true; - } - } // Concatenated GZIP potentially coming up - refresh input buffer sd.refresh = true; } diff --git a/src/common/multi_file_reader.cpp b/src/common/multi_file_reader.cpp index 5d1f17e0cf2f..6a2cbf701725 100644 --- a/src/common/multi_file_reader.cpp +++ b/src/common/multi_file_reader.cpp @@ -440,10 +440,11 @@ void UnionByName::CombineUnionTypes(const vector &col_names, const vecto } } -bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(const vector &files) { +bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(const vector &files, ClientContext &context) { std::unordered_set partitions; + auto &fs = FileSystem::GetFileSystem(context); - auto splits_first_file = StringUtil::Split(files.front(), FileSystem::PathSeparator()); + auto splits_first_file = StringUtil::Split(files.front(), fs.PathSeparator(files.front())); if (splits_first_file.size() < 2) { return false; } @@ -457,7 +458,7 @@ bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(const vector partitions; - auto splits = StringUtil::Split(file, FileSystem::PathSeparator()); + auto splits = StringUtil::Split(file, fs.PathSeparator(file)); if (splits.size() < 2) { return; } @@ -518,7 +521,7 @@ void MultiFileReaderOptions::AutoDetectHivePartitioning(const vector &fi auto_detect_hive_partitioning = false; } if (auto_detect_hive_partitioning) { - hive_partitioning = AutoDetectHivePartitioningInternal(files); + hive_partitioning = AutoDetectHivePartitioningInternal(files, context); } if (hive_partitioning && hive_types_autocast) { AutoDetectHiveTypesInternal(files.front(), context); diff --git a/src/common/serializer/CMakeLists.txt b/src/common/serializer/CMakeLists.txt index 8110ba358e40..ea8de74fccad 100644 --- a/src/common/serializer/CMakeLists.txt +++ b/src/common/serializer/CMakeLists.txt @@ -6,7 +6,8 @@ add_library_unity( buffered_deserializer.cpp buffered_file_reader.cpp buffered_file_writer.cpp - buffered_serializer.cpp) + buffered_serializer.cpp + format_serializer.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ PARENT_SCOPE) diff --git a/src/common/serializer/buffered_file_reader.cpp b/src/common/serializer/buffered_file_reader.cpp index f4fd1fe00920..485b10d12188 100644 --- a/src/common/serializer/buffered_file_reader.cpp +++ b/src/common/serializer/buffered_file_reader.cpp @@ -63,13 +63,4 @@ ClientContext &BufferedFileReader::GetContext() { return *context; } -optional_ptr BufferedFileReader::GetCatalog() { - return catalog; -} - -void BufferedFileReader::SetCatalog(Catalog &catalog_p) { - D_ASSERT(!catalog); - this->catalog = &catalog_p; -} - } // namespace duckdb diff --git a/src/common/serializer/format_serializer.cpp b/src/common/serializer/format_serializer.cpp new file mode 100644 index 000000000000..76415a81eee7 --- /dev/null +++ b/src/common/serializer/format_serializer.cpp @@ -0,0 +1,15 @@ +#include "duckdb/common/serializer/format_serializer.hpp" + +namespace duckdb { + +template <> +void FormatSerializer::WriteValue(const vector &vec) { + auto count = 
vec.size(); + OnListBegin(count); + for (auto item : vec) { + WriteValue(item); + } + OnListEnd(count); +} + +} // namespace duckdb diff --git a/src/common/types.cpp b/src/common/types.cpp index bcfd7bae2fe8..51cafd5e6c63 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -395,7 +395,15 @@ string LogicalType::ToString() const { return StringUtil::Format("DECIMAL(%d,%d)", width, scale); } case LogicalTypeId::ENUM: { - return KeywordHelper::WriteOptionallyQuoted(EnumType::GetTypeName(*this)); + string ret = "ENUM("; + for (idx_t i = 0; i < EnumType::GetSize(*this); i++) { + if (i > 0) { + ret += ", "; + } + ret += "'" + KeywordHelper::WriteOptionallyQuoted(EnumType::GetString(*this, i).GetString(), '\'') + "'"; + } + ret += ")"; + return ret; } case LogicalTypeId::USER: { return KeywordHelper::WriteOptionallyQuoted(UserType::GetTypeName(*this)); @@ -995,34 +1003,12 @@ LogicalType LogicalType::USER(const string &user_type_name) { //===--------------------------------------------------------------------===// // Enum Type //===--------------------------------------------------------------------===// -void EnumType::Serialize(FieldWriter &writer, const ExtraTypeInfo &type_info, bool serialize_internals) { - D_ASSERT(type_info.type == ExtraTypeInfoType::ENUM_TYPE_INFO); - auto &enum_info = type_info.Cast(); - // Store Schema Name - writer.WriteString(enum_info.GetSchemaName()); - // Store Enum Name - writer.WriteString(enum_info.GetEnumName()); - // Store If we are serializing the internals - writer.WriteField(serialize_internals); - if (serialize_internals) { - // We must serialize the internals - auto dict_size = enum_info.GetDictSize(); - // Store Dictionary Size - writer.WriteField(dict_size); - // Store Vector Order By Insertion - ((Vector &)enum_info.GetValuesInsertOrder()).Serialize(dict_size, writer.GetSerializer()); // NOLINT - FIXME - } -} - -const string &EnumType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().GetEnumName(); +LogicalType LogicalType::ENUM(Vector &ordered_data, idx_t size) { + return EnumTypeInfo::CreateType(ordered_data, size); } LogicalType LogicalType::ENUM(const string &enum_name, Vector &ordered_data, idx_t size) { - return EnumTypeInfo::CreateType(enum_name, ordered_data, size); + return LogicalType::ENUM(ordered_data, size); } const string EnumType::GetValue(const Value &val) { @@ -1045,27 +1031,6 @@ idx_t EnumType::GetSize(const LogicalType &type) { return info->Cast().GetDictSize(); } -void EnumType::SetCatalog(LogicalType &type, optional_ptr catalog_entry) { - auto info = type.AuxInfo(); - if (!info) { - return; - } - ((ExtraTypeInfo &)*info).catalog_entry = catalog_entry; -} - -optional_ptr EnumType::GetCatalog(const LogicalType &type) { - auto info = type.AuxInfo(); - if (!info) { - return nullptr; - } - return info->catalog_entry; -} - -string EnumType::GetSchemaName(const LogicalType &type) { - auto catalog_entry = EnumType::GetCatalog(type); - return catalog_entry ? 
catalog_entry->schema.name : ""; -} - PhysicalType EnumType::GetPhysicalType(const LogicalType &type) { D_ASSERT(type.id() == LogicalTypeId::ENUM); auto aux_info = type.AuxInfo(); @@ -1090,15 +1055,6 @@ void LogicalType::Serialize(Serializer &serializer) const { writer.Finalize(); } -void LogicalType::SerializeEnumType(Serializer &serializer) const { - FieldWriter writer(serializer); - writer.WriteField(id_); - writer.WriteField(type_info_->type); - EnumType::Serialize(writer, *type_info_, true); - writer.WriteString(type_info_->alias); - writer.Finalize(); -} - LogicalType LogicalType::Deserialize(Deserializer &source) { FieldReader reader(source); auto id = reader.ReadRequired(); diff --git a/src/common/virtual_file_system.cpp b/src/common/virtual_file_system.cpp index f4af67304bd9..0aaff1423b96 100644 --- a/src/common/virtual_file_system.cpp +++ b/src/common/virtual_file_system.cpp @@ -108,6 +108,10 @@ void VirtualFileSystem::RemoveFile(const string &filename) { FindFileSystem(filename).RemoveFile(filename); } +string VirtualFileSystem::PathSeparator(const string &path) { + return FindFileSystem(path).PathSeparator(path); +} + vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { return FindFileSystem(path).Glob(path, opener); } diff --git a/src/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/core_functions/aggregate/holistic/approximate_quantile.cpp index f17263704979..10eda208999a 100644 --- a/src/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -4,6 +4,8 @@ #include "duckdb/planner/expression.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/field_writer.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" #include #include @@ -17,6 +19,8 @@ struct ApproxQuantileState { }; struct ApproximateQuantileBindData : public FunctionData { + ApproximateQuantileBindData() { + } explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) { } @@ -48,6 +52,18 @@ struct ApproximateQuantileBindData : public FunctionData { return make_uniq(std::move(quantiles)); } + static void FormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty("quantiles", bind_data.quantiles); + } + + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty("quantiles", result->quantiles); + return std::move(result); + } + vector quantiles; }; @@ -192,6 +208,8 @@ unique_ptr BindApproxQuantileDecimal(ClientContext &context, Aggre function.name = "approx_quantile"; function.serialize = ApproximateQuantileBindData::Serialize; function.deserialize = ApproximateQuantileBindData::Deserialize; + function.format_serialize = ApproximateQuantileBindData::FormatSerialize; + function.format_deserialize = ApproximateQuantileBindData::FormatDeserialize; return bind_data; } @@ -200,6 +218,8 @@ AggregateFunction GetApproximateQuantileAggregate(PhysicalType type) { fun.bind = BindApproxQuantile; fun.serialize = ApproximateQuantileBindData::Serialize; fun.deserialize = ApproximateQuantileBindData::Deserialize; + fun.format_serialize = ApproximateQuantileBindData::FormatSerialize; + fun.format_deserialize = ApproximateQuantileBindData::FormatDeserialize; // temporarily push an 
argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::FLOAT); return fun; @@ -255,6 +275,8 @@ AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType auto fun = ApproxQuantileListAggregate(type, type); fun.serialize = ApproximateQuantileBindData::Serialize; fun.deserialize = ApproximateQuantileBindData::Deserialize; + fun.format_serialize = ApproximateQuantileBindData::FormatSerialize; + fun.format_deserialize = ApproximateQuantileBindData::FormatDeserialize; return fun; } @@ -300,6 +322,8 @@ unique_ptr BindApproxQuantileDecimalList(ClientContext &context, A function.name = "approx_quantile"; function.serialize = ApproximateQuantileBindData::Serialize; function.deserialize = ApproximateQuantileBindData::Deserialize; + function.format_serialize = ApproximateQuantileBindData::FormatSerialize; + function.format_deserialize = ApproximateQuantileBindData::FormatDeserialize; return bind_data; } @@ -308,6 +332,8 @@ AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) { fun.bind = BindApproxQuantile; fun.serialize = ApproximateQuantileBindData::Serialize; fun.deserialize = ApproximateQuantileBindData::Deserialize; + fun.format_serialize = ApproximateQuantileBindData::FormatSerialize; + fun.format_deserialize = ApproximateQuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile auto list_of_float = LogicalType::LIST(LogicalType::FLOAT); fun.arguments.push_back(list_of_float); diff --git a/src/core_functions/aggregate/holistic/quantile.cpp b/src/core_functions/aggregate/holistic/quantile.cpp index cc76449c15df..88f58cfe8daa 100644 --- a/src/core_functions/aggregate/holistic/quantile.cpp +++ b/src/core_functions/aggregate/holistic/quantile.cpp @@ -8,6 +8,8 @@ #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/queue.hpp" #include "duckdb/common/field_writer.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" #include #include @@ -417,6 +419,8 @@ inline Value QuantileAbs(const Value &v) { } struct QuantileBindData : public FunctionData { + QuantileBindData() { + } explicit QuantileBindData(const Value &quantile_p) : quantiles(1, QuantileAbs(quantile_p)), order(1, 0), desc(quantile_p < 0) { @@ -456,6 +460,27 @@ struct QuantileBindData : public FunctionData { return desc == other.desc && quantiles == other.quantiles && order == other.order; } + static void FormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty("quantiles", bind_data.quantiles); + serializer.WriteProperty("order", bind_data.order); + serializer.WriteProperty("desc", bind_data.desc); + } + + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty("quantiles", result->quantiles); + deserializer.ReadProperty("order", result->order); + deserializer.ReadProperty("desc", result->desc); + return std::move(result); + } + + static void FormatSerializeDecimal(FormatSerializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + throw SerializationException("FIXME: quantile serialize for decimal"); + } + vector quantiles; vector order; bool desc; @@ -1189,6 +1214,8 @@ unique_ptr BindMedianDecimal(ClientContext &context, AggregateFunc function.name = "median"; function.serialize = 
QuantileDecimalSerialize; function.deserialize = QuantileDeserialize; + function.format_serialize = QuantileBindData::FormatSerializeDecimal; + function.format_deserialize = QuantileBindData::FormatDeserialize; function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return bind_data; } @@ -1245,6 +1272,8 @@ unique_ptr BindDiscreteQuantileDecimal(ClientContext &context, Agg function.name = "quantile_disc"; function.serialize = QuantileDecimalSerialize; function.deserialize = QuantileDeserialize; + function.format_serialize = QuantileBindData::FormatSerializeDecimal; + function.format_deserialize = QuantileBindData::FormatDeserialize; function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return bind_data; } @@ -1256,6 +1285,8 @@ unique_ptr BindDiscreteQuantileDecimalList(ClientContext &context, function.name = "quantile_disc"; function.serialize = QuantileDecimalSerialize; function.deserialize = QuantileDeserialize; + function.format_serialize = QuantileBindData::FormatSerializeDecimal; + function.format_deserialize = QuantileBindData::FormatDeserialize; function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return bind_data; } @@ -1267,6 +1298,8 @@ unique_ptr BindContinuousQuantileDecimal(ClientContext &context, A function.name = "quantile_cont"; function.serialize = QuantileDecimalSerialize; function.deserialize = QuantileDeserialize; + function.format_serialize = QuantileBindData::FormatSerializeDecimal; + function.format_deserialize = QuantileBindData::FormatDeserialize; function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return bind_data; } @@ -1278,6 +1311,8 @@ unique_ptr BindContinuousQuantileDecimalList(ClientContext &contex function.name = "quantile_cont"; function.serialize = QuantileDecimalSerialize; function.deserialize = QuantileDeserialize; + function.format_serialize = QuantileBindData::FormatSerializeDecimal; + function.format_deserialize = QuantileBindData::FormatDeserialize; function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return bind_data; } @@ -1298,6 +1333,8 @@ AggregateFunction GetMedianAggregate(const LogicalType &type) { fun.bind = BindMedian; fun.serialize = QuantileSerialize; fun.deserialize = QuantileDeserialize; + fun.format_serialize = QuantileBindData::FormatSerialize; + fun.format_deserialize = QuantileBindData::FormatDeserialize; return fun; } @@ -1306,6 +1343,8 @@ AggregateFunction GetDiscreteQuantileAggregate(const LogicalType &type) { fun.bind = BindQuantile; fun.serialize = QuantileSerialize; fun.deserialize = QuantileDeserialize; + fun.format_serialize = QuantileBindData::FormatSerialize; + fun.format_deserialize = QuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::DOUBLE); fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; @@ -1317,6 +1356,8 @@ AggregateFunction GetDiscreteQuantileListAggregate(const LogicalType &type) { fun.bind = BindQuantile; fun.serialize = QuantileSerialize; fun.deserialize = QuantileDeserialize; + fun.format_serialize = QuantileBindData::FormatSerialize; + fun.format_deserialize = QuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); fun.arguments.push_back(list_of_double); @@ -1329,6 +1370,8 @@ AggregateFunction GetContinuousQuantileAggregate(const LogicalType &type) { fun.bind = BindQuantile; fun.serialize = 
QuantileSerialize; fun.deserialize = QuantileDeserialize; + fun.format_serialize = QuantileBindData::FormatSerialize; + fun.format_deserialize = QuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::DOUBLE); fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; @@ -1340,6 +1383,8 @@ AggregateFunction GetContinuousQuantileListAggregate(const LogicalType &type) { fun.bind = BindQuantile; fun.serialize = QuantileSerialize; fun.deserialize = QuantileDeserialize; + fun.format_serialize = QuantileBindData::FormatSerialize; + fun.format_deserialize = QuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); fun.arguments.push_back(list_of_double); @@ -1353,6 +1398,8 @@ AggregateFunction GetQuantileDecimalAggregate(const vector &argumen fun.bind = bind; fun.serialize = QuantileSerialize; fun.deserialize = QuantileDeserialize; + fun.format_serialize = QuantileBindData::FormatSerialize; + fun.format_deserialize = QuantileBindData::FormatDeserialize; fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return fun; } diff --git a/src/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/core_functions/aggregate/holistic/reservoir_quantile.cpp index 3f5a5c108700..c82f69a2eadf 100644 --- a/src/core_functions/aggregate/holistic/reservoir_quantile.cpp +++ b/src/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -4,6 +4,8 @@ #include "duckdb/planner/expression.hpp" #include "duckdb/common/queue.hpp" #include "duckdb/common/field_writer.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" #include #include @@ -49,6 +51,8 @@ struct ReservoirQuantileState { }; struct ReservoirQuantileBindData : public FunctionData { + ReservoirQuantileBindData() { + } ReservoirQuantileBindData(double quantile_p, int32_t sample_size_p) : quantiles(1, quantile_p), sample_size(sample_size_p) { } @@ -80,6 +84,20 @@ struct ReservoirQuantileBindData : public FunctionData { return make_uniq(std::move(quantiles), sample_size); } + static void FormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty("quantiles", bind_data.quantiles); + serializer.WriteProperty("sample_size", bind_data.sample_size); + } + + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty("quantiles", result->quantiles); + deserializer.ReadProperty("sample_size", result->sample_size); + return std::move(result); + } + vector quantiles; int32_t sample_size; }; @@ -355,6 +373,8 @@ unique_ptr BindReservoirQuantileDecimal(ClientContext &context, Ag function.name = "reservoir_quantile"; function.serialize = ReservoirQuantileBindData::Serialize; function.deserialize = ReservoirQuantileBindData::Deserialize; + function.format_serialize = ReservoirQuantileBindData::FormatSerialize; + function.format_deserialize = ReservoirQuantileBindData::FormatDeserialize; return bind_data; } @@ -363,6 +383,8 @@ AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) { fun.bind = BindReservoirQuantile; fun.serialize = ReservoirQuantileBindData::Serialize; fun.deserialize = ReservoirQuantileBindData::Deserialize; + fun.format_serialize = 
ReservoirQuantileBindData::FormatSerialize; + fun.format_deserialize = ReservoirQuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::DOUBLE); return fun; @@ -374,6 +396,8 @@ unique_ptr BindReservoirQuantileDecimalList(ClientContext &context auto bind_data = BindReservoirQuantile(context, function, arguments); function.serialize = ReservoirQuantileBindData::Serialize; function.deserialize = ReservoirQuantileBindData::Deserialize; + function.format_serialize = ReservoirQuantileBindData::FormatSerialize; + function.format_deserialize = ReservoirQuantileBindData::FormatDeserialize; function.name = "reservoir_quantile"; return bind_data; } @@ -383,6 +407,8 @@ AggregateFunction GetReservoirQuantileListAggregate(const LogicalType &type) { fun.bind = BindReservoirQuantile; fun.serialize = ReservoirQuantileBindData::Serialize; fun.deserialize = ReservoirQuantileBindData::Deserialize; + fun.format_serialize = ReservoirQuantileBindData::FormatSerialize; + fun.format_deserialize = ReservoirQuantileBindData::FormatDeserialize; // temporarily push an argument so we can bind the actual quantile auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); fun.arguments.push_back(list_of_double); @@ -411,6 +437,8 @@ static void GetReservoirQuantileDecimalFunction(AggregateFunctionSet &set, const BindReservoirQuantileDecimal); fun.serialize = ReservoirQuantileBindData::Serialize; fun.deserialize = ReservoirQuantileBindData::Deserialize; + fun.format_serialize = ReservoirQuantileBindData::FormatSerialize; + fun.format_deserialize = ReservoirQuantileBindData::FormatDeserialize; set.AddFunction(fun); fun.arguments.emplace_back(LogicalType::INTEGER); diff --git a/src/core_functions/scalar/date/strftime.cpp b/src/core_functions/scalar/date/strftime.cpp index 2a586b47e97d..69e549a7c92f 100644 --- a/src/core_functions/scalar/date/strftime.cpp +++ b/src/core_functions/scalar/date/strftime.cpp @@ -94,6 +94,16 @@ ScalarFunctionSet StrfTimeFun::GetFunctions() { return strftime; } +StrpTimeFormat::StrpTimeFormat() { +} + +StrpTimeFormat::StrpTimeFormat(const string &format_string) { + if (format_string.empty()) { + return; + } + StrTimeFormat::ParseFormatSpecifier(format_string, *this); +} + struct StrpTimeBindData : public FunctionData { StrpTimeBindData(const StrpTimeFormat &format, const string &format_string) : formats(1, format), format_strings(1, format_string) { diff --git a/src/core_functions/scalar/list/list_lambdas.cpp b/src/core_functions/scalar/list/list_lambdas.cpp index fa883b503ea5..48725efac96c 100644 --- a/src/core_functions/scalar/list/list_lambdas.cpp +++ b/src/core_functions/scalar/list/list_lambdas.cpp @@ -8,6 +8,8 @@ #include "duckdb/planner/expression/bound_lambda_expression.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { @@ -28,6 +30,19 @@ struct ListLambdaBindData : public FunctionData { ScalarFunction &bound_function) { throw NotImplementedException("FIXME: list lambda deserialize"); } + + static void FormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty("stype", bind_data.stype); + serializer.WriteOptionalProperty("lambda_expr", bind_data.lambda_expr); + } + + static 
unique_ptr FormatDeserialize(FormatDeserializer &deserializer, ScalarFunction &function) { + auto stype = deserializer.ReadProperty("stype"); + auto lambda_expr = deserializer.ReadOptionalProperty>("lambda_expr"); + return make_uniq(stype, std::move(lambda_expr)); + } }; ListLambdaBindData::ListLambdaBindData(const LogicalType &stype_p, unique_ptr lambda_expr_p) @@ -35,12 +50,12 @@ ListLambdaBindData::ListLambdaBindData(const LogicalType &stype_p, unique_ptr ListLambdaBindData::Copy() const { - return make_uniq(stype, lambda_expr->Copy()); + return make_uniq(stype, lambda_expr ? lambda_expr->Copy() : nullptr); } bool ListLambdaBindData::Equals(const FunctionData &other_p) const { auto &other = other_p.Cast(); - return lambda_expr->Equals(*other.lambda_expr) && stype == other.stype; + return Expression::Equals(lambda_expr, other.lambda_expr) && stype == other.stype; } ListLambdaBindData::~ListLambdaBindData() { @@ -330,7 +345,7 @@ static unique_ptr ListLambdaBind(ClientContext &context, ScalarFun if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { bound_function.arguments[0] = LogicalType::SQLNULL; bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); + return make_uniq(bound_function.return_type, nullptr); } if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { @@ -385,6 +400,8 @@ ScalarFunction ListTransformFun::GetFunction() { fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; fun.serialize = ListLambdaBindData::Serialize; fun.deserialize = ListLambdaBindData::Deserialize; + fun.format_serialize = ListLambdaBindData::FormatSerialize; + fun.format_deserialize = ListLambdaBindData::FormatDeserialize; return fun; } @@ -394,6 +411,8 @@ ScalarFunction ListFilterFun::GetFunction() { fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; fun.serialize = ListLambdaBindData::Serialize; fun.deserialize = ListLambdaBindData::Deserialize; + fun.format_serialize = ListLambdaBindData::FormatSerialize; + fun.format_deserialize = ListLambdaBindData::FormatDeserialize; return fun; } diff --git a/src/core_functions/scalar/union/union_tag.cpp b/src/core_functions/scalar/union/union_tag.cpp index 9b5f5d3ec994..431df0ad24fd 100644 --- a/src/core_functions/scalar/union/union_tag.cpp +++ b/src/core_functions/scalar/union/union_tag.cpp @@ -39,7 +39,7 @@ static unique_ptr UnionTagBind(ClientContext &context, ScalarFunct FlatVector::GetData(varchar_vector)[i] = str.IsInlined() ? 
str : StringVector::AddString(varchar_vector, str); } - auto enum_type = LogicalType::ENUM("", varchar_vector, member_count); + auto enum_type = LogicalType::ENUM(varchar_vector, member_count); bound_function.return_type = enum_type; return nullptr; diff --git a/src/execution/column_binding_resolver.cpp b/src/execution/column_binding_resolver.cpp index 7d24d8e7bc11..aee05a2e4332 100644 --- a/src/execution/column_binding_resolver.cpp +++ b/src/execution/column_binding_resolver.cpp @@ -3,7 +3,6 @@ #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_any_join.hpp" #include "duckdb/planner/operator/logical_create_index.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" #include "duckdb/planner/operator/logical_insert.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" @@ -29,12 +28,9 @@ void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { for (auto &cond : comp_join.conditions) { VisitExpression(&cond.left); } - if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - // visit the duplicate eliminated columns on the LHS, if any - auto &delim_join = op.Cast(); - for (auto &expr : delim_join.duplicate_eliminated_columns) { - VisitExpression(&expr); - } + // visit the duplicate eliminated columns on the LHS, if any + for (auto &expr : comp_join.duplicate_eliminated_columns) { + VisitExpression(&expr); } // then get the bindings of the RHS and resolve the RHS expressions VisitOperator(*comp_join.children[1]); diff --git a/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/execution/operator/persistent/physical_copy_to_file.cpp index 5473c5b1a9e8..c192a4d942dc 100644 --- a/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -108,7 +108,7 @@ SinkCombineResultType PhysicalCopyToFile::Combine(ExecutionContext &context, Ope auto partition_key_map = l.part_buffer->GetReverseMap(); string trimmed_path = file_path; - StringUtil::RTrim(trimmed_path, fs.PathSeparator()); + StringUtil::RTrim(trimmed_path, fs.PathSeparator(trimmed_path)); for (idx_t i = 0; i < partitions.size(); i++) { string hive_path = diff --git a/src/execution/operator/scan/physical_table_scan.cpp b/src/execution/operator/scan/physical_table_scan.cpp index dfab2ec4832f..d007348868aa 100644 --- a/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/execution/operator/scan/physical_table_scan.cpp @@ -9,16 +9,6 @@ namespace duckdb { -PhysicalTableScan::PhysicalTableScan(vector types, TableFunction function_p, - unique_ptr bind_data_p, vector column_ids_p, - vector names_p, unique_ptr table_filters_p, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)), column_ids(std::move(column_ids_p)), - names(std::move(names_p)), table_filters(std::move(table_filters_p)) { - extra_info.file_filters = ""; -} - PhysicalTableScan::PhysicalTableScan(vector types, TableFunction function_p, unique_ptr bind_data_p, vector returned_types_p, vector column_ids_p, vector projection_ids_p, diff --git a/src/execution/operator/schema/physical_attach.cpp b/src/execution/operator/schema/physical_attach.cpp index 1d7d3ec9e3ad..d2aa6608d1a6 100644 --- a/src/execution/operator/schema/physical_attach.cpp +++ b/src/execution/operator/schema/physical_attach.cpp @@ -66,7 +66,8 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext 
&context, DataChunk &c const auto &path = info->path; if (name.empty()) { - name = AttachedDatabase::ExtractDatabaseName(path); + auto &fs = FileSystem::GetFileSystem(context.client); + name = AttachedDatabase::ExtractDatabaseName(path, fs); } auto &db_manager = DatabaseManager::Get(context.client); auto existing_db = db_manager.GetDatabaseFromPath(context.client, path); diff --git a/src/execution/operator/schema/physical_create_type.cpp b/src/execution/operator/schema/physical_create_type.cpp index 62ef0295ebe1..68bc258b36d6 100644 --- a/src/execution/operator/schema/physical_create_type.cpp +++ b/src/execution/operator/schema/physical_create_type.cpp @@ -74,15 +74,11 @@ SourceResultType PhysicalCreateType::GetData(ExecutionContext &context, DataChun if (IsSink()) { D_ASSERT(info->type == LogicalType::INVALID); auto &g_sink_state = sink_state->Cast(); - info->type = LogicalType::ENUM(info->name, g_sink_state.result, g_sink_state.size); + info->type = LogicalType::ENUM(g_sink_state.result, g_sink_state.size); } auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - auto catalog_entry = catalog.CreateType(context.client, *info); - D_ASSERT(catalog_entry->type == CatalogType::TYPE_ENTRY); - auto &catalog_type = catalog_entry->Cast(); - EnumType::SetCatalog(info->type, &catalog_type); - + catalog.CreateType(context.client, *info); return SourceResultType::FINISHED; } diff --git a/src/execution/physical_plan/plan_asof_join.cpp b/src/execution/physical_plan/plan_asof_join.cpp index 13cbcdd3b916..927defa4ff27 100644 --- a/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/execution/physical_plan/plan_asof_join.cpp @@ -7,11 +7,10 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" -#include "duckdb/planner/operator/logical_asof_join.hpp" namespace duckdb { -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalAsOfJoin &op) { +unique_ptr PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) { // now visit the children D_ASSERT(op.children.size() == 2); idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); diff --git a/src/execution/physical_plan/plan_comparison_join.cpp b/src/execution/physical_plan/plan_comparison_join.cpp index 8a9cf064e8d9..48df353b1c28 100644 --- a/src/execution/physical_plan/plan_comparison_join.cpp +++ b/src/execution/physical_plan/plan_comparison_join.cpp @@ -237,7 +237,7 @@ static void RewriteJoinCondition(Expression &expr, idx_t offset) { ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { RewriteJoinCondition(child, offset); }); } -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalComparisonJoin &op) { +unique_ptr PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoin &op) { // now visit the children D_ASSERT(op.children.size() == 2); idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); @@ -338,4 +338,17 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalComparison return plan; } +unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalComparisonJoin &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + return PlanAsOfJoin(op); + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + return PlanComparisonJoin(op); + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + return PlanDelimJoin(op); + default: + throw InternalException("Unrecognized operator type for LogicalComparisonJoin"); + } +} + } // namespace duckdb 
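The hunks above and below follow one recurring pattern: every FunctionData subclass that participates in plan serialization gains a static FormatSerialize/FormatDeserialize pair, and each site that constructs the function registers both callbacks next to the existing serialize/deserialize ones. A minimal sketch of that pattern, assuming a hypothetical bind data with a single field (MyBindData and value are illustrative names, not part of this patch):

	struct MyBindData : public FunctionData {
		MyBindData() = default; // default-constructible so FormatDeserialize can fill it in
		explicit MyBindData(int32_t value_p) : value(value_p) {
		}

		static void FormatSerialize(FormatSerializer &serializer, const optional_ptr<FunctionData> bind_data_p,
		                            const AggregateFunction &function) {
			auto &bind_data = bind_data_p->Cast<MyBindData>();
			serializer.WriteProperty("value", bind_data.value);
		}

		static unique_ptr<FunctionData> FormatDeserialize(FormatDeserializer &deserializer,
		                                                  AggregateFunction &function) {
			auto result = make_uniq<MyBindData>();
			deserializer.ReadProperty("value", result->value);
			return std::move(result);
		}

		int32_t value;
	};

	// registered wherever the function object is built, alongside the old callbacks:
	// fun.format_serialize = MyBindData::FormatSerialize;
	// fun.format_deserialize = MyBindData::FormatDeserialize;

Properties are written and read under matching tags in the same order, so the deserializer mirrors the serializer field by field; bind data that cannot be round-tripped yet (e.g. the export-state functions further down) installs callbacks that throw SerializationException.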
diff --git a/src/execution/physical_plan/plan_create_index.cpp b/src/execution/physical_plan/plan_create_index.cpp index 53333c2bc594..c346a7b9754c 100644 --- a/src/execution/physical_plan/plan_create_index.cpp +++ b/src/execution/physical_plan/plan_create_index.cpp @@ -1,12 +1,11 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/execution/operator/projection/physical_projection.hpp" #include "duckdb/execution/operator/filter/physical_filter.hpp" -#include "duckdb/execution/operator/scan/physical_table_scan.hpp" #include "duckdb/execution/operator/schema/physical_create_index.hpp" #include "duckdb/execution/operator/order/physical_order.hpp" #include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/function/table/table_scan.hpp" #include "duckdb/planner/operator/logical_create_index.hpp" +#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/table_filter.hpp" @@ -14,11 +13,10 @@ namespace duckdb { unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreateIndex &op) { - // generate a physical plan for the parallel index creation which consists of the following operators // table scan - projection (for expression execution) - filter (NOT NULL) - order - create index - - D_ASSERT(op.children.empty()); + D_ASSERT(op.children.size() == 1); + auto table_scan = CreatePlan(*op.children[0]); // validate that all expressions contain valid scalar functions // e.g. get_current_timestamp(), random(), and sequence values are not allowed as ART keys @@ -32,19 +30,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreateInde } // table scan operator for index key columns and row IDs - - unique_ptr table_filters; - op.info->column_ids.emplace_back(COLUMN_IDENTIFIER_ROW_ID); - - auto &bind_data = op.bind_data->Cast(); - bind_data.is_create_index = true; - - auto table_scan = - make_uniq(op.info->scan_types, op.function, std::move(op.bind_data), op.info->column_ids, - op.info->names, std::move(table_filters), op.estimated_cardinality); - dependencies.AddDependency(op.table); - op.info->column_ids.pop_back(); D_ASSERT(op.info->scan_types.size() - 1 <= op.info->names.size()); D_ASSERT(op.info->scan_types.size() - 1 <= op.info->column_ids.size()); diff --git a/src/execution/physical_plan/plan_delim_join.cpp b/src/execution/physical_plan/plan_delim_join.cpp index 2f7072f557b5..f30cb2591f59 100644 --- a/src/execution/physical_plan/plan_delim_join.cpp +++ b/src/execution/physical_plan/plan_delim_join.cpp @@ -1,10 +1,7 @@ -#include "duckdb/execution/aggregate_hashtable.hpp" #include "duckdb/execution/operator/join/physical_delim_join.hpp" #include "duckdb/execution/operator/join/physical_hash_join.hpp" #include "duckdb/execution/operator/projection/physical_projection.hpp" #include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" @@ -19,9 +16,9 @@ static void GatherDelimScans(const PhysicalOperator &op, vector PhysicalPlanGenerator::CreatePlan(LogicalDelimJoin &op) { +unique_ptr PhysicalPlanGenerator::PlanDelimJoin(LogicalComparisonJoin &op) { // first create the underlying join - auto plan = CreatePlan(op.Cast()); + auto plan = 
PlanComparisonJoin(op); // this should create a join, not a cross product D_ASSERT(plan && plan->type != PhysicalOperatorType::CROSS_PRODUCT); // duplicate eliminated join diff --git a/src/execution/physical_plan_generator.cpp b/src/execution/physical_plan_generator.cpp index 1178dbdbed20..b47c284c9c36 100644 --- a/src/execution/physical_plan_generator.cpp +++ b/src/execution/physical_plan_generator.cpp @@ -113,12 +113,8 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOperator & case LogicalOperatorType::LOGICAL_ANY_JOIN: plan = CreatePlan(op.Cast()); break; - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - plan = CreatePlan(op.Cast()); - break; case LogicalOperatorType::LOGICAL_ASOF_JOIN: - plan = CreatePlan(op.Cast()); - break; + case LogicalOperatorType::LOGICAL_DELIM_JOIN: case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: plan = CreatePlan(op.Cast()); break; diff --git a/src/function/aggregate/distributive/count.cpp b/src/function/aggregate/distributive/count.cpp index 8b0b89573e18..1f1c482835e8 100644 --- a/src/function/aggregate/distributive/count.cpp +++ b/src/function/aggregate/distributive/count.cpp @@ -218,22 +218,11 @@ AggregateFunction CountFun::GetFunction() { return fun; } -static void CountStarSerialize(FieldWriter &writer, const FunctionData *bind_data, const AggregateFunction &function) { -} - -static unique_ptr CountStarDeserialize(PlanDeserializationState &state, FieldReader &reader, - AggregateFunction &function) { - return nullptr; -} - AggregateFunction CountStarFun::GetFunction() { auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); fun.name = "count_star"; fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; fun.window = CountStarFunction::Window; - // TODO is there a better way to set those? - fun.serialize = CountStarSerialize; - fun.deserialize = CountStarDeserialize; return fun; } diff --git a/src/function/aggregate/sorted_aggregate_function.cpp b/src/function/aggregate/sorted_aggregate_function.cpp index 28dadf90f8b1..3c343ff8cd2f 100644 --- a/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/function/aggregate/sorted_aggregate_function.cpp @@ -516,14 +516,6 @@ struct SortedAggregateFunction { result.Verify(count); } - - static void Serialize(FieldWriter &writer, const FunctionData *bind_data, const AggregateFunction &function) { - throw NotImplementedException("FIXME: serialize sorted aggregate not supported"); - } - static unique_ptr Deserialize(PlanDeserializationState &state, FieldReader &reader, - AggregateFunction &function) { - throw NotImplementedException("FIXME: deserialize sorted aggregate not supported"); - } }; void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, @@ -582,7 +574,7 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE AggregateFunction::StateCombine, SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, AggregateFunction::StateDestroy, nullptr, - SortedAggregateFunction::Window, SortedAggregateFunction::Serialize, SortedAggregateFunction::Deserialize); + SortedAggregateFunction::Window); expr.function = std::move(ordered_aggregate); expr.bind_info = std::move(sorted_bind); diff --git a/src/function/cast/string_cast.cpp b/src/function/cast/string_cast.cpp index 716676d43aaa..709e42fae3d5 100644 --- a/src/function/cast/string_cast.cpp +++ b/src/function/cast/string_cast.cpp @@ -36,7 +36,6 @@ bool StringEnumCastLoop(const string_t *source_data, ValidityMask 
&source_mask, template bool StringEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); - auto enum_name = EnumType::GetTypeName(result.GetType()); switch (source.GetVectorType()) { case VectorType::CONSTANT_VECTOR: { result.SetVectorType(VectorType::CONSTANT_VECTOR); diff --git a/src/function/scalar/system/aggregate_export.cpp b/src/function/scalar/system/aggregate_export.cpp index 6262a04111c7..605f225debd9 100644 --- a/src/function/scalar/system/aggregate_export.cpp +++ b/src/function/scalar/system/aggregate_export.cpp @@ -286,6 +286,17 @@ static unique_ptr ExportStateAggregateDeserialize(PlanDeserializat throw NotImplementedException("FIXME: export state deserialize"); } +static void ExportStateAggregateFormatSerialize(FormatSerializer &serializer, + const optional_ptr bind_data_p, + const AggregateFunction &function) { + throw SerializationException("FIXME: export state serialize"); +} + +static unique_ptr ExportStateAggregateFormatDeserialize(FormatDeserializer &deserializer, + AggregateFunction &function) { + throw SerializationException("FIXME: export state deserialize"); +} + static void ExportStateScalarSerialize(FieldWriter &writer, const FunctionData *bind_data_p, const ScalarFunction &function) { throw NotImplementedException("FIXME: export state serialize"); @@ -295,6 +306,16 @@ static unique_ptr ExportStateScalarDeserialize(PlanDeserialization throw NotImplementedException("FIXME: export state deserialize"); } +static void ExportStateScalarFormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + throw SerializationException("FIXME: export state serialize"); +} + +static unique_ptr ExportStateScalarFormatDeserialize(FormatDeserializer &deserializer, + ScalarFunction &function) { + throw SerializationException("FIXME: export state deserialize"); +} + unique_ptr ExportAggregateFunction::Bind(unique_ptr child_aggregate) { auto &bound_function = child_aggregate->function; @@ -331,6 +352,8 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; export_function.serialize = ExportStateAggregateSerialize; export_function.deserialize = ExportStateAggregateDeserialize; + export_function.format_serialize = ExportStateAggregateFormatSerialize; + export_function.format_deserialize = ExportStateAggregateFormatDeserialize; return make_uniq(export_function, std::move(child_aggregate->children), std::move(child_aggregate->filter), std::move(export_bind_data), @@ -343,6 +366,8 @@ ScalarFunction ExportAggregateFunction::GetFinalize() { result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; result.serialize = ExportStateScalarSerialize; result.deserialize = ExportStateScalarDeserialize; + result.format_serialize = ExportStateScalarFormatSerialize; + result.format_deserialize = ExportStateScalarFormatDeserialize; return result; } @@ -353,6 +378,8 @@ ScalarFunction ExportAggregateFunction::GetCombine() { result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; result.serialize = ExportStateScalarSerialize; result.deserialize = ExportStateScalarDeserialize; + result.format_serialize = ExportStateScalarFormatSerialize; + result.format_deserialize = ExportStateScalarFormatDeserialize; return result; } diff --git a/src/function/scalar_function.cpp b/src/function/scalar_function.cpp index 0bfeba7d81e5..6d46a25b6263 100644 --- a/src/function/scalar_function.cpp +++ 
b/src/function/scalar_function.cpp @@ -13,7 +13,8 @@ ScalarFunction::ScalarFunction(string name, vector arguments, Logic : BaseScalarFunction(std::move(name), std::move(arguments), std::move(return_type), side_effects, std::move(varargs), null_handling), function(std::move(function)), bind(bind), init_local_state(init_local_state), dependency(dependency), - statistics(statistics), serialize(nullptr), deserialize(nullptr) { + statistics(statistics), serialize(nullptr), deserialize(nullptr), format_serialize(nullptr), + format_deserialize(nullptr) { } ScalarFunction::ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, @@ -26,9 +27,10 @@ ScalarFunction::ScalarFunction(vector arguments, LogicalType return } bool ScalarFunction::operator==(const ScalarFunction &rhs) const { - return CompareScalarFunctionT(rhs.function) && bind == rhs.bind && dependency == rhs.dependency && - statistics == rhs.statistics; + return name == rhs.name && arguments == rhs.arguments && return_type == rhs.return_type && varargs == rhs.varargs && + bind == rhs.bind && dependency == rhs.dependency && statistics == rhs.statistics; } + bool ScalarFunction::operator!=(const ScalarFunction &rhs) const { return !(*this == rhs); } @@ -56,23 +58,6 @@ bool ScalarFunction::Equal(const ScalarFunction &rhs) const { return true; // they are equal } -bool ScalarFunction::CompareScalarFunctionT(const scalar_function_t &other) const { - typedef void(scalar_function_ptr_t)(DataChunk &, ExpressionState &, Vector &); - - auto func_ptr = (scalar_function_ptr_t **)function.template target(); // NOLINT - auto other_ptr = (scalar_function_ptr_t **)other.template target(); // NOLINT - - // Case the functions were created from lambdas the target will return a nullptr - if (!func_ptr && !other_ptr) { - return true; - } - if (func_ptr == nullptr || other_ptr == nullptr) { - // scalar_function_t (std::functions) from lambdas cannot be compared - return false; - } - return CastPointerToValue(*func_ptr) == CastPointerToValue(*other_ptr); -} - void ScalarFunction::NopFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() >= 1); result.Reference(input.data[0]); diff --git a/src/function/table/arrow.cpp b/src/function/table/arrow.cpp index 04dbf0798300..de27b9f4d74c 100644 --- a/src/function/table/arrow.cpp +++ b/src/function/table/arrow.cpp @@ -123,7 +123,7 @@ LogicalType ArrowTableFunction::GetArrowLogicalType( child_list_t child_types; for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { auto child_type = GetArrowLogicalType(*schema.children[type_idx], arrow_convert_data, col_idx); - child_types.push_back({schema.children[type_idx]->name, child_type}); + child_types.emplace_back(schema.children[type_idx]->name, child_type); } return LogicalType::STRUCT(child_types); diff --git a/src/function/table/arrow_conversion.cpp b/src/function/table/arrow_conversion.cpp index 3e32acbdedcd..ec893828a1eb 100644 --- a/src/function/table/arrow_conversion.cpp +++ b/src/function/table/arrow_conversion.cpp @@ -93,7 +93,7 @@ static void SetValidityMask(Vector &vector, ArrowArray &array, ArrowScanLocalSta static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, std::unordered_map> &arrow_convert_data, idx_t col_idx, ArrowConvertDataIndices &arrow_convert_idx, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr); + ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); static void 
ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, std::unordered_map> &arrow_convert_data, @@ -265,12 +265,13 @@ static void SetVectorString(Vector &vector, idx_t size, char *cdata, T *offsets) } } -static void DirectConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, - int64_t nested_offset) { +static void DirectConversion(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, int64_t nested_offset, + uint64_t parent_offset) { auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); - auto data_ptr = ArrowBufferData(array, 1) + internal_type * (scan_state.chunk_offset + array.offset); + auto data_ptr = + ArrowBufferData(array, 1) + internal_type * (scan_state.chunk_offset + array.offset + parent_offset); if (nested_offset != -1) { - data_ptr = ArrowBufferData(array, 1) + internal_type * (array.offset + nested_offset); + data_ptr = ArrowBufferData(array, 1) + internal_type * (array.offset + nested_offset + parent_offset); } FlatVector::SetData(vector, data_ptr); } @@ -359,7 +360,7 @@ static void IntervalConversionMonthDayNanos(Vector &vector, ArrowArray &array, A static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLocalState &scan_state, idx_t size, std::unordered_map> &arrow_convert_data, idx_t col_idx, ArrowConvertDataIndices &arrow_convert_idx, int64_t nested_offset, - ValidityMask *parent_mask) { + ValidityMask *parent_mask, uint64_t parent_offset) { switch (vector.GetType().id()) { case LogicalTypeId::SQLNULL: vector.Reference(Value()); @@ -407,7 +408,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLoca case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_NS: { - DirectConversion(vector, array, scan_state, nested_offset); + DirectConversion(vector, array, scan_state, nested_offset, parent_offset); break; } case LogicalTypeId::VARCHAR: { @@ -432,7 +433,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLoca auto precision = arrow_convert_data[col_idx]->date_time_precision[arrow_convert_idx.datetime_precision_index++]; switch (precision) { case ArrowDateTimeType::DAYS: { - DirectConversion(vector, array, scan_state, nested_offset); + DirectConversion(vector, array, scan_state, nested_offset, parent_offset); break; } case ArrowDateTimeType::MILLISECONDS: { @@ -495,7 +496,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLoca break; } case ArrowDateTimeType::MICROSECONDS: { - DirectConversion(vector, array, scan_state, nested_offset); + DirectConversion(vector, array, scan_state, nested_offset, parent_offset); break; } case ArrowDateTimeType::NANOSECONDS: { @@ -640,7 +641,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowScanLoca } } ColumnArrowToDuckDB(*child_entries[type_idx], *array.children[type_idx], scan_state, size, - arrow_convert_data, col_idx, arrow_convert_idx, nested_offset, &struct_validity_mask); + arrow_convert_data, col_idx, arrow_convert_idx, nested_offset, &struct_validity_mask, + array.offset); } break; } diff --git a/src/function/table/read_csv.cpp b/src/function/table/read_csv.cpp index 9d343e73cf75..d9066a315aa4 100644 --- a/src/function/table/read_csv.cpp +++ b/src/function/table/read_csv.cpp @@ -15,6 +15,8 @@ #include "duckdb/main/client_data.hpp" #include "duckdb/execution/operator/persistent/csv_line_info.hpp" #include 
"duckdb/execution/operator/persistent/csv_rejects_table.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" #include @@ -1235,6 +1237,20 @@ static unique_ptr CSVReaderDeserialize(PlanDeserializationState &s return std::move(result_data); } +static void CSVReaderFormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const TableFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty("extra_info", function.extra_info); + serializer.WriteProperty("csv_data", bind_data); +} + +static unique_ptr CSVReaderFormatDeserialize(FormatDeserializer &deserializer, TableFunction &function) { + unique_ptr result; + deserializer.ReadProperty("extra_info", function.extra_info); + deserializer.ReadProperty("csv_data", result); + return std::move(result); +} + TableFunction ReadCSVTableFunction::GetFunction() { TableFunction read_csv("read_csv", {LogicalType::VARCHAR}, ReadCSVFunction, ReadCSVBind, ReadCSVInitGlobal, ReadCSVInitLocal); @@ -1242,6 +1258,8 @@ TableFunction ReadCSVTableFunction::GetFunction() { read_csv.pushdown_complex_filter = CSVComplexFilterPushdown; read_csv.serialize = CSVReaderSerialize; read_csv.deserialize = CSVReaderDeserialize; + read_csv.format_serialize = CSVReaderFormatSerialize; + read_csv.format_deserialize = CSVReaderFormatDeserialize; read_csv.get_batch_index = CSVReaderGetBatchIndex; read_csv.cardinality = CSVReaderCardinality; read_csv.projection_pushdown = true; @@ -1279,7 +1297,8 @@ unique_ptr ReadCSVReplacement(ClientContext &context, const string &ta table_function->function = make_uniq("read_csv_auto", std::move(children)); if (!FileSystem::HasGlob(table_name)) { - table_function->alias = FileSystem::ExtractBaseName(table_name); + auto &fs = FileSystem::GetFileSystem(context); + table_function->alias = fs.ExtractBaseName(table_name); } return std::move(table_function); diff --git a/src/function/table/system/test_all_types.cpp b/src/function/table/system/test_all_types.cpp index 1652c3eccda2..a0b856266c1b 100644 --- a/src/function/table/system/test_all_types.cpp +++ b/src/function/table/system/test_all_types.cpp @@ -70,14 +70,14 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { auto small_enum_ptr = FlatVector::GetData(small_enum); small_enum_ptr[0] = StringVector::AddStringOrBlob(small_enum, "DUCK_DUCK_ENUM"); small_enum_ptr[1] = StringVector::AddStringOrBlob(small_enum, "GOOSE"); - result.emplace_back(LogicalType::ENUM("small_enum", small_enum, 2), "small_enum"); + result.emplace_back(LogicalType::ENUM(small_enum, 2), "small_enum"); Vector medium_enum(LogicalType::VARCHAR, 300); auto medium_enum_ptr = FlatVector::GetData(medium_enum); for (idx_t i = 0; i < 300; i++) { medium_enum_ptr[i] = StringVector::AddStringOrBlob(medium_enum, string("enum_") + to_string(i)); } - result.emplace_back(LogicalType::ENUM("medium_enum", medium_enum, 300), "medium_enum"); + result.emplace_back(LogicalType::ENUM(medium_enum, 300), "medium_enum"); if (use_large_enum) { // this is a big one... 
not sure if we should push this one here, but it's required for completeness @@ -86,13 +86,13 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { for (idx_t i = 0; i < 70000; i++) { large_enum_ptr[i] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(i)); } - result.emplace_back(LogicalType::ENUM("large_enum", large_enum, 70000), "large_enum"); + result.emplace_back(LogicalType::ENUM(large_enum, 70000), "large_enum"); } else { Vector large_enum(LogicalType::VARCHAR, 2); auto large_enum_ptr = FlatVector::GetData(large_enum); large_enum_ptr[0] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(0)); large_enum_ptr[1] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(69999)); - result.emplace_back(LogicalType::ENUM("large_enum", large_enum, 2), "large_enum"); + result.emplace_back(LogicalType::ENUM(large_enum, 2), "large_enum"); } // arrays diff --git a/src/function/table/table_scan.cpp b/src/function/table/table_scan.cpp index e2fec74e2fef..5aca910b1a4f 100644 --- a/src/function/table/table_scan.cpp +++ b/src/function/table/table_scan.cpp @@ -15,6 +15,8 @@ #include "duckdb/catalog/dependency_list.hpp" #include "duckdb/function/function_set.hpp" #include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { @@ -447,6 +449,33 @@ static unique_ptr TableScanDeserialize(PlanDeserializationState &s return std::move(result); } +static void TableScanFormatSerialize(FormatSerializer &serializer, const optional_ptr bind_data_p, + const TableFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty("catalog", bind_data.table.schema.catalog.GetName()); + serializer.WriteProperty("schema", bind_data.table.schema.name); + serializer.WriteProperty("table", bind_data.table.name); + serializer.WriteProperty("is_index_scan", bind_data.is_index_scan); + serializer.WriteProperty("is_create_index", bind_data.is_create_index); + serializer.WriteProperty("result_ids", bind_data.result_ids); +} + +static unique_ptr TableScanFormatDeserialize(FormatDeserializer &deserializer, TableFunction &function) { + auto catalog = deserializer.ReadProperty("catalog"); + auto schema = deserializer.ReadProperty("schema"); + auto table = deserializer.ReadProperty("table"); + auto &catalog_entry = + Catalog::GetEntry(deserializer.Get(), catalog, schema, table); + if (catalog_entry.type != CatalogType::TABLE_ENTRY) { + throw SerializationException("Can't find table for %s.%s", schema, table); + } + auto result = make_uniq(catalog_entry.Cast()); + deserializer.ReadProperty("is_index_scan", result->is_index_scan); + deserializer.ReadProperty("is_create_index", result->is_create_index); + deserializer.ReadProperty("result_ids", result->result_ids); + return std::move(result); +} + TableFunction TableScanFunction::GetIndexScanFunction() { TableFunction scan_function("index_scan", {}, IndexScanFunction); scan_function.init_local = nullptr; @@ -462,6 +491,8 @@ TableFunction TableScanFunction::GetIndexScanFunction() { scan_function.filter_pushdown = false; scan_function.serialize = TableScanSerialize; scan_function.deserialize = TableScanDeserialize; + scan_function.format_serialize = TableScanFormatSerialize; + scan_function.format_deserialize = TableScanFormatDeserialize; return
scan_function; } @@ -482,6 +513,8 @@ TableFunction TableScanFunction::GetFunction() { scan_function.filter_prune = true; scan_function.serialize = TableScanSerialize; scan_function.deserialize = TableScanDeserialize; + scan_function.format_serialize = TableScanFormatSerialize; + scan_function.format_deserialize = TableScanFormatDeserialize; return scan_function; } diff --git a/src/function/table_function.cpp b/src/function/table_function.cpp index 4fcf8d82f91b..61a6d1744b8f 100644 --- a/src/function/table_function.cpp +++ b/src/function/table_function.cpp @@ -18,8 +18,8 @@ TableFunction::TableFunction(string name, vector arguments, table_f init_global(init_global), init_local(init_local), function(function), in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_batch_index(nullptr), - get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), projection_pushdown(false), - filter_pushdown(false), filter_prune(false) { + get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), format_serialize(nullptr), + format_deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), filter_prune(false) { } TableFunction::TableFunction(const vector &arguments, table_function_t function, @@ -32,7 +32,8 @@ TableFunction::TableFunction() init_local(nullptr), function(nullptr), in_out_function(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_batch_index(nullptr), get_batch_info(nullptr), serialize(nullptr), deserialize(nullptr), - projection_pushdown(false), filter_pushdown(false), filter_prune(false) { + format_serialize(nullptr), format_deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), + filter_prune(false) { } bool TableFunction::Equal(const TableFunction &rhs) const { diff --git a/src/include/duckdb.h b/src/include/duckdb.h index 8c82044370ef..a54ed019ee39 100644 --- a/src/include/duckdb.h +++ b/src/include/duckdb.h @@ -360,6 +360,21 @@ The instantiated connection should be closed using 'duckdb_disconnect' */ DUCKDB_API duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out_connection); +/*! +Interrupt running query + +* connection: The connection to interrupt +*/ +DUCKDB_API void duckdb_interrupt(duckdb_connection connection); + +/*! +Get progress of the running query + +* connection: The working connection +* returns: -1 if no progress or a percentage of the progress +*/ +DUCKDB_API double duckdb_query_progress(duckdb_connection connection); + /*! Closes the specified connection and de-allocates all memory allocated for that connection. diff --git a/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index ccd00eb8650b..207e52e7381a 100644 --- a/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -71,8 +71,6 @@ class TableCatalogEntry : public StandardEntry { DUCKDB_API vector GetTypes(); //! Returns a list of the columns of the table DUCKDB_API const ColumnList &GetColumns() const; - //! Returns a mutable list of the columns of the table - DUCKDB_API ColumnList &GetColumnsMutable(); //! Returns the underlying storage of the table virtual DataTable &GetStorage(); //!
Returns a list of the bound constraints of the table diff --git a/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp b/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp index 0852d4b12496..7bfb18202efa 100644 --- a/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/type_catalog_entry.hpp @@ -29,10 +29,7 @@ class TypeCatalogEntry : public StandardEntry { LogicalType user_type; public: - //! Serialize the meta information of the TypeCatalogEntry a serializer - virtual void Serialize(Serializer &serializer) const; - //! Deserializes to a TypeCatalogEntry - static unique_ptr Deserialize(Deserializer &source); + unique_ptr GetInfo() const override; string ToSQL() const override; }; diff --git a/src/include/duckdb/catalog/catalog_set.hpp b/src/include/duckdb/catalog/catalog_set.hpp index d57cd8448d4c..dba2d661930f 100644 --- a/src/include/duckdb/catalog/catalog_set.hpp +++ b/src/include/duckdb/catalog/catalog_set.hpp @@ -125,12 +125,6 @@ class CatalogSet { void Verify(Catalog &catalog); private: - //! Adjusts table dependencies on the event of an UNDO - void AdjustTableDependencies(CatalogEntry &entry); - //! Adjust one dependency - void AdjustDependency(CatalogEntry &entry, TableCatalogEntry &table, ColumnDefinition &column, bool remove); - //! Adjust User dependency - void AdjustUserDependency(CatalogEntry &entry, ColumnDefinition &column, bool remove); //! Given a root entry, gets the entry valid for this transaction CatalogEntry &GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry ¤t); CatalogEntry &GetCommittedEntry(CatalogEntry ¤t); diff --git a/src/include/duckdb/common/arrow/arrow_cpp.hpp b/src/include/duckdb/common/arrow/arrow_cpp.hpp new file mode 100644 index 000000000000..d9b62200be57 --- /dev/null +++ b/src/include/duckdb/common/arrow/arrow_cpp.hpp @@ -0,0 +1,184 @@ +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/arrow/arrow.hpp" + +// These classes are wrappers around Arrow objects, making sure the conventions around ownership are respected. 
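+// Ownership convention, following the Arrow C data interface: an object is valid only while its
+// release callback is non-null. The OWNING template parameter controls whether the wrapper calls
+// release on destruction; constructing or moving from another wrapper nulls the source's release
+// pointer, so the underlying data is never released twice.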
+ +namespace duckdb { + +template +class ArrowObjectBase { +public: + explicit ArrowObjectBase() { + object.release = nullptr; + } + ArrowObjectBase(ARROW_OBJECT interface) { + object.release = nullptr; + *this = interface; + } + operator ARROW_OBJECT *() { + return &object; + } + ArrowObjectBase(ARROW_OBJECT *interface) : ArrowObjectBase(*interface) { + if (OWNING) { + // We have taken ownership over this schema object, communicate this to the caller + interface->release = nullptr; + } + } + ArrowObjectBase(ArrowObjectBase &&other) : ArrowObjectBase(&other.object) { + } + ArrowObjectBase(ArrowObjectBase &other) = delete; + ~ArrowObjectBase() { + if (OWNING && object.release) { + object.release(&object); + } + object.release = nullptr; + } + +public: + ArrowObjectBase &operator=(ARROW_OBJECT other) { + if (Valid()) { + object.release(&object); + } + memcpy(&object, &other, sizeof(object)); + return *this; + } + ArrowObjectBase &operator=(ArrowObjectBase &&other) { + *this = other.object; + other.object.release = nullptr; + return *this; + } + +public: + bool Valid() const { + return object.release != nullptr; + } + +protected: + void AssertOwnership() const { + // Only when the release pointer is not null can we be sure the data of the array is valid and accessible + D_ASSERT(object.release); + } + +protected: + ARROW_OBJECT object; +}; + +template +class ArrowSchemaCPP : public ArrowObjectBase { +private: + using base = ArrowObjectBase; + +public: + using base::base; + using base::operator=; + using base::AssertOwnership; + using base::object; + +public: + const string Format() const { + AssertOwnership(); + return object.format; + } + const string Name() const { + AssertOwnership(); + return object.name; + } + const string MetaData() const { + AssertOwnership(); + return object.metadata; + } + int64_t Flags() const { + AssertOwnership(); + return object.flags; + } + int64_t ChildrenCount() const { + AssertOwnership(); + return object.n_children; + } + struct ArrowSchema **Children() const { + AssertOwnership(); + return object.children; + } + struct ArrowSchema *Dictionary() const { + AssertOwnership(); + return object.dictionary; + } +}; + +template +class ArrowArrayCPP : public ArrowObjectBase { +private: + using base = ArrowObjectBase; + +public: + using base::base; + using base::operator=; + using base::AssertOwnership; + using base::object; + +public: + int64_t Length() const { + AssertOwnership(); + return object.length; + } + int64_t NullCount() const { + AssertOwnership(); + return object.null_count; + } + int64_t Offset() const { + AssertOwnership(); + return object.offset; + } + int64_t BufferCount() const { + AssertOwnership(); + return object.n_buffers; + } + int64_t ChildrenCount() const { + AssertOwnership(); + return object.n_children; + } + const void **Buffers() const { + AssertOwnership(); + return object.buffers; + } + struct ArrowArray **Children() const { + AssertOwnership(); + return object.children; + } + struct ArrowArray *Dictionary() const { + AssertOwnership(); + return object.dictionary; + } +}; + +template +class ArrowArrayStreamCPP : public ArrowObjectBase { +private: + using base = ArrowObjectBase; + +public: + using base::base; + using base::operator=; + using base::AssertOwnership; + using base::object; + +public: + int GetSchema(ArrowSchemaCPP &out) { + ArrowSchema result; + auto ret = object.get_schema(&object, &result); + out = result; + return ret; + } + int GetNext(ArrowArrayCPP &out) { + ArrowArray result; + auto ret = object.get_next(&object, &result); + out = result; + return ret; + } + 
string GetLastError() { + auto result = object.get_last_error(&object); + return result; + } +}; + +} // namespace duckdb diff --git a/src/include/duckdb/common/arrow/arrow_wrapper.hpp b/src/include/duckdb/common/arrow/arrow_wrapper.hpp index 2400fdfa57bf..39e811e6705e 100644 --- a/src/include/duckdb/common/arrow/arrow_wrapper.hpp +++ b/src/include/duckdb/common/arrow/arrow_wrapper.hpp @@ -10,6 +10,8 @@ #include "duckdb/common/arrow/arrow.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/common/preserved_error.hpp" +#include "duckdb/main/chunk_scan_state.hpp" +#include "duckdb/common/arrow/arrow_options.hpp" //! Here we have the internal duckdb classes that interact with Arrow's Internal Header (i.e., duckdb/commons/arrow.hpp) namespace duckdb { @@ -56,9 +58,9 @@ class ArrowArrayStreamWrapper { class ArrowUtil { public: - static bool TryFetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out, idx_t &result_count, - PreservedError &error); - static idx_t FetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out); + static bool TryFetchChunk(ChunkScanState &scan_state, ArrowOptions options, idx_t chunk_size, ArrowArray *out, + idx_t &result_count, PreservedError &error); + static idx_t FetchChunk(ChunkScanState &scan_state, ArrowOptions options, idx_t chunk_size, ArrowArray *out); private: static bool TryFetchNext(QueryResult &result, unique_ptr &out, PreservedError &error); diff --git a/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp b/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp index 10fabcd1760f..629316de9f3a 100644 --- a/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp +++ b/src/include/duckdb/common/arrow/result_arrow_wrapper.hpp @@ -10,17 +10,21 @@ #include "duckdb/main/query_result.hpp" #include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/main/chunk_scan_state.hpp" namespace duckdb { class ResultArrowArrayStreamWrapper { public: explicit ResultArrowArrayStreamWrapper(unique_ptr result, idx_t batch_size); + +public: ArrowArrayStream stream; unique_ptr result; PreservedError last_error; idx_t batch_size; vector column_types; vector column_names; + unique_ptr scan_state; private: static int MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out); diff --git a/src/include/duckdb/common/extra_type_info.hpp b/src/include/duckdb/common/extra_type_info.hpp index ec53f6c19411..0a8c300d136a 100644 --- a/src/include/duckdb/common/extra_type_info.hpp +++ b/src/include/duckdb/common/extra_type_info.hpp @@ -35,7 +35,6 @@ struct ExtraTypeInfo { ExtraTypeInfoType type; string alias; - optional_ptr catalog_entry; public: bool Equals(ExtraTypeInfo *other_p) const; @@ -184,19 +183,17 @@ struct UserTypeInfo : public ExtraTypeInfo { enum EnumDictType : uint8_t { INVALID = 0, VECTOR_DICT = 1 }; struct EnumTypeInfo : public ExtraTypeInfo { - explicit EnumTypeInfo(string enum_name_p, Vector &values_insert_order_p, idx_t dict_size_p); + explicit EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p); EnumTypeInfo(const EnumTypeInfo &) = delete; EnumTypeInfo &operator=(const EnumTypeInfo &) = delete; public: const EnumDictType &GetEnumDictType() const; - const string &GetEnumName() const; - const string GetSchemaName() const; const Vector &GetValuesInsertOrder() const; const idx_t &GetDictSize() const; static PhysicalType DictType(idx_t size); - static LogicalType CreateType(const string &enum_name, Vector &ordered_data, idx_t size); + static LogicalType CreateType(Vector &ordered_data, idx_t size); void 
Serialize(FieldWriter &writer) const override; static shared_ptr Deserialize(FieldReader &reader); @@ -212,7 +209,6 @@ struct EnumTypeInfo : public ExtraTypeInfo { private: EnumDictType dict_type; - string enum_name; idx_t dict_size; }; diff --git a/src/include/duckdb/common/field_writer.hpp b/src/include/duckdb/common/field_writer.hpp index 8f3cebfeaba1..7b066b4ef0b8 100644 --- a/src/include/duckdb/common/field_writer.hpp +++ b/src/include/duckdb/common/field_writer.hpp @@ -167,10 +167,6 @@ class FieldDeserializer : public Deserializer { return root.GetContext(); } - optional_ptr GetCatalog() override { - return root.GetCatalog(); - } - private: Deserializer &root; idx_t remaining_data; diff --git a/src/include/duckdb/common/file_system.hpp b/src/include/duckdb/common/file_system.hpp index cd48efe07a3f..1f06b3ed664c 100644 --- a/src/include/duckdb/common/file_system.hpp +++ b/src/include/duckdb/common/file_system.hpp @@ -185,21 +185,21 @@ class FileSystem { DUCKDB_API virtual string ExpandPath(const string &path); //! Returns the system-available memory in bytes. Returns DConstants::INVALID_INDEX if the system function fails. DUCKDB_API static idx_t GetAvailableMemory(); - //! Path separator for the current file system - DUCKDB_API static string PathSeparator(); + //! Path separator for path + DUCKDB_API virtual string PathSeparator(const string &path); //! Checks if path is starts with separator (i.e., '/' on UNIX '\\' on Windows) - DUCKDB_API static bool IsPathAbsolute(const string &path); + DUCKDB_API bool IsPathAbsolute(const string &path); //! Normalize an absolute path - the goal of normalizing is converting "\test.db" and "C:/test.db" into "C:\test.db" //! so that the database system cache can correctly - DUCKDB_API static string NormalizeAbsolutePath(const string &path); + DUCKDB_API string NormalizeAbsolutePath(const string &path); //! Join two paths together - DUCKDB_API static string JoinPath(const string &a, const string &path); + DUCKDB_API string JoinPath(const string &a, const string &path); //! Convert separators in a path to the local separators (e.g. convert "/" into \\ on windows) - DUCKDB_API static string ConvertSeparators(const string &path); + DUCKDB_API string ConvertSeparators(const string &path); //! Extract the base name of a file (e.g. if the input is lib/example.dll the base name is 'example') - DUCKDB_API static string ExtractBaseName(const string &path); + DUCKDB_API string ExtractBaseName(const string &path); //! Extract the name of a file (e.g if the input is lib/example.dll the name is 'example.dll') - DUCKDB_API static string ExtractName(const string &path); + DUCKDB_API string ExtractName(const string &path); //! 
Returns the value of an environment variable - or the empty string if it is not set DUCKDB_API static string GetEnvVariable(const string &name); diff --git a/src/include/duckdb/common/filename_pattern.hpp b/src/include/duckdb/common/filename_pattern.hpp index 152722ed7d8d..3795fc364857 100644 --- a/src/include/duckdb/common/filename_pattern.hpp +++ b/src/include/duckdb/common/filename_pattern.hpp @@ -23,7 +23,7 @@ class FilenamePattern { public: void SetFilenamePattern(const string &pattern); - string CreateFilename(const FileSystem &fs, const string &path, const string &extension, idx_t offset) const; + string CreateFilename(FileSystem &fs, const string &path, const string &extension, idx_t offset) const; private: string _base; diff --git a/src/include/duckdb/common/multi_file_reader.hpp b/src/include/duckdb/common/multi_file_reader.hpp index 3ff87af11172..44c25179f43d 100644 --- a/src/include/duckdb/common/multi_file_reader.hpp +++ b/src/include/duckdb/common/multi_file_reader.hpp @@ -8,7 +8,7 @@ #pragma once -#include "duckdb/common/types.hpp" +#include "duckdb/common/common.hpp" #include "duckdb/common/multi_file_reader_options.hpp" #include "duckdb/common/enums/file_glob_options.hpp" #include "duckdb/common/union_by_name.hpp" @@ -32,6 +32,8 @@ struct HivePartitioningIndex { DUCKDB_API void Serialize(Serializer &serializer) const; DUCKDB_API static HivePartitioningIndex Deserialize(Deserializer &source); + DUCKDB_API void FormatSerialize(FormatSerializer &serializer) const; + DUCKDB_API static HivePartitioningIndex FormatDeserialize(FormatDeserializer &deserializer); }; //! The bind data for the multi-file reader, obtained through MultiFileReader::BindReader @@ -43,6 +45,8 @@ struct MultiFileReaderBindData { DUCKDB_API void Serialize(Serializer &serializer) const; DUCKDB_API static MultiFileReaderBindData Deserialize(Deserializer &source); + DUCKDB_API void FormatSerialize(FormatSerializer &serializer) const; + DUCKDB_API static MultiFileReaderBindData FormatDeserialize(FormatDeserializer &deserializer); }; struct MultiFileFilterEntry { diff --git a/src/include/duckdb/common/multi_file_reader_options.hpp b/src/include/duckdb/common/multi_file_reader_options.hpp index 98ce2b73b83a..d09047b5f187 100644 --- a/src/include/duckdb/common/multi_file_reader_options.hpp +++ b/src/include/duckdb/common/multi_file_reader_options.hpp @@ -28,9 +28,11 @@ struct MultiFileReaderOptions { DUCKDB_API void Serialize(Serializer &serializer) const; DUCKDB_API static MultiFileReaderOptions Deserialize(Deserializer &source); + DUCKDB_API void FormatSerialize(FormatSerializer &serializer) const; + DUCKDB_API static MultiFileReaderOptions FormatDeserialize(FormatDeserializer &source); DUCKDB_API void AddBatchInfo(BindInfo &bind_info) const; DUCKDB_API void AutoDetectHivePartitioning(const vector &files, ClientContext &context); - DUCKDB_API static bool AutoDetectHivePartitioningInternal(const vector &files); + DUCKDB_API static bool AutoDetectHivePartitioningInternal(const vector &files, ClientContext &context); DUCKDB_API void AutoDetectHiveTypesInternal(const string &file, ClientContext &context); DUCKDB_API void VerifyHiveTypesArePartitions(const std::map &partitions) const; DUCKDB_API LogicalType GetHiveLogicalType(const string &hive_partition_column) const; diff --git a/src/include/duckdb/common/opener_file_system.hpp b/src/include/duckdb/common/opener_file_system.hpp index 57c41a2b7d27..28b352452ef8 100644 --- a/src/include/duckdb/common/opener_file_system.hpp +++ 
b/src/include/duckdb/common/opener_file_system.hpp @@ -103,6 +103,10 @@ class OpenerFileSystem : public FileSystem { GetFileSystem().RemoveFile(filename); } + string PathSeparator(const string &path) override { + return GetFileSystem().PathSeparator(path); + } + vector Glob(const string &path, FileOpener *opener = nullptr) override { if (opener) { throw InternalException("OpenerFileSystem cannot take an opener - the opener is pushed automatically"); diff --git a/src/include/duckdb/common/serializer.hpp b/src/include/duckdb/common/serializer.hpp index c7ba197c35d4..9c7eda329e11 100644 --- a/src/include/duckdb/common/serializer.hpp +++ b/src/include/duckdb/common/serializer.hpp @@ -22,8 +22,6 @@ class Serializer { uint64_t version = 0L; public: - bool is_query_plan = false; - virtual ~Serializer() { } @@ -119,11 +117,6 @@ class Deserializer { throw InternalException("This deserializer does not have a client-context"); }; - //! Gets the catalog for the deserializer - virtual optional_ptr GetCatalog() { - return nullptr; - }; - template T Read() { T value; diff --git a/src/include/duckdb/common/serializer/buffered_file_reader.hpp b/src/include/duckdb/common/serializer/buffered_file_reader.hpp index aef17a27346d..711afd65a6ee 100644 --- a/src/include/duckdb/common/serializer/buffered_file_reader.hpp +++ b/src/include/duckdb/common/serializer/buffered_file_reader.hpp @@ -37,14 +37,10 @@ class BufferedFileReader : public Deserializer { ClientContext &GetContext() override; - optional_ptr GetCatalog() override; - void SetCatalog(Catalog &catalog); - private: idx_t file_size; idx_t total_read; optional_ptr context; - optional_ptr catalog; }; } // namespace duckdb diff --git a/src/include/duckdb/common/serializer/format_deserializer.hpp b/src/include/duckdb/common/serializer/format_deserializer.hpp index 8e4abc85d526..34c5a8ddd435 100644 --- a/src/include/duckdb/common/serializer/format_deserializer.hpp +++ b/src/include/duckdb/common/serializer/format_deserializer.hpp @@ -127,6 +127,16 @@ class FormatDeserializer { return data.Unset(); } + // Manually begin an object - should be followed by EndObject + void BeginObject(const char *tag) { + SetTag(tag); + OnObjectBegin(); + } + + void EndObject() { + OnObjectEnd(); + } + private: // Deserialize anything implementing a FormatDeserialize method template @@ -208,6 +218,28 @@ class FormatDeserializer { return map; } + template + inline typename std::enable_if::value, T>::type Read() { + using KEY_TYPE = typename is_map::KEY_TYPE; + using VALUE_TYPE = typename is_map::VALUE_TYPE; + + T map; + auto size = OnMapBegin(); + for (idx_t i = 0; i < size; i++) { + OnMapEntryBegin(); + OnMapKeyBegin(); + auto key = Read(); + OnMapKeyEnd(); + OnMapValueBegin(); + auto value = Read(); + OnMapValueEnd(); + OnMapEntryEnd(); + map[std::move(key)] = std::move(value); + } + OnMapEnd(); + return map; + } + // Deserialize an unordered set template inline typename std::enable_if::value, T>::type Read() { diff --git a/src/include/duckdb/common/serializer/format_serializer.hpp b/src/include/duckdb/common/serializer/format_serializer.hpp index e9985426acaa..fe6f1b641a0e 100644 --- a/src/include/duckdb/common/serializer/format_serializer.hpp +++ b/src/include/duckdb/common/serializer/format_serializer.hpp @@ -28,25 +28,11 @@ class FormatSerializer { public: // Serialize a value template - typename std::enable_if::value, void>::type WriteProperty(const char *tag, const T &value) { + void WriteProperty(const char *tag, const T &value) { SetTag(tag); WriteValue(value); } - // 
- // Serialize an enum - template - typename std::enable_if<std::is_enum<T>::value, void>::type WriteProperty(const char *tag, T value) { - SetTag(tag); - if (serialize_enum_as_string) { - // Use the enum serializer to lookup tostring function - auto str = EnumUtil::ToChars(value); - WriteValue(str); - } else { - // Use the underlying type - WriteValue(static_cast<typename std::underlying_type<T>::type>(value)); - } - } - // Optional pointer template void WriteOptionalProperty(const char *tag, POINTER &&ptr) { @@ -67,7 +53,29 @@ class FormatSerializer { WriteDataPtr(ptr, count); } + // Manually begin an object - should be followed by EndObject + void BeginObject(const char *tag) { + SetTag(tag); + OnObjectBegin(); + } + + void EndObject() { + OnObjectEnd(); + } + protected: + template + typename std::enable_if<std::is_enum<T>::value, void>::type WriteValue(const T value) { + if (serialize_enum_as_string) { + // Use the enum serializer to lookup tostring function + auto str = EnumUtil::ToChars(value); + WriteValue(str); + } else { + // Use the underlying type + WriteValue(static_cast<typename std::underlying_type<T>::type>(value)); + } + } + // Unique Pointer Ref template void WriteValue(const unique_ptr &ptr) { @@ -160,6 +168,24 @@ class FormatSerializer { OnMapEnd(count); } + // Map + template + void WriteValue(const duckdb::map &map) { + auto count = map.size(); + OnMapBegin(count); + for (auto &item : map) { + OnMapEntryBegin(); + OnMapKeyBegin(); + WriteValue(item.first); + OnMapKeyEnd(); + OnMapValueBegin(); + WriteValue(item.second); + OnMapValueEnd(); + OnMapEntryEnd(); + } + OnMapEnd(count); + } + // class or struct implementing `FormatSerialize(FormatSerializer& FormatSerializer)`; template typename std::enable_if::value>::type WriteValue(const T &value) { @@ -247,4 +273,8 @@ } }; +// We need to special case vector<bool> because elements of vector<bool> cannot be referenced +template <> +void FormatSerializer::WriteValue(const vector<bool> &vec); + } // namespace duckdb diff --git a/src/include/duckdb/common/serializer/serialization_traits.hpp b/src/include/duckdb/common/serializer/serialization_traits.hpp index a97575b12865..c12fc30d9744 100644 --- a/src/include/duckdb/common/serializer/serialization_traits.hpp +++ b/src/include/duckdb/common/serializer/serialization_traits.hpp @@ -70,6 +70,16 @@ struct is_unordered_map> : std::true_typ typedef typename std::tuple_element<3, std::tuple>::type EQUAL_TYPE; }; +template +struct is_map : std::false_type {}; +template +struct is_map> : std::true_type { + typedef typename std::tuple_element<0, std::tuple>::type KEY_TYPE; + typedef typename std::tuple_element<1, std::tuple>::type VALUE_TYPE; + typedef typename std::tuple_element<2, std::tuple>::type HASH_TYPE; + typedef typename std::tuple_element<3, std::tuple>::type EQUAL_TYPE; +}; + template struct is_unique_ptr : std::false_type {}; diff --git a/src/include/duckdb/common/stack_checker.hpp b/src/include/duckdb/common/stack_checker.hpp new file mode 100644 index 000000000000..a2375e8ef966 --- /dev/null +++ b/src/include/duckdb/common/stack_checker.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/stack_checker.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +template <class RECURSIVE_CLASS> +class StackChecker { +public: + StackChecker(RECURSIVE_CLASS &recursive_class_p, idx_t stack_usage_p) + : recursive_class(recursive_class_p), stack_usage(stack_usage_p) { + recursive_class.stack_depth += stack_usage; + } + ~StackChecker() { +
recursive_class.stack_depth -= stack_usage; + } + StackChecker(StackChecker &&other) noexcept + : recursive_class(other.recursive_class), stack_usage(other.stack_usage) { + other.stack_usage = 0; + } + StackChecker(const StackChecker &) = delete; + +private: + RECURSIVE_CLASS &recursive_class; + idx_t stack_usage; +}; + +} // namespace duckdb diff --git a/src/include/duckdb/common/types.hpp b/src/include/duckdb/common/types.hpp index c33dffcd484d..b1830a6fda77 100644 --- a/src/include/duckdb/common/types.hpp +++ b/src/include/duckdb/common/types.hpp @@ -286,8 +286,6 @@ struct LogicalType { //! Serializes a LogicalType to a stand-alone binary blob DUCKDB_API void Serialize(Serializer &serializer) const; - DUCKDB_API void SerializeEnumType(Serializer &serializer) const; - //! Deserializes a blob back into an LogicalType DUCKDB_API static LogicalType Deserialize(Deserializer &source); @@ -373,9 +371,10 @@ struct LogicalType { DUCKDB_API static LogicalType STRUCT(child_list_t<LogicalType> children); // NOLINT DUCKDB_API static LogicalType AGGREGATE_STATE(aggregate_state_t state_type); // NOLINT DUCKDB_API static LogicalType MAP(const LogicalType &child); // NOLINT - DUCKDB_API static LogicalType MAP( child_list_t<LogicalType> children); // NOLINT DUCKDB_API static LogicalType MAP(LogicalType key, LogicalType value); // NOLINT DUCKDB_API static LogicalType UNION( child_list_t<LogicalType> members); // NOLINT + DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT + // DEPRECATED - provided for backwards compatibility DUCKDB_API static LogicalType ENUM(const string &enum_name, Vector &ordered_data, idx_t size); // NOLINT DUCKDB_API static LogicalType USER(const string &user_type_name); // NOLINT //! A list of all NUMERIC types (integral and floating point types) @@ -400,21 +399,17 @@ struct ListType { DUCKDB_API static const LogicalType &GetChildType(const LogicalType &type); }; -struct UserType{ +struct UserType { DUCKDB_API static const string &GetTypeName(const LogicalType &type); }; -struct EnumType{ - DUCKDB_API static const string &GetTypeName(const LogicalType &type); +struct EnumType { DUCKDB_API static int64_t GetPos(const LogicalType &type, const string_t& key); DUCKDB_API static const Vector &GetValuesInsertOrder(const LogicalType &type); DUCKDB_API static idx_t GetSize(const LogicalType &type); DUCKDB_API static const string GetValue(const Value &val); - DUCKDB_API static void SetCatalog(LogicalType &type, optional_ptr catalog_entry); - DUCKDB_API static optional_ptr GetCatalog(const LogicalType &type); - DUCKDB_API static string GetSchemaName(const LogicalType &type); DUCKDB_API static PhysicalType GetPhysicalType(const LogicalType &type); - DUCKDB_API static void Serialize(FieldWriter& writer, const ExtraTypeInfo& type_info, bool serialize_internals); + DUCKDB_API static string_t GetString(const LogicalType &type, idx_t pos); }; struct StructType { diff --git a/src/include/duckdb/common/types/timestamp.hpp b/src/include/duckdb/common/types/timestamp.hpp index 1d5f5d962e6a..f6f427d4a2fe 100644 --- a/src/include/duckdb/common/types/timestamp.hpp +++ b/src/include/duckdb/common/types/timestamp.hpp @@ -26,7 +26,7 @@ struct timestamp_t { // NOLINT int64_t value; timestamp_t() = default; - explicit inline timestamp_t(int64_t value_p) : value(value_p) { + explicit inline constexpr timestamp_t(int64_t value_p) : value(value_p) { } inline timestamp_t &operator=(int64_t value_p) { value = value_p; @@ -67,21 +67,25 @@ struct timestamp_t { // NOLINT timestamp_t &operator-=(const int64_t &delta);
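A short usage sketch for the StackChecker template introduced in stack_checker.hpp above; Transformer and ExpressionBinder adopt exactly this pattern later in the patch. The Walker/Node types, the depth limit, and the exception message here are illustrative only.

// Illustrative only: a recursive visitor guarded by StackChecker.
struct Node {
	vector<unique_ptr<Node>> children;
};

class Walker {
	friend class StackChecker<Walker>;

public:
	void Visit(Node &node) {
		// RAII guard: bumps stack_depth now, restores it when the scope unwinds
		auto guard = StackCheck();
		for (auto &child : node.children) {
			Visit(*child);
		}
	}

private:
	static constexpr idx_t MAXIMUM_STACK_DEPTH = 128; // illustrative limit
	idx_t stack_depth = 0;

	StackChecker<Walker> StackCheck(idx_t extra_stack = 1) {
		if (stack_depth + extra_stack >= MAXIMUM_STACK_DEPTH) {
			throw InternalException("Maximum recursion depth exceeded");
		}
		return StackChecker<Walker>(*this, extra_stack);
	}
};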
// special values - static timestamp_t infinity() { // NOLINT + static constexpr timestamp_t infinity() { // NOLINT return timestamp_t(NumericLimits<int64_t>::Maximum()); - } // NOLINT - static timestamp_t ninfinity() { // NOLINT + } // NOLINT + static constexpr timestamp_t ninfinity() { // NOLINT return timestamp_t(-NumericLimits<int64_t>::Maximum()); - } // NOLINT - static inline timestamp_t epoch() { // NOLINT + } // NOLINT + static constexpr inline timestamp_t epoch() { // NOLINT return timestamp_t(0); } // NOLINT }; -struct timestamp_tz_t : public timestamp_t {}; // NOLINT -struct timestamp_ns_t : public timestamp_t {}; // NOLINT -struct timestamp_ms_t : public timestamp_t {}; // NOLINT -struct timestamp_sec_t : public timestamp_t {}; // NOLINT +struct timestamp_tz_t : public timestamp_t { // NOLINT +}; +struct timestamp_ns_t : public timestamp_t { // NOLINT +}; +struct timestamp_ms_t : public timestamp_t { // NOLINT +}; +struct timestamp_sec_t : public timestamp_t { // NOLINT +}; enum class TimestampCastResult : uint8_t { SUCCESS, ERROR_INCORRECT_FORMAT, ERROR_NON_UTC_TIMEZONE }; diff --git a/src/include/duckdb/common/virtual_file_system.hpp b/src/include/duckdb/common/virtual_file_system.hpp index 4b5285abed1a..69990bcbd1b9 100644 --- a/src/include/duckdb/common/virtual_file_system.hpp +++ b/src/include/duckdb/common/virtual_file_system.hpp @@ -68,6 +68,8 @@ class VirtualFileSystem : public FileSystem { void SetDisabledFileSystems(const vector<string> &names) override; + string PathSeparator(const string &path) override; + private: FileSystem &FindFileSystem(const string &path); FileSystem &FindFileSystemInternal(const string &path); diff --git a/src/include/duckdb/execution/operator/persistent/csv_reader_options.hpp b/src/include/duckdb/execution/operator/persistent/csv_reader_options.hpp index 864da460d16a..e5f56809c72e 100644 --- a/src/include/duckdb/execution/operator/persistent/csv_reader_options.hpp +++ b/src/include/duckdb/execution/operator/persistent/csv_reader_options.hpp @@ -147,6 +147,8 @@ struct BufferedCSVReaderOptions { void Serialize(FieldWriter &writer) const; void Deserialize(FieldReader &reader); + void FormatSerialize(FormatSerializer &serializer) const; + static BufferedCSVReaderOptions FormatDeserialize(FormatDeserializer &deserializer); void SetCompression(const string &compression); void SetHeader(bool has_header); diff --git a/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp b/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp index 733d067174a3..e00a18ec51fd 100644 --- a/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp +++ b/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp @@ -22,10 +22,6 @@ class PhysicalTableScan : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::TABLE_SCAN; public: - //! Regular Table Scan - PhysicalTableScan(vector types, TableFunction function, unique_ptr bind_data, - vector column_ids, vector names, unique_ptr table_filters, - idx_t estimated_cardinality); //! 
Table scan that immediately projects out filter columns that are unused in the remainder of the query plan PhysicalTableScan(vector types, TableFunction function, unique_ptr bind_data, vector returned_types, vector column_ids, vector projection_ids, diff --git a/src/include/duckdb/execution/physical_plan_generator.hpp b/src/include/duckdb/execution/physical_plan_generator.hpp index a795014c92f4..f17dd1b82051 100644 --- a/src/include/duckdb/execution/physical_plan_generator.hpp +++ b/src/include/duckdb/execution/physical_plan_generator.hpp @@ -50,7 +50,6 @@ class PhysicalPlanGenerator { unique_ptr CreatePlan(LogicalAggregate &op); unique_ptr CreatePlan(LogicalAnyJoin &op); - unique_ptr CreatePlan(LogicalAsOfJoin &op); unique_ptr CreatePlan(LogicalColumnDataGet &op); unique_ptr CreatePlan(LogicalComparisonJoin &op); unique_ptr CreatePlan(LogicalCreate &op); @@ -59,7 +58,6 @@ class PhysicalPlanGenerator { unique_ptr CreatePlan(LogicalCrossProduct &op); unique_ptr CreatePlan(LogicalDelete &op); unique_ptr CreatePlan(LogicalDelimGet &op); - unique_ptr CreatePlan(LogicalDelimJoin &op); unique_ptr CreatePlan(LogicalDistinct &op); unique_ptr CreatePlan(LogicalDummyScan &expr); unique_ptr CreatePlan(LogicalEmptyResult &op); @@ -93,6 +91,9 @@ class PhysicalPlanGenerator { unique_ptr CreatePlan(LogicalCTERef &op); unique_ptr CreatePlan(LogicalPivot &op); + unique_ptr PlanAsOfJoin(LogicalComparisonJoin &op); + unique_ptr PlanComparisonJoin(LogicalComparisonJoin &op); + unique_ptr PlanDelimJoin(LogicalComparisonJoin &op); unique_ptr ExtractAggregateExpressions(unique_ptr child, vector> &expressions, vector> &groups); diff --git a/src/include/duckdb/function/aggregate_function.hpp b/src/include/duckdb/function/aggregate_function.hpp index ef93e491d949..31b2b8969d73 100644 --- a/src/include/duckdb/function/aggregate_function.hpp +++ b/src/include/duckdb/function/aggregate_function.hpp @@ -52,6 +52,11 @@ typedef void (*aggregate_serialize_t)(FieldWriter &writer, const FunctionData *b typedef unique_ptr (*aggregate_deserialize_t)(PlanDeserializationState &context, FieldReader &reader, AggregateFunction &function); +typedef void (*aggregate_format_serialize_t)(FormatSerializer &serializer, const optional_ptr bind_data, + const AggregateFunction &function); +typedef unique_ptr (*aggregate_format_deserialize_t)(FormatDeserializer &deserializer, + AggregateFunction &function); + class AggregateFunction : public BaseScalarFunction { public: AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, @@ -66,7 +71,8 @@ class AggregateFunction : public BaseScalarFunction { LogicalType(LogicalTypeId::INVALID), null_handling), state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), - serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + serialize(serialize), deserialize(deserialize), format_serialize(nullptr), format_deserialize(nullptr), + order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, @@ -80,7 +86,8 @@ class AggregateFunction : public BaseScalarFunction { LogicalType(LogicalTypeId::INVALID)), state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), simple_update(simple_update), window(window), bind(bind), 
destructor(destructor), statistics(statistics), - serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + serialize(serialize), deserialize(deserialize), format_serialize(nullptr), format_deserialize(nullptr), + order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, @@ -131,6 +138,8 @@ class AggregateFunction : public BaseScalarFunction { aggregate_serialize_t serialize; aggregate_deserialize_t deserialize; + aggregate_format_serialize_t format_serialize; + aggregate_format_deserialize_t format_deserialize; //! Whether or not the aggregate is order dependent AggregateOrderDependent order_dependent; diff --git a/src/include/duckdb/function/function_serialization.hpp b/src/include/duckdb/function/function_serialization.hpp index 17497114b7e6..cb15bada607a 100644 --- a/src/include/duckdb/function/function_serialization.hpp +++ b/src/include/duckdb/function/function_serialization.hpp @@ -11,6 +11,8 @@ #include "duckdb/common/field_writer.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { @@ -97,6 +99,85 @@ class FunctionSerializer { function.return_type = return_type; return function; } + + template + static void FormatSerialize(FormatSerializer &serializer, const FUNC &function, + optional_ptr<FunctionData> bind_info) { + D_ASSERT(!function.name.empty()); + serializer.WriteProperty("name", function.name); + serializer.WriteProperty("arguments", function.arguments); + serializer.WriteProperty("original_arguments", function.original_arguments); + bool has_serialize = function.format_serialize; + serializer.WriteProperty("has_serialize", has_serialize); + if (has_serialize) { + serializer.BeginObject("function_data"); + function.format_serialize(serializer, bind_info, function); + serializer.EndObject(); + D_ASSERT(function.format_deserialize); + } + } + + template + static FUNC DeserializeFunction(ClientContext &context, CatalogType catalog_type, const string &name, + vector<LogicalType> arguments, vector<LogicalType> original_arguments) { + auto &func_catalog = Catalog::GetEntry(context, catalog_type, SYSTEM_CATALOG, DEFAULT_SCHEMA, name); + if (func_catalog.type != catalog_type) { + throw InternalException("DeserializeFunction - can't find catalog entry for function %s", name); + } + auto &functions = func_catalog.Cast(); + auto function = functions.functions.GetFunctionByArguments( + context, original_arguments.empty() ? 
arguments : original_arguments); + function.arguments = std::move(arguments); + function.original_arguments = std::move(original_arguments); + return function; + } + + template + static pair FormatDeserializeBase(FormatDeserializer &deserializer, CatalogType catalog_type) { + auto &context = deserializer.Get(); + auto name = deserializer.ReadProperty("name"); + auto arguments = deserializer.ReadProperty>("arguments"); + auto original_arguments = deserializer.ReadProperty>("original_arguments"); + auto function = DeserializeFunction(context, catalog_type, name, std::move(arguments), + std::move(original_arguments)); + auto has_serialize = deserializer.ReadProperty("has_serialize"); + return make_pair(std::move(function), has_serialize); + } + + template + static unique_ptr FunctionDeserialize(FormatDeserializer &deserializer, FUNC &function) { + if (!function.format_deserialize) { + throw SerializationException("Function requires deserialization but no deserialization function for %s", + function.name); + } + deserializer.BeginObject("function_data"); + auto result = function.format_deserialize(deserializer, function); + deserializer.EndObject(); + return result; + } + + template + static pair> FormatDeserialize(FormatDeserializer &deserializer, + CatalogType catalog_type, + vector> &children) { + auto &context = deserializer.Get(); + auto entry = FormatDeserializeBase(deserializer, catalog_type); + auto &function = entry.first; + auto has_serialize = entry.second; + + unique_ptr bind_data; + if (has_serialize) { + bind_data = FunctionDeserialize(deserializer, function); + } else if (function.bind) { + try { + bind_data = function.bind(context, function, children); + } catch (Exception &ex) { + // FIXME + throw SerializationException("Error during bind of function in deserialization: %s", ex.what()); + } + } + return make_pair(std::move(function), std::move(bind_data)); + } }; } // namespace duckdb diff --git a/src/include/duckdb/function/scalar/strftime_format.hpp b/src/include/duckdb/function/scalar/strftime_format.hpp index 1c803090e670..ba4597c1dcfa 100644 --- a/src/include/duckdb/function/scalar/strftime_format.hpp +++ b/src/include/duckdb/function/scalar/strftime_format.hpp @@ -121,6 +121,8 @@ struct StrfTimeFormat : public StrTimeFormat { struct StrpTimeFormat : public StrTimeFormat { public: + StrpTimeFormat(); + //! 
Type-safe parsing argument struct ParseResult { int32_t data[8]; // year, month, day, hour, min, sec, µs, offset @@ -148,12 +150,18 @@ struct StrpTimeFormat : public StrTimeFormat { date_t ParseDate(string_t str); timestamp_t ParseTimestamp(string_t str); + void FormatSerialize(FormatSerializer &serializer) const; + static StrpTimeFormat FormatDeserialize(FormatDeserializer &deserializer); + protected: static string FormatStrpTimeError(const string &input, idx_t position); DUCKDB_API void AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) override; int NumericSpecifierWidth(StrTimeSpecifier specifier); int32_t TryParseCollection(const char *data, idx_t &pos, idx_t size, const string_t collection[], idx_t collection_count); + +private: + explicit StrpTimeFormat(const string &format_string); }; } // namespace duckdb diff --git a/src/include/duckdb/function/scalar_function.hpp b/src/include/duckdb/function/scalar_function.hpp index 0f1bf29776ff..106b8e74b225 100644 --- a/src/include/duckdb/function/scalar_function.hpp +++ b/src/include/duckdb/function/scalar_function.hpp @@ -69,6 +69,11 @@ typedef void (*function_serialize_t)(FieldWriter &writer, const FunctionData *bi typedef unique_ptr (*function_deserialize_t)(PlanDeserializationState &state, FieldReader &reader, ScalarFunction &function); +typedef void (*function_format_serialize_t)(FormatSerializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function); +typedef unique_ptr (*function_format_deserialize_t)(FormatDeserializer &deserializer, + ScalarFunction &function); + class ScalarFunction : public BaseScalarFunction { public: DUCKDB_API ScalarFunction(string name, vector arguments, LogicalType return_type, @@ -100,14 +105,14 @@ class ScalarFunction : public BaseScalarFunction { function_serialize_t serialize; function_deserialize_t deserialize; + function_format_serialize_t format_serialize; + function_format_deserialize_t format_deserialize; + DUCKDB_API bool operator==(const ScalarFunction &rhs) const; DUCKDB_API bool operator!=(const ScalarFunction &rhs) const; DUCKDB_API bool Equal(const ScalarFunction &rhs) const; -private: - bool CompareScalarFunctionT(const scalar_function_t &other) const; - public: DUCKDB_API static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result); diff --git a/src/include/duckdb/function/table/read_csv.hpp b/src/include/duckdb/function/table/read_csv.hpp index ef40e15d2c53..380da758ed10 100644 --- a/src/include/duckdb/function/table/read_csv.hpp +++ b/src/include/duckdb/function/table/read_csv.hpp @@ -76,6 +76,10 @@ struct ColumnInfo { info.types = reader.ReadRequiredSerializableList(); return info; } + + void FormatSerialize(FormatSerializer &serializer) const; + static ColumnInfo FormatDeserialize(FormatDeserializer &deserializer); + vector names; vector types; }; @@ -105,6 +109,9 @@ struct ReadCSVData : public BaseCSVData { this->initial_reader = std::move(reader); } void FinalizeRead(ClientContext &context); + + void FormatSerialize(FormatSerializer &serializer) const; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; struct CSVCopyFunction { diff --git a/src/include/duckdb/function/table_function.hpp b/src/include/duckdb/function/table_function.hpp index daff56b18934..13d79619adfe 100644 --- a/src/include/duckdb/function/table_function.hpp +++ b/src/include/duckdb/function/table_function.hpp @@ -207,6 +207,12 @@ typedef void (*table_function_serialize_t)(FieldWriter &writer, const FunctionDa typedef unique_ptr 
(*table_function_deserialize_t)(PlanDeserializationState &context, FieldReader &reader, TableFunction &function); +typedef void (*table_function_format_serialize_t)(FormatSerializer &serializer, + const optional_ptr bind_data, + const TableFunction &function); +typedef unique_ptr (*table_function_format_deserialize_t)(FormatDeserializer &deserializer, + TableFunction &function); + class TableFunction : public SimpleNamedParameterFunction { public: DUCKDB_API @@ -265,6 +271,8 @@ class TableFunction : public SimpleNamedParameterFunction { table_function_serialize_t serialize; table_function_deserialize_t deserialize; + table_function_format_serialize_t format_serialize; + table_function_format_deserialize_t format_deserialize; bool verify_serialization = true; //! Whether or not the table function supports projection pushdown. If not supported a projection will be added diff --git a/src/include/duckdb/main/attached_database.hpp b/src/include/duckdb/main/attached_database.hpp index 0508afaffd5b..14b6f6e053ef 100644 --- a/src/include/duckdb/main/attached_database.hpp +++ b/src/include/duckdb/main/attached_database.hpp @@ -60,7 +60,7 @@ class AttachedDatabase : public CatalogEntry { bool IsInitialDatabase() const; void SetInitialDatabase(); - static string ExtractDatabaseName(const string &dbpath); + static string ExtractDatabaseName(const string &dbpath, FileSystem &fs); private: DatabaseInstance &db; diff --git a/src/include/duckdb/main/chunk_scan_state.hpp b/src/include/duckdb/main/chunk_scan_state.hpp new file mode 100644 index 000000000000..cccff2fe5fb1 --- /dev/null +++ b/src/include/duckdb/main/chunk_scan_state.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "duckdb/common/vector.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/preserved_error.hpp" + +namespace duckdb { + +class DataChunk; + +//! 
Abstract chunk fetcher +class ChunkScanState { +public: + explicit ChunkScanState() { + } + virtual ~ChunkScanState() { + } + +public: + ChunkScanState(const ChunkScanState &other) = delete; + ChunkScanState(ChunkScanState &&other) = default; + ChunkScanState &operator=(const ChunkScanState &other) = delete; + ChunkScanState &operator=(ChunkScanState &&other) = default; + +public: + virtual bool LoadNextChunk(PreservedError &error) = 0; + virtual bool HasError() const = 0; + virtual PreservedError &GetError() = 0; + virtual const vector &Types() const = 0; + virtual const vector &Names() const = 0; + idx_t CurrentOffset() const; + idx_t RemainingInChunk() const; + DataChunk &CurrentChunk(); + bool ChunkIsEmpty() const; + bool Finished() const; + bool ScanStarted() const; + void IncreaseOffset(idx_t increment, bool unsafe = false); + +protected: + idx_t offset = 0; + bool finished = false; + unique_ptr current_chunk; +}; + +} // namespace duckdb diff --git a/src/include/duckdb/main/chunk_scan_state/query_result.hpp b/src/include/duckdb/main/chunk_scan_state/query_result.hpp new file mode 100644 index 000000000000..d6f21a5076c6 --- /dev/null +++ b/src/include/duckdb/main/chunk_scan_state/query_result.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "duckdb/main/chunk_scan_state.hpp" +#include "duckdb/common/preserved_error.hpp" + +namespace duckdb { + +class QueryResult; + +class QueryResultChunkScanState : public ChunkScanState { +public: + QueryResultChunkScanState(QueryResult &result); + ~QueryResultChunkScanState(); + +public: + bool LoadNextChunk(PreservedError &error) override; + bool HasError() const override; + PreservedError &GetError() override; + const vector &Types() const override; + const vector &Names() const override; + +private: + bool InternalLoad(PreservedError &error); + +private: + QueryResult &result; +}; + +} // namespace duckdb diff --git a/src/include/duckdb/main/extension_entries.hpp b/src/include/duckdb/main/extension_entries.hpp index ca026d2d0a17..08f406242f94 100644 --- a/src/include/duckdb/main/extension_entries.hpp +++ b/src/include/duckdb/main/extension_entries.hpp @@ -162,6 +162,9 @@ static constexpr ExtensionEntry EXTENSION_FUNCTIONS[] = {{"->>", "json"}, {"st_removerepeatedpoints", "spatial"}, {"st_geomfromgeojson", "spatial"}, {"st_readosm", "spatial"}, + {"st_reduceprecision", "spatial"}, + {"st_geomfromhexwkb", "spatial"}, + {"st_geomfromhexewkb", "spatial"}, {"st_numpoints", "spatial"}}; static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { diff --git a/src/include/duckdb/main/query_result.hpp b/src/include/duckdb/main/query_result.hpp index 68fb07275028..f7e97ca38df7 100644 --- a/src/include/duckdb/main/query_result.hpp +++ b/src/include/duckdb/main/query_result.hpp @@ -63,16 +63,7 @@ class BaseQueryResult { //! The error (in case execution was not successful) PreservedError error; }; -struct CurrentChunk { - //! The current data chunk - unique_ptr data_chunk; - //! The current position in the data chunk - idx_t position; - //! If we have a current chunk we must scan for result production - bool Valid(); - //! The remaining size of the current chunk - idx_t RemainingSize(); -}; + //! The QueryResult object holds the result of a query. It can either be a MaterializedQueryResult, in which case the //! result contains the entire result set, or a StreamQueryResult in which case the Fetch method can be called to //! incrementally fetch data from the database. 
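The CurrentChunk bookkeeping removed above is superseded by the ChunkScanState machinery introduced in this patch. A rough editorial sketch (not part of the patch) of how a consumer that needs fixed-size slices, such as the Arrow conversion path, can drive it; DrainInSlices and the slice handling are illustrative:

// Illustrative only: drain a QueryResult in fixed-size slices.
void DrainInSlices(QueryResult &result, idx_t slice_size) {
	QueryResultChunkScanState scan_state(result);
	PreservedError error;
	while (scan_state.LoadNextChunk(error)) {
		if (scan_state.ChunkIsEmpty()) {
			break; // fetch produced no further chunk: result exhausted
		}
		auto &chunk = scan_state.CurrentChunk();
		while (scan_state.RemainingInChunk() > 0) {
			idx_t count = MinValue<idx_t>(slice_size, scan_state.RemainingInChunk());
			// ... convert rows [CurrentOffset(), CurrentOffset() + count) of chunk ...
			(void)chunk;
			scan_state.IncreaseOffset(count);
		}
	}
	// if LoadNextChunk returned false, `error` (or the result's error) holds the cause
}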
@@ -89,10 +80,6 @@ class QueryResult : public BaseQueryResult { ClientProperties client_properties; //! The next result (if any) unique_ptr next; - //! In case we are converting the result from Native DuckDB to a different library (e.g., Arrow, Polars) - //! We might be producing chunks of a pre-determined size. - //! To comply, we use the following variable to store the current chunk, and it's position. - CurrentChunk current_chunk; public: template diff --git a/src/include/duckdb/optimizer/deliminator.hpp b/src/include/duckdb/optimizer/deliminator.hpp index 0d110563b0a2..482c4064adce 100644 --- a/src/include/duckdb/optimizer/deliminator.hpp +++ b/src/include/duckdb/optimizer/deliminator.hpp @@ -27,9 +27,9 @@ class Deliminator { void FindCandidates(unique_ptr &op, vector &candidates); void FindJoinWithDelimGet(unique_ptr &op, DelimCandidate &candidate); //! Remove joins with a DelimGet - bool RemoveJoinWithDelimGet(LogicalDelimJoin &delim_join, const idx_t delim_get_count, + bool RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, unique_ptr &join, bool &all_equality_conditions); - bool RemoveInequalityJoinWithDelimGet(LogicalDelimJoin &delim_join, const idx_t delim_get_count, + bool RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, unique_ptr &join, const vector &replacement_bindings); diff --git a/src/include/duckdb/parser/transformer.hpp b/src/include/duckdb/parser/transformer.hpp index d9db17138533..849e8740683a 100644 --- a/src/include/duckdb/parser/transformer.hpp +++ b/src/include/duckdb/parser/transformer.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/constants.hpp" #include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/stack_checker.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/parser/group_by_node.hpp" @@ -26,7 +27,6 @@ namespace duckdb { class ColumnDefinition; -class StackChecker; struct OrderByNode; struct CopyInfo; struct CommonTableExpressionInfo; @@ -39,7 +39,7 @@ struct PivotColumn; //! The transformer class is responsible for transforming the internal Postgres //! 
parser representation into the DuckDB representation class Transformer { - friend class StackChecker; + friend class StackChecker<Transformer>; struct CreatePivotEntry { string enum_name; @@ -343,7 +343,7 @@ class Transformer { idx_t stack_depth; void InitializeStackCheck(); - StackChecker StackCheck(idx_t extra_stack = 1); + StackChecker<Transformer> StackCheck(idx_t extra_stack = 1); public: template @@ -356,18 +356,6 @@ } }; -class StackChecker { -public: - StackChecker(Transformer &transformer, idx_t stack_usage); - ~StackChecker(); - StackChecker(StackChecker &&) noexcept; - StackChecker(const StackChecker &) = delete; - -private: - Transformer &transformer; - idx_t stack_usage; -}; - vector<string> ReadPgListToString(duckdb_libpgquery::PGList *column_list); } // namespace duckdb diff --git a/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp b/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp index fbee39204467..e7c21fa6d2af 100644 --- a/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp +++ b/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp @@ -57,5 +57,8 @@ class BoundAggregateExpression : public Expression { unique_ptr Copy() override; void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(ExpressionDeserializationState &state, FieldReader &reader); + + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/expression/bound_function_expression.hpp b/src/include/duckdb/planner/expression/bound_function_expression.hpp index 50491e828f5c..8cc949265f6e 100644 --- a/src/include/duckdb/planner/expression/bound_function_expression.hpp +++ b/src/include/duckdb/planner/expression/bound_function_expression.hpp @@ -46,5 +46,9 @@ class BoundFunctionExpression : public Expression { void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(ExpressionDeserializationState &state, FieldReader &reader); + + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; + } // namespace duckdb diff --git a/src/include/duckdb/planner/expression/bound_window_expression.hpp b/src/include/duckdb/planner/expression/bound_window_expression.hpp index e0cc1f5befbb..ba5b1745bec6 100644 --- a/src/include/duckdb/planner/expression/bound_window_expression.hpp +++ b/src/include/duckdb/planner/expression/bound_window_expression.hpp @@ -67,5 +67,8 @@ class BoundWindowExpression : public Expression { void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(ExpressionDeserializationState &state, FieldReader &reader); + + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/expression_binder.hpp b/src/include/duckdb/planner/expression_binder.hpp index 3efc79964ee9..11e5a882c937 100644 --- a/src/include/duckdb/planner/expression_binder.hpp +++ b/src/include/duckdb/planner/expression_binder.hpp @@ -9,11 +9,12 @@ #pragma once #include "duckdb/common/exception.hpp" +#include "duckdb/common/stack_checker.hpp" +#include "duckdb/common/unordered_map.hpp" #include "duckdb/parser/expression/bound_expression.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/tokens.hpp" #include 
"duckdb/planner/expression.hpp" -#include "duckdb/common/unordered_map.hpp" namespace duckdb { @@ -51,6 +52,8 @@ struct BindResult { }; class ExpressionBinder { + friend class StackChecker; + public: ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder = false); virtual ~ExpressionBinder(); @@ -110,6 +113,15 @@ class ExpressionBinder { void ReplaceMacroParametersRecursive(unique_ptr &expr); +private: + //! Maximum stack depth + static constexpr const idx_t MAXIMUM_STACK_DEPTH = 128; + //! Current stack depth + idx_t stack_depth = DConstants::INVALID_INDEX; + + void InitializeStackCheck(); + StackChecker StackCheck(const ParsedExpression &expr, idx_t extra_stack = 1); + protected: BindResult BindExpression(BetweenExpression &expr, idx_t depth); BindResult BindExpression(CaseExpression &expr, idx_t depth); diff --git a/src/include/duckdb/planner/filter/conjunction_filter.hpp b/src/include/duckdb/planner/filter/conjunction_filter.hpp index 4d58adc483d6..7371ca474311 100644 --- a/src/include/duckdb/planner/filter/conjunction_filter.hpp +++ b/src/include/duckdb/planner/filter/conjunction_filter.hpp @@ -45,6 +45,8 @@ class ConjunctionOrFilter : public ConjunctionFilter { bool Equals(const TableFilter &other) const override; void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(FieldReader &source); + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; class ConjunctionAndFilter : public ConjunctionFilter { @@ -60,6 +62,8 @@ class ConjunctionAndFilter : public ConjunctionFilter { bool Equals(const TableFilter &other) const override; void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(FieldReader &source); + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/filter/constant_filter.hpp b/src/include/duckdb/planner/filter/constant_filter.hpp index 7da12c9fc0c6..2ce3b968d1a1 100644 --- a/src/include/duckdb/planner/filter/constant_filter.hpp +++ b/src/include/duckdb/planner/filter/constant_filter.hpp @@ -32,6 +32,8 @@ class ConstantFilter : public TableFilter { bool Equals(const TableFilter &other) const override; void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(FieldReader &source); + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/filter/null_filter.hpp b/src/include/duckdb/planner/filter/null_filter.hpp index 695dc378629b..af8a61644256 100644 --- a/src/include/duckdb/planner/filter/null_filter.hpp +++ b/src/include/duckdb/planner/filter/null_filter.hpp @@ -24,6 +24,8 @@ class IsNullFilter : public TableFilter { string ToString(const string &column_name) override; void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(FieldReader &source); + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; class IsNotNullFilter : public TableFilter { @@ -38,6 +40,8 @@ class IsNotNullFilter : public TableFilter { string ToString(const string &column_name) override; void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(FieldReader &source); + void 
FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/logical_tokens.hpp b/src/include/duckdb/planner/logical_tokens.hpp index fc40b5cb7fbb..0d033ef6adfc 100644 --- a/src/include/duckdb/planner/logical_tokens.hpp +++ b/src/include/duckdb/planner/logical_tokens.hpp @@ -14,7 +14,6 @@ class LogicalOperator; class LogicalAggregate; class LogicalAnyJoin; -class LogicalAsOfJoin; class LogicalColumnDataGet; class LogicalComparisonJoin; class LogicalCopyToFile; @@ -26,7 +25,6 @@ class LogicalCrossProduct; class LogicalCTERef; class LogicalDelete; class LogicalDelimGet; -class LogicalDelimJoin; class LogicalDistinct; class LogicalDummyScan; class LogicalEmptyResult; diff --git a/src/include/duckdb/planner/operator/list.hpp b/src/include/duckdb/planner/operator/list.hpp index 36eb6b092844..e782f8ed9c1d 100644 --- a/src/include/duckdb/planner/operator/list.hpp +++ b/src/include/duckdb/planner/operator/list.hpp @@ -1,6 +1,5 @@ #include "duckdb/planner/operator/logical_aggregate.hpp" #include "duckdb/planner/operator/logical_any_join.hpp" -#include "duckdb/planner/operator/logical_asof_join.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_copy_to_file.hpp" @@ -11,7 +10,6 @@ #include "duckdb/planner/operator/logical_cteref.hpp" #include "duckdb/planner/operator/logical_delete.hpp" #include "duckdb/planner/operator/logical_delim_get.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" #include "duckdb/planner/operator/logical_distinct.hpp" #include "duckdb/planner/operator/logical_dummy_scan.hpp" #include "duckdb/planner/operator/logical_empty_result.hpp" @@ -19,6 +17,7 @@ #include "duckdb/planner/operator/logical_explain.hpp" #include "duckdb/planner/operator/logical_export.hpp" #include "duckdb/planner/operator/logical_expression_get.hpp" +#include "duckdb/planner/operator/logical_extension_operator.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_insert.hpp" diff --git a/src/include/duckdb/planner/operator/logical_asof_join.hpp b/src/include/duckdb/planner/operator/logical_asof_join.hpp deleted file mode 100644 index 5289d67837c1..000000000000 --- a/src/include/duckdb/planner/operator/logical_asof_join.hpp +++ /dev/null @@ -1,27 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/operator/logical_asof_join.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/operator/logical_comparison_join.hpp" - -namespace duckdb { - -//! LogicalAsOfJoin represents a temporal-style join with one less-than inequality. -//! This inequality matches the greatest value on the right that satisfies the condition. 
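With this header gone (and logical_delim_join.hpp removed below), AS-OF and delim joins become plain LogicalComparisonJoins distinguished only by their operator type. A brief editorial sketch (not part of the patch) of how downstream code now dispatches; the InspectJoin visitor is illustrative:

// Illustrative only: dispatch on the type enum rather than the C++ class.
void InspectJoin(LogicalOperator &op) {
	switch (op.type) {
	case LogicalOperatorType::LOGICAL_COMPARISON_JOIN:
	case LogicalOperatorType::LOGICAL_ASOF_JOIN:
	case LogicalOperatorType::LOGICAL_DELIM_JOIN: {
		auto &join = op.Cast<LogicalComparisonJoin>();
		if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) {
			// the duplicate-eliminated columns now live on the comparison join itself
			for (auto &expr : join.duplicate_eliminated_columns) {
				(void)expr; // ... inspect the columns pushed from LHS into the RHS ...
			}
		}
		break;
	}
	default:
		break;
	}
}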
-class LogicalAsOfJoin : public LogicalComparisonJoin { -public: - static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_ASOF_JOIN; - -public: - explicit LogicalAsOfJoin(JoinType type); - - static unique_ptr Deserialize(LogicalDeserializationState &state, FieldReader &reader); -}; - -} // namespace duckdb diff --git a/src/include/duckdb/planner/operator/logical_comparison_join.hpp b/src/include/duckdb/planner/operator/logical_comparison_join.hpp index d2faad5f5da1..23e57700ecd7 100644 --- a/src/include/duckdb/planner/operator/logical_comparison_join.hpp +++ b/src/include/duckdb/planner/operator/logical_comparison_join.hpp @@ -15,7 +15,6 @@ #include "duckdb/planner/operator/logical_join.hpp" namespace duckdb { -class LogicalDelimJoin; //! LogicalComparisonJoin represents a join that involves comparisons between the LHS and RHS class LogicalComparisonJoin : public LogicalJoin { @@ -30,6 +29,8 @@ class LogicalComparisonJoin : public LogicalJoin { vector conditions; //! Used for duplicate-eliminated MARK joins vector mark_types; + //! The set of columns that will be duplicate eliminated from the LHS and pushed into the RHS + vector> duplicate_eliminated_columns; public: string ParamsToString() const override; @@ -41,9 +42,6 @@ class LogicalComparisonJoin : public LogicalJoin { void FormatSerialize(FormatSerializer &serializer) const override; static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); - //! Turn a delim join into a regular comparison join (after all required delim scans have been pruned) - static unique_ptr FromDelimJoin(LogicalDelimJoin &join); - public: static unique_ptr CreateJoin(ClientContext &context, JoinType type, JoinRefType ref_type, unique_ptr left_child, diff --git a/src/include/duckdb/planner/operator/logical_copy_to_file.hpp b/src/include/duckdb/planner/operator/logical_copy_to_file.hpp index 95a827b1b12d..e0cd91575cf5 100644 --- a/src/include/duckdb/planner/operator/logical_copy_to_file.hpp +++ b/src/include/duckdb/planner/operator/logical_copy_to_file.hpp @@ -45,6 +45,8 @@ class LogicalCopyToFile : public LogicalOperator { bool SupportSerialization() const override { return false; } + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); protected: void ResolveTypes() override { diff --git a/src/include/duckdb/planner/operator/logical_create_index.hpp b/src/include/duckdb/planner/operator/logical_create_index.hpp index 42f6ee897b9c..42dd0dc0b914 100644 --- a/src/include/duckdb/planner/operator/logical_create_index.hpp +++ b/src/include/duckdb/planner/operator/logical_create_index.hpp @@ -19,19 +19,14 @@ class LogicalCreateIndex : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CREATE_INDEX; public: - LogicalCreateIndex(unique_ptr bind_data_p, unique_ptr info_p, - vector> expressions_p, TableCatalogEntry &table_p, - TableFunction function_p); + LogicalCreateIndex(unique_ptr info_p, vector> expressions_p, + TableCatalogEntry &table_p); - //! The bind data of the function - unique_ptr bind_data; // Info for index creation unique_ptr info; //! The table to create the index for TableCatalogEntry &table; - //! The function that is called - TableFunction function; //! 
Unbound expressions to be used in the optimizer vector> unbound_expressions; @@ -40,7 +35,15 @@ class LogicalCreateIndex : public LogicalOperator { void Serialize(FieldWriter &writer) const override; static unique_ptr Deserialize(LogicalDeserializationState &state, FieldReader &reader); + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); + protected: void ResolveTypes() override; + +private: + LogicalCreateIndex(ClientContext &context, unique_ptr info, vector> expressions); + + TableCatalogEntry &BindTable(ClientContext &context, CreateIndexInfo &info); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/operator/logical_delim_join.hpp b/src/include/duckdb/planner/operator/logical_delim_join.hpp deleted file mode 100644 index 62422d96964f..000000000000 --- a/src/include/duckdb/planner/operator/logical_delim_join.hpp +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/operator/logical_delim_join.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/operator/logical_comparison_join.hpp" - -namespace duckdb { - -//! LogicalDelimJoin represents a special "duplicate eliminated" join. This join type is only used for subquery -//! flattening, and involves performing duplicate elimination on the LEFT side which is then pushed into the RIGHT side. -class LogicalDelimJoin : public LogicalComparisonJoin { -public: - static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_DELIM_JOIN; - -public: - explicit LogicalDelimJoin(JoinType type); - - //! The set of columns that will be duplicate eliminated from the LHS and pushed into the RHS - vector> duplicate_eliminated_columns; - -public: - void Serialize(FieldWriter &writer) const override; - static unique_ptr Deserialize(LogicalDeserializationState &state, FieldReader &reader); -}; - -} // namespace duckdb diff --git a/src/include/duckdb/planner/operator/logical_extension_operator.hpp b/src/include/duckdb/planner/operator/logical_extension_operator.hpp index c41d46133ea5..0dd082ffe6ae 100644 --- a/src/include/duckdb/planner/operator/logical_extension_operator.hpp +++ b/src/include/duckdb/planner/operator/logical_extension_operator.hpp @@ -26,6 +26,11 @@ struct LogicalExtensionOperator : public LogicalOperator { static unique_ptr Deserialize(LogicalDeserializationState &state, FieldReader &reader); + virtual void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); + virtual unique_ptr CreatePlan(ClientContext &context, PhysicalPlanGenerator &generator) = 0; + + virtual string GetExtensionName() const; }; } // namespace duckdb diff --git a/src/include/duckdb/planner/operator/logical_get.hpp b/src/include/duckdb/planner/operator/logical_get.hpp index e7c2e79d62b3..bdeb33484ee6 100644 --- a/src/include/duckdb/planner/operator/logical_get.hpp +++ b/src/include/duckdb/planner/operator/logical_get.hpp @@ -69,9 +69,15 @@ class LogicalGet : public LogicalOperator { //! 
Skips the serialization check in VerifyPlan bool SupportSerialization() const override { return function.verify_serialization; - }; + } + + void FormatSerialize(FormatSerializer &serializer) const override; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); protected: void ResolveTypes() override; + +private: + LogicalGet(); }; } // namespace duckdb diff --git a/src/include/duckdb/planner/operator_extension.hpp b/src/include/duckdb/planner/operator_extension.hpp index fa0fcb3c095a..7bf2551863f8 100644 --- a/src/include/duckdb/planner/operator_extension.hpp +++ b/src/include/duckdb/planner/operator_extension.hpp @@ -36,6 +36,7 @@ class OperatorExtension { virtual std::string GetName() = 0; virtual unique_ptr Deserialize(LogicalDeserializationState &state, FieldReader &reader) = 0; + virtual unique_ptr FormatDeserialize(FormatDeserializer &deserializer) = 0; virtual ~OperatorExtension() { } diff --git a/src/include/duckdb/planner/table_filter.hpp b/src/include/duckdb/planner/table_filter.hpp index dd3f7d7010a4..3c8f69bed4af 100644 --- a/src/include/duckdb/planner/table_filter.hpp +++ b/src/include/duckdb/planner/table_filter.hpp @@ -29,7 +29,7 @@ enum class TableFilterType : uint8_t { //! TableFilter represents a filter pushed down into the table scan. class TableFilter { public: - TableFilter(TableFilterType filter_type_p) : filter_type(filter_type_p) { + explicit TableFilter(TableFilterType filter_type_p) : filter_type(filter_type_p) { } virtual ~TableFilter() { } @@ -48,6 +48,9 @@ class TableFilter { virtual void Serialize(FieldWriter &writer) const = 0; static unique_ptr Deserialize(Deserializer &source); + virtual void FormatSerialize(FormatSerializer &serializer) const; + static unique_ptr FormatDeserialize(FormatDeserializer &deserializer); + public: template TARGET &Cast() { @@ -100,6 +103,9 @@ class TableFilterSet { void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &source); + + void FormatSerialize(FormatSerializer &serializer) const; + static TableFilterSet FormatDeserialize(FormatDeserializer &deserializer); }; } // namespace duckdb diff --git a/src/include/duckdb/storage/meta_block_reader.hpp b/src/include/duckdb/storage/meta_block_reader.hpp index df675fd1443f..e9dfc14162f4 100644 --- a/src/include/duckdb/storage/meta_block_reader.hpp +++ b/src/include/duckdb/storage/meta_block_reader.hpp @@ -37,13 +37,10 @@ class MetaBlockReader : public Deserializer { void ReadData(data_ptr_t buffer, idx_t read_size) override; ClientContext &GetContext() override; - optional_ptr GetCatalog() override; - void SetCatalog(Catalog &catalog_p); void SetContext(ClientContext &context_p); private: void ReadNewBlock(block_id_t id); optional_ptr context; - optional_ptr catalog; }; } // namespace duckdb diff --git a/src/include/duckdb/storage/serialization/create_info.json b/src/include/duckdb/storage/serialization/create_info.json index 16e405d975e3..6dd26665cf39 100644 --- a/src/include/duckdb/storage/serialization/create_info.json +++ b/src/include/duckdb/storage/serialization/create_info.json @@ -49,6 +49,10 @@ "type": "string", "property": "index_name" }, + { + "name": "table", + "type": "string" + }, { "name": "index_type", "type": "IndexType" diff --git a/src/include/duckdb/storage/serialization/expression.json b/src/include/duckdb/storage/serialization/expression.json index 4ea3ed16aa45..e903c9b79935 100644 --- a/src/include/duckdb/storage/serialization/expression.json +++ b/src/include/duckdb/storage/serialization/expression.json @@ 
-277,5 +277,23 @@ } ], "constructor": ["return_type"] + }, + { + "class": "BoundFunctionExpression", + "base": "Expression", + "enum": "BOUND_FUNCTION", + "custom_implementation": true + }, + { + "class": "BoundAggregateExpression", + "base": "Expression", + "enum": "BOUND_AGGREGATE", + "custom_implementation": true + }, + { + "class": "BoundWindowExpression", + "base": "Expression", + "enum": "BOUND_WINDOW", + "custom_implementation": true } ] diff --git a/src/include/duckdb/storage/serialization/logical_operator.json b/src/include/duckdb/storage/serialization/logical_operator.json index af9ccfa7cd53..1e5a595e3ccc 100644 --- a/src/include/duckdb/storage/serialization/logical_operator.json +++ b/src/include/duckdb/storage/serialization/logical_operator.json @@ -352,7 +352,7 @@ { "class": "LogicalComparisonJoin", "base": "LogicalOperator", - "enum": "LOGICAL_COMPARISON_JOIN", + "enum": ["LOGICAL_ASOF_JOIN", "LOGICAL_COMPARISON_JOIN", "LOGICAL_DELIM_JOIN"], "members": [ { "name": "join_type", @@ -377,6 +377,10 @@ { "name": "mark_types", "type": "vector" + }, + { + "name": "duplicate_eliminated_columns", + "type": "vector" } ], "constructor": ["join_type", "type"] @@ -792,5 +796,39 @@ "type": "BoundPivotInfo" } ] + }, + { + "class": "LogicalGet", + "base": "LogicalOperator", + "enum": "LOGICAL_GET", + "custom_implementation": true + }, + { + "class": "LogicalCopyToFile", + "base": "LogicalOperator", + "enum": "LOGICAL_COPY_TO_FILE", + "custom_implementation": true + }, + { + "class": "LogicalCreateIndex", + "base": "LogicalOperator", + "enum": "LOGICAL_CREATE_INDEX", + "members": [ + { + "name": "info", + "type": "CreateInfo*" + }, + { + "name": "unbound_expressions", + "type": "vector" + } + ], + "constructor": ["$ClientContext", "info", "unbound_expressions"] + }, + { + "class": "LogicalExtensionOperator", + "base": "LogicalOperator", + "enum": "LOGICAL_EXTENSION_OPERATOR", + "custom_implementation": true } ] diff --git a/src/include/duckdb/storage/serialization/nodes.json b/src/include/duckdb/storage/serialization/nodes.json index c7cb5d11719a..7ba078560080 100644 --- a/src/include/duckdb/storage/serialization/nodes.json +++ b/src/include/duckdb/storage/serialization/nodes.json @@ -9,12 +9,12 @@ { "name": "id", "type": "LogicalTypeId", - "property": "id_" + "serialize_property": "id_" }, { "name": "type_info", "type": "shared_ptr", - "property": "type_info_", + "serialize_property": "type_info_", "optional": true } ], @@ -323,5 +323,312 @@ } ], "pointer_type": "none" + }, + { + "class": "TableFilterSet", + "includes": [ + "duckdb/planner/table_filter.hpp" + ], + "members": [ + { + "name": "filters", + "type": "unordered_map" + } + ], + "pointer_type": "none" + }, + { + "class": "MultiFileReaderOptions", + "includes": [ + "duckdb/common/multi_file_reader_options.hpp" + ], + "members": [ + { + "name": "filename", + "type": "bool" + }, + { + "name": "hive_partitioning", + "type": "bool" + }, + { + "name": "auto_detect_hive_partitioning", + "type": "bool" + }, + { + "name": "union_by_name", + "type": "bool" + }, + { + "name": "hive_types_autocast", + "type": "bool" + }, + { + "name": "hive_types_schema", + "type": "unordered_map" + } + ], + "pointer_type": "none" + }, + { + "class": "MultiFileReaderBindData", + "includes": [ + "duckdb/common/multi_file_reader.hpp" + ], + "members": [ + { + "name": "filename_idx", + "type": "idx_t" + }, + { + "name": "hive_partitioning_indexes", + "type": "vector" + } + ], + "pointer_type": "none" + }, + { + "class": "HivePartitioningIndex", + "members": [ + { 
+ "name": "value", + "type": "string" + }, + { + "name": "index", + "type": "idx_t" + } + ], + "pointer_type": "none", + "constructor": ["value", "index"] + }, + { + "class": "BufferedCSVReaderOptions", + "includes": [ + "duckdb/execution/operator/persistent/csv_reader_options.hpp" + ], + "members": [ + { + "name": "has_delimiter", + "type": "bool" + }, + { + "name": "delimiter", + "type": "string" + }, + { + "name": "has_quote", + "type": "bool" + }, + { + "name": "quote", + "type": "string" + }, + { + "name": "has_escape", + "type": "bool" + }, + { + "name": "escape", + "type": "string" + }, + { + "name": "has_header", + "type": "bool" + }, + { + "name": "header", + "type": "bool" + }, + { + "name": "ignore_errors", + "type": "bool" + }, + { + "name": "num_cols", + "type": "idx_t" + }, + { + "name": "buffer_sample_size", + "type": "idx_t" + }, + { + "name": "null_str", + "type": "string" + }, + { + "name": "compression", + "type": "FileCompressionType" + }, + { + "name": "new_line", + "type": "NewLineIdentifier" + }, + { + "name": "allow_quoted_nulls", + "type": "bool" + }, + { + "name": "skip_rows", + "type": "idx_t" + }, + { + "name": "skip_rows_set", + "type": "bool" + }, + { + "name": "maximum_line_size", + "type": "idx_t" + }, + { + "name": "normalize_names", + "type": "bool" + }, + { + "name": "force_not_null", + "type": "vector" + }, + { + "name": "all_varchar", + "type": "bool" + }, + { + "name": "sample_chunk_size", + "type": "idx_t" + }, + { + "name": "sample_chunks", + "type": "idx_t" + }, + { + "name": "auto_detect", + "type": "bool" + }, + { + "name": "file_path", + "type": "string" + }, + { + "name": "decimal_separator", + "type": "string" + }, + { + "name": "null_padding", + "type": "bool" + }, + { + "name": "buffer_size", + "type": "idx_t" + }, + { + "name": "file_options", + "type": "MultiFileReaderOptions" + }, + { + "name": "force_quote", + "type": "vector" + }, + { + "name": "date_format", + "type": "unordered_map date_format" + }, + { + "name": "has_format", + "type": "unordered_map date_format" + }, + { + "name": "rejects_table_name", + "type": "string" + }, + { + "name": "rejects_limit", + "type": "idx_t" + }, + { + "name": "rejects_recovery_columns", + "type": "vector" + }, + { + "name": "rejects_recovery_column_ids", + "type": "vector" + } + ], + "pointer_type": "none" + }, + { + "class": "StrpTimeFormat", + "includes": [ + "duckdb/function/scalar/strftime_format.hpp" + ], + "members": [ + { + "name": "format_specifier", + "type": "string" + } + ], + "constructor": ["format_specifier&"], + "pointer_type": "none" + }, + { + "class": "ReadCSVData", + "includes": [ + "duckdb/function/table/read_csv.hpp" + ], + "members": [ + { + "name": "files", + "type": "vector" + }, + { + "name": "csv_types", + "type": "vector" + }, + { + "name": "csv_names", + "type": "vector" + }, + { + "name": "return_types", + "type": "vector" + }, + { + "name": "return_names", + "type": "vector" + }, + { + "name": "filename_col_idx", + "type": "idx_t" + }, + { + "name": "options", + "type": "BufferedCSVReaderOptions" + }, + { + "name": "single_threaded", + "type": "bool" + }, + { + "name": "reader_bind", + "type": "MultiFileReaderBindData" + }, + { + "name": "column_info", + "type": "vector" + } + ] + }, + { + "class": "ColumnInfo", + "members": [ + { + "name": "names", + "type": "vector" + }, + { + "name": "types", + "type": "vector" + } + ], + "pointer_type": "none" } ] diff --git a/src/include/duckdb/storage/serialization/table_filter.json 
b/src/include/duckdb/storage/serialization/table_filter.json new file mode 100644 index 000000000000..63c1a2c3eabf --- /dev/null +++ b/src/include/duckdb/storage/serialization/table_filter.json @@ -0,0 +1,76 @@ +[ + { + "class": "TableFilter", + "class_type": "filter_type", + "includes": [ + "duckdb/planner/table_filter.hpp" + ], + "members": [ + { + "name": "filter_type", + "type": "TableFilterType" + } + ] + }, + { + "class": "IsNullFilter", + "base": "TableFilter", + "includes": [ + "duckdb/planner/filter/null_filter.hpp" + ], + "enum": "IS_NULL", + "members": [ + ] + }, + { + "class": "IsNotNullFilter", + "base": "TableFilter", + "enum": "IS_NOT_NULL", + "members": [ + ] + }, + { + "class": "ConstantFilter", + "base": "TableFilter", + "includes": [ + "duckdb/planner/filter/constant_filter.hpp" + ], + "enum": "CONSTANT_COMPARISON", + "members": [ + { + "name": "comparison_type", + "type": "ExpressionType" + }, + { + "name": "constant", + "type": "Value" + } + ], + "constructor": ["comparison_type", "constant"] + }, + { + "class": "ConjunctionOrFilter", + "base": "TableFilter", + "includes": [ + "duckdb/planner/filter/conjunction_filter.hpp" + ], + "enum": "CONJUNCTION_OR", + "members": [ + { + "name": "child_filters", + "type": "vector" + } + ] + }, + { + "class": "ConjunctionAndFilter", + "base": "TableFilter", + "enum": "CONJUNCTION_AND", + "members": [ + { + "name": "child_filters", + "type": "vector" + } + ] + } +] diff --git a/src/main/CMakeLists.txt b/src/main/CMakeLists.txt index f19ec798933b..0df3eaf3ca20 100644 --- a/src/main/CMakeLists.txt +++ b/src/main/CMakeLists.txt @@ -4,6 +4,7 @@ endif() add_subdirectory(extension) add_subdirectory(relation) add_subdirectory(settings) +add_subdirectory(chunk_scan_state) if(FORCE_QUERY_LOG) add_definitions(-DDUCKDB_FORCE_QUERY_LOG="\""${FORCE_QUERY_LOG}"\"") @@ -18,6 +19,7 @@ add_library_unity( client_context.cpp client_data.cpp client_verify.cpp + chunk_scan_state.cpp config.cpp connection.cpp database.cpp diff --git a/src/main/attached_database.cpp b/src/main/attached_database.cpp index 36c1ad5b982c..24605378ce99 100644 --- a/src/main/attached_database.cpp +++ b/src/main/attached_database.cpp @@ -87,11 +87,11 @@ bool AttachedDatabase::IsReadOnly() const { return type == AttachedDatabaseType::READ_ONLY_DATABASE; } -string AttachedDatabase::ExtractDatabaseName(const string &dbpath) { +string AttachedDatabase::ExtractDatabaseName(const string &dbpath, FileSystem &fs) { if (dbpath.empty() || dbpath == ":memory:") { return "memory"; } - return FileSystem::ExtractBaseName(dbpath); + return fs.ExtractBaseName(dbpath); } void AttachedDatabase::Initialize() { diff --git a/src/main/capi/duckdb-c.cpp b/src/main/capi/duckdb-c.cpp index e82621118084..5377df466c98 100644 --- a/src/main/capi/duckdb-c.cpp +++ b/src/main/capi/duckdb-c.cpp @@ -54,6 +54,22 @@ duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out) { return DuckDBSuccess; } +void duckdb_interrupt(duckdb_connection connection) { + if (!connection) { + return; + } + Connection *conn = reinterpret_cast(connection); + conn->Interrupt(); +} + +double duckdb_query_progress(duckdb_connection connection) { + if (!connection) { + return -1; + } + Connection *conn = reinterpret_cast(connection); + return conn->context->GetProgress(); +} + void duckdb_disconnect(duckdb_connection *connection) { if (connection && *connection) { Connection *conn = reinterpret_cast(*connection); diff --git a/src/main/chunk_scan_state.cpp b/src/main/chunk_scan_state.cpp new file mode 100644 index 
000000000000..458c9ffc4817 --- /dev/null +++ b/src/main/chunk_scan_state.cpp @@ -0,0 +1,42 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/main/chunk_scan_state.hpp" + +namespace duckdb { + +idx_t ChunkScanState::CurrentOffset() const { + return offset; +} + +void ChunkScanState::IncreaseOffset(idx_t increment, bool unsafe) { + D_ASSERT(unsafe || increment <= RemainingInChunk()); + offset += increment; +} + +bool ChunkScanState::ChunkIsEmpty() const { + return !current_chunk || current_chunk->size() == 0; +} + +bool ChunkScanState::Finished() const { + return finished; +} + +bool ChunkScanState::ScanStarted() const { + return !ChunkIsEmpty(); +} + +DataChunk &ChunkScanState::CurrentChunk() { + // Scan must already be started + D_ASSERT(current_chunk); + return *current_chunk; +} + +idx_t ChunkScanState::RemainingInChunk() const { + if (ChunkIsEmpty()) { + return 0; + } + D_ASSERT(current_chunk); + D_ASSERT(offset <= current_chunk->size()); + return current_chunk->size() - offset; +} + +} // namespace duckdb diff --git a/src/main/chunk_scan_state/CMakeLists.txt b/src/main/chunk_scan_state/CMakeLists.txt new file mode 100644 index 000000000000..4a897ef9d0cc --- /dev/null +++ b/src/main/chunk_scan_state/CMakeLists.txt @@ -0,0 +1,4 @@ +add_library_unity(duckdb_main_chunk_scan_state OBJECT query_result.cpp) +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/src/main/chunk_scan_state/query_result.cpp b/src/main/chunk_scan_state/query_result.cpp new file mode 100644 index 000000000000..84e45e3646f3 --- /dev/null +++ b/src/main/chunk_scan_state/query_result.cpp @@ -0,0 +1,53 @@ +#include "duckdb/main/query_result.hpp" +#include "duckdb/main/stream_query_result.hpp" +#include "duckdb/main/chunk_scan_state/query_result.hpp" + +namespace duckdb { + +QueryResultChunkScanState::QueryResultChunkScanState(QueryResult &result) : ChunkScanState(), result(result) { +} + +QueryResultChunkScanState::~QueryResultChunkScanState() { +} + +bool QueryResultChunkScanState::InternalLoad(PreservedError &error) { + D_ASSERT(!finished); + if (result.type == QueryResultType::STREAM_RESULT) { + auto &stream_result = result.Cast(); + if (!stream_result.IsOpen()) { + return true; + } + } + return result.TryFetch(current_chunk, error); +} + +bool QueryResultChunkScanState::HasError() const { + return result.HasError(); +} + +PreservedError &QueryResultChunkScanState::GetError() { + D_ASSERT(result.HasError()); + return result.GetErrorObject(); +} + +const vector &QueryResultChunkScanState::Types() const { + return result.types; +} + +const vector &QueryResultChunkScanState::Names() const { + return result.names; +} + +bool QueryResultChunkScanState::LoadNextChunk(PreservedError &error) { + if (finished) { + return !finished; + } + auto load_result = InternalLoad(error); + if (!load_result) { + finished = true; + } + offset = 0; + return !finished; +} + +} // namespace duckdb diff --git a/src/main/database.cpp b/src/main/database.cpp index d3798425d57d..633ecccd21d8 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -158,7 +158,7 @@ duckdb::unique_ptr DatabaseInstance::CreateAttachedDatabase(At void DatabaseInstance::CreateMainDatabase() { AttachInfo info; - info.name = AttachedDatabase::ExtractDatabaseName(config.options.database_path); + info.name = AttachedDatabase::ExtractDatabaseName(config.options.database_path, GetFileSystem()); info.path = config.options.database_path; auto attached_database = CreateAttachedDatabase(info, config.options.database_type, 
config.options.access_mode); diff --git a/src/main/db_instance_cache.cpp b/src/main/db_instance_cache.cpp index 0e7550008f46..41012f559ab0 100644 --- a/src/main/db_instance_cache.cpp +++ b/src/main/db_instance_cache.cpp @@ -3,7 +3,7 @@ namespace duckdb { -string GetDBAbsolutePath(const string &database_p) { +string GetDBAbsolutePath(const string &database_p, FileSystem &fs) { auto database = FileSystem::ExpandPath(database_p, nullptr); if (database.empty()) { return ":memory:"; @@ -16,15 +16,17 @@ string GetDBAbsolutePath(const string &database_p) { // this database path is handled by a replacement open and is not a file path return database; } - if (FileSystem::IsPathAbsolute(database)) { - return FileSystem::NormalizeAbsolutePath(database); + if (fs.IsPathAbsolute(database)) { + return fs.NormalizeAbsolutePath(database); } - return FileSystem::NormalizeAbsolutePath(FileSystem::JoinPath(FileSystem::GetWorkingDirectory(), database)); + return fs.NormalizeAbsolutePath(fs.JoinPath(FileSystem::GetWorkingDirectory(), database)); } shared_ptr DBInstanceCache::GetInstanceInternal(const string &database, const DBConfig &config) { shared_ptr db_instance; - auto abs_database_path = GetDBAbsolutePath(database); + + auto local_fs = FileSystem::CreateLocal(); + auto abs_database_path = GetDBAbsolutePath(database, *local_fs); if (db_instances.find(abs_database_path) != db_instances.end()) { db_instance = db_instances[abs_database_path].lock(); if (db_instance) { @@ -48,7 +50,13 @@ shared_ptr DBInstanceCache::GetInstance(const string &database, const DB shared_ptr DBInstanceCache::CreateInstanceInternal(const string &database, DBConfig &config, bool cache_instance) { - auto abs_database_path = GetDBAbsolutePath(database); + string abs_database_path; + if (config.file_system) { + abs_database_path = GetDBAbsolutePath(database, *config.file_system); + } else { + auto tmp_fs = FileSystem::CreateLocal(); + abs_database_path = GetDBAbsolutePath(database, *tmp_fs); + } if (db_instances.find(abs_database_path) != db_instances.end()) { throw duckdb::Exception(ExceptionType::CONNECTION, "Instance with path: " + abs_database_path + " already exists."); diff --git a/src/main/extension/CMakeLists.txt b/src/main/extension/CMakeLists.txt index 117e9d5ee60f..2865e8120eeb 100644 --- a/src/main/extension/CMakeLists.txt +++ b/src/main/extension/CMakeLists.txt @@ -28,18 +28,23 @@ if(NOT ${DISABLE_BUILTIN_EXTENSIONS}) foreach(EXT_NAME IN LISTS DUCKDB_EXTENSION_NAMES) string(TOUPPER ${EXT_NAME} EXT_NAME_UPPERCASE) if(${DUCKDB_EXTENSION_${EXT_NAME_UPPERCASE}_SHOULD_LINK}) + # Assumes lowercase input! 
- set(EXTENSION_CLASS ${EXT_NAME}Extension) - string(SUBSTRING ${EXT_NAME} 0 1 FIRST_LETTER) - string(TOUPPER ${FIRST_LETTER} FIRST_LETTER) - string(REGEX REPLACE "^.(.*)" "${FIRST_LETTER}\\1" EXTENSION_CLASS - "${EXT_NAME}") + string(REPLACE "_" ";" EXT_NAME_SPLIT ${EXT_NAME}) + set(EXT_NAME_CAMELCASE "") + foreach(EXT_NAME_PART IN LISTS EXT_NAME_SPLIT) + string(SUBSTRING ${EXT_NAME_PART} 0 1 FIRST_LETTER) + string(SUBSTRING ${EXT_NAME_PART} 1 -1 REMAINDER) + string(TOUPPER ${FIRST_LETTER} FIRST_LETTER) + set(EXT_NAME_CAMELCASE + "${EXT_NAME_CAMELCASE}${FIRST_LETTER}${REMAINDER}") + endforeach() set(EXT_LOADER_NAME_LIST "${EXT_LOADER_NAME_LIST},\n\t\"${EXT_NAME}\"") set(EXT_LOADER_BODY "${EXT_LOADER_BODY}\ if (extension==\"${EXT_NAME}\") { - db.LoadExtension<${EXTENSION_CLASS}Extension>(); + db.LoadExtension<${EXT_NAME_CAMELCASE}Extension>(); return true; } ") diff --git a/src/main/extension/extension_helper.cpp b/src/main/extension/extension_helper.cpp index 0e0461cf9c74..dacf0da1a8f5 100644 --- a/src/main/extension/extension_helper.cpp +++ b/src/main/extension/extension_helper.cpp @@ -179,6 +179,19 @@ ExtensionLoadResult ExtensionHelper::LoadExtensionInternal(DuckDB &db, const std } #endif +#ifdef DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE + if (!initial_load && StringUtil::Contains(DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE, extension)) { + Connection con(db); + auto result = con.Query((string) "LOAD '" + DUCKDB_EXTENSIONS_BUILD_PATH + "/" + extension + "/" + extension + + ".duckdb_extension'"); + if (result->HasError()) { + result->Print(); + return ExtensionLoadResult::EXTENSION_UNKNOWN; + } + return ExtensionLoadResult::LOADED_EXTENSION; + } +#endif + // This is the main extension loading mechanism that loads the extension that are statically linked. #if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS if (TryLoadLinkedExtension(db, extension)) { diff --git a/src/main/extension/extension_install.cpp b/src/main/extension/extension_install.cpp index 6976a1cf829e..a6e6f42a9185 100644 --- a/src/main/extension/extension_install.cpp +++ b/src/main/extension/extension_install.cpp @@ -56,7 +56,7 @@ string ExtensionHelper::ExtensionDirectory(DBConfig &config, FileSystem &fs) { // expand ~ in extension directory extension_directory = fs.ExpandPath(extension_directory); if (!fs.DirectoryExists(extension_directory)) { - auto sep = fs.PathSeparator(); + auto sep = fs.PathSeparator(extension_directory); auto splits = StringUtil::Split(extension_directory, sep); D_ASSERT(!splits.empty()); string extension_directory_prefix; diff --git a/src/main/query_result.cpp b/src/main/query_result.cpp index 9679dc183523..09b00b65d3c7 100644 --- a/src/main/query_result.cpp +++ b/src/main/query_result.cpp @@ -57,19 +57,6 @@ QueryResult::QueryResult(QueryResultType type, StatementType statement_type, Sta client_properties(std::move(client_properties_p)) { } -bool CurrentChunk::Valid() { - if (data_chunk) { - if (position < data_chunk->size()) { - return true; - } - } - return false; -} - -idx_t CurrentChunk::RemainingSize() { - return data_chunk->size() - position; -} - QueryResult::QueryResult(QueryResultType type, PreservedError error) : BaseQueryResult(type, std::move(error)), client_properties("UTC", ArrowOffsetSize::REGULAR) { } diff --git a/src/optimizer/column_lifetime_analyzer.cpp b/src/optimizer/column_lifetime_analyzer.cpp index 787b7619eb79..6f0a7ed82699 100644 --- a/src/optimizer/column_lifetime_analyzer.cpp +++ b/src/optimizer/column_lifetime_analyzer.cpp @@ -1,7 +1,6 @@ #include 
"duckdb/optimizer/column_lifetime_optimizer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" #include "duckdb/planner/operator/logical_filter.hpp" namespace duckdb { @@ -38,7 +37,7 @@ void ColumnLifetimeAnalyzer::StandardVisitOperator(LogicalOperator &op) { LogicalOperatorVisitor::VisitOperatorExpressions(op); if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { // visit the duplicate eliminated columns on the LHS, if any - auto &delim_join = op.Cast(); + auto &delim_join = op.Cast(); for (auto &expr : delim_join.duplicate_eliminated_columns) { VisitExpression(&expr); } diff --git a/src/optimizer/compressed_materialization.cpp b/src/optimizer/compressed_materialization.cpp index 49eccd95174f..a1fc9912b924 100644 --- a/src/optimizer/compressed_materialization.cpp +++ b/src/optimizer/compressed_materialization.cpp @@ -9,7 +9,6 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" #include "duckdb/planner/operator/logical_projection.hpp" namespace duckdb { diff --git a/src/optimizer/deliminator.cpp b/src/optimizer/deliminator.cpp index 807a559f1e9f..0e45c4d86ee6 100644 --- a/src/optimizer/deliminator.cpp +++ b/src/optimizer/deliminator.cpp @@ -7,21 +7,21 @@ #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_delim_get.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" #include "duckdb/planner/operator/logical_filter.hpp" namespace duckdb { struct DelimCandidate { public: - explicit DelimCandidate(unique_ptr &op, LogicalDelimJoin &delim_join) + explicit DelimCandidate(unique_ptr &op, LogicalComparisonJoin &delim_join) : op(op), delim_join(delim_join), delim_get_count(0) { } public: unique_ptr &op; - LogicalDelimJoin &delim_join; + LogicalComparisonJoin &delim_join; vector>> joins; idx_t delim_get_count; }; @@ -55,6 +55,7 @@ unique_ptr Deliminator::Optimize(unique_ptr op // Change type if there are no more duplicate-eliminated columns if (candidate.joins.size() == candidate.delim_get_count && all_removed) { + delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; delim_join.duplicate_eliminated_columns.clear(); if (all_equality_conditions) { for (auto &cond : delim_join.conditions) { @@ -63,7 +64,6 @@ unique_ptr Deliminator::Optimize(unique_ptr op } } } - candidate.op = LogicalComparisonJoin::FromDelimJoin(delim_join); } } @@ -80,7 +80,7 @@ void Deliminator::FindCandidates(unique_ptr &op, vectorCast()); + candidates.emplace_back(op, op->Cast()); auto &candidate = candidates.back(); // DelimGets are in the RHS @@ -125,7 +125,7 @@ static bool ChildJoinTypeCanBeDeliminated(JoinType &join_type) { } } -bool Deliminator::RemoveJoinWithDelimGet(LogicalDelimJoin &delim_join, const idx_t delim_get_count, +bool Deliminator::RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, unique_ptr &join, bool &all_equality_conditions) { auto &comparison_join = join->Cast(); if (!ChildJoinTypeCanBeDeliminated(comparison_join.join_type)) { @@ -218,7 +218,7 @@ bool 
FindAndReplaceBindings(vector &traced_bindings, const vector return true; } -bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalDelimJoin &delim_join, const idx_t delim_get_count, +bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_join, const idx_t delim_get_count, unique_ptr &join, const vector &replacement_bindings) { auto &comparison_join = join->Cast(); diff --git a/src/optimizer/unnest_rewriter.cpp b/src/optimizer/unnest_rewriter.cpp index 34836e664d25..4c93505891e6 100644 --- a/src/optimizer/unnest_rewriter.cpp +++ b/src/optimizer/unnest_rewriter.cpp @@ -2,7 +2,7 @@ #include "duckdb/common/pair.hpp" #include "duckdb/planner/operator/logical_delim_get.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_unnest.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_window.hpp" @@ -17,11 +17,9 @@ void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) { } void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr *expression) { - auto &expr = *expression; if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { - auto &bound_column_ref = expr->Cast(); for (idx_t i = 0; i < replace_bindings.size(); i++) { if (bound_column_ref.binding == replace_bindings[i].old_binding) { @@ -76,7 +74,7 @@ void UnnestRewriter::FindCandidates(unique_ptr *op_ptr, } // found a delim join - auto &delim_join = op->children[0]->Cast(); + auto &delim_join = op->children[0]->Cast(); // only support INNER delim joins if (delim_join.join_type != JoinType::INNER) { return; @@ -295,7 +293,7 @@ void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &update void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); - auto &delim_join = op.Cast(); + auto &delim_join = op.Cast(); for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { auto &expr = *delim_join.duplicate_eliminated_columns[i]; D_ASSERT(expr.type == ExpressionType::BOUND_COLUMN_REF); diff --git a/src/parser/transform/helpers/transform_typename.cpp b/src/parser/transform/helpers/transform_typename.cpp index 66be3f382f44..2527b55f1ac8 100644 --- a/src/parser/transform/helpers/transform_typename.cpp +++ b/src/parser/transform/helpers/transform_typename.cpp @@ -4,6 +4,7 @@ #include "duckdb/parser/transformer.hpp" #include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/vector.hpp" namespace duckdb { @@ -21,7 +22,21 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n if (base_type == LogicalTypeId::LIST) { throw ParserException("LIST is not valid as a stand-alone type"); } else if (base_type == LogicalTypeId::ENUM) { - throw ParserException("ENUM is not valid as a stand-alone type"); + if (!type_name.typmods || type_name.typmods->length == 0) { + throw ParserException("Enum needs a set of entries"); + } + Vector enum_vector(LogicalType::VARCHAR, type_name.typmods->length); + auto string_data = FlatVector::GetData(enum_vector); + idx_t pos = 0; + for (auto node = type_name.typmods->head; node; node = node->next) { + auto constant_value = PGPointerCast(node->data.ptr_value); + if (constant_value->type != duckdb_libpgquery::T_PGAConst || + constant_value->val.type != duckdb_libpgquery::T_PGString) { + throw ParserException("Enum type requires a set of strings as type modifiers"); + } + 
string_data[pos++] = StringVector::AddString(enum_vector, constant_value->val.val.str); + } + return LogicalType::ENUM(enum_vector, type_name.typmods->length); } else if (base_type == LogicalTypeId::STRUCT) { if (!type_name.typmods || type_name.typmods->length == 0) { throw ParserException("Struct needs a name and entries"); diff --git a/src/parser/transform/statement/transform_create_type.cpp b/src/parser/transform/statement/transform_create_type.cpp index 95f26fa6205a..3235ed3b3917 100644 --- a/src/parser/transform/statement/transform_create_type.cpp +++ b/src/parser/transform/statement/transform_create_type.cpp @@ -56,7 +56,7 @@ unique_ptr Transformer::TransformCreateType(duckdb_libpgquery:: D_ASSERT(stmt.query == nullptr); idx_t size = 0; auto ordered_array = PGListToVector(stmt.vals, size); - info->type = LogicalType::ENUM(info->name, ordered_array, size); + info->type = LogicalType::ENUM(ordered_array, size); } } break; diff --git a/src/parser/transformer.cpp b/src/parser/transformer.cpp index dafb83e9e30e..c6af4d92c53a 100644 --- a/src/parser/transformer.cpp +++ b/src/parser/transformer.cpp @@ -9,20 +9,6 @@ namespace duckdb { -StackChecker::StackChecker(Transformer &transformer_p, idx_t stack_usage_p) - : transformer(transformer_p), stack_usage(stack_usage_p) { - transformer.stack_depth += stack_usage; -} - -StackChecker::~StackChecker() { - transformer.stack_depth -= stack_usage; -} - -StackChecker::StackChecker(StackChecker &&other) noexcept - : transformer(other.transformer), stack_usage(other.stack_usage) { - other.stack_usage = 0; -} - Transformer::Transformer(ParserOptions &options) : parent(nullptr), options(options), stack_depth(DConstants::INVALID_INDEX) { } @@ -59,7 +45,7 @@ void Transformer::InitializeStackCheck() { stack_depth = 0; } -StackChecker Transformer::StackCheck(idx_t extra_stack) { +StackChecker Transformer::StackCheck(idx_t extra_stack) { auto &root = RootTransformer(); D_ASSERT(root.stack_depth != DConstants::INVALID_INDEX); if (root.stack_depth + extra_stack >= options.max_expression_depth) { @@ -67,7 +53,7 @@ StackChecker Transformer::StackCheck(idx_t extra_stack) { "increase the maximum expression depth.", options.max_expression_depth); } - return StackChecker(root, extra_stack); + return StackChecker(root, extra_stack); } unique_ptr Transformer::TransformStatement(duckdb_libpgquery::PGNode &stmt) { diff --git a/src/planner/binder/expression/bind_macro_expression.cpp b/src/planner/binder/expression/bind_macro_expression.cpp index 02962c0fdeeb..cce36d49d3ac 100644 --- a/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/planner/binder/expression/bind_macro_expression.cpp @@ -1,12 +1,12 @@ #include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/common/reference_map.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/function/scalar_macro_function.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/subquery_expression.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/planner/expression_binder.hpp" -#include "duckdb/function/scalar_macro_function.hpp" - namespace duckdb { void ExpressionBinder::ReplaceMacroParametersRecursive(unique_ptr &expr) { @@ -79,8 +79,10 @@ BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacro new_macro_binding->arguments = &positionals; macro_binding = new_macro_binding.get(); - // replace current expression with stored macro expression, and replace params + // replace 
current expression with stored macro expression expr = macro_def.expression->Copy(); + + // now replace the parameters ReplaceMacroParametersRecursive(expr); // bind the unfolded macro diff --git a/src/planner/binder/query_node/plan_subquery.cpp b/src/planner/binder/query_node/plan_subquery.cpp index 7674f5c026ed..8d019e1422bb 100644 --- a/src/planner/binder/query_node/plan_subquery.cpp +++ b/src/planner/binder/query_node/plan_subquery.cpp @@ -138,10 +138,10 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq } } -static unique_ptr +static unique_ptr CreateDuplicateEliminatedJoin(const vector &correlated_columns, JoinType join_type, unique_ptr original_plan, bool perform_delim) { - auto delim_join = make_uniq(join_type); + auto delim_join = make_uniq(join_type, LogicalOperatorType::LOGICAL_DELIM_JOIN); if (!perform_delim) { // if we are not performing a delim join, we push a row_number() OVER() window operator on the LHS // and perform all duplicate elimination on that row number instead @@ -165,7 +165,7 @@ CreateDuplicateEliminatedJoin(const vector &correlated_col return delim_join; } -static void CreateDelimJoinConditions(LogicalDelimJoin &delim_join, +static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, const vector &correlated_columns, vector bindings, idx_t base_offset, bool perform_delim) { auto col_count = perform_delim ? correlated_columns.size() : 1; diff --git a/src/planner/binder/statement/bind_create.cpp b/src/planner/binder/statement/bind_create.cpp index 4def15eff7f5..0916da300df6 100644 --- a/src/planner/binder/statement/bind_create.cpp +++ b/src/planner/binder/statement/bind_create.cpp @@ -24,7 +24,6 @@ #include "duckdb/planner/operator/logical_create_index.hpp" #include "duckdb/planner/operator/logical_create_table.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/operator/logical_distinct.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" @@ -39,6 +38,7 @@ #include "duckdb/main/database_manager.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/function/table/table_scan.hpp" namespace duckdb { @@ -134,33 +134,6 @@ void Binder::BindCreateViewInfo(CreateViewInfo &base) { base.types = query_node.types; } -static void QualifyFunctionNames(ClientContext &context, unique_ptr &expr) { - switch (expr->GetExpressionClass()) { - case ExpressionClass::FUNCTION: { - auto &func = expr->Cast(); - auto function = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, func.catalog, func.schema, - func.function_name, OnEntryNotFound::RETURN_NULL); - if (function) { - func.catalog = function->ParentCatalog().GetName(); - func.schema = function->ParentSchema().name; - } - break; - } - case ExpressionClass::SUBQUERY: { - // replacing parameters within a subquery is slightly different - auto &sq = (expr->Cast()).subquery; - ParsedExpressionIterator::EnumerateQueryNodeChildren( - *sq->node, [&](unique_ptr &child) { QualifyFunctionNames(context, child); }); - break; - } - default: // fall through - break; - } - // unfold child expressions - ParsedExpressionIterator::EnumerateChildren( - *expr, [&](unique_ptr &child) { QualifyFunctionNames(context, child); }); -} - SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { auto &base = info.Cast(); auto &scalar_function = base.function->Cast(); @@ -190,7 +163,6 
@@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { auto this_macro_binding = make_uniq(dummy_types, dummy_names, base.name); macro_binding = this_macro_binding.get(); ExpressionBinder::QualifyColumnNames(*this, scalar_function.expression); - QualifyFunctionNames(context, scalar_function.expression); // create a copy of the expression because we do not want to alter the original auto expression = scalar_function.expression->Copy(); @@ -260,23 +232,7 @@ void Binder::BindLogicalType(ClientContext &context, LogicalType &type, optional } else { type = Catalog::GetType(context, INVALID_CATALOG, schema, user_type_name); } - } else if (type.id() == LogicalTypeId::ENUM) { - auto enum_type_name = EnumType::GetTypeName(type); - optional_ptr enum_type_catalog; - if (catalog) { - enum_type_catalog = - catalog->GetEntry(context, schema, enum_type_name, OnEntryNotFound::RETURN_NULL); - if (!enum_type_catalog) { - // look in the system catalog if the type was not found - enum_type_catalog = Catalog::GetEntry(context, SYSTEM_CATALOG, schema, enum_type_name, - OnEntryNotFound::RETURN_NULL); - } - } else { - enum_type_catalog = Catalog::GetEntry(context, INVALID_CATALOG, schema, enum_type_name, - OnEntryNotFound::RETURN_NULL); - } - - EnumType::SetCatalog(type, enum_type_catalog.get()); + BindLogicalType(context, type, catalog, schema); } } @@ -483,10 +439,14 @@ unique_ptr DuckCatalog::BindCreateIndex(Binder &binder, CreateS create_index_info->scan_types.emplace_back(LogicalType::ROW_TYPE); create_index_info->names = get.names; create_index_info->column_ids = get.column_ids; + auto &bind_data = get.bind_data->Cast(); + bind_data.is_create_index = true; + get.column_ids.push_back(COLUMN_IDENTIFIER_ROW_ID); // the logical CREATE INDEX also needs all fields to scan the referenced table - return make_uniq(std::move(get.bind_data), std::move(create_index_info), std::move(expressions), - table, std::move(get.function)); + auto result = make_uniq(std::move(create_index_info), std::move(expressions), table); + result->children.push_back(std::move(plan)); + return std::move(result); } BoundStatement Binder::Bind(CreateStatement &stmt) { @@ -548,7 +508,6 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { if (plan->type != LogicalOperatorType::LOGICAL_GET) { throw BinderException("Cannot create index on a view!"); } - result.plan = table.catalog.BindCreateIndex(*this, stmt, table, std::move(plan)); break; } @@ -653,8 +612,6 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { // We set b to be an alias for the underlying type of a auto inner_type = Catalog::GetType(context, schema.catalog.GetName(), schema.name, UserType::GetTypeName(create_type_info.type)); - // clear to nullptr, we don't need this - EnumType::SetCatalog(inner_type, nullptr); inner_type.SetAlias(create_type_info.name); create_type_info.type = inner_type; } diff --git a/src/planner/binder/statement/bind_create_table.cpp b/src/planner/binder/statement/bind_create_table.cpp index f85fd915b1b2..eff11c73e3bd 100644 --- a/src/planner/binder/statement/bind_create_table.cpp +++ b/src/planner/binder/statement/bind_create_table.cpp @@ -292,30 +292,6 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptrschema.catalog); - // We add a catalog dependency - auto type_dependency = EnumType::GetCatalog(column.Type()); - - if (type_dependency) { - // Only if the USER comes from a create type - if (schema.catalog.IsTemporaryCatalog() && column.Type().id() == LogicalTypeId::ENUM) { - // for enum types that are used in tables in the temp 
catalog, we need to - // make a copy of the enum type definition that is accessible there - auto enum_type = type_dependency->user_type; - auto &enum_entries = EnumType::GetValuesInsertOrder(enum_type); - auto enum_size = EnumType::GetSize(enum_type); - Vector copy_enum_entries_vec(LogicalType::VARCHAR, enum_size); - auto copy_enum_entries_ptr = FlatVector::GetData(copy_enum_entries_vec); - auto enum_entries_ptr = FlatVector::GetData(enum_entries); - for (idx_t enum_idx = 0; enum_idx < enum_size; enum_idx++) { - copy_enum_entries_ptr[enum_idx] = - StringVector::AddStringOrBlob(copy_enum_entries_vec, enum_entries_ptr[enum_idx]); - } - auto copy_type = LogicalType::ENUM(EnumType::GetTypeName(enum_type), copy_enum_entries_vec, enum_size); - column.SetType(copy_type); - } else { - result->dependencies.AddDependency(*type_dependency); - } - } } result->dependencies.VerifyDependencies(schema.catalog, result->Base().table); properties.allow_stream_result = false; diff --git a/src/planner/binder/tableref/plan_joinref.cpp b/src/planner/binder/tableref/plan_joinref.cpp index 6ba97090f790..53804ef9d434 100644 --- a/src/planner/binder/tableref/plan_joinref.cpp +++ b/src/planner/binder/tableref/plan_joinref.cpp @@ -7,7 +7,6 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_any_join.hpp" -#include "duckdb/planner/operator/logical_asof_join.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/planner/operator/logical_dependent_join.hpp" @@ -193,12 +192,11 @@ unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &con } else { // we successfully converted expressions into JoinConditions // create a LogicalComparisonJoin - unique_ptr comp_join; + auto logical_type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; if (reftype == JoinRefType::ASOF) { - comp_join = make_uniq(type); - } else { - comp_join = make_uniq(type); + logical_type = LogicalOperatorType::LOGICAL_ASOF_JOIN; } + auto comp_join = make_uniq(type, logical_type); comp_join->conditions = std::move(conditions); comp_join->children.push_back(std::move(left_child)); comp_join->children.push_back(std::move(right_child)); diff --git a/src/planner/expression/bound_aggregate_expression.cpp b/src/planner/expression/bound_aggregate_expression.cpp index de63c708266d..c0885c1c4d9d 100644 --- a/src/planner/expression/bound_aggregate_expression.cpp +++ b/src/planner/expression/bound_aggregate_expression.cpp @@ -104,4 +104,27 @@ unique_ptr BoundAggregateExpression::Deserialize(ExpressionDeseriali return std::move(x); } +void BoundAggregateExpression::FormatSerialize(FormatSerializer &serializer) const { + Expression::FormatSerialize(serializer); + serializer.WriteProperty("return_type", return_type); + serializer.WriteProperty("children", children); + FunctionSerializer::FormatSerialize(serializer, function, bind_info.get()); + serializer.WriteProperty("aggregate_type", aggr_type); + serializer.WriteOptionalProperty("filter", filter); + serializer.WriteOptionalProperty("order_bys", order_bys); +} + +unique_ptr BoundAggregateExpression::FormatDeserialize(FormatDeserializer &deserializer) { + auto return_type = deserializer.ReadProperty("return_type"); + auto children = deserializer.ReadProperty>>("children"); + auto entry = FunctionSerializer::FormatDeserialize( + deserializer, CatalogType::AGGREGATE_FUNCTION_ENTRY, children); + auto aggregate_type = 
deserializer.ReadProperty("aggregate_type"); + auto filter = deserializer.ReadOptionalProperty>("filter"); + auto result = make_uniq(std::move(entry.first), std::move(children), std::move(filter), + std::move(entry.second), aggregate_type); + deserializer.ReadOptionalProperty("order_bys", result->order_bys); + return std::move(result); +} + } // namespace duckdb diff --git a/src/planner/expression/bound_function_expression.cpp b/src/planner/expression/bound_function_expression.cpp index fe4f985f781c..d2b5e8f676f5 100644 --- a/src/planner/expression/bound_function_expression.cpp +++ b/src/planner/expression/bound_function_expression.cpp @@ -3,6 +3,8 @@ #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/common/types/hash.hpp" #include "duckdb/function/function_serialization.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { @@ -92,4 +94,24 @@ unique_ptr BoundFunctionExpression::Deserialize(ExpressionDeserializ return make_uniq(std::move(return_type), std::move(function), std::move(children), std::move(bind_info), is_operator); } + +void BoundFunctionExpression::FormatSerialize(FormatSerializer &serializer) const { + Expression::FormatSerialize(serializer); + serializer.WriteProperty("return_type", return_type); + serializer.WriteProperty("children", children); + FunctionSerializer::FormatSerialize(serializer, function, bind_info.get()); + serializer.WriteProperty("is_operator", is_operator); +} + +unique_ptr BoundFunctionExpression::FormatDeserialize(FormatDeserializer &deserializer) { + auto return_type = deserializer.ReadProperty("return_type"); + auto children = deserializer.ReadProperty>>("children"); + auto entry = FunctionSerializer::FormatDeserialize( + deserializer, CatalogType::SCALAR_FUNCTION_ENTRY, children); + auto result = make_uniq(std::move(return_type), std::move(entry.first), + std::move(children), std::move(entry.second)); + deserializer.ReadProperty("is_operator", result->is_operator); + return std::move(result); +} + } // namespace duckdb diff --git a/src/planner/expression/bound_window_expression.cpp b/src/planner/expression/bound_window_expression.cpp index cb72340cb8bd..6e4afb434f44 100644 --- a/src/planner/expression/bound_window_expression.cpp +++ b/src/planner/expression/bound_window_expression.cpp @@ -162,4 +162,51 @@ unique_ptr BoundWindowExpression::Deserialize(ExpressionDeserializat return std::move(result); } +void BoundWindowExpression::FormatSerialize(FormatSerializer &serializer) const { + Expression::FormatSerialize(serializer); + serializer.WriteProperty("return_type", return_type); + serializer.WriteProperty("children", children); + if (type == ExpressionType::WINDOW_AGGREGATE) { + D_ASSERT(aggregate); + FunctionSerializer::FormatSerialize(serializer, *aggregate, bind_info.get()); + } + serializer.WriteProperty("partitions", partitions); + serializer.WriteProperty("orders", orders); + serializer.WriteOptionalProperty("filters", filter_expr); + serializer.WriteProperty("ignore_nulls", ignore_nulls); + serializer.WriteProperty("start", start); + serializer.WriteProperty("end", end); + serializer.WriteOptionalProperty("start_expr", start_expr); + serializer.WriteOptionalProperty("end_expr", end_expr); + serializer.WriteOptionalProperty("offset_expr", offset_expr); + serializer.WriteOptionalProperty("default_expr", default_expr); +} + +unique_ptr BoundWindowExpression::FormatDeserialize(FormatDeserializer &deserializer) { + auto 
expression_type = deserializer.Get(); + auto return_type = deserializer.ReadProperty("return_type"); + auto children = deserializer.ReadProperty>>("children"); + unique_ptr aggregate; + unique_ptr bind_info; + if (expression_type == ExpressionType::WINDOW_AGGREGATE) { + auto entry = FunctionSerializer::FormatDeserialize( + deserializer, CatalogType::AGGREGATE_FUNCTION_ENTRY, children); + aggregate = make_uniq(std::move(entry.first)); + bind_info = std::move(entry.second); + } + auto result = + make_uniq(expression_type, return_type, std::move(aggregate), std::move(bind_info)); + deserializer.ReadProperty("partitions", result->partitions); + deserializer.ReadProperty("orders", result->orders); + deserializer.ReadOptionalProperty("filters", result->filter_expr); + deserializer.ReadProperty("ignore_nulls", result->ignore_nulls); + deserializer.ReadProperty("start", result->start); + deserializer.ReadProperty("end", result->end); + deserializer.ReadOptionalProperty("start_expr", result->start_expr); + deserializer.ReadOptionalProperty("end_expr", result->end_expr); + deserializer.ReadOptionalProperty("offset_expr", result->offset_expr); + deserializer.ReadOptionalProperty("default_expr", result->default_expr); + return std::move(result); +} + } // namespace duckdb diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index 1f1dcbb5c726..266b4102c5c0 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -10,6 +10,7 @@ namespace duckdb { ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) : binder(binder), context(context) { + InitializeStackCheck(); if (replace_binder) { stored_binder = &binder.GetActiveBinder(); binder.SetActiveBinder(*this); @@ -28,7 +29,26 @@ ExpressionBinder::~ExpressionBinder() { } } +void ExpressionBinder::InitializeStackCheck() { + if (binder.HasActiveBinder()) { + stack_depth = binder.GetActiveBinder().stack_depth; + } else { + stack_depth = 0; + } +} + +StackChecker ExpressionBinder::StackCheck(const ParsedExpression &expr, idx_t extra_stack) { + D_ASSERT(stack_depth != DConstants::INVALID_INDEX); + if (stack_depth + extra_stack >= MAXIMUM_STACK_DEPTH) { + throw BinderException("Maximum recursion depth exceeded (Maximum: %llu) while binding \"%s\"", + MAXIMUM_STACK_DEPTH, expr.ToString()); + } + return StackChecker(*this, extra_stack); +} + BindResult ExpressionBinder::BindExpression(unique_ptr &expr, idx_t depth, bool root_expression) { + auto stack_checker = StackCheck(*expr); + auto &expr_ref = *expr; switch (expr_ref.expression_class) { case ExpressionClass::BETWEEN: diff --git a/src/planner/logical_operator.cpp b/src/planner/logical_operator.cpp index 40d01886d837..2e96f2a30864 100644 --- a/src/planner/logical_operator.cpp +++ b/src/planner/logical_operator.cpp @@ -131,7 +131,6 @@ void LogicalOperator::Verify(ClientContext &context) { } BufferedSerializer serializer; // We are serializing a query plan - serializer.is_query_plan = true; try { expressions[expr_idx]->Serialize(serializer); } catch (NotImplementedException &ex) { @@ -266,12 +265,8 @@ unique_ptr LogicalOperator::Deserialize(Deserializer &deseriali break; case LogicalOperatorType::LOGICAL_JOIN: throw InternalException("LogicalJoin deserialize not supported"); - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - result = LogicalDelimJoin::Deserialize(state, reader); - break; case LogicalOperatorType::LOGICAL_ASOF_JOIN: - result = LogicalAsOfJoin::Deserialize(state, reader); - break; + case 
LogicalOperatorType::LOGICAL_DELIM_JOIN: case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: result = LogicalComparisonJoin::Deserialize(state, reader); break; diff --git a/src/planner/logical_operator_visitor.cpp b/src/planner/logical_operator_visitor.cpp index 32283e7d6b46..e692473f1fac 100644 --- a/src/planner/logical_operator_visitor.cpp +++ b/src/planner/logical_operator_visitor.cpp @@ -70,13 +70,10 @@ void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op, case LogicalOperatorType::LOGICAL_DELIM_JOIN: case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - auto &delim_join = op.Cast(); - for (auto &expr : delim_join.duplicate_eliminated_columns) { - callback(&expr); - } - } auto &join = op.Cast(); + for (auto &expr : join.duplicate_eliminated_columns) { + callback(&expr); + } for (auto &cond : join.conditions) { callback(&cond.left); callback(&cond.right); diff --git a/src/planner/operator/CMakeLists.txt b/src/planner/operator/CMakeLists.txt index 1528eeed0540..976b6908aed4 100644 --- a/src/planner/operator/CMakeLists.txt +++ b/src/planner/operator/CMakeLists.txt @@ -3,7 +3,6 @@ add_library_unity( OBJECT logical_aggregate.cpp logical_any_join.cpp - logical_asof_join.cpp logical_column_data_get.cpp logical_comparison_join.cpp logical_copy_to_file.cpp @@ -14,7 +13,6 @@ add_library_unity( logical_cteref.cpp logical_delete.cpp logical_delim_get.cpp - logical_delim_join.cpp logical_dependent_join.cpp logical_distinct.cpp logical_dummy_scan.cpp diff --git a/src/planner/operator/logical_asof_join.cpp b/src/planner/operator/logical_asof_join.cpp deleted file mode 100644 index 95cc415b3b51..000000000000 --- a/src/planner/operator/logical_asof_join.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "duckdb/planner/operator/logical_asof_join.hpp" - -namespace duckdb { - -LogicalAsOfJoin::LogicalAsOfJoin(JoinType type) : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_ASOF_JOIN) { -} - -unique_ptr LogicalAsOfJoin::Deserialize(LogicalDeserializationState &state, FieldReader &reader) { - auto result = make_uniq(JoinType::INVALID); - LogicalComparisonJoin::Deserialize(*result, state, reader); - return std::move(result); -} - -} // namespace duckdb diff --git a/src/planner/operator/logical_comparison_join.cpp b/src/planner/operator/logical_comparison_join.cpp index befdb9736ed9..507c0ee240ac 100644 --- a/src/planner/operator/logical_comparison_join.cpp +++ b/src/planner/operator/logical_comparison_join.cpp @@ -1,7 +1,6 @@ #include "duckdb/common/field_writer.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/common/enum_util.hpp" namespace duckdb { @@ -26,6 +25,7 @@ void LogicalComparisonJoin::Serialize(FieldWriter &writer) const { LogicalJoin::Serialize(writer); writer.WriteRegularSerializableList(mark_types); writer.WriteRegularSerializableList(conditions); + writer.WriteSerializableList(duplicate_eliminated_columns); } void LogicalComparisonJoin::Deserialize(LogicalComparisonJoin &comparison_join, LogicalDeserializationState &state, @@ -33,6 +33,7 @@ void LogicalComparisonJoin::Deserialize(LogicalComparisonJoin &comparison_join, LogicalJoin::Deserialize(comparison_join, state, reader); comparison_join.mark_types = reader.ReadRequiredSerializableList(); 
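// NB: with LogicalDelimJoin folded into LogicalComparisonJoin, the
// delim-specific state now travels through the same field-writer stream as the
// rest of the join. A minimal sketch of the symmetric write/read pair,
// assuming the templated FieldWriter/FieldReader helpers:
//   writer.WriteSerializableList<Expression>(duplicate_eliminated_columns);
//   comparison_join.duplicate_eliminated_columns =
//       reader.ReadRequiredSerializableList<Expression>(state.gstate);
// For an ordinary comparison join the list is simply empty, so non-delim plans
// round-trip unchanged.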
comparison_join.conditions = reader.ReadRequiredSerializableList(state.gstate); + comparison_join.duplicate_eliminated_columns = reader.ReadRequiredSerializableList(state.gstate); } unique_ptr LogicalComparisonJoin::Deserialize(LogicalDeserializationState &state, @@ -42,17 +43,4 @@ unique_ptr LogicalComparisonJoin::Deserialize(LogicalDeserializ return std::move(result); } -unique_ptr LogicalComparisonJoin::FromDelimJoin(LogicalDelimJoin &delim_join) { - auto new_join = make_uniq(delim_join.join_type); - new_join->children = std::move(delim_join.children); - new_join->conditions = std::move(delim_join.conditions); - new_join->types = std::move(delim_join.types); - new_join->mark_types = std::move(delim_join.mark_types); - new_join->mark_index = delim_join.mark_index; - new_join->left_projection_map = std::move(delim_join.left_projection_map); - new_join->right_projection_map = std::move(delim_join.right_projection_map); - new_join->join_stats = std::move(delim_join.join_stats); - return std::move(new_join); -} - } // namespace duckdb diff --git a/src/planner/operator/logical_copy_to_file.cpp b/src/planner/operator/logical_copy_to_file.cpp index ea32a05b571d..cc6744a2f8bb 100644 --- a/src/planner/operator/logical_copy_to_file.cpp +++ b/src/planner/operator/logical_copy_to_file.cpp @@ -59,6 +59,14 @@ unique_ptr LogicalCopyToFile::Deserialize(LogicalDeserializatio return std::move(result); } +void LogicalCopyToFile::FormatSerialize(FormatSerializer &serializer) const { + throw SerializationException("LogicalCopyToFile not implemented yet"); +} + +unique_ptr LogicalCopyToFile::FormatDeserialize(FormatDeserializer &deserializer) { + throw SerializationException("LogicalCopyToFile not implemented yet"); +} + idx_t LogicalCopyToFile::EstimateCardinality(ClientContext &context) { return 1; } diff --git a/src/planner/operator/logical_create_index.cpp b/src/planner/operator/logical_create_index.cpp index f1c0210ce2c9..9be5ad799e1d 100644 --- a/src/planner/operator/logical_create_index.cpp +++ b/src/planner/operator/logical_create_index.cpp @@ -6,11 +6,9 @@ namespace duckdb { -LogicalCreateIndex::LogicalCreateIndex(unique_ptr bind_data_p, unique_ptr info_p, - vector> expressions_p, TableCatalogEntry &table_p, - TableFunction function_p) - : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), bind_data(std::move(bind_data_p)), - info(std::move(info_p)), table(table_p), function(std::move(function_p)) { +LogicalCreateIndex::LogicalCreateIndex(unique_ptr info_p, vector> expressions_p, + TableCatalogEntry &table_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), info(std::move(info_p)), table(table_p) { for (auto &expr : expressions_p) { this->unbound_expressions.push_back(expr->Copy()); @@ -22,12 +20,21 @@ LogicalCreateIndex::LogicalCreateIndex(unique_ptr bind_data_p, uni } } +LogicalCreateIndex::LogicalCreateIndex(ClientContext &context, unique_ptr info_p, + vector> expressions_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), + info(unique_ptr_cast(std::move(info_p))), table(BindTable(context, *info)) { + for (auto &expr : expressions_p) { + this->unbound_expressions.push_back(expr->Copy()); + } + this->expressions = std::move(expressions_p); +} + void LogicalCreateIndex::Serialize(FieldWriter &writer) const { writer.WriteOptional(info); writer.WriteString(table.catalog.GetName()); writer.WriteString(table.schema.name); writer.WriteString(table.name); - FunctionSerializer::SerializeBase(writer, function, bind_data.get()); 
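// NB: the table function and its bind data are deliberately no longer written
// out here; on deserialization the operator re-resolves its target table via
// the catalog instead. A minimal sketch of that lookup, assuming the templated
// Catalog::GetEntry overload (mirroring the BindTable helper added below):
//   auto &table = Catalog::GetEntry<TableCatalogEntry>(context, info.catalog,
//                                                      info.schema, info.table);
// A serialized plan therefore stays loadable as long as those names resolve.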
writer.WriteSerializableList(unbound_expressions); } @@ -38,10 +45,6 @@ unique_ptr LogicalCreateIndex::Deserialize(LogicalDeserializati auto catalog = reader.ReadRequired(); auto schema = reader.ReadRequired(); auto table_name = reader.ReadRequired(); - unique_ptr bind_data; - bool has_deserialize; - auto function = FunctionSerializer::DeserializeBaseInternal( - reader, state.gstate, CatalogType::TABLE_FUNCTION_ENTRY, bind_data, has_deserialize); auto unbound_expressions = reader.ReadRequiredSerializableList(state.gstate); if (info->type != CatalogType::INDEX_ENTRY) { throw InternalException("Unexpected type: '%s', expected '%s'", CatalogTypeToString(info->type), @@ -49,12 +52,18 @@ unique_ptr LogicalCreateIndex::Deserialize(LogicalDeserializati } auto index_info = unique_ptr_cast(std::move(info)); auto &table = Catalog::GetEntry(context, catalog, schema, table_name); - return make_uniq(std::move(bind_data), std::move(index_info), std::move(unbound_expressions), - table, std::move(function)); + return make_uniq(std::move(index_info), std::move(unbound_expressions), table); } void LogicalCreateIndex::ResolveTypes() { types.emplace_back(LogicalType::BIGINT); } +TableCatalogEntry &LogicalCreateIndex::BindTable(ClientContext &context, CreateIndexInfo &info) { + auto &catalog = info.catalog; + auto &schema = info.schema; + auto &table_name = info.table; + return Catalog::GetEntry(context, catalog, schema, table_name); +} + } // namespace duckdb diff --git a/src/planner/operator/logical_delim_join.cpp b/src/planner/operator/logical_delim_join.cpp deleted file mode 100644 index 678884437451..000000000000 --- a/src/planner/operator/logical_delim_join.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "duckdb/common/field_writer.hpp" -#include "duckdb/planner/operator/logical_delim_join.hpp" - -namespace duckdb { - -LogicalDelimJoin::LogicalDelimJoin(JoinType type) - : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_DELIM_JOIN) { -} - -void LogicalDelimJoin::Serialize(FieldWriter &writer) const { - LogicalComparisonJoin::Serialize(writer); - if (type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { - throw InternalException("LogicalDelimJoin needs to have type LOGICAL_DELIM_JOIN"); - } - writer.WriteSerializableList(duplicate_eliminated_columns); -} - -unique_ptr LogicalDelimJoin::Deserialize(LogicalDeserializationState &state, FieldReader &reader) { - auto result = make_uniq(JoinType::INVALID); - LogicalComparisonJoin::Deserialize(*result, state, reader); - result->duplicate_eliminated_columns = reader.ReadRequiredSerializableList(state.gstate); - return std::move(result); -} - -} // namespace duckdb diff --git a/src/planner/operator/logical_extension_operator.cpp b/src/planner/operator/logical_extension_operator.cpp index e513c841c496..630b15f21629 100644 --- a/src/planner/operator/logical_extension_operator.cpp +++ b/src/planner/operator/logical_extension_operator.cpp @@ -1,5 +1,7 @@ #include "duckdb/planner/operator/logical_extension_operator.hpp" #include "duckdb/main/config.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { unique_ptr LogicalExtensionOperator::Deserialize(LogicalDeserializationState &state, @@ -15,4 +17,26 @@ unique_ptr LogicalExtensionOperator::Deserialize(Logic throw SerializationException("No serialization method exists for extension: " + extension_name); } + +void LogicalExtensionOperator::FormatSerialize(FormatSerializer &serializer) const { + 
LogicalOperator::FormatSerialize(serializer); + serializer.WriteProperty("extension_name", GetExtensionName()); +} + +unique_ptr LogicalExtensionOperator::FormatDeserialize(FormatDeserializer &deserializer) { + auto &config = DBConfig::GetConfig(deserializer.Get()); + auto extension_name = deserializer.ReadProperty("extension_name"); + for (auto &extension : config.operator_extensions) { + if (extension->GetName() == extension_name) { + return extension->FormatDeserialize(deserializer); + } + } + throw SerializationException("No deserialization method exists for extension: " + extension_name); +} + +string LogicalExtensionOperator::GetExtensionName() const { + throw SerializationException("LogicalExtensionOperator::GetExtensionName not implemented which is required for " + "serializing extension operators"); +} + } // namespace duckdb diff --git a/src/planner/operator/logical_get.cpp b/src/planner/operator/logical_get.cpp index 4968f3d08858..c5ed6cec0e0d 100644 --- a/src/planner/operator/logical_get.cpp +++ b/src/planner/operator/logical_get.cpp @@ -8,9 +8,14 @@ #include "duckdb/function/table/table_scan.hpp" #include "duckdb/main/config.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" namespace duckdb { +LogicalGet::LogicalGet() : LogicalOperator(LogicalOperatorType::LOGICAL_GET) { +} + LogicalGet::LogicalGet(idx_t table_index, TableFunction function, unique_ptr bind_data, vector returned_types, vector returned_names) : LogicalOperator(LogicalOperatorType::LOGICAL_GET), table_index(table_index), function(std::move(function)), @@ -199,6 +204,70 @@ unique_ptr LogicalGet::Deserialize(LogicalDeserializationState return std::move(result); } +void LogicalGet::FormatSerialize(FormatSerializer &serializer) const { + LogicalOperator::FormatSerialize(serializer); + serializer.WriteProperty("table_index", table_index); + serializer.WriteProperty("returned_types", returned_types); + serializer.WriteProperty("names", names); + serializer.WriteProperty("column_ids", column_ids); + serializer.WriteProperty("projection_ids", projection_ids); + serializer.WriteProperty("table_filters", table_filters); + FunctionSerializer::FormatSerialize(serializer, function, bind_data.get()); + if (!function.format_serialize) { + D_ASSERT(!function.format_deserialize); + // no serialize method: serialize input values and named_parameters for rebinding purposes + serializer.WriteProperty("parameters", parameters); + serializer.WriteProperty("named_parameters", named_parameters); + serializer.WriteProperty("input_table_types", input_table_types); + serializer.WriteProperty("input_table_names", input_table_names); + } + serializer.WriteProperty("projected_input", projected_input); +} + +unique_ptr LogicalGet::FormatDeserialize(FormatDeserializer &deserializer) { + auto result = unique_ptr(new LogicalGet()); + deserializer.ReadProperty("table_index", result->table_index); + deserializer.ReadProperty("returned_types", result->returned_types); + deserializer.ReadProperty("names", result->names); + deserializer.ReadProperty("column_ids", result->column_ids); + deserializer.ReadProperty("projection_ids", result->projection_ids); + deserializer.ReadProperty("table_filters", result->table_filters); + auto entry = FunctionSerializer::FormatDeserializeBase( + deserializer, CatalogType::TABLE_FUNCTION_ENTRY); + auto &function = entry.first; + auto has_serialize = entry.second; + + unique_ptr bind_data; + if 
(!has_serialize) { + deserializer.ReadProperty("parameters", result->parameters); + deserializer.ReadProperty("named_parameters", result->named_parameters); + deserializer.ReadProperty("input_table_types", result->input_table_types); + deserializer.ReadProperty("input_table_names", result->input_table_names); + TableFunctionBindInput input(result->parameters, result->named_parameters, result->input_table_types, + result->input_table_names, function.function_info.get()); + + vector bind_return_types; + vector bind_names; + if (!function.bind) { + throw InternalException("Table function \"%s\" has neither bind nor (de)serialize", function.name); + } + bind_data = function.bind(deserializer.Get(), input, bind_return_types, bind_names); + if (result->returned_types != bind_return_types) { + throw SerializationException( + "Table function deserialization failure - bind returned different return types than were serialized"); + } + // names can actually be different because of aliases - only the sizes cannot be different + if (result->names.size() != bind_names.size()) { + throw SerializationException( + "Table function deserialization failure - bind returned different returned names than were serialized"); + } + } else { + bind_data = FunctionSerializer::FunctionDeserialize(deserializer, function); + } + deserializer.ReadProperty("projected_input", result->projected_input); + return std::move(result); +} + vector LogicalGet::GetTableIndex() const { return vector {table_index}; } diff --git a/src/planner/planner.cpp b/src/planner/planner.cpp index 1ced9d38c9e4..26464ba2f615 100644 --- a/src/planner/planner.cpp +++ b/src/planner/planner.cpp @@ -176,7 +176,6 @@ void Planner::VerifyPlan(ClientContext &context, unique_ptr &op } BufferedSerializer serializer; - serializer.is_query_plan = true; try { op->Serialize(serializer); } catch (NotImplementedException &ex) { diff --git a/src/storage/checkpoint_manager.cpp b/src/storage/checkpoint_manager.cpp index 1691a08fd0bd..c633fb00cf0c 100644 --- a/src/storage/checkpoint_manager.cpp +++ b/src/storage/checkpoint_manager.cpp @@ -131,7 +131,6 @@ void SingleFileCheckpointReader::LoadFromStorage() { con.BeginTransaction(); // create the MetaBlockReader to read from the storage MetaBlockReader reader(block_manager, meta_block); - reader.SetCatalog(catalog.GetAttached().GetCatalog()); reader.SetContext(*con.context); LoadCheckpoint(*con.context, reader); con.Commit(); @@ -414,10 +413,7 @@ void CheckpointWriter::WriteType(TypeCatalogEntry &type) { void CheckpointReader::ReadType(ClientContext &context, MetaBlockReader &reader) { auto info = TypeCatalogEntry::Deserialize(reader); - auto &catalog_entry = catalog.CreateType(context, *info)->Cast(); - if (info->type.id() == LogicalTypeId::ENUM) { - EnumType::SetCatalog(info->type, &catalog_entry); - } + catalog.CreateType(context, info->Cast()); } //===--------------------------------------------------------------------===// diff --git a/src/storage/meta_block_reader.cpp b/src/storage/meta_block_reader.cpp index d31f5f1c15e5..c3b00e701d39 100644 --- a/src/storage/meta_block_reader.cpp +++ b/src/storage/meta_block_reader.cpp @@ -43,10 +43,6 @@ ClientContext &MetaBlockReader::GetContext() { return *context; } -optional_ptr MetaBlockReader::GetCatalog() { - return catalog; -} - void MetaBlockReader::ReadNewBlock(block_id_t id) { auto &buffer_manager = block_manager.buffer_manager; @@ -65,11 +61,6 @@ void MetaBlockReader::ReadNewBlock(block_id_t id) { offset = sizeof(block_id_t); } -void 
MetaBlockReader::SetCatalog(Catalog &catalog_p) { - D_ASSERT(!catalog); - catalog = &catalog_p; -} - void MetaBlockReader::SetContext(ClientContext &context_p) { D_ASSERT(!context); context = &context_p; diff --git a/src/storage/serialization/CMakeLists.txt b/src/storage/serialization/CMakeLists.txt index 535293e9f181..543c033353f0 100644 --- a/src/storage/serialization/CMakeLists.txt +++ b/src/storage/serialization/CMakeLists.txt @@ -12,6 +12,7 @@ add_library_unity( serialize_query_node.cpp serialize_result_modifier.cpp serialize_statement.cpp + serialize_table_filter.cpp serialize_tableref.cpp serialize_types.cpp) set(ALL_OBJECT_FILES diff --git a/src/storage/serialization/serialize_create_info.cpp b/src/storage/serialization/serialize_create_info.cpp index 81e00def219c..8faeae6a12fe 100644 --- a/src/storage/serialization/serialize_create_info.cpp +++ b/src/storage/serialization/serialize_create_info.cpp @@ -75,6 +75,7 @@ unique_ptr<CreateInfo> CreateInfo::FormatDeserialize(FormatDeserializer &deseria void CreateIndexInfo::FormatSerialize(FormatSerializer &serializer) const { CreateInfo::FormatSerialize(serializer); serializer.WriteProperty("name", index_name); + serializer.WriteProperty("table", table); serializer.WriteProperty("index_type", index_type); serializer.WriteProperty("constraint_type", constraint_type); serializer.WriteProperty("parsed_expressions", parsed_expressions); @@ -85,6 +86,7 @@ void CreateIndexInfo::FormatSerialize(FormatSerializer &serializer) const { unique_ptr<CreateInfo> CreateIndexInfo::FormatDeserialize(FormatDeserializer &deserializer) { auto result = duckdb::unique_ptr<CreateIndexInfo>(new CreateIndexInfo()); deserializer.ReadProperty("name", result->index_name); + deserializer.ReadProperty("table", result->table); deserializer.ReadProperty("index_type", result->index_type); deserializer.ReadProperty("constraint_type", result->constraint_type); deserializer.ReadProperty("parsed_expressions", result->parsed_expressions); diff --git a/src/storage/serialization/serialize_expression.cpp b/src/storage/serialization/serialize_expression.cpp index 47df2a531ca7..10a1e9dc5d26 100644 --- a/src/storage/serialization/serialize_expression.cpp +++ b/src/storage/serialization/serialize_expression.cpp @@ -22,6 +22,9 @@ unique_ptr<Expression> Expression::FormatDeserialize(FormatDeserializer &deseria deserializer.Set<LogicalType &>(type); unique_ptr<Expression> result; switch (expression_class) { + case ExpressionClass::BOUND_AGGREGATE: + result = BoundAggregateExpression::FormatDeserialize(deserializer); + break; case ExpressionClass::BOUND_BETWEEN: result = BoundBetweenExpression::FormatDeserialize(deserializer); break; @@ -46,6 +49,9 @@ unique_ptr<Expression> Expression::FormatDeserialize(FormatDeseria case ExpressionClass::BOUND_DEFAULT: result = BoundDefaultExpression::FormatDeserialize(deserializer); break; + case ExpressionClass::BOUND_FUNCTION: + result = BoundFunctionExpression::FormatDeserialize(deserializer); + break; case ExpressionClass::BOUND_LAMBDA: result = BoundLambdaExpression::FormatDeserialize(deserializer); break; @@ -64,6 +70,9 @@ unique_ptr<Expression> Expression::FormatDeserialize(FormatDeseria case ExpressionClass::BOUND_UNNEST: result = BoundUnnestExpression::FormatDeserialize(deserializer); break; + case ExpressionClass::BOUND_WINDOW: + result = BoundWindowExpression::FormatDeserialize(deserializer); + break; default: throw SerializationException("Unsupported type for deserialization of Expression!"); } diff --git a/src/storage/serialization/serialize_logical_operator.cpp
b/src/storage/serialization/serialize_logical_operator.cpp index 9e4ccaa1e3d2..21e186d23e44 100644 --- a/src/storage/serialization/serialize_logical_operator.cpp +++ b/src/storage/serialization/serialize_logical_operator.cpp @@ -31,6 +31,9 @@ unique_ptr<LogicalOperator> LogicalOperator::FormatDeserialize(FormatDeserialize case LogicalOperatorType::LOGICAL_ANY_JOIN: result = LogicalAnyJoin::FormatDeserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + result = LogicalComparisonJoin::FormatDeserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_ATTACH: result = LogicalSimple::FormatDeserialize(deserializer); break; @@ -40,6 +43,12 @@ unique_ptr<LogicalOperator> LogicalOperator::FormatDeserialize(FormatDeserialize case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: result = LogicalComparisonJoin::FormatDeserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + result = LogicalCopyToFile::FormatDeserialize(deserializer); + break; + case LogicalOperatorType::LOGICAL_CREATE_INDEX: + result = LogicalCreateIndex::FormatDeserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_CREATE_MACRO: result = LogicalCreate::FormatDeserialize(deserializer); break; @@ -70,6 +79,9 @@ unique_ptr<LogicalOperator> LogicalOperator::FormatDeserialize(FormatDeserialize case LogicalOperatorType::LOGICAL_DELIM_GET: result = LogicalDelimGet::FormatDeserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + result = LogicalComparisonJoin::FormatDeserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_DETACH: result = LogicalSimple::FormatDeserialize(deserializer); break; @@ -94,9 +106,15 @@ unique_ptr<LogicalOperator> LogicalOperator::FormatDeserialize(FormatDeserialize case LogicalOperatorType::LOGICAL_EXPRESSION_GET: result = LogicalExpressionGet::FormatDeserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: + result = LogicalExtensionOperator::FormatDeserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_FILTER: result = LogicalFilter::FormatDeserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_GET: + result = LogicalGet::FormatDeserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_INSERT: result = LogicalInsert::FormatDeserialize(deserializer); break; @@ -255,6 +273,7 @@ void LogicalComparisonJoin::FormatSerialize(FormatSerializer &serializer) const serializer.WriteProperty("right_projection_map", right_projection_map); serializer.WriteProperty("conditions", conditions); serializer.WriteProperty("mark_types", mark_types); + serializer.WriteProperty("duplicate_eliminated_columns", duplicate_eliminated_columns); } unique_ptr<LogicalOperator> LogicalComparisonJoin::FormatDeserialize(FormatDeserializer &deserializer) { @@ -265,6 +284,7 @@ unique_ptr<LogicalOperator> LogicalComparisonJoin::FormatDeserialize(FormatDeser deserializer.ReadProperty("right_projection_map", result->right_projection_map); deserializer.ReadProperty("conditions", result->conditions); deserializer.ReadProperty("mark_types", result->mark_types); + deserializer.ReadProperty("duplicate_eliminated_columns", result->duplicate_eliminated_columns); return std::move(result); } @@ -279,6 +299,19 @@ unique_ptr<LogicalOperator> LogicalCreate::FormatDeserialize(FormatDeserializer return std::move(result); } +void LogicalCreateIndex::FormatSerialize(FormatSerializer &serializer) const { + LogicalOperator::FormatSerialize(serializer); + serializer.WriteProperty("info", *info); + serializer.WriteProperty("unbound_expressions", unbound_expressions); +} + +unique_ptr<LogicalOperator>
LogicalCreateIndex::FormatDeserialize(FormatDeserializer &deserializer) { + auto info = deserializer.ReadProperty<unique_ptr<CreateIndexInfo>>("info"); + auto unbound_expressions = deserializer.ReadProperty<vector<unique_ptr<Expression>>>("unbound_expressions"); + auto result = duckdb::unique_ptr<LogicalCreateIndex>(new LogicalCreateIndex(deserializer.Get<ClientContext &>(), std::move(info), std::move(unbound_expressions))); + return std::move(result); +} + void LogicalCreateTable::FormatSerialize(FormatSerializer &serializer) const { LogicalOperator::FormatSerialize(serializer); serializer.WriteProperty("catalog", schema.ParentCatalog().GetName()); diff --git a/src/storage/serialization/serialize_nodes.cpp b/src/storage/serialization/serialize_nodes.cpp index e8f20cd44a6d..6c531368edb9 100644 --- a/src/storage/serialization/serialize_nodes.cpp +++ b/src/storage/serialization/serialize_nodes.cpp @@ -22,6 +22,12 @@ #include "duckdb/planner/expression/bound_parameter_data.hpp" #include "duckdb/planner/joinside.hpp" #include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/multi_file_reader_options.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/execution/operator/persistent/csv_reader_options.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/function/table/read_csv.hpp" namespace duckdb { @@ -79,6 +85,86 @@ BoundPivotInfo BoundPivotInfo::FormatDeserialize(FormatDeserializer &deserialize return result; } +void BufferedCSVReaderOptions::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("has_delimiter", has_delimiter); + serializer.WriteProperty("delimiter", delimiter); + serializer.WriteProperty("has_quote", has_quote); + serializer.WriteProperty("quote", quote); + serializer.WriteProperty("has_escape", has_escape); + serializer.WriteProperty("escape", escape); + serializer.WriteProperty("has_header", has_header); + serializer.WriteProperty("header", header); + serializer.WriteProperty("ignore_errors", ignore_errors); + serializer.WriteProperty("num_cols", num_cols); + serializer.WriteProperty("buffer_sample_size", buffer_sample_size); + serializer.WriteProperty("null_str", null_str); + serializer.WriteProperty("compression", compression); + serializer.WriteProperty("new_line", new_line); + serializer.WriteProperty("allow_quoted_nulls", allow_quoted_nulls); + serializer.WriteProperty("skip_rows", skip_rows); + serializer.WriteProperty("skip_rows_set", skip_rows_set); + serializer.WriteProperty("maximum_line_size", maximum_line_size); + serializer.WriteProperty("normalize_names", normalize_names); + serializer.WriteProperty("force_not_null", force_not_null); + serializer.WriteProperty("all_varchar", all_varchar); + serializer.WriteProperty("sample_chunk_size", sample_chunk_size); + serializer.WriteProperty("sample_chunks", sample_chunks); + serializer.WriteProperty("auto_detect", auto_detect); + serializer.WriteProperty("file_path", file_path); + serializer.WriteProperty("decimal_separator", decimal_separator); + serializer.WriteProperty("null_padding", null_padding); + serializer.WriteProperty("buffer_size", buffer_size); + serializer.WriteProperty("file_options", file_options); + serializer.WriteProperty("force_quote", force_quote); + serializer.WriteProperty("date_format", date_format); + serializer.WriteProperty("has_format", has_format); + serializer.WriteProperty("rejects_table_name", rejects_table_name); + serializer.WriteProperty("rejects_limit", rejects_limit); + serializer.WriteProperty("rejects_recovery_columns",
rejects_recovery_columns); + serializer.WriteProperty("rejects_recovery_column_ids", rejects_recovery_column_ids); +} + +BufferedCSVReaderOptions BufferedCSVReaderOptions::FormatDeserialize(FormatDeserializer &deserializer) { + BufferedCSVReaderOptions result; + deserializer.ReadProperty("has_delimiter", result.has_delimiter); + deserializer.ReadProperty("delimiter", result.delimiter); + deserializer.ReadProperty("has_quote", result.has_quote); + deserializer.ReadProperty("quote", result.quote); + deserializer.ReadProperty("has_escape", result.has_escape); + deserializer.ReadProperty("escape", result.escape); + deserializer.ReadProperty("has_header", result.has_header); + deserializer.ReadProperty("header", result.header); + deserializer.ReadProperty("ignore_errors", result.ignore_errors); + deserializer.ReadProperty("num_cols", result.num_cols); + deserializer.ReadProperty("buffer_sample_size", result.buffer_sample_size); + deserializer.ReadProperty("null_str", result.null_str); + deserializer.ReadProperty("compression", result.compression); + deserializer.ReadProperty("new_line", result.new_line); + deserializer.ReadProperty("allow_quoted_nulls", result.allow_quoted_nulls); + deserializer.ReadProperty("skip_rows", result.skip_rows); + deserializer.ReadProperty("skip_rows_set", result.skip_rows_set); + deserializer.ReadProperty("maximum_line_size", result.maximum_line_size); + deserializer.ReadProperty("normalize_names", result.normalize_names); + deserializer.ReadProperty("force_not_null", result.force_not_null); + deserializer.ReadProperty("all_varchar", result.all_varchar); + deserializer.ReadProperty("sample_chunk_size", result.sample_chunk_size); + deserializer.ReadProperty("sample_chunks", result.sample_chunks); + deserializer.ReadProperty("auto_detect", result.auto_detect); + deserializer.ReadProperty("file_path", result.file_path); + deserializer.ReadProperty("decimal_separator", result.decimal_separator); + deserializer.ReadProperty("null_padding", result.null_padding); + deserializer.ReadProperty("buffer_size", result.buffer_size); + deserializer.ReadProperty("file_options", result.file_options); + deserializer.ReadProperty("force_quote", result.force_quote); + deserializer.ReadProperty("date_format", result.date_format); + deserializer.ReadProperty("has_format", result.has_format); + deserializer.ReadProperty("rejects_table_name", result.rejects_table_name); + deserializer.ReadProperty("rejects_limit", result.rejects_limit); + deserializer.ReadProperty("rejects_recovery_columns", result.rejects_recovery_columns); + deserializer.ReadProperty("rejects_recovery_column_ids", result.rejects_recovery_column_ids); + return result; +} + void CaseCheck::FormatSerialize(FormatSerializer &serializer) const { serializer.WriteProperty("when_expr", *when_expr); serializer.WriteProperty("then_expr", *then_expr); @@ -121,6 +207,18 @@ ColumnDefinition ColumnDefinition::FormatDeserialize(FormatDeserializer &deseria return result; } +void ColumnInfo::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("names", names); + serializer.WriteProperty("types", types); +} + +ColumnInfo ColumnInfo::FormatDeserialize(FormatDeserializer &deserializer) { + ColumnInfo result; + deserializer.ReadProperty("names", result.names); + deserializer.ReadProperty("types", result.types); + return result; +} + void ColumnList::FormatSerialize(FormatSerializer &serializer) const { serializer.WriteProperty("columns", columns); } @@ -155,6 +253,18 @@ CommonTableExpressionMap 
CommonTableExpressionMap::FormatDeserialize(FormatDeser return result; } +void HivePartitioningIndex::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("value", value); + serializer.WriteProperty("index", index); +} + +HivePartitioningIndex HivePartitioningIndex::FormatDeserialize(FormatDeserializer &deserializer) { + auto value = deserializer.ReadProperty<string>("value"); + auto index = deserializer.ReadProperty<idx_t>("index"); + HivePartitioningIndex result(std::move(value), index); + return result; +} + void JoinCondition::FormatSerialize(FormatSerializer &serializer) const { serializer.WriteProperty("left", *left); serializer.WriteProperty("right", *right); @@ -181,6 +291,38 @@ LogicalType LogicalType::FormatDeserialize(FormatDeserializer &deserializer) { return result; } +void MultiFileReaderBindData::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("filename_idx", filename_idx); + serializer.WriteProperty("hive_partitioning_indexes", hive_partitioning_indexes); +} + +MultiFileReaderBindData MultiFileReaderBindData::FormatDeserialize(FormatDeserializer &deserializer) { + MultiFileReaderBindData result; + deserializer.ReadProperty("filename_idx", result.filename_idx); + deserializer.ReadProperty("hive_partitioning_indexes", result.hive_partitioning_indexes); + return result; +} + +void MultiFileReaderOptions::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("filename", filename); + serializer.WriteProperty("hive_partitioning", hive_partitioning); + serializer.WriteProperty("auto_detect_hive_partitioning", auto_detect_hive_partitioning); + serializer.WriteProperty("union_by_name", union_by_name); + serializer.WriteProperty("hive_types_autocast", hive_types_autocast); + serializer.WriteProperty("hive_types_schema", hive_types_schema); +} + +MultiFileReaderOptions MultiFileReaderOptions::FormatDeserialize(FormatDeserializer &deserializer) { + MultiFileReaderOptions result; + deserializer.ReadProperty("filename", result.filename); + deserializer.ReadProperty("hive_partitioning", result.hive_partitioning); + deserializer.ReadProperty("auto_detect_hive_partitioning", result.auto_detect_hive_partitioning); + deserializer.ReadProperty("union_by_name", result.union_by_name); + deserializer.ReadProperty("hive_types_autocast", result.hive_types_autocast); + deserializer.ReadProperty("hive_types_schema", result.hive_types_schema); + return result; +} + void OrderByNode::FormatSerialize(FormatSerializer &serializer) const { serializer.WriteProperty("type", type); serializer.WriteProperty("null_order", null_order); @@ -211,6 +353,34 @@ PivotColumn PivotColumn::FormatDeserialize(FormatDeserializer &deserializer) { return result; } +void ReadCSVData::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("files", files); + serializer.WriteProperty("csv_types", csv_types); + serializer.WriteProperty("csv_names", csv_names); + serializer.WriteProperty("return_types", return_types); + serializer.WriteProperty("return_names", return_names); + serializer.WriteProperty("filename_col_idx", filename_col_idx); + serializer.WriteProperty("options", options); + serializer.WriteProperty("single_threaded", single_threaded); + serializer.WriteProperty("reader_bind", reader_bind); + serializer.WriteProperty("column_info", column_info); +} + +unique_ptr<FunctionData> ReadCSVData::FormatDeserialize(FormatDeserializer &deserializer) { + auto result = duckdb::unique_ptr<ReadCSVData>(new ReadCSVData()); + deserializer.ReadProperty("files",
result->files); + deserializer.ReadProperty("csv_types", result->csv_types); + deserializer.ReadProperty("csv_names", result->csv_names); + deserializer.ReadProperty("return_types", result->return_types); + deserializer.ReadProperty("return_names", result->return_names); + deserializer.ReadProperty("filename_col_idx", result->filename_col_idx); + deserializer.ReadProperty("options", result->options); + deserializer.ReadProperty("single_threaded", result->single_threaded); + deserializer.ReadProperty("reader_bind", result->reader_bind); + deserializer.ReadProperty("column_info", result->column_info); + return result; +} + void SampleOptions::FormatSerialize(FormatSerializer &serializer) const { serializer.WriteProperty("sample_size", sample_size); serializer.WriteProperty("is_percentage", is_percentage); @@ -227,6 +397,26 @@ unique_ptr<SampleOptions> SampleOptions::FormatDeserialize(FormatDeserializer &d return result; } +void StrpTimeFormat::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("format_specifier", format_specifier); +} + +StrpTimeFormat StrpTimeFormat::FormatDeserialize(FormatDeserializer &deserializer) { + auto format_specifier = deserializer.ReadProperty<string>("format_specifier"); + StrpTimeFormat result(format_specifier); + return result; +} + +void TableFilterSet::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("filters", filters); +} + +TableFilterSet TableFilterSet::FormatDeserialize(FormatDeserializer &deserializer) { + TableFilterSet result; + deserializer.ReadProperty("filters", result.filters); + return result; +} + void VacuumOptions::FormatSerialize(FormatSerializer &serializer) const { serializer.WriteProperty("vacuum", vacuum); serializer.WriteProperty("analyze", analyze); diff --git a/src/storage/serialization/serialize_table_filter.cpp b/src/storage/serialization/serialize_table_filter.cpp new file mode 100644 index 000000000000..175cba91fa1d --- /dev/null +++ b/src/storage/serialization/serialize_table_filter.cpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_serialization.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/common/serializer/format_serializer.hpp" +#include "duckdb/common/serializer/format_deserializer.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/planner/filter/null_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" + +namespace duckdb { + +void TableFilter::FormatSerialize(FormatSerializer &serializer) const { + serializer.WriteProperty("filter_type", filter_type); +} + +unique_ptr<TableFilter> TableFilter::FormatDeserialize(FormatDeserializer &deserializer) { + auto filter_type = deserializer.ReadProperty<TableFilterType>("filter_type"); + unique_ptr<TableFilter> result; + switch (filter_type) { + case TableFilterType::CONJUNCTION_AND: + result = ConjunctionAndFilter::FormatDeserialize(deserializer); + break; + case TableFilterType::CONJUNCTION_OR: + result = ConjunctionOrFilter::FormatDeserialize(deserializer); + break; + case TableFilterType::CONSTANT_COMPARISON: + result = ConstantFilter::FormatDeserialize(deserializer); + break; + case TableFilterType::IS_NOT_NULL: + result = IsNotNullFilter::FormatDeserialize(deserializer); + break; + case TableFilterType::IS_NULL: + result =
IsNullFilter::FormatDeserialize(deserializer); + break; + default: + throw SerializationException("Unsupported type for deserialization of TableFilter!"); + } + return result; +} + +void ConjunctionAndFilter::FormatSerialize(FormatSerializer &serializer) const { + TableFilter::FormatSerialize(serializer); + serializer.WriteProperty("child_filters", child_filters); +} + +unique_ptr<TableFilter> ConjunctionAndFilter::FormatDeserialize(FormatDeserializer &deserializer) { + auto result = duckdb::unique_ptr<ConjunctionAndFilter>(new ConjunctionAndFilter()); + deserializer.ReadProperty("child_filters", result->child_filters); + return std::move(result); +} + +void ConjunctionOrFilter::FormatSerialize(FormatSerializer &serializer) const { + TableFilter::FormatSerialize(serializer); + serializer.WriteProperty("child_filters", child_filters); +} + +unique_ptr<TableFilter> ConjunctionOrFilter::FormatDeserialize(FormatDeserializer &deserializer) { + auto result = duckdb::unique_ptr<ConjunctionOrFilter>(new ConjunctionOrFilter()); + deserializer.ReadProperty("child_filters", result->child_filters); + return std::move(result); +} + +void ConstantFilter::FormatSerialize(FormatSerializer &serializer) const { + TableFilter::FormatSerialize(serializer); + serializer.WriteProperty("comparison_type", comparison_type); + serializer.WriteProperty("constant", constant); +} + +unique_ptr<TableFilter> ConstantFilter::FormatDeserialize(FormatDeserializer &deserializer) { + auto comparison_type = deserializer.ReadProperty<ExpressionType>("comparison_type"); + auto constant = deserializer.ReadProperty<Value>("constant"); + auto result = duckdb::unique_ptr<ConstantFilter>(new ConstantFilter(comparison_type, constant)); + return std::move(result); +} + +void IsNotNullFilter::FormatSerialize(FormatSerializer &serializer) const { + TableFilter::FormatSerialize(serializer); +} + +unique_ptr<TableFilter> IsNotNullFilter::FormatDeserialize(FormatDeserializer &deserializer) { + auto result = duckdb::unique_ptr<IsNotNullFilter>(new IsNotNullFilter()); + return std::move(result); +} + +void IsNullFilter::FormatSerialize(FormatSerializer &serializer) const { + TableFilter::FormatSerialize(serializer); +} + +unique_ptr<TableFilter> IsNullFilter::FormatDeserialize(FormatDeserializer &deserializer) { + auto result = duckdb::unique_ptr<IsNullFilter>(new IsNullFilter()); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/storage/storage_info.cpp b/src/storage/storage_info.cpp index 3b198128d1a1..f22ca733d957 100644 --- a/src/storage/storage_info.cpp +++ b/src/storage/storage_info.cpp @@ -2,7 +2,7 @@ namespace duckdb { -const uint64_t VERSION_NUMBER = 53; +const uint64_t VERSION_NUMBER = 54; struct StorageVersionInfo { const char *version_name; diff --git a/src/storage/table/list_column_data.cpp b/src/storage/table/list_column_data.cpp index 9af4b92decc0..c009cd8f497c 100644 --- a/src/storage/table/list_column_data.cpp +++ b/src/storage/table/list_column_data.cpp @@ -89,16 +89,19 @@ idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t co D_ASSERT(scan_count > 0); validity.ScanCount(state.child_states[0], result, count); - auto data = FlatVector::GetData<list_entry_t>(offset_vector); - auto last_entry = data[scan_count - 1]; + UnifiedVectorFormat offsets; + offset_vector.ToUnifiedFormat(scan_count, offsets); + auto data = UnifiedVectorFormat::GetData<list_entry_t>(offsets); + auto last_entry = data[offsets.sel->get_index(scan_count - 1)]; // shift all offsets so they are 0 at the first entry auto result_data = FlatVector::GetData<list_entry_t>(result); auto base_offset = state.last_offset; idx_t current_offset = 0; for (idx_t i = 0; i < scan_count; i++) { + auto offset_index = offsets.sel->get_index(i);
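+ // after ToUnifiedFormat the offsets vector may be dictionary-encoded, so row i's list_entry_t has to be located through the selection vector rather than read at position i directly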
result_data[i].offset = current_offset; - result_data[i].length = data[i] - current_offset - base_offset; + result_data[i].length = data[offset_index] - current_offset - base_offset; current_offset += result_data[i].length; } diff --git a/src/storage/wal_replay.cpp b/src/storage/wal_replay.cpp index d42222be7895..f7524d424eae 100644 --- a/src/storage/wal_replay.cpp +++ b/src/storage/wal_replay.cpp @@ -36,7 +36,6 @@ bool WriteAheadLog::Replay(AttachedDatabase &database, string &path) { // first deserialize the WAL to look for a checkpoint flag // if there is a checkpoint flag, we might have already flushed the contents of the WAL to disk ReplayState checkpoint_state(database, *con.context, *initial_reader); - initial_reader->SetCatalog(checkpoint_state.catalog); checkpoint_state.deserialize_only = true; try { while (true) { @@ -73,7 +72,6 @@ bool WriteAheadLog::Replay(AttachedDatabase &database, string &path) { // we need to recover from the WAL: actually set up the replay state BufferedFileReader reader(FileSystem::Get(database), path.c_str(), con.context.get()); - reader.SetCatalog(checkpoint_state.catalog); ReplayState state(database, *con.context, reader); // replay the WAL @@ -284,7 +282,7 @@ void ReplayState::ReplayDropSchema() { void ReplayState::ReplayCreateType() { auto info = TypeCatalogEntry::Deserialize(source); info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - catalog.CreateType(context, *info); + catalog.CreateType(context, info->Cast<CreateTypeInfo>()); } void ReplayState::ReplayDropType() { diff --git a/test/api/capi/test_capi_streaming.cpp b/test/api/capi/test_capi_streaming.cpp index 820d2516c088..0b308885ed76 100644 --- a/test/api/capi/test_capi_streaming.cpp +++ b/test/api/capi/test_capi_streaming.cpp @@ -97,3 +97,45 @@ TEST_CASE("Test other methods on streaming results in C API", "[capi]") { auto is_null = result->IsNull(0, 0); REQUIRE(is_null == false); } + +TEST_CASE("Test query progress and interrupt in C API", "[capi]") { + CAPITester tester; + CAPIPrepared prepared; + CAPIPending pending; + duckdb::unique_ptr<CAPIResult> result; + + // test null handling + REQUIRE(duckdb_query_progress(nullptr) == -1.0); + duckdb_interrupt(nullptr); + + // open the database in in-memory mode + REQUIRE(tester.OpenDatabase(nullptr)); + REQUIRE_NO_FAIL(tester.Query("SET threads=1")); + REQUIRE_NO_FAIL(tester.Query("create table tbl as select range a, mod(range,10) b from range(10000);")); + REQUIRE_NO_FAIL(tester.Query("create table tbl_2 as select range a from range(10000);")); + REQUIRE_NO_FAIL(tester.Query("set enable_progress_bar=true;")); + REQUIRE_NO_FAIL(tester.Query("set enable_progress_bar_print=false;")); + // test no progress before query + REQUIRE(duckdb_query_progress(tester.connection) == -1.0); + // test zero progress with query + REQUIRE(prepared.Prepare(tester, "select count(*) from tbl where a = (select min(a) from tbl_2)")); + REQUIRE(pending.PendingStreaming(prepared)); + REQUIRE(duckdb_query_progress(tester.connection) == 0.0); + + // test progress + while (duckdb_query_progress(tester.connection) == 0.0) { + auto state = pending.ExecuteTask(); + REQUIRE(state == DUCKDB_PENDING_RESULT_NOT_READY); + } + REQUIRE(duckdb_query_progress(tester.connection) >= 0.0); + + // test interrupt + duckdb_interrupt(tester.connection); + while (true) { + auto state = pending.ExecuteTask(); + REQUIRE(state != DUCKDB_PENDING_RESULT_READY); + if (state == DUCKDB_PENDING_ERROR) { + break; + } + } +} diff --git a/test/api/serialized_plans/serialized_plans.binary
b/test/api/serialized_plans/serialized_plans.binary index a1c408a928f7..841773beff43 100644 Binary files a/test/api/serialized_plans/serialized_plans.binary and b/test/api/serialized_plans/serialized_plans.binary differ diff --git a/test/extension/loadable_extension_demo.cpp b/test/extension/loadable_extension_demo.cpp index bc37ae09332b..c587a868589e 100644 --- a/test/extension/loadable_extension_demo.cpp +++ b/test/extension/loadable_extension_demo.cpp @@ -242,8 +242,7 @@ DUCKDB_EXTENSION_API void loadable_extension_demo_init(duckdb::DatabaseInstance target_type.SetAlias(alias_name); alias_info->type = target_type; - auto &entry = catalog.CreateType(client_context, *alias_info)->Cast(); - EnumType::SetCatalog(target_type, &entry); + catalog.CreateType(client_context, *alias_info); // Function add point ScalarFunction add_point_func("add_point", {target_type, target_type}, target_type, AddPointFunction); diff --git a/test/parquet/encrypted_parquet.test b/test/parquet/encrypted_parquet.test new file mode 100644 index 000000000000..81c939902a51 --- /dev/null +++ b/test/parquet/encrypted_parquet.test @@ -0,0 +1,18 @@ +# name: test/parquet/encrypted_parquet.test +# description: Test Parquet reader on data/parquet-testing/encryption +# group: [parquet] + +require parquet + +statement ok +PRAGMA enable_verification + +statement error +SELECT * FROM parquet_scan('data/parquet-testing/encryption/encrypted_footer.parquet') limit 50; +---- +Invalid Input Error: Encrypted Parquet files are not supported for file 'data/parquet-testing/encryption/encrypted_footer.parquet' + +statement error +SELECT * FROM parquet_scan('data/parquet-testing/encryption/encrypted_column.parquet') limit 50; +---- +Invalid Error: Failed to read Parquet file "data/parquet-testing/encryption/encrypted_column.parquet": Encrypted Parquet files are not supported \ No newline at end of file diff --git a/test/sql/catalog/function/test_recursive_macro.test b/test/sql/catalog/function/test_recursive_macro.test index ea877bca26d6..be9739561679 100644 --- a/test/sql/catalog/function/test_recursive_macro.test +++ b/test/sql/catalog/function/test_recursive_macro.test @@ -5,15 +5,15 @@ statement ok CREATE MACRO "sum"(x) AS (CASE WHEN sum(x) IS NULL THEN 0 ELSE sum(x) END); -query I +statement error SELECT sum(1); ---- -1 +Binder Error: Maximum recursion depth exceeded -query I +statement error SELECT sum(1) WHERE 42=0 ---- -0 +Binder Error: Maximum recursion depth exceeded statement ok DROP MACRO sum @@ -31,3 +31,33 @@ query I SELECT sum(1) WHERE 42=0 ---- 0 + +# evil test case by Mark +statement ok +create macro m1(a) as a+1; + +statement ok +create macro m2(a) as m1(a)+1; + +statement ok +create or replace macro m1(a) as m2(a)+1; + +statement error +select m2(42); +---- +Binder Error: Maximum recursion depth exceeded + +# also table macros +statement ok +create macro m3(a) as a+1; + +statement ok +create macro m4(a) as table select m3(a); + +statement ok +create or replace macro m3(a) as (from m4(42)); + +statement error +select m3(42); +---- +Binder Error: Maximum recursion depth exceeded diff --git a/test/sql/copy/csv/read_csv_subquery.test b/test/sql/copy/csv/read_csv_subquery.test index fada62f69345..0eb5c1576fd5 100644 --- a/test/sql/copy/csv/read_csv_subquery.test +++ b/test/sql/copy/csv/read_csv_subquery.test @@ -10,7 +10,7 @@ WITH urls AS ( SELECT 'a.csv' AS url UNION ALL SELECT 'b.csv' ) SELECT * -FROM read_csv_auto((SELECT url FROM urls LIMIT 3), delimiter=',') +FROM read_csv_auto((SELECT url FROM urls LIMIT 3), delim=',') 
WHERE properties.height > -1.0 LIMIT 10 ---- @@ -30,6 +30,6 @@ Table function cannot contain aggregates statement error SELECT * -FROM read_csv_auto('a.csv', delimiter=',', 42) +FROM read_csv_auto('a.csv', delim=',', 42) ---- Unnamed parameters cannot come after named parameters diff --git a/test/sql/fts/test_fts_attach.test b/test/sql/fts/test_fts_attach.test index 7c4c5c2820ac..530108e0ce82 100644 --- a/test/sql/fts/test_fts_attach.test +++ b/test/sql/fts/test_fts_attach.test @@ -19,3 +19,47 @@ PRAGMA create_fts_index(search_con.main.my_table, 'CustomerId', 'CustomerName') statement ok SELECT search_con.fts_main_my_table.match_bm25(1, 'han') + +statement ok +DETACH search_con + +# test reopened #8141 +load __TEST_DIR__/index.db + +statement ok +CREATE TABLE data AS SELECT 0 __index, 0 id, 'lorem ipsum' nl, NULL code; + +statement ok +PRAGMA create_fts_index('data', '__index', '*', overwrite=1); + +# test that it works before doing the problematic stuff +query IIII +SELECT * FROM data WHERE fts_main_data.match_bm25(__index, 'lorem') IS NOT NULL; +---- +0 0 lorem ipsum NULL + +statement ok +ATTACH ':memory:' AS memory; + +statement ok +USE memory; + +statement ok +DETACH "index"; + +# now attach again +statement ok +ATTACH '__TEST_DIR__/index.db' AS db; + +statement ok +USE db; + +query T +SELECT COUNT(*) FROM data; +---- +1 + +query IIII +SELECT * FROM data WHERE fts_main_data.match_bm25(__index, 'lorem') IS NOT NULL; +---- +0 0 lorem ipsum NULL diff --git a/test/sql/join/asof/test_asof_join_merge.test b/test/sql/join/asof/test_asof_join_merge.test_slow similarity index 82% rename from test/sql/join/asof/test_asof_join_merge.test rename to test/sql/join/asof/test_asof_join_merge.test_slow index ae92cf02e676..544deaad4cd5 100644 --- a/test/sql/join/asof/test_asof_join_merge.test +++ b/test/sql/join/asof/test_asof_join_merge.test_slow @@ -1,4 +1,4 @@ -# name: test/sql/join/asof/test_asof_join_merge.test +# name: test/sql/join/asof/test_asof_join_merge.test_slow # description: Test merge queue and repartitioning # group: [asof] @@ -8,6 +8,9 @@ PRAGMA memory_limit='400M' statement ok PRAGMA threads=4 +statement ok +SET temp_directory='__TEST_DIR__/temp.tmp' + query II WITH build AS ( SELECT k, ('2021-01-01'::TIMESTAMP + INTERVAL (i) SECOND) AS t, i % 37 AS v diff --git a/test/sql/storage/compression/rle/rle_constant.test b/test/sql/storage/compression/rle/rle_constant.test index f3bd302c969f..74cf52e29605 100644 --- a/test/sql/storage/compression/rle/rle_constant.test +++ b/test/sql/storage/compression/rle/rle_constant.test @@ -5,6 +5,8 @@ # load the DB from disk load __TEST_DIR__/test_rle.db +require vector_size 2048 + statement ok PRAGMA force_compression = 'rle' diff --git a/test/sql/storage_version/storage_version.db b/test/sql/storage_version/storage_version.db index 6835427ede2e..98dafcef180e 100644 Binary files a/test/sql/storage_version/storage_version.db and b/test/sql/storage_version/storage_version.db differ diff --git a/test/sql/types/alias/test_alias_struct_nested_alias.test b/test/sql/types/alias/test_alias_struct_nested_alias.test new file mode 100644 index 000000000000..4ef55a8e3353 --- /dev/null +++ b/test/sql/types/alias/test_alias_struct_nested_alias.test @@ -0,0 +1,37 @@ +# name: test/sql/types/alias/test_alias_struct_nested_alias.test +# description: Test creates alias for struct type +# group: [alias] + +require skip_reload + +statement ok +PRAGMA enable_verification + +# Create a USER type +statement ok +CREATE TYPE foobar AS ENUM( + 'Foo', + 'Bar' +); + +# Create a 
USER type which is a STRUCT that contains another USER type +statement ok +CREATE TYPE top_nest AS STRUCT( + foobar FOOBAR +); + +# Create a table out of this type +statement ok +CREATE TABLE failing ( + top_nest TOP_NEST +); + +statement ok +insert into failing VALUES ( + {'foobar': 'Foo'} +) + +query I +SELECT top_nest FROM failing; +---- +{'foobar': Foo} diff --git a/test/sql/types/enum/standalone_enum.test b/test/sql/types/enum/standalone_enum.test new file mode 100644 index 000000000000..e10db16132a3 --- /dev/null +++ b/test/sql/types/enum/standalone_enum.test @@ -0,0 +1,34 @@ +# name: test/sql/types/enum/standalone_enum.test +# description: Test stand-alone enums +# group: [enum] + +statement ok +PRAGMA enable_verification + +query I +SELECT 'hello'::ENUM('world', 'hello'); +---- +hello + +statement ok +CREATE TABLE test AS SELECT 'hello'::ENUM('world', 'hello') AS h; + +query I +SELECT * FROM test +---- +hello + +statement error +SELECT 'hello'::ENUM; +---- +Enum needs a set of entries + +statement error +SELECT 'hello'::ENUM(42); +---- +Enum type requires a set of strings as type modifiers + +statement error +SELECT 'hello'::ENUM('zzz', 42); +---- +Enum type requires a set of strings as type modifiers diff --git a/test/sql/types/enum/test_alter_enum.test b/test/sql/types/enum/test_alter_enum.test index 99daca6a8dff..868daf64d26c 100644 --- a/test/sql/types/enum/test_alter_enum.test +++ b/test/sql/types/enum/test_alter_enum.test @@ -42,7 +42,10 @@ Mark query I select typeof(name) from person limit 1 ---- -name_enum +ENUM('Pedro', 'Mark', 'Hannes') + +# FIXME: dependencies between enums and tables are currently disabled +mode skip # This should not be possible statement error diff --git a/test/sql/types/enum/test_enum_constraints.test b/test/sql/types/enum/test_enum_constraints.test index 9d8b732fcf0d..061838079593 100644 --- a/test/sql/types/enum/test_enum_constraints.test +++ b/test/sql/types/enum/test_enum_constraints.test @@ -5,6 +5,9 @@ statement ok PRAGMA enable_verification +# FIXME: dependencies between enums and tables are currently disabled +mode skip + statement ok CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); diff --git a/test/sql/types/enum/test_enum_storage.test b/test/sql/types/enum/test_enum_storage.test index add278fb8227..d4d8fc04241b 100644 --- a/test/sql/types/enum/test_enum_storage.test +++ b/test/sql/types/enum/test_enum_storage.test @@ -71,13 +71,20 @@ loop i 0 2 restart +# FIXME: dependencies between enums and tables are currently disabled +mode skip + # We cant drop mood and we verify the constraint is still there after reloading the database statement error DROP TYPE mood; +---- +Dependency query TT select * from person ---- Moe happy -endloop \ No newline at end of file +mode unskip + +endloop diff --git a/test/sql/types/enum/test_enum_structs.test b/test/sql/types/enum/test_enum_structs.test new file mode 100644 index 000000000000..499717d21924 --- /dev/null +++ b/test/sql/types/enum/test_enum_structs.test @@ -0,0 +1,121 @@ +# name: test/sql/types/enum/test_enum_structs.test +# description: ENUM types used inside structs +# group: [enum] + +statement ok +PRAGMA enable_verification + +# load the DB from disk +load __TEST_DIR__/enum_types.db + +statement ok +CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + +statement ok +CREATE TABLE person ( + id INTEGER, + c STRUCT( + name text, + current_mood mood + ) +); + +statement ok +INSERT INTO person VALUES (1, ROW('Mark', 'happy')); + +query II +FROM person +---- +1 {'name': Mark, 'current_mood': happy} + +# 
FIXME: dependencies between enums and tables are currently disabled +mode skip + +statement error +DROP TYPE mood +---- +Dependency + +mode unskip + +restart + +query II +FROM person +---- +1 {'name': Mark, 'current_mood': happy} + +# after dropping the column we can drop the type +statement ok +ALTER TABLE person DROP COLUMN c + +statement ok +DROP TYPE mood + +statement ok +CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + +statement ok +ALTER TABLE person ADD COLUMN c STRUCT( + name text, + current_mood mood + ) + +# FIXME: dependencies between enums and tables are currently disabled +mode skip + +# we cannot drop the type after adding it back +statement error +DROP TYPE mood +---- +Dependency + +mode unskip + +statement ok +ALTER TABLE person DROP COLUMN c + +statement ok +DROP TYPE mood + +statement ok +CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + +statement ok +ALTER TABLE person ADD COLUMN c INT + +statement ok +ALTER TABLE person ALTER c SET DATA TYPE STRUCT( + name text, + current_mood mood + ) + +query II +FROM person +---- +1 NULL + +# FIXME: dependencies between enums and tables are currently disabled +mode skip + +# we cannot drop the type after adding it using an alter type +statement error +DROP TYPE mood +---- +Dependency + +mode unskip + +statement ok +UPDATE person SET c=ROW('Mark', 'happy') + +query II +FROM person +---- +1 {'name': Mark, 'current_mood': happy} + +statement ok +DROP TABLE person + +statement ok +DROP TYPE mood diff --git a/test/sqlite/sqllogic_test_runner.cpp b/test/sqlite/sqllogic_test_runner.cpp index ea4ffc2a20c0..c528ee927e02 100644 --- a/test/sqlite/sqllogic_test_runner.cpp +++ b/test/sqlite/sqllogic_test_runner.cpp @@ -201,7 +201,10 @@ void SQLLogicTestRunner::ExecuteFile(string script) { // for the original SQLite tests we convert floating point numbers to integers // for our own tests this is undesirable since it hides certain errors - if (script.find("sqlite") != string::npos || script.find("sqllogictest") != string::npos) { + if (script.find("test/sqlite/select") != string::npos) { + original_sqlite_test = true; + } + if (script.find("third_party/sqllogictest") != string::npos) { original_sqlite_test = true; } diff --git a/tools/nodejs/src/statement.cpp b/tools/nodejs/src/statement.cpp index 57e6eb2b8094..0e1299967321 100644 --- a/tools/nodejs/src/statement.cpp +++ b/tools/nodejs/src/statement.cpp @@ -548,7 +548,6 @@ static Napi::Value TypeToObject(Napi::Env &env, const duckdb::LogicalType &type) obj.Set("value", TypeToObject(env, value_type)); } break; case duckdb::LogicalTypeId::ENUM: { - auto name = duckdb::EnumType::GetTypeName(type); auto &values_vec = duckdb::EnumType::GetValuesInsertOrder(type); auto enum_size = duckdb::EnumType::GetSize(type); auto arr = Napi::Array::New(env, enum_size); @@ -556,7 +555,6 @@ static Napi::Value TypeToObject(Napi::Env &env, const duckdb::LogicalType &type) auto child_name = values_vec.GetValue(i).GetValue(); arr.Set(i, child_name); } - obj.Set("name", name); obj.Set("values", arr); } break; case duckdb::LogicalTypeId::UNION: { diff --git a/tools/nodejs/test/columns.test.ts b/tools/nodejs/test/columns.test.ts index 8275d0a52171..55dabce29daa 100644 --- a/tools/nodejs/test/columns.test.ts +++ b/tools/nodejs/test/columns.test.ts @@ -48,8 +48,7 @@ describe('Column Types', function() { name: 'small_enum', type: { id: 'ENUM', - sql_type: 'small_enum', - name: 'small_enum', + sql_type: "ENUM('DUCK_DUCK_ENUM', 'GOOSE')", values: [ "DUCK_DUCK_ENUM", "GOOSE" diff --git a/tools/odbc/connection.cpp 
b/tools/odbc/connection.cpp index 1b235637c1d0..3b806c37ff67 100644 --- a/tools/odbc/connection.cpp +++ b/tools/odbc/connection.cpp @@ -101,6 +101,17 @@ SQLRETURN SQL_API SQLGetConnectAttr(SQLHDBC connection_handle, SQLINTEGER attrib } } +/** + * @brief Sets attribute for connection + * @param connection_handle + * @param attribute Attribute to set, for full list see: + * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlsetconnectattr-function?view=sql-server-ver15#comments + * @param value_ptr Value to set, depending on the attribute, could be either an unsigned integer or a pointer to a null + * terminated string. + * @param string_length Length of the string, if the attribute is a string, in bytes. If the attribute is an integer, + * this value is ignored. + * @return SQL return code + */ SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC connection_handle, SQLINTEGER attribute, SQLPOINTER value_ptr, SQLINTEGER string_length) { // attributes before connection @@ -1004,6 +1015,18 @@ SQLRETURN SQL_API SQLGetInfo(SQLHDBC connection_handle, SQLUSMALLINT info_type, } } // end SQLGetInfo +/** + * @brief Requests a commit or rollback operation for all active operations on all statements associated with a + * connection. https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlendtran-function?view=sql-server-ver15 + * @param handle_type Can either be SQL_HANDLE_ENV or SQL_HANDLE_DBC + * @param handle The input handle + * @param completion_type Can either be SQL_COMMIT or SQL_ROLLBACK + * + * For more about committing and rolling back transactions, see: + * https://learn.microsoft.com/en-us/sql/odbc/reference/develop-app/committing-and-rolling-back-transactions?view=sql-server-ver15 + * + * @return + */ SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT completion_type) { if (handle_type != SQL_HANDLE_DBC) { // theoretically this can also be done on env but no no no return SQL_ERROR; diff --git a/tools/odbc/driver.cpp b/tools/odbc/driver.cpp index 808dd2ffe2a2..061b1d6f9b13 100644 --- a/tools/odbc/driver.cpp +++ b/tools/odbc/driver.cpp @@ -49,10 +49,25 @@ SQLRETURN duckdb::FreeHandle(SQLSMALLINT handle_type, SQLHANDLE handle) { } } +/** + * @brief Frees a handle + * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlfreehandle-function?view=sql-server-ver15 + * @param handle_type + * @param handle + * @return SQL return code + */ SQLRETURN SQL_API SQLFreeHandle(SQLSMALLINT handle_type, SQLHANDLE handle) { return duckdb::FreeHandle(handle_type, handle); } +/** + * @brief Allocates a handle + * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlallochandle-function?view=sql-server-ver15 + * @param handle_type Can be SQL_HANDLE_ENV, SQL_HANDLE_DBC, SQL_HANDLE_STMT, SQL_HANDLE_DESC + * @param input_handle Handle to associate with the new handle, if applicable + * @param output_handle_ptr The new handle + * @return + */ SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handle_type, SQLHANDLE input_handle, SQLHANDLE *output_handle_ptr) { switch (handle_type) { case SQL_HANDLE_DBC: { @@ -84,6 +99,18 @@ SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handle_type, SQLHANDLE input_handle } } +static SQLUINTEGER ExtractMajorVersion(SQLPOINTER value_ptr) { + // Values like 380 represent version 3.8, here we extract the major version (3 in this case) + auto full_version = (SQLUINTEGER)(uintptr_t)value_ptr; + if (full_version > 100) { + return full_version / 100; + } + if (full_version > 10) { + return full_version / 10; + } + return full_version; 
+} + SQLRETURN SQL_API SQLSetEnvAttr(SQLHENV environment_handle, SQLINTEGER attribute, SQLPOINTER value_ptr, SQLINTEGER string_length) { duckdb::OdbcHandleEnv *env = nullptr; @@ -93,7 +120,8 @@ SQLRETURN SQL_API SQLSetEnvAttr(SQLHENV environment_handle, SQLINTEGER attribute switch (attribute) { case SQL_ATTR_ODBC_VERSION: { - switch ((SQLUINTEGER)(intptr_t)value_ptr) { + auto major_version = ExtractMajorVersion(value_ptr); + switch (major_version) { case SQL_OV_ODBC3: case SQL_OV_ODBC2: // TODO actually do something with this? diff --git a/tools/odbc/include/parameter_descriptor.hpp b/tools/odbc/include/parameter_descriptor.hpp index fe03ca14572b..50f1283ab546 100644 --- a/tools/odbc/include/parameter_descriptor.hpp +++ b/tools/odbc/include/parameter_descriptor.hpp @@ -51,11 +51,9 @@ class ParameterDescriptor { void SetSQLDescDataPtr(DescRecord &apd_record, SQLPOINTER data_ptr); SQLLEN *GetSQLDescIndicatorPtr(DescRecord &apd_record, idx_t set_idx = 0); - void SetSQLDescIndicatorPtr(DescRecord &apd_record, SQLLEN *ind_ptr); void SetSQLDescIndicatorPtr(DescRecord &apd_record, SQLLEN value); SQLLEN *GetSQLDescOctetLengthPtr(DescRecord &apd_record, idx_t set_idx = 0); - void SetSQLDescOctetLengthPtr(DescRecord &apd_record, SQLLEN *ind_ptr); private: OdbcHandleStmt *stmt; diff --git a/tools/odbc/linux_setup/unixodbc_setup.sh b/tools/odbc/linux_setup/unixodbc_setup.sh index 42d9d4e34491..af917348c168 100755 --- a/tools/odbc/linux_setup/unixodbc_setup.sh +++ b/tools/odbc/linux_setup/unixodbc_setup.sh @@ -26,8 +26,8 @@ function ReadArgs() { "-D") shift DRIVER_PATH=$1 - if grep -qv "libduckdb_odbc.so" <<< $DRIVER_PATH; then - printf "\n****Driver path doesn't contain 'libduckdb_odbc.so'****\n\n" + if grep -qv "libduckdb_odbc" <<< $DRIVER_PATH; then + printf "\n****Driver path doesn't contain 'libduckdb_odbc'****\n\n" Usage fi shift @@ -60,13 +60,8 @@ EOF } function ConfigUserInstFile() { - INST_SETUP=$1 - if test -f ~/.odbcinst.ini; then - #file already exist - sed -i "/DuckDB Driver/{n;s#.*libduckdb_odbc.so#Driver=${DRIVER_PATH}#}" ~/.odbcinst.ini - else - cp $INST_SETUP ~/.odbcinst.ini - fi + SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + DRIVER_PATH="$DRIVER_PATH" python3 "${SCRIPT_DIR}/update_odbc_path.py" } # Exit immediately if a command exits with a non-zero status. 
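Note on the rewritten ConfigUserInstFile above: it now delegates to the Python helper that follows, which uses configparser. configparser's default optionxform lowercases option names, so the managed entry may be written back to ~/.odbcinst.ini as driver = ... rather than Driver = ...; unixODBC generally matches INI keys case-insensitively, so this should be harmless. A minimal sketch of that behavior, with a made-up driver path:
import configparser
config = configparser.ConfigParser()
config["DuckDB Driver"] = {"Driver": "/tmp/libduckdb_odbc.so"}  # hypothetical path
# the default optionxform() lowercases option names on insertion
print(list(config["DuckDB Driver"]))  # prints ['driver']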
diff --git a/tools/odbc/linux_setup/update_odbc_path.py b/tools/odbc/linux_setup/update_odbc_path.py new file mode 100644 index 000000000000..abe9f7994457 --- /dev/null +++ b/tools/odbc/linux_setup/update_odbc_path.py @@ -0,0 +1,26 @@ +import os +import configparser + +# Fetch DRIVER_PATH from the environment +DRIVER_PATH = os.environ.get("DRIVER_PATH") + +if not DRIVER_PATH: + raise ValueError("Environment variable DRIVER_PATH is not set.") + +# The path to the .odbcinst.ini file +config_file_path = os.path.expanduser("~/.odbcinst.ini") + +# Create a ConfigParser object and read the existing .odbcinst.ini file +config = configparser.ConfigParser() +config.read(config_file_path) + +# Update the 'DuckDB Driver' section with the new DRIVER_PATH +if "DuckDB Driver" in config: + config["DuckDB Driver"]["Driver"] = DRIVER_PATH +else: + config.add_section("DuckDB Driver") + config["DuckDB Driver"]["Driver"] = DRIVER_PATH + +# Write the modified configuration back to the .odbcinst.ini file +with open(config_file_path, "w") as configfile: + config.write(configfile) diff --git a/tools/odbc/parameter_descriptor.cpp b/tools/odbc/parameter_descriptor.cpp index 6b5d6be3a269..4de793d4e055 100644 --- a/tools/odbc/parameter_descriptor.cpp +++ b/tools/odbc/parameter_descriptor.cpp @@ -417,14 +417,6 @@ SQLLEN *ParameterDescriptor::GetSQLDescIndicatorPtr(DescRecord &apd_record, idx_ return apd_record.sql_desc_indicator_ptr + set_idx; } -void ParameterDescriptor::SetSQLDescIndicatorPtr(DescRecord &apd_record, SQLLEN *ind_ptr) { - auto sql_ind_ptr = apd_record.sql_desc_indicator_ptr; - if (cur_apd->header.sql_desc_bind_offset_ptr) { - sql_ind_ptr += *cur_apd->header.sql_desc_bind_offset_ptr; - } - sql_ind_ptr = ind_ptr; -} - void ParameterDescriptor::SetSQLDescIndicatorPtr(DescRecord &apd_record, SQLLEN value) { auto sql_ind_ptr = apd_record.sql_desc_indicator_ptr; if (cur_apd->header.sql_desc_bind_offset_ptr) { @@ -439,11 +431,3 @@ SQLLEN *ParameterDescriptor::GetSQLDescOctetLengthPtr(DescRecord &apd_record, id } return apd_record.sql_desc_octet_length_ptr + set_idx; } - -void ParameterDescriptor::SetSQLDescOctetLengthPtr(DescRecord &apd_record, SQLLEN *len_ptr) { - auto sql_len_ptr = apd_record.sql_desc_octet_length_ptr; - if (cur_apd->header.sql_desc_bind_offset_ptr) { - sql_len_ptr += *cur_apd->header.sql_desc_bind_offset_ptr; - } - sql_len_ptr = len_ptr; -} diff --git a/tools/odbc/prepared.cpp b/tools/odbc/prepared.cpp index 64ad80a13333..3bbdd263c6d3 100644 --- a/tools/odbc/prepared.cpp +++ b/tools/odbc/prepared.cpp @@ -6,8 +6,6 @@ #include "duckdb/main/prepared_statement_data.hpp" -#include - using duckdb::hugeint_t; using duckdb::idx_t; using duckdb::Load; @@ -220,6 +218,13 @@ SQLRETURN SQL_API SQLDescribeCol(SQLHSTMT statement_handle, SQLUSMALLINT column_ return SQL_SUCCESS; } +/** + * @brief Used together with SQLPutData and SQLGetData to support binding to a column or parameter data. + * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlparamdata-function?view=sql-server-ver15 + * @param statement_handle + * @param value_ptr_ptr Pointer to a buffer in which to return the address of the bound data buffer. 
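+ * Returns SQL_NEED_DATA while a data-at-execution parameter still expects data; the application supplies it with SQLPutData and then calls SQLParamData again.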
+ * @return + */ SQLRETURN SQL_API SQLParamData(SQLHSTMT statement_handle, SQLPOINTER *value_ptr_ptr) { duckdb::OdbcHandleStmt *hstmt = nullptr; if (ConvertHSTMTPrepared(statement_handle, hstmt) != SQL_SUCCESS) { @@ -234,6 +239,15 @@ SQLRETURN SQL_API SQLParamData(SQLHSTMT statement_handle, SQLPOINTER *value_ptr_ return SQL_SUCCESS; } +/** + * @brief Allows the application to set data for a parameter or column at execution time. + * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlputdata-function?view=sql-server-ver15 + * @param statement_handle + * @param data_ptr Pointer to a buffer containing the data to be sent to the data source. Must be in the C data type + * specified in SQLBindParameter or SQLBindCol. + * @param str_len_or_ind_ptr Length of data_ptr. + * @return + */ SQLRETURN SQL_API SQLPutData(SQLHSTMT statement_handle, SQLPOINTER data_ptr, SQLLEN str_len_or_ind_ptr) { duckdb::OdbcHandleStmt *hstmt = nullptr; if (ConvertHSTMTPrepared(statement_handle, hstmt) != SQL_SUCCESS) { diff --git a/tools/odbc/statement.cpp b/tools/odbc/statement.cpp index 20f8dae0e607..8ea153c3b78e 100644 --- a/tools/odbc/statement.cpp +++ b/tools/odbc/statement.cpp @@ -331,7 +331,7 @@ SQLRETURN SQL_API SQLCancel(SQLHSTMT statement_handle) { /** *@brief Executes a prepared statement - * + * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlexecdirect-function?view=sql-server-ver15 * @param statement_handle A handle to a statement object. Stores information about the statement and the results of a *query. * @param statement_text The text of the query to execute. @@ -679,6 +679,12 @@ SQLRETURN SQL_API SQLFreeStmt(SQLHSTMT statement_handle, SQLUSMALLINT option) { return SQL_ERROR; } +/** + * @brief Determines whether more results are available on a statement, and, if so, prepares the next result set. 
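+ * Returns SQL_SUCCESS while another result set is available and SQL_NO_DATA once all results have been consumed.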
+ * https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlmoreresults-function?view=sql-server-ver15 + * @param statement_handle + * @return + */ SQLRETURN SQL_API SQLMoreResults(SQLHSTMT statement_handle) { duckdb::OdbcHandleStmt *hstmt = nullptr; if (ConvertHSTMT(statement_handle, hstmt) != SQL_SUCCESS) { diff --git a/tools/odbc/test/CMakeLists.txt b/tools/odbc/test/CMakeLists.txt index e418c1058187..096bd2b998f2 100644 --- a/tools/odbc/test/CMakeLists.txt +++ b/tools/odbc/test/CMakeLists.txt @@ -15,6 +15,14 @@ add_executable( tests/diagnostics.cpp tests/select.cpp tests/row_wise_fetching.cpp - tests/set_attr.cpp) + tests/set_attr.cpp + tests/cte.cpp + tests/cursor_commit.cpp + tests/declare_fetch_block.cpp + tests/data_execution.cpp + tests/multicolumn_param_bind.cpp + tests/numeric.cpp + tests/quotes.cpp + tests/result_conversion.cpp) target_link_libraries(test_odbc duckdb_odbc ODBC::ODBC) diff --git a/tools/odbc/test/common.cpp b/tools/odbc/test/common.cpp index 603cdc62fa1e..c81b58aca9d7 100644 --- a/tools/odbc/test/common.cpp +++ b/tools/odbc/test/common.cpp @@ -55,7 +55,7 @@ void ACCESS_DIAGNOSTIC(std::string &state, std::string &message, SQLHANDLE handl } } -void DATA_CHECK(HSTMT hstmt, SQLSMALLINT col_num, const char *expected_content) { +void DATA_CHECK(HSTMT &hstmt, SQLSMALLINT col_num, const std::string expected_content) { SQLCHAR content[256]; SQLLEN content_len; @@ -63,13 +63,13 @@ void DATA_CHECK(HSTMT hstmt, SQLSMALLINT col_num, const char *expected_content) SQLRETURN ret = SQLGetData(hstmt, col_num, SQL_C_CHAR, content, sizeof(content), &content_len); ODBC_CHECK(ret, "SQLGetData"); if (content_len == SQL_NULL_DATA) { - REQUIRE(expected_content == nullptr); + REQUIRE(expected_content.empty()); return; } - REQUIRE(STR_EQUAL(ConvertToCString(content), expected_content)); + REQUIRE(ConvertToString(content) == expected_content); } -void METADATA_CHECK(HSTMT hstmt, SQLUSMALLINT col_num, const std::string &expected_col_name, +void METADATA_CHECK(HSTMT &hstmt, SQLUSMALLINT col_num, const std::string &expected_col_name, SQLSMALLINT expected_col_name_len, SQLSMALLINT expected_col_data_type, SQLULEN expected_col_size, SQLSMALLINT expected_col_decimal_digits, SQLSMALLINT expected_col_nullable) { SQLCHAR col_name[256]; @@ -168,7 +168,7 @@ void EXEC_SQL(HSTMT hstmt, const std::string &query) { EXECUTE_AND_CHECK("SQLExecDirect (" + query + ")", SQLExecDirect, hstmt, ConvertToSQLCHAR(query.c_str()), SQL_NTS); } -void InitializeDatabase(HSTMT hstmt) { +void InitializeDatabase(HSTMT &hstmt) { EXEC_SQL(hstmt, "DROP TABLE IF EXISTS test_table_1;"); EXEC_SQL(hstmt, "CREATE TABLE test_table_1 (id integer PRIMARY KEY, t varchar(20));"); EXEC_SQL(hstmt, "INSERT INTO test_table_1 VALUES (1, 'foo');"); @@ -183,13 +183,13 @@ void InitializeDatabase(HSTMT hstmt) { EXEC_SQL(hstmt, "INSERT INTO bool_table VALUES (4, 'false', false);"); EXEC_SQL(hstmt, "INSERT INTO bool_table VALUES (5, 'not', false);"); - EXEC_SQL(hstmt, "DROP TABLE IF EXISTS byte_table;"); - EXEC_SQL(hstmt, "CREATE TABLE byte_table (id integer, t blob);"); - EXEC_SQL(hstmt, "INSERT INTO byte_table VALUES (1, '\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x10'::blob);"); - EXEC_SQL(hstmt, "INSERT INTO byte_table VALUES (2, 'bar');"); - EXEC_SQL(hstmt, "INSERT INTO byte_table VALUES (3, 'foobar');"); - EXEC_SQL(hstmt, "INSERT INTO byte_table VALUES (4, 'foo');"); - EXEC_SQL(hstmt, "INSERT INTO byte_table VALUES (5, 'barf');"); + EXEC_SQL(hstmt, "DROP TABLE IF EXISTS bytea_table;"); + EXEC_SQL(hstmt, "CREATE TABLE bytea_table (id 
integer, t blob);"); + EXEC_SQL(hstmt, "INSERT INTO bytea_table VALUES (1, '\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x10'::blob);"); + EXEC_SQL(hstmt, "INSERT INTO bytea_table VALUES (2, 'bar');"); + EXEC_SQL(hstmt, "INSERT INTO bytea_table VALUES (3, 'foobar');"); + EXEC_SQL(hstmt, "INSERT INTO bytea_table VALUES (4, 'foo');"); + EXEC_SQL(hstmt, "INSERT INTO bytea_table VALUES (5, 'barf');"); EXEC_SQL(hstmt, "DROP TABLE IF EXISTS interval_table;"); EXEC_SQL(hstmt, "CREATE TABLE interval_table(id integer, iv interval, d varchar(100));"); @@ -217,6 +217,10 @@ SQLCHAR *ConvertToSQLCHAR(const char *str) { return reinterpret_cast(const_cast(str)); } +SQLCHAR *ConvertToSQLCHAR(const std::string &str) { + return reinterpret_cast(const_cast(str.c_str())); +} + std::string ConvertToString(SQLCHAR *str) { return std::string(reinterpret_cast(str)); } @@ -229,4 +233,17 @@ SQLPOINTER ConvertToSQLPOINTER(uint64_t ptr) { return reinterpret_cast(static_cast(ptr)); } +SQLPOINTER ConvertToSQLPOINTER(const char *str) { + return reinterpret_cast(const_cast(str)); +} + +std::string ConvertHexToString(SQLCHAR val[16], int precision) { + std::stringstream ss; + ss << std::hex << std::uppercase << std::setfill('0'); + for (int i = 0; i < precision; i++) { + ss << std::setw(2) << static_cast(val[i]); + } + return ss.str().substr(0, precision); +} + } // namespace odbc_test diff --git a/tools/odbc/test/common.h b/tools/odbc/test/common.h index 543d880a2d19..ed67484b7525 100644 --- a/tools/odbc/test/common.h +++ b/tools/odbc/test/common.h @@ -55,7 +55,7 @@ void ACCESS_DIAGNOSTIC(std::string &state, std::string &message, SQLHANDLE handl * @param col_num The number of the column in the result set * @param expected_content The expected content of the column */ -void DATA_CHECK(HSTMT hstmt, SQLSMALLINT col_num, const char *expected_content); +void DATA_CHECK(HSTMT &hstmt, SQLSMALLINT col_num, const std::string expected_content); /** * @brief @@ -71,7 +71,7 @@ void DATA_CHECK(HSTMT hstmt, SQLSMALLINT col_num, const char *expected_content); * @param expected_col_decimal_digits * @param expected_col_nullable */ -void METADATA_CHECK(HSTMT hstmt, SQLUSMALLINT col_num, const std::string &expected_col_name, +void METADATA_CHECK(HSTMT &hstmt, SQLUSMALLINT col_num, const std::string &expected_col_name, SQLSMALLINT expected_col_name_len, SQLSMALLINT expected_col_data_type, SQLULEN expected_col_size, SQLSMALLINT expected_col_decimal_digits, SQLSMALLINT expected_col_nullable); @@ -111,15 +111,18 @@ void DISCONNECT_FROM_DATABASE(SQLHANDLE &env, SQLHANDLE &dbc); */ void EXEC_SQL(HSTMT hstmt, const std::string &query); -void InitializeDatabase(HSTMT hstmt); +void InitializeDatabase(HSTMT &hstmt); std::map InitializeTypesMap(); // Converters SQLCHAR *ConvertToSQLCHAR(const char *str); +SQLCHAR *ConvertToSQLCHAR(const std::string &str); std::string ConvertToString(SQLCHAR *str); const char *ConvertToCString(SQLCHAR *str); SQLPOINTER ConvertToSQLPOINTER(uint64_t ptr); +SQLPOINTER ConvertToSQLPOINTER(const char *str); +std::string ConvertHexToString(SQLCHAR val[16], int precision); } // namespace odbc_test diff --git a/tools/odbc/test/run_nanodbc_tests.sh b/tools/odbc/test/run_nanodbc_tests.sh index c764e488ac11..879e59eeabe3 100755 --- a/tools/odbc/test/run_nanodbc_tests.sh +++ b/tools/odbc/test/run_nanodbc_tests.sh @@ -8,8 +8,21 @@ set -e BASE_DIR=$(dirname $0) #Configuring ODBC files -$BASE_DIR/../linux_setup/unixodbc_setup.sh -u -D $(pwd)/build/debug/tools/odbc/libduckdb_odbc.so +# Check the OS and set the "extension" variable +case 
"$(uname -s)" in + Darwin) + extension="dylib" + ;; + Linux) + extension="so" + ;; + *) + echo "Unsupported OS. Exiting." + exit 1 + ;; +esac +$BASE_DIR/../linux_setup/unixodbc_setup.sh -u -D $(pwd)/build/debug/tools/odbc/libduckdb_odbc.${extension} export NANODBC_TEST_CONNSTR_ODBC="DRIVER=DuckDB Driver;" export ASAN_OPTIONS=verify_asan_link_order=0 diff --git a/tools/odbc/test/run_psqlodbc_tests.sh b/tools/odbc/test/run_psqlodbc_tests.sh index 4d632ff14053..00c0a9f2ff9d 100755 --- a/tools/odbc/test/run_psqlodbc_tests.sh +++ b/tools/odbc/test/run_psqlodbc_tests.sh @@ -35,6 +35,6 @@ do exit 1 fi # clean odbc trace file - rm $TRACE_FILE + rm -f $TRACE_FILE done < ${BASE_DUCKDB_DIR}/tools/odbc/test/psql_supported_tests diff --git a/tools/odbc/test/tests/bools_as_char.cpp b/tools/odbc/test/tests/bools_as_char.cpp index 48410f7ad74b..873b281a5a27 100644 --- a/tools/odbc/test/tests/bools_as_char.cpp +++ b/tools/odbc/test/tests/bools_as_char.cpp @@ -64,7 +64,7 @@ TEST_CASE("Test bools to char conversion", "[odbc]") { // Fetch result METADATA_CHECK(hstmt, 3, "b", sizeof('b'), SQL_CHAR, types_map[SQL_CHAR], 0, SQL_NULLABLE_UNKNOWN); - std::vector expected_data[3] = {{"1", "yeah", "true"}, {"2", "yes", "true"}, {"3", "true", "true"}}; + std::vector expected_data[3] = {{"1", "yeah", "true"}, {"2", "yes", "true"}, {"3", "true", "true"}}; for (int i = 0; i < 3; i++) { EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); diff --git a/tools/odbc/test/tests/catalog_functions.cpp b/tools/odbc/test/tests/catalog_functions.cpp index 65ed48d65082..fb912e18161d 100644 --- a/tools/odbc/test/tests/catalog_functions.cpp +++ b/tools/odbc/test/tests/catalog_functions.cpp @@ -48,11 +48,7 @@ void TestGetTypeInfo(HSTMT &hstmt, std::map &types_map) { auto &entry = expected_data[i].first; METADATA_CHECK(hstmt, i + 1, entry.col_name.c_str(), entry.col_name.length(), entry.col_type, types_map[entry.col_type], 0, SQL_NULLABLE_UNKNOWN); - if (expected_data[i].second.empty()) { - DATA_CHECK(hstmt, i + 1, nullptr); - continue; - } - DATA_CHECK(hstmt, i + 1, expected_data[i].second.c_str()); + DATA_CHECK(hstmt, i + 1, expected_data[i].second); } // Test SQLGetTypeInfo with SQL_ALL_TYPES and data_type @@ -163,7 +159,7 @@ static void TestSQLTables(HSTMT &hstmt, std::map &types_ma DATA_CHECK(hstmt, 3, "bool_table"); break; case 2: - DATA_CHECK(hstmt, 3, "byte_table"); + DATA_CHECK(hstmt, 3, "bytea_table"); break; case 3: DATA_CHECK(hstmt, 3, "interval_table"); @@ -220,8 +216,8 @@ static void TestSQLColumns(HSTMT &hstmt, std::map &types_m std::vector> expected_data = { {"bool_table", "id", "13", "INTEGER"}, {"bool_table", "t", "25", "VARCHAR"}, - {"bool_table", "b", "10", "BOOLEAN"}, {"byte_table", "id", "13", "INTEGER"}, - {"byte_table", "t", "26", "BLOB"}, {"interval_table", "id", "13", "INTEGER"}, + {"bool_table", "b", "10", "BOOLEAN"}, {"bytea_table", "id", "13", "INTEGER"}, + {"bytea_table", "t", "26", "BLOB"}, {"interval_table", "id", "13", "INTEGER"}, {"interval_table", "iv", "27", "INTERVAL"}, {"interval_table", "d", "25", "VARCHAR"}, {"lo_test_table", "id", "13", "INTEGER"}, {"lo_test_table", "large_data", "26", "BLOB"}, {"test_table_1", "id", "13", "INTEGER"}, {"test_table_1", "t", "25", "VARCHAR"}, @@ -240,12 +236,12 @@ static void TestSQLColumns(HSTMT &hstmt, std::map &types_m } auto &entry = expected_data[i]; - DATA_CHECK(hstmt, 1, nullptr); + DATA_CHECK(hstmt, 1, ""); DATA_CHECK(hstmt, 2, "main"); - DATA_CHECK(hstmt, 3, entry[0].c_str()); - DATA_CHECK(hstmt, 4, entry[1].c_str()); - DATA_CHECK(hstmt, 5, entry[2].c_str()); 
- DATA_CHECK(hstmt, 6, entry[3].c_str()); + DATA_CHECK(hstmt, 3, entry[0]); + DATA_CHECK(hstmt, 4, entry[1]); + DATA_CHECK(hstmt, 5, entry[2]); + DATA_CHECK(hstmt, 6, entry[3]); } } diff --git a/tools/odbc/test/tests/cte.cpp b/tools/odbc/test/tests/cte.cpp new file mode 100644 index 000000000000..1ac08723e55a --- /dev/null +++ b/tools/odbc/test/tests/cte.cpp @@ -0,0 +1,76 @@ +#include "../common.h" + +using namespace odbc_test; + +static void RunDataCheckOnTable(HSTMT &hstmt, int num_rows) { + for (int i = 1; i <= num_rows; i++) { + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + DATA_CHECK(hstmt, 1, std::to_string(i)); + DATA_CHECK(hstmt, 2, std::string("foo") + std::to_string(i)); + } +} + +// Test Simple With Query +static void SimpleWithTest(HSTMT &hstmt) { + EXECUTE_AND_CHECK("SQLExecDirect(WITH)", SQLExecDirect, hstmt, + ConvertToSQLCHAR("with recursive cte as (select g, 'foo' || g as foocol from " "generate_series(1,10) as g(g)) select * from cte;"), + SQL_NTS); + + RunDataCheckOnTable(hstmt, 10); + + EXECUTE_AND_CHECK("SQLFreeStmt(CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +// Test With Query with Prepare and Execute +static void PreparedWithTest(HSTMT &hstmt) { + EXECUTE_AND_CHECK("SQLPrepare(WITH)", SQLPrepare, hstmt, + ConvertToSQLCHAR("with cte as (select g, 'foo' || g as foocol from generate_series(1,10) as " "g(g)) select * from cte WHERE g < ?"), + SQL_NTS); + + SQLINTEGER param = 3; + SQLLEN param_len = sizeof(param); + EXECUTE_AND_CHECK("SQLBindParameter", SQLBindParameter, hstmt, 1, SQL_PARAM_INPUT, SQL_INTEGER, SQL_INTEGER, 0, 0, + &param, sizeof(param), &param_len); + + EXECUTE_AND_CHECK("SQLExecute", SQLExecute, hstmt); + + RunDataCheckOnTable(hstmt, 2); + + EXECUTE_AND_CHECK("SQLFreeStmt(CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +static void TestCTEandSetFetchEnv(const char *extra_params) { + SQLHANDLE env; + SQLHANDLE dbc; + + HSTMT hstmt = SQL_NULL_HSTMT; + + // Connect to the database using SQLDriverConnect + DRIVER_CONNECT_TO_DATABASE(env, dbc, extra_params); + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + SimpleWithTest(hstmt); + PreparedWithTest(hstmt); + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); +} + +/** + * Runs two WITH queries both with declare fetch on and off, which should not affect the result + * because the queries are not using cursors.
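+ * (UseDeclareFetch controls whether the driver fetches results through a declared cursor; the data checks assert + * that the same rows come back in both modes.)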
+ */ +TEST_CASE("Test CTE", "[odbc]") { + // First test with UseDeclareFetch=0 + TestCTEandSetFetchEnv("UseDeclareFetch=0"); + + // Then test with UseDeclareFetch=1 + TestCTEandSetFetchEnv("UseDeclareFetch=1;Fetch=1"); +} diff --git a/tools/odbc/test/tests/cursor_commit.cpp b/tools/odbc/test/tests/cursor_commit.cpp new file mode 100644 index 000000000000..ea0c86b4818b --- /dev/null +++ b/tools/odbc/test/tests/cursor_commit.cpp @@ -0,0 +1,117 @@ +#include "../common.h" + +using namespace odbc_test; + +/** + * Execute a query that generates a result set + */ +static void SimpleCursorCommitTest(SQLHANDLE dbc) { + HSTMT hstmt = SQL_NULL_HSTMT; + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + EXECUTE_AND_CHECK("SQLSetStmtAttr", SQLSetStmtAttr, hstmt, SQL_ATTR_CURSOR_TYPE, + ConvertToSQLPOINTER(SQL_CURSOR_STATIC), SQL_IS_UINTEGER); + + EXECUTE_AND_CHECK("SQLExecDirect", SQLExecDirect, hstmt, + ConvertToSQLCHAR("SELECT g FROM generate_series(1,3) g(g)"), SQL_NTS); + + char buf[1024]; + SQLLEN buf_len; + EXECUTE_AND_CHECK("SQLBindCol", SQLBindCol, hstmt, 1, SQL_C_CHAR, &buf, sizeof(buf), &buf_len); + + // Commit. This implicitly closes the cursor in the server. + EXECUTE_AND_CHECK("SQLEndTran", SQLEndTran, SQL_HANDLE_DBC, dbc, SQL_COMMIT); + + for (char i = 1; i < 4; i++) { + EXECUTE_AND_CHECK("SQLFetchScroll", SQLFetchScroll, hstmt, SQL_FETCH_NEXT, 0); + REQUIRE(buf_len == 1); + REQUIRE(STR_EQUAL(buf, std::to_string(i).c_str())); + } + + SQLRETURN ret = SQLFetchScroll(hstmt, SQL_FETCH_NEXT, 0); + REQUIRE(ret == SQL_NO_DATA); + + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); +} + +static void PreparedCursorCommitTest(SQLHANDLE dbc) { + HSTMT hstmt = SQL_NULL_HSTMT; + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + // Try to commit without an open query + EXECUTE_AND_CHECK("SQLEndTran", SQLEndTran, SQL_HANDLE_DBC, dbc, SQL_COMMIT); + + // Prepare a statement + EXECUTE_AND_CHECK("SQLPrepare", SQLPrepare, hstmt, ConvertToSQLCHAR("SELECT ?::BOOL"), SQL_NTS); + + // Commit with a prepared statement + EXECUTE_AND_CHECK("SQLEndTran", SQLEndTran, SQL_HANDLE_DBC, dbc, SQL_COMMIT); + + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); +} + +static void MultipleHSTMTTest(SQLHANDLE dbc) { + HSTMT hstmt1 = SQL_NULL_HSTMT; + HSTMT hstmt2 = SQL_NULL_HSTMT; + + // Allocate two statement handles + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt1); + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt2); + + // Execute queries on both statement handles + EXECUTE_AND_CHECK("SQLExecDirect", SQLExecDirect, hstmt1, + ConvertToSQLCHAR("SELECT g FROM generate_series(1,3) g(g)"), SQL_NTS); + EXECUTE_AND_CHECK("SQLExecDirect", SQLExecDirect, hstmt2, + ConvertToSQLCHAR("SELECT g FROM generate_series(1,3) g(g)"), SQL_NTS); + + // Free first statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt1, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt1); + + // Test committing after the first handle has been released + EXECUTE_AND_CHECK("SQLEndTran", SQLEndTran, SQL_HANDLE_DBC, dbc, SQL_COMMIT); + + SQLINTEGER buf; + SQLLEN buf_len; + + EXECUTE_AND_CHECK("SQLBindCol", SQLBindCol, hstmt2, 1, SQL_C_SLONG, &buf, sizeof(buf), &buf_len); + + // Fetch from the second statement handle + for (int i = 1; i < 4; i++) { + EXECUTE_AND_CHECK("SQLFetchScroll", SQLFetchScroll, hstmt2, SQL_FETCH_NEXT, 0); + REQUIRE(buf_len == sizeof(buf)); + REQUIRE(buf == i); + } + + SQLRETURN ret = SQLFetchScroll(hstmt2, SQL_FETCH_NEXT, 0); + REQUIRE(ret == SQL_NO_DATA); + + // Free second statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt2, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt2); +} + +// These tests are related to cursor commit behavior. +// The cursor represents the result set of a query, and is closed when the transaction is committed. +TEST_CASE("Test setting cursor attributes, and closing the cursor", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + + // Connect to the database using SQLConnect + CONNECT_TO_DATABASE(env, dbc); + + EXECUTE_AND_CHECK("SQLSetConnectAttr", SQLSetConnectAttr, dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, + SQL_IS_INTEGER); + + SimpleCursorCommitTest(dbc); + PreparedCursorCommitTest(dbc); + MultipleHSTMTTest(dbc); + + DISCONNECT_FROM_DATABASE(env, dbc); +} diff --git a/tools/odbc/test/tests/data_execution.cpp b/tools/odbc/test/tests/data_execution.cpp new file mode 100644 index 000000000000..1d24e3d95200 --- /dev/null +++ b/tools/odbc/test/tests/data_execution.cpp @@ -0,0 +1,131 @@ +#include "../common.h" + +using namespace odbc_test; +
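+// These tests exercise ODBC's data-at-execution protocol: parameters bound with SQL_DATA_AT_EXEC are supplied +// after SQLExecute returns SQL_NEED_DATA, by calling SQLParamData to identify which parameter the driver wants +// and SQLPutData (possibly in several chunks) to send its value.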
+static void DataAtExecution(HSTMT &hstmt) { + // Prepare a statement + EXECUTE_AND_CHECK("SQLPrepare", SQLPrepare, hstmt, + ConvertToSQLCHAR("SELECT id FROM bytea_table WHERE t = ? OR t = ?"), SQL_NTS); + + SQLCHAR *param_1 = ConvertToSQLCHAR("bar"); + SQLLEN param_1_bytes = strlen(ConvertToCString(param_1)); + SQLLEN param_1_len = SQL_DATA_AT_EXEC; + EXECUTE_AND_CHECK("SQLBindParameter", SQLBindParameter, hstmt, 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_VARCHAR, + param_1_bytes, 0, ConvertToSQLPOINTER(1), 0, &param_1_len); + + SQLLEN param_2_len = SQL_DATA_AT_EXEC; + EXECUTE_AND_CHECK("SQLBindParameter", SQLBindParameter, hstmt, 2, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_VARCHAR, 6, 0, + ConvertToSQLPOINTER(2), 0, &param_2_len); + + // Execute the statement + SQLRETURN ret = SQLExecute(hstmt); + REQUIRE(ret == SQL_NEED_DATA); + + // Set the parameter data + SQLPOINTER param_id = nullptr; + while ((ret = SQLParamData(hstmt, &param_id)) == SQL_NEED_DATA) { + if (param_id == ConvertToSQLPOINTER(1)) { + EXECUTE_AND_CHECK("SQLPutData", SQLPutData, hstmt, param_1, param_1_bytes); + } else if (param_id == ConvertToSQLPOINTER(2)) { + EXECUTE_AND_CHECK("SQLPutData", SQLPutData, hstmt, ConvertToSQLPOINTER("foo"), 3); + EXECUTE_AND_CHECK("SQLPutData", SQLPutData, hstmt, ConvertToSQLPOINTER("bar"), 3); + } else { + FAIL("Unexpected parameter id"); + } + } + ODBC_CHECK(ret, "SQLParamData"); + + // Fetch the results + for (int i = 2; i < 4; i++) { + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + DATA_CHECK(hstmt, 1, std::to_string(i)); + } + + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +static void ArrayBindingDataAtExecution(HSTMT &hstmt) { + SQLLEN str_ind[2] = {SQL_DATA_AT_EXEC, SQL_DATA_AT_EXEC}; + SQLUSMALLINT status[2]; + SQLULEN num_processed; + + // Prepare a statement + EXECUTE_AND_CHECK("SQLPrepare", SQLPrepare, hstmt, ConvertToSQLCHAR("SELECT id FROM bytea_table WHERE t = ?"), + SQL_NTS); +
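+ // With column-wise binding and SQL_ATTR_PARAMSET_SIZE set to 2, a single SQLExecute processes both parameter + // sets; the driver reports progress through num_processed and the status array configured below.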
+ // Set STMT attributes PARAM_BIND_TYPE, PARAM_STATUS_PTR, PARAMS_PROCESSED_PTR, and PARAMSET_SIZE + EXECUTE_AND_CHECK("SQLSetStmtAttr (SQL_ATTR_PARAM_BIND_TYPE)", SQLSetStmtAttr, hstmt, SQL_ATTR_PARAM_BIND_TYPE, + reinterpret_cast<SQLPOINTER>(SQL_PARAM_BIND_BY_COLUMN), 0); + EXECUTE_AND_CHECK("SQLSetStmtAttr(SQL_ATTR_PARAM_STATUS_PTR)", SQLSetStmtAttr, hstmt, SQL_ATTR_PARAM_STATUS_PTR, + status, 0); + EXECUTE_AND_CHECK("SQLSetStmtAttr(SQL_ATTR_PARAMS_PROCESSED_PTR)", SQLSetStmtAttr, hstmt, + SQL_ATTR_PARAMS_PROCESSED_PTR, &num_processed, 0); + EXECUTE_AND_CHECK("SQLSetStmtAttr(SQL_ATTR_PARAMSET_SIZE)", SQLSetStmtAttr, hstmt, SQL_ATTR_PARAMSET_SIZE, + ConvertToSQLPOINTER(2), 0); + + // Bind the array + EXECUTE_AND_CHECK("SQLBindParameter", SQLBindParameter, hstmt, 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_VARBINARY, 5, + 0, ConvertToSQLPOINTER(1), 0, str_ind); + + // Execute the statement + SQLRETURN ret = SQLExecute(hstmt); + REQUIRE(ret == SQL_NEED_DATA); + + // Set the parameter data + SQLPOINTER param_id = nullptr; + while ((ret = SQLParamData(hstmt, &param_id)) == SQL_NEED_DATA) { + if (num_processed == 1) { + EXECUTE_AND_CHECK("SQLPutData", SQLPutData, hstmt, ConvertToSQLPOINTER("foo"), 3); + } else if (num_processed == 2) { + EXECUTE_AND_CHECK("SQLPutData", SQLPutData, hstmt, ConvertToSQLPOINTER("barf"), 4); + } else { + FAIL("Unexpected parameter id"); + } + } + ODBC_CHECK(ret, "SQLParamData"); + + for (SQLULEN i = 0; i < num_processed; i++) { + REQUIRE(status[i] == SQL_PARAM_SUCCESS); + } + + // Fetch the results + for (int i = 4;; i++) { + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + DATA_CHECK(hstmt, 1, std::to_string(i)); + + ret = SQLMoreResults(hstmt); + if (ret == SQL_NO_DATA) { + break; + } else if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + ODBC_CHECK(ret, "SQLMoreResults"); + } + } + REQUIRE(SQLFetch(hstmt) == SQL_NO_DATA); +} + +TEST_CASE("Test SQLBindParameter, SQLParamData, and SQLPutData", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + + HSTMT hstmt = SQL_NULL_HSTMT; + + // Connect to the database using SQLConnect + CONNECT_TO_DATABASE(env, dbc); + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + InitializeDatabase(hstmt); + + // Tests data-at-execution for a single parameter + DataAtExecution(hstmt); + + // Tests data-at-execution for an array of parameters + ArrayBindingDataAtExecution(hstmt); + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); +} diff --git a/tools/odbc/test/tests/declare_fetch_block.cpp b/tools/odbc/test/tests/declare_fetch_block.cpp new file mode 100644 index 000000000000..de777be899bf --- /dev/null +++ b/tools/odbc/test/tests/declare_fetch_block.cpp @@ -0,0 +1,174 @@ +#include "../common.h" + +using namespace odbc_test; + +const int TABLE_SIZE[] = {120, 4096}; +const int ARRAY_SIZE[] = {84, 512}; + +enum ESize { SMALL, LARGE }; + +static void TemporaryTable(HSTMT &hstmt, ESize S) { + EXECUTE_AND_CHECK("SQLExecDirect (CREATE TABLE)", SQLExecDirect, hstmt, + ConvertToSQLCHAR("CREATE TEMPORARY TABLE test (id int4 primary key)"), SQL_NTS); + + // Insert TABLE_SIZE[S] rows + for (int i = 0; i < TABLE_SIZE[S]; i++) { + std::string query = "INSERT INTO test VALUES (" + std::to_string(i) + ")"; + EXECUTE_AND_CHECK("SQLExecDirect (INSERT)", SQLExecDirect, hstmt, ConvertToSQLCHAR(query), SQL_NTS); + } + + EXECUTE_AND_CHECK("SQLFreeStmt(CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); +} +
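+// BlockCursor fetches the whole table through a block cursor: SQL_ATTR_ROW_ARRAY_SIZE requests ARRAY_SIZE[S] rows +// per SQLFetch, and the driver reports the actual count through SQL_ATTR_ROWS_FETCHED_PTR.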
+static void BlockCursor(HSTMT &hstmt, ESize S, SQLINTEGER *&id, SQLLEN *&id_ind) { + SQLULEN rows_fetched; + + // Set the row array size to ARRAY_SIZE[S] + EXECUTE_AND_CHECK("SQLSetStmtAttr (ROW_ARRAY_SIZE)", SQLSetStmtAttr, hstmt, SQL_ATTR_ROW_ARRAY_SIZE, + ConvertToSQLPOINTER(ARRAY_SIZE[S]), 0); + // Set ROWS_FETCHED_PTR to rows_fetched + EXECUTE_AND_CHECK("SQLSetStmtAttr (ROWS_FETCHED_PTR)", SQLSetStmtAttr, hstmt, SQL_ATTR_ROWS_FETCHED_PTR, + &rows_fetched, 0); + + // Bind Column + EXECUTE_AND_CHECK("SQLBindCol (id)", SQLBindCol, hstmt, 1, SQL_C_SLONG, id, 0, id_ind); + + // Execute the query + EXECUTE_AND_CHECK("SQLExecDirect (SELECT)", SQLExecDirect, hstmt, ConvertToSQLCHAR("SELECT * FROM test"), SQL_NTS); + + int expected_rows_fetched = 0; + if (S == SMALL) { + expected_rows_fetched = 2; + } else { + expected_rows_fetched = 8; + } + + int total_rows_fetched = 0; + + // Fetch results + for (int i = 0; i < expected_rows_fetched; i++) { + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + REQUIRE(rows_fetched <= ARRAY_SIZE[S]); + total_rows_fetched += rows_fetched; + REQUIRE(total_rows_fetched == i * ARRAY_SIZE[S] + rows_fetched); + } + REQUIRE(total_rows_fetched == TABLE_SIZE[S]); + + SQLRETURN ret = SQLFetch(hstmt); + REQUIRE(ret == SQL_NO_DATA); + + EXECUTE_AND_CHECK("SQLFreeStmt(CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +static void FetchRows(HSTMT &hstmt, SQLULEN &rows_fetched, SQLSMALLINT scroll_orientation, ESize S) { + int total_rows_fetched = 0; + for (int j = 0; j < TABLE_SIZE[S]; j++) { + EXECUTE_AND_CHECK("SQLFetchScroll", SQLFetchScroll, hstmt, scroll_orientation, 0); + REQUIRE(rows_fetched == 1); + total_rows_fetched++; + } + REQUIRE(total_rows_fetched == TABLE_SIZE[S]); + REQUIRE(SQLFetchScroll(hstmt, scroll_orientation, 0) == SQL_NO_DATA); +} + +static void ScrollNext(HSTMT &hstmt, ESize S) { + SQLULEN rows_fetched; + + // Set the row array size to 1 + EXECUTE_AND_CHECK("SQLSetStmtAttr(ROW_ARRAY_SIZE)", SQLSetStmtAttr, hstmt, SQL_ATTR_ROW_ARRAY_SIZE, + ConvertToSQLPOINTER(1), SQL_IS_INTEGER); + // Set rows fetched ptr + EXECUTE_AND_CHECK("SQLSetStmtAttr (ROWS_FETCHED_PTR)", SQLSetStmtAttr, hstmt, SQL_ATTR_ROWS_FETCHED_PTR, + &rows_fetched, 0); + // Set the cursor type to Static, which means the data in the result set is fixed + EXECUTE_AND_CHECK("SQLSetStmtAttr(CURSOR_TYPE)", SQLSetStmtAttr, hstmt, SQL_ATTR_CURSOR_TYPE, + ConvertToSQLPOINTER(SQL_CURSOR_STATIC), 0); + // and the concurrency to Rowver: the cursor uses optimistic concurrency control, comparing row versions such as + // SQLBase ROWID or Sybase TIMESTAMP.
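+ // The driver apparently does not support ROWVER concurrency: the call below is expected to return + // SQL_SUCCESS_WITH_INFO, the standard signal that the driver substituted a similar attribute value.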
+ SQLRETURN ret = SQLSetStmtAttr(hstmt, SQL_ATTR_CONCURRENCY, ConvertToSQLPOINTER(SQL_CONCUR_ROWVER), 0); + REQUIRE(ret == SQL_SUCCESS_WITH_INFO); + + // Execute the query + EXECUTE_AND_CHECK("SQLExecDirect (SELECT)", SQLExecDirect, hstmt, ConvertToSQLCHAR("SELECT * FROM test"), SQL_NTS); + + // Fetch results using SQLFetchScroll + for (int i = 0; i < 2; i++) { + // First check if fetch next works + FetchRows(hstmt, rows_fetched, SQL_FETCH_NEXT, S); + + // Then check if fetch prior works + FetchRows(hstmt, rows_fetched, SQL_FETCH_PRIOR, S); + } + + // Close the cursor + EXECUTE_AND_CHECK("SQLFreeStmt(CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); + // Unbind the columns + EXECUTE_AND_CHECK("SQLFreeStmt(UNBIND)", SQLFreeStmt, hstmt, SQL_UNBIND); +} + +static void FetchAbsolute(HSTMT &hstmt, ESize S) { + SQLULEN rows_fetched; + // Set rows fetched ptr + EXECUTE_AND_CHECK("SQLSetStmtAttr (ROWS_FETCHED_PTR)", SQLSetStmtAttr, hstmt, SQL_ATTR_ROWS_FETCHED_PTR, + &rows_fetched, 0); + + EXECUTE_AND_CHECK("SQLExecDirect", SQLExecDirect, hstmt, ConvertToSQLCHAR("SELECT * FROM test"), SQL_NTS); + + // Fetch beyond the last row, should return SQL_NO_DATA + SQLRETURN ret = SQLFetchScroll(hstmt, SQL_FETCH_ABSOLUTE, TABLE_SIZE[S] + 1); + REQUIRE(ret == SQL_NO_DATA); + + EXECUTE_AND_CHECK("SQLFetchScroll (ABSOLUTE, 1)", SQLFetchScroll, hstmt, SQL_FETCH_ABSOLUTE, 1); + + // Keep fetching until we reach the last row + int id; + for (id = 1;; id++) { + if (SQLFetchScroll(hstmt, SQL_FETCH_NEXT, 0) == SQL_NO_DATA) { + id--; + break; + } + } + REQUIRE(id == TABLE_SIZE[S]); + + // Close the cursor + EXECUTE_AND_CHECK("SQLFreeStmt(CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +TEST_CASE("Test Using SQLFetchScroll with different orientations", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + HSTMT hstmt = SQL_NULL_HSTMT; + + // Perform the tests for both SMALL and LARGE tables and different fetch sizes + ESize size[] = {SMALL, LARGE}; + for (int i = 0; i < 2; i++) { + // Connect to the database using SQLDriverConnect with UseDeclareFetch=1 + DRIVER_CONNECT_TO_DATABASE(env, dbc, "UseDeclareFetch=1"); + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + // Create a temporary table and insert size[i] rows + TemporaryTable(hstmt, size[i]); + + SQLINTEGER *id = new SQLINTEGER[TABLE_SIZE[size[i]]]; + SQLLEN *id_ind = new SQLLEN[TABLE_SIZE[size[i]]]; + // Block cursor, fetch rows in blocks of size[i] + BlockCursor(hstmt, size[i], id, id_ind); + + // Scroll cursor, fetch rows one by one + ScrollNext(hstmt, size[i]); + + // Fetch rows using SQL_FETCH_ABSOLUTE + FetchAbsolute(hstmt, size[i]); + + delete[] id; + delete[] id_ind; + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); + } +} diff --git a/tools/odbc/test/tests/multicolumn_param_bind.cpp b/tools/odbc/test/tests/multicolumn_param_bind.cpp new file mode 100644 index 000000000000..8934dc5d723d --- /dev/null +++ b/tools/odbc/test/tests/multicolumn_param_bind.cpp @@ -0,0 +1,103 @@ +#include "../common.h" + +using namespace odbc_test; + +#define MAX_INSERT_COUNT 2 +#define MAX_BUFFER_SIZE 100 + +TEST_CASE("Test binding multiple columns at once", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + + HSTMT hstmt = SQL_NULL_HSTMT; + + // Connect to the database using SQLConnect + CONNECT_TO_DATABASE(env, dbc); + + // Allocate a statement handle
+ EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + // Create a table + EXECUTE_AND_CHECK("SQLExecDirect (CREATE TABLE)", SQLExecDirect, hstmt, + ConvertToSQLCHAR("CREATE TABLE test_tbl (Column1 VARCHAR(100), Column2 VARCHAR(100))"), SQL_NTS); + + // Free and re-allocate the statement handle to clear the statement + EXECUTE_AND_CHECK("SQLFreeStmt (SQL_CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + // Set the SQL_ATTR_PARAM_BIND_TYPE statement attribute to use column-wise binding. + EXECUTE_AND_CHECK("SQLSetStmtAttr (SQL_ATTR_PARAM_BIND_TYPE)", SQLSetStmtAttr, hstmt, SQL_ATTR_PARAM_BIND_TYPE, + reinterpret_cast<SQLPOINTER>(SQL_PARAM_BIND_BY_COLUMN), 0); + + // Specify an array in which to return the status of each set of parameters. + SQLUSMALLINT param_status[MAX_INSERT_COUNT]; + EXECUTE_AND_CHECK("SQLSetStmtAttr (SQL_ATTR_PARAM_STATUS_PTR)", SQLSetStmtAttr, hstmt, SQL_ATTR_PARAM_STATUS_PTR, + param_status, 0); + + // Specify an SQLULEN value into which to return the number of sets of parameters processed. + SQLULEN params_processed; + EXECUTE_AND_CHECK("SQLSetStmtAttr (SQL_ATTR_PARAMS_PROCESSED_PTR)", SQLSetStmtAttr, hstmt, + SQL_ATTR_PARAMS_PROCESSED_PTR, &params_processed, 0); + + // Specify the number of parameter sets to be processed before execution occurs. + EXECUTE_AND_CHECK("SQLSetStmtAttr (SQL_ATTR_PARAMSET_SIZE)", SQLSetStmtAttr, hstmt, SQL_ATTR_PARAMSET_SIZE, + ConvertToSQLPOINTER(MAX_INSERT_COUNT), 0); + + const char *c1_r1 = "John Doe", *c1_r2 = "Jane Doe"; + const char *c2_r1 = "John", *c2_r2 = "Jane"; + + SQLLEN c1_ind[MAX_INSERT_COUNT] = {static_cast<SQLLEN>(strlen(c1_r1)), static_cast<SQLLEN>(strlen(c1_r2))}; + SQLLEN c2_ind[MAX_INSERT_COUNT] = {static_cast<SQLLEN>(strlen(c2_r1)), static_cast<SQLLEN>(strlen(c2_r2))}; + + char c1[MAX_INSERT_COUNT][MAX_BUFFER_SIZE]; + memcpy(c1[0], c1_r1, c1_ind[0]); + memcpy(c1[1], c1_r2, c1_ind[1]); + + char c2[MAX_INSERT_COUNT][MAX_BUFFER_SIZE]; + memcpy(c2[0], c2_r1, c2_ind[0]); + memcpy(c2[1], c2_r2, c2_ind[1]); + + // Bind the parameters in column-wise fashion. + EXECUTE_AND_CHECK("SQLBindParameter", SQLBindParameter, hstmt, 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, + MAX_BUFFER_SIZE - 1, 0, c1, MAX_BUFFER_SIZE, c1_ind); + EXECUTE_AND_CHECK("SQLBindParameter", SQLBindParameter, hstmt, 2, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, + MAX_BUFFER_SIZE - 1, 0, c2, MAX_BUFFER_SIZE, c2_ind); + + // Execute the statement + EXECUTE_AND_CHECK("SQLExecDirect (INSERT)", SQLExecDirect, hstmt, + ConvertToSQLCHAR("INSERT INTO test_tbl (Column1, Column2) VALUES (?, ?)"), SQL_NTS); + + // Verify that the correct number of parameter sets were processed.
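+ // params_processed was filled in by the driver through SQL_ATTR_PARAMS_PROCESSED_PTR, and each entry of + // param_status reports the outcome of one parameter set.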
+ for (SQLULEN i = 0; i < params_processed; i++) { + REQUIRE(param_status[i] == SQL_PARAM_SUCCESS); + } + + // Close the cursor + EXECUTE_AND_CHECK("SQLFreeStmt (SQL_CLOSE)", SQLFreeStmt, hstmt, SQL_CLOSE); + + // Get the data back and verify it + EXECUTE_AND_CHECK("SQLExecDirect (SELECT)", SQLExecDirect, hstmt, ConvertToSQLCHAR("SELECT * FROM test_tbl"), + SQL_NTS); + + std::string col_name[2] = {"Column1", "Column2"}; + std::map<SQLSMALLINT, SQLULEN> types_map = InitializeTypesMap(); + + while (true) { + SQLRETURN ret = SQLFetch(hstmt); + if (ret == SQL_NO_DATA) { + break; + } + ODBC_CHECK(ret, "SQLFetch"); + for (SQLULEN j = 0; j < params_processed; j++) { + EXECUTE_AND_CHECK("SQLGetData", SQLGetData, hstmt, 1, SQL_C_CHAR, c1[j], MAX_BUFFER_SIZE, &c1_ind[j]); + EXECUTE_AND_CHECK("SQLGetData", SQLGetData, hstmt, 2, SQL_C_CHAR, c2[j], MAX_BUFFER_SIZE, &c2_ind[j]); + } + } + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); +} diff --git a/tools/odbc/test/tests/numeric.cpp b/tools/odbc/test/tests/numeric.cpp new file mode 100644 index 000000000000..5e9884071e5b --- /dev/null +++ b/tools/odbc/test/tests/numeric.cpp @@ -0,0 +1,120 @@ +#include "../common.h" + +using namespace odbc_test; + +static unsigned char HexToInt(char c) { + if (c >= '0' && c <= '9') { + return static_cast<unsigned char>(c - '0'); + } else if (c >= 'a' && c <= 'f') { + return static_cast<unsigned char>(c - 'a' + 10); + } else if (c >= 'A' && c <= 'F') { + return static_cast<unsigned char>(c - 'A' + 10); + } else { + FAIL("invalid hex-encoded numeric value"); + return 0; + } +} + +static void BuildNumericStruct(SQL_NUMERIC_STRUCT *numeric, unsigned char sign, const char *hexval, + unsigned char precision, unsigned char scale) { + memset(numeric, 0, sizeof(SQL_NUMERIC_STRUCT)); + numeric->sign = sign; + numeric->precision = precision; + numeric->scale = scale; + + // Convert hexval to binary + int len = 0; + while (*hexval) { + if (*hexval == ' ') { + hexval++; + continue; + } + if (len >= SQL_MAX_NUMERIC_LEN) { + FAIL("hex-encoded numeric value too long"); + } + numeric->val[len] = HexToInt(*hexval) << 4 | HexToInt(*(hexval + 1)); + hexval += 2; + len++; + } +} + +static void TestNumericParams(HSTMT &hstmt, unsigned char sign, const char *hexval, unsigned char precision, + unsigned char scale, const std::string &expected) { + SQL_NUMERIC_STRUCT numeric; + BuildNumericStruct(&numeric, sign, hexval, precision, scale); + + SQLLEN numeric_len = sizeof(numeric); + EXECUTE_AND_CHECK("SQLBindParameter (numeric)", SQLBindParameter, hstmt, 1, SQL_PARAM_INPUT, SQL_C_NUMERIC, + SQL_NUMERIC, precision, scale, &numeric, numeric_len, &numeric_len); + + EXECUTE_AND_CHECK("SQLExecute", SQLExecute, hstmt); + + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + DATA_CHECK(hstmt, 1, expected); + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +static void TestNumericResult(HSTMT &hstmt, const char *num_str, const std::string &expected_result, + unsigned char expected_precision = 18, unsigned char expected_scale = 3, + unsigned int precision = 18, unsigned int scale = 3) { + SQL_NUMERIC_STRUCT numeric; + + std::string query = "SELECT '" + std::string(num_str) + "'::numeric(" + std::to_string(precision) + "," + + std::to_string(scale) + ")"; + EXECUTE_AND_CHECK("SQLExecDirect", SQLExecDirect, hstmt, ConvertToSQLCHAR(query), SQL_NTS); + + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + 
EXECUTE_AND_CHECK("SQLGetData", SQLGetData, hstmt, 1, SQL_C_NUMERIC, &numeric, sizeof(numeric), nullptr); + REQUIRE(numeric.precision == expected_precision); + REQUIRE(numeric.scale == expected_scale); + REQUIRE(numeric.sign == 1); + REQUIRE(ConvertHexToString(numeric.val, expected_result.length()) == expected_result); +} + +TEST_CASE("Test numeric limits and conversion", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + + HSTMT hstmt = SQL_NULL_HSTMT; + + // Connect to the database using SQLConnect + CONNECT_TO_DATABASE(env, dbc); + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + // Test 25.212 with default precision and scale + EXECUTE_AND_CHECK("SQLPrepare (?::numeric)", SQLPrepare, hstmt, ConvertToSQLCHAR("SELECT ?::numeric"), SQL_NTS); + TestNumericParams(hstmt, 1, "7C62", 5, 3, "25.212"); + + // Test 0 (negative and positive) with precision 1 and scale 0 + EXECUTE_AND_CHECK("SQLPrepare (?::numeric(1,0))", SQLPrepare, hstmt, ConvertToSQLCHAR("SELECT ?::numeric(1,0)"), + SQL_NTS); + TestNumericParams(hstmt, 1, "00", 1, 0, "0"); + TestNumericParams(hstmt, 0, "00", 1, 0, "0"); + + // Test 7.70 with precision 3 and scale 2 + EXECUTE_AND_CHECK("SQLPrepare (?::numeric(3,2))", SQLPrepare, hstmt, ConvertToSQLCHAR("SELECT ?::numeric(3,2)"), + SQL_NTS); + TestNumericParams(hstmt, 1, "0203", 3, 2, "7.70"); + + // Test 12345678901234567890123456789012345678 with precision 38 and scale 0 + EXECUTE_AND_CHECK("SQLPrepare (?::numeric(38,0))", SQLPrepare, hstmt, ConvertToSQLCHAR("SELECT ?::numeric(38,0)"), + SQL_NTS); + TestNumericParams(hstmt, 1, "4EF338DE509049C4133302F0F6B04909", 38, 0, "12345678901234567890123456789012345678"); + + // Test setting numeric struct within the application + TestNumericResult(hstmt, "25.212", "7C62000000000000", 5, 3); + TestNumericResult(hstmt, "24197857161011715162171839636988778104", "7856341278563412", 38, 0, 38, 0); + TestNumericResult(hstmt, "12345678901234567890123456789012345678", "4EF338DE509049C4", 38, 0, 38, 0); + TestNumericResult(hstmt, "-0", "0000000000000000", 1, 3); + TestNumericResult(hstmt, "0", "0000000000000000", 1, 3); + TestNumericResult(hstmt, "7.70", "0203000000000000", 3, 2, 3, 2); + TestNumericResult(hstmt, "999999999999", "FF0FA5D4E8000000", 12, 3); + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); +} diff --git a/tools/odbc/test/tests/quotes.cpp b/tools/odbc/test/tests/quotes.cpp new file mode 100644 index 000000000000..10a3402981dc --- /dev/null +++ b/tools/odbc/test/tests/quotes.cpp @@ -0,0 +1,71 @@ +#include "../common.h" + +using namespace odbc_test; + +/** + * Bind a parameter and execute a query. Check that the result is as expected. 
+ */ +static void BindParamAndExecute(HSTMT &hstmt, SQLCHAR *query, SQLCHAR *param, + const std::vector<std::string> &expected_result) { + SQLLEN len = strlen((char *)param); + + EXECUTE_AND_CHECK("SQLBindParameter (param)", SQLBindParameter, hstmt, 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_CHAR, 20, + 0, param, len, &len); + + EXECUTE_AND_CHECK("SQLExecDirect", SQLExecDirect, hstmt, query, SQL_NTS); + + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + for (size_t i = 0; i < expected_result.size(); i++) { + DATA_CHECK(hstmt, i + 1, expected_result[i]); + } + + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +TEST_CASE("Test parameter quoting in combination with special characters", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + + HSTMT hstmt = SQL_NULL_HSTMT; + + // Connect to the database using SQLConnect + CONNECT_TO_DATABASE(env, dbc); + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + // Check that the driver escapes quotes correctly + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT 'foo', ?::text"), ConvertToSQLCHAR("param'quote"), + {"foo", "param'quote"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT 'foo', ?::text"), ConvertToSQLCHAR("param\\backslash"), + {"foo", "param\\backslash"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT 'foo', ?::text"), ConvertToSQLCHAR("ends with backslash\\"), + {"foo", "ends with backslash\\"}); + + // Check that the driver's built-in parser interprets quotes correctly. Check that it distinguishes between ? + // parameter markers and ? literals. + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT 'doubled '' quotes', ?::text"), ConvertToSQLCHAR("param"), + {"doubled ' quotes", "param"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT E'escaped quote\\' here', ?::text"), ConvertToSQLCHAR("param"), + {"escaped quote' here", "param"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT $$dollar quoted string$$, ?::text"), ConvertToSQLCHAR("param"), + {"dollar quoted string", "param"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT $xx$complex $dollar quotes$xx$, ?::text"), + ConvertToSQLCHAR("param"), {"complex $dollar quotes", "param"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT $dollar$more complex $dollar quotes$dollar$, ?::text"), + ConvertToSQLCHAR("param"), {"more complex $dollar quotes", "param"}); + + // Test backslash escaping without the E'' syntax + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT 'escaped quote'' here', ?::text"), ConvertToSQLCHAR("param"), + {"escaped quote' here", "param"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT ?::text, '1' a$1"), ConvertToSQLCHAR("$ in an identifier"), + {"$ in an identifier", "1"}); + BindParamAndExecute(hstmt, ConvertToSQLCHAR("SELECT '1'::text a$$S1,?::text,$$2 $'s in an identifier$$::text"), + ConvertToSQLCHAR("param"), {"1", "param", "2 $'s in an identifier"}); + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); +} diff --git a/tools/odbc/test/tests/result_conversion.cpp b/tools/odbc/test/tests/result_conversion.cpp new file mode 100644 index 000000000000..0ab0bddc035c --- /dev/null +++ b/tools/odbc/test/tests/result_conversion.cpp @@ -0,0 +1,596 @@ +#include "../common.h" + +using namespace odbc_test; + +#define EXPECTED_ERROR "error" +#define EPSILON 0.0000001
+ +static const std::vector<std::vector<std::string>> all_types = {{"boolean", "true"}, + {"bytea", "\\x464F4F"}, + {"char", "x"}, + {"int8", "1234567890"}, + {"int2", "12345"}, + {"int4", "1234567"}, + {"text", "textdata"}, + {"float4", "1.234"}, + {"double", "1.23456789012"}, + {"varchar", "foobar"}, + {"date", "2011-02-13"}, + {"time", "13:23:34"}, + {"timestamp", "2011-02-15 15:49:18"}, + {"interval", "10 years -11 months -12 days 13:14:00"}, + {"numeric", "1234.567890"}}; + +static const std::vector<SQLINTEGER> all_sql_types = {SQL_C_CHAR, + SQL_C_WCHAR, + SQL_C_SSHORT, + SQL_C_USHORT, + SQL_C_SLONG, + SQL_C_ULONG, + SQL_C_FLOAT, + SQL_C_DOUBLE, + SQL_C_BIT, + SQL_C_STINYINT, + SQL_C_UTINYINT, + SQL_C_SBIGINT, + SQL_C_UBIGINT, + SQL_C_BINARY, + SQL_C_BOOKMARK, + SQL_C_VARBOOKMARK, + SQL_C_TYPE_DATE, + SQL_C_TYPE_TIME, + SQL_C_TYPE_TIMESTAMP, + SQL_C_NUMERIC, + SQL_C_GUID, + SQL_C_INTERVAL_YEAR, + SQL_C_INTERVAL_MONTH, + SQL_C_INTERVAL_DAY, + SQL_C_INTERVAL_HOUR, + SQL_C_INTERVAL_MINUTE, + SQL_C_INTERVAL_SECOND, + SQL_C_INTERVAL_YEAR_TO_MONTH, + SQL_C_INTERVAL_DAY_TO_HOUR, + SQL_C_INTERVAL_DAY_TO_MINUTE, + SQL_C_INTERVAL_DAY_TO_SECOND, + SQL_C_INTERVAL_HOUR_TO_MINUTE, + SQL_C_INTERVAL_HOUR_TO_SECOND, + SQL_C_INTERVAL_MINUTE_TO_SECOND}; + +static const std::vector<std::vector<std::string>> results = { + { + "true", + "true", + "1", + "1", + "1", + "1", + "1.000000", + "1.000000", + EXPECTED_ERROR, + "1", + "1", + "1", + "1", + "74727565", + "1", + "74727565", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + }, + {"F4F4F", "F4F4F", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, "4634463446", + EXPECTED_ERROR, "4634463446", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"x", + "x", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + "78", + EXPECTED_ERROR, + "78", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR}, + {"1234567890", "1234567890", EXPECTED_ERROR, EXPECTED_ERROR, "1234567890", + "1234567890", "1234567936.000000", "1234567890.000000", EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, "1234567890", "1234567890", "31323334353637383930", "1234567890", + "31323334353637383930", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"12345", "12345", "12345", "12345", "12345", "12345", "12345.000000", + "12345.000000", EXPECTED_ERROR, 
EXPECTED_ERROR, EXPECTED_ERROR, "12345", "12345", "3132333435", + "12345", "3132333435", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"1234567", "1234567", EXPECTED_ERROR, EXPECTED_ERROR, "1234567", "1234567", + "1234567.000000", "1234567.000000", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, "1234567", + "1234567", "31323334353637", "1234567", "31323334353637", EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"textdata", "textdata", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, "7465787464617461", EXPECTED_ERROR, "7465787464617461", EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + { + "1.234", + "1.234", + "1", + "1", + "1", + "1", + "1.234000", + "1.2339999676", + EXPECTED_ERROR, + "1", + "1", + "1", + "1", + "312E323334", + "1", + "312E323334", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + }, + {"1.23456789012", + "1.23456789012", + "1", + "1", + "1", + "1", + "1.234568", + "1.234568", + EXPECTED_ERROR, + "1", + "1", + "1", + "1", + "312E3233343536373839303132", + "1", + "312E3233343536373839303132", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR}, + {"foobar", "foobar", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, "666F6F626172", + EXPECTED_ERROR, "666F6F626172", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"2011-02-13", "2011-02-13", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, "323031312D30322D3133", EXPECTED_ERROR, + "323031312D30322D3133", "2011-2-13", EXPECTED_ERROR, "2011-2-13-0-0-0-0", EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, 
EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"13:23:34", "13:23:34", EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, "31333A32333A3334", EXPECTED_ERROR, "31333A32333A3334", EXPECTED_ERROR, "13-23-34", + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR, EXPECTED_ERROR}, + {"2011-02-15 15:49:18", + "2011-02-15 15:49:18", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + "323031312D30322D31352031353A34393A3138", + EXPECTED_ERROR, + "323031312D30322D31352031353A34393A3138", + "2011-2-15", + "15-49-18", + "2011-2-15-15-49-18-0", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR}, + {"9 years 1 month -12 days 13:14:00", + "9 years 1 month -12 days 13:14:00", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + "392079656172732031206D6F6E7468202D313220646179732031333A31343A3030", + EXPECTED_ERROR, + "392079656172732031206D6F6E7468202D313220646179732031333A31343A3030", + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + EXPECTED_ERROR, + "0-9", + "0-109", + "0-3258", + "0-78205", + "0-4692314", + "0-281538840", + "0-9-1", + "0-3258-13", + "0-3258-13-14", + "0-3258-13-14-0-0", + "0-78205-14", + "0-78205-14-0-0", + "0-4692314-0-0"}, + {"1234.568", "1234.568", + "1235", "1235", + "1235", "1235", + "1234.567993", "1234.568000", + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, "1235", + "1235", "313233342E353638", + "1235", "313233342E353638", + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, "7:3:1:88D61200000000000000000000000000", + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, + EXPECTED_ERROR, EXPECTED_ERROR, +};
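+// Each row of the results table above corresponds to one entry in all_types and lists the expected value for every +// C type in all_sql_types: binary targets expect hex-encoded strings, and EXPECTED_ERROR marks conversions the +// driver rejects.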
+ +static std::vector<std::string> Split(std::string text, char delim) { + std::string line; + std::vector<std::string> vec; + std::stringstream ss(text); + while (std::getline(ss, line, delim)) { + vec.push_back(line); + } + return vec; +} + +static void ConvertToTypes(SQLINTEGER sql_type, void *result, const std::string &expected_result) { + switch (sql_type) { + case SQL_C_CHAR: + REQUIRE(STR_EQUAL(static_cast<const char *>(result), expected_result.c_str())); + return; + case SQL_C_WCHAR: { + WCHAR *wresult = static_cast<WCHAR *>(result); + for (size_t i = 0; i < expected_result.size(); i++) { + REQUIRE(wresult[i] == static_cast<WCHAR>(expected_result[i])); + } + return; + } + case SQL_C_SSHORT: + REQUIRE(*static_cast<SQLSMALLINT *>(result) == static_cast<SQLSMALLINT>(std::stoi(expected_result))); + return; + case SQL_C_USHORT: + REQUIRE(*static_cast<SQLUSMALLINT *>(result) == static_cast<SQLUSMALLINT>(std::stoul(expected_result))); + return; + case SQL_C_SLONG: + REQUIRE(*static_cast<SQLINTEGER *>(result) == static_cast<SQLINTEGER>(std::stoi(expected_result))); + return; + case SQL_C_ULONG: + REQUIRE(*static_cast<SQLUINTEGER *>(result) == static_cast<SQLUINTEGER>(std::stoul(expected_result))); + return; + case SQL_C_FLOAT: + REQUIRE(*static_cast<SQLREAL *>(result) == + Approx(static_cast<SQLREAL>(std::stof(expected_result))).margin(EPSILON)); + return; + case SQL_C_DOUBLE: + REQUIRE(*static_cast<SQLDOUBLE *>(result) == + Approx(static_cast<SQLDOUBLE>(std::stod(expected_result))).margin(EPSILON)); + return; + case SQL_C_BIT: + REQUIRE(*static_cast<SQLCHAR *>(result) == static_cast<SQLCHAR>(std::stoi(expected_result))); + return; + case SQL_C_STINYINT: + REQUIRE(*static_cast<SQLSCHAR *>(result) == static_cast<SQLSCHAR>(std::stoi(expected_result))); + return; + case SQL_C_UTINYINT: + REQUIRE(*static_cast<SQLCHAR *>(result) == static_cast<SQLCHAR>(std::stoul(expected_result))); + return; + case SQL_C_SBIGINT: + REQUIRE(*static_cast<SQLBIGINT *>(result) == static_cast<SQLBIGINT>(std::stoll(expected_result))); + return; + case SQL_C_UBIGINT: + REQUIRE(*static_cast<SQLUBIGINT *>(result) == static_cast<SQLUBIGINT>(std::stoull(expected_result))); + return; + case SQL_C_BINARY: + REQUIRE(ConvertHexToString(static_cast<SQLCHAR *>(result), expected_result.length()) == expected_result); + return; + case SQL_C_TYPE_DATE: { + auto *date = static_cast<DATE_STRUCT *>(result); + std::vector<std::string> split_expected_result = Split(expected_result, '-'); + REQUIRE(date->year == std::stoi(split_expected_result[0])); + REQUIRE(date->month == std::stoi(split_expected_result[1])); + REQUIRE(date->day == std::stoi(split_expected_result[2])); + return; + } + case SQL_C_TYPE_TIME: { + auto *time = static_cast<TIME_STRUCT *>(result); + std::vector<std::string> split_expected_result = Split(expected_result, '-'); + REQUIRE(time->hour == std::stoi(split_expected_result[0])); + REQUIRE(time->minute == std::stoi(split_expected_result[1])); + REQUIRE(time->second == std::stoi(split_expected_result[2])); + return; + } + case SQL_C_TYPE_TIMESTAMP: { + auto *timestamp = static_cast<TIMESTAMP_STRUCT *>(result); + std::vector<std::string> split_expected_result = Split(expected_result, '-'); + REQUIRE(timestamp->year == std::stoi(split_expected_result[0])); + REQUIRE(timestamp->month == std::stoi(split_expected_result[1])); + REQUIRE(timestamp->day == std::stoi(split_expected_result[2])); + REQUIRE(timestamp->hour == std::stoi(split_expected_result[3])); + REQUIRE(timestamp->minute == std::stoi(split_expected_result[4])); + REQUIRE(timestamp->second == std::stoi(split_expected_result[5])); + REQUIRE(timestamp->fraction == std::stoi(split_expected_result[6])); + return; + } + case SQL_C_NUMERIC: { + auto *numeric = static_cast<SQL_NUMERIC_STRUCT *>(result); + std::vector<std::string> split_expected_result = Split(expected_result, ':'); + REQUIRE(numeric->precision == std::stoi(split_expected_result[0])); + REQUIRE(numeric->scale == std::stoi(split_expected_result[1])); + REQUIRE(numeric->sign == std::stoi(split_expected_result[2])); + REQUIRE(ConvertHexToString(numeric->val, split_expected_result[3].length()) == split_expected_result[3]); + return; + } + case SQL_C_GUID: { + // This one never gets called because it never succeeds + return; + } + case SQL_C_INTERVAL_YEAR: + case SQL_C_INTERVAL_MONTH: + case SQL_C_INTERVAL_DAY: + case SQL_C_INTERVAL_HOUR: + case SQL_C_INTERVAL_MINUTE: + case SQL_C_INTERVAL_SECOND: + case SQL_C_INTERVAL_YEAR_TO_MONTH: + case SQL_C_INTERVAL_DAY_TO_HOUR: + case SQL_C_INTERVAL_DAY_TO_MINUTE: + case SQL_C_INTERVAL_DAY_TO_SECOND: + case SQL_C_INTERVAL_HOUR_TO_MINUTE: + case SQL_C_INTERVAL_HOUR_TO_SECOND: + case SQL_C_INTERVAL_MINUTE_TO_SECOND: { + auto *interval = static_cast<SQL_INTERVAL_STRUCT *>(result); + std::vector<std::string> split_expected_result = Split(expected_result, '-'); + REQUIRE(interval->interval_sign == std::stoi(split_expected_result[0])); + switch (interval->interval_type) { + case SQL_IS_YEAR: + 
REQUIRE(interval->intval.year_month.year == std::stoi(split_expected_result[1])); + break; + case SQL_IS_MONTH: + REQUIRE(interval->intval.year_month.month == std::stoi(split_expected_result[1])); + break; + case SQL_IS_DAY: + REQUIRE(interval->intval.day_second.day == std::stoi(split_expected_result[1])); + break; + case SQL_IS_HOUR: + REQUIRE(interval->intval.day_second.hour == std::stoi(split_expected_result[1])); + break; + case SQL_IS_MINUTE: + REQUIRE(interval->intval.day_second.minute == std::stoi(split_expected_result[1])); + break; + case SQL_IS_SECOND: + REQUIRE(interval->intval.day_second.second == std::stoi(split_expected_result[1])); + break; + case SQL_IS_YEAR_TO_MONTH: + REQUIRE(interval->intval.year_month.year == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.year_month.month == std::stoi(split_expected_result[2])); + break; + case SQL_IS_DAY_TO_HOUR: + REQUIRE(interval->intval.day_second.day == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.day_second.hour == std::stoi(split_expected_result[2])); + break; + case SQL_IS_DAY_TO_MINUTE: + REQUIRE(interval->intval.day_second.day == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.day_second.hour == std::stoi(split_expected_result[2])); + REQUIRE(interval->intval.day_second.minute == std::stoi(split_expected_result[3])); + break; + case SQL_IS_DAY_TO_SECOND: + REQUIRE(interval->intval.day_second.day == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.day_second.hour == std::stoi(split_expected_result[2])); + REQUIRE(interval->intval.day_second.minute == std::stoi(split_expected_result[3])); + REQUIRE(interval->intval.day_second.second == std::stoi(split_expected_result[4])); + REQUIRE(interval->intval.day_second.fraction == std::stoi(split_expected_result[5])); + break; + case SQL_IS_HOUR_TO_MINUTE: + REQUIRE(interval->intval.day_second.hour == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.day_second.minute == std::stoi(split_expected_result[2])); + break; + case SQL_IS_HOUR_TO_SECOND: + REQUIRE(interval->intval.day_second.hour == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.day_second.minute == std::stoi(split_expected_result[2])); + REQUIRE(interval->intval.day_second.second == std::stoi(split_expected_result[3])); + REQUIRE(interval->intval.day_second.fraction == std::stoi(split_expected_result[4])); + break; + case SQL_IS_MINUTE_TO_SECOND: + REQUIRE(interval->intval.day_second.minute == std::stoi(split_expected_result[1])); + REQUIRE(interval->intval.day_second.second == std::stoi(split_expected_result[2])); + REQUIRE(interval->intval.day_second.fraction == std::stoi(split_expected_result[3])); + break; + default: + FAIL("Unknown interval type"); + } + return; + } + default: + FAIL("Unknown type"); + } +} + +static void ConvertAndCheck(HSTMT &hstmt, const std::string &type, const std::string &type_to_convert, + SQLINTEGER sql_type, const std::string &expected_result, int content_size = 256) { + std::string query = "SELECT $$" + type_to_convert + "$$::" + type; + + EXECUTE_AND_CHECK(query.c_str(), SQLExecDirect, hstmt, ConvertToSQLCHAR(query), SQL_NTS); + + EXECUTE_AND_CHECK("SQLFetch", SQLFetch, hstmt); + SQLCHAR content[256]; + SQLLEN content_len; + SQLRETURN ret = SQLGetData(hstmt, 1, sql_type, content, content_size, &content_len); + if (expected_result == EXPECTED_ERROR) { + REQUIRE(ret == SQL_ERROR); + return; + } + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + FAIL("SQLGetData failed"); + } + 
ConvertToTypes(sql_type, content, expected_result); + + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); +} + +TEST_CASE("Test converting using SQLGetData", "[odbc]") { + SQLHANDLE env; + SQLHANDLE dbc; + + HSTMT hstmt = SQL_NULL_HSTMT; + + // Connect to the database using SQLConnect + CONNECT_TO_DATABASE(env, dbc); + + // Allocate a statement handle + EXECUTE_AND_CHECK("SQLAllocHandle (HSTMT)", SQLAllocHandle, SQL_HANDLE_STMT, dbc, &hstmt); + + for (int type_index = 0; type_index < all_types.size(); type_index++) { + for (int sql_type_index = 0; sql_type_index < all_sql_types.size(); sql_type_index++) { + ConvertAndCheck(hstmt, all_types[type_index][0], all_types[type_index][1], all_sql_types[sql_type_index], + results[type_index][sql_type_index]); + } + } + + // Conversion to GUID throws error if the string is not of correct form + ConvertAndCheck(hstmt, "text", "543c5e21-435a-440b-943c-64af1ad571f1", SQL_C_GUID, EXPECTED_ERROR); + + // Test for truncations + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_CHAR, "foob", 5); + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_CHAR, "fooba", 6); + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_CHAR, "foobar", 7); + + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_WCHAR, "foob", 10); + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_WCHAR, "foob", 11); + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_WCHAR, "fooba", 12); + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_WCHAR, "fooba", 13); + ConvertAndCheck(hstmt, "text", "foobar", SQL_C_WCHAR, "foobar", 14); + + ConvertAndCheck(hstmt, "text", "", SQL_C_CHAR, ""); + + // Test different timestamp subtype conversions + // Timestamp -> Date + ConvertAndCheck(hstmt, "timestamp_s", "2021-07-15 12:30:00", SQL_C_TYPE_DATE, "2021-7-15"); + ConvertAndCheck(hstmt, "timestamp_ms", "2021-07-15 12:30:00", SQL_C_TYPE_DATE, "2021-7-15"); + ConvertAndCheck(hstmt, "timestamp", "2021-07-15 12:30:00", SQL_C_TYPE_DATE, "2021-7-15"); + ConvertAndCheck(hstmt, "timestamp_ns", "2021-07-15 12:30:00", SQL_C_TYPE_DATE, "2021-7-15"); + + // TIMESTAMP -> TIME + ConvertAndCheck(hstmt, "timestamp_s", "2021-07-15 12:30:00", SQL_C_TYPE_TIME, "12-30-0"); + ConvertAndCheck(hstmt, "timestamp_ms", "2021-07-15 12:30:00", SQL_C_TYPE_TIME, "12-30-0"); + ConvertAndCheck(hstmt, "timestamp", "2021-07-15 12:30:00", SQL_C_TYPE_TIME, "12-30-0"); + ConvertAndCheck(hstmt, "timestamp_ns", "2021-07-15 12:30:00", SQL_C_TYPE_TIME, "12-30-0"); + + // TIMESTAMP -> TIMESTAMP + ConvertAndCheck(hstmt, "timestamp_s", "2021-07-15 12:30:00", SQL_C_TYPE_TIMESTAMP, "2021-7-15-12-30-0-0"); + ConvertAndCheck(hstmt, "timestamp_ms", "2021-07-15 12:30:00", SQL_C_TYPE_TIMESTAMP, "2021-7-15-12-30-0-0"); + ConvertAndCheck(hstmt, "timestamp", "2021-07-15 12:30:00", SQL_C_TYPE_TIMESTAMP, "2021-7-15-12-30-0-0"); + ConvertAndCheck(hstmt, "timestamp_ns", "2021-07-15 12:30:00", SQL_C_TYPE_TIMESTAMP, "2021-7-15-12-30-0-0"); + + // Free the statement handle + EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE); + EXECUTE_AND_CHECK("SQLFreeHandle (HSTMT)", SQLFreeHandle, SQL_HANDLE_STMT, hstmt); + + DISCONNECT_FROM_DATABASE(env, dbc); +} diff --git a/tools/odbc/test/tests/row_wise_fetching.cpp b/tools/odbc/test/tests/row_wise_fetching.cpp index 655a1c511a06..07d3838feeb1 100644 --- a/tools/odbc/test/tests/row_wise_fetching.cpp +++ b/tools/odbc/test/tests/row_wise_fetching.cpp @@ -1,5 +1,4 @@ #include "../common.h" -#include using namespace odbc_test; diff --git a/tools/odbc/test/tests/select.cpp b/tools/odbc/test/tests/select.cpp 
index 3f365539bd61..bd3d8ab58ad5 100644
--- a/tools/odbc/test/tests/select.cpp
+++ b/tools/odbc/test/tests/select.cpp
@@ -47,7 +47,7 @@ TEST_CASE("Test Select Statement", "[odbc]") {
 
 	// Check the data
 	for (int i = 1; i < 1600; i++) {
-		DATA_CHECK(hstmt, i, std::to_string(i).c_str());
+		DATA_CHECK(hstmt, i, std::to_string(i));
 	}
 
 	// SELECT $x; should throw error
diff --git a/tools/odbc/test/tests/set_attr.cpp b/tools/odbc/test/tests/set_attr.cpp
index feaae48fa44a..da3ba21f8ddf 100644
--- a/tools/odbc/test/tests/set_attr.cpp
+++ b/tools/odbc/test/tests/set_attr.cpp
@@ -1,7 +1,5 @@
 #include "../common.h"
 
-#include
-
 using namespace odbc_test;
 
 TEST_CASE("Test SQL_ATTR_ROW_BIND_TYPE attribute in SQLSetStmtAttr", "[odbc]") {
@@ -19,7 +17,7 @@ TEST_CASE("Test SQL_ATTR_ROW_BIND_TYPE attribute in SQLSetStmtAttr", "[odbc]") {
 	// Set the statement attribute SQL_ATTR_ROW_BIND_TYPE
 	uint64_t row_len = 256;
 	EXECUTE_AND_CHECK("SQLSetStmtAttr (SQL_ATTR_ROW_BIND_TYPE)", SQLSetStmtAttr, hstmt, SQL_ATTR_ROW_BIND_TYPE,
-	                  (SQLPOINTER)row_len, SQL_IS_INTEGER);
+	                  ConvertToSQLPOINTER(row_len), SQL_IS_INTEGER);
 
 	// Check the statement attribute SQL_ATTR_ROW_BIND_TYPE
 	SQLULEN buf;
@@ -48,7 +46,7 @@ TEST_CASE("Test SQL_ATTR_ACCESS_MODE and SQL_ATTR_METADATA_ID attribute in SQLSe
 
 	// Set the Connect attribute SQL_ATTR_ACCESS_MODE to SQL_MODE_READ_ONLY
 	EXECUTE_AND_CHECK("SQLSetConnectAttr (SQL_ATTR_ACCESS_MODE)", SQLSetConnectAttr, dbc, SQL_ATTR_ACCESS_MODE,
-	                  (SQLPOINTER)SQL_MODE_READ_ONLY, SQL_IS_INTEGER);
+	                  ConvertToSQLPOINTER(SQL_MODE_READ_ONLY), SQL_IS_INTEGER);
 
 	// Check the Connect attribute SQL_ATTR_ACCESS_MODE
 	SQLUINTEGER buf;
@@ -67,7 +65,7 @@ TEST_CASE("Test SQL_ATTR_ACCESS_MODE and SQL_ATTR_METADATA_ID attribute in SQLSe
 
 	// Set the Connect attribute SQL_ATTR_METADATA_ID to SQL_TRUE
 	EXECUTE_AND_CHECK("SQLSetConnectAttr (SQL_ATTR_METADATA_ID)", SQLSetConnectAttr, dbc, SQL_ATTR_METADATA_ID,
-	                  (SQLPOINTER)SQL_TRUE, SQL_IS_INTEGER);
+	                  ConvertToSQLPOINTER(SQL_TRUE), SQL_IS_INTEGER);
 
 	// Free the statement handle
 	EXECUTE_AND_CHECK("SQLFreeStmt (HSTMT)", SQLFreeStmt, hstmt, SQL_CLOSE);
diff --git a/tools/pythonpkg/README.md b/tools/pythonpkg/README.md
index 9fbbc76872fe..14130738569d 100644
--- a/tools/pythonpkg/README.md
+++ b/tools/pythonpkg/README.md
@@ -125,3 +125,14 @@ For example:
 export PATH="$PATH:/opt/homebrew/Cellar/llvm/15.0.2/bin"
 ```
 
+# What are py::object and py::handle?
+
+These are classes provided by pybind11, the library we use to manage our interaction with the Python environment.
+`py::handle` is a direct wrapper around a raw `PyObject *` and does not manage any references.
+`py::object` is similar to `py::handle`, but it can manage the refcount.
+
+I say *can* because it depends on how the object was created: `py::reinterpret_borrow<py::object>(...)` wraps a *borrowed* reference and increments the refcount, so the resulting py::object owns an independent reference and releases it when it goes out of scope. A py::handle never does this, but a py::handle also can't be used if the prototype requires a py::object.
+
+`py::reinterpret_steal<py::object>(...)` takes over a *new* reference instead: it does not increment the refcount, but it will decrement the refcount when the py::object goes out of scope.
+
+When directly interacting with CPython functions that return a borrowed `PyObject *`, such as `PyDateTime_DATE_GET_TZINFO`, you should generally wrap the call in `py::reinterpret_borrow`; there is no new reference to steal.
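To make the borrow/steal distinction concrete, here is a minimal sketch (the helper names are hypothetical, but both CPython calls are real APIs with exactly the ownership semantics described above):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// PyList_New returns a NEW reference: steal it, so the py::object takes
// over that single reference and decrefs it on destruction.
py::object StealExample() {
	return py::reinterpret_steal<py::object>(PyList_New(0));
}

// PyDict_GetItem returns a BORROWED reference (assuming the key exists):
// borrowing increments the refcount, so our py::object owns an
// independent reference while the dict keeps its own.
py::object BorrowExample(py::dict &dict, const py::object &key) {
	return py::reinterpret_borrow<py::object>(PyDict_GetItem(dict.ptr(), key.ptr()));
}
```

Getting this wrong in either direction produces a leak (borrowing a new reference) or a premature free (stealing a borrowed one).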
diff --git a/tools/pythonpkg/cibw.toml b/tools/pythonpkg/cibw.toml index beb1aaf60aae..bbbdeea1fb72 100644 --- a/tools/pythonpkg/cibw.toml +++ b/tools/pythonpkg/cibw.toml @@ -4,7 +4,7 @@ [tool.cibuildwheel] environment = "PIP_CONSTRAINT='build-constraints.txt'" before-build = 'pip install oldest-supported-numpy' -before-test = 'pip install --prefer-binary pandas pytest-timeout mypy "psutil>=5.9.0" "requests>=2.26" fsspec && (pip install --prefer-binary "pyarrow>=8.0" || true) && (pip install --prefer-binary "torch" || true) && (pip install --prefer-binary "polars" || true)&& (pip install --prefer-binary "tensorflow" || true)' +before-test = 'pip install --prefer-binary pandas pytest-timeout mypy "psutil>=5.9.0" "requests>=2.26" fsspec && (pip install --prefer-binary "pyarrow>=8.0" || true) && (pip install --prefer-binary "torch" || true) && (pip install --prefer-binary "polars" || true) && (pip install --prefer-binary "adbc_driver_manager" || true) && (pip install --prefer-binary "tensorflow" || true)' test-requires = 'pytest' test-command = 'DUCKDB_PYTHON_TEST_EXTENSION_PATH={project} DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED=1 python -m pytest {project}/tests' diff --git a/tools/pythonpkg/src/include/duckdb_python/import_cache/modules/datetime_module.hpp b/tools/pythonpkg/src/include/duckdb_python/import_cache/modules/datetime_module.hpp index 2a7b6f0d3dff..8fe2b79ebbef 100644 --- a/tools/pythonpkg/src/include/duckdb_python/import_cache/modules/datetime_module.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/import_cache/modules/datetime_module.hpp @@ -12,6 +12,40 @@ namespace duckdb { +struct DatetimeDatetimeCacheItem : public PythonImportCacheItem { +public: + static constexpr const char *Name = "datetime.datetime"; + +public: + ~DatetimeDatetimeCacheItem() override { + } + virtual void LoadSubtypes(PythonImportCache &cache) override { + max.LoadAttribute("max", cache, *this); + min.LoadAttribute("min", cache, *this); + } + +public: + PythonImportCacheItem max; + PythonImportCacheItem min; +}; + +struct DatetimeDateCacheItem : public PythonImportCacheItem { +public: + static constexpr const char *Name = "datetime.date"; + +public: + ~DatetimeDateCacheItem() override { + } + virtual void LoadSubtypes(PythonImportCache &cache) override { + max.LoadAttribute("max", cache, *this); + min.LoadAttribute("min", cache, *this); + } + +public: + PythonImportCacheItem max; + PythonImportCacheItem min; +}; + struct DatetimeCacheItem : public PythonImportCacheItem { public: static constexpr const char *Name = "datetime"; @@ -27,8 +61,8 @@ struct DatetimeCacheItem : public PythonImportCacheItem { } public: - PythonImportCacheItem datetime; - PythonImportCacheItem date; + DatetimeDatetimeCacheItem datetime; + DatetimeDateCacheItem date; PythonImportCacheItem time; PythonImportCacheItem timedelta; }; diff --git a/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache.hpp b/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache.hpp index 6e47b7d72b4b..d9a6c14feae7 100644 --- a/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache.hpp @@ -91,7 +91,7 @@ struct PythonImportCache { IpywidgetsCacheItem ipywidgets_module; public: - PyObject *AddCache(py::object item); + py::handle AddCache(py::object item); private: vector owned_objects; diff --git a/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache_item.hpp 
b/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache_item.hpp index f02e2b73f160..71713358ce9d 100644 --- a/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache_item.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/import_cache/python_import_cache_item.hpp @@ -38,13 +38,13 @@ struct PythonImportCacheItem { } private: - PyObject *AddCache(PythonImportCache &cache, py::object object); + py::handle AddCache(PythonImportCache &cache, py::object object); private: //! Whether or not we attempted to load the module bool load_succeeded; //! The stored item - PyObject *object; + py::handle object; }; } // namespace duckdb diff --git a/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_analyzer.hpp b/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_analyzer.hpp index ac98200cdc8e..306bd18b3946 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_analyzer.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_analyzer.hpp @@ -28,17 +28,17 @@ class PandasAnalyzer { } public: - LogicalType GetListType(py::handle &ele, bool &can_convert); + LogicalType GetListType(py::object &ele, bool &can_convert); LogicalType DictToMap(const PyDictionary &dict, bool &can_convert); LogicalType DictToStruct(const PyDictionary &dict, bool &can_convert); - LogicalType GetItemType(py::handle ele, bool &can_convert); - bool Analyze(py::handle column); + LogicalType GetItemType(py::object ele, bool &can_convert); + bool Analyze(py::object column); LogicalType AnalyzedType() { return analyzed_type; } private: - LogicalType InnerAnalyze(py::handle column, bool &can_convert, bool sample = true, idx_t increment = 1); + LogicalType InnerAnalyze(py::object column, bool &can_convert, bool sample = true, idx_t increment = 1); uint64_t GetSampleIncrement(idx_t rows); private: diff --git a/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_bind.hpp b/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_bind.hpp index 84c0a6d7ac7b..30625e571a44 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_bind.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pandas/pandas_bind.hpp @@ -17,8 +17,8 @@ struct PandasColumnBindData { unique_ptr mask; //! Only for categorical types string internal_categorical_type; - //! When object types are cast we must hold their data somewhere - PythonObjectContainer object_str_val; + //! Hold ownership of objects created during scanning + PythonObjectContainer object_str_val; }; struct Pandas { diff --git a/tools/pythonpkg/src/include/duckdb_python/pybind11/python_object_container.hpp b/tools/pythonpkg/src/include/duckdb_python/pybind11/python_object_container.hpp index cf1f465e79fd..8614f90d4d37 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pybind11/python_object_container.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pybind11/python_object_container.hpp @@ -15,14 +15,8 @@ namespace duckdb { -template -struct PythonAssignmentFunction { - typedef void (*assign_t)(TGT_PY_TYPE &, SRC_PY_TYPE &); -}; - //! Every Python Object Must be created through our container //! 
The Container ensures that the GIL is held on Python Object Construction/Destruction/Modification
-template <class PY_TYPE>
 class PythonObjectContainer {
 public:
 	PythonObjectContainer() {
@@ -33,32 +27,21 @@ class PythonObjectContainer {
 		py_obj.clear();
 	}
 
-	unique_ptr<PythonGILWrapper> GetLock() {
-		return make_uniq<PythonGILWrapper>();
-	}
-
-	template <class NEW_PY_TYPE>
-	void AssignInternal(typename PythonAssignmentFunction<PY_TYPE, NEW_PY_TYPE>::assign_t lambda,
-	                    NEW_PY_TYPE &new_value, PythonGILWrapper &lock) {
-		PY_TYPE obj;
-		lambda(obj, new_value);
-		PushInternal(lock, obj);
+	void Push(py::object &&obj) {
+		py::gil_scoped_acquire gil;
+		PushInternal(std::move(obj));
 	}
 
-	void PushInternal(PythonGILWrapper &lock, PY_TYPE obj) {
-		py_obj.push_back(obj);
+	const py::object &LastAddedObject() {
+		D_ASSERT(!py_obj.empty());
+		return py_obj.back();
 	}
 
-	void Push(PY_TYPE obj) {
-		auto lock = GetLock();
-		PushInternal(*lock, std::move(obj));
-	}
-
-	const PY_TYPE *GetPointerTop() {
-		return &py_obj.back();
+private:
+	void PushInternal(py::object &&obj) {
+		py_obj.emplace_back(obj);
 	}
 
-private:
-	vector<PY_TYPE> py_obj;
+	vector<py::object> py_obj;
 };
 
 } // namespace duckdb
diff --git a/tools/pythonpkg/src/include/duckdb_python/pyfilesystem.hpp b/tools/pythonpkg/src/include/duckdb_python/pyfilesystem.hpp
index 31dfd4e4cffe..02f616b7e040 100644
--- a/tools/pythonpkg/src/include/duckdb_python/pyfilesystem.hpp
+++ b/tools/pythonpkg/src/include/duckdb_python/pyfilesystem.hpp
@@ -85,6 +85,7 @@ class PythonFilesystem : public FileSystem {
 	bool OnDiskFile(FileHandle &handle) override {
 		return false;
 	}
+	string PathSeparator(const string &path) override;
 	int64_t GetFileSize(FileHandle &handle) override;
 	void RemoveFile(const string &filename) override;
 	void MoveFile(const string &source, const string &dest) override;
diff --git a/tools/pythonpkg/src/include/duckdb_python/pyresult.hpp b/tools/pythonpkg/src/include/duckdb_python/pyresult.hpp
index c76bf24c49e5..65c340ede9ea 100644
--- a/tools/pythonpkg/src/include/duckdb_python/pyresult.hpp
+++ b/tools/pythonpkg/src/include/duckdb_python/pyresult.hpp
@@ -59,7 +59,7 @@ struct DuckDBPyResult {
 
 	void FillNumpy(py::dict &res, idx_t col_idx, NumpyResultConversion &conversion, const char *name);
 
-	bool FetchArrowChunk(QueryResult *result, py::list &batches, idx_t rows_per_batch);
+	bool FetchArrowChunk(ChunkScanState &scan_state, py::list &batches, idx_t rows_per_batch);
 
 	PandasDataFrame FrameFromNumpy(bool date_as_object, const py::handle &o);
diff --git a/tools/pythonpkg/src/include/duckdb_python/python_objects.hpp b/tools/pythonpkg/src/include/duckdb_python/python_objects.hpp
index 0abaabb2cbd7..8b9c8f4eb4d5 100644
--- a/tools/pythonpkg/src/include/duckdb_python/python_objects.hpp
+++ b/tools/pythonpkg/src/include/duckdb_python/python_objects.hpp
@@ -37,7 +37,7 @@ struct PyDictionary {
 	idx_t len;
 
 public:
-	PyObject *operator[](const py::object &obj) const {
+	py::handle operator[](const py::object &obj) const {
 		return PyDict_GetItem(dict.ptr(), obj.ptr());
 	}
@@ -117,9 +117,9 @@ struct PyTimeDelta {
 	interval_t ToInterval();
 
 private:
-	static int64_t GetDays(PyObject *obj);
-	static int64_t GetSeconds(PyObject *obj);
-	static int64_t GetMicros(PyObject *obj);
+	static int64_t GetDays(py::handle &obj);
+	static int64_t GetSeconds(py::handle &obj);
+	static int64_t GetMicros(py::handle &obj);
 };
 
 struct PyTime {
@@ -130,18 +130,18 @@ struct PyTime {
 	int32_t minute;
 	int32_t second;
 	int32_t microsecond;
-	PyObject *timezone_obj;
+	py::object timezone_obj;
 
 public:
 	dtime_t ToDuckTime();
 	Value ToDuckValue();
 
 private:
-	static int32_t GetHours(PyObject *obj);
-	static int32_t
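For context, a hedged usage sketch of the simplified container (the caller below is hypothetical; `Push` and `LastAddedObject` are the members shown in the hunk above, and this mirrors the numpy_scan.cpp call sites later in this patch):

```cpp
// Assumes `container` outlives any raw pointers handed out from it.
void KeepAlive(PythonObjectContainer &container, py::handle val) {
	// Push acquires the GIL internally before touching the vector, so the
	// call site no longer needs to create and pass a PythonGILWrapper.
	container.Push(py::str(val));
	// The container keeps the py::object alive, so a raw PyObject* into it
	// can safely be used afterwards, e.g. for zero-copy string scanning.
	PyObject *raw = container.LastAddedObject().ptr();
	(void)raw;
}
```

Moving the GIL acquisition inside `Push` keeps the locking invariant in one place instead of at every caller, which is what allowed the template machinery (`PythonAssignmentFunction`, `AssignInternal`, `GetLock`) to be deleted.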
GetMinutes(PyObject *obj); - static int32_t GetSeconds(PyObject *obj); - static int32_t GetMicros(PyObject *obj); - static PyObject *GetTZInfo(PyObject *obj); + static int32_t GetHours(py::handle &obj); + static int32_t GetMinutes(py::handle &obj); + static int32_t GetSeconds(py::handle &obj); + static int32_t GetMicros(py::handle &obj); + static py::object GetTZInfo(py::handle &obj); }; struct PyDateTime { @@ -155,23 +155,25 @@ struct PyDateTime { int32_t minute; int32_t second; int32_t micros; - PyObject *tzone_obj; + py::object tzone_obj; public: timestamp_t ToTimestamp(); date_t ToDate(); dtime_t ToDuckTime(); Value ToDuckValue(const LogicalType &target_type); - -public: - static int32_t GetYears(PyObject *obj); - static int32_t GetMonths(PyObject *obj); - static int32_t GetDays(PyObject *obj); - static int32_t GetHours(PyObject *obj); - static int32_t GetMinutes(PyObject *obj); - static int32_t GetSeconds(PyObject *obj); - static int32_t GetMicros(PyObject *obj); - static PyObject *GetTZInfo(PyObject *obj); + bool IsPositiveInfinity() const; + bool IsNegativeInfinity() const; + +public: + static int32_t GetYears(py::handle &obj); + static int32_t GetMonths(py::handle &obj); + static int32_t GetDays(py::handle &obj); + static int32_t GetHours(py::handle &obj); + static int32_t GetMinutes(py::handle &obj); + static int32_t GetSeconds(py::handle &obj); + static int32_t GetMicros(py::handle &obj); + static py::object GetTZInfo(py::handle &obj); }; struct PyDate { @@ -183,6 +185,8 @@ struct PyDate { public: Value ToDuckValue(); + bool IsPositiveInfinity() const; + bool IsNegativeInfinity() const; }; struct PyTimezone { @@ -190,7 +194,7 @@ struct PyTimezone { PyTimezone() = delete; public: - DUCKDB_API static interval_t GetUTCOffset(PyObject *tzone_obj); + DUCKDB_API static interval_t GetUTCOffset(py::handle &tzone_obj); }; struct PythonObject { diff --git a/tools/pythonpkg/src/include/duckdb_python/pyutil.hpp b/tools/pythonpkg/src/include/duckdb_python/pyutil.hpp index 38487020bf56..ca19af816ad9 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pyutil.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pyutil.hpp @@ -6,28 +6,28 @@ namespace duckdb { struct PyUtil { - static idx_t PyByteArrayGetSize(PyObject *obj) { - return PyByteArray_GET_SIZE(obj); // NOLINT + static idx_t PyByteArrayGetSize(py::handle &obj) { + return PyByteArray_GET_SIZE(obj.ptr()); // NOLINT } - static Py_buffer *PyMemoryViewGetBuffer(PyObject *obj) { - return PyMemoryView_GET_BUFFER(obj); + static Py_buffer *PyMemoryViewGetBuffer(py::handle &obj) { + return PyMemoryView_GET_BUFFER(obj.ptr()); } - static bool PyUnicodeIsCompactASCII(PyObject *obj) { - return PyUnicode_IS_COMPACT_ASCII(obj); + static bool PyUnicodeIsCompactASCII(py::handle &obj) { + return PyUnicode_IS_COMPACT_ASCII(obj.ptr()); } - static const char *PyUnicodeData(PyObject *obj) { - return const_char_ptr_cast(PyUnicode_DATA(obj)); + static const char *PyUnicodeData(py::handle &obj) { + return const_char_ptr_cast(PyUnicode_DATA(obj.ptr())); } - static char *PyUnicodeDataMutable(PyObject *obj) { - return char_ptr_cast(PyUnicode_DATA(obj)); + static char *PyUnicodeDataMutable(py::handle &obj) { + return char_ptr_cast(PyUnicode_DATA(obj.ptr())); } - static idx_t PyUnicodeGetLength(PyObject *obj) { - return PyUnicode_GET_LENGTH(obj); + static idx_t PyUnicodeGetLength(py::handle &obj) { + return PyUnicode_GET_LENGTH(obj.ptr()); } static bool PyUnicodeIsCompact(PyCompactUnicodeObject *obj) { @@ -38,20 +38,20 @@ struct PyUtil { return 
PyUnicode_IS_ASCII(obj); } - static int PyUnicodeKind(PyObject *obj) { - return PyUnicode_KIND(obj); + static int PyUnicodeKind(py::handle &obj) { + return PyUnicode_KIND(obj.ptr()); } - static Py_UCS1 *PyUnicode1ByteData(PyObject *obj) { - return PyUnicode_1BYTE_DATA(obj); + static Py_UCS1 *PyUnicode1ByteData(py::handle &obj) { + return PyUnicode_1BYTE_DATA(obj.ptr()); } - static Py_UCS2 *PyUnicode2ByteData(PyObject *obj) { - return PyUnicode_2BYTE_DATA(obj); + static Py_UCS2 *PyUnicode2ByteData(py::handle &obj) { + return PyUnicode_2BYTE_DATA(obj.ptr()); } - static Py_UCS4 *PyUnicode4ByteData(PyObject *obj) { - return PyUnicode_4BYTE_DATA(obj); + static Py_UCS4 *PyUnicode4ByteData(py::handle &obj) { + return PyUnicode_4BYTE_DATA(obj.ptr()); } }; diff --git a/tools/pythonpkg/src/map.cpp b/tools/pythonpkg/src/map.cpp index 54e580fd017a..0faa7d0308da 100644 --- a/tools/pythonpkg/src/map.cpp +++ b/tools/pythonpkg/src/map.cpp @@ -25,7 +25,7 @@ struct MapFunctionData : public TableFunctionData { vector in_names, out_names; }; -static py::handle FunctionCall(NumpyResultConversion &conversion, const vector &names, PyObject *function) { +static py::object FunctionCall(NumpyResultConversion &conversion, const vector &names, PyObject *function) { py::dict in_numpy_dict; for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { in_numpy_dict[names[col_idx].c_str()] = conversion.ToArray(col_idx); @@ -40,7 +40,7 @@ static py::handle FunctionCall(NumpyResultConversion &conversion, const vector(df_obj); if (df.is_none()) { // no return, probably modified in place throw InvalidInputException("No return value from Python function"); } diff --git a/tools/pythonpkg/src/native/python_conversion.cpp b/tools/pythonpkg/src/native/python_conversion.cpp index 0eee897e75a9..91b063db4967 100644 --- a/tools/pythonpkg/src/native/python_conversion.cpp +++ b/tools/pythonpkg/src/native/python_conversion.cpp @@ -482,15 +482,14 @@ Value TransformPythonValue(py::handle ele, const LogicalType &target_type, bool case PythonObjectType::String: return ele.cast(); case PythonObjectType::ByteArray: { - auto byte_array = ele.ptr(); - const_data_ptr_t bytes = const_data_ptr_cast(PyByteArray_AsString(byte_array)); // NOLINT - idx_t byte_length = PyUtil::PyByteArrayGetSize(byte_array); // NOLINT + auto byte_array = ele; + const_data_ptr_t bytes = const_data_ptr_cast(PyByteArray_AsString(byte_array.ptr())); // NOLINT + idx_t byte_length = PyUtil::PyByteArrayGetSize(byte_array); // NOLINT return Value::BLOB(bytes, byte_length); } case PythonObjectType::MemoryView: { py::memoryview py_view = ele.cast(); - PyObject *py_view_ptr = py_view.ptr(); - Py_buffer *py_buf = PyUtil::PyMemoryViewGetBuffer(py_view_ptr); // NOLINT + Py_buffer *py_buf = PyUtil::PyMemoryViewGetBuffer(py_view); // NOLINT return Value::BLOB(const_data_ptr_t(py_buf->buf), idx_t(py_buf->len)); } case PythonObjectType::Bytes: { diff --git a/tools/pythonpkg/src/native/python_objects.cpp b/tools/pythonpkg/src/native/python_objects.cpp index 64e2e251ccee..95c3214b4637 100644 --- a/tools/pythonpkg/src/native/python_objects.cpp +++ b/tools/pythonpkg/src/native/python_objects.cpp @@ -20,10 +20,9 @@ PyDictionary::PyDictionary(py::object dict) { } PyTimeDelta::PyTimeDelta(py::handle &obj) { - auto ptr = obj.ptr(); - days = PyTimeDelta::GetDays(ptr); - seconds = PyTimeDelta::GetSeconds(ptr); - microseconds = PyTimeDelta::GetMicros(ptr); + days = PyTimeDelta::GetDays(obj); + seconds = PyTimeDelta::GetSeconds(obj); + microseconds = PyTimeDelta::GetMicros(obj); } interval_t 
PyTimeDelta::ToInterval() { @@ -42,16 +41,16 @@ interval_t PyTimeDelta::ToInterval() { return interval; } -int64_t PyTimeDelta::GetDays(PyObject *obj) { - return PyDateTime_TIMEDELTA_GET_DAYS(obj); // NOLINT +int64_t PyTimeDelta::GetDays(py::handle &obj) { + return PyDateTime_TIMEDELTA_GET_DAYS(obj.ptr()); // NOLINT } -int64_t PyTimeDelta::GetSeconds(PyObject *obj) { - return PyDateTime_TIMEDELTA_GET_SECONDS(obj); // NOLINT +int64_t PyTimeDelta::GetSeconds(py::handle &obj) { + return PyDateTime_TIMEDELTA_GET_SECONDS(obj.ptr()); // NOLINT } -int64_t PyTimeDelta::GetMicros(PyObject *obj) { - return PyDateTime_TIMEDELTA_GET_MICROSECONDS(obj); // NOLINT +int64_t PyTimeDelta::GetMicros(py::handle &obj) { + return PyDateTime_TIMEDELTA_GET_MICROSECONDS(obj.ptr()); // NOLINT } PyDecimal::PyDecimal(py::handle &obj) : obj(obj) { @@ -209,12 +208,11 @@ Value PyDecimal::ToDuckValue() { } PyTime::PyTime(py::handle &obj) : obj(obj) { - auto ptr = obj.ptr(); - hour = PyTime::GetHours(ptr); // NOLINT - minute = PyTime::GetMinutes(ptr); // NOLINT - second = PyTime::GetSeconds(ptr); // NOLINT - microsecond = PyTime::GetMicros(ptr); // NOLINT - timezone_obj = PyTime::GetTZInfo(ptr); // NOLINT + hour = PyTime::GetHours(obj); // NOLINT + minute = PyTime::GetMinutes(obj); // NOLINT + second = PyTime::GetSeconds(obj); // NOLINT + microsecond = PyTime::GetMicros(obj); // NOLINT + timezone_obj = PyTime::GetTZInfo(obj); // NOLINT } dtime_t PyTime::ToDuckTime() { return Time::FromTime(hour, minute, second, microsecond); @@ -222,7 +220,7 @@ dtime_t PyTime::ToDuckTime() { Value PyTime::ToDuckValue() { auto duckdb_time = this->ToDuckTime(); - if (this->timezone_obj != Py_None) { + if (!py::none().is(this->timezone_obj)) { auto utc_offset = PyTimezone::GetUTCOffset(this->timezone_obj); // 'Add' requires a date_t for overflows date_t ignored_date; @@ -232,43 +230,42 @@ Value PyTime::ToDuckValue() { return Value::TIME(duckdb_time); } -int32_t PyTime::GetHours(PyObject *obj) { - return PyDateTime_TIME_GET_HOUR(obj); // NOLINT +int32_t PyTime::GetHours(py::handle &obj) { + return PyDateTime_TIME_GET_HOUR(obj.ptr()); // NOLINT } -int32_t PyTime::GetMinutes(PyObject *obj) { - return PyDateTime_TIME_GET_MINUTE(obj); // NOLINT +int32_t PyTime::GetMinutes(py::handle &obj) { + return PyDateTime_TIME_GET_MINUTE(obj.ptr()); // NOLINT } -int32_t PyTime::GetSeconds(PyObject *obj) { - return PyDateTime_TIME_GET_SECOND(obj); // NOLINT +int32_t PyTime::GetSeconds(py::handle &obj) { + return PyDateTime_TIME_GET_SECOND(obj.ptr()); // NOLINT } -int32_t PyTime::GetMicros(PyObject *obj) { - return PyDateTime_TIME_GET_MICROSECOND(obj); // NOLINT +int32_t PyTime::GetMicros(py::handle &obj) { + return PyDateTime_TIME_GET_MICROSECOND(obj.ptr()); // NOLINT } -PyObject *PyTime::GetTZInfo(PyObject *obj) { - return PyDateTime_TIME_GET_TZINFO(obj); // NOLINT +py::object PyTime::GetTZInfo(py::handle &obj) { + // The object returned is borrowed, there is no reference to steal + return py::reinterpret_borrow(PyDateTime_TIME_GET_TZINFO(obj.ptr())); // NOLINT } -interval_t PyTimezone::GetUTCOffset(PyObject *tzone_obj) { - auto tzinfo = py::reinterpret_borrow(tzone_obj); - auto res = tzinfo.attr("utcoffset")(py::none()); +interval_t PyTimezone::GetUTCOffset(py::handle &tzone_obj) { + auto res = tzone_obj.attr("utcoffset")(py::none()); auto timedelta = PyTimeDelta(res); return timedelta.ToInterval(); } PyDateTime::PyDateTime(py::handle &obj) : obj(obj) { - auto ptr = obj.ptr(); - year = PyDateTime::GetYears(ptr); - month = PyDateTime::GetMonths(ptr); - day = 
PyDateTime::GetDays(ptr); - hour = PyDateTime::GetHours(ptr); - minute = PyDateTime::GetMinutes(ptr); - second = PyDateTime::GetSeconds(ptr); - micros = PyDateTime::GetMicros(ptr); - tzone_obj = PyDateTime::GetTZInfo(ptr); + year = PyDateTime::GetYears(obj); + month = PyDateTime::GetMonths(obj); + day = PyDateTime::GetDays(obj); + hour = PyDateTime::GetHours(obj); + minute = PyDateTime::GetMinutes(obj); + second = PyDateTime::GetSeconds(obj); + micros = PyDateTime::GetMicros(obj); + tzone_obj = PyDateTime::GetTZInfo(obj); } timestamp_t PyDateTime::ToTimestamp() { @@ -277,9 +274,25 @@ timestamp_t PyDateTime::ToTimestamp() { return Timestamp::FromDatetime(date, time); } +bool PyDateTime::IsPositiveInfinity() const { + return year == 9999 && month == 12 && day == 31 && hour == 23 && minute == 59 && second == 59 && micros == 999999; +} + +bool PyDateTime::IsNegativeInfinity() const { + return year == 1 && month == 1 && day == 1 && hour == 0 && minute == 0 && second == 0 && micros == 0; +} + Value PyDateTime::ToDuckValue(const LogicalType &target_type) { + if (IsPositiveInfinity()) { + // FIXME: respect the target_type ? + return Value::TIMESTAMP(timestamp_t::infinity()); + } + if (IsNegativeInfinity()) { + // FIXME: respect the target_type ? + return Value::TIMESTAMP(timestamp_t::ninfinity()); + } auto timestamp = ToTimestamp(); - if (tzone_obj != Py_None) { + if (!py::none().is(tzone_obj)) { auto utc_offset = PyTimezone::GetUTCOffset(tzone_obj); // Need to subtract the UTC offset, so we invert the interval utc_offset = Interval::Invert(utc_offset); @@ -310,53 +323,79 @@ dtime_t PyDateTime::ToDuckTime() { return Time::FromTime(hour, minute, second, micros); } -int32_t PyDateTime::GetYears(PyObject *obj) { - return PyDateTime_GET_YEAR(obj); // NOLINT +int32_t PyDateTime::GetYears(py::handle &obj) { + return PyDateTime_GET_YEAR(obj.ptr()); // NOLINT } -int32_t PyDateTime::GetMonths(PyObject *obj) { - return PyDateTime_GET_MONTH(obj); // NOLINT +int32_t PyDateTime::GetMonths(py::handle &obj) { + return PyDateTime_GET_MONTH(obj.ptr()); // NOLINT } -int32_t PyDateTime::GetDays(PyObject *obj) { - return PyDateTime_GET_DAY(obj); // NOLINT +int32_t PyDateTime::GetDays(py::handle &obj) { + return PyDateTime_GET_DAY(obj.ptr()); // NOLINT } -int32_t PyDateTime::GetHours(PyObject *obj) { - return PyDateTime_DATE_GET_HOUR(obj); // NOLINT +int32_t PyDateTime::GetHours(py::handle &obj) { + return PyDateTime_DATE_GET_HOUR(obj.ptr()); // NOLINT } -int32_t PyDateTime::GetMinutes(PyObject *obj) { - return PyDateTime_DATE_GET_MINUTE(obj); // NOLINT +int32_t PyDateTime::GetMinutes(py::handle &obj) { + return PyDateTime_DATE_GET_MINUTE(obj.ptr()); // NOLINT } -int32_t PyDateTime::GetSeconds(PyObject *obj) { - return PyDateTime_DATE_GET_SECOND(obj); // NOLINT +int32_t PyDateTime::GetSeconds(py::handle &obj) { + return PyDateTime_DATE_GET_SECOND(obj.ptr()); // NOLINT } -int32_t PyDateTime::GetMicros(PyObject *obj) { - return PyDateTime_DATE_GET_MICROSECOND(obj); // NOLINT +int32_t PyDateTime::GetMicros(py::handle &obj) { + return PyDateTime_DATE_GET_MICROSECOND(obj.ptr()); // NOLINT } -PyObject *PyDateTime::GetTZInfo(PyObject *obj) { - return PyDateTime_DATE_GET_TZINFO(obj); // NOLINT +py::object PyDateTime::GetTZInfo(py::handle &obj) { + // The object returned is borrowed, there is no reference to steal + return py::reinterpret_borrow(PyDateTime_DATE_GET_TZINFO(obj.ptr())); // NOLINT } PyDate::PyDate(py::handle &ele) { - auto ptr = ele.ptr(); - year = PyDateTime::GetYears(ptr); - month = PyDateTime::GetMonths(ptr); 
- day = PyDateTime::GetDays(ptr); + year = PyDateTime::GetYears(ele); + month = PyDateTime::GetMonths(ele); + day = PyDateTime::GetDays(ele); } Value PyDate::ToDuckValue() { + if (IsPositiveInfinity()) { + return Value::DATE(date_t::infinity()); + } + if (IsNegativeInfinity()) { + return Value::DATE(date_t::ninfinity()); + } return Value::DATE(year, month, day); } +bool PyDate::IsPositiveInfinity() const { + return year == 9999 && month == 12 && day == 31; +} + +bool PyDate::IsNegativeInfinity() const { + return year == 1 && month == 1 && day == 1; +} + void PythonObject::Initialize() { PyDateTime_IMPORT; // NOLINT: Python datetime initialize #2 } +enum class InfinityType : uint8_t { NONE, POSITIVE, NEGATIVE }; + +InfinityType GetTimestampInfinityType(timestamp_t ×tamp) { + if (timestamp == timestamp_t::infinity()) { + return InfinityType::POSITIVE; + } + if (timestamp == timestamp_t::ninfinity()) { + return InfinityType::NEGATIVE; + } + return InfinityType::NONE; +} + py::object PythonObject::FromValue(const Value &val, const LogicalType &type) { auto &import_cache = *DuckDBPyConnection::ImportCache(); if (val.IsNull()) { @@ -408,6 +447,8 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type) { case LogicalTypeId::TIMESTAMP_TZ: { D_ASSERT(type.InternalType() == PhysicalType::INT64); auto timestamp = val.GetValueUnsafe(); + + InfinityType infinity = InfinityType::NONE; if (type.id() == LogicalTypeId::TIMESTAMP_MS) { timestamp = Timestamp::FromEpochMs(timestamp.value); } else if (type.id() == LogicalTypeId::TIMESTAMP_NS) { @@ -415,6 +456,19 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type) { } else if (type.id() == LogicalTypeId::TIMESTAMP_SEC) { timestamp = Timestamp::FromEpochSeconds(timestamp.value); } + infinity = GetTimestampInfinityType(timestamp); + + // Deal with infinity + switch (infinity) { + case InfinityType::POSITIVE: { + return py::reinterpret_borrow(import_cache.datetime().datetime.max()); + } + case InfinityType::NEGATIVE: { + return py::reinterpret_borrow(import_cache.datetime().datetime.min()); + } + case InfinityType::NONE: + break; + } int32_t year, month, day, hour, min, sec, micros; date_t date; dtime_t time; @@ -437,6 +491,12 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type) { auto date = val.GetValueUnsafe(); int32_t year, month, day; + if (!duckdb::Date::IsFinite(date)) { + if (date == date_t::infinity()) { + return py::reinterpret_borrow(import_cache.datetime().date.max()); + } + return py::reinterpret_borrow(import_cache.datetime().date.min()); + } duckdb::Date::Convert(date, year, month, day); return py::reinterpret_steal(PyDate_FromDate(year, month, day)); } diff --git a/tools/pythonpkg/src/numpy/array_wrapper.cpp b/tools/pythonpkg/src/numpy/array_wrapper.cpp index 4e185eb24084..d68dc53ca325 100644 --- a/tools/pythonpkg/src/numpy/array_wrapper.cpp +++ b/tools/pythonpkg/src/numpy/array_wrapper.cpp @@ -162,19 +162,20 @@ struct StringConvert { // based on the max codepoint, we construct the result string auto result = PyUnicode_New(start_pos + codepoint_count, max_codepoint); // based on the resulting unicode kind, we fill in the code points - auto kind = PyUtil::PyUnicodeKind(result); + auto result_handle = py::handle(result); + auto kind = PyUtil::PyUnicodeKind(result_handle); switch (kind) { case PyUnicode_1BYTE_KIND: - ConvertUnicodeValueTemplated(PyUtil::PyUnicode1ByteData(result), codepoints, codepoint_count, data, - start_pos); + 
ConvertUnicodeValueTemplated(PyUtil::PyUnicode1ByteData(result_handle), codepoints, + codepoint_count, data, start_pos); break; case PyUnicode_2BYTE_KIND: - ConvertUnicodeValueTemplated(PyUtil::PyUnicode2ByteData(result), codepoints, codepoint_count, data, - start_pos); + ConvertUnicodeValueTemplated(PyUtil::PyUnicode2ByteData(result_handle), codepoints, + codepoint_count, data, start_pos); break; case PyUnicode_4BYTE_KIND: - ConvertUnicodeValueTemplated(PyUtil::PyUnicode4ByteData(result), codepoints, codepoint_count, data, - start_pos); + ConvertUnicodeValueTemplated(PyUtil::PyUnicode4ByteData(result_handle), codepoints, + codepoint_count, data, start_pos); break; default: throw NotImplementedException("Unsupported typekind constant '%d' for Python Unicode Compact decode", kind); @@ -198,7 +199,8 @@ struct StringConvert { // no unicode: fast path // directly construct the string and memcpy it auto result = PyUnicode_New(len, 127); - auto target_data = PyUtil::PyUnicodeDataMutable(result); + auto result_handle = py::handle(result); + auto target_data = PyUtil::PyUnicodeDataMutable(result_handle); memcpy(target_data, data, len); return result; } diff --git a/tools/pythonpkg/src/numpy/numpy_bind.cpp b/tools/pythonpkg/src/numpy/numpy_bind.cpp index 937f7c458491..e4614503e34a 100644 --- a/tools/pythonpkg/src/numpy/numpy_bind.cpp +++ b/tools/pythonpkg/src/numpy/numpy_bind.cpp @@ -52,7 +52,7 @@ void NumpyBind::Bind(const ClientContext &context, py::handle df, vector(pandas_col); diff --git a/tools/pythonpkg/src/numpy/numpy_scan.cpp b/tools/pythonpkg/src/numpy/numpy_scan.cpp index d1c420877fa9..b5a8eb5978cc 100644 --- a/tools/pythonpkg/src/numpy/numpy_scan.cpp +++ b/tools/pythonpkg/src/numpy/numpy_scan.cpp @@ -317,26 +317,22 @@ void NumpyScan::Scan(PandasColumnBindData &bind_data, idx_t count, idx_t offset, } if (!py::isinstance(val)) { if (!gil) { - gil = bind_data.object_str_val.GetLock(); + gil = make_uniq(); } - bind_data.object_str_val.AssignInternal( - [](py::str &obj, PyObject &new_val) { - py::handle object_handle = &new_val; - obj = py::str(object_handle); - }, - *val, *gil); - val = reinterpret_cast(bind_data.object_str_val.GetPointerTop()->ptr()); + bind_data.object_str_val.Push(std::move(py::str(val))); + val = reinterpret_cast(bind_data.object_str_val.LastAddedObject().ptr()); } } // Python 3 string representation: // https://github.com/python/cpython/blob/3a8fdb28794b2f19f6c8464378fb8b46bce1f5f4/Include/cpython/unicodeobject.h#L79 - if (!py::isinstance(val)) { + py::handle val_handle(val); + if (!py::isinstance(val_handle)) { out_mask.SetInvalid(row); continue; } - if (PyUtil::PyUnicodeIsCompactASCII(val)) { + if (PyUtil::PyUnicodeIsCompactASCII(val_handle)) { // ascii string: we can zero copy - tgt_ptr[row] = string_t(PyUtil::PyUnicodeData(val), PyUtil::PyUnicodeGetLength(val)); + tgt_ptr[row] = string_t(PyUtil::PyUnicodeData(val_handle), PyUtil::PyUnicodeGetLength(val_handle)); } else { // unicode gunk auto ascii_obj = reinterpret_cast(val); @@ -347,19 +343,19 @@ void NumpyScan::Scan(PandasColumnBindData &bind_data, idx_t count, idx_t offset, tgt_ptr[row] = string_t(const_char_ptr_cast(unicode_obj->utf8), unicode_obj->utf8_length); } else if (PyUtil::PyUnicodeIsCompact(unicode_obj) && !PyUtil::PyUnicodeIsASCII(unicode_obj)) { // NOLINT - auto kind = PyUtil::PyUnicodeKind(val); + auto kind = PyUtil::PyUnicodeKind(val_handle); switch (kind) { case PyUnicode_1BYTE_KIND: - tgt_ptr[row] = DecodePythonUnicode(PyUtil::PyUnicode1ByteData(val), - PyUtil::PyUnicodeGetLength(val), out); + 
tgt_ptr[row] = DecodePythonUnicode(PyUtil::PyUnicode1ByteData(val_handle), + PyUtil::PyUnicodeGetLength(val_handle), out); break; case PyUnicode_2BYTE_KIND: - tgt_ptr[row] = DecodePythonUnicode(PyUtil::PyUnicode2ByteData(val), - PyUtil::PyUnicodeGetLength(val), out); + tgt_ptr[row] = DecodePythonUnicode(PyUtil::PyUnicode2ByteData(val_handle), + PyUtil::PyUnicodeGetLength(val_handle), out); break; case PyUnicode_4BYTE_KIND: - tgt_ptr[row] = DecodePythonUnicode(PyUtil::PyUnicode4ByteData(val), - PyUtil::PyUnicodeGetLength(val), out); + tgt_ptr[row] = DecodePythonUnicode(PyUtil::PyUnicode4ByteData(val_handle), + PyUtil::PyUnicodeGetLength(val_handle), out); break; default: throw NotImplementedException( diff --git a/tools/pythonpkg/src/pandas/analyzer.cpp b/tools/pythonpkg/src/pandas/analyzer.cpp index 277239d3ae2a..32ea4242f395 100644 --- a/tools/pythonpkg/src/pandas/analyzer.cpp +++ b/tools/pythonpkg/src/pandas/analyzer.cpp @@ -143,7 +143,7 @@ static bool UpgradeType(LogicalType &left, const LogicalType &right) { return true; } -LogicalType PandasAnalyzer::GetListType(py::handle &ele, bool &can_convert) { +LogicalType PandasAnalyzer::GetListType(py::object &ele, bool &can_convert) { auto size = py::len(ele); if (size == 0) { @@ -153,7 +153,8 @@ LogicalType PandasAnalyzer::GetListType(py::handle &ele, bool &can_convert) { idx_t i = 0; LogicalType list_type = LogicalType::SQLNULL; for (auto py_val : ele) { - auto item_type = GetItemType(py_val, can_convert); + auto object = py::reinterpret_borrow(py_val); + auto item_type = GetItemType(object, can_convert); if (!i) { list_type = item_type; } else { @@ -257,7 +258,7 @@ LogicalType PandasAnalyzer::DictToStruct(const PyDictionary &dict, bool &can_con //! e.g python lists can consist of multiple different types, which we cant communicate downwards through //! 
LogicalType's alone -LogicalType PandasAnalyzer::GetItemType(py::handle ele, bool &can_convert) { +LogicalType PandasAnalyzer::GetItemType(py::object ele, bool &can_convert) { auto object_type = GetPythonObjectType(ele); switch (object_type) { @@ -362,7 +363,7 @@ uint64_t PandasAnalyzer::GetSampleIncrement(idx_t rows) { return rows / sample; } -LogicalType PandasAnalyzer::InnerAnalyze(py::handle column, bool &can_convert, bool sample, idx_t increment) { +LogicalType PandasAnalyzer::InnerAnalyze(py::object column, bool &can_convert, bool sample, idx_t increment) { idx_t rows = py::len(column); if (!rows) { @@ -406,13 +407,13 @@ LogicalType PandasAnalyzer::InnerAnalyze(py::handle column, bool &can_convert, b return item_type; } -bool PandasAnalyzer::Analyze(py::handle column) { +bool PandasAnalyzer::Analyze(py::object column) { // Disable analyze if (sample_size == 0) { return false; } bool can_convert = true; - LogicalType type = InnerAnalyze(column, can_convert); + LogicalType type = InnerAnalyze(std::move(column), can_convert); if (can_convert) { analyzed_type = type; } diff --git a/tools/pythonpkg/src/pandas/bind.cpp b/tools/pythonpkg/src/pandas/bind.cpp index f0b80dbf9020..b7e1bf29c893 100644 --- a/tools/pythonpkg/src/pandas/bind.cpp +++ b/tools/pythonpkg/src/pandas/bind.cpp @@ -76,7 +76,7 @@ static LogicalType BindColumn(PandasBindColumn &column_p, PandasColumnBindData & enum_entries_ptr[i] = StringVector::AddStringOrBlob(enum_entries_vec, enum_entries[i]); } D_ASSERT(py::hasattr(column.attr("cat"), "codes")); - column_type = LogicalType::ENUM(enum_name, enum_entries_vec, size); + column_type = LogicalType::ENUM(enum_entries_vec, size); auto pandas_col = py::array(column.attr("cat").attr("codes")); bind_data.internal_categorical_type = string(py::str(pandas_col.attr("dtype"))); bind_data.pandas_col = make_uniq(pandas_col); diff --git a/tools/pythonpkg/src/pyfilesystem.cpp b/tools/pythonpkg/src/pyfilesystem.cpp index 78b88dfe99bc..50fafd8aeccf 100644 --- a/tools/pythonpkg/src/pyfilesystem.cpp +++ b/tools/pythonpkg/src/pyfilesystem.cpp @@ -115,6 +115,9 @@ vector PythonFilesystem::Glob(const string &path, FileOpener *opener) { } return results; } +string PythonFilesystem::PathSeparator(const string &path) { + return "/"; +} int64_t PythonFilesystem::GetFileSize(FileHandle &handle) { // TODO: this value should be cached on the PythonFileHandle PythonGILWrapper gil; diff --git a/tools/pythonpkg/src/pyrelation.cpp b/tools/pythonpkg/src/pyrelation.cpp index 984c09a089b5..02aee6bf1fb1 100644 --- a/tools/pythonpkg/src/pyrelation.cpp +++ b/tools/pythonpkg/src/pyrelation.cpp @@ -634,7 +634,7 @@ unique_ptr DuckDBPyRelation::GetAttribute(const string &name) return make_uniq(rel->Project({StringUtil::Format("%s.%s", names[0], name)})); } if (ContainsColumnByName(name)) { - return make_uniq(rel->Project({name})); + return make_uniq(rel->Project({StringUtil::Format("\"%s\"", name)})); } throw py::attribute_error(StringUtil::Format("This relation does not contain a column by the name of '%s'", name)); } diff --git a/tools/pythonpkg/src/pyresult.cpp b/tools/pythonpkg/src/pyresult.cpp index 9bab3fda9ef6..ebcdb5860972 100644 --- a/tools/pythonpkg/src/pyresult.cpp +++ b/tools/pythonpkg/src/pyresult.cpp @@ -16,6 +16,7 @@ #include "duckdb_python/numpy/array_wrapper.hpp" #include "duckdb/common/exception.hpp" #include "duckdb_python/arrow/arrow_export_utils.hpp" +#include "duckdb/main/chunk_scan_state/query_result.hpp" namespace duckdb { @@ -287,19 +288,20 @@ py::dict DuckDBPyResult::FetchTF() { return 
result_dict; } -bool DuckDBPyResult::FetchArrowChunk(QueryResult *query_result, py::list &batches, idx_t rows_per_batch) { +bool DuckDBPyResult::FetchArrowChunk(ChunkScanState &scan_state, py::list &batches, idx_t rows_per_batch) { ArrowArray data; idx_t count; + auto &query_result = *result.get(); + auto arrow_options = query_result.GetArrowOptions(query_result); { py::gil_scoped_release release; - count = ArrowUtil::FetchChunk(query_result, rows_per_batch, &data); + count = ArrowUtil::FetchChunk(scan_state, arrow_options, rows_per_batch, &data); } if (count == 0) { return false; } ArrowSchema arrow_schema; - ArrowConverter::ToArrowSchema(&arrow_schema, query_result->types, query_result->names, - QueryResult::GetArrowOptions(*query_result)); + ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, query_result.names, arrow_options); TransformDuckToArrowChunk(arrow_schema, data, batches); return true; } @@ -311,8 +313,8 @@ py::list DuckDBPyResult::FetchAllArrowChunks(idx_t rows_per_batch) { auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib"); py::list batches; - - while (FetchArrowChunk(result.get(), batches, rows_per_batch)) { + QueryResultChunkScanState scan_state(*result.get()); + while (FetchArrowChunk(scan_state, batches, rows_per_batch)) { } return batches; } diff --git a/tools/pythonpkg/src/python_import_cache.cpp b/tools/pythonpkg/src/python_import_cache.cpp index e245705bae3f..b5b00e46a66e 100644 --- a/tools/pythonpkg/src/python_import_cache.cpp +++ b/tools/pythonpkg/src/python_import_cache.cpp @@ -20,7 +20,7 @@ bool PythonImportCacheItem::IsLoaded() const { return type.ptr() != nullptr; } -PyObject *PythonImportCacheItem::AddCache(PythonImportCache &cache, py::object object) { +py::handle PythonImportCacheItem::AddCache(PythonImportCache &cache, py::object object) { return cache.AddCache(std::move(object)); } @@ -59,7 +59,7 @@ PythonImportCache::~PythonImportCache() { owned_objects.clear(); } -PyObject *PythonImportCache::AddCache(py::object item) { +py::handle PythonImportCache::AddCache(py::object item) { auto object_ptr = item.ptr(); owned_objects.push_back(std::move(item)); return object_ptr; diff --git a/tools/pythonpkg/tests/fast/api/test_3728.py b/tools/pythonpkg/tests/fast/api/test_3728.py index aa5864c23f1c..da0d2015a4bf 100644 --- a/tools/pythonpkg/tests/fast/api/test_3728.py +++ b/tools/pythonpkg/tests/fast/api/test_3728.py @@ -15,5 +15,5 @@ def test_3728_describe_enum(self, duckdb_cursor): # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" assert cursor.table("person").execute().description == [ ('name', 'STRING', None, None, None, None, None), - ('current_mood', 'mood', None, None, None, None, None), + ('current_mood', "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), ] diff --git a/tools/pythonpkg/tests/fast/api/test_attribute_getter.py b/tools/pythonpkg/tests/fast/api/test_attribute_getter.py index dd2955db1d09..3ad1e027d66e 100644 --- a/tools/pythonpkg/tests/fast/api/test_attribute_getter.py +++ b/tools/pythonpkg/tests/fast/api/test_attribute_getter.py @@ -53,3 +53,7 @@ def test_getattr_struct(self): rel = duckdb.sql("select {'a':5, 'b':6} as a, 5 as b") assert rel.a.a.fetchall()[0][0] == 5 assert rel.a.b.fetchall()[0][0] == 6 + + def test_getattr_spaces(self): + rel = duckdb.sql('select 42 as "hello world"') + assert rel['hello world'].fetchall()[0][0] == 42 diff --git a/tools/pythonpkg/tests/fast/arrow/test_5547.py b/tools/pythonpkg/tests/fast/arrow/test_5547.py new file mode 100644 index 
000000000000..5a376ad7f57e --- /dev/null +++ b/tools/pythonpkg/tests/fast/arrow/test_5547.py @@ -0,0 +1,37 @@ +import duckdb +import pandas as pd +from pandas.testing import assert_frame_equal +import pytest + +pa = pytest.importorskip('pyarrow') + + +def test_5547(): + num_rows = 2**17 + 1 + + tbl = pa.Table.from_pandas( + pd.DataFrame.from_records( + [ + dict( + id=i, + nested=dict( + a=i, + ), + ) + for i in range(num_rows) + ] + ) + ) + + con = duckdb.connect() + expected = tbl.to_pandas() + result = con.execute( + """ + SELECT * + FROM tbl + """ + ).df() + + assert_frame_equal(expected, result) + + con.close() diff --git a/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py b/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py index b0d0cc156e6b..dc484f1bb2f3 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py +++ b/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py @@ -42,5 +42,6 @@ def test_copy_on_write(self, col): ) rel = con.sql('select * from df_in') res = rel.fetchall() + print(res) expected = convert_to_result(col) assert res == expected diff --git a/tools/pythonpkg/tests/fast/test_adbc.py b/tools/pythonpkg/tests/fast/test_adbc.py new file mode 100644 index 000000000000..b56852811795 --- /dev/null +++ b/tools/pythonpkg/tests/fast/test_adbc.py @@ -0,0 +1,94 @@ +import duckdb +import pytest +import sys +import datetime +import os + +adbc_driver_manager = pytest.importorskip("adbc_driver_manager.dbapi") +adbc_driver_manager_lib = pytest.importorskip("adbc_driver_manager._lib") + +pyarrow = pytest.importorskip("pyarrow") + + +def test_insertion(): + if sys.platform.startswith("win"): + pytest.xfail("Not supported on Windows") + con = adbc_driver_manager.connect(driver=duckdb.duckdb.__file__, entrypoint="duckdb_adbc_init") + + table = pyarrow.table( + [ + [1, 2, 3, 4], + ["a", "b", None, "d"], + ], + names=["ints", "strs"], + ) + reader = table.to_reader() + + with con.cursor() as cursor: + cursor.adbc_ingest("ingest", reader, "create") + cursor.execute("SELECT * FROM ingest") + assert cursor.fetch_arrow_table() == table + + with con.cursor() as cursor: + cursor.adbc_ingest("ingest_table", table, "create") + cursor.execute("SELECT * FROM ingest") + assert cursor.fetch_arrow_table() == table + + # Test Append + with con.cursor() as cursor: + with pytest.raises( + adbc_driver_manager_lib.InternalError, + match=r'Failed to create table \'ingest_table\': Table with name "ingest_table" already exists!', + ): + cursor.adbc_ingest("ingest_table", table, "create") + cursor.adbc_ingest("ingest_table", table, "append") + cursor.execute("SELECT count(*) FROM ingest_table") + assert cursor.fetch_arrow_table().to_pydict() == {"count_star()": [8]} + + +def test_read(): + if sys.platform.startswith("win"): + pytest.xfail("Not supported on Windows") + con = adbc_driver_manager.connect(driver=duckdb.duckdb.__file__, entrypoint="duckdb_adbc_init") + with con.cursor() as cursor: + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "category.csv") + cursor.execute(f"SELECT * FROM '{filename}'") + assert cursor.fetch_arrow_table().to_pydict() == { + "CATEGORY_ID": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "NAME": [ + "Action", + "Animation", + "Children", + "Classics", + "Comedy", + "Documentary", + "Drama", + "Family", + "Foreign", + "Games", + "Horror", + "Music", + "New", + "Sci-Fi", + "Sports", + "Travel", + ], + "LAST_UPDATE": [ + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + 
datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + datetime.datetime(2006, 2, 15, 4, 46, 27), + ], + } diff --git a/tools/pythonpkg/tests/fast/test_filesystem.py b/tools/pythonpkg/tests/fast/test_filesystem.py index 94aa9afc5616..ac0d9eed7506 100644 --- a/tools/pythonpkg/tests/fast/test_filesystem.py +++ b/tools/pythonpkg/tests/fast/test_filesystem.py @@ -176,7 +176,8 @@ def test_database_attach(self, tmp_path: Path, monkeypatch: MonkeyPatch): fs = filesystem('file', skip_instance_cache=True) write_errors = intercept(monkeypatch, LocalFileOpener, 'write') conn.register_filesystem(fs) - conn.execute(f"ATTACH 'file://{db_path}'") + db_path_posix = str(PurePosixPath(tmp_path.as_posix()) / "hello.db") + conn.execute(f"ATTACH 'file://{db_path_posix}'") conn.execute('INSERT INTO hello.t VALUES (1)') @@ -192,19 +193,13 @@ def test_copy_partition(self, duckdb_cursor: DuckDBPyConnection, memory: Abstrac duckdb_cursor.execute("copy (select 1 as a) to 'memory://root' (partition_by (a))") - assert ( - memory.open('/root\\a=1\\data_0.csv' if sys.platform == 'win32' else '/root/a=1/data_0.csv').read() - == b'1\n' - ) + assert memory.open('/root/a=1/data_0.csv').read() == b'1\n' def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute("copy (select 2 as a) to 'memory://partition' (partition_by (a))") - if sys.platform == 'win32': - path = 'memory:///partition\\*\\*.csv' - else: - path = 'memory:///partition/*/*.csv' + path = 'memory:///partition/*/*.csv' query = "SELECT * FROM read_csv_auto('" + path + "'" diff --git a/tools/pythonpkg/tests/fast/types/test_datetime_date.py b/tools/pythonpkg/tests/fast/types/test_datetime_date.py new file mode 100644 index 000000000000..83973be40224 --- /dev/null +++ b/tools/pythonpkg/tests/fast/types/test_datetime_date.py @@ -0,0 +1,30 @@ +import duckdb +import datetime + + +class TestDateTimeDate(object): + def test_date_infinity(self): + con = duckdb.connect() + # Positive infinity + con.execute("SELECT 'infinity'::DATE") + result = con.fetchall() + # datetime.date.max + assert result == [(datetime.date(9999, 12, 31),)] + + con.execute("SELECT '-infinity'::DATE") + result = con.fetchall() + # datetime.date.min + assert result == [(datetime.date(1, 1, 1),)] + + def test_date_infinity_roundtrip(self): + con = duckdb.connect() + + # positive infinity + con.execute("select $1, $1 = 'infinity'::DATE", [datetime.date.max]) + res = con.fetchall() + assert res == [(datetime.date.max, True)] + + # negative infinity + con.execute("select $1, $1 = '-infinity'::DATE", [datetime.date.min]) + res = con.fetchall() + assert res == [(datetime.date.min, True)] diff --git a/tools/pythonpkg/tests/fast/types/test_datetime_datetime.py b/tools/pythonpkg/tests/fast/types/test_datetime_datetime.py new file mode 100644 index 000000000000..45aff1fca0fe --- /dev/null +++ b/tools/pythonpkg/tests/fast/types/test_datetime_datetime.py @@ -0,0 +1,49 @@ +import duckdb 
+import datetime +import pytest + + +def create_query(positive, type): + inf = 'infinity' if positive else '-infinity' + return f""" + select '{inf}'::{type} + """ + + +class TestDateTimeDateTime(object): + @pytest.mark.parametrize('positive', [True, False]) + @pytest.mark.parametrize( + 'type', + [ + 'TIMESTAMP', + 'TIMESTAMP_S', + 'TIMESTAMP_MS', + 'TIMESTAMP_NS', + 'TIMESTAMPTZ', + 'TIMESTAMP_US', + ], + ) + def test_timestamp_infinity(self, positive, type): + con = duckdb.connect() + + if type in ['TIMESTAMP_S', 'TIMESTAMP_MS', 'TIMESTAMP_NS']: + # Infinity (both positive and negative) is not supported for non-usecond timetamps + return + + expected_val = datetime.datetime.max if positive else datetime.datetime.min + query = create_query(positive, type) + res = con.sql(query).fetchall()[0][0] + assert res == expected_val + + def test_timestamp_infinity_roundtrip(self): + con = duckdb.connect() + + # positive infinity + con.execute("select $1, $1 = 'infinity'::TIMESTAMP", [datetime.datetime.max]) + res = con.fetchall() + assert res == [(datetime.datetime.max, True)] + + # negative infinity + con.execute("select $1, $1 = '-infinity'::TIMESTAMP", [datetime.datetime.min]) + res = con.fetchall() + assert res == [(datetime.datetime.min, True)] diff --git a/tools/rpkg/src/scan.cpp b/tools/rpkg/src/scan.cpp index bef98b2c7edd..a6de173d7323 100644 --- a/tools/rpkg/src/scan.cpp +++ b/tools/rpkg/src/scan.cpp @@ -108,7 +108,7 @@ static duckdb::unique_ptr DataFrameScanBind(ClientContext &context for (R_xlen_t level_idx = 0; level_idx < levels.size(); level_idx++) { levels_ptr[level_idx] = StringVector::AddString(duckdb_levels, (string)levels[level_idx]); } - duckdb_col_type = LogicalType::ENUM(df_names[col_idx], duckdb_levels, levels.size()); + duckdb_col_type = LogicalType::ENUM(duckdb_levels, levels.size()); break; } case RType::STRING: diff --git a/tools/rpkg/src/statement.cpp b/tools/rpkg/src/statement.cpp index 9758a2962dd6..3cb8cbcef952 100644 --- a/tools/rpkg/src/statement.cpp +++ b/tools/rpkg/src/statement.cpp @@ -8,6 +8,7 @@ #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/arrow/arrow_wrapper.hpp" #include "duckdb/common/arrow/result_arrow_wrapper.hpp" +#include "duckdb/main/chunk_scan_state/query_result.hpp" #include "duckdb/parser/statement/relation_statement.hpp" @@ -267,13 +268,14 @@ struct AppendableRList { idx_t size = 0; }; -bool FetchArrowChunk(QueryResult *result, AppendableRList &batches_list, ArrowArray &arrow_data, - ArrowSchema &arrow_schema, SEXP batch_import_from_c, SEXP arrow_namespace, idx_t chunk_size) { - auto count = ArrowUtil::FetchChunk(result, chunk_size, &arrow_data); +bool FetchArrowChunk(ChunkScanState &scan_state, ArrowOptions options, AppendableRList &batches_list, + ArrowArray &arrow_data, ArrowSchema &arrow_schema, SEXP batch_import_from_c, SEXP arrow_namespace, + idx_t chunk_size) { + auto count = ArrowUtil::FetchChunk(scan_state, options, chunk_size, &arrow_data); if (count == 0) { return false; } - ArrowConverter::ToArrowSchema(&arrow_schema, result->types, result->names, QueryResult::GetArrowOptions(*result)); + ArrowConverter::ToArrowSchema(&arrow_schema, scan_state.Types(), scan_state.Names(), options); batches_list.PrepAppend(); batches_list.Append(cpp11::safe[Rf_eval](batch_import_from_c, arrow_namespace)); return true; @@ -298,8 +300,10 @@ bool FetchArrowChunk(QueryResult *result, AppendableRList &batches_list, ArrowAr // create data batches AppendableRList batches_list; - while (FetchArrowChunk(result, batches_list, arrow_data, 
arrow_schema, batch_import_from_c, arrow_namespace, - chunk_size)) { + QueryResultChunkScanState scan_state(*result); + auto arrow_options = result->GetArrowOptions(*result); + while (FetchArrowChunk(scan_state, arrow_options, batches_list, arrow_data, arrow_schema, batch_import_from_c, + arrow_namespace, chunk_size)) { } SET_LENGTH(batches_list.the_list, batches_list.size);
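Taken together, the Python and R hunks above converge on a single consumption pattern for exporting a `QueryResult` as Arrow batches. A hedged sketch of that pattern, using only the signatures visible in this diff (the surrounding setup and the per-batch hand-off are elided):

```cpp
// Assumes: `result` is a QueryResult obtained from a query, chunk_size > 0.
QueryResultChunkScanState scan_state(*result);
auto arrow_options = result->GetArrowOptions(*result);

ArrowArray arrow_data;
// FetchChunk reports how many rows it wrote into arrow_data; 0 means done.
while (ArrowUtil::FetchChunk(scan_state, arrow_options, chunk_size, &arrow_data) != 0) {
	ArrowSchema arrow_schema;
	// Types and names now come from the scan state rather than the result,
	// so the same loop works for both the Python and the R bindings.
	ArrowConverter::ToArrowSchema(&arrow_schema, scan_state.Types(), scan_state.Names(), arrow_options);
	// ... hand arrow_data + arrow_schema to the host runtime (pyarrow / R arrow) ...
}
```

Factoring the cursor into `ChunkScanState` is what lets `FetchArrowChunk` drop its `QueryResult *` parameter in both bindings while keeping the fetch loops identical.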