first commit

This commit is contained in:
JiXieShi
2024-11-14 22:55:43 +08:00
commit 421cfb8cfa
98 changed files with 12617 additions and 0 deletions

45
internal/api/dashboard.go Normal file
View File

@@ -0,0 +1,45 @@
package api
import (
"licserver/internal/model"
"net/http"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type DashboardHandler struct {
db *gorm.DB
}
func NewDashboardHandler(db *gorm.DB) *DashboardHandler {
return &DashboardHandler{db: db}
}
func (h *DashboardHandler) GetStats(c *gin.Context) {
var stats struct {
TotalDevices int64 `json:"total_devices"`
TotalLicenses int64 `json:"total_licenses"`
TodayNew int64 `json:"today_new"`
OnlineDevices int64 `json:"online_devices"`
}
// 获取设备总数
h.db.Model(&model.Device{}).Count(&stats.TotalDevices)
// 获取授权码总数
h.db.Model(&model.LicenseCode{}).Count(&stats.TotalLicenses)
// 获取今日新增设备数
today := time.Now().Format("2006-01-02")
h.db.Model(&model.Device{}).Where("DATE(created_at) = ?", today).Count(&stats.TodayNew)
// 获取在线设备数最近30分钟内有活动的设备
thirtyMinutesAgo := time.Now().Add(-30 * time.Minute)
h.db.Model(&model.Device{}).
Where("last_active_at > ?", thirtyMinutesAgo).
Count(&stats.OnlineDevices)
c.JSON(http.StatusOK, stats)
}

434
internal/api/device.go Normal file
View File

@@ -0,0 +1,434 @@
package api
import (
"fmt"
"net/http"
"strconv"
"licserver/internal/model"
"licserver/internal/service"
"github.com/gin-gonic/gin"
)
type DeviceHandler struct {
deviceService *service.DeviceService
}
func NewDeviceHandler(deviceService *service.DeviceService) *DeviceHandler {
return &DeviceHandler{deviceService: deviceService}
}
func (h *DeviceHandler) CreateDevice(c *gin.Context) {
var input service.DeviceCreateInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.deviceService.CreateDevice(&input); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "设备创建成功"})
}
func (h *DeviceHandler) GetDevices(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "10"))
params := &service.DeviceQueryParams{
UID: c.Query("uid"),
DeviceType: c.Query("deviceType"),
Company: c.Query("company"),
LicenseType: c.Query("licenseType"),
Status: c.Query("status"),
Page: page,
PageSize: pageSize,
}
devices, total, err := h.deviceService.GetDevices(params)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "获取设备列表成功",
"count": total,
"data": devices,
})
}
func (h *DeviceHandler) UpdateStartCount(c *gin.Context) {
uid := c.Param("uid")
if err := h.deviceService.UpdateStartCount(uid); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "启动次数更新成功"})
}
func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
uid := c.Param("uid")
var updates map[string]interface{}
if err := c.ShouldBindJSON(&updates); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.deviceService.UpdateDevice(uid, updates); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "设备更新成功"})
}
func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
uid := c.Param("uid")
if err := h.deviceService.DeleteDevice(uid); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "设备删除成功"})
}
func (h *DeviceHandler) RegisterDevice(c *gin.Context) {
var input service.DeviceRegisterInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.deviceService.RegisterDevice(&input, c.ClientIP()); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
status := "未激活"
if input.LicenseCode != "" {
status = "已激活"
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": fmt.Sprintf("设备注册成功,当前状态:%s", status),
})
}
func (h *DeviceHandler) ValidateDevice(c *gin.Context) {
uid := c.Param("uid")
if err := h.deviceService.ValidateDevice(uid); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "设备验证通过"})
}
func (h *DeviceHandler) BindLicense(c *gin.Context) {
uid := c.Param("uid")
var input struct {
LicenseCode string `json:"license_code" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.deviceService.BindLicense(uid, input.LicenseCode); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "授权码绑定成功",
})
}
func (h *DeviceHandler) UnbindLicense(c *gin.Context) {
uid := c.Param("uid")
if err := h.deviceService.UnbindLicense(uid); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "授权码解绑成功",
})
}
func (h *DeviceHandler) GetLicenseInfo(c *gin.Context) {
deviceUID := c.Param("uid")
device, err := h.deviceService.GetLicenseInfo(deviceUID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, device)
}
func (h *DeviceHandler) CheckLicenseStatus(c *gin.Context) {
deviceUID := c.Param("uid")
status, err := h.deviceService.CheckLicenseStatus(deviceUID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"status": status,
})
}
func (h *DeviceHandler) CheckUpdate(c *gin.Context) {
deviceUID := c.Param("uid")
currentVersion := c.Query("version")
update, err := h.deviceService.CheckUpdate(deviceUID, currentVersion)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, update)
}
func (h *DeviceHandler) CreateDeviceModel(c *gin.Context) {
var model model.DeviceModel
if err := c.ShouldBindJSON(&model); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
model.CreatedBy = c.GetUint("userID")
model.Status = "active"
if err := h.deviceService.CreateDeviceModel(&model); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "设备型号创建成功",
"data": model,
})
}
func (h *DeviceHandler) GetDeviceModels(c *gin.Context) {
modelName := c.Query("model_name")
deviceType := c.Query("device_type")
company := c.Query("company")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
models, total, err := h.deviceService.GetDeviceModels(modelName, deviceType, company, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"msg": "",
"count": total,
"data": models,
})
}
func (h *DeviceHandler) UpdateDeviceModel(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的ID"})
return
}
var model model.DeviceModel
if err := c.ShouldBindJSON(&model); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.deviceService.UpdateDeviceModel(uint(id), &model); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "设备型号更新成功",
})
}
func (h *DeviceHandler) DeleteDeviceModel(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的ID"})
return
}
if err := h.deviceService.DeleteDeviceModel(uint(id)); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "设备型号删除成功",
})
}
func (h *DeviceHandler) BatchDeleteDeviceModels(c *gin.Context) {
var input struct {
IDs []uint `json:"ids" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.deviceService.BatchDeleteDeviceModels(input.IDs); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "设备型号批量删除成功",
})
}
func (h *DeviceHandler) GetRegisteredDevices(c *gin.Context) {
uid := c.Query("uid")
deviceModel := c.Query("device_model")
status := c.Query("status")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
devices, total, err := h.deviceService.GetRegisteredDevices(uid, deviceModel, status, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"msg": "",
"count": total,
"data": devices,
})
}
func (h *DeviceHandler) GetDeviceLogs(c *gin.Context) {
uid := c.Param("uid")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
logs, total, err := h.deviceService.GetDeviceLogs(uid, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"msg": "",
"count": total,
"data": logs,
})
}
// GetDashboardStats 获取仪表盘统计数据
func (h *DeviceHandler) GetDashboardStats(c *gin.Context) {
stats, err := h.deviceService.GetDashboardStats()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": stats,
})
}

169
internal/api/license.go Normal file
View File

@@ -0,0 +1,169 @@
package api
import (
"fmt"
"licserver/internal/service"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
)
type LicenseHandler struct {
licenseService *service.LicenseService
}
func NewLicenseHandler(licenseService *service.LicenseService) *LicenseHandler {
return &LicenseHandler{licenseService: licenseService}
}
// 创建授权码
func (h *LicenseHandler) CreateLicenses(c *gin.Context) {
var input service.LicenseCreateInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID := c.GetUint("userID")
licenses, err := h.licenseService.CreateLicenses(&input, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "授权码创建成功",
"data": licenses,
})
}
// 使用授权码
func (h *LicenseHandler) UseLicense(c *gin.Context) {
var input struct {
Code string `json:"code" binding:"required"`
DeviceUID string `json:"device_uid" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
license, err := h.licenseService.UseLicense(input.Code, input.DeviceUID, c.ClientIP())
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "授权码使用成功",
"data": license,
})
}
// 获取授权码列表
func (h *LicenseHandler) GetLicenses(c *gin.Context) {
status := c.Query("status")
licenseType := c.Query("license_type")
batchNo := c.Query("batch_no")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
licenses, total, err := h.licenseService.GetLicenses(status, licenseType, batchNo, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 确保返回格式符合 layui table 的要求
c.JSON(http.StatusOK, gin.H{
"code": 0,
"msg": "",
"count": total,
"data": licenses,
})
}
// 获取授权码使用日志
func (h *LicenseHandler) GetLicenseLogs(c *gin.Context) {
licenseID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的授权码ID"})
return
}
// 检查是否为导出请求
if c.Query("export") == "1" {
data, err := h.licenseService.ExportLogs(uint(licenseID))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 设置响应头
filename := fmt.Sprintf("license_logs_%d_%s.csv", licenseID, time.Now().Format("20060102150405"))
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename))
c.Data(http.StatusOK, "text/csv", data)
return
}
// 常规日志查询
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "10"))
logs, total, err := h.licenseService.GetLicenseLogs(uint(licenseID), page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "获取授权码使用日志成功",
"count": total,
"data": logs,
})
}
// 添加撤销授权码的处理方法
func (h *LicenseHandler) RevokeLicense(c *gin.Context) {
code := c.Param("code")
userID := c.GetUint("userID")
if err := h.licenseService.RevokeLicense(code, userID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "授权码撤销成功",
})
}
// 添加批量撤销处理方法
func (h *LicenseHandler) RevokeLicenses(c *gin.Context) {
var input struct {
Codes []string `json:"codes" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID := c.GetUint("userID")
if err := h.licenseService.RevokeLicenses(input.Codes, userID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "授权码批量撤销成功",
})
}

26
internal/api/monitor.go Normal file
View File

@@ -0,0 +1,26 @@
package api
import (
"licserver/internal/service"
"net/http"
"github.com/gin-gonic/gin"
)
type MonitorHandler struct {
monitorService *service.MonitorService
}
func NewMonitorHandler(monitorService *service.MonitorService) *MonitorHandler {
return &MonitorHandler{monitorService: monitorService}
}
func (h *MonitorHandler) GetSystemStatus(c *gin.Context) {
status, err := h.monitorService.GetSystemStatus()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, status)
}

121
internal/api/router.go Normal file
View File

@@ -0,0 +1,121 @@
package api
import (
"licserver/internal/middleware"
"licserver/internal/utils"
"github.com/gin-gonic/gin"
)
func SetupRouter(
userHandler *UserHandler,
deviceHandler *DeviceHandler,
monitorHandler *MonitorHandler,
config *utils.Config,
uploadHandler *UploadHandler,
siteHandler *SiteHandler,
tokenHandler *TokenHandler,
licenseHandler *LicenseHandler,
) *gin.Engine {
r := gin.Default()
// 添加错误处理中间件
r.Use(middleware.ErrorHandler())
// 静态文件服务
r.Static("/static", "./web/static")
// 首页和登录页面
r.StaticFile("/", "./web/templates/index.html")
r.StaticFile("/login", "./web/templates/login.html")
// Admin页面路由组
admin := r.Group("/admin")
admin.Use(middleware.JWTAuth(&config.JWT))
{
// 使用StaticFile处理包含Layui模板的页面
admin.StaticFile("/dashboard", "./web/templates/admin/dashboard.html")
admin.StaticFile("/devices", "./web/templates/admin/devices.html")
admin.StaticFile("/device-files", "./web/templates/admin/device-files.html")
admin.StaticFile("/device-license", "./web/templates/admin/device-license.html")
admin.StaticFile("/licenses", "./web/templates/admin/licenses.html")
admin.StaticFile("/license-logs", "./web/templates/admin/license-logs.html")
admin.StaticFile("/tokens", "./web/templates/admin/tokens.html")
admin.StaticFile("/token-logs", "./web/templates/admin/token-logs.html")
admin.StaticFile("/monitor", "./web/templates/admin/monitor.html")
admin.StaticFile("/site-settings", "./web/templates/admin/site-settings.html")
admin.StaticFile("/users", "./web/templates/admin/users.html")
admin.StaticFile("/user-edit", "./web/templates/admin/user-edit.html")
admin.StaticFile("/change-password", "./web/templates/admin/change-password.html")
}
// API路由
api := r.Group("/api")
{
// 公开API
api.GET("/captcha", userHandler.GetCaptcha)
api.POST("/captcha/verify", userHandler.VerifyCaptcha)
api.POST("/login", userHandler.Login)
api.POST("/register", userHandler.Register)
api.POST("/reset-password", userHandler.ResetPassword)
api.POST("/reset-password/confirm", userHandler.ResetPasswordWithToken)
api.POST("/captcha/register", userHandler.SendRegisterCaptcha)
api.POST("/captcha/reset-password", userHandler.SendResetPasswordCaptcha)
api.POST("/validate-token", tokenHandler.ValidateToken)
// 需要认证的API
authorized := api.Group("")
authorized.Use(middleware.JWTAuth(&config.JWT))
{
// 设备型号管理
authorized.POST("/devices/models", middleware.AdminRequired(), deviceHandler.CreateDeviceModel)
authorized.GET("/devices/models", deviceHandler.GetDeviceModels)
authorized.PUT("/devices/models/:id", middleware.AdminRequired(), deviceHandler.UpdateDeviceModel)
authorized.DELETE("/devices/models/:id", middleware.AdminRequired(), deviceHandler.DeleteDeviceModel)
authorized.POST("/devices/models/batch", middleware.AdminRequired(), deviceHandler.BatchDeleteDeviceModels)
// 设备管理
authorized.POST("/devices/register", deviceHandler.RegisterDevice)
authorized.GET("/devices/registered", deviceHandler.GetRegisteredDevices)
authorized.POST("/devices/:uid/license", middleware.AdminRequired(), deviceHandler.BindLicense)
authorized.DELETE("/devices/:uid/license", middleware.AdminRequired(), deviceHandler.UnbindLicense)
authorized.GET("/devices/:uid/logs", deviceHandler.GetDeviceLogs)
// 其他API路由...
// 用户管理
authorized.GET("/users", middleware.AdminRequired(), userHandler.GetUsers)
authorized.POST("/users", middleware.AdminRequired(), userHandler.CreateUser)
authorized.GET("/users/:id", middleware.AdminRequired(), userHandler.GetUserInfo)
authorized.PUT("/users/:id", middleware.AdminRequired(), userHandler.UpdateUser)
authorized.DELETE("/users/:id", middleware.AdminRequired(), userHandler.DeleteUser)
authorized.GET("/users/profile", userHandler.GetProfile)
authorized.PUT("/users/profile", userHandler.UpdateProfile)
authorized.POST("/users/change-password", userHandler.ChangePassword)
// 系统监控
authorized.GET("/monitor/status", middleware.AdminRequired(), monitorHandler.GetSystemStatus)
// 站点设置
authorized.GET("/site/settings", middleware.AdminRequired(), siteHandler.GetSettings)
authorized.PUT("/site/settings", middleware.AdminRequired(), siteHandler.UpdateSettings)
// Token管理
authorized.POST("/tokens", middleware.AdminRequired(), tokenHandler.CreateToken)
authorized.GET("/tokens", tokenHandler.GetTokens)
authorized.GET("/tokens/:id/logs", tokenHandler.GetTokenLogs)
authorized.DELETE("/tokens/:token", middleware.AdminRequired(), tokenHandler.RevokeToken)
// 授权码管理
authorized.POST("/licenses", middleware.AdminRequired(), licenseHandler.CreateLicenses)
authorized.GET("/licenses", licenseHandler.GetLicenses)
authorized.GET("/licenses/:id/logs", licenseHandler.GetLicenseLogs)
authorized.POST("/licenses/use", licenseHandler.UseLicense)
// 仪表盘统计
authorized.GET("/dashboard/stats", deviceHandler.GetDashboardStats)
}
}
return r
}

86
internal/api/site.go Normal file
View File

@@ -0,0 +1,86 @@
package api
import (
"licserver/internal/service"
"licserver/internal/utils"
"net/http"
"github.com/gin-gonic/gin"
)
type SiteHandler struct {
siteService *service.SiteService
}
func NewSiteHandler(siteService *service.SiteService) *SiteHandler {
return &SiteHandler{siteService: siteService}
}
func (h *SiteHandler) GetSettings(c *gin.Context) {
settings := h.siteService.GetSettings()
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": settings,
"title": settings.Title,
"description": settings.Description,
"base_url": settings.BaseURL,
"icp": settings.ICP,
"copyright": settings.Copyright,
"logo": settings.Logo,
"favicon": settings.Favicon,
})
}
func (h *SiteHandler) UpdateSettings(c *gin.Context) {
var settings utils.SiteConfig
if err := c.ShouldBindJSON(&settings); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.siteService.ValidateSettings(settings); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.siteService.UpdateSettings(settings); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "站点设置更新成功",
"data": settings,
})
}
func (h *SiteHandler) BackupSettings(c *gin.Context) {
settings := h.siteService.GetSettings()
c.Header("Content-Disposition", "attachment; filename=site_settings.json")
c.JSON(http.StatusOK, settings)
}
func (h *SiteHandler) RestoreSettings(c *gin.Context) {
var settings utils.SiteConfig
if err := c.ShouldBindJSON(&settings); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.siteService.ValidateSettings(settings); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.siteService.UpdateSettings(settings); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "站点设置恢复成功",
"data": settings,
})
}

121
internal/api/token.go Normal file
View File

@@ -0,0 +1,121 @@
package api
import (
"licserver/internal/service"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
)
type TokenHandler struct {
tokenService *service.TokenService
}
func NewTokenHandler(tokenService *service.TokenService) *TokenHandler {
return &TokenHandler{tokenService: tokenService}
}
func (h *TokenHandler) CreateToken(c *gin.Context) {
var input struct {
DeviceUID string `json:"device_uid" binding:"required"`
TokenType string `json:"token_type" binding:"required,oneof=api device"`
ExpireDays int `json:"expire_days" binding:"required,min=1"`
IPList []string `json:"ip_list"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
expireTime := time.Now().AddDate(0, 0, input.ExpireDays)
userID := c.GetUint("userID")
token, err := h.tokenService.CreateToken(
input.DeviceUID,
input.TokenType,
expireTime,
input.IPList,
userID,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, token)
}
func (h *TokenHandler) ValidateToken(c *gin.Context) {
token := c.GetHeader("X-Access-Token")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供访问令牌"})
return
}
clientIP := c.ClientIP()
accessToken, err := h.tokenService.ValidateToken(token, clientIP)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, accessToken)
}
func (h *TokenHandler) RevokeToken(c *gin.Context) {
token := c.Param("token")
userID := c.GetUint("userID")
if err := h.tokenService.RevokeToken(token, userID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "令牌已撤销"})
}
func (h *TokenHandler) GetTokens(c *gin.Context) {
deviceUID := c.Query("device_uid")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
tokens, total, err := h.tokenService.GetTokens(deviceUID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "获取令牌列表成功",
"count": total,
"data": tokens,
})
}
func (h *TokenHandler) GetTokenLogs(c *gin.Context) {
tokenID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的令牌ID"})
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
logs, total, err := h.tokenService.GetTokenLogs(uint(tokenID), page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "获取令牌使用日志成功",
"count": total,
"data": logs,
})
}

213
internal/api/upload.go Normal file
View File

@@ -0,0 +1,213 @@
package api
import (
"fmt"
"licserver/internal/service"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type UploadHandler struct {
uploadService *service.UploadService
}
func NewUploadHandler(uploadService *service.UploadService) *UploadHandler {
return &UploadHandler{uploadService: uploadService}
}
func (h *UploadHandler) UploadFile(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到上传文件"})
return
}
deviceUID := c.PostForm("device_uid")
description := c.PostForm("description")
userID := c.GetUint("userID")
upload, err := h.uploadService.UploadFile(file, userID, deviceUID, description)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, upload)
}
func (h *UploadHandler) DownloadFile(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的文件ID"})
return
}
file, err := h.uploadService.DownloadFile(uint(id))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.FileAttachment(file.FilePath, file.FileName)
}
func (h *UploadHandler) DeleteFile(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的文件ID"})
return
}
userID := c.GetUint("userID")
if err := h.uploadService.DeleteFile(uint(id), userID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "文件删除成功"})
}
func (h *UploadHandler) GetDeviceFiles(c *gin.Context) {
deviceUID := c.Param("uid")
files, err := h.uploadService.GetDeviceFiles(deviceUID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, files)
}
func (h *UploadHandler) UploadChunk(c *gin.Context) {
file, err := c.FormFile("chunk")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到上传文件"})
return
}
fileHash := c.PostForm("fileHash")
if fileHash == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "未提供文件哈希"})
return
}
chunkNumber, err := strconv.Atoi(c.PostForm("chunkNumber"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的分片序号"})
return
}
totalChunks, err := strconv.Atoi(c.PostForm("totalChunks"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的总分片数"})
return
}
totalSize, err := strconv.ParseInt(c.PostForm("totalSize"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的文件大小"})
return
}
filename := c.PostForm("filename")
deviceUID := c.PostForm("deviceUID")
userID := c.GetUint("userID")
err = h.uploadService.UploadChunk(
file,
fileHash,
chunkNumber,
totalChunks,
totalSize,
filename,
userID,
deviceUID,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 检查是否所有分片都已上传
completed, err := h.uploadService.CheckUploadStatus(fileHash)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "分片上传成功",
"completed": completed,
})
}
func (h *UploadHandler) MergeChunks(c *gin.Context) {
fileHash := c.PostForm("fileHash")
if fileHash == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "未提供文件哈希"})
return
}
upload, err := h.uploadService.MergeChunks(fileHash)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "文件合并成功",
"file": upload,
})
}
func (h *UploadHandler) UploadSiteFile(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到上传文件"})
return
}
// 检查文件类型
ext := strings.ToLower(filepath.Ext(file.Filename))
allowedExts := map[string]bool{
".jpg": true, ".jpeg": true, ".png": true, ".gif": true,
".ico": true, ".svg": true,
}
if !allowedExts[ext] {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的文件类型"})
return
}
// 生成文件名
filename := fmt.Sprintf("site_%s%s", time.Now().Format("20060102150405"), ext)
// 构建目标目录路径
uploadDir := filepath.Join("web", "static", "images")
if err := os.MkdirAll(uploadDir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建目录失败"})
return
}
// 构建完整的文件路径
filePath := filepath.Join(uploadDir, filename)
// 保存文件
if err := c.SaveUploadedFile(file, filePath); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存文件失败"})
return
}
// 返回文件URL使用正斜杠作为URL路径分隔符
fileURL := "/" + strings.Join([]string{"static", "images", filename}, "/")
c.JSON(http.StatusOK, gin.H{
"url": fileURL,
"message": "文件上传成功",
})
}

424
internal/api/user.go Normal file
View File

@@ -0,0 +1,424 @@
package api
import (
"licserver/internal/service"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
)
type UserHandler struct {
userService *service.UserService
}
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{userService: userService}
}
func (h *UserHandler) Login(c *gin.Context) {
var input struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
Captcha string `json:"captcha" binding:"required"`
CaptchaId string `json:"captchaId" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 验证验证码
if !h.userService.GetCaptchaService().VerifyImageCaptcha(input.CaptchaId, input.Captcha) {
c.JSON(http.StatusBadRequest, gin.H{"error": "验证码错误"})
return
}
token, err := h.userService.Login(input.Username, input.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
// 设置 cookie
// c.SetCookie("token", token, 86400, "/", "", false, true) // 24小时过期httpOnly=true
c.JSON(http.StatusOK, gin.H{"token": token})
}
func (h *UserHandler) Register(c *gin.Context) {
var input struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
Email string `json:"email" binding:"required,email"`
Captcha string `json:"captcha" binding:"required,len=6"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.userService.Register(input.Username, input.Password, input.Email, input.Captcha); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "注册成功"})
}
func (h *UserHandler) ResetPasswordWithToken(c *gin.Context) {
var input struct {
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.userService.ResetPasswordWithToken(input.Token, input.NewPassword); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "密码重置成功"})
}
func (h *UserHandler) ResetPassword(c *gin.Context) {
var input struct {
Email string `json:"email" binding:"required,email"`
Captcha string `json:"captcha" binding:"required,len=6"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.userService.ResetPassword(input.Email, input.Captcha); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "重置密码邮件已发送"})
}
func (h *UserHandler) SendRegisterCaptcha(c *gin.Context) {
var input struct {
Email string `json:"email" binding:"required,email"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.userService.SendRegisterCaptcha(input.Email); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "验证码已发送"})
}
func (h *UserHandler) SendResetPasswordCaptcha(c *gin.Context) {
var input struct {
Email string `json:"email" binding:"required,email"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.userService.SendResetPasswordCaptcha(input.Email); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "验证码已发送"})
}
// 在 UserHandler 中添加以下方法
// 获取图片验证码
func (h *UserHandler) GetCaptcha(c *gin.Context) {
id, b64s, err := h.userService.GetCaptchaService().GenerateImageCaptcha()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "生成验证码失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"captchaId": id,
"imageBase64": b64s,
})
}
// 验证图片验证码
func (h *UserHandler) VerifyCaptcha(c *gin.Context) {
var input struct {
CaptchaId string `json:"captcha_id" binding:"required"`
Code string `json:"code" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !h.userService.GetCaptchaService().VerifyImageCaptcha(input.CaptchaId, input.Code) {
c.JSON(http.StatusBadRequest, gin.H{"error": "验证码错误"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "验证成功"})
}
// 获取用户列表
func (h *UserHandler) GetUsers(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
username := c.Query("username")
role := c.Query("role")
users, total, err := h.userService.GetUsers(username, role, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"msg": "获取用户列表成功",
"count": total,
"data": users,
})
}
// 创建用户
func (h *UserHandler) CreateUser(c *gin.Context) {
var input struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email" binding:"required,email"`
Role string `json:"role" binding:"required,oneof=admin user"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 检查权限
if c.GetString("role") != "admin" {
c.JSON(http.StatusForbidden, gin.H{"error": "需要管理员权限"})
return
}
err := h.userService.CreateUser(input.Username, input.Password, input.Email, input.Role)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "用户创建成功"})
}
// 更新用户
func (h *UserHandler) UpdateUser(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
return
}
var input struct {
Email string `json:"email" binding:"required,email"`
Role string `json:"role" binding:"required,oneof=admin user"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 检查权限
if c.GetString("role") != "admin" {
c.JSON(http.StatusForbidden, gin.H{"error": "需要管理员权限"})
return
}
err = h.userService.UpdateUser(uint(id), input.Email, input.Role)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "用户更新成功"})
}
// 删除用户
func (h *UserHandler) DeleteUser(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
return
}
// 检查权限
if c.GetString("role") != "admin" {
c.JSON(http.StatusForbidden, gin.H{"error": "需要管理员权限"})
return
}
// 不能删除自己
if uint(id) == c.GetUint("userID") {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除自己"})
return
}
err = h.userService.DeleteUser(uint(id))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "用户删除成功"})
}
// 获取用户信息
func (h *UserHandler) GetUserInfo(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
return
}
user, err := h.userService.GetUserByID(uint(id))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"})
return
}
c.JSON(http.StatusOK, user)
}
// 获取当前用户信息
func (h *UserHandler) GetProfile(c *gin.Context) {
userID := c.GetUint("userID")
user, err := h.userService.GetUserByID(userID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"})
return
}
c.JSON(http.StatusOK, user)
}
// 修改密码
func (h *UserHandler) ChangePassword(c *gin.Context) {
var input struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID := c.GetUint("userID")
err := h.userService.ChangePassword(userID, input.OldPassword, input.NewPassword)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "密码修改成功"})
}
// 在 UserHandler 结构体中添加 UpdateProfile 方法
func (h *UserHandler) UpdateProfile(c *gin.Context) {
var input struct {
Email string `json:"email" binding:"required,email"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID := c.GetUint("userID")
if err := h.userService.UpdateProfile(userID, input.Email); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "个人信息更新成功"})
}

View File

@@ -0,0 +1,80 @@
package middleware
import (
"licserver/internal/utils"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func JWTAuth(config *utils.JWTConfig) gin.HandlerFunc {
return func(c *gin.Context) {
var token string
// 1. 首先从 cookie 中获取 token
tokenCookie, err := c.Cookie("token")
if err == nil {
token = tokenCookie
}
// 2. 如果 cookie 中没有,则从 header 中获取
if token == "" {
auth := c.GetHeader("Authorization")
if auth != "" {
parts := strings.SplitN(auth, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
token = parts[1]
}
}
}
// 3. 如果 query 参数中有 token也可以使用
if token == "" {
token = c.Query("token")
}
// 如果都没有找到 token
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证信息"})
c.Abort()
return
}
// 验证 token
claims, err := utils.ParseToken(token, config)
if err != nil {
// 如果 token 无效,清除 cookie
// c.SetCookie("token", "", -1, "/", "", true, true)
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的token"})
c.Abort()
return
}
// 将用户信息存储到上下文
c.Set("userID", claims.UserID)
c.Set("username", claims.Username)
c.Set("role", claims.Role)
// 如果是从 header 或 query 参数获取的 token设置到 cookie 中
if tokenCookie == "" {
// 设置 cookie过期时间与 token 一致
// c.SetCookie("token", token, int(claims.ExpiresAt.Unix()-claims.IssuedAt.Unix()), "/", "", false, true)
}
c.Next()
}
}
// AdminRequired 检查用户是否为管理员
func AdminRequired() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists || role != "admin" {
c.JSON(http.StatusForbidden, gin.H{"error": "需要管理员权限"})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,62 @@
package middleware
import (
"licserver/internal/utils"
"net/http"
"github.com/gin-gonic/gin"
)
func ErrorHandler() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// 只处理第一个错误
if len(c.Errors) > 0 {
err := c.Errors[0].Err
var statusCode int
var response utils.ErrorResponse
switch err {
case utils.ErrUnauthorized:
statusCode = http.StatusUnauthorized
response = utils.ErrorResponse{
Code: 401,
Message: "未授权的访问",
Detail: err.Error(),
}
case utils.ErrForbidden:
statusCode = http.StatusForbidden
response = utils.ErrorResponse{
Code: 403,
Message: "禁止访问",
Detail: err.Error(),
}
case utils.ErrNotFound:
statusCode = http.StatusNotFound
response = utils.ErrorResponse{
Code: 404,
Message: "资源不存在",
Detail: err.Error(),
}
case utils.ErrInvalidInput:
statusCode = http.StatusBadRequest
response = utils.ErrorResponse{
Code: 400,
Message: "无效的输入",
Detail: err.Error(),
}
default:
statusCode = http.StatusInternalServerError
response = utils.ErrorResponse{
Code: 500,
Message: "服务器内部错误",
Detail: err.Error(),
}
}
c.JSON(statusCode, response)
c.Abort()
}
}
}

16
internal/model/captcha.go Normal file
View File

@@ -0,0 +1,16 @@
package model
import (
"time"
"gorm.io/gorm"
)
type Captcha struct {
gorm.Model
Code string `gorm:"size:6"` // 验证码
Type string `gorm:"size:20"` // 验证码类型register/login/reset
Target string `gorm:"size:255"` // 目标(邮箱或手机号)
ExpiresAt time.Time `gorm:"index"` // 过期时间
Used bool `gorm:"default:false"` // 是否已使用
}

20
internal/model/chunk.go Normal file
View File

@@ -0,0 +1,20 @@
package model
import (
"gorm.io/gorm"
)
type UploadChunk struct {
gorm.Model
FileHash string `gorm:"size:64;index"` // 完整文件的哈希值
ChunkNumber int `gorm:"index"` // 分片序号
ChunkSize int64 `gorm:""` // 分片大小
ChunkPath string `gorm:"size:255"` // 分片存储路径
TotalChunks int `gorm:""` // 总分片数
TotalSize int64 `gorm:""` // 文件总大小
Filename string `gorm:"size:255"` // 原始文件名
FileType string `gorm:"size:50"` // 文件类型
UploadedBy uint `gorm:"index"` // 上传者ID
DeviceUID string `gorm:"size:255;index"` // 关联的设备UID
Completed bool `gorm:"default:false"` // 是否已完成合并
}

View File

@@ -0,0 +1,13 @@
package model
import (
"gorm.io/gorm"
)
type DeviceLog struct {
gorm.Model
DeviceUID string `gorm:"index" json:"device_uid"` // 设备UID
Action string `json:"action"` // 操作类型
Message string `json:"message"` // 详细信息
Status string `json:"status"` // 状态success/failed
}

View File

@@ -0,0 +1,19 @@
package model
import (
"gorm.io/gorm"
)
type DeviceModel struct {
gorm.Model
ModelName string `gorm:"uniqueIndex" json:"model_name" form:"model_name"` // 设备型号名称
DeviceType string `gorm:"size:50" json:"device_type" form:"device_type"` // 设备类型
Company string `gorm:"size:255" json:"company" form:"company"` // 所属公司
Remark string `gorm:"size:500" json:"remark" form:"remark"` // 备注说明
DeviceCount int `gorm:"-" json:"device_count"` // 设备数量(非数据库字段)
CreatedBy uint `gorm:"index" json:"created_by"` // 创建者ID
CurrentVersion string `gorm:"size:50" json:"current_version"` // 当前版本
UpdateURL string `gorm:"size:255" json:"update_url"` // 更新地址
UpdateDesc string `gorm:"size:500" json:"update_desc"` // 更新说明
Status string `gorm:"size:20;default:active" json:"status"` // 状态active/disabled
}

33
internal/model/license.go Normal file
View File

@@ -0,0 +1,33 @@
package model
import (
"time"
"gorm.io/gorm"
)
type LicenseCode struct {
gorm.Model
Code string `gorm:"uniqueIndex" json:"code"` // 授权码
LicenseType string `gorm:"size:20" json:"license_type"` // 授权类型time/count/permanent
Duration int `json:"duration"` // 授权时长(分钟)仅当类型为time时有效
MaxUses int `json:"max_uses"` // 最大使用次数仅当类型为count时有效
UsedCount int `gorm:"default:0" json:"used_count"` // 已使用次数
Status string `gorm:"size:20" json:"status"` // 状态unused/used/expired/revoked
UsedBy string `gorm:"index" json:"used_by"` // 使用此授权码的设备UID
UsedAt time.Time `json:"used_at"` // 使用时间
CreatedBy uint `gorm:"index" json:"created_by"` // 创建者ID
BatchNo string `gorm:"size:50;index" json:"batch_no"` // 批次号
Remark string `gorm:"size:500" json:"remark"` // 备注
BindCount int `gorm:"default:-1" json:"bind_count"` // 可绑定次数,-1表示无限制0表示不能绑定
}
type LicenseLog struct {
gorm.Model
LicenseID uint `gorm:"index" json:"license_id"` // 关联的授权码ID
DeviceUID string `gorm:"index" json:"device_uid"` // 设备UID
Action string `gorm:"size:20" json:"action"` // 操作类型create/use/verify
IP string `gorm:"size:50" json:"ip"` // 操作IP
Status string `gorm:"size:20" json:"status"` // 状态success/failed
Message string `gorm:"size:500" json:"message"` // 详细信息
}

35
internal/model/models.go Normal file
View File

@@ -0,0 +1,35 @@
package model
import (
"time"
"gorm.io/gorm"
)
type Device struct {
gorm.Model
UID string `gorm:"uniqueIndex" json:"uid"`
IPPort string `json:"ip_port"`
ChipID string `json:"chip_id"`
DeviceType string `json:"device_type"`
DeviceModel string `json:"device_model"`
Company string `json:"company"`
RegisterTime time.Time `json:"register_time"`
ExpireTime time.Time `json:"expire_time"`
LicenseType string `json:"license_type"`
StartCount int `json:"start_count"`
Status string `json:"status"`
LicenseCode string `json:"license_code"`
MaxUses int `json:"max_uses"`
Duration int `json:"duration"`
LastActiveAt time.Time `json:"last_active_at"`
}
type User struct {
gorm.Model
Username string `gorm:"uniqueIndex" json:"username"`
Password string `json:"-"`
Email string `gorm:"uniqueIndex" json:"email"`
Role string `json:"role"`
LastLogin time.Time `json:"last_login"`
}

127
internal/model/monitor.go Normal file
View File

@@ -0,0 +1,127 @@
package model
import (
"time"
)
type SystemStatus struct {
CPU struct {
Usage float64 `json:"usage"` // CPU使用率
LoadAvg []float64 `json:"load_avg"` // 系统负载
CoreCount int `json:"core_count"` // CPU核心数
ModelName string `json:"model_name"` // CPU型号
MHz float64 `json:"mhz"` // CPU频率
} `json:"cpu"`
Memory struct {
Total uint64 `json:"total"` // 总内存
Used uint64 `json:"used"` // 已用内存
Free uint64 `json:"free"` // 空闲内存
UsageRate float64 `json:"usage_rate"` // 使用率
SwapTotal uint64 `json:"swap_total"` // 交换分区总大小
SwapUsed uint64 `json:"swap_used"` // 交换分区已用
SwapFree uint64 `json:"swap_free"` // 交换分区空闲
SwapUsageRate float64 `json:"swap_usage_rate"` // 交换分区使用率
} `json:"memory"`
Disk struct {
Partitions []DiskPartition `json:"partitions"` // 磁盘分区信息
} `json:"disk"`
Network struct {
Interfaces []NetworkInterface `json:"interfaces"` // 网络接口信息
} `json:"network"`
Process struct {
Total int `json:"total"` // 进程总数
List []ProcessInfo `json:"list"` // 进程列表Top N
} `json:"process"`
Host struct {
Hostname string `json:"hostname"` // 主机名
OS string `json:"os"` // 操作系统
Platform string `json:"platform"` // 平台
PlatformVersion string `json:"platform_version"` // 平台版本
KernelVersion string `json:"kernel_version"` // 内核版本
BootTime time.Time `json:"boot_time"` // 启动时间
} `json:"host"`
System struct {
Uptime time.Duration `json:"uptime"` // 系统运行时间
CurrentTime time.Time `json:"current_time"` // 当前时间
ActiveUsers int `json:"active_users"` // 活跃用户数
TotalDevices int `json:"total_devices"` // 设备总数
} `json:"system"`
}
type DiskPartition struct {
Device string `json:"device"` // 设备名
Mountpoint string `json:"mountpoint"` // 挂载点
Fstype string `json:"fstype"` // 文件系统类型
Total uint64 `json:"total"` // 总空间
Used uint64 `json:"used"` // 已用空间
Free uint64 `json:"free"` // 空闲空间
UsageRate float64 `json:"usage_rate"` // 使用率
}
type NetworkInterface struct {
Name string `json:"name"` // 接口名称
BytesSent uint64 `json:"bytes_sent"` // 发送字节数
BytesRecv uint64 `json:"bytes_recv"` // 接收字节数
PacketsSent uint64 `json:"packets_sent"` // 发送包数
PacketsRecv uint64 `json:"packets_recv"` // 接收包数
Addrs []string `json:"addrs"` // IP地址列表改为字符串数组
}
type ProcessInfo struct {
PID int `json:"pid"` // 进程ID
Name string `json:"name"` // 进程名称
CPU float64 `json:"cpu"` // CPU使用率
Memory float64 `json:"memory"` // 内存使用率
Created int64 `json:"created"` // 创建时间
}

View File

@@ -0,0 +1,15 @@
package model
import (
"time"
"gorm.io/gorm"
)
type PasswordResetToken struct {
gorm.Model
UserID uint `gorm:"index"`
Token string `gorm:"uniqueIndex"`
ExpiresAt time.Time
Used bool
}

40
internal/model/token.go Normal file
View File

@@ -0,0 +1,40 @@
package model
import (
"time"
"gorm.io/gorm"
)
type AccessToken struct {
gorm.Model
Token string `gorm:"uniqueIndex" json:"token"` // 访问令牌
DeviceUID string `gorm:"index" json:"device_uid"` // 关联的设备UID
Type string `gorm:"size:20" json:"type"` // 令牌类型api/device
Status string `gorm:"size:20" json:"status"` // 状态active/revoked
ExpireTime time.Time `json:"expire_time"` // 过期时间
LastUsed time.Time `json:"last_used"` // 最后使用时间
UsageCount int `gorm:"default:0" json:"usage_count"` // 使用次数
IPList string `gorm:"type:text" json:"ip_list"` // 允许的IP列表逗号分隔
CreatedBy uint `gorm:"index" json:"created_by"` // 创建者ID
}
// TableName 指定表名
func (AccessToken) TableName() string {
return "access_tokens"
}
type TokenLog struct {
gorm.Model
TokenID uint `gorm:"index" json:"token_id"` // 关联的令牌ID
Action string `gorm:"size:20" json:"action"` // 操作类型create/use/revoke
IP string `gorm:"size:50" json:"ip"` // 操作IP
UserAgent string `gorm:"size:255" json:"user_agent"` // User-Agent
Status string `gorm:"size:20" json:"status"` // 状态success/failed
Message string `gorm:"size:500" json:"message"` // 详细信息
}
// TableName 指定表名
func (TokenLog) TableName() string {
return "token_logs"
}

25
internal/model/upload.go Normal file
View File

@@ -0,0 +1,25 @@
package model
import (
"time"
"gorm.io/gorm"
)
type FileUpload struct {
gorm.Model
FileName string `gorm:"size:255" json:"file_name"` // 文件名
FilePath string `gorm:"size:255" json:"file_path"` // 文件路径
FileSize int64 `json:"file_size"` // 文件大小
FileType string `gorm:"size:50" json:"file_type"` // 文件类型
UploadedBy uint `gorm:"index" json:"uploaded_by"` // 上传者ID
DeviceModel string `gorm:"size:255;index" json:"device_model"` // 设备型号
Version string `gorm:"size:50" json:"version"` // 文件版本
Description string `gorm:"size:500" json:"description"` // 文件描述
IsUpdate bool `gorm:"default:false" json:"is_update"` // 是否为更新文件
Downloads int `gorm:"default:0" json:"downloads"` // 下载次数
LastDownload time.Time `json:"last_download"` // 最后下载时间
MD5 string `gorm:"size:32" json:"md5"` // 文件MD5值
ForceUpdate bool `gorm:"default:false" json:"force_update"` // 是否强制更新
DeviceUID string `gorm:"size:255;index" json:"device_uid"` // 关联的设备UID
}

100
internal/service/captcha.go Normal file
View File

@@ -0,0 +1,100 @@
package service
import (
"errors"
"licserver/internal/model"
"licserver/internal/utils"
"strings"
"time"
"github.com/mojocn/base64Captcha"
"gorm.io/gorm"
)
type CaptchaService struct {
db *gorm.DB
emailConfig *utils.EmailConfig
store base64Captcha.Store
}
func NewCaptchaService(db *gorm.DB, emailConfig *utils.EmailConfig) *CaptchaService {
return &CaptchaService{
db: db,
emailConfig: emailConfig,
store: base64Captcha.DefaultMemStore,
}
}
func (s *CaptchaService) SendEmailCaptcha(email, captchaType string) error {
// 检查是否存在未过期的验证码
var count int64
s.db.Model(&model.Captcha{}).
Where("target = ? AND type = ? AND expires_at > ? AND used = ?",
email, captchaType, time.Now(), false).
Count(&count)
if count > 0 {
return errors.New("请勿频繁发送验证码")
}
// 生成验证码
code, err := utils.GenerateCaptcha()
if err != nil {
return err
}
// 保存验证码
captcha := model.Captcha{
Code: code,
Type: captchaType,
Target: email,
ExpiresAt: time.Now().Add(5 * time.Minute),
Used: false,
}
if err := s.db.Create(&captcha).Error; err != nil {
return err
}
// 发送验证码邮件
emailService := utils.NewEmailService(s.emailConfig)
content := utils.GenerateEmailCaptchaContent(code, email, captchaType)
return emailService.SendEmail(email, "验证码", content)
}
func (s *CaptchaService) VerifyCaptcha(target, captchaType, code string) error {
var captcha model.Captcha
err := s.db.Where("target = ? AND type = ? AND code = ? AND used = ? AND expires_at > ?",
target, captchaType, code, false, time.Now()).
First(&captcha).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("验证码无效或已过期")
}
return err
}
// 标记验证码为已使用
return s.db.Model(&captcha).Update("used", true).Error
}
// 生成图片验证码
func (s *CaptchaService) GenerateImageCaptcha() (string, string, error) {
driver := base64Captcha.NewDriverDigit(80, 240, 6, 0.7, 80)
c := base64Captcha.NewCaptcha(driver, s.store)
id, b64s, err := c.Generate()
if err != nil {
return "", "", err
}
// 确保返回的base64字符串不包含前缀
b64s = strings.TrimPrefix(b64s, "data:image/png;base64,")
return id, b64s, nil
}
// 验证图片验证码
func (s *CaptchaService) VerifyImageCaptcha(id, code string) bool {
return s.store.Verify(id, code, true)
}

721
internal/service/device.go Normal file
View File

@@ -0,0 +1,721 @@
package service
import (
"errors"
"fmt"
"licserver/internal/model"
"time"
"gorm.io/gorm"
)
type DeviceService struct {
db *gorm.DB
licenseService *LicenseService
}
func NewDeviceService(db *gorm.DB, licenseService *LicenseService) *DeviceService {
return &DeviceService{
db: db,
licenseService: licenseService,
}
}
type DeviceRegisterInput struct {
UID string `json:"uid" binding:"required"`
DeviceModel string `json:"device_model" binding:"required"`
LicenseCode string `json:"license_code"`
}
func (s *DeviceService) RegisterDevice(input *DeviceRegisterInput, ip string) error {
// 检查设备型号是否存在且处于启用状态
var deviceModel model.DeviceModel
if err := s.db.Where("model_name = ? AND status = ?", input.DeviceModel, "active").First(&deviceModel).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("设备型号不存在或已禁用")
}
return err
}
// 检查设备是否已注册
var count int64
s.db.Model(&model.Device{}).Where("uid = ?", input.UID).Count(&count)
if count > 0 {
return errors.New("设备已注册")
}
// 创建设备记录
device := &model.Device{
UID: input.UID,
DeviceType: deviceModel.DeviceType,
DeviceModel: input.DeviceModel,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "inactive",
LastActiveAt: time.Now(),
StartCount: 0,
}
// 如果提供了授权码,进行授权绑定
if input.LicenseCode != "" {
license, err := s.licenseService.GetLicenseByCode(input.LicenseCode)
if err != nil {
return err
}
if license.Status != "unused" {
return errors.New("授权码已被使用")
}
device.Status = "active"
device.LicenseCode = license.Code
device.LicenseType = license.LicenseType
device.MaxUses = license.MaxUses
device.Duration = license.Duration
if license.LicenseType == "time" {
device.ExpireTime = time.Now().Add(time.Duration(license.Duration) * time.Minute)
}
}
return s.db.Transaction(func(tx *gorm.DB) error {
// 创建设备记录
if err := tx.Create(device).Error; err != nil {
return err
}
// 如果有授权码,更新授权码状态
if device.LicenseCode != "" {
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", device.LicenseCode).
Updates(map[string]interface{}{
"status": "used",
"used_by": input.UID,
"used_at": time.Now(),
}).Error; err != nil {
return err
}
}
// 记录设备日志
logMsg := "设备注册成功"
if device.LicenseCode != "" {
logMsg += fmt.Sprintf(",使用授权码: %s", device.LicenseCode)
}
log := model.DeviceLog{
DeviceUID: input.UID,
Action: "register",
Message: logMsg,
Status: "success",
}
if err := tx.Create(&log).Error; err != nil {
return err
}
return nil
})
}
func (s *DeviceService) ValidateDevice(uid string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return errors.New("设备未注册")
}
// 更新最后活跃时间
device.LastActiveAt = time.Now()
// 如果设备已激活,检查授权状态
if device.Status == "active" {
if device.LicenseCode != "" {
if err := s.licenseService.CheckLicenseValidity(device.LicenseCode); err != nil {
device.Status = "expired"
s.db.Save(&device)
return errors.New("设备授权已过期")
}
}
}
return s.db.Save(&device).Error
}
func (s *DeviceService) GetDevices(params *DeviceQueryParams) ([]model.Device, int64, error) {
var devices []model.Device
var total int64
query := s.db.Model(&model.Device{})
if params.UID != "" {
query = query.Where("uid LIKE ?", "%"+params.UID+"%")
}
if params.DeviceType != "" {
query = query.Where("device_type = ?", params.DeviceType)
}
if params.Company != "" {
query = query.Where("company LIKE ?", "%"+params.Company+"%")
}
if params.LicenseType != "" {
query = query.Where("license_type = ?", params.LicenseType)
}
if params.Status != "" {
query = query.Where("status = ?", params.Status)
}
query.Count(&total)
if params.Page > 0 && params.PageSize > 0 {
offset := (params.Page - 1) * params.PageSize
query = query.Offset(offset).Limit(params.PageSize)
}
err := query.Find(&devices).Error
return devices, total, err
}
func (s *DeviceService) UpdateStartCount(uid string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return err
}
// 更新启动次数和最后活跃时间
device.StartCount++
device.LastActiveAt = time.Now()
// 如果设备已激活,检查授权状态
if device.Status == "active" {
// 检查授权码有效性
if device.LicenseCode != "" {
if err := s.licenseService.CheckLicenseValidity(device.LicenseCode); err != nil {
device.Status = "expired"
}
}
// 检查次数限制
if device.LicenseType == "count" && device.StartCount >= device.MaxUses {
device.Status = "expired"
}
}
// 记录设备日志
log := model.DeviceLog{
DeviceUID: uid,
Action: "start",
Message: fmt.Sprintf("设备启动,当前次数:%d", device.StartCount),
Status: "success",
}
return s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Save(&device).Error; err != nil {
return err
}
return tx.Create(&log).Error
})
}
func (s *DeviceService) UpdateDevice(uid string, updates map[string]interface{}) error {
return s.db.Model(&model.Device{}).Where("uid = ?", uid).Updates(updates).Error
}
func (s *DeviceService) DeleteDevice(uid string) error {
return s.db.Where("uid = ?", uid).Delete(&model.Device{}).Error
}
func (s *DeviceService) GetLicenseInfo(deviceUID string) (*model.Device, error) {
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return nil, errors.New("设备不存在")
}
return &device, nil
}
func (s *DeviceService) CheckLicenseStatus(deviceUID string) (string, error) {
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return "", errors.New("设备不存在")
}
if device.LicenseCode == "" {
return "未授权", nil
}
if device.Status != "active" {
return device.Status, nil
}
switch device.LicenseType {
case "时间段":
if time.Now().After(device.ExpireTime) {
device.Status = "expired"
s.db.Save(&device)
return "已过期", nil
}
case "启动次数":
if device.StartCount >= device.MaxUses {
device.Status = "expired"
s.db.Save(&device)
return "已达到使用上限", nil
}
}
return "正常", nil
}
type DeviceQueryParams struct {
UID string
DeviceType string
Company string
LicenseType string
Status string
Page int
PageSize int
}
type DeviceCreateInput struct {
UID string `json:"uid" binding:"required"`
DeviceType string `json:"device_type" binding:"required"`
DeviceModel string `json:"device_model" binding:"required"`
Company string `json:"company"`
}
func (s *DeviceService) CreateDevice(input *DeviceCreateInput) error {
// 检查设备UID是否已存在
var count int64
if err := s.db.Model(&model.Device{}).Where("uid = ?", input.UID).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("设备UID已存在")
}
// 创建设备记录
device := &model.Device{
UID: input.UID,
DeviceType: input.DeviceType,
DeviceModel: input.DeviceModel,
Company: input.Company,
RegisterTime: time.Now(),
Status: "inactive", // 初始状态为未激活
}
return s.db.Create(device).Error
}
// 添加更新检查方法
func (s *DeviceService) CheckUpdate(deviceUID, currentVersion string) (*model.FileUpload, error) {
// 获取设备信息
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return nil, errors.New("设备不存在")
}
// 检查设备状态
if device.Status != "active" {
return nil, errors.New("设备未激活或已过期")
}
// 查找最新的更新文件
var update model.FileUpload
err := s.db.Where("device_model = ? AND is_update = ?", device.DeviceModel, true).
Order("created_at DESC").
First(&update).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil // 没有可用更新
}
return nil, err
}
// 比较版本
if update.Version <= currentVersion && !update.ForceUpdate {
return nil, nil // 当前版本已是最新
}
return &update, nil
}
// 添加设备型号相关的方法
// CreateDeviceModel 创建设备型号
func (s *DeviceService) CreateDeviceModel(model_ *model.DeviceModel) error {
// 检查型号名称是否已存在
var count int64
s.db.Model(&model.DeviceModel{}).Where("model_name = ?", model_.ModelName).Count(&count)
if count > 0 {
return errors.New("设备型号已存在")
}
return s.db.Create(model_).Error
}
// UpdateDeviceModel 更新设备型号
func (s *DeviceService) UpdateDeviceModel(id uint, model_ *model.DeviceModel) error {
// 检查型号名称是否被其他型号使用
var count int64
s.db.Model(&model.DeviceModel{}).Where("model_name = ? AND id != ?", model_.ModelName, id).Count(&count)
if count > 0 {
return errors.New("设备型号已存在")
}
return s.db.Model(&model.DeviceModel{}).Where("id = ?", id).Updates(model_).Error
}
// DeleteDeviceModel 删除设备型号
func (s *DeviceService) DeleteDeviceModel(id uint) error {
// 检查是否有设备使用此型号
var count int64
s.db.Model(&model.DeviceModel{}).Where("device_model = ?", id).Count(&count)
if count > 0 {
return errors.New("该型号下存在设备,无法删除")
}
return s.db.Delete(&model.DeviceModel{}, id).Error
}
// GetDeviceModels 获取设备型号列表
func (s *DeviceService) GetDeviceModels(modelName, deviceType, company string, page, pageSize int) ([]model.DeviceModel, int64, error) {
var models []model.DeviceModel
var total int64
query := s.db.Model(&model.DeviceModel{})
if modelName != "" {
query = query.Where("model_name LIKE ?", "%"+modelName+"%")
}
if deviceType != "" {
query = query.Where("device_type = ?", deviceType)
}
if company != "" {
query = query.Where("company LIKE ?", "%"+company+"%")
}
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
// 查询设备型号列表
if err := query.Find(&models).Error; err != nil {
return nil, 0, err
}
// 查询每个型号下的设备数量
for i := range models {
var count int64
s.db.Model(&model.Device{}).Where("device_model = ?", models[i].ModelName).Count(&count)
models[i].DeviceCount = int(count)
}
return models, total, nil
}
// BatchDeleteDeviceModels 批量删除设备型号
func (s *DeviceService) BatchDeleteDeviceModels(ids []uint) error {
// 检查是否有设备使用这些型号
var count int64
s.db.Model(&model.Device{}).Where("device_model IN (?)", ids).Count(&count)
if count > 0 {
return errors.New("选中的型号中存在正在使用的型号,无法删除")
}
return s.db.Delete(&model.DeviceModel{}, ids).Error
}
// GetRegisteredDevices 获取已注册设备列表
func (s *DeviceService) GetRegisteredDevices(uid, deviceModel, status string, page, pageSize int) ([]model.Device, int64, error) {
var devices []model.Device
var total int64
query := s.db.Model(&model.Device{})
if uid != "" {
query = query.Where("uid LIKE ?", "%"+uid+"%")
}
if deviceModel != "" {
query = query.Where("device_model = ?", deviceModel)
}
if status != "" {
query = query.Where("status = ?", status)
}
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&devices).Error
return devices, total, err
}
// BindLicense 绑定授权码
func (s *DeviceService) BindLicense(uid string, licenseCode string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return errors.New("设备不存在")
}
// 检查设备当前状态
if device.LicenseCode != "" {
return errors.New("设备已绑定授权码,请先解绑")
}
// 验证授权码
license, err := s.licenseService.GetLicenseByCode(licenseCode)
if err != nil {
return err
}
if license.Status != "unused" {
return errors.New("授权码已被使用")
}
// 根据授权类型处理
switch license.LicenseType {
case "time":
if license.Duration <= 0 {
return errors.New("无效的授权时长")
}
device.ExpireTime = time.Now().Add(time.Duration(license.Duration) * time.Minute)
device.Duration = license.Duration
device.MaxUses = 0
case "count":
if license.MaxUses <= 0 {
return errors.New("无效的使用次数")
}
device.ExpireTime = time.Time{} // 清空过期时间
device.Duration = 0
device.MaxUses = license.MaxUses
device.StartCount = 0 // 重置启动次数
case "permanent":
device.ExpireTime = time.Time{} // 清空过期时间
device.Duration = 0
device.MaxUses = 0
default:
return errors.New("无效的授权类型")
}
// 更新设备基本信息
device.LicenseCode = licenseCode
device.LicenseType = license.LicenseType
device.Status = "active"
device.LastActiveAt = time.Now()
return s.db.Transaction(func(tx *gorm.DB) error {
// 更新设备信息
if err := tx.Save(&device).Error; err != nil {
return err
}
// 更新授权码状态
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", licenseCode).
Updates(map[string]interface{}{
"status": "used",
"used_by": uid,
"used_at": time.Now(),
}).Error; err != nil {
return err
}
// 记录设备日志
log := model.DeviceLog{
DeviceUID: uid,
Action: "bind_license",
Message: fmt.Sprintf("绑定%s授权码: %s", getLicenseTypeText(license.LicenseType), licenseCode),
Status: "success",
}
if err := tx.Create(&log).Error; err != nil {
return err
}
return nil
})
}
// 获取授权类型的中文描述
func getLicenseTypeText(licenseType string) string {
switch licenseType {
case "time":
return "时间"
case "count":
return "次数"
case "permanent":
return "永久"
default:
return "未知"
}
}
// UnbindLicense 解绑授权码
func (s *DeviceService) UnbindLicense(uid string) error {
var device model.Device
if err := s.db.Where("uid = ?", uid).First(&device).Error; err != nil {
return errors.New("设备不存在")
}
if device.LicenseCode == "" {
return errors.New("设备未绑定授权码")
}
oldLicenseCode := device.LicenseCode
return s.db.Transaction(func(tx *gorm.DB) error {
// 更新设备信息
if err := tx.Model(&device).Updates(map[string]interface{}{
"license_code": "",
"license_type": "",
"status": "inactive",
"expire_time": nil,
"max_uses": 0,
"duration": 0,
}).Error; err != nil {
return err
}
// 更新授权码状态
if err := tx.Model(&model.LicenseCode{}).Where("code = ?", oldLicenseCode).
Updates(map[string]interface{}{
"status": "unused",
"used_by": "",
"used_at": nil,
}).Error; err != nil {
return err
}
// 记录设备日志
log := model.DeviceLog{
DeviceUID: uid,
Action: "unbind_license",
Message: fmt.Sprintf("解绑授权码: %s", oldLicenseCode),
Status: "success",
}
if err := tx.Create(&log).Error; err != nil {
return err
}
return nil
})
}
// GetDeviceLogs 获取设备日志
func (s *DeviceService) GetDeviceLogs(uid string, page, pageSize int) ([]model.DeviceLog, int64, error) {
var logs []model.DeviceLog
var total int64
query := s.db.Model(&model.DeviceLog{}).Where("device_uid = ?", uid)
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&logs).Error
return logs, total, err
}
// DashboardStats 仪表盘统计数据
type DashboardStats struct {
TotalDevices int64 `json:"total_devices"` // 设备总数
TotalLicenses int64 `json:"total_licenses"` // 授权码总数
TodayNew int64 `json:"today_new"` // 今日新增
OnlineDevices int64 `json:"online_devices"` // 在线设备
ActiveDevices int64 `json:"active_devices"` // 激活设备
ExpiredDevices int64 `json:"expired_devices"` // 过期设备
}
// GetDashboardStats 获取仪表盘统计数据
func (s *DeviceService) GetDashboardStats() (*DashboardStats, error) {
var stats DashboardStats
// 获取设备总数
s.db.Model(&model.Device{}).Count(&stats.TotalDevices)
// 获取授权码总数
s.db.Model(&model.LicenseCode{}).Count(&stats.TotalLicenses)
// 获取今日新增设备数
today := time.Now().Format("2006-01-02")
s.db.Model(&model.Device{}).Where("DATE(register_time) = ?", today).Count(&stats.TodayNew)
// 获取在线设备数最近30分钟内有活动的设备
thirtyMinutesAgo := time.Now().Add(-30 * time.Minute)
s.db.Model(&model.Device{}).Where("last_active_at > ?", thirtyMinutesAgo).Count(&stats.OnlineDevices)
// 获取激活设备数
s.db.Model(&model.Device{}).Where("status = ?", "active").Count(&stats.ActiveDevices)
// 获取过期设备数
s.db.Model(&model.Device{}).Where("status = ?", "expired").Count(&stats.ExpiredDevices)
return &stats, nil
}

View File

@@ -0,0 +1,298 @@
package service
import (
"licserver/internal/model"
"licserver/internal/utils"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDeviceService_RegisterDevice(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
// 创建测试设备型号
deviceModel := &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
Status: "active",
CreatedBy: 1,
}
err := db.Create(deviceModel).Error
assert.NoError(t, err)
// 创建测试授权码
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 1,
Remark: "test",
}, 1)
assert.NoError(t, err)
assert.Len(t, licenses, 1)
tests := []struct {
name string
input *DeviceRegisterInput
wantErr bool
}{
{
name: "正常注册设备",
input: &DeviceRegisterInput{
UID: "test-device-001",
DeviceModel: "test-model",
LicenseCode: licenses[0].Code,
},
wantErr: false,
},
{
name: "重复注册设备",
input: &DeviceRegisterInput{
UID: "test-device-001",
DeviceModel: "test-model",
LicenseCode: licenses[0].Code,
},
wantErr: true,
},
{
name: "使用无效授权码",
input: &DeviceRegisterInput{
UID: "test-device-002",
DeviceModel: "test-model",
LicenseCode: "invalid-code",
},
wantErr: true,
},
{
name: "使用不存在的设备型号",
input: &DeviceRegisterInput{
UID: "test-device-003",
DeviceModel: "non-existent-model",
LicenseCode: licenses[0].Code,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := deviceService.RegisterDevice(tt.input, "127.0.0.1")
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 验证设备是否正确注册
var device model.Device
err = db.Where("uid = ?", tt.input.UID).First(&device).Error
assert.NoError(t, err)
assert.Equal(t, deviceModel.DeviceType, device.DeviceType)
assert.Equal(t, "active", device.Status)
assert.Equal(t, tt.input.LicenseCode, device.LicenseCode)
})
}
}
func TestDeviceService_ValidateDevice(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
// 创建测试设备型号
deviceModel := &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
CreatedBy: 1,
}
err := db.Create(deviceModel).Error
assert.NoError(t, err)
// 创建测试授权码
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 1,
Remark: "test",
}, 1)
assert.NoError(t, err)
// 注册测试设备
device := &model.Device{
UID: "test-device-001",
DeviceType: deviceModel.DeviceType,
DeviceModel: deviceModel.ModelName,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "active",
LicenseCode: licenses[0].Code,
LicenseType: "time",
Duration: 30,
ExpireTime: time.Now().Add(30 * 24 * time.Hour),
}
err = db.Create(device).Error
assert.NoError(t, err)
tests := []struct {
name string
uid string
wantErr bool
}{
{
name: "验证正常设备",
uid: "test-device-001",
wantErr: false,
},
{
name: "验证不存在的设备",
uid: "non-existent-device",
wantErr: true,
},
{
name: "验证过期设备",
uid: "expired-device",
wantErr: true,
},
}
// 创建过期设备
expiredDevice := &model.Device{
UID: "expired-device",
DeviceType: deviceModel.DeviceType,
DeviceModel: deviceModel.ModelName,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "expired",
LicenseCode: "expired-license",
LicenseType: "time",
Duration: 30,
ExpireTime: time.Now().Add(-24 * time.Hour),
}
err = db.Create(expiredDevice).Error
assert.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := deviceService.ValidateDevice(tt.uid)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
})
}
}
func TestDeviceService_UpdateStartCount(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
// 创建测试设备型号
deviceModel := &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
CreatedBy: 1,
}
err := db.Create(deviceModel).Error
assert.NoError(t, err)
// 创建测试设备
device := &model.Device{
UID: "test-device-001",
DeviceType: deviceModel.DeviceType,
DeviceModel: deviceModel.ModelName,
Company: deviceModel.Company,
RegisterTime: time.Now(),
Status: "active",
LicenseType: "count",
MaxUses: 5,
StartCount: 0,
}
err = db.Create(device).Error
assert.NoError(t, err)
// 测试更新启动次数
for i := 1; i <= 5; i++ {
err = deviceService.UpdateStartCount(device.UID)
assert.NoError(t, err)
// 验证启动次数
var updatedDevice model.Device
err = db.First(&updatedDevice, "uid = ?", device.UID).Error
assert.NoError(t, err)
assert.Equal(t, i, updatedDevice.StartCount)
// 检查最后一次是否将状态更新为过期
if i == 5 {
assert.Equal(t, "expired", updatedDevice.Status)
} else {
assert.Equal(t, "active", updatedDevice.Status)
}
}
// 测试超出使用次数
err = deviceService.UpdateStartCount(device.UID)
assert.Error(t, err)
}
func TestDeviceService_CreateDeviceModel(t *testing.T) {
db := utils.TestDB(t)
licenseService := NewLicenseService(db)
deviceService := NewDeviceService(db, licenseService)
tests := []struct {
name string
model *model.DeviceModel
wantErr bool
}{
{
name: "创建有效设备型号",
model: &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
Status: "active",
CreatedBy: 1,
},
wantErr: false,
},
{
name: "重复的设备型号",
model: &model.DeviceModel{
ModelName: "test-model",
DeviceType: "software",
Company: "test-company",
Status: "active",
CreatedBy: 1,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := deviceService.CreateDeviceModel(tt.model)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 验证创建的设备型号
var savedModel model.DeviceModel
err = db.Where("model_name = ?", tt.model.ModelName).First(&savedModel).Error
assert.NoError(t, err)
assert.Equal(t, tt.model.DeviceType, savedModel.DeviceType)
assert.Equal(t, tt.model.Company, savedModel.Company)
assert.Equal(t, tt.model.Status, savedModel.Status)
})
}
}

709
internal/service/license.go Normal file
View File

@@ -0,0 +1,709 @@
package service
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
"licserver/internal/model"
"gorm.io/gorm"
)
type LicenseService struct {
db *gorm.DB
}
func NewLicenseService(db *gorm.DB) *LicenseService {
return &LicenseService{db: db}
}
// 生成授权码
func generateLicenseCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// 创建授权码
type LicenseCreateInput struct {
LicenseType string `json:"license_type" binding:"required"`
Duration int `json:"duration"` // 时间授权的有效期(分钟)
MaxUses int `json:"max_uses"` // 次数授权的使用次数
Count int `json:"count" binding:"required,min=1"` // 生成数量
Remark string `json:"remark"` // 备注
BindCount int `json:"bind_count"` // 可绑定次数,默认为-1无限制
}
func (s *LicenseService) CreateLicenses(input *LicenseCreateInput, createdBy uint) ([]model.LicenseCode, error) {
// 验证参数
input.LicenseType = strings.ToLower(input.LicenseType) // 转为小写
switch input.LicenseType {
case "time":
if input.Duration <= 0 {
return nil, errors.New("时间授权必须指定有效期")
}
case "count":
if input.MaxUses <= 0 {
return nil, errors.New("次数授权必须指定使用次数")
}
case "permanent":
// 永久授权不需要额外参数
default:
return nil, errors.New("无效的授权类型")
}
// 如果未指定绑定次数,设置为默认值-1
if input.BindCount == 0 {
input.BindCount = -1
}
// 生成批次号
batchNo := time.Now().Format("20060102150405")
licenses := make([]model.LicenseCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
code, err := generateLicenseCode()
if err != nil {
return nil, err
}
license := model.LicenseCode{
Code: code,
LicenseType: input.LicenseType,
Duration: input.Duration,
MaxUses: input.MaxUses,
Status: "unused",
CreatedBy: createdBy,
BatchNo: batchNo,
Remark: input.Remark,
BindCount: input.BindCount,
}
licenses = append(licenses, license)
}
// 批量创建授权码
if err := s.db.Create(&licenses).Error; err != nil {
return nil, err
}
return licenses, nil
}
// 验证并使用授权码
func (s *LicenseService) UseLicense(code, deviceUID, ip string) (*model.LicenseCode, error) {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return nil, errors.New("授权码不存在")
}
// 检查授权码状态
if license.Status != "unused" {
return nil, errors.New("授权码已被使用")
}
// 检查绑定次数
if license.BindCount == 0 {
return nil, errors.New("授权码已达到最大绑定次数限制")
}
// 更新授权码状态
updates := map[string]interface{}{
"status": "used",
"used_by": deviceUID,
"used_at": time.Now(),
}
// 如果不是无限制,减少绑定次数
if license.BindCount > 0 {
updates["bind_count"] = license.BindCount - 1
}
if err := s.db.Model(&license).Updates(updates).Error; err != nil {
return nil, err
}
// 记录使用日志
log := model.LicenseLog{
LicenseID: license.ID,
DeviceUID: deviceUID,
Action: "use",
IP: ip,
Status: "success",
Message: fmt.Sprintf("设备 %s 使用授权码", deviceUID),
}
if err := s.db.Create(&log).Error; err != nil {
return nil, err
}
return &license, nil
}
// 获取授权码列表
func (s *LicenseService) GetLicenses(status, licenseType, batchNo string, page, pageSize int) ([]model.LicenseCode, int64, error) {
var licenses []model.LicenseCode
var total int64
query := s.db.Model(&model.LicenseCode{})
if status != "" {
query = query.Where("status = ?", strings.ToLower(status))
}
if licenseType != "" {
query = query.Where("license_type = ?", strings.ToLower(licenseType))
}
if batchNo != "" {
query = query.Where("batch_no = ?", batchNo)
}
// 获取所有符合条件的授权码
var allLicenses []model.LicenseCode
if err := query.Find(&allLicenses).Error; err != nil {
return nil, 0, err
}
// 检查每个授权码的有效性
for i := range allLicenses {
if allLicenses[i].Status == "used" {
if err := s.CheckLicenseValidity(allLicenses[i].Code); err != nil {
// 如果检查失败,更新状态
s.db.Model(&allLicenses[i]).Update("status", "expired")
allLicenses[i].Status = "expired"
}
}
}
total = int64(len(allLicenses))
// 分页
if page > 0 && pageSize > 0 {
start := (page - 1) * pageSize
end := start + pageSize
if start < len(allLicenses) {
if end > len(allLicenses) {
end = len(allLicenses)
}
licenses = allLicenses[start:end]
}
} else {
licenses = allLicenses
}
return licenses, total, nil
}
// 获取授权码使用日志
func (s *LicenseService) GetLicenseLogs(licenseID uint, page, pageSize int) ([]model.LicenseLog, int64, error) {
var logs []model.LicenseLog
var total int64
query := s.db.Model(&model.LicenseLog{}).Where("license_id = ?", licenseID)
query.Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&logs).Error
return logs, total, err
}
// ExportLogs 导出授权码日志
func (s *LicenseService) ExportLogs(licenseID uint) ([]byte, error) {
logs, _, err := s.GetLicenseLogs(licenseID, 0, 0) // 获取所有日志
if err != nil {
return nil, err
}
// 创建CSV内容
var content strings.Builder
content.WriteString("操作类,设备UID,IP地址,状态,详细信息,时间\n")
for _, log := range logs {
// 转换操作类型
action := map[string]string{
"create": "创建",
"use": "使用",
"verify": "验证",
}[log.Action]
// 转换状态
status := map[string]string{
"success": "成功",
"failed": "失败",
}[log.Status]
// 写入一行记录
content.WriteString(fmt.Sprintf("%s,%s,%s,%s,%s,%s\n",
action,
log.DeviceUID,
log.IP,
status,
log.Message,
log.CreatedAt.Format("2006-01-02 15:04:05"),
))
}
return []byte(content.String()), nil
}
// 撤销授权码
func (s *LicenseService) RevokeLicense(code string, userID uint) error {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return errors.New("授权码不存在")
}
// 检查权限
if license.CreatedBy != userID {
return errors.New("无权操作此授权码")
}
// 更新状态
if err := s.db.Model(&license).Update("status", "revoked").Error; err != nil {
return err
}
// 记录日志
log := model.LicenseLog{
LicenseID: license.ID,
Action: "revoke",
Status: "success",
Message: "授权码已撤销",
}
s.db.Create(&log)
return nil
}
// 批量撤销授权码
func (s *LicenseService) RevokeLicenses(codes []string, userID uint) error {
return s.db.Transaction(func(tx *gorm.DB) error {
for _, code := range codes {
var license model.LicenseCode
if err := tx.Where("code = ?", code).First(&license).Error; err != nil {
continue
}
// 检查权限
if license.CreatedBy != userID {
continue
}
// 更新状态
if err := tx.Model(&license).Update("status", "revoked").Error; err != nil {
return err
}
// 记录日志
log := model.LicenseLog{
LicenseID: license.ID,
Action: "revoke",
Status: "success",
Message: "授权码已撤销",
}
tx.Create(&log)
}
return nil
})
}
// 验证授权码
func (s *LicenseService) ValidateLicense(code string) (*model.LicenseCode, error) {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return nil, errors.New("无效的授权码")
}
// 检查状态
if license.Status != "unused" {
return nil, errors.New("授权码已被使用或已撤销")
}
return &license, nil
}
// 导出授权码
func (s *LicenseService) ExportLicenses(codes []string) ([]byte, error) {
var licenses []model.LicenseCode
if err := s.db.Where("code IN ?", codes).Find(&licenses).Error; err != nil {
return nil, err
}
// 创建CSV内容
var content strings.Builder
content.WriteString("授权码,授权类型,有效期(天),使用次数,状态,使用设备,使用时间,批次号,备注\n")
for _, license := range licenses {
// 转换授权类型
licenseType := map[string]string{
"time": "时间授权",
"count": "次数授权",
"permanent": "永久授权",
}[license.LicenseType]
// 转换状态
status := map[string]string{
"unused": "未使用",
"used": "已使用",
"revoked": "已撤销",
}[license.Status]
// 写入一行记录
content.WriteString(fmt.Sprintf("%s,%s,%d,%d,%s,%s,%s,%s,%s\n",
license.Code,
licenseType,
license.Duration,
license.MaxUses,
status,
license.UsedBy,
license.UsedAt.Format("2006-01-02 15:04:05"),
license.BatchNo,
license.Remark,
))
}
return []byte(content.String()), nil
}
// 获取授权码统计信息
func (s *LicenseService) GetLicenseStats() (map[string]interface{}, error) {
var stats struct {
Total int64
Unused int64
Used int64
Revoked int64
Today int64
ThisWeek int64
ThisMonth int64
}
// 获取总数
s.db.Model(&model.LicenseCode{}).Count(&stats.Total)
// 获取各状态数量
s.db.Model(&model.LicenseCode{}).Where("status = ?", "unused").Count(&stats.Unused)
s.db.Model(&model.LicenseCode{}).Where("status = ?", "used").Count(&stats.Used)
s.db.Model(&model.LicenseCode{}).Where("status = ?", "revoked").Count(&stats.Revoked)
// 获取今日创建数量
today := time.Now().Format("2006-01-02")
s.db.Model(&model.LicenseCode{}).Where("DATE(created_at) = ?", today).Count(&stats.Today)
// 获取本周创建数量
weekStart := time.Now().AddDate(0, 0, -int(time.Now().Weekday()))
s.db.Model(&model.LicenseCode{}).Where("created_at >= ?", weekStart).Count(&stats.ThisWeek)
// 获取本月创建数量
monthStart := time.Now().Format("2006-01") + "-01"
s.db.Model(&model.LicenseCode{}).Where("created_at >= ?", monthStart).Count(&stats.ThisMonth)
return map[string]interface{}{
"total": stats.Total,
"unused": stats.Unused,
"used": stats.Used,
"revoked": stats.Revoked,
"today": stats.Today,
"this_week": stats.ThisWeek,
"this_month": stats.ThisMonth,
}, nil
}
// 添加检查授权码有效性的方法
func (s *LicenseService) CheckLicenseValidity(code string) error {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
return errors.New("授权码不存在")
}
if license.Status != "unused" && license.Status != "used" {
return errors.New("授权码已被撤销或过期")
}
// 检查授权类型特定的限制
switch license.LicenseType {
case "time":
// 计算过期时间
expireTime := license.UsedAt.Add(time.Duration(license.Duration) * time.Minute)
if time.Now().After(expireTime) {
// 更新状态为过期
s.db.Model(&license).Update("status", "expired")
return errors.New("授权码已过期")
}
case "count":
if license.UsedCount >= license.MaxUses {
// 更新状态为过期
s.db.Model(&license).Update("status", "expired")
return errors.New("授权码使用次数已达上限")
}
}
return nil
}
// GetLicenseByCode 通过授权码获取授权信息
func (s *LicenseService) GetLicenseByCode(code string) (*model.LicenseCode, error) {
var license model.LicenseCode
if err := s.db.Where("code = ?", code).First(&license).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("授权码不存在")
}
return nil, err
}
return &license, nil
}

View File

@@ -0,0 +1,153 @@
package service
import (
"licserver/internal/utils"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLicenseService_CreateLicenses(t *testing.T) {
db := utils.TestDB(t)
service := NewLicenseService(db)
tests := []struct {
name string
input *LicenseCreateInput
wantErr bool
}{
{
name: "创建时间授权码",
input: &LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 5,
Remark: "test time license",
},
wantErr: false,
},
{
name: "创建次数授权码",
input: &LicenseCreateInput{
LicenseType: "count",
MaxUses: 100,
Count: 3,
Remark: "test count license",
},
wantErr: false,
},
{
name: "创建永久授权码",
input: &LicenseCreateInput{
LicenseType: "permanent",
Count: 1,
Remark: "test permanent license",
},
wantErr: false,
},
{
name: "无效的授权类型",
input: &LicenseCreateInput{
LicenseType: "invalid",
Count: 1,
},
wantErr: true,
},
{
name: "时间授权无有效期",
input: &LicenseCreateInput{
LicenseType: "time",
Duration: 0,
Count: 1,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
licenses, err := service.CreateLicenses(tt.input, 1)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Len(t, licenses, tt.input.Count)
for _, license := range licenses {
assert.Equal(t, tt.input.LicenseType, license.LicenseType)
assert.Equal(t, "unused", license.Status)
assert.Equal(t, tt.input.Remark, license.Remark)
if tt.input.LicenseType == "time" {
assert.Equal(t, tt.input.Duration, license.Duration)
} else if tt.input.LicenseType == "count" {
assert.Equal(t, tt.input.MaxUses, license.MaxUses)
}
}
})
}
}
func TestLicenseService_UseLicense(t *testing.T) {
db := utils.TestDB(t)
service := NewLicenseService(db)
// 创建测试授权码
input := &LicenseCreateInput{
LicenseType: "time",
Duration: 30,
Count: 1,
Remark: "test",
}
licenses, err := service.CreateLicenses(input, 1)
assert.NoError(t, err)
assert.Len(t, licenses, 1)
tests := []struct {
name string
code string
deviceUID string
ip string
wantErr bool
}{
{
name: "正常使用授权码",
code: licenses[0].Code,
deviceUID: "test-device-001",
ip: "127.0.0.1",
wantErr: false,
},
{
name: "使用不存在的授权码",
code: "invalid-code",
deviceUID: "test-device-002",
ip: "127.0.0.1",
wantErr: true,
},
{
name: "重复使用授权码",
code: licenses[0].Code,
deviceUID: "test-device-003",
ip: "127.0.0.1",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
license, err := service.UseLicense(tt.code, tt.deviceUID, tt.ip)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, "used", license.Status)
assert.Equal(t, tt.deviceUID, license.UsedBy)
assert.NotZero(t, license.UsedAt)
})
}
}

377
internal/service/monitor.go Normal file
View File

@@ -0,0 +1,377 @@
package service
import (
"licserver/internal/model"
"runtime"
"time"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/host"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/net"
"github.com/shirou/gopsutil/v3/process"
"gorm.io/gorm"
)
type MonitorService struct {
db *gorm.DB
startTime time.Time
}
func NewMonitorService(db *gorm.DB) *MonitorService {
return &MonitorService{
db: db,
startTime: time.Now(),
}
}
func (s *MonitorService) GetSystemStatus() (*model.SystemStatus, error) {
status := &model.SystemStatus{}
// CPU信息
if err := s.getCPUInfo(status); err != nil {
return nil, err
}
// 内存信息
if err := s.getMemoryInfo(status); err != nil {
return nil, err
}
// 磁盘信息
if err := s.getDiskInfo(status); err != nil {
return nil, err
}
// 网络信息
if err := s.getNetworkInfo(status); err != nil {
return nil, err
}
// 进程信息
if err := s.getProcessInfo(status); err != nil {
return nil, err
}
// 主机信息
if err := s.getHostInfo(status); err != nil {
return nil, err
}
// 系统信息
s.getSystemInfo(status)
return status, nil
}
func (s *MonitorService) getCPUInfo(status *model.SystemStatus) error {
cpuPercent, err := cpu.Percent(time.Second, false)
if err != nil {
return err
}
// Windows 系统不支持 LoadAvg设置为默认值
status.CPU.LoadAvg = []float64{0, 0, 0}
cpuInfo, err := cpu.Info()
if err != nil {
return err
}
status.CPU.Usage = cpuPercent[0]
status.CPU.CoreCount = runtime.NumCPU()
if len(cpuInfo) > 0 {
status.CPU.ModelName = cpuInfo[0].ModelName
status.CPU.MHz = cpuInfo[0].Mhz
}
return nil
}
func (s *MonitorService) getMemoryInfo(status *model.SystemStatus) error {
memInfo, err := mem.VirtualMemory()
if err != nil {
return err
}
swapInfo, err := mem.SwapMemory()
if err != nil {
return err
}
status.Memory.Total = memInfo.Total
status.Memory.Used = memInfo.Used
status.Memory.Free = memInfo.Free
status.Memory.UsageRate = memInfo.UsedPercent
status.Memory.SwapTotal = swapInfo.Total
status.Memory.SwapUsed = swapInfo.Used
status.Memory.SwapFree = swapInfo.Free
status.Memory.SwapUsageRate = swapInfo.UsedPercent
return nil
}
func (s *MonitorService) getDiskInfo(status *model.SystemStatus) error {
partitions, err := disk.Partitions(true)
if err != nil {
return err
}
status.Disk.Partitions = make([]model.DiskPartition, 0)
for _, partition := range partitions {
usage, err := disk.Usage(partition.Mountpoint)
if err != nil {
continue
}
status.Disk.Partitions = append(status.Disk.Partitions, model.DiskPartition{
Device: partition.Device,
Mountpoint: partition.Mountpoint,
Fstype: partition.Fstype,
Total: usage.Total,
Used: usage.Used,
Free: usage.Free,
UsageRate: usage.UsedPercent,
})
}
return nil
}
func (s *MonitorService) getNetworkInfo(status *model.SystemStatus) error {
interfaces, err := net.Interfaces()
if err != nil {
return err
}
ioCounters, err := net.IOCounters(true)
if err != nil {
return err
}
status.Network.Interfaces = make([]model.NetworkInterface, 0)
for _, iface := range interfaces {
var counter net.IOCountersStat
for _, io := range ioCounters {
if io.Name == iface.Name {
counter = io
break
}
}
// 获取接口的地址列表
addrs := make([]string, 0)
for _, addr := range iface.Addrs {
addrs = append(addrs, addr.String())
}
status.Network.Interfaces = append(status.Network.Interfaces, model.NetworkInterface{
Name: iface.Name,
BytesSent: counter.BytesSent,
BytesRecv: counter.BytesRecv,
PacketsSent: counter.PacketsSent,
PacketsRecv: counter.PacketsRecv,
Addrs: addrs,
})
}
return nil
}
func (s *MonitorService) getProcessInfo(status *model.SystemStatus) error {
processes, err := process.Processes()
if err != nil {
return err
}
status.Process.Total = len(processes)
status.Process.List = make([]model.ProcessInfo, 0)
for i := 0; i < 10 && i < len(processes); i++ {
p := processes[i]
name, _ := p.Name()
cpu, _ := p.CPUPercent()
mem, _ := p.MemoryPercent()
created, _ := p.CreateTime()
status.Process.List = append(status.Process.List, model.ProcessInfo{
PID: int(p.Pid),
Name: name,
CPU: cpu,
Memory: float64(mem),
Created: created,
})
}
return nil
}
func (s *MonitorService) getHostInfo(status *model.SystemStatus) error {
info, err := host.Info()
if err != nil {
return err
}
status.Host.Hostname = info.Hostname
status.Host.OS = info.OS
status.Host.Platform = info.Platform
status.Host.PlatformVersion = info.PlatformVersion
status.Host.KernelVersion = info.KernelVersion
status.Host.BootTime = time.Unix(int64(info.BootTime), 0)
return nil
}
func (s *MonitorService) getSystemInfo(status *model.SystemStatus) {
status.System.Uptime = time.Since(s.startTime)
status.System.CurrentTime = time.Now()
var activeUsers int64
s.db.Model(&model.User{}).Count(&activeUsers)
status.System.ActiveUsers = int(activeUsers)
var totalDevices int64
s.db.Model(&model.Device{}).Count(&totalDevices)
status.System.TotalDevices = int(totalDevices)
}

47
internal/service/site.go Normal file
View File

@@ -0,0 +1,47 @@
package service
import (
"errors"
"licserver/internal/utils"
"sync"
)
type SiteService struct {
config *utils.Config
mu sync.RWMutex
}
func NewSiteService(config *utils.Config) *SiteService {
return &SiteService{
config: config,
}
}
func (s *SiteService) GetSettings() utils.SiteConfig {
s.mu.RLock()
defer s.mu.RUnlock()
return s.config.Site
}
func (s *SiteService) UpdateSettings(settings utils.SiteConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
// 更新内存中的配置
s.config.Site = settings
// 持久化配置到文件
return utils.SaveConfig(s.config, "system", "更新站点设置")
}
// ValidateSettings 验证站点设置
func (s *SiteService) ValidateSettings(settings utils.SiteConfig) error {
// 这里可以添加更多的验证逻辑
if settings.Title == "" {
return errors.New("站点标题不能为空")
}
if settings.BaseURL == "" {
return errors.New("基础URL不能为空")
}
return nil
}

188
internal/service/token.go Normal file
View File

@@ -0,0 +1,188 @@
package service
import (
"crypto/rand"
"encoding/hex"
"errors"
"strings"
"time"
"licserver/internal/model"
"gorm.io/gorm"
)
type TokenService struct {
db *gorm.DB
}
func NewTokenService(db *gorm.DB) *TokenService {
return &TokenService{db: db}
}
// 生成访问令牌
func generateToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// 创建访问令牌
func (s *TokenService) CreateToken(deviceUID, tokenType string, expireTime time.Time, ipList []string, createdBy uint) (*model.AccessToken, error) {
// 验证设备是否存在
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err != nil {
return nil, errors.New("设备不存在")
}
// 生成令牌
token, err := generateToken()
if err != nil {
return nil, err
}
// 创建令牌记录
accessToken := &model.AccessToken{
Token: token,
DeviceUID: deviceUID,
Type: tokenType,
Status: "active",
ExpireTime: expireTime,
IPList: strings.Join(ipList, ","),
CreatedBy: createdBy,
}
if err := s.db.Create(accessToken).Error; err != nil {
return nil, err
}
// 记录日志
tokenLog := &model.TokenLog{
TokenID: accessToken.ID,
Action: "create",
Status: "success",
Message: "创建访问令牌",
}
s.db.Create(tokenLog)
return accessToken, nil
}
// 验证令牌
func (s *TokenService) ValidateToken(token, ip string) (*model.AccessToken, error) {
var accessToken model.AccessToken
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
return nil, errors.New("无效的令牌")
}
// 检查令牌状态
if accessToken.Status != "active" {
return nil, errors.New("令牌已被撤销")
}
// 检查过期时间
if time.Now().After(accessToken.ExpireTime) {
return nil, errors.New("令牌已过期")
}
// 检查IP限制
if accessToken.IPList != "" {
allowedIPs := strings.Split(accessToken.IPList, ",")
allowed := false
for _, allowedIP := range allowedIPs {
if allowedIP == ip {
allowed = true
break
}
}
if !allowed {
return nil, errors.New("IP地址不允许访问")
}
}
// 更新使用记录
s.db.Model(&accessToken).Updates(map[string]interface{}{
"last_used": time.Now(),
"usage_count": gorm.Expr("usage_count + 1"),
})
// 记录日志
tokenLog := &model.TokenLog{
TokenID: accessToken.ID,
Action: "use",
IP: ip,
Status: "success",
}
s.db.Create(tokenLog)
return &accessToken, nil
}
// 撤销令牌
func (s *TokenService) RevokeToken(token string, userID uint) error {
var accessToken model.AccessToken
if err := s.db.Where("token = ?", token).First(&accessToken).Error; err != nil {
return errors.New("令牌不存在")
}
// 检查权限
if accessToken.CreatedBy != userID {
return errors.New("无权操作此令牌")
}
// 更新状态
if err := s.db.Model(&accessToken).Update("status", "revoked").Error; err != nil {
return err
}
// 记录日志
tokenLog := &model.TokenLog{
TokenID: accessToken.ID,
Action: "revoke",
Status: "success",
Message: "撤销访问令牌",
}
s.db.Create(tokenLog)
return nil
}
// 获取令牌列表
func (s *TokenService) GetTokens(deviceUID string, page, pageSize int) ([]model.AccessToken, int64, error) {
var tokens []model.AccessToken
var total int64
query := s.db.Model(&model.AccessToken{})
if deviceUID != "" {
query = query.Where("device_uid = ?", deviceUID)
}
query.Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&tokens).Error
return tokens, total, err
}
// 获取令牌日志
func (s *TokenService) GetTokenLogs(tokenID uint, page, pageSize int) ([]model.TokenLog, int64, error) {
var logs []model.TokenLog
var total int64
query := s.db.Model(&model.TokenLog{}).Where("token_id = ?", tokenID)
query.Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
err := query.Order("created_at DESC").Find(&logs).Error
return logs, total, err
}

463
internal/service/upload.go Normal file
View File

@@ -0,0 +1,463 @@
package service
import (
"errors"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"sort"
"strings"
"time"
"licserver/internal/model"
"licserver/internal/utils"
"github.com/google/uuid"
"gorm.io/gorm"
"crypto/sha256"
"encoding/hex"
)
type UploadService struct {
db *gorm.DB
config *utils.Config
}
func NewUploadService(db *gorm.DB, config *utils.Config) *UploadService {
return &UploadService{
db: db,
config: config,
}
}
func (s *UploadService) UploadFile(file *multipart.FileHeader, userID uint, deviceUID, description string) (*model.FileUpload, error) {
// 生成唯一文件名
ext := filepath.Ext(file.Filename)
uniqueID := uuid.New().String()
fileName := fmt.Sprintf("%s%s", uniqueID, ext)
filePath := filepath.Join(s.config.Upload.Path, fileName)
// 确保上传目录存在
if err := os.MkdirAll(s.config.Upload.Path, 0755); err != nil {
return nil, err
}
// 保存文件
src, err := file.Open()
if err != nil {
return nil, err
}
defer src.Close()
dst, err := os.Create(filePath)
if err != nil {
return nil, err
}
defer dst.Close()
if _, err = io.Copy(dst, src); err != nil {
return nil, err
}
// 如果提供了设备UID获取设备型号
var deviceModel string
if deviceUID != "" {
var device model.Device
if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err == nil {
deviceModel = device.DeviceModel
}
}
// 创建数据库记录
upload := &model.FileUpload{
FileName: file.Filename,
FilePath: filePath,
FileSize: file.Size,
FileType: strings.ToLower(ext),
UploadedBy: userID,
DeviceUID: deviceUID,
DeviceModel: deviceModel,
Description: description,
}
if err := s.db.Create(upload).Error; err != nil {
os.Remove(filePath)
return nil, err
}
return upload, nil
}
func (s *UploadService) DownloadFile(id uint) (*model.FileUpload, error) {
var file model.FileUpload
if err := s.db.First(&file, id).Error; err != nil {
return nil, err
}
return &file, nil
}
func (s *UploadService) DeleteFile(id uint, userID uint) error {
var file model.FileUpload
if err := s.db.First(&file, id).Error; err != nil {
return err
}
if file.UploadedBy != userID {
return errors.New("无权删除此文件")
}
if err := os.Remove(file.FilePath); err != nil && !os.IsNotExist(err) {
return err
}
return s.db.Delete(&file).Error
}
func (s *UploadService) GetDeviceFiles(deviceUID string) ([]model.FileUpload, error) {
var files []model.FileUpload
err := s.db.Where("device_uid = ?", deviceUID).Find(&files).Error
return files, err
}
func (s *UploadService) UploadChunk(
file *multipart.FileHeader,
fileHash string,
chunkNumber int,
totalChunks int,
totalSize int64,
filename string,
userID uint,
deviceUID string,
) error {
// 创建分片存储目录
chunkDir := filepath.Join(s.config.Upload.Path, "chunks", fileHash)
if err := os.MkdirAll(chunkDir, 0755); err != nil {
return err
}
// 保存分片文件
chunkPath := filepath.Join(chunkDir, fmt.Sprintf("%d", chunkNumber))
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
dst, err := os.Create(chunkPath)
if err != nil {
return err
}
defer dst.Close()
if _, err = io.Copy(dst, src); err != nil {
return err
}
// 记录分片信息
chunk := model.UploadChunk{
FileHash: fileHash,
ChunkNumber: chunkNumber,
ChunkSize: file.Size,
ChunkPath: chunkPath,
TotalChunks: totalChunks,
TotalSize: totalSize,
Filename: filename,
FileType: strings.ToLower(filepath.Ext(filename)),
UploadedBy: userID,
DeviceUID: deviceUID,
}
return s.db.Create(&chunk).Error
}
func (s *UploadService) CheckUploadStatus(fileHash string) (bool, error) {
var chunks []model.UploadChunk
if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
return false, err
}
if len(chunks) == 0 {
return false, nil
}
totalChunks := chunks[0].TotalChunks
return len(chunks) == totalChunks, nil
}
func (s *UploadService) MergeChunks(fileHash string) (*model.FileUpload, error) {
var chunks []model.UploadChunk
if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
return nil, err
}
if len(chunks) == 0 {
return nil, errors.New("未找到文件分片")
}
if len(chunks) != chunks[0].TotalChunks {
return nil, errors.New("文件分片不完整")
}
// 按分片序号排序
sort.Slice(chunks, func(i, j int) bool {
return chunks[i].ChunkNumber < chunks[j].ChunkNumber
})
// 创建最终文件
finalPath := filepath.Join(s.config.Upload.Path, fmt.Sprintf("%s%s", uuid.New().String(), chunks[0].FileType))
finalFile, err := os.Create(finalPath)
if err != nil {
return nil, err
}
defer finalFile.Close()
// 合并分片
hash := sha256.New()
for _, chunk := range chunks {
chunkFile, err := os.Open(chunk.ChunkPath)
if err != nil {
return nil, err
}
if _, err = io.Copy(finalFile, chunkFile); err != nil {
chunkFile.Close()
return nil, err
}
if _, err = io.Copy(hash, chunkFile); err != nil {
chunkFile.Close()
return nil, err
}
chunkFile.Close()
os.Remove(chunk.ChunkPath) // 删除已合并的分片
}
// 验证文件哈希
if hex.EncodeToString(hash.Sum(nil)) != fileHash {
os.Remove(finalPath)
return nil, errors.New("文件哈希验证失败")
}
// 创建文件记录
upload := &model.FileUpload{
FileName: chunks[0].Filename,
FilePath: finalPath,
FileSize: chunks[0].TotalSize,
FileType: chunks[0].FileType,
UploadedBy: chunks[0].UploadedBy,
DeviceUID: chunks[0].DeviceUID,
}
if err := s.db.Create(upload).Error; err != nil {
os.Remove(finalPath)
return nil, err
}
// 清理分片记录
s.db.Where("file_hash = ?", fileHash).Delete(&model.UploadChunk{})
os.RemoveAll(filepath.Dir(chunks[0].ChunkPath))
return upload, nil
}
func (s *UploadService) CleanupExpiredChunks() error {
expireTime := time.Now().Add(-24 * time.Hour)
var expiredChunks []model.UploadChunk
if err := s.db.Where("completed = ? AND created_at < ?", false, expireTime).Find(&expiredChunks).Error; err != nil {
return err
}
for _, chunk := range expiredChunks {
os.Remove(chunk.ChunkPath)
if len(chunk.ChunkPath) > 0 {
chunkDir := filepath.Dir(chunk.ChunkPath)
os.RemoveAll(chunkDir)
}
}
return s.db.Unscoped().Where("completed = ? AND created_at < ?", false, expireTime).Delete(&model.UploadChunk{}).Error
}

361
internal/service/user.go Normal file
View File

@@ -0,0 +1,361 @@
package service
import (
"errors"
"fmt"
"time"
"licserver/internal/model"
"licserver/internal/utils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type UserService struct {
db *gorm.DB
config *utils.Config
captchaService *CaptchaService
}
type UserProfile struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Role string `json:"role"`
}
func NewUserService(db *gorm.DB, config *utils.Config) *UserService {
return &UserService{
db: db,
config: config,
captchaService: NewCaptchaService(db, &config.Email),
}
}
func (s *UserService) Register(username, password, email, captcha string) error {
// 验证验证码
if err := s.captchaService.VerifyCaptcha(email, "register", captcha); err != nil {
return err
}
// 检查用户名是否已存在
var count int64
s.db.Model(&model.User{}).Where("username = ?", username).Count(&count)
if count > 0 {
return errors.New("用户名已存在")
}
// 检查邮箱是否已存在
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
if count > 0 {
return errors.New("邮箱已被注册")
}
// 原有的注册逻辑
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user := model.User{
Username: username,
Password: string(hashedPassword),
Email: email,
Role: "user",
}
return s.db.Create(&user).Error
}
func (s *UserService) Login(username, password string) (string, error) {
var user model.User
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", errors.New("用户不存在")
}
return "", err
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
return "", errors.New("密码错误")
}
// 生成 JWT token
token, err := utils.GenerateToken(user.ID, user.Username, user.Role, &s.config.JWT)
if err != nil {
return "", err
}
// 更新最后登录时间
s.db.Model(&user).Update("last_login", gorm.Expr("CURRENT_TIMESTAMP"))
return token, nil
}
func (s *UserService) GetUserByID(id uint) (*UserProfile, error) {
var user model.User
if err := s.db.First(&user, id).Error; err != nil {
return nil, err
}
return &UserProfile{
ID: user.ID,
Username: user.Username,
Email: user.Email,
Role: user.Role,
}, nil
}
func (s *UserService) UpdateProfile(userID uint, email string) error {
// 检查邮箱是否被其他用户使用
var count int64
if err := s.db.Model(&model.User{}).Where("email = ? AND id != ?", email, userID).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return errors.New("邮箱已被其他用户使用")
}
// 更新用户信息
return s.db.Model(&model.User{}).Where("id = ?", userID).Update("email", email).Error
}
func (s *UserService) ChangePassword(userID uint, oldPassword, newPassword string) error {
var user model.User
if err := s.db.First(&user, userID).Error; err != nil {
return err
}
// 验证旧密码
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(oldPassword)); err != nil {
return errors.New("旧密码错误")
}
// 加密新密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
return s.db.Model(&user).Update("password", string(hashedPassword)).Error
}
func (s *UserService) ResetPassword(email, captcha string) error {
// 验证验证码
if err := s.captchaService.VerifyCaptcha(email, "reset", captcha); err != nil {
return err
}
// 原有的重置密码逻辑
var user model.User
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
return errors.New("邮箱不存在")
}
// 生成重置令牌
token, err := utils.GenerateResetToken()
if err != nil {
return err
}
// 保存重置令牌
resetToken := model.PasswordResetToken{
UserID: user.ID,
Token: token,
ExpiresAt: time.Now().Add(24 * time.Hour),
Used: false,
}
if err := s.db.Create(&resetToken).Error; err != nil {
return err
}
// 发送重置邮件
emailService := utils.NewEmailService(&s.config.Email)
resetLink := fmt.Sprintf("http://localhost:%s/reset-password?token=%s", s.config.Server.Port, token)
emailBody := fmt.Sprintf(`
<h3>密码重置</h3>
<p>您好,%s</p>
<p>请点击以下链接重置您的密码:</p>
<p><a href="%s">重置密码</a></p>
<p>此链接将在24小时后失效。</p>
<p>如果您没有请求重置密码,请忽略此邮件。</p>
`, user.Username, resetLink)
return emailService.SendEmail(user.Email, "密码重置", emailBody)
}
func (s *UserService) ValidateResetToken(token string) (*model.User, error) {
var resetToken model.PasswordResetToken
if err := s.db.Where("token = ? AND used = ? AND expires_at > ?",
token, false, time.Now()).First(&resetToken).Error; err != nil {
return nil, errors.New("无效或已过期的重置令牌")
}
var user model.User
if err := s.db.First(&user, resetToken.UserID).Error; err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) ResetPasswordWithToken(token, newPassword string) error {
user, err := s.ValidateResetToken(token)
if err != nil {
return err
}
// 更新密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
// 使用事务确保原子性
return s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&user).Update("password", string(hashedPassword)).Error; err != nil {
return err
}
// 标记令牌为已使用
if err := tx.Model(&model.PasswordResetToken{}).
Where("token = ?", token).
Update("used", true).Error; err != nil {
return err
}
return nil
})
}
func (s *UserService) SendRegisterCaptcha(email string) error {
// 检查邮箱是否已被注册
var count int64
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
if count > 0 {
return errors.New("邮箱已被注册")
}
return s.captchaService.SendEmailCaptcha(email, "register")
}
func (s *UserService) SendResetPasswordCaptcha(email string) error {
var user model.User
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
return errors.New("邮箱不存在")
}
return s.captchaService.SendEmailCaptcha(email, "reset")
}
func (s *UserService) GetCaptchaService() *CaptchaService {
return s.captchaService
}
// GetUsers 获取用户列表
func (s *UserService) GetUsers(username, role string, page, pageSize int) ([]UserProfile, int64, error) {
var users []model.User
var total int64
var profiles []UserProfile
query := s.db.Model(&model.User{})
if username != "" {
query = query.Where("username LIKE ?", "%"+username+"%")
}
if role != "" {
query = query.Where("role = ?", role)
}
// 获取总数
query.Count(&total)
// 分页查询
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
query = query.Offset(offset).Limit(pageSize)
}
if err := query.Find(&users).Error; err != nil {
return nil, 0, err
}
// 转换为 UserProfile
for _, user := range users {
profiles = append(profiles, UserProfile{
ID: user.ID,
Username: user.Username,
Email: user.Email,
Role: user.Role,
})
}
return profiles, total, nil
}
// CreateUser 创建新用户
func (s *UserService) CreateUser(username, password, email, role string) error {
// 检查用户名是否已存在
var count int64
s.db.Model(&model.User{}).Where("username = ?", username).Count(&count)
if count > 0 {
return errors.New("用户名已存在")
}
// 检查邮箱是否已存在
s.db.Model(&model.User{}).Where("email = ?", email).Count(&count)
if count > 0 {
return errors.New("邮箱已被注册")
}
// 加密密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user := model.User{
Username: username,
Password: string(hashedPassword),
Email: email,
Role: role,
}
return s.db.Create(&user).Error
}
// UpdateUser 更新用户信息
func (s *UserService) UpdateUser(id uint, email, role string) error {
// 检查邮箱是否被其他用户使用
var count int64
s.db.Model(&model.User{}).Where("email = ? AND id != ?", email, id).Count(&count)
if count > 0 {
return errors.New("邮箱已被其他用户使用")
}
return s.db.Model(&model.User{}).Where("id = ?", id).Updates(map[string]interface{}{
"email": email,
"role": role,
}).Error
}
// DeleteUser 删除用户
func (s *UserService) DeleteUser(id uint) error {
// 检查是否为最后一个管理员
var adminCount int64
s.db.Model(&model.User{}).Where("role = ?", "admin").Count(&adminCount)
var user model.User
if err := s.db.First(&user, id).Error; err != nil {
return err
}
if user.Role == "admin" && adminCount <= 1 {
return errors.New("不能删除最后一个管理员")
}
return s.db.Delete(&model.User{}, id).Error
}

View File

@@ -0,0 +1,61 @@
package service
import (
"licserver/internal/model"
"licserver/internal/utils"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUserService_Register(t *testing.T) {
db := utils.TestDB(t)
config := utils.TestConfig()
userService := NewUserService(db, config)
// 测试正常注册
err := userService.Register("testuser", "password123", "test@example.com", "123456")
assert.NoError(t, err)
var user model.User
err = db.Where("username = ?", "testuser").First(&user).Error
assert.NoError(t, err)
assert.Equal(t, "testuser", user.Username)
assert.Equal(t, "test@example.com", user.Email)
assert.Equal(t, "user", user.Role)
// 测试重复用户名
err = userService.Register("testuser", "password123", "test2@example.com", "123456")
assert.Error(t, err)
assert.Contains(t, err.Error(), "用户名已存在")
// 测试重复邮箱
err = userService.Register("testuser2", "password123", "test@example.com", "123456")
assert.Error(t, err)
assert.Contains(t, err.Error(), "邮箱已被注册")
}
func TestUserService_Login(t *testing.T) {
db := utils.TestDB(t)
config := utils.TestConfig()
userService := NewUserService(db, config)
// 创建测试用户
err := userService.Register("testuser", "password123", "test@example.com", "123456")
assert.NoError(t, err)
// 测试正确登录
token, err := userService.Login("testuser", "password123")
assert.NoError(t, err)
assert.NotEmpty(t, token)
// 测试错误密码
_, err = userService.Login("testuser", "wrongpassword")
assert.Error(t, err)
assert.Contains(t, err.Error(), "密码错误")
// 测试不存在的用户
_, err = userService.Login("nonexistent", "password123")
assert.Error(t, err)
assert.Contains(t, err.Error(), "用户不存在")
}

31
internal/utils/captcha.go Normal file
View File

@@ -0,0 +1,31 @@
package utils
import (
"crypto/rand"
"fmt"
)
// GenerateCaptcha 生成6位数字验证码
func GenerateCaptcha() (string, error) {
// 生成6位随机数字
b := make([]byte, 3)
if _, err := rand.Read(b); err != nil {
return "", err
}
// 将随机字节转换为6位数字
num := int(b[0])<<16 | int(b[1])<<8 | int(b[2])
return fmt.Sprintf("%06d", num%1000000), nil
}
// GenerateEmailCaptchaContent 生成验证码邮件内容
func GenerateEmailCaptchaContent(code, username, action string) string {
return fmt.Sprintf(`
<h3>验证码</h3>
<p>您好,%s</p>
<p>您正在进行%s操作验证码为</p>
<h2 style="color: #1890ff;">%s</h2>
<p>验证码有效期为5分钟请勿泄露给他人。</p>
<p>如果这不是您的操作,请忽略此邮件。</p>
`, username, action, code)
}

77
internal/utils/config.go Normal file
View File

@@ -0,0 +1,77 @@
package utils
import (
"os"
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig
Database DatabaseConfig
JWT JWTConfig
Email EmailConfig
Upload UploadConfig
Site SiteConfig
}
type ServerConfig struct {
Port string
Mode string
}
type DatabaseConfig struct {
Type string
Path string
}
type JWTConfig struct {
Secret string
Expire string
}
type EmailConfig struct {
Host string
Port int
Username string
Password string
}
type UploadConfig struct {
Path string
}
type SiteConfig struct {
Title string `mapstructure:"title"`
Description string `mapstructure:"description"`
BaseURL string `mapstructure:"base_url"`
ICP string `mapstructure:"icp"`
Copyright string `mapstructure:"copyright"`
Logo string `mapstructure:"logo"`
Favicon string `mapstructure:"favicon"`
}
func LoadConfig() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./config")
// 读取环境变量
viper.AutomaticEnv()
if err := viper.ReadInConfig(); err != nil {
return nil, err
}
config := &Config{}
if err := viper.Unmarshal(config); err != nil {
return nil, err
}
// 环境变量优先
if port := os.Getenv("SERVER_PORT"); port != "" {
config.Server.Port = port
}
return config, nil
}

View File

@@ -0,0 +1,232 @@
package utils
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
)
type ConfigVersion struct {
Version int `json:"version"`
UpdatedAt time.Time `json:"updated_at"`
UpdatedBy string `json:"updated_by"`
Comment string `json:"comment"`
}
type ConfigWithVersion struct {
Config *Config `json:"config"`
Version ConfigVersion `json:"version"`
}
var (
configMutex sync.RWMutex
configFile = "config/config.json" // JSON格式更适合动态更新
)
// SaveConfig 保存配置到文件
func SaveConfig(config *Config, updatedBy, comment string) error {
configMutex.Lock()
defer configMutex.Unlock()
// 读取当前版本
currentVersion := 0
if existing, err := LoadPersistedConfig(); err == nil {
currentVersion = existing.Version.Version
}
// 创建新的配置版本
configWithVersion := ConfigWithVersion{
Config: config,
Version: ConfigVersion{
Version: currentVersion + 1,
UpdatedAt: time.Now(),
UpdatedBy: updatedBy,
Comment: comment,
},
}
// 确保配置目录存在
configDir := filepath.Dir(configFile)
if err := os.MkdirAll(configDir, 0755); err != nil {
return err
}
// 备份旧配置
if err := backupConfig(); err != nil {
return err
}
// 将配置转换为JSON
data, err := json.MarshalIndent(configWithVersion, "", " ")
if err != nil {
return err
}
// 写入文件
return os.WriteFile(configFile, data, 0644)
}
// LoadPersistedConfig 加载持久化的配置
func LoadPersistedConfig() (*ConfigWithVersion, error) {
configMutex.RLock()
defer configMutex.RUnlock()
// 检查配置文件是否存在
if _, err := os.Stat(configFile); os.IsNotExist(err) {
// 如果不存在,创建默认配置
config, err := LoadConfig()
if err != nil {
return nil, err
}
return &ConfigWithVersion{
Config: config,
Version: ConfigVersion{
Version: 1,
UpdatedAt: time.Now(),
UpdatedBy: "system",
Comment: "初始配置",
},
}, nil
}
// 读取配置文件
data, err := os.ReadFile(configFile)
if err != nil {
return nil, err
}
// 解析JSON配置
var configWithVersion ConfigWithVersion
if err := json.Unmarshal(data, &configWithVersion); err != nil {
return nil, err
}
return &configWithVersion, nil
}
// MergeConfig 合并配置(环境变量优先)
func MergeConfig(persisted, env *Config) *Config {
if env.Server.Port != "" {
persisted.Server.Port = env.Server.Port
}
if env.Server.Mode != "" {
persisted.Server.Mode = env.Server.Mode
}
// ... 其他配置项的合并 ...
return persisted
}
// backupConfig 备份配置文件
func backupConfig() error {
if _, err := os.Stat(configFile); os.IsNotExist(err) {
return nil
}
// 确保备份目录存在
backupDir := "config/backups"
if err := os.MkdirAll(backupDir, 0755); err != nil {
return err
}
// 读取当前配置
data, err := os.ReadFile(configFile)
if err != nil {
return err
}
// 创建备份文件名
backupFile := filepath.Join(backupDir,
fmt.Sprintf("config_%s.json", time.Now().Format("20060102150405")))
// 写入备份文件
return os.WriteFile(backupFile, data, 0644)
}

View File

@@ -0,0 +1,56 @@
package utils
import (
"licserver/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func InitDB(config *DatabaseConfig) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(config.Path), &gorm.Config{})
if err != nil {
return nil, err
}
// 自动迁移数据库结构
err = db.AutoMigrate(
&model.User{},
&model.Device{},
&model.DeviceModel{},
&model.PasswordResetToken{},
&model.Captcha{},
&model.FileUpload{},
&model.UploadChunk{},
&model.LicenseCode{},
&model.LicenseLog{},
&model.AccessToken{},
&model.TokenLog{},
)
if err != nil {
return nil, err
}
return db, nil
}

40
internal/utils/email.go Normal file
View File

@@ -0,0 +1,40 @@
package utils
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/smtp"
)
type EmailService struct {
config *EmailConfig
auth smtp.Auth
}
func NewEmailService(config *EmailConfig) *EmailService {
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
return &EmailService{
config: config,
auth: auth,
}
}
func (s *EmailService) SendEmail(to, subject, body string) error {
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
msg := []byte(fmt.Sprintf("To: %s\r\n"+
"Subject: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n"+
"\r\n"+
"%s\r\n", to, subject, body))
return smtp.SendMail(addr, s.auth, s.config.Username, []string{to}, msg)
}
func GenerateResetToken() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

29
internal/utils/errors.go Normal file
View File

@@ -0,0 +1,29 @@
package utils
import "errors"
var (
// 通用错误
ErrInvalidInput = errors.New("无效的输入")
ErrNotFound = errors.New("资源不存在")
ErrUnauthorized = errors.New("未授权的访问")
ErrForbidden = errors.New("禁止访问")
// 授权相关错误
ErrInvalidToken = errors.New("无效的令牌")
ErrTokenExpired = errors.New("令牌已过期")
ErrInvalidCaptcha = errors.New("无效的验证码")
ErrCaptchaExpired = errors.New("验证码已过期")
ErrInvalidLicense = errors.New("无效的授权码")
ErrLicenseExpired = errors.New("授权码已过期")
ErrLicenseUsed = errors.New("授权码已被使用")
ErrDeviceNotFound = errors.New("设备不存在")
ErrDeviceRegistered = errors.New("设备已注册")
)
// ErrorResponse 统一错误响应结构
type ErrorResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Detail string `json:"detail,omitempty"`
}

51
internal/utils/jwt.go Normal file
View File

@@ -0,0 +1,51 @@
package utils
import (
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
)
type Claims struct {
UserID uint
Username string
Role string
jwt.RegisteredClaims
}
func GenerateToken(userID uint, username, role string, config *JWTConfig) (string, error) {
expDuration, err := time.ParseDuration(config.Expire)
if err != nil {
return "", err
}
claims := Claims{
UserID: userID,
Username: username,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expDuration)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(config.Secret))
}
func ParseToken(tokenString string, config *JWTConfig) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(config.Secret), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
}
return nil, errors.New("invalid token")
}

50
internal/utils/logger.go Normal file
View File

@@ -0,0 +1,50 @@
package utils
import (
"fmt"
"os"
"path/filepath"
"time"
)
type Logger struct {
logFile *os.File
}
func NewLogger(logPath string) (*Logger, error) {
// 确保日志目录存在
if err := os.MkdirAll(filepath.Dir(logPath), 0755); err != nil {
return nil, err
}
// 打开日志文件
file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, err
}
return &Logger{logFile: file}, nil
}
func (l *Logger) Info(format string, args ...interface{}) {
l.log("INFO", format, args...)
}
func (l *Logger) Error(format string, args ...interface{}) {
l.log("ERROR", format, args...)
}
func (l *Logger) Debug(format string, args ...interface{}) {
l.log("DEBUG", format, args...)
}
func (l *Logger) log(level, format string, args ...interface{}) {
timestamp := time.Now().Format("2006-01-02 15:04:05")
message := fmt.Sprintf(format, args...)
logLine := fmt.Sprintf("[%s] [%s] %s\n", timestamp, level, message)
l.logFile.WriteString(logLine)
}
func (l *Logger) Close() error {
return l.logFile.Close()
}

View File

@@ -0,0 +1,68 @@
package utils
import (
"licserver/internal/model"
"testing"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// TestDB 创建测试数据库连接
func TestDB(t *testing.T) *gorm.DB {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 迁移测试表
err = db.AutoMigrate(
&model.User{},
&model.Device{},
&model.DeviceModel{},
&model.LicenseCode{},
&model.LicenseLog{},
&model.AccessToken{},
&model.TokenLog{},
&model.Captcha{},
&model.PasswordResetToken{},
&model.FileUpload{},
&model.UploadChunk{},
)
assert.NoError(t, err)
return db
}
// TestConfig 创建测试配置
func TestConfig() *Config {
return &Config{
Server: ServerConfig{
Port: "8080",
Mode: "test",
},
Database: DatabaseConfig{
Type: "sqlite3",
Path: ":memory:",
},
JWT: JWTConfig{
Secret: "test-secret",
Expire: "24h",
},
Email: EmailConfig{
Host: "smtp.example.com",
Port: 587,
Username: "test@example.com",
Password: "test-password",
},
Upload: UploadConfig{
Path: "./test-uploads",
},
Site: SiteConfig{
Title: "Test Site",
Description: "Test Description",
BaseURL: "http://localhost:8080",
ICP: "Test ICP",
Copyright: "Test Copyright",
},
}
}