diff --git a/encoding/msgpack/codec.go b/encoding/msgpack/codec.go new file mode 100644 index 0000000..1605436 --- /dev/null +++ b/encoding/msgpack/codec.go @@ -0,0 +1,264 @@ +package msgpack + +import ( + "encoding/binary" + "errors" + "io" + "reflect" + "time" + + "github.com/dmcgowan/msgpack" + "github.com/docker/libchan" + "github.com/docker/libchan/encoding" +) + +const ( + duplexStreamCode = 1 + inboundStreamCode = 2 + outboundStreamCode = 3 + inboundChannelCode = 4 + outboundChannelCode = 5 + timeCode = 6 +) + +type cproducer struct { + encoding.ChannelFactory +} + +type creceiver struct { + encoding.ChannelReceiver +} + +func decodeReferenceID(b []byte) (referenceID uint64, err error) { + if len(b) == 8 { + referenceID = binary.BigEndian.Uint64(b) + } else if len(b) == 4 { + referenceID = uint64(binary.BigEndian.Uint32(b)) + } else { + err = errors.New("bad reference id") + } + return +} + +func encodeReferenceID(referenceID uint64) []byte { + if referenceID > 0xffffffff { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, referenceID) + return buf + } + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(referenceID)) + return buf +} + +func (p *cproducer) copySendChannel(send libchan.Sender) (uint64, error) { + recv, copyID, err := p.CreateReceiver() + if err != nil { + return 0, err + } + // Start copying into sender + go func() { + libchan.Copy(send, recv) + send.Close() + }() + return copyID, nil +} + +func (p *cproducer) copyReceiveChannel(recv libchan.Receiver) (uint64, error) { + send, copyID, err := p.CreateSender() + if err != nil { + return 0, err + } + // Start copying from receiver + go func() { + libchan.Copy(send, recv) + send.Close() + }() + return copyID, nil +} + +func (r *creceiver) decodeStream(b []byte) (io.ReadWriteCloser, error) { + referenceID, err := decodeReferenceID(b) + if err != nil { + return nil, err + } + + return r.GetStream(referenceID) +} + +func (r *creceiver) decodeReceiver(v reflect.Value, b []byte) error { + referenceID, err := decodeReferenceID(b) + if err != nil { + return err + } + + recv, err := r.GetReceiver(referenceID) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(recv)) + + return nil +} + +func (r *creceiver) decodeSender(v reflect.Value, b []byte) error { + referenceID, err := decodeReferenceID(b) + if err != nil { + return err + } + + send, err := r.GetSender(referenceID) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(send)) + + return nil +} + +func (r *creceiver) decodeWStream(v reflect.Value, b []byte) error { + bs, err := r.decodeStream(b) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(bs)) + + return nil +} + +func (r *creceiver) decodeRStream(v reflect.Value, b []byte) error { + bs, err := r.decodeStream(b) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(bs)) + + return bs.Close() +} + +func entimeCode(t *time.Time) ([]byte, error) { + var b [12]byte + binary.BigEndian.PutUint64(b[0:8], uint64(t.Unix())) + binary.BigEndian.PutUint32(b[8:12], uint32(t.Nanosecond())) + return b[:], nil +} + +func detimeCode(v reflect.Value, b []byte) error { + if len(b) != 12 { + return errors.New("Invalid length") + } + t := time.Unix(int64(binary.BigEndian.Uint64(b[0:8])), int64(binary.BigEndian.Uint32(b[8:12]))) + + if v.Kind() == reflect.Ptr { + v.Set(reflect.ValueOf(&t)) + } else { + v.Set(reflect.ValueOf(t)) + } + + return nil +} + +func (p *cproducer) encodeExtended(iv reflect.Value) (i int, b []byte, e error) { + switch v := iv.Interface().(type) { + case libchan.Sender: + copyCh, err := p.copySendChannel(v) + if err != nil { + return 0, nil, err + } + return inboundChannelCode, encodeReferenceID(copyCh), nil + case libchan.Receiver: + copyCh, err := p.copyReceiveChannel(v) + if err != nil { + return 0, nil, err + } + return outboundChannelCode, encodeReferenceID(copyCh), nil + + case io.Reader: + // Either ReadWriteCloser, ReadWriter, or ReadCloser + streamCopy, copyID, err := p.CreateStream() + if err != nil { + return 0, nil, err + } + go func() { + io.Copy(streamCopy, v) + streamCopy.Close() + }() + code := outboundStreamCode + if wc, ok := v.(io.WriteCloser); ok { + go func() { + io.Copy(wc, streamCopy) + wc.Close() + }() + code = duplexStreamCode + } else if w, ok := v.(io.Writer); ok { + go func() { + io.Copy(w, streamCopy) + }() + code = duplexStreamCode + } + return code, encodeReferenceID(copyID), nil + case io.Writer: + streamCopy, copyID, err := p.CreateStream() + if err != nil { + return 0, nil, err + } + if wc, ok := v.(io.WriteCloser); ok { + go func() { + io.Copy(wc, streamCopy) + wc.Close() + }() + } else { + go func() { + io.Copy(v, streamCopy) + }() + } + return inboundStreamCode, encodeReferenceID(copyID), nil + case *time.Time: + b, err := entimeCode(v) + return timeCode, b, err + } + return 0, nil, nil +} + +// Codec implements the libchan encoding using msgpack5. +type Codec struct{} + +// NewEncoder returns a libchan encoder which encodes given objects +// to msgpack5 on the given datastream using the given encoding +// channel producer. +func (codec *Codec) NewEncoder(w io.Writer, p encoding.ChannelFactory) encoding.Encoder { + prd := &cproducer{p} + encoder := msgpack.NewEncoder(w) + exts := msgpack.NewExtensions() + exts.SetEncoder(prd.encodeExtended) + encoder.AddExtensions(exts) + return encoder +} + +// NewDecoder returns a libchan decoder which decodes objects from +// the given data stream from msgpack5 into provided object using +// the provided types for libchan interfaces. +func (codec *Codec) NewDecoder(r io.Reader, recv encoding.ChannelReceiver, streamT, recvT, sendT reflect.Type) encoding.Decoder { + rec := &creceiver{recv} + decoder := msgpack.NewDecoder(r) + exts := msgpack.NewExtensions() + exts.AddDecoder(duplexStreamCode, streamT, rec.decodeWStream) + exts.AddDecoder(inboundStreamCode, streamT, rec.decodeWStream) + exts.AddDecoder(outboundStreamCode, streamT, rec.decodeRStream) + exts.AddDecoder(inboundChannelCode, sendT, rec.decodeSender) + exts.AddDecoder(outboundChannelCode, recvT, rec.decodeReceiver) + exts.AddDecoder(timeCode, reflect.TypeOf(&time.Time{}), detimeCode) + decoder.AddExtensions(exts) + return decoder +} + +// NewRawMessage returns a transit object which will copy a +// msgpack5 datastream and allow decoding that object +// using a Decoder from the codec object. +func (codec *Codec) NewRawMessage() encoding.Decoder { + return new(msgpack.RawMessage) +} diff --git a/encoding/types.go b/encoding/types.go new file mode 100644 index 0000000..b6584ea --- /dev/null +++ b/encoding/types.go @@ -0,0 +1,78 @@ +package encoding + +import ( + "io" + "reflect" + + "github.com/docker/libchan" +) + +// ChannelFactory represents an object which is able to create new +// channels and streams. This interface is used by an encoder +// create a channel or stream, copy the encoded type, and +// encode the identifier. +type ChannelFactory interface { + // CreateSender creates a new send channel and returns + // the identifier associated with the sender. This + // identifier can be used to get the Receiver on + // the receiving side by calling GetReceiver. + CreateSender() (libchan.Sender, uint64, error) + + // CreateReceiver creates a new receive channel and + // returns the identifier associated with the receiver. + // This identifier can be used to get the Sender on + // the receiving side by calling GetSender. + CreateReceiver() (libchan.Receiver, uint64, error) + + // CreateStream createsa new byte stream and returns + // the identifier associate with the stream. This + // identifier can be used to get the byte stream + // by calling GetStream on the receiving side. + CreateStream() (io.ReadWriteCloser, uint64, error) +} + +// ChannelReceiver represents an object which is able to receive +// new channels and streams and retrieve by an integer identifer. +type ChannelReceiver interface { + // GetSender gets a remotely created sender referenced + // by the given identifier. + GetSender(uint64) (libchan.Sender, error) + + // GetReceiver gets a remotely created receiver referenced + // by the given identifier. + GetReceiver(uint64) (libchan.Receiver, error) + + // GetStream gets a remotely created byte stream + // referenced by the given identifier. + GetStream(uint64) (io.ReadWriteCloser, error) +} + +// Encoder represents an object which can encode an interface +// into data stream to be decoded. This Encoder must be able +// to encode interfaces by converting to libchan channels and +// streams and encoding the identifier. +type Encoder interface { + Encode(v ...interface{}) error +} + +// Decoder represents an object which can decode from a data +// stream into an interface. The decoder must have support +// for decoding stream and channel identifiers into a libchan +// Sender or Receiver as well as io Readers and Writers. +type Decoder interface { + Decode(v ...interface{}) error +} + +// ChannelCodec represents a libchan codec capable of encoding +// Go interfaces into data streams supporting libchan types as +// well as decode into libchan supported interfaces. In addition +// to encoding and decoding, the codec must provide a transit +// type which is capable of copying a data stream in order to +// delay decoding into an object until finally received. +// The RawMessage must return an object similar to json.RawMessage +// with the capability of decoding itself into an object. +type ChannelCodec interface { + NewEncoder(io.Writer, ChannelFactory) Encoder + NewDecoder(io.Reader, ChannelReceiver, reflect.Type, reflect.Type, reflect.Type) Decoder + NewRawMessage() Decoder +} diff --git a/examples/rexec/client.go b/examples/rexec/client.go index bd3a084..ab90eb6 100644 --- a/examples/rexec/client.go +++ b/examples/rexec/client.go @@ -8,6 +8,7 @@ import ( "os" "github.com/docker/libchan" + "github.com/docker/libchan/encoding/msgpack" "github.com/docker/libchan/spdy" ) @@ -46,7 +47,7 @@ func main() { if err != nil { log.Fatal(err) } - transport := spdy.NewTransport(p) + transport := spdy.NewTransport(p, &msgpack.Codec{}) sender, err := transport.NewSendChannel() if err != nil { log.Fatal(err) diff --git a/examples/rexec/rexec_server/server.go b/examples/rexec/rexec_server/server.go index ac0a71a..d5dc726 100644 --- a/examples/rexec/rexec_server/server.go +++ b/examples/rexec/rexec_server/server.go @@ -10,6 +10,7 @@ import ( "syscall" "github.com/docker/libchan" + "github.com/docker/libchan/encoding/msgpack" "github.com/docker/libchan/spdy" ) @@ -67,7 +68,7 @@ func main() { log.Print(err) break } - t := spdy.NewTransport(p) + t := spdy.NewTransport(p, &msgpack.Codec{}) go func() { for { diff --git a/spdy/bench_test.go b/spdy/bench_test.go index 9307046..99e47c5 100644 --- a/spdy/bench_test.go +++ b/spdy/bench_test.go @@ -10,10 +10,6 @@ import ( "github.com/docker/libchan" ) -var ( - testPipe = Pipe -) - type SimpleStruct struct { Value int } diff --git a/spdy/encode.go b/spdy/encode.go deleted file mode 100644 index 6d6d952..0000000 --- a/spdy/encode.go +++ /dev/null @@ -1,293 +0,0 @@ -package spdy - -import ( - "encoding/binary" - "errors" - "io" - "reflect" - "time" - - "github.com/dmcgowan/msgpack" - "github.com/docker/libchan" -) - -const ( - duplexStreamCode = 1 - inboundStreamCode = 2 - outboundStreamCode = 3 - inboundChannelCode = 4 - outboundChannelCode = 5 - timeCode = 6 -) - -func decodeReferenceID(b []byte) (referenceID uint64, err error) { - if len(b) == 8 { - referenceID = binary.BigEndian.Uint64(b) - } else if len(b) == 4 { - referenceID = uint64(binary.BigEndian.Uint32(b)) - } else { - err = errors.New("bad reference id") - } - return -} - -func encodeReferenceID(b []byte, referenceID uint64) (n int) { - if referenceID > 0xffffffff { - binary.BigEndian.PutUint64(b, referenceID) - n = 8 - } else { - binary.BigEndian.PutUint32(b, uint32(referenceID)) - n = 4 - } - return -} - -func (s *stream) channelBytes() ([]byte, error) { - buf := make([]byte, 8) - written := encodeReferenceID(buf, s.referenceID) - return buf[:written], nil -} - -func (s *stream) copySendChannel(send libchan.Sender) (*nopSender, error) { - recv, sendCopy, err := s.CreateNestedReceiver() - if err != nil { - return nil, err - } - // Start copying into sender - go func() { - libchan.Copy(send, recv) - send.Close() - }() - return sendCopy.(*nopSender), nil -} - -func (s *stream) copyReceiveChannel(recv libchan.Receiver) (*nopReceiver, error) { - send, recvCopy, err := s.CreateNestedSender() - if err != nil { - return nil, err - } - // Start copying from receiver - go func() { - libchan.Copy(send, recv) - send.Close() - }() - return recvCopy.(*nopReceiver), nil -} - -func (s *stream) decodeStream(b []byte) (*stream, error) { - referenceID, err := decodeReferenceID(b) - if err != nil { - return nil, err - } - - gs := s.session.getStream(referenceID) - if gs == nil { - return nil, errors.New("stream does not exist") - } - - return gs, nil -} - -func (s *stream) decodeReceiver(v reflect.Value, b []byte) error { - bs, err := s.decodeStream(b) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(&receiver{stream: bs})) - - return nil -} - -func (s *stream) decodeSender(v reflect.Value, b []byte) error { - bs, err := s.decodeStream(b) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(&sender{stream: bs})) - - return nil -} - -func (s *stream) streamBytes() ([]byte, error) { - var buf [8]byte - written := encodeReferenceID(buf[:], s.referenceID) - - return buf[:written], nil -} - -func (s *stream) decodeWStream(v reflect.Value, b []byte) error { - bs, err := s.decodeStream(b) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(bs)) - - return nil -} - -func (s *stream) decodeRStream(v reflect.Value, b []byte) error { - bs, err := s.decodeStream(b) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(bs)) - - return nil -} - -func encodeTime(t *time.Time) ([]byte, error) { - var b [12]byte - binary.BigEndian.PutUint64(b[0:8], uint64(t.Unix())) - binary.BigEndian.PutUint32(b[8:12], uint32(t.Nanosecond())) - return b[:], nil -} - -func decodeTime(v reflect.Value, b []byte) error { - if len(b) != 12 { - return errors.New("Invalid length") - } - t := time.Unix(int64(binary.BigEndian.Uint64(b[0:8])), int64(binary.BigEndian.Uint32(b[8:12]))) - - if v.Kind() == reflect.Ptr { - v.Set(reflect.ValueOf(&t)) - } else { - v.Set(reflect.ValueOf(t)) - } - - return nil -} - -func (s *stream) encodeExtended(iv reflect.Value) (i int, b []byte, e error) { - switch v := iv.Interface().(type) { - case *nopSender: - if v.stream == nil { - return 0, nil, errors.New("bad type") - } - if v.stream.session != s.session { - rc, err := s.copySendChannel(v) - if err != nil { - return 0, nil, err - } - b, err := rc.stream.channelBytes() - return inboundChannelCode, b, err - } - - b, err := v.stream.channelBytes() - return inboundChannelCode, b, err - case *nopReceiver: - if v.stream == nil { - return 0, nil, errors.New("bad type") - } - if v.stream.session != s.session { - rc, err := s.copyReceiveChannel(v) - if err != nil { - return 0, nil, err - } - b, err := rc.stream.channelBytes() - return outboundChannelCode, b, err - } - - b, err := v.stream.channelBytes() - return outboundChannelCode, b, err - case *stream: - if v.referenceID == 0 { - return 0, nil, errors.New("bad type") - } - if v.session != s.session { - streamCopy, err := s.createByteStream() - if err != nil { - return 0, nil, err - } - go func(r io.Reader) { - io.Copy(streamCopy, r) - streamCopy.Close() - }(v) - go func(w io.WriteCloser) { - io.Copy(w, streamCopy) - w.Close() - }(v) - v = streamCopy - - } - b, err := v.channelBytes() - return duplexStreamCode, b, err - case libchan.Sender: - copyCh, err := s.copySendChannel(v) - if err != nil { - return 0, nil, err - } - b, err := copyCh.stream.channelBytes() - return inboundChannelCode, b, err - case libchan.Receiver: - copyCh, err := s.copyReceiveChannel(v) - if err != nil { - return 0, nil, err - } - b, err := copyCh.stream.channelBytes() - return outboundChannelCode, b, err - - case io.Reader: - // Either ReadWriteCloser, ReadWriter, or ReadCloser - streamCopy, err := s.createByteStream() - if err != nil { - return 0, nil, err - } - go func() { - io.Copy(streamCopy, v) - streamCopy.Close() - }() - code := outboundStreamCode - if wc, ok := v.(io.WriteCloser); ok { - go func() { - io.Copy(wc, streamCopy) - wc.Close() - }() - code = duplexStreamCode - } else if w, ok := v.(io.Writer); ok { - go func() { - io.Copy(w, streamCopy) - }() - code = duplexStreamCode - } - b, err := streamCopy.streamBytes() - return code, b, err - case io.Writer: - streamCopy, err := s.createByteStream() - if err != nil { - return 0, nil, err - } - if wc, ok := v.(io.WriteCloser); ok { - go func() { - io.Copy(wc, streamCopy) - wc.Close() - }() - } else { - go func() { - io.Copy(v, streamCopy) - }() - } - - b, err := streamCopy.streamBytes() - return inboundStreamCode, b, err - case *time.Time: - b, err := encodeTime(v) - return timeCode, b, err - } - return 0, nil, nil -} - -func (s *stream) initializeExtensions() *msgpack.Extensions { - exts := msgpack.NewExtensions() - exts.SetEncoder(s.encodeExtended) - exts.AddDecoder(duplexStreamCode, reflect.TypeOf(&stream{}), s.decodeWStream) - exts.AddDecoder(inboundStreamCode, reflect.TypeOf(&stream{}), s.decodeWStream) - exts.AddDecoder(outboundStreamCode, reflect.TypeOf(&stream{}), s.decodeRStream) - exts.AddDecoder(inboundChannelCode, reflect.TypeOf(&sender{}), s.decodeSender) - exts.AddDecoder(outboundChannelCode, reflect.TypeOf(&receiver{}), s.decodeReceiver) - exts.AddDecoder(timeCode, reflect.TypeOf(&time.Time{}), decodeTime) - return exts -} diff --git a/spdy/pipe.go b/spdy/pipe.go deleted file mode 100644 index 5414595..0000000 --- a/spdy/pipe.go +++ /dev/null @@ -1,82 +0,0 @@ -package spdy - -import ( - "io" - "net" - - "github.com/docker/libchan" -) - -type pipeSender struct { - session libchan.Transport - sender *sender -} - -type pipeReceiver struct { - session libchan.Transport - receiver *receiver -} - -// Pipe creates a top-level channel pipe using an in memory transport. -func Pipe() (libchan.Receiver, libchan.Sender, error) { - c1, c2 := net.Pipe() - - s1, err := NewSpdyStreamProvider(c1, false) - if err != nil { - return nil, nil, err - } - t1 := NewTransport(s1) - - s2, err := NewSpdyStreamProvider(c2, true) - if err != nil { - return nil, nil, err - } - t2 := NewTransport(s2) - - var recv libchan.Receiver - waitError := make(chan error) - - go func() { - var err error - recv, err = t2.WaitReceiveChannel() - waitError <- err - }() - - send, senderErr := t1.NewSendChannel() - if senderErr != nil { - c1.Close() - c2.Close() - return nil, nil, senderErr - } - - receiveErr := <-waitError - if receiveErr != nil { - c1.Close() - c2.Close() - return nil, nil, receiveErr - } - return &pipeReceiver{t2, recv.(*receiver)}, &pipeSender{t1, send.(*sender)}, nil -} - -func (p *pipeSender) Send(message interface{}) error { - return p.sender.Send(message) -} - -func (p *pipeSender) Close() error { - err := p.sender.Close() - if err != nil { - return err - } - if closer, ok := p.session.(io.Closer); ok { - return closer.Close() - } - return nil -} - -func (p *pipeReceiver) Receive(message interface{}) error { - return p.receiver.Receive(message) -} - -func (p *pipeReceiver) SendTo(dst libchan.Sender) (int, error) { - return p.receiver.SendTo(dst) -} diff --git a/spdy/pipe_test.go b/spdy/pipe_test.go index 2d36450..5cdee92 100644 --- a/spdy/pipe_test.go +++ b/spdy/pipe_test.go @@ -9,8 +9,51 @@ import ( "time" "github.com/docker/libchan" + "github.com/docker/libchan/encoding/msgpack" ) +// testPipe creates a top-level channel pipe using an in memory +// transport using spdy and msgpack +func testPipe() (libchan.Receiver, libchan.Sender, error) { + c1, c2 := net.Pipe() + + s1, err := NewSpdyStreamProvider(c1, false) + if err != nil { + return nil, nil, err + } + t1 := NewTransport(s1, &msgpack.Codec{}) + + s2, err := NewSpdyStreamProvider(c2, true) + if err != nil { + return nil, nil, err + } + t2 := NewTransport(s2, &msgpack.Codec{}) + + var recv libchan.Receiver + waitError := make(chan error) + + go func() { + var err error + recv, err = t2.WaitReceiveChannel() + waitError <- err + }() + + send, senderErr := t1.NewSendChannel() + if senderErr != nil { + c1.Close() + c2.Close() + return nil, nil, senderErr + } + + receiveErr := <-waitError + if receiveErr != nil { + c1.Close() + c2.Close() + return nil, nil, receiveErr + } + return recv, send, nil +} + type PipeMessage struct { Message string Stream io.ReadWriteCloser @@ -107,7 +150,7 @@ func SpawnPipeTest(t *testing.T, client PipeSenderRoutine, server PipeReceiverRo endClient := make(chan bool) endServer := make(chan bool) - receiver, sender, err := Pipe() + receiver, sender, err := testPipe() if err != nil { t.Fatalf("Error creating pipe: %s", err) } diff --git a/spdy/proxy_test.go b/spdy/proxy_test.go index 4360f0f..ed494c1 100644 --- a/spdy/proxy_test.go +++ b/spdy/proxy_test.go @@ -136,8 +136,8 @@ func SpawnProxyTest(t *testing.T, client PipeSenderRoutine, server PipeReceiverR endServer := make(chan bool) endProxy := make(chan bool) - receiver1, sender1, err := Pipe() - receiver2, sender2, err := Pipe() + receiver1, sender1, err := testPipe() + receiver2, sender2, err := testPipe() if err != nil { t.Fatalf("Error creating pipe: %s", err) diff --git a/spdy/session.go b/spdy/session.go index da76b9b..8867d9f 100644 --- a/spdy/session.go +++ b/spdy/session.go @@ -5,12 +5,13 @@ import ( "errors" "io" "net/http" + "reflect" "strconv" "sync" "sync/atomic" - "github.com/dmcgowan/msgpack" "github.com/docker/libchan" + "github.com/docker/libchan/encoding" ) var ( @@ -27,6 +28,7 @@ type Transport struct { receiverChan chan *receiver streamC *sync.Cond streams map[uint64]*stream + codec encoding.ChannelCodec } type stream struct { @@ -39,33 +41,26 @@ type stream struct { type sender struct { stream *stream encodeLock sync.Mutex - encoder *msgpack.Encoder + encoder encoding.Encoder buffer *bufio.Writer } type receiver struct { stream *stream decodeLock sync.Mutex - decoder *msgpack.Decoder -} - -type nopReceiver struct { - stream *stream -} - -type nopSender struct { - stream *stream + decoder encoding.Decoder } // NewTransport returns an object implementing the // libchan Transport interface using a stream provider. -func NewTransport(provider StreamProvider) libchan.Transport { +func NewTransport(provider StreamProvider, codec encoding.ChannelCodec) libchan.Transport { session := &Transport{ provider: provider, referenceCounter: 1, receiverChan: make(chan *receiver), streamC: sync.NewCond(new(sync.Mutex)), streams: make(map[uint64]*stream), + codec: codec, } go session.handleStreams() @@ -167,8 +162,12 @@ func (s *Transport) createSubStream(parentID uint64) (*stream, error) { return newStream, nil } -func (s *stream) createByteStream() (*stream, error) { - return s.session.createSubStream(s.referenceID) +func (s *stream) CreateStream() (io.ReadWriteCloser, uint64, error) { + strm, err := s.session.createSubStream(s.referenceID) + if err != nil { + return nil, 0, err + } + return strm, strm.referenceID, nil } // NewSendChannel creates and returns a new send channel. The receive @@ -195,28 +194,50 @@ func (s *Transport) WaitReceiveChannel() (libchan.Receiver, error) { return r, nil } -// CreateNestedReceiver creates a new channel returning the local -// receiver and the remote sender. The remote sender needs to be -// sent across the channel before being utilized. -func (s *stream) CreateNestedReceiver() (libchan.Receiver, libchan.Sender, error) { +// CreateReceiver creates a new channel returning the local +// receiver and the remote sender identifier. +func (s *stream) CreateReceiver() (libchan.Receiver, uint64, error) { stream, err := s.session.createSubStream(s.referenceID) if err != nil { - return nil, nil, err + return nil, 0, err } - return &receiver{stream: stream}, &nopSender{stream: stream}, err + return &receiver{stream: stream}, uint64(stream.referenceID), err } -// CreateNestedReceiver creates a new channel returning the local -// sender and the remote receiver. The remote receiver needs to be -// sent across the channel before being utilized. -func (s *stream) CreateNestedSender() (libchan.Sender, libchan.Receiver, error) { +// CreateSender creates a new channel returning the local +// sender and the remote receiver identifier. +func (s *stream) CreateSender() (libchan.Sender, uint64, error) { stream, err := s.session.createSubStream(s.referenceID) if err != nil { - return nil, nil, err + return nil, 0, err + } + + return &sender{stream: stream}, uint64(stream.referenceID), err +} + +func (s *stream) GetSender(sid uint64) (libchan.Sender, error) { + strm := s.session.getStream(sid) + if strm == nil { + return nil, errors.New("sender does not exist") + } + return &sender{stream: strm}, nil +} + +func (s *stream) GetReceiver(sid uint64) (libchan.Receiver, error) { + strm := s.session.getStream(sid) + if strm == nil { + return nil, errors.New("sender does not exist") } + return &receiver{stream: strm}, nil +} - return &sender{stream: stream}, &nopReceiver{stream: stream}, err +func (s *stream) GetStream(sid uint64) (io.ReadWriteCloser, error) { + strm := s.session.getStream(sid) + if strm == nil { + return nil, errors.New("sender does not exist") + } + return strm, nil } // Send sends a message across the channel to a receiver on the @@ -226,8 +247,7 @@ func (s *sender) Send(message interface{}) error { defer s.encodeLock.Unlock() if s.encoder == nil { s.buffer = bufio.NewWriter(s.stream) - s.encoder = msgpack.NewEncoder(s.buffer) - s.encoder.AddExtensions(s.stream.initializeExtensions()) + s.encoder = s.stream.session.codec.NewEncoder(s.buffer, s.stream) } if err := s.encoder.Encode(message); err != nil { @@ -243,14 +263,19 @@ func (s *sender) Close() error { return s.stream.stream.Close() } +var ( + streamT = reflect.TypeOf(&stream{}) + recvT = reflect.TypeOf(&receiver{}) + sendT = reflect.TypeOf(&sender{}) +) + // Receive receives a message sent across the channel from // a sender on the other side of the transport. func (r *receiver) Receive(message interface{}) error { r.decodeLock.Lock() defer r.decodeLock.Unlock() if r.decoder == nil { - r.decoder = msgpack.NewDecoder(r.stream) - r.decoder.AddExtensions(r.stream.initializeExtensions()) + r.decoder = r.stream.session.codec.NewDecoder(r.stream, r.stream, streamT, recvT, sendT) } decodeErr := r.decoder.Decode(message) @@ -264,14 +289,14 @@ func (r *receiver) Receive(message interface{}) error { func (r *receiver) SendTo(dst libchan.Sender) (int, error) { var n int for { - var rm msgpack.RawMessage - if err := r.Receive(&rm); err == io.EOF { + rm := r.stream.session.codec.NewRawMessage() + if err := r.Receive(rm); err == io.EOF { break } else if err != nil { return n, err } - if err := dst.Send(&rm); err != nil { + if err := dst.Send(rm); err != nil { return n, err } n++ @@ -279,22 +304,6 @@ func (r *receiver) SendTo(dst libchan.Sender) (int, error) { return n, nil } -func (*nopSender) Send(interface{}) error { - return ErrOperationNotAllowed -} - -func (*nopSender) Close() error { - return ErrOperationNotAllowed -} - -func (*nopReceiver) Receive(interface{}) error { - return ErrOperationNotAllowed -} - -func (*nopReceiver) SendTo(libchan.Sender) (int, error) { - return 0, ErrOperationNotAllowed -} - func (s *stream) Read(b []byte) (int, error) { return s.stream.Read(b) } diff --git a/spdy/session_test.go b/spdy/session_test.go index 3232641..c3c6421 100644 --- a/spdy/session_test.go +++ b/spdy/session_test.go @@ -1,6 +1,7 @@ package spdy import ( + "bufio" "io" "net" "os" @@ -9,6 +10,7 @@ import ( "time" "github.com/docker/libchan" + "github.com/docker/libchan/encoding/msgpack" ) type InOutMessage struct { @@ -143,17 +145,15 @@ type MessageWithByteStream struct { func TestByteStream(t *testing.T) { client := func(t *testing.T, sndr libchan.Sender, s libchan.Transport) { - bs, bsErr := sndr.(*sender).stream.createByteStream() - if bsErr != nil { - t.Fatalf("Error creating byte stream: %s", bsErr) - } + bs, remote := net.Pipe() + w := bufio.NewWriter(bs) m1 := &MessageWithByteStream{ Message: "with a byte stream", - Stream: bs, + Stream: remote, } - _, writeErr := bs.Write([]byte("Hello there server!")) + _, writeErr := w.Write([]byte("Hello there server!")) if writeErr != nil { t.Fatalf("Error writing to byte stream: %s", writeErr) } @@ -162,6 +162,9 @@ func TestByteStream(t *testing.T) { if sendErr != nil { t.Fatalf("Error sending channel: %s", sendErr) } + if flushErr := w.Flush(); flushErr != nil { + t.Fatalf("Error flushing: %s", flushErr) + } readBytes := make([]byte, 30) n, readErr := bs.Read(readBytes) @@ -203,7 +206,7 @@ func TestByteStream(t *testing.T) { t.Fatalf("Unexpected read value:\n\tExpected: %q\n\tActual: %q", expected, string(readBytes[:n])) } - _, writeErr := bs.Write([]byte("G'day client ☺")) + _, writeErr := m1.Stream.Write([]byte("G'day client ☺")) if writeErr != nil { t.Fatalf("Error writing to byte stream: %s", writeErr) } @@ -446,7 +449,7 @@ func ClientSendWrapper(f func(t *testing.T, c libchan.Sender, s libchan.Transpor if sessionErr != nil { t.Fatalf("Error creating session: %s", sessionErr) } - session := NewTransport(provider) + session := NewTransport(provider, &msgpack.Codec{}) sender, senderErr := session.NewSendChannel() if senderErr != nil { @@ -478,7 +481,7 @@ func ServerReceiveWrapper(f func(t *testing.T, c libchan.Receiver, s libchan.Tra if sessionErr != nil { t.Fatalf("Error creating session: %s", sessionErr) } - session := NewTransport(provider) + session := NewTransport(provider, &msgpack.Codec{}) receiver, receiverErr := session.WaitReceiveChannel() if receiverErr != nil {