mirror of
https://github.com/MatrixSeven/file-transfer-go.git
synced 2026-02-27 18:24:42 +08:00
第一版本
This commit is contained in:
129
internal/handlers/handlers.go
Normal file
129
internal/handlers/handlers.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"chuan/internal/models"
|
||||
"chuan/internal/services"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
p2pService *services.P2PService
|
||||
templates map[string]*template.Template
|
||||
}
|
||||
|
||||
func NewHandler(p2pService *services.P2PService) *Handler {
|
||||
h := &Handler{
|
||||
p2pService: p2pService,
|
||||
templates: make(map[string]*template.Template),
|
||||
}
|
||||
|
||||
// 加载模板
|
||||
h.loadTemplates()
|
||||
return h
|
||||
}
|
||||
|
||||
// 加载模板
|
||||
func (h *Handler) loadTemplates() {
|
||||
templateDir := "web/templates"
|
||||
|
||||
// 加载基础模板
|
||||
baseTemplate := filepath.Join(templateDir, "base.html")
|
||||
|
||||
// 加载各个页面模板
|
||||
templates := []string{"index.html"}
|
||||
|
||||
for _, tmplName := range templates {
|
||||
tmplPath := filepath.Join(templateDir, tmplName)
|
||||
tmpl, err := template.ParseFiles(baseTemplate, tmplPath)
|
||||
if err != nil {
|
||||
panic("加载模板失败: " + err.Error())
|
||||
}
|
||||
h.templates[tmplName] = tmpl
|
||||
println("模板加载成功:", tmplName)
|
||||
}
|
||||
}
|
||||
|
||||
// IndexHandler 首页处理器
|
||||
func (h *Handler) IndexHandler(w http.ResponseWriter, r *http.Request) {
|
||||
tmpl, exists := h.templates["index.html"]
|
||||
if !exists {
|
||||
http.Error(w, "模板不存在", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Title": "P2P文件传输",
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(w, data); err != nil {
|
||||
http.Error(w, "渲染模板失败: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRoomHandler 创建房间API
|
||||
func (h *Handler) CreateRoomHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "方法不允许", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Files []models.FileTransferInfo `json:"files"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "解析请求失败", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建房间
|
||||
code := h.p2pService.CreateRoom(req.Files)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"code": code,
|
||||
"message": "房间创建成功",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// GetRoomInfoHandler 获取房间信息API
|
||||
func (h *Handler) GetRoomInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "缺少取件码", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
room, exists := h.p2pService.GetRoomByCode(code)
|
||||
if !exists {
|
||||
response := map[string]interface{}{
|
||||
"success": false,
|
||||
"message": "取件码不存在或已过期",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"files": room.Files,
|
||||
"message": "房间信息获取成功",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// HandleP2PWebSocket 处理P2P WebSocket连接
|
||||
func (h *Handler) HandleP2PWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
h.p2pService.HandleWebSocket(w, r)
|
||||
}
|
||||
68
internal/models/models.go
Normal file
68
internal/models/models.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileInfo 文件信息结构
|
||||
type FileInfo struct {
|
||||
ID string `json:"id"`
|
||||
FileName string `json:"filename"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
ContentType string `json:"content_type"`
|
||||
Code string `json:"code"`
|
||||
UploadTime time.Time `json:"upload_time"`
|
||||
ExpiryTime time.Time `json:"expiry_time"`
|
||||
DownloadURL string `json:"download_url"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// UploadResponse 上传响应结构
|
||||
type UploadResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code,omitempty"`
|
||||
FileInfo FileInfo `json:"file_info,omitempty"`
|
||||
DownloadURL string `json:"download_url,omitempty"`
|
||||
}
|
||||
|
||||
// WebRTCOffer WebRTC offer 结构
|
||||
type WebRTCOffer struct {
|
||||
SDP string `json:"sdp"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// WebRTCAnswer WebRTC answer 结构
|
||||
type WebRTCAnswer struct {
|
||||
SDP string `json:"sdp"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// WebRTCICECandidate ICE candidate 结构
|
||||
type WebRTCICECandidate struct {
|
||||
Candidate string `json:"candidate"`
|
||||
SDPMLineIndex int `json:"sdpMLineIndex"`
|
||||
SDPMid string `json:"sdpMid"`
|
||||
}
|
||||
|
||||
// VideoMessage 视频消息结构
|
||||
type VideoMessage struct {
|
||||
Type string `json:"type"`
|
||||
Payload interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
// FileTransferInfo P2P文件传输信息
|
||||
type FileTransferInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Type string `json:"type"`
|
||||
LastModified int64 `json:"lastModified"`
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应结构
|
||||
type ErrorResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
168
internal/services/file_service.go
Normal file
168
internal/services/file_service.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"chuan/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type FileService struct {
|
||||
uploadDir string
|
||||
}
|
||||
|
||||
func NewFileService() *FileService {
|
||||
return &FileService{
|
||||
uploadDir: "./uploads",
|
||||
}
|
||||
}
|
||||
|
||||
// SaveFile 保存上传的文件
|
||||
func (fs *FileService) SaveFile(file multipart.File, header *multipart.FileHeader) (*models.FileInfo, error) {
|
||||
// 生成唯一文件ID
|
||||
fileID := uuid.New().String()
|
||||
|
||||
// 生成取件码
|
||||
code := fs.generateCode()
|
||||
|
||||
// 创建文件路径
|
||||
fileExt := filepath.Ext(header.Filename)
|
||||
fileName := fmt.Sprintf("%s%s", fileID, fileExt)
|
||||
filePath := filepath.Join(fs.uploadDir, fileName)
|
||||
|
||||
// 确保上传目录存在
|
||||
if err := os.MkdirAll(fs.uploadDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建上传目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建目标文件
|
||||
dst, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建文件失败: %v", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// 复制文件内容
|
||||
size, err := io.Copy(dst, file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取文件内容类型
|
||||
contentType := header.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = fs.getContentType(header.Filename)
|
||||
}
|
||||
|
||||
fileInfo := &models.FileInfo{
|
||||
ID: fileID,
|
||||
FileName: header.Filename,
|
||||
FileSize: size,
|
||||
ContentType: contentType,
|
||||
Code: code,
|
||||
UploadTime: time.Now(),
|
||||
ExpiryTime: time.Now().Add(24 * time.Hour), // 24小时过期
|
||||
FilePath: filePath,
|
||||
DownloadURL: fmt.Sprintf("/download/%s", code),
|
||||
}
|
||||
|
||||
// 存储文件信息到内存(生产环境应使用Redis)
|
||||
store := GetStore()
|
||||
if err := store.StoreFileInfo(fileInfo); err != nil {
|
||||
return nil, fmt.Errorf("存储文件信息失败: %v", err)
|
||||
}
|
||||
|
||||
return fileInfo, nil
|
||||
}
|
||||
|
||||
// generateCode 生成6位取件码
|
||||
func (fs *FileService) generateCode() string {
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, 6)
|
||||
rand.Read(b)
|
||||
for i := range b {
|
||||
b[i] = charset[b[i]%byte(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// getContentType 根据文件扩展名获取内容类型
|
||||
func (fs *FileService) getContentType(filename string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
switch ext {
|
||||
case ".pdf":
|
||||
return "application/pdf"
|
||||
case ".epub":
|
||||
return "application/epub+zip"
|
||||
case ".mobi":
|
||||
return "application/x-mobipocket-ebook"
|
||||
case ".txt":
|
||||
return "text/plain"
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
case ".mp4":
|
||||
return "video/mp4"
|
||||
case ".avi":
|
||||
return "video/avi"
|
||||
case ".mov":
|
||||
return "video/quicktime"
|
||||
case ".zip":
|
||||
return "application/zip"
|
||||
case ".rar":
|
||||
return "application/x-rar-compressed"
|
||||
case ".7z":
|
||||
return "application/x-7z-compressed"
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
// GetFileByCode 根据取件码获取文件信息
|
||||
func (fs *FileService) GetFileByCode(code string) (*models.FileInfo, error) {
|
||||
store := GetStore()
|
||||
return store.GetFileInfo(code)
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件
|
||||
func (fs *FileService) DeleteFile(code string) error {
|
||||
fileInfo, err := fs.GetFileByCode(code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除物理文件
|
||||
if err := os.Remove(fileInfo.FilePath); err != nil {
|
||||
return fmt.Errorf("删除文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 从内存存储删除文件信息
|
||||
store := GetStore()
|
||||
store.DeleteFileInfo(code)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertEpubToMobi 将EPUB转换为MOBI格式
|
||||
func (fs *FileService) ConvertEpubToMobi(epubPath string) (string, error) {
|
||||
// TODO: 集成Calibre API进行格式转换
|
||||
// 这里暂时返回原文件路径
|
||||
return epubPath, fmt.Errorf("格式转换功能尚未实现")
|
||||
}
|
||||
|
||||
// CleanExpiredFiles 清理过期文件
|
||||
func (fs *FileService) CleanExpiredFiles() error {
|
||||
// TODO: 实现定期清理过期文件的逻辑
|
||||
return nil
|
||||
}
|
||||
61
internal/services/memory_store.go
Normal file
61
internal/services/memory_store.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"chuan/internal/models"
|
||||
)
|
||||
|
||||
// 内存存储(生产环境应使用Redis)
|
||||
type MemoryStore struct {
|
||||
files map[string]*models.FileInfo
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var globalStore = &MemoryStore{
|
||||
files: make(map[string]*models.FileInfo),
|
||||
}
|
||||
|
||||
// StoreFileInfo 存储文件信息
|
||||
func (ms *MemoryStore) StoreFileInfo(fileInfo *models.FileInfo) error {
|
||||
ms.mutex.Lock()
|
||||
defer ms.mutex.Unlock()
|
||||
|
||||
ms.files[fileInfo.Code] = fileInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFileInfo 获取文件信息
|
||||
func (ms *MemoryStore) GetFileInfo(code string) (*models.FileInfo, error) {
|
||||
ms.mutex.RLock()
|
||||
defer ms.mutex.RUnlock()
|
||||
|
||||
fileInfo, exists := ms.files[code]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("文件不存在或已过期")
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(fileInfo.ExpiryTime) {
|
||||
delete(ms.files, code)
|
||||
return nil, fmt.Errorf("文件已过期")
|
||||
}
|
||||
|
||||
return fileInfo, nil
|
||||
}
|
||||
|
||||
// DeleteFileInfo 删除文件信息
|
||||
func (ms *MemoryStore) DeleteFileInfo(code string) error {
|
||||
ms.mutex.Lock()
|
||||
defer ms.mutex.Unlock()
|
||||
|
||||
delete(ms.files, code)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStore 获取全局存储实例
|
||||
func GetStore() *MemoryStore {
|
||||
return globalStore
|
||||
}
|
||||
256
internal/services/p2p_service.go
Normal file
256
internal/services/p2p_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"chuan/internal/models"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type FileTransferRoom struct {
|
||||
ID string
|
||||
Code string // 取件码
|
||||
Files []models.FileTransferInfo // 待传输文件信息
|
||||
Sender *websocket.Conn // 发送方连接
|
||||
Receiver *websocket.Conn // 接收方连接
|
||||
CreatedAt time.Time // 创建时间
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type P2PService struct {
|
||||
rooms map[string]*FileTransferRoom // 使用取件码作为key
|
||||
roomsMux sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
func NewP2PService() *P2PService {
|
||||
service := &P2PService{
|
||||
rooms: make(map[string]*FileTransferRoom),
|
||||
roomsMux: sync.RWMutex{},
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许所有来源,生产环境应当限制
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 启动房间清理任务
|
||||
go service.cleanupExpiredRooms()
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// CreateRoom 创建新房间并返回取件码
|
||||
func (p *P2PService) CreateRoom(files []models.FileTransferInfo) string {
|
||||
code := generatePickupCode()
|
||||
|
||||
p.roomsMux.Lock()
|
||||
defer p.roomsMux.Unlock()
|
||||
|
||||
room := &FileTransferRoom{
|
||||
ID: "room_" + code,
|
||||
Code: code,
|
||||
Files: files,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
p.rooms[code] = room
|
||||
log.Printf("创建房间,取件码: %s,文件数量: %d", code, len(files))
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
// GetRoomByCode 根据取件码获取房间
|
||||
func (p *P2PService) GetRoomByCode(code string) (*FileTransferRoom, bool) {
|
||||
p.roomsMux.RLock()
|
||||
defer p.roomsMux.RUnlock()
|
||||
|
||||
room, exists := p.rooms[code]
|
||||
return room, exists
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket连接
|
||||
func (p *P2PService) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := p.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("WebSocket升级失败: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 获取取件码和角色
|
||||
code := r.URL.Query().Get("code")
|
||||
role := r.URL.Query().Get("role") // "sender" or "receiver"
|
||||
|
||||
if code == "" || (role != "sender" && role != "receiver") {
|
||||
log.Printf("缺少取件码或角色参数")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取房间
|
||||
room, exists := p.GetRoomByCode(code)
|
||||
if !exists {
|
||||
log.Printf("房间不存在: %s", code)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置连接
|
||||
room.mutex.Lock()
|
||||
if role == "sender" {
|
||||
room.Sender = conn
|
||||
log.Printf("发送方连接到房间: %s", code)
|
||||
} else {
|
||||
room.Receiver = conn
|
||||
log.Printf("接收方连接到房间: %s", code)
|
||||
|
||||
// 发送文件列表给接收方
|
||||
filesMsg := models.VideoMessage{
|
||||
Type: "file-list",
|
||||
Payload: map[string]interface{}{"files": room.Files},
|
||||
}
|
||||
if err := conn.WriteJSON(filesMsg); err != nil {
|
||||
log.Printf("发送文件列表失败: %v", err)
|
||||
}
|
||||
|
||||
// 通知发送方接收方已连接
|
||||
if room.Sender != nil {
|
||||
readyMsg := models.VideoMessage{
|
||||
Type: "receiver-ready",
|
||||
Payload: map[string]interface{}{},
|
||||
}
|
||||
if err := room.Sender.WriteJSON(readyMsg); err != nil {
|
||||
log.Printf("发送接收方就绪消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
room.mutex.Unlock() // 连接关闭时清理
|
||||
defer func() {
|
||||
room.mutex.Lock()
|
||||
if role == "sender" {
|
||||
room.Sender = nil
|
||||
} else {
|
||||
room.Receiver = nil
|
||||
}
|
||||
room.mutex.Unlock()
|
||||
|
||||
// 如果双方都断开连接,删除房间
|
||||
p.cleanupRoom(code)
|
||||
}()
|
||||
|
||||
// 处理消息
|
||||
for {
|
||||
var msg models.VideoMessage
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
log.Printf("读取WebSocket消息失败: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
log.Printf("收到WebSocket消息: 类型=%s, 来自=%s, 房间=%s", msg.Type, role, code)
|
||||
|
||||
// 转发消息到对方
|
||||
p.forwardMessage(room, role, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// forwardMessage 转发消息到对方
|
||||
func (p *P2PService) forwardMessage(room *FileTransferRoom, senderRole string, msg models.VideoMessage) {
|
||||
room.mutex.RLock()
|
||||
defer room.mutex.RUnlock()
|
||||
|
||||
var targetConn *websocket.Conn
|
||||
var targetRole string
|
||||
if senderRole == "sender" && room.Receiver != nil {
|
||||
targetConn = room.Receiver
|
||||
targetRole = "receiver"
|
||||
} else if senderRole == "receiver" && room.Sender != nil {
|
||||
targetConn = room.Sender
|
||||
targetRole = "sender"
|
||||
}
|
||||
|
||||
if targetConn != nil {
|
||||
log.Printf("转发消息: 类型=%s, 从%s到%s", msg.Type, senderRole, targetRole)
|
||||
if err := targetConn.WriteJSON(msg); err != nil {
|
||||
log.Printf("转发消息失败: %v", err)
|
||||
} else {
|
||||
log.Printf("消息转发成功: 类型=%s", msg.Type)
|
||||
}
|
||||
} else {
|
||||
log.Printf("无法转发消息: 目标连接不存在, 发送方=%s", senderRole)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRoom 清理房间
|
||||
func (p *P2PService) cleanupRoom(code string) {
|
||||
p.roomsMux.Lock()
|
||||
defer p.roomsMux.Unlock()
|
||||
|
||||
if room, exists := p.rooms[code]; exists {
|
||||
room.mutex.RLock()
|
||||
bothDisconnected := room.Sender == nil && room.Receiver == nil
|
||||
room.mutex.RUnlock()
|
||||
|
||||
if bothDisconnected {
|
||||
delete(p.rooms, code)
|
||||
log.Printf("清理房间: %s", code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredRooms 定期清理过期房间
|
||||
func (p *P2PService) cleanupExpiredRooms() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
p.roomsMux.Lock()
|
||||
now := time.Now()
|
||||
for code, room := range p.rooms {
|
||||
// 房间存在超过1小时则删除
|
||||
if now.Sub(room.CreatedAt) > time.Hour {
|
||||
delete(p.rooms, code)
|
||||
log.Printf("清理过期房间: %s", code)
|
||||
}
|
||||
}
|
||||
p.roomsMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// generatePickupCode 生成6位取件码
|
||||
func generatePickupCode() string {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
code := rand.Intn(900000) + 100000
|
||||
return strconv.Itoa(code)
|
||||
}
|
||||
|
||||
// GetRoomStats 获取房间统计信息
|
||||
func (p *P2PService) GetRoomStats() map[string]interface{} {
|
||||
p.roomsMux.RLock()
|
||||
defer p.roomsMux.RUnlock()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_rooms": len(p.rooms),
|
||||
"rooms": make([]map[string]interface{}, 0),
|
||||
}
|
||||
|
||||
for code, room := range p.rooms {
|
||||
room.mutex.RLock()
|
||||
roomInfo := map[string]interface{}{
|
||||
"code": code,
|
||||
"file_count": len(room.Files),
|
||||
"has_sender": room.Sender != nil,
|
||||
"has_receiver": room.Receiver != nil,
|
||||
"created_at": room.CreatedAt,
|
||||
}
|
||||
room.mutex.RUnlock()
|
||||
stats["rooms"] = append(stats["rooms"].([]map[string]interface{}), roomInfo)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
175
internal/services/webrtc_service.go
Normal file
175
internal/services/webrtc_service.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"chuan/internal/models"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WebRTCService struct {
|
||||
clients map[string]*websocket.Conn
|
||||
clientsMux sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
func NewWebRTCService() *WebRTCService {
|
||||
return &WebRTCService{
|
||||
clients: make(map[string]*websocket.Conn),
|
||||
clientsMux: sync.RWMutex{},
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许所有来源,生产环境应当限制
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket连接
|
||||
func (ws *WebRTCService) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := ws.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("WebSocket升级失败: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 为客户端生成唯一ID
|
||||
clientID := ws.generateClientID()
|
||||
|
||||
// 添加客户端到连接池
|
||||
ws.clientsMux.Lock()
|
||||
ws.clients[clientID] = conn
|
||||
ws.clientsMux.Unlock()
|
||||
|
||||
// 连接关闭时清理
|
||||
defer func() {
|
||||
ws.clientsMux.Lock()
|
||||
delete(ws.clients, clientID)
|
||||
ws.clientsMux.Unlock()
|
||||
}()
|
||||
|
||||
// 发送欢迎消息
|
||||
welcomeMsg := models.VideoMessage{
|
||||
Type: "welcome",
|
||||
Payload: map[string]string{"clientId": clientID},
|
||||
}
|
||||
ws.sendMessage(conn, welcomeMsg)
|
||||
|
||||
// 处理消息
|
||||
for {
|
||||
var msg models.VideoMessage
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
log.Printf("读取WebSocket消息失败: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "offer":
|
||||
ws.handleOffer(clientID, msg)
|
||||
case "answer":
|
||||
ws.handleAnswer(clientID, msg)
|
||||
case "ice-candidate":
|
||||
ws.handleICECandidate(clientID, msg)
|
||||
case "join-room":
|
||||
ws.handleJoinRoom(clientID, msg)
|
||||
case "leave-room":
|
||||
ws.handleLeaveRoom(clientID, msg)
|
||||
default:
|
||||
log.Printf("未知消息类型: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleOffer 处理WebRTC Offer
|
||||
func (ws *WebRTCService) handleOffer(clientID string, msg models.VideoMessage) {
|
||||
// 广播offer到其他客户端
|
||||
ws.broadcastToOthers(clientID, msg)
|
||||
}
|
||||
|
||||
// handleAnswer 处理WebRTC Answer
|
||||
func (ws *WebRTCService) handleAnswer(clientID string, msg models.VideoMessage) {
|
||||
// 广播answer到其他客户端
|
||||
ws.broadcastToOthers(clientID, msg)
|
||||
}
|
||||
|
||||
// handleICECandidate 处理ICE candidate
|
||||
func (ws *WebRTCService) handleICECandidate(clientID string, msg models.VideoMessage) {
|
||||
// 广播ICE candidate到其他客户端
|
||||
ws.broadcastToOthers(clientID, msg)
|
||||
}
|
||||
|
||||
// handleJoinRoom 处理加入房间
|
||||
func (ws *WebRTCService) handleJoinRoom(clientID string, msg models.VideoMessage) {
|
||||
// TODO: 实现房间管理逻辑
|
||||
log.Printf("客户端 %s 加入房间", clientID)
|
||||
}
|
||||
|
||||
// handleLeaveRoom 处理离开房间
|
||||
func (ws *WebRTCService) handleLeaveRoom(clientID string, msg models.VideoMessage) {
|
||||
// TODO: 实现房间管理逻辑
|
||||
log.Printf("客户端 %s 离开房间", clientID)
|
||||
}
|
||||
|
||||
// broadcastToOthers 向其他客户端广播消息
|
||||
func (ws *WebRTCService) broadcastToOthers(senderID string, msg models.VideoMessage) {
|
||||
ws.clientsMux.RLock()
|
||||
defer ws.clientsMux.RUnlock()
|
||||
|
||||
for clientID, conn := range ws.clients {
|
||||
if clientID != senderID {
|
||||
ws.sendMessage(conn, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendMessage 发送消息到WebSocket连接
|
||||
func (ws *WebRTCService) sendMessage(conn *websocket.Conn, msg models.VideoMessage) {
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
log.Printf("发送WebSocket消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// generateClientID 生成客户端ID
|
||||
func (ws *WebRTCService) generateClientID() string {
|
||||
// 简单的ID生成,生产环境应使用更安全的方法
|
||||
return "client_" + randomString(8)
|
||||
}
|
||||
|
||||
// CreateOffer 创建WebRTC Offer
|
||||
func (ws *WebRTCService) CreateOffer() (*models.WebRTCOffer, error) {
|
||||
// TODO: 实现WebRTC Offer创建
|
||||
return &models.WebRTCOffer{
|
||||
SDP: "v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\n...", // 示例SDP
|
||||
Type: "offer",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateAnswer 创建WebRTC Answer
|
||||
func (ws *WebRTCService) CreateAnswer(offer *models.WebRTCOffer) (*models.WebRTCAnswer, error) {
|
||||
// TODO: 实现WebRTC Answer创建
|
||||
return &models.WebRTCAnswer{
|
||||
SDP: "v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\n...", // 示例SDP
|
||||
Type: "answer",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddICECandidate 添加ICE候选
|
||||
func (ws *WebRTCService) AddICECandidate(candidate *models.WebRTCICECandidate) error {
|
||||
// TODO: 实现ICE候选处理
|
||||
return nil
|
||||
}
|
||||
|
||||
// randomString 生成随机字符串
|
||||
func randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[i%len(charset)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
Reference in New Issue
Block a user