From 488ff8e61423c1542ac9d46c9c75423beef7b00c Mon Sep 17 00:00:00 2001 From: "Alina (Xi) Li" Date: Tue, 28 Oct 2025 16:44:13 -0700 Subject: [PATCH 1/4] Extract implementation of gh-46574 --- .../flight/sql/odbc/odbc_impl/CMakeLists.txt | 4 +- .../odbc/odbc_impl/flight_sql_auth_method.cc | 4 +- .../odbc/odbc_impl/flight_sql_connection.cc | 12 +- .../odbc/odbc_impl/flight_sql_result_set.cc | 6 +- .../odbc/odbc_impl/flight_sql_result_set.h | 1 + .../odbc/odbc_impl/flight_sql_statement.cc | 78 +++++----- .../sql/odbc/odbc_impl/flight_sql_statement.h | 3 +- .../flight_sql_statement_get_tables.cc | 47 +++--- .../flight_sql_statement_get_tables.h | 27 ++-- .../flight_sql_stream_chunk_buffer.cc | 50 +++++-- .../flight_sql_stream_chunk_buffer.h | 6 +- .../flight_sql_stream_chunk_buffer_test.cc | 136 ++++++++++++++++++ .../sql/odbc/odbc_impl/get_info_cache.cc | 15 +- .../sql/odbc/odbc_impl/get_info_cache.h | 4 +- 14 files changed, 295 insertions(+), 98 deletions(-) create mode 100644 cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt index b232577ee37..c94312fcd22 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt @@ -162,9 +162,11 @@ add_arrow_test(odbc_spi_impl_test accessors/time_array_accessor_test.cc accessors/timestamp_array_accessor_test.cc flight_sql_connection_test.cc + flight_sql_stream_chunk_buffer_test.cc parse_table_types_test.cc json_converter_test.cc record_batch_transformer_test.cc util_test.cc EXTRA_LINK_LIBS - arrow_odbc_spi_impl) + arrow_odbc_spi_impl + arrow_flight_testing_shared) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc index bdf7f71589c..b0090a8cf74 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc @@ -44,8 +44,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler { NoOpClientAuthHandler() {} Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override { - // Write a blank string. The server should ignore this and just accept any Handshake - // request. + // The server should ignore this and just accept any Handshake + // request. Some servers do not allow authentication with no handshakes. return outgoing->Write(std::string()); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc index 479a72f3fea..d8bde79f6ee 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc @@ -100,6 +100,8 @@ inline std::string GetCerts() { return ""; } #endif const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::DRIVER, + FlightSqlConnection::DSN, FlightSqlConnection::HOST, FlightSqlConnection::PORT, FlightSqlConnection::USER, @@ -153,14 +155,14 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, auto flight_ssl_configs = LoadFlightSslConfigs(properties); Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); - FlightClientOptions client_options = + client_options_ = BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs); const std::shared_ptr& cookie_factory = GetCookieFactory(); - client_options.middleware.push_back(cookie_factory); + client_options_.middleware.push_back(cookie_factory); std::unique_ptr flight_client; - ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client)); + ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client)); std::unique_ptr auth_method = FlightSqlAuthMethod::FromProperties(flight_client, properties); @@ -364,7 +366,7 @@ void FlightSqlConnection::Close() { std::shared_ptr FlightSqlConnection::CreateStatement() { return std::shared_ptr(new FlightSqlStatement( - diagnostics_, *sql_client_, call_options_, metadata_settings_)); + diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_)); } bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, @@ -410,7 +412,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string& driver_version) : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), odbc_version_(odbc_version), - info_(call_options_, sql_client_, driver_version), + info_(client_options_, call_options_, sql_client_, driver_version), closed_(true) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); attribute_[LOGIN_TIMEOUT] = static_cast(0); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc index 19149b3c48d..80967b9f200 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc @@ -29,12 +29,12 @@ namespace arrow::flight::sql::odbc { FlightSqlResultSet::FlightSqlResultSet( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) : metadata_settings_(metadata_settings), - chunk_buffer_(flight_sql_client, call_options, flight_info, + chunk_buffer_(flight_sql_client, client_options, call_options, flight_info, metadata_settings_.chunk_buffer_capacity), transformer_(transformer), metadata_(transformer diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h index 6083b332824..ac2ae80e010 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h @@ -51,6 +51,7 @@ class FlightSqlResultSet : public ResultSet { ~FlightSqlResultSet() override; FlightSqlResultSet(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc index 30eb1fdf44a..785a04c7b0e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc @@ -41,9 +41,10 @@ using util::ThrowIfNotOK; namespace { -void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement) { +void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement, + const FlightCallOptions& options) { if (prepared_statement != nullptr) { - ThrowIfNotOK(prepared_statement->Close()); + ThrowIfNotOK(prepared_statement->Close(options)); prepared_statement.reset(); } } @@ -52,11 +53,13 @@ void ClosePreparedStatementIfAny(std::shared_ptr& prepared_st FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings) : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), sql_client_(sql_client), + client_options_(std::move(client_options)), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { attribute_[METADATA_ID] = static_cast(SQL_FALSE); @@ -97,7 +100,7 @@ boost::optional FlightSqlStatement::GetAttribute( boost::optional> FlightSqlStatement::Prepare( const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Prepare(call_options_, query); @@ -111,27 +114,30 @@ boost::optional> FlightSqlStatement::Prepare( } bool FlightSqlStatement::ExecutePrepared() { + // GH-47990 TODO: use DCHECK instead of assert assert(prepared_statement_.get() != nullptr); - Result> result = prepared_statement_->Execute(); + Result> result = + prepared_statement_->Execute(call_options_); + ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } bool FlightSqlStatement::Execute(const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Execute(call_options_, query); ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } @@ -146,33 +152,35 @@ std::shared_ptr FlightSqlStatement::GetTables( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* table_type, const ColumnNames& column_names) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); std::vector table_types; if ((catalog_name && *catalog_name == "%") && (schema_name && schema_name->empty()) && (table_name && table_name->empty())) { - current_result_set_ = GetTablesForSQLAllCatalogs( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllCatalogs(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && *schema_name == "%") && (table_name && table_name->empty())) { - current_result_set_ = - GetTablesForSQLAllDbSchemas(column_names, call_options_, sql_client_, schema_name, - diagnostics_, metadata_settings_); + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, client_options_, call_options_, sql_client_, schema_name, + diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && schema_name->empty()) && (table_name && table_name->empty()) && (table_type && *table_type == "%")) { - current_result_set_ = GetTablesForSQLAllTableTypes( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllTableTypes(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else { if (table_type) { ParseTableTypes(*table_type, table_types); } current_result_set_ = GetTablesForGenericUse( - column_names, call_options_, sql_client_, catalog_name, schema_name, table_name, - table_types, diagnostics_, metadata_settings_); + column_names, client_options_, call_options_, sql_client_, catalog_name, + schema_name, table_name, table_types, diagnostics_, metadata_settings_); } return current_result_set_; @@ -199,7 +207,7 @@ std::shared_ptr FlightSqlStatement::GetTables_V3( std::shared_ptr FlightSqlStatement::GetColumns_V2( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -210,9 +218,9 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } @@ -220,7 +228,7 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( std::shared_ptr FlightSqlStatement::GetColumns_V3( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -231,15 +239,15 @@ std::shared_ptr FlightSqlStatement::GetColumns_V3( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -249,15 +257,15 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -267,9 +275,9 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h index 36dc245c1d7..3593b2f774d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h @@ -32,6 +32,7 @@ class FlightSqlStatement : public Statement { private: Diagnostics diagnostics_; std::map attribute_; + FlightClientOptions client_options_; FlightCallOptions call_options_; FlightSqlClient& sql_client_; std::shared_ptr current_result_set_; @@ -46,7 +47,7 @@ class FlightSqlStatement : public Statement { public: FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, - FlightCallOptions call_options, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings); bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc index 2a6bb8970b4..f50cc4cf2f9 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc @@ -69,9 +69,9 @@ void ParseTableTypes(const std::string& table_type, } std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetCatalogs(call_options); std::shared_ptr schema; @@ -89,13 +89,15 @@ std::shared_ptr GetTablesForSQLAllCatalogs( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetDbSchemas(call_options, nullptr, schema_name); @@ -115,14 +117,15 @@ std::shared_ptr GetTablesForSQLAllDbSchemas( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTableTypes(call_options); std::shared_ptr schema; @@ -140,16 +143,17 @@ std::shared_ptr GetTablesForSQLAllTableTypes( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForGenericUse( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTables( call_options, catalog_name, schema_name, table_name, false, &table_types); @@ -168,8 +172,9 @@ std::shared_ptr GetTablesForGenericUse( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h index 31abab91cb5..0c3ad10f97b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h @@ -40,25 +40,26 @@ void ParseTableTypes(const std::string& table_type, std::vector& table_types); std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForGenericUse( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc index 25bf04ea507..fa1a4d9f064 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc @@ -20,37 +20,67 @@ namespace arrow::flight::sql::odbc { -using arrow::Result; - FlightStreamChunkBuffer::FlightStreamChunkBuffer( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, size_t queue_capacity) + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, + size_t queue_capacity) : queue_(queue_capacity) { - // FIXME: Endpoint iteration should consider endpoints may be at different hosts for (const auto& endpoint : flight_info->endpoints()) { const Ticket& ticket = endpoint.ticket; - auto result = flight_sql_client.DoGet(call_options, ticket); + arrow::Result> result; + std::shared_ptr temp_flight_sql_client; + auto endpoint_locations = endpoint.locations; + if (endpoint_locations.empty()) { + // list of Locations needs to be empty to proceed + result = flight_sql_client.DoGet(call_options, ticket); + } else { + // If it is non-empty, the driver should create a FlightSqlClient to connect to one + // of the specified Locations directly. + + // GH-47117: Currently a new FlightClient will be made for each partition that + // returns a non-empty Location, which is then disposed of. It may be better to + // cache clients because a server may report the same Locations. It would also be + // good to identify when the reported Location is the same as the original + // connection's Location and skip creating a FlightClient in that scenario. + + std::unique_ptr temp_flight_client; + util::ThrowIfNotOK(FlightClient::Connect(endpoint_locations[0], client_options) + .Value(&temp_flight_client)); + temp_flight_sql_client.reset(new FlightSqlClient(std::move(temp_flight_client))); + + result = temp_flight_sql_client->DoGet(call_options, ticket); + } + util::ThrowIfNotOK(result.status()); std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); - BlockingQueue>::Supplier supplier = [=] { + BlockingQueue, + std::shared_ptr>>::Supplier supplier = [=] { auto result = stream_reader_ptr->Next(); bool is_not_ok = !result.ok(); bool is_not_empty = result.ok() && (result.ValueOrDie().data != nullptr); - return boost::make_optional(is_not_ok || is_not_empty, std::move(result)); + // If result is valid, save the temp Flight SQL Client for future stream reader + // call. temp_flight_sql_client is intentionally null if the list of endpoint + // Locations is empty. + // After all data is fetched from reader, the temp client is closed. + return boost::make_optional( + is_not_ok || is_not_empty, + std::make_pair(std::move(result), temp_flight_sql_client)); }; queue_.AddProducer(std::move(supplier)); } } bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk* chunk) { - Result result; - if (!queue_.Pop(&result)) { + std::pair, std::shared_ptr> + closeable_endpoint_stream_pair; + if (!queue_.Pop(&closeable_endpoint_stream_pair)) { return false; } + Result result = closeable_endpoint_stream_pair.first; if (!result.status().ok()) { Close(); throw DriverException(result.status().message()); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h index f59336c984d..772a854eb59 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h @@ -23,11 +23,15 @@ namespace arrow::flight::sql::odbc { +using arrow::Result; + class FlightStreamChunkBuffer { - BlockingQueue> queue_; + BlockingQueue, std::shared_ptr>> + queue_; public: FlightStreamChunkBuffer(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, size_t queue_capacity = 5); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc new file mode 100644 index 00000000000..a3f23ecaaf9 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array.h" + +#include "arrow/testing/gtest_util.h" + +#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h" +#include "arrow/flight/sql/odbc/odbc_impl/json_converter.h" +#include "arrow/flight/test_flight_server.h" +#include "arrow/flight/test_util.h" + +#include + +namespace arrow::flight::sql::odbc { + +using arrow::Array; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightDescriptor; +using arrow::flight::FlightEndpoint; +using arrow::flight::Location; +using arrow::flight::Ticket; +using arrow::flight::sql::FlightSqlClient; + +class FlightStreamChunkBufferTest : public ::testing::Test { + // Sets up two mock servers for each test case. + // This is for testing endpoint iteration only. + + protected: + void SetUp() override { + // Set up server 1 + server1 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options1(location1); + ASSERT_OK(server1->Init(options1)); + ASSERT_OK_AND_ASSIGN(server_location1, + Location::ForGrpcTcp("localhost", server1->port())); + + // Set up server 2 + server2 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options2(location2); + ASSERT_OK(server2->Init(options2)); + ASSERT_OK_AND_ASSIGN(server_location2, + Location::ForGrpcTcp("localhost", server2->port())); + + // Make SQL Client that is connected to server 1 + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location1)); + sql_client.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(server1->Shutdown()); + ASSERT_OK(server2->Shutdown()); + } + + public: + arrow::flight::Location server_location1; + std::shared_ptr server1; + arrow::flight::Location server_location2; + std::shared_ptr server2; + std::shared_ptr sql_client; +}; + +FlightInfo MultipleEndpointsFlightInfo(Location location1, Location location2) { + // Sever will generate random data for `ticket-ints-1` + FlightEndpoint endpoint1({Ticket{"ticket-ints-1"}, {location1}, std::nullopt, {}}); + FlightEndpoint endpoint2({Ticket{"ticket-ints-1"}, {location2}, std::nullopt, {}}); + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; + + auto schema1 = arrow::flight::ExampleIntSchema(); + + return arrow::flight::MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, + 100000, false, ""); +} + +void VerifyArraysContainIntsOnly(std::shared_ptr intArray) { + for (int64_t i = 0; i < intArray->length(); ++i) { + // null values are accepted + if (!intArray->IsNull(i)) { + auto scalar_data = intArray->GetScalar(i).ValueOrDie(); + std::string scalar_str = ConvertToJson(*scalar_data); + ASSERT_TRUE(std::all_of(scalar_str.begin(), scalar_str.end(), ::isdigit)); + } + } +} + +TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { + FlightClientOptions client_options = FlightClientOptions::Defaults(); + FlightCallOptions options; + FlightInfo info = MultipleEndpointsFlightInfo(server_location1, server_location2); + std::shared_ptr info_ptr = std::make_shared(info); + + FlightStreamChunkBuffer chunk_buffer(*sql_client, client_options, options, info_ptr); + + FlightStreamChunk current_chunk; + + // Server returns 5 batch of results from each endpoints. + // Each batch contains 8 columns + int num_chunks = 0; + while (chunk_buffer.GetNext(¤t_chunk)) { + num_chunks++; + + int num_cols = current_chunk.data->num_columns(); + ASSERT_EQ(8, num_cols); + + for (int i = 0; i < num_cols; i++) { + auto array = current_chunk.data->column(i); + // Each array has random length + ASSERT_GT(array->length(), 0); + + VerifyArraysContainIntsOnly(array); + } + } + + // Verify 5 batches of data is returned by each of the two endpoints. + // In total 10 batches should be returned. + ASSERT_EQ(10, num_chunks); +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc index bf2f6b6eca2..7f6ba8042de 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc @@ -199,10 +199,14 @@ inline void SetDefaultIfMissing(std::unordered_map& } // namespace -GetInfoCache::GetInfoCache(FlightCallOptions& call_options, +GetInfoCache::GetInfoCache(FlightClientOptions& client_options, + FlightCallOptions& call_options, std::unique_ptr& client, const std::string& driver_version) - : call_options_(call_options), sql_client_(client), has_server_info_(false) { + : client_options_(client_options), + call_options_(call_options), + sql_client_(client), + has_server_info_(false) { info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; info_[SQL_DRIVER_VER] = util::ConvertToDBMSVer(driver_version); @@ -283,7 +287,8 @@ bool GetInfoCache::LoadInfoFromServer() { arrow::Result> result = sql_client_->GetSqlInfo(call_options_, {}); util::ThrowIfNotOK(result.status()); - FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, result.ValueOrDie()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, client_options_, call_options_, + result.ValueOrDie()); FlightStreamChunk chunk; bool supports_correlation_name = false; @@ -311,8 +316,8 @@ bool GetInfoCache::LoadInfoFromServer() { std::string server_name( reinterpret_cast(scalar->child_value().get())->view()); - // TODO: Consider creating different properties in GetSqlInfo. - // TODO: Investigate if SQL_SERVER_NAME should just be the host + // GH-47855 TODO: Consider creating different properties in GetSqlInfo. + // GH-47856 TODO: Investigate if SQL_SERVER_NAME should just be the host // address as well. In JDBC, FLIGHT_SQL_SERVER_NAME is only used for // the DatabaseProductName. info_[SQL_SERVER_NAME] = server_name; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h index d0e0efd159f..a1452e4b466 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h @@ -30,13 +30,15 @@ namespace arrow::flight::sql::odbc { class GetInfoCache { private: std::unordered_map info_; + FlightClientOptions& client_options_; FlightCallOptions& call_options_; std::unique_ptr& sql_client_; std::mutex mutex_; std::atomic has_server_info_; public: - GetInfoCache(FlightCallOptions& call_options, std::unique_ptr& client, + GetInfoCache(FlightClientOptions& client_options, FlightCallOptions& call_options, + std::unique_ptr& client, const std::string& driver_version); void SetProperty(uint16_t property, Connection::Info value); Connection::Info GetInfo(uint16_t info_type); From 62f7b431ce8292a9347cee9c2196fabc80cc04e4 Mon Sep 17 00:00:00 2001 From: "Alina (Xi) Li" Date: Wed, 29 Oct 2025 14:27:28 -0700 Subject: [PATCH 2/4] Add `server->Wait` call --- .../sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc index a3f23ecaaf9..18ae3d8c57e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc @@ -66,7 +66,9 @@ class FlightStreamChunkBufferTest : public ::testing::Test { void TearDown() override { ASSERT_OK(server1->Shutdown()); + ASSERT_OK(server1->Wait()); ASSERT_OK(server2->Shutdown()); + ASSERT_OK(server1->Wait()); } public: From 3e168684a30b2b9187d83501fe819b4a29acbf07 Mon Sep 17 00:00:00 2001 From: justing-bq <62349012+justing-bq@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:07:23 -0800 Subject: [PATCH 3/4] Avoid using "using" in Headers --- .../sql/odbc/odbc_impl/attribute_utils.h | 34 ++++++++++--------- .../sql/odbc/odbc_impl/encoding_utils.h | 27 +++++++-------- .../odbc_impl/flight_sql_get_tables_reader.cc | 4 +-- .../odbc_impl/flight_sql_get_tables_reader.h | 6 ++-- .../flight_sql_get_type_info_reader.cc | 1 + .../flight_sql_get_type_info_reader.h | 26 +++++++------- .../flight_sql_stream_chunk_buffer.cc | 2 ++ .../flight_sql_stream_chunk_buffer.h | 5 ++- .../flight/sql/odbc/odbc_impl/spi/statement.h | 4 +-- .../flight/sql/odbc/odbc_impl/system_dsn.h | 8 ++--- .../arrow/flight/sql/odbc/odbc_impl/util.cc | 1 + .../arrow/flight/sql/odbc/odbc_impl/util.h | 33 +++++++++--------- .../flight/sql/odbc/tests/odbc_test_suite.cc | 8 ++--- 13 files changed, 77 insertions(+), 82 deletions(-) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h index 7baea759ede..b70460e5dd9 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h @@ -30,10 +30,6 @@ namespace ODBC { -using arrow::flight::sql::odbc::Diagnostics; -using arrow::flight::sql::odbc::DriverException; -using arrow::flight::sql::odbc::WcsToUtf8; - template inline void GetAttribute(T attribute_value, SQLPOINTER output, O output_size, O* output_len_ptr) { @@ -70,7 +66,7 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER template inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER output, O output_size, O* output_len_ptr, - Diagnostics& diagnostics) { + arrow::flight::sql::odbc::Diagnostics& diagnostics) { SQLRETURN result = GetAttributeUTF8(attribute_value, output, output_size, output_len_ptr); if (SQL_SUCCESS_WITH_INFO == result) { @@ -85,10 +81,11 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, O output_size, O* output_len_ptr) { size_t length = ConvertToSqlWChar( attribute_value, reinterpret_cast(output), - is_length_in_bytes ? output_size : output_size * GetSqlWCharSize()); + is_length_in_bytes ? output_size + : output_size * arrow::flight::sql::odbc::GetSqlWCharSize()); if (!is_length_in_bytes) { - length = length / GetSqlWCharSize(); + length = length / arrow::flight::sql::odbc::GetSqlWCharSize(); } if (output_len_ptr) { @@ -97,17 +94,19 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, if (output && output_size < - static_cast(length + (is_length_in_bytes ? GetSqlWCharSize() : 1))) { + static_cast(length + (is_length_in_bytes + ? arrow::flight::sql::odbc::GetSqlWCharSize() + : 1))) { return SQL_SUCCESS_WITH_INFO; } return SQL_SUCCESS; } template -inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, - bool is_length_in_bytes, SQLPOINTER output, - O output_size, O* output_len_ptr, - Diagnostics& diagnostics) { +inline SQLRETURN GetAttributeSQLWCHAR( + const std::string& attribute_value, bool is_length_in_bytes, SQLPOINTER output, + O output_size, O* output_len_ptr, + arrow::flight::sql::odbc::Diagnostics& diagnostics) { SQLRETURN result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output, output_size, output_len_ptr); if (SQL_SUCCESS_WITH_INFO == result) { @@ -120,7 +119,7 @@ template inline SQLRETURN GetStringAttribute(bool is_unicode, const std::string& attribute_value, bool is_length_in_bytes, SQLPOINTER output, O output_size, O* output_len_ptr, - Diagnostics& diagnostics) { + arrow::flight::sql::odbc::Diagnostics& diagnostics) { SQLRETURN result = SQL_SUCCESS; if (is_unicode) { result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output, @@ -158,9 +157,11 @@ inline void SetAttributeSQLWCHAR(SQLPOINTER new_value, SQLINTEGER input_length_i std::string& attribute_to_write) { thread_local std::vector utf8_str; if (input_length_in_bytes == SQL_NTS) { - WcsToUtf8(new_value, &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8(new_value, &utf8_str); } else { - WcsToUtf8(new_value, input_length_in_bytes / GetSqlWCharSize(), &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8( + new_value, input_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize(), + &utf8_str); } attribute_to_write.assign((char*)utf8_str.data()); } @@ -168,7 +169,8 @@ inline void SetAttributeSQLWCHAR(SQLPOINTER new_value, SQLINTEGER input_length_i template void CheckIfAttributeIsSetToOnlyValidValue(SQLPOINTER value, T allowed_value) { if (static_cast(reinterpret_cast(value)) != allowed_value) { - throw DriverException("Optional feature not implemented", "HYC00"); + throw arrow::flight::sql::odbc::DriverException("Optional feature not implemented", + "HYC00"); } } } // namespace ODBC diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h index a5cc3a6f4c8..4c7fb8667b8 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h @@ -33,17 +33,12 @@ namespace ODBC { -using arrow::flight::sql::odbc::DriverException; -using arrow::flight::sql::odbc::GetSqlWCharSize; -using arrow::flight::sql::odbc::Utf8ToWcs; -using arrow::flight::sql::odbc::WcsToUtf8; - // Return the number of bytes required for the conversion. template inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, SQLLEN buffer_size_in_bytes) { thread_local std::vector wstr; - Utf8ToWcs(str.data(), str.size(), &wstr); + arrow::flight::sql::odbc::Utf8ToWcs(str.data(), str.size(), &wstr); SQLLEN value_length_in_bytes = wstr.size(); if (buffer) { @@ -52,11 +47,14 @@ inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, // Write a NUL terminator if (buffer_size_in_bytes >= - value_length_in_bytes + static_cast(GetSqlWCharSize())) { - reinterpret_cast(buffer)[value_length_in_bytes / GetSqlWCharSize()] = + value_length_in_bytes + + static_cast(arrow::flight::sql::odbc::GetSqlWCharSize())) { + reinterpret_cast( + buffer)[value_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize()] = '\0'; } else { - SQLLEN num_chars_written = buffer_size_in_bytes / GetSqlWCharSize(); + SQLLEN num_chars_written = + buffer_size_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize(); // If we failed to even write one char, the buffer is too small to hold a // NUL-terminator. if (num_chars_written > 0) { @@ -69,15 +67,16 @@ inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, SQLLEN buffer_size_in_bytes) { - switch (GetSqlWCharSize()) { + switch (arrow::flight::sql::odbc::GetSqlWCharSize()) { case sizeof(char16_t): return ConvertToSqlWChar(str, buffer, buffer_size_in_bytes); case sizeof(char32_t): return ConvertToSqlWChar(str, buffer, buffer_size_in_bytes); default: assert(false); - throw DriverException("Encoding is unsupported, SQLWCHAR size: " + - std::to_string(GetSqlWCharSize())); + throw arrow::flight::sql::odbc::DriverException( + "Encoding is unsupported, SQLWCHAR size: " + + std::to_string(arrow::flight::sql::odbc::GetSqlWCharSize())); } } @@ -93,9 +92,9 @@ inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQ thread_local std::vector utf8_str; if (msg_len == SQL_NTS) { - WcsToUtf8((void*)wchar_msg, &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, &utf8_str); } else { - WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str); } return std::string(utf8_str.begin(), utf8_str.end()); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc index 5fe6069648f..ebff8c40f2c 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc @@ -36,7 +36,7 @@ GetTablesReader::GetTablesReader(std::shared_ptr record_batch) bool GetTablesReader::Next() { return ++current_row_ < record_batch_->num_rows(); } -optional GetTablesReader::GetCatalogName() { +std::optional GetTablesReader::GetCatalogName() { const auto& array = checked_pointer_cast(record_batch_->column(0)); if (array->IsNull(current_row_)) return nullopt; @@ -44,7 +44,7 @@ optional GetTablesReader::GetCatalogName() { return array->GetString(current_row_); } -optional GetTablesReader::GetDbSchemaName() { +std::optional GetTablesReader::GetDbSchemaName() { const auto& array = checked_pointer_cast(record_batch_->column(1)); if (array->IsNull(current_row_)) return nullopt; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h index 6cc464d072b..ad9739d87bb 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h @@ -20,8 +20,6 @@ namespace arrow::flight::sql::odbc { -using std::optional; - class GetTablesReader { private: std::shared_ptr record_batch_; @@ -32,9 +30,9 @@ class GetTablesReader { bool Next(); - optional GetCatalogName(); + std::optional GetCatalogName(); - optional GetDbSchemaName(); + std::optional GetDbSchemaName(); std::string GetTableName(); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc index 7f290096e5a..13115c88dbd 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc @@ -28,6 +28,7 @@ namespace arrow::flight::sql::odbc { using arrow::internal::checked_pointer_cast; using std::nullopt; +using std::optional; GetTypeInfoReader::GetTypeInfoReader(std::shared_ptr record_batch) : record_batch_(std::move(record_batch)), current_row_(-1) {} diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h index a7c1d51182f..ce38a925ae1 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h @@ -20,8 +20,6 @@ namespace arrow::flight::sql::odbc { -using std::optional; - class GetTypeInfoReader { private: std::shared_ptr record_batch_; @@ -36,13 +34,13 @@ class GetTypeInfoReader { int32_t GetDataType(); - optional GetColumnSize(); + std::optional GetColumnSize(); - optional GetLiteralPrefix(); + std::optional GetLiteralPrefix(); - optional GetLiteralSuffix(); + std::optional GetLiteralSuffix(); - optional> GetCreateParams(); + std::optional> GetCreateParams(); int32_t GetNullable(); @@ -50,25 +48,25 @@ class GetTypeInfoReader { int32_t GetSearchable(); - optional GetUnsignedAttribute(); + std::optional GetUnsignedAttribute(); bool GetFixedPrecScale(); - optional GetAutoIncrement(); + std::optional GetAutoIncrement(); - optional GetLocalTypeName(); + std::optional GetLocalTypeName(); - optional GetMinimumScale(); + std::optional GetMinimumScale(); - optional GetMaximumScale(); + std::optional GetMaximumScale(); int32_t GetSqlDataType(); - optional GetDatetimeSubcode(); + std::optional GetDatetimeSubcode(); - optional GetNumPrecRadix(); + std::optional GetNumPrecRadix(); - optional GetIntervalPrecision(); + std::optional GetIntervalPrecision(); }; } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc index fa1a4d9f064..0862a2f8441 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc @@ -20,6 +20,8 @@ namespace arrow::flight::sql::odbc { +using arrow::Result; + FlightStreamChunkBuffer::FlightStreamChunkBuffer( FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h index 772a854eb59..696e67e5aa7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h @@ -23,10 +23,9 @@ namespace arrow::flight::sql::odbc { -using arrow::Result; - class FlightStreamChunkBuffer { - BlockingQueue, std::shared_ptr>> + BlockingQueue< + std::pair, std::shared_ptr>> queue_; public: diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h index 970e447dfdc..390950e7413 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h @@ -24,8 +24,6 @@ namespace arrow::flight::sql::odbc { -using boost::optional; - class ResultSet; class ResultSetMetadata; @@ -69,7 +67,7 @@ class Statement { /// /// \param attribute Attribute identifier to be retrieved. /// \return Value associated with the attribute. - virtual optional GetAttribute( + virtual boost::optional GetAttribute( Statement::StatementAttributeId attribute) = 0; /// \brief Prepares the statement. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h index 32d17af6753..1110849977f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h @@ -22,8 +22,6 @@ namespace arrow::flight::sql::odbc { -using config::Configuration; - #if defined _WIN32 /** * Display connection window for user to configure connection parameters. @@ -32,7 +30,7 @@ using config::Configuration; * @param config Output configuration. * @return True on success and false on fail. */ -bool DisplayConnectionWindow(void* window_parent, Configuration& config); +bool DisplayConnectionWindow(void* window_parent, config::Configuration& config); /** * For SQLDriverConnect. @@ -44,7 +42,7 @@ bool DisplayConnectionWindow(void* window_parent, Configuration& config); * @param properties Output properties. * @return True on success and false on fail. */ -bool DisplayConnectionWindow(void* window_parent, Configuration& config, +bool DisplayConnectionWindow(void* window_parent, config::Configuration& config, Connection::ConnPropertyMap& properties); #endif @@ -55,7 +53,7 @@ bool DisplayConnectionWindow(void* window_parent, Configuration& config, * @param driver Driver. * @return True on success and false on fail. */ -bool RegisterDsn(const Configuration& config, LPCWSTR driver); +bool RegisterDsn(const config::Configuration& config, LPCWSTR driver); /** * Unregister specified DSN. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc index df6aff9cfa7..a2cc08358e2 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc @@ -64,6 +64,7 @@ CDataType GetDefaultCCharType(bool use_wide_char) { using std::make_optional; using std::nullopt; +using std::optional; /// \brief Returns the mapping from Arrow type to SqlDataType /// \param field the field to return the SqlDataType for diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h index 8197f741d1e..4980cb0f19e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h @@ -49,8 +49,6 @@ namespace util { typedef std::function(const std::shared_ptr&)> ArrayConvertTask; -using std::optional; - inline void ThrowIfNotOK(const Status& status) { if (!status.ok()) { throw DriverException(status.message()); @@ -63,7 +61,7 @@ inline bool CheckIfSetToOnlyValidValue(const AttributeTypeT& value, T allowed_va } template -Status AppendToBuilder(BUILDER& builder, optional opt_value) { +Status AppendToBuilder(BUILDER& builder, std::optional opt_value) { if (opt_value) { return builder.Append(*opt_value); } else { @@ -87,29 +85,30 @@ CDataType ConvertCDataTypeFromV2ToV3(int16_t data_type_v2); std::string GetTypeNameFromSqlDataType(int16_t data_type); -optional GetRadixFromSqlDataType(SqlDataType data_type); +std::optional GetRadixFromSqlDataType(SqlDataType data_type); int16_t GetNonConciseDataType(SqlDataType data_type); -optional GetSqlDateTimeSubCode(SqlDataType data_type); +std::optional GetSqlDateTimeSubCode(SqlDataType data_type); -optional GetCharOctetLength(SqlDataType data_type, - const arrow::Result& column_size, - const int32_t decimal_precison = 0); +std::optional GetCharOctetLength(SqlDataType data_type, + const arrow::Result& column_size, + const int32_t decimal_precison = 0); -optional GetBufferLength(SqlDataType data_type, - const optional& column_size); +std::optional GetBufferLength(SqlDataType data_type, + const std::optional& column_size); -optional GetLength(SqlDataType data_type, const optional& column_size); +std::optional GetLength(SqlDataType data_type, + const std::optional& column_size); -optional GetTypeScale(SqlDataType data_type, - const optional& type_scale); +std::optional GetTypeScale(SqlDataType data_type, + const std::optional& type_scale); -optional GetColumnSize(SqlDataType data_type, - const optional& column_size); +std::optional GetColumnSize(SqlDataType data_type, + const std::optional& column_size); -optional GetDisplaySize(SqlDataType data_type, - const optional& column_size); +std::optional GetDisplaySize(SqlDataType data_type, + const std::optional& column_size); std::string ConvertSqlPatternToRegexString(const std::string& pattern); diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc index fccb5525759..782a51156c8 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -447,22 +447,22 @@ std::wstring ConvertToWString(const std::vector& str_val, SQLSMALLINT } else { EXPECT_GT(str_len, 0); EXPECT_LE(str_len, static_cast(kOdbcBufferSize)); - attr_str = std::wstring(str_val.begin(), - str_val.begin() + str_len / ODBC::GetSqlWCharSize()); + attr_str = + std::wstring(str_val.begin(), str_val.begin() + str_len / GetSqlWCharSize()); } return attr_str; } void CheckStringColumnW(SQLHSTMT stmt, int col_id, const std::wstring& expected) { SQLWCHAR buf[1024]; - SQLLEN buf_len = sizeof(buf) * ODBC::GetSqlWCharSize(); + SQLLEN buf_len = sizeof(buf) * GetSqlWCharSize(); ASSERT_EQ(SQL_SUCCESS, SQLGetData(stmt, col_id, SQL_C_WCHAR, buf, buf_len, &buf_len)); EXPECT_GT(buf_len, 0); // returned buf_len is in bytes so convert to length in characters - size_t char_count = static_cast(buf_len) / ODBC::GetSqlWCharSize(); + size_t char_count = static_cast(buf_len) / GetSqlWCharSize(); std::wstring returned(buf, buf + char_count); EXPECT_EQ(expected, returned); From 4d34c4c39065c363c94a056e88a3e242227ab008 Mon Sep 17 00:00:00 2001 From: justing-bq <62349012+justing-bq@users.noreply.github.com> Date: Mon, 10 Nov 2025 15:38:08 -0800 Subject: [PATCH 4/4] Address more feedback --- .../odbc_impl/flight_sql_stream_chunk_buffer.cc | 5 ++++- .../flight_sql_stream_chunk_buffer_test.cc | 17 +++++------------ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc index 0862a2f8441..a01a0c2407d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc @@ -49,7 +49,8 @@ FlightStreamChunkBuffer::FlightStreamChunkBuffer( std::unique_ptr temp_flight_client; util::ThrowIfNotOK(FlightClient::Connect(endpoint_locations[0], client_options) .Value(&temp_flight_client)); - temp_flight_sql_client.reset(new FlightSqlClient(std::move(temp_flight_client))); + temp_flight_sql_client = + std::make_shared(std::move(temp_flight_client)); result = temp_flight_sql_client->DoGet(call_options, ticket); } @@ -67,6 +68,8 @@ FlightStreamChunkBuffer::FlightStreamChunkBuffer( // call. temp_flight_sql_client is intentionally null if the list of endpoint // Locations is empty. // After all data is fetched from reader, the temp client is closed. + + // gh-48084 Replace boost::optional with std::optional return boost::make_optional( is_not_ok || is_not_empty, std::make_pair(std::move(result), temp_flight_sql_client)); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc index 18ae3d8c57e..cbe5cd8f7e5 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc @@ -24,6 +24,7 @@ #include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" +#include #include namespace arrow::flight::sql::odbc { @@ -92,17 +93,6 @@ FlightInfo MultipleEndpointsFlightInfo(Location location1, Location location2) { 100000, false, ""); } -void VerifyArraysContainIntsOnly(std::shared_ptr intArray) { - for (int64_t i = 0; i < intArray->length(); ++i) { - // null values are accepted - if (!intArray->IsNull(i)) { - auto scalar_data = intArray->GetScalar(i).ValueOrDie(); - std::string scalar_str = ConvertToJson(*scalar_data); - ASSERT_TRUE(std::all_of(scalar_str.begin(), scalar_str.end(), ::isdigit)); - } - } -} - TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { FlightClientOptions client_options = FlightClientOptions::Defaults(); FlightCallOptions options; @@ -127,7 +117,10 @@ TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { // Each array has random length ASSERT_GT(array->length(), 0); - VerifyArraysContainIntsOnly(array); + std::vector int_types = { + Type::type::INT8, Type::type::UINT8, Type::type::INT16, Type::type::UINT16, + Type::type::INT32, Type::type::UINT32, Type::type::INT64, Type::type::UINT64}; + ASSERT_THAT(int_types, testing::Contains(array->type_id())); } }