From 7181d861ddb4f5cf20616af7a10f436d62616d31 Mon Sep 17 00:00:00 2001 From: Yuta Imazu Date: Fri, 4 Oct 2024 22:50:20 +0900 Subject: [PATCH] kernel: do not rely on struct layout when downcasting --- common/extra.h | 11 +++++++++++ kernel/console/tty.c | 16 ++++++++++------ kernel/fs/fifo.c | 22 +++++++++++++++------- kernel/fs/fs.c | 3 ++- kernel/fs/proc/pid.c | 6 +++--- kernel/fs/proc/root.c | 4 ++-- kernel/socket.h | 8 ++++++++ kernel/syscall/socket.c | 9 ++++----- kernel/unix_socket.c | 35 ++++++++++++++++------------------- 9 files changed, 71 insertions(+), 43 deletions(-) diff --git a/common/extra.h b/common/extra.h index e9417447..5aa75416 100644 --- a/common/extra.h +++ b/common/extra.h @@ -15,6 +15,17 @@ #define SIZEOF_FIELD(t, f) sizeof(((t*)0)->f) #define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) +// clang-format off +#define CONTAINER_OF(ptr, type, member) \ + ({ \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wgnu-statement-expression-from-macro-expansion\"") \ + const __typeof__(((type*)0)->member)* __mptr = (ptr); \ + (type*)((char*)__mptr - offsetof(type, member)); \ + _Pragma("GCC diagnostic pop") \ + }) +// clang-format on + #define ROUND_UP(x, align) (((x) + ((align) - 1)) & ~((align) - 1)) #define ROUND_DOWN(x, align) ((x) & ~((align) - 1)) #define DIV_CEIL(lhs, rhs) (((lhs) + (rhs) - 1) / (rhs)) diff --git a/kernel/console/tty.c b/kernel/console/tty.c index 4f31519b..50374f82 100644 --- a/kernel/console/tty.c +++ b/kernel/console/tty.c @@ -1,4 +1,5 @@ #include "private.h" +#include #include #include #include @@ -9,20 +10,23 @@ #include #include +static struct tty* tty_from_file(struct file* file) { + return CONTAINER_OF(file->inode, struct tty, inode); +} + static bool can_read(struct tty* tty) { return !ring_buf_is_empty(&tty->input_buf); } static bool unblock_read(struct file* file) { - struct tty* tty = (struct tty*)file->inode; - return can_read(tty); + return can_read(tty_from_file(file)); } static ssize_t tty_pread(struct file* file, void* buf, size_t count, uint64_t offset) { (void)offset; - struct tty* tty = (struct tty*)file->inode; + struct tty* tty = tty_from_file(file); for (;;) { int rc = file_block(file, unblock_read, 0); @@ -81,7 +85,7 @@ static void processed_echo(struct tty* tty, const char* buf, size_t count) { static ssize_t tty_pwrite(struct file* file, const void* buf, size_t count, uint64_t offset) { (void)offset; - struct tty* tty = (struct tty*)file->inode; + struct tty* tty = tty_from_file(file); spinlock_lock(&tty->lock); processed_echo(tty, buf, count); spinlock_unlock(&tty->lock); @@ -89,7 +93,7 @@ static ssize_t tty_pwrite(struct file* file, const void* buf, size_t count, } static int tty_ioctl(struct file* file, int request, void* user_argp) { - struct tty* tty = (struct tty*)file->inode; + struct tty* tty = tty_from_file(file); struct termios* termios = &tty->termios; int ret = 0; spinlock_lock(&tty->lock); @@ -157,9 +161,9 @@ static int tty_ioctl(struct file* file, int request, void* user_argp) { } static short tty_poll(struct file* file, short events) { - struct tty* tty = (struct tty*)file->inode; short revents = 0; if (events & POLLIN) { + struct tty* tty = tty_from_file(file); spinlock_lock(&tty->lock); if (can_read(tty)) revents |= POLLIN; diff --git a/kernel/fs/fifo.c b/kernel/fs/fifo.c index 335652f5..3a1d222f 100644 --- a/kernel/fs/fifo.c +++ b/kernel/fs/fifo.c @@ -15,8 +15,16 @@ struct fifo { atomic_size_t num_writers; }; +static struct fifo* fifo_from_inode(struct inode* inode) { + return CONTAINER_OF(inode, struct fifo, inode); +} + +static struct fifo* fifo_from_file(struct file* file) { + return fifo_from_inode(file->inode); +} + static void fifo_destroy_inode(struct inode* inode) { - struct fifo* fifo = (struct fifo*)inode; + struct fifo* fifo = fifo_from_inode(inode); ring_buf_destroy(&fifo->buf); kfree(fifo); } @@ -36,7 +44,7 @@ static bool unblock_open(struct file* file) { static int fifo_open(struct file* file, mode_t mode) { (void)mode; - struct fifo* fifo = (struct fifo*)file->inode; + struct fifo* fifo = fifo_from_file(file); switch (file->flags & O_ACCMODE) { case O_RDONLY: ++fifo->num_readers; @@ -61,7 +69,7 @@ static int fifo_open(struct file* file, mode_t mode) { } static int fifo_close(struct file* file) { - struct fifo* fifo = (struct fifo*)file->inode; + struct fifo* fifo = fifo_from_file(file); switch (file->flags & O_ACCMODE) { case O_RDONLY: --fifo->num_readers; @@ -76,7 +84,7 @@ static int fifo_close(struct file* file) { } static bool unblock_read(struct file* file) { - const struct fifo* fifo = (const struct fifo*)file->inode; + const struct fifo* fifo = fifo_from_file(file); return fifo->num_writers == 0 || !ring_buf_is_empty(&fifo->buf); } @@ -84,7 +92,7 @@ static ssize_t fifo_pread(struct file* file, void* buffer, size_t count, uint64_t offset) { (void)offset; - struct fifo* fifo = (struct fifo*)file->inode; + struct fifo* fifo = fifo_from_file(file); struct ring_buf* buf = &fifo->buf; for (;;) { @@ -115,7 +123,7 @@ static ssize_t fifo_pwrite(struct file* file, const void* buffer, size_t count, uint64_t offset) { (void)offset; - struct fifo* fifo = (struct fifo*)file->inode; + struct fifo* fifo = fifo_from_file(file); struct ring_buf* buf = &fifo->buf; for (;;) { @@ -190,5 +198,5 @@ struct inode* fifo_create(void) { inode->mode = S_IFIFO; inode->ref_count = 1; - return (struct inode*)fifo; + return inode; } diff --git a/kernel/fs/fs.c b/kernel/fs/fs.c index 609e3677..dd156c3c 100644 --- a/kernel/fs/fs.c +++ b/kernel/fs/fs.c @@ -9,6 +9,7 @@ #include #include #include +#include void inode_ref(struct inode* inode) { ASSERT(inode); @@ -27,7 +28,7 @@ void inode_destroy(struct inode* inode) { ASSERT(inode->ref_count == 0 && inode->num_links == 0); ASSERT(inode->fops->destroy_inode); inode_unref(inode->fifo); - inode_unref((struct inode*)inode->bound_socket); + inode_unref(&inode->bound_socket->inode); inode->fops->destroy_inode(inode); } diff --git a/kernel/fs/proc/pid.c b/kernel/fs/proc/pid.c index 7b0f8796..1a125c68 100644 --- a/kernel/fs/proc/pid.c +++ b/kernel/fs/proc/pid.c @@ -159,7 +159,7 @@ static int add_item(proc_dir_inode* parent, const proc_item_def* item_def, pid_t pid) { proc_pid_item_inode* node = kmalloc(sizeof(proc_pid_item_inode)); if (!node) { - inode_unref((struct inode*)parent); + inode_unref(&parent->inode); return -ENOMEM; } *node = (proc_pid_item_inode){0}; @@ -174,7 +174,7 @@ static int add_item(proc_dir_inode* parent, const proc_item_def* item_def, inode->ref_count = 1; int rc = dentry_append(&parent->children, item_def->name, inode); - inode_unref((struct inode*)parent); + inode_unref(&parent->inode); return rc; } @@ -215,6 +215,6 @@ struct inode* proc_pid_dir_inode_create(proc_dir_inode* parent, pid_t pid) { return ERR_PTR(rc); } - inode_unref((struct inode*)parent); + inode_unref(&parent->inode); return inode; } diff --git a/kernel/fs/proc/root.c b/kernel/fs/proc/root.c index 148c72f9..874fe43e 100644 --- a/kernel/fs/proc/root.c +++ b/kernel/fs/proc/root.c @@ -218,7 +218,7 @@ static int proc_root_getdents(struct file* file, getdents_callback_fn callback, static int add_item(proc_dir_inode* parent, const proc_item_def* item_def) { proc_item_inode* node = kmalloc(sizeof(proc_item_inode)); if (!node) { - inode_unref((struct inode*)parent); + inode_unref(&parent->inode); return -ENOMEM; } *node = (proc_item_inode){0}; @@ -232,7 +232,7 @@ static int add_item(proc_dir_inode* parent, const proc_item_def* item_def) { inode->ref_count = 1; int rc = dentry_append(&parent->children, item_def->name, inode); - inode_unref((struct inode*)parent); + inode_unref(&parent->inode); return rc; } diff --git a/kernel/socket.h b/kernel/socket.h index b0a662be..89642216 100644 --- a/kernel/socket.h +++ b/kernel/socket.h @@ -35,3 +35,11 @@ NODISCARD int unix_socket_listen(struct unix_socket*, int backlog); NODISCARD struct unix_socket* unix_socket_accept(struct file*); NODISCARD int unix_socket_connect(struct file*, struct inode* addr_inode); NODISCARD int unix_socket_shutdown(struct file*, int how); + +static inline struct unix_socket* unix_socket_from_inode(struct inode* inode) { + return CONTAINER_OF(inode, struct unix_socket, inode); +} + +static inline struct unix_socket* unix_socket_from_file(struct file* file) { + return unix_socket_from_inode(file->inode); +} diff --git a/kernel/syscall/socket.c b/kernel/syscall/socket.c index aeb798bb..f8da6a13 100644 --- a/kernel/syscall/socket.c +++ b/kernel/syscall/socket.c @@ -16,7 +16,7 @@ int sys_socket(int domain, int type, int protocol) { struct unix_socket* socket = unix_socket_create(); if (IS_ERR(socket)) return PTR_ERR(socket); - struct file* file = inode_open((struct inode*)socket, O_RDWR, 0); + struct file* file = inode_open(&socket->inode, O_RDWR, 0); if (IS_ERR(file)) return PTR_ERR(file); int fd = task_alloc_file_descriptor(-1, file); @@ -31,7 +31,7 @@ int sys_bind(int sockfd, const struct sockaddr* user_addr, socklen_t addrlen) { return PTR_ERR(file); if (!S_ISSOCK(file->inode->mode)) return -ENOTSOCK; - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); if (addrlen <= sizeof(sa_family_t) || sizeof(struct sockaddr_un) < addrlen) return -EINVAL; @@ -69,7 +69,7 @@ int sys_listen(int sockfd, int backlog) { if (!S_ISSOCK(file->inode->mode)) return -ENOTSOCK; - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); return unix_socket_listen(socket, backlog); } @@ -101,8 +101,7 @@ int sys_accept4(int sockfd, struct sockaddr* user_addr, socklen_t* user_addrlen, struct unix_socket* connector = unix_socket_accept(file); if (IS_ERR(connector)) return PTR_ERR(connector); - struct file* connector_file = - inode_open((struct inode*)connector, O_RDWR, 0); + struct file* connector_file = inode_open(&connector->inode, O_RDWR, 0); if (IS_ERR(connector_file)) return PTR_ERR(connector_file); diff --git a/kernel/unix_socket.c b/kernel/unix_socket.c index 1c8f1e93..ccfbd30e 100644 --- a/kernel/unix_socket.c +++ b/kernel/unix_socket.c @@ -7,38 +7,37 @@ #include "task.h" static void unix_socket_destroy_inode(struct inode* inode) { - struct unix_socket* socket = (struct unix_socket*)inode; + struct unix_socket* socket = unix_socket_from_inode(inode); ring_buf_destroy(&socket->to_connector_buf); ring_buf_destroy(&socket->to_acceptor_buf); kfree(socket); } static int unix_socket_close(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); socket->is_open_for_writing_to_connector = false; socket->is_open_for_writing_to_acceptor = false; return 0; } static bool is_connector(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; - return socket->connector_file == file; + return unix_socket_from_file(file)->connector_file == file; } static bool is_open_for_reading(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); return is_connector(file) ? socket->is_open_for_writing_to_connector : socket->is_open_for_writing_to_acceptor; } static struct ring_buf* buf_to_read(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); return is_connector(file) ? &socket->to_connector_buf : &socket->to_acceptor_buf; } static struct ring_buf* buf_to_write(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); return is_connector(file) ? &socket->to_acceptor_buf : &socket->to_connector_buf; } @@ -54,7 +53,7 @@ static ssize_t unix_socket_pread(struct file* file, void* buffer, size_t count, uint64_t offset) { (void)offset; - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); if (!socket->is_connected) return -EINVAL; @@ -78,7 +77,7 @@ static ssize_t unix_socket_pread(struct file* file, void* buffer, size_t count, } static bool is_writable(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); if (is_connector(file)) { if (!socket->is_open_for_writing_to_acceptor) return false; @@ -93,7 +92,7 @@ static ssize_t unix_socket_pwrite(struct file* file, const void* buffer, size_t count, uint64_t offset) { (void)offset; - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); if (!socket->is_connected) return -ENOTCONN; @@ -121,7 +120,7 @@ static ssize_t unix_socket_pwrite(struct file* file, const void* buffer, } static short unix_socket_poll(struct file* file, short events) { - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); short revents = 0; if (events & POLLIN) { bool can_read = @@ -211,15 +210,14 @@ int unix_socket_listen(struct unix_socket* socket, int backlog) { } static bool is_acceptable(struct file* file) { - struct unix_socket* socket = (struct unix_socket*)file->inode; - return socket->num_pending > 0; + return unix_socket_from_file(file)->num_pending > 0; } struct unix_socket* unix_socket_accept(struct file* file) { if (!S_ISSOCK(file->inode->mode)) return ERR_PTR(-ENOTSOCK); - struct unix_socket* listener = (struct unix_socket*)file->inode; + struct unix_socket* listener = unix_socket_from_file(file); mutex_lock(&listener->lock); bool is_listening = listener->state == SOCKET_STATE_LISTENING; @@ -255,8 +253,7 @@ struct unix_socket* unix_socket_accept(struct file* file) { } static bool is_connectable(struct file* file) { - struct unix_socket* connector = (struct unix_socket*)file->inode; - return connector->is_connected; + return unix_socket_from_file(file)->is_connected; } int unix_socket_connect(struct file* file, struct inode* addr_inode) { @@ -267,7 +264,7 @@ int unix_socket_connect(struct file* file, struct inode* addr_inode) { if (!listener) return -ECONNREFUSED; - struct unix_socket* connector = (struct unix_socket*)file->inode; + struct unix_socket* connector = unix_socket_from_file(file); mutex_lock(&connector->lock); switch (connector->state) { @@ -297,7 +294,7 @@ int unix_socket_connect(struct file* file, struct inode* addr_inode) { connector->state = SOCKET_STATE_PENDING; connector->next = NULL; - inode_ref((struct inode*)connector); + inode_ref(&connector->inode); if (listener->next) { struct unix_socket* it = listener->next; @@ -330,7 +327,7 @@ int unix_socket_shutdown(struct file* file, int how) { bool shut_read = how == SHUT_RD || how == SHUT_RDWR; bool shut_write = how == SHUT_WR || how == SHUT_RDWR; bool conn = is_connector(file); - struct unix_socket* socket = (struct unix_socket*)file->inode; + struct unix_socket* socket = unix_socket_from_file(file); if ((conn && shut_read) || (!conn && shut_write)) socket->is_open_for_writing_to_connector = false; if ((conn && shut_write) || (!conn && shut_read))