package service

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	"licserver/internal/model"
	"licserver/internal/utils"
)

// TestDeviceService_RegisterDevice exercises DeviceService.RegisterDevice with
// a valid registration, a repeated registration, an unknown license code and an
// unknown device model.
//
// The cases share one database handle and run in order: the repeated
// registration reuses the UID and license code of the first case, so it relies
// on that case having run before it.
func TestDeviceService_RegisterDevice(t *testing.T) {
	// NOTE(review): utils.TestDB presumably hands out an isolated test
	// database per test — confirm against its implementation.
	db := utils.TestDB(t)
	licenseService := NewLicenseService(db)
	deviceService := NewDeviceService(db, licenseService)

	// Seed the device model that the registrations refer to.
	deviceModel := &model.DeviceModel{
		ModelName:  "test-model",
		DeviceType: "software",
		Company:    "test-company",
		Status:     "active",
		CreatedBy:  1,
	}
	if err := db.Create(deviceModel).Error; err != nil {
		t.Fatalf("creating device model: %v", err)
	}

	// Seed a single time-based license code.
	licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
		LicenseType: "time",
		Duration:    30,
		Count:       1,
		Remark:      "test",
	}, 1)
	if err != nil {
		t.Fatalf("creating licenses: %v", err)
	}
	// Abort on an unexpected count: the case table below indexes licenses[0]
	// and would panic on an empty slice.
	if !assert.Len(t, licenses, 1) {
		t.FailNow()
	}

	tests := []struct {
		name    string
		input   *DeviceRegisterInput
		wantErr bool
	}{
		{
			name: "正常注册设备",
			input: &DeviceRegisterInput{
				UID:         "test-device-001",
				DeviceModel: "test-model",
				LicenseCode: licenses[0].Code,
			},
			wantErr: false,
		},
		{
			name: "重复注册设备",
			input: &DeviceRegisterInput{
				UID:         "test-device-001",
				DeviceModel: "test-model",
				LicenseCode: licenses[0].Code,
			},
			wantErr: true,
		},
		{
			name: "使用无效授权码",
			input: &DeviceRegisterInput{
				UID:         "test-device-002",
				DeviceModel: "test-model",
				LicenseCode: "invalid-code",
			},
			wantErr: true,
		},
		{
			name: "使用不存在的设备型号",
			input: &DeviceRegisterInput{
				UID:         "test-device-003",
				DeviceModel: "non-existent-model",
				LicenseCode: licenses[0].Code,
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := deviceService.RegisterDevice(tt.input, "127.0.0.1")
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}

			// Confirm the device row was stored with the expected values.
			var device model.Device
			err = db.Where("uid = ?", tt.input.UID).First(&device).Error
			if !assert.NoError(t, err) {
				// Comparing fields of a zero-valued device would only add noise.
				return
			}
			assert.Equal(t, deviceModel.DeviceType, device.DeviceType)
			assert.Equal(t, "active", device.Status)
			assert.Equal(t, tt.input.LicenseCode, device.LicenseCode)
		})
	}
}

// TestDeviceService_ValidateDevice checks DeviceService.ValidateDevice against
// an active device, a UID that was never stored, and a device whose status is
// "expired" and whose ExpireTime lies in the past.
func TestDeviceService_ValidateDevice(t *testing.T) {
	db := utils.TestDB(t)
	licenseService := NewLicenseService(db)
	deviceService := NewDeviceService(db, licenseService)

	// Seed the device model both fixtures are derived from.
	deviceModel := &model.DeviceModel{
		ModelName:  "test-model",
		DeviceType: "software",
		Company:    "test-company",
		CreatedBy:  1,
	}
	if err := db.Create(deviceModel).Error; err != nil {
		t.Fatalf("creating device model: %v", err)
	}

	// Seed a license code for the active device.
	licenses, err := licenseService.CreateLicenses(&LicenseCreateInput{
		LicenseType: "time",
		Duration:    30,
		Count:       1,
		Remark:      "test",
	}, 1)
	if err != nil {
		t.Fatalf("creating licenses: %v", err)
	}
	// licenses[0] is read below; stop instead of panicking on an empty slice.
	if len(licenses) == 0 {
		t.Fatal("creating licenses: got no licenses, want at least 1")
	}

	now := time.Now()

	// Active device: inserted directly, expiring 30 days from now.
	device := &model.Device{
		UID:          "test-device-001",
		DeviceType:   deviceModel.DeviceType,
		DeviceModel:  deviceModel.ModelName,
		Company:      deviceModel.Company,
		RegisterTime: now,
		Status:       "active",
		LicenseCode:  licenses[0].Code,
		LicenseType:  "time",
		Duration:     30,
		ExpireTime:   now.Add(30 * 24 * time.Hour),
	}
	if err := db.Create(device).Error; err != nil {
		t.Fatalf("creating active device: %v", err)
	}

	// Expired device: status "expired" and an ExpireTime one day in the past.
	expiredDevice := &model.Device{
		UID:          "expired-device",
		DeviceType:   deviceModel.DeviceType,
		DeviceModel:  deviceModel.ModelName,
		Company:      deviceModel.Company,
		RegisterTime: now,
		Status:       "expired",
		LicenseCode:  "expired-license",
		LicenseType:  "time",
		Duration:     30,
		ExpireTime:   now.Add(-24 * time.Hour),
	}
	if err := db.Create(expiredDevice).Error; err != nil {
		t.Fatalf("creating expired device: %v", err)
	}

	tests := []struct {
		name    string
		uid     string
		wantErr bool
	}{
		{
			name:    "验证正常设备",
			uid:     "test-device-001",
			wantErr: false,
		},
		{
			name:    "验证不存在的设备",
			uid:     "non-existent-device",
			wantErr: true,
		},
		{
			name:    "验证过期设备",
			uid:     "expired-device",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := deviceService.ValidateDevice(tt.uid)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}

// TestDeviceService_UpdateStartCount calls DeviceService.UpdateStartCount on a
// count-licensed device once per permitted use. After each call it expects
// StartCount to equal the number of calls made so far, and Status to stay
// "active" until the final permitted use, where it must become "expired". One
// further call must return an error.
func TestDeviceService_UpdateStartCount(t *testing.T) {
	// maxUses is the start quota given to the fixture; it also bounds the loop.
	const maxUses = 5

	db := utils.TestDB(t)
	licenseService := NewLicenseService(db)
	deviceService := NewDeviceService(db, licenseService)

	// Seed the device model the fixture is derived from.
	deviceModel := &model.DeviceModel{
		ModelName:  "test-model",
		DeviceType: "software",
		Company:    "test-company",
		CreatedBy:  1,
	}
	if err := db.Create(deviceModel).Error; err != nil {
		t.Fatalf("creating device model: %v", err)
	}

	// Seed a count-licensed device that has not been started yet.
	device := &model.Device{
		UID:          "test-device-001",
		DeviceType:   deviceModel.DeviceType,
		DeviceModel:  deviceModel.ModelName,
		Company:      deviceModel.Company,
		RegisterTime: time.Now(),
		Status:       "active",
		LicenseType:  "count",
		MaxUses:      maxUses,
		StartCount:   0,
	}
	if err := db.Create(device).Error; err != nil {
		t.Fatalf("creating device: %v", err)
	}

	// Each iteration builds on the previous one, so stop at the first failure
	// instead of reporting a cascade of follow-up mismatches.
	for i := 1; i <= maxUses; i++ {
		if err := deviceService.UpdateStartCount(device.UID); err != nil {
			t.Fatalf("UpdateStartCount call %d: %v", i, err)
		}

		// Reload the row to observe what was persisted.
		var updatedDevice model.Device
		if err := db.First(&updatedDevice, "uid = ?", device.UID).Error; err != nil {
			t.Fatalf("reloading device after call %d: %v", i, err)
		}
		assert.Equal(t, i, updatedDevice.StartCount)

		// Only the last permitted use flips the status to expired.
		wantStatus := "active"
		if i == maxUses {
			wantStatus = "expired"
		}
		assert.Equal(t, wantStatus, updatedDevice.Status)
	}

	// The quota is used up; another start must be rejected.
	assert.Error(t, deviceService.UpdateStartCount(device.UID))
}

// TestDeviceService_CreateDeviceModel checks that DeviceService.CreateDeviceModel
// stores a new model and rejects a second model with identical fields.
//
// The cases run in order against one database handle; the duplicate case
// depends on the first case having stored "test-model".
func TestDeviceService_CreateDeviceModel(t *testing.T) {
	db := utils.TestDB(t)
	licenseService := NewLicenseService(db)
	deviceService := NewDeviceService(db, licenseService)

	tests := []struct {
		name    string
		model   *model.DeviceModel
		wantErr bool
	}{
		{
			name: "创建有效设备型号",
			model: &model.DeviceModel{
				ModelName:  "test-model",
				DeviceType: "software",
				Company:    "test-company",
				Status:     "active",
				CreatedBy:  1,
			},
			wantErr: false,
		},
		{
			name: "重复的设备型号",
			model: &model.DeviceModel{
				ModelName:  "test-model",
				DeviceType: "software",
				Company:    "test-company",
				Status:     "active",
				CreatedBy:  1,
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := deviceService.CreateDeviceModel(tt.model)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			if !assert.NoError(t, err) {
				return
			}

			// Read the model back and compare the persisted fields.
			var savedModel model.DeviceModel
			err = db.Where("model_name = ?", tt.model.ModelName).First(&savedModel).Error
			if !assert.NoError(t, err) {
				// Comparing fields of a zero-valued model would only add noise.
				return
			}
			assert.Equal(t, tt.model.DeviceType, savedModel.DeviceType)
			assert.Equal(t, tt.model.Company, savedModel.Company)
			assert.Equal(t, tt.model.Status, savedModel.Status)
		})
	}
}