Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48206][SQL][TESTS] Add tests for window rewrites with RewriteWithExpression #46492

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.IntegerType

class RewriteWithExpressionSuite extends PlanTest {

Expand All @@ -37,6 +36,20 @@ class RewriteWithExpressionSuite extends PlanTest {
private val testRelation = LocalRelation($"a".int, $"b".int)
private val testRelation2 = LocalRelation($"x".int, $"y".int)

private def normalizeCommonExpressionIds(plan: LogicalPlan): LogicalPlan = {
plan.transformAllExpressions {
case a: Alias if a.name.startsWith("_common_expr") =>
a.withName("_common_expr_0")
case a: AttributeReference if a.name.startsWith("_common_expr") =>
a.withName("_common_expr_0")
}
}

override def comparePlans(
plan1: LogicalPlan, plan2: LogicalPlan, checkAnalysis: Boolean = true): Unit = {
super.comparePlans(normalizeCommonExpressionIds(plan1), normalizeCommonExpressionIds(plan2))
}

test("simple common expression") {
val a = testRelation.output.head
val expr = With(a) { case Seq(ref) =>
Expand All @@ -52,65 +65,48 @@ class RewriteWithExpressionSuite extends PlanTest {
ref * ref
}
val plan = testRelation.select(expr.as("col"))
val commonExprId = expr.defs.head.id.id
val commonExprName = s"_common_expr_$commonExprId"
comparePlans(
Optimizer.execute(plan),
testRelation
.select((testRelation.output :+ (a + a).as(commonExprName)): _*)
.select(($"$commonExprName" * $"$commonExprName").as("col"))
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
.select(($"_common_expr_0" * $"_common_expr_0").as("col"))
.analyze
)
}

test("nested WITH expression in the definition expression") {
val a = testRelation.output.head
val Seq(a, b) = testRelation.output
val innerExpr = With(a + a) { case Seq(ref) =>
ref + ref
}
val innerCommonExprId = innerExpr.defs.head.id.id
val innerCommonExprName = s"_common_expr_$innerCommonExprId"

val b = testRelation.output.last
val outerExpr = With(innerExpr + b) { case Seq(ref) =>
ref * ref
}
val outerCommonExprId = outerExpr.defs.head.id.id
val outerCommonExprName = s"_common_expr_$outerCommonExprId"

val plan = testRelation.select(outerExpr.as("col"))
val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b)
.as(outerCommonExprName)
val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)(
exprId = rewrittenOuterExpr.exprId)
comparePlans(
Optimizer.execute(plan),
testRelation
.select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*)
.select((testRelation.output :+ $"$innerCommonExprName" :+ rewrittenOuterExpr): _*)
.select((outerExprAttr * outerExprAttr).as("col"))
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
.select((testRelation.output ++ Seq($"_common_expr_0",
($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))): _*)
.select(($"_common_expr_1" * $"_common_expr_1").as("col"))
.analyze
)
}

test("nested WITH expression in the main expression") {
val a = testRelation.output.head
val Seq(a, b) = testRelation.output
val innerExpr = With(a + a) { case Seq(ref) =>
ref + ref
}
val innerCommonExprId = innerExpr.defs.head.id.id
val innerCommonExprName = s"_common_expr_$innerCommonExprId"

val b = testRelation.output.last
val outerExpr = With(b + b) { case Seq(ref) =>
ref * ref + innerExpr
}
val outerCommonExprId = outerExpr.defs.head.id.id
val outerCommonExprName = s"_common_expr_$outerCommonExprId"

val plan = testRelation.select(outerExpr.as("col"))
val rewrittenInnerExpr = (a + a).as(innerCommonExprName)
val rewrittenOuterExpr = (b + b).as(outerCommonExprName)
val rewrittenInnerExpr = (a + a).as("_common_expr_0")
val rewrittenOuterExpr = (b + b).as("_common_expr_1")
val finalExpr = rewrittenOuterExpr.toAttribute * rewrittenOuterExpr.toAttribute +
(rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
comparePlans(
Expand All @@ -124,11 +120,10 @@ class RewriteWithExpressionSuite extends PlanTest {
}

test("correlated nested WITH expression is not supported") {
val b = testRelation.output.last
val Seq(a, b) = testRelation.output
val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
val outerRef = new CommonExpressionRef(outerCommonExprDef)

val a = testRelation.output.head
// The inner expression definition references the outer expression
val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1))
val ref1 = new CommonExpressionRef(commonExprDef1)
Expand All @@ -152,13 +147,11 @@ class RewriteWithExpressionSuite extends PlanTest {
ref < 10 && ref > 0
}
val plan = testRelation.where(condition)
val commonExprId = condition.defs.head.id.id
val commonExprName = s"_common_expr_$commonExprId"
comparePlans(
Optimizer.execute(plan),
testRelation
.select((testRelation.output :+ (a + a).as(commonExprName)): _*)
.where($"$commonExprName" < 10 && $"$commonExprName" > 0)
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
.where($"_common_expr_0" < 10 && $"_common_expr_0" > 0)
.select(testRelation.output: _*)
.analyze
)
Expand All @@ -170,13 +163,11 @@ class RewriteWithExpressionSuite extends PlanTest {
ref < 10 && ref > 0
}
val plan = testRelation.join(testRelation2, condition = Some(condition))
val commonExprId = condition.defs.head.id.id
val commonExprName = s"_common_expr_$commonExprId"
comparePlans(
Optimizer.execute(plan),
testRelation
.select((testRelation.output :+ (a + a).as(commonExprName)): _*)
.join(testRelation2, condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0))
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
.join(testRelation2, condition = Some($"_common_expr_0" < 10 && $"_common_expr_0" > 0))
.select((testRelation.output ++ testRelation2.output): _*)
.analyze
)
Expand All @@ -188,14 +179,12 @@ class RewriteWithExpressionSuite extends PlanTest {
ref < 10 && ref > 0
}
val plan = testRelation.join(testRelation2, condition = Some(condition))
val commonExprId = condition.defs.head.id.id
val commonExprName = s"_common_expr_$commonExprId"
comparePlans(
Optimizer.execute(plan),
testRelation
.join(
testRelation2.select((testRelation2.output :+ (x + x).as(commonExprName)): _*),
condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)
testRelation2.select((testRelation2.output :+ (x + x).as("_common_expr_0")): _*),
condition = Some($"_common_expr_0" < 10 && $"_common_expr_0" > 0)
)
.select((testRelation.output ++ testRelation2.output): _*)
.analyze
Expand Down Expand Up @@ -234,14 +223,12 @@ class RewriteWithExpressionSuite extends PlanTest {
ref * ref
}, a))
val plan2 = testRelation.select(expr2.as("col"))
val commonExprId = expr2.children.head.asInstanceOf[With].defs.head.id.id
val commonExprName = s"_common_expr_$commonExprId"
// With in the always-evaluated branches can still be optimized.
comparePlans(
Optimizer.execute(plan2),
testRelation
.select((testRelation.output :+ (a + a).as(commonExprName)): _*)
.select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col"))
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
.select(Coalesce(Seq(($"_common_expr_0" * $"_common_expr_0"), a)).as("col"))
.analyze
)
}
Expand All @@ -261,38 +248,32 @@ class RewriteWithExpressionSuite extends PlanTest {
(expr2 + 2).as("col1"),
count(expr3 - 3).as("col2")
)
val commonExpr1Id = expr1.defs.head.id.id
val commonExpr1Name = s"_common_expr_$commonExpr1Id"
// Note that the common expression in expr2 gets de-duplicated by PullOutGroupingExpressions.
val commonExpr3Id = expr3.defs.head.id.id
val commonExpr3Name = s"_common_expr_$commonExpr3Id"
val groupingExprName = "_groupingexpression"
val aggExprName = "_aggregateexpression"
comparePlans(
Optimizer.execute(plan),
testRelation
.select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
.select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
.select(testRelation.output :+
($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
.select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr3Name)): _*)
.groupBy($"$groupingExprName")(
$"$groupingExprName",
count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as(aggExprName)
($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"): _*)
.select(testRelation.output ++ Seq($"_groupingexpression",
(a + 1).as("_common_expr_1")): _*)
.groupBy($"_groupingexpression")(
$"_groupingexpression",
count($"_common_expr_1" * $"_common_expr_1" - 3).as("_aggregateexpression")
)
.select(($"$groupingExprName" + 2).as("col1"), $"`$aggExprName`".as("col2"))
.select(($"_groupingexpression" + 2).as("col1"), $"_aggregateexpression".as("col2"))
.analyze
)
// Running CollapseProject after the rule cleans up the unnecessary projections.
comparePlans(
CollapseProject(Optimizer.execute(plan)),
testRelation
.select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
.select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
.select(testRelation.output ++ Seq(
($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
(a + 1).as(commonExpr3Name)): _*)
.groupBy($"$groupingExprName")(
($"$groupingExprName" + 2).as("col1"),
count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as("col2")
($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"),
(a + 1).as("_common_expr_1")): _*)
.groupBy($"_groupingexpression")(
($"_groupingexpression" + 2).as("col1"),
count($"_common_expr_1" * $"_common_expr_1" - 3).as("col2")
)
.analyze
)
Expand All @@ -311,21 +292,16 @@ class RewriteWithExpressionSuite extends PlanTest {
expr1.as("col2"),
max(expr2).as("col3")
)
val commonExpr1Id = expr1.defs.head.id.id
val commonExpr1Name = s"_common_expr_$commonExpr1Id"
val commonExpr2Id = expr2.defs.head.id.id
val commonExpr2Name = s"_common_expr_$commonExpr2Id"
val aggExprName = "_aggregateexpression"
comparePlans(
Optimizer.execute(plan),
testRelation
.select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*)
.groupBy(a)(a, max($"$commonExpr2Name" * $"$commonExpr2Name").as(aggExprName))
.select(a, $"`$aggExprName`", (a + 1).as(commonExpr1Name))
.select(testRelation.output :+ (b + 2).as("_common_expr_0"): _*)
.groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression"))
.select(a, $"_aggregateexpression", (a + 1).as("_common_expr_1"))
.select(
(a + 3).as("col1"),
($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"),
$"`$aggExprName`".as("col3")
($"_common_expr_1" * $"_common_expr_1").as("col2"),
$"_aggregateexpression".as("col3")
)
.analyze
)
Expand All @@ -340,14 +316,13 @@ class RewriteWithExpressionSuite extends PlanTest {
(a - 1).as("col1"),
expr.as("col2")
)
val aggExprName = "_aggregateexpression"
comparePlans(
Optimizer.execute(plan),
testRelation
.groupBy(a)(a, count(a - 1).as(aggExprName))
.groupBy(a)(a, count(a - 1).as("_aggregateexpression"))
.select(
(a - 1).as("col1"),
($"$aggExprName" * $"$aggExprName").as("col2")
($"_aggregateexpression" * $"_aggregateexpression").as("col2")
)
.analyze
)
Expand Down Expand Up @@ -376,19 +351,91 @@ class RewriteWithExpressionSuite extends PlanTest {
ref * max(expr) + ref
}
val plan = testRelation.groupBy(a)(nestedExpr.as("col")).analyze
val commonExpr1Id = expr.defs.head.id.id
val commonExpr1Name = s"_common_expr_$commonExpr1Id"
val commonExpr2Id = nestedExpr.defs.head.id.id
val commonExpr2Name = s"_common_expr_$commonExpr2Id"
val aggExprName = "_aggregateexpression"
comparePlans(
Optimizer.execute(plan),
testRelation
.select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
.groupBy(a)(a, max($"$commonExpr1Name" * $"$commonExpr1Name").as(aggExprName))
.select($"a", $"$aggExprName", (a - 1).as(commonExpr2Name))
.select(($"$commonExpr2Name" * $"$aggExprName" + $"$commonExpr2Name").as("col"))
.select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
.groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression"))
.select($"a", $"_aggregateexpression", (a - 1).as("_common_expr_1"))
.select(($"_common_expr_1" * $"_aggregateexpression" + $"_common_expr_1").as("col"))
.analyze
)
}

test("WITH expression in window exprs") {
val Seq(a, b) = testRelation.output
val expr1 = With(a + 1) { case Seq(ref) =>
ref * ref
}
val expr2 = With(b + 2) { case Seq(ref) =>
ref * ref
}
val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
val plan = testRelation
.window(
Seq(windowExpr(count(a), windowSpec(Seq(expr2), Nil, frame)).as("col2")),
Seq(expr2),
Nil
)
.window(
Seq(windowExpr(sum(expr1), windowSpec(Seq(a), Nil, frame)).as("col3")),
Seq(a),
Nil
)
.select((a - 1).as("col1"), $"col2", $"col3")
.analyze
comparePlans(
Optimizer.execute(plan),
testRelation
.select(a, b, (b + 2).as("_common_expr_0"))
.select(a, b, $"_common_expr_0", (b + 2).as("_common_expr_1"))
.window(
Seq(windowExpr(count(a), windowSpec(Seq($"_common_expr_0" * $"_common_expr_0"), Nil,
frame)).as("col2")),
Seq($"_common_expr_1" * $"_common_expr_1"),
Nil
)
.select(a, b, $"col2")
.select(a, b, $"col2", (a + 1).as("_common_expr_2"))
.window(
Seq(windowExpr(sum($"_common_expr_2" * $"_common_expr_2"),
windowSpec(Seq(a), Nil, frame)).as("col3")),
Seq(a),
Nil
)
.select(a, b, $"col2", $"col3")
.select((a - 1).as("col1"), $"col2", $"col3")
.analyze
)
}

test("WITH common expression is window function") {
val a = testRelation.output.head
val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
val winExpr = windowExpr(sum(a), windowSpec(Seq(a), Nil, frame))
val expr = With(winExpr) {
case Seq(ref) => ref * ref
}
val plan = testRelation.select(expr.as("col")).analyze
comparePlans(
Optimizer.execute(plan),
testRelation
.select(a)
.window(Seq(winExpr.as("_we0")), Seq(a), Nil)
.select(a, $"_we0", ($"_we0" * $"_we0").as("col"))
.select($"col")
.analyze
)
}

test("window functions in child of WITH expression with ref is not supported") {
val a = testRelation.output.head
intercept[java.lang.AssertionError] {
val expr = With(a - 1) { case Seq(ref) =>
ref + windowExpr(sum(ref), windowSpec(Seq(a), Nil, UnspecifiedFrame))
}
val plan = testRelation.window(Seq(expr.as("col")), Seq(a), Nil)
Optimizer.execute(plan)
}
}
}