Thursday, 07 January 2010

Extending LINQ selectors with custom expressions

While working on a particularly nasty LINQ query, I ran into a situation I've been in a number of times since I started working with LINQ to SQL, and I'm sure a lot of other people have been there too.

Basically, the query used a filter based on multiple columns in a couple of different tables. This is of course easy to do with LINQ, but I wanted to store the filter in a shared class so that I could reuse it in multiple places. The problem is that if you put this logic inside a method, an extension method or a custom property, LINQ to SQL throws an exception when the query runs, because the provider only sees a call to compiled code and cannot translate what is inside it:

Method 'Boolean IsComplete(Linq.Product)' has no supported translation to SQL
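
To see why, it helps to look at what the provider actually receives. This sketch uses a hypothetical Product class with made-up IsShipped and IsPaid columns and a hypothetical ProductRules class, and builds the same check twice: inline, and through a method. The inline version arrives as an AndAlso node over two member accesses, which a provider can map to SQL; the method version arrives as a single call node that only names IsComplete.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical entity standing in for the Product class in the error message.
public class Product
{
    public bool IsShipped { get; set; }
    public bool IsPaid { get; set; }
}

public static class ProductRules
{
    // Shared logic as an ordinary method: compiled code a query provider cannot look into.
    public static bool IsComplete(Product p) => p.IsShipped && p.IsPaid;

    // The same check written inline, and routed through the method.
    public static Expression<Func<Product, bool>> Inline = p => p.IsShipped && p.IsPaid;
    public static Expression<Func<Product, bool>> ViaMethod = p => IsComplete(p);

    // Describes the top node of a lambda body, i.e. what a provider starts translating from.
    public static string Describe(Expression<Func<Product, bool>> e)
    {
        if (e.Body is MethodCallExpression call)
            return "Call " + call.Method.Name;   // only the method's name and signature are visible
        return e.Body.NodeType.ToString();       // structured node, e.g. AndAlso
    }
}
```

Describe(ProductRules.Inline) returns "AndAlso", while Describe(ProductRules.ViaMethod) returns "Call IsComplete".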

I had tried to find workarounds on several occasions, and at some point I found a solution that works for filters: instead of a method or a property, use an expression tree. An Expression<Func<T, bool>> can be passed to the Where extension method, and the LINQ to SQL provider will parse it and translate it to SQL where possible. For example:

public static Expression<Func<User, bool>> IsActiveExpression = u => u.IsActive && !u.IsDeleted;
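
A minimal sketch of reusing such a filter, assuming a hypothetical User entity and a hypothetical UserFilters class to hold it. Because the filter is an Expression<Func<User, bool>>, Queryable.Where accepts it directly. Against a LINQ to SQL table the tree would be translated to a WHERE clause; here it runs over an in-memory list wrapped with AsQueryable(), which simply compiles the tree.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity with the two columns the filter uses.
public class User
{
    public int UserId { get; set; }
    public bool IsActive { get; set; }
    public bool IsDeleted { get; set; }
}

// Shared class holding the reusable filter as an expression tree, not a method.
public static class UserFilters
{
    public static Expression<Func<User, bool>> IsActiveExpression = u => u.IsActive && !u.IsDeleted;

    // Works with any IQueryable<User>; the same filter can be reused in as many queries as needed.
    public static List<int> ActiveIds(IQueryable<User> users) =>
        users.Where(IsActiveExpression).Select(u => u.UserId).ToList();
}
```

For users 1 (active), 2 (active but deleted) and 3 (inactive), ActiveIds returns only 1.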

This worked fine until I needed something different: I had to use this value in a selector, not a filter. Because the value had to be assigned to a property of a complex object, I couldn't simply assign the expression to the Boolean property. Calling .Compile() and .Invoke() on the expression didn't help either; it just produced more exceptions.
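
The reason Compile() and Invoke() don't help is visible in the tree itself. In this sketch (hypothetical User and UserInfo classes, a hypothetical CompileInvokeDemo class holding the shared expression, and a hypothetical NameCollector visitor), the selector's tree is walked and its method calls and member accesses are recorded. It contains a call to Compile() on the IsActiveExpression field, while the IsDeleted column the filter actually depends on never appears, so there is nothing in it the provider could turn into SQL.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical source entity and result class.
public class User
{
    public bool IsActive { get; set; }
    public bool IsDeleted { get; set; }
}

public class UserInfo
{
    public bool IsActive { get; set; }
}

// Records the names of every method call and member access found in a tree.
public sealed class NameCollector : ExpressionVisitor
{
    public readonly List<string> Calls = new List<string>();
    public readonly List<string> Members = new List<string>();

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        Calls.Add(node.Method.Name);
        return base.VisitMethodCall(node);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        Members.Add(node.Member.Name);
        return base.VisitMember(node);
    }
}

public static class CompileInvokeDemo
{
    public static Expression<Func<User, bool>> IsActiveExpression = u => u.IsActive && !u.IsDeleted;

    // The shared expression is compiled and invoked inside the selector.
    public static Expression<Func<User, UserInfo>> Selector =
        u => new UserInfo { IsActive = IsActiveExpression.Compile().Invoke(u) };

    // Walks the selector's tree and returns what was found.
    public static NameCollector Inspect()
    {
        var collector = new NameCollector();
        collector.Visit(Selector);
        return collector;
    }
}
```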

Searching didn't help much, and the only solution I found didn't work for me, as it couldn't compile my query. Finally I decided that instead of trying to execute this expression, I could take the expression tree that the compiler builds for the selector and rewrite it with my changes.

After a bit more googling, debugging and object dumping in LINQPad, I found what I was looking for.
From what I've seen so far, the body of a selector expression comes in two types: a MemberInitExpression, when it initialises an object of a named class, or a NewExpression, when it creates an anonymous type. In the first case, I can access the member bindings and replace the existing ones with the expressions I provide. In the second case, I go through the constructor arguments and replace those whose member names match the ones I provide.
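
The two shapes are easy to confirm. This sketch, with hypothetical User and UserDto classes and a small generic helper so the compiler can infer the anonymous result type, returns the node type of each selector body.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical source entity and named result class.
public class User
{
    public int UserId { get; set; }
    public string Name { get; set; }
}

public class UserDto
{
    public int UserId { get; set; }
    public string Name { get; set; }
}

public static class SelectorShapes
{
    // Helper so the anonymous result type R can be inferred from the lambda.
    static LambdaExpression Lambda<T, R>(Expression<Func<T, R>> e) => e;

    // Object initializer on a named class: the body is a MemberInitExpression.
    public static ExpressionType NamedKind() =>
        Lambda((User u) => new UserDto { UserId = u.UserId, Name = u.Name }).Body.NodeType;

    // Anonymous type: the body is a NewExpression (a constructor call).
    public static ExpressionType AnonymousKind() =>
        Lambda((User u) => new { u.UserId, u.Name }).Body.NodeType;
}
```

NamedKind() returns ExpressionType.MemberInit and AnonymousKind() returns ExpressionType.New.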

With this in mind, the extension method turned out quite simple, and it did exactly what I wanted!

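using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class QueryableExtensions
{
    // Projects the query with 'selector', but builds the members named in 'updates' from the supplied expressions.
    // Each update must be a lambda over the same parameter type T as the selector.
    public static IQueryable<R> SelectExpression<T, R>(this IQueryable<T> query, Expression<Func<T, R>> selector, Dictionary<String, Expression> updates)
    {
        var properties = typeof(R).GetProperties();
        Expression newBody;

        if (selector.Body is MemberInitExpression)
        {
            // Named class: new R { A = ..., B = ... }
            var initEx = (MemberInitExpression)selector.Body;
            // Keep the bindings that are not being replaced...
            var bindings = initEx.Bindings.Where(b => !updates.ContainsKey(b.Member.Name)).ToList();
            // ...and bind each updated property to its expression, invoked with the selector's parameter
            var bindUpdates = updates.Select(u => (MemberBinding)Expression.Bind(
                properties.Single(p => p.Name == u.Key),
                Expression.Invoke(u.Value, selector.Parameters.ToArray())));

            newBody = Expression.MemberInit(initEx.NewExpression, bindings.Concat(bindUpdates));
        }
        else if (selector.Body is NewExpression && ((NewExpression)selector.Body).Members != null)
        {
            // Anonymous type: new { A = ..., B = ... } is a constructor call with one argument per member
            var newEx = (NewExpression)selector.Body;
            var args = newEx.Arguments.ToArray();
            var members = newEx.Members.ToArray();
            var newArgs = new List<Expression>();
            for (int i = 0; i < members.Length; i++)
            {
                // Swap in the update expression where the member name matches; keep the original otherwise
                if (updates.ContainsKey(members[i].Name))
                    newArgs.Add(Expression.Invoke(updates[members[i].Name], selector.Parameters.ToArray()));
                else
                    newArgs.Add(args[i]);
            }
            newBody = Expression.New(newEx.Constructor, newArgs, members);
        }
        else
        {
            throw new InvalidOperationException("The selector must use an object initializer or an anonymous type.");
        }

        return query.Select(Expression.Lambda<Func<T, R>>(newBody, selector.Parameters));
    }
}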
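
A reduced sketch of the MemberInit branch, with one hard-coded binding instead of the dictionary, assuming hypothetical User and UserInfo classes and a hypothetical BindingRewriteDemo class. It adds an IsActive binding that invokes the shared expression with the selector's own parameter, then compiles and runs the result in memory, so it checks the tree rewrite but not the SQL translation.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical source entity and result class.
public class User
{
    public string Name { get; set; }
    public bool IsActive { get; set; }
    public bool IsDeleted { get; set; }
}

public class UserInfo
{
    public string Name { get; set; }
    public bool IsActive { get; set; }
}

public static class BindingRewriteDemo
{
    public static Expression<Func<User, bool>> IsActiveExpression = u => u.IsActive && !u.IsDeleted;

    // Starts from a selector that only sets Name and adds an IsActive binding to it.
    public static Expression<Func<User, UserInfo>> BuildSelector()
    {
        Expression<Func<User, UserInfo>> selector = u => new UserInfo { Name = u.Name };
        var init = (MemberInitExpression)selector.Body;

        // IsActive = Invoke(IsActiveExpression, u), reusing the selector's parameter u
        MemberBinding isActive = Expression.Bind(
            typeof(UserInfo).GetProperty("IsActive"),
            Expression.Invoke(IsActiveExpression, selector.Parameters));

        var body = Expression.MemberInit(init.NewExpression, init.Bindings.Concat(new[] { isActive }));
        return Expression.Lambda<Func<User, UserInfo>>(body, selector.Parameters);
    }
}
```

Using Expression.Invoke leaves the shared expression untouched: the selector's parameter is simply passed to it, so its own parameter never has to be rewritten. The trade-off is that the query provider must understand InvocationExpression nodes, which this approach relies on LINQ to SQL to do.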

To use it, replace the .Select() call in your query with SelectExpression and pass a dictionary that maps property names to the expressions that should compute them:

return query.SelectExpression(u => new User
{
    UserId = u.UserId,
    Name = u.Name
},
new Dictionary<String, Expression>
{
    { "IsActive", User.IsActiveExpression }
});
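
The anonymous-type branch can be tried on its own in the same way. This is a minimal sketch assuming a hypothetical User class, a hypothetical AnonymousRewriteDemo class and a hypothetical member name Active. It rebuilds the NewExpression, swapping the constructor argument for Active with an invocation of the shared expression, and compiles the result to run it in memory. Like the original method, it relies on NewExpression.Members reporting the anonymous type's property names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical source entity.
public class User
{
    public string Name { get; set; }
    public bool IsActive { get; set; }
    public bool IsDeleted { get; set; }
}

public static class AnonymousRewriteDemo
{
    public static Expression<Func<User, bool>> IsActiveExpression = u => u.IsActive && !u.IsDeleted;

    // Rebuilds an anonymous-type selector, replacing the argument for the member "Active".
    public static Func<User, T> Rewrite<T>(Expression<Func<User, T>> selector)
    {
        var newEx = (NewExpression)selector.Body;
        var args = new List<Expression>();
        for (int i = 0; i < newEx.Members.Count; i++)
        {
            // Arguments line up with Members by position
            args.Add(newEx.Members[i].Name == "Active"
                ? Expression.Invoke(IsActiveExpression, selector.Parameters)
                : newEx.Arguments[i]);
        }
        var body = Expression.New(newEx.Constructor, args, newEx.Members);
        return Expression.Lambda<Func<User, T>>(body, selector.Parameters).Compile();
    }
}
```

For example, Rewrite((User u) => new { u.Name, Active = false }) returns a function whose Active value comes from IsActiveExpression rather than from the placeholder false.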