package com.tencentcloudapi.gateway.filter;
import com.tencentcloudapi.common.Sign;
import com.tencentcloudapi.gateway.api.dto.BaseResponse;
import com.tencentcloudapi.gateway.api.dto.UserInfo;
import com.tencentcloudapi.gateway.api.service.UserManageService;
import com.tencentcloudapi.gateway.api.utils.FilterUtil;
import io.micrometer.common.util.StringUtils;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import javax.xml.bind.DatatypeConverter;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
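/**
* @Description Rate-limiting filter (限流过滤器): a global gateway filter that rejects requests once the
*              per-second request cap (maxQPS) is reached. When a request carries both the
*              authorization and X-TC-Action headers, it re-signs the TC3 authorization header
*              with an available Tencent Cloud secret before forwarding.
* @Author miller.Lai
* @Date 2023-11-06 10:23
*/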
@Component
@Slf4j
public class RequestLimitFilter implements GlobalFilter, Ordered {
/** Service used to pick an available Tencent Cloud secret for the requested action. */
@Resource
private UserManageService userManageService;
/**
 * Per-second request counters, keyed by the epoch second (as a string) in which a request arrived.
 * Only the current second's counter is kept: the map is cleared whenever a new second is first seen.
 */
private final ConcurrentHashMap<String, AtomicInteger> requestLimitMap = new ConcurrentHashMap<>();
/**
 * Maximum number of requests accepted within one second; further requests in that second are rejected.
 * Read from {@code tencentcloud.api.request.maxQps}, defaulting to 60.
 */
@Value("${tencentcloud.api.request.maxQps:60}")
private int maxQPS;
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
// 接口请求所在秒钟
String currentTime = String.valueOf(System.currentTimeMillis() / 1000L);
// 每秒并发数限制
synchronized(this){
AtomicInteger currentQPS = requestLimitMap.get(currentTime);
if(currentQPS != null ){
if(currentQPS.get() >= maxQPS){
log.error("当前请求达到QPS上线,请求被拒绝处理");
return FilterUtil.getVoidMono(exchange,new BaseResponse("500","系统繁忙,请稍后重试",null));
}else{
currentQPS.getAndIncrement();
}
}else{
requestLimitMap.clear();
requestLimitMap.put(currentTime,new AtomicInteger(1));
}
}
ServerHttpRequest request = exchange.getRequest();
List<String> authorizationList = request.getHeaders().get("authorization");
List<String> actionList = request.getHeaders().get("X-TC-Action");
// 如果请求头中存在 authorization 信息,则要进行替换,以适应现有的接口TPS限制策略
if (!CollectionUtils.isEmpty(authorizationList) && StringUtils.isNotEmpty(authorizationList.get(0))&&
!CollectionUtils.isEmpty(actionList) && StringUtils.isNotEmpty(actionList.get(0))) {
UserInfo userInfo = new UserInfo();
// 接口名称
String action = null;
try {
action = actionList.get(0);
// 获取可用的腾讯秘钥,这是一个阻塞方法
userInfo = userManageService.getAvailableUserInfo(action,exchange);
String secretId = userInfo.getSecretInfo().getSecretId();
String secretKey = userInfo.getSecretInfo().getSecretKey();
// 根据可用的秘钥对请求头中的认证信息做重新生成
String signedHeaders = "content-type;host";
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
sdf.setTimeZone(TimeZone.getTimeZone("UTC"));
// 接口请求所在秒钟
String timestamp = String.valueOf(System.currentTimeMillis() / 1000L);
String date = sdf.format(new Date(Long.parseLong(timestamp + "000")));
String service = request.getHeaders().get("service") != null ? request.getHeaders().get("service").get(0) : "";
String credentialScope = date + "/" + service + "/tc3_request";
String stringToSign = request.getHeaders().get("stringToSign") != null ? new String(Base64.getDecoder().decode(Objects.requireNonNull(request.getHeaders().get("stringToSign")).get(0).getBytes())) : "";
try {
byte[] secretDate = Sign.hmac256(("TC3" + secretKey).getBytes(StandardCharsets.UTF_8), date);
byte[] secretService = Sign.hmac256(secretDate, service);
byte[] secretSigning = Sign.hmac256(secretService, "tc3_request");
String signature = DatatypeConverter.printHexBinary(Sign.hmac256(secretSigning, stringToSign)).toLowerCase();
String authorization = "TC3-HMAC-SHA256 Credential=" + secretId + "/" + credentialScope + ", SignedHeaders=" + signedHeaders + ", Signature=" + signature;
exchange.getRequest().mutate().headers(httpHeaders -> {
// 去除自定义的请求头
httpHeaders.remove("service");
httpHeaders.remove("stringToSign");
// 去除不合法的认证信息
httpHeaders.remove("authorization");
// 塞入有效的认证信息
httpHeaders.add("authorization", authorization);
});
} catch (Exception e) {
log.error(e.getMessage());
return FilterUtil.getVoidMono(exchange,new BaseResponse("500",e.getMessage(),null));
}
UserInfo finalUserInfo = userInfo;
String finalAction = action;
// 获取是否正常获取信号量许可的标志
String hasAcquired =(String) exchange.getAttributes().get("hasAcquired");
log.info("{} ************* 请求头修饰完毕,开始进行转发,修饰后的请求头信息: *************",exchange.getAttributes().get("requestId"));
// 打印请求头信息
exchange.getRequest().getHeaders().forEach((header, values) -> {
log.info( exchange.getAttributes().get("requestId")+" " +header + " : " + values);
});
log.info("{} ************* 秘钥信息: *************",exchange.getAttributes().get("requestId"));
log.info("{} secretId:{}",exchange.getAttributes().get("requestId"),secretId);
log.info("{} secretKey:{}",exchange.getAttributes().get("requestId"),secretKey);
return chain.filter(exchange).then( Mono.fromRunnable(() -> {
// 接口逻辑执行完毕后释放信号量锁
if("1".equals(hasAcquired)) {
finalUserInfo.getInterfaceInfo(finalAction).getSemaphore().release();
log.info("{} ************* 已释放请求通行许可 *************",exchange.getAttributes().get("requestId"));
}
}));
} catch (Exception e) {
// 在发生错误时释放信号量锁
// 将当前线程标记为已获取许可
// 获取是否正常获取信号量许可的标志
String hasAcquired =(String) exchange.getAttributes().get("hasAcquired");
if("1".equals(hasAcquired)){
userInfo.getInterfaceInfo(action).getSemaphore().release();
log.info("{} ************* 已释放请求通行许可 *************",exchange.getAttributes().get("requestId"));
}