summaryrefslogtreecommitdiff
path: root/net/vmw_vsock/af_vsock.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r--net/vmw_vsock/af_vsock.c88
1 files changed, 60 insertions, 28 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 943d58b07a55..036bdcc9d5c5 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -113,12 +113,14 @@
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
+static void vsock_close(struct sock *sk, long timeout);
/* Protocol family. */
static struct proto vsock_proto = {
.name = "AF_VSOCK",
.owner = THIS_MODULE,
.obj_size = sizeof(struct vsock_sock),
+ .close = vsock_close,
};
/* The default peer timeout indicates how long we will wait for a peer response
@@ -328,7 +330,10 @@ EXPORT_SYMBOL_GPL(vsock_find_connected_socket);
void vsock_remove_sock(struct vsock_sock *vsk)
{
- vsock_remove_bound(vsk);
+ /* Transport reassignment must not remove the binding. */
+ if (sock_flag(sk_vsock(vsk), SOCK_DEAD))
+ vsock_remove_bound(vsk);
+
vsock_remove_connected(vsk);
}
EXPORT_SYMBOL_GPL(vsock_remove_sock);
@@ -800,39 +805,44 @@ static bool sock_type_connectible(u16 type)
static void __vsock_release(struct sock *sk, int level)
{
- if (sk) {
- struct sock *pending;
- struct vsock_sock *vsk;
+ struct vsock_sock *vsk;
+ struct sock *pending;
- vsk = vsock_sk(sk);
- pending = NULL; /* Compiler warning. */
+ vsk = vsock_sk(sk);
+ pending = NULL; /* Compiler warning. */
- /* When "level" is SINGLE_DEPTH_NESTING, use the nested
- * version to avoid the warning "possible recursive locking
- * detected". When "level" is 0, lock_sock_nested(sk, level)
- * is the same as lock_sock(sk).
- */
- lock_sock_nested(sk, level);
+ /* When "level" is SINGLE_DEPTH_NESTING, use the nested
+ * version to avoid the warning "possible recursive locking
+ * detected". When "level" is 0, lock_sock_nested(sk, level)
+ * is the same as lock_sock(sk).
+ */
+ lock_sock_nested(sk, level);
- if (vsk->transport)
- vsk->transport->release(vsk);
- else if (sock_type_connectible(sk->sk_type))
- vsock_remove_sock(vsk);
+ /* Indicate to vsock_remove_sock() that the socket is being released and
+ * can be removed from the bound_table. Unlike transport reassignment
+ * case, where the socket must remain bound despite vsock_remove_sock()
+ * being called from the transport release() callback.
+ */
+ sock_set_flag(sk, SOCK_DEAD);
- sock_orphan(sk);
- sk->sk_shutdown = SHUTDOWN_MASK;
+ if (vsk->transport)
+ vsk->transport->release(vsk);
+ else if (sock_type_connectible(sk->sk_type))
+ vsock_remove_sock(vsk);
- skb_queue_purge(&sk->sk_receive_queue);
+ sock_orphan(sk);
+ sk->sk_shutdown = SHUTDOWN_MASK;
- /* Clean up any sockets that never were accepted. */
- while ((pending = vsock_dequeue_accept(sk)) != NULL) {
- __vsock_release(pending, SINGLE_DEPTH_NESTING);
- sock_put(pending);
- }
+ skb_queue_purge(&sk->sk_receive_queue);
- release_sock(sk);
- sock_put(sk);
+ /* Clean up any sockets that never were accepted. */
+ while ((pending = vsock_dequeue_accept(sk)) != NULL) {
+ __vsock_release(pending, SINGLE_DEPTH_NESTING);
+ sock_put(pending);
}
+
+ release_sock(sk);
+ sock_put(sk);
}
static void vsock_sk_destruct(struct sock *sk)
@@ -899,9 +909,22 @@ s64 vsock_stream_has_space(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);
+/* Dummy callback required by sockmap.
+ * See unconditional call of saved_close() in sock_map_close().
+ */
+static void vsock_close(struct sock *sk, long timeout)
+{
+}
+
static int vsock_release(struct socket *sock)
{
- __vsock_release(sock->sk, 0);
+ struct sock *sk = sock->sk;
+
+ if (!sk)
+ return 0;
+
+ sk->sk_prot->close(sk, 0);
+ __vsock_release(sk, 0);
sock->sk = NULL;
sock->state = SS_FREE;
@@ -1386,6 +1409,11 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
if (err < 0)
goto out;
+ /* sk_err might have been set as a result of an earlier
+ * (failed) connect attempt.
+ */
+ sk->sk_err = 0;
+
/* Mark sock as connecting and set the error code to in
* progress in case this is a non-blocking connect.
*/
@@ -1400,7 +1428,11 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
timeout = vsk->connect_timeout;
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
- while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) {
+ /* If the socket is already closing or it is in an error state, there
+ * is no point in waiting.
+ */
+ while (sk->sk_state != TCP_ESTABLISHED &&
+ sk->sk_state != TCP_CLOSING && sk->sk_err == 0) {
if (flags & O_NONBLOCK) {
/* If we're not going to block, we schedule a timeout
* function to generate a timeout on the connection