Skip to content

Commit

Permalink
✨ feat: Implement WebRTC logic check for screen sharing.
Browse files Browse the repository at this point in the history
- Added WebRTC logic to handle screen sharing functionality.
- Implemented mechanisms for upgrading WebSocket connections and managing peer connections.
- Set up handlers for dispatching keyframes, handling incoming tracks, and processing WebSocket messages.
- Introduced screen sharing start and stop handlers to manage screen sharing sessions.

Related issue: YJU-OKURA#80
yuminn-k committed May 4, 2024
1 parent f7d8e5e commit dddde02
Showing 1 changed file with 366 additions and 18 deletions.
384 changes: 366 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
@@ -2,12 +2,19 @@ package main

import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"log"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"

@@ -29,7 +36,187 @@ import (
"gorm.io/gorm"
)

var redisClient *redis.Client
var (
redisClient *redis.Client
addr = flag.String("addr", ":8080", "http service address")
upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
return origin == "http://localhost:3000" // 적절한 출처를 설정해야 합니다.
},
}

// lock for peerConnections and trackLocals
listLock sync.RWMutex
peerConnections = make(map[string]*peerConnectionState)
trackLocals = make(map[string]*webrtc.TrackLocalStaticRTP)
stopChan = make(chan bool)
)

type websocketMessage struct {
Event string `json:"event"`
Data string `json:"data"`
}

type peerConnectionState struct {
id string
peerConnection *webrtc.PeerConnection // peerConnection *webrtc.PeerConnection
websocket *threadSafeWriter // websocket *threadSafeWriter
}

// Helper to make Gorilla Websockets thread safe
type threadSafeWriter struct {
*websocket.Conn
sync.Mutex
}

func (t *threadSafeWriter) WriteJSON(v interface{}) error {
t.Lock()
defer t.Unlock()

return t.Conn.WriteJSON(v)
}

// Adding a new peer connection with a unique ID
func addPeerConnection(id string, pc *peerConnectionState) {
listLock.Lock()
defer listLock.Unlock()
peerConnections[id] = pc
}

// Removing a peer connection by ID
func removePeerConnection(id string) {
listLock.Lock()
defer listLock.Unlock()
if pc, ok := peerConnections[id]; ok {
pc.peerConnection.Close()
pc.websocket.Close()
delete(peerConnections, id)
}
}

func dispatchKeyFramesEvery(interval time.Duration, stop chan bool) {
ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
dispatchKeyFrame()
case <-stop:
return
}
}
}

func websocketHandler(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("WebSocket upgrade error: %v", err)
return
}
defer c.Close()

pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
if err != nil {
log.Printf("PeerConnection creation failed: %v", err)
return
}
defer pc.Close()

id := uuid.New().String()
peerState := &peerConnectionState{id: id, peerConnection: pc, websocket: &threadSafeWriter{Conn: c}}
addPeerConnection(id, peerState)

pc.OnICECandidate(func(i *webrtc.ICECandidate) {
if i != nil {
sendCandidateToClient(i, peerState.websocket)
}
})

pc.OnConnectionStateChange(func(p webrtc.PeerConnectionState) {
switch p {
case webrtc.PeerConnectionStateClosed, webrtc.PeerConnectionStateFailed:
stopDispatchKeyFrames()
removePeerConnection(id)
case webrtc.PeerConnectionStateConnected:
stopChan = make(chan bool)
go dispatchKeyFramesEvery(3*time.Second, stopChan)
}
})

pc.OnTrack(func(t *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
handleIncomingTrack(t, peerState)
})

handleWebSocketMessages(peerState)

go dispatchKeyFramesEvery(3*time.Second, stopChan)
}

func stopDispatchKeyFrames() {
select {
case stopChan <- true:
default:
}
}

func sendCandidateToClient(candidate *webrtc.ICECandidate, client *threadSafeWriter) {
candidateData, _ := json.Marshal(candidate.ToJSON())
client.WriteJSON(&websocketMessage{Event: "candidate", Data: string(candidateData)})
}

func handleIncomingTrack(track *webrtc.TrackRemote, state *peerConnectionState) {
localTrack := addTrack(track)
defer removeTrack(localTrack)

buffer := make([]byte, 1500)
for {
i, _, readErr := track.Read(buffer)
if readErr != nil {
break
}
if _, writeErr := localTrack.Write(buffer[:i]); writeErr != nil {
break
}
}
}

func handleWebSocketMessages(state *peerConnectionState) {
for {
_, message, err := state.websocket.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
var msg websocketMessage
if err := json.Unmarshal(message, &msg); err != nil {
log.Println("unmarshal:", err)
continue
}

switch msg.Event {
case "candidate":
handleCandidateMessage(msg.Data, state.peerConnection)
case "answer":
handleAnswerMessage(msg.Data, state.peerConnection)
default:
log.Println("Unknown message event:", msg.Event)
}
}
}

func handleCandidateMessage(data string, pc *webrtc.PeerConnection) {
var candidate webrtc.ICECandidateInit
json.Unmarshal([]byte(data), &candidate)
pc.AddICECandidate(candidate)
}

func handleAnswerMessage(data string, pc *webrtc.PeerConnection) {
var answer webrtc.SessionDescription
json.Unmarshal([]byte(data), &answer)
pc.SetRemoteDescription(answer)
}

func main() {
configureGinMode()
@@ -46,21 +233,181 @@ func main() {
router := setupRouter(db, jwtService)
startServer(router)

// Parse the flags passed to program
flag.Parse()

// Init other state
log.SetFlags(0)
trackLocals = map[string]*webrtc.TrackLocalStaticRTP{}

// Start dispatching key frames with a control channel
go dispatchKeyFramesEvery(time.Second*3, stopChan) // request a keyframe every 3 seconds

// start HTTP server
log.Fatal(http.ListenAndServe(*addr, nil))
}

// Add to list of tracks and fire renegotation for all PeerConnections
func addTrack(t *webrtc.TrackRemote) *webrtc.TrackLocalStaticRTP {
listLock.Lock()
defer func() {
listLock.Unlock()
signalPeerConnections()
}()

// Create a new TrackLocal with the same codec as our incoming
trackLocal, err := webrtc.NewTrackLocalStaticRTP(t.Codec().RTPCodecCapability, t.ID(), t.StreamID())
if err != nil {
panic(err)
}

trackLocals[t.ID()] = trackLocal
return trackLocal
}

// Remove from list of tracks and fire renegotation for all PeerConnections
func removeTrack(t *webrtc.TrackLocalStaticRTP) {
listLock.Lock()
defer func() {
listLock.Unlock()
signalPeerConnections()
}()

delete(trackLocals, t.ID())
}

// signalPeerConnections updates each PeerConnection so that it is getting all the expected media tracks
func signalPeerConnections() {
listLock.Lock()
defer listLock.Unlock()

attemptSync := func() (tryAgain bool) {
for id, conn := range peerConnections {
if conn.peerConnection.ConnectionState() == webrtc.PeerConnectionStateClosed {
removePeerConnection(id)
return true
}

existingSenders := map[string]bool{}
for _, sender := range conn.peerConnection.GetSenders() {
if sender.Track() == nil {
continue
}
existingSenders[sender.Track().ID()] = true
if _, ok := trackLocals[sender.Track().ID()]; !ok {
if err := conn.peerConnection.RemoveTrack(sender); err != nil {
return true
}
}
}

for trackID, track := range trackLocals {
if _, ok := existingSenders[trackID]; !ok {
if _, err := conn.peerConnection.AddTrack(track); err != nil {
return true
}
}
}

offer, err := conn.peerConnection.CreateOffer(nil)
if err != nil {
return true
}

if err = conn.peerConnection.SetLocalDescription(offer); err != nil {
return true
}

offerString, err := json.Marshal(offer)
if err != nil {
return true
}

if err = conn.websocket.WriteJSON(&websocketMessage{
Event: "offer",
Data: string(offerString),
}); err != nil {
return true
}
}
return false
}

for syncAttempt := 0; ; syncAttempt++ {
if syncAttempt == 25 {
time.Sleep(time.Second * 3)
go signalPeerConnections()
return
}
if !attemptSync() {
break
}
}
}

// dispatchKeyFrame sends a keyframe to all PeerConnections, used everytime a new user joins the call
func dispatchKeyFrame() {
listLock.Lock()
defer listLock.Unlock()

for i := range peerConnections {
for _, receiver := range peerConnections[i].peerConnection.GetReceivers() {
if receiver.Track() == nil {
continue
}

_ = peerConnections[i].peerConnection.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{
MediaSSRC: uint32(receiver.Track().SSRC()),
},
})
}
}
}

func startScreenSharingHandler(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("WebSocket upgrade error: %v", err)
return
}
defer c.Close()

pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
if err != nil {
log.Printf("PeerConnection creation failed: %v", err)
return
}
defer pc.Close()

id := uuid.New().String()
peerState := &peerConnectionState{id: id, peerConnection: pc, websocket: &threadSafeWriter{Conn: c}}
addPeerConnection(id, peerState)

track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: "video/vp8"}, "screen", "share")
if err != nil {
log.Printf("Failed to create track: %v", err)
return
}

_, err = pc.AddTrack(track)
if err != nil {
log.Printf("Failed to add track: %v", err)
return
}

handleWebSocketMessages(peerState)
}

func stopScreenSharingHandler(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
removePeerConnection(id)
}

// Initialize gRPC server
//func startGRPCServer() {
// lis, err := net.Listen("tcp", ":50051")
// if err != nil {
// log.Fatalf("failed to listen: %v", err)
// }
// grpcServer := grpc.NewServer()
// // Register gRPC services here
// log.Printf("server listening at %v", lis.Addr())
// if err := grpcServer.Serve(lis); err != nil {
// log.Fatalf("failed to serve: %v", err)
// }
//}
// Unique ID generator, could be a simple counter or a UUID generator
func generateUniqueID() string {
return uuid.New().String() // Assuming github.com/google/uuid is imported
}

// configureGinMode Ginのモードを設定する
func configureGinMode() {
@@ -258,7 +605,7 @@ func setupRoutes(router *gin.Engine, userController *controllers.UserController,
setupGoogleAuthRoutes(router, googleAuthController)
setupCreateClassRoutes(router, createClassController, jwtService)
setupChatRoutes(router, chatController, jwtService)
setupLiveClassRoutes(router, liveClassController)
setupLiveClassRoutes(router)
}

// @securityDefinitions.apikey Bearer
@@ -294,6 +641,7 @@ func setupClassBoardRoutes(router *gin.Engine, controller *controllers.ClassBoar
cb.DELETE(":id", controller.DeleteClassBoard)

cb.GET("subscribe", controller.SubscribeClassBoardUpdates)
cb.GET("search", controller.SearchClassBoards)

// TODO: フロントエンド側の実装が完了したら、コメントアウトを外す
//protected := cb.Group("/:uid/:cid")
@@ -483,10 +831,10 @@ func manageChatRooms(db *gorm.DB, chatManager *services.Manager) {
}
}

func setupLiveClassRoutes(router *gin.Engine, liveClassController *controllers.LiveClassController) {
func setupLiveClassRoutes(router *gin.Engine) {
live := router.Group("/api/gin/live")
{
live.GET("screen_share/:uid/:cid", liveClassController.GetScreenShareInfo)
live.POST("screen_share/start/:cid", liveClassController.StartScreenShare)
live.GET("screen-share/start", gin.WrapF(startScreenSharingHandler))
live.GET("screen-share/stop", gin.WrapF(stopScreenSharingHandler))
}
}

0 comments on commit dddde02

Please sign in to comment.