diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 8bfe961b434d0..58d8806585a23 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -477,20 +477,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) goto err; } - if (vsk->transport) { - if (vsk->transport == new_transport) { - ret = 0; - goto err; - } - - /* transport->release() must be called with sock lock acquired. - * This path can only be taken during vsock_stream_connect(), - * where we have already held the sock lock. - * In the other cases, this function is called on a new socket - * which is not assigned to any transport. - */ - vsk->transport->release(vsk); - vsock_deassign_transport(vsk); + if (vsk->transport && vsk->transport == new_transport) { + ret = 0; + goto err; } /* We increase the module refcnt to prevent the transport unloading @@ -507,6 +496,26 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) */ mutex_unlock(&vsock_register_mutex); + if (vsk->transport) { + /* transport->release() must be called with sock lock acquired. + * This path can only be taken during vsock_stream_connect(), + * where we have already held the sock lock. + * In the other cases, this function is called on a new socket + * which is not assigned to any transport. + */ + vsk->transport->release(vsk); + vsock_deassign_transport(vsk); + + /* transport's release() and destruct() can touch some socket + * state, since we are reassigning the socket to a new transport + * during vsock_connect(), let's reset these fields to have a + * clean state. + */ + sock_reset_flag(sk, SOCK_DONE); + sk->sk_state = TCP_CLOSE; + vsk->peer_shutdown = 0; + } + ret = new_transport->init(vsk, psk); if (ret) { module_put(new_transport->module);