netlink: Use rhashtable walk interface in diag dump

This patch converts the diag dumping code to use the rhashtable
walk code instead of going through rhashtable by hand.  The lock
nl_table_lock is now only taken while we process the multicast
list as it's not needed for the rhashtable walk.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
Herbert Xu 2016-08-19 16:21:37 +08:00 committed by David S. Miller
parent 39ec406d12
commit ad20207432

View File

@ -63,43 +63,75 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
int protocol, int s_num)
{
struct rhashtable_iter *hti = (void *)cb->args[2];
struct netlink_table *tbl = &nl_table[protocol];
struct rhashtable *ht = &tbl->hash;
const struct bucket_table *htbl = rht_dereference_rcu(ht->tbl, ht);
struct net *net = sock_net(skb->sk);
struct netlink_diag_req *req;
struct netlink_sock *nlsk;
struct sock *sk;
int ret = 0, num = 0, i;
int num = 2;
int ret = 0;
req = nlmsg_data(cb->nlh);
for (i = 0; i < htbl->size; i++) {
struct rhash_head *pos;
if (s_num > 1)
goto mc_list;
rht_for_each_entry_rcu(nlsk, pos, htbl, i, node) {
sk = (struct sock *)nlsk;
num--;
if (!net_eq(sock_net(sk), net))
continue;
if (num < s_num) {
num++;
if (!hti) {
hti = kmalloc(sizeof(*hti), GFP_KERNEL);
if (!hti)
return -ENOMEM;
cb->args[2] = (long)hti;
}
if (!s_num)
rhashtable_walk_enter(&tbl->hash, hti);
ret = rhashtable_walk_start(hti);
if (ret == -EAGAIN)
ret = 0;
if (ret)
goto stop;
while ((nlsk = rhashtable_walk_next(hti))) {
if (IS_ERR(nlsk)) {
ret = PTR_ERR(nlsk);
if (ret == -EAGAIN) {
ret = 0;
continue;
}
break;
}
if (sk_diag_fill(sk, skb, req,
NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq,
NLM_F_MULTI,
sock_i_ino(sk)) < 0) {
ret = 1;
goto done;
}
sk = (struct sock *)nlsk;
num++;
if (!net_eq(sock_net(sk), net))
continue;
if (sk_diag_fill(sk, skb, req,
NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq,
NLM_F_MULTI,
sock_i_ino(sk)) < 0) {
ret = 1;
break;
}
}
stop:
rhashtable_walk_stop(hti);
if (ret)
goto done;
rhashtable_walk_exit(hti);
cb->args[2] = 0;
num++;
mc_list:
read_lock(&nl_table_lock);
sk_for_each_bound(sk, &tbl->mc_list) {
if (sk_hashed(sk))
continue;
@ -116,13 +148,14 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
NLM_F_MULTI,
sock_i_ino(sk)) < 0) {
ret = 1;
goto done;
break;
}
num++;
}
read_unlock(&nl_table_lock);
done:
cb->args[0] = num;
cb->args[1] = protocol;
return ret;
}
@ -131,20 +164,20 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
struct netlink_diag_req *req;
int s_num = cb->args[0];
int err = 0;
req = nlmsg_data(cb->nlh);
rcu_read_lock();
read_lock(&nl_table_lock);
if (req->sdiag_protocol == NDIAG_PROTO_ALL) {
int i;
for (i = cb->args[1]; i < MAX_LINKS; i++) {
if (__netlink_diag_dump(skb, cb, i, s_num))
err = __netlink_diag_dump(skb, cb, i, s_num);
if (err)
break;
s_num = 0;
}
cb->args[1] = i;
} else {
if (req->sdiag_protocol >= MAX_LINKS) {
read_unlock(&nl_table_lock);
@ -152,13 +185,22 @@ static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
return -ENOENT;
}
__netlink_diag_dump(skb, cb, req->sdiag_protocol, s_num);
err = __netlink_diag_dump(skb, cb, req->sdiag_protocol, s_num);
}
read_unlock(&nl_table_lock);
rcu_read_unlock();
return err < 0 ? err : skb->len;
}
return skb->len;
static int netlink_diag_dump_done(struct netlink_callback *cb)
{
struct rhashtable_iter *hti = (void *)cb->args[2];
if (cb->args[0] == 1)
rhashtable_walk_exit(hti);
kfree(hti);
return 0;
}
static int netlink_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
@ -172,6 +214,7 @@ static int netlink_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
if (h->nlmsg_flags & NLM_F_DUMP) {
struct netlink_dump_control c = {
.dump = netlink_diag_dump,
.done = netlink_diag_dump_done,
};
return netlink_dump_start(net->diag_nlsk, skb, h, &c);
} else