1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
using System;
using System.Collections.Generic;
using System.Text;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// Rewrites CROSS APPLY joins: columns projected on the right side whose
/// expressions are reported (by SqlAliasesReferenced.ReferencesAny) as referencing
/// aliases produced by the left side are removed from the right-side select and
/// re-projected by a new select that wraps the CROSS APPLY.
/// </summary>
class SqlLiftIndependentRowExpressions {
    /// <summary>
    /// Runs the lifting visitor over <paramref name="node"/> and returns the visited node.
    /// </summary>
    internal static SqlNode Lift(SqlNode node) {
        return new ColumnLifter().Visit(node);
    }

    /// <summary>
    /// Visitor that performs the lift. While the right side of a CROSS APPLY is being
    /// visited, <c>expressionSink</c> holds the scope collecting liftable columns;
    /// outside of that (or beneath a select that blocks lifting) it is null.
    /// </summary>
    private class ColumnLifter : SqlVisitor {
        SelectScope expressionSink;
        SqlAggregateChecker aggregateChecker;

        internal ColumnLifter() {
            this.aggregateChecker = new SqlAggregateChecker();
        }

        /// <summary>
        /// Per-CROSS-APPLY bookkeeping used while visiting the right side of the join.
        /// </summary>
        class SelectScope {
            // One entry per select that gave up columns; each entry is later
            // re-projected by a select wrapped around the join.
            internal Stack<List<SqlColumn>> Lifted = new Stack<List<SqlColumn>>();
            // Aliases gathered from the left side of the CROSS APPLY.
            internal IEnumerable<SqlAlias> LeftProduction;
            // Columns that a column reference visited so far points at; these stay put.
            internal HashSet<SqlColumn> ReferencedExpressions = new HashSet<SqlColumn>();
        }

        /// <summary>
        /// Splits the select's row columns into those that stay and those that are
        /// handed to the active sink, then visits the select with the base visitor.
        /// The sink that was active on entry is restored before returning.
        /// </summary>
        internal override SqlSelect VisitSelect(SqlSelect select) {
            SelectScope outerSink = this.expressionSink;

            // Lifting must not cross a TOP, a GROUP BY (explicit, or implicit via
            // aggregates) or a DISTINCT: any of them switches the sink off for
            // this select and everything visited beneath it.
            bool hasTop = select.Top != null;
            bool isGrouped = select.GroupBy.Count > 0 || this.aggregateChecker.HasAggregates(select);
            bool isDistinct = select.IsDistinct;
            if (hasTop || isGrouped || isDistinct) {
                this.expressionSink = null;
            }

            SelectScope sink = this.expressionSink;
            if (sink != null) {
                List<SqlColumn> retained = new List<SqlColumn>();
                List<SqlColumn> liftable = new List<SqlColumn>();
                foreach (SqlColumn column in select.Row.Columns) {
                    bool usesLeftAliases = SqlAliasesReferenced.ReferencesAny(column.Expression, sink.LeftProduction);
                    bool alreadyReferenced = sink.ReferencedExpressions.Contains(column);
                    // A column moves out only if it touches the left side and no
                    // column reference seen so far targets it.
                    List<SqlColumn> destination = (usesLeftAliases && !alreadyReferenced) ? liftable : retained;
                    destination.Add(column);
                }

                // Replace the select's projection with just the retained columns.
                select.Row.Columns.Clear();
                select.Row.Columns.AddRange(retained);

                if (liftable.Count > 0) {
                    sink.Lifted.Push(liftable);
                }
            }

            SqlSelect visited = base.VisitSelect(select);
            this.expressionSink = outerSink;
            return visited;
        }

        /// <summary>
        /// Records the referenced column in the active sink (if there is one) so that
        /// VisitSelect keeps it in place. The reference itself is returned unchanged.
        /// </summary>
        internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
            SelectScope sink = this.expressionSink;
            if (sink != null) {
                sink.ReferencedExpressions.Add(cref.Column);
            }
            return cref;
        }

        /// <summary>
        /// For a CROSS APPLY, visits the right side under a fresh sink and wraps the
        /// join in one new aliased select per group of lifted columns. Every other
        /// join type is delegated to the base visitor.
        /// </summary>
        internal override SqlSource VisitJoin(SqlJoin join) {
            if (join.JoinType != SqlJoinType.CrossApply) {
                return base.VisitJoin(join);
            }

            // The left side and the condition are visited under whatever sink is
            // currently active.
            join.Left = this.VisitSource(join.Left);
            join.Condition = this.VisitExpression(join.Condition);

            // The right side gets its own sink, seeded with the left side's aliases.
            SelectScope outerSink = this.expressionSink;
            SelectScope applySink = new SelectScope();
            this.expressionSink = applySink;
            applySink.LeftProduction = SqlGatherProducedAliases.Gather(join.Left);
            join.Right = this.VisitSource(join.Right);

            // Wrap the join once for each collected group. Stack<T> enumerates
            // from the most recently pushed entry to the oldest.
            SqlSource result = join;
            foreach (List<SqlColumn> group in applySink.Lifted) {
                result = PushSourceDown(result, group);
            }

            this.expressionSink = outerSink;
            return result;
        }

        /// <summary>
        /// Builds a new select over <paramref name="sqlSource"/> whose row carries
        /// <paramref name="cols"/>, and returns that select wrapped in an alias.
        /// <paramref name="cols"/> must contain at least one column; the first one
        /// supplies the CLR and SQL types of the SqlNop passed to the new select.
        /// </summary>
        [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
        private SqlSource PushSourceDown(SqlSource sqlSource, List<SqlColumn> cols) {
            SqlColumn first = cols[0];
            SqlNop placeholder = new SqlNop(first.ClrType, first.SqlType, sqlSource.SourceExpression);
            SqlSelect wrapper = new SqlSelect(placeholder, sqlSource, sqlSource.SourceExpression);
            wrapper.Row.Columns.AddRange(cols);
            return new SqlAlias(wrapper);
        }
    }
}
}
|