aboutsummaryrefslogtreecommitdiff
path: root/lib/net/tcpconn.ex
blob: 64c85e963a51704c0568672b2e3fc6ee13ddaa1f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
defmodule SNet.TCPConn do
  @moduledoc """
  GenServer that owns a single TCP connection to a peer.

  After start-up the process performs a handshake over the socket it was
  given (exchange of identity and session public keys, followed by a mutual
  signed challenge), switches the socket to active mode and registers the
  peer with `SNet.Manager`. From then on every packet is a term serialized
  with `:erlang.term_to_binary/1` and sealed with the session key pair.

  The child spec uses `restart: :temporary`, so a connection that stops or
  crashes is not restarted by its supervisor.
  """

  use GenServer, restart: :temporary
  require Salty.Box.Curve25519xchacha20poly1305, as: Box
  require Salty.Sign.Ed25519, as: Sign
  require Logger

  @doc """
  Starts a connection process.

  `state` is a map that must contain at least `:socket` (the TCP socket to
  talk on) and `:my_port` (the port number advertised to the peer during the
  handshake), as these are the only keys read before the state is rebuilt.
  """
  def start_link(state) do
    GenServer.start_link(__MODULE__, state)
  end

  @impl true
  def init(state) do
    # The handshake blocks on socket reads, so it is deferred to a cast to
    # keep `start_link/1` from blocking the caller.
    GenServer.cast(self(), :handshake)
    {:ok, state}
  end

  # Replies with "<hex peer public key>@<ip>:<port>" for the connected peer.
  # Reads keys that only exist once the handshake has completed.
  @impl true
  def handle_call(:get_host_str, _from, state) do
    pkey_hex = Base.encode16(state.his_pkey, case: :lower)
    ip = to_string(:inet_parse.ntoa(state.addr))
    {:reply, "#{pkey_hex}@#{ip}:#{state.port}", state}
  end

  # Runs the handshake. Any failure (socket error, malformed packet, bad
  # signature) fails an assertive match and crashes the process on purpose.
  # Uses blocking `:gen_tcp.recv/2`, so the socket is presumably handed over
  # in passive mode -- TODO confirm against the code that opens the socket.
  @impl true
  def handle_cast(:handshake, state) do
    socket = state.socket

    {srv_pkey, srv_skey} = Shard.Identity.get_keypair()
    {:ok, sess_pkey, sess_skey} = Box.keypair()
    {:ok, challenge} = Salty.Random.buf(32)

    # Exchange public keys and challenge
    hello = {srv_pkey, sess_pkey, challenge, state.my_port}
    :ok = :gen_tcp.send(socket, :erlang.term_to_binary(hello))
    {:ok, pkt} = :gen_tcp.recv(socket, 0)

    # `:safe` refuses to create new atoms from data sent by the remote side.
    {cli_pkey, cli_sess_pkey, cli_challenge, his_port} =
      :erlang.binary_to_term(pkt, [:safe])

    # Do challenge and check their challenge
    {:ok, cli_challenge_sign} = Sign.sign_detached(cli_challenge, srv_skey)
    :ok = :gen_tcp.send(socket, encode_pkt(cli_challenge_sign, cli_sess_pkey, sess_skey))

    {:ok, pkt} = :gen_tcp.recv(socket, 0)
    challenge_sign = decode_pkt(pkt, cli_sess_pkey, sess_skey)
    :ok = Sign.verify_detached(challenge_sign, challenge, cli_pkey)

    # Connected: from here on packets arrive as `{:tcp, socket, data}` messages.
    :ok = :inet.setopts(socket, active: true)

    {:ok, {addr, port}} = :inet.peername(socket)

    # The state is rebuilt from scratch; keys of the initial map other than
    # the socket (e.g. `:my_port`) are intentionally not carried over.
    state = %{
      socket: socket,
      my_pkey: srv_pkey,
      my_skey: srv_skey,
      his_pkey: cli_pkey,
      conn_my_pkey: sess_pkey,
      conn_my_skey: sess_skey,
      conn_his_pkey: cli_sess_pkey,
      addr: addr,
      port: port
    }

    # NOTE(review): `:peer_up` reports the port advertised by the peer
    # (`his_port`), while `:peer_down` reports the TCP source port kept in
    # `state.port`. Confirm that SNet.Manager does not need them to agree.
    GenServer.cast(SNet.Manager, {:peer_up, cli_pkey, self(), addr, his_port})
    Logger.info "New peer: #{print_id state} at #{inspect addr}:#{port}"
    GenServer.cast(self(), :init_pull)

    {:noreply, state}
  end

  # Encrypts `msg` and sends it to the peer.
  def handle_cast({:send_msg, msg}, state) do
    send_msg(state, msg)
    {:noreply, state}
  end

  # Tells the peer which shards we are interested in: the ids of all entries
  # returned by `Shard.Manager`'s `:list` call.
  def handle_cast(:init_pull, state) do
    id_list = for {id, _} <- GenServer.call(Shard.Manager, :list), do: id
    send_msg(state, {:interested, id_list})
    {:noreply, state}
  end

  # Incoming packet: decrypt, deserialize (atom-safe) and dispatch.
  @impl true
  def handle_info({:tcp, _socket, raw_data}, state) do
    msg = decode_pkt(raw_data, state.conn_his_pkey, state.conn_my_skey)
    handle_packet(:erlang.binary_to_term(msg, [:safe]), state)
    {:noreply, state}
  end

  # The peer closed the connection: unregister it and stop normally.
  def handle_info({:tcp_closed, _socket}, state) do
    Logger.info "Disconnected: #{print_id state} at #{inspect state.addr}:#{state.port}"
    notify_peer_down(state)
    {:stop, :normal, state}
  end

  # Socket error in active mode. Without this clause the message would match
  # no `handle_info/2` clause and crash the process before SNet.Manager is
  # told that the peer is gone.
  def handle_info({:tcp_error, _socket, reason}, state) do
    Logger.info "Connection error (#{inspect reason}): #{print_id state} at #{inspect state.addr}:#{state.port}"
    notify_peer_down(state)
    {:stop, :normal, state}
  end

  # Informs SNet.Manager that this peer is no longer reachable.
  defp notify_peer_down(state) do
    GenServer.cast(SNet.Manager, {:peer_down, state.his_pkey, state.addr, state.port})
  end

  # Seals `pkt` for public key `pk` using secret key `sk` and returns the
  # freshly generated random nonce followed by the ciphertext.
  defp encode_pkt(pkt, pk, sk) do
    {:ok, n} = Salty.Random.buf(Box.noncebytes())
    {:ok, msg} = Box.easy(pkt, n, pk, sk)
    n <> msg
  end

  # Inverse of `encode_pkt/3`: splits off the leading nonce and opens the box.
  # Raises if the packet is shorter than a nonce or fails to open.
  defp decode_pkt(pkt, pk, sk) do
    nonce_len = Box.noncebytes()
    n = binary_part(pkt, 0, nonce_len)
    enc = binary_part(pkt, nonce_len, byte_size(pkt) - nonce_len)
    {:ok, msg} = Box.open_easy(enc, n, pk, sk)
    msg
  end

  # Serializes `msg`, encrypts it with the session keys and writes it to the
  # socket. Returns the result of `:gen_tcp.send/2`; callers in this module
  # ignore it, so a failed send is not treated as fatal here.
  defp send_msg(state, msg) do
    msgbin = :erlang.term_to_binary(msg)
    enc = encode_pkt(msgbin, state.conn_his_pkey, state.conn_my_skey)
    :gen_tcp.send(state.socket, enc)
  end

  # The peer announces the shards it is interested in.
  defp handle_packet({:interested, shards}, state) do
    GenServer.cast(Shard.Manager, {:interested, state.his_pkey, self(), shards})
  end

  # Any other two-element tuple is a message addressed to a shard.
  defp handle_packet({shard, msg}, state) do
    GenServer.cast(Shard.Manager, {:dispatch, state.his_pkey, self(), shard, msg})
  end

  # Short identifier for log lines: the first 8 bytes of the peer's public
  # key, as lowercase hex.
  defp print_id(state) do
    state.his_pkey
    |> binary_part(0, 8)
    |> Base.encode16(case: :lower)
  end
end