Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48155][SQL] AQEPropagateEmptyRelation for join should check if remain child is just BroadcastQueryStageExec #46523

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }

protected def canPropagate(plan: LogicalPlan): Boolean = true
AngersZhuuuu marked this conversation as resolved.
Show resolved Hide resolved

protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
case p: Union if p.children.exists(isEmpty) =>
val newChildren = p.children.filterNot(isEmpty)
Expand Down Expand Up @@ -111,18 +113,19 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
case LeftSemi if isRightEmpty | isFalseCondition => empty(p)
case LeftAnti if isRightEmpty | isFalseCondition => p.left
case FullOuter if isLeftEmpty && isRightEmpty => empty(p)
case LeftOuter | FullOuter if isRightEmpty =>
case LeftOuter | FullOuter if isRightEmpty && canPropagate(p.left) =>
Project(p.left.output ++ nullValueProjectList(p.right), p.left)
case RightOuter if isRightEmpty => empty(p)
case RightOuter | FullOuter if isLeftEmpty =>
case RightOuter | FullOuter if isLeftEmpty && canPropagate(p.right) =>
Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
case LeftOuter if isFalseCondition =>
case LeftOuter if isFalseCondition && canPropagate(p.left) =>
Project(p.left.output ++ nullValueProjectList(p.right), p.left)
case RightOuter if isFalseCondition =>
case RightOuter if isFalseCondition && canPropagate(p.right) =>
Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
case _ => p
}
} else if (joinType == LeftSemi && conditionOpt.isEmpty && nonEmpty(p.right)) {
} else if (joinType == LeftSemi && conditionOpt.isEmpty &&
nonEmpty(p.right) && canPropagate(p.left)) {
p.left
} else if (joinType == LeftAnti && conditionOpt.isEmpty && nonEmpty(p.right)) {
empty(p)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
case _ => false
}

override protected def canPropagate(plan: LogicalPlan): Boolean = plan match {
case LogicalQueryStage(_, _: BroadcastQueryStageExec) => false
AngersZhuuuu marked this conversation as resolved.
Show resolved Hide resolved
case _ => true
}

override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
// LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
// `PropagateEmptyRelationBase.commonApplyFunc`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ class AdaptiveQueryExecSuite
}
}

private def findTopLevelUnion(plan: SparkPlan): Seq[UnionExec] = {
collect(plan) {
case l: UnionExec => l
}
}

private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
collectWithSubqueries(plan) {
case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
Expand Down Expand Up @@ -2772,6 +2778,30 @@ class AdaptiveQueryExecSuite
}
}

test("SPARK-48155: AQEPropagateEmptyRelation check remained child for join") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
AngersZhuuuu marked this conversation as resolved.
Show resolved Hide resolved
|INNER JOIN (
| SELECT * FROM testData2
| WHERE b = 0
| UNION ALL
| SELECT * FROM testData2
| WHErE b != 0
|) t2
|ON t1.b = t2.b AND t1.a = 0
|RIGHT OUTER JOIN testData2 t3
|ON t1.a > t3.a
|GROUP BY t3.b
""".stripMargin
)
assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
assert(findTopLevelUnion(adaptivePlan).size == 0)
}
}

test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") {
// partitioning: HashPartitioning
Expand Down