/*
 * Copyright (c) Atmosphère-NX
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see .
 */
#include 
#include "htcs_session.hpp"
#include "htcs_virtual_socket_collection.hpp"
namespace ams::htcs::client {
    namespace {
        constexpr inline s32 InvalidSocket = -1;
        constexpr inline s32 InvalidPrimitive = -1;
    }
    /* Declare client functions. */
    sf::SharedPointer socket(s32 &last_error);
    s32 close(sf::SharedPointer socket, s32 &last_error);
    s32 bind(sf::SharedPointer socket, const htcs::SockAddrHtcs *address, s32 &last_error);
    s32 listen(sf::SharedPointer socket, s32 backlog_count, s32 &last_error);
    sf::SharedPointer accept(sf::SharedPointer socket, htcs::SockAddrHtcs *address, s32 &last_error);
    s32 fcntl(sf::SharedPointer socket, s32 command, s32 value, s32 &last_error);
    s32 shutdown(sf::SharedPointer socket, s32 how, s32 &last_error);
    ssize_t recv(sf::SharedPointer socket, void *buffer, size_t buffer_size, s32 flags, s32 &last_error);
    ssize_t send(sf::SharedPointer socket, const void *buffer, size_t buffer_size, s32 flags, s32 &last_error);
    s32 connect(sf::SharedPointer socket, const htcs::SockAddrHtcs *address, s32 &last_error);
    s32 select(s32 * const read, s32 &num_read, s32 * const write, s32 &num_write, s32 * const except, s32 &num_except, htcs::TimeVal *timeout, s32 &last_error);
    struct VirtualSocket {
        s32 m_id;
        s32 m_primitive;
        sf::SharedPointer m_socket;
        bool m_do_bind;
        htcs::SockAddrHtcs m_address;
        s32 m_listen_backlog_count;
        s32 m_fcntl_command;
        s32 m_fcntl_value;
        bool m_blocking;
        VirtualSocket() {
            /* Initialize. */
            this->Init();
        }
        ~VirtualSocket() { /* ... */ }
        void Init() {
            /* Setup fields. */
            m_id        = InvalidSocket;
            m_primitive = InvalidPrimitive;
            m_socket    = nullptr;
            m_blocking  = true;
            m_do_bind   = false;
            std::memset(std::addressof(m_address), 0, sizeof(m_address));
            m_listen_backlog_count = -1;
            m_fcntl_command = -1;
            m_fcntl_value   = 0;
        }
        s32 Bind(const htcs::SockAddrHtcs *address, s32 &last_error) {
            /* Mark the bind. */
            m_do_bind = true;
            /* Set our address. */
            std::memcpy(std::addressof(m_address), address, sizeof(m_address));
            /* Clear the error. */
            last_error = 0;
            return 0;
        }
        s32 Listen(s32 backlog_count, s32 &last_error) {
            s32 res = -1;
            if (m_do_bind) {
                /* Set backlog count. */
                m_listen_backlog_count = std::max(backlog_count, 1);
                /* Clear error. */
                last_error = 0;
                res        = 0;
            } else {
                last_error = HTCS_EINVAL;
            }
            return res;
        }
        s32 Fcntl(s32 command, s32 value, s32 &last_error) {
            /* Clear error. */
            s32 res    = 0;
            last_error = 0;
            if (command == HTCS_F_SETFL) {
                m_fcntl_command = command;
                m_fcntl_value   = value;
                m_blocking = (value & HTCS_O_NONBLOCK) == 0;
            } else if (command == HTCS_F_GETFL) {
                res = m_fcntl_value;
            } else {
                last_error = HTCS_EINVAL;
                res        = -1;
            }
            return res;
        }
        s32 SetSocket(sf::SharedPointer socket, s32 &last_error) {
            s32 res = 0;
            if (m_socket == nullptr && socket != nullptr) {
                /* Set our socket. */
                m_socket = socket;
                /* Bind, fcntl, and listen, since those may have been deferred. */
                if (m_do_bind) {
                    res = bind(m_socket, std::addressof(m_address), last_error);
                }
                if (res == 0 && m_fcntl_command != -1) {
                    res = fcntl(m_socket, m_fcntl_command, m_fcntl_value, last_error);
                }
                if (res == 0 && m_listen_backlog_count > 0) {
                    res = listen(m_socket, m_listen_backlog_count, last_error);
                }
            }
            return res;
        }
    };
    VirtualSocketCollection::VirtualSocketCollection()
        : m_socket_list(nullptr),
          m_list_count(0),
          m_list_size(0),
          m_next_id(1),
          m_mutex()
    {
        /* ... */
    }
    VirtualSocketCollection::~VirtualSocketCollection() {
        /* Clear ourselves. */
        this->Clear();
        /* Destroy all sockets in our list. */
        for (auto i = 0; i < m_list_size; ++i) {
            std::destroy_at(m_socket_list + i);
        }
        /* Clear the backing memory for our socket list. */
        std::memset(m_buffer, 0, sizeof(VirtualSocket) * m_list_size);
    }
    size_t VirtualSocketCollection::GetWorkingMemorySize(int num_sockets) {
        AMS_ASSERT(num_sockets < htcs::SocketCountMax);
        return num_sockets * sizeof(VirtualSocket);
    }
    void VirtualSocketCollection::Init(void *buffer, size_t buffer_size) {
        /* Set our buffer. */
        m_buffer      = buffer;
        m_buffer_size = buffer_size;
        /* Configure our list. */
        m_list_size = static_cast(m_buffer_size / sizeof(VirtualSocket));
        m_list_size = std::min(m_list_size, htcs::SocketCountMax);
        /* Initialize our list. */
        m_socket_list = static_cast(m_buffer);
        for (auto i = 0; i < m_list_size; ++i) {
            std::construct_at(m_socket_list + i);
        }
    }
    void VirtualSocketCollection::Clear() {
        /* Lock ourselves. */
        std::scoped_lock lk(m_mutex);
        /* Clear our list. */
        m_list_count = 0;
    }
    s32 VirtualSocketCollection::Socket(s32 &error_code) {
        /* Clear error code. */
        error_code = 0;
        /* Create the socket. */
        return this->CreateSocket(sf::SharedPointer{nullptr}, error_code);
    }
    s32 VirtualSocketCollection::Close(s32 id, s32 &error_code) {
        /* Clear error code. */
        error_code = 0;
        /* Prepare to find the socket. */
        s32 res = 0;
        sf::SharedPointer socket = nullptr;
        /* Find the socket. */
        s32 index;
        {
            std::scoped_lock lk(m_mutex);
            if (index = this->Find(id, std::addressof(error_code)); index >= 0) {
                /* Get the socket's object. */
                VirtualSocket *virt_socket = m_socket_list + index;
                socket = virt_socket->m_socket;
                /* Move the list. */
                for (/* ... */; index < m_list_count - 1; ++index) {
                    m_socket_list[index] = m_socket_list[index + 1];
                }
                /* Clear the now unused last list entry. */
                m_socket_list[index].Init();
                /* Decrement our list count. */
                --m_list_count;
            }
        }
        /* If we found the socket, close it. */
        if (socket != nullptr) {
            close(socket, error_code);
            /* Clear the error code. */
            res        = 0;
            error_code = 0;
        }
        return index >= 0 ? res : -1;
    }
    s32 VirtualSocketCollection::Bind(s32 id, const htcs::SockAddrHtcs *address, s32 &error_code) {
        /* Setup result/error code. */
        s32 res    = -1;
        error_code = 0;
        /* Get the socket. */
        sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code));
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            res = bind(socket, address, error_code);
        } else if (error_code != HTCS_EBADF) {
            /* Lock ourselves. */
            std::scoped_lock lk(m_mutex);
            /* Check if the socket is already bound. */
            if (const auto index = this->Find(id); index >= 0) {
                const auto exists = this->HasAddr(address);
                if (m_socket_list[index].m_do_bind) {
                    error_code = HTCS_EINVAL;
                    res        = -1;
                } else if (exists) {
                    error_code = HTCS_EADDRINUSE;
                    res        = -1;
                } else {
                    res = m_socket_list[index].Bind(address, error_code);
                }
            } else {
                error_code = HTCS_EBADF;
            }
        }
        return res;
    }
    s32 VirtualSocketCollection::Listen(s32 id, s32 backlog_count, s32 &error_code) {
        /* Setup result/error code. */
        s32 res    = -1;
        error_code = 0;
        /* Get the socket. */
        sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code));
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            res = listen(socket, backlog_count, error_code);
        } else if (error_code != HTCS_EBADF) {
            /* Lock ourselves. */
            std::scoped_lock lk(m_mutex);
            /* Try to listen on the virtual socket. */
            if (const auto index = this->Find(id); index >= 0) {
                res = m_socket_list[index].Listen(backlog_count, error_code);
            }
        }
        return res;
    }
    s32 VirtualSocketCollection::Accept(s32 id, htcs::SockAddrHtcs *address, s32 &error_code) {
        /* Setup result/error code. */
        s32 res    = -1;
        error_code = 0;
        /* Declare socket that we're creating. */
        sf::SharedPointer new_socket = nullptr;
        /* Get the socket. */
        sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code));
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            if (error_code != HTCS_ENONE) {
                return -1;
            }
            new_socket = this->DoAccept(socket, id, address, error_code);
        } else if (error_code != HTCS_EBADF) {
            /* Fetch the socket. */
            socket = this->FetchSocket(id, error_code);
            /* Wait for the socket. */
            while (socket == nullptr) {
                /* Determine whether we should block/listen. */
                bool block_until_done = true;
                bool listened         = false;
                s32 index;
                {
                    std::scoped_lock lk(m_mutex);
                    if (index = this->Find(id, std::addressof(error_code)); index >= 0) {
                        block_until_done = m_socket_list[index].m_blocking;
                        listened         = m_socket_list[index].m_listen_backlog_count > 0;
                    }
                }
                /* Check that the socket exists. */
                if (index < 0) {
                    error_code = HTCS_EINTR;
                    return -1;
                }
                /* Check that the socket has been listened. */
                if (!listened) {
                    error_code = HTCS_EINVAL;
                    return -1;
                }
                /* Check that we should block. */
                if (!block_until_done) {
                    error_code = HTCS_EWOULDBLOCK;
                    return -1;
                }
                /* Wait before trying again. */
                os::SleepThread(TimeSpan::FromMilliSeconds(500));
                /* Fetch the potentially updated socket. */
                socket = this->FetchSocket(id, error_code);
            }
            /* Check that we haven't errored. */
            if (error_code != HTCS_ENONE) {
                return -1;
            }
            /* Do the accept. */
            new_socket = this->DoAccept(socket, id, address, error_code);
        }
        /* If we have a new socket, register it. */
        if (new_socket != 0) {
            res = this->CreateSocket(new_socket, error_code);
            if (res < 0) {
                s32 tmp_error_code;
                close(new_socket, tmp_error_code);
            }
        }
        return res;
    }
    s32 VirtualSocketCollection::Fcntl(s32 id, s32 command, s32 value, s32 &error_code) {
        /* Setup result/error code. */
        s32 res    = -1;
        error_code = 0;
        /* Get the socket. */
        sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code));
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            res = fcntl(socket, command, value, error_code);
        } else if (error_code != HTCS_EBADF) {
            /* Lock ourselves. */
            std::scoped_lock lk(m_mutex);
            /* Try to listen on the virtual socket. */
            if (const auto index = this->Find(id); index >= 0) {
                res = m_socket_list[index].Fcntl(command, value, error_code);
            }
        }
        return res;
    }
    s32 VirtualSocketCollection::Shutdown(s32 id, s32 how, s32 &error_code) {
        /* Setup result/error code. */
        s32 res    = -1;
        error_code = 0;
        /* Get the socket. */
        sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code));
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            res = shutdown(socket, how, error_code);
        } else if (error_code != HTCS_EBADF) {
            error_code = HTCS_ENOTCONN;
        }
        return res;
    }
    ssize_t VirtualSocketCollection::Recv(s32 id, void *buffer, size_t buffer_size, s32 flags, s32 &error_code) {
        /* Setup result/error code. */
        ssize_t res = -1;
        error_code  = 0;
        /* Fetch the socket. */
        sf::SharedPointer socket = this->FetchSocket(id, error_code);
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            if (error_code != HTCS_ENONE) {
                return -1;
            }
            res = recv(socket, buffer, buffer_size, flags, error_code);
        } else if (error_code != HTCS_EBADF) {
            error_code = HTCS_ENOTCONN;
        }
        return res;
    }
    ssize_t VirtualSocketCollection::Send(s32 id, const void *buffer, size_t buffer_size, s32 flags, s32 &error_code) {
        /* Setup result/error code. */
        ssize_t res = -1;
        error_code  = 0;
        /* Fetch the socket. */
        sf::SharedPointer socket = this->FetchSocket(id, error_code);
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            if (error_code != HTCS_ENONE) {
                return -1;
            }
            res = send(socket, buffer, buffer_size, flags, error_code);
        } else if (error_code != HTCS_EBADF) {
            error_code = HTCS_ENOTCONN;
        }
        return res;
    }
    s32 VirtualSocketCollection::Connect(s32 id, const htcs::SockAddrHtcs *address, s32 &error_code) {
        /* Setup result/error code. */
        s32 res    = -1;
        error_code = 0;
        /* Fetch the socket. */
        sf::SharedPointer socket = this->FetchSocket(id, error_code);
        /* If we found the socket, bind. */
        if (socket != nullptr) {
            if (error_code != HTCS_ENONE) {
                return -1;
            }
            res = connect(socket, address, error_code);
        } else if (error_code != HTCS_EBADF) {
            error_code = HTCS_EADDRNOTAVAIL;
        }
        return res;
    }
    s32 VirtualSocketCollection::Select(htcs::FdSet *read, htcs::FdSet *write, htcs::FdSet *except, htcs::TimeVal *timeout, s32 &error_code) {
        /* Setup result/error code. */
        s32 res            = -1;
        s32 tmp_error_code = 0;
        /* Declare buffers. */
        s32 read_primitives[SocketCountMax];
        s32 write_primitives[SocketCountMax];
        s32 except_primitives[SocketCountMax];
        /* Get reads. */
        s32 num_read = this->GetSockets(read_primitives, read, tmp_error_code);
        if (tmp_error_code != HTCS_ENONE) {
            error_code = tmp_error_code;
            return res;
        }
        /* Get writes. */
        s32 num_write = this->GetSockets(write_primitives, write, tmp_error_code);
        if (tmp_error_code != HTCS_ENONE) {
            error_code = tmp_error_code;
            return res;
        }
        /* Get excepts. */
        s32 num_except = this->GetSockets(except_primitives, except, tmp_error_code);
        if (tmp_error_code != HTCS_ENONE) {
            error_code = tmp_error_code;
            return res;
        }
        /* Perform the select. */
        if (num_read + num_write + num_except > 0) {
            res = select(read_primitives, num_read, write_primitives, num_write, except_primitives, num_except, timeout, error_code);
            /* Set the socket primitives. */
            this->SetSockets(read, read_primitives, num_read);
            this->SetSockets(write, write_primitives, num_write);
            this->SetSockets(except, except_primitives, num_except);
        } else {
            error_code = HTCS_EINVAL;
        }
        return res;
    }
    s32 VirtualSocketCollection::CreateId() {
        /* Lock ourselves. */
        std::scoped_lock lk(m_mutex);
        /* Get a free id. */
        s32 res = 0;
        do {
            res = m_next_id++;
            if (m_next_id <= 0) {
                m_next_id = 1;
            }
        } while (this->Find(res) >= 0);
        return res;
    }
    s32 VirtualSocketCollection::Add(sf::SharedPointer socket) {
        /* Check that the socket isn't null. */
        if (socket == nullptr) {
            return -1;
        }
        /* Create the socket. */
        s32 error_code;
        return this->CreateSocket(socket, error_code);
    }
    void VirtualSocketCollection::Insert(s32 id, sf::SharedPointer socket) {
        /* Lock ourselves. */
        std::scoped_lock lk(m_mutex);
        /* Add the socket to the list. */
        if (m_list_count != 0) {
            /* Ensure the list remains in sorted order. */
            s32 index;
            for (index = m_list_count - 1; index >= 0; --index) {
                if (m_socket_list[index].m_id < id) {
                    break;
                }
                m_socket_list[index + 1] = m_socket_list[index];
            }
            /* Set the socket in the list. */
            m_socket_list[index + 1].m_id     = id;
            m_socket_list[index + 1].m_socket = socket;
        } else {
            /* Set the socket in the list. */
            m_socket_list[0].m_id     = id;
            m_socket_list[0].m_socket = socket;
        }
        /* Increment our count. */
        ++m_list_count;
    }
    void VirtualSocketCollection::SetSize(s32 size) {
        AMS_UNUSED(size);
    }
    s32 VirtualSocketCollection::Find(s32 id, s32 *error_code) {
        /* Perform a binary search to find the socket. */
        if (m_list_count > 0) {
            s32 left  = 0;
            s32 right = m_list_count - 1;
            while (left <= right) {
                const s32 mid = (left + right) / 2;
                if (m_socket_list[mid].m_id == id) {
                    return mid;
                } else if (m_socket_list[mid].m_id > id) {
                    right = mid - 1;
                } else /* if (m_socket_list[mid].m_id < id) */ {
                    left  = mid + 1;
                }
            }
        }
        /* We failed to find the socket. */
        if (error_code != nullptr) {
            *error_code = HTCS_EBADF;
        }
        return InvalidSocket;
    }
    s32 VirtualSocketCollection::FindByPrimitive(s32 primitive) {
        /* Find a socket with the desired primitive. */
        for (auto i = 0; i < m_list_size; ++i) {
            if (m_socket_list[i].m_primitive == primitive) {
                return i;
            }
        }
        return InvalidPrimitive;
    }
    bool VirtualSocketCollection::HasAddr(const htcs::SockAddrHtcs *address) {
        /* Try to find a matching socket. */
        for (auto i = 0; i < m_list_count; ++i) {
            if (m_socket_list[i].m_address.family == address->family &&
                std::strcmp(m_socket_list[i].m_address.peer_name.name, address->peer_name.name) == 0 &&
                std::strcmp(m_socket_list[i].m_address.port_name.name, address->port_name.name) == 0)
            {
                return true;
            }
        }
        return false;
    }
    sf::SharedPointer VirtualSocketCollection::GetSocket(s32 id, s32 *error_code) {
        sf::SharedPointer res = nullptr;
        /* Get the socket. */
        {
            std::scoped_lock lk(m_mutex);
            if (const auto index = this->Find(id, error_code); index >= 0) {
                res = m_socket_list[index].m_socket;
            }
        }
        return res;
    }
    sf::SharedPointer VirtualSocketCollection::FetchSocket(s32 id, s32 &error_code) {
        /* Clear the error code. */
        error_code = 0;
        /* Get the socket. */
        auto socket = this->GetSocket(id, std::addressof(error_code));
        if (socket == nullptr && error_code == HTCS_ENONE) {
            socket = this->RealizeSocket(id);
        }
        return socket;
    }
    sf::SharedPointer VirtualSocketCollection::RealizeSocket(s32 id) {
        /* Clear the error code. */
        s32 error_code = 0;
        /* Get socket. */
        sf::SharedPointer res = socket(error_code);
        if (res != nullptr) {
            /* Assign the new socket. */
            s32 index;
            {
                std::scoped_lock lk(m_mutex);
                index = this->Find(id, std::addressof(error_code));
                if (index >= 0) {
                    m_socket_list[index].SetSocket(res, error_code);
                }
            }
            /* If the socket was deleted, close it. */
            if (index < 0) {
                s32 temp_error = 0;
                close(res, temp_error);
                res = nullptr;
            }
        }
        return res;
    }
    sf::SharedPointer VirtualSocketCollection::DoAccept(sf::SharedPointer socket, s32 id, htcs::SockAddrHtcs *address, s32 &error_code) {
        /* Clear the error code. */
        error_code = 0;
        /* Try to accept. */
        sf::SharedPointer new_socket = accept(socket, address, error_code);
        if (error_code == HTCS_ENETDOWN) {
            new_socket = accept(socket, address, error_code);
            std::scoped_lock lk(m_mutex);
            if (const auto index = this->Find(id, std::addressof(error_code)); index >= 0) {
                m_socket_list[index].m_socket = nullptr;
            }
        }
        return new_socket;
    }
    s32 VirtualSocketCollection::GetSockets(s32 * const out_primitives, htcs::FdSet *set, s32 &error_code) {
        /* Clear the error code. */
        error_code = 0;
        s32 count  = 0;
        /* Walk the fdset. */
        if (set != nullptr) {
            for (auto i = 0; i < FdSetSize; ++i) {
                /* If the set no longer has fds, we're done. */
                if (set->fds[i] == 0) {
                    break;
                }
                /* Find the fd's primitive. */
                s32 primitive = InvalidPrimitive;
                s32 index;
                {
                    std::scoped_lock lk(m_mutex);
                    if (index = this->Find(set->fds[i], std::addressof(error_code)); index >= 0) {
                        /* Get the primitive, if necessary. */
                        if (m_socket_list[index].m_primitive == InvalidPrimitive && m_socket_list[index].m_socket != nullptr) {
                            m_socket_list[index].m_socket->GetPrimitive(std::addressof(m_socket_list[index].m_primitive));
                        }
                        primitive = m_socket_list[index].m_primitive;
                    }
                }
                /* Check that an error didn't occur. */
                if (error_code != HTCS_ENONE) {
                    return 0;
                }
                /* If the primitive is invalid, try to realize the socket. */
                if (primitive == InvalidPrimitive) {
                    if (this->RealizeSocket(set->fds[i]) != nullptr) {
                        std::scoped_lock lk(m_mutex);
                        /* Get the primitive. */
                        if (index = this->Find(set->fds[i], std::addressof(error_code)); index >= 0) {
                            m_socket_list[index].m_socket->GetPrimitive(std::addressof(m_socket_list[index].m_primitive));
                            primitive = m_socket_list[index].m_primitive;
                        }
                    }
                    /* Check that an error didn't occur. */
                    if (error_code != HTCS_ENONE) {
                        return 0;
                    }
                }
                /* Set the output primitive. */
                if (primitive != InvalidPrimitive) {
                    out_primitives[count++] = primitive;
                }
            }
        }
        return count;
    }
    void VirtualSocketCollection::SetSockets(htcs::FdSet *set, s32 * const primitives, s32 count) {
        if (set != nullptr) {
            /* Clear the set. */
            FdSetZero(set);
            /* Copy the fds. */
            for (auto i = 0; i < count; ++i) {
                std::scoped_lock lk(m_mutex);
                if (const auto index = this->FindByPrimitive(primitives[i]); index >= 0) {
                    set->fds[i] = m_socket_list[index].m_id;
                }
            }
        }
    }
    s32 VirtualSocketCollection::CreateSocket(sf::SharedPointer socket, s32 &error_code) {
        /* Clear the error code. */
        error_code = 0;
        s32 id     = InvalidSocket;
        /* Check that we can add to the list. */
        if (m_list_count < m_list_size) {
            /* Create a new id. */
            id = this->CreateId();
            /* Insert the socket into the list. */
            this->Insert(id, socket);
        } else {
            if (socket != nullptr) {
                s32 tmp_error_code;
                close(socket, tmp_error_code);
            }
            error_code = HTCS_EMFILE;
        }
        return id;
    }
}