diff --git a/inmem.go b/inmem.go index 14c46ca..76cf7ef 100644 --- a/inmem.go +++ b/inmem.go @@ -1,6 +1,7 @@ package libchan import ( + "encoding" "encoding/binary" "errors" "io" @@ -312,6 +313,13 @@ func (w *pipeSender) copyValue(v interface{}) (interface{}, error) { return w.copyChannelMessage(val) case map[interface{}]interface{}: return w.copyChannelInterfaceMessage(val) + case encoding.BinaryMarshaler: + p, err := val.MarshalBinary() + if err != nil { + return nil, err + } + + return w.copyValue(p) default: if rv := reflect.ValueOf(v); rv.Kind() == reflect.Ptr { if rv.Elem().Kind() == reflect.Struct { @@ -419,10 +427,8 @@ func (w *pipeSender) copyChannelInterfaceMessage(m map[interface{}]interface{}) return mCopy, nil } func (w *pipeSender) copyStructure(m interface{}) (interface{}, error) { - v := reflect.ValueOf(m) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } + v := reflect.Indirect(reflect.ValueOf(m)) + if v.Kind() != reflect.Struct { return nil, errors.New("invalid non struct type") } diff --git a/spdy/copy.go b/spdy/copy.go index cb0aa7e..a946827 100644 --- a/spdy/copy.go +++ b/spdy/copy.go @@ -1,6 +1,7 @@ package spdy import ( + "encoding" "errors" "io" "net" @@ -55,6 +56,13 @@ func (c *channel) copyValue(v interface{}) (interface{}, error) { return c.copyChannelMessage(val) case map[interface{}]interface{}: return c.copyChannelInterfaceMessage(val) + case encoding.BinaryMarshaler: + p, err := val.MarshalBinary() + if err != nil { + return nil, err + } + + return c.copyValue(p) default: if rv := reflect.ValueOf(v); rv.Kind() == reflect.Ptr { if rv.Elem().Kind() == reflect.Struct { @@ -163,10 +171,8 @@ func (c *channel) copyChannelInterfaceMessage(m map[interface{}]interface{}) (in } func (c *channel) copyStructure(m interface{}) (interface{}, error) { - v := reflect.ValueOf(m) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } + v := reflect.Indirect(reflect.ValueOf(m)) + if v.Kind() != reflect.Struct { return nil, errors.New("invalid non struct type") } @@ -177,6 +183,9 @@ func (c *channel) copyStructValue(v reflect.Value) (interface{}, error) { mCopy := make(map[string]interface{}) t := v.Type() for i := 0; i < v.NumField(); i++ { + // TODO(stevvooe): Calling Interface without checking if a type can be + // interfaced may lead to panics. This value copier may need to be + // refactored to handle arbitrary types. vCopy, vErr := c.copyValue(v.Field(i).Interface()) if vErr != nil { return nil, vErr diff --git a/spdy/copy_test.go b/spdy/copy_test.go new file mode 100644 index 0000000..82f9935 --- /dev/null +++ b/spdy/copy_test.go @@ -0,0 +1,16 @@ +package spdy + +import ( + "testing" + + "github.com/docker/libchan/testutil" +) + +func TestTypeTransmission(t *testing.T) { + sender, receiver, err := Pipe() + if err != nil { + t.Fatalf("error creating pipe: %v", err) + } + + testutil.CheckTypeTransmission(t, receiver, sender) +} diff --git a/testutil/check.go b/testutil/check.go new file mode 100644 index 0000000..efff948 --- /dev/null +++ b/testutil/check.go @@ -0,0 +1,62 @@ +// Package testutil contains checks that implementations of libchan transports +// can use to check compliance. For now, this will only work with Go, but +// cross-language tests could be added here, as well. +package testutil + +import ( + "io" + "io/ioutil" + "reflect" + "strings" + "testing" + "time" + + "github.com/docker/libchan" +) + +func CheckTypeTransmission(t *testing.T, receiver libchan.Receiver, sender libchan.Sender) { + // Add types that should be transmitted by value to this struct. Their + // equality will be tested with reflect.DeepEquals. + type ValueTypes struct { + I int + T time.Time + } + + // Add other types, that may include readers or stateful items. + type A struct { + // TODO(stevvooe): Ideally, this would be embedded but libchan doesn't + // seem to transmit embedded structs correctly. + V ValueTypes + Reader io.ReadCloser // TODO(stevvooe): Only io.ReadCloser is support for now. + } + + readerContent := "asdf" + expected := A{ + V: ValueTypes{ + I: 1234, + T: time.Now(), + }, + Reader: ioutil.NopCloser(strings.NewReader(readerContent)), + } + + go func() { + if err := sender.Send(expected); err != nil { + t.Fatalf("unexpected error sending: %v", err) + } + }() + + var received A + if err := receiver.Receive(&received); err != nil { + t.Fatalf("unexpected error receiving: %v", err) + } + + if !reflect.DeepEqual(received.V, expected.V) { + t.Fatalf("expected structs to be equal: %#v != %#v", received.V, expected.V) + } + + receivedContent, _ := ioutil.ReadAll(received.Reader) + if string(receivedContent) != readerContent { + t.Fatalf("reader transmitted different content %q != %q", receivedContent, readerContent) + } + +} diff --git a/testutil/check_test.go b/testutil/check_test.go new file mode 100644 index 0000000..085cc85 --- /dev/null +++ b/testutil/check_test.go @@ -0,0 +1,14 @@ +package testutil + +import ( + "github.com/docker/libchan" + + "testing" +) + +// TestTypeTransmission tests the main package (to avoid import cycles) and +// provides and example of how this test should be used in other packages. +func TestTypeTransmission(t *testing.T) { + receiver, sender := libchan.Pipe() + CheckTypeTransmission(t, receiver, sender) +}