LicenseManger/internal/service/device_test.go

299 lines
6.9 KiB
Go

package service
import (
"licserver/internal/model"
"licserver/internal/utils"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestDeviceService_RegisterDevice exercises DeviceService.RegisterDevice
// against a test database seeded with one device model and one license.
//
// The table cases run sequentially and share database state: the second case
// re-submits the UID and license code that the first case registered.
//
// Fixture failures abort the test immediately. Continuing after a failed
// license creation would panic on licenses[0] while building the table.
func TestDeviceService_RegisterDevice(t *testing.T) {
	db := utils.TestDB(t)
	licenseService := NewLicenseService(db)
	deviceService := NewDeviceService(db, licenseService)

	// Create the device model referenced by the registration inputs.
	deviceModel := &model.DeviceModel{
		ModelName:  "test-model",
		DeviceType: "software",
		Company:    "test-company",
		Status:     "active",
		CreatedBy:  1,
	}
	if err := db.Create(deviceModel).Error; err != nil {
		t.Fatalf("creating device model fixture: %v", err)
	}

	// Create the license code used by the registration inputs.
	licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
		LicenseType: "time",
		Duration:    30,
		Count:       1,
		Remark:      "test",
	}, 1)
	if err != nil {
		t.Fatalf("creating license fixture: %v", err)
	}
	if len(licenses) != 1 {
		t.Fatalf("creating license fixture: got %d licenses, want 1", len(licenses))
	}

	tests := []struct {
		name    string
		input   *DeviceRegisterInput
		wantErr bool
	}{
		{
			name: "正常注册设备",
			input: &DeviceRegisterInput{
				UID:         "test-device-001",
				DeviceModel: "test-model",
				LicenseCode: licenses[0].Code,
			},
			wantErr: false,
		},
		{
			// Same UID and license code as the previous case.
			name: "重复注册设备",
			input: &DeviceRegisterInput{
				UID:         "test-device-001",
				DeviceModel: "test-model",
				LicenseCode: licenses[0].Code,
			},
			wantErr: true,
		},
		{
			name: "使用无效授权码",
			input: &DeviceRegisterInput{
				UID:         "test-device-002",
				DeviceModel: "test-model",
				LicenseCode: "invalid-code",
			},
			wantErr: true,
		},
		{
			// NOTE(review): this case reuses the license code already submitted
			// by the first case, so the expected error may not originate from
			// the device-model lookup — confirm against RegisterDevice.
			name: "使用不存在的设备型号",
			input: &DeviceRegisterInput{
				UID:         "test-device-003",
				DeviceModel: "non-existent-model",
				LicenseCode: licenses[0].Code,
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := deviceService.RegisterDevice(tt.input, "127.0.0.1")
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Verify that the device row was stored with the expected fields.
			var device model.Device
			err = db.Where("uid = ?", tt.input.UID).First(&device).Error
			if !assert.NoError(t, err) {
				// Without a loaded row the field checks below would only add noise.
				return
			}
			assert.Equal(t, deviceModel.DeviceType, device.DeviceType)
			assert.Equal(t, "active", device.Status)
			assert.Equal(t, tt.input.LicenseCode, device.LicenseCode)
		})
	}
}
// TestDeviceService_ValidateDevice checks DeviceService.ValidateDevice for
// three UIDs: an active device with a future expiry, a UID that was never
// inserted, and a device stored with status "expired" and a past expiry.
//
// All fixtures are inserted before the table runs. Any fixture failure aborts
// the test, because indexing an empty license slice would panic and a missing
// device row would make later cases fail (or pass) for the wrong reason.
func TestDeviceService_ValidateDevice(t *testing.T) {
	db := utils.TestDB(t)
	licenseService := NewLicenseService(db)
	deviceService := NewDeviceService(db, licenseService)

	// Create the device model whose attributes are copied onto the devices.
	deviceModel := &model.DeviceModel{
		ModelName:  "test-model",
		DeviceType: "software",
		Company:    "test-company",
		CreatedBy:  1,
	}
	if err := db.Create(deviceModel).Error; err != nil {
		t.Fatalf("creating device model fixture: %v", err)
	}

	// Create the license code attached to the active device.
	licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
		LicenseType: "time",
		Duration:    30,
		Count:       1,
		Remark:      "test",
	}, 1)
	if err != nil {
		t.Fatalf("creating license fixture: %v", err)
	}
	if len(licenses) == 0 {
		t.Fatal("creating license fixture: no licenses returned")
	}

	// Insert an active device that expires 30 days from now.
	device := &model.Device{
		UID:          "test-device-001",
		DeviceType:   deviceModel.DeviceType,
		DeviceModel:  deviceModel.ModelName,
		Company:      deviceModel.Company,
		RegisterTime: time.Now(),
		Status:       "active",
		LicenseCode:  licenses[0].Code,
		LicenseType:  "time",
		Duration:     30,
		ExpireTime:   time.Now().Add(30 * 24 * time.Hour),
	}
	if err := db.Create(device).Error; err != nil {
		t.Fatalf("creating active device fixture: %v", err)
	}

	// Insert a device that expired one day ago.
	expiredDevice := &model.Device{
		UID:          "expired-device",
		DeviceType:   deviceModel.DeviceType,
		DeviceModel:  deviceModel.ModelName,
		Company:      deviceModel.Company,
		RegisterTime: time.Now(),
		Status:       "expired",
		LicenseCode:  "expired-license",
		LicenseType:  "time",
		Duration:     30,
		ExpireTime:   time.Now().Add(-24 * time.Hour),
	}
	if err := db.Create(expiredDevice).Error; err != nil {
		t.Fatalf("creating expired device fixture: %v", err)
	}

	tests := []struct {
		name    string
		uid     string
		wantErr bool
	}{
		{
			name:    "验证正常设备",
			uid:     "test-device-001",
			wantErr: false,
		},
		{
			name:    "验证不存在的设备",
			uid:     "non-existent-device",
			wantErr: true,
		},
		{
			name:    "验证过期设备",
			uid:     "expired-device",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := deviceService.ValidateDevice(tt.uid)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestDeviceService_UpdateStartCount inserts a count-licensed device with a
// limit of five uses and calls DeviceService.UpdateStartCount six times.
//
// After each of the first five calls the stored StartCount must equal the
// number of calls made so far. The status must stay "active" until the fifth
// call, after which it must read "expired". The sixth call must return an
// error.
func TestDeviceService_UpdateStartCount(t *testing.T) {
	// Use limit shared by the fixture and the loop bound below.
	const maxUses = 5

	db := utils.TestDB(t)
	licenseSvc := NewLicenseService(db)
	svc := NewDeviceService(db, licenseSvc)

	// Device model whose attributes are copied onto the device fixture.
	fixtureModel := &model.DeviceModel{
		ModelName:  "test-model",
		DeviceType: "software",
		Company:    "test-company",
		CreatedBy:  1,
	}
	assert.NoError(t, db.Create(fixtureModel).Error)

	// Count-licensed device that has not been started yet.
	fixtureDevice := &model.Device{
		UID:          "test-device-001",
		DeviceType:   fixtureModel.DeviceType,
		DeviceModel:  fixtureModel.ModelName,
		Company:      fixtureModel.Company,
		RegisterTime: time.Now(),
		Status:       "active",
		LicenseType:  "count",
		MaxUses:      maxUses,
		StartCount:   0,
	}
	assert.NoError(t, db.Create(fixtureDevice).Error)

	for attempt := 1; attempt <= maxUses; attempt++ {
		assert.NoError(t, svc.UpdateStartCount(fixtureDevice.UID))

		// Reload the row into a fresh struct to observe the persisted counter.
		var reloaded model.Device
		loadErr := db.First(&reloaded, "uid = ?", fixtureDevice.UID).Error
		assert.NoError(t, loadErr)
		assert.Equal(t, attempt, reloaded.StartCount)

		// Only the final permitted start is expected to flip the status.
		wantStatus := "active"
		if attempt == maxUses {
			wantStatus = "expired"
		}
		assert.Equal(t, wantStatus, reloaded.Status)
	}

	// One more start beyond the limit must be rejected.
	assert.Error(t, svc.UpdateStartCount(fixtureDevice.UID))
}
// TestDeviceService_CreateDeviceModel submits the same device model definition
// twice through DeviceService.CreateDeviceModel. The first submission must
// succeed and be readable back by model name; the second, identical one must
// return an error. The cases run in order and depend on each other.
func TestDeviceService_CreateDeviceModel(t *testing.T) {
	db := utils.TestDB(t)
	svc := NewDeviceService(db, NewLicenseService(db))

	// newFixture returns a fresh instance per case so that the two
	// submissions never share a pointer.
	newFixture := func() *model.DeviceModel {
		return &model.DeviceModel{
			ModelName:  "test-model",
			DeviceType: "software",
			Company:    "test-company",
			Status:     "active",
			CreatedBy:  1,
		}
	}

	cases := []struct {
		name    string
		model   *model.DeviceModel
		wantErr bool
	}{
		{name: "创建有效设备型号", model: newFixture(), wantErr: false},
		{name: "重复的设备型号", model: newFixture(), wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			createErr := svc.CreateDeviceModel(tc.model)
			if !tc.wantErr {
				assert.NoError(t, createErr)

				// Read the row back and compare the fields that were submitted.
				var stored model.DeviceModel
				lookupErr := db.Where("model_name = ?", tc.model.ModelName).First(&stored).Error
				assert.NoError(t, lookupErr)
				assert.Equal(t, tc.model.DeviceType, stored.DeviceType)
				assert.Equal(t, tc.model.Company, stored.Company)
				assert.Equal(t, tc.model.Status, stored.Status)
				return
			}
			assert.Error(t, createErr)
		})
	}
}