From d4f027c03bfc5f632b7d63c125dfa57a7ba89926 Mon Sep 17 00:00:00 2001 From: Rubidium Date: Sun, 18 Apr 2021 10:23:41 +0200 Subject: Codechange: encapsulate writing data from Packets into sockets/files/buffers to prevent packet state modifications outside of the Packet --- src/network/core/packet.h | 52 +++++++++++++++++++++++++++++++++++++++++ src/network/core/tcp.cpp | 6 ++--- src/network/core/tcp_listen.h | 4 ++-- src/network/core/udp.cpp | 2 +- src/network/network_client.cpp | 36 ++++++++++++++-------------- src/network/network_content.cpp | 16 +++++++++++-- 6 files changed, 90 insertions(+), 26 deletions(-) diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 3eee0522f..b091d8a7e 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -88,6 +88,58 @@ public: size_t RemainingBytesToTransfer() const; + /** + * Transfer data from the packet to the given function. It starts reading at the + * position the last transfer stopped. + * See Packet::TransferIn for more information about transferring data to functions. + * @param transfer_function The function to pass the buffer as second parameter and the + * amount to write as third parameter. It returns the amount that + * was written or -1 upon errors. + * @param limit The maximum amount of bytes to transfer. + * @param destination The first parameter of the transfer function. + * @param args The fourth and further parameters to the transfer function, if any. + * @return The return value of the transfer_function. + */ + template < + typename A = size_t, ///< The type for the amount to be passed, so it can be cast to the right type. + typename F, ///< The type of the function. + typename D, ///< The type of the destination. + typename ... Args> ///< The types of the remaining arguments to the function. + ssize_t TransferOutWithLimit(F transfer_function, size_t limit, D destination, Args&& ... args) + { + size_t amount = std::min(this->RemainingBytesToTransfer(), limit); + if (amount == 0) return 0; + + assert(this->pos < this->buffer.size()); + assert(this->pos + amount <= this->buffer.size()); + /* Making buffer a char means casting a lot in the Recv/Send functions. */ + const char *output_buffer = reinterpret_cast(this->buffer + this->pos); + ssize_t bytes = transfer_function(destination, output_buffer, static_cast(amount), std::forward(args)...); + if (bytes > 0) this->pos += bytes; + return bytes; + } + + /** + * Transfer data from the packet to the given function. It starts reading at the + * position the last transfer stopped. + * See Packet::TransferIn for more information about transferring data to functions. + * @param transfer_function The function to pass the buffer as second parameter and the + * amount to write as third parameter. It returns the amount that + * was written or -1 upon errors. + * @param destination The first parameter of the transfer function. + * @param args The fourth and further parameters to the transfer function, if any. + * @tparam A The type for the amount to be passed, so it can be cast to the right type. + * @tparam F The type of the transfer_function. + * @tparam D The type of the destination. + * @tparam Args The types of the remaining arguments to the function. + * @return The return value of the transfer_function. + */ + template + ssize_t TransferOut(F transfer_function, D destination, Args&& ... args) + { + return TransferOutWithLimit(transfer_function, std::numeric_limits::max(), destination, std::forward(args)...); + } + /** * Transfer data from the given function into the packet. It starts writing at the * position the last transfer stopped. diff --git a/src/network/core/tcp.cpp b/src/network/core/tcp.cpp index aa1e1cbed..ab18f47a8 100644 --- a/src/network/core/tcp.cpp +++ b/src/network/core/tcp.cpp @@ -103,7 +103,7 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down) p = this->packet_queue; while (p != nullptr) { - res = send(this->sock, (const char*)p->buffer + p->pos, p->size - p->pos, 0); + res = p->TransferOut(send, this->sock, 0); if (res == -1) { int err = GET_LAST_ERROR(); if (err != EWOULDBLOCK) { @@ -122,10 +122,8 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down) return SPS_CLOSED; } - p->pos += res; - /* Is this packet sent? */ - if (p->pos == p->size) { + if (p->RemainingBytesToTransfer() == 0) { /* Go to the next packet */ this->packet_queue = p->next; delete p; diff --git a/src/network/core/tcp_listen.h b/src/network/core/tcp_listen.h index 1f073aa73..53a3d57cc 100644 --- a/src/network/core/tcp_listen.h +++ b/src/network/core/tcp_listen.h @@ -63,7 +63,7 @@ public: DEBUG(net, 1, "[%s] Banned ip tried to join (%s), refused", Tsocket::GetName(), entry.c_str()); - if (send(s, (const char*)p.buffer, p.size, 0) < 0) { + if (p.TransferOut(send, s, 0) < 0) { DEBUG(net, 0, "send failed with error %d", GET_LAST_ERROR()); } closesocket(s); @@ -80,7 +80,7 @@ public: Packet p(Tfull_packet); p.PrepareToSend(); - if (send(s, (const char*)p.buffer, p.size, 0) < 0) { + if (p.TransferOut(send, s, 0) < 0) { DEBUG(net, 0, "send failed with error %d", GET_LAST_ERROR()); } closesocket(s); diff --git a/src/network/core/udp.cpp b/src/network/core/udp.cpp index 398f53142..8e476f4e2 100644 --- a/src/network/core/udp.cpp +++ b/src/network/core/udp.cpp @@ -99,7 +99,7 @@ void NetworkUDPSocketHandler::SendPacket(Packet *p, NetworkAddress *recv, bool a } /* Send the buffer */ - int res = sendto(s.second, (const char*)p->buffer, p->size, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength()); + ssize_t res = p->TransferOut(sendto, s.second, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength()); DEBUG(net, 7, "[udp] sendto(%s)", send.GetAddressAsString().c_str()); /* Check for any errors, but ignore it otherwise */ diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp index 10b4fd141..6156dc486 100644 --- a/src/network/network_client.cpp +++ b/src/network/network_client.cpp @@ -35,7 +35,6 @@ /* This file handles all the client-commands */ - /** Read some packets, and when do use that data as initial load filter. */ struct PacketReader : LoadFilter { static const size_t CHUNK = 32 * 1024; ///< 32 KiB chunks of memory. @@ -59,35 +58,38 @@ struct PacketReader : LoadFilter { } } + /** + * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut. + * @param destination The reader to add the data to. + * @param source The buffer to read data from. + * @param amount The number of bytes to copy. + * @return The number of bytes that were copied. + */ + static inline ssize_t TransferOutMemCopy(PacketReader *destination, const char *source, size_t amount) + { + memcpy(destination->buf, source, amount); + destination->buf += amount; + destination->written_bytes += amount; + return amount; + } + /** * Add a packet to this buffer. * @param p The packet to add. */ - void AddPacket(const Packet *p) + void AddPacket(Packet *p) { assert(this->read_bytes == 0); - - size_t in_packet = p->size - p->pos; - size_t to_write = std::min(this->bufe - this->buf, in_packet); - const byte *pbuf = p->buffer + p->pos; - - this->written_bytes += in_packet; - if (to_write != 0) { - memcpy(this->buf, pbuf, to_write); - this->buf += to_write; - } + p->TransferOutWithLimit(TransferOutMemCopy, this->bufe - this->buf, this); /* Did everything fit in the current chunk, then we're done. */ - if (to_write == in_packet) return; + if (p->RemainingBytesToTransfer() == 0) return; /* Allocate a new chunk and add the remaining data. */ - pbuf += to_write; - to_write = in_packet - to_write; this->blocks.push_back(this->buf = CallocT(CHUNK)); this->bufe = this->buf + CHUNK; - memcpy(this->buf, pbuf, to_write); - this->buf += to_write; + p->TransferOutWithLimit(TransferOutMemCopy, this->bufe - this->buf, this); } size_t Read(byte *rbuf, size_t size) override diff --git a/src/network/network_content.cpp b/src/network/network_content.cpp index 0220f890b..529225235 100644 --- a/src/network/network_content.cpp +++ b/src/network/network_content.cpp @@ -459,6 +459,18 @@ static bool GunzipFile(const ContentInfo *ci) #endif /* defined(WITH_ZLIB) */ } +/** + * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut. + * @param file The file to write data to. + * @param buffer The buffer to write to the file. + * @param amount The number of bytes to write. + * @return The number of bytes that were written. + */ +static inline ssize_t TransferOutFWrite(FILE *file, const char *buffer, size_t amount) +{ + return fwrite(buffer, 1, amount, file); +} + bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet *p) { if (this->curFile == nullptr) { @@ -476,8 +488,8 @@ bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet *p) } } else { /* We have a file opened, thus are downloading internal content */ - size_t toRead = (size_t)(p->size - p->pos); - if (fwrite(p->buffer + p->pos, 1, toRead, this->curFile) != toRead) { + size_t toRead = p->RemainingBytesToTransfer(); + if (toRead != 0 && (size_t)p->TransferOut(TransferOutFWrite, this->curFile) != toRead) { DeleteWindowById(WC_NETWORK_STATUS_WINDOW, WN_NETWORK_STATUS_WINDOW_CONTENT_DOWNLOAD); ShowErrorMessage(STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD, STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD_FILE_NOT_WRITABLE, WL_ERROR); this->Close(); -- cgit v1.2.3-70-g09d2