package com.demo3.interceper;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
@Intercepts({@Signature(method="prepare",type=StatementHandler.class,args={Connection.class})})
public class PageInterceptor implements Interceptor{
private String databaseType; //数据库类型,不同的数据库有不同的分页方法
/**
* 拦截后要执行的方法
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
// 对于StatementHandler其实只有两个实现类,一个是RoutingSattementHandler,另一个是抽象类BaseStatementhandler
//BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler
//SimpleStatementHandler是用于处理Statement的;PreparedStatementHandler是用于处理PreparedStatement的;
//CallableStatementHandler是用于处理CallStatement的
//mbatis进行sql语句处理时都是建立RountingStatementHandler,而在RountingStatementHandler里面有一个StatementHandler类型的delegate属性
//RountingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler
//PreparedStatementHandler或CallStatementHandler,在RoutingStatemnetHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法
//我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler
//的时候,是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象
RoutingStatementHandler handler = (RoutingStatementHandler)invocation.getTarget();
//通过反射获取到当前RoutingStatementHandler的delegate属性
StatementHandler delegate = (StatementHandler)ReflectUtil.getFieldValue(handler,"delegate");
//获取到当前Statementhandler的boundSql,这里不管是调用handler.getBoundSql(),
//还是直接调用delegate.getBoundSql()结果是一样的,因为之前已经说过了
//RoutingStatementHandler实现的所有StatementHandler接口方法里面都是调用的delegate对应的方法。
BoundSql boundsql = delegate.getBoundSql();
//拿到当前绑定Sql的参数对象,就是我们在调用对应的Mapper映射语句时所传入的参数对象
Object obj = boundsql.getParameterObject();
//这里我们简单的通过传入的是Page对象就认定它是需要进行分页操作的。
if(obj instanceof Page<?>){
Page<?> page = (Page<?>)obj;
//通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
MappedStatement mappedStatement = (MappedStatement)ReflectUtil.getFieldValue(delegate,"mappedStatement");
//拦截到的prepare方法参数是一个Connection对象
Connection conn = (Connection)invocation.getArgs()[0];
//获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句
String sql = boundsql.getSql();
//给当前的page参数对象设置总记录数
this.setTotalRecord(page, mappedStatement, conn);
//获取分页Sql语句
String pageSql = this.getPageSql(page,sql);
//利用反射设置当前BoundSql对应的sql属性为我们建立好分页Sql语居
ReflectUtil.setFieldValue(boundsql,"sql",pageSql);
}
return invocation.proceed();
}
/**
* 拦截器对应的封装原始对象的方法
* @param arg0
* @return
*/
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
/**
* 设置注册拦截器时设定的属性
*/
@Override
public void setProperties(Properties properties) {
this.databaseType = properties.getProperty("databaseType");
}
/**
* 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle
* @param page
* @return
*/
private String getPageSql(Page<?> page,String sql){
StringBuffer sqlBuffer = new StringBuffer();
sqlBuffer.append(sql);
if("mysql".equalsIgnoreCase(databaseType)){
return getMysqlPageSql(page, sqlBuffer);
}else if("oracle".equalsIgnoreCase(databaseType)){
return getOraclePageSql(page, sqlBuffer);
}
return sqlBuffer.toString();
}
/**
* 获取Mysql数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Mysql数据库分页语句
*/
private String getMysqlPageSql(Page<?> page,StringBuffer buffer){
//计算第一条记录的位置,Mysql中记录的位置从0开始
int offset = (page.getPageNo()-1)*page.getPageSize();
buffer.append(" limit ").append(offset).append(",").append(page.getPageSize());
return buffer.toString();
}
/**
* 获取Oracle数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Oracle数据库的分页查询语句
*/
private String getOraclePageSql(Page<?> page,StringBuffer buffer){
//计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
int offset = (page.getPageNo()-1)*page.getPageSize()+1;
buffer.insert(0, "select b.*,rownum r from (").append(") b where rownum < ").append(offset+page.getPageSize());
buffer.insert(0, "select * from (").append(") where r >=").append(offset);
//上面的Sql语句拼接之后大概是这个样子:
//select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
return buffer.toString();
}
private void setTotalRecord(Page<?> page,MappedStatement ms,Connection conn){
//获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
//delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
BoundSql boundsql = ms.getBoundSql(page);
//获取到我们自己写在Mapper映射语句中对应的Sql语句
String sql = boundsql.getSql();
//通过查询Sql语句获取到对应的计算总记录数的sql语句
String countsql = this.getCountSql(sql);
//通过BoundSql获取对应的参数映射
List<ParameterMapping> parameterMapping = boundsql.getParameterMappings();
//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
BoundSql countBoundSql = new BoundSql(ms.getConfiguration(), countsql, parameterMapping, page);
//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
ParameterHandler parameterHandler = new DefaultParameterHandler(ms, page, boundsql);
//通过connection建立一个countSql对应的PreparedStatement对象。
PreparedStatement past = null;
ResultSet rs = null;
try {
past = conn.prepareStatement(countsql);
//通过parameterHandler给PreparedStatement对象设置参数
parameterHandler.setParameters(past);
//之后就是执行获取总记录数的Sql语句和获取结果了。
rs = past.executeQuery();
if(rs.next()){
int totalRecord = rs.getInt(1);
//给当前的参数page对象设置总记录数
page.setTotalRecord(totalRecord);
}
} catch (SQLException e) {
e.printStackTrace();
}finally{
try {
if(rs!=null)
rs.close();
if(past!=null)
past.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
* @param sql
* @return
*/
private String getCountSql(String sql){
int index = sql.indexOf("from");
return "select count(*)"+ sql.substring(index);
}
/**
* 利用反射进行操作的一个工具类
*
*/
private static class ReflectUtil{
/**
* 利用反射获取指定对象的指定属性
* @param obj 目标对象
* @param f