aboutsummaryrefslogblamecommitdiff
path: root/main.py
blob: 456da52a9f347d3a2233df3426790ca68ee1450a (plain) (tree)
1
2
3
4
5
6
7
8
                     





                 
          

























                                                                                                      
                                     
             
                                                                                                
                                            
                                                                                 

                         
                                                     
                  


                                                         

                        
                                                                         

                      
                                     


































































































                                                                                                                           

                             


































                                                                                          

                                                           





                                                                                                                                          
                                             
                          
                                


                                       
                               
                                     
 
                                                  
                                      
                                                 
        
                                                                     

                                   
                                                                               



                                                      

                                                                                                             




                                           

                                             


                                      
                                                                                                        
                                     

                                            

                                                         



                                                
                                                     

                                      
                                                                    

                                                                                                   






                                                                                           



                                                                                    













                                                      




































                                                                    
                                                                                         
 




                                                 
              






                                                
                                       
                                                         

                                                           
                          

                                   
                            

                                    
                                   
                                                     

                                                     

                         

              

                            
                               












                                                                                                                                                                           

                                      

                            



                                                  
#!/usr/bin/env python
import ipaddress
import os
import shutil
import subprocess
import sys
import yaml
import net

class SubnetManager:
    """Hand out IP addresses by carving a base network into zone subnets.

    The base network has ``zone + local`` host bits: the upper ``zone`` bits
    number the zone subnets, the lower ``local`` bits number the addresses
    inside one zone subnet.

    Args:
        config: mapping with ``'base'`` (network address string), ``'zone'``
            (bits used to number zones) and ``'local'`` (bits used for
            addresses within a zone). ``ip_network`` is strict, so ``base``
            must have its host bits cleared (ValueError otherwise).
    """

    def __init__(self, config):
        self.base = config['base']
        self.zone_size = config['zone']
        self.local_size = config['local']
        max_prefixlen = 128 if ipaddress.ip_address(self.base).version == 6 else 32
        self.prefix = max_prefixlen - self.zone_size - self.local_size
        # subnets(prefixlen_diff) yields subnets of prefix ``prefix + zone_size``.
        self.networks = ipaddress.ip_network((self.base, self.prefix)).subnets(self.zone_size)

        self.current_net = self._hosts(next(self.networks))

    @staticmethod
    def _hosts(network):
        """Return an iterator over the usable host addresses of *network*."""
        # hosts() is a generator for most networks, but for a single-address
        # network (/32 or /128) it returns a list; iter() lets next() work on both.
        return iter(network.hosts())

    def next_local(self):
        """Return the next unused address of the current zone subnet.

        Raises:
            StopIteration: when the current zone subnet is exhausted.
        """
        return next(self.current_net)

    def next_zone(self):
        """Move on to the next zone subnet.

        Raises:
            StopIteration: when no zone subnets remain.
        """
        self.current_net = self._hosts(next(self.networks))

    def __str__(self):
        return f'SubnetManager{{base: {self.base}, zone: {self.zone_size}, local: {self.local_size}}}'

class Latency:
    """A non-negative delay, stored as an integer number of microseconds.

    Args:
        latency: an ``int`` (taken as microseconds), another ``Latency``, or
            a string with an optional unit suffix ``us``, ``ms``, ``s`` or
            ``m`` (minutes). A string without a suffix is milliseconds and
            may be fractional (``"1.5"`` -> 1500 us).
        offset: optional value in any of the same forms, subtracted from
            *latency*. The result is clamped at 0.
    """

    # Checked in order: "us" and "ms" must be tried before "s".
    _UNITS = (("us", 1), ("ms", 1000), ("s", 1000000), ("m", 60 * 1000000))

    def __init__(self, latency, offset = None):
        if isinstance(latency, Latency):
            # Network's default offset is a Latency instance; accept it.
            self.latency_us = latency.latency_us
        elif type(latency) is int:
            self.latency_us = latency
        else:
            for suffix, factor in self._UNITS:
                if latency.endswith(suffix):
                    self.latency_us = int(float(latency[:-len(suffix)]) * factor)
                    break
            else:
                self.latency_us = int(float(latency) * 1000)
        if offset:
            self.latency_us -= Latency(offset).latency_us
        if self.latency_us < 0:
            self.latency_us = 0

    def __eq__(self, o):
        return isinstance(o, Latency) and o.latency_us == self.latency_us

    def __str__(self):
        # The stored value is in microseconds, so label it as such.
        return f'{self.latency_us}us'

class Bandwidth:
    """Link rate in bits per second, optionally asymmetric (down/up).

    Args:
        bw: a rate, or a mapping with ``"down"`` and ``"up"`` rates. A rate is
            either a plain number (bits/s) or a string made of a number, an
            optional multiplier (``k``/``m``/``g`` decimal, ``ki``/``mi``/``gi``
            binary, case-insensitive) and an optional unit (``bit``/``bps``/``b``
            for bits, ``byte``/``Byte``/``B`` for bytes), e.g. ``"100mbit"``.
    """

    # Unit suffixes are case-sensitive: lower-case "b" means bits, "B" bytes.
    _UNITS = [("bit", 1), ("bps", 1), ("b", 1), ("byte", 8), ("Byte", 8), ("B", 8)]
    # Multiplier suffixes, matched case-insensitively after the unit is removed.
    _MULTIPLIERS = [("k", 1000), ("ki", 1024), ("m", 1000**2), ("mi", 1024**2),
                    ("g", 1000**3), ("gi", 1024**3)]

    def __init__(self, bw):
        if isinstance(bw, dict):
            self.down = self._convert(bw["down"])
            self.up = self._convert(bw["up"])
        else:
            self.down = self._convert(bw)
            self.up = self._convert(bw)

    @classmethod
    def _convert(cls, bw):
        """Return the rate *bw* as an integer number of bits per second."""
        # Unquoted YAML numbers arrive as int/float rather than str.
        if isinstance(bw, (int, float)):
            return int(bw)
        factor = 1
        for suffix, f in cls._UNITS:
            if bw.endswith(suffix):
                bw = bw[:-len(suffix)]
                factor = f
                break
        for suffix, f in cls._MULTIPLIERS:
            if bw.lower().endswith(suffix):
                return int(float(bw[:-len(suffix)]) * factor * f)
        return int(float(bw) * factor)

    def __str__(self):
        def fmt(rate):
            for suffix, factor in [("g", 1000**3), ("m", 1000**2), ("k", 1000)]:
                if rate > 10 * factor:
                    return f'{rate/factor:.1f}{suffix}bps'
            # Small rates used to fall off the loop and render as "None".
            return f'{rate}bps'
        return f'{fmt(self.down)}/{fmt(self.up)}'

    def __eq__(self, o):
        return (isinstance(o, Bandwidth) and
                o.down == self.down and
                o.up == self.up)
        

class LinkInfo:
    """Properties of one emulated link: bandwidth, latency and jitter.

    ``offset`` is subtracted from ``latency`` (see ``Latency``). Any other
    keyword arguments (remaining keys of a YAML link mapping) are accepted
    and ignored.
    """

    def __init__(self, bandwidth, latency, jitter = None, offset = None, **kwargs):
        self.bandwidth = Bandwidth(bandwidth)
        self.latency = Latency(latency, offset)
        # A missing or falsy jitter means no jitter at all.
        self.jitter = Latency(jitter if jitter else 0)

    def __eq__(self, o):
        if not isinstance(o, LinkInfo):
            return False
        return (o.bandwidth == self.bandwidth
                and o.latency == self.latency
                and o.jitter == self.jitter)

    def __str__(self):
        return f'LinkInfo{{bw: {self.bandwidth}, latency: {self.latency}, jitter: {self.jitter}}}'

class Server:
    """A single host with a name, a link description and an IP address.

    ``ip`` starts out as None; ``Network.assign_ips`` fills it in later.
    """

    def __init__(self, name, link):
        self.name = name
        self.link = link
        self.ip = None

    def is_zone(self):
        """Servers are leaves of the topology, never zones."""
        return False

    def __str__(self):
        fields = f'name: {self.name}, ip: {self.ip}, link: {self.link}'
        return 'Server{' + fields + '}'

    def __repr__(self):
        return str(self)

class Zone:
    """A named group of servers sharing one external link.

    ``servers`` maps server name to server object. ``link`` is None until
    ``set_link`` is first called.
    """

    def __init__(self, name):
        self.name = name
        self.link = None
        self.servers = {}

    def is_zone(self):
        """Zones are containers of servers."""
        return True

    def add_server(self, server):
        """Register *server* under ``server.name``.

        Raises:
            Exception: if a server with that name is already registered.
        """
        if self.servers.get(server.name) is not None:
            raise Exception(f"Duplicate server '{server.name}' in zone '{self.name}'")
        self.servers[server.name] = server

    def set_link(self, link):
        """Set the zone's external link; later calls must pass an equal link.

        Raises:
            Exception: if *link* differs from the one already set.
        """
        if self.link is None:
            self.link = link
            return
        if self.link != link:
            raise Exception(f"Uncoherent link configuration for zone '{self.name}'")

    def __str__(self):
        members = list(self.servers.values())
        return f'Zone{{name: {self.name}, link: {self.link}, servers: {members}}}'

    def __repr__(self):
        return str(self)


class Network:
    """The parsed topology: zones, standalone servers and host uplink settings."""

    def __init__(self):
        # name -> Zone, or name -> Server for a server that is not in a zone.
        self.zones = {}
        self.subnet_manager = None
        self.latency_off = Latency(0)
        # Address/link of the host side of the uplink; set by the caller or
        # (for host_ip) by assign_ips.
        self.host_ip = None
        self.host_link = None

    def set_subnet_manager(self, subnet):
        """Create the SubnetManager from the *subnet* config mapping."""
        self.subnet_manager = SubnetManager(subnet)

    def set_latency_offset(self, latency_off):
        """Set the offset passed to LinkInfo for zone-external and standalone links."""
        self.latency_off = latency_off

    def add_server(self, server):
        """Add one entry of the config's ``servers`` list.

        An entry with a ``zone`` mapping is added to that zone, which is created on
        first use. The zone's ``internal`` link becomes the server's link, and an
        optional ``external`` link becomes the zone's link. An entry without
        ``zone`` becomes a standalone server whose own keys describe its link.

        Raises:
            Exception: when a zone name clashes with a standalone server, a
                standalone server clashes with an existing entry, a server is
                duplicated inside a zone, or external links disagree.
        """
        name = server["name"]
        zone_obj = server.get("zone")
        if not zone_obj:
            if name in self.zones:
                raise Exception("Duplicate zone: " + name)
            self.zones[name] = Server(name, LinkInfo(offset = self.latency_off, **server))
            return

        zone_name = zone_obj["name"]
        zone = self.zones.get(zone_name)
        if zone is None:
            zone = Zone(zone_name)
            self.zones[zone_name] = zone
        elif not zone.is_zone():
            # The clash is on the zone name, so report that (not the server name).
            raise Exception("Duplicate zone: " + zone_name)
        if link := zone_obj.get("external"):
            zone.set_link(LinkInfo(offset = self.latency_off, **link))
        zone.add_server(Server(name, LinkInfo(**zone_obj["internal"])))

    def assign_ips(self):
        """Assign addresses: each zone and standalone server gets its own subnet.

        Unless ``host_ip`` was already set, the host gets an address from the
        subnet that follows the last one used by a zone.
        """
        for zone in self.zones.values():
            if zone.is_zone():
                for server in zone.servers.values():
                    server.ip = self.subnet_manager.next_local()
            else:
                zone.ip = self.subnet_manager.next_local()
            self.subnet_manager.next_zone()
        if not self.host_ip:
            self.host_ip = self.subnet_manager.next_local()

    def __str__(self):
        return f'Network{{subnet_manager: {self.subnet_manager}, zones: {list(self.zones.values())}, latency_offset: {self.latency_off}}}'

class NamespaceManager:
    """Turn a parsed Network into namespaces, veth pairs and bridges via ``net``."""

    def __init__(self):
        # Namespaces already known; "unconfined" is registered up front.
        self.namespaces = {"unconfined"}
        # Prefix length given to every veth address; set in build_network.
        self.prefixlen = 0
        net.ns.name_unconfined()

    def make_namespace(self, name):
        """Create namespace *name* unless it was already created."""
        if name in self.namespaces:
            return
        net.ns.create(name)
        self.namespaces.add(name)

    def make_bridge(self, name, namespace, ports):
        """Create bridge *name* in *namespace* enslaving *ports*."""
        self.make_namespace(namespace)
        net.create_bridge(name, namespace, ports)

    def make_veth(self, name1, name2, space1, space2, ip, link=None):
        """Create a veth pair: *name1* in *space1*, *name2* in *space2*."""
        for space in (space1, space2):
            self.make_namespace(space)
        net.create_veth(name1, space1, name2, space2, ip, self.prefixlen, link)

    def build_network(self, network):
        """Build the core namespace, every zone/server, and the core bridge br0."""
        self.prefixlen = network.subnet_manager.prefix
        core = "testnet-core"
        self.make_veth("veth-testnet", "unconfined", "unconfined", core, network.host_ip, network.host_link)
        bridge_ports = ["unconfined"]
        for entry in network.zones.values():
            builder = self.build_zone if entry.is_zone() else self.build_server
            builder(entry)
            bridge_ports.append('veth-' + entry.name)
        self.make_bridge("br0", core, bridge_ports)

    def build_zone(self, zone):
        """Build a zone namespace, its servers and its bridge, linked to the core."""
        zone_ns = "testnet-" + zone.name
        self.make_veth("veth-" + zone.name, "veth-" + zone.name, zone_ns, "testnet-core", None, zone.link)
        bridge_ports = ['veth-' + zone.name]
        for member in zone.servers.values():
            self.build_server(member, zone)
            bridge_ports.append('veth-' + member.name)
        self.make_bridge("br-" + zone.name, zone_ns, bridge_ports)

    def build_server(self, server, zone = None):
        """Create the server's namespace and veth into its zone (or the core)."""
        if not zone:
            # Standalone servers hang directly off the core namespace.
            parent_ns = "testnet-core"
            own_ns = "testnet-" + server.name + "-" + server.name
        else:
            parent_ns = "testnet-" + zone.name
            own_ns = parent_ns + "-" + server.name
        self.make_veth("veth", "veth-" + server.name, own_ns, parent_ns, server.ip, server.link)

def parse(yaml):
    """Build a Network from an already-loaded config mapping.

    Args:
        yaml: the loaded config mapping (the name shadows the ``yaml`` module
            inside this function). It must contain ``servers``. An optional
            ``global`` mapping may hold ``subnet``, ``latency-offset`` and
            ``upstream`` (with ``ip`` and ``conn``).

    Returns:
        A Network with every server added and IP addresses assigned.
    """
    server_list = yaml["servers"]
    # ``global:`` with no value loads as None; treat it like an absent section.
    global_conf = yaml.get("global") or {}
    subnet = global_conf.get("subnet", {'base': 'fc00:9a7a:9e::', 'local': 64, 'zone': 16})
    latency_offset = global_conf.get("latency-offset", 0)

    network = Network()
    if upstream := global_conf.get("upstream"):
        network.host_ip = upstream.get("ip")
        if host_link := upstream.get("conn"):
            # LinkInfo's keyword is ``offset``; ``latency_offset=`` used to be
            # swallowed by its **kwargs, so the offset was silently dropped.
            network.host_link = LinkInfo(offset=latency_offset, **host_link)
    network.set_subnet_manager(subnet)
    network.set_latency_offset(latency_offset)
    for server in server_list:
        network.add_server(server)
    network.assign_ips()
    return network

def create(config_path):
    """Build the test network described by the YAML file at *config_path*.

    The config is parsed before it is copied to ``.current_state.yml``, so an
    invalid file cannot overwrite the recorded state of an existing network.
    """
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    network = parse(config)
    # Recorded before building so run/run-all can find the network.
    shutil.copy(config_path, ".current_state.yml")
    nsm = NamespaceManager()
    nsm.build_network(network)

def run(netns, cmd):
    """Replace this process with *cmd* running in one host's namespace.

    Args:
        netns: ``host`` or ``zone:host``. A bare name first matches a
            standalone server; otherwise it must match exactly one server
            across all zones.
        cmd: argv list. When empty, ``$SHELL`` (or ``/bin/sh``) is started.

    The command's environment gets HOST, IP and ZONE (empty for a
    standalone server, as with run-all).

    Raises:
        Exception: if no host, or more than one host, matches.
    """
    if ":" in netns:
        zone_name, host = netns.split(":", 1)
    else:
        zone_name, host = None, netns
    with open(".current_state.yml", "r") as file:
        config = yaml.safe_load(file)
    zones = parse(config).zones
    server = None
    zone = None
    if zone_name:
        # Explicit zone:host, so only look inside that zone.
        if (zone := zones.get(zone_name)) and zone.is_zone():
            server = zone.servers.get(host)
    elif (zone := zones.get(host)) and not zone.is_zone():
        # A standalone server is stored as its own "zone".
        server = zone
    else:
        # Search every zone; the bare name must be unambiguous.
        for z in zones.values():
            if not z.is_zone():
                continue
            if (s := z.servers.get(host)):
                if server:
                    raise Exception("Multiple matching host found.")
                server = s
                zone = z

    if not server:
        raise Exception("No matching host was found")

    env = os.environ.copy()
    env["HOST"] = server.name
    # Always set ZONE so a value inherited from the calling shell cannot leak in.
    env["ZONE"] = zone.name if zone.is_zone() else ""
    env["IP"] = str(server.ip)
    name = f'testnet-{zone.name}-{server.name}'

    if len(cmd) == 0:
        cmd = [os.getenv("SHELL") or "/bin/sh"]
    os.execve("/usr/bin/env", ["/usr/bin/env", "ip", "netns", "exec", name] + cmd, env)

def runall(cmd):
    """Run *cmd* in every host's namespace, one host after another.

    Each run's environment gets HOST, IP, and ZONE (empty for standalone
    servers). It also gets ID, numbering hosts from 1 in zone iteration order,
    and SIZE, the number of entries in the config's ``servers`` list.
    """
    with open(".current_state.yml", "r") as file:
        config = yaml.safe_load(file)
    zones = parse(config).zones
    size = str(len(config['servers']))

    # (ZONE value, namespace "zone" part, server) for every host, in order.
    hosts = []
    for zone in zones.values():
        if zone.is_zone():
            hosts.extend((zone.name, zone, server) for server in zone.servers.values())
        else:
            # A standalone server is stored as its own "zone".
            hosts.append(("", zone, zone))

    for number, (zone_env, zone, server) in enumerate(hosts, start=1):
        env = os.environ.copy()
        env["ZONE"] = zone_env
        env["HOST"] = server.name
        env["IP"] = str(server.ip)
        env["ID"] = str(number)
        env["SIZE"] = size
        name = f'testnet-{zone.name}-{server.name}'
        net.ns.run(name, cmd, env)

def destroy():
    """Kill the test network's namespaces and delete the recorded state.

    Kills every namespace returned by ``net.ns.list()``, then calls
    ``net.ns.forget("unconfined")``. A missing ``.current_state.yml`` is
    tolerated instead of crashing after the teardown has already happened.
    """
    # NOTE(review): confirm net.ns.list() returns only testnet-managed namespaces.
    for ns in net.ns.list():
        net.ns.kill(ns)
    net.ns.forget("unconfined")
    try:
        os.remove(".current_state.yml")
    except FileNotFoundError:
        pass

if __name__ == "__main__":
    # Command-line entry point: dispatch on the first argument.
    USAGE = """Usage:
    mk-testnet create [config_path] # create a new network. config_path defaults to config.yml
    mk-testnet run-all <cmd> [args...] # run a command as each host. sets the HOST, IP, ZONE, ID and SIZE environment variables
    mk-testnet run <name> [cmd [args...]] # run command in host named <name>. Use zonename:name if multiple zones host a server with the same name. If cmd is empty, run a shell
    mk-testnet destroy # destroy the current environment"""

    cmd = sys.argv[1] if len(sys.argv) > 1 else None
    # "run" needs a host name and "run-all" a command; report a usage error
    # (non-zero exit) instead of an IndexError or an empty command.
    if cmd is None or (cmd in ("run", "run-all") and len(sys.argv) < 3):
        print(USAGE, file=sys.stderr)
        sys.exit(1)

    if cmd == "create":
        create(sys.argv[2] if len(sys.argv) > 2 else "config.yml")
    elif cmd == "run":
        run(sys.argv[2], sys.argv[3:])
    elif cmd == "run-all":
        runall(sys.argv[2:])
    elif cmd == "destroy":
        destroy()
    else:
        raise Exception(f"Unknown command: {cmd}")