299 lines
6.9 KiB
Go
299 lines
6.9 KiB
Go
package service
|
|
|
|
import (
|
|
"licserver/internal/model"
|
|
"licserver/internal/utils"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestDeviceService_RegisterDevice(t *testing.T) {
|
|
db := utils.TestDB(t)
|
|
licenseService := NewLicenseService(db)
|
|
deviceService := NewDeviceService(db, licenseService)
|
|
|
|
// 创建测试设备型号
|
|
deviceModel := &model.DeviceModel{
|
|
ModelName: "test-model",
|
|
DeviceType: "software",
|
|
Company: "test-company",
|
|
Status: "active",
|
|
CreatedBy: 1,
|
|
}
|
|
err := db.Create(deviceModel).Error
|
|
assert.NoError(t, err)
|
|
|
|
// 创建测试授权码
|
|
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
|
|
LicenseType: "time",
|
|
Duration: 30,
|
|
Count: 1,
|
|
Remark: "test",
|
|
}, 1)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, licenses, 1)
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *DeviceRegisterInput
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "正常注册设备",
|
|
input: &DeviceRegisterInput{
|
|
UID: "test-device-001",
|
|
DeviceModel: "test-model",
|
|
LicenseCode: licenses[0].Code,
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "重复注册设备",
|
|
input: &DeviceRegisterInput{
|
|
UID: "test-device-001",
|
|
DeviceModel: "test-model",
|
|
LicenseCode: licenses[0].Code,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "使用无效授权码",
|
|
input: &DeviceRegisterInput{
|
|
UID: "test-device-002",
|
|
DeviceModel: "test-model",
|
|
LicenseCode: "invalid-code",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "使用不存在的设备型号",
|
|
input: &DeviceRegisterInput{
|
|
UID: "test-device-003",
|
|
DeviceModel: "non-existent-model",
|
|
LicenseCode: licenses[0].Code,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := deviceService.RegisterDevice(tt.input, "127.0.0.1")
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
|
|
// 验证设备是否正确注册
|
|
var device model.Device
|
|
err = db.Where("uid = ?", tt.input.UID).First(&device).Error
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, deviceModel.DeviceType, device.DeviceType)
|
|
assert.Equal(t, "active", device.Status)
|
|
assert.Equal(t, tt.input.LicenseCode, device.LicenseCode)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeviceService_ValidateDevice(t *testing.T) {
|
|
db := utils.TestDB(t)
|
|
licenseService := NewLicenseService(db)
|
|
deviceService := NewDeviceService(db, licenseService)
|
|
|
|
// 创建测试设备型号
|
|
deviceModel := &model.DeviceModel{
|
|
ModelName: "test-model",
|
|
DeviceType: "software",
|
|
Company: "test-company",
|
|
CreatedBy: 1,
|
|
}
|
|
err := db.Create(deviceModel).Error
|
|
assert.NoError(t, err)
|
|
|
|
// 创建测试授权码
|
|
licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
|
|
LicenseType: "time",
|
|
Duration: 30,
|
|
Count: 1,
|
|
Remark: "test",
|
|
}, 1)
|
|
assert.NoError(t, err)
|
|
|
|
// 注册测试设备
|
|
device := &model.Device{
|
|
UID: "test-device-001",
|
|
DeviceType: deviceModel.DeviceType,
|
|
DeviceModel: deviceModel.ModelName,
|
|
Company: deviceModel.Company,
|
|
RegisterTime: time.Now(),
|
|
Status: "active",
|
|
LicenseCode: licenses[0].Code,
|
|
LicenseType: "time",
|
|
Duration: 30,
|
|
ExpireTime: time.Now().Add(30 * 24 * time.Hour),
|
|
}
|
|
err = db.Create(device).Error
|
|
assert.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
uid string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "验证正常设备",
|
|
uid: "test-device-001",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "验证不存在的设备",
|
|
uid: "non-existent-device",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "验证过期设备",
|
|
uid: "expired-device",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
// 创建过期设备
|
|
expiredDevice := &model.Device{
|
|
UID: "expired-device",
|
|
DeviceType: deviceModel.DeviceType,
|
|
DeviceModel: deviceModel.ModelName,
|
|
Company: deviceModel.Company,
|
|
RegisterTime: time.Now(),
|
|
Status: "expired",
|
|
LicenseCode: "expired-license",
|
|
LicenseType: "time",
|
|
Duration: 30,
|
|
ExpireTime: time.Now().Add(-24 * time.Hour),
|
|
}
|
|
err = db.Create(expiredDevice).Error
|
|
assert.NoError(t, err)
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := deviceService.ValidateDevice(tt.uid)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeviceService_UpdateStartCount(t *testing.T) {
|
|
db := utils.TestDB(t)
|
|
licenseService := NewLicenseService(db)
|
|
deviceService := NewDeviceService(db, licenseService)
|
|
|
|
// 创建测试设备型号
|
|
deviceModel := &model.DeviceModel{
|
|
ModelName: "test-model",
|
|
DeviceType: "software",
|
|
Company: "test-company",
|
|
CreatedBy: 1,
|
|
}
|
|
err := db.Create(deviceModel).Error
|
|
assert.NoError(t, err)
|
|
|
|
// 创建测试设备
|
|
device := &model.Device{
|
|
UID: "test-device-001",
|
|
DeviceType: deviceModel.DeviceType,
|
|
DeviceModel: deviceModel.ModelName,
|
|
Company: deviceModel.Company,
|
|
RegisterTime: time.Now(),
|
|
Status: "active",
|
|
LicenseType: "count",
|
|
MaxUses: 5,
|
|
StartCount: 0,
|
|
}
|
|
err = db.Create(device).Error
|
|
assert.NoError(t, err)
|
|
|
|
// 测试更新启动次数
|
|
for i := 1; i <= 5; i++ {
|
|
err = deviceService.UpdateStartCount(device.UID)
|
|
assert.NoError(t, err)
|
|
|
|
// 验证启动次数
|
|
var updatedDevice model.Device
|
|
err = db.First(&updatedDevice, "uid = ?", device.UID).Error
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, i, updatedDevice.StartCount)
|
|
|
|
// 检查最后一次是否将状态更新为过期
|
|
if i == 5 {
|
|
assert.Equal(t, "expired", updatedDevice.Status)
|
|
} else {
|
|
assert.Equal(t, "active", updatedDevice.Status)
|
|
}
|
|
}
|
|
|
|
// 测试超出使用次数
|
|
err = deviceService.UpdateStartCount(device.UID)
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestDeviceService_CreateDeviceModel(t *testing.T) {
|
|
db := utils.TestDB(t)
|
|
licenseService := NewLicenseService(db)
|
|
deviceService := NewDeviceService(db, licenseService)
|
|
|
|
tests := []struct {
|
|
name string
|
|
model *model.DeviceModel
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "创建有效设备型号",
|
|
model: &model.DeviceModel{
|
|
ModelName: "test-model",
|
|
DeviceType: "software",
|
|
Company: "test-company",
|
|
Status: "active",
|
|
CreatedBy: 1,
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "重复的设备型号",
|
|
model: &model.DeviceModel{
|
|
ModelName: "test-model",
|
|
DeviceType: "software",
|
|
Company: "test-company",
|
|
Status: "active",
|
|
CreatedBy: 1,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := deviceService.CreateDeviceModel(tt.model)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
|
|
// 验证创建的设备型号
|
|
var savedModel model.DeviceModel
|
|
err = db.Where("model_name = ?", tt.model.ModelName).First(&savedModel).Error
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.model.DeviceType, savedModel.DeviceType)
|
|
assert.Equal(t, tt.model.Company, savedModel.Company)
|
|
assert.Equal(t, tt.model.Status, savedModel.Status)
|
|
})
|
|
}
|
|
}
|