diff --git a/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java new file mode 100644 index 00000000..9c288e3b --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/query/select/LimitOffsetFetchClauseIntegrationTests.java @@ -0,0 +1,653 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.query.select; + +import static com.mongodb.hibernate.internal.MongoAssertions.fail; +import static java.lang.String.format; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.hibernate.cfg.JdbcSettings.DIALECT; +import static org.hibernate.cfg.QuerySettings.QUERY_PLAN_CACHE_ENABLED; +import static org.junit.jupiter.params.provider.EnumSource.Mode.EXCLUDE; + +import com.mongodb.hibernate.dialect.MongoDialect; +import com.mongodb.hibernate.internal.FeatureNotSupportedException; +import com.mongodb.hibernate.internal.MongoConstants; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.bson.BsonDocument; +import org.hibernate.Session; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.dialect.spi.DialectResolutionInfo; +import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.query.sqm.FetchClauseType; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.SqlAstTranslatorFactory; +import org.hibernate.sql.ast.tree.MutationStatement; +import org.hibernate.sql.ast.tree.select.SelectStatement; +import org.hibernate.sql.exec.spi.JdbcOperationQueryMutation; +import org.hibernate.sql.exec.spi.JdbcOperationQuerySelect; +import org.hibernate.sql.model.ast.TableMutation; +import org.hibernate.sql.model.jdbc.JdbcMutationOperation; +import org.hibernate.stat.QueryStatistics; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.Setting; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.ValueSource; + +@DomainModel(annotatedClasses = Book.class) +class LimitOffsetFetchClauseIntegrationTests extends AbstractSelectionQueryIntegrationTests { + + private static final List testingBooks = List.of( + new Book(0, "Nostromo", 1904, true), + new Book(1, "The Age of Innocence", 1920, false), + new Book(2, "Remembrance of Things Past", 1913, true), + new Book(3, "The Magic Mountain", 1924, false), + new Book(4, "A Passage to India", 1924, true), + new Book(5, "Ulysses", 1922, false), + new Book(6, "Mrs. Dalloway", 1925, false), + new Book(7, "The Trial", 1925, true), + new Book(8, "Sons and Lovers", 1913, false), + new Book(9, "The Sound and the Fury", 1929, false)); + + private static List getBooksByIds(int... ids) { + return Arrays.stream(ids) + .mapToObj(id -> testingBooks.stream() + .filter(c -> c.id == id) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("id does not exist: " + id))) + .toList(); + } + + @BeforeEach + void beforeEach() { + getSessionFactoryScope().inTransaction(session -> testingBooks.forEach(session::persist)); + getTestCommandListener().clear(); + } + + @Nested + class WithoutQueryOptionsLimit { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testHqlLimitClauseOnly(boolean useLiteralParameter) { + assertSelectionQuery( + useLiteralParameter ? "from Book order by id LIMIT 5" : "from Book order by id LIMIT :limit", + Book.class, + useLiteralParameter ? null : q -> q.setParameter("limit", 5), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$limit": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(5), + getBooksByIds(0, 1, 2, 3, 4)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testHqlOffsetClauseOnly(boolean useLiteralParameter) { + assertSelectionQuery( + useLiteralParameter ? "from Book order by id OFFSET 7" : "from Book order by id OFFSET :offset", + Book.class, + useLiteralParameter ? null : q -> q.setParameter("offset", 7), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$skip": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(7), + getBooksByIds(7, 8, 9)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testHqlLimitAndOffsetClauses(boolean useLiteralParameters) { + assertSelectionQuery( + useLiteralParameters + ? "from Book order by id LIMIT 2 OFFSET 3" + : "from Book order by id LIMIT :limit OFFSET :offset", + Book.class, + useLiteralParameters + ? null + : q -> q.setParameter("offset", 3).setParameter("limit", 2), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$skip": %d + }, + { + "$limit": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(3, 2), + getBooksByIds(3, 4)); + } + + @ParameterizedTest + @ValueSource( + strings = { + "FETCH FIRST :limit ROWS ONLY", + "FETCH NEXT :limit ROWS ONLY", + }) + void testHqlFetchClauseOnly(String fetchClause) { + assertSelectionQuery( + "from Book order by id " + fetchClause, + Book.class, + q -> q.setParameter("limit", 5), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$limit": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(5), + getBooksByIds(0, 1, 2, 3, 4)); + } + } + + @Nested + class WithQueryOptionsLimit { + + @Nested + class WithoutHqlClauses { + @Test + void testQueryOptionsSetFirstResultOnly() { + assertSelectionQuery( + "from Book order by id", + Book.class, + q -> q.setFirstResult(6), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$skip": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(6), + getBooksByIds(6, 7, 8, 9)); + } + + @Test + void testQueryOptionsSetMaxResultOnly() { + assertSelectionQuery( + "from Book order by id", + Book.class, + q -> q.setMaxResults(3), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$limit": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(3), + getBooksByIds(0, 1, 2)); + } + + @Test + void testQueryOptionsSetFirstResultAndMaxResults() { + assertSelectionQuery( + "from Book order by id", + Book.class, + q -> q.setFirstResult(2).setMaxResults(3), + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + { + "$skip": %d + }, + { + "$limit": %d + }, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """ + .formatted(2, 3), + getBooksByIds(2, 3, 4)); + } + } + + @Nested + class WithHqlClauses { + + private static final String expectedMqlTemplate = + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + %s, + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """; + + @Test + void testFirstResultConflictingOnly() { + var firstResult = 5; + var expectedBooks = getBooksByIds(5, 6, 7, 8, 9); + assertSelectionQuery( + "from Book order by id LIMIT :limit OFFSET :offset", + Book.class, + q -> + // hql clauses will be ignored totally + q.setParameter("limit", 10) + .setParameter("offset", 0) + .setFirstResult(firstResult), + expectedMqlTemplate.formatted("{\"$skip\": " + firstResult + "}"), + expectedBooks); + } + + @Test + void testMaxResultsConflictingOnly() { + var maxResults = 3; + var expectedBooks = getBooksByIds(0, 1, 2); + assertSelectionQuery( + "from Book order by id LIMIT :limit OFFSET :offset", + Book.class, + q -> + // hql clauses will be ignored totally + q.setParameter("limit", 10) + .setParameter("offset", 0) + .setMaxResults(maxResults), + expectedMqlTemplate.formatted("{\"$limit\": " + maxResults + "}"), + expectedBooks); + } + + @Test + void testBothFirstResultAndMaxResultsConflicting() { + var firstResult = 5; + var maxResults = 3; + var expectedBooks = getBooksByIds(5, 6, 7); + assertSelectionQuery( + "from Book order by id LIMIT :limit OFFSET :offset", + Book.class, + q -> + // hql clauses will be ignored totally + q.setParameter("limit", 10) + .setParameter("offset", 0) + .setFirstResult(firstResult) + .setMaxResults(maxResults), + expectedMqlTemplate.formatted( + "{\"$skip\": " + firstResult + "}," + "{\"$limit\": " + maxResults + "}"), + expectedBooks); + } + } + } + + @Nested + class FeatureNotSupportedTests { + + @ParameterizedTest + @EnumSource(value = FetchClauseType.class, mode = EXCLUDE, names = "ROWS_ONLY") + void testUnsupportedFetchClauseType(FetchClauseType fetchClauseType) { + var hqlSuffix = + switch (fetchClauseType) { + case ROWS_ONLY -> fail("ROWS_ONLY should have been excluded from the test"); + case ROWS_WITH_TIES -> "FETCH FIRST :limit ROWS WITH TIES"; + case PERCENT_ONLY -> "FETCH FIRST :limit PERCENT ROWS ONLY"; + case PERCENT_WITH_TIES -> "FETCH FIRST :limit PERCENT ROWS WITH TIES"; + }; + var hql = "from Book order by id " + hqlSuffix; + assertSelectQueryFailure( + hql, + Book.class, + q -> q.setParameter("limit", 10), + FeatureNotSupportedException.class, + "%s does not support '%s' fetch clause type", + MongoConstants.MONGO_DBMS_NAME, + fetchClauseType); + } + } + + @Nested + @DomainModel(annotatedClasses = Book.class) + @ServiceRegistry( + settings = { + @Setting(name = QUERY_PLAN_CACHE_ENABLED, value = "true"), + @Setting( + name = DIALECT, + value = + "com.mongodb.hibernate.query.select.LimitOffsetFetchClauseIntegrationTests$TranslatingCacheTestingDialect"), + }) + class QueryPlanCacheTests extends AbstractSelectionQueryIntegrationTests { + + private static final String HQL = "from Book order by id"; + private static final String expectedMqlTemplate = + """ + { + "aggregate": "books", + "pipeline": [ + { + "$sort": { + "_id": 1 + } + }, + %s + %s + { + "$project": { + "_id": true, + "discount": true, + "isbn13": true, + "outOfStock": true, + "price": true, + "publishYear": true, + "title": true + } + } + ] + } + """; + + private TranslatingCacheTestingDialect translatingCacheTestingDialect; + + @BeforeEach + void beforeEach() { + translatingCacheTestingDialect = (TranslatingCacheTestingDialect) getSessionFactoryScope() + .getSessionFactory() + .getJdbcServices() + .getDialect(); + getTestCommandListener().clear(); + } + + @ParameterizedTest + @CsvSource({"true,false", "false,true", "true,true"}) + void testQueryOptionsLimitCached(boolean isFirstResultSet, boolean isMaxResultsSet) { + getSessionFactoryScope().inTransaction(session -> { + setQueryOptionsAndQuery( + session, + isFirstResultSet ? 5 : null, + isMaxResultsSet ? 10 : null, + format( + expectedMqlTemplate, + (isFirstResultSet ? "{\"$skip\": 5}," : ""), + (isMaxResultsSet ? "{\"$limit\": 10}," : ""))); + var initialSelectTranslatingCount = translatingCacheTestingDialect.getSelectTranslatingCount(); + + assertThat(initialSelectTranslatingCount).isPositive(); + + setQueryOptionsAndQuery( + session, + isFirstResultSet ? 3 : null, + isMaxResultsSet ? 6 : null, + format( + expectedMqlTemplate, + (isFirstResultSet ? "{\"$skip\": 3}," : ""), + (isMaxResultsSet ? "{\"$limit\": 6}," : ""))); + assertThat(translatingCacheTestingDialect.getSelectTranslatingCount()) + .isEqualTo(initialSelectTranslatingCount); + }); + } + + @Test + void testCacheInvalidatedDueToQueryOptionsAdded() { + getSessionFactoryScope().inTransaction(session -> { + setQueryOptionsAndQuery(session, null, null, format(expectedMqlTemplate, "", "")); + var initialSelectTranslatingCount = translatingCacheTestingDialect.getSelectTranslatingCount(); + assertThat(initialSelectTranslatingCount).isPositive(); + + setQueryOptionsAndQuery(session, 1, null, format(expectedMqlTemplate, "{\"$skip\": 1},", "")); + assertThat(translatingCacheTestingDialect.getSelectTranslatingCount()) + .isEqualTo(initialSelectTranslatingCount + 1); + + setQueryOptionsAndQuery( + session, 1, 5, format(expectedMqlTemplate, "{\"$skip\": 1},", "{\"$limit\": 5},")); + assertThat(translatingCacheTestingDialect.getSelectTranslatingCount()) + .isEqualTo(initialSelectTranslatingCount + 2); + }); + } + + @Test + void testCacheInvalidatedDueToQueryOptionsRemoved() { + getSessionFactoryScope().inTransaction(session -> { + setQueryOptionsAndQuery( + session, 10, 5, format(expectedMqlTemplate, "{\"$skip\": 10},", "{\"$limit\": 5},")); + var initialSelectTranslatingCount = translatingCacheTestingDialect.getSelectTranslatingCount(); + assertThat(initialSelectTranslatingCount).isPositive(); + + setQueryOptionsAndQuery(session, null, 5, format(expectedMqlTemplate, "", "{\"$limit\": 5},")); + assertThat(translatingCacheTestingDialect.getSelectTranslatingCount()) + .isEqualTo(initialSelectTranslatingCount + 1); + + setQueryOptionsAndQuery(session, null, null, format(expectedMqlTemplate, "", "")); + assertThat(translatingCacheTestingDialect.getSelectTranslatingCount()) + .isEqualTo(initialSelectTranslatingCount + 2); + }); + } + + private void setQueryOptionsAndQuery( + Session session, Integer firstResult, Integer maxResults, String expectedMql) { + var query = session.createSelectionQuery(HQL, Book.class); + if (firstResult != null) { + query.setFirstResult(firstResult); + } + if (maxResults != null) { + query.setMaxResults(maxResults); + } + getTestCommandListener().clear(); + query.getResultList(); + if (expectedMql != null) { + var expectedCommand = BsonDocument.parse(expectedMql); + assertActualCommand(expectedCommand); + } + } + } + + /** + * A dialect that counts how many times the select translator is created. + * + *

Note that {@link QueryStatistics#getPlanCacheHitCount()} is not used, because it counts the number of times + * the query plan cache is hit, not whether {@link SqlAstTranslator} is reused afterwards (e.g., incompatible + * {@link org.hibernate.query.spi.QueryOptions QueryOptions}s will end up with new translator bing created). + */ + public static final class TranslatingCacheTestingDialect extends Dialect { + private final AtomicInteger selectTranslatingCounter = new AtomicInteger(); + private final Dialect delegate; + + public TranslatingCacheTestingDialect(DialectResolutionInfo info) { + super(info); + delegate = new MongoDialect(info); + } + + @Override + public SqlAstTranslatorFactory getSqlAstTranslatorFactory() { + return new SqlAstTranslatorFactory() { + @Override + public SqlAstTranslator buildSelectTranslator( + SessionFactoryImplementor sessionFactory, SelectStatement statement) { + selectTranslatingCounter.incrementAndGet(); + return delegate.getSqlAstTranslatorFactory().buildSelectTranslator(sessionFactory, statement); + } + + @Override + public SqlAstTranslator buildMutationTranslator( + SessionFactoryImplementor sessionFactory, MutationStatement statement) { + throw new IllegalStateException("mutation translator not expected"); + } + + @Override + public SqlAstTranslator buildModelMutationTranslator( + TableMutation mutation, SessionFactoryImplementor sessionFactory) { + return delegate.getSqlAstTranslatorFactory().buildModelMutationTranslator(mutation, sessionFactory); + } + }; + } + + public int getSelectTranslatingCount() { + return selectTranslatingCounter.get(); + } + } +} diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java b/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java index 22df389a..6254f6f5 100644 --- a/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java +++ b/src/main/java/com/mongodb/hibernate/internal/translate/AbstractMqlTranslator.java @@ -17,18 +17,20 @@ package com.mongodb.hibernate.internal.translate; import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNull; import static com.mongodb.hibernate.internal.MongoAssertions.assertTrue; +import static com.mongodb.hibernate.internal.MongoAssertions.fail; import static com.mongodb.hibernate.internal.MongoConstants.EXTENDED_JSON_WRITER_SETTINGS; import static com.mongodb.hibernate.internal.MongoConstants.MONGO_DBMS_NAME; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_AGGREGATE; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_MUTATION; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_NAME; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.FIELD_PATH; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.FIELD_VALUE; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.FILTER; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.MUTATION_RESULT; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.PROJECT_STAGE_SPECIFICATIONS; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.SELECT_RESULT; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.SORT_FIELDS; import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.TUPLE; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.VALUE; import static com.mongodb.hibernate.internal.translate.mongoast.AstLiteralValue.FALSE; import static com.mongodb.hibernate.internal.translate.mongoast.AstLiteralValue.TRUE; import static com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstSortOrder.ASC; @@ -43,6 +45,7 @@ import static com.mongodb.hibernate.internal.translate.mongoast.filter.AstLogicalFilterOperator.NOR; import static com.mongodb.hibernate.internal.translate.mongoast.filter.AstLogicalFilterOperator.OR; import static java.lang.String.format; +import static org.hibernate.query.sqm.FetchClauseType.ROWS_ONLY; import com.mongodb.hibernate.internal.FeatureNotSupportedException; import com.mongodb.hibernate.internal.extension.service.StandardServiceRegistryScopedState; @@ -52,15 +55,16 @@ import com.mongodb.hibernate.internal.translate.mongoast.AstLiteralValue; import com.mongodb.hibernate.internal.translate.mongoast.AstNode; import com.mongodb.hibernate.internal.translate.mongoast.AstParameterMarker; -import com.mongodb.hibernate.internal.translate.mongoast.command.AstCommand; import com.mongodb.hibernate.internal.translate.mongoast.command.AstDeleteCommand; import com.mongodb.hibernate.internal.translate.mongoast.command.AstInsertCommand; import com.mongodb.hibernate.internal.translate.mongoast.command.AstUpdateCommand; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstAggregateCommand; +import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstLimitStage; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstMatchStage; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstProjectStage; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstProjectStageIncludeSpecification; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstProjectStageSpecification; +import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstSkipStage; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstSortField; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstSortOrder; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstSortStage; @@ -73,6 +77,8 @@ import com.mongodb.hibernate.internal.type.ValueConversions; import java.io.IOException; import java.io.StringWriter; +import java.sql.PreparedStatement; +import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.util.ArrayList; import java.util.HashSet; @@ -86,6 +92,7 @@ import org.hibernate.persister.entity.EntityPersister; import org.hibernate.persister.internal.SqlFragmentPredicate; import org.hibernate.query.NullPrecedence; +import org.hibernate.query.spi.Limit; import org.hibernate.query.spi.QueryOptions; import org.hibernate.query.sqm.ComparisonOperator; import org.hibernate.query.sqm.sql.internal.BasicValuedPathInterpretation; @@ -165,6 +172,8 @@ import org.hibernate.sql.ast.tree.select.SortSpecification; import org.hibernate.sql.ast.tree.update.Assignment; import org.hibernate.sql.ast.tree.update.UpdateStatement; +import org.hibernate.sql.exec.internal.AbstractJdbcParameter; +import org.hibernate.sql.exec.spi.ExecutionContext; import org.hibernate.sql.exec.spi.JdbcOperation; import org.hibernate.sql.exec.spi.JdbcParameterBinder; import org.hibernate.sql.exec.spi.JdbcParameterBindings; @@ -178,6 +187,7 @@ import org.hibernate.sql.model.internal.TableInsertStandard; import org.hibernate.sql.model.internal.TableUpdateCustomSql; import org.hibernate.sql.model.internal.TableUpdateStandard; +import org.hibernate.type.BasicType; import org.jspecify.annotations.Nullable; abstract class AbstractMqlTranslator implements SqlAstTranslator { @@ -190,6 +200,8 @@ abstract class AbstractMqlTranslator implements SqlAstT private final Set affectedTableNames = new HashSet<>(); + private @Nullable QueryOptionsLimit queryOptionsLimit; + AbstractMqlTranslator(SessionFactoryImplementor sessionFactory) { this.sessionFactory = sessionFactory; assertNotNull(sessionFactory @@ -225,15 +237,11 @@ public Stack getCurrentClauseStack() { @Override public Set getAffectedTableNames() { - return affectedTableNames; - } - - List getParameterBinders() { - return parameterBinders; + throw fail(); } @SuppressWarnings("overloads") - R acceptAndYield(Statement statement, AstVisitorValueDescriptor resultDescriptor) { + R acceptAndYield(Statement statement, AstVisitorValueDescriptor resultDescriptor) { return astVisitorValueHolder.execute(resultDescriptor, () -> statement.accept(this)); } @@ -254,12 +262,15 @@ public void visitStandardTableInsert(TableInsertStandard tableInsert) { if (valueExpression == null) { throw new FeatureNotSupportedException(); } - var fieldValue = acceptAndYield(valueExpression, FIELD_VALUE); + var fieldValue = acceptAndYield(valueExpression, VALUE); astElements.add(new AstElement(fieldName, fieldValue)); } astVisitorValueHolder.yield( - COLLECTION_MUTATION, - new AstInsertCommand(tableInsert.getMutatingTable().getTableName(), new AstDocument(astElements))); + MUTATION_RESULT, + ModelMutationMqlTranslator.Result.create( + new AstInsertCommand( + tableInsert.getMutatingTable().getTableName(), new AstDocument(astElements)), + parameterBinders)); } @Override @@ -277,8 +288,10 @@ public void visitStandardTableDelete(TableDeleteStandard tableDelete) { } var keyFilter = getKeyFilter(tableDelete); astVisitorValueHolder.yield( - COLLECTION_MUTATION, - new AstDeleteCommand(tableDelete.getMutatingTable().getTableName(), keyFilter)); + MUTATION_RESULT, + ModelMutationMqlTranslator.Result.create( + new AstDeleteCommand(tableDelete.getMutatingTable().getTableName(), keyFilter), + parameterBinders)); } @Override @@ -293,12 +306,14 @@ public void visitStandardTableUpdate(TableUpdateStandard tableUpdate) { var updates = new ArrayList(tableUpdate.getNumberOfValueBindings()); for (var valueBinding : tableUpdate.getValueBindings()) { var fieldName = valueBinding.getColumnReference().getColumnExpression(); - var fieldValue = acceptAndYield(valueBinding.getValueExpression(), FIELD_VALUE); + var fieldValue = acceptAndYield(valueBinding.getValueExpression(), VALUE); updates.add(new AstFieldUpdate(fieldName, fieldValue)); } astVisitorValueHolder.yield( - COLLECTION_MUTATION, - new AstUpdateCommand(tableUpdate.getMutatingTable().getTableName(), keyFilter, updates)); + MUTATION_RESULT, + ModelMutationMqlTranslator.Result.create( + new AstUpdateCommand(tableUpdate.getMutatingTable().getTableName(), keyFilter, updates), + parameterBinders)); } private AstFilter getKeyFilter(AbstractRestrictedTableMutation tableMutation) { @@ -314,14 +329,14 @@ private AstFilter getKeyFilter(AbstractRestrictedTableMutation(3); + var stages = new ArrayList(); createMatchStage(querySpec).ifPresent(stages::add); createSortStage(querySpec).ifPresent(stages::add); + + var skipLimitStagesAndJdbcParams = + assertNotNull(queryOptionsLimit).createSkipLimitStagesAndJdbcParams(querySpec); + stages.addAll(skipLimitStagesAndJdbcParams.stages()); + stages.add(createProjectStage(querySpec.getSelectClause())); - astVisitorValueHolder.yield(COLLECTION_AGGREGATE, new AstAggregateCommand(collection, stages)); + astVisitorValueHolder.yield( + SELECT_RESULT, + new SelectMqlTranslator.Result( + new AstAggregateCommand(collection, stages), + parameterBinders, + affectedTableNames, + skipLimitStagesAndJdbcParams.offset(), + skipLimitStagesAndJdbcParams.limit())); } private Optional createMatchStage(QuerySpec querySpec) { @@ -378,6 +402,72 @@ private Optional createSortStage(QuerySpec querySpec) { return Optional.empty(); } + @Override + public void visitOffsetFetchClause(QueryPart queryPart) { + fail(); + } + + private final class QueryOptionsLimit { + private final @Nullable Limit limit; + + QueryOptionsLimit(@Nullable Limit limit) { + this.limit = limit; + } + + StagesAndJdbcParameters createSkipLimitStagesAndJdbcParams(QueryPart queryPart) { + Expression skipExpression; + Expression limitExpression; + JdbcParameter offsetParameter = null; + JdbcParameter limitParameter = null; + if (queryPart.isRoot() && limit != null && !limit.isEmpty()) { + // We check if limit's firstRow/maxRows is set, + // but ignore the actual values when creating OffsetJdbcParameter/LimitJdbcParameter. + // Hibernate ORM reuses the translation result for the same HQL/SQL queries + // with different values passed to setFirstResult/setMaxResults. Therefore, we cannot include the + // values available when translating in the translation result. The only thing we pay attention to is + // whether they are specified or not, because the translation results corresponding to + // setFirstResult/setMaxResults being present + // must be different from those with the limits being absent. Hibernate ORM also caches them separately. + var basicIntegerType = sessionFactory.getTypeConfiguration().getBasicTypeForJavaType(Integer.class); + if (limit.getFirstRow() != null) { + offsetParameter = new OffsetJdbcParameter(basicIntegerType); + } + if (limit.getMaxRows() != null) { + limitParameter = new LimitJdbcParameter(basicIntegerType); + } + skipExpression = offsetParameter; + limitExpression = limitParameter; + } else { + if (queryPart.getFetchClauseType() != ROWS_ONLY) { + throw new FeatureNotSupportedException(format( + "%s does not support '%s' fetch clause type", + MONGO_DBMS_NAME, queryPart.getFetchClauseType())); + } + skipExpression = queryPart.getOffsetClauseExpression(); + limitExpression = queryPart.getFetchClauseExpression(); + } + var skipAndLimitStages = new ArrayList(); + if (skipExpression != null) { + var skipValue = acceptAndYield(skipExpression, VALUE); + skipAndLimitStages.add(new AstSkipStage(skipValue)); + } + if (limitExpression != null) { + var limitValue = acceptAndYield(limitExpression, VALUE); + skipAndLimitStages.add(new AstLimitStage(limitValue)); + } + return new StagesAndJdbcParameters(skipAndLimitStages, offsetParameter, limitParameter); + } + + record StagesAndJdbcParameters( + List stages, @Nullable JdbcParameter offset, @Nullable JdbcParameter limit) {} + } + + void applyQueryOptions(QueryOptions queryOptions) { + checkQueryOptionsSupportability(queryOptions); + assertNull(queryOptionsLimit); + queryOptionsLimit = new QueryOptionsLimit(queryOptions.getLimit()); + } + private AstProjectStage createProjectStage(SelectClause selectClause) { var projectStageSpecifications = acceptAndYield(selectClause, PROJECT_STAGE_SPECIFICATIONS); return new AstProjectStage(projectStageSpecifications); @@ -420,7 +510,7 @@ public void visitRelationalPredicate(ComparisonPredicate comparisonPredicate) { } var fieldPath = acceptAndYield((isFieldOnLeftHandSide ? lhs : rhs), FIELD_PATH); - var comparisonValue = acceptAndYield((isFieldOnLeftHandSide ? rhs : lhs), FIELD_VALUE); + var comparisonValue = acceptAndYield((isFieldOnLeftHandSide ? rhs : lhs), VALUE); var operator = isFieldOnLeftHandSide ? comparisonPredicate.getOperator() @@ -479,7 +569,7 @@ public void visitQueryLiteral(QueryLiteral queryLiteral) { if (literalValue == null) { throw new FeatureNotSupportedException("TODO-HIBERNATE-74 https://jira.mongodb.org/browse/HIBERNATE-74"); } - astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(toBsonValue(literalValue))); + astVisitorValueHolder.yield(VALUE, new AstLiteralValue(toBsonValue(literalValue))); } @Override @@ -499,7 +589,7 @@ public void visitJunction(Junction junction) { @Override public void visitUnparsedNumericLiteral(UnparsedNumericLiteral unparsedNumericLiteral) { var literalValue = assertNotNull(unparsedNumericLiteral.getLiteralValue()); - astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(toBsonValue(literalValue))); + astVisitorValueHolder.yield(VALUE, new AstLiteralValue(toBsonValue(literalValue))); } @Override @@ -599,11 +689,6 @@ public void visitQueryGroup(QueryGroup queryGroup) { throw new FeatureNotSupportedException(); } - @Override - public void visitOffsetFetchClause(QueryPart queryPart) { - throw new FeatureNotSupportedException(); - } - @Override public void visitSqlSelection(SqlSelection sqlSelection) { throw new FeatureNotSupportedException(); @@ -871,7 +956,7 @@ static void checkJdbcParameterBindingsSupportability(@Nullable JdbcParameterBind } } - static void checkQueryOptionsSupportability(QueryOptions queryOptions) { + private static void checkQueryOptionsSupportability(QueryOptions queryOptions) { if (queryOptions.getTimeout() != null) { throw new FeatureNotSupportedException("'timeout' inQueryOptions not supported"); } @@ -913,9 +998,6 @@ static void checkQueryOptionsSupportability(QueryOptions queryOptions) { if (queryOptions.getFetchSize() != null) { throw new FeatureNotSupportedException("TODO-HIBERNATE-54 https://jira.mongodb.org/browse/HIBERNATE-54"); } - if (queryOptions.getLimit() != null && !queryOptions.getLimit().isEmpty()) { - throw new FeatureNotSupportedException("TODO-HIBERNATE-70 https://jira.mongodb.org/browse/HIBERNATE-70"); - } } private static AstComparisonFilterOperator getAstComparisonFilterOperator(ComparisonOperator operator) { @@ -956,4 +1038,52 @@ private static BsonValue toBsonValue(Object value) { throw new FeatureNotSupportedException(e); } } + + private static final class OffsetJdbcParameter extends AbstractJdbcParameter { + + OffsetJdbcParameter(BasicType type) { + super(type); + } + + @Override + @SuppressWarnings("unchecked") + public void bindParameterValue( + PreparedStatement statement, + int startPosition, + JdbcParameterBindings jdbcParamBindings, + ExecutionContext executionContext) + throws SQLException { + getJdbcMapping() + .getJdbcValueBinder() + .bind( + statement, + executionContext.getQueryOptions().getLimit().getFirstRow(), + startPosition, + executionContext.getSession()); + } + } + + private static final class LimitJdbcParameter extends AbstractJdbcParameter { + + LimitJdbcParameter(BasicType type) { + super(type); + } + + @Override + @SuppressWarnings("unchecked") + public void bindParameterValue( + PreparedStatement statement, + int startPosition, + JdbcParameterBindings jdbcParamBindings, + ExecutionContext executionContext) + throws SQLException { + getJdbcMapping() + .getJdbcValueBinder() + .bind( + statement, + executionContext.getQueryOptions().getLimit().getMaxRows(), + startPosition, + executionContext.getSession()); + } + } } diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptor.java b/src/main/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptor.java index a0f8cc99..ad844f63 100644 --- a/src/main/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptor.java +++ b/src/main/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptor.java @@ -20,7 +20,6 @@ import static com.mongodb.hibernate.internal.MongoAssertions.fail; import com.mongodb.hibernate.internal.translate.mongoast.AstValue; -import com.mongodb.hibernate.internal.translate.mongoast.command.AstCommand; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstProjectStageSpecification; import com.mongodb.hibernate.internal.translate.mongoast.command.aggregate.AstSortField; import com.mongodb.hibernate.internal.translate.mongoast.filter.AstFilter; @@ -34,13 +33,15 @@ @SuppressWarnings("UnusedTypeParameter") final class AstVisitorValueDescriptor { - static final AstVisitorValueDescriptor COLLECTION_MUTATION = new AstVisitorValueDescriptor<>(); - static final AstVisitorValueDescriptor COLLECTION_AGGREGATE = new AstVisitorValueDescriptor<>(); + static final AstVisitorValueDescriptor MUTATION_RESULT = + new AstVisitorValueDescriptor<>(); + static final AstVisitorValueDescriptor SELECT_RESULT = + new AstVisitorValueDescriptor<>(); static final AstVisitorValueDescriptor COLLECTION_NAME = new AstVisitorValueDescriptor<>(); static final AstVisitorValueDescriptor FIELD_PATH = new AstVisitorValueDescriptor<>(); - static final AstVisitorValueDescriptor FIELD_VALUE = new AstVisitorValueDescriptor<>(); + static final AstVisitorValueDescriptor VALUE = new AstVisitorValueDescriptor<>(); static final AstVisitorValueDescriptor> PROJECT_STAGE_SPECIFICATIONS = new AstVisitorValueDescriptor<>(); diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java b/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java index b940323d..90b75c40 100644 --- a/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java +++ b/src/main/java/com/mongodb/hibernate/internal/translate/ModelMutationMqlTranslator.java @@ -16,11 +16,16 @@ package com.mongodb.hibernate.internal.translate; +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; import static com.mongodb.hibernate.internal.MongoAssertions.assertNull; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_MUTATION; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.MUTATION_RESULT; +import static java.util.Collections.emptyList; +import com.mongodb.hibernate.internal.translate.mongoast.command.AstCommand; +import java.util.List; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.query.spi.QueryOptions; +import org.hibernate.sql.exec.spi.JdbcParameterBinder; import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.model.ast.TableMutation; import org.hibernate.sql.model.internal.TableUpdateNoSet; @@ -39,15 +44,38 @@ final class ModelMutationMqlTranslator extends @Override public O translate(@Nullable JdbcParameterBindings jdbcParameterBindings, QueryOptions queryOptions) { assertNull(jdbcParameterBindings); - checkQueryOptionsSupportability(queryOptions); + applyQueryOptions(queryOptions); - String mql; + Result result; if ((TableMutation) tableMutation instanceof TableUpdateNoSet) { - mql = ""; + result = Result.empty(); } else { - var mutationCommand = acceptAndYield(tableMutation, COLLECTION_MUTATION); - mql = renderMongoAstNode(mutationCommand); + result = acceptAndYield(tableMutation, MUTATION_RESULT); + } + return result.createJdbcMutationOperation(tableMutation); + } + + static final class Result { + private final @Nullable AstCommand command; + + private final List parameterBinders; + + private Result(@Nullable AstCommand command, List parameterBinders) { + this.command = command; + this.parameterBinders = parameterBinders; + } + + static Result create(AstCommand command, List parameterBinders) { + return new Result(assertNotNull(command), parameterBinders); + } + + private static Result empty() { + return new Result(null, emptyList()); + } + + private O createJdbcMutationOperation(TableMutation tableMutation) { + var mql = command == null ? "" : renderMongoAstNode(command); + return tableMutation.createMutationOperation(mql, parameterBinders); } - return tableMutation.createMutationOperation(mql, getParameterBinders()); } } diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslator.java b/src/main/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslator.java index 3a745ac2..2e6981b6 100644 --- a/src/main/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslator.java +++ b/src/main/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslator.java @@ -16,14 +16,22 @@ package com.mongodb.hibernate.internal.translate; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_AGGREGATE; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.SELECT_RESULT; +import static java.lang.Integer.MAX_VALUE; +import static java.util.Collections.emptyMap; import static org.hibernate.sql.ast.SqlTreePrinter.logSqlAst; +import static org.hibernate.sql.exec.spi.JdbcLockStrategy.NONE; +import com.mongodb.hibernate.internal.translate.mongoast.command.AstCommand; +import java.util.List; +import java.util.Set; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.query.spi.QueryOptions; import org.hibernate.sql.ast.tree.Statement; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.exec.spi.JdbcOperationQuerySelect; +import org.hibernate.sql.exec.spi.JdbcParameterBinder; import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.results.jdbc.spi.JdbcValuesMappingProducerProvider; import org.jspecify.annotations.Nullable; @@ -31,13 +39,10 @@ final class SelectMqlTranslator extends AbstractMqlTranslator { private final SelectStatement selectStatement; - private final JdbcValuesMappingProducerProvider jdbcValuesMappingProducerProvider; SelectMqlTranslator(SessionFactoryImplementor sessionFactory, SelectStatement selectStatement) { super(sessionFactory); this.selectStatement = selectStatement; - jdbcValuesMappingProducerProvider = - sessionFactory.getServiceRegistry().requireService(JdbcValuesMappingProducerProvider.class); } @Override @@ -47,16 +52,51 @@ public JdbcOperationQuerySelect translate( logSqlAst(selectStatement); checkJdbcParameterBindingsSupportability(jdbcParameterBindings); - checkQueryOptionsSupportability(queryOptions); + applyQueryOptions(queryOptions); - var aggregateCommand = acceptAndYield((Statement) selectStatement, COLLECTION_AGGREGATE); - var jdbcValuesMappingProducer = - jdbcValuesMappingProducerProvider.buildMappingProducer(selectStatement, getSessionFactory()); + var result = acceptAndYield((Statement) selectStatement, SELECT_RESULT); + return result.createJdbcOperationQuerySelect(selectStatement, getSessionFactory()); + } + + static final class Result { + private final AstCommand command; + private final List parameterBinders; + private final Set affectedTableNames; + private final @Nullable JdbcParameter offsetParameter; + private final @Nullable JdbcParameter limitParameter; + + Result( + AstCommand command, + List parameterBinders, + Set affectedTableNames, + @Nullable JdbcParameter offsetParameter, + @Nullable JdbcParameter limitParameter) { + this.command = command; + this.parameterBinders = parameterBinders; + this.affectedTableNames = affectedTableNames; + this.offsetParameter = offsetParameter; + this.limitParameter = limitParameter; + } - return new JdbcOperationQuerySelect( - renderMongoAstNode(aggregateCommand), - getParameterBinders(), - jdbcValuesMappingProducer, - getAffectedTableNames()); + private JdbcOperationQuerySelect createJdbcOperationQuerySelect( + SelectStatement selectStatement, SessionFactoryImplementor sessionFactory) { + var jdbcValuesMappingProducerProvider = + sessionFactory.getServiceRegistry().requireService(JdbcValuesMappingProducerProvider.class); + var jdbcValuesMappingProducer = + jdbcValuesMappingProducerProvider.buildMappingProducer(selectStatement, sessionFactory); + return new JdbcOperationQuerySelect( + renderMongoAstNode(command), + parameterBinders, + jdbcValuesMappingProducer, + affectedTableNames, + 0, + MAX_VALUE, + emptyMap(), + NONE, + // The following parameters are provided for query plan cache purposes. + // Not setting them could result in reusing the wrong query plan and subsequently the wrong MQL. + offsetParameter, + limitParameter); + } } } diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstLimitStage.java b/src/main/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstLimitStage.java new file mode 100644 index 00000000..2cc23673 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstLimitStage.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.internal.translate.mongoast.command.aggregate; + +import com.mongodb.hibernate.internal.translate.mongoast.AstValue; +import org.bson.BsonWriter; + +public record AstLimitStage(AstValue value) implements AstStage { + @Override + public void render(BsonWriter writer) { + writer.writeStartDocument(); + { + writer.writeName("$limit"); + value.render(writer); + } + writer.writeEndDocument(); + } +} diff --git a/src/main/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstSkipStage.java b/src/main/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstSkipStage.java new file mode 100644 index 00000000..f6ba9e7b --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstSkipStage.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.internal.translate.mongoast.command.aggregate; + +import com.mongodb.hibernate.internal.translate.mongoast.AstValue; +import org.bson.BsonWriter; + +public record AstSkipStage(AstValue value) implements AstStage { + @Override + public void render(BsonWriter writer) { + writer.writeStartDocument(); + { + writer.writeName("$skip"); + value.render(writer); + } + writer.writeEndDocument(); + } +} diff --git a/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptorTests.java b/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptorTests.java index 011ced34..f75244fe 100644 --- a/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptorTests.java +++ b/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueDescriptorTests.java @@ -24,6 +24,6 @@ class AstVisitorValueDescriptorTests { @Test void testToString() { - assertEquals("COLLECTION_MUTATION", AstVisitorValueDescriptor.COLLECTION_MUTATION.toString()); + assertEquals("MUTATION_RESULT", AstVisitorValueDescriptor.MUTATION_RESULT.toString()); } } diff --git a/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueHolderTests.java b/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueHolderTests.java index b11d20aa..c9865e63 100644 --- a/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueHolderTests.java +++ b/src/test/java/com/mongodb/hibernate/internal/translate/AstVisitorValueHolderTests.java @@ -16,8 +16,9 @@ package com.mongodb.hibernate.internal.translate; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.COLLECTION_MUTATION; -import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.FIELD_VALUE; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.MUTATION_RESULT; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.VALUE; +import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -45,9 +46,9 @@ void beforeEach() { void testSimpleUsage() { var value = new AstLiteralValue(new BsonString("field_value")); - Runnable valueYielder = () -> astVisitorValueHolder.yield(FIELD_VALUE, value); + Runnable valueYielder = () -> astVisitorValueHolder.yield(VALUE, value); - var valueGotten = astVisitorValueHolder.execute(FIELD_VALUE, valueYielder); + var valueGotten = astVisitorValueHolder.execute(VALUE, valueYielder); assertSame(value, valueGotten); } @@ -57,15 +58,17 @@ void testRecursiveUsage() { Runnable tableInserter = () -> { Runnable fieldValueYielder = () -> { - astVisitorValueHolder.yield(FIELD_VALUE, AstParameterMarker.INSTANCE); + astVisitorValueHolder.yield(VALUE, AstParameterMarker.INSTANCE); }; - var fieldValue = astVisitorValueHolder.execute(FIELD_VALUE, fieldValueYielder); + var fieldValue = astVisitorValueHolder.execute(VALUE, fieldValueYielder); AstElement astElement = new AstElement("province", fieldValue); astVisitorValueHolder.yield( - COLLECTION_MUTATION, new AstInsertCommand("city", new AstDocument(List.of(astElement)))); + MUTATION_RESULT, + ModelMutationMqlTranslator.Result.create( + new AstInsertCommand("city", new AstDocument(List.of(astElement))), emptyList())); }; - astVisitorValueHolder.execute(COLLECTION_MUTATION, tableInserter); + astVisitorValueHolder.execute(MUTATION_RESULT, tableInserter); } @Test @@ -73,11 +76,11 @@ void testRecursiveUsage() { void testHolderNotEmptyWhenSetting() { Runnable valueYielder = () -> { - astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(new BsonString("value1"))); - astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(new BsonString("value2"))); + astVisitorValueHolder.yield(VALUE, new AstLiteralValue(new BsonString("value1"))); + astVisitorValueHolder.yield(VALUE, new AstLiteralValue(new BsonString("value2"))); }; - assertThrows(Error.class, () -> astVisitorValueHolder.execute(FIELD_VALUE, valueYielder)); + assertThrows(Error.class, () -> astVisitorValueHolder.execute(VALUE, valueYielder)); } @Test @@ -85,14 +88,14 @@ void testHolderNotEmptyWhenSetting() { void testHolderExpectingDifferentDescriptor() { Runnable valueYielder = - () -> astVisitorValueHolder.yield(FIELD_VALUE, new AstLiteralValue(new BsonString("some_value"))); + () -> astVisitorValueHolder.yield(VALUE, new AstLiteralValue(new BsonString("some_value"))); - assertThrows(Error.class, () -> astVisitorValueHolder.execute(COLLECTION_MUTATION, valueYielder)); + assertThrows(Error.class, () -> astVisitorValueHolder.execute(MUTATION_RESULT, valueYielder)); } @Test @DisplayName("Exception is thrown when no value is yielded") void testHolderStillEmpty() { - assertThrows(Error.class, () -> astVisitorValueHolder.execute(FIELD_VALUE, () -> {})); + assertThrows(Error.class, () -> astVisitorValueHolder.execute(VALUE, () -> {})); } } diff --git a/src/test/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslatorTests.java b/src/test/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslatorTests.java index 3e70c220..2e6981b6 100644 --- a/src/test/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslatorTests.java +++ b/src/test/java/com/mongodb/hibernate/internal/translate/SelectMqlTranslatorTests.java @@ -16,80 +16,87 @@ package com.mongodb.hibernate.internal.translate; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; +import static com.mongodb.hibernate.internal.translate.AstVisitorValueDescriptor.SELECT_RESULT; +import static java.lang.Integer.MAX_VALUE; +import static java.util.Collections.emptyMap; +import static org.hibernate.sql.ast.SqlTreePrinter.logSqlAst; +import static org.hibernate.sql.exec.spi.JdbcLockStrategy.NONE; -import com.mongodb.hibernate.internal.extension.service.StandardServiceRegistryScopedState; +import com.mongodb.hibernate.internal.translate.mongoast.command.AstCommand; +import java.util.List; +import java.util.Set; import org.hibernate.engine.spi.SessionFactoryImplementor; -import org.hibernate.metamodel.mapping.SelectableMapping; -import org.hibernate.persister.entity.EntityPersister; import org.hibernate.query.spi.QueryOptions; -import org.hibernate.service.spi.ServiceRegistryImplementor; -import org.hibernate.spi.NavigablePath; -import org.hibernate.sql.ast.spi.SqlAliasBaseImpl; -import org.hibernate.sql.ast.tree.expression.ColumnReference; -import org.hibernate.sql.ast.tree.from.NamedTableReference; -import org.hibernate.sql.ast.tree.from.StandardTableGroup; -import org.hibernate.sql.ast.tree.select.QuerySpec; +import org.hibernate.sql.ast.tree.Statement; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; import org.hibernate.sql.ast.tree.select.SelectStatement; -import org.hibernate.sql.results.internal.SqlSelectionImpl; +import org.hibernate.sql.exec.spi.JdbcOperationQuerySelect; +import org.hibernate.sql.exec.spi.JdbcParameterBinder; +import org.hibernate.sql.exec.spi.JdbcParameterBindings; import org.hibernate.sql.results.jdbc.spi.JdbcValuesMappingProducerProvider; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.MockMakers; -import org.mockito.junit.jupiter.MockitoExtension; +import org.jspecify.annotations.Nullable; -@ExtendWith(MockitoExtension.class) -class SelectMqlTranslatorTests { +final class SelectMqlTranslator extends AbstractMqlTranslator { - @Test - void testAffectedTableNames( - @Mock EntityPersister entityPersister, - @Mock(mockMaker = MockMakers.PROXY) SessionFactoryImplementor sessionFactory, - @Mock JdbcValuesMappingProducerProvider jdbcValuesMappingProducerProvider, - @Mock(mockMaker = MockMakers.PROXY) ServiceRegistryImplementor serviceRegistry, - @Mock StandardServiceRegistryScopedState standardServiceRegistryScopedState, - @Mock SelectableMapping selectableMapping) { + private final SelectStatement selectStatement; - var tableName = "books"; - SelectStatement selectFromTableName; - { // prepare `selectFromTableName` - doReturn(new String[] {tableName}).when(entityPersister).getQuerySpaces(); + SelectMqlTranslator(SessionFactoryImplementor sessionFactory, SelectStatement selectStatement) { + super(sessionFactory); + this.selectStatement = selectStatement; + } - var namedTableReference = new NamedTableReference(tableName, "b1_0"); + @Override + public JdbcOperationQuerySelect translate( + @Nullable JdbcParameterBindings jdbcParameterBindings, QueryOptions queryOptions) { - var querySpec = new QuerySpec(true); - var tableGroup = new StandardTableGroup( - false, - new NavigablePath("Book"), - entityPersister, - null, - namedTableReference, - new SqlAliasBaseImpl("b1"), - sessionFactory); - querySpec.getFromClause().addRoot(tableGroup); - querySpec - .getSelectClause() - .addSqlSelection(new SqlSelectionImpl( - new ColumnReference(tableGroup.getPrimaryTableReference(), selectableMapping))); - selectFromTableName = new SelectStatement(querySpec); - } - { // prepare `sessionFactory` - doReturn(serviceRegistry).when(sessionFactory).getServiceRegistry(); - doReturn(jdbcValuesMappingProducerProvider) - .when(serviceRegistry) - .requireService(eq(JdbcValuesMappingProducerProvider.class)); - doReturn(standardServiceRegistryScopedState) - .when(serviceRegistry) - .requireService(eq(StandardServiceRegistryScopedState.class)); - } + logSqlAst(selectStatement); + + checkJdbcParameterBindingsSupportability(jdbcParameterBindings); + applyQueryOptions(queryOptions); - var translator = new SelectMqlTranslator(sessionFactory, selectFromTableName); + var result = acceptAndYield((Statement) selectStatement, SELECT_RESULT); + return result.createJdbcOperationQuerySelect(selectStatement, getSessionFactory()); + } + + static final class Result { + private final AstCommand command; + private final List parameterBinders; + private final Set affectedTableNames; + private final @Nullable JdbcParameter offsetParameter; + private final @Nullable JdbcParameter limitParameter; - translator.translate(null, QueryOptions.NONE); + Result( + AstCommand command, + List parameterBinders, + Set affectedTableNames, + @Nullable JdbcParameter offsetParameter, + @Nullable JdbcParameter limitParameter) { + this.command = command; + this.parameterBinders = parameterBinders; + this.affectedTableNames = affectedTableNames; + this.offsetParameter = offsetParameter; + this.limitParameter = limitParameter; + } - assertThat(translator.getAffectedTableNames()).containsExactly(tableName); + private JdbcOperationQuerySelect createJdbcOperationQuerySelect( + SelectStatement selectStatement, SessionFactoryImplementor sessionFactory) { + var jdbcValuesMappingProducerProvider = + sessionFactory.getServiceRegistry().requireService(JdbcValuesMappingProducerProvider.class); + var jdbcValuesMappingProducer = + jdbcValuesMappingProducerProvider.buildMappingProducer(selectStatement, sessionFactory); + return new JdbcOperationQuerySelect( + renderMongoAstNode(command), + parameterBinders, + jdbcValuesMappingProducer, + affectedTableNames, + 0, + MAX_VALUE, + emptyMap(), + NONE, + // The following parameters are provided for query plan cache purposes. + // Not setting them could result in reusing the wrong query plan and subsequently the wrong MQL. + offsetParameter, + limitParameter); + } } } diff --git a/src/test/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstLimitStageTests.java b/src/test/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstLimitStageTests.java new file mode 100644 index 00000000..2521f238 --- /dev/null +++ b/src/test/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstLimitStageTests.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.internal.translate.mongoast.command.aggregate; + +import static com.mongodb.hibernate.internal.translate.mongoast.AstNodeAssertions.assertRendering; + +import com.mongodb.hibernate.internal.translate.mongoast.AstLiteralValue; +import org.bson.BsonInt32; +import org.junit.jupiter.api.Test; + +class AstLimitStageTests { + + @Test + void testRendering() { + var limitValue = 10; + var astLimitStage = new AstLimitStage(new AstLiteralValue(new BsonInt32(limitValue))); + + var expectedJson = + """ + {"$limit": {"$numberInt": "%d"}}\ + """.formatted(limitValue); + assertRendering(expectedJson, astLimitStage); + } +} diff --git a/src/test/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstSkipStageTests.java b/src/test/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstSkipStageTests.java new file mode 100644 index 00000000..c7e5eb9a --- /dev/null +++ b/src/test/java/com/mongodb/hibernate/internal/translate/mongoast/command/aggregate/AstSkipStageTests.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.internal.translate.mongoast.command.aggregate; + +import static com.mongodb.hibernate.internal.translate.mongoast.AstNodeAssertions.assertRendering; + +import com.mongodb.hibernate.internal.translate.mongoast.AstLiteralValue; +import org.bson.BsonInt32; +import org.junit.jupiter.api.Test; + +class AstSkipStageTests { + + @Test + void testRendering() { + var skipValue = 5; + var astSkipStage = new AstSkipStage(new AstLiteralValue(new BsonInt32(skipValue))); + + var expectedJson = + """ + {"$skip": {"$numberInt": "%d"}}\ + """.formatted(skipValue); + assertRendering(expectedJson, astSkipStage); + } +}