diff --git a/cassandra/src/main/java/org/apache/calcite/adapter/cassandra/CassandraRules.java b/cassandra/src/main/java/org/apache/calcite/adapter/cassandra/CassandraRules.java
index db5d83960d0c..71f7592191b4 100644
--- a/cassandra/src/main/java/org/apache/calcite/adapter/cassandra/CassandraRules.java
+++ b/cassandra/src/main/java/org/apache/calcite/adapter/cassandra/CassandraRules.java
@@ -35,7 +35,9 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
@@ -405,16 +407,33 @@ protected CassandraLimitRule(CassandraLimitRuleConfig config) {
super(config);
}
- public RelNode convert(EnumerableLimit limit) {
+ public @Nullable RelNode convert(EnumerableLimit limit) {
+ final RexLiteral fetch =
+ limit.fetch == null
+ ? null
+ : RexUtil.reduceFetchToLiteral(limit.getCluster(), limit.fetch);
+ if (limit.fetch != null && fetch == null) {
+ return null;
+ }
+ if (fetch != null) {
+ try {
+ RexLiteral.bigDecimalValue(fetch).intValueExact();
+ } catch (ArithmeticException e) {
+ return null;
+ }
+ }
final RelTraitSet traitSet =
limit.getTraitSet().replace(CassandraRel.CONVENTION);
return new CassandraLimit(limit.getCluster(), traitSet,
- convert(limit.getInput(), CassandraRel.CONVENTION), limit.offset, limit.fetch);
+ convert(limit.getInput(), CassandraRel.CONVENTION), limit.offset, fetch);
}
@Override public void onMatch(RelOptRuleCall call) {
EnumerableLimit limit = call.rel(0);
- call.transformTo(convert(limit));
+ final RelNode converted = convert(limit);
+ if (converted != null) {
+ call.transformTo(converted);
+ }
}
/** Deprecated in favor of CassandraLimitRuleConfig. */
diff --git a/cassandra/src/test/java/org/apache/calcite/test/CassandraAdapterTest.java b/cassandra/src/test/java/org/apache/calcite/test/CassandraAdapterTest.java
index 411ffdbb65be..3d36224e6243 100644
--- a/cassandra/src/test/java/org/apache/calcite/test/CassandraAdapterTest.java
+++ b/cassandra/src/test/java/org/apache/calcite/test/CassandraAdapterTest.java
@@ -127,6 +127,34 @@ static void load(CqlSession session) {
.explainContains("CassandraLimit(fetch=[8])\n");
}
+ @Test void testFetchExpression() {
+ CalciteAssert.that()
+ .with(TWISSANDRA)
+ .query("select \"tweet_id\" from \"userline\" "
+ + "where \"username\" = '!PUBLIC!' "
+ + "fetch next (1 + abs(-2)) rows only")
+ .returnsCount(3)
+ .explainContains("CassandraLimit(fetch=[3])\n");
+ CalciteAssert.that()
+ .with(TWISSANDRA)
+ .query("select \"tweet_id\" from \"userline\" "
+ + "where \"username\" = '!PUBLIC!' "
+ + "fetch next (0 - 1) rows only")
+ .throws_("FETCH value -1 is out of range");
+ }
+
+ @Test void testFetchExpressionBeyondIntegerRange() {
+ CalciteAssert.that()
+ .with(TWISSANDRA)
+ .query("select \"tweet_id\" from \"userline\" "
+ + "where \"username\" = '!PUBLIC!' "
+ + "fetch next "
+ + "(cast(3000000000 as bigint) + 1) rows only")
+ .returnsCount(146)
+ .explainContains("EnumerableLimit(fetch=[3000000001:BIGINT])\n"
+ + " CassandraToEnumerableConverter\n");
+ }
+
@Test void testSortLimit() {
CalciteAssert.that()
.with(TWISSANDRA)
diff --git a/core/src/main/codegen/templates/Parser.jj b/core/src/main/codegen/templates/Parser.jj
index 2f784b34125b..09778ba05062 100644
--- a/core/src/main/codegen/templates/Parser.jj
+++ b/core/src/main/codegen/templates/Parser.jj
@@ -690,7 +690,7 @@ SqlNode ExprOrJoinOrOrderedQuery(ExprContext exprContext) :
*
*
* [ OFFSET start { ROW | ROWS } ]
- * [ FETCH { FIRST | NEXT } [ count ] { ROW | ROWS } ONLY ]
+ * [ FETCH { FIRST | NEXT } [ count | (expression) ] { ROW | ROWS } ONLY ]
*
*/
SqlNode OrderedQueryOrExpr(ExprContext exprContext) :
@@ -777,10 +777,26 @@ void FetchClause(SqlNode[] offsetFetch) :
{
// SQL:2008-style syntax. "OFFSET ... FETCH ...".
// If you specify both LIMIT and FETCH, FETCH wins.
- ( | ) offsetFetch[1] = UnsignedNumericLiteralOrParam()
+ ( | ) offsetFetch[1] = FetchCount()
( | )
}
+/**
+ * Parses the row count of a FETCH clause. Expressions must be parenthesized.
+ */
+SqlNode FetchCount() :
+{
+ final SqlNode e;
+}
+{
+ (
+ e = UnsignedNumericLiteralOrParam()
+ |
+ e = Expression(ExprContext.ACCEPT_NON_QUERY)
+ )
+ { return e; }
+}
+
/**
* Parses a LIMIT clause in an ORDER BY expression.
*/
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimit.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimit.java
index 3c046de9228d..55c22ae0560b 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimit.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimit.java
@@ -17,6 +17,11 @@
package org.apache.calcite.adapter.enumerable;
import org.apache.calcite.DataContext;
+import org.apache.calcite.linq4j.AbstractEnumerable;
+import org.apache.calcite.linq4j.Enumerable;
+import org.apache.calcite.linq4j.EnumerableDefaults;
+import org.apache.calcite.linq4j.Enumerator;
+import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
@@ -33,10 +38,14 @@
import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.BuiltInMethod;
import org.checkerframework.checker.nullness.qual.Nullable;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.Comparator;
import java.util.List;
/** Relational expression that applies a limit and/or offset to its input. */
@@ -54,6 +63,7 @@ public EnumerableLimit(
@Nullable RexNode offset,
@Nullable RexNode fetch) {
super(cluster, traitSet, input);
+ validateLiteralFetch(fetch);
this.offset = offset;
this.fetch = fetch;
assert getConvention() instanceof EnumerableConvention;
@@ -110,8 +120,8 @@ public static EnumerableLimit create(final RelNode input, @Nullable RexNode offs
if (fetch != null) {
v =
builder.append("fetch",
- Expressions.call(v, BuiltInMethod.TAKE.method,
- getExpression(fetch)));
+ Expressions.call(EnumerableLimit.class, "take", v,
+ getExpressionForFetch(fetch, implementor, builder)));
}
builder.add(Expressions.return_(null, v));
@@ -127,10 +137,105 @@ static Expression getExpression(RexNode rexNode) {
Expressions.constant("?" + param.getIndex())),
Integer.class);
} else {
- // TODO: Enumerable runtime only supports INT types for FETCH and OFFSET, not BIGINT types.
- // Currently, using BIGINT types for execution will result in an error message.
+ // TODO: Enumerable runtime only supports INT types for OFFSET, not BIGINT types.
+ // Currently, using BIGINT types for OFFSET will result in an error message.
// This issue needs to be fixed. For more information, see CALCITE-7156.
return Expressions.constant(RexLiteral.intValue(rexNode));
}
}
+
+ static Expression getExpressionForFetch(RexNode rexNode,
+ EnumerableRelImplementor implementor, BlockBuilder builder) {
+ if (rexNode instanceof RexDynamicParam) {
+ final RexDynamicParam param = (RexDynamicParam) rexNode;
+ return Expressions.call(EnumerableLimit.class, "toFetchValue",
+ Expressions.convert_(
+ Expressions.call(DataContext.ROOT,
+ BuiltInMethod.DATA_CONTEXT_GET.method,
+ Expressions.constant("?" + param.getIndex())),
+ Number.class));
+ } else if (rexNode instanceof RexLiteral) {
+ return Expressions.constant(
+ toFetchValue(((RexLiteral) rexNode).getValueAs(Number.class)));
+ } else {
+ final Expression expression =
+ RexToLixTranslator.forAggregation(implementor.getTypeFactory(),
+ builder, null, implementor.getConformance())
+ .translate(rexNode);
+ return Expressions.call(EnumerableLimit.class, "toFetchValue",
+ Expressions.convert_(Expressions.box(expression), Number.class));
+ }
+ }
+
+ /** Converts a FETCH expression result to Calcite's canonical representation. */
+ public static BigDecimal toFetchValue(@Nullable Number value) {
+ return RexUtil.validateFetchValue(value);
+ }
+
+ /** Applies a FETCH value without narrowing it to {@code int} or {@code long}. */
+ public static Enumerable take(Enumerable source, BigDecimal fetch) {
+ final BigInteger count = fetch.toBigIntegerExact();
+ return new AbstractEnumerable() {
+ @Override public Enumerator enumerator() {
+ final Enumerator input = source.enumerator();
+ return new Enumerator() {
+ private BigInteger remaining = count;
+ private boolean done;
+
+ @Override public T current() {
+ return input.current();
+ }
+
+ @Override public boolean moveNext() {
+ if (done) {
+ return false;
+ }
+ if (remaining.signum() == 0 || !input.moveNext()) {
+ done = true;
+ return false;
+ }
+ // Preserve take(int)'s eager evaluation of the current row.
+ input.current();
+ remaining = remaining.subtract(BigInteger.ONE);
+ return true;
+ }
+
+ @Override public void reset() {
+ input.reset();
+ remaining = count;
+ done = false;
+ }
+
+ @Override public void close() {
+ input.close();
+ }
+ };
+ }
+ };
+ }
+
+ /** Sorts and applies FETCH while preserving an arbitrary-precision value. */
+ public static Enumerable orderBy(Enumerable source,
+ Function1 keySelector, @Nullable Comparator comparator,
+ int offset, @Nullable BigDecimal fetch) {
+ if (fetch != null
+ && fetch.compareTo(BigDecimal.valueOf(Integer.MAX_VALUE)) <= 0
+ && comparator != null) {
+ return EnumerableDefaults.orderBy(source, keySelector, comparator,
+ offset, fetch.intValueExact());
+ }
+ Enumerable result = EnumerableDefaults.orderBy(source, keySelector, comparator);
+ if (offset > 0) {
+ result = result.skip(offset);
+ }
+ return fetch == null ? result : take(result, fetch);
+ }
+
+ static void validateLiteralFetch(@Nullable RexNode fetch) {
+ if (fetch instanceof RexLiteral) {
+ final Number value = ((RexLiteral) fetch).getValueAs(Number.class);
+ toFetchValue(value);
+ }
+ }
+
}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimitSort.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimitSort.java
index defc28f69a53..f48389f28c8e 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimitSort.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableLimitSort.java
@@ -25,12 +25,14 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Pair;
import org.checkerframework.checker.nullness.qual.Nullable;
+import java.math.BigDecimal;
+
import static org.apache.calcite.adapter.enumerable.EnumerableLimit.getExpression;
+import static org.apache.calcite.adapter.enumerable.EnumerableLimit.getExpressionForFetch;
/**
* Implementation of {@link org.apache.calcite.rel.core.Sort} in
@@ -52,6 +54,7 @@ public EnumerableLimitSort(
@Nullable RexNode offset,
@Nullable RexNode fetch) {
super(cluster, traitSet, input, collation, offset, fetch);
+ EnumerableLimit.validateLiteralFetch(fetch);
assert this.getConvention() instanceof EnumerableConvention;
assert this.getConvention() == input.getConvention();
}
@@ -98,9 +101,9 @@ public static EnumerableLimitSort create(
final Expression fetchVal;
if (this.fetch == null) {
- fetchVal = Expressions.constant(Integer.MAX_VALUE);
+ fetchVal = Expressions.constant(null, BigDecimal.class);
} else {
- fetchVal = getExpression(this.fetch);
+ fetchVal = getExpressionForFetch(this.fetch, implementor, builder);
}
final Expression offsetVal;
@@ -112,17 +115,13 @@ public static EnumerableLimitSort create(
builder.add(
Expressions.return_(null,
- Expressions.call(BuiltInMethod.ORDER_BY_WITH_FETCH_AND_OFFSET.method,
+ Expressions.call(EnumerableLimit.class, "orderBy",
Expressions.list(childExp,
builder.append("keySelector", pair.left))
.appendIfNotNull(
builder.appendIfNotNull("comparator", pair.right))
- .appendIfNotNull(
- builder.appendIfNotNull("offset",
- Expressions.constant(offsetVal)))
- .appendIfNotNull(
- builder.appendIfNotNull("fetch",
- Expressions.constant(fetchVal))))));
+ .append(builder.append("offset", offsetVal))
+ .append(builder.append("fetch", fetchVal)))));
return implementor.result(physType, builder.toBlock());
}
}
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java
index 7d47e639b78e..f0b574a43d61 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java
@@ -29,6 +29,7 @@
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
@@ -88,9 +89,13 @@ public EnumerableMergeUnionRule(Config config) {
// Push down sort limit, if possible.
RexNode inputFetch = null;
if (sort.fetch != null) {
- if (sort.offset == null) {
+ final boolean safeToRepeat =
+ RexUtil.isDeterministic(sort.fetch);
+ if (sort.offset == null && safeToRepeat) {
inputFetch = sort.fetch;
- } else if (sort.fetch instanceof RexLiteral && sort.offset instanceof RexLiteral) {
+ } else if (safeToRepeat
+ && sort.fetch instanceof RexLiteral
+ && sort.offset instanceof RexLiteral) {
inputFetch =
call.builder().literal(RexLiteral.bigDecimalValue(sort.fetch)
.add(RexLiteral.bigDecimalValue(sort.offset)));
diff --git a/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java b/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
index 9a35cf7cce2d..1b080506677e 100644
--- a/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
+++ b/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
@@ -54,6 +54,7 @@
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
@@ -767,7 +768,18 @@ protected JdbcSortRule(Config config) {
* JDBC convention
* @return A new JdbcSort
*/
- public RelNode convert(Sort sort, boolean convertInputTraits) {
+ public @Nullable RelNode convert(Sort sort, boolean convertInputTraits) {
+ final RexNode fetch;
+ if (sort.fetch == null
+ || sort.fetch instanceof RexLiteral
+ || sort.fetch instanceof RexDynamicParam) {
+ fetch = sort.fetch;
+ } else {
+ fetch = RexUtil.reduceFetchToLiteral(sort.getCluster(), sort.fetch);
+ if (fetch == null) {
+ return null;
+ }
+ }
final RelTraitSet traitSet = sort.getTraitSet().replace(out);
final RelNode input;
@@ -779,7 +791,7 @@ public RelNode convert(Sort sort, boolean convertInputTraits) {
}
return new JdbcSort(sort.getCluster(), traitSet,
- input, sort.getCollation(), sort.offset, sort.fetch);
+ input, sort.getCollation(), sort.offset, fetch);
}
}
diff --git a/core/src/main/java/org/apache/calcite/interpreter/SortNode.java b/core/src/main/java/org/apache/calcite/interpreter/SortNode.java
index 71d9f2b22e42..1425689d45ce 100644
--- a/core/src/main/java/org/apache/calcite/interpreter/SortNode.java
+++ b/core/src/main/java/org/apache/calcite/interpreter/SortNode.java
@@ -16,14 +16,19 @@
*/
package org.apache.calcite.interpreter;
+import org.apache.calcite.adapter.enumerable.EnumerableLimit;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.Util;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
@@ -35,8 +40,18 @@
* {@link org.apache.calcite.rel.core.Sort}.
*/
public class SortNode extends AbstractSingleNode {
+ private final @Nullable Scalar fetchScalar;
+ private final @Nullable Context fetchContext;
+
public SortNode(Compiler compiler, Sort rel) {
super(compiler, rel);
+ if (rel.fetch != null && !(rel.fetch instanceof RexLiteral)) {
+ this.fetchScalar = compiler.compile(ImmutableList.of(rel.fetch), null);
+ this.fetchContext = compiler.createContext();
+ } else {
+ this.fetchScalar = null;
+ this.fetchContext = null;
+ }
}
private static int getValueAsInt(RexNode node) {
@@ -44,15 +59,31 @@ private static int getValueAsInt(RexNode node) {
() -> "getValueAs(Integer.class) for " + node);
}
+ private @Nullable BigDecimal getFetch() {
+ if (rel.fetch == null) {
+ return null;
+ }
+ final @Nullable Number value;
+ if (rel.fetch instanceof RexLiteral) {
+ value = ((RexLiteral) rel.fetch).getValueAs(Number.class);
+ } else {
+ final Object result =
+ requireNonNull(fetchScalar, "fetchScalar")
+ .execute(requireNonNull(fetchContext, "fetchContext"));
+ if (result != null && !(result instanceof Number)) {
+ throw new IllegalArgumentException("FETCH value is not numeric: " + result);
+ }
+ value = (Number) result;
+ }
+ return EnumerableLimit.toFetchValue(value);
+ }
+
@Override public void run() throws InterruptedException {
final int offset =
rel.offset == null
? 0
: getValueAsInt(rel.offset);
- final int fetch =
- rel.fetch == null
- ? -1
- : getValueAsInt(rel.fetch);
+ final @Nullable BigDecimal fetch = getFetch();
// In pure limit mode. No sort required.
Row row;
loop:
@@ -63,9 +94,12 @@ private static int getValueAsInt(RexNode node) {
break loop;
}
}
- if (fetch >= 0) {
- for (int i = 0; i < fetch && (row = source.receive()) != null; i++) {
+ if (fetch != null) {
+ BigDecimal fetched = BigDecimal.ZERO;
+ while (fetched.compareTo(fetch) < 0
+ && (row = source.receive()) != null) {
sink.send(row);
+ fetched = fetched.add(BigDecimal.ONE);
}
} else {
while ((row = source.receive()) != null) {
@@ -79,9 +113,11 @@ private static int getValueAsInt(RexNode node) {
list.add(row);
}
list.sort(comparator());
- final int end = fetch < 0 || offset + fetch > list.size()
+ final int available = Math.max(list.size() - offset, 0);
+ final int end = fetch == null
+ || fetch.compareTo(BigDecimal.valueOf(available)) >= 0
? list.size()
- : offset + fetch;
+ : offset + fetch.intValueExact();
for (int i = offset; i < end; i++) {
sink.send(list.get(i));
}
diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
index 1f6502243626..ffd9fb8e7c0f 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
@@ -25,6 +25,7 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
@@ -56,6 +57,7 @@
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
+import java.util.Objects;
import java.util.Set;
import static com.google.common.base.Preconditions.checkArgument;
@@ -1043,8 +1045,16 @@ private static boolean alreadySmaller(RelMetadataQuery mq, RelNode input,
if (fetch == null) {
return true;
}
+ final RelNode strippedInput = input.stripped();
+ if (strippedInput instanceof Sort) {
+ final Sort sort = (Sort) strippedInput;
+ if (Objects.equals(offset, sort.offset)
+ && Objects.equals(fetch, sort.fetch)) {
+ return true;
+ }
+ }
final Double rowCount = mq.getMaxRowCount(input);
- if (rowCount == null || offset instanceof RexDynamicParam || fetch instanceof RexDynamicParam) {
+ if (rowCount == null || offset instanceof RexDynamicParam || !(fetch instanceof RexLiteral)) {
// Cannot be determined
return false;
}
diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
index 0404fa7bc51c..12cbd7831200 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java
@@ -54,11 +54,14 @@
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexDynamicParam;
+import org.apache.calcite.rex.RexExecutor;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlAsofJoin;
@@ -1191,7 +1194,7 @@ public Result visit(Sort e) {
sqlSelect.setOffset(offset);
}
if (e.fetch != null) {
- SqlNode fetch = builder.context.toSql(null, e.fetch);
+ SqlNode fetch = toSqlFetch(e, builder.context);
sqlSelect.setFetch(fetch);
}
return result(sqlSelect, ImmutableList.of(Clause.ORDER_BY), e, null);
@@ -1249,13 +1252,32 @@ public Result visit(Sort e) {
* The builder must have been created with OFFSET and FETCH clauses. */
void offsetFetch(Sort e, Builder builder) {
if (e.fetch != null) {
- builder.setFetch(builder.context.toSql(null, e.fetch));
+ builder.setFetch(toSqlFetch(e, builder.context));
}
if (e.offset != null) {
builder.setOffset(builder.context.toSql(null, e.offset));
}
}
+ private static SqlNode toSqlFetch(Sort sort, Context context) {
+ final RexNode fetch = requireNonNull(sort.fetch, "fetch");
+ if (fetch instanceof RexLiteral
+ || fetch instanceof RexDynamicParam
+ || !RexUtil.isConstant(fetch)
+ || !RexUtil.isDeterministic(fetch)
+ || RexUtil.containsDynamicFunction(fetch)
+ || RexUtil.containsDynamicParam(fetch)) {
+ return context.toSql(null, fetch);
+ }
+ final RexExecutor executor =
+ Util.first(sort.getCluster().getPlanner().getExecutor(), RexUtil.EXECUTOR);
+ final List reducedValues = new ArrayList<>(1);
+ executor.reduce(sort.getCluster().getRexBuilder(),
+ Collections.singletonList(fetch), reducedValues);
+ final RexNode reduced = reducedValues.get(0);
+ return context.toSql(null, reduced instanceof RexLiteral ? reduced : fetch);
+ }
+
public boolean hasTrickyRollup(Sort e, Aggregate aggregate) {
return !dialect.supportsAggregateFunction(SqlKind.ROLLUP)
&& dialect.supportsGroupByWithRollup()
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/MeasureRules.java b/core/src/main/java/org/apache/calcite/rel/rules/MeasureRules.java
index 037a4d605459..f69810a14d42 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/MeasureRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/MeasureRules.java
@@ -30,7 +30,6 @@
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
@@ -508,8 +507,7 @@ protected ProjectSortMeasureRule(ProjectSortMeasureRuleConfig config) {
relBuilder.push(sort.getInput())
.projectPlus(map.keySet())
- .sortLimit(sort.offset == null ? 0 : RexLiteral.numberValue(sort.offset),
- sort.fetch == null ? -1 : RexLiteral.numberValue(sort.fetch),
+ .sortLimit(sort.offset, sort.fetch,
sort.getSortExps())
.project(newProjects);
call.transformTo(relBuilder.build());
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/PruneEmptyRules.java b/core/src/main/java/org/apache/calcite/rel/rules/PruneEmptyRules.java
index 5d70c5e0dec4..5cb60bc90652 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/PruneEmptyRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/PruneEmptyRules.java
@@ -41,7 +41,6 @@
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
@@ -519,9 +518,8 @@ public interface SortFetchZeroRuleConfig extends PruneEmptyRule.Config {
return new RemoveEmptySingleRule(this) {
@Override public boolean matches(final RelOptRuleCall call) {
Sort sort = call.rel(0);
- return sort.fetch != null
- && !(sort.fetch instanceof RexDynamicParam)
- && RexLiteral.bigDecimalValue(sort.fetch).equals(BigDecimal.ZERO);
+ return sort.fetch instanceof RexLiteral
+ && BigDecimal.ZERO.equals(RexLiteral.bigDecimalValue(sort.fetch));
}
};
}
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SortJoinTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SortJoinTransposeRule.java
index 4310d6d65576..df967e56aad6 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/SortJoinTransposeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/SortJoinTransposeRule.java
@@ -105,9 +105,9 @@ public SortJoinTransposeRule(Class extends Sort> sortClass,
final Sort sort = call.rel(0);
final Join join = call.rel(1);
- // Do nothing if SORT contains dynamic parameters in offset or fetch
+ // The pushed fetch is calculated from literal offset and fetch values.
if (sort.offset instanceof RexDynamicParam
- || sort.fetch instanceof RexDynamicParam) {
+ || sort.fetch != null && !(sort.fetch instanceof RexLiteral)) {
return false;
}
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SortRemoveRedundantRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SortRemoveRedundantRule.java
index 9bcf026fc656..08563cdcdb86 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/SortRemoveRedundantRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/SortRemoveRedundantRule.java
@@ -133,6 +133,9 @@ protected SortRemoveRedundantRule(final SortRemoveRedundantRule.Config config) {
private static Optional getRowCountThreshold(Sort sort) {
if (RelOptUtil.isLimit(sort)) {
assert sort.fetch != null;
+ if (!(sort.fetch instanceof RexLiteral)) {
+ return Optional.empty();
+ }
final BigDecimal fetch = RexLiteral.bigDecimalValue(sort.fetch);
// We don't need to deal with fetch is 0.
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SortUnionTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SortUnionTransposeRule.java
index 416825ee926d..93b6af657c43 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/SortUnionTransposeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/SortUnionTransposeRule.java
@@ -23,7 +23,7 @@
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
-import org.apache.calcite.rex.RexDynamicParam;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilderFactory;
import org.immutables.value.Value;
@@ -67,13 +67,14 @@ public SortUnionTransposeRule(
@Override public boolean matches(RelOptRuleCall call) {
final Sort sort = call.rel(0);
final Union union = call.rel(1);
- // We only apply this rule if Union.all is true, Sort.offset is null and Sort.fetch is not
- // a dynamic param.
+ // Re-evaluating a non-deterministic FETCH in every branch can produce a
+ // different limit from the top Sort.
// There is a flag indicating if this rule should be applied when
// Sort.fetch is null.
return union.all
&& sort.offset == null
- && !(sort.fetch instanceof RexDynamicParam)
+ && (sort.fetch == null
+ || RexUtil.isDeterministic(sort.fetch))
&& (config.matchNullFetch() || sort.fetch != null);
}
diff --git a/core/src/main/java/org/apache/calcite/rex/RexUtil.java b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
index 3604e98dfd5b..9c8e60c2cf3e 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexUtil.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
@@ -19,6 +19,7 @@
import org.apache.calcite.DataContexts;
import org.apache.calcite.linq4j.function.Predicate1;
import org.apache.calcite.plan.PlanTooComplexError;
+import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptPredicateList;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelCollation;
@@ -48,6 +49,7 @@
import org.apache.calcite.util.ControlFlowException;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.NumberUtil;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.RangeSets;
import org.apache.calcite.util.Sarg;
@@ -63,9 +65,11 @@
import org.apiguardian.api.API;
import org.checkerframework.checker.nullness.qual.Nullable;
+import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -840,6 +844,88 @@ public static boolean isDeterministic(RexNode e) {
}
}
+ /** Returns whether an expression contains a dynamic function. */
+ public static boolean containsDynamicFunction(RexNode e) {
+ try {
+ e.accept(
+ new RexVisitorImpl(true) {
+ @Override public Void visitCall(RexCall call) {
+ if (call.getOperator().isDynamicFunction()) {
+ throw Util.FoundOne.NULL;
+ }
+ return super.visitCall(call);
+ }
+ });
+ return false;
+ } catch (Util.FoundOne ex) {
+ Util.swallow(ex, null);
+ return true;
+ }
+ }
+
+ /** Returns whether an expression contains a dynamic parameter. */
+ public static boolean containsDynamicParam(RexNode e) {
+ try {
+ e.accept(
+ new RexVisitorImpl(true) {
+ @Override public Void visitDynamicParam(RexDynamicParam dynamicParam) {
+ throw Util.FoundOne.NULL;
+ }
+ });
+ return false;
+ } catch (Util.FoundOne ex) {
+ Util.swallow(ex, null);
+ return true;
+ }
+ }
+
+ /** Converts a FETCH expression result to its validated canonical representation. */
+ public static BigDecimal validateFetchValue(@Nullable Number value) {
+ if (value == null) {
+ throw new IllegalArgumentException("FETCH expression evaluated to NULL");
+ }
+ final BigDecimal decimal = NumberUtil.toBigDecimal(value);
+ try {
+ decimal.toBigIntegerExact();
+ } catch (ArithmeticException e) {
+ throw new IllegalArgumentException("FETCH value " + value
+ + " is not an integer", e);
+ }
+ if (decimal.signum() < 0) {
+ throw new IllegalArgumentException("FETCH value " + value
+ + " is out of range; expected a non-negative value");
+ }
+ return decimal;
+ }
+
+ /** Reduces a constant FETCH expression to a validated literal. */
+ public static @Nullable RexLiteral reduceFetchToLiteral(
+ RelOptCluster cluster, RexNode fetch) {
+ final RexLiteral literal;
+ if (fetch instanceof RexLiteral) {
+ literal = (RexLiteral) fetch;
+ } else {
+ if (!isConstant(fetch)
+ || !isDeterministic(fetch)
+ || containsDynamicFunction(fetch)
+ || containsDynamicParam(fetch)) {
+ return null;
+ }
+ final RexExecutor executor =
+ Util.first(cluster.getPlanner().getExecutor(), EXECUTOR);
+ final List reducedValues = new ArrayList<>(1);
+ executor.reduce(cluster.getRexBuilder(),
+ Collections.singletonList(fetch), reducedValues);
+ final RexNode reduced = reducedValues.get(0);
+ if (!(reduced instanceof RexLiteral)) {
+ return null;
+ }
+ literal = (RexLiteral) reduced;
+ }
+ validateFetchValue(literal.getValueAs(Number.class));
+ return literal;
+ }
+
public static List retainDeterministic(List list) {
List conjunctions = new ArrayList<>();
for (RexNode x : list) {
diff --git a/core/src/main/java/org/apache/calcite/runtime/CalciteResource.java b/core/src/main/java/org/apache/calcite/runtime/CalciteResource.java
index a9cb02d065cb..9d77b69cd9b3 100644
--- a/core/src/main/java/org/apache/calcite/runtime/CalciteResource.java
+++ b/core/src/main/java/org/apache/calcite/runtime/CalciteResource.java
@@ -164,6 +164,15 @@ ExInstWithCause validatorContext(int a0, int a1,
@BaseMessage("Values passed to {0} operator must have compatible types")
ExInst incompatibleValueType(String a0);
+ @BaseMessage("FETCH expression must have an integral numeric type; actual type is ''{0}''")
+ ExInst fetchExpressionMustBeIntegral(String type);
+
+ @BaseMessage("FETCH expression cannot reference table column ''{0}''")
+ ExInst fetchExpressionCannotReferenceColumn(String column);
+
+ @BaseMessage("FETCH expression evaluated to NULL")
+ ExInst fetchExpressionEvaluatedToNull();
+
@BaseMessage("Values in expression list must have compatible types")
ExInst incompatibleTypesInList();
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java
index 869976f0d42a..663aee63b0db 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java
@@ -1081,7 +1081,18 @@ protected static void unparseFetchUsingAnsi(SqlWriter writer, @Nullable SqlNode
writer.startList(SqlWriter.FrameTypeEnum.FETCH);
writer.keyword("FETCH");
writer.keyword("NEXT");
- fetch.unparse(writer, -1, -1);
+ if (fetch instanceof SqlLiteral
+ || fetch instanceof SqlDynamicParam) {
+ fetch.unparse(writer, -1, -1);
+ } else {
+ final SqlWriter.Frame expressionFrame = writer.startList("(", ")");
+ if (fetch instanceof SqlCall) {
+ writer.getDialect().unparseCall(writer, (SqlCall) fetch, 0, 0);
+ } else {
+ fetch.unparse(writer, 0, 0);
+ }
+ writer.endList(expressionFrame);
+ }
writer.keyword("ROWS");
writer.keyword("ONLY");
writer.endList(fetchFrame);
@@ -1091,13 +1102,32 @@ protected static void unparseFetchUsingAnsi(SqlWriter writer, @Nullable SqlNode
/** Unparses offset/fetch using "LIMIT fetch OFFSET offset" syntax. */
protected static void unparseFetchUsingLimit(SqlWriter writer, @Nullable SqlNode offset,
@Nullable SqlNode fetch) {
+ unparseFetchUsingLimit(writer, offset, fetch, false);
+ }
+
+ /** Unparses offset/fetch using "LIMIT fetch OFFSET offset" syntax,
+ * optionally allowing a scalar expression as fetch. */
+ protected static void unparseFetchUsingLimit(SqlWriter writer, @Nullable SqlNode offset,
+ @Nullable SqlNode fetch, boolean allowExpression) {
checkArgument(fetch != null || offset != null);
- unparseLimit(writer, fetch);
+ unparseLimit(writer, fetch, allowExpression);
unparseOffset(writer, offset);
}
protected static void unparseLimit(SqlWriter writer, @Nullable SqlNode fetch) {
+ unparseLimit(writer, fetch, false);
+ }
+
+ private static void unparseLimit(SqlWriter writer, @Nullable SqlNode fetch,
+ boolean allowExpression) {
if (fetch != null) {
+ if (!allowExpression
+ && !(fetch instanceof SqlLiteral)
+ && !(fetch instanceof SqlDynamicParam)) {
+ throw new IllegalArgumentException(
+ "LIMIT dialect does not support FETCH expressions that cannot "
+ + "be reduced to a literal");
+ }
writer.newlineAndIndent();
final SqlWriter.Frame fetchFrame =
writer.startList(SqlWriter.FrameTypeEnum.FETCH);
diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/SqliteSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/SqliteSqlDialect.java
index 82376ae576ab..f31276413600 100644
--- a/core/src/main/java/org/apache/calcite/sql/dialect/SqliteSqlDialect.java
+++ b/core/src/main/java/org/apache/calcite/sql/dialect/SqliteSqlDialect.java
@@ -90,7 +90,7 @@ public SqliteSqlDialect(SqlDialect.Context context) {
@Override public void unparseOffsetFetch(SqlWriter writer, @Nullable SqlNode offset,
@Nullable SqlNode fetch) {
- unparseFetchUsingLimit(writer, offset, fetch);
+ unparseFetchUsingLimit(writer, offset, fetch, true);
}
@Override public void unparseCall(SqlWriter writer, SqlCall call,
diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
index 370ef5ecc39a..c1143ff1c7c4 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
@@ -1745,6 +1745,35 @@ private void handleOffsetFetch(@Nullable SqlNode offset, @Nullable SqlNode fetch
}
}
+ private void validateFetchExpression(@Nullable SqlNode fetch) {
+ if (fetch == null || fetch instanceof SqlDynamicParam) {
+ return;
+ }
+ if (SqlUtil.isNullLiteral(fetch, true)) {
+ throw newValidationError(fetch,
+ RESOURCE.fetchExpressionEvaluatedToNull());
+ }
+ validateNoAggs(aggOrOverFinder, fetch, "FETCH");
+ fetch.accept(new SqlBasicVisitor() {
+ @Override public Void visit(SqlIdentifier id) {
+ if (makeNullaryCall(id) != null) {
+ return null;
+ }
+ throw newValidationError(id,
+ RESOURCE.fetchExpressionCannotReferenceColumn(id.toString()));
+ }
+ });
+ final SqlValidatorScope scope = getEmptyScope();
+ inferUnknownTypes(typeFactory.createSqlType(SqlTypeName.INTEGER), scope, fetch);
+ validateExpr(fetch, scope);
+ final RelDataType type = getValidatedNodeType(fetch);
+ if (!SqlTypeUtil.isIntType(type)
+ && !(SqlTypeUtil.isDecimal(type) && type.getScale() == 0)) {
+ throw newValidationError(fetch,
+ RESOURCE.fetchExpressionMustBeIntegral(type.getFullTypeString()));
+ }
+ }
+
/**
* Performs expression rewrites which are always used unconditionally. These
* rewrites massage the expression tree into a standard form so that the
@@ -4442,6 +4471,7 @@ protected void validateSelect(
validateWindowClause(select);
validateQualifyClause(select);
handleOffsetFetch(select.getOffset(), select.getFetch());
+ validateFetchExpression(select.getFetch());
// Validate the SELECT clause late, because a select item might
// depend on the GROUP BY list, or the window function might reference
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index 14d72f8dfe69..2b75f0cdb324 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -574,6 +574,10 @@ protected RexNode removeCorrelationExpr(
// Its output does not change the input ordering, so there's no
// need to call propagateExpr.
+ if (isCorVarDefined && !canDecorrelateFetch(rel)) {
+ return null;
+ }
+
final RelNode oldInput = rel.getInput();
final Frame frame = getInvoke(oldInput, isCorVarDefined, rel, true);
if (frame == null) {
@@ -1136,8 +1140,13 @@ private static void shiftMapping(Map mapping, int startIndex,
return register(sort, result, mapOldToNewOutputs, corDefOutputs);
}
+ static boolean canDecorrelateFetch(Sort sort) {
+ return sort.fetch == null
+ || RexUtil.reduceFetchToLiteral(sort.getCluster(), sort.fetch) != null;
+ }
+
protected @Nullable Frame decorrelateSortAsAggregate(Sort sort, final Frame frame) {
- if (sort.offset != null || sort.fetch == null) {
+ if (sort.offset != null || !(sort.fetch instanceof RexLiteral)) {
return null;
}
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/TopDownGeneralDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/TopDownGeneralDecorrelator.java
index c3d2bd924961..c80e91d6de57 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/TopDownGeneralDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/TopDownGeneralDecorrelator.java
@@ -616,6 +616,10 @@ public RelNode unnestInternal(Aggregate aggregate, boolean allowEmptyOutputFromR
}
public RelNode unnestInternal(Sort sort, boolean allowEmptyOutputFromRewrite) {
+ if (!RelDecorrelator.canDecorrelateFetch(sort)) {
+ throw new UnsupportedOperationException(
+ "Cannot decorrelate Sort with a runtime FETCH expression");
+ }
RelNode newInput = unnest(sort.getInput(), allowEmptyOutputFromRewrite);
UnnestedQuery inputInfo =
requireNonNull(mapRelToUnnestedQuery.get(sort.getInput()));
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index cf9ccecdefef..48d96afe1c10 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -79,13 +79,22 @@
import org.apache.calcite.rex.RexExecutor;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLambda;
+import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexNodeAndFieldIndex;
+import org.apache.calcite.rex.RexOver;
+import org.apache.calcite.rex.RexPatternFieldRef;
+import org.apache.calcite.rex.RexRangeRef;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSimplify;
import org.apache.calcite.rex.RexSubQuery;
+import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexUnknownAs;
import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.rex.RexWindowExclusion;
@@ -108,6 +117,7 @@
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.type.TableFunctionReturnTypeInference;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.sql2rel.SqlToRelConverter;
@@ -3797,8 +3807,7 @@ public RelBuilder sortLimit(Number offset, Number fetch,
*
* @param offsetNode RexLiteral means number of rows to skip is deterministic,
* RexDynamicParam means number of rows to skip is dynamic.
- * @param fetchNode RexLiteral means maximum number of rows to fetch is deterministic,
- * RexDynamicParam mean maximum number is dynamic.
+ * @param fetchNode Maximum number of rows to fetch
* @param nodes Sort expressions
*/
public RelBuilder sortLimit(@Nullable RexNode offsetNode, @Nullable RexNode fetchNode,
@@ -3808,12 +3817,19 @@ public RelBuilder sortLimit(@Nullable RexNode offsetNode, @Nullable RexNode fetc
throw new IllegalArgumentException("OFFSET node must be RexLiteral or RexDynamicParam");
}
}
- if (fetchNode != null) {
- if (!(fetchNode instanceof RexLiteral || fetchNode instanceof RexDynamicParam)) {
- throw new IllegalArgumentException("FETCH node must be RexLiteral or RexDynamicParam");
- }
+ if (fetchNode != null && isInvalidFetchExpression(fetchNode)) {
+ throw new IllegalArgumentException(
+ "FETCH node must not reference input fields or contain aggregate functions, "
+ + "window functions, or subqueries");
+ }
+ if (fetchNode != null
+ && !SqlTypeUtil.isIntType(fetchNode.getType())
+ && !(SqlTypeUtil.isDecimal(fetchNode.getType())
+ && fetchNode.getType().getScale() == 0)) {
+ throw new IllegalArgumentException(
+ "FETCH node must have an integral numeric type; actual type is "
+ + fetchNode.getType().getFullTypeString());
}
-
final Registrar registrar = new Registrar(fields(), ImmutableList.of());
final List fieldCollations =
registrar.registerFieldCollations(nodes);
@@ -3880,6 +3896,68 @@ public RelBuilder sortLimit(@Nullable RexNode offsetNode, @Nullable RexNode fetc
return this;
}
+ private static boolean isInvalidFetchExpression(RexNode node) {
+ try {
+ node.accept(
+ new RexVisitorImpl(true) {
+ @Override public Void visitCall(RexCall call) {
+ if (call.getOperator().isAggregator()) {
+ throw Util.FoundOne.NULL;
+ }
+ return super.visitCall(call);
+ }
+
+ @Override public Void visitInputRef(RexInputRef inputRef) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitLocalRef(RexLocalRef localRef) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitTableInputRef(RexTableInputRef ref) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitPatternFieldRef(RexPatternFieldRef fieldRef) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitCorrelVariable(RexCorrelVariable correlVariable) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitRangeRef(RexRangeRef rangeRef) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitOver(RexOver over) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitSubQuery(RexSubQuery subQuery) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitNodeAndFieldIndex(
+ RexNodeAndFieldIndex nodeAndFieldIndex) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitLambda(RexLambda lambda) {
+ throw Util.FoundOne.NULL;
+ }
+
+ @Override public Void visitLambdaRef(RexLambdaRef lambdaRef) {
+ throw Util.FoundOne.NULL;
+ }
+ });
+ return false;
+ } catch (Util.FoundOne e) {
+ return true;
+ }
+ }
+
private static RelFieldCollation collation(RexNode node,
RelFieldCollation.Direction direction,
RelFieldCollation.@Nullable NullDirection nullDirection,
diff --git a/core/src/main/resources/org/apache/calcite/runtime/CalciteResource.properties b/core/src/main/resources/org/apache/calcite/runtime/CalciteResource.properties
index 056aeb7b0715..6cacd96bb195 100644
--- a/core/src/main/resources/org/apache/calcite/runtime/CalciteResource.properties
+++ b/core/src/main/resources/org/apache/calcite/runtime/CalciteResource.properties
@@ -61,6 +61,9 @@ ValidatorContext=From line {0,number,#}, column {1,number,#} to line {2,number,#
CannotCastValue=Cast function cannot convert value of type {0} to type {1}
UnknownDatatypeName=Unknown datatype name ''{0}''
IncompatibleValueType=Values passed to {0} operator must have compatible types
+FetchExpressionMustBeIntegral=FETCH expression must have an integral numeric type; actual type is ''{0}''
+FetchExpressionCannotReferenceColumn=FETCH expression cannot reference table column ''{0}''
+FetchExpressionEvaluatedToNull=FETCH expression evaluated to NULL
IncompatibleTypesInList=Values in expression list must have compatible types
IncompatibleCharset=Cannot apply operation ''{0}'' to strings with different charsets ''{1}'' and ''{2}''
InvalidOrderByPos=ORDER BY is only allowed on top-level SELECT
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index 66afda8653b2..bb667b62c775 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -4831,6 +4831,48 @@ private SqlDialect nonOrdinalDialect() {
.withSybase().ok(expectedSybase);
}
+ @Test void testFetchExpressionWithLimitDialect() {
+ final String query = "select \"product_id\"\n"
+ + "from \"product\"\n"
+ + "fetch next (1 + 2) rows only";
+ final String expected = "SELECT `product_id`\n"
+ + "FROM `foodmart`.`product`\n"
+ + "LIMIT 3";
+ sql(query).withMysql().ok(expected);
+ }
+
+ @Test void testParameterizedFetchExpressionWithLimitDialect() {
+ final String query = "select \"product_id\"\n"
+ + "from \"product\"\n"
+ + "fetch next (? + 1) rows only";
+ sql(query).withMysql().throws_(
+ "LIMIT dialect does not support FETCH expressions that cannot "
+ + "be reduced to a literal");
+ }
+
+ @Test void testParameterizedFetchExpressionWithSQLite() {
+ final String query = "select \"product_id\"\n"
+ + "from \"product\"\n"
+ + "fetch next (? + 1) rows only";
+ final String expected = "SELECT \"product_id\"\n"
+ + "FROM \"foodmart\".\"product\"\n"
+ + "LIMIT ? + 1";
+ sql(query).withSQLite().ok(expected);
+ }
+
+ @Test void testDynamicFetchExpressionIsNotReduced() {
+ final String query = "select \"product_id\"\n"
+ + "from \"product\"\n"
+ + "fetch next (extract(day from current_date)) rows only";
+ final String expected = "SELECT \"product_id\"\n"
+ + "FROM \"foodmart\".\"product\"\n"
+ + "FETCH NEXT (EXTRACT(DAY FROM CURRENT_DATE)) ROWS ONLY";
+ sql(query).ok(expected);
+ sql(query).withMysql().throws_(
+ "LIMIT dialect does not support FETCH expressions that cannot "
+ + "be reduced to a literal");
+ }
+
@Test void testSelectQueryComplex() {
String query =
"select count(*), \"units_per_case\" from \"product\" where \"cases_per_pallet\" > 100 "
diff --git a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
index 2bdc7a1d56c4..004ac0514faa 100644
--- a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
+++ b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
@@ -3564,6 +3564,18 @@ private void assertTypeAndToString(
hasSize(0));
}
+ @Test void testContainsDynamicParam() {
+ final RelDataType intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
+ final RexNode literal = rexBuilder.makeExactLiteral(BigDecimal.ONE, intType);
+ final RexNode dynamicParam = rexBuilder.makeDynamicParam(intType, 0);
+ final RexNode expression =
+ rexBuilder.makeCall(SqlStdOperatorTable.PLUS, literal, dynamicParam);
+
+ assertThat(RexUtil.containsDynamicParam(literal), is(false));
+ assertThat(RexUtil.containsDynamicParam(dynamicParam), is(true));
+ assertThat(RexUtil.containsDynamicParam(expression), is(true));
+ }
+
@Test void testConstantMap() {
final RelDataType intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
final RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT);
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcAdapterTest.java b/core/src/test/java/org/apache/calcite/test/JdbcAdapterTest.java
index df12fbe8a8d1..942fabb670f9 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcAdapterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcAdapterTest.java
@@ -1386,6 +1386,60 @@ private LockWrapper exclusiveCleanDb(Connection c) throws SQLException {
.planHasSql("SELECT \"EMPNO\", \"ENAME\"\nFROM \"SCOTT\".\"EMP\"\nWHERE \"EMPNO\" = ?");
}
+ @Test void testFetchExpressionPushDown() {
+ CalciteAssert.model(JdbcTest.SCOTT_MODEL)
+ .query("select empno from scott.emp "
+ + "fetch next (1 + abs(-2)) rows only")
+ .explainContains("JdbcSort(fetch=[3])")
+ .returnsCount(3);
+ }
+
+ @Test void testParameterizedFetchExpressionIsNotPushedDown() {
+ CalciteAssert.model(JdbcTest.SCOTT_MODEL)
+ .query("select empno from scott.emp "
+ + "fetch next (? + 1) rows only")
+ .consumesPreparedStatement(p -> p.setInt(1, 2))
+ .explainContains("EnumerableLimit(fetch=[+(?0, 1)])\n"
+ + " JdbcToEnumerableConverter\n")
+ .returnsCount(3);
+ }
+
+ @Test void testDynamicFetchExpressionIsNotPushedDown() {
+ CalciteAssert.model(JdbcTest.SCOTT_MODEL)
+ .query("select empno from scott.emp "
+ + "fetch next (extract(day from current_date)) rows only")
+ .explainContains("EnumerableLimit(fetch=[EXTRACT(FLAG(DAY), CURRENT_DATE)])\n"
+ + " JdbcToEnumerableConverter\n");
+ }
+
+ @Test void testParameterizedFetchExpressionRepeatedExecution() throws Exception {
+ CalciteAssert.model(JdbcTest.SCOTT_MODEL)
+ .doWithConnection(connection -> {
+ final String sql = "select empno from scott.emp order by empno "
+ + "fetch next (? + 1) rows only";
+ try (PreparedStatement p = connection.prepareStatement(sql)) {
+ p.setInt(1, 0);
+ try (ResultSet resultSet = p.executeQuery()) {
+ assertThat(rowCount(resultSet), is(1));
+ }
+ p.setInt(1, 2);
+ try (ResultSet resultSet = p.executeQuery()) {
+ assertThat(rowCount(resultSet), is(3));
+ }
+ } catch (SQLException e) {
+ throw TestUtil.rethrow(e);
+ }
+ });
+ }
+
+ private static int rowCount(ResultSet resultSet) throws SQLException {
+ int count = 0;
+ while (resultSet.next()) {
+ count++;
+ }
+ return count;
+ }
+
/**
* Test case for
* [CALCITE-4619]
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index 78098912edfc..43afce9ade4a 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -3765,6 +3765,87 @@ public void checkOrderBy(final boolean desc,
+ "store_id=4; grocery_sqft=16844\n");
}
+ /** Tests FETCH with a parenthesized expression. */
+ @Test void testFetchExpression() {
+ CalciteAssert.that()
+ .query("select * from (values (1), (2), (3), (4)) as t(x)\n"
+ + "fetch next (1 + abs(-2)) rows only")
+ .returns("X=1\n"
+ + "X=2\n"
+ + "X=3\n");
+ }
+
+ /** Tests FETCH expressions in bindable/interpreter convention. */
+ @Test void testBindableFetchExpression() {
+ try (Hook.Closeable ignored = Hook.ENABLE_BINDABLE.addThread(Hook.propertyJ(true))) {
+ final CalciteAssert.AssertThat with = CalciteAssert.that();
+ with
+ .query("select * from (values (1), (2), (3), (4)) as t(x)\n"
+ + "fetch next (rand_integer(1) + 2) rows only")
+ .explainContains("BindableSort(fetch=[+(RAND_INTEGER(1), 2)])")
+ .returns("X=1\n"
+ + "X=2\n");
+ with.query("select * from (values (1), (2), (3), (4)) as t(x)\n"
+ + "fetch next (cast(9223372036854775808 as decimal(20, 0))) rows only")
+ .returns("X=1\nX=2\nX=3\nX=4\n");
+ }
+ }
+
+ /** Tests scalar functions with positive and negative arguments in FETCH. */
+ @Test void testFetchExpressionFunctionArguments() {
+ final CalciteAssert.AssertThat with = CalciteAssert.that();
+ final String values = "select * from (values (1), (2), (3)) as t(x)\n";
+ with.query(values + "fetch next (abs(2)) rows only")
+ .returns("X=1\n"
+ + "X=2\n");
+ with.query(values + "fetch next (abs(-2)) rows only")
+ .returns("X=1\n"
+ + "X=2\n");
+ }
+
+ /** Tests invalid runtime values produced by FETCH expressions. */
+ @Test void testFetchExpressionInvalidValue() {
+ final CalciteAssert.AssertThat with = CalciteAssert.that();
+ final String values = "select * from (values (1), (2), (3)) as t(x)\n";
+ with.query(values + "fetch next (0 - 1) rows only")
+ .throws_("FETCH value -1 is out of range; expected a non-negative value");
+ with.query(values + "fetch next (-1) rows only")
+ .throws_("FETCH value -1 is out of range; expected a non-negative value");
+ with.query(values
+ + "fetch next (cast(null as integer)) rows only")
+ .throws_("FETCH expression evaluated to NULL");
+ }
+
+ @Test void testCorrelatedFetchExpressionInvalidValue() {
+ final String sqlPrefix = "select d.\"name\", e.\"name\"\n"
+ + "from \"hr\".\"depts\" d,\n"
+ + "lateral (select \"name\" from \"hr\".\"emps\"\n"
+ + " where \"deptno\" = d.\"deptno\"\n";
+ for (String fetch : new String[] {"(0 - 1)", "(-1)"}) {
+ for (boolean topDown : new boolean[] {false, true}) {
+ CalciteAssert.hr()
+ .with(CalciteConnectionProperty.TOPDOWN_GENERAL_DECORRELATION_ENABLED, topDown)
+ .query(sqlPrefix + " fetch next " + fetch + " rows only) e")
+ .throws_("FETCH value -1 is out of range");
+ }
+ }
+ }
+
+ /** Tests FETCH values beyond the range of BIGINT. */
+ @Test void testFetchExpressionBeyondLong() {
+ final CalciteAssert.AssertThat with = CalciteAssert.that();
+ final String values = "select * from (values (1), (2), (3), (4)) as t(x)\n";
+ final String expected = "X=1\nX=2\nX=3\nX=4\n";
+ with.query(values + "fetch next 9223372036854775808 rows only")
+ .returns(expected);
+ with.query(values + "fetch next "
+ + "(cast(9223372036854775808 as decimal(20, 0)) + 1) rows only")
+ .returns(expected);
+ with.query(values + "order by x fetch next "
+ + "(cast(9223372036854775808 as decimal(20, 0)) + 1) rows only")
+ .returns(expected);
+ }
+
/** Tests ORDER BY ... OFFSET ... FETCH. */
@Test void testOrderByOffsetFetch() {
CalciteAssert.that()
@@ -5977,6 +6058,125 @@ private CalciteAssert.AssertQuery withEmpDept(String sql) {
Matchers.returnsUnordered("name=Eric"));
}
+ /** Tests dynamic parameters in parenthesized FETCH expressions. */
+ @Test void testPreparedFetchExpression() throws Exception {
+ CalciteAssert.that()
+ .doWithConnection(connection -> {
+ final String values =
+ "select * from (values (1), (2), (3), (4)) as t(x)\n";
+ checkPreparedFetch(connection, values + "fetch next (?) rows only",
+ 2, "X=1\nX=2\n");
+ checkPreparedFetch(connection, values + "fetch next (? + 1) rows only",
+ 2, "X=1\nX=2\nX=3\n");
+ checkPreparedFetch(connection,
+ values + "fetch next (abs(cast(? as integer))) rows only",
+ 2, "X=1\nX=2\n");
+ checkPreparedFetch(connection,
+ values + "fetch next (abs(cast(? as integer))) rows only",
+ -2, "X=1\nX=2\n");
+ checkPreparedFetchRepeated(connection,
+ values + "fetch next (?) rows only",
+ new int[] {1, 3},
+ new String[] {"X=1\n", "X=1\nX=2\nX=3\n"});
+ checkPreparedFetchRepeated(connection,
+ values + "fetch next (? + 1) rows only",
+ new int[] {0, 2, 3},
+ new String[] {"X=1\n", "X=1\nX=2\nX=3\n",
+ "X=1\nX=2\nX=3\nX=4\n"});
+ checkPreparedFetch(connection,
+ values + "fetch next (? + abs(2)) rows only",
+ 1, "X=1\nX=2\nX=3\n");
+ checkPreparedFetch(connection,
+ values + "fetch next (cast(? as decimal(20, 0))) rows only",
+ new BigDecimal("9223372036854775808"),
+ "X=1\nX=2\nX=3\nX=4\n");
+
+ checkPreparedFetchFails(connection,
+ values + "fetch next (?) rows only", -1,
+ "FETCH value -1 is out of range");
+ checkPreparedFetchFails(connection,
+ values + "fetch next (? + 1) rows only", -2,
+ "FETCH value -1 is out of range");
+ });
+ }
+
+ /** Tests dynamic parameters in bindable/interpreter FETCH expressions. */
+ @Test void testBindablePreparedFetchExpression() throws Exception {
+ try (Hook.Closeable ignored = Hook.ENABLE_BINDABLE.addThread(Hook.propertyJ(true))) {
+ CalciteAssert.that()
+ .doWithConnection(connection -> {
+ final String values =
+ "select * from (values (1), (2), (3), (4)) as t(x)\n";
+ checkPreparedFetch(connection,
+ values + "fetch next (? + 1) rows only",
+ 2, "X=1\nX=2\nX=3\n");
+ checkPreparedFetchRepeated(connection,
+ values + "fetch next (? + 1) rows only",
+ new int[] {0, 2, 3},
+ new String[] {"X=1\n", "X=1\nX=2\nX=3\n",
+ "X=1\nX=2\nX=3\nX=4\n"});
+ checkPreparedFetch(connection,
+ values + "fetch next (cast(? as decimal(20, 0))) rows only",
+ new BigDecimal("9223372036854775808"),
+ "X=1\nX=2\nX=3\nX=4\n");
+ });
+ }
+ }
+
+ private static void checkPreparedFetch(Connection connection, String sql,
+ int value, String expected) {
+ try (PreparedStatement p = connection.prepareStatement(sql)) {
+ p.setInt(1, value);
+ try (ResultSet r = p.executeQuery()) {
+ assertThat(CalciteAssert.toString(r), is(expected));
+ }
+ } catch (SQLException e) {
+ throw TestUtil.rethrow(e);
+ }
+ }
+
+ private static void checkPreparedFetch(Connection connection, String sql,
+ BigDecimal value, String expected) {
+ try (PreparedStatement p = connection.prepareStatement(sql)) {
+ p.setBigDecimal(1, value);
+ try (ResultSet r = p.executeQuery()) {
+ assertThat(CalciteAssert.toString(r), is(expected));
+ }
+ } catch (SQLException e) {
+ throw TestUtil.rethrow(e);
+ }
+ }
+
+ private static void checkPreparedFetchRepeated(Connection connection, String sql,
+ int[] values, String[] expected) {
+ try (PreparedStatement p = connection.prepareStatement(sql)) {
+ for (int i = 0; i < values.length; i++) {
+ p.setInt(1, values[i]);
+ try (ResultSet r = p.executeQuery()) {
+ assertThat(CalciteAssert.toString(r), is(expected[i]));
+ }
+ }
+ } catch (SQLException e) {
+ throw TestUtil.rethrow(e);
+ }
+ }
+
+ private static void checkPreparedFetchFails(Connection connection, String sql,
+ long value, String expectedMessage) {
+ try (PreparedStatement p = connection.prepareStatement(sql)) {
+ if (value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE) {
+ p.setInt(1, (int) value);
+ } else {
+ p.setLong(1, value);
+ }
+ final SQLException e =
+ assertThrows(SQLException.class, p::executeQuery);
+ assertThat(e.getMessage(), containsString(expectedMessage));
+ } catch (SQLException e) {
+ throw TestUtil.rethrow(e);
+ }
+ }
+
private void checkPreparedOffsetFetch(final int offset, final int fetch,
final Matcher super ResultSet> matcher) throws Exception {
CalciteAssert.hr()
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index f1a4f1c27162..20e4c9c97df6 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -52,7 +52,10 @@
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexNodeAndFieldIndex;
+import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.schema.SchemaPlus;
@@ -5227,6 +5230,78 @@ private static RelNode buildCorrelateWithJoin(JoinRelType type, RelBuilder build
assertThat(mq.getMaxRowCount(planAfter), is(Double.POSITIVE_INFINITY));
}
+ @Test void testFetchExpressionCannotReferenceInputField() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ builder.scan("DEPT");
+ final RexNode field = builder.field("DEPTNO");
+
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, field, ImmutableList.of()));
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null,
+ builder.call(SqlStdOperatorTable.PLUS, builder.literal(1), field),
+ ImmutableList.of()));
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null,
+ new RexNodeAndFieldIndex(0, 0, "DEPTNO", field.getType()),
+ ImmutableList.of()));
+ }
+
+ @Test void testFetchExpressionMustHaveIntegralType() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ builder.scan("DEPT");
+
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, builder.literal("x"), ImmutableList.of()));
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, builder.literal(new BigDecimal("1.5")),
+ ImmutableList.of()));
+ }
+
+ @Test void testFetchExpressionCannotContainAggregateWindowOrSubQuery() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ final RelDataType intType =
+ builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
+ builder.scan("DEPT");
+ final RexNode aggregate =
+ builder.call(SqlStdOperatorTable.SUM, builder.literal(1));
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, aggregate, ImmutableList.of()));
+
+ final RexNode over =
+ builder.getRexBuilder().makeOver(intType,
+ SqlStdOperatorTable.ROW_NUMBER, ImmutableList.of(),
+ ImmutableList.of(), ImmutableList.of(),
+ RexWindowBounds.UNBOUNDED_PRECEDING,
+ RexWindowBounds.UNBOUNDED_FOLLOWING,
+ true, true, false, false, false);
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, over, ImmutableList.of()));
+
+ final RelBuilder subQueryBuilder = RelBuilder.create(config().build());
+ final RexNode subQuery =
+ RexSubQuery.scalar(subQueryBuilder.values(new String[] {"N"}, 1).build());
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, subQuery, ImmutableList.of()));
+ }
+
+ @Test void testFetchExpressionCannotContainLambda() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ final RelDataType intType =
+ builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
+ builder.scan("DEPT");
+ final RexLambdaRef lambdaRef = new RexLambdaRef(0, "x", intType);
+ final RexNode lambda =
+ builder.getRexBuilder().makeLambdaCall(
+ builder.call(SqlStdOperatorTable.PLUS, lambdaRef, builder.literal(1)),
+ ImmutableList.of(lambdaRef));
+
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, lambda, ImmutableList.of()));
+ assertThrows(IllegalArgumentException.class,
+ () -> builder.sortLimit(null, lambdaRef, ImmutableList.of()));
+ }
+
@Test void testAdoptConventionEnumerable() {
final RelBuilder builder = RelBuilder.create(config().build());
RelNode root = builder
diff --git a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
index f438a0ad9423..34f64b814a7b 100644
--- a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
@@ -1467,7 +1467,7 @@ void testColumnOriginsUnion() {
@Test void testRowCountSortLimitBeyondLong() {
final BigDecimal fetch = BigDecimal.valueOf(Long.MAX_VALUE).add(BigDecimal.ONE);
final double fetchDouble = fetch.doubleValue();
- final String sql = "select * from emp order by ename limit " + fetchDouble;
+ final String sql = "select * from emp order by ename limit " + fetch.toPlainString();
final RelMetadataFixture fixture = sql(sql);
fixture.assertThatRowCount(is(EMP_SIZE), is(0D), is(fetchDouble));
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 715d9f4f7b8e..711e8e9b8432 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1680,6 +1680,28 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}
+ @Test void testSortUnionTransposeWithNonDeterministicFetch() {
+ final String sql = "select a.name from dept a\n"
+ + "union all\n"
+ + "select b.name from dept b\n"
+ + "order by name fetch next (rand_integer(10)) rows only";
+ sql(sql)
+ .withPreRule(CoreRules.PROJECT_SET_OP_TRANSPOSE)
+ .withRule(CoreRules.SORT_UNION_TRANSPOSE)
+ .checkUnchanged();
+ }
+
+ @Test void testSortUnionTransposePushesParameterizedFetchExpression() {
+ final String sql = "select a.name from dept a\n"
+ + "union all\n"
+ + "select b.name from dept b\n"
+ + "order by name fetch next (? + 1) rows only";
+ sql(sql)
+ .withPreRule(CoreRules.PROJECT_SET_OP_TRANSPOSE)
+ .withRule(CoreRules.SORT_UNION_TRANSPOSE)
+ .check();
+ }
+
@Test void testSortRemovalAllKeysConstant() {
final String sql = "select count(*) as c\n"
+ "from sales.emp\n"
@@ -5899,11 +5921,9 @@ private void checkEmptyJoin(RelOptFixture f) {
.checkUnchanged();
}
- /** Test case for
- * [CALCITE-6647]
- * SortUnionTransposeRule should not push SORT past a UNION when SORT's fetch is DynamicParam
- . */
- @Test void testSortWithDynamicParam() {
+ /** Verifies that SortUnionTransposeRule pushes a deterministic dynamic FETCH
+ * past a UNION only once. */
+ @Test void testSortWithDynamicParamPushesOnce() {
HepProgramBuilder builder = new HepProgramBuilder();
builder.addRuleClass(SortProjectTransposeRule.class);
builder.addRuleClass(SortUnionTransposeRule.class);
@@ -9617,6 +9637,16 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
.check();
}
+ @Test void testDecorrelateProjectWithFetchExpression() {
+ final String query = "SELECT name, "
+ + "(SELECT sal FROM emp where dept.deptno = emp.deptno order by sal "
+ + "fetch next (1 + 0) rows only) "
+ + "FROM dept";
+ sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
+ .withLateDecorrelate(true)
+ .check();
+ }
+
/** Test case for [CALCITE-7289]
* Select NULL subquery throwing exception. */
@Test void testNullSelect() {
@@ -11980,6 +12010,33 @@ private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
.check();
}
+ @Test void testNondeterministicFetchPreventsDecorrelation() {
+ checkNondeterministicFetchPreventsDecorrelation(false);
+ }
+
+ @Test void testNondeterministicFetchPreventsTopDownDecorrelation() {
+ checkNondeterministicFetchPreventsDecorrelation(true);
+ }
+
+ private void checkNondeterministicFetchPreventsDecorrelation(boolean topDown) {
+ final String sql = "select t.deptno, e.ename\n"
+ + "from (select distinct deptno from emp) t,\n"
+ + "lateral (select ename from emp\n"
+ + " where emp.deptno = t.deptno\n"
+ + " order by sal\n"
+ + " fetch next (rand_integer(2) + 1) rows only) e";
+
+ final RelOptFixture fixture = sql(sql)
+ .withRule() // empty program
+ .withLateDecorrelate(true)
+ .withTopDownGeneralDecorrelate(topDown);
+ if (topDown) {
+ fixture.check();
+ } else {
+ fixture.checkUnchanged();
+ }
+ }
+
@Test void testTopDownGeneralDecorrelateForFilterSome() {
final String sql = "select empno from emp where "
+ "empno > SOME(select empno from emp_b where emp.ename = emp_b.ename)";
diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
index 4e8d17cd7a27..c807e5881fa1 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
@@ -1250,6 +1250,12 @@ public static void checkActualAndReferenceFiles() {
sql(sql).ok();
}
+ @Test void testFetchWithExpression() {
+ final String sql =
+ "select empno from emp fetch next (1 + abs(-2)) rows only";
+ sql(sql).ok();
+ }
+
/** Test case for
* [CALCITE-439]
* SqlValidatorUtil.uniquify() may not terminate under some conditions. */
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index 5cd74f62f21a..2422fb3860bb 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -9998,6 +9998,21 @@ void testGroupExpressionEquivalenceParams() {
.rewritesTo(expected);
}
+ @Test void testFetchExpressionType() {
+ sql("select name from dept fetch next (^upper('x')^) rows only")
+ .fails("FETCH expression must have an integral numeric type; "
+ + "actual type is 'CHAR\\(1\\) NOT NULL'");
+ sql("select name from dept fetch next (^'x'^) rows only")
+ .fails("FETCH expression must have an integral numeric type; "
+ + "actual type is 'CHAR\\(1\\) NOT NULL'");
+ sql("select name from dept fetch next (^1.5^) rows only")
+ .fails("FETCH expression must have an integral numeric type; "
+ + "actual type is 'DECIMAL\\(2, 1\\) NOT NULL'");
+ sql("select name from dept "
+ + "fetch next (^row_number() over ()^) rows only")
+ .fails("Windowed aggregate expression is illegal in FETCH clause");
+ }
+
@Test void testRewriteWithOffsetWithoutOrderBy() {
final String sql = "select name from dept offset 2";
final String expected = "SELECT `NAME`\n"
diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java
index 68bb56cf366d..23afb8c24628 100644
--- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java
+++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java
@@ -78,6 +78,30 @@ class EnumerableMergeUnionTest {
"empid=45; name=Pascal");
}
+ @Test void mergeUnionDoesNotPushNonDeterministicFetch() {
+ tester(false,
+ new HrSchemaBig(),
+ "select * from (select empid, name from emps "
+ + "union all select empid, name from emps) "
+ + "order by empid fetch next (rand_integer(10)) rows only")
+ .explainContains("EnumerableLimitSort(sort0=[$0], dir0=[ASC], "
+ + "fetch=[RAND_INTEGER(10)])\n"
+ + " EnumerableMergeUnion(all=[true])\n"
+ + " EnumerableSort(sort0=[$0], dir0=[ASC])\n");
+ }
+
+ @Test void mergeUnionPushesParameterizedFetchExpression() {
+ tester(false,
+ new HrSchemaBig(),
+ "select * from (select empid, name from emps "
+ + "union all select empid, name from emps) "
+ + "order by empid fetch next (? + 1) rows only")
+ .explainContains("EnumerableLimit(fetch=[+(?0, 1)])\n"
+ + " EnumerableMergeUnion(all=[true])\n"
+ + " EnumerableLimitSort(sort0=[$0], dir0=[ASC], "
+ + "fetch=[+(?0, 1)])\n");
+ }
+
@Test void mergeUnionAllOrderByName() {
tester(false,
new HrSchemaBig(),
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c782bb913cc1..1d25ccf3b7e8 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2755,6 +2755,46 @@ LogicalProject(NAME=[$1])
LogicalFilter(condition=[<=($3, 1)])
LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 DESC NULLS FIRST)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -10984,6 +11024,95 @@ LogicalProject(USER=[USER])
LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
LogicalProject(NAME=[$1], DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -19584,7 +19713,55 @@ LogicalSort(sort0=[$0], dir0=[ASC], fetch=[0])
]]>
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index 5ee7180efa75..fa39cbab96fd 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -2426,6 +2426,18 @@ LogicalSort(fetch=[5])
LogicalSort(fetch=[?0])
LogicalProject(EMPNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+
+
+
+
diff --git a/core/src/test/resources/sql/fetch.iq b/core/src/test/resources/sql/fetch.iq
new file mode 100644
index 000000000000..9e335f712a07
--- /dev/null
+++ b/core/src/test/resources/sql/fetch.iq
@@ -0,0 +1,176 @@
+# fetch.iq
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+!use post
+!set outputformat mysql
+
+# FETCH accepts a parenthesized arithmetic expression.
+select *
+from (values (1), (2), (3), (4)) as t(x)
+fetch next (1 + abs(-2)) rows only;
++---+
+| X |
++---+
+| 1 |
+| 2 |
+| 3 |
++---+
+(3 rows)
+
+!ok
+
+# FETCH accepts a parenthesized scalar expression.
+select *
+from (values (1), (2), (3), (4)) as t(x)
+fetch next (abs(2)) rows only;
++---+
+| X |
++---+
+| 1 |
+| 2 |
++---+
+(2 rows)
+
+!ok
+
+# FETCH values are not restricted to the BIGINT range.
+select *
+from (values (1), (2), (3), (4)) as t(x)
+fetch next (cast(9223372036854775808 as decimal(20, 0)) + 1) rows only;
++---+
+| X |
++---+
+| 1 |
+| 2 |
+| 3 |
+| 4 |
++---+
+(4 rows)
+
+!ok
+
+# FETCH expression cannot be negative.
+select *
+from (values (1), (2), (3)) as t(x)
+fetch next (0 - 1) rows only;
+FETCH value -1 is out of range
+!error
+
+# FETCH expression cannot evaluate to NULL.
+select *
+from (values (1), (2), (3)) as t(x)
+fetch next (cast(null as integer)) rows only;
+FETCH expression evaluated to NULL
+!error
+
+# FETCH expression must have an integral numeric type.
+select *
+from (values (1), (2), (3)) as t(x)
+fetch next (1.5) rows only;
+FETCH expression must have an integral numeric type
+!error
+
+# FETCH expression cannot reference input columns.
+select *
+from (values (1), (2), (3)) as t(x)
+fetch next (x) rows only;
+FETCH expression cannot reference table column 'X'
+!error
+
+# Expressions without parentheses are not allowed in FETCH.
+select *
+from (values (1), (2), (3)) as t(x)
+fetch next 1 + 2 rows only;
+Encountered "+"
+!error
+
+# FETCH expression works with a table source.
+select deptno, dname
+from dept
+order by deptno
+fetch next (1 + 1) rows only;
++--------+-------------+
+| DEPTNO | DNAME |
++--------+-------------+
+| 10 | Sales |
+| 20 | Marketing |
++--------+-------------+
+(2 rows)
+
+!ok
+
+# FETCH expression works together with OFFSET on a table source.
+select deptno, dname
+from dept
+order by deptno
+offset 1 rows
+fetch next (1 + 1) rows only;
++--------+-------------+
+| DEPTNO | DNAME |
++--------+-------------+
+| 20 | Marketing |
+| 30 | Engineering |
++--------+-------------+
+(2 rows)
+
+!ok
+
+# FETCH expression may contain a scalar function on a table source.
+select deptno
+from dept
+order by deptno
+fetch next (abs(-3)) rows only;
++--------+
+| DEPTNO |
++--------+
+| 10 |
+| 20 |
+| 30 |
++--------+
+(3 rows)
+
+!ok
+
+# FETCH expression cannot reference columns of a table source.
+select deptno, dname
+from dept
+order by deptno
+fetch next (deptno) rows only;
+FETCH expression cannot reference table column 'DEPTNO'
+!error
+
+# FETCH expression cannot reference columns even inside a larger expression.
+select deptno, dname
+from dept
+order by deptno
+fetch next (deptno + 1) rows only;
+FETCH expression cannot reference table column 'DEPTNO'
+!error
+
+# FETCH expression may be zero on a table source.
+select deptno
+from dept
+order by deptno
+fetch next (2 - 2) rows only;
++--------+
+| DEPTNO |
++--------+
++--------+
+(0 rows)
+
+!ok
diff --git a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java
index ff64d44ed2c1..1c8f9ef98623 100644
--- a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java
+++ b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidRules.java
@@ -850,6 +850,20 @@ protected DruidSortRule(DruidSortRuleConfig config) {
@Override public void onMatch(RelOptRuleCall call) {
final Sort sort = call.rel(0);
final DruidQuery query = call.rel(1);
+ final RexLiteral fetch =
+ sort.fetch == null
+ ? null
+ : RexUtil.reduceFetchToLiteral(sort.getCluster(), sort.fetch);
+ if (sort.fetch != null && fetch == null) {
+ return;
+ }
+ if (fetch != null) {
+ try {
+ RexLiteral.bigDecimalValue(fetch).intValueExact();
+ } catch (ArithmeticException e) {
+ return;
+ }
+ }
if (!DruidQuery.isValidSignature(query.signature() + 'l')) {
return;
}
@@ -865,7 +879,8 @@ protected DruidSortRule(DruidSortRuleConfig config) {
}
final RelNode newSort = sort
- .copy(sort.getTraitSet(), ImmutableList.of(Util.last(query.rels)));
+ .copy(sort.getTraitSet(), Util.last(query.rels), sort.getCollation(),
+ sort.offset, fetch);
call.transformTo(DruidQuery.extendQuery(query, newSort));
}
diff --git a/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java b/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
index eec575e331d3..1323fcb3552b 100644
--- a/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
+++ b/druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
@@ -967,6 +967,31 @@ private void checkGroupBySingleSortLimit(boolean approx) {
.explainContains(explain);
}
+ @Test void testFetchExpression() {
+ final String sql = "select \"state_province\"\n"
+ + "from \"foodmart\"\n"
+ + "fetch next (1 + abs(-2)) rows only";
+ sql(sql)
+ .returnsCount(3)
+ .explainContains("DruidQuery(table=[[foodmart, foodmart]], "
+ + "intervals=[[1900-01-09T00:00:00.000Z/"
+ + "2992-01-10T00:00:00.000Z]], projects=[[$30]], fetch=[3])");
+ sql("select \"state_province\" from \"foodmart\" "
+ + "fetch next (0 - 1) rows only")
+ .throws_("FETCH value -1 is out of range");
+ }
+
+ @Test void testFetchExpressionBeyondIntegerRange() {
+ final String sql = "select \"state_province\"\n"
+ + "from \"foodmart\"\n"
+ + "fetch next "
+ + "(cast(3000000000 as bigint) + 1) rows only";
+ sql(sql)
+ .returnsCount(86837)
+ .explainContains("BindableSort(fetch=[3000000001:BIGINT])\n"
+ + " DruidQuery(table=[[foodmart, foodmart]], ");
+ }
+
/** Tests that distinct-count is pushed down to Druid and evaluated using
* "cardinality". The result is approximate, but gives the correct result in
* this example when rounded down using FLOOR. */
diff --git a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchRules.java b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchRules.java
index 3c44992b0a48..259c7bc55090 100644
--- a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchRules.java
+++ b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchRules.java
@@ -38,6 +38,7 @@
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -217,12 +218,26 @@ protected ElasticsearchSortRule(Config config) {
super(config);
}
- @Override public RelNode convert(RelNode relNode) {
+ @Override public @Nullable RelNode convert(RelNode relNode) {
final Sort sort = (Sort) relNode;
+ final RexLiteral fetch =
+ sort.fetch == null
+ ? null
+ : RexUtil.reduceFetchToLiteral(sort.getCluster(), sort.fetch);
+ if (sort.fetch != null && fetch == null) {
+ return null;
+ }
+ if (fetch != null) {
+ try {
+ RexLiteral.bigDecimalValue(fetch).longValueExact();
+ } catch (ArithmeticException e) {
+ return null;
+ }
+ }
final RelTraitSet traitSet = sort.getTraitSet().replace(out).replace(sort.getCollation());
return new ElasticsearchSort(relNode.getCluster(), traitSet,
convert(sort.getInput(), traitSet.replace(RelCollations.EMPTY)), sort.getCollation(),
- sort.offset, sort.fetch);
+ sort.offset, fetch);
}
}
diff --git a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchSort.java b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchSort.java
index e7af28627978..56e84659f99d 100644
--- a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchSort.java
+++ b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchSort.java
@@ -78,7 +78,7 @@ public class ElasticsearchSort extends Sort implements ElasticsearchRel {
}
if (fetch != null) {
- implementor.fetch(RexLiteral.numberValue(fetch).longValue());
+ implementor.fetch(RexLiteral.bigDecimalValue(fetch).longValueExact());
}
}
diff --git a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/Scrolling.java b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/Scrolling.java
index 64b1d34681df..2e2561c4349f 100644
--- a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/Scrolling.java
+++ b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/Scrolling.java
@@ -22,6 +22,7 @@
import java.util.Collections;
import java.util.Iterator;
+import java.util.NoSuchElementException;
import java.util.function.Consumer;
import static com.google.common.base.Preconditions.checkArgument;
@@ -72,12 +73,35 @@ Iterator query(ObjectNode query) {
Iterator result = flatten(iterator);
// apply limit
if (limit != Long.MAX_VALUE) {
- result = Iterators.limit(result, (int) limit);
+ result = limit(result, limit);
}
return result;
}
+ private static Iterator limit(Iterator iterator, long limit) {
+ checkArgument(limit >= 0, "limit: %s >= 0", limit);
+ return new Iterator() {
+ private long remaining = limit;
+
+ @Override public boolean hasNext() {
+ return remaining > 0 && iterator.hasNext();
+ }
+
+ @Override public E next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ --remaining;
+ return iterator.next();
+ }
+
+ @Override public void remove() {
+ iterator.remove();
+ }
+ };
+ }
+
/**
* Combines lazily multiple {@link ElasticsearchJson.Result} into a single iterator of
* {@link ElasticsearchJson.SearchHit}.
diff --git a/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java b/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java
index b548d511540b..52d9dfc4c38e 100644
--- a/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java
+++ b/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java
@@ -571,6 +571,39 @@ private static Consumer sortedResultSetChecker(String column,
ElasticsearchChecker.elasticsearchChecker(
"'_source':['state','id']",
"size:3"));
+ calciteAssert()
+ .query("select state, id from zips\n"
+ + "fetch next (0 - 1) rows only")
+ .throws_("FETCH value -1 is out of range");
+ }
+
+ @Test void testFetchExpression() {
+ final String sql = "select state, id from zips\n"
+ + "fetch next (1 + abs(-2)) rows only";
+
+ calciteAssert()
+ .query(sql)
+ .returnsCount(3)
+ .explainContains("ElasticsearchSort(fetch=[3])")
+ .queryContains(
+ ElasticsearchChecker.elasticsearchChecker(
+ "'_source':['state','id']",
+ "size:3"));
+ calciteAssert()
+ .query("select state, id from zips\n"
+ + "fetch next (cast(3000000000 as bigint) + 1) rows only")
+ .runs()
+ .explainContains("ElasticsearchSort(fetch=[3000000001:BIGINT])");
+ }
+
+ @Test void testFetchExpressionBeyondLongRange() {
+ calciteAssert()
+ .query("select state, id from zips\n"
+ + "fetch next "
+ + "(cast(9223372036854775808 as decimal(20, 0))) rows only")
+ .returnsCount(ZIPS_SIZE)
+ .explainContains("EnumerableLimit(fetch=[9223372036854775808:DECIMAL(19, 0)])\n"
+ + " ElasticsearchToEnumerableConverter\n");
}
@Test void limit2() {
diff --git a/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeRules.java b/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeRules.java
index b8beb854e0f6..f9788f27d63b 100644
--- a/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeRules.java
+++ b/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeRules.java
@@ -34,6 +34,7 @@
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -221,6 +222,20 @@ protected GeodeSortLimitRule(GeodeSortLimitRuleConfig config) {
@Override public void onMatch(RelOptRuleCall call) {
final Sort sort = call.rel(0);
+ final RexLiteral fetch =
+ sort.fetch == null
+ ? null
+ : RexUtil.reduceFetchToLiteral(sort.getCluster(), sort.fetch);
+ if (sort.fetch != null && fetch == null) {
+ return;
+ }
+ if (fetch != null) {
+ try {
+ RexLiteral.bigDecimalValue(fetch).longValueExact();
+ } catch (ArithmeticException e) {
+ return;
+ }
+ }
final RelTraitSet traitSet = sort.getTraitSet()
.replace(GeodeRel.CONVENTION)
@@ -229,7 +244,7 @@ protected GeodeSortLimitRule(GeodeSortLimitRuleConfig config) {
GeodeSort geodeSort =
new GeodeSort(sort.getCluster(), traitSet,
convert(call.getPlanner(), sort.getInput(), traitSet.replace(RelCollations.EMPTY)),
- sort.getCollation(), sort.fetch);
+ sort.getCollation(), fetch);
call.transformTo(geodeSort);
}
diff --git a/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeSort.java b/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeSort.java
index c9b33252d980..75b0d2d1aa98 100644
--- a/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeSort.java
+++ b/geode/src/main/java/org/apache/calcite/adapter/geode/rel/GeodeSort.java
@@ -86,7 +86,7 @@ public class GeodeSort extends Sort implements GeodeRel {
}
if (fetch != null) {
- geodeImplementContext.setLimit(RexLiteral.numberValue(fetch).longValue());
+ geodeImplementContext.setLimit(RexLiteral.bigDecimalValue(fetch).longValueExact());
}
}
diff --git a/geode/src/test/java/org/apache/calcite/adapter/geode/rel/GeodeBookstoreTest.java b/geode/src/test/java/org/apache/calcite/adapter/geode/rel/GeodeBookstoreTest.java
index 3604be2e6a13..2bfbd58ff2ac 100644
--- a/geode/src/test/java/org/apache/calcite/adapter/geode/rel/GeodeBookstoreTest.java
+++ b/geode/src/test/java/org/apache/calcite/adapter/geode/rel/GeodeBookstoreTest.java
@@ -356,6 +356,28 @@ private CalciteAssert.AssertThat calciteAssert() {
+ " GeodeTableScan(table=[[geode, BookCustomer]])\n");
}
+ @Test void testFetchExpression() {
+ calciteAssert()
+ .query("select * from geode.BookCustomer "
+ + "fetch next (1 + abs(-2)) rows only")
+ .returnsCount(3)
+ .explainContains("GeodeSort(fetch=[3])");
+ calciteAssert()
+ .query("select * from geode.BookCustomer "
+ + "fetch next (0 - 1) rows only")
+ .throws_("FETCH value -1 is out of range");
+ }
+
+ @Test void testFetchExpressionBeyondLongRange() {
+ calciteAssert()
+ .query("select * from geode.BookCustomer "
+ + "fetch next "
+ + "(cast(9223372036854775808 as decimal(20, 0))) rows only")
+ .returnsCount(3)
+ .explainContains("EnumerableLimit(fetch=[9223372036854775808:DECIMAL(19, 0)])\n"
+ + " GeodeToEnumerableConverter\n");
+ }
+
@Test void testSelectWithNestedPdx2() {
calciteAssert()
.query("select primaryAddress from geode.BookCustomer limit 2")
diff --git a/mongodb/src/main/java/org/apache/calcite/adapter/mongodb/MongoRules.java b/mongodb/src/main/java/org/apache/calcite/adapter/mongodb/MongoRules.java
index 48e46e610d55..ee55036564a1 100644
--- a/mongodb/src/main/java/org/apache/calcite/adapter/mongodb/MongoRules.java
+++ b/mongodb/src/main/java/org/apache/calcite/adapter/mongodb/MongoRules.java
@@ -36,6 +36,7 @@
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
@@ -46,6 +47,7 @@
import org.apache.calcite.util.Util;
import org.apache.calcite.util.trace.CalciteTrace;
+import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import java.util.AbstractList;
@@ -262,14 +264,28 @@ private static class MongoSortRule extends MongoConverterRule {
super(config);
}
- @Override public RelNode convert(RelNode rel) {
+ @Override public @Nullable RelNode convert(RelNode rel) {
final Sort sort = (Sort) rel;
+ final RexLiteral fetch =
+ sort.fetch == null
+ ? null
+ : RexUtil.reduceFetchToLiteral(sort.getCluster(), sort.fetch);
+ if (sort.fetch != null && fetch == null) {
+ return null;
+ }
+ if (fetch != null) {
+ try {
+ RexLiteral.bigDecimalValue(fetch).longValueExact();
+ } catch (ArithmeticException e) {
+ return null;
+ }
+ }
final RelTraitSet traitSet =
sort.getTraitSet().replace(out)
.replace(sort.getCollation());
return new MongoSort(rel.getCluster(), traitSet,
convert(sort.getInput(), traitSet.replace(RelCollations.EMPTY)),
- sort.getCollation(), sort.offset, sort.fetch);
+ sort.getCollation(), sort.offset, fetch);
}
}
diff --git a/mongodb/src/test/java/org/apache/calcite/adapter/mongodb/MongoAdapterTest.java b/mongodb/src/test/java/org/apache/calcite/adapter/mongodb/MongoAdapterTest.java
index 7c79718e7984..a745dd227b9e 100644
--- a/mongodb/src/test/java/org/apache/calcite/adapter/mongodb/MongoAdapterTest.java
+++ b/mongodb/src/test/java/org/apache/calcite/adapter/mongodb/MongoAdapterTest.java
@@ -215,6 +215,54 @@ private CalciteAssert.AssertThat assertModel(URL url) {
mongoChecker(
"{$limit: 3}",
"{$project: {STATE: '$state', ID: '$_id'}}"));
+ assertModel(MODEL)
+ .query("select state, id from zips\n"
+ + "fetch next (0 - 1) rows only")
+ .throws_("FETCH value -1 is out of range");
+ assertModel(MODEL)
+ .query("select state, id from zips\n"
+ + "fetch next (cast(null as integer)) rows only")
+ .throws_("FETCH expression evaluated to NULL");
+ assertModel(MODEL)
+ .query("select state, id from zips\n"
+ + "fetch next 3000000001 rows only")
+ .runs()
+ .explainContains("MongoSort(fetch=[3000000001:BIGINT])")
+ .queryContains(
+ mongoChecker(
+ "{$project: {STATE: '$state', ID: '$_id'}}",
+ "{$limit: 3000000001}"));
+ }
+
+ @Test void testFetchExpression() {
+ assertModel(MODEL)
+ .query("select state, id from zips\n"
+ + "fetch next (1 + abs(-2)) rows only")
+ .returnsCount(3)
+ .explainContains("MongoSort(fetch=[3])")
+ .queryContains(
+ mongoChecker(
+ "{$limit: 3}",
+ "{$project: {STATE: '$state', ID: '$_id'}}"));
+ assertModel(MODEL)
+ .query("select state, id from zips\n"
+ + "fetch next (cast(3000000000 as bigint) + 1) rows only")
+ .runs()
+ .explainContains("MongoSort(fetch=[3000000001:BIGINT])")
+ .queryContains(
+ mongoChecker(
+ "{$project: {STATE: '$state', ID: '$_id'}}",
+ "{$limit: 3000000001}"));
+ }
+
+ @Test void testFetchExpressionBeyondLongRange() {
+ assertModel(MODEL)
+ .query("select state, id from zips\n"
+ + "fetch next "
+ + "(cast(9223372036854775808 as decimal(20, 0))) rows only")
+ .returnsCount(ZIPS_SIZE)
+ .explainContains("EnumerableLimit(fetch=[9223372036854775808:DECIMAL(19, 0)])\n"
+ + " MongoToEnumerableConverter\n");
}
@Test void testJoin() {
diff --git a/server/src/test/java/org/apache/calcite/test/ServerTest.java b/server/src/test/java/org/apache/calcite/test/ServerTest.java
index 40e1430c8799..edc4e308f5be 100644
--- a/server/src/test/java/org/apache/calcite/test/ServerTest.java
+++ b/server/src/test/java/org/apache/calcite/test/ServerTest.java
@@ -43,6 +43,7 @@
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.DriverManager;
+import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
@@ -299,6 +300,40 @@ static Connection connect() throws SQLException {
}
}
+ /** Tests that FETCH cannot reference a column of its input table. */
+ @Test void testFetchExpressionCannotReferenceInputColumn() throws Exception {
+ try (Connection c = connect();
+ Statement s = c.createStatement()) {
+ s.execute("create table person (id int not null, name varchar(20))");
+ try (PreparedStatement p =
+ c.prepareStatement("insert into person (id, name) values (?, ?)")) {
+ p.setInt(1, 1);
+ p.setString(2, "foo");
+ assertThat(p.executeUpdate(), is(1));
+ }
+
+ SQLException e =
+ assertThrows(
+ SQLException.class, () -> s.executeQuery("select * from person "
+ + "fetch next id rows only"));
+ assertThat(e.getMessage(), containsString("Encountered \"id\""));
+
+ e =
+ assertThrows(
+ SQLException.class, () -> s.executeQuery("select * from person "
+ + "fetch next (id) rows only"));
+ assertThat(e.getMessage(),
+ containsString("FETCH expression cannot reference table column 'ID'"));
+
+ e =
+ assertThrows(
+ SQLException.class, () -> s.executeQuery("select * from person "
+ + "fetch next (1 + id) rows only"));
+ assertThat(e.getMessage(),
+ containsString("FETCH expression cannot reference table column 'ID'"));
+ }
+ }
+
/** Test case for
* [CALCITE-6022]
* Support "CREATE TABLE ... LIKE" DDL in server module. */
diff --git a/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java b/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java
index 4ab6f776ba4a..8e1c1cb31bce 100644
--- a/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java
+++ b/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java
@@ -4029,12 +4029,31 @@ void checkPeriodPredicate(Checker checker) {
+ "FROM `FOO`\n"
+ "OFFSET ? ROWS\n"
+ "FETCH NEXT ? ROWS ONLY");
+ // Arithmetic and scalar expressions are allowed within parentheses.
+ sql("select a from foo fetch next (1 + abs(-2)) rows only")
+ .ok("SELECT `A`\n"
+ + "FROM `FOO`\n"
+ + "FETCH NEXT (1 + ABS(-2)) ROWS ONLY");
+ // Expressions without parentheses are not allowed.
+ sql("select a from foo fetch next 1 ^+^ 2 rows only")
+ .fails("(?s).*Encountered \"\\+\" at .*");
+ sql("select a from foo fetch next ? ^+^ abs(2) rows only")
+ .fails("(?s).*Encountered \"\\+\" at .*");
// missing ROWS after FETCH
sql("select a from foo offset 1 fetch next 3 ^only^")
.fails("(?s).*Encountered \"only\" at .*");
// FETCH before OFFSET is illegal
sql("select a from foo fetch next 3 rows only ^offset^ 1")
.fails("(?s).*Encountered \"offset\" at .*");
+ // Subqueries are not allowed in FETCH
+ sql("select a from foo fetch next ^select^ 2 rows only")
+ .fails("(?s).*Encountered \"select\" at .*");
+ sql("select a from foo fetch next (^select^ 2) rows only")
+ .fails("(?s).*Encountered \"select\" at .*");
+ sql("select a from foo fetch next (^select^ ?) rows only")
+ .fails("(?s).*Encountered \"select\" at .*");
+ sql("select a from foo fetch next (^select^ max(a) from foo) rows only")
+ .fails("(?s).*Encountered \"select\" at .*");
}
/**