`
ayufox
  • 浏览: 273752 次
  • 性别: Icon_minigender_1
  • 来自: 深圳
社区版块
存档分类
最新评论

AST构建Hibernate动态查询

阅读更多
一、效果
java 代码
 
  1. public class HqlCompilerImplTest extends TestCase  
  2. {  
  3.     private IHqlCompiler compiler = new HqlCompilerImpl();  
  4.   
  5.     public void test1()  
  6.     {  
  7.         Hql hql = compiler.compile("from User u where u.username = :username and u.age = :age order by u.username,u.age",  
  8.                 new MapContext());  
  9.         assertEquals("from User u order by u.username,u.age", hql.getHql());  
  10.         assertEquals(0, hql.getParameters().length);  
  11.     }  
  12.   
  13.     public void test2)  
  14.     {  
  15.         MapContext context = new MapContext();  
  16.         context.add("username""ray");  
  17.           
  18.         Hql hql = compiler  
  19.                 .compile(  
  20.                         "from User u where u.username = :username and u.age = :age order by u.username,u.age",  
  21.                         context);  
  22.           
  23.         assertEquals(  
  24.                 "from User u where u.username = :username order by u.username,u.age",  
  25.                 hql.getHql());  
  26.         assertEquals(1, hql.getParameters().length);  
  27.         assertEquals("username", hql.getParameters()[0].getName());  
  28.         assertEquals("ray", hql.getParameters()[0].getValue());  
  29.     }  
  30.   
  31.     public void test3()  
  32.     {  
  33.         MapContext context = new MapContext();  
  34.         context.add("username""ray");  
  35.         context.add("age"new Integer(10));  
  36.           
  37.         Hql hql = compiler  
  38.         .compile(  
  39.                 "from User u where u.username = :username and u.age = :age order by u.username,u.age",  
  40.                 context);  
  41.         assertEquals(  
  42.                 "from User u where (u.username = :username and u.age = :age) order by u.username,u.age",  
  43.                 hql.getHql());  
  44.         assertEquals(2, hql.getParameters().length);  
  45.         assertEquals("username", hql.getParameters()[0].getName());  
  46.         assertEquals("ray", hql.getParameters()[0].getValue());  
  47.         assertEquals("age", hql.getParameters()[1].getName());  
  48.         assertEquals(new Integer(10), hql.getParameters()[1].getValue());  
  49.     }  
  50. }  

二、基本接口
java 代码
 
  1. package com.ayufox.framework.core;  
  2.   
  3. /** 
  4.  * @author ray 
  5.  * 
  6.  */  
  7. public interface Context  
  8. {  
  9.     Object get(String name);  
  10. }  
java 代码
 
  1. package com.ayufox.framework.core.dao.hqlx;  
  2.   
  3. import com.ayufox.framework.core.Context;  
  4. import com.ayufox.framework.core.dao.hql.Hql;  
  5.   
  6. /** 
  7.  * @author ray 
  8.  * 
  9.  */  
  10. public interface IHqlCompiler  
  11. {  
  12.     Hql compile(String hql, Object ... values);  
  13.       
  14.     Hql compile(String hql, Context context);  
  15. }  
三、实现
java 代码
 
  1. package com.ayufox.framework.core.dao.hqlx.ast;  
  2.   
  3. import java.io.ByteArrayOutputStream;  
  4. import java.io.PrintStream;  
  5. import java.util.Collections;  
  6. import java.util.HashMap;  
  7.   
  8. import org.apache.commons.logging.Log;  
  9. import org.apache.commons.logging.LogFactory;  
  10. import org.hibernate.hql.ast.HqlParser;  
  11.   
  12. import antlr.RecognitionException;  
  13. import antlr.TokenStreamException;  
  14. import antlr.collections.AST;  
  15.   
  16. import com.ayufox.framework.core.Context;  
  17. import com.ayufox.framework.core.cache.Cache;  
  18. import com.ayufox.framework.core.cache.MapCache;  
  19. import com.ayufox.framework.core.dao.hql.Hql;  
  20. import com.ayufox.framework.core.dao.hqlx.DynamicMapContext;  
  21. import com.ayufox.framework.core.dao.hqlx.IHqlCompiler;  
  22. import com.ayufox.framework.core.dao.hqlx.IllegalSyntaxException;  
  23.   
  24. /** 
  25.  * @author ray 
  26.  * 
  27.  */  
  28. public class HqlCompilerImpl implements IHqlCompiler  
  29. {  
  30.     private final static Log LOG = LogFactory.getLog(HqlCompilerImpl.class);  
  31.   
  32.     private Cache astCache = new MapCache(Collections  
  33.             .synchronizedMap(new HashMap()));  
  34.   
  35.     public void setAstCache(Cache astCache)  
  36.     {  
  37.         this.astCache = astCache;  
  38.     }  
  39.   
  40.     /* (non-Javadoc) 
  41.      * @see com.konceptusa.framework.core.dao.hqlx.IHqlCompiler#compile(java.lang.String, java.lang.Object[]) 
  42.      */  
  43.     public Hql compile(String hql, Object... values)  
  44.     {  
  45.         return compile(hql, new DynamicMapContext(values));  
  46.     }  
  47.   
  48.     /* (non-Javadoc) 
  49.      * @see com.konceptusa.framework.core.dao.hqlx.IHqlCompiler#compile(java.lang.String, com.konceptusa.framework.core.Context) 
  50.      */  
  51.     public Hql compile(String hql, Context context)  
  52.     {  
  53.         if (hql == null || context == null)  
  54.         {  
  55.             throw new IllegalArgumentException("hql or context can't be null");  
  56.         }  
  57.   
  58.         AST ast = getRootAST(hql);  
  59.   
  60.         HqlCompileExecutor compilerContext = new HqlCompileExecutor(ast,  
  61.                 context);  
  62.         return compilerContext.build();  
  63.     }  
  64.   
  65.     protected AST getRootAST(String hql)  
  66.     {  
  67.         AST ast = (AST) this.astCache.get(hql);  
  68.         if (ast == null)  
  69.         {  
  70.             ast = createAST(hql);  
  71.             this.astCache.put(hql, ast);  
  72.             if (LOG.isDebugEnabled())  
  73.             {  
  74.                 LOG.debug("get ast[" + ast + "] from cache for hql[" + hql  
  75.                         + "]");  
  76.             }  
  77.         }  
  78.         return ast;  
  79.     }  
  80.   
  81.     private AST createAST(String hql)  
  82.     {  
  83.         HqlParser parser = HqlParser.getInstance(hql);  
  84.         try  
  85.         {  
  86.             parser.statement();  
  87.         }  
  88.         catch (RecognitionException e)  
  89.         {  
  90.             throw new IllegalSyntaxException(e);  
  91.         }  
  92.         catch (TokenStreamException e)  
  93.         {  
  94.             throw new IllegalSyntaxException(e);  
  95.         }  
  96.         AST ast = parser.getAST();  
  97.         parser.getParseErrorHandler().throwQueryException();  
  98.         if (LOG.isDebugEnabled())  
  99.         {  
  100.             ByteArrayOutputStream baos = new ByteArrayOutputStream();  
  101.             parser.showAst(ast, new PrintStream(baos));  
  102.             LOG.debug("AST:" + new String(baos.toByteArray()));  
  103.         }  
  104.         return ast;  
  105.     }  
  106. }  

分享到:
评论
3 楼 onlyerlee 2007-09-27  
能把源码发到我邮箱里来吗?我想运行下知道是怎样的效果. onlyerli@ecvision.com
2 楼 pigfly 2007-05-27  
这样还是要自己写hql语句,可不可以根据结果集合自动拼凑hql呢
1 楼 ayufox 2007-05-25  
package com.ayufox.framework.core.dao.hqlx.ast;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hibernate.hql.antlr.SqlTokenTypes;

import antlr.collections.AST;

import com.ayufox.framework.core.Context;
import com.ayufox.framework.core.common.CommonUtils;
import com.ayufox.framework.core.dao.hql.Hql;
import com.ayufox.framework.core.dao.hql.Parameter;

/**
 * Renders a Hibernate HQL AST back into HQL text against a {@link Context}.
 * <p>
 * Every named parameter ({@code :name}) is looked up in the context. A comparison or
 * between condition whose parameter has no (non-empty) value is left out of the output,
 * together with any parameters it had already collected. An and/or with only one
 * surviving side is reduced to that side, and a where clause with nothing left is
 * omitted. A missing parameter anywhere else makes {@link #build()} throw
 * {@code ValueNotExistException}.
 * <p>
 * Found values are collected, in output order, as the {@link Parameter}s of the
 * resulting {@link Hql}. A parameter written {@code :a_b} keeps that text in the HQL but
 * is looked up and recorded under the name {@code a.b}.
 * <p>
 * Node types without a case in {@link #print} are logged as warnings and produce no
 * output. The result of {@link #build()} is computed once and then cached.
 *
 * @author ray
 */
public class HqlCompileExecutor implements SqlTokenTypes
{
	private final static Log LOG = LogFactory.getLog(HqlCompileExecutor.class);

	/** Token type (boxed Integer) to its constant name in SqlTokenTypes; used for logging. */
	private final static Map<Object, String> NODE_TYPE_MAP = new HashMap<Object, String>();
	static
	{
		Field[] fields = SqlTokenTypes.class.getFields();
		for (int i = 0; i < fields.length; i++)
		{
			Field field = fields[i];
			int modifiers = field.getModifiers();
			if (Modifier.isPublic(modifiers) && Modifier.isStatic(modifiers)
					&& Modifier.isFinal(modifiers))
			{
				String name = field.getName();
				try
				{
					Object value = field.get(null);
					NODE_TYPE_MAP.put(value, name);
				}
				catch (IllegalAccessException ex)
				{
					LOG.error("Imposible", ex);
				}
			}
		}
	}

	protected Hql hql;
	protected List<Parameter> parameters;
	protected Context context;

	private AST root;

	public HqlCompileExecutor(AST root, Context context)
	{
		init(root, context);
	}

	/** For subclasses that call {@link #init(AST, Context)} themselves. */
	protected HqlCompileExecutor()
	{

	}

	protected void init(AST root, Context context)
	{
		this.context = context;
		this.root = root;
		this.parameters = new ArrayList<Parameter>();
	}

	/**
	 * Renders the tree (first call only) and returns the resulting query.
	 *
	 * @return the HQL text and the parameters whose values were found
	 */
	public Hql build()
	{
		if (this.hql == null)
		{
			String hql = buildHql();
			if (LOG.isDebugEnabled())
			{
				LOG.debug("hql:" + hql);
				for (int i = 0; i < this.parameters.size(); i++)
				{
					LOG.debug("  parameter[" + i + "]:name=["
							+ this.parameters.get(i).getName() + "],value=["
							+ this.parameters.get(i).getValue() + "]");
				}
			}
			this.hql = new Hql(hql, this.parameters
					.toArray(new Parameter[this.parameters.size()]));
		}
		return this.hql;
	}

	protected String buildHql()
	{
		StringBuffer buffer = new StringBuffer();
		print(buffer, this.root);
		return buffer.toString();
	}

	/*
	 * Dispatches on the node type and appends the node's HQL to buffer.
	 */
	protected void print(StringBuffer buffer, AST ast)
	{
		switch (ast.getType())
		{
			case QUERY: // query root
				printChildren(buffer, ast, " ");
				break;
			case RANGE: // one entity of the from clause
				printRANGE(buffer, ast);
				break;
			case IDENT:
			case ALIAS:
			case ROW_STAR: // *
			case DISTINCT:
			case ALL:
			case NUM_INT:
			case NUM_DOUBLE:
			case NUM_FLOAT:
			case NUM_LONG:
			case ASCENDING:
			case DESCENDING:
				printText(buffer, ast);
				break;
			case INNER:
			case OUTER:
			case FULL:
			case LEFT:
			case RIGHT:
				// join-type keywords; printJOIN writes the "join" keyword after them
				printText(buffer, ast);
				break;
			case JOIN:
				printJOIN(buffer, ast);
				break;
			case FROM:
				printSelfAndChildren(buffer, ast, " ");
				break;
			case SELECT:
				printSELECT(buffer, ast);
				break;
			case SELECT_FROM:
				printSELECT_FROM(buffer, ast);
				break;
			case COUNT:
			case AGGREGATE: // aggregate functions
				printFunction(buffer, ast);
				break;
			case DOT: // .
				printDOT(buffer, ast);
				break;
			case AND:
			case OR:
				printLinkWord(buffer, ast);
				break;
			case WHERE:
				printWhere(buffer, ast);
				break;
			case CONSTRUCTOR:
				printCONSTRUCTOR(buffer, ast);
				break;
			case LIKE:
			case NOT_LIKE:
			case EQ:
			case GT:
			case GE:
			case LT:
			case LE:
			case IN:
			case NOT_IN:
				printCondition(buffer, ast);
				break;
			case BETWEEN:
			case NOT_BETWEEN:
				printBETWEEN(buffer, ast);
				break;
			case IN_LIST: // the parenthesised operand of in
				printInList(buffer, ast);
				break;
			case COLON: // :
				printCOLON(buffer, ast);
				break;
			case ORDER:
				printORDER(buffer, ast);
				break;
			case GROUP:
				buffer.append("group by ");
				printChildren(buffer, ast, ",");
				break;
			default:
				// NOTE(review): unsupported nodes are dropped from the output, which can
				// silently change the query's meaning
				LOG.warn("Incognizance node, type["
						+ NODE_TYPE_MAP.get(ast.getType()) + "("
						+ ast.getType() + ")] text[" + ast.getText() + "]");
		}
	}

	/*
	 * new X(a,b,...): the first child is the class, the rest are the arguments
	 */
	protected void printCONSTRUCTOR(StringBuffer buffer, AST ast)
	{
		buffer.append("new ");
		AST constructor = ast.getFirstChild();
		print(buffer, constructor);
		buffer.append('(');
		AST next = constructor.getNextSibling();
		while (next != null)
		{
			if (buffer.charAt(buffer.length() - 1) != '(')
			{
				buffer.append(',');
			}
			print(buffer, next);
			next = next.getNextSibling();
		}
		buffer.append(')');
	}

	/*
	 * order by: items are comma separated, an asc/desc node is attached with a space
	 */
	protected void printORDER(StringBuffer buffer, AST ast)
	{
		buffer.append("order by ");
		AST child = ast.getFirstChild();
		AST next = null;
		while (child != null)
		{
			print(buffer, child);
			next = child.getNextSibling();
			if (next != null)
			{
				if (next.getType() == DESCENDING || next.getType() == ASCENDING)
				{
					buffer.append(" ");
				}
				else
				{
					buffer.append(",");
				}
			}
			child = next;
		}
	}

	protected void printRANGE(StringBuffer buffer, AST ast)
	{
		printChildren(buffer, ast, " ");

		AST next = ast.getNextSibling();
		if (next != null && next.getType() == RANGE)
		{
			buffer.append(",");
		}
	}

	/*
	 * (a,b,...): the list elements are the direct children of the IN_LIST node
	 */
	protected void printInList(StringBuffer buffer, AST ast)
	{
		buffer.append("(");
		printChildren(buffer, ast, ",");
		buffer.append(")");
	}

	protected void printWhere(StringBuffer buffer, AST ast)
	{
		// a where node has exactly one child; the keyword is only written if that child
		// produced output
		int length = buffer.length();
		AST child = ast.getFirstChild();

		print(buffer, child);

		if (buffer.length() > length)
		{
			buffer.insert(length, ast.getText() + " ");
		}
	}

	/*
	 * Connectives and/or. The keyword and surrounding parentheses are only written when
	 * both operands produced output; otherwise the surviving operand stands alone.
	 */
	protected void printLinkWord(StringBuffer buffer, AST ast)
	{
		int sourceLength = buffer.length();
		AST left = ast.getFirstChild();
		AST right = left.getNextSibling();

		print(buffer, left);

		int middleLength = buffer.length();

		print(buffer, right);

		int allLength = buffer.length();

		if ((allLength > middleLength) && (middleLength > sourceLength))
		{
			buffer.insert(sourceLength, "(");

			// +1 accounts for the "(" just inserted
			buffer.insert(middleLength + 1, " " + ast.getText() + " ");

			buffer.append(")");
		}
	}

	/*
	 * :name. Looks the value up under the name with '_' replaced by '.', records it as a
	 * Parameter and writes ":name". Throws ValueNotExistException (with nothing left in
	 * the buffer) when the value is empty.
	 */
	protected void printCOLON(StringBuffer buffer, AST ast)
	{
		int length = buffer.length();
		printChildren(buffer, ast, "");
		String name = buffer.substring(length, buffer.length()).replace('_',
				'.');
		Object value = context.get(name);
		if (CommonUtils.isEmpty(value))
		{
			buffer.setLength(length);
			throw new ValueNotExistException(name);
		}
		buffer.insert(length, ast.getText());
		this.parameters.add(new Parameter(name, value));
	}

	protected void printText(StringBuffer buffer, AST ast)
	{
		buffer.append(ast.getText());
	}

	/*
	 * functions such as avg(..), count(...)
	 */
	protected void printFunction(StringBuffer buffer, AST ast)
	{
		buffer.append(ast.getText());
		buffer.append("(");
		printChildren(buffer, ast, " ");
		buffer.append(")");
	}

	/*
	 * path navigation: left.right
	 */
	protected void printDOT(StringBuffer buffer, AST ast)
	{
		AST left = ast.getFirstChild();
		print(buffer, left);

		buffer.append(ast.getText());

		AST right = left.getNextSibling();
		print(buffer, right);
	}

	/*
	 * Prints all children of ast.
	 * @param buffer output
	 * @param ast parent node
	 * @param join separator written between children that produced output; children
	 *        that print nothing get no separator
	 */
	protected void printChildren(StringBuffer buffer, AST ast, String join)
	{
		printSiblings(buffer, ast.getFirstChild(), join);
	}

	/*
	 * Prints first and all its following siblings, writing join only between nodes
	 * that actually produced output.
	 */
	private void printSiblings(StringBuffer buffer, AST first, String join)
	{
		boolean printedAny = false;
		for (AST node = first; node != null; node = node.getNextSibling())
		{
			int start = buffer.length();
			if (printedAny)
			{
				buffer.append(join);
			}
			int contentStart = buffer.length();
			print(buffer, node);
			if (buffer.length() == contentStart)
			{
				// nothing printed (e.g. an empty where): drop the separator as well
				buffer.setLength(start);
			}
			else
			{
				printedAny = true;
			}
		}
	}

	/*
	 * Prints the node's own text, a space, then its children.
	 * @param buffer output
	 * @param ast node
	 * @param join separator between children
	 */
	protected void printSelfAndChildren(StringBuffer buffer, AST ast,
			String join)
	{
		buffer.append(ast.getText());
		buffer.append(" ");
		printChildren(buffer, ast, join);
	}

	/*
	 * select clause: a leading distinct child is separated from the selected items by a
	 * space, the items themselves by commas
	 */
	protected void printSELECT(StringBuffer buffer, AST ast)
	{
		buffer.append(ast.getText());
		buffer.append(" ");
		AST child = ast.getFirstChild();
		if (child != null && child.getType() == DISTINCT)
		{
			printText(buffer, child);
			buffer.append(" ");
			child = child.getNextSibling();
		}
		printSiblings(buffer, child, ",");
	}

	/*
	 * join: leading join-type children (left/right/outer/full/inner) are written before
	 * the "join" keyword, the remaining children (path, alias, ...) after it
	 */
	protected void printJOIN(StringBuffer buffer, AST ast)
	{
		AST child = ast.getFirstChild();
		while (child != null && isJoinType(child.getType()))
		{
			print(buffer, child);
			buffer.append(' ');
			child = child.getNextSibling();
		}
		buffer.append("join");

		int length = buffer.length();
		buffer.append(' ');
		printSiblings(buffer, child, " ");
		if (buffer.length() == length + 1)
		{
			buffer.setLength(length);
		}
	}

	/* true for the keyword nodes that qualify a join */
	private static boolean isJoinType(int type)
	{
		return type == INNER || type == OUTER || type == FULL || type == LEFT
				|| type == RIGHT;
	}

	/*
	 * SELECT_FROM: the first child is the from clause, the optional second one the
	 * select clause, which is written first
	 */
	protected void printSELECT_FROM(StringBuffer buffer, AST ast)
	{
		AST from = ast.getFirstChild();
		AST select = from.getNextSibling();
		if (select != null)
		{
			print(buffer, select);
			buffer.append(" ");
		}
		print(buffer, from);
	}

	/*
	 * property [not] between first and second; dropped entirely (output and collected
	 * parameters) if any parameter in it has no value
	 */
	protected void printBETWEEN(StringBuffer buffer, AST ast)
	{
		int length = buffer.length();
		int parameterCount = this.parameters.size();
		AST property = ast.getFirstChild();
		AST first = property.getNextSibling();
		AST second = first.getNextSibling();
		try
		{
			print(buffer, property);
			buffer.append(" ");
			printText(buffer, ast);
			buffer.append(" ");
			print(buffer, first);
			buffer.append(" and ");
			print(buffer, second);
		}
		catch (ValueNotExistException e)
		{
			rollback(buffer, length, parameterCount);
		}
	}

	/*
	 * Binary conditions such as like, =, in. Dropped entirely (output and collected
	 * parameters) if a parameter on either side has no value.
	 */
	protected void printCondition(StringBuffer buffer, AST ast)
	{
		int length = buffer.length();
		int parameterCount = this.parameters.size();
		AST left = ast.getFirstChild();
		AST right = left.getNextSibling();
		try
		{
			print(buffer, left);

			buffer.append(" ");
			buffer.append(ast.getText());
			buffer.append(" ");

			print(buffer, right);
		}
		catch (ValueNotExistException e)
		{
			rollback(buffer, length, parameterCount);
		}
	}

	/*
	 * Undoes a partially printed condition: truncates the output to length and forgets
	 * the parameters collected after the first parameterCount.
	 */
	private void rollback(StringBuffer buffer, int length, int parameterCount)
	{
		buffer.setLength(length);
		this.parameters.subList(parameterCount, this.parameters.size()).clear();
	}
}

相关推荐

Global site tag (gtag.js) - Google Analytics