diff options
Diffstat (limited to 'main.py')
-rwxr-xr-x | main.py | 338 |
1 files changed, 338 insertions, 0 deletions
@@ -0,0 +1,338 @@ +#!/bin/env python +import ipaddress +import os +import shutil +import subprocess +import sys +import yaml +from pyroute2 import NDB, netns + +class SubnetManager: + def __init__(self, config): + self.base = config['base'] + self.zone_size = config['zone'] + self.local_size = config['local'] + if ipaddress.ip_address(self.base).version == 6: + self.prefix = 128 - self.zone_size - self.local_size + else: + self.prefix = 32 - self.zone_size - self.local_size + self.networks = ipaddress.ip_network((self.base, self.prefix)).subnets(self.zone_size) + + self.current_net = next(self.networks).hosts() + + def next_local(self): + return next(self.current_net) + + def next_zone(self): + self.current_net = next(self.networks).hosts() + + def __str__(self): + return f'SubnetManager{{base: {self.base}, zone: {self.zone_size}, local: {self.local_size}}}' + +class Latency: + def __init__(self, latency, offset = None): + if type(latency) is int: + self.latency_ms = latency + else: + for suffix, factor in [("ms", 1), ("s", 1000), ("m", 60 * 1000)]: + if latency.endswith(suffix): + self.latency_ms = int(float(latency[:-len(suffix)]) * factor) + break + else: + self.latency_ms = int(latency) + if offset: + self.latency_ms -= Latency(offset).latency_ms + if self.latency_ms < 0: + self.latency_ms = 0 + + def __eq__(self, o): + return isinstance(o, Latency) and o.latency_ms == self.latency_ms + + def __str__(self): + return f'{self.latency_ms}ms' + +class Bandwidth: + def __init__(self, bw): + def convert(bw): + factor = 1 + for suffix, f in [("bit", 1), ("bps",1), ("b", 1), ("byte", 8), ("Byte",8), ("B", 8)]: + if bw.endswith(suffix): + bw = bw[:-len(suffix)] + factor = f + break + + for suffix, f in [("k", 1000), ("ki", 1024), ("m", 1000**2), ("mi", 1024**2), ("g", 1000**3), ("gi", 1024**3)]: + if bw.lower().endswith(suffix): + return int(float(bw[:-len(suffix)]) * factor * f) + else: + return int(float(bw) * factor) + if type(bw) is dict: + self.down = convert(bw["down"]) + self.up = convert(bw["up"]) + else: + self.down = convert(bw) + self.up = convert(bw) + + def __str__(self): + def convert(bw): + for suffix, factor in [("g", 1000**3), ("m", 1000**2), ("k", 1000)]: + if bw > 10 * factor: + return f'{bw/factor:.1f}{suffix}bps' + return f'{convert(self.down)}/{convert(self.up)}' + + def __eq__(self, o): + return (isinstance(o, Bandwidth) and + o.down == self.down and + o.up == self.up) + + +class LinkInfo: + def __init__(self, bandwidth, latency, jitter = None, offset = None, **kwargs): + self.bandwidth = Bandwidth(bandwidth) + self.latency = Latency(latency, offset) + self.jitter = Latency(jitter or 0) + + def __eq__(self, o): + return (isinstance(o, LinkInfo) and + o.bandwidth == self.bandwidth and + o.latency == self.latency and + o.jitter == self.jitter) + + def __str__(self): + return f'LinkInfo{{bw: {self.bandwidth}, latency: {self.latency}, jitter: {self.jitter}}}' + +class Server: + def __init__(self, name, link): + self.ip = None + self.name = name + self.link = link + + def is_zone(self): + return False + + def __str__(self): + return f'Server{{name: {self.name}, ip: {self.ip}, link: {self.link}}}' + + def __repr__(self): + return self.__str__() + +class Zone: + def __init__(self, name): + self.name = name + self.link = None + self.servers = {} + + def is_zone(self): + return True + + def add_server(self, server): + if self.servers.get(server.name) is None: + self.servers[server.name] = server + else: + raise Exception(f"Duplicate server '{server.name}' in zone '{self.name}'") + + def set_link(self, link): + if self.link is None: + self.link = link + elif self.link != link: + raise Exception(f"Uncoherent link configuration for zone '{self.name}'") + + def __str__(self): + return f'Zone{{name: {self.name}, link: {self.link}, servers: {list(self.servers.values())}}}' + + def __repr__(self): + return self.__str__() + + +class Network: + def __init__(self): + self.zones = {} + self.subnet_manager = None + self.latency_off = Latency(0) + + def set_subnet_manager(self, subnet): + self.subnet_manager = SubnetManager(subnet) + + def set_latency_offset(self, latency_off): + self.latency_off = latency_off + + def add_server(self, server): + name = server["name"] + if zone_obj := server.get("zone"): + zone_name = zone_obj["name"] + if zone := self.zones.get(zone_name): + if not zone.is_zone(): + raise Exception("Duplicate zone: " + name) + else: + zone = Zone(zone_name) + self.zones[zone_name] = zone + if link:=zone_obj.get("external"): + zone.set_link(LinkInfo(offset = self.latency_off, **link)) + zone.add_server(Server(name, LinkInfo(**zone_obj["internal"]))) + + else: + name = name + if name in self.zones: + raise Exception("Duplicate zone: " + name) + self.zones[name] = Server(name, LinkInfo(offset = self.latency_off, **server)) + + def assign_ips(self): + for zone in self.zones.values(): + if zone.is_zone(): + for server in zone.servers.values(): + server.ip = self.subnet_manager.next_local() + else: + zone.ip = self.subnet_manager.next_local() + self.subnet_manager.next_zone() + + def __str__(self): + return f'Network{{subnet_manager: {self.subnet_manager}, zones: {list(self.zones.values())}, latency_offset: {self.latency_off}}}' + +class NamespaceManager: + def __init__(self): + self.ndb = NDB(log='debug') + self.namespaces = set() + + def make_namespace(self, name): + if not name in self.namespaces: + self.namespaces.add(name) + self.ndb.sources.add(netns=name) + + def make_bridge(self, name, namespace): + self.make_namespace(namespace) + return (self.ndb.interfaces.create(ifname=name, kind='bridge', target=namespace) + .set('state', 'up')) + + def make_veth(self, name1, name2, space1, space2, ip, link): + self.make_namespace(space1) + self.make_namespace(space2) + veth = (self.ndb.interfaces.create(ifname=name1, target=space1, kind='veth', + peer={'ifname': name2, 'net_ns_fd': space2}) + .set('state', 'up')) + if ip: + veth.add_ip(address=str(ip), prefixlen=self.prefixlen) + veth.commit() + #NSPopen(space1, ['tc', 'qdisc', 'add', 'dev', nam1, 'root', 'netem', 'delay', '1000ms', '0ms']) + #add the other way too + + def build_network(self, network): + self.prefixlen = network.subnet_manager.prefix + netns = "testnet-core" + bridge = self.make_bridge("br0", netns) + for zone in network.zones.values(): + if zone.is_zone(): + self.build_zone(zone) + else: + self.build_server(zone) + bridge.add_port({'ifname': 'veth-' + zone.name, 'target': netns}) + bridge.commit() + + def build_zone(self, zone): + netns = "testnet-" + zone.name + bridge = self.make_bridge("br-" + zone.name, netns) + self.make_veth("veth-" + zone.name, "veth-" + zone.name, netns, "testnet-core", None, zone.link) + bridge.add_port({'ifname': 'veth-' + zone.name, 'target': netns}) + for server in zone.servers.values(): + self.build_server(server, zone) + bridge.add_port({'ifname': 'veth-' + server.name, 'target': netns}) + bridge.commit() + + def build_server(self, server, zone = None): + if zone: + zone_name = "testnet-" + zone.name + else: + zone_name = "testnet-core" + namespace = zone_name + "-" + server.name + self.make_veth("veth", "veth-" + server.name, namespace, zone_name, server.ip, server.link) + + def close(self): + self.ndb.close() + +def parse(yaml): + server_list = yaml["servers"] + global_conf = yaml.get("global", {}) + subnet = global_conf.get("subnet", {'base': 'fc00:9a7a:9e::', 'local': 64, 'zone': 16}) + latency_offset = global_conf.get("latency-offset", 0) + + network = Network() + network.set_subnet_manager(subnet) + network.set_latency_offset(latency_offset) + for server in server_list: + network.add_server(server) + network.assign_ips() + return network + +def create(config_path): + with open(config_path, "r") as file: + config = yaml.safe_load(file) + shutil.copy(config_path, ".current_state.yml") + network = parse(config) + nsm = NamespaceManager() + nsm.build_network(network) + nsm.close() + +def run(netns, cmd): + if ":" in netns: + zone_name,host = netns.split(":", 1) + else: + zone_name,host = None, netns + with open(".current_state.yml", "r") as file: + config = yaml.safe_load(file) + zones = parse(config).zones + server = None + zone = None + if zone_name: + if (zone := zones.get(zone_name)) and zone.is_zone(): + server = zone.servers.get(host) + elif (zone := zones.get(host)) and not zone.is_zone(): + server = zone + else: + for z in zones.values(): + if not z.is_zone(): continue + if (s := z.servers.get(host)): + if server: + raise Exception("Multiple matching host found.") + server = s + zone = z + + if not server: + raise Exception("No matching host was found") + + env = os.environ.copy() + env["HOST"] = server.name + if zone.is_zone(): + env["ZONE"] = zone.name + env["IP"] = str(server.ip) + name = f'testnet-{zone.name}-{server.name}' + + if len(cmd) == 0: + cmd = [os.getenv("SHELL") or "/bin/sh"] + os.execve("/bin/env", ["/bin/env", "ip", "netns" , "exec", name ] + cmd, env) + + +def destroy(): + for ns in netns.listnetns(): + if ns.startswith("testnet-"): + subprocess.run(f"ip netns pids {ns} | xargs -r kill", check=True, shell=True) + netns.remove(ns) + os.remove(".current_state.yml") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("""Usage: + mk-testnet create [config_path] # create a new network. config_path defailt to config.yml + mk-testnet run-all <cmd> [args...] # run a command as each host. set the IP, NAME and ZONE environment variables + mk-testnet run <name> [cmd [args...]] # run command in host named <name>. Use zonename:name if multiple zones hosts server with same name. If cmd is empty, run a shell + mk-testnet destroy # destroy the current environment""") + exit() + cmd = sys.argv[1] + if cmd == "create": + create(sys.argv[2] if len(sys.argv) > 2 else "config.yml") + elif cmd == "run-all": + runall(sys.argv[2:]) + elif cmd == "run": + run(sys.argv[2], sys.argv[3:]) + elif cmd == "destroy": + destroy() + else: + raise Exception(f"Unknown command: {cmd}") |