diff --git a/spring-core/src/main/java/org/springframework/util/function/SingletonSupplier.java b/spring-core/src/main/java/org/springframework/util/function/SingletonSupplier.java index 5b53f023b6a7..ece76159e3f5 100644 --- a/spring-core/src/main/java/org/springframework/util/function/SingletonSupplier.java +++ b/spring-core/src/main/java/org/springframework/util/function/SingletonSupplier.java @@ -47,6 +47,7 @@ public class SingletonSupplier implements Supplier<@Nullable T> { private volatile @Nullable T singletonInstance; + private volatile boolean initialized; /** * Guards access to write operations on the {@code singletonInstance} field. */ @@ -63,6 +64,7 @@ public SingletonSupplier(@Nullable T instance, Supplier d this.instanceSupplier = null; this.defaultSupplier = defaultSupplier; this.singletonInstance = instance; + this.initialized = (instance != null); } /** @@ -85,6 +87,7 @@ private SingletonSupplier(@Nullable T singletonInstance) { this.instanceSupplier = null; this.defaultSupplier = null; this.singletonInstance = singletonInstance; + this.initialized = (singletonInstance != null); } @@ -94,26 +97,24 @@ private SingletonSupplier(@Nullable T singletonInstance) { */ @Override public @Nullable T get() { - T instance = this.singletonInstance; - if (instance == null) { + if (!this.initialized) { this.writeLock.lock(); try { - instance = this.singletonInstance; - if (instance == null) { + if (!this.initialized) { if (this.instanceSupplier != null) { - instance = this.instanceSupplier.get(); + this.singletonInstance = this.instanceSupplier.get(); } - if (instance == null && this.defaultSupplier != null) { - instance = this.defaultSupplier.get(); + if (this.singletonInstance == null && this.defaultSupplier != null) { + this.singletonInstance = this.defaultSupplier.get(); } - this.singletonInstance = instance; + this.initialized = true; } } finally { this.writeLock.unlock(); } } - return instance; + return this.singletonInstance; } /** diff --git a/spring-core/src/test/java/org/springframework/util/function/SingletonSupplierTests.java b/spring-core/src/test/java/org/springframework/util/function/SingletonSupplierTests.java new file mode 100644 index 000000000000..b87fb3c4538a --- /dev/null +++ b/spring-core/src/test/java/org/springframework/util/function/SingletonSupplierTests.java @@ -0,0 +1,190 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.util.function; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link SingletonSupplier}. + * + * @author Dmytro Nosan + */ +class SingletonSupplierTests { + + @Test + void shouldReturnDefaultWhenInstanceSupplierReturnsNull() { + SingletonSupplier singletonSupplier = new SingletonSupplier<>(() -> null, () -> "Default"); + assertThat(singletonSupplier.get()).isEqualTo("Default"); + } + + @Test + void shouldReturnNullForOfNullableWithNullInstance() { + SingletonSupplier singletonSupplier = SingletonSupplier.ofNullable((String) null); + assertThat(singletonSupplier).isNull(); + } + + @Test + void shouldReturnNullForOfNullableWithNullSupplier() { + SingletonSupplier singletonSupplier = SingletonSupplier.ofNullable((Supplier) null); + assertThat(singletonSupplier).isNull(); + } + + @Test + void shouldReturnNullWhenAllSuppliersReturnNull() { + SingletonSupplier singletonSupplier = new SingletonSupplier<>(() -> null, () -> null); + assertThat(singletonSupplier.get()).isNull(); + } + + @Test + void shouldReturnNullWhenNoInstanceOrDefaultSupplier() { + SingletonSupplier singletonSupplier = new SingletonSupplier<>((String) null, null); + assertThat(singletonSupplier.get()).isNull(); + } + + @Test + void shouldReturnSingletonInstanceOnMultipleCalls() { + SingletonSupplier singletonSupplier = SingletonSupplier.of("Hello"); + assertThat(singletonSupplier.get()).isEqualTo("Hello"); + assertThat(singletonSupplier.get()).isEqualTo("Hello"); + } + + + @Test + void shouldReturnSingletonInstanceOnMultipleSupplierCalls() { + SingletonSupplier singletonSupplier = SingletonSupplier.of(new HelloStringSupplier()); + assertThat(singletonSupplier.get()).isEqualTo("Hello 0"); + assertThat(singletonSupplier.get()).isEqualTo("Hello 0"); + } + + @Test + void shouldReturnSupplierForOfNullableWithNonNullInstance() { + SingletonSupplier singletonSupplier = SingletonSupplier.ofNullable("Hello"); + assertThat(singletonSupplier).isNotNull(); + assertThat(singletonSupplier.get()).isEqualTo("Hello"); + } + + @Test + void shouldReturnSupplierForOfNullableWithNonNullSupplier() { + SingletonSupplier singletonSupplier = SingletonSupplier.ofNullable(() -> "Hello"); + assertThat(singletonSupplier).isNotNull(); + assertThat(singletonSupplier.get()).isEqualTo("Hello"); + } + + @Test + void shouldThrowWhenObtainCalledAndNoInstanceAvailable() { + SingletonSupplier singletonSupplier = new SingletonSupplier<>((String) null, null); + assertThatThrownBy(singletonSupplier::obtain).isInstanceOf(IllegalStateException.class) + .hasMessage("No instance from Supplier"); + } + + @Test + void shouldUseDefaultSupplierWhenInstanceIsNull() { + SingletonSupplier singletonSupplier = new SingletonSupplier<>((String) null, () -> "defaultSupplier"); + assertThat(singletonSupplier.get()).isEqualTo("defaultSupplier"); + } + + @Test + void shouldUseDefaultSupplierWhenInstanceSupplierReturnsNull() { + SingletonSupplier singletonSupplier = new SingletonSupplier<>((Supplier) null, () -> "defaultSupplier"); + assertThat(singletonSupplier.get()).isEqualTo("defaultSupplier"); + } + + @Test + void shouldUseInstanceSupplierWhenProvidedAndIgnoreDefaultSupplier() { + AtomicInteger defaultValue = new AtomicInteger(); + SingletonSupplier singletonSupplier = new SingletonSupplier<>(() -> -1, defaultValue::incrementAndGet); + assertThat(singletonSupplier.get()).isEqualTo(-1); + assertThat(defaultValue.get()).isEqualTo(0); + } + + @Test + void shouldUseInstanceWhenProvidedAndIgnoreDefaultSupplier() { + AtomicInteger defaultValue = new AtomicInteger(); + SingletonSupplier singletonSupplier = new SingletonSupplier<>(-1, defaultValue::incrementAndGet); + assertThat(singletonSupplier.get()).isEqualTo(-1); + assertThat(defaultValue.get()).isEqualTo(0); + } + + @Test + void shouldReturnConsistentlyNullSingletonInstanceOnMultipleSupplierCalls() { + SingletonSupplier singletonSupplier = SingletonSupplier.of(new Supplier<>() { + + int count = 0; + + @Override + public String get() { + if (this.count++ == 0) { + return null; + } + return "Hello"; + } + }); + + assertThat(singletonSupplier.get()).isNull(); + assertThat(singletonSupplier.get()).isNull(); + } + + @RepeatedTest(100) + void shouldReturnSingletonInstanceOnMultipleConcurrentSupplierCalls() throws Exception { + int numberOfThreads = 4; + CountDownLatch ready = new CountDownLatch(numberOfThreads); + CountDownLatch start = new CountDownLatch(1); + List> futures = new ArrayList<>(); + SingletonSupplier singletonSupplier = SingletonSupplier.of(new HelloStringSupplier()); + ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads); + try { + for (int i = 0; i < numberOfThreads; i++) { + futures.add(executorService.submit(() -> { + ready.countDown(); + start.await(); + return singletonSupplier.obtain(); + })); + } + ready.await(); + start.countDown(); + assertThat(futures).extracting(Future::get).containsOnly("Hello 0"); + } + finally { + executorService.shutdown(); + } + } + + + private static final class HelloStringSupplier implements Supplier { + + private final AtomicInteger count = new AtomicInteger(); + + @Override + public String get() { + return "Hello " + this.count.getAndIncrement(); + } + } + +} diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java index d8baf9692447..fba4ada29b0f 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java @@ -209,6 +209,7 @@ void dataSourceInitialization() throws Exception { reset(dataSource); given(dataSource.getConnection()).willReturn(connection); + translator = new SQLErrorCodeSQLExceptionTranslator(dataSource); assertThat(translator.translate("test", null, duplicateKeyException)) .isInstanceOf(DuplicateKeyException.class);