package com.wujialiang.auth.interceptor;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.wujialiang.auth.context.UserContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.core.annotation.Order;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
/**
 * Native MyBatis interceptor that applies a data-permission filter to a specific mapper statement.
 * <p>
 * For the intercepted statement, the SELECT is rewritten so that its main table is inner-joined with
 * {@code user_company} on {@code company_id}, restricted to a single user id. When this class is used
 * to intercept SQL, mind its ordering relative to the pagination plugin.
 *
 * @author RudeCrab
 */
@Slf4j
// Not registered as a Spring bean by default; uncomment to enable this interceptor.
//@Component
@Order(value = 1000)
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class DataInterceptor implements Interceptor {

    /**
     * Fully qualified id of the mapper method whose SQL is rewritten.
     * If more methods need filtering, store their ids in a collection and test membership instead.
     */
    private static final String TARGET_MAPPER_ID = "com.wujialiang.auth.mapper.DataMapper.getAllDatas";

    /** Table joined to enforce the permission filter, written together with its alias {@code uc}. */
    private static final String PERMISSION_TABLE = "user_company uc";

    /**
     * User id the filter is restricted to.
     * TODO: hard-coded because UserContext does not provide a getCurrentUserName() (or similar)
     * accessor yet; replace with the logged-in user's id once it is available.
     */
    private static final String CURRENT_USER_ID = "1";

    /**
     * Rewrites the SQL of {@link #TARGET_MAPPER_ID} before it is prepared; all other statements pass
     * through unchanged.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of proceeding with the (possibly rewritten) invocation
     * @throws Throwable whatever the underlying invocation throws, or {@link IllegalStateException}
     *                   if the target statement's SQL cannot be rewritten
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Unwrap plugin proxies to reach the real StatementHandler, then access its internals reflectively.
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        // The id is the fully qualified mapper method name, e.g. com.rudecrab.mapper.UserMapper.insertUser
        String id = mappedStatement.getId();
        log.info("mapper: ==> {}", id);
        // Not the statement we filter: let it run untouched.
        if (!TARGET_MAPPER_ID.equals(id)) {
            return invocation.proceed();
        }
        String sql = statementHandler.getBoundSql().getSql();
        log.info("原始SQL语句: ==> {}", sql);
        sql = getSql(sql);
        // Replace the SQL that MyBatis is about to prepare.
        metaObject.setValue("delegate.boundSql.sql", sql);
        log.info("拦截后SQL语句:==>{}", sql);
        return invocation.proceed();
    }

    /**
     * Parses a plain SELECT and appends an inner join on {@code user_company} that restricts rows to
     * the companies of {@link #CURRENT_USER_ID}.
     *
     * @param sql the original SQL
     * @return the rewritten SQL
     * @throws IllegalStateException if the SQL cannot be parsed, is not a plain SELECT, or its FROM
     *                               item is not a table. The interceptor fails closed here, because
     *                               running the unfiltered SQL would bypass the permission check.
     */
    private String getSql(String sql) {
        Statement stmt;
        try {
            stmt = CCJSqlParserUtil.parse(sql);
        } catch (JSQLParserException e) {
            throw new IllegalStateException("Failed to parse SQL for data-permission rewrite: " + sql, e);
        }
        if (!(stmt instanceof Select)) {
            throw new IllegalStateException("Data-permission rewrite only supports SELECT statements: " + sql);
        }
        Select selectStatement = (Select) stmt;
        if (!(selectStatement.getSelectBody() instanceof PlainSelect)) {
            throw new IllegalStateException("Data-permission rewrite only supports plain SELECT statements: " + sql);
        }
        PlainSelect ps = (PlainSelect) selectStatement.getSelectBody();
        FromItem fromItem = ps.getFromItem();
        if (!(fromItem instanceof Table)) {
            throw new IllegalStateException("Data-permission rewrite requires a table in the FROM clause: " + sql);
        }
        Table table = (Table) fromItem;
        // Qualify the join column with the main table's alias if present, otherwise with its name.
        String mainTable = table.getAlias() == null ? table.getName() : table.getAlias().getName();
        List<Join> joins = ps.getJoins();
        if (joins == null) {
            joins = new ArrayList<>(1);
        }
        joins.add(buildPermissionJoin(mainTable));
        ps.setJoins(joins);
        // Serialize the whole Select (not just the PlainSelect body) so a WITH clause, if any, is kept.
        return selectStatement.toString();
    }

    /**
     * Builds {@code INNER JOIN user_company uc ON <mainTable>.company_id = uc.company_id AND uc.user_id = <id>}.
     *
     * @param mainTable alias or name used to qualify the main table's {@code company_id} column
     * @return the join to append to the SELECT
     */
    private Join buildPermissionJoin(String mainTable) {
        Join join = new Join();
        join.setInner(true);
        join.setRightItem(new Table(PERMISSION_TABLE));
        // Condition 1: both tables are linked by company_id.
        EqualsTo joinExpression = new EqualsTo();
        joinExpression.setLeftExpression(new Column(mainTable + ".company_id"));
        joinExpression.setRightExpression(new Column("uc.company_id"));
        // Condition 2: restrict to the (currently hard-coded) user id.
        EqualsTo userIdExpression = new EqualsTo();
        userIdExpression.setLeftExpression(new Column("uc.user_id"));
        userIdExpression.setRightExpression(new LongValue(CURRENT_USER_ID));
        join.setOnExpression(new AndExpression(joinExpression, userIdExpression));
        return join;
    }
}