summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/network/core/packet.cpp41
-rw-r--r--src/network/core/packet.h3
-rw-r--r--src/network/network_admin.cpp2
-rw-r--r--src/network/network_client.cpp2
-rw-r--r--src/network/network_udp.cpp12
5 files changed, 37 insertions, 23 deletions
diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp
index 4eb0e929e..54f5a79e1 100644
--- a/src/network/core/packet.cpp
+++ b/src/network/core/packet.cpp
@@ -68,6 +68,16 @@ void Packet::PrepareToSend()
this->pos = 0; // We start reading from here
}
+/**
+ * Is it safe to write to the packet, i.e. didn't we run over the buffer?
+ * @param bytes_to_write The amount of bytes we want to try to write.
+ * @return True iff the given amount of bytes can be written to the packet.
+ */
+bool Packet::CanWriteToPacket(size_t bytes_to_write)
+{
+ return this->size + bytes_to_write < SEND_MTU;
+}
+
/*
* The next couple of functions make sure we can send
* uint8, uint16, uint32 and uint64 endian-safe
@@ -95,7 +105,7 @@ void Packet::Send_bool(bool data)
*/
void Packet::Send_uint8(uint8 data)
{
- assert(this->size < SEND_MTU - sizeof(data));
+ assert(this->CanWriteToPacket(sizeof(data)));
this->buffer[this->size++] = data;
}
@@ -105,7 +115,7 @@ void Packet::Send_uint8(uint8 data)
*/
void Packet::Send_uint16(uint16 data)
{
- assert(this->size < SEND_MTU - sizeof(data));
+ assert(this->CanWriteToPacket(sizeof(data)));
this->buffer[this->size++] = GB(data, 0, 8);
this->buffer[this->size++] = GB(data, 8, 8);
}
@@ -116,7 +126,7 @@ void Packet::Send_uint16(uint16 data)
*/
void Packet::Send_uint32(uint32 data)
{
- assert(this->size < SEND_MTU - sizeof(data));
+ assert(this->CanWriteToPacket(sizeof(data)));
this->buffer[this->size++] = GB(data, 0, 8);
this->buffer[this->size++] = GB(data, 8, 8);
this->buffer[this->size++] = GB(data, 16, 8);
@@ -129,7 +139,7 @@ void Packet::Send_uint32(uint32 data)
*/
void Packet::Send_uint64(uint64 data)
{
- assert(this->size < SEND_MTU - sizeof(data));
+ assert(this->CanWriteToPacket(sizeof(data)));
this->buffer[this->size++] = GB(data, 0, 8);
this->buffer[this->size++] = GB(data, 8, 8);
this->buffer[this->size++] = GB(data, 16, 8);
@@ -148,8 +158,8 @@ void Packet::Send_uint64(uint64 data)
void Packet::Send_string(const char *data)
{
assert(data != nullptr);
- /* The <= *is* valid due to the fact that we are comparing sizes and not the index. */
- assert(this->size + strlen(data) + 1 <= SEND_MTU);
+ /* Length of the string + 1 for the '\0' termination. */
+ assert(this->CanWriteToPacket(strlen(data) + 1));
while ((this->buffer[this->size++] = *data++) != '\0') {}
}
@@ -162,18 +172,21 @@ void Packet::Send_string(const char *data)
/**
- * Is it safe to read from the packet, i.e. didn't we run over the buffer ?
- * @param bytes_to_read The amount of bytes we want to try to read.
+ * Is it safe to read from the packet, i.e. didn't we run over the buffer?
+ * In case \c close_connection is true, the connection will be closed when one would
+ * overrun the buffer. When it is false, the connection remains untouched.
+ * @param bytes_to_read The amount of bytes we want to try to read.
+ * @param close_connection Whether to close the connection if one cannot read that amount.
* @return True if that is safe, otherwise false.
*/
-bool Packet::CanReadFromPacket(uint bytes_to_read)
+bool Packet::CanReadFromPacket(size_t bytes_to_read, bool close_connection)
{
/* Don't allow reading from a quit client/client who send bad data */
if (this->cs->HasClientQuit()) return false;
/* Check if variable is within packet-size */
if (this->pos + bytes_to_read > this->size) {
- this->cs->NetworkSocketHandler::CloseConnection();
+ if (close_connection) this->cs->NetworkSocketHandler::CloseConnection();
return false;
}
@@ -235,7 +248,7 @@ uint8 Packet::Recv_uint8()
{
uint8 n;
- if (!this->CanReadFromPacket(sizeof(n))) return 0;
+ if (!this->CanReadFromPacket(sizeof(n), true)) return 0;
n = this->buffer[this->pos++];
return n;
@@ -249,7 +262,7 @@ uint16 Packet::Recv_uint16()
{
uint16 n;
- if (!this->CanReadFromPacket(sizeof(n))) return 0;
+ if (!this->CanReadFromPacket(sizeof(n), true)) return 0;
n = (uint16)this->buffer[this->pos++];
n += (uint16)this->buffer[this->pos++] << 8;
@@ -264,7 +277,7 @@ uint32 Packet::Recv_uint32()
{
uint32 n;
- if (!this->CanReadFromPacket(sizeof(n))) return 0;
+ if (!this->CanReadFromPacket(sizeof(n), true)) return 0;
n = (uint32)this->buffer[this->pos++];
n += (uint32)this->buffer[this->pos++] << 8;
@@ -281,7 +294,7 @@ uint64 Packet::Recv_uint64()
{
uint64 n;
- if (!this->CanReadFromPacket(sizeof(n))) return 0;
+ if (!this->CanReadFromPacket(sizeof(n), true)) return 0;
n = (uint64)this->buffer[this->pos++];
n += (uint64)this->buffer[this->pos++] << 8;
diff --git a/src/network/core/packet.h b/src/network/core/packet.h
index 6e5c5509c..901d3f593 100644
--- a/src/network/core/packet.h
+++ b/src/network/core/packet.h
@@ -63,6 +63,7 @@ public:
/* Sending/writing of packets */
void PrepareToSend();
+ bool CanWriteToPacket(size_t bytes_to_write);
void Send_bool (bool data);
void Send_uint8 (uint8 data);
void Send_uint16(uint16 data);
@@ -75,7 +76,7 @@ public:
bool ParsePacketSize();
void PrepareToRead();
- bool CanReadFromPacket (uint bytes_to_read);
+ bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false);
bool Recv_bool ();
uint8 Recv_uint8 ();
uint16 Recv_uint16();
diff --git a/src/network/network_admin.cpp b/src/network/network_admin.cpp
index fa97b7e57..057ad5988 100644
--- a/src/network/network_admin.cpp
+++ b/src/network/network_admin.cpp
@@ -613,7 +613,7 @@ NetworkRecvStatus ServerNetworkAdminSocketHandler::SendCmdNames()
/* Should SEND_MTU be exceeded, start a new packet
* (magic 5: 1 bool "more data" and one uint16 "command id", one
* byte for string '\0' termination and 1 bool "no more data" */
- if (p->size + strlen(cmdname) + 5 >= SEND_MTU) {
+ if (p->CanWriteToPacket(strlen(cmdname) + 5)) {
p->Send_bool(false);
this->SendPacket(p);
diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp
index 72f69f99f..10b4fd141 100644
--- a/src/network/network_client.cpp
+++ b/src/network/network_client.cpp
@@ -933,7 +933,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_FRAME(Packet *p
}
#endif
/* Receive the token. */
- if (p->pos != p->size) this->token = p->Recv_uint8();
+ if (p->CanReadFromPacket(sizeof(uint8))) this->token = p->Recv_uint8();
DEBUG(net, 5, "Received FRAME %d", _frame_counter_server);
diff --git a/src/network/network_udp.cpp b/src/network/network_udp.cpp
index 46a21fc87..aa34515bd 100644
--- a/src/network/network_udp.cpp
+++ b/src/network/network_udp.cpp
@@ -220,23 +220,23 @@ void ServerNetworkUDPSocketHandler::Receive_CLIENT_DETAIL_INFO(Packet *p, Networ
static const uint MIN_CI_SIZE = 54;
uint max_cname_length = NETWORK_COMPANY_NAME_LENGTH;
- if (Company::GetNumItems() * (MIN_CI_SIZE + NETWORK_COMPANY_NAME_LENGTH) >= (uint)SEND_MTU - packet.size) {
+ if (!packet.CanWriteToPacket(Company::GetNumItems() * (MIN_CI_SIZE + NETWORK_COMPANY_NAME_LENGTH))) {
/* Assume we can at least put the company information in the packets. */
- assert(Company::GetNumItems() * MIN_CI_SIZE < (uint)SEND_MTU - packet.size);
+ assert(packet.CanWriteToPacket(Company::GetNumItems() * MIN_CI_SIZE));
/* At this moment the company names might not fit in the
* packet. Check whether that is really the case. */
for (;;) {
- int free = SEND_MTU - packet.size;
+ size_t required = 0;
for (const Company *company : Company::Iterate()) {
char company_name[NETWORK_COMPANY_NAME_LENGTH];
SetDParam(0, company->index);
GetString(company_name, STR_COMPANY_NAME, company_name + max_cname_length - 1);
- free -= MIN_CI_SIZE;
- free -= (int)strlen(company_name);
+ required += MIN_CI_SIZE;
+ required += strlen(company_name);
}
- if (free >= 0) break;
+ if (packet.CanWriteToPacket(required)) break;
/* Try again, with slightly shorter strings. */
assert(max_cname_length > 0);