From dddde022d7d25a4d3468388f019e6a8a7847be86 Mon Sep 17 00:00:00 2001 From: devYuMinKim Date: Sun, 5 May 2024 04:08:09 +0900 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Implement=20WebRTC=20logic?= =?UTF-8?q?=20check=20for=20screen=20sharing.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added WebRTC logic to handle screen sharing functionality. - Implemented mechanisms for upgrading WebSocket connections and managing peer connections. - Set up handlers for dispatching keyframes, handling incoming tracks, and processing WebSocket messages. - Introduced screen sharing start and stop handlers to manage screen sharing sessions. Related issue: #80 --- main.go | 384 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 366 insertions(+), 18 deletions(-) diff --git a/main.go b/main.go index 06bab65..f7ce419 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,19 @@ package main import ( "context" + "encoding/json" "errors" + "flag" "fmt" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" "log" "net/http" "os" "os/signal" + "sync" "syscall" "time" @@ -29,7 +36,187 @@ import ( "gorm.io/gorm" ) -var redisClient *redis.Client +var ( + redisClient *redis.Client + addr = flag.String("addr", ":8080", "http service address") + upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + origin := r.Header.Get("Origin") + return origin == "http://localhost:3000" // 적절한 출처를 설정해야 합니다. + }, + } + + // lock for peerConnections and trackLocals + listLock sync.RWMutex + peerConnections = make(map[string]*peerConnectionState) + trackLocals = make(map[string]*webrtc.TrackLocalStaticRTP) + stopChan = make(chan bool) +) + +type websocketMessage struct { + Event string `json:"event"` + Data string `json:"data"` +} + +type peerConnectionState struct { + id string + peerConnection *webrtc.PeerConnection // peerConnection *webrtc.PeerConnection + websocket *threadSafeWriter // websocket *threadSafeWriter +} + +// Helper to make Gorilla Websockets thread safe +type threadSafeWriter struct { + *websocket.Conn + sync.Mutex +} + +func (t *threadSafeWriter) WriteJSON(v interface{}) error { + t.Lock() + defer t.Unlock() + + return t.Conn.WriteJSON(v) +} + +// Adding a new peer connection with a unique ID +func addPeerConnection(id string, pc *peerConnectionState) { + listLock.Lock() + defer listLock.Unlock() + peerConnections[id] = pc +} + +// Removing a peer connection by ID +func removePeerConnection(id string) { + listLock.Lock() + defer listLock.Unlock() + if pc, ok := peerConnections[id]; ok { + pc.peerConnection.Close() + pc.websocket.Close() + delete(peerConnections, id) + } +} + +func dispatchKeyFramesEvery(interval time.Duration, stop chan bool) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + dispatchKeyFrame() + case <-stop: + return + } + } +} + +func websocketHandler(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade error: %v", err) + return + } + defer c.Close() + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + log.Printf("PeerConnection creation failed: %v", err) + return + } + defer pc.Close() + + id := uuid.New().String() + peerState := &peerConnectionState{id: id, peerConnection: pc, websocket: &threadSafeWriter{Conn: c}} + addPeerConnection(id, peerState) + + pc.OnICECandidate(func(i *webrtc.ICECandidate) { + if i != nil { + sendCandidateToClient(i, peerState.websocket) + } + }) + + pc.OnConnectionStateChange(func(p webrtc.PeerConnectionState) { + switch p { + case webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateFailed: + stopDispatchKeyFrames() + removePeerConnection(id) + case webrtc.PeerConnectionStateConnected: + stopChan = make(chan bool) + go dispatchKeyFramesEvery(3*time.Second, stopChan) + } + }) + + pc.OnTrack(func(t *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + handleIncomingTrack(t, peerState) + }) + + handleWebSocketMessages(peerState) + + go dispatchKeyFramesEvery(3*time.Second, stopChan) +} + +func stopDispatchKeyFrames() { + select { + case stopChan <- true: + default: + } +} + +func sendCandidateToClient(candidate *webrtc.ICECandidate, client *threadSafeWriter) { + candidateData, _ := json.Marshal(candidate.ToJSON()) + client.WriteJSON(&websocketMessage{Event: "candidate", Data: string(candidateData)}) +} + +func handleIncomingTrack(track *webrtc.TrackRemote, state *peerConnectionState) { + localTrack := addTrack(track) + defer removeTrack(localTrack) + + buffer := make([]byte, 1500) + for { + i, _, readErr := track.Read(buffer) + if readErr != nil { + break + } + if _, writeErr := localTrack.Write(buffer[:i]); writeErr != nil { + break + } + } +} + +func handleWebSocketMessages(state *peerConnectionState) { + for { + _, message, err := state.websocket.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + var msg websocketMessage + if err := json.Unmarshal(message, &msg); err != nil { + log.Println("unmarshal:", err) + continue + } + + switch msg.Event { + case "candidate": + handleCandidateMessage(msg.Data, state.peerConnection) + case "answer": + handleAnswerMessage(msg.Data, state.peerConnection) + default: + log.Println("Unknown message event:", msg.Event) + } + } +} + +func handleCandidateMessage(data string, pc *webrtc.PeerConnection) { + var candidate webrtc.ICECandidateInit + json.Unmarshal([]byte(data), &candidate) + pc.AddICECandidate(candidate) +} + +func handleAnswerMessage(data string, pc *webrtc.PeerConnection) { + var answer webrtc.SessionDescription + json.Unmarshal([]byte(data), &answer) + pc.SetRemoteDescription(answer) +} func main() { configureGinMode() @@ -46,21 +233,181 @@ func main() { router := setupRouter(db, jwtService) startServer(router) + // Parse the flags passed to program + flag.Parse() + + // Init other state + log.SetFlags(0) + trackLocals = map[string]*webrtc.TrackLocalStaticRTP{} + + // Start dispatching key frames with a control channel + go dispatchKeyFramesEvery(time.Second*3, stopChan) // request a keyframe every 3 seconds + + // start HTTP server + log.Fatal(http.ListenAndServe(*addr, nil)) +} + +// Add to list of tracks and fire renegotation for all PeerConnections +func addTrack(t *webrtc.TrackRemote) *webrtc.TrackLocalStaticRTP { + listLock.Lock() + defer func() { + listLock.Unlock() + signalPeerConnections() + }() + + // Create a new TrackLocal with the same codec as our incoming + trackLocal, err := webrtc.NewTrackLocalStaticRTP(t.Codec().RTPCodecCapability, t.ID(), t.StreamID()) + if err != nil { + panic(err) + } + + trackLocals[t.ID()] = trackLocal + return trackLocal +} + +// Remove from list of tracks and fire renegotation for all PeerConnections +func removeTrack(t *webrtc.TrackLocalStaticRTP) { + listLock.Lock() + defer func() { + listLock.Unlock() + signalPeerConnections() + }() + + delete(trackLocals, t.ID()) +} + +// signalPeerConnections updates each PeerConnection so that it is getting all the expected media tracks +func signalPeerConnections() { + listLock.Lock() + defer listLock.Unlock() + + attemptSync := func() (tryAgain bool) { + for id, conn := range peerConnections { + if conn.peerConnection.ConnectionState() == webrtc.PeerConnectionStateClosed { + removePeerConnection(id) + return true + } + + existingSenders := map[string]bool{} + for _, sender := range conn.peerConnection.GetSenders() { + if sender.Track() == nil { + continue + } + existingSenders[sender.Track().ID()] = true + if _, ok := trackLocals[sender.Track().ID()]; !ok { + if err := conn.peerConnection.RemoveTrack(sender); err != nil { + return true + } + } + } + + for trackID, track := range trackLocals { + if _, ok := existingSenders[trackID]; !ok { + if _, err := conn.peerConnection.AddTrack(track); err != nil { + return true + } + } + } + + offer, err := conn.peerConnection.CreateOffer(nil) + if err != nil { + return true + } + + if err = conn.peerConnection.SetLocalDescription(offer); err != nil { + return true + } + + offerString, err := json.Marshal(offer) + if err != nil { + return true + } + + if err = conn.websocket.WriteJSON(&websocketMessage{ + Event: "offer", + Data: string(offerString), + }); err != nil { + return true + } + } + return false + } + + for syncAttempt := 0; ; syncAttempt++ { + if syncAttempt == 25 { + time.Sleep(time.Second * 3) + go signalPeerConnections() + return + } + if !attemptSync() { + break + } + } +} + +// dispatchKeyFrame sends a keyframe to all PeerConnections, used everytime a new user joins the call +func dispatchKeyFrame() { + listLock.Lock() + defer listLock.Unlock() + + for i := range peerConnections { + for _, receiver := range peerConnections[i].peerConnection.GetReceivers() { + if receiver.Track() == nil { + continue + } + + _ = peerConnections[i].peerConnection.WriteRTCP([]rtcp.Packet{ + &rtcp.PictureLossIndication{ + MediaSSRC: uint32(receiver.Track().SSRC()), + }, + }) + } + } +} + +func startScreenSharingHandler(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade error: %v", err) + return + } + defer c.Close() + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + log.Printf("PeerConnection creation failed: %v", err) + return + } + defer pc.Close() + + id := uuid.New().String() + peerState := &peerConnectionState{id: id, peerConnection: pc, websocket: &threadSafeWriter{Conn: c}} + addPeerConnection(id, peerState) + + track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: "video/vp8"}, "screen", "share") + if err != nil { + log.Printf("Failed to create track: %v", err) + return + } + + _, err = pc.AddTrack(track) + if err != nil { + log.Printf("Failed to add track: %v", err) + return + } + + handleWebSocketMessages(peerState) +} + +func stopScreenSharingHandler(w http.ResponseWriter, r *http.Request) { + id := r.URL.Query().Get("id") + removePeerConnection(id) } -// Initialize gRPC server -//func startGRPCServer() { -// lis, err := net.Listen("tcp", ":50051") -// if err != nil { -// log.Fatalf("failed to listen: %v", err) -// } -// grpcServer := grpc.NewServer() -// // Register gRPC services here -// log.Printf("server listening at %v", lis.Addr()) -// if err := grpcServer.Serve(lis); err != nil { -// log.Fatalf("failed to serve: %v", err) -// } -//} +// Unique ID generator, could be a simple counter or a UUID generator +func generateUniqueID() string { + return uuid.New().String() // Assuming github.com/google/uuid is imported +} // configureGinMode Ginのモードを設定する func configureGinMode() { @@ -258,7 +605,7 @@ func setupRoutes(router *gin.Engine, userController *controllers.UserController, setupGoogleAuthRoutes(router, googleAuthController) setupCreateClassRoutes(router, createClassController, jwtService) setupChatRoutes(router, chatController, jwtService) - setupLiveClassRoutes(router, liveClassController) + setupLiveClassRoutes(router) } // @securityDefinitions.apikey Bearer @@ -294,6 +641,7 @@ func setupClassBoardRoutes(router *gin.Engine, controller *controllers.ClassBoar cb.DELETE(":id", controller.DeleteClassBoard) cb.GET("subscribe", controller.SubscribeClassBoardUpdates) + cb.GET("search", controller.SearchClassBoards) // TODO: フロントエンド側の実装が完了したら、コメントアウトを外す //protected := cb.Group("/:uid/:cid") @@ -483,10 +831,10 @@ func manageChatRooms(db *gorm.DB, chatManager *services.Manager) { } } -func setupLiveClassRoutes(router *gin.Engine, liveClassController *controllers.LiveClassController) { +func setupLiveClassRoutes(router *gin.Engine) { live := router.Group("/api/gin/live") { - live.GET("screen_share/:uid/:cid", liveClassController.GetScreenShareInfo) - live.POST("screen_share/start/:cid", liveClassController.StartScreenShare) + live.GET("screen-share/start", gin.WrapF(startScreenSharingHandler)) + live.GET("screen-share/stop", gin.WrapF(stopScreenSharingHandler)) } }