diff options
author | Michael Lutz <michi@icosahedron.de> | 2021-10-28 23:48:26 +0200 |
---|---|---|
committer | Michael Lutz <michi@icosahedron.de> | 2021-12-16 22:28:32 +0100 |
commit | a05fd7aa50ecbee425df2d6f9015ec3ea359232f (patch) | |
tree | 9b9208d39813670789635b638592a903a2aafaa4 /src | |
parent | b0990fcff7358e839468e5cf811ffddc8b9d73e2 (diff) | |
download | openttd-a05fd7aa50ecbee425df2d6f9015ec3ea359232f.tar.xz |
Change: [Network] Transfer command data as serialized byte stream without fixed structure.
The data will be transmitted as the length followed by the serialized data. This allows the command
data to be different for every command type in the future.
Diffstat (limited to 'src')
-rw-r--r-- | src/command.cpp | 32 | ||||
-rw-r--r-- | src/command_func.h | 7 | ||||
-rw-r--r-- | src/command_type.h | 8 | ||||
-rw-r--r-- | src/core/span_type.hpp | 2 | ||||
-rw-r--r-- | src/misc/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/misc/endian_buffer.hpp | 206 | ||||
-rw-r--r-- | src/network/core/packet.cpp | 28 | ||||
-rw-r--r-- | src/network/core/packet.h | 2 | ||||
-rw-r--r-- | src/network/network.cpp | 28 | ||||
-rw-r--r-- | src/network/network_admin.cpp | 5 | ||||
-rw-r--r-- | src/network/network_command.cpp | 129 | ||||
-rw-r--r-- | src/network/network_internal.h | 13 | ||||
-rw-r--r-- | src/network/network_server.cpp | 8 | ||||
-rw-r--r-- | src/string.cpp | 18 | ||||
-rw-r--r-- | src/string_func.h | 3 |
15 files changed, 455 insertions, 35 deletions
diff --git a/src/command.cpp b/src/command.cpp index 001f2ee7d..788c71989 100644 --- a/src/command.cpp +++ b/src/command.cpp @@ -55,6 +55,8 @@ #include "viewport_cmd.h" #include "water_cmd.h" #include "waypoint_cmd.h" +#include "misc/endian_buffer.hpp" +#include "string_func.h" #include <array> @@ -400,11 +402,37 @@ bool DoCommandP(Commands cmd, StringID err_message, CommandCallback *callback, T } /** + * Toplevel network safe docommand function for the current company. Must not be called recursively. + * The callback is called when the command succeeded or failed. The parameters + * \a tile, \a p1, and \a p2 are from the #CommandProc function. The parameter \a cmd is the command to execute. + * + * @param cmd The command to execute (a CMD_* value) + * @param err_message Message prefix to show on error + * @param callback A callback function to call after the command is finished + * @param my_cmd indicator if the command is from a company or server (to display error messages for a user) + * @param tile The tile to perform a command on (see #CommandProc) + * @param p1 Additional data for the command (see #CommandProc) + * @param p2 Additional data for the command (see #CommandProc) + * @param text The text to pass + * @return \c true if the command succeeded, else \c false. + */ +bool InjectNetworkCommand(Commands cmd, StringID err_message, CommandCallback *callback, bool my_cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text) +{ + return DoCommandP(cmd, err_message, callback, my_cmd, true, tile, p1, p2, text); +} + +/** * Helper to deduplicate the code for returning. * @param cmd the command cost to return. */ #define return_dcpi(cmd) { _docommand_recursive = 0; return cmd; } +/** Helper to format command parameters into a hex string. */ +static std::string CommandParametersToHexString(TileIndex tile, uint32 p1, uint32 p2, const std::string &text) +{ + return FormatArrayAsHex(EndianBufferWriter<>::FromValue(std::make_tuple(tile, p1, p2, text))); +} + /*! * Helper function for the toplevel network safe docommand function for the current company. * @@ -482,7 +510,7 @@ CommandCost DoCommandPInternal(Commands cmd, StringID err_message, CommandCallba if (!_networking || _generating_world || network_command) { /* Log the failed command as well. Just to be able to be find * causes of desyncs due to bad command test implementations. */ - Debug(desync, 1, "cmdf: {:08x}; {:02x}; {:02x}; {:06x}; {:08x}; {:08x}; {:08x}; {:08x}; \"{}\" ({})", _date, _date_fract, (int)_current_company, tile, p1, p2, cmd, err_message, text, GetCommandName(cmd)); + Debug(desync, 1, "cmdf: {:08x}; {:02x}; {:02x}; {:08x}; {:08x}; {:06x}; {} ({})", _date, _date_fract, (int)_current_company, cmd, err_message, tile, CommandParametersToHexString(tile, p1, p2, text), GetCommandName(cmd)); } cur_company.Restore(); return_dcpi(res); @@ -502,7 +530,7 @@ CommandCost DoCommandPInternal(Commands cmd, StringID err_message, CommandCallba * reset the storages as we've not executed the command. */ return_dcpi(CommandCost()); } - Debug(desync, 1, "cmd: {:08x}; {:02x}; {:02x}; {:06x}; {:08x}; {:08x}; {:08x}; {:08x}; \"{}\" ({})", _date, _date_fract, (int)_current_company, tile, p1, p2, cmd, err_message, text, GetCommandName(cmd)); + Debug(desync, 1, "cmd: {:08x}; {:02x}; {:02x}; {:08x}; {:08x}; {:06x}; {} ({})", _date, _date_fract, (int)_current_company, cmd, err_message, tile, CommandParametersToHexString(tile, p1, p2, text), GetCommandName(cmd)); /* Actually try and execute the command. If no cost-type is given * use the construction one */ diff --git a/src/command_func.h b/src/command_func.h index 03bfc73f1..0d000755c 100644 --- a/src/command_func.h +++ b/src/command_func.h @@ -12,6 +12,7 @@ #include "command_type.h" #include "company_type.h" +#include <vector> /** * Define a default return value for a failed command. @@ -32,6 +33,9 @@ static const CommandCost CMD_ERROR = CommandCost(INVALID_STRING_ID); */ #define return_cmd_error(errcode) return CommandCost(errcode); +/** Storage buffer for serialized command data. */ +typedef std::vector<byte> CommandDataBuffer; + CommandCost DoCommand(DoCommandFlag flags, Commands cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text = {}); CommandCost DoCommand(const CommandContainer *container, DoCommandFlag flags); @@ -41,9 +45,12 @@ bool DoCommandP(Commands cmd, CommandCallback *callback, TileIndex tile, uint32 bool DoCommandP(Commands cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text = {}); bool DoCommandP(const CommandContainer *container, bool my_cmd = true, bool network_command = false); +bool InjectNetworkCommand(Commands cmd, StringID err_message, CommandCallback *callback, bool my_cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); + CommandCost DoCommandPInternal(Commands cmd, StringID err_message, CommandCallback *callback, bool my_cmd, bool estimate_only, bool network_command, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); +void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex location, const CommandDataBuffer &cmd_data); extern Money _additional_cash_required; diff --git a/src/command_type.h b/src/command_type.h index fa381ce13..3a15087b9 100644 --- a/src/command_type.h +++ b/src/command_type.h @@ -424,11 +424,19 @@ enum CommandPauseLevel { typedef CommandCost CommandProc(DoCommandFlag flags, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); +template <typename T> struct CommandFunctionTraitHelper; +template <typename... Targs> +struct CommandFunctionTraitHelper<CommandCost(*)(DoCommandFlag, Targs...)> { + using Args = std::tuple<std::decay_t<Targs>...>; +}; + /** Defines the traits of a command. */ template <Commands Tcmd> struct CommandTraits; #define DEF_CMD_TRAIT(cmd_, proc_, flags_, type_) \ template<> struct CommandTraits<cmd_> { \ + using Args = typename CommandFunctionTraitHelper<decltype(&proc_)>::Args; \ + static constexpr Commands cmd = cmd_; \ static constexpr auto &proc = proc_; \ static constexpr CommandFlags flags = (CommandFlags)(flags_); \ static constexpr CommandType type = type_; \ diff --git a/src/core/span_type.hpp b/src/core/span_type.hpp index 614be8456..0df528816 100644 --- a/src/core/span_type.hpp +++ b/src/core/span_type.hpp @@ -92,6 +92,8 @@ public: constexpr const_iterator cbegin() const noexcept { return const_iterator(first); } constexpr const_iterator cend() const noexcept { return const_iterator(last); } + constexpr reference operator[](size_type idx) const { return first[idx]; } + private: pointer first; pointer last; diff --git a/src/misc/CMakeLists.txt b/src/misc/CMakeLists.txt index ee2ca6a41..24cde73e4 100644 --- a/src/misc/CMakeLists.txt +++ b/src/misc/CMakeLists.txt @@ -5,6 +5,7 @@ add_files( countedptr.hpp dbg_helpers.cpp dbg_helpers.h + endian_buffer.hpp fixedsizearray.hpp getoptdata.cpp getoptdata.h diff --git a/src/misc/endian_buffer.hpp b/src/misc/endian_buffer.hpp new file mode 100644 index 000000000..c20d9a8b9 --- /dev/null +++ b/src/misc/endian_buffer.hpp @@ -0,0 +1,206 @@ +/* + * This file is part of OpenTTD. + * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2. + * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>. + */ + +/** @file endian_buffer.hpp Endian-aware buffer. */ + +#ifndef ENDIAN_BUFFER_HPP +#define ENDIAN_BUFFER_HPP + +#include <iterator> +#include <string_view> +#include "../core/span_type.hpp" +#include "../core/bitmath_func.hpp" + +struct StrongTypedefBase; + +/** + * Endian-aware buffer adapter that always writes values in little endian order. + * @note This class uses operator overloading (<<, just like streams) for writing + * as this allows providing custom operator overloads for more complex types + * like e.g. structs without needing to modify this class. + */ +template <typename Tcont = typename std::vector<byte>, typename Titer = typename std::back_insert_iterator<Tcont>> +class EndianBufferWriter { + /** Output iterator for the destination buffer. */ + Titer buffer; + +public: + EndianBufferWriter(Titer buffer) : buffer(buffer) {} + EndianBufferWriter(typename Titer::container_type &container) : buffer(std::back_inserter(container)) {} + + EndianBufferWriter &operator <<(const std::string &data) { return *this << std::string_view{ data }; } + EndianBufferWriter &operator <<(const char *data) { return *this << std::string_view{ data }; } + EndianBufferWriter &operator <<(std::string_view data) { this->Write(data); return *this; } + EndianBufferWriter &operator <<(bool data) { return *this << static_cast<byte>(data ? 1 : 0); } + + template <typename... Targs> + EndianBufferWriter &operator <<(const std::tuple<Targs...> &data) + { + this->WriteTuple(data, std::index_sequence_for<Targs...>{}); + return *this; + } + + template <class T, std::enable_if_t<std::disjunction_v<std::negation<std::is_class<T>>, std::is_base_of<StrongTypedefBase, T>>, int> = 0> + EndianBufferWriter &operator <<(const T data) + { + if constexpr (std::is_enum_v<T>) { + this->Write(static_cast<std::underlying_type_t<const T>>(data)); + } else if constexpr (std::is_base_of_v<StrongTypedefBase, T>) { + this->Write(data.value); + } else { + this->Write(data); + } + return *this; + } + + template <typename Tvalue, typename Tbuf = std::vector<byte>> + static Tbuf FromValue(const Tvalue &data) + { + Tbuf buffer; + EndianBufferWriter writer{ buffer }; + writer << data; + return buffer; + } + +private: + /** Helper function to write a tuple to the buffer. */ + template<class Ttuple, size_t... Tindices> + void WriteTuple(const Ttuple &values, std::index_sequence<Tindices...>) { + ((*this << std::get<Tindices>(values)), ...); + } + + /** Write overload for string values. */ + void Write(std::string_view value) + { + for (auto c : value) { + this->buffer++ = c; + } + this->buffer++ = '\0'; + } + + /** Fundamental write function. */ + template <class T> + void Write(T value) + { + static_assert(sizeof(T) <= 8, "Value can't be larger than 8 bytes"); + + if constexpr (sizeof(T) > 1) { + this->buffer++ = GB(value, 0, 8); + this->buffer++ = GB(value, 8, 8); + if constexpr (sizeof(T) > 2) { + this->buffer++ = GB(value, 16, 8); + this->buffer++ = GB(value, 24, 8); + } + if constexpr (sizeof(T) > 4) { + this->buffer++ = GB(value, 32, 8); + this->buffer++ = GB(value, 40, 8); + this->buffer++ = GB(value, 48, 8); + this->buffer++ = GB(value, 56, 8); + } + } else { + this->buffer++ = value; + } + } +}; + +/** + * Endian-aware buffer adapter that always reads values in little endian order. + * @note This class uses operator overloading (>>, just like streams) for reading + * as this allows providing custom operator overloads for more complex types + * like e.g. structs without needing to modify this class. + */ +class EndianBufferReader { + /** Reference to storage buffer. */ + span<const byte> buffer; + /** Current read position. */ + size_t read_pos = 0; + +public: + EndianBufferReader(span<const byte> buffer) : buffer(buffer) {} + + void rewind() { this->read_pos = 0; } + + EndianBufferReader &operator >>(std::string &data) { data = this->ReadStr(); return *this; } + EndianBufferReader &operator >>(bool &data) { data = this->Read<byte>() != 0; return *this; } + + template <typename... Targs> + EndianBufferReader &operator >>(std::tuple<Targs...> &data) + { + this->ReadTuple(data, std::index_sequence_for<Targs...>{}); + return *this; + } + + template <class T, std::enable_if_t<std::disjunction_v<std::negation<std::is_class<T>>, std::is_base_of<StrongTypedefBase, T>>, int> = 0> + EndianBufferReader &operator >>(T &data) + { + if constexpr (std::is_enum_v<T>) { + data = static_cast<T>(this->Read<std::underlying_type_t<T>>()); + } else if constexpr (std::is_base_of_v<StrongTypedefBase, T>) { + data.value = this->Read<decltype(data.value)>(); + } else { + data = this->Read<T>(); + } + return *this; + } + + template <typename Tvalue> + static Tvalue ToValue(span<const byte> buffer) + { + Tvalue result{}; + EndianBufferReader reader{ buffer }; + reader >> result; + return result; + } + +private: + /** Helper function to read a tuple from the buffer. */ + template<class Ttuple, size_t... Tindices> + void ReadTuple(Ttuple &values, std::index_sequence<Tindices...>) { + ((*this >> std::get<Tindices>(values)), ...); + } + + /** Read overload for string data. */ + std::string ReadStr() + { + std::string str; + while (this->read_pos < this->buffer.size()) { + char ch = this->Read<char>(); + if (ch == '\0') break; + str.push_back(ch); + } + return str; + } + + /** Fundamental read function. */ + template <class T> + T Read() + { + static_assert(!std::is_const_v<T>, "Can't read into const variables"); + static_assert(sizeof(T) <= 8, "Value can't be larger than 8 bytes"); + + if (read_pos + sizeof(T) > this->buffer.size()) return {}; + + T value = static_cast<T>(this->buffer[this->read_pos++]); + if constexpr (sizeof(T) > 1) { + value += static_cast<T>(this->buffer[this->read_pos++]) << 8; + } + if constexpr (sizeof(T) > 2) { + value += static_cast<T>(this->buffer[this->read_pos++]) << 16; + value += static_cast<T>(this->buffer[this->read_pos++]) << 24; + } + if constexpr (sizeof(T) > 4) { + value += static_cast<T>(this->buffer[this->read_pos++]) << 32; + value += static_cast<T>(this->buffer[this->read_pos++]) << 40; + value += static_cast<T>(this->buffer[this->read_pos++]) << 48; + value += static_cast<T>(this->buffer[this->read_pos++]) << 56; + } + + return value; + } +}; + +#endif /* ENDIAN_BUFFER_HPP */ diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index e106d5787..ec0919757 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -186,6 +186,17 @@ void Packet::Send_string(const std::string_view data) } /** + * Copy a sized byte buffer into the packet. + * @param data The data to send. + */ +void Packet::Send_buffer(const std::vector<byte> &data) +{ + assert(this->CanWriteToPacket(sizeof(uint16) + data.size())); + this->Send_uint16((uint16)data.size()); + this->buffer.insert(this->buffer.end(), data.begin(), data.end()); +} + +/** * Send as many of the bytes as possible in the packet. This can mean * that it is possible that not all bytes are sent. To cope with this * the function returns the amount of bytes that were actually sent. @@ -367,6 +378,23 @@ uint64 Packet::Recv_uint64() } /** + * Extract a sized byte buffer from the packet. + * @return The extracted buffer. + */ +std::vector<byte> Packet::Recv_buffer() +{ + uint16 size = this->Recv_uint16(); + if (size == 0 || !this->CanReadFromPacket(size, true)) return {}; + + std::vector<byte> data; + while (size-- > 0) { + data.push_back(this->buffer[this->pos++]); + } + + return data; +} + +/** * Reads characters (bytes) from the packet until it finds a '\0', or reaches a * maximum of \c length characters. * When the '\0' has not been reached in the first \c length read characters, diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 277ff8bba..04a232e1c 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -72,6 +72,7 @@ public: void Send_uint32(uint32 data); void Send_uint64(uint64 data); void Send_string(const std::string_view data); + void Send_buffer(const std::vector<byte> &data); size_t Send_bytes (const byte *begin, const byte *end); /* Reading/receiving of packets */ @@ -87,6 +88,7 @@ public: uint16 Recv_uint16(); uint32 Recv_uint32(); uint64 Recv_uint64(); + std::vector<byte> Recv_buffer(); std::string Recv_string(size_t length, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); size_t RemainingBytesToTransfer() const; diff --git a/src/network/network.cpp b/src/network/network.cpp index 13f2fe52a..8194f34d0 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -35,6 +35,7 @@ #include "../core/pool_func.hpp" #include "../gfx_func.h" #include "../error.h" +#include "../misc_cmd.h" #include <charconv> #include <sstream> #include <iomanip> @@ -1064,8 +1065,8 @@ void NetworkGameLoop() while (f != nullptr && !feof(f)) { if (_date == next_date && _date_fract == next_date_fract) { if (cp != nullptr) { - NetworkSendCommand(cp->cmd, cp->err_msg, nullptr, cp->company, cp->tile, cp->p1, cp->p2, cp->text); - Debug(desync, 0, "Injecting: {:08x}; {:02x}; {:02x}; {:06x}; {:08x}; {:08x}; {:08x}; \"{}\" ({})", _date, _date_fract, (int)_current_company, cp->tile, cp->p1, cp->p2, cp->cmd, cp->text, GetCommandName(cp->cmd)); + NetworkSendCommand(cp->cmd, cp->err_msg, nullptr, cp->company, cp->data); + Debug(desync, 0, "Injecting: {:08x}; {:02x}; {:02x}; {:08x}; {:06x}; {} ({})", _date, _date_fract, (int)_current_company, cp->cmd, cp->tile, FormatArrayAsHex(cp->data), GetCommandName(cp->cmd)); delete cp; cp = nullptr; } @@ -1104,15 +1105,21 @@ void NetworkGameLoop() cp = new CommandPacket(); int company; uint cmd; - char buffer[128]; - int ret = sscanf(p, "%x; %x; %x; %x; %x; %x; %x; %x; \"%127[^\"]\"", &next_date, &next_date_fract, &company, &cp->tile, &cp->p1, &cp->p2, &cmd, &cp->err_msg, buffer); - cp->text = buffer; - /* There are 8 pieces of data to read, however the last is a - * string that might or might not exist. Ignore it if that - * string misses because in 99% of the time it's not used. */ - assert(ret == 9 || ret == 8); + char buffer[256]; + int ret = sscanf(p, "%x; %x; %x; %x; %x; %x; %255s", &next_date, &next_date_fract, &company, &cmd, &cp->err_msg, &cp->tile, buffer); + assert(ret == 6); cp->company = (CompanyID)company; cp->cmd = (Commands)cmd; + + /* Parse command data. */ + std::vector<byte> args; + size_t arg_len = strlen(buffer); + for (size_t i = 0; i + 1 < arg_len; i += 2) { + byte e = 0; + std::from_chars(buffer + i, buffer + i + 1, e, 16); + args.emplace_back(e); + } + cp->data = args; } else if (strncmp(p, "join: ", 6) == 0) { /* Manually insert a pause when joining; this way the client can join at the exact right time. */ int ret = sscanf(p + 6, "%x; %x", &next_date, &next_date_fract); @@ -1121,8 +1128,7 @@ void NetworkGameLoop() cp = new CommandPacket(); cp->company = COMPANY_SPECTATOR; cp->cmd = CMD_PAUSE; - cp->p1 = PM_PAUSED_NORMAL; - cp->p2 = 1; + cp->data = EndianBufferWriter<>::FromValue(CommandTraits<CMD_PAUSE>::Args{ 0, PM_PAUSED_NORMAL, 1, "" }); _ddc_fastforward = false; } else if (strncmp(p, "sync: ", 6) == 0) { int ret = sscanf(p + 6, "%x; %x; %x; %x", &next_date, &next_date_fract, &sync_state[0], &sync_state[1]); diff --git a/src/network/network_admin.cpp b/src/network/network_admin.cpp index 99f803e24..4711cdf04 100644 --- a/src/network/network_admin.cpp +++ b/src/network/network_admin.cpp @@ -630,10 +630,7 @@ NetworkRecvStatus ServerNetworkAdminSocketHandler::SendCmdLogging(ClientID clien p->Send_uint32(client_id); p->Send_uint8 (cp->company); p->Send_uint16(cp->cmd); - p->Send_uint32(cp->p1); - p->Send_uint32(cp->p2); - p->Send_uint32(cp->tile); - p->Send_string(cp->text); + p->Send_buffer(cp->data); p->Send_uint32(cp->frame); this->SendPacket(p); diff --git a/src/network/network_command.cpp b/src/network/network_command.cpp index 0fae6bcbf..472d5e60e 100644 --- a/src/network/network_command.cpp +++ b/src/network/network_command.cpp @@ -15,18 +15,41 @@ #include "../company_func.h" #include "../settings_type.h" #include "../airport_cmd.h" +#include "../aircraft_cmd.h" +#include "../autoreplace_cmd.h" +#include "../company_cmd.h" #include "../depot_cmd.h" #include "../dock_cmd.h" +#include "../economy_cmd.h" +#include "../engine_cmd.h" +#include "../goal_cmd.h" #include "../group_cmd.h" #include "../industry_cmd.h" +#include "../landscape_cmd.h" +#include "../misc_cmd.h" +#include "../news_cmd.h" +#include "../object_cmd.h" +#include "../order_cmd.h" #include "../rail_cmd.h" #include "../road_cmd.h" +#include "../roadveh_cmd.h" +#include "../settings_cmd.h" +#include "../signs_cmd.h" +#include "../station_cmd.h" +#include "../story_cmd.h" +#include "../subsidy_cmd.h" #include "../terraform_cmd.h" +#include "../timetable_cmd.h" #include "../town_cmd.h" #include "../train_cmd.h" +#include "../tree_cmd.h" #include "../tunnelbridge_cmd.h" #include "../vehicle_cmd.h" +#include "../viewport_cmd.h" +#include "../water_cmd.h" +#include "../waypoint_cmd.h" #include "../script/script_cmd.h" +#include <array> #include "../safeguards.h" @@ -62,6 +85,23 @@ static CommandCallback * const _callback_table[] = { /* 0x1B */ CcAddVehicleNewGroup, }; +/* Helpers to generate the command dispatch table from the command traits. */ + +template <Commands Tcmd> static CommandDataBuffer SanitizeCmdStrings(const CommandDataBuffer &data); +template <Commands Tcmd> static void UnpackNetworkCommand(const CommandPacket *cp); +struct CommandDispatch { + CommandDataBuffer(*Sanitize)(const CommandDataBuffer &); + void (*Unpack)(const CommandPacket *); +}; + +template<typename T, T... i> +inline constexpr auto MakeDispatchTable(std::integer_sequence<T, i...>) noexcept +{ + return std::array<CommandDispatch, sizeof...(i)>{{ { &SanitizeCmdStrings<static_cast<Commands>(i)>, &UnpackNetworkCommand<static_cast<Commands>(i)> }... }}; +} +static constexpr auto _cmd_dispatch = MakeDispatchTable(std::make_integer_sequence<std::underlying_type_t<Commands>, CMD_END>{}); + + /** * Append a CommandPacket at the end of the queue. * @param p The packet to append to the queue. @@ -149,15 +189,28 @@ static CommandQueue _local_execution_queue; */ void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex tile, uint32 p1, uint32 p2, const std::string &text) { + auto data = EndianBufferWriter<CommandDataBuffer>::FromValue(std::make_tuple(tile, p1, p2, text)); + NetworkSendCommand(cmd, err_message, callback, company, tile, data); +} + +/** + * Prepare a DoCommand to be send over the network + * @param cmd The command to execute (a CMD_* value) + * @param err_message Message prefix to show on error + * @param callback A callback function to call after the command is finished + * @param company The company that wants to send the command + * @param location Location of the command (e.g. for error message position) + * @param cmd_data The command proc arguments. + */ +void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex location, const CommandDataBuffer &cmd_data) +{ CommandPacket c; c.company = company; - c.tile = tile; - c.p1 = p1; - c.p2 = p2; c.cmd = cmd; c.err_msg = err_message; c.callback = callback; - c.text = text; + c.tile = location; + c.data = cmd_data; if (_network_server) { /* If we are the server, we queue the command in our 'special' queue. @@ -220,7 +273,7 @@ void NetworkExecuteLocalCommandQueue() /* We can execute this command */ _current_company = cp->company; - DoCommandP(cp, cp->my_cmd, true); + _cmd_dispatch[cp->cmd].Unpack(cp); queue.Pop(); delete cp; @@ -311,11 +364,8 @@ const char *NetworkGameSocketHandler::ReceiveCommand(Packet *p, CommandPacket *c if (!IsValidCommand(cp->cmd)) return "invalid command"; if (GetCommandFlags(cp->cmd) & CMD_OFFLINE) return "single-player only command"; cp->err_msg = p->Recv_uint16(); - - cp->p1 = p->Recv_uint32(); - cp->p2 = p->Recv_uint32(); cp->tile = p->Recv_uint32(); - cp->text = p->Recv_string(NETWORK_COMPANY_NAME_LENGTH, (!_network_server && GetCommandFlags(cp->cmd) & CMD_STR_CTRL) != 0 ? SVS_ALLOW_CONTROL_CODE | SVS_REPLACE_WITH_QUESTION_MARK : SVS_REPLACE_WITH_QUESTION_MARK); + cp->data = _cmd_dispatch[cp->cmd].Sanitize(p->Recv_buffer()); byte callback = p->Recv_uint8(); if (callback >= lengthof(_callback_table)) return "invalid callback"; @@ -331,13 +381,11 @@ const char *NetworkGameSocketHandler::ReceiveCommand(Packet *p, CommandPacket *c */ void NetworkGameSocketHandler::SendCommand(Packet *p, const CommandPacket *cp) { - p->Send_uint8 (cp->company); + p->Send_uint8(cp->company); p->Send_uint16(cp->cmd); p->Send_uint16(cp->err_msg); - p->Send_uint32(cp->p1); - p->Send_uint32(cp->p2); p->Send_uint32(cp->tile); - p->Send_string(cp->text); + p->Send_buffer(cp->data); byte callback = 0; while (callback < lengthof(_callback_table) && _callback_table[callback] != cp->callback) { @@ -350,3 +398,58 @@ void NetworkGameSocketHandler::SendCommand(Packet *p, const CommandPacket *cp) } p->Send_uint8 (callback); } + +/** + * Insert a client ID into the command data in a command packet. + * @param cp Command packet to modify. + * @param client_id Client id to insert. + */ +void NetworkReplaceCommandClientId(CommandPacket &cp, ClientID client_id) +{ + /* Unpack command parameters. */ + auto params = EndianBufferReader::ToValue<std::tuple<TileIndex, uint32, uint32, std::string>>(cp.data); + + /* Insert client id. */ + std::get<2>(params) = client_id; + + /* Repack command parameters. */ + cp.data = EndianBufferWriter<CommandDataBuffer>::FromValue(params); +} + + +/** Validate a single string argument coming from network. */ +template <class T> +static inline void SanitizeSingleStringHelper([[maybe_unused]] CommandFlags cmd_flags, T &data) +{ + if constexpr (std::is_same_v<std::string, T>) { + data = StrMakeValid(data.substr(0, NETWORK_COMPANY_NAME_LENGTH), (!_network_server && cmd_flags & CMD_STR_CTRL) != 0 ? SVS_ALLOW_CONTROL_CODE | SVS_REPLACE_WITH_QUESTION_MARK : SVS_REPLACE_WITH_QUESTION_MARK); + } +} + +/** Helper function to perform validation on command data strings. */ +template<class Ttuple, size_t... Tindices> +static inline void SanitizeStringsHelper(CommandFlags cmd_flags, Ttuple &values, std::index_sequence<Tindices...>) +{ + ((SanitizeSingleStringHelper(cmd_flags, std::get<Tindices>(values))), ...); +} + +/** + * Validate and sanitize strings in command data. + * @tparam Tcmd Command this data belongs to. + * @param data Command data. + * @return Sanitized command data. + */ +template <Commands Tcmd> +CommandDataBuffer SanitizeCmdStrings(const CommandDataBuffer &data) +{ + auto args = EndianBufferReader::ToValue<typename CommandTraits<Tcmd>::Args>(data); + SanitizeStringsHelper(CommandTraits<Tcmd>::flags, args, std::make_index_sequence<std::tuple_size_v<typename CommandTraits<Tcmd>::Args>>{}); + return EndianBufferWriter<CommandDataBuffer>::FromValue(args); +} + +template <Commands Tcmd> +void UnpackNetworkCommand(const CommandPacket *cp) +{ + auto args = EndianBufferReader::ToValue<typename CommandTraits<Tcmd>::Args>(cp->data); + std::apply(&InjectNetworkCommand, std::tuple_cat(std::make_tuple(Tcmd, cp->err_msg, cp->callback, cp->my_cmd), args)); +} diff --git a/src/network/network_internal.h b/src/network/network_internal.h index 25240da5d..58c99867c 100644 --- a/src/network/network_internal.h +++ b/src/network/network_internal.h @@ -15,6 +15,8 @@ #include "core/tcp_game.h" #include "../command_type.h" +#include "../command_func.h" +#include "../misc/endian_buffer.hpp" #ifdef RANDOM_DEBUG /** @@ -104,19 +106,26 @@ void UpdateNetworkGameWindow(); /** * Everything we need to know about a command to be able to execute it. */ -struct CommandPacket : CommandContainer { +struct CommandPacket { /** Make sure the pointer is nullptr. */ - CommandPacket() : next(nullptr), company(INVALID_COMPANY), frame(0), my_cmd(false) {} + CommandPacket() : next(nullptr), company(INVALID_COMPANY), frame(0), my_cmd(false), tile(0) {} CommandPacket *next; ///< the next command packet (if in queue) CompanyID company; ///< company that is executing the command uint32 frame; ///< the frame in which this packet is executed bool my_cmd; ///< did the command originate from "me" + + Commands cmd; ///< command being executed. + StringID err_msg; ///< string ID of error message to use. + CommandCallback *callback; ///< any callback function executed upon successful completion of the command. + TileIndex tile; ///< location of the command (for e.g. error message or effect display). + CommandDataBuffer data; ///< command parameters. }; void NetworkDistributeCommands(); void NetworkExecuteLocalCommandQueue(); void NetworkFreeLocalCommandQueue(); void NetworkSyncCommandQueue(NetworkClientSocket *cs); +void NetworkReplaceCommandClientId(CommandPacket &cp, ClientID client_id); void ShowNetworkError(StringID error_string); void NetworkTextMessage(NetworkAction action, TextColour colour, bool self_send, const std::string &name, const std::string &str = "", int64 data = 0, const std::string &data_str = ""); diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index 50b46dfe1..967ad40a8 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -24,6 +24,7 @@ #include "../genworld.h" #include "../company_func.h" #include "../company_gui.h" +#include "../company_cmd.h" #include "../roadveh.h" #include "../order_backup.h" #include "../core/pool_func.hpp" @@ -1048,14 +1049,15 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::Receive_CLIENT_COMMAND(Packet * to match the company in the packet. If it doesn't, the client has done * something pretty naughty (or a bug), and will be kicked */ - if (!(cp.cmd == CMD_COMPANY_CTRL && cp.p1 == 0 && ci->client_playas == COMPANY_NEW_COMPANY) && ci->client_playas != cp.company) { + uint32 company_p1 = cp.cmd == CMD_COMPANY_CTRL ? std::get<1>(EndianBufferReader::ToValue<CommandTraits<CMD_COMPANY_CTRL>::Args>(cp.data)) : 0; + if (!(cp.cmd == CMD_COMPANY_CTRL && company_p1 == 0 && ci->client_playas == COMPANY_NEW_COMPANY) && ci->client_playas != cp.company) { IConsolePrint(CC_WARNING, "Kicking client #{} (IP: {}) due to calling a command as another company {}.", ci->client_playas + 1, this->GetClientIP(), cp.company + 1); return this->SendError(NETWORK_ERROR_COMPANY_MISMATCH); } if (cp.cmd == CMD_COMPANY_CTRL) { - if (cp.p1 != 0 || cp.company != COMPANY_SPECTATOR) { + if (company_p1 != 0 || cp.company != COMPANY_SPECTATOR) { return this->SendError(NETWORK_ERROR_CHEATER); } @@ -1066,7 +1068,7 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::Receive_CLIENT_COMMAND(Packet } } - if (GetCommandFlags(cp.cmd) & CMD_CLIENT_ID) cp.p2 = this->client_id; + if (GetCommandFlags(cp.cmd) & CMD_CLIENT_ID) NetworkReplaceCommandClientId(cp, this->client_id); this->incoming_queue.Append(&cp); return NETWORK_RECV_STATUS_OKAY; diff --git a/src/string.cpp b/src/string.cpp index d027cb7bf..aeac4fe84 100644 --- a/src/string.cpp +++ b/src/string.cpp @@ -19,6 +19,7 @@ #include <stdarg.h> #include <ctype.h> /* required for tolower() */ #include <sstream> +#include <iomanip> #ifdef _MSC_VER #include <errno.h> // required by vsnprintf implementation for MSVC @@ -161,6 +162,23 @@ char *CDECL str_fmt(const char *str, ...) } /** + * Format a byte array into a continuous hex string. + * @param data Array to format + * @return Converted string. + */ +std::string FormatArrayAsHex(span<const byte> data) +{ + std::ostringstream ss; + ss << std::uppercase << std::setfill('0') << std::setw(2) << std::hex; + + for (auto b : data) { + ss << b; + } + + return ss.str(); +} + +/** * Scan the string for old values of SCC_ENCODED and fix it to * it's new, static value. * @param str the string to scan diff --git a/src/string_func.h b/src/string_func.h index 0cbf26d6b..a5d3499c7 100644 --- a/src/string_func.h +++ b/src/string_func.h @@ -28,6 +28,7 @@ #include <iosfwd> #include "core/bitmath_func.hpp" +#include "core/span_type.hpp" #include "string_type.h" char *strecat(char *dst, const char *src, const char *last) NOACCESS(3); @@ -39,6 +40,8 @@ int CDECL vseprintf(char *str, const char *last, const char *format, va_list ap) char *CDECL str_fmt(const char *str, ...) WARN_FORMAT(1, 2); +std::string FormatArrayAsHex(span<const byte> data); + void StrMakeValidInPlace(char *str, const char *last, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK) NOACCESS(2); [[nodiscard]] std::string StrMakeValid(const std::string &str, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); void StrMakeValidInPlace(char *str, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); |