Skip to content
116 changes: 81 additions & 35 deletions pkg/sip/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/livekit/protocol/rpc"
lksip "github.com/livekit/protocol/sip"
"github.com/livekit/protocol/tracer"
"github.com/livekit/protocol/utils"
"github.com/livekit/protocol/utils/traceid"
"github.com/livekit/psrpc"
lksdk "github.com/livekit/server-sdk-go/v2"
Expand All @@ -64,6 +65,8 @@ const (
inviteOKRetryAttempts = 5
inviteOKRetryAttemptsNoACK = 2
inviteOkAckLateTimeout = inviteOkRetryIntervalMax

inviteCredentialValidity = 60 * time.Minute // Allow reuse of credentials for 1h
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is smart, might want to extend this to max_call_duration or something.

Also, keep in mind that this is per Call-ID for now, so new calls would still need re-auth.

)

var errNoACK = errors.New("no ACK received for 200 OK")
Expand Down Expand Up @@ -134,23 +137,50 @@ func (s *Server) getCallInfo(id string) *inboundCallInfo {
return c
}

func (s *Server) getInvite(sipCallID string) *inProgressInvite {
s.imu.Lock()
defer s.imu.Unlock()
for i := range s.inProgressInvites {
if s.inProgressInvites[i].sipCallID == sipCallID {
return s.inProgressInvites[i]
func (s *Server) cleanupInvites() {
ticker := time.NewTicker(5 * time.Minute) // Periodic cleanup every 5 minutes
defer ticker.Stop()
for {
select {
case <-s.closing.Watch():
return
case <-ticker.C:
s.imu.Lock()
for it := s.inviteTimeoutQueue.IterateRemoveAfter(inviteCredentialValidity); it.Next(); {
key := it.Item().Value
delete(s.inProgressInvites, key)
}
s.imu.Unlock()
}
}
if len(s.inProgressInvites) >= digestLimit {
s.inProgressInvites = s.inProgressInvites[1:]
}

func (s *Server) getInvite(sipCallID, toTag, fromTag string) *inProgressInvite {
key := dialogKey{
sipCallID: sipCallID,
toTag: toTag,
fromTag: fromTag,
}

s.imu.RLock()
is, exists := s.inProgressInvites[key]
s.imu.RUnlock()
if !exists {
s.imu.Lock()
is, exists = s.inProgressInvites[key]
if !exists {
is = &inProgressInvite{sipCallID: sipCallID, timeoutLink: utils.TimeoutQueueItem[dialogKey]{Value: key}}
s.inProgressInvites[key] = is
}
s.imu.Unlock()
}
is := &inProgressInvite{sipCallID: sipCallID}
s.inProgressInvites = append(s.inProgressInvites, is)

// Always reset the timeout link, whether just created or not
s.inviteTimeoutQueue.Reset(&is.timeoutLink)
return is
}

func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string) (ok bool) {
func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string, inviteState *inProgressInvite) (ok bool) {
log = log.WithValues(
"username", username,
"passwordHash", hashPassword(password),
Expand All @@ -171,14 +201,6 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
_ = tx.Respond(sip.NewResponseFromRequest(req, 100, "Processing", nil))
}

// Extract SIP Call ID for tracking in-progress invites
sipCallID := ""
if h := req.CallID(); h != nil {
sipCallID = h.Value()
}
inviteState := s.getInvite(sipCallID)
log = log.WithValues("inviteStateSipCallID", sipCallID)

h := req.GetHeader("Proxy-Authorization")
if h == nil {
inviteState.challenge = digest.Challenge{
Expand Down Expand Up @@ -230,7 +252,6 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
// Check if we have a valid challenge state
if inviteState.challenge.Realm == "" {
log.Warnw("No challenge state found for authentication attempt", errors.New("missing challenge state"),
"sipCallID", sipCallID,
"expectedRealm", UserAgent,
)
_ = tx.Respond(sip.NewResponseFromRequest(req, 401, "Bad credentials", nil))
Expand Down Expand Up @@ -305,20 +326,18 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
s.log.Errorw("cannot parse source IP", err, "fromIP", src)
return psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
}
callID := lksip.NewCallID()
tid := traceid.FromGUID(callID)
sipCallID := legCallIDFromReq(req)
tr := callTransportFromReq(req)
legTr := legTransportFromReq(req)
log := s.log.WithValues(
"callID", callID,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one log line that will not have callID now is "Bad request", when validation fails.

"traceID", tid.String(),
"sipCallID", sipCallID,
"fromIP", src.Addr(),
"toIP", req.Destination(),
"transport", tr,
)

var call *inboundCall
cc := s.newInbound(log, LocalTag(callID), s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
cc := s.newInbound(log, s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string {
c := call
if c == nil || len(c.attrsToHdr) == 0 {
return headers
Expand All @@ -331,25 +350,53 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
})
log = LoggerWithParams(log, cc)
log = LoggerWithHeaders(log, cc)
cc.log = log
log.Infow("processing invite")

if err := cc.ValidateInvite(); err != nil {
log.Errorw("invalid invite", err)
if s.conf.HideInboundPort {
cc.Drop()
} else {
cc.RespondAndDrop(sip.StatusBadRequest, "Bad request")
}
return psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed"))
}

// Establish ID
fromTag, _ := req.From().Params.Get("tag") // always exists, via ValidateInvite() check
toParams := req.To().Params // To() always exists, via ValidateInvite() check
if toParams == nil {
toParams = sip.NewParams()
req.To().Params = toParams
}
toTag, ok := toParams.Get("tag")
if !ok {
// No to-tag on the invite means we need to generate one per RFC 3261 section 12.
// Generate a new to-tag early, to make sure both INVITES have the same ID.
toTag = utils.NewGuid("")
toParams.Add("tag", toTag)
}
inviteProgress := s.getInvite(sipCallID, toTag, fromTag)
callID := inviteProgress.lkCallID
if callID == "" {
callID = lksip.NewCallID()
inviteProgress.lkCallID = callID
}
cc.id = LocalTag(callID)
tid := traceid.FromGUID(sipCallID)

log = log.WithValues("callID", callID)
log = log.WithValues("traceID", tid.String())
cc.log = log
log.Infow("processing invite")

ctx, span := tracer.Start(ctx, "Server.onInvite")
defer span.End()

s.cmu.RLock()
existing := s.byCallID[cc.SIPCallID()]
existing := s.byCallID[sipCallID]
s.cmu.RUnlock()
if existing != nil && existing.cc.InviteCSeq() < cc.InviteCSeq() {
log.Infow("accepting reinvite", "sipCallID", existing.cc.ID(), "content-type", req.ContentType(), "content-length", req.ContentLength())
log.Infow("accepting reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength())
existing.log().Infow("reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength(), "cseq", cc.InviteCSeq())
cc.AcceptAsKeepAlive(existing.cc.OwnSDP())
return nil
Expand All @@ -376,7 +423,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE

callInfo := &rpc.SIPCall{
LkCallId: callID,
SipCallId: cc.SIPCallID(),
SipCallId: sipCallID,
SourceIp: src.Addr().String(),
Address: ToSIPUri("", cc.Address()),
From: ToSIPUri("", from),
Expand Down Expand Up @@ -447,15 +494,15 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
// We will send password request anyway, so might as well signal that the progress is made.
cc.Processing()
}
s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password) {
s.getCallInfo(sipCallID).countInvite(log, req)
if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password, inviteProgress) {
cmon.InviteErrorShort("unauthorized")
// handleInviteAuth will generate the SIP Response as needed
return psrpc.NewErrorf(psrpc.PermissionDenied, "invalid credentials were provided")
}
// ok
case AuthAccept:
s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
s.getCallInfo(sipCallID).countInvite(log, req)
// ok
}

Expand Down Expand Up @@ -1366,11 +1413,10 @@ func (c *inboundCall) transferCall(ctx context.Context, transferTo string, heade

}

func (s *Server) newInbound(log logger.Logger, id LocalTag, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction, getHeaders setHeadersFunc) *sipInbound {
func (s *Server) newInbound(log logger.Logger, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction, getHeaders setHeadersFunc) *sipInbound {
c := &sipInbound{
log: log,
s: s,
id: id,
invite: invite,
inviteTx: inviteTx,
legTr: legTransportFromReq(invite),
Expand Down
2 changes: 1 addition & 1 deletion pkg/sip/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ authLoop:
if err != nil {
return nil, fmt.Errorf("invalid challenge %q: %w", challengeStr, err)
}
toHeader := resp.To()
Copy link
Contributor Author

@alexlivekit alexlivekit Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we're doing the right validation on the inbound side, the outbound side E2E test caught this error!
But this also means out clients might run into the same issue, in case some of them are not spec-compliant.

toHeader = resp.To()
if toHeader == nil {
return nil, errors.New("no 'To' header on Response")
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/sip/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ func legTransportFromReq(req *sip.Request) Transport {
return ""
}

func legCallIDFromReq(req *sip.Request) string {
if callID := req.CallID(); callID != nil {
return callID.Value()
}
return ""
}

func transportPort(c *config.Config, t Transport) int {
if t == TransportTLS {
if tc := c.TLS; tc != nil {
Expand Down
61 changes: 37 additions & 24 deletions pkg/sip/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils"
"github.com/livekit/protocol/utils/traceid"
"github.com/livekit/sipgo"
"github.com/livekit/sipgo/sip"
Expand All @@ -44,8 +45,7 @@ import (
)

const (
UserAgent = "LiveKit"
digestLimit = 500
UserAgent = "LiveKit"
)

const (
Expand Down Expand Up @@ -127,18 +127,25 @@ type Handler interface {
OnSessionEnd(ctx context.Context, callIdentifier *CallIdentifier, callInfo *livekit.SIPCallInfo, reason string)
}

type dialogKey struct {
sipCallID string
toTag string
fromTag string
}

type Server struct {
log logger.Logger
mon *stats.Monitor
region string
sipSrv *sipgo.Server
getIOClient GetIOInfoClient
getRoom GetRoomFunc
sipListeners []io.Closer
sipUnhandled RequestHandler

imu sync.Mutex
inProgressInvites []*inProgressInvite
log logger.Logger
mon *stats.Monitor
region string
sipSrv *sipgo.Server
getIOClient GetIOInfoClient
getRoom GetRoomFunc
sipListeners []io.Closer
sipUnhandled RequestHandler
inviteTimeoutQueue utils.TimeoutQueue[dialogKey]

imu sync.RWMutex
inProgressInvites map[dialogKey]*inProgressInvite

closing core.Fuse
cmu sync.RWMutex
Expand All @@ -159,8 +166,10 @@ type Server struct {
}

type inProgressInvite struct {
sipCallID string
challenge digest.Challenge
sipCallID string
challenge digest.Challenge
lkCallID string // SCL_* LiveKit call ID assigned to this dialog
timeoutLink utils.TimeoutQueueItem[dialogKey]
}

type ServerOption func(s *Server)
Expand All @@ -178,15 +187,16 @@ func NewServer(region string, conf *config.Config, log logger.Logger, mon *stats
log = logger.GetLogger()
}
s := &Server{
log: log,
conf: conf,
region: region,
mon: mon,
getIOClient: getIOClient,
getRoom: DefaultGetRoomFunc,
byRemoteTag: make(map[RemoteTag]*inboundCall),
byLocalTag: make(map[LocalTag]*inboundCall),
byCallID: make(map[string]*inboundCall),
log: log,
conf: conf,
region: region,
mon: mon,
getIOClient: getIOClient,
getRoom: DefaultGetRoomFunc,
inProgressInvites: make(map[dialogKey]*inProgressInvite),
byRemoteTag: make(map[RemoteTag]*inboundCall),
byLocalTag: make(map[LocalTag]*inboundCall),
byCallID: make(map[string]*inboundCall),
}
for _, option := range options {
option(s)
Expand Down Expand Up @@ -330,6 +340,9 @@ func (s *Server) Start(agent *sipgo.UserAgent, sc *ServiceConfig, tlsConf *tls.C
}
}

// Start the cleanup task
go s.cleanupInvites()

return nil
}

Expand Down
Loading
Loading