189 lines
4.1 KiB
Go
189 lines
4.1 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"licserver/internal/model"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type TokenService struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewTokenService(db *gorm.DB) *TokenService {
|
|
return &TokenService{db: db}
|
|
}
|
|
|
|
// generateToken returns a new random access token: 32 bytes read from
// crypto/rand, hex-encoded into a 64-character lowercase string. If the
// random source fails, the read error is returned unchanged with an empty
// token.
func generateToken() (string, error) {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	token := hex.EncodeToString(raw[:])
	return token, nil
}
|
|
|
|
// 创建访问令牌
|
|
func (s *TokenService) CreateToken(deviceUID, tokenType string, expireTime time.Time, ipList []string, createdBy uint) (*model.AccessToken, error) {
|
|
// 验证设备是否存在
|
|
var device model.Device
|
|
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
|
|
return nil, errors.New("设备不存在")
|
|
}
|
|
|
|
// 生成令牌
|
|
token, err := generateToken()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 创建令牌记录
|
|
accessToken := &model.AccessToken{
|
|
Token: token,
|
|
DeviceUID: deviceUID,
|
|
Type: tokenType,
|
|
Status: "active",
|
|
ExpireTime: expireTime,
|
|
IPList: strings.Join(ipList, ","),
|
|
CreatedBy: createdBy,
|
|
}
|
|
|
|
if err := s.db.Create(accessToken).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 记录日志
|
|
tokenLog := &model.TokenLog{
|
|
TokenID: accessToken.ID,
|
|
Action: "create",
|
|
Status: "success",
|
|
Message: "创建访问令牌",
|
|
}
|
|
s.db.Create(tokenLog)
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
// 验证令牌
|
|
func (s *TokenService) ValidateToken(token, ip string) (*model.AccessToken, error) {
|
|
var accessToken model.AccessToken
|
|
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
|
|
return nil, errors.New("无效的令牌")
|
|
}
|
|
|
|
// 检查令牌状态
|
|
if accessToken.Status != "active" {
|
|
return nil, errors.New("令牌已被撤销")
|
|
}
|
|
|
|
// 检查过期时间
|
|
if time.Now().After(accessToken.ExpireTime) {
|
|
return nil, errors.New("令牌已过期")
|
|
}
|
|
|
|
// 检查IP限制
|
|
if accessToken.IPList != "" {
|
|
allowedIPs := strings.Split(accessToken.IPList, ",")
|
|
allowed := false
|
|
for _, allowedIP := range allowedIPs {
|
|
if allowedIP == ip {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
return nil, errors.New("IP地址不允许访问")
|
|
}
|
|
}
|
|
|
|
// 更新使用记录
|
|
s.db.Model(&accessToken).Updates(map[string]interface{}{
|
|
"last_used": time.Now(),
|
|
"usage_count": gorm.Expr("usage_count + 1"),
|
|
})
|
|
|
|
// 记录日志
|
|
tokenLog := &model.TokenLog{
|
|
TokenID: accessToken.ID,
|
|
Action: "use",
|
|
IP: ip,
|
|
Status: "success",
|
|
}
|
|
s.db.Create(tokenLog)
|
|
|
|
return &accessToken, nil
|
|
}
|
|
|
|
// 撤销令牌
|
|
func (s *TokenService) RevokeToken(token string, userID uint) error {
|
|
var accessToken model.AccessToken
|
|
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
|
|
return errors.New("令牌不存在")
|
|
}
|
|
|
|
// 检查权限
|
|
if accessToken.CreatedBy != userID {
|
|
return errors.New("无权操作此令牌")
|
|
}
|
|
|
|
// 更新状态
|
|
if err := s.db.Model(&accessToken).Update("status", "revoked").Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// 记录日志
|
|
tokenLog := &model.TokenLog{
|
|
TokenID: accessToken.ID,
|
|
Action: "revoke",
|
|
Status: "success",
|
|
Message: "撤销访问令牌",
|
|
}
|
|
s.db.Create(tokenLog)
|
|
|
|
return nil
|
|
}
|
|
|
|
// 获取令牌列表
|
|
func (s *TokenService) GetTokens(deviceUID string, page, pageSize int) ([]model.AccessToken, int64, error) {
|
|
var tokens []model.AccessToken
|
|
var total int64
|
|
|
|
query := s.db.Model(&model.AccessToken{})
|
|
if deviceUID != "" {
|
|
query = query.Where("device_uid = ?", deviceUID)
|
|
}
|
|
|
|
query.Count(&total)
|
|
|
|
if page > 0 && pageSize > 0 {
|
|
offset := (page - 1) * pageSize
|
|
query = query.Offset(offset).Limit(pageSize)
|
|
}
|
|
|
|
err := query.Order("created_at DESC").Find(&tokens).Error
|
|
return tokens, total, err
|
|
}
|
|
|
|
// 获取令牌日志
|
|
func (s *TokenService) GetTokenLogs(tokenID uint, page, pageSize int) ([]model.TokenLog, int64, error) {
|
|
var logs []model.TokenLog
|
|
var total int64
|
|
|
|
query := s.db.Model(&model.TokenLog{}).Where("token_id = ?", tokenID)
|
|
query.Count(&total)
|
|
|
|
if page > 0 && pageSize > 0 {
|
|
offset := (page - 1) * pageSize
|
|
query = query.Offset(offset).Limit(pageSize)
|
|
}
|
|
|
|
err := query.Order("created_at DESC").Find(&logs).Error
|
|
return logs, total, err
|
|
}
|