LicenseManger/internal/service/upload.go

464 lines
6.4 KiB
Go
Raw Permalink Normal View History

2024-11-14 14:55:43 +00:00
package service
import (
"errors"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"sort"
"strings"
"time"
"licserver/internal/model"
"licserver/internal/utils"
"github.com/google/uuid"
"gorm.io/gorm"
"crypto/sha256"
"encoding/hex"
)
// UploadService saves uploaded files under the configured upload directory
// and keeps their metadata in the database.
type UploadService struct {
	db     *gorm.DB
	config *utils.Config
}

// NewUploadService returns an UploadService that uses db for metadata and
// config for settings such as the upload directory.
func NewUploadService(db *gorm.DB, config *utils.Config) *UploadService {
	svc := new(UploadService)
	svc.db = db
	svc.config = config
	return svc
}
// UploadFile stores a single multipart upload and records it in the database.
//
// The file is written to the configured upload directory, which is created if
// needed. Its on-disk name is a fresh UUID followed by the extension of the
// client-supplied name. A model.FileUpload record is then created with the
// original name, size, lower-cased extension, uploader, optional device and
// description.
//
// If deviceUID is non-empty, the matching device's model is copied onto the
// record. A failed lookup is ignored and leaves DeviceModel empty.
//
// If anything fails after the file has been created on disk (copy, close or
// database insert), the file is removed so no orphaned data is left behind.
func (s *UploadService) UploadFile(file *multipart.FileHeader, userID uint, deviceUID, description string) (*model.FileUpload, error) {
	// Generate a unique on-disk name; only the extension of the
	// client-supplied name is reused.
	ext := filepath.Ext(file.Filename)
	uniqueID := uuid.New().String()
	fileName := fmt.Sprintf("%s%s", uniqueID, ext)
	filePath := filepath.Join(s.config.Upload.Path, fileName)

	// Make sure the upload directory exists.
	if err := os.MkdirAll(s.config.Upload.Path, 0755); err != nil {
		return nil, err
	}

	// Save the file contents.
	src, err := file.Open()
	if err != nil {
		return nil, err
	}
	defer src.Close()

	dst, err := os.Create(filePath)
	if err != nil {
		return nil, err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		os.Remove(filePath) // best effort: drop the partially written file
		return nil, err
	}
	// Close explicitly rather than via defer so a failed flush is reported
	// and the handle is released before the file might need removing.
	if err := dst.Close(); err != nil {
		os.Remove(filePath)
		return nil, err
	}

	// If a device UID was supplied, look up its model (best effort).
	var deviceModel string
	if deviceUID != "" {
		var device model.Device
		if err := s.db.Where("uid = ?", deviceUID).First(&device).Error; err == nil {
			deviceModel = device.DeviceModel
		}
	}

	// Create the database record.
	upload := &model.FileUpload{
		FileName:    file.Filename,
		FilePath:    filePath,
		FileSize:    file.Size,
		FileType:    strings.ToLower(ext),
		UploadedBy:  userID,
		DeviceUID:   deviceUID,
		DeviceModel: deviceModel,
		Description: description,
	}
	if err := s.db.Create(upload).Error; err != nil {
		os.Remove(filePath)
		return nil, err
	}
	return upload, nil
}
// DownloadFile loads the upload record whose primary key is id.
// It only reads the database record and does not touch the file on disk.
// Any lookup error (including a missing record) is returned unchanged.
func (s *UploadService) DownloadFile(id uint) (*model.FileUpload, error) {
	record := new(model.FileUpload)
	err := s.db.First(record, id).Error
	if err != nil {
		return nil, err
	}
	return record, nil
}
// DeleteFile deletes the upload with the given id on behalf of userID.
//
// Only the user recorded as the uploader may delete the file; anyone else
// gets an error. The database record is deleted inside a transaction, and the
// file on disk is then removed. If removing the file fails for any reason
// other than the file already being gone, the transaction is rolled back so
// the record does not end up pointing at a file that was not deleted (and
// vice versa: the file is never removed while the record deletion fails).
func (s *UploadService) DeleteFile(id uint, userID uint) error {
	var file model.FileUpload
	if err := s.db.First(&file, id).Error; err != nil {
		return err
	}
	if file.UploadedBy != userID {
		return errors.New("无权删除此文件")
	}
	return s.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Delete(&file).Error; err != nil {
			return err
		}
		if err := os.Remove(file.FilePath); err != nil && !os.IsNotExist(err) {
			// Returning an error rolls back the record deletion.
			return err
		}
		return nil
	})
}
// GetDeviceFiles lists every upload record whose device_uid equals deviceUID.
// The slice filled by the query is returned together with the query error.
func (s *UploadService) GetDeviceFiles(deviceUID string) ([]model.FileUpload, error) {
	var result []model.FileUpload
	if err := s.db.Where("device_uid = ?", deviceUID).Find(&result).Error; err != nil {
		return result, err
	}
	return result, nil
}
// UploadChunk stores one chunk of a chunked upload and records it in the
// database.
//
// The chunk is written to <upload path>/chunks/<fileHash>/<chunkNumber>.
// fileHash must be a hex-encoded SHA-256 digest, which is the only form
// MergeChunks can ever verify. Anything else is rejected, which also prevents
// a crafted value such as "../.." from escaping the chunk directory.
//
// Re-uploading a chunk number that is already recorded (e.g. a client retry)
// overwrites the chunk file and replaces the existing database record(s) for
// that chunk instead of adding a duplicate row.
func (s *UploadService) UploadChunk(
	file *multipart.FileHeader,
	fileHash string,
	chunkNumber int,
	totalChunks int,
	totalSize int64,
	filename string,
	userID uint,
	deviceUID string,
) error {
	// fileHash is client-controlled and becomes a path component below.
	if decoded, err := hex.DecodeString(fileHash); err != nil || len(decoded) != sha256.Size {
		return errors.New("无效的文件哈希")
	}

	// Create the per-file chunk directory.
	chunkDir := filepath.Join(s.config.Upload.Path, "chunks", fileHash)
	if err := os.MkdirAll(chunkDir, 0755); err != nil {
		return err
	}

	// Save the chunk data.
	chunkPath := filepath.Join(chunkDir, fmt.Sprintf("%d", chunkNumber))
	src, err := file.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.Create(chunkPath)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		os.Remove(chunkPath) // best effort: don't leave a truncated chunk behind
		return err
	}
	// Close explicitly so a failed flush is reported.
	if err := dst.Close(); err != nil {
		os.Remove(chunkPath)
		return err
	}

	// Record the chunk, replacing any earlier record for the same chunk number.
	chunk := model.UploadChunk{
		FileHash:    fileHash,
		ChunkNumber: chunkNumber,
		ChunkSize:   file.Size,
		ChunkPath:   chunkPath,
		TotalChunks: totalChunks,
		TotalSize:   totalSize,
		Filename:    filename,
		FileType:    strings.ToLower(filepath.Ext(filename)),
		UploadedBy:  userID,
		DeviceUID:   deviceUID,
	}
	return s.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Unscoped().
			Where("file_hash = ? AND chunk_number = ?", fileHash, chunkNumber).
			Delete(&model.UploadChunk{}).Error; err != nil {
			return err
		}
		return tx.Create(&chunk).Error
	})
}
// CheckUploadStatus reports whether every chunk of the upload identified by
// fileHash has been recorded.
//
// It returns false with a nil error when no chunks are recorded. Completeness
// is judged by the number of distinct chunk numbers compared with the
// TotalChunks value of the first record. Duplicate rows left by retried chunk
// uploads therefore neither inflate the count (which could report a missing
// chunk as complete) nor push it past TotalChunks (which would block a
// complete upload).
func (s *UploadService) CheckUploadStatus(fileHash string) (bool, error) {
	var chunks []model.UploadChunk
	if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
		return false, err
	}
	if len(chunks) == 0 {
		return false, nil
	}
	seen := make(map[int]struct{}, len(chunks))
	for i := range chunks {
		seen[chunks[i].ChunkNumber] = struct{}{}
	}
	return len(seen) == chunks[0].TotalChunks, nil
}
// MergeChunks concatenates all recorded chunks of the upload identified by
// fileHash into one file, verifies it and registers it as a model.FileUpload.
//
// fileHash must equal the lowercase hex SHA-256 of the merged content.
// Chunks are concatenated in ascending ChunkNumber order. Duplicate records for
// the same chunk number are used once; such records share one chunk file path,
// because UploadChunk names files by chunk number. The merge is refused if the
// number of distinct chunks differs from TotalChunks.
//
// The content is written to the merged file and fed to the hash in a single
// pass. Chunk files and records are cleaned up only after the merged file has
// been verified and recorded. If any step fails, the merged file is removed
// and the chunks are left in place so the merge can be retried.
func (s *UploadService) MergeChunks(fileHash string) (*model.FileUpload, error) {
	var chunks []model.UploadChunk
	if err := s.db.Where("file_hash = ?", fileHash).Find(&chunks).Error; err != nil {
		return nil, err
	}
	if len(chunks) == 0 {
		return nil, errors.New("未找到文件分片")
	}

	// Sort by chunk number, then drop duplicate records of the same chunk
	// (in-place filter over the sorted slice).
	sort.Slice(chunks, func(i, j int) bool {
		return chunks[i].ChunkNumber < chunks[j].ChunkNumber
	})
	distinct := chunks[:1]
	for _, c := range chunks[1:] {
		if c.ChunkNumber != distinct[len(distinct)-1].ChunkNumber {
			distinct = append(distinct, c)
		}
	}
	first := distinct[0]
	if len(distinct) != first.TotalChunks {
		return nil, errors.New("文件分片不完整")
	}

	// Create the final file under a fresh random name.
	finalPath := filepath.Join(s.config.Upload.Path, fmt.Sprintf("%s%s", uuid.New().String(), first.FileType))
	finalFile, err := os.Create(finalPath)
	if err != nil {
		return nil, err
	}

	// Each byte must reach both the output file and the hash. Copying the
	// chunk twice would hash nothing, because the first copy drains the reader.
	hash := sha256.New()
	out := io.MultiWriter(finalFile, hash)
	appendChunk := func(path string) error {
		f, err := os.Open(path)
		if err != nil {
			return err
		}
		defer f.Close()
		_, err = io.Copy(out, f)
		return err
	}
	for _, c := range distinct {
		if err := appendChunk(c.ChunkPath); err != nil {
			finalFile.Close()
			os.Remove(finalPath)
			return nil, err
		}
	}
	if err := finalFile.Close(); err != nil {
		os.Remove(finalPath)
		return nil, err
	}

	// Verify the merged content before anything is cleaned up.
	if hex.EncodeToString(hash.Sum(nil)) != fileHash {
		os.Remove(finalPath)
		return nil, errors.New("文件哈希验证失败")
	}

	// Create the file record.
	upload := &model.FileUpload{
		FileName:   first.Filename,
		FilePath:   finalPath,
		FileSize:   first.TotalSize,
		FileType:   first.FileType,
		UploadedBy: first.UploadedBy,
		DeviceUID:  first.DeviceUID,
	}
	if err := s.db.Create(upload).Error; err != nil {
		os.Remove(finalPath)
		return nil, err
	}

	// Best-effort cleanup of chunk records and files. The merged upload is
	// already stored, so failures here are not reported.
	s.db.Where("file_hash = ?", fileHash).Delete(&model.UploadChunk{})
	os.RemoveAll(filepath.Dir(first.ChunkPath))
	return upload, nil
}
// CleanupExpiredChunks removes chunk records with completed = false that
// were created more than 24 hours ago, together with their chunk files.
//
// Only the expired chunks' own files are deleted. Their directory is then
// removed only if it has become empty. Chunk files of the same upload that are
// not yet expired therefore survive alongside their database records. (A
// recursive delete of the directory would leave those records pointing at
// missing files.) Finally the expired records are permanently deleted via
// Unscoped.
//
// File-system removals are best effort and their errors are ignored; only
// database errors are returned.
func (s *UploadService) CleanupExpiredChunks() error {
	expireTime := time.Now().Add(-24 * time.Hour)
	var expiredChunks []model.UploadChunk
	if err := s.db.Where("completed = ? AND created_at < ?", false, expireTime).Find(&expiredChunks).Error; err != nil {
		return err
	}

	dirs := make(map[string]struct{})
	for _, chunk := range expiredChunks {
		if chunk.ChunkPath == "" {
			// Nothing on disk; filepath.Dir("") would be "." (the working directory).
			continue
		}
		os.Remove(chunk.ChunkPath)
		dirs[filepath.Dir(chunk.ChunkPath)] = struct{}{}
	}
	for dir := range dirs {
		// os.Remove refuses to delete a non-empty directory, so directories
		// still holding live chunks are kept.
		os.Remove(dir)
	}

	return s.db.Unscoped().Where("completed = ? AND created_at < ?", false, expireTime).Delete(&model.UploadChunk{}).Error
}