diff --git a/fs/9p/v9fs.c b/fs/9p/v9fs.c index 0429c8ee58f1..89bac3d2f05b 100644 --- a/fs/9p/v9fs.c +++ b/fs/9p/v9fs.c @@ -343,7 +343,7 @@ static int v9fs_parse_options(struct v9fs_session_info *v9ses, char *opts) v9ses->uid = make_kuid(current_user_ns(), uid); if (!uid_valid(v9ses->uid)) { ret = -EINVAL; - pr_info("Uknown uid %s\n", s); + pr_info("Unknown uid %s\n", s); } } diff --git a/fs/9p/vfs_file.c b/fs/9p/vfs_file.c index 03c9e325bfbc..5f2e48d41d72 100644 --- a/fs/9p/vfs_file.c +++ b/fs/9p/vfs_file.c @@ -533,7 +533,7 @@ v9fs_mmap_file_mmap(struct file *filp, struct vm_area_struct *vma) return retval; } -static int +static vm_fault_t v9fs_vm_page_mkwrite(struct vm_fault *vmf) { struct v9fs_inode *v9inode; diff --git a/fs/9p/xattr.c b/fs/9p/xattr.c index f329eee6dc93..352abc39e891 100644 --- a/fs/9p/xattr.c +++ b/fs/9p/xattr.c @@ -105,7 +105,7 @@ int v9fs_fid_xattr_set(struct p9_fid *fid, const char *name, { struct kvec kvec = {.iov_base = (void *)value, .iov_len = value_len}; struct iov_iter from; - int retval; + int retval, err; iov_iter_kvec(&from, WRITE | ITER_KVEC, &kvec, 1, value_len); @@ -126,7 +126,9 @@ int v9fs_fid_xattr_set(struct p9_fid *fid, const char *name, retval); else p9_client_write(fid, 0, &from, &retval); - p9_client_clunk(fid); + err = p9_client_clunk(fid); + if (!retval && err) + retval = err; return retval; } diff --git a/include/net/9p/client.h b/include/net/9p/client.h index 7af9d769b97d..0fa0fbab33b0 100644 --- a/include/net/9p/client.h +++ b/include/net/9p/client.h @@ -27,6 +27,7 @@ #define NET_9P_CLIENT_H #include +#include /* Number of requests per row */ #define P9_ROW_MAXTAG 255 @@ -112,7 +113,7 @@ enum p9_req_status_t { struct p9_req_t { int status; int t_err; - wait_queue_head_t *wq; + wait_queue_head_t wq; struct p9_fcall *tc; struct p9_fcall *rc; void *aux; @@ -128,8 +129,7 @@ struct p9_req_t { * @proto_version: 9P protocol version to use * @trans_mod: module API instantiated with this client * @trans: tranport instance state and API - * @fidpool: fid handle accounting for session - * @fidlist: List of active fid handles + * @fids: All active FID handles * @tagpool - transaction id accounting for session * @reqs - 2D array of requests * @max_tag - current maximum tag id allocated @@ -169,8 +169,7 @@ struct p9_client { } tcp; } trans_opts; - struct p9_idpool *fidpool; - struct list_head fidlist; + struct idr fids; struct p9_idpool *tagpool; struct p9_req_t *reqs[P9_ROW_MAXTAG]; @@ -188,7 +187,6 @@ struct p9_client { * @iounit: the server reported maximum transaction size for this file * @uid: the numeric uid of the local user who owns this handle * @rdir: readdir accounting structure (allocated on demand) - * @flist: per-client-instance fid tracking * @dlist: per-dentry fid tracking * * TODO: This needs lots of explanation. @@ -204,7 +202,6 @@ struct p9_fid { void *rdir; - struct list_head flist; struct hlist_node dlist; /* list of all fids attached to a dentry */ }; diff --git a/net/9p/client.c b/net/9p/client.c index 5c1343195292..deae53a7dffc 100644 --- a/net/9p/client.c +++ b/net/9p/client.c @@ -283,8 +283,9 @@ p9_tag_alloc(struct p9_client *c, u16 tag, unsigned int max_size) return ERR_PTR(-ENOMEM); } for (col = 0; col < P9_ROW_MAXTAG; col++) { - c->reqs[row][col].status = REQ_STATUS_IDLE; - c->reqs[row][col].tc = NULL; + req = &c->reqs[row][col]; + req->status = REQ_STATUS_IDLE; + init_waitqueue_head(&req->wq); } c->max_tag += P9_ROW_MAXTAG; } @@ -294,13 +295,6 @@ p9_tag_alloc(struct p9_client *c, u16 tag, unsigned int max_size) col = tag % P9_ROW_MAXTAG; req = &c->reqs[row][col]; - if (!req->wq) { - req->wq = kmalloc(sizeof(wait_queue_head_t), GFP_NOFS); - if (!req->wq) - goto grow_failed; - init_waitqueue_head(req->wq); - } - if (!req->tc) req->tc = p9_fcall_alloc(alloc_msize); if (!req->rc) @@ -320,9 +314,7 @@ p9_tag_alloc(struct p9_client *c, u16 tag, unsigned int max_size) pr_err("Couldn't grow tag array\n"); kfree(req->tc); kfree(req->rc); - kfree(req->wq); req->tc = req->rc = NULL; - req->wq = NULL; return ERR_PTR(-ENOMEM); } @@ -341,7 +333,7 @@ struct p9_req_t *p9_tag_lookup(struct p9_client *c, u16 tag) * buffer to read the data into */ tag++; - if(tag >= c->max_tag) + if (tag >= c->max_tag) return NULL; row = tag / P9_ROW_MAXTAG; @@ -410,7 +402,6 @@ static void p9_tag_cleanup(struct p9_client *c) /* free requests associated with tags */ for (row = 0; row < (c->max_tag/P9_ROW_MAXTAG); row++) { for (col = 0; col < P9_ROW_MAXTAG; col++) { - kfree(c->reqs[row][col].wq); kfree(c->reqs[row][col].tc); kfree(c->reqs[row][col].rc); } @@ -448,12 +439,12 @@ void p9_client_cb(struct p9_client *c, struct p9_req_t *req, int status) /* * This barrier is needed to make sure any change made to req before - * the other thread wakes up will indeed be seen by the waiting side. + * the status change is visible to another thread */ smp_wmb(); req->status = status; - wake_up(req->wq); + wake_up(&req->wq); p9_debug(P9_DEBUG_MUX, "wakeup: %d\n", req->tc->tag); } EXPORT_SYMBOL(p9_client_cb); @@ -478,20 +469,11 @@ p9_parse_header(struct p9_fcall *pdu, int32_t *size, int8_t *type, int16_t *tag, int err; pdu->offset = 0; - if (pdu->size == 0) - pdu->size = 7; err = p9pdu_readf(pdu, 0, "dbw", &r_size, &r_type, &r_tag); if (err) goto rewind_and_exit; - pdu->size = r_size; - pdu->id = r_type; - pdu->tag = r_tag; - - p9_debug(P9_DEBUG_9P, "<<< size=%d type: %d tag: %d\n", - pdu->size, pdu->id, pdu->tag); - if (type) *type = r_type; if (tag) @@ -499,6 +481,16 @@ p9_parse_header(struct p9_fcall *pdu, int32_t *size, int8_t *type, int16_t *tag, if (size) *size = r_size; + if (pdu->size != r_size || r_size < 7) { + err = -EINVAL; + goto rewind_and_exit; + } + + pdu->id = r_type; + pdu->tag = r_tag; + + p9_debug(P9_DEBUG_9P, "<<< size=%d type: %d tag: %d\n", + pdu->size, pdu->id, pdu->tag); rewind_and_exit: if (rewind) @@ -525,6 +517,12 @@ static int p9_check_errors(struct p9_client *c, struct p9_req_t *req) int ecode; err = p9_parse_header(req->rc, NULL, &type, NULL, 0); + if (req->rc->size >= c->msize) { + p9_debug(P9_DEBUG_ERROR, + "requested packet size too big: %d\n", + req->rc->size); + return -EIO; + } /* * dump the response from server * This should be after check errors which poplulate pdu_fcall. @@ -774,7 +772,7 @@ p9_client_rpc(struct p9_client *c, int8_t type, const char *fmt, ...) } again: /* Wait for the response */ - err = wait_event_killable(*req->wq, req->status >= REQ_STATUS_RCVD); + err = wait_event_killable(req->wq, req->status >= REQ_STATUS_RCVD); /* * Make sure our req is coherent with regard to updates in other @@ -909,34 +907,31 @@ static struct p9_fid *p9_fid_create(struct p9_client *clnt) { int ret; struct p9_fid *fid; - unsigned long flags; p9_debug(P9_DEBUG_FID, "clnt %p\n", clnt); fid = kmalloc(sizeof(struct p9_fid), GFP_KERNEL); if (!fid) - return ERR_PTR(-ENOMEM); - - ret = p9_idpool_get(clnt->fidpool); - if (ret < 0) { - ret = -ENOSPC; - goto error; - } - fid->fid = ret; + return NULL; memset(&fid->qid, 0, sizeof(struct p9_qid)); fid->mode = -1; fid->uid = current_fsuid(); fid->clnt = clnt; fid->rdir = NULL; - spin_lock_irqsave(&clnt->lock, flags); - list_add(&fid->flist, &clnt->fidlist); - spin_unlock_irqrestore(&clnt->lock, flags); + fid->fid = 0; - return fid; + idr_preload(GFP_KERNEL); + spin_lock_irq(&clnt->lock); + ret = idr_alloc_u32(&clnt->fids, fid, &fid->fid, P9_NOFID - 1, + GFP_NOWAIT); + spin_unlock_irq(&clnt->lock); + idr_preload_end(); + + if (!ret) + return fid; -error: kfree(fid); - return ERR_PTR(ret); + return NULL; } static void p9_fid_destroy(struct p9_fid *fid) @@ -946,9 +941,8 @@ static void p9_fid_destroy(struct p9_fid *fid) p9_debug(P9_DEBUG_FID, "fid %d\n", fid->fid); clnt = fid->clnt; - p9_idpool_put(fid->fid, clnt->fidpool); spin_lock_irqsave(&clnt->lock, flags); - list_del(&fid->flist); + idr_remove(&clnt->fids, fid->fid); spin_unlock_irqrestore(&clnt->lock, flags); kfree(fid->rdir); kfree(fid); @@ -958,7 +952,7 @@ static int p9_client_version(struct p9_client *c) { int err = 0; struct p9_req_t *req; - char *version; + char *version = NULL; int msize; p9_debug(P9_DEBUG_9P, ">>> TVERSION msize %d protocol %d\n", @@ -1031,7 +1025,7 @@ struct p9_client *p9_client_create(const char *dev_name, char *options) memcpy(clnt->name, client_id, strlen(client_id) + 1); spin_lock_init(&clnt->lock); - INIT_LIST_HEAD(&clnt->fidlist); + idr_init(&clnt->fids); err = p9_tag_init(clnt); if (err < 0) @@ -1051,18 +1045,12 @@ struct p9_client *p9_client_create(const char *dev_name, char *options) goto destroy_tagpool; } - clnt->fidpool = p9_idpool_create(); - if (IS_ERR(clnt->fidpool)) { - err = PTR_ERR(clnt->fidpool); - goto put_trans; - } - p9_debug(P9_DEBUG_MUX, "clnt %p trans %p msize %d protocol %d\n", clnt, clnt->trans_mod, clnt->msize, clnt->proto_version); err = clnt->trans_mod->create(clnt, dev_name, options); if (err) - goto destroy_fidpool; + goto put_trans; if (clnt->msize > clnt->trans_mod->maxsize) clnt->msize = clnt->trans_mod->maxsize; @@ -1075,8 +1063,6 @@ struct p9_client *p9_client_create(const char *dev_name, char *options) close_trans: clnt->trans_mod->close(clnt); -destroy_fidpool: - p9_idpool_destroy(clnt->fidpool); put_trans: v9fs_put_trans(clnt->trans_mod); destroy_tagpool: @@ -1089,7 +1075,8 @@ EXPORT_SYMBOL(p9_client_create); void p9_client_destroy(struct p9_client *clnt) { - struct p9_fid *fid, *fidptr; + struct p9_fid *fid; + int id; p9_debug(P9_DEBUG_MUX, "clnt %p\n", clnt); @@ -1098,14 +1085,11 @@ void p9_client_destroy(struct p9_client *clnt) v9fs_put_trans(clnt->trans_mod); - list_for_each_entry_safe(fid, fidptr, &clnt->fidlist, flist) { + idr_for_each_entry(&clnt->fids, fid, id) { pr_info("Found fid %d not clunked\n", fid->fid); p9_fid_destroy(fid); } - if (clnt->fidpool) - p9_idpool_destroy(clnt->fidpool); - p9_tag_cleanup(clnt); kfree(clnt); @@ -1138,9 +1122,8 @@ struct p9_fid *p9_client_attach(struct p9_client *clnt, struct p9_fid *afid, p9_debug(P9_DEBUG_9P, ">>> TATTACH afid %d uname %s aname %s\n", afid ? afid->fid : -1, uname, aname); fid = p9_fid_create(clnt); - if (IS_ERR(fid)) { - err = PTR_ERR(fid); - fid = NULL; + if (!fid) { + err = -ENOMEM; goto error; } fid->uid = n_uname; @@ -1189,9 +1172,8 @@ struct p9_fid *p9_client_walk(struct p9_fid *oldfid, uint16_t nwname, clnt = oldfid->clnt; if (clone) { fid = p9_fid_create(clnt); - if (IS_ERR(fid)) { - err = PTR_ERR(fid); - fid = NULL; + if (!fid) { + err = -ENOMEM; goto error; } @@ -1576,7 +1558,7 @@ p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) int count = iov_iter_count(to); int rsize, non_zc = 0; char *dataptr; - + rsize = fid->iounit; if (!rsize || rsize > clnt->msize-P9_IOHDRSZ) rsize = clnt->msize - P9_IOHDRSZ; @@ -1790,7 +1772,7 @@ struct p9_stat_dotl *p9_client_getattr_dotl(struct p9_fid *fid, "<<< st_mtime_sec=%lld st_mtime_nsec=%lld\n" "<<< st_ctime_sec=%lld st_ctime_nsec=%lld\n" "<<< st_btime_sec=%lld st_btime_nsec=%lld\n" - "<<< st_gen=%lld st_data_version=%lld", + "<<< st_gen=%lld st_data_version=%lld\n", ret->st_result_mask, ret->qid.type, ret->qid.path, ret->qid.version, ret->st_mode, ret->st_nlink, from_kuid(&init_user_ns, ret->st_uid), @@ -2019,9 +2001,8 @@ struct p9_fid *p9_client_xattrwalk(struct p9_fid *file_fid, err = 0; clnt = file_fid->clnt; attr_fid = p9_fid_create(clnt); - if (IS_ERR(attr_fid)) { - err = PTR_ERR(attr_fid); - attr_fid = NULL; + if (!attr_fid) { + err = -ENOMEM; goto error; } p9_debug(P9_DEBUG_9P, diff --git a/net/9p/protocol.c b/net/9p/protocol.c index 931ea00c4fed..4a1e1dd30b52 100644 --- a/net/9p/protocol.c +++ b/net/9p/protocol.c @@ -156,7 +156,7 @@ p9pdu_vreadf(struct p9_fcall *pdu, int proto_version, const char *fmt, *sptr = kmalloc(len + 1, GFP_NOFS); if (*sptr == NULL) { - errcode = -EFAULT; + errcode = -ENOMEM; break; } if (pdu_read(pdu, *sptr, len)) { diff --git a/net/9p/trans_fd.c b/net/9p/trans_fd.c index 588bf88c3305..e2ef3c782c53 100644 --- a/net/9p/trans_fd.c +++ b/net/9p/trans_fd.c @@ -185,6 +185,8 @@ static void p9_mux_poll_stop(struct p9_conn *m) spin_lock_irqsave(&p9_poll_lock, flags); list_del_init(&m->poll_pending_link); spin_unlock_irqrestore(&p9_poll_lock, flags); + + flush_work(&p9_poll_work); } /** @@ -197,15 +199,14 @@ static void p9_mux_poll_stop(struct p9_conn *m) static void p9_conn_cancel(struct p9_conn *m, int err) { struct p9_req_t *req, *rtmp; - unsigned long flags; LIST_HEAD(cancel_list); p9_debug(P9_DEBUG_ERROR, "mux %p err %d\n", m, err); - spin_lock_irqsave(&m->client->lock, flags); + spin_lock(&m->client->lock); if (m->err) { - spin_unlock_irqrestore(&m->client->lock, flags); + spin_unlock(&m->client->lock); return; } @@ -217,7 +218,6 @@ static void p9_conn_cancel(struct p9_conn *m, int err) list_for_each_entry_safe(req, rtmp, &m->unsent_req_list, req_list) { list_move(&req->req_list, &cancel_list); } - spin_unlock_irqrestore(&m->client->lock, flags); list_for_each_entry_safe(req, rtmp, &cancel_list, req_list) { p9_debug(P9_DEBUG_ERROR, "call back req %p\n", req); @@ -226,6 +226,7 @@ static void p9_conn_cancel(struct p9_conn *m, int err) req->t_err = err; p9_client_cb(m->client, req, REQ_STATUS_ERROR); } + spin_unlock(&m->client->lock); } static __poll_t @@ -324,7 +325,9 @@ static void p9_read_work(struct work_struct *work) if ((!m->req) && (m->rc.offset == m->rc.capacity)) { p9_debug(P9_DEBUG_TRANS, "got new header\n"); - err = p9_parse_header(&m->rc, NULL, NULL, NULL, 0); + /* Header size */ + m->rc.size = 7; + err = p9_parse_header(&m->rc, &m->rc.size, NULL, NULL, 0); if (err) { p9_debug(P9_DEBUG_ERROR, "error parsing header: %d\n", err); @@ -369,12 +372,14 @@ static void p9_read_work(struct work_struct *work) */ if ((m->req) && (m->rc.offset == m->rc.capacity)) { p9_debug(P9_DEBUG_TRANS, "got new packet\n"); + m->req->rc->size = m->rc.offset; spin_lock(&m->client->lock); if (m->req->status != REQ_STATUS_ERROR) status = REQ_STATUS_RCVD; list_del(&m->req->req_list); - spin_unlock(&m->client->lock); + /* update req->status while holding client->lock */ p9_client_cb(m->client, m->req, status); + spin_unlock(&m->client->lock); m->rc.sdata = NULL; m->rc.offset = 0; m->rc.capacity = 0; @@ -940,7 +945,7 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args) if (err < 0) return err; - if (valid_ipaddr4(addr) < 0) + if (addr == NULL || valid_ipaddr4(addr) < 0) return -EINVAL; csocket = NULL; @@ -990,6 +995,9 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args) csocket = NULL; + if (addr == NULL) + return -EINVAL; + if (strlen(addr) >= UNIX_PATH_MAX) { pr_err("%s (%d): address too long: %s\n", __func__, task_pid_nr(current), addr); diff --git a/net/9p/trans_rdma.c b/net/9p/trans_rdma.c index b06286f253cb..b513cffeeb3c 100644 --- a/net/9p/trans_rdma.c +++ b/net/9p/trans_rdma.c @@ -320,6 +320,7 @@ recv_done(struct ib_cq *cq, struct ib_wc *wc) if (wc->status != IB_WC_SUCCESS) goto err_out; + c->rc->size = wc->byte_len; err = p9_parse_header(c->rc, NULL, NULL, &tag, 1); if (err) goto err_out; @@ -644,6 +645,9 @@ rdma_create_trans(struct p9_client *client, const char *addr, char *args) struct rdma_conn_param conn_param; struct ib_qp_init_attr qp_attr; + if (addr == NULL) + return -EINVAL; + /* Parse the transport specific mount options */ err = parse_opts(args, &opts); if (err < 0) diff --git a/net/9p/trans_virtio.c b/net/9p/trans_virtio.c index 05006cbb3361..7728b0acde09 100644 --- a/net/9p/trans_virtio.c +++ b/net/9p/trans_virtio.c @@ -89,10 +89,8 @@ struct virtio_chan { unsigned long p9_max_pages; /* Scatterlist: can be too big for stack. */ struct scatterlist sg[VIRTQUEUE_NUM]; - - int tag_len; /* - * tag name to identify a mount Non-null terminated + * tag name to identify a mount null terminated */ char *tag; @@ -144,24 +142,27 @@ static void req_done(struct virtqueue *vq) struct virtio_chan *chan = vq->vdev->priv; unsigned int len; struct p9_req_t *req; + bool need_wakeup = false; unsigned long flags; p9_debug(P9_DEBUG_TRANS, ": request done\n"); - while (1) { - spin_lock_irqsave(&chan->lock, flags); - req = virtqueue_get_buf(chan->vq, &len); - if (req == NULL) { - spin_unlock_irqrestore(&chan->lock, flags); - break; + spin_lock_irqsave(&chan->lock, flags); + while ((req = virtqueue_get_buf(chan->vq, &len)) != NULL) { + if (!chan->ring_bufs_avail) { + chan->ring_bufs_avail = 1; + need_wakeup = true; } - chan->ring_bufs_avail = 1; - spin_unlock_irqrestore(&chan->lock, flags); - /* Wakeup if anyone waiting for VirtIO ring space. */ - wake_up(chan->vc_wq); - if (len) + + if (len) { + req->rc->size = len; p9_client_cb(chan->client, req, REQ_STATUS_RCVD); + } } + spin_unlock_irqrestore(&chan->lock, flags); + /* Wakeup if anyone waiting for VirtIO ring space. */ + if (need_wakeup) + wake_up(chan->vc_wq); } /** @@ -188,7 +189,7 @@ static int pack_sg_list(struct scatterlist *sg, int start, s = rest_of_page(data); if (s > count) s = count; - BUG_ON(index > limit); + BUG_ON(index >= limit); /* Make sure we don't terminate early. */ sg_unmark_end(&sg[index]); sg_set_buf(&sg[index++], data, s); @@ -233,6 +234,7 @@ pack_sg_list_p(struct scatterlist *sg, int start, int limit, s = PAGE_SIZE - data_off; if (s > count) s = count; + BUG_ON(index >= limit); /* Make sure we don't terminate early. */ sg_unmark_end(&sg[index]); sg_set_page(&sg[index++], pdata[i++], s, data_off); @@ -382,8 +384,8 @@ static int p9_get_mapped_pages(struct virtio_chan *chan, * p9_virtio_zc_request - issue a zero copy request * @client: client instance issuing the request * @req: request to be issued - * @uidata: user bffer that should be ued for zero copy read - * @uodata: user buffer that shoud be user for zero copy write + * @uidata: user buffer that should be used for zero copy read + * @uodata: user buffer that should be used for zero copy write * @inlen: read buffer size * @outlen: write buffer size * @in_hdr_len: reader header size, This is the size of response protocol data @@ -406,6 +408,7 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, p9_debug(P9_DEBUG_TRANS, "virtio request\n"); if (uodata) { + __le32 sz; int n = p9_get_mapped_pages(chan, &out_pages, uodata, outlen, &offs, &need_drop); if (n < 0) @@ -416,6 +419,12 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); outlen = n; } + /* The size field of the message must include the length of the + * header and the length of the data. We didn't actually know + * the length of the data until this point so add it in now. + */ + sz = cpu_to_le32(req->tc->size + outlen); + memcpy(&req->tc->sdata[0], &sz, sizeof(sz)); } else if (uidata) { int n = p9_get_mapped_pages(chan, &in_pages, uidata, inlen, &offs, &need_drop); @@ -446,7 +455,7 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, out += pack_sg_list_p(chan->sg, out, VIRTQUEUE_NUM, out_pages, out_nr_pages, offs, outlen); } - + /* * Take care of in data * For example TREAD have 11. @@ -490,7 +499,7 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, virtqueue_kick(chan->vq); spin_unlock_irqrestore(&chan->lock, flags); p9_debug(P9_DEBUG_TRANS, "virtio request kicked\n"); - err = wait_event_killable(*req->wq, req->status >= REQ_STATUS_RCVD); + err = wait_event_killable(req->wq, req->status >= REQ_STATUS_RCVD); /* * Non kernel buffers are pinned, unpin them */ @@ -517,14 +526,15 @@ static ssize_t p9_mount_tag_show(struct device *dev, { struct virtio_chan *chan; struct virtio_device *vdev; + int tag_len; vdev = dev_to_virtio(dev); chan = vdev->priv; + tag_len = strlen(chan->tag); - memcpy(buf, chan->tag, chan->tag_len); - buf[chan->tag_len] = 0; + memcpy(buf, chan->tag, tag_len + 1); - return chan->tag_len + 1; + return tag_len + 1; } static DEVICE_ATTR(mount_tag, 0444, p9_mount_tag_show, NULL); @@ -563,7 +573,7 @@ static int p9_virtio_probe(struct virtio_device *vdev) chan->vq = virtio_find_single_vq(vdev, req_done, "requests"); if (IS_ERR(chan->vq)) { err = PTR_ERR(chan->vq); - goto out_free_vq; + goto out_free_chan; } chan->vq->vdev->priv = chan; spin_lock_init(&chan->lock); @@ -577,7 +587,7 @@ static int p9_virtio_probe(struct virtio_device *vdev) err = -EINVAL; goto out_free_vq; } - tag = kmalloc(tag_len, GFP_KERNEL); + tag = kzalloc(tag_len + 1, GFP_KERNEL); if (!tag) { err = -ENOMEM; goto out_free_vq; @@ -586,7 +596,6 @@ static int p9_virtio_probe(struct virtio_device *vdev) virtio_cread_bytes(vdev, offsetof(struct virtio_9p_config, tag), tag, tag_len); chan->tag = tag; - chan->tag_len = tag_len; err = sysfs_create_file(&(vdev->dev.kobj), &dev_attr_mount_tag.attr); if (err) { goto out_free_tag; @@ -616,6 +625,7 @@ static int p9_virtio_probe(struct virtio_device *vdev) kfree(tag); out_free_vq: vdev->config->del_vqs(vdev); +out_free_chan: kfree(chan); fail: return err; @@ -643,10 +653,12 @@ p9_virtio_create(struct p9_client *client, const char *devname, char *args) int ret = -ENOENT; int found = 0; + if (devname == NULL) + return -EINVAL; + mutex_lock(&virtio_9p_lock); list_for_each_entry(chan, &virtio_chan_list, chan_list) { - if (!strncmp(devname, chan->tag, chan->tag_len) && - strlen(devname) == chan->tag_len) { + if (!strcmp(devname, chan->tag)) { if (!chan->inuse) { chan->inuse = true; found = 1; diff --git a/net/9p/trans_xen.c b/net/9p/trans_xen.c index 2e2b8bca54f3..c2d54ac76bfd 100644 --- a/net/9p/trans_xen.c +++ b/net/9p/trans_xen.c @@ -94,6 +94,9 @@ static int p9_xen_create(struct p9_client *client, const char *addr, char *args) { struct xen_9pfs_front_priv *priv; + if (addr == NULL) + return -EINVAL; + read_lock(&xen_9pfs_lock); list_for_each_entry(priv, &xen_9pfs_devs, list) { if (!strcmp(priv->tag, addr)) { diff --git a/net/9p/util.c b/net/9p/util.c index 59f278e64f58..55ad98277e85 100644 --- a/net/9p/util.c +++ b/net/9p/util.c @@ -138,4 +138,3 @@ int p9_idpool_check(int id, struct p9_idpool *p) return idr_find(&p->pool, id) != NULL; } EXPORT_SYMBOL(p9_idpool_check); -