l2tp: fix race in l2tp_recv_common()
authorGuillaume Nault <g.nault@alphalink.fr>
Thu, 2 Apr 2020 17:32:45 +0000 (18:32 +0100)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Mon, 13 Apr 2020 08:31:27 +0000 (10:31 +0200)
commit 61b9a047729bb230978178bca6729689d0c50ca2 upstream.

Taking a reference on sessions in l2tp_recv_common() is racy; this
has to be done by the callers.

To this end, a new function is required (l2tp_session_get()) to
atomically lookup a session and take a reference on it. Callers then
have to manually drop this reference.

Fixes: fd558d186df2 ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts")
Signed-off-by: Guillaume Nault <g.nault@alphalink.fr>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Will Deacon <will@kernel.org>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
net/l2tp/l2tp_core.c
net/l2tp/l2tp_core.h
net/l2tp/l2tp_ip.c
net/l2tp/l2tp_ip6.c

index 429dbb06424001df7150db8904ccb0e4860049c7..4d41fe40723d24604a894ed087afc6904052a391 100644 (file)
@@ -277,6 +277,55 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
 }
 EXPORT_SYMBOL_GPL(l2tp_session_find);
 
+/* Like l2tp_session_find() but takes a reference on the returned session.
+ * Optionally calls session->ref() too if do_ref is true.
+ */
+struct l2tp_session *l2tp_session_get(struct net *net,
+                                     struct l2tp_tunnel *tunnel,
+                                     u32 session_id, bool do_ref)
+{
+       struct hlist_head *session_list;
+       struct l2tp_session *session;
+
+       if (!tunnel) {
+               struct l2tp_net *pn = l2tp_pernet(net);
+
+               session_list = l2tp_session_id_hash_2(pn, session_id);
+
+               rcu_read_lock_bh();
+               hlist_for_each_entry_rcu(session, session_list, global_hlist) {
+                       if (session->session_id == session_id) {
+                               l2tp_session_inc_refcount(session);
+                               if (do_ref && session->ref)
+                                       session->ref(session);
+                               rcu_read_unlock_bh();
+
+                               return session;
+                       }
+               }
+               rcu_read_unlock_bh();
+
+               return NULL;
+       }
+
+       session_list = l2tp_session_id_hash(tunnel, session_id);
+       read_lock_bh(&tunnel->hlist_lock);
+       hlist_for_each_entry(session, session_list, hlist) {
+               if (session->session_id == session_id) {
+                       l2tp_session_inc_refcount(session);
+                       if (do_ref && session->ref)
+                               session->ref(session);
+                       read_unlock_bh(&tunnel->hlist_lock);
+
+                       return session;
+               }
+       }
+       read_unlock_bh(&tunnel->hlist_lock);
+
+       return NULL;
+}
+EXPORT_SYMBOL_GPL(l2tp_session_get);
+
 struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
                                          bool do_ref)
 {
@@ -636,6 +685,9 @@ discard:
  * a data (not control) frame before coming here. Fields up to the
  * session-id have already been parsed and ptr points to the data
  * after the session-id.
+ *
+ * session->ref() must have been called prior to l2tp_recv_common().
+ * session->deref() will be called automatically after skb is processed.
  */
 void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
                      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -645,14 +697,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
        int offset;
        u32 ns, nr;
 
-       /* The ref count is increased since we now hold a pointer to
-        * the session. Take care to decrement the refcnt when exiting
-        * this function from now on...
-        */
-       l2tp_session_inc_refcount(session);
-       if (session->ref)
-               (*session->ref)(session);
-
        /* Parse and check optional cookie */
        if (session->peer_cookie_len > 0) {
                if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -803,8 +847,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
        /* Try to dequeue as many skbs from reorder_q as we can. */
        l2tp_recv_dequeue(session);
 
-       l2tp_session_dec_refcount(session);
-
        return;
 
 discard:
@@ -813,8 +855,6 @@ discard:
 
        if (session->deref)
                (*session->deref)(session);
-
-       l2tp_session_dec_refcount(session);
 }
 EXPORT_SYMBOL(l2tp_recv_common);
 
@@ -921,8 +961,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
        }
 
        /* Find the session context */
-       session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
+       session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
        if (!session || !session->recv_skb) {
+               if (session) {
+                       if (session->deref)
+                               session->deref(session);
+                       l2tp_session_dec_refcount(session);
+               }
+
                /* Not found? Pass to userspace to deal with */
                l2tp_info(tunnel, L2TP_MSG_DATA,
                          "%s: no session found (%u/%u). Passing up.\n",
@@ -935,6 +981,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
                goto error;
 
        l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
+       l2tp_session_dec_refcount(session);
 
        return 0;
 
index fad47e9d74bcccc50faec73b71c81b3a2f315b89..705fbc63ddc248683c4e8c3f0a6be0ce0a327fd4 100644 (file)
@@ -243,6 +243,9 @@ out:
        return tunnel;
 }
 
+struct l2tp_session *l2tp_session_get(struct net *net,
+                                     struct l2tp_tunnel *tunnel,
+                                     u32 session_id, bool do_ref);
 struct l2tp_session *l2tp_session_find(struct net *net,
                                       struct l2tp_tunnel *tunnel,
                                       u32 session_id);
index 7efb3cadc152be7f8e2c1f72e2f55fe02923ef7e..58f87bdd12c756d0ce4e722d4a9f5705f0d412ab 100644 (file)
@@ -142,19 +142,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
        }
 
        /* Ok, this is a data packet. Lookup the session. */
-       session = l2tp_session_find(net, NULL, session_id);
-       if (session == NULL)
+       session = l2tp_session_get(net, NULL, session_id, true);
+       if (!session)
                goto discard;
 
        tunnel = session->tunnel;
-       if (tunnel == NULL)
-               goto discard;
+       if (!tunnel)
+               goto discard_sess;
 
        /* Trace packet contents, if enabled */
        if (tunnel->debug & L2TP_MSG_DATA) {
                length = min(32u, skb->len);
                if (!pskb_may_pull(skb, length))
-                       goto discard;
+                       goto discard_sess;
 
                /* Point to L2TP header */
                optr = ptr = skb->data;
@@ -167,6 +167,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
                goto discard;
 
        l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
+       l2tp_session_dec_refcount(session);
 
        return 0;
 
@@ -204,6 +205,12 @@ pass_up:
 
        return sk_receive_skb(sk, skb, 1);
 
+discard_sess:
+       if (session->deref)
+               session->deref(session);
+       l2tp_session_dec_refcount(session);
+       goto discard;
+
 discard_put:
        sock_put(sk);
 
index 391dd9d8144fc01cd26bfe608e3160bfd003dabd..af04a8a68269724b41b3d0a4f2dea47149547c9d 100644 (file)
@@ -154,19 +154,19 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
        }
 
        /* Ok, this is a data packet. Lookup the session. */
-       session = l2tp_session_find(net, NULL, session_id);
-       if (session == NULL)
+       session = l2tp_session_get(net, NULL, session_id, true);
+       if (!session)
                goto discard;
 
        tunnel = session->tunnel;
-       if (tunnel == NULL)
-               goto discard;
+       if (!tunnel)
+               goto discard_sess;
 
        /* Trace packet contents, if enabled */
        if (tunnel->debug & L2TP_MSG_DATA) {
                length = min(32u, skb->len);
                if (!pskb_may_pull(skb, length))
-                       goto discard;
+                       goto discard_sess;
 
                /* Point to L2TP header */
                optr = ptr = skb->data;
@@ -180,6 +180,8 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
 
        l2tp_recv_common(session, skb, ptr, optr, 0, skb->len,
                         tunnel->recv_payload_hook);
+       l2tp_session_dec_refcount(session);
+
        return 0;
 
 pass_up:
@@ -217,6 +219,12 @@ pass_up:
 
        return sk_receive_skb(sk, skb, 1);
 
+discard_sess:
+       if (session->deref)
+               session->deref(session);
+       l2tp_session_dec_refcount(session);
+       goto discard;
+
 discard_put:
        sock_put(sk);