这两天利用空闲时间研究了一下 LINQ 的查询提供程序(IQueryProvider),并做了一个简单的实现。代码写得很乱,没有注释,也没怎么对代码进行设计,因此有很多临时变量和一些不必要的操作,重点在于说明实现原理。微软的 LINQ to SQL 实现非常复杂,这个例子只简单实现了 Select 和 Where,其他操作都没有实现;并且 Where 查询只支持有限的 ==、>、< 运算,不过这不是重点,如果需要可以添加对应的实现。
先把代码记录下来吧,以后有时间再优化下代码和添加些注释。
IQueryable 的实现(CustomTable),IQueryProvider 的实现(CustomProvider)紧随其后:
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Linq.Expressions;
- using System.Collections;
namespace SimpleLinq2Sql {
    /// <summary>
    /// A LINQ queryable whose elements are of type <typeparamref name="T"/>.
    /// An instance is either a root table (its <see cref="Expression"/> is a constant that refers
    /// to the instance itself) or a query produced by the provider's <c>CreateQuery</c>.
    /// </summary>
    /// <typeparam name="T">Element type produced when the query is enumerated.</typeparam>
    public class CustomTable<T> : IQueryable<T> {
        private readonly Type _ElementType;
        private readonly Expression _Expression;
        private readonly IQueryProvider _Provider;

        /// <summary>Element type of the sequence; always <c>typeof(T)</c>.</summary>
        public Type ElementType {
            get { return _ElementType; }
        }

        /// <summary>Expression tree that describes this query.</summary>
        public Expression Expression {
            get { return _Expression; }
        }

        /// <summary>Provider used to execute <see cref="Expression"/>.</summary>
        public IQueryProvider Provider {
            get { return _Provider; }
        }

        /// <summary>Creates a query (or root table) bound to <paramref name="provider"/>.</summary>
        /// <param name="expression">
        /// Expression tree of the query. When <c>null</c>, this instance is a root table and its
        /// expression becomes a constant referring to the instance itself.
        /// </param>
        /// <param name="provider">Provider that executes the query; must not be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="expression"/> is not a sequence of <typeparamref name="T"/>.
        /// </exception>
        public CustomTable(Expression expression, IQueryProvider provider) {
            if (provider == null) {
                throw new ArgumentNullException("provider", "provider can't be null");
            }
            // Reject trees that cannot yield T, instead of failing later during enumeration.
            if (expression != null && !typeof(IEnumerable<T>).IsAssignableFrom(expression.Type)) {
                throw new ArgumentException(
                    "expression must be assignable to IEnumerable<" + typeof(T).Name + ">", "expression");
            }
            _ElementType = typeof(T);
            _Provider = provider;
            // A null expression marks a root table: the query is the table itself.
            _Expression = expression ?? Expression.Constant(this);
        }

        /// <summary>Creates a root table backed by a new <see cref="CustomProvider"/>.</summary>
        public CustomTable() : this(null, new CustomProvider()) {
        }

        /// <summary>Executes <see cref="Expression"/> through <see cref="Provider"/> and enumerates the result.</summary>
        public IEnumerator<T> GetEnumerator() {
            return Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator() {
            return GetEnumerator();
        }

        /// <summary>
        /// Returns the provider's text; for <see cref="CustomProvider"/> this is the SQL text it
        /// most recently generated.
        /// </summary>
        public override string ToString() {
            return _Provider.ToString();
        }
    }
}
IQueryProvider 的实现(CustomProvider):
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Linq.Expressions;
- using System.Reflection;
- using System.Data;
- using System.Data.SqlClient;
namespace SimpleLinq2Sql {
    /// <summary>
    /// A minimal LINQ query provider. It translates <c>Queryable.Where</c> / <c>Queryable.Select</c>
    /// chains rooted at a <see cref="CustomTable{T}"/> into nested SQL SELECT statements and runs
    /// them with <see cref="SqlConnection"/>.
    /// </summary>
    /// <remarks>
    /// Supported shapes:
    /// <list type="bullet">
    /// <item><c>Where(r =&gt; ...)</c> using <c>==</c>, <c>!=</c>, <c>&lt;</c>, <c>&lt;=</c>, <c>&gt;</c>, <c>&gt;=</c>
    /// between a member of <c>r</c> (left side) and a value (right side), combined with <c>&amp;&amp;</c> / <c>||</c>.</item>
    /// <item><c>Select(r =&gt; r)</c> and <c>Select(r =&gt; new { ... })</c> built only from members of <c>r</c>.</item>
    /// </list>
    /// Values are sent as SQL parameters (<c>@p0</c>, <c>@p1</c>, ...), never spliced into the SQL text.
    /// Any other expression throws <see cref="NotSupportedException"/>.
    /// </remarks>
    public class CustomProvider : IQueryProvider {
        // The connection string the original sample hard-coded; still the default.
        private const string DefaultConnectionString = "Data Source=.;Initial Catalog=TestLinq;Integrated Security=True";

        private readonly string _connectionString;

        // SQL of the most recently created query (with @pN placeholders); only used by ToString().
        private string sql = "";

        /// <summary>Creates a provider that uses the default local TestLinq database.</summary>
        public CustomProvider() : this(DefaultConnectionString) {
        }

        /// <summary>Creates a provider that connects with <paramref name="connectionString"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> is null.</exception>
        public CustomProvider(string connectionString) {
            if (connectionString == null) throw new ArgumentNullException("connectionString");
            _connectionString = connectionString;
        }

        /// <summary>Wraps <paramref name="expression"/> in a new query.</summary>
        /// <remarks>Translates eagerly, so unsupported queries fail where they are built.</remarks>
        public IQueryable<T> CreateQuery<T>(Expression expression) {
            Type ignored;
            sql = Translate(expression, new List<SqlParameter>(), out ignored);
            return new CustomTable<T>(expression, this);
        }

        /// <summary>Non-generic <see cref="CreateQuery{T}"/>; the element type comes from the translation.</summary>
        public IQueryable CreateQuery(Expression expression) {
            Type elementType;
            sql = Translate(expression, new List<SqlParameter>(), out elementType);
            object[] args = new object[] { expression, this };
            return (IQueryable)Activator.CreateInstance(typeof(CustomTable<>).MakeGenericType(elementType), args);
        }

        /// <summary>Runs <paramref name="expression"/> and casts the materialized List to <typeparamref name="T"/>.</summary>
        public T Execute<T>(Expression expression) {
            return (T)ExecuteSql(expression);
        }

        /// <summary>Runs <paramref name="expression"/> and returns the materialized List.</summary>
        public object Execute(Expression expression) {
            return ExecuteSql(expression);
        }

        // Translates a whole query tree; aliases are numbered t0, t1, ... from the innermost operator out.
        private static string Translate(Expression expression, List<SqlParameter> parameters, out Type elementType) {
            if (expression == null) throw new ArgumentNullException("expression");
            int nextAlias = 0;
            return TranslateQuery(expression, parameters, ref nextAlias, out elementType);
        }

        // Translates one query node (root table, Where or Select), recursing into its source first.
        private static string TranslateQuery(Expression expression, List<SqlParameter> parameters, ref int nextAlias, out Type elementType) {
            Type rootType = GetRootElementType(expression);
            if (rootType != null) {
                // The table itself is being enumerated: select all of its columns.
                string rootAlias = "t" + nextAlias++;
                elementType = rootType;
                return "select " + rootAlias + ".* from " + MapHelper.GetTableName(rootType) + " as " + rootAlias;
            }

            MethodCallExpression call = expression as MethodCallExpression;
            if (call == null || call.Method.DeclaringType != typeof(Queryable) || call.Arguments.Count != 2) {
                throw new NotSupportedException("Unsupported query expression: " + expression);
            }
            LambdaExpression lambda = StripQuotes(call.Arguments[1]);
            if (lambda == null || lambda.Parameters.Count != 1) {
                throw new NotSupportedException("Unsupported lambda in " + call.Method.Name + ": " + call.Arguments[1]);
            }

            // The source is either a mapped table or the previous operator wrapped as a derived table.
            Expression source = call.Arguments[0];
            Type rowType = GetRootElementType(source);
            string alias;
            string from;
            if (rowType != null) {
                alias = "t" + nextAlias++;
                from = MapHelper.GetTableName(rowType) + " as " + alias;
            } else {
                string inner = TranslateQuery(source, parameters, ref nextAlias, out rowType);
                alias = "t" + nextAlias++;
                from = "( " + inner + " ) as " + alias;
            }

            switch (call.Method.Name) {
                case "Where":
                    elementType = rowType;
                    return "select " + alias + ".* from " + from + " where "
                        + TranslatePredicate(lambda.Body, rowType, alias, parameters);
                case "Select":
                    elementType = lambda.Body.Type;
                    return TranslateProjection(lambda.Body, rowType, alias) + " from " + from;
                default:
                    throw new NotSupportedException("Unsupported query operator: " + call.Method.Name);
            }
        }

        // Element type of a root table constant, or null when the node is not one.
        private static Type GetRootElementType(Expression expression) {
            ConstantExpression constant = expression as ConstantExpression;
            if (constant == null) return null;
            IQueryable root = constant.Value as IQueryable;
            return root == null ? null : root.ElementType;
        }

        // Queryable wraps lambda arguments in Quote nodes; unwrap them.
        private static LambdaExpression StripQuotes(Expression expression) {
            while (expression.NodeType == ExpressionType.Quote) {
                expression = ((UnaryExpression)expression).Operand;
            }
            return expression as LambdaExpression;
        }

        // Removes implicit numeric/enum/nullable conversions the compiler inserts around operands.
        private static Expression StripConversions(Expression expression) {
            while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked) {
                expression = ((UnaryExpression)expression).Operand;
            }
            return expression;
        }

        // Translates a Where body: comparisons, optionally combined with && / ||.
        private static string TranslatePredicate(Expression node, Type rowType, string alias, List<SqlParameter> parameters) {
            switch (node.NodeType) {
                case ExpressionType.AndAlso:
                case ExpressionType.OrElse: {
                    BinaryExpression logical = (BinaryExpression)node;
                    string op = node.NodeType == ExpressionType.AndAlso ? " and " : " or ";
                    return "(" + TranslatePredicate(logical.Left, rowType, alias, parameters) + op
                        + TranslatePredicate(logical.Right, rowType, alias, parameters) + ")";
                }
                case ExpressionType.Equal:
                case ExpressionType.NotEqual:
                case ExpressionType.LessThan:
                case ExpressionType.LessThanOrEqual:
                case ExpressionType.GreaterThan:
                case ExpressionType.GreaterThanOrEqual:
                    return TranslateComparison((BinaryExpression)node, rowType, alias, parameters);
                default:
                    throw new NotSupportedException("Unsupported where clause: " + node);
            }
        }

        // Translates "member op value" into "tN.Column op @pK", adding the value as a parameter.
        private static string TranslateComparison(BinaryExpression comparison, Type rowType, string alias, List<SqlParameter> parameters) {
            MemberExpression member = StripConversions(comparison.Left) as MemberExpression;
            if (member == null || !(member.Expression is ParameterExpression)) {
                throw new NotSupportedException("The left side of a comparison must be a member of the query element: " + comparison);
            }
            MemberExpression rightMember = StripConversions(comparison.Right) as MemberExpression;
            if (rightMember != null && rightMember.Expression is ParameterExpression) {
                throw new NotSupportedException("Column-to-column comparisons are not supported: " + comparison);
            }

            string column = alias + "." + MapHelper.GetColumnName(rowType, member.Member.Name);
            object value = Evaluate(comparison.Right);

            if (value == null) {
                // "= NULL" is never true in SQL, so use IS [NOT] NULL.
                if (comparison.NodeType == ExpressionType.Equal) return column + " is null";
                if (comparison.NodeType == ExpressionType.NotEqual) return column + " is not null";
                throw new NotSupportedException("Only == and != can be used with null: " + comparison);
            }

            string name = "@p" + parameters.Count;
            parameters.Add(new SqlParameter(name, value));
            return column + GetSqlOperator(comparison.NodeType) + name;
        }

        // Maps a comparison node type to its SQL operator.
        private static string GetSqlOperator(ExpressionType nodeType) {
            switch (nodeType) {
                case ExpressionType.Equal: return " = ";
                case ExpressionType.NotEqual: return " <> ";
                case ExpressionType.LessThan: return " < ";
                case ExpressionType.LessThanOrEqual: return " <= ";
                case ExpressionType.GreaterThan: return " > ";
                case ExpressionType.GreaterThanOrEqual: return " >= ";
                default: throw new NotSupportedException("Unsupported comparison: " + nodeType);
            }
        }

        // Computes the value side of a comparison: a literal, or a sub-tree such as a captured
        // local variable, a conversion or `new DateTime(...)`, which is compiled and invoked.
        private static object Evaluate(Expression expression) {
            ConstantExpression constant = expression as ConstantExpression;
            if (constant != null) return constant.Value;
            return Expression.Lambda<Func<object>>(Expression.Convert(expression, typeof(object))).Compile()();
        }

        // Translates a Select body: identity, or an anonymous type built from members of the row.
        private static string TranslateProjection(Expression body, Type rowType, string alias) {
            if (body is ParameterExpression) {
                return "select " + alias + ".*";
            }
            NewExpression created = body as NewExpression;
            if (created == null || created.Members == null || created.Arguments.Count == 0) {
                throw new NotSupportedException("Select supports only r => r or anonymous types built from members: " + body);
            }
            List<string> columns = new List<string>();
            for (int i = 0; i < created.Arguments.Count; i++) {
                MemberExpression source = created.Arguments[i] as MemberExpression;
                if (source == null || !(source.Expression is ParameterExpression)) {
                    throw new NotSupportedException("Unsupported projection member: " + created.Arguments[i]);
                }
                // Always alias with the target member name. The mapped column can differ from the
                // property name (Name -> StuName), and results are bound back by member name.
                columns.Add(alias + "." + MapHelper.GetColumnName(rowType, source.Member.Name)
                    + " as [" + created.Members[i].Name + "]");
            }
            return "select " + string.Join(", ", columns.ToArray());
        }

        // Translates the expression actually being executed (not the last one created), runs it
        // with its parameters and materializes the rows as a List of the query's element type.
        private object ExecuteSql(Expression expression) {
            List<SqlParameter> parameters = new List<SqlParameter>();
            Type elementType;
            string commandText = Translate(expression, parameters, out elementType);

            DataTable table = new DataTable();
            using (SqlConnection connection = new SqlConnection(_connectionString))
            using (SqlCommand command = new SqlCommand(commandText, connection))
            using (SqlDataAdapter adapter = new SqlDataAdapter(command)) {
                command.Parameters.AddRange(parameters.ToArray());
                connection.Open();
                adapter.Fill(table);
            }
            return Table2Entity.ConvertFromTable(table, elementType);
        }

        /// <summary>SQL text (with @pN placeholders) of the most recently created query.</summary>
        public override string ToString() {
            return sql;
        }
    }
}
实体、属性与数据库中的表、列映射帮助类(MapHelper):
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Reflection;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Reflection helpers that read <see cref="TableAttribute"/> / <see cref="ColumnAttribute"/>
    /// to map entity types and properties to table names, column names and column types.
    /// </summary>
    public static class MapHelper {
        /// <summary>Returns the table name declared on <paramref name="type"/> by <see cref="TableAttribute"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="type"/> is not marked with <see cref="TableAttribute"/>.</exception>
        public static string GetTableName(Type type) {
            if (type == null) throw new ArgumentNullException("type");
            TableAttribute ta = (TableAttribute)Attribute.GetCustomAttribute(type, typeof(TableAttribute), false);
            if (ta == null) {
                throw new ArgumentException("Type " + type.FullName + " is not marked with [Table].", "type");
            }
            return ta.TableName;
        }

        /// <summary>
        /// Returns the column name for <paramref name="propertyName"/>. This is the
        /// <see cref="ColumnAttribute"/> name when one is given, otherwise the property name itself.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="ArgumentException">No instance property with that name exists.</exception>
        public static string GetColumnName(Type type, string propertyName) {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = (ColumnAttribute)Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false);
            // No attribute, or an attribute without a name: the column is named after the property.
            if (ca == null || string.IsNullOrEmpty(ca.ColumnName)) return propertyName;
            return ca.ColumnName;
        }

        /// <summary>
        /// Returns the CLR type of the column. When the property has a <see cref="ColumnAttribute"/>,
        /// this comes from its <c>ColumnType</c>; otherwise it is the property's declared type.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="ArgumentException">No instance property with that name exists.</exception>
        public static Type GetColumnType(Type type, string propertyName) {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = (ColumnAttribute)Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false);
            return ca == null ? pi.PropertyType : ToClrType(ca.ColumnType);
        }

        // Looks up a public or non-public instance property, failing with a descriptive message.
        private static PropertyInfo FindProperty(Type type, string propertyName) {
            if (type == null) throw new ArgumentNullException("type");
            if (propertyName == null) throw new ArgumentNullException("propertyName");
            PropertyInfo pi = type.GetProperty(propertyName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
            if (pi == null) {
                throw new ArgumentException(
                    "Type " + type.FullName + " has no instance property named '" + propertyName + "'.", "propertyName");
            }
            return pi;
        }

        // Maps a declared DataType to its CLR type; unknown values are rejected instead of yielding null.
        private static Type ToClrType(DataType dtype) {
            switch (dtype) {
                case DataType.String: return typeof(String);
                case DataType.Int: return typeof(Int32);
                case DataType.DateTime: return typeof(DateTime);
                case DataType.Float: return typeof(float);
                case DataType.Double: return typeof(double);
                default: throw new ArgumentOutOfRangeException("dtype", dtype, "Unknown DataType value.");
            }
        }
    }
}
自定义 TableAttribute:
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Marks an entity class with the name of the database table it maps to.
    /// </summary>
    [AttributeUsage(AttributeTargets.Class)]
    internal class TableAttribute : Attribute {
        private readonly string name;

        /// <param name="tableName">Name of the mapped table.</param>
        public TableAttribute(string tableName) {
            name = tableName;
        }

        /// <summary>The table name, exactly as passed to the constructor.</summary>
        public string TableName {
            get { return name; }
        }
    }
}
自定义 ColumnAttribute:
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Maps an entity property to a database column and optionally declares the column's data type.
    /// </summary>
    [AttributeUsage(AttributeTargets.Property)]
    internal class ColumnAttribute : Attribute {
        private readonly string name;
        private DataType columnType = DataType.String;

        /// <param name="columnName">Name of the mapped column.</param>
        public ColumnAttribute(string columnName) {
            name = columnName;
        }

        /// <summary>The column name, exactly as passed to the constructor.</summary>
        public string ColumnName {
            get { return name; }
        }

        /// <summary>Declared column type; <see cref="DataType.String"/> unless set via the named argument.</summary>
        public DataType ColumnType {
            get { return columnType; }
            set { columnType = value; }
        }
    }
}
数据类型枚举(DataType):
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Column data types that can be declared through <c>ColumnAttribute.ColumnType</c>.
    /// </summary>
    public enum DataType {
        /// <summary>32-bit integer column (CLR <see cref="System.Int32"/>).</summary>
        Int = 0,
        /// <summary>Text column (CLR <see cref="System.String"/>).</summary>
        String = 1,
        /// <summary>Single-precision column (CLR <see cref="System.Single"/>).</summary>
        Float = 2,
        /// <summary>Double-precision column (CLR <see cref="System.Double"/>).</summary>
        Double = 3,
        /// <summary>Date/time column (CLR <see cref="System.DateTime"/>).</summary>
        DateTime = 4
    }
}
Table 转换对应实体类(Table2Entity):
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Data;
- using System.Reflection;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Materializes the rows of a <see cref="DataTable"/> into objects of a given type.
    /// </summary>
    internal static class Table2Entity {
        /// <summary>
        /// Builds one object from <paramref name="dr"/>.
        /// Types marked with <see cref="TableAttribute"/> are created with their parameterless constructor;
        /// every writable, non-indexed instance property is then set from its mapped column.
        /// Other types (intended for anonymous projection types) are created through their single public
        /// constructor, each parameter bound by name to the column of the same name.
        /// </summary>
        static object ConvertFromDataRow(DataRow dr, Type type) {
            if (type.IsDefined(typeof(TableAttribute), false)) {
                return CreateMappedEntity(dr, type);
            }
            return CreateByConstructor(dr, type);
        }

        // Creates a [Table]-mapped entity and fills its properties from the mapped columns.
        private static object CreateMappedEntity(DataRow dr, Type type) {
            object entity = Activator.CreateInstance(type);
            PropertyInfo[] properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
            foreach (PropertyInfo p in properties) {
                // Read-only and indexed properties cannot be populated from a column.
                if (!p.CanWrite || p.GetIndexParameters().Length != 0) continue;
                string column = MapHelper.GetColumnName(type, p.Name);
                p.SetValue(entity, ReadColumn(dr, column, p.PropertyType), null);
            }
            return entity;
        }

        // Creates an object by binding constructor parameters to columns by name. The binding is by
        // name, not by position, because Type.GetProperties does not guarantee declaration order.
        private static object CreateByConstructor(DataRow dr, Type type) {
            ConstructorInfo[] constructors = type.GetConstructors();
            if (constructors.Length != 1) {
                throw new NotSupportedException(
                    "Type " + type.FullName + " must be marked with [Table] or have exactly one public constructor.");
            }
            ParameterInfo[] parameters = constructors[0].GetParameters();
            object[] args = new object[parameters.Length];
            for (int i = 0; i < parameters.Length; i++) {
                args[i] = ReadColumn(dr, parameters[i].Name, parameters[i].ParameterType);
            }
            return constructors[0].Invoke(args);
        }

        // Reads a column and converts it to targetType. DB NULL becomes null for reference and
        // Nullable<T> targets; for other value types it is rejected instead of crashing in Convert.
        private static object ReadColumn(DataRow dr, string column, Type targetType) {
            if (!dr.Table.Columns.Contains(column)) {
                throw new InvalidOperationException("The result set has no column named '" + column + "'.");
            }
            object raw = dr[column];
            Type underlying = Nullable.GetUnderlyingType(targetType);
            if (raw == null || raw is DBNull) {
                if (targetType.IsValueType && underlying == null) {
                    throw new InvalidOperationException(
                        "Column '" + column + "' is NULL but " + targetType.FullName + " cannot hold null.");
                }
                return null;
            }
            Type valueType = underlying ?? targetType;
            if (valueType.IsInstanceOfType(raw)) return raw;
            if (valueType.IsEnum) return Enum.ToObject(valueType, raw);
            return Convert.ChangeType(raw, valueType, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Converts every row of <paramref name="dt"/> and returns them as a <c>List&lt;type&gt;</c>
        /// typed as <see cref="object"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        public static object ConvertFromTable(DataTable dt, Type type) {
            if (dt == null) throw new ArgumentNullException("dt");
            if (type == null) throw new ArgumentNullException("type");
            System.Collections.IList list =
                (System.Collections.IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(type));
            foreach (DataRow dr in dt.Rows) {
                list.Add(ConvertFromDataRow(dr, type));
            }
            return list;
        }
    }
}
自定义的实体类(Student)及 Program 执行:
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Entity mapped to the <c>Student</c> table.
    /// </summary>
    [Table("Student")]
    public class Student {
        /// <summary>Column <c>ID</c>, declared as <see cref="DataType.Int"/>.</summary>
        [Column("ID", ColumnType = DataType.Int)]
        public int ID { get; set; }

        /// <summary>Column <c>StuName</c> (the property name differs from the column name).</summary>
        [Column("StuName")]
        public string Name { get; set; }

        /// <summary>Column <c>Address</c>.</summary>
        [Column("Address")]
        public string Address { get; set; }

        /// <summary>Column <c>Sex</c>, declared as <see cref="DataType.Int"/>.</summary>
        [Column("Sex", ColumnType = DataType.Int)]
        public int Sex { get; set; }

        /// <summary>Column <c>CollegeID</c>, declared as <see cref="DataType.Int"/>.</summary>
        [Column("CollegeID", ColumnType = DataType.Int)]
        public int CollegeID { get; set; }
    }
}
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
namespace SimpleLinq2Sql {
    /// <summary>
    /// Console demo. It builds a Where / Select / Where query over <see cref="Student"/>, prints the
    /// SQL the provider generated, then enumerates the query and prints each row.
    /// </summary>
    class Program {
        static void Main(string[] args) {
            // The lambdas were mangled as "r = >" in the original paste, which does not compile.
            var o = new CustomTable<Student>()
                .Where(r => r.Address == "china")
                .Select(r => new {
                    NewName = r.Name,
                    Country = r.Address,
                    r.CollegeID,
                    r.Sex
                })
                .Where(r => r.Sex == 1);

            // CustomTable.ToString() returns the provider's SQL text.
            Console.WriteLine(o.ToString());

            // Enumerating executes the query.
            foreach (var i in o) {
                Console.WriteLine(i.NewName + "," + i.Country + "," + i.CollegeID + "," + i.Sex);
            }
            Console.Read();
        }
    }
}
来源: http://lib.csdn.net/article/dotnet/39079