Skip to content

Commit f54f156

Browse files
committed
inProgressInvite lifecycle, stable ID generation, minor log field reshuffling
1 parent ec3f57b commit f54f156

File tree

3 files changed

+60
-35
lines changed

3 files changed

+60
-35
lines changed

pkg/sip/inbound.go

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030

3131
msdk "github.com/livekit/media-sdk"
3232
"github.com/livekit/protocol/rpc"
33+
uuid "github.com/satori/go.uuid"
3334

3435
"github.com/frostbyte73/core"
3536
"github.com/icholy/digest"
@@ -64,6 +65,8 @@ const (
6465
inviteOKRetryAttempts = 5
6566
inviteOKRetryAttemptsNoACK = 2
6667
inviteOkAckLateTimeout = inviteOkRetryIntervalMax
68+
69+
inviteCredentialValidity = 60 * time.Minute // Allow reuse of credentials for 1h
6770
)
6871

6972
var errNoACK = errors.New("no ACK received for 200 OK")
@@ -137,20 +140,23 @@ func (s *Server) getCallInfo(id string) *inboundCallInfo {
137140
func (s *Server) getInvite(sipCallID string) *inProgressInvite {
138141
s.imu.Lock()
139142
defer s.imu.Unlock()
140-
for i := range s.inProgressInvites {
141-
if s.inProgressInvites[i].sipCallID == sipCallID {
142-
return s.inProgressInvites[i]
143-
}
144-
}
145-
if len(s.inProgressInvites) >= digestLimit {
146-
s.inProgressInvites = s.inProgressInvites[1:]
143+
is, ok := s.inProgressInvites[sipCallID]
144+
if ok {
145+
return is
147146
}
148-
is := &inProgressInvite{sipCallID: sipCallID}
149-
s.inProgressInvites = append(s.inProgressInvites, is)
147+
is = &inProgressInvite{sipCallID: sipCallID}
148+
s.inProgressInvites[sipCallID] = is
149+
150+
go func() {
151+
time.Sleep(inviteCredentialValidity)
152+
s.imu.Lock()
153+
defer s.imu.Unlock()
154+
delete(s.inProgressInvites, sipCallID)
155+
}()
150156
return is
151157
}
152158

153-
func (s *Server) handleInviteAuth(log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string) (ok bool) {
159+
func (s *Server) handleInviteAuth(log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string, inviteState *inProgressInvite) (ok bool) {
154160
log = log.WithValues(
155161
"username", username,
156162
"passwordHash", hashPassword(password),
@@ -178,8 +184,6 @@ func (s *Server) handleInviteAuth(log logger.Logger, req *sip.Request, tx sip.Se
178184
}
179185
ci := s.getCallInfo(sipCallID)
180186
ci.countInvite(log, req)
181-
inviteState := s.getInvite(sipCallID)
182-
log = log.WithValues("inviteStateSipCallID", sipCallID)
183187

184188
h := req.GetHeader("Proxy-Authorization")
185189
if h == nil {
@@ -220,7 +224,6 @@ func (s *Server) handleInviteAuth(log logger.Logger, req *sip.Request, tx sip.Se
220224
// Check if we have a valid challenge state
221225
if inviteState.challenge.Realm == "" {
222226
log.Warnw("No challenge state found for authentication attempt", errors.New("missing challenge state"),
223-
"sipCallID", sipCallID,
224227
"expectedRealm", UserAgent,
225228
)
226229
_ = tx.Respond(sip.NewResponseFromRequest(req, 401, "Bad credentials", nil))
@@ -295,18 +298,18 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
295298
s.log.Errorw("cannot parse source IP", err, "fromIP", src)
296299
return psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
297300
}
298-
callID := lksip.NewCallID()
301+
sipCallID := legCallIDFromReq(req)
299302
tr := callTransportFromReq(req)
300303
legTr := legTransportFromReq(req)
301304
log := s.log.WithValues(
302-
"callID", callID,
305+
"sipCallID", sipCallID,
303306
"fromIP", src.Addr(),
304307
"toIP", req.Destination(),
305308
"transport", tr,
306309
)
307310

308311
var call *inboundCall
309-
cc := s.newInbound(log, LocalTag(callID), s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
312+
cc := s.newInbound(log, "unassigned", s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
310313
c := call
311314
if c == nil || len(c.attrsToHdr) == 0 {
312315
return headers
@@ -319,8 +322,6 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
319322
})
320323
log = LoggerWithParams(log, cc)
321324
log = LoggerWithHeaders(log, cc)
322-
cc.log = log
323-
log.Infow("processing invite")
324325

325326
if err := cc.ValidateInvite(); err != nil {
326327
if s.conf.HideInboundPort {
@@ -330,6 +331,28 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
330331
}
331332
return psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed"))
332333
}
334+
335+
// Establish ID
336+
if _, ok := req.To().Params.Get("tag"); !ok {
337+
// No to-tag on the invite means we need to generate one per RFC 3261 section 12.
338+
if !inviteHasAuth(req) {
339+
// No auth = a 407 response and another INVITE+auth.
340+
// Generate a new to-tag early, to make sure both INVITES have the same ID.
341+
uuid, _ := uuid.NewV4() // Same as NewResponseFromRequest in sipgo
342+
req.To().Params.Add("tag", uuid.String())
343+
}
344+
}
345+
inviteProgress := s.getInvite(req.CallID().Value())
346+
callID := inviteProgress.lkCallID
347+
if callID == "" {
348+
callID = lksip.NewCallID()
349+
inviteProgress.lkCallID = callID
350+
}
351+
352+
log = log.WithValues("callID", callID)
353+
cc.log = log
354+
log.Infow("processing invite")
355+
333356
ctx, span := tracer.Start(ctx, "Server.onInvite")
334357
defer span.End()
335358

@@ -352,12 +375,6 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
352375
cc.Processing()
353376
}
354377

355-
// Extract SIP Call ID directly from the request
356-
sipCallID := ""
357-
if h := req.CallID(); h != nil {
358-
sipCallID = h.Value()
359-
}
360-
361378
callInfo := &rpc.SIPCall{
362379
LkCallId: callID,
363380
SipCallId: sipCallID,
@@ -421,7 +438,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
421438
// We will send password request anyway, so might as well signal that the progress is made.
422439
cc.Processing()
423440
}
424-
if !s.handleInviteAuth(log, req, tx, from.User, r.Username, r.Password) {
441+
if !s.handleInviteAuth(log, req, tx, from.User, r.Username, r.Password, inviteProgress) {
425442
cmon.InviteErrorShort("unauthorized")
426443
// handleInviteAuth will generate the SIP Response as needed
427444
return psrpc.NewErrorf(psrpc.PermissionDenied, "invalid credentials were provided")

pkg/sip/protocol.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,13 @@ func legTransportFromReq(req *sip.Request) Transport {
173173
return ""
174174
}
175175

176+
func legCallIDFromReq(req *sip.Request) string {
177+
if callID := req.CallID(); callID != nil {
178+
return callID.Value()
179+
}
180+
return ""
181+
}
182+
176183
func transportPort(c *config.Config, t Transport) int {
177184
if t == TransportTLS {
178185
if tc := c.TLS; tc != nil {

pkg/sip/server.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ import (
4343
)
4444

4545
const (
46-
UserAgent = "LiveKit"
47-
digestLimit = 500
46+
UserAgent = "LiveKit"
4847
)
4948

5049
const (
@@ -133,7 +132,7 @@ type Server struct {
133132
sipUnhandled RequestHandler
134133

135134
imu sync.Mutex
136-
inProgressInvites []*inProgressInvite
135+
inProgressInvites map[string]*inProgressInvite
137136

138137
closing core.Fuse
139138
cmu sync.RWMutex
@@ -155,20 +154,22 @@ type Server struct {
155154
type inProgressInvite struct {
156155
sipCallID string
157156
challenge digest.Challenge
157+
lkCallID string // SCL_* LiveKit call ID assigned to this dialog
158158
}
159159

160160
func NewServer(region string, conf *config.Config, log logger.Logger, mon *stats.Monitor, getIOClient GetIOInfoClient) *Server {
161161
if log == nil {
162162
log = logger.GetLogger()
163163
}
164164
s := &Server{
165-
log: log,
166-
conf: conf,
167-
region: region,
168-
mon: mon,
169-
getIOClient: getIOClient,
170-
activeCalls: make(map[RemoteTag]*inboundCall),
171-
byLocal: make(map[LocalTag]*inboundCall),
165+
log: log,
166+
conf: conf,
167+
region: region,
168+
mon: mon,
169+
getIOClient: getIOClient,
170+
inProgressInvites: make(map[string]*inProgressInvite),
171+
activeCalls: make(map[RemoteTag]*inboundCall),
172+
byLocal: make(map[LocalTag]*inboundCall),
172173
}
173174
s.infos.byCallID = expirable.NewLRU[string, *inboundCallInfo](maxCallCache, nil, callCacheTTL)
174175
s.initMediaRes()

0 commit comments

Comments
 (0)