diff --git a/src/integrationTest/java/com/mongodb/hibernate/BasicCrudIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/BasicCrudIntegrationTests.java index ddcbab15..953a1a34 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/BasicCrudIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/BasicCrudIntegrationTests.java @@ -16,14 +16,12 @@ package com.mongodb.hibernate; -import static com.mongodb.hibernate.MongoTestAssertions.assertEquals; +import static com.mongodb.hibernate.MongoTestAssertions.assertEq; import static org.assertj.core.api.Assertions.assertThat; import com.mongodb.client.MongoCollection; import com.mongodb.hibernate.junit.InjectMongoCollection; import com.mongodb.hibernate.junit.MongoExtension; -import jakarta.persistence.Column; -import jakarta.persistence.Embeddable; import jakarta.persistence.Entity; import jakarta.persistence.Id; import jakarta.persistence.Table; @@ -39,10 +37,7 @@ @SessionFactory(exportSchema = false) @DomainModel( - annotatedClasses = { - BasicCrudIntegrationTests.Book.class, - BasicCrudIntegrationTests.BookWithEmbeddedField.class, - BasicCrudIntegrationTests.BookDynamicallyUpdated.class + annotatedClasses = {BasicCrudIntegrationTests.Book.class, BasicCrudIntegrationTests.BookDynamicallyUpdated.class }) @ExtendWith(MongoExtension.class) class BasicCrudIntegrationTests implements SessionFactoryScopeAware { @@ -106,31 +101,6 @@ void testEntityWithNullFieldValueInsertion() { .formatted(author)); assertCollectionContainsExactly(expectedDocument); } - - @Test - void testEntityWithEmbeddedFieldInsertion() { - sessionFactoryScope.inTransaction(session -> { - var book = new BookWithEmbeddedField(); - book.id = 1; - book.title = "War and Peace"; - var author = new Author(); - author.firstName = "Leo"; - author.lastName = "Tolstoy"; - book.author = author; - book.publishYear = 1867; - session.persist(book); - }); - var expectedDocument = BsonDocument.parse( - """ - { - _id: 1, - title: "War and Peace", - authorFirstName: "Leo", - authorLastName: "Tolstoy", - publishYear: 1867 - }"""); - assertCollectionContainsExactly(expectedDocument); - } } @Nested @@ -219,7 +189,7 @@ void testFindByPrimaryKeyWithoutNullValueField() { sessionFactoryScope.inTransaction(session -> session.persist(book)); var loadedBook = sessionFactoryScope.fromTransaction(session -> session.find(Book.class, 1)); - assertEquals(book, loadedBook); + assertEq(book, loadedBook); } @Test @@ -236,7 +206,7 @@ void testFindByPrimaryKeyWithNullValueField() { sessionFactoryScope.inTransaction(session -> session.persist(book)); var loadedBook = sessionFactoryScope.fromTransaction(session -> session.find(Book.class, 1)); - assertEquals(book, loadedBook); + assertEq(book, loadedBook); } } @@ -270,27 +240,4 @@ static class BookDynamicallyUpdated { int publishYear; } - - @Entity - @Table(name = "books") - static class BookWithEmbeddedField { - @Id - int id; - - String title; - - Author author; - - int publishYear; - } - - @Embeddable - static class Author { - - @Column(name = "authorFirstName") - String firstName; - - @Column(name = "authorLastName") - String lastName; - } } diff --git a/src/integrationTest/java/com/mongodb/hibernate/MongoTestAssertions.java b/src/integrationTest/java/com/mongodb/hibernate/MongoTestAssertions.java index 1406357b..2f026ebd 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/MongoTestAssertions.java +++ b/src/integrationTest/java/com/mongodb/hibernate/MongoTestAssertions.java @@ -18,6 +18,8 @@ import static org.assertj.core.api.Assertions.assertThat; +import java.util.function.BiConsumer; +import org.assertj.core.api.RecursiveComparisonAssert; import org.jspecify.annotations.Nullable; public final class MongoTestAssertions { @@ -28,11 +30,19 @@ private MongoTestAssertions() {} * {@link org.junit.jupiter.api.Assertions#assertEquals(Object, Object)}. It should work even if * {@code expected}/{@code actual} does not override {@link Object#equals(Object)}. */ - public static void assertEquals(@Nullable Object expected, @Nullable Object actual) { - assertThat(actual) - .usingRecursiveComparison() - .usingOverriddenEquals() - .withStrictTypeChecking() - .isEqualTo(expected); + public static void assertEq(@Nullable Object expected, @Nullable Object actual) { + assertUsingRecursiveComparison(expected, actual, RecursiveComparisonAssert::isEqualTo); + } + + public static void assertUsingRecursiveComparison( + @Nullable Object expected, + @Nullable Object actual, + BiConsumer, Object> assertion) { + assertion.accept( + assertThat(expected) + .usingRecursiveComparison() + .usingOverriddenEquals() + .withStrictTypeChecking(), + actual); } } diff --git a/src/integrationTest/java/com/mongodb/hibernate/boot/FailedBootstrappingIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/boot/FailedBootstrappingIntegrationTests.java index f0e387ac..72eab997 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/boot/FailedBootstrappingIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/boot/FailedBootstrappingIntegrationTests.java @@ -23,7 +23,7 @@ import com.mongodb.hibernate.junit.MongoExtension; import org.bson.BsonDocument; import org.bson.BsonString; -import org.hibernate.cfg.Configuration; +import org.hibernate.boot.MetadataSources; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -50,7 +50,7 @@ void couldNotInstantiateDialect() { } } """))) { - new Configuration().buildSessionFactory().close(); + new MetadataSources().buildMetadata(); } }) .hasRootCause( diff --git a/src/integrationTest/java/com/mongodb/hibernate/embeddable/EmbeddableIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/embeddable/EmbeddableIntegrationTests.java new file mode 100644 index 00000000..93a8f35f --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/embeddable/EmbeddableIntegrationTests.java @@ -0,0 +1,242 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.embeddable; + +import static com.mongodb.hibernate.MongoTestAssertions.assertEq; +import static com.mongodb.hibernate.MongoTestAssertions.assertUsingRecursiveComparison; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.mongodb.client.MongoCollection; +import com.mongodb.hibernate.junit.InjectMongoCollection; +import com.mongodb.hibernate.junit.MongoExtension; +import jakarta.persistence.AccessType; +import jakarta.persistence.AttributeOverride; +import jakarta.persistence.Column; +import jakarta.persistence.Embeddable; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import org.bson.BsonDocument; +import org.hibernate.annotations.Parent; +import org.hibernate.boot.MetadataSources; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@SessionFactory(exportSchema = false) +@DomainModel( + annotatedClasses = { + EmbeddableIntegrationTests.ItemWithFlattenedValues.class, + EmbeddableIntegrationTests.ItemWithOmittedEmptyValue.class + }) +@ExtendWith(MongoExtension.class) +class EmbeddableIntegrationTests implements SessionFactoryScopeAware { + @InjectMongoCollection("items") + private static MongoCollection mongoCollection; + + private SessionFactoryScope sessionFactoryScope; + + @Test + void testFlattenedValues() { + var item = new ItemWithFlattenedValues(new Single(1), new Single(2), new PairWithParent(3, new Pair(4, 5))); + item.flattened2.parent = item; + sessionFactoryScope.inTransaction(session -> session.persist(item)); + assertCollectionContainsExactly( + """ + { + _id: 1, + flattened1_a: 2, + flattened2_a: 3, + flattened2_flattened_a: 4, + flattened2_flattened_b: 5 + } + """); + var loadedItem = sessionFactoryScope.fromTransaction( + session -> session.find(ItemWithFlattenedValues.class, item.flattenedId)); + assertEq(item, loadedItem); + var updatedItem = sessionFactoryScope.fromTransaction(session -> { + var result = session.find(ItemWithFlattenedValues.class, item.flattenedId); + result.flattened1.a = -result.flattened1.a; + return result; + }); + assertCollectionContainsExactly( + """ + { + _id: 1, + flattened1_a: -2, + flattened2_a: 3, + flattened2_flattened_a: 4, + flattened2_flattened_b: 5 + } + """); + loadedItem = sessionFactoryScope.fromTransaction( + session -> session.find(ItemWithFlattenedValues.class, updatedItem.flattenedId)); + assertEq(updatedItem, loadedItem); + } + + @Test + void testFlattenedEmptyValue() { + var item = new ItemWithOmittedEmptyValue(1, new Empty()); + sessionFactoryScope.inTransaction(session -> session.persist(item)); + assertCollectionContainsExactly( + // Hibernate ORM does not store/read the empty `item.omitted` value. + // See https://hibernate.atlassian.net/browse/HHH-11936 for more details. + """ + { + _id: 1 + } + """); + var loadedItem = + sessionFactoryScope.fromTransaction(session -> session.find(ItemWithOmittedEmptyValue.class, item.id)); + assertUsingRecursiveComparison(item, loadedItem, (assertion, actual) -> assertion + .ignoringFields("omitted") + .isEqualTo(actual)); + var updatedItem = sessionFactoryScope.fromTransaction(session -> { + var result = session.find(ItemWithOmittedEmptyValue.class, item.id); + result.omitted = null; + return result; + }); + assertCollectionContainsExactly( + """ + { + _id: 1 + } + """); + loadedItem = sessionFactoryScope.fromTransaction( + session -> session.find(ItemWithOmittedEmptyValue.class, updatedItem.id)); + assertEq(updatedItem, loadedItem); + } + + @Override + public void injectSessionFactoryScope(SessionFactoryScope sessionFactoryScope) { + this.sessionFactoryScope = sessionFactoryScope; + } + + private static void assertCollectionContainsExactly(String json) { + assertThat(mongoCollection.find()).containsExactly(BsonDocument.parse(json)); + } + + @Entity + @Table(name = "items") + static class ItemWithFlattenedValues { + @Id + Single flattenedId; + + @AttributeOverride(name = "a", column = @Column(name = "flattened1_a")) + Single flattened1; + + @AttributeOverride(name = "a", column = @Column(name = "flattened2_a")) + @AttributeOverride(name = "flattened.a", column = @Column(name = "flattened2_flattened_a")) + @AttributeOverride(name = "flattened.b", column = @Column(name = "flattened2_flattened_b")) + PairWithParent flattened2; + + ItemWithFlattenedValues() {} + + ItemWithFlattenedValues(Single flattenedId, Single flattened1, PairWithParent flattened2) { + this.flattenedId = flattenedId; + this.flattened1 = flattened1; + this.flattened2 = flattened2; + } + } + + @Embeddable + static class Single { + int a; + + Single() {} + + Single(int a) { + this.a = a; + } + } + + @Embeddable + static class PairWithParent { + int a; + Pair flattened; + + @Parent ItemWithFlattenedValues parent; + + PairWithParent() {} + + PairWithParent(int a, Pair flattened) { + this.a = a; + this.flattened = flattened; + } + + /** + * Hibernate ORM requires a getter for a {@link Parent} field, despite us using {@linkplain AccessType#FIELD + * field-based access}. + */ + void setParent(ItemWithFlattenedValues parent) { + this.parent = parent; + } + + /** + * Hibernate ORM requires a getter for a {@link Parent} field, despite us using {@linkplain AccessType#FIELD + * field-based access}. + */ + ItemWithFlattenedValues getParent() { + return parent; + } + } + + @Embeddable + record Pair(int a, int b) {} + + @Entity + @Table(name = "items") + static class ItemWithOmittedEmptyValue { + @Id + int id; + + Empty omitted; + + ItemWithOmittedEmptyValue() {} + + ItemWithOmittedEmptyValue(int id, Empty omitted) { + this.id = id; + this.omitted = omitted; + } + } + + @Embeddable + static class Empty {} + + @Nested + class Unsupported { + @Test + void testPrimaryKeySpanningMultipleFields() { + assertThatThrownBy(() -> new MetadataSources() + .addAnnotatedClass(ItemWithPairAsId.class) + .buildMetadata()) + .hasMessageContaining("does not support primary key spanning multiple columns"); + } + + @Entity + @Table(name = "items") + static class ItemWithPairAsId { + @Id + Pair id; + } + } +} diff --git a/src/integrationTest/java/com/mongodb/hibernate/embeddable/StructAggregateEmbeddableIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/embeddable/StructAggregateEmbeddableIntegrationTests.java new file mode 100644 index 00000000..d5c3c38e --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/embeddable/StructAggregateEmbeddableIntegrationTests.java @@ -0,0 +1,340 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.embeddable; + +import static com.mongodb.hibernate.MongoTestAssertions.assertEq; +import static com.mongodb.hibernate.MongoTestAssertions.assertUsingRecursiveComparison; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.mongodb.client.MongoCollection; +import com.mongodb.hibernate.junit.InjectMongoCollection; +import com.mongodb.hibernate.junit.MongoExtension; +import jakarta.persistence.AccessType; +import jakarta.persistence.Column; +import jakarta.persistence.Embeddable; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import org.bson.BsonDocument; +import org.hibernate.annotations.Parent; +import org.hibernate.annotations.Struct; +import org.hibernate.boot.MetadataSources; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SessionFactoryScopeAware; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@SessionFactory(exportSchema = false) +@DomainModel( + annotatedClasses = { + StructAggregateEmbeddableIntegrationTests.ItemWithNestedValues.class, + StructAggregateEmbeddableIntegrationTests.ItemWithOmittedEmptyValue.class, + StructAggregateEmbeddableIntegrationTests.Unsupported.ItemWithNestedValueHavingNonInsertable.class, + StructAggregateEmbeddableIntegrationTests.Unsupported.ItemWithNestedValueHavingAllNonInsertable.class, + StructAggregateEmbeddableIntegrationTests.Unsupported.ItemWithNestedValueHavingNonUpdatable.class + }) +@ExtendWith(MongoExtension.class) +class StructAggregateEmbeddableIntegrationTests implements SessionFactoryScopeAware { + @InjectMongoCollection("items") + private static MongoCollection mongoCollection; + + private SessionFactoryScope sessionFactoryScope; + + @Test + void testNestedValues() { + var item = new ItemWithNestedValues(new Single(1), new Single(2), new PairWithParent(3, new Pair(4, 5))); + item.nested2.parent = item; + sessionFactoryScope.inTransaction(session -> session.persist(item)); + assertCollectionContainsExactly( + // Hibernate ORM flattens `item.id` despite it being of an aggregate type + """ + { + _id: 1, + nested1: { + a: 2 + }, + nested2: { + a: 3, + nested: { + a: 4, + b: 5 + } + } + } + """); + var loadedItem = + sessionFactoryScope.fromTransaction(session -> session.find(ItemWithNestedValues.class, item.nestedId)); + assertEq(item, loadedItem); + var updatedItem = sessionFactoryScope.fromTransaction(session -> { + var result = session.find(ItemWithNestedValues.class, item.nestedId); + result.nested1.a = -result.nested1.a; + return result; + }); + assertCollectionContainsExactly( + """ + { + _id: 1, + nested1: { + a: -2 + }, + nested2: { + a: 3, + nested: { + a: 4, + b: 5 + } + } + } + """); + loadedItem = sessionFactoryScope.fromTransaction( + session -> session.find(ItemWithNestedValues.class, updatedItem.nestedId)); + assertEq(updatedItem, loadedItem); + } + + @Test + void testNestedEmptyValue() { + var item = new ItemWithOmittedEmptyValue(1, new Empty()); + sessionFactoryScope.inTransaction(session -> session.persist(item)); + assertCollectionContainsExactly( + // Hibernate ORM does not store/read the empty `item.omitted` value. + // See https://hibernate.atlassian.net/browse/HHH-11936 for more details. + """ + { + _id: 1 + } + """); + var loadedItem = + sessionFactoryScope.fromTransaction(session -> session.find(ItemWithOmittedEmptyValue.class, item.id)); + assertUsingRecursiveComparison(item, loadedItem, (assertion, actual) -> assertion + .ignoringFields("omitted") + .isEqualTo(actual)); + var updatedItem = sessionFactoryScope.fromTransaction(session -> { + var result = session.find(ItemWithOmittedEmptyValue.class, item.id); + result.omitted = null; + return result; + }); + assertCollectionContainsExactly( + """ + { + _id: 1 + } + """); + loadedItem = sessionFactoryScope.fromTransaction( + session -> session.find(ItemWithOmittedEmptyValue.class, updatedItem.id)); + assertEq(updatedItem, loadedItem); + } + + @Override + public void injectSessionFactoryScope(SessionFactoryScope sessionFactoryScope) { + this.sessionFactoryScope = sessionFactoryScope; + } + + private static void assertCollectionContainsExactly(String json) { + assertThat(mongoCollection.find()).containsExactly(BsonDocument.parse(json)); + } + + @Entity + @Table(name = "items") + static class ItemWithNestedValues { + @Id + Single nestedId; + + Single nested1; + + PairWithParent nested2; + + ItemWithNestedValues() {} + + ItemWithNestedValues(Single nestedId, Single nested1, PairWithParent nested2) { + this.nestedId = nestedId; + this.nested1 = nested1; + this.nested2 = nested2; + } + } + + @Embeddable + @Struct(name = "Single") + static class Single { + int a; + + Single() {} + + Single(int a) { + this.a = a; + } + } + + @Embeddable + @Struct(name = "PairWithParent") + static class PairWithParent { + int a; + Pair nested; + + @Parent ItemWithNestedValues parent; + + PairWithParent() {} + + PairWithParent(int a, Pair nested) { + this.a = a; + this.nested = nested; + } + + /** + * Hibernate ORM requires a getter for a {@link Parent} field, despite us using {@linkplain AccessType#FIELD + * field-based access}. + */ + void setParent(ItemWithNestedValues parent) { + this.parent = parent; + } + + /** + * Hibernate ORM requires a getter for a {@link Parent} field, despite us using {@linkplain AccessType#FIELD + * field-based access}. + */ + ItemWithNestedValues getParent() { + return parent; + } + } + + @Embeddable + @Struct(name = "Pair") + record Pair(int a, int b) {} + + @Entity + @Table(name = "items") + static class ItemWithOmittedEmptyValue { + @Id + int id; + + Empty omitted; + + ItemWithOmittedEmptyValue() {} + + ItemWithOmittedEmptyValue(int id, Empty omitted) { + this.id = id; + this.omitted = omitted; + } + } + + @Embeddable + @Struct(name = "Empty") + static class Empty {} + + @Nested + class Unsupported { + @Test + void testPrimaryKeySpanningMultipleFields() { + assertThatThrownBy(() -> new MetadataSources() + .addAnnotatedClass(ItemWithPairAsId.class) + .buildMetadata()) + .hasMessageContaining("does not support primary key spanning multiple columns"); + } + + @Test + void testNonInsertable() { + var item = new ItemWithNestedValueHavingNonInsertable(1, new PairHavingNonInsertable(2, 3)); + assertThatThrownBy(() -> sessionFactoryScope.inTransaction(session -> session.persist(item))) + .hasMessageContaining("must be insertable"); + } + + @Test + void testAllNonInsertable() { + var item = new ItemWithNestedValueHavingAllNonInsertable(1, new PairAllNonInsertable(2, 3)); + sessionFactoryScope.inTransaction(session -> session.persist(item)); + assertCollectionContainsExactly( + // `item.omitted` is considered empty because all its persistent attributes are non-insertable. + // Hibernate ORM does not store/read the empty `item.omitted` value. + // See https://hibernate.atlassian.net/browse/HHH-11936 for more details. + """ + { + _id: 1 + } + """); + assertThatThrownBy(() -> sessionFactoryScope.fromTransaction( + session -> session.find(ItemWithNestedValueHavingAllNonInsertable.class, item.id))) + .isInstanceOf(Exception.class); + } + + @Test + void testNonUpdatable() { + sessionFactoryScope.inTransaction(session -> { + var item = new ItemWithNestedValueHavingNonUpdatable(1, new PairHavingNonUpdatable(2, 3)); + session.persist(item); + assertThatThrownBy(session::flush).hasMessageContaining("must be updatable"); + }); + } + + @Entity + @Table(name = "items") + static class ItemWithPairAsId { + @Id + Pair id; + } + + @Entity + @Table(name = "items") + record ItemWithNestedValueHavingNonInsertable(@Id int id, PairHavingNonInsertable nested) {} + + @Embeddable + @Struct(name = "PairHavingNonInsertable") + record PairHavingNonInsertable(@Column(insertable = false) int a, int b) {} + + @Entity + @Table(name = "items") + record ItemWithNestedValueHavingNonUpdatable(@Id int id, PairHavingNonUpdatable nested) {} + + @Embeddable + @Struct(name = "PairHavingNonUpdatable") + static class PairHavingNonUpdatable { + @Column(updatable = false) + int a; + + int b; + + PairHavingNonUpdatable() {} + + PairHavingNonUpdatable(int a, int b) { + this.a = a; + this.b = b; + } + } + + @Entity + @Table(name = "items") + static class ItemWithNestedValueHavingAllNonInsertable { + @Id + int id; + + PairAllNonInsertable omitted; + + ItemWithNestedValueHavingAllNonInsertable() {} + + ItemWithNestedValueHavingAllNonInsertable(int id, PairAllNonInsertable omitted) { + this.id = id; + this.omitted = omitted; + } + } + + @Embeddable + @Struct(name = "PairAllNonInsertable") + record PairAllNonInsertable(@Column(insertable = false) int a, @Column(insertable = false) int b) {} + } +} diff --git a/src/integrationTest/java/com/mongodb/hibernate/type/ObjectIdIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/type/ObjectIdIntegrationTests.java index c9035a36..40be8e2a 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/type/ObjectIdIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/type/ObjectIdIntegrationTests.java @@ -16,7 +16,7 @@ package com.mongodb.hibernate.type; -import static com.mongodb.hibernate.MongoTestAssertions.assertEquals; +import static com.mongodb.hibernate.MongoTestAssertions.assertEq; import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -91,7 +91,7 @@ void findById() { item.vExplicitlyAnnotatedNotForThePublic = new ObjectId(3, 4); sessionFactoryScope.inTransaction(session -> session.persist(item)); var loadedItem = sessionFactoryScope.fromTransaction(session -> session.find(Item.class, item.id)); - assertEquals(item, loadedItem); + assertEq(item, loadedItem); } @Override @@ -131,7 +131,7 @@ void assignedValue() { item.id = 1; item.v = v; sessionFactoryScope.inTransaction(session -> session.persist(item)); - assertEquals(v, item.v); + assertEq(v, item.v); } } diff --git a/src/main/java/com/mongodb/hibernate/dialect/MongoAggregateSupport.java b/src/main/java/com/mongodb/hibernate/dialect/MongoAggregateSupport.java new file mode 100644 index 00000000..ea2457f1 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/dialect/MongoAggregateSupport.java @@ -0,0 +1,71 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.dialect; + +import static java.lang.String.format; + +import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import java.sql.JDBCType; +import org.hibernate.dialect.aggregate.AggregateSupportImpl; +import org.hibernate.mapping.AggregateColumn; +import org.hibernate.mapping.Column; + +final class MongoAggregateSupport extends AggregateSupportImpl { + static final MongoAggregateSupport INSTANCE = new MongoAggregateSupport(); + + private MongoAggregateSupport() {} + + @Override + public String aggregateComponentCustomReadExpression( + String template, + String placeholder, + String aggregateParentReadExpression, + String columnExpression, + AggregateColumn aggregateColumn, + Column column) { + var aggregateColumnType = aggregateColumn.getTypeCode(); + if (aggregateColumnType == JDBCType.STRUCT.getVendorTypeNumber()) { + return format( + "unused from %s.aggregateComponentCustomReadExpression", + MongoAggregateSupport.class.getSimpleName()); + } + throw new FeatureNotSupportedException(format("The SQL type code [%d] is not supported", aggregateColumnType)); + } + + @Override + public String aggregateComponentAssignmentExpression( + String aggregateParentAssignmentExpression, + String columnExpression, + AggregateColumn aggregateColumn, + Column column) { + var aggregateColumnType = aggregateColumn.getTypeCode(); + if (aggregateColumnType == JDBCType.STRUCT.getVendorTypeNumber()) { + return format( + "unused from %s.aggregateComponentAssignmentExpression", + MongoAggregateSupport.class.getSimpleName()); + } + throw new FeatureNotSupportedException(format("The SQL type code [%d] is not supported", aggregateColumnType)); + } + + @Override + public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) { + if (aggregateSqlTypeCode == JDBCType.STRUCT.getVendorTypeNumber()) { + return false; + } + throw new FeatureNotSupportedException(format("The SQL type code [%d] is not supported", aggregateSqlTypeCode)); + } +} diff --git a/src/main/java/com/mongodb/hibernate/dialect/MongoDialect.java b/src/main/java/com/mongodb/hibernate/dialect/MongoDialect.java index d3015f9b..4aee1d37 100644 --- a/src/main/java/com/mongodb/hibernate/dialect/MongoDialect.java +++ b/src/main/java/com/mongodb/hibernate/dialect/MongoDialect.java @@ -20,12 +20,14 @@ import static java.lang.String.format; import com.mongodb.hibernate.internal.translate.MongoTranslatorFactory; +import com.mongodb.hibernate.internal.type.MongoStructJdbcType; import com.mongodb.hibernate.internal.type.ObjectIdJavaType; import com.mongodb.hibernate.internal.type.ObjectIdJdbcType; import com.mongodb.hibernate.jdbc.MongoConnectionProvider; import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.DatabaseVersion; import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.aggregate.AggregateSupport; import org.hibernate.engine.jdbc.dialect.spi.DialectResolutionInfo; import org.hibernate.service.ServiceRegistry; import org.hibernate.sql.ast.SqlAstTranslatorFactory; @@ -91,10 +93,16 @@ public void contribute(TypeContributions typeContributions, ServiceRegistry serv super.contribute(typeContributions, serviceRegistry); typeContributions.contributeJavaType(ObjectIdJavaType.INSTANCE); typeContributions.contributeJdbcType(ObjectIdJdbcType.INSTANCE); + typeContributions.contributeJdbcType(MongoStructJdbcType.INSTANCE); } @Override public @Nullable String toQuotedIdentifier(@Nullable String name) { return name; } + + @Override + public AggregateSupport getAggregateSupport() { + return MongoAggregateSupport.INSTANCE; + } } diff --git a/src/main/java/com/mongodb/hibernate/internal/FeatureNotSupportedException.java b/src/main/java/com/mongodb/hibernate/internal/FeatureNotSupportedException.java index 6654e461..dc63e46a 100644 --- a/src/main/java/com/mongodb/hibernate/internal/FeatureNotSupportedException.java +++ b/src/main/java/com/mongodb/hibernate/internal/FeatureNotSupportedException.java @@ -17,6 +17,7 @@ package com.mongodb.hibernate.internal; import java.io.Serial; +import java.sql.SQLFeatureNotSupportedException; public final class FeatureNotSupportedException extends RuntimeException { @@ -38,4 +39,8 @@ public FeatureNotSupportedException() {} public FeatureNotSupportedException(String message) { super(message); } + + public FeatureNotSupportedException(SQLFeatureNotSupportedException cause) { + super(cause); + } } diff --git a/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java b/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java index 13c55f2b..e5f37ef5 100644 --- a/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java +++ b/src/main/java/com/mongodb/hibernate/internal/extension/MongoAdditionalMappingContributor.java @@ -83,8 +83,8 @@ private static void setIdentifierColumnName(PersistentClass persistentClass) { var idColumns = identifier.getColumns(); if (idColumns.size() > 1) { throw new FeatureNotSupportedException(format( - "%s: %s does not support [%s] field spanning multiple columns %s", - persistentClass, MONGO_DBMS_NAME, ID_FIELD_NAME, idColumns)); + "%s: %s does not support primary key spanning multiple columns %s", + persistentClass, MONGO_DBMS_NAME, idColumns)); } assertTrue(idColumns.size() == 1); var idColumn = idColumns.get(0); diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java b/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java index 5ccf118c..8887f726 100644 --- a/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java +++ b/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java @@ -18,7 +18,6 @@ import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; -import static com.mongodb.hibernate.internal.MongoConstants.ID_FIELD_NAME; import static com.mongodb.hibernate.internal.MongoConstants.MONGO_DBMS_NAME; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_AGGREGATE; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_MUTATION; @@ -63,24 +62,18 @@ import com.mongodb.hibernate.internal.translate.mongoast.filter.AstFilter; import com.mongodb.hibernate.internal.translate.mongoast.filter.AstFilterFieldPath; import com.mongodb.hibernate.internal.translate.mongoast.filter.AstLogicalFilter; +import com.mongodb.hibernate.internal.type.ValueConverter; import java.io.IOException; import java.io.StringWriter; -import java.math.BigDecimal; +import java.sql.SQLFeatureNotSupportedException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; -import org.bson.BsonBoolean; -import org.bson.BsonDecimal128; -import org.bson.BsonDouble; -import org.bson.BsonInt32; -import org.bson.BsonInt64; -import org.bson.BsonString; import org.bson.BsonValue; import org.bson.json.JsonMode; import org.bson.json.JsonWriter; import org.bson.json.JsonWriterSettings; -import org.bson.types.Decimal128; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.internal.util.collections.Stack; import org.hibernate.persister.entity.EntityPersister; @@ -328,7 +321,7 @@ private AstFilter getKeyFilter(AbstractRestrictedTableMutation 1) { throw new FeatureNotSupportedException( - format("%s does not support '%s' spanning multiple columns", MONGO_DBMS_NAME, ID_FIELD_NAME)); + format("%s does not support primary key spanning multiple columns", MONGO_DBMS_NAME)); } assertTrue(tableMutation.getNumberOfKeyBindings() == 1); var keyBinding = tableMutation.getKeyBindings().get(0); @@ -469,15 +462,18 @@ public void visitSelectClause(SelectClause selectClause) { @Override public void visitColumnReference(ColumnReference columnReference) { if (columnReference.isColumnExpressionFormula()) { - throw new FeatureNotSupportedException(); + throw new FeatureNotSupportedException("Formulas are not supported"); } astVisitorValueHolder.yield(FIELD_PATH, columnReference.getColumnExpression()); } @Override public void visitQueryLiteral(QueryLiteral queryLiteral) { - var bsonValue = toBsonValue(queryLiteral.getLiteralValue()); - astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(bsonValue)); + var literalValue = queryLiteral.getLiteralValue(); + if (literalValue == null) { + throw new FeatureNotSupportedException("TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74"); + } + astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(toBsonValue(literalValue))); } @Override @@ -496,8 +492,8 @@ public void visitJunction(Junction junction) { @Override public void visitUnparsedNumericLiteral(UnparsedNumericLiteral unparsedNumericLiteral) { - astVisitorValueHolder.yield( - FIELD_VALUE, new AstLiteralValue(toBsonValue(unparsedNumericLiteral.getLiteralValue()))); + var literalValue = assertNotNull(unparsedNumericLiteral.getLiteralValue()); + astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(toBsonValue(literalValue))); } @Override @@ -889,28 +885,11 @@ private static boolean isComparingFieldWithValue(ComparisonPredicate comparisonP || (isFieldPathExpression(rhs) && isValueExpression(lhs)); } - private static BsonValue toBsonValue(@Nullable Object queryLiteral) { - if (queryLiteral == null) { - throw new FeatureNotSupportedException("TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74"); - } - if (queryLiteral instanceof Boolean boolValue) { - return BsonBoolean.valueOf(boolValue); - } - if (queryLiteral instanceof Integer intValue) { - return new BsonInt32(intValue); - } - if (queryLiteral instanceof Long longValue) { - return new BsonInt64(longValue); - } - if (queryLiteral instanceof Double doubleValue) { - return new BsonDouble(doubleValue); - } - if (queryLiteral instanceof BigDecimal bigDecimalValue) { - return new BsonDecimal128(new Decimal128(bigDecimalValue)); - } - if (queryLiteral instanceof String stringValue) { - return new BsonString(stringValue); + private static BsonValue toBsonValue(Object value) { + try { + return ValueConverter.toBsonValue(value); + } catch (SQLFeatureNotSupportedException e) { + throw new FeatureNotSupportedException(e); } - throw new FeatureNotSupportedException("Unsupported Java type: " + queryLiteral.getClass()); } } diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java b/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java index 071c8fdf..b940323d 100644 --- a/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java +++ b/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java @@ -23,6 +23,7 @@ import org.hibernate.query.spi.QueryOptions; import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.model.ast.TableMutation; +import org.hibernate.sql.model.internal.TableUpdateNoSet; import org.hibernate.sql.model.jdbc.JdbcMutationOperation; import org.jspecify.annotations.Nullable; @@ -40,7 +41,13 @@ public O translate(@Nullable JdbcParameterBindings jdbcParameterBindings, QueryO assertNull(jdbcParameterBindings); checkQueryOptionsSupportability(queryOptions); - var mutationCommand = acceptAndYield(tableMutation, COLLECTION_MUTATION); - return tableMutation.createMutationOperation(renderMongoAstNode(mutationCommand), getParameterBinders()); + String mql; + if ((TableMutation) tableMutation instanceof TableUpdateNoSet) { + mql = ""; + } else { + var mutationCommand = acceptAndYield(tableMutation, COLLECTION_MUTATION); + mql = renderMongoAstNode(mutationCommand); + } + return tableMutation.createMutationOperation(mql, getParameterBinders()); } } diff --git a/src/main/java/com/mongodb/hibernate/internal/type/MongoStructJdbcType.java b/src/main/java/com/mongodb/hibernate/internal/type/MongoStructJdbcType.java new file mode 100644 index 00000000..20b8cd46 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/internal/type/MongoStructJdbcType.java @@ -0,0 +1,218 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.internal.type; + +import static com.mongodb.hibernate.internal.MongoAssertions.assertFalse; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; +import static com.mongodb.hibernate.internal.MongoAssertions.fail; +import static com.mongodb.hibernate.internal.type.ValueConverter.toBsonValue; +import static com.mongodb.hibernate.internal.type.ValueConverter.toDomainValue; + +import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import java.io.Serial; +import java.sql.CallableStatement; +import java.sql.JDBCType; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import org.bson.BsonDocument; +import org.bson.BsonValue; +import org.hibernate.annotations.Struct; +import org.hibernate.metamodel.mapping.EmbeddableMappingType; +import org.hibernate.metamodel.spi.RuntimeModelCreationContext; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.AggregateJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.StructJdbcType; +import org.jspecify.annotations.Nullable; + +/** Thread-safe. */ +public final class MongoStructJdbcType implements StructJdbcType { + @Serial + private static final long serialVersionUID = 1L; + + public static final MongoStructJdbcType INSTANCE = new MongoStructJdbcType(); + private static final JDBCType JDBC_TYPE = JDBCType.STRUCT; + + private final @Nullable EmbeddableMappingType embeddableMappingType; + + private final @Nullable String structTypeName; + + private MongoStructJdbcType() { + this(null, null); + } + + private MongoStructJdbcType( + @Nullable EmbeddableMappingType embeddableMappingType, @Nullable String structTypeName) { + this.embeddableMappingType = embeddableMappingType; + this.structTypeName = structTypeName; + } + + @Override + public String getStructTypeName() { + return assertNotNull(structTypeName); + } + + /** + * This method may be called multiple times with equal {@code sqlType} and different {@code mappingType}. + * + * @param sqlType The {@link Struct#name()}. + */ + @Override + public AggregateJdbcType resolveAggregateJdbcType( + EmbeddableMappingType mappingType, String sqlType, RuntimeModelCreationContext creationContext) { + return new MongoStructJdbcType(mappingType, sqlType); + } + + @Override + public EmbeddableMappingType getEmbeddableMappingType() { + return assertNotNull(embeddableMappingType); + } + + @Override + public BsonDocument createJdbcValue(Object domainValue, WrapperOptions options) throws SQLException { + var embeddableMappingType = assertNotNull(this.embeddableMappingType); + if (embeddableMappingType.isPolymorphic()) { + throw new FeatureNotSupportedException("Polymorphic mapping is not supported"); + } + var result = new BsonDocument(); + var jdbcValueCount = embeddableMappingType.getJdbcValueCount(); + for (int columnIndex = 0; columnIndex < jdbcValueCount; columnIndex++) { + var jdbcValueSelectable = embeddableMappingType.getJdbcValueSelectable(columnIndex); + assertFalse(jdbcValueSelectable.isFormula()); + if (!jdbcValueSelectable.isInsertable()) { + throw new FeatureNotSupportedException( + "Persistent attributes of a `@Struct @Embeddable` must be insertable"); + } + if (!jdbcValueSelectable.isUpdateable()) { + throw new FeatureNotSupportedException( + "Persistent attributes of a `@Struct @Embeddable` must be updatable"); + } + var fieldName = jdbcValueSelectable.getSelectableName(); + var value = embeddableMappingType.getValue(domainValue, columnIndex); + if (value == null) { + throw new FeatureNotSupportedException( + "TODO-HIBERNATE-48 https://jira.mongodb.org/browse/HIBERNATE-48"); + } + BsonValue bsonValue; + var jdbcMapping = jdbcValueSelectable.getJdbcMapping(); + if (jdbcMapping.getJdbcType().getJdbcTypeCode() == JDBC_TYPE.getVendorTypeNumber()) { + if (!(jdbcMapping.getJdbcValueBinder() instanceof Binder structValueBinder)) { + throw fail(); + } + if (!(structValueBinder.getJdbcType() instanceof MongoStructJdbcType structJdbcType)) { + throw fail(); + } + bsonValue = structJdbcType.createJdbcValue(value, options); + } else { + bsonValue = toBsonValue(value); + } + result.append(fieldName, bsonValue); + } + return result; + } + + @Override + public Object[] extractJdbcValues(Object rawJdbcValue, WrapperOptions options) throws SQLException { + if (!(rawJdbcValue instanceof BsonDocument bsonDocument)) { + throw fail(); + } + var result = new Object[bsonDocument.size()]; + var elementIdx = 0; + for (var value : bsonDocument.values()) { + assertNotNull(value); + result[elementIdx++] = + value instanceof BsonDocument ? extractJdbcValues(value, options) : toDomainValue(value); + } + return result; + } + + @Override + public int getJdbcTypeCode() { + return JDBC_TYPE.getVendorTypeNumber(); + } + + @Override + public ValueBinder getBinder(JavaType javaType) { + return new Binder<>(javaType); + } + + @Override + public ValueExtractor getExtractor(JavaType javaType) { + return new Extractor<>(javaType); + } + + /** Thread-safe. */ + private final class Binder extends BasicBinder { + @Serial + private static final long serialVersionUID = 1L; + + Binder(JavaType javaType) { + super(javaType, MongoStructJdbcType.this); + } + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + if (!(getJdbcType() instanceof MongoStructJdbcType structJdbcType)) { + throw fail(); + } + st.setObject(index, structJdbcType.createJdbcValue(value, options), structJdbcType.getJdbcTypeCode()); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) throws SQLException { + throw new SQLFeatureNotSupportedException(); + } + } + + /** Thread-safe. */ + private final class Extractor extends BasicExtractor { + @Serial + private static final long serialVersionUID = 1L; + + Extractor(JavaType javaType) { + super(javaType, MongoStructJdbcType.this); + } + + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if (!(getJdbcType() instanceof MongoStructJdbcType structJdbcType)) { + throw fail(); + } + var classX = getJavaType().getJavaTypeClass(); + assertTrue(classX.equals(Object[].class)); + var bsonDocument = rs.getObject(paramIndex, BsonDocument.class); + return classX.cast(structJdbcType.extractJdbcValues(bsonDocument, options)); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + throw new SQLFeatureNotSupportedException(); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + throw new SQLFeatureNotSupportedException(); + } + } +} diff --git a/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJavaType.java b/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJavaType.java index c19db47f..12d5b847 100644 --- a/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJavaType.java +++ b/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJavaType.java @@ -25,6 +25,7 @@ import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators; +/** Thread-safe. */ public final class ObjectIdJavaType extends AbstractClassJavaType { @Serial private static final long serialVersionUID = 1L; @@ -49,10 +50,10 @@ public X unwrap(ObjectId value, Class type, WrapperOptions options) { @Override public ObjectId wrap(X value, WrapperOptions options) { - if (value instanceof ObjectId wrapped) { - return wrapped; + if (!(value instanceof ObjectId wrapped)) { + throw new FeatureNotSupportedException(); } - throw new FeatureNotSupportedException(); + return wrapped; } @Override diff --git a/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJdbcType.java b/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJdbcType.java index ac1a13da..177196d0 100644 --- a/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJdbcType.java +++ b/src/main/java/com/mongodb/hibernate/internal/type/ObjectIdJdbcType.java @@ -32,6 +32,7 @@ import org.hibernate.type.descriptor.jdbc.BasicExtractor; import org.hibernate.type.descriptor.jdbc.JdbcType; +/** Thread-safe. */ public final class ObjectIdJdbcType implements JdbcType { @Serial private static final long serialVersionUID = 1L; @@ -58,7 +59,7 @@ public ValueBinder getBinder(JavaType javaType) { throw new FeatureNotSupportedException(); } @SuppressWarnings("unchecked") - var result = (ValueBinder) Binder.INSTANCE; + var result = (ValueBinder) new Binder(JAVA_TYPE); return result; } @@ -68,19 +69,17 @@ public ValueExtractor getExtractor(JavaType javaType) { throw new FeatureNotSupportedException(); } @SuppressWarnings("unchecked") - var result = (ValueExtractor) Extractor.INSTANCE; + var result = (ValueExtractor) new Extractor(JAVA_TYPE); return result; } /** Thread-safe. */ - private static final class Binder extends BasicBinder { + private final class Binder extends BasicBinder { @Serial private static final long serialVersionUID = 1L; - static final Binder INSTANCE = new Binder(); - - private Binder() { - super(JAVA_TYPE, ObjectIdJdbcType.INSTANCE); + private Binder(JavaType javaType) { + super(javaType, ObjectIdJdbcType.this); } @Override @@ -97,14 +96,12 @@ protected void doBind(CallableStatement st, ObjectId value, String name, Wrapper } /** Thread-safe. */ - private static final class Extractor extends BasicExtractor { + private final class Extractor extends BasicExtractor { @Serial private static final long serialVersionUID = 1L; - static final Extractor INSTANCE = new Extractor(); - - private Extractor() { - super(JAVA_TYPE, ObjectIdJdbcType.INSTANCE); + private Extractor(JavaType javaType) { + super(javaType, ObjectIdJdbcType.this); } @Override diff --git a/src/main/java/com/mongodb/hibernate/internal/type/ValueConverter.java b/src/main/java/com/mongodb/hibernate/internal/type/ValueConverter.java new file mode 100644 index 00000000..3e24a06a --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/internal/type/ValueConverter.java @@ -0,0 +1,191 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.internal.type; + +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static java.lang.String.format; + +import java.math.BigDecimal; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLFeatureNotSupportedException; +import org.bson.BsonBinary; +import org.bson.BsonBoolean; +import org.bson.BsonDecimal128; +import org.bson.BsonDouble; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonObjectId; +import org.bson.BsonString; +import org.bson.BsonValue; +import org.bson.types.Decimal128; +import org.bson.types.ObjectId; + +/** + * Provides conversion methods between {@link BsonValue}s, which our {@link PreparedStatement}/{@link ResultSet} + * implementation uses under the hood and rarely exposes, and domain values we usually use when setting parameter values + * on our {@link PreparedStatement}, or retrieving column values from a {@link ResultSet}. + */ +public final class ValueConverter { + private ValueConverter() {} + + public static BsonValue toBsonValue(Object value) throws SQLFeatureNotSupportedException { + assertNotNull(value); + if (value instanceof Boolean v) { + return toBsonValue(v.booleanValue()); + } else if (value instanceof Integer v) { + return toBsonValue(v.intValue()); + } else if (value instanceof Long v) { + return toBsonValue(v.longValue()); + } else if (value instanceof Double v) { + return toBsonValue(v.doubleValue()); + } else if (value instanceof BigDecimal v) { + return toBsonValue(v); + } else if (value instanceof String v) { + return toBsonValue(v); + } else if (value instanceof byte[] v) { + return toBsonValue(v); + } else if (value instanceof ObjectId v) { + return toBsonValue(v); + } else { + throw new SQLFeatureNotSupportedException(format( + "Value [%s] of type [%s] is not supported", + value, value.getClass().getTypeName())); + } + } + + public static BsonBoolean toBsonValue(boolean value) { + return BsonBoolean.valueOf(value); + } + + public static BsonInt32 toBsonValue(int value) { + return new BsonInt32(value); + } + + public static BsonInt64 toBsonValue(long value) { + return new BsonInt64(value); + } + + public static BsonDouble toBsonValue(double value) { + return new BsonDouble(value); + } + + public static BsonDecimal128 toBsonValue(BigDecimal value) { + return new BsonDecimal128(new Decimal128(value)); + } + + public static BsonString toBsonValue(String value) { + return new BsonString(value); + } + + public static BsonBinary toBsonValue(byte[] value) { + return new BsonBinary(value); + } + + public static BsonObjectId toBsonValue(ObjectId value) { + return new BsonObjectId(value); + } + + static Object toDomainValue(BsonValue value) throws SQLFeatureNotSupportedException { + assertNotNull(value); + if (value instanceof BsonBoolean v) { + return toDomainValue(v); + } else if (value instanceof BsonInt32 v) { + return toDomainValue(v); + } else if (value instanceof BsonInt64 v) { + return toDomainValue(v); + } else if (value instanceof BsonDouble v) { + return toDomainValue(v); + } else if (value instanceof BsonDecimal128 v) { + return toDomainValue(v); + } else if (value instanceof BsonString v) { + return toDomainValue(v); + } else if (value instanceof BsonBinary v) { + return toDomainValue(v); + } else if (value instanceof BsonObjectId v) { + return toDomainValue(v); + } else { + throw new SQLFeatureNotSupportedException(format( + "Value [%s] of type [%s] is not supported", + value, value.getClass().getTypeName())); + } + } + + public static boolean toBooleanDomainValue(BsonValue value) { + return toDomainValue(value.asBoolean()); + } + + private static boolean toDomainValue(BsonBoolean value) { + return value.getValue(); + } + + public static int toIntDomainValue(BsonValue value) { + return toDomainValue(value.asInt32()); + } + + private static int toDomainValue(BsonInt32 value) { + return value.intValue(); + } + + public static long toLongDomainValue(BsonValue value) { + return toDomainValue(value.asInt64()); + } + + private static long toDomainValue(BsonInt64 value) { + return value.longValue(); + } + + public static double toDoubleDomainValue(BsonValue value) { + return toDomainValue(value.asDouble()); + } + + private static double toDomainValue(BsonDouble value) { + return value.getValue(); + } + + public static BigDecimal toBigDecimalDomainValue(BsonValue value) { + return toDomainValue(value.asDecimal128()); + } + + private static BigDecimal toDomainValue(BsonDecimal128 value) { + return value.decimal128Value().bigDecimalValue(); + } + + public static String toStringDomainValue(BsonValue value) { + return toDomainValue(value.asString()); + } + + private static String toDomainValue(BsonString value) { + return value.getValue(); + } + + public static byte[] toByteArrayDomainValue(BsonValue value) { + return toDomainValue(value.asBinary()); + } + + private static byte[] toDomainValue(BsonBinary value) { + return value.asBinary().getData(); + } + + public static ObjectId toObjectIdDomainValue(BsonValue value) { + return toDomainValue(value.asObjectId()); + } + + private static ObjectId toDomainValue(BsonObjectId value) { + return value.getValue(); + } +} diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java index 264f8acb..155f0c29 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -17,6 +17,7 @@ package com.mongodb.hibernate.jdbc; import static com.mongodb.hibernate.internal.MongoAssertions.fail; +import static com.mongodb.hibernate.internal.type.ValueConverter.toBsonValue; import static java.lang.String.format; import com.mongodb.client.ClientSession; @@ -39,18 +40,9 @@ import java.util.List; import java.util.function.Consumer; import org.bson.BsonArray; -import org.bson.BsonBinary; -import org.bson.BsonBoolean; -import org.bson.BsonDecimal128; import org.bson.BsonDocument; -import org.bson.BsonDouble; -import org.bson.BsonInt32; -import org.bson.BsonInt64; -import org.bson.BsonObjectId; -import org.bson.BsonString; import org.bson.BsonType; import org.bson.BsonValue; -import org.bson.types.Decimal128; import org.bson.types.ObjectId; final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { @@ -111,7 +103,7 @@ public void setNull(int parameterIndex, int sqlType) throws SQLException { case Types.SQLXML: case Types.STRUCT: throw new SQLFeatureNotSupportedException( - "Unsupported sql type: " + JDBCType.valueOf(sqlType).getName()); + "Unsupported SQL type: " + JDBCType.valueOf(sqlType).getName()); } throw new SQLFeatureNotSupportedException( "TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74, TODO-HIBERNATE-48 https://jira.mongodb.org/browse/HIBERNATE-48"); @@ -121,49 +113,49 @@ public void setNull(int parameterIndex, int sqlType) throws SQLException { public void setBoolean(int parameterIndex, boolean x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, BsonBoolean.valueOf(x)); + setParameter(parameterIndex, toBsonValue(x)); } @Override public void setInt(int parameterIndex, int x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, new BsonInt32(x)); + setParameter(parameterIndex, toBsonValue(x)); } @Override public void setLong(int parameterIndex, long x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, new BsonInt64(x)); + setParameter(parameterIndex, toBsonValue(x)); } @Override public void setDouble(int parameterIndex, double x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, new BsonDouble(x)); + setParameter(parameterIndex, toBsonValue(x)); } @Override public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, new BsonDecimal128(new Decimal128(x))); + setParameter(parameterIndex, toBsonValue(x)); } @Override public void setString(int parameterIndex, String x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, new BsonString(x)); + setParameter(parameterIndex, toBsonValue(x)); } @Override public void setBytes(int parameterIndex, byte[] x) throws SQLException { checkClosed(); checkParameterIndex(parameterIndex); - setParameter(parameterIndex, new BsonBinary(x)); + setParameter(parameterIndex, toBsonValue(x)); } @Override @@ -193,24 +185,23 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ checkParameterIndex(parameterIndex); BsonValue value; if (targetSqlType == MqlType.OBJECT_ID.getVendorTypeNumber()) { - if (x instanceof ObjectId v) { - value = new BsonObjectId(v); - } else { + if (!(x instanceof ObjectId v)) { throw fail(); } + value = toBsonValue(v); + } else if (targetSqlType == JDBCType.STRUCT.getVendorTypeNumber()) { + if (!(x instanceof BsonDocument v)) { + throw fail(); + } + value = v; } else { - throw new SQLFeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); + throw new SQLFeatureNotSupportedException(format( + "Parameter value [%s] of SQL type [%d] with index [%d] is not supported", + x, targetSqlType, parameterIndex)); } setParameter(parameterIndex, value); } - @Override - public void setObject(int parameterIndex, Object x) throws SQLException { - checkClosed(); - checkParameterIndex(parameterIndex); - throw new SQLFeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); - } - @Override public void addBatch() throws SQLException { checkClosed(); @@ -245,13 +236,6 @@ public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws S throw new SQLFeatureNotSupportedException("TODO-HIBERNATE-42 https://jira.mongodb.org/browse/HIBERNATE-42"); } - @Override - public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { - checkClosed(); - checkParameterIndex(parameterIndex); - throw new SQLFeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); - } - @Override public void setQueryTimeout(int seconds) throws SQLException { checkClosed(); diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java index 96ddc440..64a462b8 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoResultSet.java @@ -37,6 +37,7 @@ import static java.lang.String.format; import com.mongodb.client.MongoCursor; +import com.mongodb.hibernate.internal.type.ValueConverter; import java.math.BigDecimal; import java.sql.Array; import java.sql.Date; @@ -111,43 +112,42 @@ public boolean wasNull() throws SQLException { public @Nullable String getString(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue(columnIndex, bsonValue -> bsonValue.asString().getValue()); + return getValue(columnIndex, ValueConverter::toStringDomainValue); } @Override public boolean getBoolean(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue(columnIndex, bsonValue -> bsonValue.asBoolean().getValue(), false); + return getValue(columnIndex, ValueConverter::toBooleanDomainValue, false); } @Override public int getInt(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue(columnIndex, bsonValue -> bsonValue.asInt32().intValue(), 0); + return getValue(columnIndex, ValueConverter::toIntDomainValue, 0); } @Override public long getLong(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue(columnIndex, bsonValue -> bsonValue.asInt64().longValue(), 0L); + return getValue(columnIndex, ValueConverter::toLongDomainValue, 0L); } @Override public double getDouble(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue(columnIndex, bsonValue -> bsonValue.asDouble().getValue(), 0) - .doubleValue(); + return getValue(columnIndex, ValueConverter::toDoubleDomainValue, 0d); } @Override public byte @Nullable [] getBytes(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue(columnIndex, bsonValue -> bsonValue.asBinary().getData()); + return getValue(columnIndex, ValueConverter::toByteArrayDomainValue); } @Override @@ -189,9 +189,7 @@ public double getDouble(int columnIndex) throws SQLException { public @Nullable BigDecimal getBigDecimal(int columnIndex) throws SQLException { checkClosed(); checkColumnIndex(columnIndex); - return getValue( - columnIndex, - bsonValue -> bsonValue.asDecimal128().decimal128Value().bigDecimalValue()); + return getValue(columnIndex, ValueConverter::toBigDecimalDomainValue); } @Override @@ -207,9 +205,12 @@ public double getDouble(int columnIndex) throws SQLException { checkColumnIndex(columnIndex); Object value; if (type.equals(ObjectId.class)) { - value = getValue(columnIndex, bsonValue -> bsonValue.asObjectId().getValue()); + value = getValue(columnIndex, ValueConverter::toObjectIdDomainValue); + } else if (type.equals(BsonDocument.class)) { + value = getValue(columnIndex, BsonValue::asDocument); } else { - throw new SQLFeatureNotSupportedException("To be implemented in scope of Array / Struct tickets"); + throw new SQLFeatureNotSupportedException( + format("Type [%s] for a column with index [%d] is not supported", type, columnIndex)); } return type.cast(value); } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java index c62eeda7..4c371ecd 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -258,14 +258,11 @@ private static void checkSetterMethods( () -> asserter.accept(() -> mongoPreparedStatement.setTimestamp(parameterIndex, new Timestamp(now))), () -> asserter.accept( () -> mongoPreparedStatement.setObject(parameterIndex, Mockito.mock(Array.class), Types.OTHER)), - () -> asserter.accept( - () -> mongoPreparedStatement.setObject(parameterIndex, Mockito.mock(Array.class))), () -> asserter.accept(() -> mongoPreparedStatement.setArray(parameterIndex, Mockito.mock(Array.class))), () -> asserter.accept(() -> mongoPreparedStatement.setDate(parameterIndex, new Date(now), calendar)), () -> asserter.accept(() -> mongoPreparedStatement.setTime(parameterIndex, new Time(now), calendar)), () -> asserter.accept( - () -> mongoPreparedStatement.setTimestamp(parameterIndex, new Timestamp(now), calendar)), - () -> asserter.accept(() -> mongoPreparedStatement.setNull(parameterIndex, Types.STRUCT, "BOOK"))); + () -> mongoPreparedStatement.setTimestamp(parameterIndex, new Timestamp(now), calendar))); } private static void checkMethodsWithOpenPrecondition(