first commit
This commit is contained in:
100
internal/service/captcha.go
Normal file
100
internal/service/captcha.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"licserver/internal/model"
|
||||
"licserver/internal/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type CaptchaService struct {
|
||||
db *gorm.DB
|
||||
emailConfig *utils.EmailConfig
|
||||
store base64Captcha.Store
|
||||
}
|
||||
|
||||
func NewCaptchaService(db *gorm.DB, emailConfig *utils.EmailConfig) *CaptchaService {
|
||||
return &CaptchaService{
|
||||
db: db,
|
||||
emailConfig: emailConfig,
|
||||
store: base64Captcha.DefaultMemStore,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SendEmailCaptcha(email, captchaType string) error {
|
||||
// 检查是否存在未过期的验证码
|
||||
var count int64
|
||||
s.db.Model(&model.Captcha{}).
|
||||
Where("target = ? AND type = ? AND expires_at > ? AND used = ?",
|
||||
email, captchaType, time.Now(), false).
|
||||
Count(&count)
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("请勿频繁发送验证码")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
code, err := utils.GenerateCaptcha()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存验证码
|
||||
captcha := model.Captcha{
|
||||
Code: code,
|
||||
Type: captchaType,
|
||||
Target: email,
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&captcha).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送验证码邮件
|
||||
emailService := utils.NewEmailService(s.emailConfig)
|
||||
content := utils.GenerateEmailCaptchaContent(code, email, captchaType)
|
||||
return emailService.SendEmail(email, "验证码", content)
|
||||
}
|
||||
|
||||
func (s *CaptchaService) VerifyCaptcha(target, captchaType, code string) error {
|
||||
var captcha model.Captcha
|
||||
err := s.db.Where("target = ? AND type = ? AND code = ? AND used = ? AND expires_at > ?",
|
||||
target, captchaType, code, false, time.Now()).
|
||||
First(&captcha).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("验证码无效或已过期")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 标记验证码为已使用
|
||||
return s.db.Model(&captcha).Update("used", true).Error
|
||||
}
|
||||
|
||||
// 生成图片验证码
|
||||
func (s *CaptchaService) GenerateImageCaptcha() (string, string, error) {
|
||||
driver := base64Captcha.NewDriverDigit(80, 240, 6, 0.7, 80)
|
||||
c := base64Captcha.NewCaptcha(driver, s.store)
|
||||
id, b64s, err := c.Generate()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 确保返回的base64字符串不包含前缀
|
||||
b64s = strings.TrimPrefix(b64s, "data:image/png;base64,")
|
||||
|
||||
return id, b64s, nil
|
||||
}
|
||||
|
||||
// 验证图片验证码
|
||||
func (s *CaptchaService) VerifyImageCaptcha(id, code string) bool {
|
||||
return s.store.Verify(id, code, true)
|
||||
}
|
721
internal/service/device.go
Normal file
721
internal/service/device.go
Normal file
@@ -0,0 +1,721 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"licserver/internal/model"
|
||||
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DeviceService struct {
|
||||
db *gorm.DB
|
||||
|
||||
licenseService *LicenseService
|
||||
}
|
||||
|
||||
func NewDeviceService(db *gorm.DB, licenseService *LicenseService) *DeviceService {
|
||||
|
||||
return &DeviceService{
|
||||
|
||||
db: db,
|
||||
|
||||
licenseService: licenseService,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type DeviceRegisterInput struct {
|
||||
UID string `json:"uid" binding:"required"`
|
||||
DeviceModel string `json:"device_model" binding:"required"`
|
||||
LicenseCode string `json:"license_code"`
|
||||
}
|
||||
|
||||
func (s *DeviceService) RegisterDevice(input *DeviceRegisterInput, ip string) error {
|
||||
// 检查设备型号是否存在且处于启用状态
|
||||
var deviceModel model.DeviceModel
|
||||
if err := s.db.Where("model_name = ? AND status = ?", input.DeviceModel, "active").First(&deviceModel).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("设备型号不存在或已禁用")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查设备是否已注册
|
||||
var count int64
|
||||
s.db.Model(&model.Device{}).Where("uid = ?", input.UID).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("设备已注册")
|
||||
}
|
||||
|
||||
// 创建设备记录
|
||||
device := &model.Device{
|
||||
UID: input.UID,
|
||||
DeviceType: deviceModel.DeviceType,
|
||||
DeviceModel: input.DeviceModel,
|
||||
Company: deviceModel.Company,
|
||||
RegisterTime: time.Now(),
|
||||
Status: "inactive",
|
||||
LastActiveAt: time.Now(),
|
||||
StartCount: 0,
|
||||
}
|
||||
|
||||
// 如果提供了授权码,进行授权绑定
|
||||
if input.LicenseCode != "" {
|
||||
license, err := s.licenseService.GetLicenseByCode(input.LicenseCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if license.Status != "unused" {
|
||||
return errors.New("授权码已被使用")
|
||||
}
|
||||
|
||||
device.Status = "active"
|
||||
device.LicenseCode = license.Code
|
||||
device.LicenseType = license.LicenseType
|
||||
device.MaxUses = license.MaxUses
|
||||
device.Duration = license.Duration
|
||||
|
||||
if license.LicenseType == "time" {
|
||||
device.ExpireTime = time.Now().Add(time.Duration(license.Duration) * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建设备记录
|
||||
if err := tx.Create(device).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果有授权码,更新授权码状态
|
||||
if device.LicenseCode != "" {
|
||||
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", device.LicenseCode).
|
||||
Updates(map[string]interface{}{
|
||||
"status": "used",
|
||||
"used_by": input.UID,
|
||||
"used_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 记录设备日志
|
||||
logMsg := "设备注册成功"
|
||||
if device.LicenseCode != "" {
|
||||
logMsg += fmt.Sprintf(",使用授权码: %s", device.LicenseCode)
|
||||
}
|
||||
log := model.DeviceLog{
|
||||
DeviceUID: input.UID,
|
||||
Action: "register",
|
||||
Message: logMsg,
|
||||
Status: "success",
|
||||
}
|
||||
if err := tx.Create(&log).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DeviceService) ValidateDevice(uid string) error {
|
||||
var device model.Device
|
||||
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
|
||||
return errors.New("设备未注册")
|
||||
}
|
||||
|
||||
// 更新最后活跃时间
|
||||
device.LastActiveAt = time.Now()
|
||||
|
||||
// 如果设备已激活,检查授权状态
|
||||
if device.Status == "active" {
|
||||
if device.LicenseCode != "" {
|
||||
if err := s.licenseService.CheckLicenseValidity(device.LicenseCode); err != nil {
|
||||
device.Status = "expired"
|
||||
s.db.Save(&device)
|
||||
return errors.New("设备授权已过期")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.db.Save(&device).Error
|
||||
}
|
||||
|
||||
func (s *DeviceService) GetDevices(params *DeviceQueryParams) ([]model.Device, int64, error) {
|
||||
|
||||
var devices []model.Device
|
||||
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.Device{})
|
||||
|
||||
if params.UID != "" {
|
||||
|
||||
query = query.Where("uid LIKE ?", "%"+params.UID+"%")
|
||||
|
||||
}
|
||||
|
||||
if params.DeviceType != "" {
|
||||
|
||||
query = query.Where("device_type = ?", params.DeviceType)
|
||||
|
||||
}
|
||||
|
||||
if params.Company != "" {
|
||||
|
||||
query = query.Where("company LIKE ?", "%"+params.Company+"%")
|
||||
|
||||
}
|
||||
|
||||
if params.LicenseType != "" {
|
||||
|
||||
query = query.Where("license_type = ?", params.LicenseType)
|
||||
|
||||
}
|
||||
|
||||
if params.Status != "" {
|
||||
|
||||
query = query.Where("status = ?", params.Status)
|
||||
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
if params.Page > 0 && params.PageSize > 0 {
|
||||
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
|
||||
query = query.Offset(offset).Limit(params.PageSize)
|
||||
|
||||
}
|
||||
|
||||
err := query.Find(&devices).Error
|
||||
|
||||
return devices, total, err
|
||||
|
||||
}
|
||||
|
||||
func (s *DeviceService) UpdateStartCount(uid string) error {
|
||||
var device model.Device
|
||||
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新启动次数和最后活跃时间
|
||||
device.StartCount++
|
||||
device.LastActiveAt = time.Now()
|
||||
|
||||
// 如果设备已激活,检查授权状态
|
||||
if device.Status == "active" {
|
||||
// 检查授权码有效性
|
||||
if device.LicenseCode != "" {
|
||||
if err := s.licenseService.CheckLicenseValidity(device.LicenseCode); err != nil {
|
||||
device.Status = "expired"
|
||||
}
|
||||
}
|
||||
// 检查次数限制
|
||||
if device.LicenseType == "count" && device.StartCount >= device.MaxUses {
|
||||
device.Status = "expired"
|
||||
}
|
||||
}
|
||||
|
||||
// 记录设备日志
|
||||
log := model.DeviceLog{
|
||||
DeviceUID: uid,
|
||||
Action: "start",
|
||||
Message: fmt.Sprintf("设备启动,当前次数:%d", device.StartCount),
|
||||
Status: "success",
|
||||
}
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Save(&device).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&log).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DeviceService) UpdateDevice(uid string, updates map[string]interface{}) error {
|
||||
|
||||
return s.db.Model(&model.Device{}).Where("uid = ?", uid).Updates(updates).Error
|
||||
|
||||
}
|
||||
|
||||
func (s *DeviceService) DeleteDevice(uid string) error {
|
||||
|
||||
return s.db.Where("uid = ?", uid).Delete(&model.Device{}).Error
|
||||
|
||||
}
|
||||
|
||||
func (s *DeviceService) GetLicenseInfo(deviceUID string) (*model.Device, error) {
|
||||
|
||||
var device model.Device
|
||||
|
||||
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
|
||||
|
||||
return nil, errors.New("设备不存在")
|
||||
|
||||
}
|
||||
|
||||
return &device, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *DeviceService) CheckLicenseStatus(deviceUID string) (string, error) {
|
||||
|
||||
var device model.Device
|
||||
|
||||
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
|
||||
|
||||
return "", errors.New("设备不存在")
|
||||
|
||||
}
|
||||
|
||||
if device.LicenseCode == "" {
|
||||
|
||||
return "未授权", nil
|
||||
|
||||
}
|
||||
|
||||
if device.Status != "active" {
|
||||
|
||||
return device.Status, nil
|
||||
|
||||
}
|
||||
|
||||
switch device.LicenseType {
|
||||
|
||||
case "时间段":
|
||||
|
||||
if time.Now().After(device.ExpireTime) {
|
||||
|
||||
device.Status = "expired"
|
||||
|
||||
s.db.Save(&device)
|
||||
|
||||
return "已过期", nil
|
||||
|
||||
}
|
||||
|
||||
case "启动次数":
|
||||
|
||||
if device.StartCount >= device.MaxUses {
|
||||
|
||||
device.Status = "expired"
|
||||
|
||||
s.db.Save(&device)
|
||||
|
||||
return "已达到使用上限", nil
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return "正常", nil
|
||||
|
||||
}
|
||||
|
||||
type DeviceQueryParams struct {
|
||||
UID string
|
||||
DeviceType string
|
||||
Company string
|
||||
LicenseType string
|
||||
Status string
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
type DeviceCreateInput struct {
|
||||
UID string `json:"uid" binding:"required"`
|
||||
DeviceType string `json:"device_type" binding:"required"`
|
||||
DeviceModel string `json:"device_model" binding:"required"`
|
||||
Company string `json:"company"`
|
||||
}
|
||||
|
||||
func (s *DeviceService) CreateDevice(input *DeviceCreateInput) error {
|
||||
// 检查设备UID是否已存在
|
||||
var count int64
|
||||
if err := s.db.Model(&model.Device{}).Where("uid = ?", input.UID).Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return errors.New("设备UID已存在")
|
||||
}
|
||||
|
||||
// 创建设备记录
|
||||
device := &model.Device{
|
||||
UID: input.UID,
|
||||
DeviceType: input.DeviceType,
|
||||
DeviceModel: input.DeviceModel,
|
||||
Company: input.Company,
|
||||
RegisterTime: time.Now(),
|
||||
Status: "inactive", // 初始状态为未激活
|
||||
}
|
||||
|
||||
return s.db.Create(device).Error
|
||||
}
|
||||
|
||||
// 添加更新检查方法
|
||||
func (s *DeviceService) CheckUpdate(deviceUID, currentVersion string) (*model.FileUpload, error) {
|
||||
// 获取设备信息
|
||||
var device model.Device
|
||||
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
|
||||
return nil, errors.New("设备不存在")
|
||||
}
|
||||
|
||||
// 检查设备状态
|
||||
if device.Status != "active" {
|
||||
return nil, errors.New("设备未激活或已过期")
|
||||
}
|
||||
|
||||
// 查找最新的更新文件
|
||||
var update model.FileUpload
|
||||
err := s.db.Where("device_model = ? AND is_update = ?", device.DeviceModel, true).
|
||||
Order("created_at DESC").
|
||||
First(&update).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil // 没有可用更新
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 比较版本
|
||||
if update.Version <= currentVersion && !update.ForceUpdate {
|
||||
return nil, nil // 当前版本已是最新
|
||||
}
|
||||
|
||||
return &update, nil
|
||||
}
|
||||
|
||||
// 添加设备型号相关的方法
|
||||
|
||||
// CreateDeviceModel 创建设备型号
|
||||
func (s *DeviceService) CreateDeviceModel(model_ *model.DeviceModel) error {
|
||||
// 检查型号名称是否已存在
|
||||
var count int64
|
||||
s.db.Model(&model.DeviceModel{}).Where("model_name = ?", model_.ModelName).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("设备型号已存在")
|
||||
}
|
||||
|
||||
return s.db.Create(model_).Error
|
||||
}
|
||||
|
||||
// UpdateDeviceModel 更新设备型号
|
||||
func (s *DeviceService) UpdateDeviceModel(id uint, model_ *model.DeviceModel) error {
|
||||
// 检查型号名称是否被其他型号使用
|
||||
var count int64
|
||||
s.db.Model(&model.DeviceModel{}).Where("model_name = ? AND id != ?", model_.ModelName, id).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("设备型号已存在")
|
||||
}
|
||||
|
||||
return s.db.Model(&model.DeviceModel{}).Where("id = ?", id).Updates(model_).Error
|
||||
}
|
||||
|
||||
// DeleteDeviceModel 删除设备型号
|
||||
func (s *DeviceService) DeleteDeviceModel(id uint) error {
|
||||
// 检查是否有设备使用此型号
|
||||
var count int64
|
||||
s.db.Model(&model.DeviceModel{}).Where("device_model = ?", id).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("该型号下存在设备,无法删除")
|
||||
}
|
||||
|
||||
return s.db.Delete(&model.DeviceModel{}, id).Error
|
||||
}
|
||||
|
||||
// GetDeviceModels 获取设备型号列表
|
||||
func (s *DeviceService) GetDeviceModels(modelName, deviceType, company string, page, pageSize int) ([]model.DeviceModel, int64, error) {
|
||||
var models []model.DeviceModel
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.DeviceModel{})
|
||||
|
||||
if modelName != "" {
|
||||
query = query.Where("model_name LIKE ?", "%"+modelName+"%")
|
||||
}
|
||||
if deviceType != "" {
|
||||
query = query.Where("device_type = ?", deviceType)
|
||||
}
|
||||
if company != "" {
|
||||
query = query.Where("company LIKE ?", "%"+company+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
query.Count(&total)
|
||||
|
||||
// 分页查询
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
// 查询设备型号列表
|
||||
if err := query.Find(&models).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 查询每个型号下的设备数量
|
||||
for i := range models {
|
||||
var count int64
|
||||
s.db.Model(&model.Device{}).Where("device_model = ?", models[i].ModelName).Count(&count)
|
||||
models[i].DeviceCount = int(count)
|
||||
}
|
||||
|
||||
return models, total, nil
|
||||
}
|
||||
|
||||
// BatchDeleteDeviceModels 批量删除设备型号
|
||||
func (s *DeviceService) BatchDeleteDeviceModels(ids []uint) error {
|
||||
// 检查是否有设备使用这些型号
|
||||
var count int64
|
||||
s.db.Model(&model.Device{}).Where("device_model IN (?)", ids).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("选中的型号中存在正在使用的型号,无法删除")
|
||||
}
|
||||
|
||||
return s.db.Delete(&model.DeviceModel{}, ids).Error
|
||||
}
|
||||
|
||||
// GetRegisteredDevices 获取已注册设备列表
|
||||
func (s *DeviceService) GetRegisteredDevices(uid, deviceModel, status string, page, pageSize int) ([]model.Device, int64, error) {
|
||||
var devices []model.Device
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.Device{})
|
||||
|
||||
if uid != "" {
|
||||
query = query.Where("uid LIKE ?", "%"+uid+"%")
|
||||
}
|
||||
if deviceModel != "" {
|
||||
query = query.Where("device_model = ?", deviceModel)
|
||||
}
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
query.Count(&total)
|
||||
|
||||
// 分页查询
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&devices).Error
|
||||
return devices, total, err
|
||||
}
|
||||
|
||||
// BindLicense 绑定授权码
|
||||
func (s *DeviceService) BindLicense(uid string, licenseCode string) error {
|
||||
var device model.Device
|
||||
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
|
||||
return errors.New("设备不存在")
|
||||
}
|
||||
|
||||
// 检查设备当前状态
|
||||
if device.LicenseCode != "" {
|
||||
return errors.New("设备已绑定授权码,请先解绑")
|
||||
}
|
||||
|
||||
// 验证授权码
|
||||
license, err := s.licenseService.GetLicenseByCode(licenseCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if license.Status != "unused" {
|
||||
return errors.New("授权码已被使用")
|
||||
}
|
||||
|
||||
// 根据授权类型处理
|
||||
switch license.LicenseType {
|
||||
case "time":
|
||||
if license.Duration <= 0 {
|
||||
return errors.New("无效的授权时长")
|
||||
}
|
||||
device.ExpireTime = time.Now().Add(time.Duration(license.Duration) * time.Minute)
|
||||
device.Duration = license.Duration
|
||||
device.MaxUses = 0
|
||||
case "count":
|
||||
if license.MaxUses <= 0 {
|
||||
return errors.New("无效的使用次数")
|
||||
}
|
||||
device.ExpireTime = time.Time{} // 清空过期时间
|
||||
device.Duration = 0
|
||||
device.MaxUses = license.MaxUses
|
||||
device.StartCount = 0 // 重置启动次数
|
||||
case "permanent":
|
||||
device.ExpireTime = time.Time{} // 清空过期时间
|
||||
device.Duration = 0
|
||||
device.MaxUses = 0
|
||||
default:
|
||||
return errors.New("无效的授权类型")
|
||||
}
|
||||
|
||||
// 更新设备基本信息
|
||||
device.LicenseCode = licenseCode
|
||||
device.LicenseType = license.LicenseType
|
||||
device.Status = "active"
|
||||
device.LastActiveAt = time.Now()
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 更新设备信息
|
||||
if err := tx.Save(&device).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新授权码状态
|
||||
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", licenseCode).
|
||||
Updates(map[string]interface{}{
|
||||
"status": "used",
|
||||
"used_by": uid,
|
||||
"used_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录设备日志
|
||||
log := model.DeviceLog{
|
||||
DeviceUID: uid,
|
||||
Action: "bind_license",
|
||||
Message: fmt.Sprintf("绑定%s授权码: %s", getLicenseTypeText(license.LicenseType), licenseCode),
|
||||
Status: "success",
|
||||
}
|
||||
if err := tx.Create(&log).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// 获取授权类型的中文描述
|
||||
func getLicenseTypeText(licenseType string) string {
|
||||
switch licenseType {
|
||||
case "time":
|
||||
return "时间"
|
||||
case "count":
|
||||
return "次数"
|
||||
case "permanent":
|
||||
return "永久"
|
||||
default:
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
// UnbindLicense 解绑授权码
|
||||
func (s *DeviceService) UnbindLicense(uid string) error {
|
||||
var device model.Device
|
||||
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
|
||||
return errors.New("设备不存在")
|
||||
}
|
||||
|
||||
if device.LicenseCode == "" {
|
||||
return errors.New("设备未绑定授权码")
|
||||
}
|
||||
|
||||
oldLicenseCode := device.LicenseCode
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 更新设备信息
|
||||
if err := tx.Model(&device).Updates(map[string]interface{}{
|
||||
"license_code": "",
|
||||
"license_type": "",
|
||||
"status": "inactive",
|
||||
"expire_time": nil,
|
||||
"max_uses": 0,
|
||||
"duration": 0,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新授权码状态
|
||||
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", oldLicenseCode).
|
||||
Updates(map[string]interface{}{
|
||||
"status": "unused",
|
||||
"used_by": "",
|
||||
"used_at": nil,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录设备日志
|
||||
log := model.DeviceLog{
|
||||
DeviceUID: uid,
|
||||
Action: "unbind_license",
|
||||
Message: fmt.Sprintf("解绑授权码: %s", oldLicenseCode),
|
||||
Status: "success",
|
||||
}
|
||||
if err := tx.Create(&log).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetDeviceLogs 获取设备日志
|
||||
func (s *DeviceService) GetDeviceLogs(uid string, page, pageSize int) ([]model.DeviceLog, int64, error) {
|
||||
var logs []model.DeviceLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.DeviceLog{}).Where("device_uid = ?", uid)
|
||||
|
||||
// 获取总数
|
||||
query.Count(&total)
|
||||
|
||||
// 分页查询
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&logs).Error
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
// DashboardStats 仪表盘统计数据
|
||||
type DashboardStats struct {
|
||||
TotalDevices int64 `json:"total_devices"` // 设备总数
|
||||
TotalLicenses int64 `json:"total_licenses"` // 授权码总数
|
||||
TodayNew int64 `json:"today_new"` // 今日新增
|
||||
OnlineDevices int64 `json:"online_devices"` // 在线设备
|
||||
ActiveDevices int64 `json:"active_devices"` // 激活设备
|
||||
ExpiredDevices int64 `json:"expired_devices"` // 过期设备
|
||||
}
|
||||
|
||||
// GetDashboardStats 获取仪表盘统计数据
|
||||
func (s *DeviceService) GetDashboardStats() (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
|
||||
// 获取设备总数
|
||||
s.db.Model(&model.Device{}).Count(&stats.TotalDevices)
|
||||
|
||||
// 获取授权码总数
|
||||
s.db.Model(&model.LicenseCode{}).Count(&stats.TotalLicenses)
|
||||
|
||||
// 获取今日新增设备数
|
||||
today := time.Now().Format("2006-01-02")
|
||||
s.db.Model(&model.Device{}).Where("DATE(register_time) = ?", today).Count(&stats.TodayNew)
|
||||
|
||||
// 获取在线设备数(最近30分钟内有活动的设备)
|
||||
thirtyMinutesAgo := time.Now().Add(-30 * time.Minute)
|
||||
s.db.Model(&model.Device{}).Where("last_active_at > ?", thirtyMinutesAgo).Count(&stats.OnlineDevices)
|
||||
|
||||
// 获取激活设备数
|
||||
s.db.Model(&model.Device{}).Where("status = ?", "active").Count(&stats.ActiveDevices)
|
||||
|
||||
// 获取过期设备数
|
||||
s.db.Model(&model.Device{}).Where("status = ?", "expired").Count(&stats.ExpiredDevices)
|
||||
|
||||
return &stats, nil
|
||||
}
|
298
internal/service/device_test.go
Normal file
298
internal/service/device_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"licserver/internal/model"
|
||||
"licserver/internal/utils"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeviceService_RegisterDevice(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
licenseService := NewLicenseService(db)
|
||||
deviceService := NewDeviceService(db, licenseService)
|
||||
|
||||
// 创建测试设备型号
|
||||
deviceModel := &model.DeviceModel{
|
||||
ModelName: "test-model",
|
||||
DeviceType: "software",
|
||||
Company: "test-company",
|
||||
Status: "active",
|
||||
CreatedBy: 1,
|
||||
}
|
||||
err := db.Create(deviceModel).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试授权码
|
||||
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
|
||||
LicenseType: "time",
|
||||
Duration: 30,
|
||||
Count: 1,
|
||||
Remark: "test",
|
||||
}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, licenses, 1)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *DeviceRegisterInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常注册设备",
|
||||
input: &DeviceRegisterInput{
|
||||
UID: "test-device-001",
|
||||
DeviceModel: "test-model",
|
||||
LicenseCode: licenses[0].Code,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "重复注册设备",
|
||||
input: &DeviceRegisterInput{
|
||||
UID: "test-device-001",
|
||||
DeviceModel: "test-model",
|
||||
LicenseCode: licenses[0].Code,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "使用无效授权码",
|
||||
input: &DeviceRegisterInput{
|
||||
UID: "test-device-002",
|
||||
DeviceModel: "test-model",
|
||||
LicenseCode: "invalid-code",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "使用不存在的设备型号",
|
||||
input: &DeviceRegisterInput{
|
||||
UID: "test-device-003",
|
||||
DeviceModel: "non-existent-model",
|
||||
LicenseCode: licenses[0].Code,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := deviceService.RegisterDevice(tt.input, "127.0.0.1")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证设备是否正确注册
|
||||
var device model.Device
|
||||
err = db.Where("uid = ?", tt.input.UID).First(&device).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, deviceModel.DeviceType, device.DeviceType)
|
||||
assert.Equal(t, "active", device.Status)
|
||||
assert.Equal(t, tt.input.LicenseCode, device.LicenseCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceService_ValidateDevice(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
licenseService := NewLicenseService(db)
|
||||
deviceService := NewDeviceService(db, licenseService)
|
||||
|
||||
// 创建测试设备型号
|
||||
deviceModel := &model.DeviceModel{
|
||||
ModelName: "test-model",
|
||||
DeviceType: "software",
|
||||
Company: "test-company",
|
||||
CreatedBy: 1,
|
||||
}
|
||||
err := db.Create(deviceModel).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试授权码
|
||||
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
|
||||
LicenseType: "time",
|
||||
Duration: 30,
|
||||
Count: 1,
|
||||
Remark: "test",
|
||||
}, 1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 注册测试设备
|
||||
device := &model.Device{
|
||||
UID: "test-device-001",
|
||||
DeviceType: deviceModel.DeviceType,
|
||||
DeviceModel: deviceModel.ModelName,
|
||||
Company: deviceModel.Company,
|
||||
RegisterTime: time.Now(),
|
||||
Status: "active",
|
||||
LicenseCode: licenses[0].Code,
|
||||
LicenseType: "time",
|
||||
Duration: 30,
|
||||
ExpireTime: time.Now().Add(30 * 24 * time.Hour),
|
||||
}
|
||||
err = db.Create(device).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uid string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "验证正常设备",
|
||||
uid: "test-device-001",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "验证不存在的设备",
|
||||
uid: "non-existent-device",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "验证过期设备",
|
||||
uid: "expired-device",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
// 创建过期设备
|
||||
expiredDevice := &model.Device{
|
||||
UID: "expired-device",
|
||||
DeviceType: deviceModel.DeviceType,
|
||||
DeviceModel: deviceModel.ModelName,
|
||||
Company: deviceModel.Company,
|
||||
RegisterTime: time.Now(),
|
||||
Status: "expired",
|
||||
LicenseCode: "expired-license",
|
||||
LicenseType: "time",
|
||||
Duration: 30,
|
||||
ExpireTime: time.Now().Add(-24 * time.Hour),
|
||||
}
|
||||
err = db.Create(expiredDevice).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := deviceService.ValidateDevice(tt.uid)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceService_UpdateStartCount(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
licenseService := NewLicenseService(db)
|
||||
deviceService := NewDeviceService(db, licenseService)
|
||||
|
||||
// 创建测试设备型号
|
||||
deviceModel := &model.DeviceModel{
|
||||
ModelName: "test-model",
|
||||
DeviceType: "software",
|
||||
Company: "test-company",
|
||||
CreatedBy: 1,
|
||||
}
|
||||
err := db.Create(deviceModel).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试设备
|
||||
device := &model.Device{
|
||||
UID: "test-device-001",
|
||||
DeviceType: deviceModel.DeviceType,
|
||||
DeviceModel: deviceModel.ModelName,
|
||||
Company: deviceModel.Company,
|
||||
RegisterTime: time.Now(),
|
||||
Status: "active",
|
||||
LicenseType: "count",
|
||||
MaxUses: 5,
|
||||
StartCount: 0,
|
||||
}
|
||||
err = db.Create(device).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 测试更新启动次数
|
||||
for i := 1; i <= 5; i++ {
|
||||
err = deviceService.UpdateStartCount(device.UID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证启动次数
|
||||
var updatedDevice model.Device
|
||||
err = db.First(&updatedDevice, "uid = ?", device.UID).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, i, updatedDevice.StartCount)
|
||||
|
||||
// 检查最后一次是否将状态更新为过期
|
||||
if i == 5 {
|
||||
assert.Equal(t, "expired", updatedDevice.Status)
|
||||
} else {
|
||||
assert.Equal(t, "active", updatedDevice.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试超出使用次数
|
||||
err = deviceService.UpdateStartCount(device.UID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDeviceService_CreateDeviceModel(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
licenseService := NewLicenseService(db)
|
||||
deviceService := NewDeviceService(db, licenseService)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model *model.DeviceModel
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "创建有效设备型号",
|
||||
model: &model.DeviceModel{
|
||||
ModelName: "test-model",
|
||||
DeviceType: "software",
|
||||
Company: "test-company",
|
||||
Status: "active",
|
||||
CreatedBy: 1,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "重复的设备型号",
|
||||
model: &model.DeviceModel{
|
||||
ModelName: "test-model",
|
||||
DeviceType: "software",
|
||||
Company: "test-company",
|
||||
Status: "active",
|
||||
CreatedBy: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := deviceService.CreateDeviceModel(tt.model)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证创建的设备型号
|
||||
var savedModel model.DeviceModel
|
||||
err = db.Where("model_name = ?", tt.model.ModelName).First(&savedModel).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.model.DeviceType, savedModel.DeviceType)
|
||||
assert.Equal(t, tt.model.Company, savedModel.Company)
|
||||
assert.Equal(t, tt.model.Status, savedModel.Status)
|
||||
})
|
||||
}
|
||||
}
|
709
internal/service/license.go
Normal file
709
internal/service/license.go
Normal file
@@ -0,0 +1,709 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"encoding/hex"
|
||||
|
||||
"errors"
|
||||
|
||||
"fmt"
|
||||
|
||||
"strings"
|
||||
|
||||
"time"
|
||||
|
||||
"licserver/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LicenseService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewLicenseService(db *gorm.DB) *LicenseService {
|
||||
|
||||
return &LicenseService{db: db}
|
||||
|
||||
}
|
||||
|
||||
// 生成授权码
|
||||
|
||||
func generateLicenseCode() (string, error) {
|
||||
|
||||
b := make([]byte, 16)
|
||||
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
return "", err
|
||||
|
||||
}
|
||||
|
||||
return hex.EncodeToString(b), nil
|
||||
|
||||
}
|
||||
|
||||
// 创建授权码
|
||||
|
||||
type LicenseCreateInput struct {
|
||||
LicenseType string `json:"license_type" binding:"required"`
|
||||
|
||||
Duration int `json:"duration"` // 时间授权的有效期(分钟)
|
||||
|
||||
MaxUses int `json:"max_uses"` // 次数授权的使用次数
|
||||
|
||||
Count int `json:"count" binding:"required,min=1"` // 生成数量
|
||||
|
||||
Remark string `json:"remark"` // 备注
|
||||
|
||||
BindCount int `json:"bind_count"` // 可绑定次数,默认为-1(无限制)
|
||||
|
||||
}
|
||||
|
||||
func (s *LicenseService) CreateLicenses(input *LicenseCreateInput, createdBy uint) ([]model.LicenseCode, error) {
|
||||
|
||||
// 验证参数
|
||||
|
||||
input.LicenseType = strings.ToLower(input.LicenseType) // 转为小写
|
||||
|
||||
switch input.LicenseType {
|
||||
case "time":
|
||||
if input.Duration <= 0 {
|
||||
return nil, errors.New("时间授权必须指定有效期")
|
||||
}
|
||||
case "count":
|
||||
if input.MaxUses <= 0 {
|
||||
return nil, errors.New("次数授权必须指定使用次数")
|
||||
}
|
||||
case "permanent":
|
||||
// 永久授权不需要额外参数
|
||||
default:
|
||||
return nil, errors.New("无效的授权类型")
|
||||
}
|
||||
|
||||
// 如果未指定绑定次数,设置为默认值-1
|
||||
if input.BindCount == 0 {
|
||||
input.BindCount = -1
|
||||
}
|
||||
|
||||
// 生成批次号
|
||||
|
||||
batchNo := time.Now().Format("20060102150405")
|
||||
|
||||
licenses := make([]model.LicenseCode, 0, input.Count)
|
||||
|
||||
for i := 0; i < input.Count; i++ {
|
||||
|
||||
code, err := generateLicenseCode()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
license := model.LicenseCode{
|
||||
|
||||
Code: code,
|
||||
|
||||
LicenseType: input.LicenseType,
|
||||
|
||||
Duration: input.Duration,
|
||||
|
||||
MaxUses: input.MaxUses,
|
||||
|
||||
Status: "unused",
|
||||
|
||||
CreatedBy: createdBy,
|
||||
|
||||
BatchNo: batchNo,
|
||||
|
||||
Remark: input.Remark,
|
||||
|
||||
BindCount: input.BindCount,
|
||||
}
|
||||
|
||||
licenses = append(licenses, license)
|
||||
|
||||
}
|
||||
|
||||
// 批量创建授权码
|
||||
|
||||
if err := s.db.Create(&licenses).Error; err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
return licenses, nil
|
||||
|
||||
}
|
||||
|
||||
// 验证并使用授权码
|
||||
|
||||
func (s *LicenseService) UseLicense(code, deviceUID, ip string) (*model.LicenseCode, error) {
|
||||
|
||||
var license model.LicenseCode
|
||||
|
||||
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
|
||||
|
||||
return nil, errors.New("授权码不存在")
|
||||
|
||||
}
|
||||
|
||||
// 检查授权码状态
|
||||
|
||||
if license.Status != "unused" {
|
||||
|
||||
return nil, errors.New("授权码已被使用")
|
||||
|
||||
}
|
||||
|
||||
// 检查绑定次数
|
||||
if license.BindCount == 0 {
|
||||
return nil, errors.New("授权码已达到最大绑定次数限制")
|
||||
}
|
||||
|
||||
// 更新授权码状态
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"status": "used",
|
||||
"used_by": deviceUID,
|
||||
"used_at": time.Now(),
|
||||
}
|
||||
|
||||
// 如果不是无限制,减少绑定次数
|
||||
if license.BindCount > 0 {
|
||||
updates["bind_count"] = license.BindCount - 1
|
||||
}
|
||||
|
||||
if err := s.db.Model(&license).Updates(updates).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录使用日志
|
||||
log := model.LicenseLog{
|
||||
LicenseID: license.ID,
|
||||
DeviceUID: deviceUID,
|
||||
Action: "use",
|
||||
IP: ip,
|
||||
Status: "success",
|
||||
Message: fmt.Sprintf("设备 %s 使用授权码", deviceUID),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&log).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &license, nil
|
||||
|
||||
}
|
||||
|
||||
// 获取授权码列表
|
||||
|
||||
func (s *LicenseService) GetLicenses(status, licenseType, batchNo string, page, pageSize int) ([]model.LicenseCode, int64, error) {
|
||||
|
||||
var licenses []model.LicenseCode
|
||||
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.LicenseCode{})
|
||||
|
||||
if status != "" {
|
||||
|
||||
query = query.Where("status = ?", strings.ToLower(status))
|
||||
|
||||
}
|
||||
|
||||
if licenseType != "" {
|
||||
|
||||
query = query.Where("license_type = ?", strings.ToLower(licenseType))
|
||||
|
||||
}
|
||||
|
||||
if batchNo != "" {
|
||||
|
||||
query = query.Where("batch_no = ?", batchNo)
|
||||
|
||||
}
|
||||
|
||||
// 获取所有符合条件的授权码
|
||||
|
||||
var allLicenses []model.LicenseCode
|
||||
|
||||
if err := query.Find(&allLicenses).Error; err != nil {
|
||||
|
||||
return nil, 0, err
|
||||
|
||||
}
|
||||
|
||||
// 检查每个授权码的有效性
|
||||
|
||||
for i := range allLicenses {
|
||||
|
||||
if allLicenses[i].Status == "used" {
|
||||
|
||||
if err := s.CheckLicenseValidity(allLicenses[i].Code); err != nil {
|
||||
|
||||
// 如果检查失败,更新状态
|
||||
|
||||
s.db.Model(&allLicenses[i]).Update("status", "expired")
|
||||
|
||||
allLicenses[i].Status = "expired"
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
total = int64(len(allLicenses))
|
||||
|
||||
// 分页
|
||||
|
||||
if page > 0 && pageSize > 0 {
|
||||
|
||||
start := (page - 1) * pageSize
|
||||
|
||||
end := start + pageSize
|
||||
|
||||
if start < len(allLicenses) {
|
||||
|
||||
if end > len(allLicenses) {
|
||||
|
||||
end = len(allLicenses)
|
||||
|
||||
}
|
||||
|
||||
licenses = allLicenses[start:end]
|
||||
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
licenses = allLicenses
|
||||
|
||||
}
|
||||
|
||||
return licenses, total, nil
|
||||
|
||||
}
|
||||
|
||||
// 获取授权码使用日志
|
||||
|
||||
func (s *LicenseService) GetLicenseLogs(licenseID uint, page, pageSize int) ([]model.LicenseLog, int64, error) {
|
||||
|
||||
var logs []model.LicenseLog
|
||||
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.LicenseLog{}).Where("license_id = ?", licenseID)
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
if page > 0 && pageSize > 0 {
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&logs).Error
|
||||
|
||||
return logs, total, err
|
||||
|
||||
}
|
||||
|
||||
// ExportLogs 导出授权码日志
|
||||
|
||||
func (s *LicenseService) ExportLogs(licenseID uint) ([]byte, error) {
|
||||
|
||||
logs, _, err := s.GetLicenseLogs(licenseID, 0, 0) // 获取所有日志
|
||||
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 创建CSV内容
|
||||
|
||||
var content strings.Builder
|
||||
|
||||
content.WriteString("操作类,设备UID,IP地址,状态,详细信息,时间\n")
|
||||
|
||||
for _, log := range logs {
|
||||
|
||||
// 转换操作类型
|
||||
|
||||
action := map[string]string{
|
||||
|
||||
"create": "创建",
|
||||
|
||||
"use": "使用",
|
||||
|
||||
"verify": "验证",
|
||||
}[log.Action]
|
||||
|
||||
// 转换状态
|
||||
|
||||
status := map[string]string{
|
||||
|
||||
"success": "成功",
|
||||
|
||||
"failed": "失败",
|
||||
}[log.Status]
|
||||
|
||||
// 写入一行记录
|
||||
|
||||
content.WriteString(fmt.Sprintf("%s,%s,%s,%s,%s,%s\n",
|
||||
|
||||
action,
|
||||
|
||||
log.DeviceUID,
|
||||
|
||||
log.IP,
|
||||
|
||||
status,
|
||||
|
||||
log.Message,
|
||||
|
||||
log.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
))
|
||||
|
||||
}
|
||||
|
||||
return []byte(content.String()), nil
|
||||
|
||||
}
|
||||
|
||||
// 撤销授权码
|
||||
|
||||
func (s *LicenseService) RevokeLicense(code string, userID uint) error {
|
||||
|
||||
var license model.LicenseCode
|
||||
|
||||
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
|
||||
|
||||
return errors.New("授权码不存在")
|
||||
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
|
||||
if license.CreatedBy != userID {
|
||||
|
||||
return errors.New("无权操作此授权码")
|
||||
|
||||
}
|
||||
|
||||
// 更新状态
|
||||
|
||||
if err := s.db.Model(&license).Update("status", "revoked").Error; err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
|
||||
log := model.LicenseLog{
|
||||
|
||||
LicenseID: license.ID,
|
||||
|
||||
Action: "revoke",
|
||||
|
||||
Status: "success",
|
||||
|
||||
Message: "授权码已撤销",
|
||||
}
|
||||
|
||||
s.db.Create(&log)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// 批量撤销授权码
|
||||
|
||||
func (s *LicenseService) RevokeLicenses(codes []string, userID uint) error {
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
for _, code := range codes {
|
||||
|
||||
var license model.LicenseCode
|
||||
|
||||
if err := tx.Where("code = ?", code).First(&license).Error; err != nil {
|
||||
|
||||
continue
|
||||
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
|
||||
if license.CreatedBy != userID {
|
||||
|
||||
continue
|
||||
|
||||
}
|
||||
|
||||
// 更新状态
|
||||
|
||||
if err := tx.Model(&license).Update("status", "revoked").Error; err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
|
||||
log := model.LicenseLog{
|
||||
|
||||
LicenseID: license.ID,
|
||||
|
||||
Action: "revoke",
|
||||
|
||||
Status: "success",
|
||||
|
||||
Message: "授权码已撤销",
|
||||
}
|
||||
|
||||
tx.Create(&log)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// 验证授权码
|
||||
|
||||
func (s *LicenseService) ValidateLicense(code string) (*model.LicenseCode, error) {
|
||||
|
||||
var license model.LicenseCode
|
||||
|
||||
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
|
||||
|
||||
return nil, errors.New("无效的授权码")
|
||||
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
|
||||
if license.Status != "unused" {
|
||||
|
||||
return nil, errors.New("授权码已被使用或已撤销")
|
||||
|
||||
}
|
||||
|
||||
return &license, nil
|
||||
|
||||
}
|
||||
|
||||
// 导出授权码
|
||||
|
||||
func (s *LicenseService) ExportLicenses(codes []string) ([]byte, error) {
|
||||
|
||||
var licenses []model.LicenseCode
|
||||
|
||||
if err := s.db.Where("code IN ?", codes).Find(&licenses).Error; err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 创建CSV内容
|
||||
|
||||
var content strings.Builder
|
||||
|
||||
content.WriteString("授权码,授权类型,有效期(天),使用次数,状态,使用设备,使用时间,批次号,备注\n")
|
||||
|
||||
for _, license := range licenses {
|
||||
|
||||
// 转换授权类型
|
||||
|
||||
licenseType := map[string]string{
|
||||
|
||||
"time": "时间授权",
|
||||
|
||||
"count": "次数授权",
|
||||
|
||||
"permanent": "永久授权",
|
||||
}[license.LicenseType]
|
||||
|
||||
// 转换状态
|
||||
|
||||
status := map[string]string{
|
||||
|
||||
"unused": "未使用",
|
||||
|
||||
"used": "已使用",
|
||||
|
||||
"revoked": "已撤销",
|
||||
}[license.Status]
|
||||
|
||||
// 写入一行记录
|
||||
|
||||
content.WriteString(fmt.Sprintf("%s,%s,%d,%d,%s,%s,%s,%s,%s\n",
|
||||
|
||||
license.Code,
|
||||
|
||||
licenseType,
|
||||
|
||||
license.Duration,
|
||||
|
||||
license.MaxUses,
|
||||
|
||||
status,
|
||||
|
||||
license.UsedBy,
|
||||
|
||||
license.UsedAt.Format("2006-01-02 15:04:05"),
|
||||
|
||||
license.BatchNo,
|
||||
|
||||
license.Remark,
|
||||
))
|
||||
|
||||
}
|
||||
|
||||
return []byte(content.String()), nil
|
||||
|
||||
}
|
||||
|
||||
// 获取授权码统计信息
|
||||
|
||||
func (s *LicenseService) GetLicenseStats() (map[string]interface{}, error) {
|
||||
|
||||
var stats struct {
|
||||
Total int64
|
||||
|
||||
Unused int64
|
||||
|
||||
Used int64
|
||||
|
||||
Revoked int64
|
||||
|
||||
Today int64
|
||||
|
||||
ThisWeek int64
|
||||
|
||||
ThisMonth int64
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Count(&stats.Total)
|
||||
|
||||
// 获取各状态数量
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Where("status = ?", "unused").Count(&stats.Unused)
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Where("status = ?", "used").Count(&stats.Used)
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Where("status = ?", "revoked").Count(&stats.Revoked)
|
||||
|
||||
// 获取今日创建数量
|
||||
|
||||
today := time.Now().Format("2006-01-02")
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Where("DATE(created_at) = ?", today).Count(&stats.Today)
|
||||
|
||||
// 获取本周创建数量
|
||||
|
||||
weekStart := time.Now().AddDate(0, 0, -int(time.Now().Weekday()))
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Where("created_at >= ?", weekStart).Count(&stats.ThisWeek)
|
||||
|
||||
// 获取本月创建数量
|
||||
|
||||
monthStart := time.Now().Format("2006-01") + "-01"
|
||||
|
||||
s.db.Model(&model.LicenseCode{}).Where("created_at >= ?", monthStart).Count(&stats.ThisMonth)
|
||||
|
||||
return map[string]interface{}{
|
||||
|
||||
"total": stats.Total,
|
||||
|
||||
"unused": stats.Unused,
|
||||
|
||||
"used": stats.Used,
|
||||
|
||||
"revoked": stats.Revoked,
|
||||
|
||||
"today": stats.Today,
|
||||
|
||||
"this_week": stats.ThisWeek,
|
||||
|
||||
"this_month": stats.ThisMonth,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// 添加检查授权码有效性的方法
|
||||
|
||||
func (s *LicenseService) CheckLicenseValidity(code string) error {
|
||||
|
||||
var license model.LicenseCode
|
||||
|
||||
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
|
||||
|
||||
return errors.New("授权码不存在")
|
||||
|
||||
}
|
||||
|
||||
if license.Status != "unused" && license.Status != "used" {
|
||||
|
||||
return errors.New("授权码已被撤销或过期")
|
||||
|
||||
}
|
||||
|
||||
// 检查授权类型特定的限制
|
||||
|
||||
switch license.LicenseType {
|
||||
case "time":
|
||||
// 计算过期时间
|
||||
|
||||
expireTime := license.UsedAt.Add(time.Duration(license.Duration) * time.Minute)
|
||||
|
||||
if time.Now().After(expireTime) {
|
||||
|
||||
// 更新状态为过期
|
||||
|
||||
s.db.Model(&license).Update("status", "expired")
|
||||
|
||||
return errors.New("授权码已过期")
|
||||
|
||||
}
|
||||
case "count":
|
||||
if license.UsedCount >= license.MaxUses {
|
||||
|
||||
// 更新状态为过期
|
||||
|
||||
s.db.Model(&license).Update("status", "expired")
|
||||
|
||||
return errors.New("授权码使用次数已达上限")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// GetLicenseByCode 通过授权码获取授权信息
|
||||
func (s *LicenseService) GetLicenseByCode(code string) (*model.LicenseCode, error) {
|
||||
var license model.LicenseCode
|
||||
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("授权码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &license, nil
|
||||
}
|
153
internal/service/license_test.go
Normal file
153
internal/service/license_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"licserver/internal/utils"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLicenseService_CreateLicenses(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
service := NewLicenseService(db)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *LicenseCreateInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "创建时间授权码",
|
||||
input: &LicenseCreateInput{
|
||||
LicenseType: "time",
|
||||
Duration: 30,
|
||||
Count: 5,
|
||||
Remark: "test time license",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "创建次数授权码",
|
||||
input: &LicenseCreateInput{
|
||||
LicenseType: "count",
|
||||
MaxUses: 100,
|
||||
Count: 3,
|
||||
Remark: "test count license",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "创建永久授权码",
|
||||
input: &LicenseCreateInput{
|
||||
LicenseType: "permanent",
|
||||
Count: 1,
|
||||
Remark: "test permanent license",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效的授权类型",
|
||||
input: &LicenseCreateInput{
|
||||
LicenseType: "invalid",
|
||||
Count: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "时间授权无有效期",
|
||||
input: &LicenseCreateInput{
|
||||
LicenseType: "time",
|
||||
Duration: 0,
|
||||
Count: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
licenses, err := service.CreateLicenses(tt.input, 1)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, licenses, tt.input.Count)
|
||||
|
||||
for _, license := range licenses {
|
||||
assert.Equal(t, tt.input.LicenseType, license.LicenseType)
|
||||
assert.Equal(t, "unused", license.Status)
|
||||
assert.Equal(t, tt.input.Remark, license.Remark)
|
||||
|
||||
if tt.input.LicenseType == "time" {
|
||||
assert.Equal(t, tt.input.Duration, license.Duration)
|
||||
} else if tt.input.LicenseType == "count" {
|
||||
assert.Equal(t, tt.input.MaxUses, license.MaxUses)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLicenseService_UseLicense(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
service := NewLicenseService(db)
|
||||
|
||||
// 创建测试授权码
|
||||
input := &LicenseCreateInput{
|
||||
LicenseType: "time",
|
||||
Duration: 30,
|
||||
Count: 1,
|
||||
Remark: "test",
|
||||
}
|
||||
|
||||
licenses, err := service.CreateLicenses(input, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, licenses, 1)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
deviceUID string
|
||||
ip string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常使用授权码",
|
||||
code: licenses[0].Code,
|
||||
deviceUID: "test-device-001",
|
||||
ip: "127.0.0.1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "使用不存在的授权码",
|
||||
code: "invalid-code",
|
||||
deviceUID: "test-device-002",
|
||||
ip: "127.0.0.1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "重复使用授权码",
|
||||
code: licenses[0].Code,
|
||||
deviceUID: "test-device-003",
|
||||
ip: "127.0.0.1",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
license, err := service.UseLicense(tt.code, tt.deviceUID, tt.ip)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "used", license.Status)
|
||||
assert.Equal(t, tt.deviceUID, license.UsedBy)
|
||||
assert.NotZero(t, license.UsedAt)
|
||||
})
|
||||
}
|
||||
}
|
377
internal/service/monitor.go
Normal file
377
internal/service/monitor.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"licserver/internal/model"
|
||||
|
||||
"runtime"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/host"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/net"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MonitorService struct {
|
||||
db *gorm.DB
|
||||
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
func NewMonitorService(db *gorm.DB) *MonitorService {
|
||||
|
||||
return &MonitorService{
|
||||
|
||||
db: db,
|
||||
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) GetSystemStatus() (*model.SystemStatus, error) {
|
||||
|
||||
status := &model.SystemStatus{}
|
||||
|
||||
// CPU信息
|
||||
|
||||
if err := s.getCPUInfo(status); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 内存信息
|
||||
|
||||
if err := s.getMemoryInfo(status); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 磁盘信息
|
||||
|
||||
if err := s.getDiskInfo(status); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 网络信息
|
||||
|
||||
if err := s.getNetworkInfo(status); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 进程信息
|
||||
|
||||
if err := s.getProcessInfo(status); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 主机信息
|
||||
|
||||
if err := s.getHostInfo(status); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 系统信息
|
||||
|
||||
s.getSystemInfo(status)
|
||||
|
||||
return status, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getCPUInfo(status *model.SystemStatus) error {
|
||||
|
||||
cpuPercent, err := cpu.Percent(time.Second, false)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// Windows 系统不支持 LoadAvg,设置为默认值
|
||||
|
||||
status.CPU.LoadAvg = []float64{0, 0, 0}
|
||||
|
||||
cpuInfo, err := cpu.Info()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
status.CPU.Usage = cpuPercent[0]
|
||||
|
||||
status.CPU.CoreCount = runtime.NumCPU()
|
||||
|
||||
if len(cpuInfo) > 0 {
|
||||
|
||||
status.CPU.ModelName = cpuInfo[0].ModelName
|
||||
|
||||
status.CPU.MHz = cpuInfo[0].Mhz
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getMemoryInfo(status *model.SystemStatus) error {
|
||||
|
||||
memInfo, err := mem.VirtualMemory()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
swapInfo, err := mem.SwapMemory()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
status.Memory.Total = memInfo.Total
|
||||
|
||||
status.Memory.Used = memInfo.Used
|
||||
|
||||
status.Memory.Free = memInfo.Free
|
||||
|
||||
status.Memory.UsageRate = memInfo.UsedPercent
|
||||
|
||||
status.Memory.SwapTotal = swapInfo.Total
|
||||
|
||||
status.Memory.SwapUsed = swapInfo.Used
|
||||
|
||||
status.Memory.SwapFree = swapInfo.Free
|
||||
|
||||
status.Memory.SwapUsageRate = swapInfo.UsedPercent
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getDiskInfo(status *model.SystemStatus) error {
|
||||
|
||||
partitions, err := disk.Partitions(true)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
status.Disk.Partitions = make([]model.DiskPartition, 0)
|
||||
|
||||
for _, partition := range partitions {
|
||||
|
||||
usage, err := disk.Usage(partition.Mountpoint)
|
||||
|
||||
if err != nil {
|
||||
|
||||
continue
|
||||
|
||||
}
|
||||
|
||||
status.Disk.Partitions = append(status.Disk.Partitions, model.DiskPartition{
|
||||
|
||||
Device: partition.Device,
|
||||
|
||||
Mountpoint: partition.Mountpoint,
|
||||
|
||||
Fstype: partition.Fstype,
|
||||
|
||||
Total: usage.Total,
|
||||
|
||||
Used: usage.Used,
|
||||
|
||||
Free: usage.Free,
|
||||
|
||||
UsageRate: usage.UsedPercent,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getNetworkInfo(status *model.SystemStatus) error {
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
ioCounters, err := net.IOCounters(true)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
status.Network.Interfaces = make([]model.NetworkInterface, 0)
|
||||
|
||||
for _, iface := range interfaces {
|
||||
|
||||
var counter net.IOCountersStat
|
||||
|
||||
for _, io := range ioCounters {
|
||||
|
||||
if io.Name == iface.Name {
|
||||
|
||||
counter = io
|
||||
|
||||
break
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 获取接口的地址列表
|
||||
|
||||
addrs := make([]string, 0)
|
||||
|
||||
for _, addr := range iface.Addrs {
|
||||
addrs = append(addrs, addr.String())
|
||||
}
|
||||
|
||||
status.Network.Interfaces = append(status.Network.Interfaces, model.NetworkInterface{
|
||||
|
||||
Name: iface.Name,
|
||||
|
||||
BytesSent: counter.BytesSent,
|
||||
|
||||
BytesRecv: counter.BytesRecv,
|
||||
|
||||
PacketsSent: counter.PacketsSent,
|
||||
|
||||
PacketsRecv: counter.PacketsRecv,
|
||||
|
||||
Addrs: addrs,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getProcessInfo(status *model.SystemStatus) error {
|
||||
|
||||
processes, err := process.Processes()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
status.Process.Total = len(processes)
|
||||
|
||||
status.Process.List = make([]model.ProcessInfo, 0)
|
||||
|
||||
for i := 0; i < 10 && i < len(processes); i++ {
|
||||
|
||||
p := processes[i]
|
||||
|
||||
name, _ := p.Name()
|
||||
|
||||
cpu, _ := p.CPUPercent()
|
||||
|
||||
mem, _ := p.MemoryPercent()
|
||||
|
||||
created, _ := p.CreateTime()
|
||||
|
||||
status.Process.List = append(status.Process.List, model.ProcessInfo{
|
||||
|
||||
PID: int(p.Pid),
|
||||
|
||||
Name: name,
|
||||
|
||||
CPU: cpu,
|
||||
|
||||
Memory: float64(mem),
|
||||
|
||||
Created: created,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getHostInfo(status *model.SystemStatus) error {
|
||||
|
||||
info, err := host.Info()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
status.Host.Hostname = info.Hostname
|
||||
|
||||
status.Host.OS = info.OS
|
||||
|
||||
status.Host.Platform = info.Platform
|
||||
|
||||
status.Host.PlatformVersion = info.PlatformVersion
|
||||
|
||||
status.Host.KernelVersion = info.KernelVersion
|
||||
|
||||
status.Host.BootTime = time.Unix(int64(info.BootTime), 0)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *MonitorService) getSystemInfo(status *model.SystemStatus) {
|
||||
|
||||
status.System.Uptime = time.Since(s.startTime)
|
||||
|
||||
status.System.CurrentTime = time.Now()
|
||||
|
||||
var activeUsers int64
|
||||
|
||||
s.db.Model(&model.User{}).Count(&activeUsers)
|
||||
|
||||
status.System.ActiveUsers = int(activeUsers)
|
||||
|
||||
var totalDevices int64
|
||||
|
||||
s.db.Model(&model.Device{}).Count(&totalDevices)
|
||||
|
||||
status.System.TotalDevices = int(totalDevices)
|
||||
|
||||
}
|
47
internal/service/site.go
Normal file
47
internal/service/site.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"licserver/internal/utils"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SiteService struct {
|
||||
config *utils.Config
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSiteService(config *utils.Config) *SiteService {
|
||||
return &SiteService{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SiteService) GetSettings() utils.SiteConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.config.Site
|
||||
}
|
||||
|
||||
func (s *SiteService) UpdateSettings(settings utils.SiteConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 更新内存中的配置
|
||||
s.config.Site = settings
|
||||
|
||||
// 持久化配置到文件
|
||||
return utils.SaveConfig(s.config, "system", "更新站点设置")
|
||||
}
|
||||
|
||||
// ValidateSettings 验证站点设置
|
||||
func (s *SiteService) ValidateSettings(settings utils.SiteConfig) error {
|
||||
// 这里可以添加更多的验证逻辑
|
||||
if settings.Title == "" {
|
||||
return errors.New("站点标题不能为空")
|
||||
}
|
||||
if settings.BaseURL == "" {
|
||||
return errors.New("基础URL不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
188
internal/service/token.go
Normal file
188
internal/service/token.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"licserver/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TokenService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTokenService(db *gorm.DB) *TokenService {
|
||||
return &TokenService{db: db}
|
||||
}
|
||||
|
||||
// 生成访问令牌
|
||||
func generateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// 创建访问令牌
|
||||
func (s *TokenService) CreateToken(deviceUID, tokenType string, expireTime time.Time, ipList []string, createdBy uint) (*model.AccessToken, error) {
|
||||
// 验证设备是否存在
|
||||
var device model.Device
|
||||
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
|
||||
return nil, errors.New("设备不存在")
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建令牌记录
|
||||
accessToken := &model.AccessToken{
|
||||
Token: token,
|
||||
DeviceUID: deviceUID,
|
||||
Type: tokenType,
|
||||
Status: "active",
|
||||
ExpireTime: expireTime,
|
||||
IPList: strings.Join(ipList, ","),
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
|
||||
if err := s.db.Create(accessToken).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
tokenLog := &model.TokenLog{
|
||||
TokenID: accessToken.ID,
|
||||
Action: "create",
|
||||
Status: "success",
|
||||
Message: "创建访问令牌",
|
||||
}
|
||||
s.db.Create(tokenLog)
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// 验证令牌
|
||||
func (s *TokenService) ValidateToken(token, ip string) (*model.AccessToken, error) {
|
||||
var accessToken model.AccessToken
|
||||
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
|
||||
return nil, errors.New("无效的令牌")
|
||||
}
|
||||
|
||||
// 检查令牌状态
|
||||
if accessToken.Status != "active" {
|
||||
return nil, errors.New("令牌已被撤销")
|
||||
}
|
||||
|
||||
// 检查过期时间
|
||||
if time.Now().After(accessToken.ExpireTime) {
|
||||
return nil, errors.New("令牌已过期")
|
||||
}
|
||||
|
||||
// 检查IP限制
|
||||
if accessToken.IPList != "" {
|
||||
allowedIPs := strings.Split(accessToken.IPList, ",")
|
||||
allowed := false
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if allowedIP == ip {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
return nil, errors.New("IP地址不允许访问")
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用记录
|
||||
s.db.Model(&accessToken).Updates(map[string]interface{}{
|
||||
"last_used": time.Now(),
|
||||
"usage_count": gorm.Expr("usage_count + 1"),
|
||||
})
|
||||
|
||||
// 记录日志
|
||||
tokenLog := &model.TokenLog{
|
||||
TokenID: accessToken.ID,
|
||||
Action: "use",
|
||||
IP: ip,
|
||||
Status: "success",
|
||||
}
|
||||
s.db.Create(tokenLog)
|
||||
|
||||
return &accessToken, nil
|
||||
}
|
||||
|
||||
// 撤销令牌
|
||||
func (s *TokenService) RevokeToken(token string, userID uint) error {
|
||||
var accessToken model.AccessToken
|
||||
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
|
||||
return errors.New("令牌不存在")
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if accessToken.CreatedBy != userID {
|
||||
return errors.New("无权操作此令牌")
|
||||
}
|
||||
|
||||
// 更新状态
|
||||
if err := s.db.Model(&accessToken).Update("status", "revoked").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
tokenLog := &model.TokenLog{
|
||||
TokenID: accessToken.ID,
|
||||
Action: "revoke",
|
||||
Status: "success",
|
||||
Message: "撤销访问令牌",
|
||||
}
|
||||
s.db.Create(tokenLog)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取令牌列表
|
||||
func (s *TokenService) GetTokens(deviceUID string, page, pageSize int) ([]model.AccessToken, int64, error) {
|
||||
var tokens []model.AccessToken
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.AccessToken{})
|
||||
if deviceUID != "" {
|
||||
query = query.Where("device_uid = ?", deviceUID)
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&tokens).Error
|
||||
return tokens, total, err
|
||||
}
|
||||
|
||||
// 获取令牌日志
|
||||
func (s *TokenService) GetTokenLogs(tokenID uint, page, pageSize int) ([]model.TokenLog, int64, error) {
|
||||
var logs []model.TokenLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.TokenLog{}).Where("token_id = ?", tokenID)
|
||||
query.Count(&total)
|
||||
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&logs).Error
|
||||
return logs, total, err
|
||||
}
|
463
internal/service/upload.go
Normal file
463
internal/service/upload.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"fmt"
|
||||
|
||||
"io"
|
||||
|
||||
"mime/multipart"
|
||||
|
||||
"os"
|
||||
|
||||
"path/filepath"
|
||||
|
||||
"sort"
|
||||
|
||||
"strings"
|
||||
|
||||
"time"
|
||||
|
||||
"licserver/internal/model"
|
||||
|
||||
"licserver/internal/utils"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"crypto/sha256"
|
||||
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
type UploadService struct {
|
||||
db *gorm.DB
|
||||
|
||||
config *utils.Config
|
||||
}
|
||||
|
||||
func NewUploadService(db *gorm.DB, config *utils.Config) *UploadService {
|
||||
|
||||
return &UploadService{
|
||||
|
||||
db: db,
|
||||
|
||||
config: config,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) UploadFile(file *multipart.FileHeader, userID uint, deviceUID, description string) (*model.FileUpload, error) {
|
||||
|
||||
// 生成唯一文件名
|
||||
|
||||
ext := filepath.Ext(file.Filename)
|
||||
|
||||
uniqueID := uuid.New().String()
|
||||
|
||||
fileName := fmt.Sprintf("%s%s", uniqueID, ext)
|
||||
|
||||
filePath := filepath.Join(s.config.Upload.Path, fileName)
|
||||
|
||||
// 确保上传目录存在
|
||||
|
||||
if err := os.MkdirAll(s.config.Upload.Path, 0755); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 保存文件
|
||||
|
||||
src, err := file.Open()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
defer src.Close()
|
||||
|
||||
dst, err := os.Create(filePath)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
defer dst.Close()
|
||||
|
||||
if _, err = io.Copy(dst, src); err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 如果提供了设备UID,获取设备型号
|
||||
|
||||
var deviceModel string
|
||||
|
||||
if deviceUID != "" {
|
||||
|
||||
var device model.Device
|
||||
|
||||
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err == nil {
|
||||
|
||||
deviceModel = device.DeviceModel
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 创建数据库记录
|
||||
|
||||
upload := &model.FileUpload{
|
||||
|
||||
FileName: file.Filename,
|
||||
|
||||
FilePath: filePath,
|
||||
|
||||
FileSize: file.Size,
|
||||
|
||||
FileType: strings.ToLower(ext),
|
||||
|
||||
UploadedBy: userID,
|
||||
|
||||
DeviceUID: deviceUID,
|
||||
|
||||
DeviceModel: deviceModel,
|
||||
|
||||
Description: description,
|
||||
}
|
||||
|
||||
if err := s.db.Create(upload).Error; err != nil {
|
||||
|
||||
os.Remove(filePath)
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
return upload, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) DownloadFile(id uint) (*model.FileUpload, error) {
|
||||
|
||||
var file model.FileUpload
|
||||
|
||||
if err := s.db.First(&file, id).Error; err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
return &file, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) DeleteFile(id uint, userID uint) error {
|
||||
|
||||
var file model.FileUpload
|
||||
|
||||
if err := s.db.First(&file, id).Error; err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
if file.UploadedBy != userID {
|
||||
|
||||
return errors.New("无权删除此文件")
|
||||
|
||||
}
|
||||
|
||||
if err := os.Remove(file.FilePath); err != nil && !os.IsNotExist(err) {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
return s.db.Delete(&file).Error
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) GetDeviceFiles(deviceUID string) ([]model.FileUpload, error) {
|
||||
|
||||
var files []model.FileUpload
|
||||
|
||||
err := s.db.Where("device_uid = ?", deviceUID).Find(&files).Error
|
||||
|
||||
return files, err
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) UploadChunk(
|
||||
|
||||
file *multipart.FileHeader,
|
||||
|
||||
fileHash string,
|
||||
|
||||
chunkNumber int,
|
||||
|
||||
totalChunks int,
|
||||
|
||||
totalSize int64,
|
||||
|
||||
filename string,
|
||||
|
||||
userID uint,
|
||||
|
||||
deviceUID string,
|
||||
|
||||
) error {
|
||||
|
||||
// 创建分片存储目录
|
||||
|
||||
chunkDir := filepath.Join(s.config.Upload.Path, "chunks", fileHash)
|
||||
|
||||
if err := os.MkdirAll(chunkDir, 0755); err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// 保存分片文件
|
||||
|
||||
chunkPath := filepath.Join(chunkDir, fmt.Sprintf("%d", chunkNumber))
|
||||
|
||||
src, err := file.Open()
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
defer src.Close()
|
||||
|
||||
dst, err := os.Create(chunkPath)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
defer dst.Close()
|
||||
|
||||
if _, err = io.Copy(dst, src); err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// 记录分片信息
|
||||
|
||||
chunk := model.UploadChunk{
|
||||
|
||||
FileHash: fileHash,
|
||||
|
||||
ChunkNumber: chunkNumber,
|
||||
|
||||
ChunkSize: file.Size,
|
||||
|
||||
ChunkPath: chunkPath,
|
||||
|
||||
TotalChunks: totalChunks,
|
||||
|
||||
TotalSize: totalSize,
|
||||
|
||||
Filename: filename,
|
||||
|
||||
FileType: strings.ToLower(filepath.Ext(filename)),
|
||||
|
||||
UploadedBy: userID,
|
||||
|
||||
DeviceUID: deviceUID,
|
||||
}
|
||||
|
||||
return s.db.Create(&chunk).Error
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) CheckUploadStatus(fileHash string) (bool, error) {
|
||||
|
||||
var chunks []model.UploadChunk
|
||||
|
||||
if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
|
||||
|
||||
return false, err
|
||||
|
||||
}
|
||||
|
||||
if len(chunks) == 0 {
|
||||
|
||||
return false, nil
|
||||
|
||||
}
|
||||
|
||||
totalChunks := chunks[0].TotalChunks
|
||||
|
||||
return len(chunks) == totalChunks, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) MergeChunks(fileHash string) (*model.FileUpload, error) {
|
||||
|
||||
var chunks []model.UploadChunk
|
||||
|
||||
if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
if len(chunks) == 0 {
|
||||
|
||||
return nil, errors.New("未找到文件分片")
|
||||
|
||||
}
|
||||
|
||||
if len(chunks) != chunks[0].TotalChunks {
|
||||
|
||||
return nil, errors.New("文件分片不完整")
|
||||
|
||||
}
|
||||
|
||||
// 按分片序号排序
|
||||
|
||||
sort.Slice(chunks, func(i, j int) bool {
|
||||
|
||||
return chunks[i].ChunkNumber < chunks[j].ChunkNumber
|
||||
|
||||
})
|
||||
|
||||
// 创建最终文件
|
||||
|
||||
finalPath := filepath.Join(s.config.Upload.Path, fmt.Sprintf("%s%s", uuid.New().String(), chunks[0].FileType))
|
||||
|
||||
finalFile, err := os.Create(finalPath)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
defer finalFile.Close()
|
||||
|
||||
// 合并分片
|
||||
|
||||
hash := sha256.New()
|
||||
|
||||
for _, chunk := range chunks {
|
||||
|
||||
chunkFile, err := os.Open(chunk.ChunkPath)
|
||||
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
if _, err = io.Copy(finalFile, chunkFile); err != nil {
|
||||
|
||||
chunkFile.Close()
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
if _, err = io.Copy(hash, chunkFile); err != nil {
|
||||
|
||||
chunkFile.Close()
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
chunkFile.Close()
|
||||
|
||||
os.Remove(chunk.ChunkPath) // 删除已合并的分片
|
||||
|
||||
}
|
||||
|
||||
// 验证文件哈希
|
||||
|
||||
if hex.EncodeToString(hash.Sum(nil)) != fileHash {
|
||||
|
||||
os.Remove(finalPath)
|
||||
|
||||
return nil, errors.New("文件哈希验证失败")
|
||||
|
||||
}
|
||||
|
||||
// 创建文件记录
|
||||
|
||||
upload := &model.FileUpload{
|
||||
|
||||
FileName: chunks[0].Filename,
|
||||
|
||||
FilePath: finalPath,
|
||||
|
||||
FileSize: chunks[0].TotalSize,
|
||||
|
||||
FileType: chunks[0].FileType,
|
||||
|
||||
UploadedBy: chunks[0].UploadedBy,
|
||||
|
||||
DeviceUID: chunks[0].DeviceUID,
|
||||
}
|
||||
|
||||
if err := s.db.Create(upload).Error; err != nil {
|
||||
|
||||
os.Remove(finalPath)
|
||||
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
// 清理分片记录
|
||||
|
||||
s.db.Where("file_hash = ?", fileHash).Delete(&model.UploadChunk{})
|
||||
|
||||
os.RemoveAll(filepath.Dir(chunks[0].ChunkPath))
|
||||
|
||||
return upload, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UploadService) CleanupExpiredChunks() error {
|
||||
|
||||
expireTime := time.Now().Add(-24 * time.Hour)
|
||||
|
||||
var expiredChunks []model.UploadChunk
|
||||
|
||||
if err := s.db.Where("completed = ? AND created_at < ?", false, expireTime).Find(&expiredChunks).Error; err != nil {
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
for _, chunk := range expiredChunks {
|
||||
|
||||
os.Remove(chunk.ChunkPath)
|
||||
|
||||
if len(chunk.ChunkPath) > 0 {
|
||||
|
||||
chunkDir := filepath.Dir(chunk.ChunkPath)
|
||||
|
||||
os.RemoveAll(chunkDir)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return s.db.Unscoped().Where("completed = ? AND created_at < ?", false, expireTime).Delete(&model.UploadChunk{}).Error
|
||||
|
||||
}
|
361
internal/service/user.go
Normal file
361
internal/service/user.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"fmt"
|
||||
|
||||
"time"
|
||||
|
||||
"licserver/internal/model"
|
||||
|
||||
"licserver/internal/utils"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
db *gorm.DB
|
||||
config *utils.Config
|
||||
captchaService *CaptchaService
|
||||
}
|
||||
|
||||
type UserProfile struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB, config *utils.Config) *UserService {
|
||||
return &UserService{
|
||||
db: db,
|
||||
config: config,
|
||||
captchaService: NewCaptchaService(db, &config.Email),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserService) Register(username, password, email, captcha string) error {
|
||||
// 验证验证码
|
||||
if err := s.captchaService.VerifyCaptcha(email, "register", captcha); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查用户名是否已存在
|
||||
var count int64
|
||||
s.db.Model(&model.User{}).Where("username = ?", username).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
// 原有的注册逻辑
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
Username: username,
|
||||
Password: string(hashedPassword),
|
||||
Email: email,
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
return s.db.Create(&user).Error
|
||||
}
|
||||
|
||||
func (s *UserService) Login(username, password string) (string, error) {
|
||||
var user model.User
|
||||
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", errors.New("用户不存在")
|
||||
}
|
||||
return "", err
|
||||
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||
return "", errors.New("密码错误")
|
||||
}
|
||||
|
||||
// 生成 JWT token
|
||||
token, err := utils.GenerateToken(user.ID, user.Username, user.Role, &s.config.JWT)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
s.db.Model(&user).Update("last_login", gorm.Expr("CURRENT_TIMESTAMP"))
|
||||
return token, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UserService) GetUserByID(id uint) (*UserProfile, error) {
|
||||
var user model.User
|
||||
if err := s.db.First(&user, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &UserProfile{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateProfile(userID uint, email string) error {
|
||||
// 检查邮箱是否被其他用户使用
|
||||
var count int64
|
||||
if err := s.db.Model(&model.User{}).Where("email = ? AND id != ?", email, userID).Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("email", email).Error
|
||||
}
|
||||
|
||||
func (s *UserService) ChangePassword(userID uint, oldPassword, newPassword string) error {
|
||||
var user model.User
|
||||
if err := s.db.First(&user, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(oldPassword)); err != nil {
|
||||
return errors.New("旧密码错误")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Model(&user).Update("password", string(hashedPassword)).Error
|
||||
}
|
||||
|
||||
func (s *UserService) ResetPassword(email, captcha string) error {
|
||||
// 验证验证码
|
||||
if err := s.captchaService.VerifyCaptcha(email, "reset", captcha); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 原有的重置密码逻辑
|
||||
var user model.User
|
||||
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
|
||||
return errors.New("邮箱不存在")
|
||||
}
|
||||
|
||||
// 生成重置令牌
|
||||
token, err := utils.GenerateResetToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存重置令牌
|
||||
resetToken := model.PasswordResetToken{
|
||||
UserID: user.ID,
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
Used: false,
|
||||
}
|
||||
if err := s.db.Create(&resetToken).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送重置邮件
|
||||
emailService := utils.NewEmailService(&s.config.Email)
|
||||
resetLink := fmt.Sprintf("http://localhost:%s/reset-password?token=%s", s.config.Server.Port, token)
|
||||
emailBody := fmt.Sprintf(`
|
||||
<h3>密码重置</h3>
|
||||
<p>您好,%s</p>
|
||||
<p>请点击以下链接重置您的密码:</p>
|
||||
<p><a href="%s">重置密码</a></p>
|
||||
<p>此链接将在24小时后失效。</p>
|
||||
<p>如果您没有请求重置密码,请忽略此邮件。</p>
|
||||
`, user.Username, resetLink)
|
||||
return emailService.SendEmail(user.Email, "密码重置", emailBody)
|
||||
|
||||
}
|
||||
|
||||
func (s *UserService) ValidateResetToken(token string) (*model.User, error) {
|
||||
var resetToken model.PasswordResetToken
|
||||
if err := s.db.Where("token = ? AND used = ? AND expires_at > ?",
|
||||
token, false, time.Now()).First(&resetToken).Error; err != nil {
|
||||
return nil, errors.New("无效或已过期的重置令牌")
|
||||
}
|
||||
|
||||
var user model.User
|
||||
if err := s.db.First(&user, resetToken.UserID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *UserService) ResetPasswordWithToken(token, newPassword string) error {
|
||||
user, err := s.ValidateResetToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用事务确保原子性
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
if err := tx.Model(&user).Update("password", string(hashedPassword)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 标记令牌为已使用
|
||||
if err := tx.Model(&model.PasswordResetToken{}).
|
||||
Where("token = ?", token).
|
||||
Update("used", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UserService) SendRegisterCaptcha(email string) error {
|
||||
// 检查邮箱是否已被注册
|
||||
var count int64
|
||||
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
return s.captchaService.SendEmailCaptcha(email, "register")
|
||||
}
|
||||
|
||||
func (s *UserService) SendResetPasswordCaptcha(email string) error {
|
||||
var user model.User
|
||||
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
|
||||
return errors.New("邮箱不存在")
|
||||
}
|
||||
|
||||
return s.captchaService.SendEmailCaptcha(email, "reset")
|
||||
}
|
||||
|
||||
func (s *UserService) GetCaptchaService() *CaptchaService {
|
||||
return s.captchaService
|
||||
}
|
||||
|
||||
// GetUsers 获取用户列表
|
||||
func (s *UserService) GetUsers(username, role string, page, pageSize int) ([]UserProfile, int64, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
var profiles []UserProfile
|
||||
|
||||
query := s.db.Model(&model.User{})
|
||||
|
||||
if username != "" {
|
||||
query = query.Where("username LIKE ?", "%"+username+"%")
|
||||
}
|
||||
if role != "" {
|
||||
query = query.Where("role = ?", role)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
query.Count(&total)
|
||||
|
||||
// 分页查询
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
query = query.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
if err := query.Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为 UserProfile
|
||||
for _, user := range users {
|
||||
profiles = append(profiles, UserProfile{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
})
|
||||
}
|
||||
|
||||
return profiles, total, nil
|
||||
}
|
||||
|
||||
// CreateUser 创建新用户
|
||||
func (s *UserService) CreateUser(username, password, email, role string) error {
|
||||
// 检查用户名是否已存在
|
||||
var count int64
|
||||
s.db.Model(&model.User{}).Where("username = ?", username).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
Username: username,
|
||||
Password: string(hashedPassword),
|
||||
Email: email,
|
||||
Role: role,
|
||||
}
|
||||
|
||||
return s.db.Create(&user).Error
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (s *UserService) UpdateUser(id uint, email, role string) error {
|
||||
// 检查邮箱是否被其他用户使用
|
||||
var count int64
|
||||
s.db.Model(&model.User{}).Where("email = ? AND id != ?", email, id).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
return s.db.Model(&model.User{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||
"email": email,
|
||||
"role": role,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (s *UserService) DeleteUser(id uint) error {
|
||||
// 检查是否为最后一个管理员
|
||||
var adminCount int64
|
||||
s.db.Model(&model.User{}).Where("role = ?", "admin").Count(&adminCount)
|
||||
|
||||
var user model.User
|
||||
if err := s.db.First(&user, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.Role == "admin" && adminCount <= 1 {
|
||||
return errors.New("不能删除最后一个管理员")
|
||||
}
|
||||
|
||||
return s.db.Delete(&model.User{}, id).Error
|
||||
}
|
61
internal/service/user_test.go
Normal file
61
internal/service/user_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"licserver/internal/model"
|
||||
"licserver/internal/utils"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUserService_Register(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
config := utils.TestConfig()
|
||||
userService := NewUserService(db, config)
|
||||
|
||||
// 测试正常注册
|
||||
err := userService.Register("testuser", "password123", "test@example.com", "123456")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var user model.User
|
||||
err = db.Where("username = ?", "testuser").First(&user).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "testuser", user.Username)
|
||||
assert.Equal(t, "test@example.com", user.Email)
|
||||
assert.Equal(t, "user", user.Role)
|
||||
|
||||
// 测试重复用户名
|
||||
err = userService.Register("testuser", "password123", "test2@example.com", "123456")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "用户名已存在")
|
||||
|
||||
// 测试重复邮箱
|
||||
err = userService.Register("testuser2", "password123", "test@example.com", "123456")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "邮箱已被注册")
|
||||
}
|
||||
|
||||
func TestUserService_Login(t *testing.T) {
|
||||
db := utils.TestDB(t)
|
||||
config := utils.TestConfig()
|
||||
userService := NewUserService(db, config)
|
||||
|
||||
// 创建测试用户
|
||||
err := userService.Register("testuser", "password123", "test@example.com", "123456")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 测试正确登录
|
||||
token, err := userService.Login("testuser", "password123")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 测试错误密码
|
||||
_, err = userService.Login("testuser", "wrongpassword")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "密码错误")
|
||||
|
||||
// 测试不存在的用户
|
||||
_, err = userService.Login("nonexistent", "password123")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "用户不存在")
|
||||
}
|
Reference in New Issue
Block a user