first commit

This commit is contained in:
JiXieShi
2024-11-14 22:55:43 +08:00
commit 421cfb8cfa
98 changed files with 12617 additions and 0 deletions

100
internal/service/captcha.go Normal file
View File

@@ -0,0 +1,100 @@
package service
import (
"errors"
"licserver/internal/model"
"licserver/internal/utils"
"strings"
"time"
"github.com/mojocn/base64Captcha"
"gorm.io/gorm"
)
type CaptchaService struct {
db *gorm.DB
emailConfig *utils.EmailConfig
store base64Captcha.Store
}
func NewCaptchaService(db *gorm.DB, emailConfig *utils.EmailConfig) *CaptchaService {
return &CaptchaService{
db: db,
emailConfig: emailConfig,
store: base64Captcha.DefaultMemStore,
}
}
func (s *CaptchaService) SendEmailCaptcha(email, captchaType string) error {
// 检查是否存在未过期的验证码
var count int64
s.db.Model(&model.Captcha{}).
Where("target = ? AND type = ? AND expires_at > ? AND used = ?",
email, captchaType, time.Now(), false).
Count(&count)
if count > 0 {
return errors.New("请勿频繁发送验证码")
}
// 生成验证码
code, err := utils.GenerateCaptcha()
if err != nil {
return err
}
// 保存验证码
captcha := model.Captcha{
Code: code,
Type: captchaType,
Target: email,
ExpiresAt: time.Now().Add(5 * time.Minute),
Used: false,
}
if err := s.db.Create(&captcha).Error; err != nil {
return err
}
// 发送验证码邮件
emailService := utils.NewEmailService(s.emailConfig)
content := utils.GenerateEmailCaptchaContent(code, email, captchaType)
return emailService.SendEmail(email, "验证码", content)
}
func (s *CaptchaService) VerifyCaptcha(target, captchaType, code string) error {
var captcha model.Captcha
err := s.db.Where("target = ? AND type = ? AND code = ? AND used = ? AND expires_at > ?",
target, captchaType, code, false, time.Now()).
First(&captcha).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("验证码无效或已过期")
}
return err
}
// 标记验证码为已使用
return s.db.Model(&captcha).Update("used", true).Error
}
// 生成图片验证码
func (s *CaptchaService) GenerateImageCaptcha() (string, string, error) {
driver := base64Captcha.NewDriverDigit(80, 240, 6, 0.7, 80)
c := base64Captcha.NewCaptcha(driver, s.store)
id, b64s, err := c.Generate()
if err != nil {
return "", "", err
}
// 确保返回的base64字符串不包含前缀
b64s = strings.TrimPrefix(b64s, "data:image/png;base64,")
return id, b64s, nil
}
// 验证图片验证码
func (s *CaptchaService) VerifyImageCaptcha(id, code string) bool {
return s.store.Verify(id, code, true)
}

721
internal/service/device.go Normal file
View File

@@ -0,0 +1,721 @@
package service
import (
"errors"
"fmt"
"licserver/internal/model"
"time"
"gorm.io/gorm"
)
type DeviceService struct {
db *gorm.DB
licenseService *LicenseService
}
func NewDeviceService(db *gorm.DB, licenseService *LicenseService) *DeviceService {
return &DeviceService{
db: db,
licenseService: licenseService,
}
}
type DeviceRegisterInput struct {
UID string `json:"uid" binding:"required"`
DeviceModel string `json:"device_model" binding:"required"`
LicenseCode string `json:"license_code"`
}
func (s *DeviceService) RegisterDevice(input *DeviceRegisterInput, ip string) error {
// 检查设备型号是否存在且处于启用状态
var deviceModel model.DeviceModel
if err := s.db.Where("model_name = ? AND status = ?", input.DeviceModel, "active").First(&deviceModel).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("设备型号不存在或已禁用")
}
return err
}
// 检查设备是否已注册
var count int64
s.db.Model(&model.Device{}).Where("uid = ?", input.UID).Count(&count)
if count > 0 {
return errors.New("设备已注册")
}
// 创建设备记录
device := &model.Device{
UID: input.UID,
DeviceType: deviceModel.DeviceType,
DeviceModel: input.DeviceModel,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "inactive",
LastActiveAt: time.Now(),
StartCount: 0,
}
// 如果提供了授权码,进行授权绑定
if input.LicenseCode != "" {
license, err := s.licenseService.GetLicenseByCode(input.LicenseCode)
if err != nil {
return err
}
if license.Status != "unused" {
return errors.New("授权码已被使用")
}
device.Status = "active"
device.LicenseCode = license.Code
device.LicenseType = license.LicenseType
device.MaxUses = license.MaxUses
device.Duration = license.Duration
if license.LicenseType == "time" {
device.ExpireTime = time.Now().Add(time.Duration(license.Duration) * time.Minute)
}
}
return s.db.Transaction(func(tx *gorm.DB) error {
// 创建设备记录
if err := tx.Create(device).Error; err != nil {
return err
}
// 如果有授权码,更新授权码状态
if device.LicenseCode != "" {
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", device.LicenseCode).
Updates(map[string]interface{}{
"status": "used",
"used_by": input.UID,
"used_at": time.Now(),
}).Error; err != nil {
return err
}
}
// 记录设备日志
logMsg := "设备注册成功"
if device.LicenseCode != "" {
logMsg += fmt.Sprintf(",使用授权码: %s", device.LicenseCode)
}
log := model.DeviceLog{
DeviceUID: input.UID,
Action: "register",
Message: logMsg,
Status: "success",
}
if err := tx.Create(&log).Error; err != nil {
return err
}
return nil
})
}
func (s *DeviceService) ValidateDevice(uid string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return errors.New("设备未注册")
}
// 更新最后活跃时间
device.LastActiveAt = time.Now()
// 如果设备已激活,检查授权状态
if device.Status == "active" {
if device.LicenseCode != "" {
if err := s.licenseService.CheckLicenseValidity(device.LicenseCode); err != nil {
device.Status = "expired"
s.db.Save(&device)
return errors.New("设备授权已过期")
}
}
}
return s.db.Save(&device).Error
}
func (s *DeviceService) GetDevices(params *DeviceQueryParams) ([]model.Device, int64, error) {
var devices []model.Device
var total int64
query := s.db.Model(&model.Device{})
if params.UID != "" {
query = query.Where("uid LIKE ?", "%"+params.UID+"%")
}
if params.DeviceType != "" {
query = query.Where("device_type = ?", params.DeviceType)
}
if params.Company != "" {
query = query.Where("company LIKE ?", "%"+params.Company+"%")
}
if params.LicenseType != "" {
query = query.Where("license_type = ?", params.LicenseType)
}
if params.Status != "" {
query = query.Where("status = ?", params.Status)
}
query.Count(&total)
if params.Page > 0 && params.PageSize > 0 {
offset := (params.Page - 1) * params.PageSize
query = query.Offset(offset).Limit(params.PageSize)
}
err := query.Find(&devices).Error
return devices, total, err
}
func (s *DeviceService) UpdateStartCount(uid string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return err
}
// 更新启动次数和最后活跃时间
device.StartCount++
device.LastActiveAt = time.Now()
// 如果设备已激活,检查授权状态
if device.Status == "active" {
// 检查授权码有效性
if device.LicenseCode != "" {
if err := s.licenseService.CheckLicenseValidity(device.LicenseCode); err != nil {
device.Status = "expired"
}
}
// 检查次数限制
if device.LicenseType == "count" && device.StartCount >= device.MaxUses {
device.Status = "expired"
}
}
// 记录设备日志
log := model.DeviceLog{
DeviceUID: uid,
Action: "start",
Message: fmt.Sprintf("设备启动,当前次数:%d", device.StartCount),
Status: "success",
}
return s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Save(&device).Error; err != nil {
return err
}
return tx.Create(&log).Error
})
}
func (s *DeviceService) UpdateDevice(uid string, updates map[string]interface{}) error {
return s.db.Model(&model.Device{}).Where("uid = ?", uid).Updates(updates).Error
}
func (s *DeviceService) DeleteDevice(uid string) error {
return s.db.Where("uid = ?", uid).Delete(&model.Device{}).Error
}
func (s *DeviceService) GetLicenseInfo(deviceUID string) (*model.Device, error) {
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return nil, errors.New("设备不存在")
}
return &device, nil
}
func (s *DeviceService) CheckLicenseStatus(deviceUID string) (string, error) {
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return "", errors.New("设备不存在")
}
if device.LicenseCode == "" {
return "未授权", nil
}
if device.Status != "active" {
return device.Status, nil
}
switch device.LicenseType {
case "时间段":
if time.Now().After(device.ExpireTime) {
device.Status = "expired"
s.db.Save(&device)
return "已过期", nil
}
case "启动次数":
if device.StartCount >= device.MaxUses {
device.Status = "expired"
s.db.Save(&device)
return "已达到使用上限", nil
}
}
return "正常", nil
}
type DeviceQueryParams struct {
UID string
DeviceType string
Company string
LicenseType string
Status string
Page int
PageSize int
}
type DeviceCreateInput struct {
UID string `json:"uid" binding:"required"`
DeviceType string `json:"device_type" binding:"required"`
DeviceModel string `json:"device_model" binding:"required"`
Company string `json:"company"`
}
func (s *DeviceService) CreateDevice(input *DeviceCreateInput) error {
// 检查设备UID是否已存在
var count int64
if err := s.db.Model(&model.Device{}).Where("uid = ?", input.UID).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("设备UID已存在")
}
// 创建设备记录
device := &model.Device{
UID: input.UID,
DeviceType: input.DeviceType,
DeviceModel: input.DeviceModel,
Company: input.Company,
RegisterTime: time.Now(),
Status: "inactive", // 初始状态为未激活
}
return s.db.Create(device).Error
}
// 添加更新检查方法
func (s *DeviceService) CheckUpdate(deviceUID, currentVersion string) (*model.FileUpload, error) {
// 获取设备信息
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return nil, errors.New("设备不存在")
}
// 检查设备状态
if device.Status != "active" {
return nil, errors.New("设备未激活或已过期")
}
// 查找最新的更新文件
var update model.FileUpload
err := s.db.Where("device_model = ? AND is_update = ?", device.DeviceModel, true).
Order("created_at DESC").
First(&update).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil // 没有可用更新
}
return nil, err
}
// 比较版本
if update.Version <= currentVersion && !update.ForceUpdate {
return nil, nil // 当前版本已是最新
}
return &update, nil
}
// 添加设备型号相关的方法
// CreateDeviceModel 创建设备型号
func (s *DeviceService) CreateDeviceModel(model_ *model.DeviceModel) error {
// 检查型号名称是否已存在
var count int64
s.db.Model(&model.DeviceModel{}).Where("model_name = ?", model_.ModelName).Count(&count)
if count > 0 {
return errors.New("设备型号已存在")
}
return s.db.Create(model_).Error
}
// UpdateDeviceModel 更新设备型号
func (s *DeviceService) UpdateDeviceModel(id uint, model_ *model.DeviceModel) error {
// 检查型号名称是否被其他型号使用
var count int64
s.db.Model(&model.DeviceModel{}).Where("model_name = ? AND id != ?", model_.ModelName, id).Count(&count)
if count > 0 {
return errors.New("设备型号已存在")
}
return s.db.Model(&model.DeviceModel{}).Where("id = ?", id).Updates(model_).Error
}
// DeleteDeviceModel 删除设备型号
func (s *DeviceService) DeleteDeviceModel(id uint) error {
// 检查是否有设备使用此型号
var count int64
s.db.Model(&model.DeviceModel{}).Where("device_model = ?", id).Count(&count)
if count > 0 {
return errors.New("该型号下存在设备,无法删除")
}
return s.db.Delete(&model.DeviceModel{}, id).Error
}
// GetDeviceModels 获取设备型号列表
func (s *DeviceService) GetDeviceModels(modelName, deviceType, company string, page, pageSize int) ([]model.DeviceModel, int64, error) {
var models []model.DeviceModel
var total int64
query := s.db.Model(&model.DeviceModel{})
if modelName != "" {
query = query.Where("model_name LIKE ?", "%"+modelName+"%")
}
if deviceType != "" {
query = query.Where("device_type = ?", deviceType)
}
if company != "" {
query = query.Where("company LIKE ?", "%"+company+"%")
}
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
// 查询设备型号列表
if err := query.Find(&models).Error; err != nil {
return nil, 0, err
}
// 查询每个型号下的设备数量
for i := range models {
var count int64
s.db.Model(&model.Device{}).Where("device_model = ?", models[i].ModelName).Count(&count)
models[i].DeviceCount = int(count)
}
return models, total, nil
}
// BatchDeleteDeviceModels 批量删除设备型号
func (s *DeviceService) BatchDeleteDeviceModels(ids []uint) error {
// 检查是否有设备使用这些型号
var count int64
s.db.Model(&model.Device{}).Where("device_model IN (?)", ids).Count(&count)
if count > 0 {
return errors.New("选中的型号中存在正在使用的型号,无法删除")
}
return s.db.Delete(&model.DeviceModel{}, ids).Error
}
// GetRegisteredDevices 获取已注册设备列表
func (s *DeviceService) GetRegisteredDevices(uid, deviceModel, status string, page, pageSize int) ([]model.Device, int64, error) {
var devices []model.Device
var total int64
query := s.db.Model(&model.Device{})
if uid != "" {
query = query.Where("uid LIKE ?", "%"+uid+"%")
}
if deviceModel != "" {
query = query.Where("device_model = ?", deviceModel)
}
if status != "" {
query = query.Where("status = ?", status)
}
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&devices).Error
return devices, total, err
}
// BindLicense 绑定授权码
func (s *DeviceService) BindLicense(uid string, licenseCode string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return errors.New("设备不存在")
}
// 检查设备当前状态
if device.LicenseCode != "" {
return errors.New("设备已绑定授权码,请先解绑")
}
// 验证授权码
license, err := s.licenseService.GetLicenseByCode(licenseCode)
if err != nil {
return err
}
if license.Status != "unused" {
return errors.New("授权码已被使用")
}
// 根据授权类型处理
switch license.LicenseType {
case "time":
if license.Duration <= 0 {
return errors.New("无效的授权时长")
}
device.ExpireTime = time.Now().Add(time.Duration(license.Duration) * time.Minute)
device.Duration = license.Duration
device.MaxUses = 0
case "count":
if license.MaxUses <= 0 {
return errors.New("无效的使用次数")
}
device.ExpireTime = time.Time{} // 清空过期时间
device.Duration = 0
device.MaxUses = license.MaxUses
device.StartCount = 0 // 重置启动次数
case "permanent":
device.ExpireTime = time.Time{} // 清空过期时间
device.Duration = 0
device.MaxUses = 0
default:
return errors.New("无效的授权类型")
}
// 更新设备基本信息
device.LicenseCode = licenseCode
device.LicenseType = license.LicenseType
device.Status = "active"
device.LastActiveAt = time.Now()
return s.db.Transaction(func(tx *gorm.DB) error {
// 更新设备信息
if err := tx.Save(&device).Error; err != nil {
return err
}
// 更新授权码状态
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", licenseCode).
Updates(map[string]interface{}{
"status": "used",
"used_by": uid,
"used_at": time.Now(),
}).Error; err != nil {
return err
}
// 记录设备日志
log := model.DeviceLog{
DeviceUID: uid,
Action: "bind_license",
Message: fmt.Sprintf("绑定%s授权码: %s", getLicenseTypeText(license.LicenseType), licenseCode),
Status: "success",
}
if err := tx.Create(&log).Error; err != nil {
return err
}
return nil
})
}
// 获取授权类型的中文描述
func getLicenseTypeText(licenseType string) string {
switch licenseType {
case "time":
return "时间"
case "count":
return "次数"
case "permanent":
return "永久"
default:
return "未知"
}
}
// UnbindLicense 解绑授权码
func (s *DeviceService) UnbindLicense(uid string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return errors.New("设备不存在")
}
if device.LicenseCode == "" {
return errors.New("设备未绑定授权码")
}
oldLicenseCode := device.LicenseCode
return s.db.Transaction(func(tx *gorm.DB) error {
// 更新设备信息
if err := tx.Model(&device).Updates(map[string]interface{}{
"license_code": "",
"license_type": "",
"status": "inactive",
"expire_time": nil,
"max_uses": 0,
"duration": 0,
}).Error; err != nil {
return err
}
// 更新授权码状态
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", oldLicenseCode).
Updates(map[string]interface{}{
"status": "unused",
"used_by": "",
"used_at": nil,
}).Error; err != nil {
return err
}
// 记录设备日志
log := model.DeviceLog{
DeviceUID: uid,
Action: "unbind_license",
Message: fmt.Sprintf("解绑授权码: %s", oldLicenseCode),
Status: "success",
}
if err := tx.Create(&log).Error; err != nil {
return err
}
return nil
})
}
// GetDeviceLogs 获取设备日志
func (s *DeviceService) GetDeviceLogs(uid string, page, pageSize int) ([]model.DeviceLog, int64, error) {
var logs []model.DeviceLog
var total int64
query := s.db.Model(&model.DeviceLog{}).Where("device_uid = ?", uid)
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&logs).Error
return logs, total, err
}
// DashboardStats 仪表盘统计数据
type DashboardStats struct {
TotalDevices int64 `json:"total_devices"` // 设备总数
TotalLicenses int64 `json:"total_licenses"` // 授权码总数
TodayNew int64 `json:"today_new"` // 今日新增
OnlineDevices int64 `json:"online_devices"` // 在线设备
ActiveDevices int64 `json:"active_devices"` // 激活设备
ExpiredDevices int64 `json:"expired_devices"` // 过期设备
}
// GetDashboardStats 获取仪表盘统计数据
func (s *DeviceService) GetDashboardStats() (*DashboardStats, error) {
var stats DashboardStats
// 获取设备总数
s.db.Model(&model.Device{}).Count(&stats.TotalDevices)
// 获取授权码总数
s.db.Model(&model.LicenseCode{}).Count(&stats.TotalLicenses)
// 获取今日新增设备数
today := time.Now().Format("2006-01-02")
s.db.Model(&model.Device{}).Where("DATE(register_time) = ?", today).Count(&stats.TodayNew)
// 获取在线设备数最近30分钟内有活动的设备
thirtyMinutesAgo := time.Now().Add(-30 * time.Minute)
s.db.Model(&model.Device{}).Where("last_active_at > ?", thirtyMinutesAgo).Count(&stats.OnlineDevices)
// 获取激活设备数
s.db.Model(&model.Device{}).Where("status = ?", "active").Count(&stats.ActiveDevices)
// 获取过期设备数
s.db.Model(&model.Device{}).Where("status = ?", "expired").Count(&stats.ExpiredDevices)
return &stats, nil
}

View File

@@ -0,0 +1,298 @@
package service
import (
"licserver/internal/model"
"licserver/internal/utils"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDeviceService_RegisterDevice(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
// 创建测试设备型号
deviceModel := &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
Status: "active",
CreatedBy: 1,
}
err := db.Create(deviceModel).Error
assert.NoError(t, err)
// 创建测试授权码
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 1,
Remark: "test",
}, 1)
assert.NoError(t, err)
assert.Len(t, licenses, 1)
tests := []struct {
name string
input *DeviceRegisterInput
wantErr bool
}{
{
name: "正常注册设备",
input: &DeviceRegisterInput{
UID: "test-device-001",
DeviceModel: "test-model",
LicenseCode: licenses[0].Code,
},
wantErr: false,
},
{
name: "重复注册设备",
input: &DeviceRegisterInput{
UID: "test-device-001",
DeviceModel: "test-model",
LicenseCode: licenses[0].Code,
},
wantErr: true,
},
{
name: "使用无效授权码",
input: &DeviceRegisterInput{
UID: "test-device-002",
DeviceModel: "test-model",
LicenseCode: "invalid-code",
},
wantErr: true,
},
{
name: "使用不存在的设备型号",
input: &DeviceRegisterInput{
UID: "test-device-003",
DeviceModel: "non-existent-model",
LicenseCode: licenses[0].Code,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := deviceService.RegisterDevice(tt.input, "127.0.0.1")
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 验证设备是否正确注册
var device model.Device
err = db.Where("uid = ?", tt.input.UID).First(&device).Error
assert.NoError(t, err)
assert.Equal(t, deviceModel.DeviceType, device.DeviceType)
assert.Equal(t, "active", device.Status)
assert.Equal(t, tt.input.LicenseCode, device.LicenseCode)
})
}
}
func TestDeviceService_ValidateDevice(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
// 创建测试设备型号
deviceModel := &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
CreatedBy: 1,
}
err := db.Create(deviceModel).Error
assert.NoError(t, err)
// 创建测试授权码
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 1,
Remark: "test",
}, 1)
assert.NoError(t, err)
// 注册测试设备
device := &model.Device{
UID: "test-device-001",
DeviceType: deviceModel.DeviceType,
DeviceModel: deviceModel.ModelName,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "active",
LicenseCode: licenses[0].Code,
LicenseType: "time",
Duration: 30,
ExpireTime: time.Now().Add(30 * 24 * time.Hour),
}
err = db.Create(device).Error
assert.NoError(t, err)
tests := []struct {
name string
uid string
wantErr bool
}{
{
name: "验证正常设备",
uid: "test-device-001",
wantErr: false,
},
{
name: "验证不存在的设备",
uid: "non-existent-device",
wantErr: true,
},
{
name: "验证过期设备",
uid: "expired-device",
wantErr: true,
},
}
// 创建过期设备
expiredDevice := &model.Device{
UID: "expired-device",
DeviceType: deviceModel.DeviceType,
DeviceModel: deviceModel.ModelName,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "expired",
LicenseCode: "expired-license",
LicenseType: "time",
Duration: 30,
ExpireTime: time.Now().Add(-24 * time.Hour),
}
err = db.Create(expiredDevice).Error
assert.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := deviceService.ValidateDevice(tt.uid)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
})
}
}
func TestDeviceService_UpdateStartCount(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
// 创建测试设备型号
deviceModel := &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
CreatedBy: 1,
}
err := db.Create(deviceModel).Error
assert.NoError(t, err)
// 创建测试设备
device := &model.Device{
UID: "test-device-001",
DeviceType: deviceModel.DeviceType,
DeviceModel: deviceModel.ModelName,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "active",
LicenseType: "count",
MaxUses: 5,
StartCount: 0,
}
err = db.Create(device).Error
assert.NoError(t, err)
// 测试更新启动次数
for i := 1; i <= 5; i++ {
err = deviceService.UpdateStartCount(device.UID)
assert.NoError(t, err)
// 验证启动次数
var updatedDevice model.Device
err = db.First(&updatedDevice, "uid = ?", device.UID).Error
assert.NoError(t, err)
assert.Equal(t, i, updatedDevice.StartCount)
// 检查最后一次是否将状态更新为过期
if i == 5 {
assert.Equal(t, "expired", updatedDevice.Status)
} else {
assert.Equal(t, "active", updatedDevice.Status)
}
}
// 测试超出使用次数
err = deviceService.UpdateStartCount(device.UID)
assert.Error(t, err)
}
func TestDeviceService_CreateDeviceModel(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
tests := []struct {
name string
model *model.DeviceModel
wantErr bool
}{
{
name: "创建有效设备型号",
model: &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
Status: "active",
CreatedBy: 1,
},
wantErr: false,
},
{
name: "重复的设备型号",
model: &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
Status: "active",
CreatedBy: 1,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := deviceService.CreateDeviceModel(tt.model)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 验证创建的设备型号
var savedModel model.DeviceModel
err = db.Where("model_name = ?", tt.model.ModelName).First(&savedModel).Error
assert.NoError(t, err)
assert.Equal(t, tt.model.DeviceType, savedModel.DeviceType)
assert.Equal(t, tt.model.Company, savedModel.Company)
assert.Equal(t, tt.model.Status, savedModel.Status)
})
}
}

709
internal/service/license.go Normal file
View File

@@ -0,0 +1,709 @@
package service
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
"licserver/internal/model"
"gorm.io/gorm"
)
type LicenseService struct {
db *gorm.DB
}
func NewLicenseService(db *gorm.DB) *LicenseService {
return &LicenseService{db: db}
}
// 生成授权码
func generateLicenseCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// 创建授权码
type LicenseCreateInput struct {
LicenseType string `json:"license_type" binding:"required"`
Duration int `json:"duration"` // 时间授权的有效期(分钟)
MaxUses int `json:"max_uses"` // 次数授权的使用次数
Count int `json:"count" binding:"required,min=1"` // 生成数量
Remark string `json:"remark"` // 备注
BindCount int `json:"bind_count"` // 可绑定次数,默认为-1无限制
}
func (s *LicenseService) CreateLicenses(input *LicenseCreateInput, createdBy uint) ([]model.LicenseCode, error) {
// 验证参数
input.LicenseType = strings.ToLower(input.LicenseType) // 转为小写
switch input.LicenseType {
case "time":
if input.Duration <= 0 {
return nil, errors.New("时间授权必须指定有效期")
}
case "count":
if input.MaxUses <= 0 {
return nil, errors.New("次数授权必须指定使用次数")
}
case "permanent":
// 永久授权不需要额外参数
default:
return nil, errors.New("无效的授权类型")
}
// 如果未指定绑定次数,设置为默认值-1
if input.BindCount == 0 {
input.BindCount = -1
}
// 生成批次号
batchNo := time.Now().Format("20060102150405")
licenses := make([]model.LicenseCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
code, err := generateLicenseCode()
if err != nil {
return nil, err
}
license := model.LicenseCode{
Code: code,
LicenseType: input.LicenseType,
Duration: input.Duration,
MaxUses: input.MaxUses,
Status: "unused",
CreatedBy: createdBy,
BatchNo: batchNo,
Remark: input.Remark,
BindCount: input.BindCount,
}
licenses = append(licenses, license)
}
// 批量创建授权码
if err := s.db.Create(&licenses).Error; err != nil {
return nil, err
}
return licenses, nil
}
// 验证并使用授权码
func (s *LicenseService) UseLicense(code, deviceUID, ip string) (*model.LicenseCode, error) {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return nil, errors.New("授权码不存在")
}
// 检查授权码状态
if license.Status != "unused" {
return nil, errors.New("授权码已被使用")
}
// 检查绑定次数
if license.BindCount == 0 {
return nil, errors.New("授权码已达到最大绑定次数限制")
}
// 更新授权码状态
updates := map[string]interface{}{
"status": "used",
"used_by": deviceUID,
"used_at": time.Now(),
}
// 如果不是无限制,减少绑定次数
if license.BindCount > 0 {
updates["bind_count"] = license.BindCount - 1
}
if err := s.db.Model(&license).Updates(updates).Error; err != nil {
return nil, err
}
// 记录使用日志
log := model.LicenseLog{
LicenseID: license.ID,
DeviceUID: deviceUID,
Action: "use",
IP: ip,
Status: "success",
Message: fmt.Sprintf("设备 %s 使用授权码", deviceUID),
}
if err := s.db.Create(&log).Error; err != nil {
return nil, err
}
return &license, nil
}
// 获取授权码列表
func (s *LicenseService) GetLicenses(status, licenseType, batchNo string, page, pageSize int) ([]model.LicenseCode, int64, error) {
var licenses []model.LicenseCode
var total int64
query := s.db.Model(&model.LicenseCode{})
if status != "" {
query = query.Where("status = ?", strings.ToLower(status))
}
if licenseType != "" {
query = query.Where("license_type = ?", strings.ToLower(licenseType))
}
if batchNo != "" {
query = query.Where("batch_no = ?", batchNo)
}
// 获取所有符合条件的授权码
var allLicenses []model.LicenseCode
if err := query.Find(&allLicenses).Error; err != nil {
return nil, 0, err
}
// 检查每个授权码的有效性
for i := range allLicenses {
if allLicenses[i].Status == "used" {
if err := s.CheckLicenseValidity(allLicenses[i].Code); err != nil {
// 如果检查失败,更新状态
s.db.Model(&allLicenses[i]).Update("status", "expired")
allLicenses[i].Status = "expired"
}
}
}
total = int64(len(allLicenses))
// 分页
if page > 0 && pageSize > 0 {
start := (page - 1) * pageSize
end := start + pageSize
if start < len(allLicenses) {
if end > len(allLicenses) {
end = len(allLicenses)
}
licenses = allLicenses[start:end]
}
} else {
licenses = allLicenses
}
return licenses, total, nil
}
// 获取授权码使用日志
func (s *LicenseService) GetLicenseLogs(licenseID uint, page, pageSize int) ([]model.LicenseLog, int64, error) {
var logs []model.LicenseLog
var total int64
query := s.db.Model(&model.LicenseLog{}).Where("license_id = ?", licenseID)
query.Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&logs).Error
return logs, total, err
}
// ExportLogs 导出授权码日志
func (s *LicenseService) ExportLogs(licenseID uint) ([]byte, error) {
logs, _, err := s.GetLicenseLogs(licenseID, 0, 0) // 获取所有日志
if err != nil {
return nil, err
}
// 创建CSV内容
var content strings.Builder
content.WriteString("操作类,设备UID,IP地址,状态,详细信息,时间\n")
for _, log := range logs {
// 转换操作类型
action := map[string]string{
"create": "创建",
"use": "使用",
"verify": "验证",
}[log.Action]
// 转换状态
status := map[string]string{
"success": "成功",
"failed": "失败",
}[log.Status]
// 写入一行记录
content.WriteString(fmt.Sprintf("%s,%s,%s,%s,%s,%s\n",
action,
log.DeviceUID,
log.IP,
status,
log.Message,
log.CreatedAt.Format("2006-01-02 15:04:05"),
))
}
return []byte(content.String()), nil
}
// 撤销授权码
func (s *LicenseService) RevokeLicense(code string, userID uint) error {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return errors.New("授权码不存在")
}
// 检查权限
if license.CreatedBy != userID {
return errors.New("无权操作此授权码")
}
// 更新状态
if err := s.db.Model(&license).Update("status", "revoked").Error; err != nil {
return err
}
// 记录日志
log := model.LicenseLog{
LicenseID: license.ID,
Action: "revoke",
Status: "success",
Message: "授权码已撤销",
}
s.db.Create(&log)
return nil
}
// 批量撤销授权码
func (s *LicenseService) RevokeLicenses(codes []string, userID uint) error {
return s.db.Transaction(func(tx *gorm.DB) error {
for _, code := range codes {
var license model.LicenseCode
if err := tx.Where("code = ?", code).First(&license).Error; err != nil {
continue
}
// 检查权限
if license.CreatedBy != userID {
continue
}
// 更新状态
if err := tx.Model(&license).Update("status", "revoked").Error; err != nil {
return err
}
// 记录日志
log := model.LicenseLog{
LicenseID: license.ID,
Action: "revoke",
Status: "success",
Message: "授权码已撤销",
}
tx.Create(&log)
}
return nil
})
}
// 验证授权码
func (s *LicenseService) ValidateLicense(code string) (*model.LicenseCode, error) {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return nil, errors.New("无效的授权码")
}
// 检查状态
if license.Status != "unused" {
return nil, errors.New("授权码已被使用或已撤销")
}
return &license, nil
}
// 导出授权码
func (s *LicenseService) ExportLicenses(codes []string) ([]byte, error) {
var licenses []model.LicenseCode
if err := s.db.Where("code IN ?", codes).Find(&licenses).Error; err != nil {
return nil, err
}
// 创建CSV内容
var content strings.Builder
content.WriteString("授权码,授权类型,有效期(天),使用次数,状态,使用设备,使用时间,批次号,备注\n")
for _, license := range licenses {
// 转换授权类型
licenseType := map[string]string{
"time": "时间授权",
"count": "次数授权",
"permanent": "永久授权",
}[license.LicenseType]
// 转换状态
status := map[string]string{
"unused": "未使用",
"used": "已使用",
"revoked": "已撤销",
}[license.Status]
// 写入一行记录
content.WriteString(fmt.Sprintf("%s,%s,%d,%d,%s,%s,%s,%s,%s\n",
license.Code,
licenseType,
license.Duration,
license.MaxUses,
status,
license.UsedBy,
license.UsedAt.Format("2006-01-02 15:04:05"),
license.BatchNo,
license.Remark,
))
}
return []byte(content.String()), nil
}
// 获取授权码统计信息
func (s *LicenseService) GetLicenseStats() (map[string]interface{}, error) {
var stats struct {
Total int64
Unused int64
Used int64
Revoked int64
Today int64
ThisWeek int64
ThisMonth int64
}
// 获取总数
s.db.Model(&model.LicenseCode{}).Count(&stats.Total)
// 获取各状态数量
s.db.Model(&model.LicenseCode{}).Where("status = ?", "unused").Count(&stats.Unused)
s.db.Model(&model.LicenseCode{}).Where("status = ?", "used").Count(&stats.Used)
s.db.Model(&model.LicenseCode{}).Where("status = ?", "revoked").Count(&stats.Revoked)
// 获取今日创建数量
today := time.Now().Format("2006-01-02")
s.db.Model(&model.LicenseCode{}).Where("DATE(created_at) = ?", today).Count(&stats.Today)
// 获取本周创建数量
weekStart := time.Now().AddDate(0, 0, -int(time.Now().Weekday()))
s.db.Model(&model.LicenseCode{}).Where("created_at >= ?", weekStart).Count(&stats.ThisWeek)
// 获取本月创建数量
monthStart := time.Now().Format("2006-01") + "-01"
s.db.Model(&model.LicenseCode{}).Where("created_at >= ?", monthStart).Count(&stats.ThisMonth)
return map[string]interface{}{
"total": stats.Total,
"unused": stats.Unused,
"used": stats.Used,
"revoked": stats.Revoked,
"today": stats.Today,
"this_week": stats.ThisWeek,
"this_month": stats.ThisMonth,
}, nil
}
// 添加检查授权码有效性的方法
func (s *LicenseService) CheckLicenseValidity(code string) error {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return errors.New("授权码不存在")
}
if license.Status != "unused" && license.Status != "used" {
return errors.New("授权码已被撤销或过期")
}
// 检查授权类型特定的限制
switch license.LicenseType {
case "time":
// 计算过期时间
expireTime := license.UsedAt.Add(time.Duration(license.Duration) * time.Minute)
if time.Now().After(expireTime) {
// 更新状态为过期
s.db.Model(&license).Update("status", "expired")
return errors.New("授权码已过期")
}
case "count":
if license.UsedCount >= license.MaxUses {
// 更新状态为过期
s.db.Model(&license).Update("status", "expired")
return errors.New("授权码使用次数已达上限")
}
}
return nil
}
// GetLicenseByCode 通过授权码获取授权信息
func (s *LicenseService) GetLicenseByCode(code string) (*model.LicenseCode, error) {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("授权码不存在")
}
return nil, err
}
return &license, nil
}

View File

@@ -0,0 +1,153 @@
package service
import (
"licserver/internal/utils"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLicenseService_CreateLicenses(t *testing.T) {
db := utils.TestDB(t)
service := NewLicenseService(db)
tests := []struct {
name string
input *LicenseCreateInput
wantErr bool
}{
{
name: "创建时间授权码",
input: &LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 5,
Remark: "test time license",
},
wantErr: false,
},
{
name: "创建次数授权码",
input: &LicenseCreateInput{
LicenseType: "count",
MaxUses: 100,
Count: 3,
Remark: "test count license",
},
wantErr: false,
},
{
name: "创建永久授权码",
input: &LicenseCreateInput{
LicenseType: "permanent",
Count: 1,
Remark: "test permanent license",
},
wantErr: false,
},
{
name: "无效的授权类型",
input: &LicenseCreateInput{
LicenseType: "invalid",
Count: 1,
},
wantErr: true,
},
{
name: "时间授权无有效期",
input: &LicenseCreateInput{
LicenseType: "time",
Duration: 0,
Count: 1,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
licenses, err := service.CreateLicenses(tt.input, 1)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Len(t, licenses, tt.input.Count)
for _, license := range licenses {
assert.Equal(t, tt.input.LicenseType, license.LicenseType)
assert.Equal(t, "unused", license.Status)
assert.Equal(t, tt.input.Remark, license.Remark)
if tt.input.LicenseType == "time" {
assert.Equal(t, tt.input.Duration, license.Duration)
} else if tt.input.LicenseType == "count" {
assert.Equal(t, tt.input.MaxUses, license.MaxUses)
}
}
})
}
}
func TestLicenseService_UseLicense(t *testing.T) {
db := utils.TestDB(t)
service := NewLicenseService(db)
// 创建测试授权码
input := &LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 1,
Remark: "test",
}
licenses, err := service.CreateLicenses(input, 1)
assert.NoError(t, err)
assert.Len(t, licenses, 1)
tests := []struct {
name string
code string
deviceUID string
ip string
wantErr bool
}{
{
name: "正常使用授权码",
code: licenses[0].Code,
deviceUID: "test-device-001",
ip: "127.0.0.1",
wantErr: false,
},
{
name: "使用不存在的授权码",
code: "invalid-code",
deviceUID: "test-device-002",
ip: "127.0.0.1",
wantErr: true,
},
{
name: "重复使用授权码",
code: licenses[0].Code,
deviceUID: "test-device-003",
ip: "127.0.0.1",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
license, err := service.UseLicense(tt.code, tt.deviceUID, tt.ip)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, "used", license.Status)
assert.Equal(t, tt.deviceUID, license.UsedBy)
assert.NotZero(t, license.UsedAt)
})
}
}

377
internal/service/monitor.go Normal file
View File

@@ -0,0 +1,377 @@
package service
import (
"licserver/internal/model"
"runtime"
"time"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/host"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/net"
"github.com/shirou/gopsutil/v3/process"
"gorm.io/gorm"
)
type MonitorService struct {
db *gorm.DB
startTime time.Time
}
func NewMonitorService(db *gorm.DB) *MonitorService {
return &MonitorService{
db: db,
startTime: time.Now(),
}
}
func (s *MonitorService) GetSystemStatus() (*model.SystemStatus, error) {
status := &model.SystemStatus{}
// CPU信息
if err := s.getCPUInfo(status); err != nil {
return nil, err
}
// 内存信息
if err := s.getMemoryInfo(status); err != nil {
return nil, err
}
// 磁盘信息
if err := s.getDiskInfo(status); err != nil {
return nil, err
}
// 网络信息
if err := s.getNetworkInfo(status); err != nil {
return nil, err
}
// 进程信息
if err := s.getProcessInfo(status); err != nil {
return nil, err
}
// 主机信息
if err := s.getHostInfo(status); err != nil {
return nil, err
}
// 系统信息
s.getSystemInfo(status)
return status, nil
}
func (s *MonitorService) getCPUInfo(status *model.SystemStatus) error {
cpuPercent, err := cpu.Percent(time.Second, false)
if err != nil {
return err
}
// Windows 系统不支持 LoadAvg设置为默认值
status.CPU.LoadAvg = []float64{0, 0, 0}
cpuInfo, err := cpu.Info()
if err != nil {
return err
}
status.CPU.Usage = cpuPercent[0]
status.CPU.CoreCount = runtime.NumCPU()
if len(cpuInfo) > 0 {
status.CPU.ModelName = cpuInfo[0].ModelName
status.CPU.MHz = cpuInfo[0].Mhz
}
return nil
}
func (s *MonitorService) getMemoryInfo(status *model.SystemStatus) error {
memInfo, err := mem.VirtualMemory()
if err != nil {
return err
}
swapInfo, err := mem.SwapMemory()
if err != nil {
return err
}
status.Memory.Total = memInfo.Total
status.Memory.Used = memInfo.Used
status.Memory.Free = memInfo.Free
status.Memory.UsageRate = memInfo.UsedPercent
status.Memory.SwapTotal = swapInfo.Total
status.Memory.SwapUsed = swapInfo.Used
status.Memory.SwapFree = swapInfo.Free
status.Memory.SwapUsageRate = swapInfo.UsedPercent
return nil
}
func (s *MonitorService) getDiskInfo(status *model.SystemStatus) error {
partitions, err := disk.Partitions(true)
if err != nil {
return err
}
status.Disk.Partitions = make([]model.DiskPartition, 0)
for _, partition := range partitions {
usage, err := disk.Usage(partition.Mountpoint)
if err != nil {
continue
}
status.Disk.Partitions = append(status.Disk.Partitions, model.DiskPartition{
Device: partition.Device,
Mountpoint: partition.Mountpoint,
Fstype: partition.Fstype,
Total: usage.Total,
Used: usage.Used,
Free: usage.Free,
UsageRate: usage.UsedPercent,
})
}
return nil
}
func (s *MonitorService) getNetworkInfo(status *model.SystemStatus) error {
interfaces, err := net.Interfaces()
if err != nil {
return err
}
ioCounters, err := net.IOCounters(true)
if err != nil {
return err
}
status.Network.Interfaces = make([]model.NetworkInterface, 0)
for _, iface := range interfaces {
var counter net.IOCountersStat
for _, io := range ioCounters {
if io.Name == iface.Name {
counter = io
break
}
}
// 获取接口的地址列表
addrs := make([]string, 0)
for _, addr := range iface.Addrs {
addrs = append(addrs, addr.String())
}
status.Network.Interfaces = append(status.Network.Interfaces, model.NetworkInterface{
Name: iface.Name,
BytesSent: counter.BytesSent,
BytesRecv: counter.BytesRecv,
PacketsSent: counter.PacketsSent,
PacketsRecv: counter.PacketsRecv,
Addrs: addrs,
})
}
return nil
}
func (s *MonitorService) getProcessInfo(status *model.SystemStatus) error {
processes, err := process.Processes()
if err != nil {
return err
}
status.Process.Total = len(processes)
status.Process.List = make([]model.ProcessInfo, 0)
for i := 0; i < 10 && i < len(processes); i++ {
p := processes[i]
name, _ := p.Name()
cpu, _ := p.CPUPercent()
mem, _ := p.MemoryPercent()
created, _ := p.CreateTime()
status.Process.List = append(status.Process.List, model.ProcessInfo{
PID: int(p.Pid),
Name: name,
CPU: cpu,
Memory: float64(mem),
Created: created,
})
}
return nil
}
func (s *MonitorService) getHostInfo(status *model.SystemStatus) error {
info, err := host.Info()
if err != nil {
return err
}
status.Host.Hostname = info.Hostname
status.Host.OS = info.OS
status.Host.Platform = info.Platform
status.Host.PlatformVersion = info.PlatformVersion
status.Host.KernelVersion = info.KernelVersion
status.Host.BootTime = time.Unix(int64(info.BootTime), 0)
return nil
}
func (s *MonitorService) getSystemInfo(status *model.SystemStatus) {
status.System.Uptime = time.Since(s.startTime)
status.System.CurrentTime = time.Now()
var activeUsers int64
s.db.Model(&model.User{}).Count(&activeUsers)
status.System.ActiveUsers = int(activeUsers)
var totalDevices int64
s.db.Model(&model.Device{}).Count(&totalDevices)
status.System.TotalDevices = int(totalDevices)
}

47
internal/service/site.go Normal file
View File

@@ -0,0 +1,47 @@
package service
import (
"errors"
"licserver/internal/utils"
"sync"
)
type SiteService struct {
config *utils.Config
mu sync.RWMutex
}
func NewSiteService(config *utils.Config) *SiteService {
return &SiteService{
config: config,
}
}
func (s *SiteService) GetSettings() utils.SiteConfig {
s.mu.RLock()
defer s.mu.RUnlock()
return s.config.Site
}
func (s *SiteService) UpdateSettings(settings utils.SiteConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
// 更新内存中的配置
s.config.Site = settings
// 持久化配置到文件
return utils.SaveConfig(s.config, "system", "更新站点设置")
}
// ValidateSettings 验证站点设置
func (s *SiteService) ValidateSettings(settings utils.SiteConfig) error {
// 这里可以添加更多的验证逻辑
if settings.Title == "" {
return errors.New("站点标题不能为空")
}
if settings.BaseURL == "" {
return errors.New("基础URL不能为空")
}
return nil
}

188
internal/service/token.go Normal file
View File

@@ -0,0 +1,188 @@
package service
import (
"crypto/rand"
"encoding/hex"
"errors"
"strings"
"time"
"licserver/internal/model"
"gorm.io/gorm"
)
type TokenService struct {
db *gorm.DB
}
func NewTokenService(db *gorm.DB) *TokenService {
return &TokenService{db: db}
}
// 生成访问令牌
func generateToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// 创建访问令牌
func (s *TokenService) CreateToken(deviceUID, tokenType string, expireTime time.Time, ipList []string, createdBy uint) (*model.AccessToken, error) {
// 验证设备是否存在
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return nil, errors.New("设备不存在")
}
// 生成令牌
token, err := generateToken()
if err != nil {
return nil, err
}
// 创建令牌记录
accessToken := &model.AccessToken{
Token: token,
DeviceUID: deviceUID,
Type: tokenType,
Status: "active",
ExpireTime: expireTime,
IPList: strings.Join(ipList, ","),
CreatedBy: createdBy,
}
if err := s.db.Create(accessToken).Error; err != nil {
return nil, err
}
// 记录日志
tokenLog := &model.TokenLog{
TokenID: accessToken.ID,
Action: "create",
Status: "success",
Message: "创建访问令牌",
}
s.db.Create(tokenLog)
return accessToken, nil
}
// 验证令牌
func (s *TokenService) ValidateToken(token, ip string) (*model.AccessToken, error) {
var accessToken model.AccessToken
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
return nil, errors.New("无效的令牌")
}
// 检查令牌状态
if accessToken.Status != "active" {
return nil, errors.New("令牌已被撤销")
}
// 检查过期时间
if time.Now().After(accessToken.ExpireTime) {
return nil, errors.New("令牌已过期")
}
// 检查IP限制
if accessToken.IPList != "" {
allowedIPs := strings.Split(accessToken.IPList, ",")
allowed := false
for _, allowedIP := range allowedIPs {
if allowedIP == ip {
allowed = true
break
}
}
if !allowed {
return nil, errors.New("IP地址不允许访问")
}
}
// 更新使用记录
s.db.Model(&accessToken).Updates(map[string]interface{}{
"last_used": time.Now(),
"usage_count": gorm.Expr("usage_count + 1"),
})
// 记录日志
tokenLog := &model.TokenLog{
TokenID: accessToken.ID,
Action: "use",
IP: ip,
Status: "success",
}
s.db.Create(tokenLog)
return &accessToken, nil
}
// 撤销令牌
func (s *TokenService) RevokeToken(token string, userID uint) error {
var accessToken model.AccessToken
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
return errors.New("令牌不存在")
}
// 检查权限
if accessToken.CreatedBy != userID {
return errors.New("无权操作此令牌")
}
// 更新状态
if err := s.db.Model(&accessToken).Update("status", "revoked").Error; err != nil {
return err
}
// 记录日志
tokenLog := &model.TokenLog{
TokenID: accessToken.ID,
Action: "revoke",
Status: "success",
Message: "撤销访问令牌",
}
s.db.Create(tokenLog)
return nil
}
// 获取令牌列表
func (s *TokenService) GetTokens(deviceUID string, page, pageSize int) ([]model.AccessToken, int64, error) {
var tokens []model.AccessToken
var total int64
query := s.db.Model(&model.AccessToken{})
if deviceUID != "" {
query = query.Where("device_uid = ?", deviceUID)
}
query.Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&tokens).Error
return tokens, total, err
}
// 获取令牌日志
func (s *TokenService) GetTokenLogs(tokenID uint, page, pageSize int) ([]model.TokenLog, int64, error) {
var logs []model.TokenLog
var total int64
query := s.db.Model(&model.TokenLog{}).Where("token_id = ?", tokenID)
query.Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&logs).Error
return logs, total, err
}

463
internal/service/upload.go Normal file
View File

@@ -0,0 +1,463 @@
package service
import (
"errors"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"sort"
"strings"
"time"
"licserver/internal/model"
"licserver/internal/utils"
"github.com/google/uuid"
"gorm.io/gorm"
"crypto/sha256"
"encoding/hex"
)
type UploadService struct {
db *gorm.DB
config *utils.Config
}
func NewUploadService(db *gorm.DB, config *utils.Config) *UploadService {
return &UploadService{
db: db,
config: config,
}
}
func (s *UploadService) UploadFile(file *multipart.FileHeader, userID uint, deviceUID, description string) (*model.FileUpload, error) {
// 生成唯一文件名
ext := filepath.Ext(file.Filename)
uniqueID := uuid.New().String()
fileName := fmt.Sprintf("%s%s", uniqueID, ext)
filePath := filepath.Join(s.config.Upload.Path, fileName)
// 确保上传目录存在
if err := os.MkdirAll(s.config.Upload.Path, 0755); err != nil {
return nil, err
}
// 保存文件
src, err := file.Open()
if err != nil {
return nil, err
}
defer src.Close()
dst, err := os.Create(filePath)
if err != nil {
return nil, err
}
defer dst.Close()
if _, err = io.Copy(dst, src); err != nil {
return nil, err
}
// 如果提供了设备UID获取设备型号
var deviceModel string
if deviceUID != "" {
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err == nil {
deviceModel = device.DeviceModel
}
}
// 创建数据库记录
upload := &model.FileUpload{
FileName: file.Filename,
FilePath: filePath,
FileSize: file.Size,
FileType: strings.ToLower(ext),
UploadedBy: userID,
DeviceUID: deviceUID,
DeviceModel: deviceModel,
Description: description,
}
if err := s.db.Create(upload).Error; err != nil {
os.Remove(filePath)
return nil, err
}
return upload, nil
}
func (s *UploadService) DownloadFile(id uint) (*model.FileUpload, error) {
var file model.FileUpload
if err := s.db.First(&file, id).Error; err != nil {
return nil, err
}
return &file, nil
}
func (s *UploadService) DeleteFile(id uint, userID uint) error {
var file model.FileUpload
if err := s.db.First(&file, id).Error; err != nil {
return err
}
if file.UploadedBy != userID {
return errors.New("无权删除此文件")
}
if err := os.Remove(file.FilePath); err != nil && !os.IsNotExist(err) {
return err
}
return s.db.Delete(&file).Error
}
func (s *UploadService) GetDeviceFiles(deviceUID string) ([]model.FileUpload, error) {
var files []model.FileUpload
err := s.db.Where("device_uid = ?", deviceUID).Find(&files).Error
return files, err
}
func (s *UploadService) UploadChunk(
file *multipart.FileHeader,
fileHash string,
chunkNumber int,
totalChunks int,
totalSize int64,
filename string,
userID uint,
deviceUID string,
) error {
// 创建分片存储目录
chunkDir := filepath.Join(s.config.Upload.Path, "chunks", fileHash)
if err := os.MkdirAll(chunkDir, 0755); err != nil {
return err
}
// 保存分片文件
chunkPath := filepath.Join(chunkDir, fmt.Sprintf("%d", chunkNumber))
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
dst, err := os.Create(chunkPath)
if err != nil {
return err
}
defer dst.Close()
if _, err = io.Copy(dst, src); err != nil {
return err
}
// 记录分片信息
chunk := model.UploadChunk{
FileHash: fileHash,
ChunkNumber: chunkNumber,
ChunkSize: file.Size,
ChunkPath: chunkPath,
TotalChunks: totalChunks,
TotalSize: totalSize,
Filename: filename,
FileType: strings.ToLower(filepath.Ext(filename)),
UploadedBy: userID,
DeviceUID: deviceUID,
}
return s.db.Create(&chunk).Error
}
func (s *UploadService) CheckUploadStatus(fileHash string) (bool, error) {
var chunks []model.UploadChunk
if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
return false, err
}
if len(chunks) == 0 {
return false, nil
}
totalChunks := chunks[0].TotalChunks
return len(chunks) == totalChunks, nil
}
func (s *UploadService) MergeChunks(fileHash string) (*model.FileUpload, error) {
var chunks []model.UploadChunk
if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
return nil, err
}
if len(chunks) == 0 {
return nil, errors.New("未找到文件分片")
}
if len(chunks) != chunks[0].TotalChunks {
return nil, errors.New("文件分片不完整")
}
// 按分片序号排序
sort.Slice(chunks, func(i, j int) bool {
return chunks[i].ChunkNumber < chunks[j].ChunkNumber
})
// 创建最终文件
finalPath := filepath.Join(s.config.Upload.Path, fmt.Sprintf("%s%s", uuid.New().String(), chunks[0].FileType))
finalFile, err := os.Create(finalPath)
if err != nil {
return nil, err
}
defer finalFile.Close()
// 合并分片
hash := sha256.New()
for _, chunk := range chunks {
chunkFile, err := os.Open(chunk.ChunkPath)
if err != nil {
return nil, err
}
if _, err = io.Copy(finalFile, chunkFile); err != nil {
chunkFile.Close()
return nil, err
}
if _, err = io.Copy(hash, chunkFile); err != nil {
chunkFile.Close()
return nil, err
}
chunkFile.Close()
os.Remove(chunk.ChunkPath) // 删除已合并的分片
}
// 验证文件哈希
if hex.EncodeToString(hash.Sum(nil)) != fileHash {
os.Remove(finalPath)
return nil, errors.New("文件哈希验证失败")
}
// 创建文件记录
upload := &model.FileUpload{
FileName: chunks[0].Filename,
FilePath: finalPath,
FileSize: chunks[0].TotalSize,
FileType: chunks[0].FileType,
UploadedBy: chunks[0].UploadedBy,
DeviceUID: chunks[0].DeviceUID,
}
if err := s.db.Create(upload).Error; err != nil {
os.Remove(finalPath)
return nil, err
}
// 清理分片记录
s.db.Where("file_hash = ?", fileHash).Delete(&model.UploadChunk{})
os.RemoveAll(filepath.Dir(chunks[0].ChunkPath))
return upload, nil
}
func (s *UploadService) CleanupExpiredChunks() error {
expireTime := time.Now().Add(-24 * time.Hour)
var expiredChunks []model.UploadChunk
if err := s.db.Where("completed = ? AND created_at < ?", false, expireTime).Find(&expiredChunks).Error; err != nil {
return err
}
for _, chunk := range expiredChunks {
os.Remove(chunk.ChunkPath)
if len(chunk.ChunkPath) > 0 {
chunkDir := filepath.Dir(chunk.ChunkPath)
os.RemoveAll(chunkDir)
}
}
return s.db.Unscoped().Where("completed = ? AND created_at < ?", false, expireTime).Delete(&model.UploadChunk{}).Error
}

361
internal/service/user.go Normal file
View File

@@ -0,0 +1,361 @@
package service
import (
"errors"
"fmt"
"time"
"licserver/internal/model"
"licserver/internal/utils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type UserService struct {
db *gorm.DB
config *utils.Config
captchaService *CaptchaService
}
type UserProfile struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Role string `json:"role"`
}
func NewUserService(db *gorm.DB, config *utils.Config) *UserService {
return &UserService{
db: db,
config: config,
captchaService: NewCaptchaService(db, &config.Email),
}
}
func (s *UserService) Register(username, password, email, captcha string) error {
// 验证验证码
if err := s.captchaService.VerifyCaptcha(email, "register", captcha); err != nil {
return err
}
// 检查用户名是否已存在
var count int64
s.db.Model(&model.User{}).Where("username = ?", username).Count(&count)
if count > 0 {
return errors.New("用户名已存在")
}
// 检查邮箱是否已存在
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
if count > 0 {
return errors.New("邮箱已被注册")
}
// 原有的注册逻辑
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user := model.User{
Username: username,
Password: string(hashedPassword),
Email: email,
Role: "user",
}
return s.db.Create(&user).Error
}
func (s *UserService) Login(username, password string) (string, error) {
var user model.User
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", errors.New("用户不存在")
}
return "", err
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
return "", errors.New("密码错误")
}
// 生成 JWT token
token, err := utils.GenerateToken(user.ID, user.Username, user.Role, &s.config.JWT)
if err != nil {
return "", err
}
// 更新最后登录时间
s.db.Model(&user).Update("last_login", gorm.Expr("CURRENT_TIMESTAMP"))
return token, nil
}
func (s *UserService) GetUserByID(id uint) (*UserProfile, error) {
var user model.User
if err := s.db.First(&user, id).Error; err != nil {
return nil, err
}
return &UserProfile{
ID: user.ID,
Username: user.Username,
Email: user.Email,
Role: user.Role,
}, nil
}
func (s *UserService) UpdateProfile(userID uint, email string) error {
// 检查邮箱是否被其他用户使用
var count int64
if err := s.db.Model(&model.User{}).Where("email = ? AND id != ?", email, userID).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("邮箱已被其他用户使用")
}
// 更新用户信息
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("email", email).Error
}
func (s *UserService) ChangePassword(userID uint, oldPassword, newPassword string) error {
var user model.User
if err := s.db.First(&user, userID).Error; err != nil {
return err
}
// 验证旧密码
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(oldPassword)); err != nil {
return errors.New("旧密码错误")
}
// 加密新密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
return s.db.Model(&user).Update("password", string(hashedPassword)).Error
}
func (s *UserService) ResetPassword(email, captcha string) error {
// 验证验证码
if err := s.captchaService.VerifyCaptcha(email, "reset", captcha); err != nil {
return err
}
// 原有的重置密码逻辑
var user model.User
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
return errors.New("邮箱不存在")
}
// 生成重置令牌
token, err := utils.GenerateResetToken()
if err != nil {
return err
}
// 保存重置令牌
resetToken := model.PasswordResetToken{
UserID: user.ID,
Token: token,
ExpiresAt: time.Now().Add(24 * time.Hour),
Used: false,
}
if err := s.db.Create(&resetToken).Error; err != nil {
return err
}
// 发送重置邮件
emailService := utils.NewEmailService(&s.config.Email)
resetLink := fmt.Sprintf("http://localhost:%s/reset-password?token=%s", s.config.Server.Port, token)
emailBody := fmt.Sprintf(`
<h3>密码重置</h3>
<p>您好,%s</p>
<p>请点击以下链接重置您的密码:</p>
<p><a href="%s">重置密码</a></p>
<p>此链接将在24小时后失效。</p>
<p>如果您没有请求重置密码,请忽略此邮件。</p>
`, user.Username, resetLink)
return emailService.SendEmail(user.Email, "密码重置", emailBody)
}
func (s *UserService) ValidateResetToken(token string) (*model.User, error) {
var resetToken model.PasswordResetToken
if err := s.db.Where("token = ? AND used = ? AND expires_at > ?",
token, false, time.Now()).First(&resetToken).Error; err != nil {
return nil, errors.New("无效或已过期的重置令牌")
}
var user model.User
if err := s.db.First(&user, resetToken.UserID).Error; err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) ResetPasswordWithToken(token, newPassword string) error {
user, err := s.ValidateResetToken(token)
if err != nil {
return err
}
// 更新密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
// 使用事务确保原子性
return s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&user).Update("password", string(hashedPassword)).Error; err != nil {
return err
}
// 标记令牌为已使用
if err := tx.Model(&model.PasswordResetToken{}).
Where("token = ?", token).
Update("used", true).Error; err != nil {
return err
}
return nil
})
}
func (s *UserService) SendRegisterCaptcha(email string) error {
// 检查邮箱是否已被注册
var count int64
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
if count > 0 {
return errors.New("邮箱已被注册")
}
return s.captchaService.SendEmailCaptcha(email, "register")
}
func (s *UserService) SendResetPasswordCaptcha(email string) error {
var user model.User
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
return errors.New("邮箱不存在")
}
return s.captchaService.SendEmailCaptcha(email, "reset")
}
func (s *UserService) GetCaptchaService() *CaptchaService {
return s.captchaService
}
// GetUsers 获取用户列表
func (s *UserService) GetUsers(username, role string, page, pageSize int) ([]UserProfile, int64, error) {
var users []model.User
var total int64
var profiles []UserProfile
query := s.db.Model(&model.User{})
if username != "" {
query = query.Where("username LIKE ?", "%"+username+"%")
}
if role != "" {
query = query.Where("role = ?", role)
}
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
if err := query.Find(&users).Error; err != nil {
return nil, 0, err
}
// 转换为 UserProfile
for _, user := range users {
profiles = append(profiles, UserProfile{
ID: user.ID,
Username: user.Username,
Email: user.Email,
Role: user.Role,
})
}
return profiles, total, nil
}
// CreateUser 创建新用户
func (s *UserService) CreateUser(username, password, email, role string) error {
// 检查用户名是否已存在
var count int64
s.db.Model(&model.User{}).Where("username = ?", username).Count(&count)
if count > 0 {
return errors.New("用户名已存在")
}
// 检查邮箱是否已存在
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
if count > 0 {
return errors.New("邮箱已被注册")
}
// 加密密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user := model.User{
Username: username,
Password: string(hashedPassword),
Email: email,
Role: role,
}
return s.db.Create(&user).Error
}
// UpdateUser 更新用户信息
func (s *UserService) UpdateUser(id uint, email, role string) error {
// 检查邮箱是否被其他用户使用
var count int64
s.db.Model(&model.User{}).Where("email = ? AND id != ?", email, id).Count(&count)
if count > 0 {
return errors.New("邮箱已被其他用户使用")
}
return s.db.Model(&model.User{}).Where("id = ?", id).Updates(map[string]interface{}{
"email": email,
"role": role,
}).Error
}
// DeleteUser 删除用户
func (s *UserService) DeleteUser(id uint) error {
// 检查是否为最后一个管理员
var adminCount int64
s.db.Model(&model.User{}).Where("role = ?", "admin").Count(&adminCount)
var user model.User
if err := s.db.First(&user, id).Error; err != nil {
return err
}
if user.Role == "admin" && adminCount <= 1 {
return errors.New("不能删除最后一个管理员")
}
return s.db.Delete(&model.User{}, id).Error
}

View File

@@ -0,0 +1,61 @@
package service
import (
"licserver/internal/model"
"licserver/internal/utils"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUserService_Register(t *testing.T) {
db := utils.TestDB(t)
config := utils.TestConfig()
userService := NewUserService(db, config)
// 测试正常注册
err := userService.Register("testuser", "password123", "test@example.com", "123456")
assert.NoError(t, err)
var user model.User
err = db.Where("username = ?", "testuser").First(&user).Error
assert.NoError(t, err)
assert.Equal(t, "testuser", user.Username)
assert.Equal(t, "test@example.com", user.Email)
assert.Equal(t, "user", user.Role)
// 测试重复用户名
err = userService.Register("testuser", "password123", "test2@example.com", "123456")
assert.Error(t, err)
assert.Contains(t, err.Error(), "用户名已存在")
// 测试重复邮箱
err = userService.Register("testuser2", "password123", "test@example.com", "123456")
assert.Error(t, err)
assert.Contains(t, err.Error(), "邮箱已被注册")
}
func TestUserService_Login(t *testing.T) {
db := utils.TestDB(t)
config := utils.TestConfig()
userService := NewUserService(db, config)
// 创建测试用户
err := userService.Register("testuser", "password123", "test@example.com", "123456")
assert.NoError(t, err)
// 测试正确登录
token, err := userService.Login("testuser", "password123")
assert.NoError(t, err)
assert.NotEmpty(t, token)
// 测试错误密码
_, err = userService.Login("testuser", "wrongpassword")
assert.Error(t, err)
assert.Contains(t, err.Error(), "密码错误")
// 测试不存在的用户
_, err = userService.Login("nonexistent", "password123")
assert.Error(t, err)
assert.Contains(t, err.Error(), "用户不存在")
}