package com.wx.server.controller;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.github.yitter.idgen.YitIdHelper;
import com.wx.server.pojo.Company;
import com.wx.server.pojo.User;
import com.wx.server.service.ICompanyService;
import com.wx.server.service.IUserService;
import com.wx.server.util.JwtUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.oltu.oauth2.as.request.OAuthAuthzRequest;
import org.apache.oltu.oauth2.as.request.OAuthTokenRequest;
import org.apache.oltu.oauth2.as.response.OAuthASResponse;
import org.apache.oltu.oauth2.common.error.OAuthError;
import org.apache.oltu.oauth2.common.exception.OAuthProblemException;
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.apache.oltu.oauth2.common.message.OAuthResponse;
import org.apache.oltu.oauth2.common.message.types.ParameterStyle;
import org.apache.oltu.oauth2.rs.request.OAuthAccessResourceRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.security.SecureRandom;
import java.util.Random;
import java.util.concurrent.TimeUnit;
@Controller
@RequestMapping("/auth")
@Slf4j
@SuppressWarnings("all")
public class AuthController {
@Autowired
private ICompanyService companyService;
@Autowired
private IUserService userService;
@Autowired
private RedisTemplate redisTemplate;
@GetMapping("/code")
public String sendCode(HttpServletRequest request, Model model) {
try {
//解析请求
OAuthAuthzRequest oathReq = new OAuthAuthzRequest(request);
//获取到客户端的id
String clientId = oathReq.getClientId();
if (clientId == null) return "redirect:/error";
Company company = companyService.getOne(new QueryWrapper<Company>().lambda().eq(Company::getClientId, clientId), false);
//如果客户端跟我们没有合作
if (company == null) return "redirect:/error";
model.addAttribute("redirect_uri", oathReq.getRedirectURI());
model.addAttribute("company_info", company);
} catch (Exception e) {
return "error";
}
//跳到认证登录页
return "login";
}
@PostMapping("/doLogin")
public String doLogin(User user, String redirect_uri, Model model) {
//用户登录
User one = userService.getOne(new QueryWrapper<User>().lambda()
.eq(User::getAccount, user.getAccount())
.eq(User::getPassword, user.getPassword()));
//登录失败
if (one == null) return "redirect:/error";
//生成一个雪花id
long snowId = YitIdHelper.nextId();
//将需要的数据放到缓存中
//1.用作等会需要的判断,判断用户进行授权的请求是否来源与后端
//2.将用户的数据放入到缓存中,避免在前端暴露
redisTemplate.opsForValue().set("request:" + snowId, redirect_uri, 5, TimeUnit.MINUTES);
redisTemplate.opsForValue().set("user:" + snowId, one.getId());
model.addAttribute("snowId", snowId + "");
return "auth";
}
@GetMapping("/doAuth")
public String doAuth(HttpServletRequest request, String snowId) throws Exception {
//回调路径
Object obj = redisTemplate.opsForValue().get("request:" + snowId);
if (obj == null) {
return "redirect:/error";
}
//request:123456
String userId = redisTemplate.opsForValue().get("user:" + snowId).toString();
String code = getCode();
redisTemplate.opsForValue().set("code:" + code, userId, 5, TimeUnit.MINUTES);
//获取构建响应的对象
OAuthASResponse.OAuthAuthorizationResponseBuilder builder =
OAuthASResponse.authorizationResponse(request, HttpServletResponse.SC_OK);
builder.setCode(code);
String redirectURI = obj.toString();
OAuthResponse oauthResp = builder.location(redirectURI).buildQueryMessage();
//形成路径路径拼接效果 http://localhost:80/client/callback?code=xx
String uri = oauthResp.getLocationUri();
// redirect:/client/callback?code=xx
return "redirect:" + uri;
}
/**
* 生成授权码方法
*/
public String getCode() {
Random r = new Random();
String code = "";
for (int i = 0; i < 8; ++i) {
int temp = r.nextInt(52);
char x = (char) (temp < 26 ? temp + 97 : (temp % 26) + 65);
code += x;
}
return code;
}
@PostMapping("/token")
public HttpEntity getAccessToken(HttpServletRequest request) throws OAuthProblemException, OAuthSystemException {
//OAuthTokenRequest解析请求
OAuthTokenRequest tokenReq = new OAuthTokenRequest(request);
//获得客户端的信息
String clientId = tokenReq.getClientId();
String clientSecret = tokenReq.getClientSecret();
Company company = companyService.getOne(new QueryWrapper<Company>().lambda().eq(Company::getClientId, clientId).eq(Company::getClientSecret, clientSecret), false);
//去数据库做查询
if (company != null) {
//做授权码的判断
String code = tokenReq.getCode();
//将授权码带入到缓存中查看是否有对应的数据
String userId = redisTemplate.opsForValue().get("code:" + code).toString();
if(userId==null){
System.out.println("授权码可能过期或者伪造了");
return null;
}
String openId = userService.getById(userId).getOpenId();
String token = new JwtUtils().generateToken(openId);
//构造保护令牌的响应对象
OAuthResponse oAuthResponse = OAuthASResponse
.tokenResponse(HttpServletResponse.SC_OK)
.setAccessToken(token)
.buildJSONMessage();
return new ResponseEntity(oAuthResponse.getBody(), HttpStatus.valueOf(oAuthResponse.getResponseStatus()));
}
return null;
}
@GetMapping("/userinfo")
@ResponseBody
public Object getUserInfo(HttpServletRequest request) throws OAuthProblemException, OAuthSystemException {
//OAuthAccessResourceRequest解析请求
OAuthAccessResourceRequest oAuthAccessResourceRequest = new OAuthAccessResourceRequest(request, ParameterStyle.HEADER);
//请求头中获取令牌
String token = oAuthAccessResourceRequest.getAccessToken();
//判断令牌是否是我发给你的
String openId = new JwtUtils().getOpenIdFromToken(token);
if (openId == null) {
return null;
}
User user = userService.getOne(new QueryWrapper<User>().lambda().eq(User::getOpenId, openId));
return user;
}
}