Skip to content

Commit a3b318f

Browse files
committed
Discard further rows once maxRows has been reached
See spring-projects#34666 (comment) Signed-off-by: Yanming Zhou <[email protected]>
1 parent 5b1c552 commit a3b318f

File tree

3 files changed

+89
-19
lines changed

3 files changed

+89
-19
lines changed

Diff for: spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java

+24-16
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
* @author Rod Johnson
103103
* @author Juergen Hoeller
104104
* @author Thomas Risberg
105+
* @author Yanming Zhou
105106
* @since May 3, 2001
106107
* @see JdbcOperations
107108
* @see PreparedStatementCreator
@@ -473,12 +474,12 @@ public String getSql() {
473474

474475
@Override
475476
public void query(String sql, RowCallbackHandler rch) throws DataAccessException {
476-
query(sql, new RowCallbackHandlerResultSetExtractor(rch));
477+
query(sql, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
477478
}
478479

479480
@Override
480481
public <T> List<T> query(String sql, RowMapper<T> rowMapper) throws DataAccessException {
481-
return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper)));
482+
return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
482483
}
483484

484485
@Override
@@ -488,7 +489,7 @@ class StreamStatementCallback implements StatementCallback<Stream<T>>, SqlProvid
488489
public Stream<T> doInStatement(Statement stmt) throws SQLException {
489490
ResultSet rs = stmt.executeQuery(sql);
490491
Connection con = stmt.getConnection();
491-
return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> {
492+
return new ResultSetSpliterator<>(rs, rowMapper, JdbcTemplate.this.maxRows).stream().onClose(() -> {
492493
JdbcUtils.closeResultSet(rs);
493494
JdbcUtils.closeStatement(stmt);
494495
DataSourceUtils.releaseConnection(con, getDataSource());
@@ -756,12 +757,12 @@ private String appendSql(@Nullable String sql, String statement) {
756757

757758
@Override
758759
public void query(PreparedStatementCreator psc, RowCallbackHandler rch) throws DataAccessException {
759-
query(psc, new RowCallbackHandlerResultSetExtractor(rch));
760+
query(psc, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
760761
}
761762

762763
@Override
763764
public void query(String sql, @Nullable PreparedStatementSetter pss, RowCallbackHandler rch) throws DataAccessException {
764-
query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch));
765+
query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
765766
}
766767

767768
@Override
@@ -782,28 +783,28 @@ public void query(String sql, RowCallbackHandler rch, @Nullable Object @Nullable
782783

783784
@Override
784785
public <T> List<T> query(PreparedStatementCreator psc, RowMapper<T> rowMapper) throws DataAccessException {
785-
return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper)));
786+
return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
786787
}
787788

788789
@Override
789790
public <T> List<T> query(String sql, @Nullable PreparedStatementSetter pss, RowMapper<T> rowMapper) throws DataAccessException {
790-
return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper)));
791+
return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
791792
}
792793

793794
@Override
794795
public <T> List<T> query(String sql, @Nullable Object @Nullable [] args, int[] argTypes, RowMapper<T> rowMapper) throws DataAccessException {
795-
return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper)));
796+
return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
796797
}
797798

798799
@Deprecated
799800
@Override
800801
public <T> List<T> query(String sql, @Nullable Object @Nullable [] args, RowMapper<T> rowMapper) throws DataAccessException {
801-
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper)));
802+
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
802803
}
803804

804805
@Override
805806
public <T> List<T> query(String sql, RowMapper<T> rowMapper, @Nullable Object @Nullable ... args) throws DataAccessException {
806-
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper)));
807+
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
807808
}
808809

809810
/**
@@ -828,7 +829,7 @@ public <T> Stream<T> queryForStream(PreparedStatementCreator psc, @Nullable Prep
828829
}
829830
ResultSet rs = ps.executeQuery();
830831
Connection con = ps.getConnection();
831-
return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> {
832+
return new ResultSetSpliterator<>(rs, rowMapper, this.maxRows).stream().onClose(() -> {
832833
JdbcUtils.closeResultSet(rs);
833834
if (pss instanceof ParameterDisposer parameterDisposer) {
834835
parameterDisposer.cleanupParameters();
@@ -1347,7 +1348,7 @@ protected Map<String, Object> processResultSet(
13471348
}
13481349
else if (param.getRowCallbackHandler() != null) {
13491350
RowCallbackHandler rch = param.getRowCallbackHandler();
1350-
(new RowCallbackHandlerResultSetExtractor(rch)).extractData(rs);
1351+
(new RowCallbackHandlerResultSetExtractor(rch, -1)).extractData(rs);
13511352
return Collections.singletonMap(param.getName(),
13521353
"ResultSet returned from stored procedure was processed");
13531354
}
@@ -1730,13 +1731,17 @@ private static class RowCallbackHandlerResultSetExtractor implements ResultSetEx
17301731

17311732
private final RowCallbackHandler rch;
17321733

1733-
public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch) {
1734+
private final int maxRows;
1735+
1736+
public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch, int maxRows) {
17341737
this.rch = rch;
1738+
this.maxRows = maxRows;
17351739
}
17361740

17371741
@Override
17381742
public @Nullable Object extractData(ResultSet rs) throws SQLException {
1739-
while (rs.next()) {
1743+
int processed = 0;
1744+
while (rs.next() && (this.maxRows == -1 || (processed++) < this.maxRows)) {
17401745
this.rch.processRow(rs);
17411746
}
17421747
return null;
@@ -1754,17 +1759,20 @@ private static class ResultSetSpliterator<T> implements Spliterator<T> {
17541759

17551760
private final RowMapper<T> rowMapper;
17561761

1762+
private final int maxRows;
1763+
17571764
private int rowNum = 0;
17581765

1759-
public ResultSetSpliterator(ResultSet rs, RowMapper<T> rowMapper) {
1766+
public ResultSetSpliterator(ResultSet rs, RowMapper<T> rowMapper, int maxRows) {
17601767
this.rs = rs;
17611768
this.rowMapper = rowMapper;
1769+
this.maxRows = maxRows;
17621770
}
17631771

17641772
@Override
17651773
public boolean tryAdvance(Consumer<? super T> action) {
17661774
try {
1767-
if (this.rs.next()) {
1775+
if (this.rs.next() && (this.maxRows == -1 || this.rowNum < this.maxRows)) {
17681776
action.accept(this.rowMapper.mapRow(this.rs, this.rowNum++));
17691777
return true;
17701778
}

Diff for: spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -52,6 +52,7 @@
5252
* you can have executable query objects (containing row-mapping logic) there.
5353
*
5454
* @author Juergen Hoeller
55+
* @author Yanming Zhou
5556
* @since 1.0.2
5657
* @param <T> the result element type
5758
* @see RowMapper
@@ -64,6 +65,8 @@ public class RowMapperResultSetExtractor<T> implements ResultSetExtractor<List<T
6465

6566
private final int rowsExpected;
6667

68+
private final int maxRows;
69+
6770

6871
/**
6972
* Create a new RowMapperResultSetExtractor.
@@ -80,17 +83,29 @@ public RowMapperResultSetExtractor(RowMapper<T> rowMapper) {
8083
* (just used for optimized collection handling)
8184
*/
8285
public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected) {
86+
this(rowMapper, rowsExpected, -1);
87+
}
88+
89+
/**
90+
* Create a new RowMapperResultSetExtractor.
91+
* @param rowMapper the RowMapper which creates an object for each row
92+
* @param rowsExpected the number of expected rows
93+
* (just used for optimized collection handling)
94+
* @param maxRows the number of max rows
95+
*/
96+
public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected, int maxRows) {
8397
Assert.notNull(rowMapper, "RowMapper must not be null");
8498
this.rowMapper = rowMapper;
8599
this.rowsExpected = rowsExpected;
100+
this.maxRows = maxRows;
86101
}
87102

88103

89104
@Override
90105
public List<T> extractData(ResultSet rs) throws SQLException {
91106
List<T> results = (this.rowsExpected > 0 ? new ArrayList<>(this.rowsExpected) : new ArrayList<>());
92107
int rowNum = 0;
93-
while (rs.next()) {
108+
while (rs.next() && (this.maxRows == -1 || rowNum < this.maxRows)) {
94109
results.add(this.rowMapper.mapRow(rs, rowNum++));
95110
}
96111
return results;

Diff for: spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java

+48-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,7 +32,9 @@
3232
import java.util.Collections;
3333
import java.util.List;
3434
import java.util.Map;
35+
import java.util.function.BiFunction;
3536
import java.util.function.Consumer;
37+
import java.util.stream.Stream;
3638

3739
import javax.sql.DataSource;
3840

@@ -77,6 +79,7 @@
7779
* @author Thomas Risberg
7880
* @author Juergen Hoeller
7981
* @author Phillip Webb
82+
* @author Yanming Zhou
8083
*/
8184
class JdbcTemplateTests {
8285

@@ -1236,6 +1239,50 @@ public int getBatchSize() {
12361239
Collections.singletonMap("someId", 456));
12371240
}
12381241

1242+
@Test
1243+
void testSkipFurtherRowsOnceMaxRowsHasBeenReachedForRowMapper() throws Exception {
1244+
testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) ->
1245+
template.query(sql, (rs, rowNum) -> rs.getString(1)));
1246+
}
1247+
1248+
@Test
1249+
void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForRowCallbackHandler() throws Exception {
1250+
testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> {
1251+
List<String> list = new ArrayList<>();
1252+
template.query(sql, (RowCallbackHandler) rs -> list.add(rs.getString(1)));
1253+
return list;
1254+
});
1255+
}
1256+
1257+
@Test
1258+
void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForStream() throws Exception {
1259+
testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> {
1260+
try (Stream<String> stream = template.queryForStream(sql, (rs, rowNum) -> rs.getString(1))) {
1261+
return stream.toList();
1262+
}
1263+
});
1264+
}
1265+
1266+
private void testDiscardFurtherRowsOnceMaxRowsHasBeenReached(BiFunction<JdbcTemplate,String,List<String>> function) throws Exception {
1267+
String sql = "SELECT FORENAME FROM CUSTMR";
1268+
String[] results = {"rod", "gary", " portia"};
1269+
int maxRows = 2;
1270+
1271+
given(this.resultSet.next()).willReturn(true, true, true, false);
1272+
given(this.resultSet.getString(1)).willReturn(results[0], results[1], results[2]);
1273+
given(this.connection.createStatement()).willReturn(this.preparedStatement);
1274+
1275+
JdbcTemplate template = new JdbcTemplate();
1276+
template.setDataSource(this.dataSource);
1277+
template.setMaxRows(maxRows);
1278+
1279+
assertThat(function.apply(template, sql)).as("same length").hasSize(maxRows);
1280+
1281+
verify(this.resultSet).close();
1282+
verify(this.preparedStatement).close();
1283+
verify(this.connection).close();
1284+
}
1285+
12391286
private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
12401287
DatabaseMetaData databaseMetaData = mock();
12411288
given(databaseMetaData.getDatabaseProductName()).willReturn("MySQL");

0 commit comments

Comments
 (0)