package com.xd.pre.modules.security.social;
import com.xd.pre.modules.data.tenant.PreTenantContextHolder;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.security.crypto.encrypt.TextEncryptor;
import org.springframework.social.connect.*;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.*;
import java.util.Map.Entry;
/**
* @Classname PreJdbcConnectionRepository
* @Description Some individual SQL statements are problematic and need to be rewritten
* @Author Created by Lihaodong (alias:小东啊) lihaodongmail@163.com
* @Date 2019-07-19 09:18
* @Version 1.0
*/
@ConditionalOnClass(ConnectionRepository.class)
public class PreJdbcConnectionRepository implements ConnectionRepository {
/** Local user id that the queries in this repository filter on ({@code userId = ?}). */
private final String userId;
/** Template used to execute the SQL statements issued by this repository. */
private final JdbcTemplate jdbcTemplate;
/** Supplies the set of registered provider ids. */
private final ConnectionFactoryLocator connectionFactoryLocator;
/** NOTE(review): presumably encrypts/decrypts stored tokens and secrets; usage is not visible in this part of the file, confirm. */
private final TextEncryptor textEncryptor;
/** Prefix prepended to the {@code UserConnection} table name when building SQL. */
private final String tablePrefix;
/**
 * Back-quoted name of the rank column.
 * NOTE(review): the quoting is presumably needed because {@code rank} is a reserved word in
 * recent MySQL versions; confirm against the target database.
 * Declared {@code static} because the value is a constant shared by every instance.
 */
private static final String RANK = "`rank`";
/**
 * Builds a repository instance from its collaborators; every argument is stored as-is.
 *
 * @param userId                   local user id used to filter the queries
 * @param jdbcTemplate             template used to run the SQL statements
 * @param connectionFactoryLocator source of the registered provider ids
 * @param textEncryptor            text encryptor kept for use by this repository
 * @param tablePrefix              prefix prepended to the {@code UserConnection} table name
 */
public PreJdbcConnectionRepository(String userId,
                                   JdbcTemplate jdbcTemplate,
                                   ConnectionFactoryLocator connectionFactoryLocator,
                                   TextEncryptor textEncryptor,
                                   String tablePrefix) {
    // Plain field assignments; the order is irrelevant.
    this.tablePrefix = tablePrefix;
    this.textEncryptor = textEncryptor;
    this.connectionFactoryLocator = connectionFactoryLocator;
    this.jdbcTemplate = jdbcTemplate;
    this.userId = userId;
}
/**
 * Loads every connection stored for the current user, grouped by provider id.
 * <p>
 * Each registered provider id is always present as a key; providers without any stored
 * connection map to an empty list. Rows are read ordered by provider id and rank.
 *
 * @return map from provider id to that provider's connections
 */
@Override
public MultiValueMap<String, Connection<?>> findAllConnections() {
    List<Connection<?>> resultList = jdbcTemplate.query(
            selectFromUserConnection() + " where userId = ? order by providerId, `rank`",
            connectionMapper, userId);
    MultiValueMap<String, Connection<?>> connections = new LinkedMultiValueMap<String, Connection<?>>();
    // Seed every registered provider with an (immutable) empty placeholder list.
    for (String registeredProviderId : connectionFactoryLocator.registeredProviderIds()) {
        connections.put(registeredProviderId, Collections.<Connection<?>>emptyList());
    }
    for (Connection<?> connection : resultList) {
        String providerId = connection.getKey().getProviderId();
        List<Connection<?>> existing = connections.get(providerId);
        // Swap the immutable placeholder for a mutable list on first use. The null check
        // covers a row whose provider id is not among the registered ones, which previously
        // caused a NullPointerException here.
        if (existing == null || existing.isEmpty()) {
            connections.put(providerId, new LinkedList<Connection<?>>());
        }
        connections.add(providerId, connection);
    }
    return connections;
}
/**
 * Returns the current user's connections to a single provider, ordered by rank.
 *
 * @param providerId id of the provider to look up
 * @return the rows mapped by {@code connectionMapper}
 */
@Override
public List<Connection<?>> findConnections(String providerId) {
    final String sql = selectFromUserConnection()
            + " where userId = ? and providerId = ? order by `rank`";
    return jdbcTemplate.query(sql, connectionMapper, userId, providerId);
}
/**
 * Returns the current user's connections for the provider resolved from the given API type.
 * <p>
 * Delegates to {@link #findConnections(String)} using the provider id obtained via
 * {@code getProviderId(apiType)}.
 *
 * @param apiType API binding type identifying the provider
 * @param <A>     API binding type
 * @return the connections of that provider, viewed as {@code Connection<A>}
 */
// The cast cannot be checked at runtime; suppressed here as on the sibling typed accessors.
@SuppressWarnings("unchecked")
@Override
public <A> List<Connection<A>> findConnections(Class<A> apiType) {
    List<?> connections = findConnections(getProviderId(apiType));
    return (List<Connection<A>>) connections;
}
/**
 * Finds the current user's connections to the given provider users.
 * <p>
 * For every provider that has at least one matching row, the returned list has the same size
 * and order as the requested provider user ids; positions without a matching connection hold
 * {@code null}. Providers without any matching row are absent from the result.
 *
 * @param providerUsers map from provider id to the provider user ids to look up
 * @return map from provider id to the positionally aligned connections
 * @throws IllegalArgumentException if {@code providerUsers} is {@code null} or empty
 */
@Override
public MultiValueMap<String, Connection<?>> findConnectionsToUsers(MultiValueMap<String, String> providerUsers) {
    if (providerUsers == null || providerUsers.isEmpty()) {
        throw new IllegalArgumentException("Unable to execute find: no providerUsers provided");
    }
    StringBuilder providerUsersCriteriaSql = new StringBuilder();
    MapSqlParameterSource parameters = new MapSqlParameterSource();
    parameters.addValue("userId", userId);
    int criterionIndex = 0;
    for (Entry<String, List<String>> entry : providerUsers.entrySet()) {
        if (criterionIndex > 0) {
            providerUsersCriteriaSql.append(" or ");
        }
        // Parameter names are derived from a counter rather than from the provider id, so
        // caller-supplied text never becomes part of the SQL string itself.
        String providerIdParam = "providerId_" + criterionIndex;
        String providerUserIdsParam = "providerUserIds_" + criterionIndex;
        providerUsersCriteriaSql.append("(providerId = :").append(providerIdParam)
                .append(" and providerUserId in (:").append(providerUserIdsParam).append("))");
        parameters.addValue(providerIdParam, entry.getKey());
        parameters.addValue(providerUserIdsParam, entry.getValue());
        criterionIndex++;
    }
    // The OR-joined criteria are wrapped in parentheses: AND binds tighter than OR, so without
    // them only the first provider clause would be restricted to the current user.
    String sql = selectFromUserConnection() + " where userId = :userId and ("
            + providerUsersCriteriaSql + ") order by providerId, `rank`";
    List<Connection<?>> resultList =
            new NamedParameterJdbcTemplate(jdbcTemplate).query(sql, parameters, connectionMapper);
    MultiValueMap<String, Connection<?>> connectionsForUsers = new LinkedMultiValueMap<String, Connection<?>>();
    for (Connection<?> connection : resultList) {
        String providerId = connection.getKey().getProviderId();
        List<String> userIds = providerUsers.get(providerId);
        List<Connection<?>> connections = connectionsForUsers.get(providerId);
        if (connections == null) {
            // Pre-fill with nulls so each connection can be placed at its requested position.
            connections = new ArrayList<Connection<?>>(
                    Collections.<Connection<?>>nCopies(userIds.size(), null));
            connectionsForUsers.put(providerId, connections);
        }
        int connectionIndex = userIds.indexOf(connection.getKey().getProviderUserId());
        // A row whose provider user id is not literally in the request (e.g. matched by a
        // case-insensitive collation) has no slot; skip it instead of failing on index -1.
        if (connectionIndex >= 0) {
            connections.set(connectionIndex, connection);
        }
    }
    return connectionsForUsers;
}
/**
 * Fetches the single connection of the current user identified by the given key.
 *
 * @param connectionKey provider id / provider user id pair to look up
 * @return the matching connection
 * @throws NoSuchConnectionException if the query yields no row
 */
@Override
public Connection<?> getConnection(ConnectionKey connectionKey) {
    try {
        final String sql = selectFromUserConnection()
                + " where userId = ? and providerId = ? and providerUserId = ?";
        final String providerId = connectionKey.getProviderId();
        final String providerUserId = connectionKey.getProviderUserId();
        return jdbcTemplate.queryForObject(sql, connectionMapper, userId, providerId, providerUserId);
    } catch (EmptyResultDataAccessException notFound) {
        // An empty result is reported to callers as a missing connection.
        throw new NoSuchConnectionException(connectionKey);
    }
}
/**
 * Fetches the connection to the given provider user for the provider resolved from the API type.
 *
 * @param apiType        API binding type identifying the provider
 * @param providerUserId id of the user on the provider side
 * @param <A>            API binding type
 * @return the matching connection, as returned by {@link #getConnection(ConnectionKey)}
 */
@SuppressWarnings("unchecked")
@Override
public <A> Connection<A> getConnection(Class<A> apiType, String providerUserId) {
    ConnectionKey key = new ConnectionKey(getProviderId(apiType), providerUserId);
    Connection<?> found = getConnection(key);
    return (Connection<A>) found;
}
/**
 * Returns the primary connection for the provider resolved from the given API type.
 *
 * @param apiType API binding type identifying the provider
 * @param <A>     API binding type
 * @return the primary connection, never {@code null}
 * @throws NotConnectedException if the provider-id lookup returns {@code null}
 */
@SuppressWarnings("unchecked")
@Override
public <A> Connection<A> getPrimaryConnection(Class<A> apiType) {
    final String providerId = getProviderId(apiType);
    final Connection<A> primary = (Connection<A>) findPrimaryConnection(providerId);
    if (primary != null) {
        return primary;
    }
    throw new NotConnectedException(providerId);
}
/**
 * Looks up the primary connection for the provider resolved from the given API type.
 * <p>
 * Returns whatever the provider-id based lookup returns, cast to the requested API type.
 *
 * @param apiType API binding type identifying the provider
 * @param <A>     API binding type
 * @return the result of the provider-id based lookup
 */
@SuppressWarnings("unchecked")
@Override
public <A> Connection<A> findPrimaryConnection(Class<A> apiType) {
    return (Connection<A>) findPrimaryConnection(getProviderId(apiType));
}
/**
* 添加时需要带上租户id
* @param connection
*/
@Transactional(rollbackFor = Exception.class)
@Override
public void addConnection(Connection<?> connection) {
try {
// 获取租户id
Long tenantId = PreTenantContextHolder.getCurrentTenantId();
ConnectionData data = connection.createData();
int rank = jdbcTemplate.queryForObject("select coalesce(max(" + RANK + ") + 1, 1) as ranks from " + tablePrefix + "UserConnection where userId = ? and providerId = ? and tenant_id = ?", new Object[]{userId, data.getProviderId(),tenantId}, Integer.class);
jdbcTemplate.update("insert into " + tablePrefix + "UserConnection (userId, providerId, providerUserId, `rank`, displayName, profileUrl, imageUrl, accessToken, secret, r