summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/network/core/tcp.h3
-rw-r--r--src/network/core/tcp_connect.cpp17
2 files changed, 16 insertions, 4 deletions
diff --git a/src/network/core/tcp.h b/src/network/core/tcp.h
index bbd0bc2a9..b4b4398de 100644
--- a/src/network/core/tcp.h
+++ b/src/network/core/tcp.h
@@ -98,6 +98,7 @@ private:
std::chrono::steady_clock::time_point last_attempt; ///< Time we last tried to connect.
std::string connection_string; ///< Current address we are connecting to (before resolving).
+ NetworkAddress bind_address; ///< Address we're binding to, if any.
void Resolve();
void OnResolved(addrinfo *ai);
@@ -113,7 +114,7 @@ private:
public:
TCPConnecter() {};
- TCPConnecter(const std::string &connection_string, uint16 default_port);
+ TCPConnecter(const std::string &connection_string, uint16 default_port, NetworkAddress bind_address = {});
virtual ~TCPConnecter();
/**
diff --git a/src/network/core/tcp_connect.cpp b/src/network/core/tcp_connect.cpp
index 0e8e2e125..6db2500f3 100644
--- a/src/network/core/tcp_connect.cpp
+++ b/src/network/core/tcp_connect.cpp
@@ -24,10 +24,13 @@
static std::vector<TCPConnecter *> _tcp_connecters;
/**
- * Create a new connecter for the given address
- * @param connection_string the address to connect to
+ * Create a new connecter for the given address.
+ * @param connection_string The address to connect to.
+ * @param default_port If not indicated in connection_string, what port to use.
+ * @param bind_address The local bind address to use. Defaults to letting the OS find one.
*/
-TCPConnecter::TCPConnecter(const std::string &connection_string, uint16 default_port)
+TCPConnecter::TCPConnecter(const std::string &connection_string, uint16 default_port, NetworkAddress bind_address) :
+ bind_address(bind_address)
{
this->connection_string = NormalizeConnectionString(connection_string, default_port);
@@ -96,6 +99,14 @@ void TCPConnecter::Connect(addrinfo *address)
return;
}
+ if (this->bind_address.GetPort() > 0) {
+ if (bind(sock, (const sockaddr *)this->bind_address.GetAddress(), this->bind_address.GetAddressLength()) != 0) {
+ Debug(net, 1, "Could not bind socket on {}: {}", this->bind_address.GetAddressAsString(), NetworkError::GetLast().AsString());
+ closesocket(sock);
+ return;
+ }
+ }
+
if (!SetNoDelay(sock)) {
Debug(net, 1, "Setting TCP_NODELAY failed: {}", NetworkError::GetLast().AsString());
}