diff --git a/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/AdaptingDriverBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/AdaptingDriverBoltConnection.java index e2c6dfa99..56cbf10ce 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/AdaptingDriverBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/AdaptingDriverBoltConnection.java @@ -16,24 +16,14 @@ */ package org.neo4j.driver.internal.adaptedbolt; -import java.time.Duration; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.concurrent.CompletionStage; -import java.util.function.Supplier; -import org.neo4j.bolt.connection.AccessMode; import org.neo4j.bolt.connection.AuthInfo; -import org.neo4j.bolt.connection.AuthTokens; import org.neo4j.bolt.connection.BoltConnection; -import org.neo4j.bolt.connection.BoltConnectionState; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; -import org.neo4j.bolt.connection.DatabaseName; -import org.neo4j.bolt.connection.NotificationConfig; -import org.neo4j.bolt.connection.TelemetryApi; -import org.neo4j.bolt.connection.TransactionType; -import org.neo4j.driver.Value; +import org.neo4j.bolt.connection.message.Message; import org.neo4j.driver.internal.value.BoltValueFactory; final class AdaptingDriverBoltConnection implements DriverBoltConnection { @@ -49,141 +39,15 @@ final class AdaptingDriverBoltConnection implements DriverBoltConnection { } @Override - public CompletionStage onLoop(Supplier supplier) { - return connection.onLoop(supplier); - } - - @Override - public CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks) { - return connection - .route(databaseName, impersonatedUser, bookmarks) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage beginTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig) { - return connection - .beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - transactionType, - txTimeout, - boltValueFactory.toBoltMap(txMetadata), - txType, - notificationConfig) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig) { - return connection - .runInAutoCommitTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - query, - boltValueFactory.toBoltMap(parameters), - txTimeout, - boltValueFactory.toBoltMap(txMetadata), - notificationConfig) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage run(String query, Map parameters) { - return connection - .run(query, boltValueFactory.toBoltMap(parameters)) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage pull(long qid, long request) { - return connection - .pull(qid, request) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage discard(long qid, long number) { + public CompletionStage writeAndFlush(DriverResponseHandler handler, List messages) { return connection - .discard(qid, number) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage commit() { - return connection.commit().exceptionally(errorMapper::mapAndTrow).thenApply(ignored -> this); - } - - @Override - public CompletionStage rollback() { - return connection.rollback().exceptionally(errorMapper::mapAndTrow).thenApply(ignored -> this); - } - - @Override - public CompletionStage reset() { - return connection.reset().exceptionally(errorMapper::mapAndTrow).thenApply(ignored -> this); - } - - @Override - public CompletionStage logoff() { - return connection.logoff().exceptionally(errorMapper::mapAndTrow).thenApply(ignored -> this); - } - - @Override - public CompletionStage logon(Map authMap) { - return connection - .logon(AuthTokens.custom(boltValueFactory.toBoltMap(authMap))) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage telemetry(TelemetryApi telemetryApi) { - return connection - .telemetry(telemetryApi) - .exceptionally(errorMapper::mapAndTrow) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage clear() { - return connection.clear().exceptionally(errorMapper::mapAndTrow).thenApply(ignored -> this); + .writeAndFlush(new AdaptingDriverResponseHandler(handler, errorMapper, boltValueFactory), messages) + .exceptionally(errorMapper::mapAndTrow); } @Override - public CompletionStage flush(DriverResponseHandler handler) { - return connection - .flush(new AdaptingDriverResponseHandler(handler, errorMapper, boltValueFactory)) - .exceptionally(errorMapper::mapAndTrow); + public CompletionStage write(List messages) { + return connection.write(messages).exceptionally(errorMapper::mapAndTrow); } @Override @@ -196,11 +60,6 @@ public CompletionStage close() { return connection.close().exceptionally(errorMapper::mapAndTrow); } - @Override - public BoltConnectionState state() { - return connection.state(); - } - @Override public CompletionStage authData() { return connection.authInfo().exceptionally(errorMapper::mapAndTrow); @@ -230,4 +89,9 @@ public boolean telemetrySupported() { public boolean serverSideRoutingEnabled() { return connection.serverSideRoutingEnabled(); } + + @Override + public BoltValueFactory valueFactory() { + return boltValueFactory; + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/DriverBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/DriverBoltConnection.java index cde9fea18..1c995908b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/DriverBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/adaptedbolt/DriverBoltConnection.java @@ -16,71 +16,22 @@ */ package org.neo4j.driver.internal.adaptedbolt; -import java.time.Duration; -import java.util.Map; -import java.util.Set; +import java.util.List; import java.util.concurrent.CompletionStage; -import java.util.function.Supplier; -import org.neo4j.bolt.connection.AccessMode; import org.neo4j.bolt.connection.AuthInfo; -import org.neo4j.bolt.connection.BoltConnectionState; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; -import org.neo4j.bolt.connection.DatabaseName; -import org.neo4j.bolt.connection.NotificationConfig; -import org.neo4j.bolt.connection.TelemetryApi; -import org.neo4j.bolt.connection.TransactionType; -import org.neo4j.driver.Value; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.driver.internal.value.BoltValueFactory; public interface DriverBoltConnection { - CompletionStage onLoop(Supplier supplier); + default CompletionStage writeAndFlush(DriverResponseHandler handler, Message messages) { + return writeAndFlush(handler, List.of(messages)); + } - CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks); + CompletionStage writeAndFlush(DriverResponseHandler handler, List messages); - CompletionStage beginTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig); - - CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig); - - CompletionStage run(String query, Map parameters); - - CompletionStage pull(long qid, long request); - - CompletionStage discard(long qid, long number); - - CompletionStage commit(); - - CompletionStage rollback(); - - CompletionStage reset(); - - CompletionStage logoff(); - - CompletionStage logon(Map authMap); - - CompletionStage telemetry(TelemetryApi telemetryApi); - - CompletionStage clear(); - - CompletionStage flush(DriverResponseHandler handler); + CompletionStage write(List messages); CompletionStage forceClose(String reason); @@ -88,8 +39,6 @@ CompletionStage runInAutoCommitTransaction( // ----- MUTABLE DATA ----- - BoltConnectionState state(); - CompletionStage authData(); // ----- IMMUTABLE DATA ----- @@ -103,4 +52,8 @@ CompletionStage runInAutoCommitTransaction( boolean telemetrySupported(); boolean serverSideRoutingEnabled(); + + // ----- EXTRAS ----- + + BoltValueFactory valueFactory(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java b/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java index 6bfcf569c..9a53b00ba 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java @@ -16,8 +16,10 @@ */ package org.neo4j.driver.internal.async; +import java.util.List; import java.util.Objects; import java.util.concurrent.CompletionStage; +import org.neo4j.bolt.connection.message.Message; import org.neo4j.driver.AuthTokenManager; import org.neo4j.driver.exceptions.SecurityException; import org.neo4j.driver.exceptions.SecurityRetryableException; @@ -35,8 +37,8 @@ public BoltConnectionWithAuthTokenManager(DriverBoltConnection delegate, AuthTok } @Override - public CompletionStage flush(DriverResponseHandler handler) { - return delegate.flush(new ErrorMappingResponseHandler(handler, this::mapSecurityError)); + public CompletionStage writeAndFlush(DriverResponseHandler handler, List messages) { + return delegate.writeAndFlush(new ErrorMappingResponseHandler(handler, this::mapSecurityError), messages); } private Throwable mapSecurityError(Throwable throwable) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java index fbf6d7de6..5b16db0a2 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java @@ -16,24 +16,16 @@ */ package org.neo4j.driver.internal.async; -import java.time.Duration; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.concurrent.CompletionStage; -import java.util.function.Supplier; -import org.neo4j.bolt.connection.AccessMode; import org.neo4j.bolt.connection.AuthInfo; -import org.neo4j.bolt.connection.BoltConnectionState; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; -import org.neo4j.bolt.connection.DatabaseName; -import org.neo4j.bolt.connection.NotificationConfig; -import org.neo4j.bolt.connection.TelemetryApi; -import org.neo4j.bolt.connection.TransactionType; -import org.neo4j.driver.Value; +import org.neo4j.bolt.connection.message.Message; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnection; import org.neo4j.driver.internal.adaptedbolt.DriverResponseHandler; +import org.neo4j.driver.internal.value.BoltValueFactory; public abstract class DelegatingBoltConnection implements DriverBoltConnection { protected final DriverBoltConnection delegate; @@ -43,117 +35,13 @@ protected DelegatingBoltConnection(DriverBoltConnection delegate) { } @Override - public CompletionStage onLoop(Supplier supplier) { - return delegate.onLoop(supplier); + public CompletionStage writeAndFlush(DriverResponseHandler handler, List messages) { + return delegate.writeAndFlush(handler, messages); } @Override - public CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks) { - return delegate.route(databaseName, impersonatedUser, bookmarks).thenApply(ignored -> this); - } - - @Override - public CompletionStage beginTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig) { - return delegate.beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - transactionType, - txTimeout, - txMetadata, - txType, - notificationConfig) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig) { - return delegate.runInAutoCommitTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - query, - parameters, - txTimeout, - txMetadata, - notificationConfig) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage run(String query, Map parameters) { - return delegate.run(query, parameters).thenApply(ignored -> this); - } - - @Override - public CompletionStage pull(long qid, long request) { - return delegate.pull(qid, request).thenApply(ignored -> this); - } - - @Override - public CompletionStage discard(long qid, long number) { - return delegate.discard(qid, number).thenApply(ignored -> this); - } - - @Override - public CompletionStage commit() { - return delegate.commit().thenApply(ignored -> this); - } - - @Override - public CompletionStage rollback() { - return delegate.rollback().thenApply(ignored -> this); - } - - @Override - public CompletionStage reset() { - return delegate.reset().thenApply(ignored -> this); - } - - @Override - public CompletionStage logoff() { - return delegate.logoff().thenApply(ignored -> this); - } - - @Override - public CompletionStage logon(Map authMap) { - return delegate.logon(authMap).thenApply(ignored -> this); - } - - @Override - public CompletionStage telemetry(TelemetryApi telemetryApi) { - return delegate.telemetry(telemetryApi).thenApply(ignored -> this); - } - - @Override - public CompletionStage clear() { - return delegate.clear().thenApply(ignored -> this); - } - - @Override - public CompletionStage flush(DriverResponseHandler handler) { - return delegate.flush(handler); + public CompletionStage write(List messages) { + return delegate.write(messages); } @Override @@ -166,11 +54,6 @@ public CompletionStage close() { return delegate.close(); } - @Override - public BoltConnectionState state() { - return delegate.state(); - } - @Override public CompletionStage authData() { return delegate.authData(); @@ -200,4 +83,9 @@ public boolean telemetrySupported() { public boolean serverSideRoutingEnabled() { return delegate.serverSideRoutingEnabled(); } + + @Override + public BoltValueFactory valueFactory() { + return delegate.valueFactory(); + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java index c3b1a48e5..995da243d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java @@ -20,6 +20,7 @@ import static org.neo4j.driver.internal.util.Futures.completedWithNull; import static org.neo4j.driver.internal.util.Futures.completionExceptionCause; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -45,6 +46,8 @@ import org.neo4j.bolt.connection.SecurityPlan; import org.neo4j.bolt.connection.TelemetryApi; import org.neo4j.bolt.connection.exception.MinVersionAcquisitionException; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.Messages; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.AccessMode; import org.neo4j.driver.AuthToken; @@ -168,22 +171,32 @@ public CompletionStage runAsync(Query query, TransactionConfig con null, this::handleDatabaseName, null); - var cursorStage = apiTelemetryWork - .pipelineTelemetryIfEnabled(connection) - .thenCompose(conn -> conn.runInAutoCommitTransaction( - connectionContext.databaseNameFuture.getNow(DatabaseNameUtil.defaultDatabase()), - asBoltAccessMode(mode), - connectionContext.impersonatedUser, - determineBookmarks(true).stream() - .map(Bookmark::value) - .collect(Collectors.toSet()), - query.text(), - parameters, - config.timeout(), - config.metadata(), - notificationConfig)) - .thenCompose(conn -> conn.pull(-1, fetchSize)) - .thenCompose(conn -> conn.flush(resultCursor)) + var cursorStage = CompletableFuture.completedStage(null) + .thenCompose(ignored -> { + var messages = new ArrayList(3); + var telemetryMessage = apiTelemetryWork.getTelemetryMessageIfEnabled(connection); + if (telemetryMessage != null) { + messages.add(telemetryMessage); + } + messages.add(Messages.run( + connectionContext + .databaseNameFuture + .getNow(DatabaseNameUtil.defaultDatabase()) + .databaseName() + .orElse(null), + asBoltAccessMode(mode), + connectionContext.impersonatedUser, + determineBookmarks(true).stream() + .map(Bookmark::value) + .collect(Collectors.toSet()), + query.text(), + connection.valueFactory().toBoltMap(parameters), + config.timeout(), + connection.valueFactory().toBoltMap(config.metadata()), + notificationConfig)); + messages.add(Messages.pull(-1, fetchSize)); + return connection.writeAndFlush(resultCursor, messages); + }) .thenCompose(ignored -> resultCursor.resultCursor()) .handle((resultCursorImpl, throwable) -> { var error = completionExceptionCause(throwable); @@ -222,21 +235,32 @@ public CompletionStage runRx( var runFailed = new AtomicBoolean(false); var responseHandler = new RunRxResponseHandler( logging, connection, query, this::handleNewBookmark, runFailed, this::handleDatabaseName); - var cursorStage = apiTelemetryWork - .pipelineTelemetryIfEnabled(connection) - .thenCompose(conn -> conn.runInAutoCommitTransaction( - connectionContext.databaseNameFuture.getNow(DatabaseNameUtil.defaultDatabase()), - asBoltAccessMode(mode), - connectionContext.impersonatedUser, - determineBookmarks(true).stream() - .map(Bookmark::value) - .collect(Collectors.toSet()), - query.text(), - parameters, - config.timeout(), - config.metadata(), - notificationConfig)) - .thenCompose(conn -> conn.flush(responseHandler)) + var cursorStage = CompletableFuture.completedStage(null) + .thenCompose(ignored -> { + var messages = new ArrayList(2); + var telemetryMessage = apiTelemetryWork.getTelemetryMessageIfEnabled(connection); + if (telemetryMessage != null) { + messages.add(telemetryMessage); + } + var runMessage = Messages.run( + connectionContext + .databaseNameFuture + .getNow(DatabaseNameUtil.defaultDatabase()) + .databaseName() + .orElse(null), + asBoltAccessMode(mode), + connectionContext.impersonatedUser, + determineBookmarks(true).stream() + .map(Bookmark::value) + .collect(Collectors.toSet()), + query.text(), + connection.valueFactory().toBoltMap(parameters), + config.timeout(), + connection.valueFactory().toBoltMap(config.metadata()), + notificationConfig); + messages.add(runMessage); + return connection.writeAndFlush(responseHandler, messages); + }) .thenCompose(ignored -> responseHandler.cursorFuture) .handle((resultCursor, throwable) -> { var error = completionExceptionCause(throwable); @@ -338,18 +362,19 @@ public CompletionStage resetAsync() { if (connection != null && connection.isOpen()) { var future = new CompletableFuture(); return connection - .reset() - .thenCompose(conn -> conn.flush(new DriverResponseHandler() { - @Override - public void onError(Throwable throwable) { - future.completeExceptionally(throwable); - } - - @Override - public void onComplete() { - future.complete(null); - } - })) + .writeAndFlush( + new DriverResponseHandler() { + @Override + public void onError(Throwable throwable) { + future.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + future.complete(null); + } + }, + Messages.reset()) .thenCompose(ignored -> future); } else { return completedWithNull(); @@ -767,8 +792,8 @@ public void onComplete() { if (error != null) { runFailed.set(true); } - cursorFuture.complete(new RxResultCursorImpl( - connection, NOOP_LOCK, query, runSummary, error, bookmarkConsumer, true, logging)); + cursorFuture.complete( + new RxResultCursorImpl(connection, query, runSummary, error, bookmarkConsumer, true, logging)); } else { var message = ignoredCount > 0 ? "Run exchange contains ignored messages." diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java index 890b7d418..e56461718 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java @@ -16,12 +16,14 @@ */ package org.neo4j.driver.internal.async; +import java.util.List; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.function.Consumer; -import java.util.function.Function; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.Messages; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnection; @@ -49,22 +51,19 @@ public TerminationAwareBoltConnection( public CompletionStage clearAndReset() { var future = new CompletableFuture(); var thisVal = this; - delegate.onLoop(() -> executor.execute(ignored -> clearAndResetBolt(future))) - .thenCompose(Function.identity()) - .whenComplete((ignored, throwable) -> { - if (throwable != null) { - throwableConsumer.accept(throwable); - future.completeExceptionally(throwable); - } - }); + executor.execute(ignored -> clearAndResetBolt(future)).whenComplete((ignored, throwable) -> { + if (throwable != null) { + throwableConsumer.accept(throwable); + future.completeExceptionally(throwable); + } + }); return future; } private CompletionStage clearAndResetBolt(CompletableFuture future) { var thisVal = this; - return delegate.clear() - .thenCompose(DriverBoltConnection::reset) - .thenCompose(conn -> conn.flush(new DriverResponseHandler() { + return delegate.writeAndFlush( + new DriverResponseHandler() { Throwable throwable = null; @Override @@ -82,49 +81,51 @@ public void onComplete() { future.complete(thisVal); } } - })); + }, + List.of(Messages.reset())); } @Override - public CompletionStage flush(DriverResponseHandler handler) { - return delegate.onLoop(() -> executor.execute(causeOfTermination -> flushBolt(causeOfTermination, handler))) - .thenCompose(Function.identity()); + public CompletionStage writeAndFlush(DriverResponseHandler handler, List messages) { + return executor.execute(causeOfTermination -> flushBolt(causeOfTermination, handler, messages)); } - private CompletionStage flushBolt(Throwable causeOfTermination, DriverResponseHandler handler) { + private CompletionStage flushBolt( + Throwable causeOfTermination, DriverResponseHandler handler, List messages) { if (causeOfTermination == null) { log.trace("This connection is active, will flush"); var terminationAwareResponseHandler = new TerminationAwareResponseHandler(logging, handler, executor, throwableConsumer); - return delegate.flush(terminationAwareResponseHandler).handle((ignored, flushThrowable) -> { - flushThrowable = Futures.completionExceptionCause(flushThrowable); - if (flushThrowable != null) { - if (log.isTraceEnabled()) { - log.error("The flush has failed", flushThrowable); - } - var flushThrowableRef = flushThrowable; - flushThrowable = executor.execute(existingThrowable -> { - if (existingThrowable != null) { - log.trace("The flush has failed, but there is an existing %s", existingThrowable); - return existingThrowable; + return delegate.writeAndFlush(terminationAwareResponseHandler, messages) + .handle((ignored, flushThrowable) -> { + flushThrowable = Futures.completionExceptionCause(flushThrowable); + if (flushThrowable != null) { + if (log.isTraceEnabled()) { + log.error("The flush has failed", flushThrowable); + } + var flushThrowableRef = flushThrowable; + flushThrowable = executor.execute(existingThrowable -> { + if (existingThrowable != null) { + log.trace("The flush has failed, but there is an existing %s", existingThrowable); + return existingThrowable; + } else { + throwableConsumer.accept(flushThrowableRef); + return flushThrowableRef; + } + }); + // rethrow + if (flushThrowable instanceof RuntimeException runtimeException) { + throw runtimeException; + } else { + throw new CompletionException(flushThrowable); + } } else { - throwableConsumer.accept(flushThrowableRef); - return flushThrowableRef; + return ignored; } }); - // rethrow - if (flushThrowable instanceof RuntimeException runtimeException) { - throw runtimeException; - } else { - throw new CompletionException(flushThrowable); - } - } else { - return ignored; - } - }); } else { // there is an existing error - return delegate.clear().thenCompose(ignored -> CompletableFuture.failedStage(causeOfTermination)); + return CompletableFuture.failedStage(causeOfTermination); } } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index d14a5850f..c0d03450d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -23,8 +23,10 @@ import static org.neo4j.driver.internal.util.Futures.futureCompletingConsumer; import static org.neo4j.driver.internal.util.LockUtil.executeWithLock; +import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; +import java.util.List; import java.util.Objects; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -39,6 +41,8 @@ import org.neo4j.bolt.connection.DatabaseName; import org.neo4j.bolt.connection.NotificationConfig; import org.neo4j.bolt.connection.TransactionType; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.Messages; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.CommitSummary; import org.neo4j.bolt.connection.summary.RunSummary; @@ -106,7 +110,6 @@ private enum State { private final ResultCursorsHolder resultCursors; private final long fetchSize; private final Lock lock = new ReentrantLock(); - private final Lock connectionLock = new ReentrantLock(); private State state = State.ACTIVE; private CompletableFuture commitFuture; private CompletableFuture rollbackFuture; @@ -120,6 +123,7 @@ private enum State { private final ApiTelemetryWork apiTelemetryWork; private final Consumer databaseNameConsumer; + private Message[] beginMessages; public UnmanagedTransaction( DriverBoltConnection connection, @@ -174,26 +178,31 @@ protected UnmanagedTransaction( // flush = false is only supported for async mode with a single subsequent run public CompletionStage beginAsync( Set initialBookmarks, TransactionConfig config, String txType, boolean flush) { - var bookmarks = initialBookmarks.stream().map(Bookmark::value).collect(Collectors.toSet()); - - return apiTelemetryWork - .pipelineTelemetryIfEnabled(connection) - .thenCompose(connection -> connection.beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - TransactionType.DEFAULT, - config.timeout(), - config.metadata(), - txType, - notificationConfig)) - .thenCompose(connection -> { + return CompletableFuture.completedStage(null) + .thenApply(ignored -> { + var messages = new ArrayList(2); + var telemetryMessage = apiTelemetryWork.getTelemetryMessageIfEnabled(connection); + if (telemetryMessage != null) { + messages.add(telemetryMessage); + } + messages.add(Messages.beginTransaction( + databaseName.databaseName().orElse(null), + accessMode, + impersonatedUser, + bookmarks, + // todo txType + TransactionType.DEFAULT, + config.timeout(), + connection.valueFactory().toBoltMap(config.metadata()), + notificationConfig)); + return messages; + }) + .thenCompose(messages -> { if (flush) { var responseHandler = new BeginResponseHandler(apiTelemetryWork, databaseNameConsumer); connection - .flush(responseHandler) + .writeAndFlush(responseHandler, messages) .thenCompose(ignored -> responseHandler.summaryFuture) .whenComplete((summary, throwable) -> { if (throwable != null) { @@ -209,7 +218,7 @@ public CompletionStage beginAsync( }); return beginFuture.thenApply(ignored -> this); } else { - return CompletableFuture.completedFuture(this); + return connection.write(messages).thenApply(ignored -> this); } }); } @@ -242,10 +251,12 @@ public CompletionStage runAsync(Query query) { beginFuture, databaseNameConsumer, apiTelemetryWork); - var flushStage = connection - .run(query.text(), parameters) - .thenCompose(ignored -> connection.pull(-1, fetchSize)) - .thenCompose(ignored -> connection.flush(resultCursor)); + var flushStage = CompletableFuture.completedStage(null).thenCompose(ignored -> { + var messages = List.of( + Messages.run(query.text(), connection.valueFactory().toBoltMap(parameters)), + Messages.pull(-1, fetchSize)); + return connection.writeAndFlush(resultCursor, messages); + }); return beginFuture.thenCompose(ignored -> { var cursorStage = flushStage .thenCompose(flushResult -> resultCursor.resultCursor()) @@ -258,17 +269,11 @@ public CompletionStage runAsync(Query query) { public CompletionStage runRx(Query query) { ensureCanRunQueries(); var parameters = query.parameters().asMap(Values::value); - var responseHandler = - new RunRxResponseHandler(logging, apiTelemetryWork, beginFuture, connection, connectionLock, query); - var flushStage = connection - .onLoop(() -> { - connectionLock.lock(); - return connection - .run(query.text(), parameters) - .thenCompose(conn -> conn.flush(responseHandler)) - .whenComplete((ignored, throwable) -> connectionLock.unlock()); - }) - .thenCompose(Function.identity()); + var responseHandler = new RunRxResponseHandler(logging, apiTelemetryWork, beginFuture, connection, query); + var flushStage = CompletableFuture.completedStage(null) + .thenCompose(runMessage -> connection.writeAndFlush( + responseHandler, + Messages.run(query.text(), connection.valueFactory().toBoltMap(parameters)))); return beginFuture.thenCompose(ignored -> { var cursorStage = flushStage.thenCompose(flushResult -> responseHandler.cursorFuture); resultCursors.add(cursorStage); @@ -416,8 +421,7 @@ private CompletionStage doCommitAsync(Throwable cursorFailure) { var commitSummary = new CompletableFuture(); var responseHandler = new BasicResponseHandler(); connection - .commit() - .thenCompose(connection -> connection.flush(responseHandler)) + .writeAndFlush(responseHandler, Messages.commit()) .thenCompose(ignored -> responseHandler.summaries()) .whenComplete((summaries, throwable) -> { if (throwable != null) { @@ -455,8 +459,7 @@ private CompletionStage doRollbackAsync() { var rollbackFuture = new CompletableFuture(); var responseHandler = new BasicResponseHandler(); connection - .rollback() - .thenCompose(connection -> connection.flush(responseHandler)) + .writeAndFlush(responseHandler, Messages.rollback()) .thenCompose(ignored -> responseHandler.summaries()) .whenComplete((summaries, throwable) -> { if (throwable != null) { @@ -679,7 +682,6 @@ private static class RunRxResponseHandler implements DriverResponseHandler { private final ApiTelemetryWork apiTelemetryWork; private final CompletableFuture beginFuture; private final DriverBoltConnection connection; - private final Lock connectionLock; private final Query query; private Throwable error; private RunSummary runSummary; @@ -690,13 +692,11 @@ private RunRxResponseHandler( ApiTelemetryWork apiTelemetryWork, CompletableFuture beginFuture, DriverBoltConnection connection, - Lock connectionLock, Query query) { this.logging = logging; this.apiTelemetryWork = apiTelemetryWork; this.beginFuture = beginFuture; this.connection = connection; - this.connectionLock = connectionLock; this.query = query; } @@ -732,13 +732,13 @@ public void onIgnored() { public void onComplete() { if (error != null) { if (!beginFuture.completeExceptionally(error)) { - cursorFuture.complete(new RxResultCursorImpl( - connection, connectionLock, query, null, error, bookmark -> {}, false, logging)); + cursorFuture.complete( + new RxResultCursorImpl(connection, query, null, error, bookmark -> {}, false, logging)); } } else { if (runSummary != null) { cursorFuture.complete(new RxResultCursorImpl( - connection, connectionLock, query, runSummary, null, bookmark -> {}, false, logging)); + connection, query, runSummary, null, bookmark -> {}, false, logging)); } else { var message = ignoredCount > 0 ? "Run exchange contains ignored messages" : "Unexpected state during run"; diff --git a/driver/src/main/java/org/neo4j/driver/internal/boltlistener/ListeningBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/boltlistener/ListeningBoltConnection.java index d38e108d8..afa78ebf5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/boltlistener/ListeningBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/boltlistener/ListeningBoltConnection.java @@ -17,25 +17,17 @@ package org.neo4j.driver.internal.boltlistener; import java.time.Duration; -import java.util.Map; +import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CompletionStage; -import java.util.function.Supplier; -import org.neo4j.bolt.connection.AccessMode; import org.neo4j.bolt.connection.AuthInfo; -import org.neo4j.bolt.connection.AuthToken; import org.neo4j.bolt.connection.BoltConnection; import org.neo4j.bolt.connection.BoltConnectionState; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; -import org.neo4j.bolt.connection.DatabaseName; -import org.neo4j.bolt.connection.NotificationConfig; import org.neo4j.bolt.connection.ResponseHandler; -import org.neo4j.bolt.connection.TelemetryApi; -import org.neo4j.bolt.connection.TransactionType; -import org.neo4j.bolt.connection.values.Value; +import org.neo4j.bolt.connection.message.Message; final class ListeningBoltConnection implements BoltConnection { private final BoltConnection delegate; @@ -47,117 +39,13 @@ public ListeningBoltConnection(BoltConnection delegate, BoltConnectionListener b } @Override - public CompletionStage onLoop(Supplier supplier) { - return delegate.onLoop(supplier); + public CompletionStage writeAndFlush(ResponseHandler handler, List messages) { + return delegate.writeAndFlush(handler, messages); } @Override - public CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks) { - return delegate.route(databaseName, impersonatedUser, bookmarks).thenApply(ignored -> this); - } - - @Override - public CompletionStage beginTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig) { - return delegate.beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - transactionType, - txTimeout, - txMetadata, - txType, - notificationConfig) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig) { - return delegate.runInAutoCommitTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - query, - parameters, - txTimeout, - txMetadata, - notificationConfig) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage run(String query, Map parameters) { - return delegate.run(query, parameters).thenApply(ignored -> this); - } - - @Override - public CompletionStage pull(long qid, long request) { - return delegate.pull(qid, request).thenApply(ignored -> this); - } - - @Override - public CompletionStage discard(long qid, long number) { - return delegate.discard(qid, number).thenApply(ignored -> this); - } - - @Override - public CompletionStage commit() { - return delegate.commit().thenApply(ignored -> this); - } - - @Override - public CompletionStage rollback() { - return delegate.rollback().thenApply(ignored -> this); - } - - @Override - public CompletionStage reset() { - return delegate.reset().thenApply(ignored -> this); - } - - @Override - public CompletionStage logoff() { - return delegate.logoff().thenApply(ignored -> this); - } - - @Override - public CompletionStage logon(AuthToken authToken) { - return delegate.logon(authToken).thenApply(ignored -> this); - } - - @Override - public CompletionStage telemetry(TelemetryApi telemetryApi) { - return delegate.telemetry(telemetryApi).thenApply(ignored -> this); - } - - @Override - public CompletionStage clear() { - return delegate.clear().thenApply(ignored -> this); - } - - @Override - public CompletionStage flush(ResponseHandler handler) { - return delegate.flush(handler); + public CompletionStage write(List messages) { + return delegate.write(messages); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java index 3d4b880cb..bf34aa3c1 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java @@ -29,6 +29,7 @@ import java.util.function.Consumer; import java.util.function.Function; import org.neo4j.bolt.connection.BoltProtocolVersion; +import org.neo4j.bolt.connection.message.Messages; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.bolt.connection.summary.TelemetrySummary; @@ -150,8 +151,7 @@ public synchronized CompletionStage consumeAsync() { var future = summaryFuture; state = State.DISCARDING; boltConnection - .discard(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.discard(runSummary.queryId(), -1)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); CompletableFuture summaryFuture; @@ -230,8 +230,7 @@ var record = records.poll(); state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), fetchSize)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); CompletableFuture recordFuture; @@ -295,8 +294,7 @@ var record = records.peek(); state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), fetchSize)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); if (error != null) { @@ -376,8 +374,7 @@ public synchronized CompletionStage singleAsync() { state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), fetchSize)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); if (error != null) { @@ -505,8 +502,7 @@ public synchronized CompletionStage> listAsync() { state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), -1)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); CompletableFuture> recordsFuture; @@ -767,8 +763,7 @@ public void onPullSummary(PullSummary summary) { state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), fetchSize)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); if (error != null) { @@ -788,8 +783,7 @@ public void onPullSummary(PullSummary summary) { state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), fetchSize)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); if (error != null) { @@ -818,8 +812,7 @@ public void onPullSummary(PullSummary summary) { state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), -1)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); if (error != null) { @@ -838,8 +831,7 @@ public void onPullSummary(PullSummary summary) { // consume is pending, discard all state = State.DISCARDING; boltConnection - .discard(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.discard(runSummary.queryId(), -1)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); CompletableFuture summaryFuture; @@ -1205,8 +1197,7 @@ public CompletionStage pullAllFailureAsync() { state = State.STREAMING; updateRecordState(RecordState.NO_RECORD); boltConnection - .pull(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) + .writeAndFlush(this, Messages.pull(runSummary.queryId(), -1)) .whenComplete((ignored, throwable) -> { var error = Futures.completionExceptionCause(throwable); CompletableFuture summaryFuture; diff --git a/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java index 1ceb35af7..57a272fd9 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java @@ -26,11 +26,10 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.locks.Lock; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Function; import org.neo4j.bolt.connection.BoltProtocolVersion; +import org.neo4j.bolt.connection.message.Messages; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Logger; @@ -88,7 +87,6 @@ public Optional databaseName() { }; private final Logger log; private final DriverBoltConnection boltConnection; - private final Lock boltConnectionLock; private final Query query; private final RunSummary runSummary; private final Throwable runError; @@ -118,7 +116,6 @@ private enum State { public RxResultCursorImpl( DriverBoltConnection boltConnection, - Lock boltConnectionLock, Query query, RunSummary runSummary, Throwable runError, @@ -126,7 +123,6 @@ public RxResultCursorImpl( boolean closeOnSummary, Logging logging) { this.boltConnection = Objects.requireNonNull(boltConnection); - this.boltConnectionLock = Objects.requireNonNull(boltConnectionLock); this.legacyNotifications = new BoltProtocolVersion(5, 5).compareTo(boltConnection.protocolVersion()) > 0; this.query = query; this.runSummary = runError == null ? runSummary : EMPTY_RUN_SUMMARY; @@ -192,20 +188,15 @@ public void request(long n) { case READY -> { var request = appendDemand(n); state = State.STREAMING; - runnable = () -> boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection - .pull(runSummary.queryId(), request) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - boltConnectionLock.unlock(); - throwable = Futures.completionExceptionCause(throwable); - if (throwable != null) { - handleError(throwable); - onComplete(); - } - }); - }); + runnable = () -> boltConnection + .writeAndFlush(this, Messages.pull(runSummary.queryId(), request)) + .whenComplete((ignored, throwable) -> { + throwable = Futures.completionExceptionCause(throwable); + if (throwable != null) { + handleError(throwable); + onComplete(); + } + }); } case STREAMING -> appendDemand(n); case FAILED, DISCARDING, SUCCEEDED -> {} @@ -272,42 +263,35 @@ public CompletionStage rollback() { } } var resetFuture = new CompletableFuture(); - boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection - .reset() - .thenCompose(conn -> conn.flush(new DriverResponseHandler() { - Throwable throwable = null; - - @Override - public void onError(Throwable throwable) { - this.throwable = Futures.completionExceptionCause(throwable); - } - - @Override - public void onComplete() { - if (throwable != null) { - resetFuture.completeExceptionally(throwable); - } else { - resetFuture.complete(null); + boltConnection + .writeAndFlush( + new DriverResponseHandler() { + Throwable throwable = null; + + @Override + public void onError(Throwable throwable) { + this.throwable = Futures.completionExceptionCause(throwable); } - } - })) - .whenComplete((ignored, throwable) -> { - boltConnectionLock.unlock(); - throwable = Futures.completionExceptionCause(throwable); - if (throwable != null) { - resetFuture.completeExceptionally(throwable); - } - }); - }); + + @Override + public void onComplete() { + if (throwable != null) { + resetFuture.completeExceptionally(throwable); + } else { + resetFuture.complete(null); + } + } + }, + Messages.reset()) + .whenComplete((ignored, throwable) -> { + throwable = Futures.completionExceptionCause(throwable); + if (throwable != null) { + resetFuture.completeExceptionally(throwable); + } + }); return resetFuture - .thenCompose(ignored -> boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection.close().whenComplete((result, error) -> boltConnectionLock.unlock()); - })) - .thenCompose(Function.identity()) + .thenCompose(ignored -> boltConnection.close()) .whenComplete((ignored, throwable) -> completeSummaryFuture(null, null)) .exceptionally(throwable -> null); } @@ -420,20 +404,15 @@ private synchronized void decrementDemand() { private synchronized Runnable setupDiscardRunnable() { state = State.DISCARDING; - return () -> boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection - .discard(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - boltConnectionLock.unlock(); - throwable = Futures.completionExceptionCause(throwable); - if (throwable != null) { - handleError(throwable); - onComplete(); - } - }); - }); + return () -> boltConnection + .writeAndFlush(this, Messages.discard(runSummary.queryId(), -1)) + .whenComplete((ignored, throwable) -> { + throwable = Futures.completionExceptionCause(throwable); + if (throwable != null) { + handleError(throwable); + onComplete(); + } + }); } private synchronized Runnable setupCompletionRunnableWithPullSummary() { @@ -444,38 +423,28 @@ private synchronized Runnable setupCompletionRunnableWithPullSummary() { if (discardPending) { discardPending = false; state = State.DISCARDING; - runnable = () -> boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection - .discard(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) + runnable = () -> boltConnection + .writeAndFlush(this, Messages.discard(runSummary.queryId(), -1)) + .whenComplete((ignored, flushThrowable) -> { + var error = Futures.completionExceptionCause(flushThrowable); + if (error != null) { + handleError(error); + onComplete(); + } + }); + } else { + var demand = getDemand(); + if (demand != 0) { + state = State.STREAMING; + runnable = () -> boltConnection + .writeAndFlush(this, Messages.pull(runSummary.queryId(), demand > 0 ? demand : -1)) .whenComplete((ignored, flushThrowable) -> { - boltConnectionLock.unlock(); var error = Futures.completionExceptionCause(flushThrowable); if (error != null) { handleError(error); onComplete(); } }); - }); - } else { - var demand = getDemand(); - if (demand != 0) { - state = State.STREAMING; - runnable = () -> boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection - .pull(runSummary.queryId(), demand > 0 ? demand : -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, flushThrowable) -> { - boltConnectionLock.unlock(); - var error = Futures.completionExceptionCause(flushThrowable); - if (error != null) { - handleError(error); - onComplete(); - } - }); - }); } else { state = State.READY; } @@ -546,12 +515,7 @@ private synchronized Runnable setupCompletionRunnableWithError(Throwable throwab private void closeBoltConnection(Runnable runnable) { var closeStage = CompletableFuture.completedStage(null); if (closeOnSummary) { - closeStage = closeStage - .thenCompose(ignored -> boltConnection.onLoop(() -> { - boltConnectionLock.lock(); - return boltConnection.close().whenComplete((result, error) -> boltConnectionLock.unlock()); - })) - .thenCompose(Function.identity()); + closeStage = closeStage.thenCompose(ignored -> boltConnection.close()); } closeStage.whenComplete((ignored, closeThrowable) -> { if (log.isTraceEnabled() && closeThrowable != null) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWork.java b/driver/src/main/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWork.java index f2c4afcdc..5764eb173 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWork.java +++ b/driver/src/main/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWork.java @@ -17,10 +17,10 @@ package org.neo4j.driver.internal.telemetry; import java.util.Objects; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicBoolean; import org.neo4j.bolt.connection.TelemetryApi; +import org.neo4j.bolt.connection.message.Messages; +import org.neo4j.bolt.connection.message.TelemetryMessage; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnection; public record ApiTelemetryWork(TelemetryApi telemetryApi, AtomicBoolean enabled, AtomicBoolean acknowledged) { @@ -36,11 +36,11 @@ public void acknowledge() { this.acknowledged.set(true); } - public CompletionStage pipelineTelemetryIfEnabled(DriverBoltConnection connection) { + public TelemetryMessage getTelemetryMessageIfEnabled(DriverBoltConnection connection) { if (enabled.get() && connection.telemetrySupported() && !(acknowledged.get())) { - return connection.telemetry(telemetryApi); + return Messages.telemetry(telemetryApi); } else { - return CompletableFuture.completedStage(connection); + return null; } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java index e698ddaa2..9a6e8c434 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java @@ -27,6 +27,7 @@ import static org.neo4j.driver.Values.parameters; import static org.neo4j.driver.testutil.TestUtil.connectionMock; import static org.neo4j.driver.testutil.TestUtil.newSession; +import static org.neo4j.driver.testutil.TestUtil.setupConnectionAnswers; import static org.neo4j.driver.testutil.TestUtil.setupFailingCommit; import static org.neo4j.driver.testutil.TestUtil.setupFailingRollback; import static org.neo4j.driver.testutil.TestUtil.setupFailingRun; @@ -36,21 +37,21 @@ import static org.neo4j.driver.testutil.TestUtil.verifyRunAndPull; import java.util.Collections; +import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.stubbing.Answer; import org.neo4j.bolt.connection.BoltProtocolVersion; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.CommitMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.RollbackMessage; import org.neo4j.bolt.connection.summary.BeginSummary; -import org.neo4j.bolt.connection.summary.CommitSummary; -import org.neo4j.bolt.connection.summary.RollbackSummary; import org.neo4j.driver.Query; import org.neo4j.driver.Result; import org.neo4j.driver.Transaction; @@ -59,6 +60,7 @@ import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnectionProvider; import org.neo4j.driver.internal.adaptedbolt.DriverResponseHandler; import org.neo4j.driver.internal.value.IntegerValue; +import org.neo4j.driver.testutil.TestUtil; class InternalTransactionTest { private DriverBoltConnection connection; @@ -69,22 +71,20 @@ class InternalTransactionTest { void setUp() { connection = connectionMock(new BoltProtocolVersion(4, 0)); var connectionProvider = mock(DriverBoltConnectionProvider.class); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedFuture(connection)); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - if (handler != null) { + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { handler.onBeginSummary(mock(BeginSummary.class)); handler.onComplete(); } - return CompletableFuture.completedFuture(null); - }); + })); var session = new InternalSession(newSession(connectionProvider, Collections.emptySet())); tx = session.beginTransaction(); } @@ -111,13 +111,18 @@ void shouldFlushOnRun(Function runReturnOne) { @Test void shouldCommit() { - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onCommitSummary(mock(CommitSummary.class)); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary(mock()); + handler.onComplete(); + } + })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); tx.commit(); @@ -129,13 +134,18 @@ void shouldCommit() { @Test void shouldRollbackByDefault() { - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock()); + handler.onComplete(); + } + })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); tx.close(); @@ -146,13 +156,18 @@ void shouldRollbackByDefault() { @Test void shouldRollback() { - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock()); + handler.onComplete(); + } + })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); tx.rollback(); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java index 2a9570dbb..0c29f6b52 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java @@ -44,6 +44,7 @@ import static org.neo4j.driver.testutil.TestUtil.setupConnectionAnswers; import static org.neo4j.driver.testutil.TestUtil.setupSuccessfulAutocommitRunAndPull; import static org.neo4j.driver.testutil.TestUtil.verifyAutocommitRunAndPull; +import static org.neo4j.driver.testutil.TestUtil.verifyBegin; import static org.neo4j.driver.testutil.TestUtil.verifyCommitTx; import static org.neo4j.driver.testutil.TestUtil.verifyRollbackTx; @@ -57,15 +58,19 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; -import java.util.stream.IntStream; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentMatchers; import org.mockito.stubbing.Answer; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.DatabaseName; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.CommitMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.RollbackMessage; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.RollbackSummary; import org.neo4j.driver.AccessMode; @@ -87,6 +92,7 @@ import org.neo4j.driver.internal.retry.RetryLogic; import org.neo4j.driver.internal.util.FixedRetryLogic; import org.neo4j.driver.internal.value.IntegerValue; +import org.neo4j.driver.testutil.TestUtil; class InternalAsyncSessionTest { private DriverBoltConnection connection; @@ -97,10 +103,6 @@ class InternalAsyncSessionTest { @BeforeEach void setUp() { connection = connectionMock(new BoltProtocolVersion(4, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); given(connection.close()).willReturn(completedFuture(null)); connectionProvider = mock(DriverBoltConnectionProvider.class); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) @@ -157,41 +159,59 @@ void shouldFlushOnRun(Function> runR @ParameterizedTest @MethodSource("allBeginTxMethods") void shouldDelegateBeginTx(Function> beginTx) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); var tx = await(beginTx.apply(asyncSession)); - verify(connection).beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); - verify(connection).flush(any()); + verifyBegin(connection); assertNotNull(tx); } @ParameterizedTest @MethodSource("allRunTxMethods") void txRunShouldBeginAndCommitTx(Function> runTx) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willReturn(completedFuture(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onCommitSummary(Optional::empty); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary(Optional::empty); + handler.onComplete(); + } })); var string = await(runTx.apply(asyncSession)); - verify(connection).beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + verifyBegin(connection); verifyCommitTx(connection); verify(connection).close(); assertThat(string, equalTo("a")); @@ -292,19 +312,32 @@ void shouldDelegateExecuteReadToRetryLogic(ExecuteVariation executeVariation) @SuppressWarnings("deprecation") private void testTxRollbackWhenThrows(AccessMode transactionMode) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); final RuntimeException error = new IllegalStateException("Oh!"); AsyncTransactionWork> work = tx -> { @@ -315,39 +348,51 @@ private void testTxRollbackWhenThrows(AccessMode transactionMode) { assertEquals(error, e); verify(connectionProvider).connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); - verify(connection).beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + verifyBegin(connection); verifyRollbackTx(connection); } private void testTxIsRetriedUntilSuccessWhenFunctionThrows(AccessMode mode) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); var failures = 12; - var failureHandlerStream = IntStream.range(0, failures) - .mapToObj(ignored -> Stream.>of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); - }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); - })) - .flatMap(Function.identity()); - var retries = failures + 1; - var successHandlers = Stream.>of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + var handlers = List.of( + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onCommitSummary(Optional::empty); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } + }, + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary(Optional::empty); + handler.onComplete(); + } }); - var allHandlers = Stream.concat(failureHandlerStream, successHandlers).toList(); - setupConnectionAnswers(connection, allHandlers); + var retries = failures + 1; + setupConnectionAnswers(connection, handlers); RetryLogic retryLogic = new FixedRetryLogic(retries); session = newSession(connectionProvider, retryLogic); @@ -358,38 +403,45 @@ private void testTxIsRetriedUntilSuccessWhenFunctionThrows(AccessMode mode) { assertEquals(42, answer); verifyInvocationCount(work, failures + 1); - verify(connection).commit(); verifyRollbackTx(connection, times(failures)); + verifyCommitTx(connection); } private void testTxIsRetriedUntilSuccessWhenCommitThrows(AccessMode mode) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); var failures = 13; - var failureHandlerStream = IntStream.range(0, failures) - .mapToObj(ignored -> Stream.>of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); - }, - handler -> { - handler.onError(new ServiceUnavailableException("")); - handler.onComplete(); - })) - .flatMap(Function.identity()); - var retries = failures + 1; - var successHandlers = Stream.>of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + var handlers = List.of( + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onCommitSummary(Optional::empty); - handler.onComplete(); + new TestUtil.MessageHandler() { + int expectedFailures = failures; + + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + if (expectedFailures-- > 0) { + handler.onError(new ServiceUnavailableException("")); + } else { + handler.onCommitSummary(Optional::empty); + } + handler.onComplete(); + } }); - var allHandlers = Stream.concat(failureHandlerStream, successHandlers).toList(); - setupConnectionAnswers(connection, allHandlers); + var retries = failures + 1; + setupConnectionAnswers(connection, handlers); RetryLogic retryLogic = new FixedRetryLogic(retries); session = newSession(connectionProvider, retryLogic); @@ -404,23 +456,34 @@ private void testTxIsRetriedUntilSuccessWhenCommitThrows(AccessMode mode) { } private void testTxIsRetriedUntilFailureWhenFunctionThrows(AccessMode mode) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); var failures = 14; - var failureHandlerStream = IntStream.range(0, failures) - .mapToObj(ignored -> Stream.>of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); - }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); - })) - .flatMap(Function.identity()); + var handlers = List.of( + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } + }, + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } + }); var retries = failures - 1; - setupConnectionAnswers(connection, failureHandlerStream.toList()); + setupConnectionAnswers(connection, handlers); RetryLogic retryLogic = new FixedRetryLogic(retries); session = newSession(connectionProvider, retryLogic); @@ -433,28 +496,44 @@ private void testTxIsRetriedUntilFailureWhenFunctionThrows(AccessMode mode) { assertThat(e, instanceOf(SessionExpiredException.class)); assertEquals("Oh!", e.getMessage()); verifyInvocationCount(work, failures); - verify(connection, never()).commit(); + then(connection) + .should(never()) + .writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof CommitMessage)); verifyRollbackTx(connection, times(failures)); } private void testTxIsRetriedUntilFailureWhenCommitFails(AccessMode mode) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); var failures = 17; - var failureHandlerStream = IntStream.range(0, failures) - .mapToObj(ignored -> Stream.>of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); - }, - handler -> { - handler.onError(new ServiceUnavailableException("")); - handler.onComplete(); - })) - .flatMap(Function.identity()); + var handlers = List.of( + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } + }, + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(new ServiceUnavailableException("")); + handler.onComplete(); + } + }); var retries = failures - 1; - setupConnectionAnswers(connection, failureHandlerStream.toList()); + setupConnectionAnswers(connection, handlers); RetryLogic retryLogic = new FixedRetryLogic(retries); session = newSession(connectionProvider, retryLogic); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java index 496b54e85..8182c1179 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java @@ -23,7 +23,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; @@ -43,7 +42,6 @@ import java.util.concurrent.ExecutionException; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -52,6 +50,12 @@ import org.mockito.stubbing.Answer; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.DatabaseName; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.CommitMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.PullMessage; +import org.neo4j.bolt.connection.message.RollbackMessage; +import org.neo4j.bolt.connection.message.RunMessage; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.CommitSummary; import org.neo4j.bolt.connection.summary.RollbackSummary; @@ -64,8 +68,10 @@ import org.neo4j.driver.internal.InternalRecord; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnection; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnectionProvider; +import org.neo4j.driver.internal.adaptedbolt.DriverResponseHandler; import org.neo4j.driver.internal.adaptedbolt.summary.PullSummary; import org.neo4j.driver.internal.value.IntegerValue; +import org.neo4j.driver.testutil.TestUtil; class InternalAsyncTransactionTest { private DriverBoltConnection connection; @@ -74,10 +80,6 @@ class InternalAsyncTransactionTest { @BeforeEach void setUp() { connection = connectionMock(new BoltProtocolVersion(4, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); var connectionProvider = mock(DriverBoltConnectionProvider.class); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .willAnswer((Answer>) invocation -> { @@ -104,23 +106,33 @@ private static Stream>> @ParameterizedTest @MethodSource("allSessionRunMethods") void shouldFlushOnRun(Function> runReturnOne) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.run(any(), any())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRunSummary(mock(RunSummary.class)); - handler.onPullSummary(mock(PullSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RunMessage.class, PullMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRunSummary(mock(RunSummary.class)); + handler.onPullSummary(mock(PullSummary.class)); + handler.onComplete(); + } })); var tx = (InternalAsyncTransaction) await(session.beginTransactionAsync()); @@ -132,20 +144,32 @@ void shouldFlushOnRun(Function> @Test void shouldCommit() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onCommitSummary(mock(CommitSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary(mock(CommitSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var tx = (InternalAsyncTransaction) await(session.beginTransactionAsync()); @@ -159,20 +183,32 @@ void shouldCommit() { @Test void shouldRollback() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var tx = (InternalAsyncTransaction) await(session.beginTransactionAsync()); @@ -185,20 +221,32 @@ void shouldRollback() { @Test void shouldReleaseConnectionWhenFailedToCommit() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onError(new ServiceUnavailableException("")); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(new ServiceUnavailableException("")); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var tx = (InternalAsyncTransaction) await(session.beginTransactionAsync()); @@ -210,20 +258,32 @@ void shouldReleaseConnectionWhenFailedToCommit() { @Test void shouldReleaseConnectionWhenFailedToRollback() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onError(new ServiceUnavailableException("")); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(new ServiceUnavailableException("")); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var tx = (InternalAsyncTransaction) await(session.beginTransactionAsync()); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java index 95e8be9c5..86eb561c7 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java @@ -21,7 +21,6 @@ import static org.hamcrest.Matchers.containsString; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -35,11 +34,14 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.function.Supplier; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; import org.mockito.ArgumentCaptor; import org.neo4j.bolt.connection.TelemetryApi; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.PullMessage; +import org.neo4j.bolt.connection.message.RunMessage; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.AuthTokenManagers; @@ -52,6 +54,7 @@ import org.neo4j.driver.TransactionConfig; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnection; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnectionProvider; +import org.neo4j.driver.internal.adaptedbolt.DriverResponseHandler; import org.neo4j.driver.internal.adaptedbolt.summary.PullSummary; import org.neo4j.driver.internal.security.BoltSecurityPlanManager; import org.neo4j.driver.internal.telemetry.ApiTelemetryWork; @@ -65,13 +68,18 @@ void logsNothingDuringFinalizationIfClosed() throws Exception { var log = mock(Logger.class); when(logging.getLog(any(Class.class))).thenReturn(log); var connection = TestUtil.connectionMock(); - given(connection.runInAutoCommitTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.pull(anyLong(), anyLong())).willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onRunSummary(mock(RunSummary.class)); - handler.onPullSummary(mock(PullSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RunMessage.class, PullMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRunSummary(mock(RunSummary.class)); + handler.onPullSummary(mock(PullSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(completedFuture(null)); var session = newSession(logging, connection); @@ -94,15 +102,17 @@ void logsMessageWithStacktraceDuringFinalizationIfLeaked(TestInfo testInfo) thro var log = mock(Logger.class); when(logging.getLog(any(Class.class))).thenReturn(log); var connection = TestUtil.connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); var session = newSession(logging, connection); // begin transaction to make session obtain a connection diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java index 78b2b4197..9c9b39504 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java @@ -18,16 +18,15 @@ import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.failedFuture; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; @@ -45,8 +44,12 @@ import static org.neo4j.driver.testutil.TestUtil.setupSuccessfulAutocommitRunAndPull; import static org.neo4j.driver.testutil.TestUtil.verifyAutocommitRunAndPull; import static org.neo4j.driver.testutil.TestUtil.verifyAutocommitRunRx; +import static org.neo4j.driver.testutil.TestUtil.verifyBegin; +import static org.neo4j.driver.testutil.TestUtil.verifyCommitTx; import static org.neo4j.driver.testutil.TestUtil.verifyRollbackTx; +import static org.neo4j.driver.testutil.TestUtil.verifyRunAndPull; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -54,21 +57,31 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.function.Consumer; -import java.util.function.Supplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.mockito.stubbing.Answer; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.DatabaseName; import org.neo4j.bolt.connection.TelemetryApi; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.CommitMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.Messages; +import org.neo4j.bolt.connection.message.PullMessage; +import org.neo4j.bolt.connection.message.ResetMessage; +import org.neo4j.bolt.connection.message.RollbackMessage; +import org.neo4j.bolt.connection.message.RunMessage; +import org.neo4j.bolt.connection.message.TelemetryMessage; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.ResetSummary; import org.neo4j.bolt.connection.summary.RollbackSummary; import org.neo4j.bolt.connection.summary.RunSummary; +import org.neo4j.bolt.connection.summary.TelemetrySummary; import org.neo4j.driver.AccessMode; import org.neo4j.driver.Query; import org.neo4j.driver.TransactionConfig; @@ -80,6 +93,8 @@ import org.neo4j.driver.internal.adaptedbolt.summary.PullSummary; import org.neo4j.driver.internal.telemetry.ApiTelemetryWork; import org.neo4j.driver.internal.util.FixedRetryLogic; +import org.neo4j.driver.internal.value.BoltValueFactory; +import org.neo4j.driver.testutil.TestUtil; class NetworkSessionTest { private DriverBoltConnection connection; @@ -89,11 +104,8 @@ class NetworkSessionTest { @BeforeEach void setUp() { connection = connectionMock(new BoltProtocolVersion(5, 4)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); given(connection.close()).willReturn(completedFuture(null)); + given(connection.valueFactory()).willReturn(mock(BoltValueFactory.class)); connectionProvider = mock(DriverBoltConnectionProvider.class); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .willAnswer((Answer>) invocation -> { @@ -135,23 +147,32 @@ void shouldNotAllowNewTxWhileOneIsRunning() { @Test void shouldBeAbleToOpenTxAfterPreviousIsClosed() { // Given - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); - }, - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); await(beginTransaction(session).closeAsync()); @@ -176,23 +197,37 @@ void shouldNotBeAbleToUseSessionWhileOngoingTransaction() { @Test void shouldBeAbleToUseSessionAgainWhenTransactionIsClosed() { // Given - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedFuture(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); await(beginTransaction(session).closeAsync()); Mockito.reset(connection); setupSuccessfulAutocommitRunAndPull(connection); + given(connection.valueFactory()).willReturn(mock(BoltValueFactory.class)); given(connection.protocolVersion()).willReturn(new BoltProtocolVersion(5, 5)); given(connection.close()).willReturn(CompletableFuture.completedFuture(null)); var query = "RETURN 1"; @@ -206,19 +241,32 @@ void shouldBeAbleToUseSessionAgainWhenTransactionIsClosed() { @Test void shouldNotCloseAlreadyClosedSession() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedFuture(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); beginTransaction(session); @@ -289,19 +337,32 @@ void acquiresNewConnectionForBeginTx() { @Test void updatesBookmarkWhenTxIsClosed() { var bookmarkAfterCommit = InternalBookmark.parse("TheBookmark"); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willReturn(CompletableFuture.completedFuture(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onCommitSummary(() -> Optional.of(bookmarkAfterCommit.value())); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary(() -> Optional.of(bookmarkAfterCommit.value())); + handler.onComplete(); + } })); var tx = beginTransaction(session); @@ -315,39 +376,53 @@ void updatesBookmarkWhenTxIsClosed() { @Test void releasesConnectionWhenTxIsClosed() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.run(any(), any())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedFuture(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRunSummary(mock(RunSummary.class)); - handler.onPullSummary(mock(PullSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RunMessage.class, PullMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRunSummary(mock(RunSummary.class)); + handler.onPullSummary(mock(PullSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); var tx = beginTransaction(session); verify(connectionProvider).connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection).should().flush(any()); + verifyBegin(connection); var query = "RETURN 42"; await(tx.runAsync(new Query(query))); - then(connection).should().run(eq(query), any()); - then(connection).should().pull(anyLong(), anyLong()); - then(connection).should(times(2)).flush(any()); - + verifyRunAndPull(connection, query); await(tx.closeAsync()); verify(connection).close(); } @@ -360,8 +435,12 @@ void bookmarkIsPropagatedFromSession() { var tx = beginTransaction(session); assertNotNull(tx); - then(connection).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection).should().flush(any()); + then(connection) + .should() + .writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof BeginMessage)); } @Test @@ -371,27 +450,35 @@ void bookmarkIsPropagatedBetweenTransactions() { var session = newSession(connectionProvider); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.commit()).willReturn(CompletableFuture.completedFuture(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onCommitSummary(() -> Optional.of(bookmark1.value())); - handler.onComplete(); - }, - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); - }, - handler -> { - handler.onCommitSummary(() -> Optional.of(bookmark2.value())); - handler.onComplete(); + new TestUtil.MessageHandler() { + int num; + + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary( + () -> Optional.of(num++ == 0 ? bookmark1.value() : bookmark2.value())); + handler.onComplete(); + } })); var tx1 = beginTransaction(session); @@ -399,10 +486,8 @@ void bookmarkIsPropagatedBetweenTransactions() { assertEquals(Collections.singleton(bookmark1), session.lastBookmarks()); var tx2 = beginTransaction(session); - then(connection) - .should(times(2)) - .beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection).should(times(3)).flush(any()); + verifyBegin(connection, times(2)); + verifyCommitTx(connection); await(tx2.commitAsync()); assertEquals(Collections.singleton(bookmark2), session.lastBookmarks()); @@ -494,7 +579,10 @@ void shouldRunAfterRunFailure() { void shouldRunAfterBeginTxFailureOnBookmark() { var error = new RuntimeException("Hi"); var connection1 = connectionMock(new BoltProtocolVersion(5, 0)); - given(connection1.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) + given(connection1.writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof BeginMessage))) .willReturn(CompletableFuture.failedStage(error)); given(connection1.close()).willReturn(CompletableFuture.completedStage(null)); var connection2 = connectionMock(new BoltProtocolVersion(5, 0)); @@ -531,7 +619,7 @@ void shouldRunAfterBeginTxFailureOnBookmark() { verify(connectionProvider, times(2)) .connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection1).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + verifyBegin(connection1); verifyAutocommitRunAndPull(connection2, "RETURN 2"); } @@ -539,22 +627,24 @@ void shouldRunAfterBeginTxFailureOnBookmark() { void shouldBeginTxAfterBeginTxFailureOnBookmark() { var error = new RuntimeException("Hi"); var connection1 = connectionMock(new BoltProtocolVersion(5, 0)); - given(connection1.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection1.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) + given(connection1.writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof BeginMessage))) .willReturn(CompletableFuture.failedStage(error)); + given(connection1.close()).willReturn(CompletableFuture.completedStage(null)); var connection2 = connectionMock(new BoltProtocolVersion(5, 0)); - given(connection2.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection2.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection2)); - setupConnectionAnswers(connection2, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection2, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); Mockito.reset(connectionProvider); @@ -586,8 +676,8 @@ void shouldBeginTxAfterBeginTxFailureOnBookmark() { verify(connectionProvider, times(2)) .connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection1).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection2).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + verifyBegin(connection1); + verifyBegin(connection2); } @Test @@ -613,33 +703,51 @@ void shouldBeginTxAfterRunFailureToAcquireConnection() { verify(connectionProvider, times(2)) .connect(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + then(connection) + .should() + .writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof BeginMessage)); } @Test void shouldMarkTransactionAsTerminatedAndThenResetConnectionOnReset() { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.reset()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onResetSummary(mock(ResetSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(ResetMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onResetSummary(mock(ResetSummary.class)); + handler.onComplete(); + } })); var tx = beginTransaction(session); assertTrue(tx.isOpen()); - verify(connection, never()).reset(); + then(connection).should(never()).writeAndFlush(any(), any(ResetMessage.class)); await(session.resetAsync()); - verify(connection).reset(); + then(connection).should().writeAndFlush(any(), eq(List.of(Messages.reset()))); } @ParameterizedTest @@ -648,7 +756,26 @@ void shouldSendTelemetryIfEnabledOnBegin(boolean telemetryDisabled) { // given var session = newSession(connectionProvider, WRITE, new FixedRetryLogic(0), Set.of(), telemetryDisabled); given(connection.telemetrySupported()).willReturn(true); - given(connection.telemetry(any())).willReturn(CompletableFuture.completedStage(connection)); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + var messageTypes = new ArrayList>(); + if (!telemetryDisabled) { + messageTypes.add(TelemetryMessage.class); + } + messageTypes.add(BeginMessage.class); + return messageTypes; + } + + @Override + public void handle(DriverResponseHandler handler) { + if (!telemetryDisabled) { + handler.onTelemetrySummary(mock(TelemetrySummary.class)); + } + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } + })); setupSuccessfulBegin(connection); // when @@ -656,9 +783,17 @@ void shouldSendTelemetryIfEnabledOnBegin(boolean telemetryDisabled) { // then if (telemetryDisabled) { - then(connection).should(never()).telemetry(any()); + then(connection) + .should(never()) + .writeAndFlush(any(), ArgumentMatchers.>argThat(messages -> messages.stream() + .anyMatch(msg -> msg instanceof TelemetryMessage))); } else { - then(connection).should().telemetry(eq(TelemetryApi.UNMANAGED_TRANSACTION)); + then(connection) + .should() + .writeAndFlush( + any(), + ArgumentMatchers.>argThat(messages -> + messages.contains(Messages.telemetry(TelemetryApi.UNMANAGED_TRANSACTION)))); } } @@ -667,31 +802,62 @@ void shouldSendTelemetryIfEnabledOnBegin(boolean telemetryDisabled) { void shouldSendTelemetryIfEnabledOnRun(boolean telemetryDisabled) { // given var query = "RETURN 1"; + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + var messageTypes = new ArrayList>(); + if (!telemetryDisabled) { + messageTypes.add(TelemetryMessage.class); + } + messageTypes.add(RunMessage.class); + messageTypes.add(PullMessage.class); + return messageTypes; + } + + @Override + public void handle(DriverResponseHandler handler) { + if (!telemetryDisabled) { + handler.onTelemetrySummary(mock(TelemetrySummary.class)); + } + handler.onRunSummary(mock(RunSummary.class)); + handler.onPullSummary(mock(PullSummary.class)); + handler.onComplete(); + } + })); setupSuccessfulAutocommitRunAndPull(connection); var session = newSession(connectionProvider, WRITE, new FixedRetryLogic(0), Set.of(), telemetryDisabled); given(connection.telemetrySupported()).willReturn(true); - given(connection.telemetry(any())).willReturn(CompletableFuture.completedStage(connection)); // when run(session, query); // then if (telemetryDisabled) { - then(connection).should(never()).telemetry(any()); + then(connection) + .should(never()) + .writeAndFlush(any(), ArgumentMatchers.>argThat(messages -> messages.stream() + .anyMatch(msg -> msg instanceof TelemetryMessage))); } else { - then(connection).should().telemetry(eq(TelemetryApi.AUTO_COMMIT_TRANSACTION)); + then(connection) + .should() + .writeAndFlush( + any(), + ArgumentMatchers.>argThat(messages -> + messages.contains(Messages.telemetry(TelemetryApi.AUTO_COMMIT_TRANSACTION)))); } } private void setupSuccessfulBegin(DriverBoltConnection connection) { - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArguments()[0]; - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); - return completedFuture(null); - }); + given(connection.writeAndFlush( + any(), + ArgumentMatchers.>argThat( + argument -> argument.size() == 1 && argument.get(0) instanceof BeginMessage))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArguments()[0]; + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + return completedFuture(null); + }); } private static void run(NetworkSession session, String query) { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index 4ddd1e7fd..2f7f04187 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -25,7 +25,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.doReturn; @@ -37,6 +37,10 @@ import static org.neo4j.driver.testutil.TestUtil.await; import static org.neo4j.driver.testutil.TestUtil.connectionMock; import static org.neo4j.driver.testutil.TestUtil.setupConnectionAnswers; +import static org.neo4j.driver.testutil.TestUtil.verifyBegin; +import static org.neo4j.driver.testutil.TestUtil.verifyCommitTx; +import static org.neo4j.driver.testutil.TestUtil.verifyRollbackTx; +import static org.neo4j.driver.testutil.TestUtil.verifyRun; import static org.neo4j.driver.testutil.TestUtil.verifyRunAndPull; import java.util.Collections; @@ -57,9 +61,16 @@ import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.DatabaseNameUtil; import org.neo4j.bolt.connection.TelemetryApi; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.CommitMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.Messages; +import org.neo4j.bolt.connection.message.PullMessage; +import org.neo4j.bolt.connection.message.ResetMessage; +import org.neo4j.bolt.connection.message.RollbackMessage; +import org.neo4j.bolt.connection.message.RunMessage; import org.neo4j.bolt.connection.summary.BeginSummary; import org.neo4j.bolt.connection.summary.CommitSummary; -import org.neo4j.bolt.connection.summary.ResetSummary; import org.neo4j.bolt.connection.summary.RollbackSummary; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.Bookmark; @@ -74,33 +85,43 @@ import org.neo4j.driver.internal.FailableCursor; import org.neo4j.driver.internal.InternalBookmark; import org.neo4j.driver.internal.adaptedbolt.DriverBoltConnection; +import org.neo4j.driver.internal.adaptedbolt.DriverResponseHandler; import org.neo4j.driver.internal.adaptedbolt.summary.PullSummary; import org.neo4j.driver.internal.telemetry.ApiTelemetryWork; +import org.neo4j.driver.testutil.TestUtil; class UnmanagedTransactionTest { @Test void shouldFlushOnRunAsync() { // Given var connection = connectionMock(new BoltProtocolVersion(5, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.run(any(), any())).willReturn(CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRunSummary(mock(RunSummary.class)); - handler.onPullSummary(mock(PullSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RunMessage.class, PullMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRunSummary(mock(RunSummary.class)); + handler.onPullSummary(mock(PullSummary.class)); + handler.onComplete(); + } })); var tx = beginTx(connection); @@ -115,23 +136,32 @@ void shouldFlushOnRunAsync() { void shouldFlushOnRunRx() { // Given var connection = connectionMock(new BoltProtocolVersion(5, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.run(any(), any())).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRunSummary(mock(RunSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RunMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRunSummary(mock(RunSummary.class)); + handler.onComplete(); + } })); var tx = beginTx(connection); @@ -139,31 +169,40 @@ void shouldFlushOnRunRx() { await(tx.runRx(new Query("RETURN 1"))); // Then - then(connection).should().run("RETURN 1", Collections.emptyMap()); - then(connection).should(times(2)).flush(any()); + verifyBegin(connection); + verifyRun(connection, "RETURN 1"); } @Test void shouldRollbackOnImplicitFailure() { // Given var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var tx = beginTx(connection); @@ -172,44 +211,46 @@ void shouldRollbackOnImplicitFailure() { await(tx.closeAsync()); // Then - then(connection).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection).should().rollback(); - then(connection).should(times(2)).flush(any()); + verifyBegin(connection); + verifyRollbackTx(connection); then(connection).should().close(); } @Test void shouldBeginTransaction() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); beginTx(connection, Collections.emptySet()); - then(connection).should().beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any()); - then(connection).should().flush(any()); + verifyBegin(connection); } @Test void shouldBeOpenAfterConstruction() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); var tx = beginTx(connection); @@ -220,15 +261,17 @@ void shouldBeOpenAfterConstruction() { @Test void shouldBeClosedWhenMarkedAsTerminated() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); var tx = beginTx(connection); @@ -240,15 +283,17 @@ void shouldBeClosedWhenMarkedAsTerminated() { @Test void shouldBeClosedWhenMarkedTerminatedAndClosed() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(completedFuture(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var tx = beginTx(connection); @@ -263,15 +308,17 @@ void shouldBeClosedWhenMarkedTerminatedAndClosed() { void shouldReleaseConnectionWhenBeginFails() { var error = new RuntimeException("Wrong bookmark!"); var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onError(error); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(error); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); @@ -299,15 +346,17 @@ void shouldReleaseConnectionWhenBeginFails() { @Test void shouldNotReleaseConnectionWhenBeginSucceeds() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); @@ -466,14 +515,17 @@ void shouldReleaseConnectionWhenTerminatedAndRolledBack() { @Test void shouldReleaseConnectionWhenClose() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onRollbackSummary(mock(RollbackSummary.class)); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); @@ -498,15 +550,17 @@ void shouldReleaseConnectionWhenClose() { void shouldReleaseConnectionOnConnectionAuthorizationExpiredExceptionFailure() { var exception = new AuthorizationExpiredException("code", "message"); var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onError(exception); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(exception); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); @@ -534,15 +588,17 @@ void shouldReleaseConnectionOnConnectionAuthorizationExpiredExceptionFailure() { @Test void shouldReleaseConnectionOnConnectionReadTimeoutExceptionFailure() { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - handler.onError(ConnectionReadTimeoutException.INSTANCE); - handler.onComplete(); + setupConnectionAnswers(connection, List.of(new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(ConnectionReadTimeoutException.INSTANCE); + handler.onComplete(); + } })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); @@ -581,13 +637,27 @@ private static Stream similarTransactionCompletingActionArgs() { void shouldReturnExistingStageOnSimilarCompletingAction( boolean protocolCommit, String initialAction, String similarAction) { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willReturn(CompletableFuture.completedStage(null)); + setupConnectionAnswers( + connection, + List.of( + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) {} + }, + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) {} + })); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); var tx = new UnmanagedTransaction( @@ -607,9 +677,9 @@ void shouldReturnExistingStageOnSimilarCompletingAction( assertSame(initialStage, similarStage); if (protocolCommit) { - then(connection).should(times(1)).commit(); + verifyCommitTx(connection, times(1)); } else { - then(connection).should(times(1)).rollback(); + verifyRollbackTx(connection, times(1)); } } @@ -636,24 +706,36 @@ void shouldReturnFailingStageOnConflictingCompletingAction( String conflictingAction, String expectedErrorMsg) { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - if (protocolActionCompleted) { - setupConnectionAnswers(connection, List.of(handler -> { - if (protocolCommit) { - handler.onCommitSummary(mock(CommitSummary.class)); - } else { - handler.onRollbackSummary(mock(RollbackSummary.class)); + var messageHandler = protocolCommit + ? new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + if (protocolActionCompleted) { + handler.onCommitSummary(mock(CommitSummary.class)); + handler.onComplete(); + } + } } - handler.onComplete(); - })); - } else { - given(connection.flush(any())).willReturn(CompletableFuture.completedStage(null)); - } + : new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + if (protocolActionCompleted) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } + } + }; + setupConnectionAnswers(connection, List.of(messageHandler)); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); var tx = new UnmanagedTransaction( @@ -673,11 +755,10 @@ void shouldReturnFailingStageOnConflictingCompletingAction( assertNotNull(originalActionStage); if (protocolCommit) { - then(connection).should().commit(); + verifyCommitTx(connection, times(1)); } else { - then(connection).should().rollback(); + verifyRollbackTx(connection, times(1)); } - then(connection).should().flush(any()); assertTrue(conflictingActionStage.toCompletableFuture().isCompletedExceptionally()); var throwable = assertThrows( ExecutionException.class, @@ -704,20 +785,32 @@ private static Stream closingNotActionTransactionArgs() { void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommittingAborted( boolean protocolCommit, int expectedProtocolInvocations, String originalAction, Boolean commitOnClose) { var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); - setupConnectionAnswers(connection, List.of(handler -> { - if (protocolCommit) { - handler.onCommitSummary(mock(CommitSummary.class)); - } else { - handler.onRollbackSummary(mock(RollbackSummary.class)); - } - handler.onComplete(); - })); + var messageHandler = protocolCommit + ? new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onCommitSummary(mock(CommitSummary.class)); + handler.onComplete(); + } + } + : new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onRollbackSummary(mock(RollbackSummary.class)); + handler.onComplete(); + } + }; + setupConnectionAnswers(connection, List.of(messageHandler)); given(connection.close()).willReturn(CompletableFuture.completedStage(null)); var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.UNMANAGED_TRANSACTION); var tx = new UnmanagedTransaction( @@ -738,11 +831,10 @@ void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommitt assertTrue(originalActionStage.toCompletableFuture().isDone()); assertFalse(originalActionStage.toCompletableFuture().isCompletedExceptionally()); if (protocolCommit) { - then(connection).should(times(expectedProtocolInvocations)).commit(); + verifyCommitTx(connection, times(expectedProtocolInvocations)); } else { - then(connection).should(times(expectedProtocolInvocations)).rollback(); + verifyRollbackTx(connection, times(expectedProtocolInvocations)); } - then(connection).should(times(expectedProtocolInvocations)).flush(any()); assertNull(closeStage.toCompletableFuture().join()); } @@ -750,24 +842,32 @@ void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommitt void shouldTerminateOnTerminateAsync() { // Given var connection = connectionMock(new BoltProtocolVersion(4, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - given(connection.clear()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.reset()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onResetSummary(mock(ResetSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(ResetMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onResetSummary(mock()); + handler.onComplete(); + } })); var tx = beginTx(connection); @@ -775,32 +875,36 @@ void shouldTerminateOnTerminateAsync() { await(tx.terminateAsync()); // Then - then(connection).should().clear(); - then(connection).should().reset(); + then(connection).should().writeAndFlush(any(), eq(List.of(Messages.reset()))); } @Test void shouldServeTheSameStageOnTerminateAsync() { // Given var connection = connectionMock(new BoltProtocolVersion(4, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - given(connection.clear()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.reset()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onResetSummary(mock(ResetSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(ResetMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) {} })); var tx = beginTx(connection); @@ -816,25 +920,45 @@ void shouldServeTheSameStageOnTerminateAsync() { void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, InterruptedException { // Given var connection = connectionMock(new BoltProtocolVersion(4, 0)); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - given(connection.run(any(), any())).willReturn(CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willReturn(CompletableFuture.completedStage(connection)); var exception = new Neo4jException("message"); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> { - handler.onError(exception); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RunMessage.class, PullMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onError(exception); + handler.onComplete(); + } + }, + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(ResetMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onResetSummary(mock()); + handler.onComplete(); + } })); var tx = beginTx(connection); Throwable actualException = null; @@ -856,22 +980,39 @@ void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, I void shouldThrowOnRunningNewQueriesWhenTransactionIsClosing(TransactionClosingTestParams testParams) { // Given var connection = connectionMock(); - given(connection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willReturn(CompletableFuture.completedStage(connection)); - given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); - given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers( connection, List.of( - handler -> { - handler.onBeginSummary(mock(BeginSummary.class)); - handler.onComplete(); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(BeginMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) { + handler.onBeginSummary(mock(BeginSummary.class)); + handler.onComplete(); + } }, - handler -> {})); + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(CommitMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) {} + }, + new TestUtil.MessageHandler() { + @Override + public List> messageTypes() { + return List.of(RollbackMessage.class); + } + + @Override + public void handle(DriverResponseHandler handler) {} + })); var tx = beginTx(connection); // When diff --git a/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java index 0d602a08a..a2d4b8b6e 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java @@ -21,6 +21,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; @@ -39,6 +40,8 @@ import org.mockito.stubbing.Answer; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.Messages; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.Query; import org.neo4j.driver.Value; @@ -78,12 +81,12 @@ void beforeEach() { @Test void shouldNextAsync() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onRecord(new Value[0]); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onRecord(new Value[0]); + return CompletableFuture.completedStage(null); + }); var record = cursor.nextAsync().toCompletableFuture().join(); @@ -92,7 +95,6 @@ var record = cursor.nextAsync().toCompletableFuture().join(); @Test void shouldFailNextAsyncOnError() { - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); var error = new Neo4jException("code", "message"); cursor.onError(error); cursor.onComplete(); @@ -106,9 +108,8 @@ void shouldFailNextAsyncOnError() { @Test void shouldFailNextAsyncOnFlushError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); var error = new RuntimeException("message"); - given(connection.flush(any())) + given(connection.writeAndFlush(any(), any(Message.class))) .willAnswer((Answer>) invocation -> CompletableFuture.failedStage(error)); var future = cursor.nextAsync().toCompletableFuture(); @@ -132,15 +133,15 @@ var record = cursor.singleAsync().toCompletableFuture().join(); void shouldFailSingleAsync() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onRecord(new Value[0]); - var pullSummary = mock(PullSummary.class); - given(pullSummary.hasMore()).willReturn(true); - handler.onPullSummary(pullSummary); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onRecord(new Value[0]); + var pullSummary = mock(PullSummary.class); + given(pullSummary.hasMore()).willReturn(true); + handler.onPullSummary(pullSummary); + return CompletableFuture.completedStage(null); + }); var future = cursor.singleAsync().toCompletableFuture(); @@ -152,14 +153,14 @@ void shouldFailSingleAsync() { void shouldFailSingleAsyncOnError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); var error = new Neo4jException("code", "message"); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onError(error); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onError(error); + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); var future = cursor.singleAsync().toCompletableFuture(); @@ -171,9 +172,8 @@ void shouldFailSingleAsyncOnError() { void shouldFailSingleAsyncOnFlushError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); var error = new RuntimeException("message"); - given(connection.flush(any())) + given(connection.writeAndFlush(any(), any(Message.class))) .willAnswer((Answer>) invocation -> CompletableFuture.failedStage(error)); var future = cursor.singleAsync().toCompletableFuture(); @@ -186,57 +186,55 @@ void shouldFailSingleAsyncOnFlushError() { void shouldFetchMore() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - for (var i = 0; i < fetchSize; i++) { - handler.onRecord(new Value[0]); - } - var pullSummary = mock(PullSummary.class); - given(pullSummary.hasMore()).willReturn(true); - handler.onPullSummary(pullSummary); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + for (var i = 0; i < fetchSize; i++) { + handler.onRecord(new Value[0]); + } + var pullSummary = mock(PullSummary.class); + given(pullSummary.hasMore()).willReturn(true); + handler.onPullSummary(pullSummary); + return CompletableFuture.completedStage(null); + }); for (var i = 0; i < fetchSize; i++) { cursor.nextAsync().toCompletableFuture().join(); } assertNotNull(cursor.nextAsync().toCompletableFuture().join()); - then(connection).should(times(2)).pull(0, fetchSize); - then(connection).should(times(2)).flush(any()); + then(connection).should(times(2)).writeAndFlush(any(), eq(Messages.pull(0, fetchSize))); } @Test void shouldListAsync() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, -1)).willReturn(CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onRecord(new Value[0]); - var pullSummary = mock(PullSummary.class); - handler.onPullSummary(pullSummary); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onRecord(new Value[0]); + var pullSummary = mock(PullSummary.class); + handler.onPullSummary(pullSummary); + return CompletableFuture.completedStage(null); + }); assertEquals(1, cursor.listAsync().toCompletableFuture().join().size()); - then(connection).should().pull(0, -1); - then(connection).should().flush(any()); + then(connection).should().writeAndFlush(any(), eq(Messages.pull(0, -1))); } @Test void shouldFailListAsyncOnError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, -1)).willReturn(CompletableFuture.completedStage(connection)); var error = new Neo4jException("code", "message"); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onError(error); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onError(error); + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); var future = cursor.listAsync().toCompletableFuture(); @@ -248,9 +246,8 @@ void shouldFailListAsyncOnError() { void shouldFailListAsyncOnFlushError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, -1)).willReturn(CompletableFuture.completedStage(connection)); var error = new RuntimeException("message"); - given(connection.flush(any())) + given(connection.writeAndFlush(any(), any(Message.class))) .willAnswer((Answer>) invocation -> CompletableFuture.failedStage(error)); var future = cursor.listAsync().toCompletableFuture(); @@ -263,14 +260,14 @@ void shouldFailListAsyncOnFlushError() { void shouldFailPeekAsyncOnError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); var error = new Neo4jException("code", "message"); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onError(error); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onError(error); + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); var future = cursor.peekAsync().toCompletableFuture(); @@ -282,9 +279,8 @@ void shouldFailPeekAsyncOnError() { void shouldFailListPeekOnFlushError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.pull(0, fetchSize)).willReturn(CompletableFuture.completedStage(connection)); var error = new RuntimeException("message"); - given(connection.flush(any())) + given(connection.writeAndFlush(any(), any(Message.class))) .willAnswer((Answer>) invocation -> CompletableFuture.failedStage(error)); var future = cursor.peekAsync().toCompletableFuture(); @@ -297,14 +293,14 @@ void shouldFailListPeekOnFlushError() { void shouldFailConsumeAsyncOnError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.discard(0, -1)).willReturn(CompletableFuture.completedStage(connection)); var error = new Neo4jException("code", "message"); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onError(error); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onError(error); + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); var future = cursor.consumeAsync().toCompletableFuture(); @@ -316,9 +312,8 @@ void shouldFailConsumeAsyncOnError() { void shouldFailConsumeAsyncOnFlushError() { cursor.onPullSummary(new PullSummaryImpl(true, Collections.emptyMap())); given(connection.serverAddress()).willReturn(BoltServerAddress.LOCAL_DEFAULT); - given(connection.discard(0, -1)).willReturn(CompletableFuture.completedStage(connection)); var error = new RuntimeException("message"); - given(connection.flush(any())) + given(connection.writeAndFlush(any(), any(Message.class))) .willAnswer((Answer>) invocation -> CompletableFuture.failedStage(error)); var future = cursor.consumeAsync().toCompletableFuture(); diff --git a/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java index 82cf1e62e..be0191774 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java @@ -68,8 +68,7 @@ void shouldNotifyRecordConsumerOfRunError(boolean getRunError) { // given var runError = mock(Throwable.class); given(connection.serverAddress()).willReturn(new BoltServerAddress("localhost")); - var cursor = new RxResultCursorImpl( - connection, mock(), query, null, runError, bookmarkConsumer, false, Logging.none()); + var cursor = new RxResultCursorImpl(connection, query, null, runError, bookmarkConsumer, false, Logging.none()); if (getRunError) { assertEquals(runError, cursor.getRunError()); } @@ -90,8 +89,7 @@ void shouldReturnSummaryWithRunError(boolean getRunError) { // given var runError = mock(Throwable.class); given(connection.serverAddress()).willReturn(new BoltServerAddress("localhost")); - var cursor = new RxResultCursorImpl( - connection, mock(), query, null, runError, bookmarkConsumer, false, Logging.none()); + var cursor = new RxResultCursorImpl(connection, query, null, runError, bookmarkConsumer, false, Logging.none()); if (getRunError) { assertEquals(runError, cursor.getRunError()); } @@ -109,8 +107,8 @@ void shouldReturnKeys() { // given var keys = List.of("a", "b"); given(runSummary.keys()).willReturn(keys); - var cursor = new RxResultCursorImpl( - connection, mock(), query, runSummary, null, bookmarkConsumer, false, Logging.none()); + var cursor = + new RxResultCursorImpl(connection, query, runSummary, null, bookmarkConsumer, false, Logging.none()); // when & then assertEquals(keys, cursor.keys()); diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java index a1c12f144..6afe19160 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java @@ -27,7 +27,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; @@ -39,13 +38,13 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; -import java.util.function.Supplier; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.stubbing.Answer; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; +import org.neo4j.bolt.connection.message.Message; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.Logging; import org.neo4j.driver.Record; @@ -138,22 +137,18 @@ void shouldCancelKeys() { void shouldObtainRecordsAndSummary() { // Given var boltConnection = mock(DriverBoltConnection.class); - given(boltConnection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(boltConnection.pull(anyLong(), anyLong())).willReturn(CompletableFuture.completedFuture(boltConnection)); given(boltConnection.serverAddress()).willReturn(new BoltServerAddress("localhost")); given(boltConnection.protocolVersion()).willReturn(new BoltProtocolVersion(5, 1)); - given(boltConnection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArguments()[0]; - handler.onRecord(values(1, 1, 1)); - handler.onRecord(values(2, 2, 2)); - handler.onRecord(values(3, 3, 3)); - handler.onPullSummary(mock()); - handler.onComplete(); - return CompletableFuture.completedFuture(null); - }); + given(boltConnection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArguments()[0]; + handler.onRecord(values(1, 1, 1)); + handler.onRecord(values(2, 2, 2)); + handler.onRecord(values(3, 3, 3)); + handler.onPullSummary(mock()); + handler.onComplete(); + return CompletableFuture.completedFuture(null); + }); var runSummary = mock(RunSummary.class); given(runSummary.keys()).willReturn(List.of("key1", "key2", "key3")); Record record1 = new InternalRecord(asList("key1", "key2", "key3"), values(1, 1, 1)); @@ -175,22 +170,18 @@ void shouldObtainRecordsAndSummary() { void shouldCancelStreamingButObtainSummary() { // Given var boltConnection = mock(DriverBoltConnection.class); - given(boltConnection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(boltConnection.pull(anyLong(), anyLong())).willReturn(CompletableFuture.completedFuture(boltConnection)); given(boltConnection.serverAddress()).willReturn(new BoltServerAddress("localhost")); given(boltConnection.protocolVersion()).willReturn(new BoltProtocolVersion(5, 1)); - given(boltConnection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArguments()[0]; - handler.onRecord(values(1, 1, 1)); - handler.onRecord(values(2, 2, 2)); - handler.onRecord(values(3, 3, 3)); - handler.onPullSummary(mock()); - handler.onComplete(); - return CompletableFuture.completedFuture(null); - }); + given(boltConnection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArguments()[0]; + handler.onRecord(values(1, 1, 1)); + handler.onRecord(values(2, 2, 2)); + handler.onRecord(values(3, 3, 3)); + handler.onPullSummary(mock()); + handler.onComplete(); + return CompletableFuture.completedFuture(null); + }); var runSummary = mock(RunSummary.class); given(runSummary.keys()).willReturn(List.of("key1", "key2", "key3")); Record record1 = new InternalRecord(asList("key1", "key2", "key3"), values(1, 1, 1)); @@ -223,20 +214,16 @@ void shouldErrorIfFailedToCreateCursor() { void shouldErrorIfFailedToStream() { // Given var boltConnection = mock(DriverBoltConnection.class); - given(boltConnection.onLoop(any())).willAnswer(invocationOnMock -> { - Supplier supplier = invocationOnMock.getArgument(0); - return CompletableFuture.completedStage(supplier.get()); - }); - given(boltConnection.pull(anyLong(), anyLong())).willReturn(CompletableFuture.completedFuture(boltConnection)); given(boltConnection.serverAddress()).willReturn(new BoltServerAddress("localhost")); given(boltConnection.protocolVersion()).willReturn(new BoltProtocolVersion(5, 1)); Throwable error = new RuntimeException("Hi"); - given(boltConnection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArguments()[0]; - handler.onError(error); - handler.onComplete(); - return CompletableFuture.completedFuture(null); - }); + given(boltConnection.writeAndFlush(any(), any(Message.class))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArguments()[0]; + handler.onError(error); + handler.onComplete(); + return CompletableFuture.completedFuture(null); + }); RxResult rxResult = newRxResult(boltConnection); // When & Then @@ -270,7 +257,7 @@ private InternalRxResult newRxResult(DriverBoltConnection boltConnection) { private InternalRxResult newRxResult(DriverBoltConnection boltConnection, RunSummary runSummary) { RxResultCursor cursor = new RxResultCursorImpl( - boltConnection, mock(), mock(), runSummary, null, databaseBookmark -> {}, false, Logging.none()); + boltConnection, mock(), runSummary, null, databaseBookmark -> {}, false, Logging.none()); return newRxResult(cursor); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWorkTest.java b/driver/src/test/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWorkTest.java index 1c04cf17e..420f7c981 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWorkTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/telemetry/ApiTelemetryWorkTest.java @@ -17,11 +17,11 @@ package org.neo4j.driver.internal.telemetry; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; -import java.util.concurrent.CompletableFuture; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.mockito.Mockito; @@ -36,14 +36,12 @@ void shouldPipelineTelemetryWhenTelemetryIsEnabledAndConnectionSupportsTelemetry var apiTelemetryWork = new ApiTelemetryWork(telemetryApi); apiTelemetryWork.setEnabled(true); var boltConnection = Mockito.mock(DriverBoltConnection.class); - var boltConnectionStage = CompletableFuture.completedFuture(boltConnection); given(boltConnection.telemetrySupported()).willReturn(true); - given(boltConnection.telemetry(telemetryApi)).willReturn(boltConnectionStage); - var stage = apiTelemetryWork.pipelineTelemetryIfEnabled(boltConnection); + var message = apiTelemetryWork.getTelemetryMessageIfEnabled(boltConnection); - assertEquals(boltConnectionStage, stage); - then(boltConnection).should().telemetry(telemetryApi); + assertNotNull(message); + assertEquals(telemetryApi, message.api()); } @ParameterizedTest @@ -54,10 +52,9 @@ void shouldNotPipelineTelemetryWhenTelemetryIsEnabledAndConnectionDoesNotSupport apiTelemetryWork.setEnabled(true); var boltConnection = Mockito.mock(DriverBoltConnection.class); - var future = apiTelemetryWork.pipelineTelemetryIfEnabled(boltConnection).toCompletableFuture(); + var message = apiTelemetryWork.getTelemetryMessageIfEnabled(boltConnection); - assertTrue(future.isDone()); - assertEquals(boltConnection, future.join()); + assertNull(message); then(boltConnection).should().telemetrySupported(); then(boltConnection).shouldHaveNoMoreInteractions(); } @@ -69,10 +66,9 @@ void shouldNotPipelineTelemetryWhenTelemetryIsDisabledAndConnectionDoesNotSuppor var apiTelemetryWork = new ApiTelemetryWork(telemetryApi); var boltConnection = Mockito.mock(DriverBoltConnection.class); - var future = apiTelemetryWork.pipelineTelemetryIfEnabled(boltConnection).toCompletableFuture(); + var message = apiTelemetryWork.getTelemetryMessageIfEnabled(boltConnection); - assertTrue(future.isDone()); - assertEquals(boltConnection, future.join()); + assertNull(message); then(boltConnection).shouldHaveNoInteractions(); } @@ -83,10 +79,9 @@ void shouldNotPipelineTelemetryWhenTelemetryIsDisabledAndConnectionSupportsTelem var boltConnection = Mockito.mock(DriverBoltConnection.class); given(boltConnection.telemetrySupported()).willReturn(true); - var future = apiTelemetryWork.pipelineTelemetryIfEnabled(boltConnection).toCompletableFuture(); + var message = apiTelemetryWork.getTelemetryMessageIfEnabled(boltConnection); - assertTrue(future.isDone()); - assertEquals(boltConnection, future.join()); + assertNull(message); then(boltConnection).shouldHaveNoInteractions(); } } diff --git a/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java b/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java index 1d7891d03..197299a16 100644 --- a/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java @@ -19,9 +19,9 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.stream.Collectors.toList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.atLeastOnce; @@ -54,14 +54,21 @@ import java.util.concurrent.Future; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeoutException; -import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; import org.neo4j.bolt.connection.BoltProtocolVersion; import org.neo4j.bolt.connection.BoltServerAddress; +import org.neo4j.bolt.connection.message.BeginMessage; +import org.neo4j.bolt.connection.message.CommitMessage; +import org.neo4j.bolt.connection.message.Message; +import org.neo4j.bolt.connection.message.PullMessage; +import org.neo4j.bolt.connection.message.RollbackMessage; +import org.neo4j.bolt.connection.message.RunMessage; import org.neo4j.bolt.connection.summary.CommitSummary; import org.neo4j.bolt.connection.summary.RunSummary; import org.neo4j.driver.AccessMode; @@ -81,6 +88,7 @@ import org.neo4j.driver.internal.retry.RetryLogic; import org.neo4j.driver.internal.security.BoltSecurityPlanManager; import org.neo4j.driver.internal.util.FixedRetryLogic; +import org.neo4j.driver.internal.value.BoltValueFactory; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -229,53 +237,86 @@ public static NetworkSession newSession( mock()); } - public static void setupConnectionAnswers( - DriverBoltConnection connection, List> handlerConsumers) { - given(connection.flush(any())).willAnswer(new Answer>() { - private int index; + public static void setupConnectionAnswers(DriverBoltConnection connection, List messageHandlers) { + for (var messageHandler : messageHandlers) { + given(connection.writeAndFlush(any(), ArgumentMatchers.>argThat(messages -> { + if (messages == null + || messages.size() + != messageHandler.messageTypes().size()) { + return false; + } + return IntStream.range(0, messages.size()).allMatch(i -> messageHandler + .messageTypes() + .get(i) + .isAssignableFrom(messages.get(i).getClass())); + }))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArguments()[0]; + messageHandler.handle(handler); + return CompletableFuture.completedFuture(null); + }); + } + } - @Override - public CompletionStage answer(InvocationOnMock invocation) { - var handler = (DriverResponseHandler) invocation.getArguments()[0]; - var consumer = handlerConsumers.get(index++); - consumer.accept(handler); - return CompletableFuture.completedFuture(null); - } - }); + public interface MessageHandler { + List> messageTypes(); + + void handle(DriverResponseHandler handler); } public static void verifyAutocommitRunRx(DriverBoltConnection connection, String query) { - then(connection) - .should() - .runInAutoCommitTransaction(any(), any(), any(), any(), eq(query), any(), any(), any(), any()); - then(connection).should().flush(any()); + then(connection).should().writeAndFlush(any(), ArgumentMatchers.>argThat(argument -> { + var runMessage = (RunMessage) argument.get(0); + return runMessage.query().equals(query); + })); + } + + public static void verifyRun(DriverBoltConnection connection, String query) { + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + then(connection).should(atLeastOnce()).writeAndFlush(any(), captor.capture()); + var messages = captor.getValue(); + assertInstanceOf(RunMessage.class, messages.get(0)); + assertEquals(query, ((RunMessage) messages.get(0)).query()); } public static void verifyRunAndPull(DriverBoltConnection connection, String query) { - then(connection).should().run(eq(query), any()); - then(connection).should().pull(anyLong(), anyLong()); - then(connection).should(atLeastOnce()).flush(any()); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + then(connection).should(atLeastOnce()).writeAndFlush(any(), captor.capture()); + var messages = captor.getAllValues().get(1); + assertInstanceOf(RunMessage.class, messages.get(0)); + assertEquals(query, ((RunMessage) messages.get(0)).query()); + assertInstanceOf(PullMessage.class, messages.get(1)); } public static void verifyAutocommitRunAndPull(DriverBoltConnection connection, String query) { - then(connection) - .should() - .runInAutoCommitTransaction(any(), any(), any(), any(), eq(query), any(), any(), any(), any()); - then(connection).should().pull(anyLong(), anyLong()); - then(connection).should().flush(any()); + then(connection).should().writeAndFlush(any(), ArgumentMatchers.>argThat(argument -> { + var runMessage = (RunMessage) argument.get(0); + var pullMessage = (PullMessage) argument.get(1); + return runMessage.query().equals(query) && pullMessage != null; + })); } public static void verifyCommitTx(DriverBoltConnection connection, VerificationMode mode) { - verify(connection, mode).commit(); - verify(connection, mode).close(); + verify(connection, mode) + .writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof CommitMessage)); } public static void verifyCommitTx(DriverBoltConnection connection) { verifyCommitTx(connection, times(1)); + verify(connection, atLeastOnce()).close(); } public static void verifyRollbackTx(DriverBoltConnection connection, VerificationMode mode) { - verify(connection, mode).rollback(); + verify(connection, mode) + .writeAndFlush( + any(), + ArgumentMatchers.>argThat( + messages -> messages.size() == 1 && messages.get(0) instanceof RollbackMessage)); } public static void verifyRollbackTx(DriverBoltConnection connection) { @@ -284,16 +325,29 @@ public static void verifyRollbackTx(DriverBoltConnection connection) { } public static void setupFailingRun(DriverBoltConnection connection, Throwable error) { - given(connection.run(any(), any())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - handler.onError(error); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush( + any(), + ArgumentMatchers.>argThat(messages -> messages.size() == 2 + && messages.get(0) instanceof RunMessage + && messages.get(1) instanceof PullMessage))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + handler.onError(error); + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); + } + + public static void verifyBegin(DriverBoltConnection connection) { + verifyBegin(connection, atLeastOnce()); + } + + public static void verifyBegin(DriverBoltConnection connection, VerificationMode mode) { + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + then(connection).should(atLeastOnce()).writeAndFlush(any(), captor.capture()); + var messages = captor.getAllValues().get(0); + assertInstanceOf(BeginMessage.class, messages.get(0)); } public static void setupFailingCommit(DriverBoltConnection connection) { @@ -301,23 +355,22 @@ public static void setupFailingCommit(DriverBoltConnection connection) { } public static void setupFailingCommit(DriverBoltConnection connection, int times) { - given(connection.commit()).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer(new Answer>() { - int invoked; - - @Override - public CompletionStage answer(InvocationOnMock invocation) { - var handler = (DriverResponseHandler) invocation.getArgument(0); - if (invoked++ < times) { - handler.onError(new ServiceUnavailableException("")); - } else { - handler.onCommitSummary(mock(CommitSummary.class)); - } - handler.onComplete(); - return CompletableFuture.completedStage(null); - } - }); + given(connection.writeAndFlush(any(), any(CommitMessage.class))) + .willAnswer(new Answer>() { + int invoked; + + @Override + public CompletionStage answer(InvocationOnMock invocation) { + var handler = (DriverResponseHandler) invocation.getArgument(0); + if (invoked++ < times) { + handler.onError(new ServiceUnavailableException("")); + } else { + handler.onCommitSummary(mock(CommitSummary.class)); + } + handler.onComplete(); + return CompletableFuture.completedStage(null); + } + }); } public static void setupFailingRollback(DriverBoltConnection connection) { @@ -325,60 +378,66 @@ public static void setupFailingRollback(DriverBoltConnection connection) { } public static void setupFailingRollback(DriverBoltConnection connection, int times) { - given(connection.rollback()).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer(new Answer>() { - int invoked; - - @Override - public CompletionStage answer(InvocationOnMock invocation) { - var handler = (DriverResponseHandler) invocation.getArgument(0); - if (invoked++ < times) { - handler.onError(new ServiceUnavailableException("")); - } else { - handler.onCommitSummary(mock(CommitSummary.class)); - } - handler.onComplete(); - return CompletableFuture.completedStage(null); - } - }); + given(connection.writeAndFlush(any(), any(RollbackMessage.class))) + .willAnswer(new Answer>() { + int invoked; + + @Override + public CompletionStage answer(InvocationOnMock invocation) { + var handler = (DriverResponseHandler) invocation.getArgument(0); + if (invoked++ < times) { + handler.onError(new ServiceUnavailableException("")); + } else { + handler.onCommitSummary(mock(CommitSummary.class)); + } + handler.onComplete(); + return CompletableFuture.completedStage(null); + } + }); } public static void setupSuccessfulRunAndPull(DriverBoltConnection connection) { - given(connection.run(any(), any())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - var runSummary = mock(RunSummary.class); - given(runSummary.keys()).willReturn(Collections.emptyList()); - handler.onRunSummary(runSummary); - var pullSummary = mock(PullSummary.class); - given(pullSummary.metadata()).willReturn(Collections.emptyMap()); - handler.onPullSummary(pullSummary); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush( + any(), + ArgumentMatchers.>argThat(messages -> messages.size() == 2 + && messages.get(0) instanceof RunMessage + && messages.get(1) instanceof PullMessage))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + var runSummary = mock(RunSummary.class); + given(runSummary.keys()).willReturn(Collections.emptyList()); + handler.onRunSummary(runSummary); + var pullSummary = mock(PullSummary.class); + given(pullSummary.metadata()).willReturn(Collections.emptyMap()); + handler.onPullSummary(pullSummary); + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); } public static void setupSuccessfulAutocommitRunAndPull(DriverBoltConnection connection) { - given(connection.runInAutoCommitTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) - .willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.pull(anyLong(), anyLong())).willAnswer((Answer>) - invocation -> CompletableFuture.completedStage(connection)); - given(connection.flush(any())).willAnswer((Answer>) invocation -> { - var handler = (DriverResponseHandler) invocation.getArgument(0); - var runSummary = mock(RunSummary.class); - given(runSummary.keys()).willReturn(Collections.emptyList()); - handler.onRunSummary(runSummary); - var pullSummary = mock(PullSummary.class); - given(pullSummary.metadata()).willReturn(Collections.emptyMap()); - handler.onPullSummary(pullSummary); - handler.onComplete(); - return CompletableFuture.completedStage(null); - }); + given(connection.writeAndFlush(any(), ArgumentMatchers.>argThat(argument -> { + if (argument.size() == 1) { + return argument.get(0) instanceof RunMessage; + } else if (argument.size() == 2) { + return argument.get(0) instanceof RunMessage && argument.get(1) instanceof PullMessage; + } else { + return false; + } + }))) + .willAnswer((Answer>) invocation -> { + var handler = (DriverResponseHandler) invocation.getArgument(0); + var runSummary = mock(RunSummary.class); + given(runSummary.keys()).willReturn(Collections.emptyList()); + handler.onRunSummary(runSummary); + if (((List) invocation.getArgument(1)).size() == 2) { + var pullSummary = mock(PullSummary.class); + given(pullSummary.metadata()).willReturn(Collections.emptyMap()); + handler.onPullSummary(pullSummary); + } + handler.onComplete(); + return CompletableFuture.completedStage(null); + }); } public static DriverBoltConnection connectionMock() { @@ -389,6 +448,7 @@ public static DriverBoltConnection connectionMock(BoltProtocolVersion protocolVe var connection = mock(DriverBoltConnection.class); when(connection.serverAddress()).thenReturn(BoltServerAddress.LOCAL_DEFAULT); when(connection.protocolVersion()).thenReturn(protocolVersion); + given(connection.valueFactory()).willReturn(mock(BoltValueFactory.class)); return connection; } diff --git a/pom.xml b/pom.xml index a196bd32c..d8b3f557c 100644 --- a/pom.xml +++ b/pom.xml @@ -31,7 +31,7 @@ true - 2.0.0 + 3.0-SNAPSHOT 1.0.4