diff --git a/learn/parse_http.go b/learn/parse_http.go index 7b37c0ef..e3511548 100644 --- a/learn/parse_http.go +++ b/learn/parse_http.go @@ -101,14 +101,14 @@ func ParseHTTP(elem akinet.ParsedNetworkContent) (*PartialWitness, error) { var streamID uuid.UUID var seq int + xForwardedFor := optionals.None[string]() switch t := elem.(type) { case akinet.HTTPRequest: streamID = t.StreamID seq = t.Seq - isRequest = true - methodMeta, datas = parseRequest(&t) + methodMeta, datas, xForwardedFor = parseRequest(&t) rawBody = t.Body bodyDecompressed = t.BodyDecompressed headers = t.Header @@ -187,8 +187,9 @@ func ParseHTTP(elem akinet.ParsedNetworkContent) (*PartialWitness, error) { } return &PartialWitness{ - Witness: &pb.Witness{Method: method}, - PairKey: toWitnessID(streamID, seq), + Witness: &pb.Witness{Method: method}, + PairKey: toWitnessID(streamID, seq), + XForwardedFor: xForwardedFor, }, nil } @@ -487,14 +488,14 @@ func parseMultipartBody(multipartType string, boundary string, bodyStream io.Rea }, nil } -func parseRequest(req *akinet.HTTPRequest) (*pb.MethodMeta, []*pb.Data) { +func parseRequest(req *akinet.HTTPRequest) (*pb.MethodMeta, []*pb.Data, optionals.Optional[string]) { datas := []*pb.Data{} noStatusCode := optionals.None[int]() datas = append(datas, parseQuery(req.URL)...) datas = append(datas, parseHeader(req.Header, noStatusCode)...) datas = append(datas, parseCookies(req.Cookies, noStatusCode)...) 
- return parseMethodMeta(req), datas + return parseMethodMeta(req), datas, parseLoadBalancer(req.Header) } func parseResponse(resp *akinet.HTTPResponse) []*pb.Data { @@ -523,6 +524,19 @@ func parseCookies(cs []*http.Cookie, responseCodeOpt optionals.Optional[int]) [] return datas } +// Extract the X-Forwarded-For header, if present +func parseLoadBalancer(header http.Header) optionals.Optional[string] { + for k, vs := range header { + switch strings.ToLower(k) { + case "x-forwarded-for": + if len(vs) > 0 { + return optionals.Some(vs[0]) + } + } + } + return optionals.None[string]() +} + // Translate headers to data objects. optionals.None indicates that the header // is in a request. func parseHeader(header http.Header, responseCodeOpt optionals.Optional[int]) []*pb.Data { @@ -551,11 +565,14 @@ func parseHeader(header http.Header, responseCodeOpt optionals.Optional[int]) [] switch strings.ToLower(k) { case "cookie", "set-cookie": - // Cookies are parsed by parseHeader. + // Cookies are parsed by parseCookies. continue case "content-type": // Handled by parseBody. continue + case "x-forwarded-for": + // Handled by parseLoadBalancer + continue case "authorization": // If the authorization header is in the request, create an // HTTPAuth object. Treat authorization headers in the response diff --git a/learn/types.go b/learn/types.go index 0853bf15..ef482f05 100644 --- a/learn/types.go +++ b/learn/types.go @@ -5,8 +5,9 @@ import ( "github.com/google/uuid" - "github.com/akitasoftware/akita-libs/akid" pb "github.com/akitasoftware/akita-ir/go/api_spec" + "github.com/akitasoftware/akita-libs/akid" + "github.com/akitasoftware/go-utils/optionals" ) var ( @@ -22,6 +23,9 @@ type PartialWitness struct { // Key used to pair this PartialWitness up with its counterpart. PairKey akid.WitnessID + + // Value of the request's X-Forwarded-For header, if present; it carries the original client address when the request came through a proxy + XForwardedFor optionals.Optional[string] } // Generates a v5 UUID as witness ID based on stream ID and seq. 
diff --git a/trace/backend_collector.go b/trace/backend_collector.go index 3174cc5c..11012a3e 100644 --- a/trace/backend_collector.go +++ b/trace/backend_collector.go @@ -3,6 +3,7 @@ package trace import ( "encoding/base64" "net" + "strings" "sync" "time" @@ -51,7 +52,8 @@ type witnessWithInfo struct { requestEnd time.Time responseStart time.Time - witness *pb.Witness + witness *pb.Witness + xForwardedFor optionals.Optional[string] } func (r witnessWithInfo) toReport() (*kgxapi.WitnessReport, error) { @@ -64,6 +66,13 @@ func (r witnessWithInfo) toReport() (*kgxapi.WitnessReport, error) { return nil, errors.Wrap(err, "failed to marshal witness proto") } + if header, ok := r.xForwardedFor.Get(); ok { + printer.Infof("Real source IP %v, X-Forwarded-For header %v\n", r.srcIP, header) + addrs := strings.Split(header, ",") + // Replace with the leftmost address, which by convention is the original client; later entries are the proxies the request passed through + r.srcIP = net.ParseIP(addrs[0]) + } + return &kgxapi.WitnessReport{ Direction: kgxapi.Inbound, OriginAddr: r.srcIP, @@ -207,6 +216,8 @@ func (c *BackendCollector) Process(t akinet.ParsedNetworkTraffic) error { if isRequest { pair.srcIP, pair.dstIP = pair.dstIP, pair.srcIP pair.srcPort, pair.dstPort = pair.dstPort, pair.srcPort + + pair.xForwardedFor = partial.XForwardedFor } c.queueUpload(pair) @@ -225,6 +236,7 @@ func (c *BackendCollector) Process(t akinet.ParsedNetworkTraffic) error { witness: partial.Witness, observationTime: t.ObservationTime, id: partial.PairKey, + xForwardedFor: partial.XForwardedFor, } // Store whichever timestamp brackets the processing interval. w.recordTimestamp(isRequest, t)