Skip to content

Commit 90c08e6

Browse files
authored
Revert "Reuse the same ID for both auth-less and auth-ful INVITEs" (#533)
1 parent 44e2418 commit 90c08e6

File tree

5 files changed

+60
-286
lines changed

5 files changed

+60
-286
lines changed

pkg/sip/inbound.go

Lines changed: 35 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ import (
4242
"github.com/livekit/protocol/rpc"
4343
lksip "github.com/livekit/protocol/sip"
4444
"github.com/livekit/protocol/tracer"
45-
"github.com/livekit/protocol/utils"
4645
"github.com/livekit/protocol/utils/traceid"
4746
"github.com/livekit/psrpc"
4847
lksdk "github.com/livekit/server-sdk-go/v2"
@@ -65,8 +64,6 @@ const (
6564
inviteOKRetryAttempts = 5
6665
inviteOKRetryAttemptsNoACK = 2
6766
inviteOkAckLateTimeout = inviteOkRetryIntervalMax
68-
69-
inviteCredentialValidity = 60 * time.Minute // Allow reuse of credentials for 1h
7067
)
7168

7269
var errNoACK = errors.New("no ACK received for 200 OK")
@@ -137,50 +134,23 @@ func (s *Server) getCallInfo(id string) *inboundCallInfo {
137134
return c
138135
}
139136

140-
func (s *Server) cleanupInvites() {
141-
ticker := time.NewTicker(5 * time.Minute) // Periodic cleanup every 5 minutes
142-
defer ticker.Stop()
143-
for {
144-
select {
145-
case <-s.closing.Watch():
146-
return
147-
case <-ticker.C:
148-
s.imu.Lock()
149-
for it := s.inviteTimeoutQueue.IterateRemoveAfter(inviteCredentialValidity); it.Next(); {
150-
key := it.Item().Value
151-
delete(s.inProgressInvites, key)
152-
}
153-
s.imu.Unlock()
137+
func (s *Server) getInvite(sipCallID string) *inProgressInvite {
138+
s.imu.Lock()
139+
defer s.imu.Unlock()
140+
for i := range s.inProgressInvites {
141+
if s.inProgressInvites[i].sipCallID == sipCallID {
142+
return s.inProgressInvites[i]
154143
}
155144
}
156-
}
157-
158-
func (s *Server) getInvite(sipCallID, toTag, fromTag string) *inProgressInvite {
159-
key := dialogKey{
160-
sipCallID: sipCallID,
161-
toTag: toTag,
162-
fromTag: fromTag,
163-
}
164-
165-
s.imu.RLock()
166-
is, exists := s.inProgressInvites[key]
167-
s.imu.RUnlock()
168-
if !exists {
169-
s.imu.Lock()
170-
is, exists = s.inProgressInvites[key]
171-
if !exists {
172-
is = &inProgressInvite{sipCallID: sipCallID, timeoutLink: utils.TimeoutQueueItem[dialogKey]{Value: key}}
173-
s.inProgressInvites[key] = is
174-
}
175-
s.imu.Unlock()
145+
if len(s.inProgressInvites) >= digestLimit {
146+
s.inProgressInvites = s.inProgressInvites[1:]
176147
}
177-
178-
// Always reset the timeout link, whether just created or not
179-
s.inviteTimeoutQueue.Reset(&is.timeoutLink)
148+
is := &inProgressInvite{sipCallID: sipCallID}
149+
s.inProgressInvites = append(s.inProgressInvites, is)
180150
return is
181151
}
182152

183-
func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string, inviteState *inProgressInvite) (ok bool) {
153+
func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string) (ok bool) {
184154
log = log.WithValues(
185155
"username", username,
186156
"passwordHash", hashPassword(password),
@@ -201,6 +171,14 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
201171
_ = tx.Respond(sip.NewResponseFromRequest(req, 100, "Processing", nil))
202172
}
203173

174+
// Extract SIP Call ID for tracking in-progress invites
175+
sipCallID := ""
176+
if h := req.CallID(); h != nil {
177+
sipCallID = h.Value()
178+
}
179+
inviteState := s.getInvite(sipCallID)
180+
log = log.WithValues("inviteStateSipCallID", sipCallID)
181+
204182
h := req.GetHeader("Proxy-Authorization")
205183
if h == nil {
206184
inviteState.challenge = digest.Challenge{
@@ -252,6 +230,7 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
252230
// Check if we have a valid challenge state
253231
if inviteState.challenge.Realm == "" {
254232
log.Warnw("No challenge state found for authentication attempt", errors.New("missing challenge state"),
233+
"sipCallID", sipCallID,
255234
"expectedRealm", UserAgent,
256235
)
257236
_ = tx.Respond(sip.NewResponseFromRequest(req, 401, "Bad credentials", nil))
@@ -326,18 +305,20 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
326305
s.log.Errorw("cannot parse source IP", err, "fromIP", src)
327306
return psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
328307
}
329-
sipCallID := legCallIDFromReq(req)
308+
callID := lksip.NewCallID()
309+
tid := traceid.FromGUID(callID)
330310
tr := callTransportFromReq(req)
331311
legTr := legTransportFromReq(req)
332312
log := s.log.WithValues(
333-
"sipCallID", sipCallID,
313+
"callID", callID,
314+
"traceID", tid.String(),
334315
"fromIP", src.Addr(),
335316
"toIP", req.Destination(),
336317
"transport", tr,
337318
)
338319

339320
var call *inboundCall
340-
cc := s.newInbound(log, s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
321+
cc := s.newInbound(log, LocalTag(callID), s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
341322
c := call
342323
if c == nil || len(c.attrsToHdr) == 0 {
343324
return headers
@@ -350,53 +331,25 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
350331
})
351332
log = LoggerWithParams(log, cc)
352333
log = LoggerWithHeaders(log, cc)
334+
cc.log = log
335+
log.Infow("processing invite")
353336

354337
if err := cc.ValidateInvite(); err != nil {
355-
log.Errorw("invalid invite", err)
356338
if s.conf.HideInboundPort {
357339
cc.Drop()
358340
} else {
359341
cc.RespondAndDrop(sip.StatusBadRequest, "Bad request")
360342
}
361343
return psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed"))
362344
}
363-
364-
// Establish ID
365-
fromTag, _ := req.From().Params.Get("tag") // always exists, via ValidateInvite() check
366-
toParams := req.To().Params // To() always exists, via ValidateInvite() check
367-
if toParams == nil {
368-
toParams = sip.NewParams()
369-
req.To().Params = toParams
370-
}
371-
toTag, ok := toParams.Get("tag")
372-
if !ok {
373-
// No to-tag on the invite means we need to generate one per RFC 3261 section 12.
374-
// Generate a new to-tag early, to make sure both INVITES have the same ID.
375-
toTag = utils.NewGuid("")
376-
toParams.Add("tag", toTag)
377-
}
378-
inviteProgress := s.getInvite(sipCallID, toTag, fromTag)
379-
callID := inviteProgress.lkCallID
380-
if callID == "" {
381-
callID = lksip.NewCallID()
382-
inviteProgress.lkCallID = callID
383-
}
384-
cc.id = LocalTag(callID)
385-
tid := traceid.FromGUID(sipCallID)
386-
387-
log = log.WithValues("callID", callID)
388-
log = log.WithValues("traceID", tid.String())
389-
cc.log = log
390-
log.Infow("processing invite")
391-
392345
ctx, span := tracer.Start(ctx, "Server.onInvite")
393346
defer span.End()
394347

395348
s.cmu.RLock()
396-
existing := s.byCallID[sipCallID]
349+
existing := s.byCallID[cc.SIPCallID()]
397350
s.cmu.RUnlock()
398351
if existing != nil && existing.cc.InviteCSeq() < cc.InviteCSeq() {
399-
log.Infow("accepting reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength())
352+
log.Infow("accepting reinvite", "sipCallID", existing.cc.ID(), "content-type", req.ContentType(), "content-length", req.ContentLength())
400353
existing.log().Infow("reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength(), "cseq", cc.InviteCSeq())
401354
cc.AcceptAsKeepAlive(existing.cc.OwnSDP())
402355
return nil
@@ -423,7 +376,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
423376

424377
callInfo := &rpc.SIPCall{
425378
LkCallId: callID,
426-
SipCallId: sipCallID,
379+
SipCallId: cc.SIPCallID(),
427380
SourceIp: src.Addr().String(),
428381
Address: ToSIPUri("", cc.Address()),
429382
From: ToSIPUri("", from),
@@ -494,15 +447,15 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
494447
// We will send password request anyway, so might as well signal that the progress is made.
495448
cc.Processing()
496449
}
497-
s.getCallInfo(sipCallID).countInvite(log, req)
498-
if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password, inviteProgress) {
450+
s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
451+
if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password) {
499452
cmon.InviteErrorShort("unauthorized")
500453
// handleInviteAuth will generate the SIP Response as needed
501454
return psrpc.NewErrorf(psrpc.PermissionDenied, "invalid credentials were provided")
502455
}
503456
// ok
504457
case AuthAccept:
505-
s.getCallInfo(sipCallID).countInvite(log, req)
458+
s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
506459
// ok
507460
}
508461

@@ -1422,10 +1375,11 @@ func (c *inboundCall) transferCall(ctx context.Context, transferTo string, heade
14221375

14231376
}
14241377

1425-
func (s *Server) newInbound(log logger.Logger, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction, getHeaders setHeadersFunc) *sipInbound {
1378+
func (s *Server) newInbound(log logger.Logger, id LocalTag, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction, getHeaders setHeadersFunc) *sipInbound {
14261379
c := &sipInbound{
14271380
log: log,
14281381
s: s,
1382+
id: id,
14291383
invite: invite,
14301384
inviteTx: inviteTx,
14311385
legTr: legTransportFromReq(invite),

pkg/sip/outbound.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ authLoop:
852852
if err != nil {
853853
return nil, fmt.Errorf("invalid challenge %q: %w", challengeStr, err)
854854
}
855-
toHeader = resp.To()
855+
toHeader := resp.To()
856856
if toHeader == nil {
857857
return nil, errors.New("no 'To' header on Response")
858858
}

pkg/sip/protocol.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,6 @@ func legTransportFromReq(req *sip.Request) Transport {
173173
return ""
174174
}
175175

176-
func legCallIDFromReq(req *sip.Request) string {
177-
if callID := req.CallID(); callID != nil {
178-
return callID.Value()
179-
}
180-
return ""
181-
}
182-
183176
func transportPort(c *config.Config, t Transport) int {
184177
if t == TransportTLS {
185178
if tc := c.TLS; tc != nil {

pkg/sip/server.go

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ import (
3535
"github.com/livekit/protocol/livekit"
3636
"github.com/livekit/protocol/logger"
3737
"github.com/livekit/protocol/rpc"
38-
"github.com/livekit/protocol/utils"
3938
"github.com/livekit/protocol/utils/traceid"
4039
"github.com/livekit/sipgo"
4140
"github.com/livekit/sipgo/sip"
@@ -45,7 +44,8 @@ import (
4544
)
4645

4746
const (
48-
UserAgent = "LiveKit"
47+
UserAgent = "LiveKit"
48+
digestLimit = 500
4949
)
5050

5151
const (
@@ -127,25 +127,18 @@ type Handler interface {
127127
OnSessionEnd(ctx context.Context, callIdentifier *CallIdentifier, callInfo *livekit.SIPCallInfo, reason string)
128128
}
129129

130-
type dialogKey struct {
131-
sipCallID string
132-
toTag string
133-
fromTag string
134-
}
135-
136130
type Server struct {
137-
log logger.Logger
138-
mon *stats.Monitor
139-
region string
140-
sipSrv *sipgo.Server
141-
getIOClient GetIOInfoClient
142-
getRoom GetRoomFunc
143-
sipListeners []io.Closer
144-
sipUnhandled RequestHandler
145-
inviteTimeoutQueue utils.TimeoutQueue[dialogKey]
146-
147-
imu sync.RWMutex
148-
inProgressInvites map[dialogKey]*inProgressInvite
131+
log logger.Logger
132+
mon *stats.Monitor
133+
region string
134+
sipSrv *sipgo.Server
135+
getIOClient GetIOInfoClient
136+
getRoom GetRoomFunc
137+
sipListeners []io.Closer
138+
sipUnhandled RequestHandler
139+
140+
imu sync.Mutex
141+
inProgressInvites []*inProgressInvite
149142

150143
closing core.Fuse
151144
cmu sync.RWMutex
@@ -166,10 +159,8 @@ type Server struct {
166159
}
167160

168161
type inProgressInvite struct {
169-
sipCallID string
170-
challenge digest.Challenge
171-
lkCallID string // SCL_* LiveKit call ID assigned to this dialog
172-
timeoutLink utils.TimeoutQueueItem[dialogKey]
162+
sipCallID string
163+
challenge digest.Challenge
173164
}
174165

175166
type ServerOption func(s *Server)
@@ -187,16 +178,15 @@ func NewServer(region string, conf *config.Config, log logger.Logger, mon *stats
187178
log = logger.GetLogger()
188179
}
189180
s := &Server{
190-
log: log,
191-
conf: conf,
192-
region: region,
193-
mon: mon,
194-
getIOClient: getIOClient,
195-
getRoom: DefaultGetRoomFunc,
196-
inProgressInvites: make(map[dialogKey]*inProgressInvite),
197-
byRemoteTag: make(map[RemoteTag]*inboundCall),
198-
byLocalTag: make(map[LocalTag]*inboundCall),
199-
byCallID: make(map[string]*inboundCall),
181+
log: log,
182+
conf: conf,
183+
region: region,
184+
mon: mon,
185+
getIOClient: getIOClient,
186+
getRoom: DefaultGetRoomFunc,
187+
byRemoteTag: make(map[RemoteTag]*inboundCall),
188+
byLocalTag: make(map[LocalTag]*inboundCall),
189+
byCallID: make(map[string]*inboundCall),
200190
}
201191
for _, option := range options {
202192
option(s)
@@ -340,9 +330,6 @@ func (s *Server) Start(agent *sipgo.UserAgent, sc *ServiceConfig, tlsConf *tls.C
340330
}
341331
}
342332

343-
// Start the cleanup task
344-
go s.cleanupInvites()
345-
346333
return nil
347334
}
348335

0 commit comments

Comments
 (0)