package service import ( "errors" "fmt" "io" "mime/multipart" "os" "path/filepath" "sort" "strings" "time" "licserver/internal/model" "licserver/internal/utils" "github.com/google/uuid" "gorm.io/gorm" "crypto/sha256" "encoding/hex" ) type UploadService struct { db *gorm.DB config *utils.Config } func NewUploadService(db *gorm.DB, config *utils.Config) *UploadService { return &UploadService{ db: db, config: config, } } func (s *UploadService) UploadFile(file *multipart.FileHeader, userID uint, deviceUID, description string) (*model.FileUpload, error) { // 生成唯一文件名 ext := filepath.Ext(file.Filename) uniqueID := uuid.New().String() fileName := fmt.Sprintf("%s%s", uniqueID, ext) filePath := filepath.Join(s.config.Upload.Path, fileName) // 确保上传目录存在 if err := os.MkdirAll(s.config.Upload.Path, 0755); err != nil { return nil, err } // 保存文件 src, err := file.Open() if err != nil { return nil, err } defer src.Close() dst, err := os.Create(filePath) if err != nil { return nil, err } defer dst.Close() if _, err = io.Copy(dst, src); err != nil { return nil, err } // 如果提供了设备UID,获取设备型号 var deviceModel string if deviceUID != "" { var device model.Device if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err == nil { deviceModel = device.DeviceModel } } // 创建数据库记录 upload := &model.FileUpload{ FileName: file.Filename, FilePath: filePath, FileSize: file.Size, FileType: strings.ToLower(ext), UploadedBy: userID, DeviceUID: deviceUID, DeviceModel: deviceModel, Description: description, } if err := s.db.Create(upload).Error; err != nil { os.Remove(filePath) return nil, err } return upload, nil } func (s *UploadService) DownloadFile(id uint) (*model.FileUpload, error) { var file model.FileUpload if err := s.db.First(&file, id).Error; err != nil { return nil, err } return &file, nil } func (s *UploadService) DeleteFile(id uint, userID uint) error { var file model.FileUpload if err := s.db.First(&file, id).Error; err != nil { return err } if file.UploadedBy != userID { return errors.New("无权删除此文件") } if err := os.Remove(file.FilePath); err != nil && !os.IsNotExist(err) { return err } return s.db.Delete(&file).Error } func (s *UploadService) GetDeviceFiles(deviceUID string) ([]model.FileUpload, error) { var files []model.FileUpload err := s.db.Where("device_uid = ?", deviceUID).Find(&files).Error return files, err } func (s *UploadService) UploadChunk( file *multipart.FileHeader, fileHash string, chunkNumber int, totalChunks int, totalSize int64, filename string, userID uint, deviceUID string, ) error { // 创建分片存储目录 chunkDir := filepath.Join(s.config.Upload.Path, "chunks", fileHash) if err := os.MkdirAll(chunkDir, 0755); err != nil { return err } // 保存分片文件 chunkPath := filepath.Join(chunkDir, fmt.Sprintf("%d", chunkNumber)) src, err := file.Open() if err != nil { return err } defer src.Close() dst, err := os.Create(chunkPath) if err != nil { return err } defer dst.Close() if _, err = io.Copy(dst, src); err != nil { return err } // 记录分片信息 chunk := model.UploadChunk{ FileHash: fileHash, ChunkNumber: chunkNumber, ChunkSize: file.Size, ChunkPath: chunkPath, TotalChunks: totalChunks, TotalSize: totalSize, Filename: filename, FileType: strings.ToLower(filepath.Ext(filename)), UploadedBy: userID, DeviceUID: deviceUID, } return s.db.Create(&chunk).Error } func (s *UploadService) CheckUploadStatus(fileHash string) (bool, error) { var chunks []model.UploadChunk if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil { return false, err } if len(chunks) == 0 { return false, nil } totalChunks := chunks[0].TotalChunks return len(chunks) == totalChunks, nil } func (s *UploadService) MergeChunks(fileHash string) (*model.FileUpload, error) { var chunks []model.UploadChunk if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil { return nil, err } if len(chunks) == 0 { return nil, errors.New("未找到文件分片") } if len(chunks) != chunks[0].TotalChunks { return nil, errors.New("文件分片不完整") } // 按分片序号排序 sort.Slice(chunks, func(i, j int) bool { return chunks[i].ChunkNumber < chunks[j].ChunkNumber }) // 创建最终文件 finalPath := filepath.Join(s.config.Upload.Path, fmt.Sprintf("%s%s", uuid.New().String(), chunks[0].FileType)) finalFile, err := os.Create(finalPath) if err != nil { return nil, err } defer finalFile.Close() // 合并分片 hash := sha256.New() for _, chunk := range chunks { chunkFile, err := os.Open(chunk.ChunkPath) if err != nil { return nil, err } if _, err = io.Copy(finalFile, chunkFile); err != nil { chunkFile.Close() return nil, err } if _, err = io.Copy(hash, chunkFile); err != nil { chunkFile.Close() return nil, err } chunkFile.Close() os.Remove(chunk.ChunkPath) // 删除已合并的分片 } // 验证文件哈希 if hex.EncodeToString(hash.Sum(nil)) != fileHash { os.Remove(finalPath) return nil, errors.New("文件哈希验证失败") } // 创建文件记录 upload := &model.FileUpload{ FileName: chunks[0].Filename, FilePath: finalPath, FileSize: chunks[0].TotalSize, FileType: chunks[0].FileType, UploadedBy: chunks[0].UploadedBy, DeviceUID: chunks[0].DeviceUID, } if err := s.db.Create(upload).Error; err != nil { os.Remove(finalPath) return nil, err } // 清理分片记录 s.db.Where("file_hash = ?", fileHash).Delete(&model.UploadChunk{}) os.RemoveAll(filepath.Dir(chunks[0].ChunkPath)) return upload, nil } func (s *UploadService) CleanupExpiredChunks() error { expireTime := time.Now().Add(-24 * time.Hour) var expiredChunks []model.UploadChunk if err := s.db.Where("completed = ? AND created_at < ?", false, expireTime).Find(&expiredChunks).Error; err != nil { return err } for _, chunk := range expiredChunks { os.Remove(chunk.ChunkPath) if len(chunk.ChunkPath) > 0 { chunkDir := filepath.Dir(chunk.ChunkPath) os.RemoveAll(chunkDir) } } return s.db.Unscoped().Where("completed = ? AND created_at < ?", false, expireTime).Delete(&model.UploadChunk{}).Error }