diff -uNr a/blatta/README.txt b/blatta/README.txt
--- a/blatta/README.txt bc8aaa1f75e7830f0359f5a35958fc19b8bbdd316e7a97cf66bba38e0a616bd673751249ce6e9fa2aad7f05e32af3ae9b351e87cac8a679fb8d381510514d813
+++ b/blatta/README.txt 7ee0913ae7addd7e419ccde7a0f4e7c2348029ad00e1d6994da98430875358b91032ded251ee34c7917c81308be84e9042da9d54780b440aec23065a53c498ff
@@ -30,3 +30,8 @@
3. Use genkey to generate a key.
4. Add the key to the peer using the key command.
5. Add an address for the peer using the address command.
+
+NOTES:
+
+To run the unit tests, you'll need to run:
+pip install mock
diff -uNr a/blatta/blatta b/blatta/blatta
--- a/blatta/blatta 50acef42c77e18fb23fa8157201ddba47717bb95ac85fe9d23393aba847a39665373cc6c98fdbb08096f321081980f421c17acd4490de8cb56d38bef637c0565
+++ b/blatta/blatta 87fb4a6177c042b76e8c0e34dc153bc2f574e78b90521124a44cf41e02f57aaf13ba15be758f3a67056f53064c8928f3da47dbc0cffb88019ae5456b6f06a9bf
@@ -8,6 +8,7 @@
import sys
import tempfile
import time
+import logging
from lib.server import VERSION
from lib.server import Server
from lib.peer import Peer
@@ -90,8 +91,11 @@
(options, args) = op.parse_args(argv[1:])
if options.channel_name is None:
options.channel_name = "#pest"
+ log_format = "%(levelname)s %(asctime)s: %(message)s"
if options.debug:
- options.verbose = True
+ logging.basicConfig(level=logging.DEBUG, format=log_format, stream=sys.stdout)
+ else:
+ logging.basicConfig(level=logging.INFO, format=log_format, stream=sys.stdout)
if options.irc_ports is None:
options.irc_ports = "6697"
if options.udp_port is None:
@@ -139,7 +143,7 @@
try:
server.start()
except KeyboardInterrupt:
- server.print_error("Interrupted.")
+ logging.error("Interrupted.")
main(sys.argv)
diff -uNr a/blatta/lib/client.py b/blatta/lib/client.py
--- a/blatta/lib/client.py a35f64ee21532cc117fb36691ece37e2f523cb01855fb4eb8268ac745519f27fd36addb324d56d4893fb7a84fe70302630e53becf3571e250b3b578ca46abce3
+++ b/blatta/lib/client.py 98e1e99d7ea8fe523728ca8f5661b244f3769a09a44cfa8c40236fe5cb32bb2ee267c7ee2e2028d1e455b5fc3bae9efb53e91a457033133c4bb25162ebcf8e5e
@@ -6,6 +6,8 @@
import os
import base64
import traceback
+import logging
+from lib.state import State
from lib.message import Message
from lib.server import VERSION
from funcs import *
@@ -22,6 +24,7 @@
def __init__(self, server, socket):
self.server = server
+ self.state = State.get_instance()
self.socket = socket
self.channels = {} # irc_lower(Channel name) --> Channel
self.nickname = None
@@ -37,15 +40,11 @@
else:
self.__handle_command = self.__registration_handler
- def is_addressed_to_me(self, message):
- command = self.__parse_udp_message(message)
- if command[0] == 'PRIVMSG':
- if command[1][0][0] == '#' or command[1][0] == self.nickname:
- return True
- else:
- return False
- else:
- return True
+ def message_from_station(self, msg):
+ targetname = self.server.channel_name if msg.command == BROADCAST else self.nickname
+ pest_prefix = msg.prefix if msg.prefix else msg.speaker
+ formatted_message = ":%s PRIVMSG %s :%s" % (pest_prefix, targetname, msg.body)
+ self.__writebuffer += formatted_message + "\r\n"
def get_prefix(self):
return "%s" % (self.nickname)
@@ -68,30 +67,6 @@
def write_queue_size(self):
return len(self.__writebuffer)
- def __parse_udp_message(self, message):
- data = " ".join(message.split()[1:]) + "\r\n"
- lines = self.__linesep_regexp.split(data)
- lines = lines[:-1]
- commands = []
- for line in lines:
- if not line:
- # Empty line. Ignore.
- continue
- x = line.split(" ", 1)
- command = x[0].upper()
- if len(x) == 1:
- arguments = []
- else:
- if len(x[1]) > 0 and x[1][0] == ":":
- arguments = [x[1][1:]]
- else:
- y = string.split(x[1], " :", 1)
- arguments = string.split(y[0])
- if len(y) == 2:
- arguments.append(y[1])
- commands.append([command, arguments])
- return commands[0]
-
def __parse_read_buffer(self):
lines = self.__linesep_regexp.split(self.__readbuffer)
self.__readbuffer = lines[-1]
@@ -159,7 +134,6 @@
% self.nickname)
self.reply("004 %s :%s blatta-%s o o"
% (self.nickname, server.name, VERSION))
- self.send_lusers()
self.send_motd()
self.__handle_command = self.__command_handler
@@ -182,32 +156,20 @@
if arguments[0] == "0":
for (channelname, channel) in self.channels.items():
self.message_channel(channel, "PART", channelname, True)
- self.channel_log(channel, "left", meta=True)
server.remove_member_from_channel(self, channelname)
self.channels = {}
return
channelnames = arguments[0].split(",")
- if len(arguments) > 1:
- keys = arguments[1].split(",")
- else:
- keys = []
- keys.extend((len(channelnames) - len(keys)) * [None])
- for (i, channelname) in enumerate(channelnames):
+ for channelname in channelnames:
if irc_lower(channelname) in self.channels:
continue
if not valid_channel_re.match(channelname):
self.reply_403(channelname)
continue
channel = server.get_channel(channelname)
- if channel.key is not None and channel.key != keys[i]:
- self.reply(
- "475 %s %s :Cannot join channel (+k) - bad key"
- % (self.nickname, channelname))
- continue
channel.add_member(self)
self.channels[irc_lower(channelname)] = channel
self.message_channel(channel, "JOIN", channelname, True)
- self.channel_log(channel, "joined", meta=True)
if channel.topic:
self.reply("332 %s %s :%s"
% (self.nickname, channel.name, channel.topic))
@@ -218,7 +180,7 @@
% (self.nickname,
channelname,
" ".join(sorted(x
- for x in self.server.state.get_peer_handles()))))
+ for x in self.state.get_peer_handles()))))
self.reply("366 %s %s :End of NAMES list"
% (self.nickname, channelname))
@@ -238,7 +200,7 @@
self.reply("323 %s :End of LIST" % self.nickname)
def lusers_handler():
- self.send_lusers()
+ pass
def mode_handler():
if len(arguments) < 1:
@@ -268,8 +230,6 @@
self.message_channel(
channel, "MODE", "%s +k %s" % (channel.name, key),
True)
- self.channel_log(
- channel, "set channel key to %s" % key, meta=True)
else:
self.reply("442 %s :You're not on that channel"
% targetname)
@@ -279,8 +239,6 @@
self.message_channel(
channel, "MODE", "%s -k" % channel.name,
True)
- self.channel_log(
- channel, "removed channel key", meta=True)
else:
self.reply("442 %s :You're not on that channel"
% targetname)
@@ -313,9 +271,6 @@
self.reply("432 %s %s :Erroneous Nickname"
% (self.nickname, newnick))
else:
- for x in self.channels.values():
- self.channel_log(
- x, "changed nickname to %s" % newnick, meta=True)
oldnickname = self.nickname
self.nickname = newnick
server.client_changed_nickname(self, oldnickname)
@@ -340,16 +295,23 @@
channel = server.get_channel(targetname)
self.message_channel(
channel, command, "%s :%s" % (channel.name, message))
- self.channel_log(channel, message)
+ # send the channel message to peers as well
+ self.server.station.infosec.message(
+ Message(
+ {
+ "speaker": self.nickname,
+ "command": BROADCAST,
+ "bounces": 0,
+ "body": message
+ }))
else:
- formatted_message = ":%s %s %s :%s" % (self.prefix, command, targetname, message)
- self.server.peer_message(Message({
+ self.server.station.infosec.message(Message({
"speaker": self.nickname,
"handle": targetname,
- "body": formatted_message,
+ "body": message,
"bounces": 0,
"command": DIRECT
- }, self.server))
+ }))
if(client):
client.message(formatted_message)
@@ -372,7 +334,6 @@
self.message_channel(
channel, "PART", "%s :%s" % (channelname, partmsg),
True)
- self.channel_log(channel, "left (%s)" % partmsg, meta=True)
del self.channels[irc_lower(channelname)]
server.remove_member_from_channel(self, channelname)
@@ -405,8 +366,6 @@
self.message_channel(
channel, "TOPIC", "%s :%s" % (channelname, newtopic),
True)
- self.channel_log(
- channel, "set topic to %r" % newtopic, meta=True)
else:
if channel.topic:
self.reply("332 %s %s :%s"
@@ -464,17 +423,21 @@
def wot_handler():
if len(arguments) < 1:
# Display the current WOT
- peers = self.server.state.get_peers()
+ peers = self.state.get_peers()
if len(peers) > 0:
for peer in peers:
- self.pest_reply("%s %s:%s" % (string.join(peer.handles, ","), peer.address, peer.port))
+ if peer.address and peer.port:
+ address = "%s:%s" % (peer.address, peer.port)
+ else:
+ address = "
"
+ self.pest_reply("%s %s" % (string.join(peer.handles, ","), address))
else:
self.pest_reply("WOT is empty")
elif len(arguments) == 1:
# Display all WOT data concerning the peer identified by HANDLE,
# including all known keys, starting with the most recently used, for that peer.
handle = arguments[0]
- peer = self.server.state.get_peer_by_handle(handle)
+ peer = self.state.get_peer_by_handle(handle)
if peer:
self.pest_reply("keys:")
for key in peer.keys:
@@ -488,7 +451,7 @@
def peer_handler():
if len(arguments) == 1:
try:
- self.server.state.add_peer(arguments[0])
+ self.state.add_peer(arguments[0])
self.pest_reply("added new peer %s" % arguments[0])
self.message(":%s JOIN %s" % (arguments[0], self.server.channel_name))
except:
@@ -499,11 +462,11 @@
def unpeer_handler():
if len(arguments) == 1:
try:
- self.server.state.remove_peer(arguments[0])
+ self.state.remove_peer(arguments[0])
self.pest_reply("removed peer %s" % arguments[0])
self.message(":%s PART %s" % (arguments[0], self.server.channel_name))
except Exception, e:
- self.server.print_debug(e)
+ logging.debug(e)
self.pest_reply("Error attempting to remove peer")
else:
self.pest_reply("Usage: UNPEER ")
@@ -518,7 +481,7 @@
handle = arguments[0]
key = arguments[1]
try:
- self.server.state.add_key(handle, key)
+ self.state.add_key(handle, key)
self.pest_reply("added key: %s" % key)
except:
self.pest_reply("Error attempting to add key")
@@ -528,23 +491,23 @@
self.pest_reply("Usage: UNKEY ")
else:
try:
- self.server.state.remove_key(arguments[0])
+ self.state.remove_key(arguments[0])
self.pest_reply("removed key: %s" % arguments[0])
except Exception, e:
self.pest_reply("Error attempting to remove key")
- self.server.print_debug(e)
+ logging.debug(e)
def at_handler():
if len(arguments) == 0:
- at = self.server.state.get_at()
+ at = self.state.get_at()
elif len(arguments) == 1:
handle = arguments[0]
- at = self.server.state.get_at(handle)
+ at = self.state.get_at(handle)
elif len(arguments) == 2:
try:
handle, address = arguments
address_ip, port = string.split(address, ":")
- self.server.state.update_address_table({"handle": handle,
+ self.state.update_at({"handle": handle,
"address": address_ip,
"port": port},
False)
@@ -552,7 +515,7 @@
except Exception as ex:
self.pest_reply("Error attempting to update address table")
stack = traceback.format_exc()
- print(stack)
+ logger.debug(stack)
return
elif len(arguments) > 2:
self.pest_reply("Usage: AT [] []")
@@ -599,12 +562,12 @@
except KeyError:
self.reply("421 %s %s :Unknown command" % (self.nickname, command))
stack = traceback.format_exc()
- print(stack)
+ logger.debug(stack)
def socket_readable_notification(self):
try:
data = self.socket.recv(2 ** 10)
- self.server.print_debug(
+ logging.debug(
"[%s:%d] -> %r" % (self.host, self.port, data))
quitmsg = "EOT"
except socket.error as x:
@@ -621,7 +584,7 @@
def socket_writable_notification(self):
try:
sent = self.socket.send(self.__writebuffer)
- self.server.print_debug(
+ logging.debug(
"[%s:%d] <- %r" % (
self.host, self.port, self.__writebuffer[:sent]))
self.__writebuffer = self.__writebuffer[sent:]
@@ -630,7 +593,7 @@
def disconnect(self, quitmsg):
self.message("ERROR :%s" % quitmsg)
- self.server.print_info(
+ logging.info(
"Disconnected connection from %s:%s (%s)." % (
self.host, self.port, quitmsg))
self.socket.close()
@@ -654,31 +617,8 @@
def message_channel(self, channel, command, message, include_self=False):
line = ":%s %s %s" % (self.prefix, command, message)
- for client in channel.members:
- if client != self or include_self:
- client.message(line)
- # send the channel message to peers as well
- self.server.peer_message(
- Message(
- {
- "speaker": self.nickname,
- "command": BROADCAST,
- "bounces": 0,
- "body": line
- }, self.server))
-
- def channel_log(self, channel, message, meta=False):
- if not self.server.logdir:
- return
- if meta:
- format = "[%s] * %s %s\n"
- else:
- format = "[%s] <%s> %s\n"
- timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC")
- logname = channel.name.replace("_", "__").replace("/", "_")
- fp = open("%s/%s.log" % (self.server.logdir, logname), "a")
- fp.write(format % (timestamp, self.nickname, message))
- fp.close()
+ if include_self:
+ self.message(line)
def message_related(self, msg, include_self=False):
clients = set()
@@ -691,10 +631,6 @@
for client in clients:
client.message(msg)
- def send_lusers(self):
- self.reply("251 %s :There are %d users and 0 services on 1 server"
- % (self.nickname, len(self.server.clients)))
-
def send_motd(self):
server = self.server
motdlines = server.get_motd_lines()
diff -uNr a/blatta/lib/infosec.py b/blatta/lib/infosec.py
--- a/blatta/lib/infosec.py 2bebeb6ee55f0941c567e114ee01352f18f96d02ac4e2f037b9ced7e71c219cc2ed51ac615d28e22244c5e07f136ec7ce1fb565c7217c62de4708aee4d8b05d5
+++ b/blatta/lib/infosec.py 2f8e9df6cf92a779900080f585a1b9873218d949be48299950c2b05a64fd5a3abd8ab927f5c5d4c5d51e7c3def2bab13a6ffe532907e681883730d60bd75b0e5
@@ -16,6 +16,7 @@
import random
import os
import pprint
+import logging
pp = pprint.PrettyPrinter(indent=4)
PACKET_SIZE = 496
@@ -34,32 +35,66 @@
IGNORED = 4
class Infosec(object):
- def __init__(self, server=None):
- self.server = server
+ def __init__(self, state=None):
+ self.state = state
- def get_message_bytes(self, message, peer=None):
- try:
- timestamp = message.timestamp
- except:
- timestamp = None
- command = message.command
- speaker = self._pad(message.speaker, MAX_SPEAKER_SIZE)
+ def message(self, message):
+ # if we are not rebroadcasting we need to set the timestamp
- # if we are rebroadcasting we need to use the original timestamp
+ if message.timestamp == None:
+ message.original = True
+ message.timestamp = int(time.time())
+ else:
+ message.original = False
+
+ target_peer = (self.state.get_peer_by_handle(message.handle)
+ if message.command == DIRECT
+ else None)
+
+ if target_peer and not target_peer.get_key():
+ logging.debug("No key for peer associated with %s" % message.handle)
+ return
+
+ if message.command == DIRECT and target_peer == None:
+ logging.debug("Aborting message: unknown handle: %s" % message.handle)
+ return
+
+ message_bytes = self.get_message_bytes(message, target_peer)
+ if message.command != IGNORE:
+ message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest())
+ logging.debug("generated message_hash: %s" % message_hash)
+ self.state.add_to_dedup_queue(message_hash)
+ self.state.log(message.speaker, message_bytes, target_peer)
+
+ if message.command == DIRECT:
+ signed_packet_bytes = self.pack(target_peer, message, message_bytes)
+ target_peer.send(signed_packet_bytes)
+ elif message.command == BROADCAST or message.command == IGNORE:
+ for peer in self.state.get_keyed_peers():
+
+ # we don't want to send a broadcast back to the originator
- if(timestamp == None):
- int_ts = int(time.time())
+ if message.peer and (peer.peer_id == message.peer.peer_id):
+ next
+
+ signed_packet_bytes = self.pack(peer, message, message_bytes)
+ peer.send(signed_packet_bytes)
else:
- int_ts = timestamp
+ pass
+
+ def get_message_bytes(self, message, peer=None):
+ timestamp = message.timestamp
+ command = message.command
+ speaker = self._pad(message.speaker, MAX_SPEAKER_SIZE)
# let's generate the self_chain value from the last message or set it to zero if
# there this is the first message
if message.original:
if command == DIRECT:
- self_chain = self.server.state.get_last_message_hash(message.speaker, peer.peer_id)
+ self_chain = self.state.get_last_message_hash(message.speaker, peer.peer_id)
elif command == BROADCAST:
- self_chain = self.server.state.get_last_message_hash(message.speaker)
+ self_chain = self.state.get_last_message_hash(message.speaker)
elif command == IGNORE:
self_chain = "\x00" * 32
net_chain = "\x00" * 32
@@ -69,16 +104,19 @@
# pack message bytes
- message_bytes = struct.pack(MESSAGE_PACKET_FORMAT, int_ts, self_chain, net_chain, speaker, message.body)
+ if message.command != IGNORE:
+ logging.debug("packing message bytes: %s" % message.body)
+ else:
+ logging.debug("packing rubbish message bytes: %s" % binascii.hexlify(message.body))
+
+ message_bytes = struct.pack(MESSAGE_PACKET_FORMAT, message.timestamp, self_chain, net_chain, speaker, message.body)
return message_bytes
- def pack(self, peer, message):
+ def pack(self, peer, message, message_bytes):
key_bytes = base64.b64decode(peer.get_key())
signing_key = key_bytes[:32]
cipher_key = key_bytes[32:]
- message_bytes = self.get_message_bytes(message, peer)
-
# pack packet bytes
nonce = self._generate_nonce(16)
@@ -111,15 +149,15 @@
try:
black_packet_bytes, signature_bytes = struct.unpack(BLACK_PACKET_FORMAT, black_packet)
except:
- self.server.print_error("Discarding malformed black packet from %s" % peer.get_key())
- return Message({ "error_code": MALFORMED_PACKET }, self.server)
+ logging.error("Discarding malformed black packet from %s" % peer.get_key())
+ return Message({ "error_code": MALFORMED_PACKET })
# check signature
signature_check_bytes = hmac.new(signing_key, black_packet_bytes, hashlib.sha384).digest()
if(signature_check_bytes != signature_bytes):
- return Message({ "error_code": INVALID_SIGNATURE }, self.server)
+ return Message({ "error_code": INVALID_SIGNATURE })
# try to decrypt black packet
@@ -130,10 +168,27 @@
nonce, bounces, version, command, message_bytes = struct.unpack(RED_PACKET_FORMAT, red_packet_bytes)
+ # compute message_hash
+
+ message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest())
+
# unpack message
- int_ts, self_chain, net_chain, speaker, message = struct.unpack(MESSAGE_PACKET_FORMAT, message_bytes)
- speaker = speaker.strip()
+ int_ts, self_chain, net_chain, speaker, body = struct.unpack(MESSAGE_PACKET_FORMAT, message_bytes)
+
+ # remove padding from speaker
+
+ for index, byte in enumerate(speaker):
+ if byte == '\x00':
+ speaker = speaker[0:index]
+ break
+
+ # remove padding from body
+
+ for index, byte in enumerate(body):
+ if byte == '\x00':
+ body = body[0:index]
+ break
# nothing to be done for an IGNORE command
@@ -143,39 +198,26 @@
# check timestamp
if(int_ts not in self._ts_range()):
- return Message({ "error_code": STALE_PACKET }, self.server)
-
- # check for duplicates
-
- message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest())
- if(self.server.state.is_duplicate_message(message_hash)):
- return Message({ "error_code": DUPLICATE_PACKET }, self.server)
- else:
- self.server.state.add_to_dedup_queue(message_hash)
+ return Message({ "error_code": STALE_PACKET })
# check self_chain
if command == DIRECT:
- self_chain_check = self.server.state.get_last_message_hash(speaker, peer.peer_id)
+ self_chain_check = self.state.get_last_message_hash(speaker, peer.peer_id)
elif command == BROADCAST:
- self_chain_check = self.server.state.get_last_message_hash(speaker)
+ self_chain_check = self.state.get_last_message_hash(speaker)
self_chain_valid = (self_chain_check == self_chain)
# log this message for use in the self_chain check
- self.server.state.log(speaker, message_bytes, peer.peer_id if (command == DIRECT) else None)
+ self.state.log(speaker, message_bytes, peer if (command == DIRECT) else None)
- # remove padding from message bytes
-
- for index, byte in enumerate(message):
- if binascii.hexlify(byte) == "00":
- unpadded_message = message[0:index]
- break
+ # build message object
- return Message({
+ message = Message({
"peer": peer,
- "body": unpadded_message.rstrip(),
+ "body": body.rstrip(),
"timestamp": int_ts,
"command": command,
"speaker": speaker,
@@ -183,12 +225,19 @@
"self_chain": self_chain,
"net_chain": net_chain,
"self_chain_valid": self_chain_valid,
- "error_code": None
- },
- self.server)
+ "message_hash": message_hash
+ })
+
+ # check for duplicates
+
+ if(self.state.is_duplicate_message(message_hash)):
+ message.error_code = DUPLICATE_PACKET
+ return message
+
+ return message
def _pad(self, text, size):
- return text.ljust(size)
+ return text.ljust(size, "\x00")
def _ts_range(self):
current_ts = int(time.time())
diff -uNr a/blatta/lib/message.py b/blatta/lib/message.py
--- a/blatta/lib/message.py 0096d80e9d0c52787f1ad8c43d6b392c5e5434dfd04f62f41361365d858ea8f92b92f901ce4165e2b4fb4079b7aa9e857cf4296dee3eb48759cf63120f3975c5
+++ b/blatta/lib/message.py 0cf5cd14c7e157cf47cf2e4f0f9fd076e89a817a71b389aa8748a8472437e62a2ac35fa5f3b3ae9be39a2f15a926e5920571267bdf39b9ed6b4e0e4cbf9c2970
@@ -1,7 +1,7 @@
class Message(object):
- def __init__(self, message, server=None):
+ def __init__(self, message):
self.original = True
- self.server = server
+ self.prefix = None
self.handle = message.get("handle")
self.peer = message.get("peer")
self.body = message.get("body")
@@ -13,5 +13,4 @@
self.net_chain = message.get("net_chain")
self.self_chain_valid = message.get("self_chain_valid")
self.error_code = message.get("error_code")
- if server:
- self.state = server.state
+ self.message_hash = message.get("message_hash")
diff -uNr a/blatta/lib/peer.py b/blatta/lib/peer.py
--- a/blatta/lib/peer.py e763bb836eba69aedebd4d4adfdd8820e1a173f16e1fd493ddd16bf6d41718155c20ee38073570d2aeb0144ea281b6fbcc5ddde48d5ca60cb01a0b08103dd1f5
+++ b/blatta/lib/peer.py c96da174ae6ceb0489ed2b50872c02e481fa7d14d09a747b387bcf7ec45f1ec1dac5326a3370951f20c8995c85fcc7a8b91bc333cedaf78b11345a3897c8f95d
@@ -1,20 +1,21 @@
import socket
-from infosec import Infosec
from commands import IGNORE
+from commands import DIRECT
+from commands import BROADCAST
+
import sys
import binascii
import traceback
+import logging
class Peer(object):
- def __init__(self, server, peer_entry):
+ def __init__(self, socket, peer_entry):
self.handles = peer_entry["handles"]
self.keys = peer_entry["keys"]
self.peer_id = peer_entry["peer_id"]
- self.server = server
self.address = peer_entry["address"]
self.port = peer_entry["port"]
- self.socket = self.server.udp_server_socket
- self.infosec = Infosec(server)
+ self.socket = socket
def get_key(self):
if len(self.keys) > 0:
@@ -22,16 +23,16 @@
else:
return None
- def send(self, msg):
- try:
- if msg.command != IGNORE:
- self.server.print_debug("packing message: %s" % msg.body)
- signed_packet_bytes = self.infosec.pack(self, msg)
- self.socket.sendto(signed_packet_bytes, (self.address, self.port))
- self.server.print_debug("[%s:%d] <- %s" % (self.address,
- self.port,
- binascii.hexlify(signed_packet_bytes)[0:16]))
+ def send(self, signed_packet_bytes):
+ if self.get_key() != None:
+ try:
+ self.socket.sendto(signed_packet_bytes, (self.address, self.port))
+ logging.debug("[%s:%d] <- %s" % (self.address,
+ self.port,
+ binascii.hexlify(signed_packet_bytes)[0:16]))
- except Exception as ex:
- stack = traceback.format_exc()
- print(stack)
+ except Exception as ex:
+ stack = traceback.format_exc()
+ logging.debug(stack)
+ else:
+ logging.debug("Discarding message to unknown handle or handle with no key: %s" % message.handle)
diff -uNr a/blatta/lib/server.py b/blatta/lib/server.py
--- a/blatta/lib/server.py 16e7971b6eab7483a4060d5cae5111dec2f61618a2022620343ef7aa3fcedee87cc6499c9f9978215c315fde958e70fa7810f50967e97dd299cd98842118c12d
+++ b/blatta/lib/server.py 7f7198c51eb6b00321c1754f1675d907263bf600b8ef67b79641e4a763357fe73935b9f1534234d226f9564e2341aeaafb50bfcff5128bfe14404651b8a36ef2
@@ -1,4 +1,4 @@
-VERSION = "9988"
+VERSION = "9987"
import os
import select
@@ -8,29 +8,18 @@
import tempfile
import time
import string
-import binascii
-import hashlib
import datetime
+import sqlite3
from datetime import datetime
+from funcs import *
from lib.client import Client
-from lib.state import State
from lib.channel import Channel
-from lib.infosec import PACKET_SIZE
-from lib.infosec import MAX_BOUNCES
-from lib.infosec import STALE_PACKET
-from lib.infosec import DUPLICATE_PACKET
-from lib.infosec import MALFORMED_PACKET
-from lib.infosec import INVALID_SIGNATURE
-from lib.infosec import IGNORED
-from lib.infosec import Infosec
-from lib.peer import Peer
+from lib.station import Station
from lib.message import Message
-from funcs import *
-from commands import BROADCAST
-from commands import DIRECT
-from commands import IGNORE
+from lib.infosec import PACKET_SIZE
import imp
import pprint
+import logging
class Server(object):
def __init__(self, options):
@@ -40,18 +29,14 @@
self.password = options.password
self.motdfile = options.motd
self.verbose = options.verbose
- self.debug = options.debug
self.logdir = options.logdir
self.chroot = options.chroot
self.setuid = options.setuid
self.statedir = options.statedir
- self.infosec = Infosec(self)
self.config_file_path = options.config_file_path
- self.state = State(self, options.db_path)
self.pp = pprint.PrettyPrinter(indent=4)
-
- if options.address_table_path != None:
- self.state.import_at_and_wot(options.address_table_path)
+ self.db_path = options.db_path
+ self.address_table_path = options.address_table_path
if options.listen:
self.address = socket.gethostbyname(options.listen)
@@ -61,8 +46,9 @@
self.name = socket.getfqdn(self.address)[:server_name_limit]
self.channels = {} # irc_lower(Channel name) --> Channel instance.
- self.clients = {} # Socket --> Client instance..peers = ""
+ self.client = None
self.nicknames = {} # irc_lower(Nickname) --> Client instance.
+
if self.logdir:
create_directory(self.logdir)
if self.statedir:
@@ -79,7 +65,7 @@
try:
pid = os.fork()
if pid > 0:
- self.print_info("PID: %d" % pid)
+ logging.info("PID: %d" % pid)
sys.exit(0)
except OSError:
sys.exit(1)
@@ -113,19 +99,6 @@
else:
return []
- def print_info(self, msg):
- if self.verbose:
- print(msg)
- sys.stdout.flush()
-
- def print_debug(self, msg):
- if self.debug:
- print("%s %s" % (datetime.now(), msg))
- sys.stdout.flush()
-
- def print_error(self, msg):
- sys.stderr.write("%s\n" % msg)
-
def client_changed_nickname(self, client, oldnickname):
if oldnickname:
del self.nicknames[irc_lower(oldnickname)]
@@ -139,118 +112,26 @@
def remove_client(self, client, quitmsg):
client.message_related(":%s QUIT :%s" % (client.prefix, quitmsg))
for x in client.channels.values():
- client.channel_log(x, "quit (%s)" % quitmsg, meta=True)
x.remove_client(client)
if client.nickname \
and irc_lower(client.nickname) in self.nicknames:
del self.nicknames[irc_lower(client.nickname)]
- del self.clients[client.socket]
+ self.client = None
def remove_channel(self, channel):
del self.channels[irc_lower(channel.name)]
- def handle_udp_data(self, bytes_address_pair):
- data = bytes_address_pair[0]
- address = bytes_address_pair[1]
- packet_info = (address[0],
- address[1],
- binascii.hexlify(data)[0:16])
- self.print_debug("[%s:%d] -> %s" % packet_info)
- for peer in self.state.get_peers():
- if peer.get_key() != None:
- message = self.infosec.unpack(peer, data)
- error_code = message.error_code
- if(error_code == None):
- self.print_debug("[%s] -> %s" % (peer.handles[0], message.body))
-
- self.conditionally_update_address_table(peer, message, address)
- # send the message to all clients
- for c in self.clients:
- if (self.clients[c].is_addressed_to_me(message.body)):
- self.clients[c].message(message.body)
- # send the message to all other peers if it should be propagated
- if(message.command == BROADCAST) and message.bounces < MAX_BOUNCES:
- self.rebroadcast(peer, message)
- return
- elif error_code == STALE_PACKET:
- self.print_debug("[%s:%d] -> stale packet: %s" % packet_info)
- return
- elif error_code == DUPLICATE_PACKET:
- self.print_debug("[%s:%d] -> duplicate packet: %s" % packet_info)
- return
- elif error_code == MALFORMED_PACKET:
- self.print_debug("[%s:%d] -> malformed packet: %s" % packet_info)
- return
- elif error_code == IGNORED:
- self.conditionally_update_address_table(peer, message, address)
- self.print_debug("[%s:%d] -> ignoring packet: %s" % packet_info)
- return
- elif error_code == INVALID_SIGNATURE:
- pass
- self.print_debug("[%s:%d] -> martian packet: %s" % packet_info)
-
- # we only update the address table if the speaker is same as peer
-
- def conditionally_update_address_table(self, peer, message, address):
- try:
- idx = peer.handles.index(message.speaker)
- except:
- idx = None
-
- if idx != None:
- self.state.update_address_table({"handle": message.speaker,
- "address": address[0],
- "port": address[1]
- })
- def peer_message(self, message):
- message.original = True
- if message.command == DIRECT:
- peer = self.state.get_peer_by_handle(message.handle)
- message_bytes = self.infosec.get_message_bytes(message, peer)
- message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest())
- self.state.add_to_dedup_queue(message_hash)
-
- self.state.log(message.speaker, message_bytes, peer.peer_id)
- if peer and (peer.get_key() != None):
- peer.send(message)
- else:
- self.print_debug("Discarding message to unknown handle or handle with no key: %s" % message.handle)
- else:
- message.timestamp = int(time.time())
- message_bytes = self.infosec.get_message_bytes(message)
- if message.command != IGNORE:
- self.state.log(message.speaker, message_bytes)
- message_hash = binascii.hexlify(hashlib.sha256(message_bytes).digest())
- self.state.add_to_dedup_queue(message_hash)
- for peer in self.state.get_peers():
- if peer.get_key() != None:
- peer.send(message)
- else:
- self.print_debug("Discarding message to handle with no key: %s" % message.handle)
-
- def rebroadcast(self, source_peer, message):
- message.original = False
- for peer in self.state.get_peers():
- if(peer.peer_id != source_peer.peer_id):
- message.command = BROADCAST
- message.bounces = message.bounces + 1
- peer.send(message)
-
-
- def sendrubbish(self):
- for socket in self.clients:
- self.peer_message(Message({
- "speaker": self.clients[socket].nickname,
- "command": IGNORE,
- "bounces": 0,
- "body": self.infosec.gen_rubbish_body()
- }, self))
-
def start(self):
# Setup UDP first
self.udp_server_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
self.udp_server_socket.bind((self.address, self.udp_port))
- self.print_info("Listening for Pest packets on udp port %d." % self.udp_port)
+ self.station = Station({ "socket": self.udp_server_socket,
+ "db_path": self.db_path,
+ "address_table_path": self.address_table_path
+ })
+ self.station.start_embargo_queue_checking()
+ self.station.start_rubbish()
+ logging.info("Listening for Pest packets on udp port %d." % self.udp_port)
serversockets = []
for port in self.irc_ports:
@@ -259,51 +140,55 @@
try:
s.bind((self.address, port))
except socket.error as e:
- self.print_error("Could not bind port %s: %s." % (port, e))
+ logging.error("Could not bind port %s: %s." % (port, e))
sys.exit(1)
s.listen(5)
serversockets.append(s)
del s
- self.print_info("Listening for IRC connections on port %d." % port)
+ logging.info("Listening for IRC connections on port %d." % port)
if self.chroot:
os.chdir(self.chroot)
os.chroot(self.chroot)
- self.print_info("Changed root directory to %s" % self.chroot)
+ logging.info("Changed root directory to %s" % self.chroot)
if self.setuid:
os.setgid(self.setuid[1])
os.setuid(self.setuid[0])
- self.print_info("Setting uid:gid to %s:%s"
+ logging.info("Setting uid:gid to %s:%s"
% (self.setuid[0], self.setuid[1]))
last_aliveness_check = time.time()
while True:
(inputready,outputready,exceptready) = select.select([self.udp_server_socket],[],[],0)
(iwtd, owtd, ewtd) = select.select(
- serversockets + [x.socket for x in self.clients.values()],
- [x.socket for x in self.clients.values()
- if x.write_queue_size() > 0],
+ serversockets + ([self.client.socket] if self.client else []),
+ [self.client.socket] if self.client and self.client.write_queue_size() > 0 else [],
[],
.2)
for x in inputready:
- if x == self.udp_server_socket:
- bytes_address_pair = self.udp_server_socket.recvfrom(PACKET_SIZE)
- self.handle_udp_data(bytes_address_pair)
+ if x == self.udp_server_socket:
+ bytes_address_pair = self.udp_server_socket.recvfrom(PACKET_SIZE)
+ self.station.embargo_queue_lock.acquire()
+ try:
+ self.station.handle_udp_data(bytes_address_pair)
+ except sqlite3.ProgrammingError as ex:
+ logging.error("sqlite3 concurrency problem")
+ self.station.embargo_queue_lock.release()
for x in iwtd:
- if x in self.clients:
- self.clients[x].socket_readable_notification()
+ if self.client != None:
+ self.client.socket_readable_notification()
else:
(conn, addr) = x.accept()
- self.clients[conn] = Client(self, conn)
- self.print_info("Accepted connection from %s:%s." % (
+ self.client = Client(self, conn)
+ self.station.client = self.client
+ logging.info("Accepted connection from %s:%s." % (
addr[0], addr[1]))
for x in owtd:
- if x in self.clients: # client may have been disconnected
- self.clients[x].socket_writable_notification()
+ if self.client and x == self.client.socket: # client may have been disconnected
+ self.client.socket_writable_notification()
now = time.time()
if last_aliveness_check + 10 < now:
- for client in self.clients.values():
- client.check_aliveness()
- last_aliveness_check = now
- self.sendrubbish() # Kludge to keep ephemeral port open when NATed
+ if self.client:
+ self.client.check_aliveness()
+ last_aliveness_check = now
def create_directory(path):
if not os.path.isdir(path):
diff -uNr a/blatta/lib/state.py b/blatta/lib/state.py
--- a/blatta/lib/state.py acd5eaffdba356d5b2b2e0ce494e3be8aed35ccf0b96f9605bfd73fd3f758286f1908d043274a5480ac02c2d270550e1b061b32c0856e521a4eaba2f9f6b29f3
+++ b/blatta/lib/state.py 4f78202d4744a3284c00c4aac9c055f4abae95eea1c51c4acd519a9723a990d4d1fc336254140f75b7d75995a781935a2a6250dd19c7e7610b6643365a47938f
@@ -2,42 +2,53 @@
import sqlite3
import imp
import hashlib
+import logging
from itertools import chain
class State(object):
-
- def __init__(self, server, db_path):
- self.server = server
- self.conn = sqlite3.connect(db_path)
- self.cursor = self.conn.cursor()
- self.cursor.execute("create table if not exists at(handle_id integer,\
- address text not null,\
- port integer not null,\
- active_at datetime default null,\
- updated_at datetime default current_timestamp,\
- unique(handle_id, address, port))")
-
- self.cursor.execute("create table if not exists wot(peer_id integer primary key)")
-
- self.cursor.execute("create table if not exists handles(handle_id integer primary key,\
- peer_id integer,\
- handle text,\
- unique(handle))")
-
- self.cursor.execute("create table if not exists keys(peer_id intenger,\
- key text,\
- used_at datetime default current_timestamp,\
- unique(key))")
-
- self.cursor.execute("create table if not exists logs(\
- handle text not null,\
- peer_id integer,\
- message_bytes blob not null,\
- created_at datetime default current_timestamp)")
-
- self.cursor.execute("create table if not exists dedup_queue(\
- hash text not null,\
- created_at datetime default current_timestamp)")
+ __instance = None
+ @staticmethod
+ def get_instance(socket=None, db_path=None):
+ if State.__instance == None:
+ State(socket, db_path)
+ return State.__instance
+
+ def __init__(self, socket, db_path):
+ if State.__instance != None:
+ raise Exception("This class is a singleton")
+ else:
+ self.socket = socket
+ self.conn = sqlite3.connect(db_path, check_same_thread=False)
+ self.cursor = self.conn.cursor()
+ self.cursor.execute("create table if not exists at(handle_id integer,\
+ address text not null,\
+ port integer not null,\
+ active_at datetime default null,\
+ updated_at datetime default current_timestamp,\
+ unique(handle_id, address, port))")
+
+ self.cursor.execute("create table if not exists wot(peer_id integer primary key)")
+
+ self.cursor.execute("create table if not exists handles(handle_id integer primary key,\
+ peer_id integer,\
+ handle text,\
+ unique(handle))")
+
+ self.cursor.execute("create table if not exists keys(peer_id intenger,\
+ key text,\
+ used_at datetime default current_timestamp,\
+ unique(key))")
+
+ self.cursor.execute("create table if not exists logs(\
+ handle text not null,\
+ peer_id integer,\
+ message_bytes blob not null,\
+ created_at datetime default current_timestamp)")
+
+ self.cursor.execute("create table if not exists dedup_queue(\
+ hash text not null,\
+ created_at datetime default current_timestamp)")
+ State.__instance = self
def get_at(self, handle=None):
at = []
@@ -60,7 +71,7 @@
(handle_id,)).fetchone()[0]
at.append({"handle": h,
"address": "%s:%s" % (address, port),
- "active_at": updated_at})
+ "active_at": updated_at if updated_at else "no packets received from this address"})
return at
@@ -69,6 +80,7 @@
self.conn.commit()
result = self.cursor.execute("select hash from dedup_queue where hash=?",
(message_hash,)).fetchone()
+ logging.debug("checking if %s is dupe" % message_hash)
if(result != None):
return True
else:
@@ -78,6 +90,7 @@
self.cursor.execute("insert into dedup_queue(hash)\
values(?)",
(message_hash,))
+ logging.debug("added %s to dedup" % message_hash)
self.conn.commit()
def get_last_message_hash(self, handle, peer_id=None):
@@ -96,9 +109,14 @@
if message_bytes:
return hashlib.sha256(message_bytes[0][:]).digest()
else:
- return "0" * 32
+ return "\x00" * 32
+
+ def log(self, handle, message_bytes, peer=None):
+ if peer != None:
+ peer_id = peer.peer_id
+ else:
+ peer_id = None
- def log(self, handle, message_bytes, peer_id=None):
self.cursor.execute("insert into logs(handle, peer_id, message_bytes)\
values(?, ?, ?)",
(handle, peer_id, buffer(message_bytes)))
@@ -124,7 +142,7 @@
self.conn.commit()
- def update_address_table(self, peer, set_active_at=True):
+ def update_at(self, peer, set_active_at=True):
row = self.cursor.execute("select handle_id from handles where handle=?",
(peer["handle"],)).fetchone()
if row != None:
@@ -196,7 +214,7 @@
(peer_id,)).fetchall()))
def get_peer_handles(self):
- handles = list(chain.from_iterable(self.cursor.execute("select handle from handles").fetchall()))
+ handles = self.listify(self.cursor.execute("select handle from handles").fetchall())
return handles
def get_peers(self):
@@ -209,6 +227,20 @@
peers.append(peer)
return peers
+ def listify(self, results):
+ return list(chain.from_iterable(results))
+
+ def get_keyed_peers(self):
+ peer_ids = self.listify(self.cursor.execute("select peer_id from keys").fetchall())
+ peers = []
+ for peer_id in peer_ids:
+ handle = self.cursor.execute("select handle from handles where peer_id=?", (peer_id,)).fetchone()[0]
+ peer = self.get_peer_by_handle(handle)
+ if not (self.is_duplicate(peers, peer)):
+ peers.append(peer)
+ return peers
+
+
def get_peer_by_handle(self, handle):
handle_info = self.cursor.execute("select handle_id, peer_id from handles where handle=?",
(handle,)).fetchone()
@@ -219,18 +251,19 @@
address = self.cursor.execute("select address, port from at where handle_id=?\
order by updated_at desc limit 1",
(handle_info[0],)).fetchone()
- handles = list(chain.from_iterable(self.cursor.execute("select handle from handles where peer_id=?",
- (handle_info[1],)).fetchall()))
- keys = list(chain.from_iterable(self.cursor.execute("select key from keys where peer_id=?\
+ handles = self.listify(self.cursor.execute("select handle from handles where peer_id=?",
+ (handle_info[1],)).fetchall())
+ keys = self.listify(self.cursor.execute("select key from keys where peer_id=?\
order by used_at desc",
- (handle_info[1],)).fetchall()))
- return Peer(self.server, {
+ (handle_info[1],)).fetchall())
+ return Peer(self.socket, {
"handles": handles,
"peer_id": handle_info[1],
"address": address[0] if address else "",
"port": address[1] if address else "",
"keys": keys
})
+
def is_duplicate(self, peers, peer):
for existing_peer in peers:
if existing_peer.address == peer.address and existing_peer.port == peer.port:
diff -uNr a/blatta/lib/station.py b/blatta/lib/station.py
--- a/blatta/lib/station.py false
+++ b/blatta/lib/station.py 9e41fdd532e857cec8e4d3407560d8570b8e6b7b713739e6d81622b0a6abcbe5a74a9e1ce70f192be2ce5f5c2d0b2374e433bcb7a1d83e322c24460adf92723a
@@ -0,0 +1,196 @@
+import time
+import threading
+import binascii
+import logging
+import os
+from lib.state import State
+from lib.infosec import MAX_BOUNCES
+from lib.infosec import STALE_PACKET
+from lib.infosec import DUPLICATE_PACKET
+from lib.infosec import MALFORMED_PACKET
+from lib.infosec import INVALID_SIGNATURE
+from lib.infosec import IGNORED
+from lib.infosec import Infosec
+from commands import IGNORE
+from lib.message import Message
+from commands import BROADCAST
+from commands import DIRECT
+from lib.peer import Peer
+
+RUBBISH_INTERVAL = 10
+
+class Station(object):
+ def __init__(self, options):
+ self.client = None
+ self.state = State.get_instance(options["socket"], options["db_path"])
+ if options.get("address_table_path") != None:
+ self.state.import_at_and_wot(options.get("address_table_path"))
+ self.infosec = Infosec(self.state)
+ self.embargo_queue = {}
+ self.embargo_queue_lock = threading.Lock()
+
+ def start_embargo_queue_checking(self):
+ threading.Thread(target=self.check_embargo_queue).start()
+
+ def start_rubbish(self):
+ pass
+ threading.Thread(target=self.send_rubbish).start()
+
+ def handle_udp_data(self, bytes_address_pair):
+ data = bytes_address_pair[0]
+ address = bytes_address_pair[1]
+ packet_info = (address[0],
+ address[1],
+ binascii.hexlify(data)[0:16])
+ logging.debug("[%s:%d] -> %s" % packet_info)
+ for peer in self.state.get_keyed_peers():
+ message = self.infosec.unpack(peer, data)
+ error_code = message.error_code
+ if(error_code == None):
+ logging.debug("%s(%s) -> %s bounces: %d" % (message.speaker, peer.handles[0], message.body, message.bounces))
+ self.conditionally_update_at(peer, message, address)
+
+ # if this is a direct message, just deliver it and return
+ if message.command == DIRECT:
+ self.deliver(message)
+ return
+
+ # if the speaker is in our wot, we need to check if the message is hearsay
+ if message.speaker in self.state.get_peer_handles():
+ self.embargo(message)
+ return
+
+ else:
+ # skip the embargo and deliver this message with appropriate simple hearsay labeling
+ message.prefix = "%s[%s]" % (message.speaker, peer.handles[0])
+ self.deliver(message)
+ return
+ elif error_code == STALE_PACKET:
+ logging.debug("[%s:%d] -> stale packet: %s" % packet_info)
+ return
+ elif error_code == DUPLICATE_PACKET:
+ logging.debug("[%s:%d] -> duplicate packet: %s" % packet_info)
+ return
+ elif error_code == MALFORMED_PACKET:
+ logging.debug("[%s:%d] -> malformed packet: %s" % packet_info)
+ return
+ elif error_code == IGNORED:
+ self.conditionally_update_at(peer, message, address)
+ logging.debug("[%s:%d] -> ignoring packet: %s" % packet_info)
+ return
+ elif error_code == INVALID_SIGNATURE:
+ pass
+ logging.debug("[%s:%d] -> martian packet: %s" % packet_info)
+
+ def deliver(self, message):
+ # add to duplicate queue
+ self.state.add_to_dedup_queue(message.message_hash)
+
+ # send to the irc client
+ if self.client:
+ self.client.message_from_station(message)
+
+ def embargo(self, message):
+ # initialize the key/value to empty array if not in the hash
+ # append message to array
+ if not message.message_hash in self.embargo_queue.keys():
+ self.embargo_queue[message.message_hash] = []
+ self.embargo_queue[message.message_hash].append(message)
+
+ def check_embargo_queue(self):
+ # get a lock so other threads can't mess with the db or the queue
+ self.embargo_queue_lock.acquire()
+ self.check_for_immediate_messages()
+ self.flush_hearsay_messages()
+
+ # release the lock
+ self.embargo_queue_lock.release()
+
+ # continue the thread loop after interval
+ time.sleep(1)
+ threading.Thread(target=self.check_embargo_queue).start()
+
+ def check_for_immediate_messages(self):
+ for key in dict(self.embargo_queue).keys():
+ messages = self.embargo_queue[key]
+
+ for message in messages:
+
+ # if this is an immediate copy of the message
+
+ if message.speaker in message.peer.handles:
+
+ # clear the queue and deliver
+
+ self.embargo_queue.pop(key, None)
+ self.deliver(message)
+ self.rebroadcast(message)
+ break
+
+
+ def flush_hearsay_messages(self):
+ # if we made it this far either we haven't found any immediate messages
+ # or we sent them all so we must deliver the remaining hearsay messages
+ # with the appropriate labeling
+ for key in dict(self.embargo_queue).keys():
+
+ # collect the source handles
+ handles = []
+ messages = self.embargo_queue[key]
+ for message in messages:
+ handles.append(message.peer.handles[0])
+
+ # select the message with the lowest bounce count
+ message = sorted(messages, key=lambda m: m.bounces)[0]
+
+ # clear the queue
+ self.embargo_queue.pop(key, None)
+
+ # compute prefix
+ if len(messages) < 4:
+ message.prefix = "%s[%s]" % (message.speaker, "|".join(handles))
+ else:
+ message.prefix = "%s[%d]" % (message.speaker, len(messages))
+
+ # deliver
+ self.deliver(message)
+
+ # send the message to all other peers if it should be propagated
+ self.rebroadcast(message)
+
+
+ # we only update the address table if the speaker is same as peer
+
+ def conditionally_update_at(self, peer, message, address):
+ if message.speaker in peer.handles:
+ self.state.update_at({
+ "handle": message.speaker,
+ "address": address[0],
+ "port": address[1]
+ })
+
+ def rebroadcast(self, message):
+ if message.bounces < MAX_BOUNCES:
+ message.command = BROADCAST
+ message.bounces = message.bounces + 1
+ self.infosec.message(message)
+ else:
+ logging.debug("[%s:%d] -> packet TTL expired: %s" % packet_info)
+
+
+ def send_rubbish(self):
+ logging.debug("sending rubbish...")
+ self.embargo_queue_lock.acquire()
+ try:
+ if self.client:
+ self.infosec.message(Message({
+ "speaker": self.client.nickname,
+ "command": IGNORE,
+ "bounces": 0,
+ "body": self.infosec.gen_rubbish_body()
+ }))
+ except:
+ logging.error("Something went wrong attempting to send rubbish")
+ self.embargo_queue_lock.release()
+ time.sleep(RUBBISH_INTERVAL)
+ threading.Thread(target=self.send_rubbish).start()
diff -uNr a/blatta/start_test_net.sh b/blatta/start_test_net.sh
--- a/blatta/start_test_net.sh 10233fa2a74d0f92f3215b417140a9481f1263ceb7ca4486cca97d48e9c112a36a9b66cb4f2c99a553626dea431d6d8ae6d22735bd2535b8bde7ea964a1f0b21
+++ b/blatta/start_test_net.sh 24a5c19318989da9f79790107499e2ebda16bc5389b739e4e3ae686c3ff024317517203b9c5c3324ae1a391a63f94939e22c8de730e758ecbc6afee4f54e108d
@@ -1,6 +1,6 @@
#!/bin/bash
# start 3 servers on different ports
-./blatta --debug --channel-name \#aleth --irc-port 6668 --udp-port 7778 --db-path a.db --address-table-path test_net_configs/a.py > logs/a &
+./blatta --debug --channel-name \#aleth --irc-port 9968 --udp-port 7778 --db-path a.db --address-table-path test_net_configs/a.py > logs/a &
./blatta --debug --channel-name \#aleth --irc-port 6669 --udp-port 7779 --db-path b.db --address-table-path test_net_configs/b.py > logs/b &
./blatta --debug --channel-name \#aleth --irc-port 6670 --udp-port 7780 --db-path c.db --address-table-path test_net_configs/c.py > logs/c &
diff -uNr a/blatta/test_net_configs/a.py b/blatta/test_net_configs/a.py
--- a/blatta/test_net_configs/a.py 3276661a7529957fb3d7aac616f26be8de21d436ae5092c40662bc2fca472a6be7e460a2d8c76286a8d84bac1e8a8bf94b98e086c9df5a8a9389ff8c9efec8b9
+++ b/blatta/test_net_configs/a.py 27bfacb1a2f3d5c0c9947045e0dbf61d2822c0da84be7b9589b261d79fd3b9b2a845d354fc0310aae2841d5dc0b1d8ff22960d33d779dfbc6e0680bd33424d27
@@ -4,9 +4,9 @@
'name': 'awt_b',
'port': 7779
},
- { 'address': 'localhost',
- 'key': 'lT8/fYe/rQdReyavsTrVqInnLFCaU38o2ZAn5+r8uoFSSWgJelafikFELR9t6SJHMpFQvLmlAbF14nL2PfOAyA==',
- 'name': 'awt_c',
- 'port': 7780
- }
+# { 'address': 'localhost',
+# 'key': 'lT8/fYe/rQdReyavsTrVqInnLFCaU38o2ZAn5+r8uoFSSWgJelafikFELR9t6SJHMpFQvLmlAbF14nL2PfOAyA==',
+# 'name': 'awt_c',
+# 'port': 7780
+# }
]
diff -uNr a/blatta/tests/__init__.py b/blatta/tests/__init__.py
--- a/blatta/tests/__init__.py false
+++ b/blatta/tests/__init__.py 85df4eea67226c8976c9484f97e06ee93506c5e43982babfe99ae2a075f4e1f43f99442cc80897e3ad2d8a409ae59320410347ca74c4317674b815082de8b240
@@ -0,0 +1 @@
+# This file can't be empty otherwise diff won't see it.
diff -uNr a/blatta/tests/test_station.py b/blatta/tests/test_station.py
--- a/blatta/tests/test_station.py false
+++ b/blatta/tests/test_station.py 991e1ff9817a01d4320abdf30e09c890b5454c726e313a488dd6bedc9e8e663019a63caf9fc16979251f653beda05ececa4a806a7e3c299e6847a8b6fe11a6e4
@@ -0,0 +1,194 @@
+# https://stackoverflow.com/questions/1896918/running-unittest-with-typical-test-directory-structure
+import unittest
+import logging
+from mock import Mock
+from mock import patch
+
+from lib.station import Station
+
+class TestStation(unittest.TestCase):
+ def setUp(self):
+ logging.basicConfig(level=logging.DEBUG)
+ options = {
+ "clients": {"clientsocket": Mock()},
+ "db_path": "tests/test.db",
+ "socket": Mock()
+ }
+ self.station = Station(options)
+ self.station.deliver = Mock()
+ self.station.rebroadcast = Mock()
+ self.station.rebroadcast.return_value = "foobar"
+
+ def tearDown(self):
+ pass
+
+ def test_embargo_bounce_ordering(self):
+ peer1 = Mock()
+ peer1.handles = ["a", "b"]
+ peer2 = Mock()
+ peer2.handles = ["c", "d"]
+ low_bounce_message = Mock()
+ low_bounce_message.peer = peer1
+ low_bounce_message.bounces = 1
+ low_bounce_message.message_hash = "messagehash"
+ high_bounce_message = Mock()
+ high_bounce_message.peer = peer2
+ high_bounce_message.bounces = 2
+ high_bounce_message.message_hash = "messagehash"
+ self.station.embargo_queue = {
+ "messagehash": [
+ low_bounce_message,
+ high_bounce_message
+ ],
+ }
+ self.station.flush_hearsay_messages()
+ self.station.deliver.assert_called_once_with(low_bounce_message)
+ self.station.rebroadcast.assert_called_once_with(low_bounce_message)
+
+ def test_immediate_message_delivered(self):
+ peer = Mock()
+ peer.handles = ["a", "b"]
+ message = Mock()
+ message.speaker = "a"
+ message.peer = peer
+ self.station.embargo_queue = {
+ "messagehash": [
+ message
+ ],
+ }
+ self.station.check_for_immediate_messages()
+ self.station.deliver.assert_called_once_with(message)
+ self.station.rebroadcast.assert_called_once_with(message)
+
+ def test_hearsay_message_not_delivered(self):
+ peer = Mock()
+ peer.handles = ["a", "b"]
+ message = Mock()
+ message.speaker = "c"
+ message.peer = peer
+ self.station.embargo_queue = {
+ "messagehash": [
+ message
+ ],
+ }
+ self.station.check_for_immediate_messages()
+ self.station.deliver.assert_not_called()
+
+ def test_embargo_queue_cleared(self):
+ peer = Mock()
+ peer.handles = ["a", "b"]
+ message = Mock()
+ message.speaker = "c"
+ message.peer = peer
+ self.station.embargo_queue = {
+ "messagehash": [
+ message
+ ],
+ }
+ self.assertEqual(len(self.station.embargo_queue), 1)
+ self.station.flush_hearsay_messages()
+ self.assertEqual(len(self.station.embargo_queue), 0)
+
+ def test_immediate_prefix(self):
+ peer = Mock()
+ peer.handles = ["a", "b"]
+ message = Mock()
+ message.speaker = "a"
+ message.prefix = None
+ message.peer = peer
+ self.station.embargo_queue = {
+ "messagehash": [
+ message
+ ],
+ }
+ self.station.check_for_immediate_messages()
+ self.assertEqual(message.prefix, None)
+
+ def test_simple_hearsay_prefix(self):
+ peer = Mock()
+ peer.handles = ["a", "b"]
+ message = Mock()
+ message.speaker = "c"
+ message.prefix = None
+ message.peer = peer
+ self.station.embargo_queue = {
+ "messagehash": [
+ message
+ ],
+ }
+ self.station.flush_hearsay_messages()
+ self.assertEqual(message.prefix, "c[a]")
+
+ def test_in_wot_hearsay_prefix_under_four(self):
+ peer1 = Mock()
+ peer1.handles = ["a", "b"]
+ peer2 = Mock()
+ peer2.handles = ["d", "e"]
+ peer3 = Mock()
+ peer3.handles = ["f", "g"]
+ message_via_peer1 = Mock()
+ message_via_peer1.speaker = "c"
+ message_via_peer1.prefix = None
+ message_via_peer1.peer = peer1
+ message_via_peer1.bounces = 1
+ message_via_peer2 = Mock()
+ message_via_peer2.speaker = "c"
+ message_via_peer2.prefix = None
+ message_via_peer2.peer = peer2
+ message_via_peer2.bounces = 2
+ message_via_peer3 = Mock()
+ message_via_peer3.speaker = "c"
+ message_via_peer3.prefix = None
+ message_via_peer3.peer = peer3
+ message_via_peer3.bounces = 1
+ self.station.embargo_queue = {
+ "messagehash": [
+ message_via_peer1,
+ message_via_peer2,
+ message_via_peer3
+ ],
+ }
+ self.station.flush_hearsay_messages()
+ self.station.deliver.assert_called_once_with(message_via_peer1)
+ self.assertEqual(message_via_peer1.prefix, "c[a|d|f]")
+
+ def test_in_wot_hearsay_prefix_more_than_three(self):
+ peer1 = Mock()
+ peer1.handles = ["a", "b"]
+ peer2 = Mock()
+ peer2.handles = ["d", "e"]
+ peer3 = Mock()
+ peer3.handles = ["f", "g"]
+ peer4 = Mock()
+ peer4.handles = ["f", "g"]
+ message_via_peer1 = Mock()
+ message_via_peer1.speaker = "c"
+ message_via_peer1.prefix = None
+ message_via_peer1.peer = peer1
+ message_via_peer1.bounces = 1
+ message_via_peer2 = Mock()
+ message_via_peer2.speaker = "c"
+ message_via_peer2.prefix = None
+ message_via_peer2.peer = peer2
+ message_via_peer2.bounces = 2
+ message_via_peer3 = Mock()
+ message_via_peer3.speaker = "c"
+ message_via_peer3.prefix = None
+ message_via_peer3.peer = peer3
+ message_via_peer3.bounces = 1
+ message_via_peer4 = Mock()
+ message_via_peer4.speaker = "c"
+ message_via_peer4.prefix = None
+ message_via_peer4.peer = peer4
+ message_via_peer4.bounces = 1
+ self.station.embargo_queue = {
+ "messagehash": [
+ message_via_peer1,
+ message_via_peer2,
+ message_via_peer3,
+ message_via_peer4
+ ],
+ }
+ self.station.flush_hearsay_messages()
+ self.station.deliver.assert_called_once_with(message_via_peer1)
+ self.assertEqual(message_via_peer1.prefix, "c[4]")