RDMA/cm: Allow ib_send_cm_rej() to be done under lock

The first thing ib_send_cm_rej() does is obtain the lock, so use the usual
unlocked wrapper, locked actor pattern here.

This avoids a sketchy lock/unlock sequence (which could allow state to
change) during cm_destroy_id().

While here simplify some of the logic in the implementation.

Link: https://lore.kernel.org/r/20200310092545.251365-14-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
This commit is contained in:
Jason Gunthorpe 2020-03-10 11:25:43 +02:00
parent 87cabf3e09
commit 81ddb41f87

View File

@ -87,6 +87,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
const void *private_data, u8 private_data_len);
static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
void *private_data, u8 private_data_len);
static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
enum ib_cm_rej_reason reason, void *ari,
u8 ari_length, const void *private_data,
u8 private_data_len);
static struct ib_client cm_client = {
.name = "cm",
@ -1060,11 +1064,11 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err)
case IB_CM_REQ_SENT:
case IB_CM_MRA_REQ_RCVD:
ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
cm_send_rej_locked(cm_id_priv, IB_CM_REJ_TIMEOUT,
&cm_id_priv->id.device->node_guid,
sizeof(cm_id_priv->id.device->node_guid),
NULL, 0);
spin_unlock_irq(&cm_id_priv->lock);
ib_send_cm_rej(cm_id, IB_CM_REJ_TIMEOUT,
&cm_id_priv->id.device->node_guid,
sizeof cm_id_priv->id.device->node_guid,
NULL, 0);
break;
case IB_CM_REQ_RCVD:
if (err == -ENOMEM) {
@ -1072,9 +1076,10 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err)
cm_reset_to_idle(cm_id_priv);
spin_unlock_irq(&cm_id_priv->lock);
} else {
cm_send_rej_locked(cm_id_priv,
IB_CM_REJ_CONSUMER_DEFINED, NULL, 0,
NULL, 0);
spin_unlock_irq(&cm_id_priv->lock);
ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
NULL, 0, NULL, 0);
}
break;
case IB_CM_REP_SENT:
@ -1084,9 +1089,9 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err)
case IB_CM_MRA_REQ_SENT:
case IB_CM_REP_RCVD:
case IB_CM_MRA_REP_SENT:
cm_send_rej_locked(cm_id_priv, IB_CM_REJ_CONSUMER_DEFINED, NULL,
0, NULL, 0);
spin_unlock_irq(&cm_id_priv->lock);
ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
NULL, 0, NULL, 0);
break;
case IB_CM_ESTABLISHED:
if (cm_id_priv->qp_type == IB_QPT_XRC_TGT) {
@ -2899,65 +2904,72 @@ static int cm_drep_handler(struct cm_work *work)
return -EINVAL;
}
int ib_send_cm_rej(struct ib_cm_id *cm_id,
enum ib_cm_rej_reason reason,
void *ari,
u8 ari_length,
const void *private_data,
u8 private_data_len)
static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
enum ib_cm_rej_reason reason, void *ari,
u8 ari_length, const void *private_data,
u8 private_data_len)
{
struct cm_id_private *cm_id_priv;
struct ib_mad_send_buf *msg;
unsigned long flags;
int ret;
lockdep_assert_held(&cm_id_priv->lock);
if ((private_data && private_data_len > IB_CM_REJ_PRIVATE_DATA_SIZE) ||
(ari && ari_length > IB_CM_REJ_ARI_LENGTH))
return -EINVAL;
cm_id_priv = container_of(cm_id, struct cm_id_private, id);
spin_lock_irqsave(&cm_id_priv->lock, flags);
switch (cm_id->state) {
switch (cm_id_priv->id.state) {
case IB_CM_REQ_SENT:
case IB_CM_MRA_REQ_RCVD:
case IB_CM_REQ_RCVD:
case IB_CM_MRA_REQ_SENT:
case IB_CM_REP_RCVD:
case IB_CM_MRA_REP_SENT:
ret = cm_alloc_msg(cm_id_priv, &msg);
if (!ret)
cm_format_rej((struct cm_rej_msg *) msg->mad,
cm_id_priv, reason, ari, ari_length,
private_data, private_data_len);
cm_reset_to_idle(cm_id_priv);
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
ari, ari_length, private_data, private_data_len);
break;
case IB_CM_REP_SENT:
case IB_CM_MRA_REP_RCVD:
ret = cm_alloc_msg(cm_id_priv, &msg);
if (!ret)
cm_format_rej((struct cm_rej_msg *) msg->mad,
cm_id_priv, reason, ari, ari_length,
private_data, private_data_len);
cm_enter_timewait(cm_id_priv);
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
ari, ari_length, private_data, private_data_len);
break;
default:
pr_debug("%s: local_id %d, cm_id->state: %d\n", __func__,
be32_to_cpu(cm_id_priv->id.local_id), cm_id->state);
ret = -EINVAL;
goto out;
be32_to_cpu(cm_id_priv->id.local_id),
cm_id_priv->id.state);
return -EINVAL;
}
if (ret)
goto out;
ret = ib_post_send_mad(msg, NULL);
if (ret)
if (ret) {
cm_free_msg(msg);
return ret;
}
out: spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return 0;
}
int ib_send_cm_rej(struct ib_cm_id *cm_id, enum ib_cm_rej_reason reason,
void *ari, u8 ari_length, const void *private_data,
u8 private_data_len)
{
struct cm_id_private *cm_id_priv =
container_of(cm_id, struct cm_id_private, id);
unsigned long flags;
int ret;
spin_lock_irqsave(&cm_id_priv->lock, flags);
ret = cm_send_rej_locked(cm_id_priv, reason, ari, ari_length,
private_data, private_data_len);
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return ret;
}
EXPORT_SYMBOL(ib_send_cm_rej);