Skip to content

Commit 9e46651

Browse files
committed
inProgressInvite lifecycle, stable ID generation, minor log field reshuffling
1 parent a8d8d8b commit 9e46651

File tree

3 files changed

+59
-27
lines changed

3 files changed

+59
-27
lines changed

pkg/sip/inbound.go

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import (
2828
"sync/atomic"
2929
"time"
3030

31+
uuid "github.com/satori/go.uuid"
32+
3133
"github.com/frostbyte73/core"
3234
"github.com/icholy/digest"
3335
"github.com/pkg/errors"
@@ -64,6 +66,8 @@ const (
6466
inviteOKRetryAttempts = 5
6567
inviteOKRetryAttemptsNoACK = 2
6668
inviteOkAckLateTimeout = inviteOkRetryIntervalMax
69+
70+
inviteCredentialValidity = 60 * time.Minute // Allow reuse of credentials for 1h
6771
)
6872

6973
var errNoACK = errors.New("no ACK received for 200 OK")
@@ -137,20 +141,23 @@ func (s *Server) getCallInfo(id string) *inboundCallInfo {
137141
func (s *Server) getInvite(sipCallID string) *inProgressInvite {
138142
s.imu.Lock()
139143
defer s.imu.Unlock()
140-
for i := range s.inProgressInvites {
141-
if s.inProgressInvites[i].sipCallID == sipCallID {
142-
return s.inProgressInvites[i]
143-
}
144+
is, ok := s.inProgressInvites[sipCallID]
145+
if ok {
146+
return is
144147
}
145-
if len(s.inProgressInvites) >= digestLimit {
146-
s.inProgressInvites = s.inProgressInvites[1:]
147-
}
148-
is := &inProgressInvite{sipCallID: sipCallID}
149-
s.inProgressInvites = append(s.inProgressInvites, is)
148+
is = &inProgressInvite{sipCallID: sipCallID}
149+
s.inProgressInvites[sipCallID] = is
150+
151+
go func() {
152+
time.Sleep(inviteCredentialValidity)
153+
s.imu.Lock()
154+
defer s.imu.Unlock()
155+
delete(s.inProgressInvites, sipCallID)
156+
}()
150157
return is
151158
}
152159

153-
func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string) (ok bool) {
160+
func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string, inviteState *inProgressInvite) (ok bool) {
154161
log = log.WithValues(
155162
"username", username,
156163
"passwordHash", hashPassword(password),
@@ -176,8 +183,6 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
176183
if h := req.CallID(); h != nil {
177184
sipCallID = h.Value()
178185
}
179-
inviteState := s.getInvite(sipCallID)
180-
log = log.WithValues("inviteStateSipCallID", sipCallID)
181186

182187
h := req.GetHeader("Proxy-Authorization")
183188
if h == nil {
@@ -230,7 +235,6 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
230235
// Check if we have a valid challenge state
231236
if inviteState.challenge.Realm == "" {
232237
log.Warnw("No challenge state found for authentication attempt", errors.New("missing challenge state"),
233-
"sipCallID", sipCallID,
234238
"expectedRealm", UserAgent,
235239
)
236240
_ = tx.Respond(sip.NewResponseFromRequest(req, 401, "Bad credentials", nil))
@@ -305,20 +309,20 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
305309
s.log.Errorw("cannot parse source IP", err, "fromIP", src)
306310
return psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
307311
}
308-
callID := lksip.NewCallID()
312+
sipCallID := legCallIDFromReq(req)
309313
tid := traceid.FromGUID(callID)
310314
tr := callTransportFromReq(req)
311315
legTr := legTransportFromReq(req)
312316
log := s.log.WithValues(
313-
"callID", callID,
317+
"sipCallID", sipCallID,
314318
"traceID", tid.String(),
315319
"fromIP", src.Addr(),
316320
"toIP", req.Destination(),
317321
"transport", tr,
318322
)
319323

320324
var call *inboundCall
321-
cc := s.newInbound(log, LocalTag(callID), s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
325+
cc := s.newInbound(log, "unassigned", s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
322326
c := call
323327
if c == nil || len(c.attrsToHdr) == 0 {
324328
return headers
@@ -331,8 +335,6 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
331335
})
332336
log = LoggerWithParams(log, cc)
333337
log = LoggerWithHeaders(log, cc)
334-
cc.log = log
335-
log.Infow("processing invite")
336338

337339
if err := cc.ValidateInvite(); err != nil {
338340
if s.conf.HideInboundPort {
@@ -342,6 +344,28 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
342344
}
343345
return psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed"))
344346
}
347+
348+
// Establish ID
349+
if _, ok := req.To().Params.Get("tag"); !ok {
350+
// No to-tag on the invite means we need to generate one per RFC 3261 section 12.
351+
if !inviteHasAuth(req) {
352+
// No auth = a 407 response and another INVITE+auth.
353+
// Generate a new to-tag early, to make sure both INVITES have the same ID.
354+
uuid, _ := uuid.NewV4() // Same as NewResponseFromRequest in sipgo
355+
req.To().Params.Add("tag", uuid.String())
356+
}
357+
}
358+
inviteProgress := s.getInvite(req.CallID().Value())
359+
callID := inviteProgress.lkCallID
360+
if callID == "" {
361+
callID = lksip.NewCallID()
362+
inviteProgress.lkCallID = callID
363+
}
364+
365+
log = log.WithValues("callID", callID)
366+
cc.log = log
367+
log.Infow("processing invite")
368+
345369
ctx, span := tracer.Start(ctx, "Server.onInvite")
346370
defer span.End()
347371

@@ -448,7 +472,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
448472
cc.Processing()
449473
}
450474
s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
451-
if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password) {
475+
if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password, inviteProgress) {
452476
cmon.InviteErrorShort("unauthorized")
453477
// handleInviteAuth will generate the SIP Response as needed
454478
return psrpc.NewErrorf(psrpc.PermissionDenied, "invalid credentials were provided")

pkg/sip/protocol.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,13 @@ func legTransportFromReq(req *sip.Request) Transport {
173173
return ""
174174
}
175175

176+
func legCallIDFromReq(req *sip.Request) string {
177+
if callID := req.CallID(); callID != nil {
178+
return callID.Value()
179+
}
180+
return ""
181+
}
182+
176183
func transportPort(c *config.Config, t Transport) int {
177184
if t == TransportTLS {
178185
if tc := c.TLS; tc != nil {

pkg/sip/server.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ import (
4444
)
4545

4646
const (
47-
UserAgent = "LiveKit"
48-
digestLimit = 500
47+
UserAgent = "LiveKit"
4948
)
5049

5150
const (
@@ -137,7 +136,7 @@ type Server struct {
137136
sipUnhandled RequestHandler
138137

139138
imu sync.Mutex
140-
inProgressInvites []*inProgressInvite
139+
inProgressInvites map[string]*inProgressInvite
141140

142141
closing core.Fuse
143142
cmu sync.RWMutex
@@ -160,18 +159,20 @@ type Server struct {
160159
type inProgressInvite struct {
161160
sipCallID string
162161
challenge digest.Challenge
162+
lkCallID string // SCL_* LiveKit call ID assigned to this dialog
163163
}
164164

165165
func NewServer(region string, conf *config.Config, log logger.Logger, mon *stats.Monitor, getIOClient GetIOInfoClient) *Server {
166166
if log == nil {
167167
log = logger.GetLogger()
168168
}
169169
s := &Server{
170-
log: log,
171-
conf: conf,
172-
region: region,
173-
mon: mon,
174-
getIOClient: getIOClient,
170+
log: log,
171+
conf: conf,
172+
region: region,
173+
mon: mon,
174+
getIOClient: getIOClient,
175+
inProgressInvites: make(map[string]*inProgressInvite),
175176
byRemoteTag: make(map[RemoteTag]*inboundCall),
176177
byLocalTag: make(map[LocalTag]*inboundCall),
177178
byCallID: make(map[string]*inboundCall),

0 commit comments

Comments
 (0)