aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/data/merklesearchtree.ex28
-rw-r--r--test/mst_test.exs13
2 files changed, 40 insertions, 1 deletions
diff --git a/lib/data/merklesearchtree.ex b/lib/data/merklesearchtree.ex
index d23e8a7..29c1dab 100644
--- a/lib/data/merklesearchtree.ex
+++ b/lib/data/merklesearchtree.ex
@@ -114,7 +114,7 @@ defmodule SData.MerkleSearchTree do
defp aux_insert_after_first(store, lst, key, value) do
case lst do
- [ {k1, _, r1} | rst ] when k1 == key ->
+ [ {k1, _old_value, r1} | rst ] when k1 == key ->
{ [ {k1, value, r1} | rst ], store }
[ {k1, v1, r1} ] ->
{r1a, r1b, new_store} = split(store, r1, key)
@@ -144,6 +144,32 @@ defmodule SData.MerkleSearchTree do
end
end
+ def get(state, key) do
+ get(state.store, state.root, key)
+ end
+
+ defp get(store, root, key) do
+ case root do
+ nil -> nil
+ _ ->
+ { _, low, lst } = store[root]
+ get_aux(store, low, lst, key)
+ end
+ end
+
+ defp get_aux(store, low, lst, key) do
+ case lst do
+ [] ->
+ get(store, low, key)
+ [ {k, v, _} | _ ] when key == k ->
+ v
+ [ {k, _, _} | _ ] when key < k ->
+ get(store, low, key)
+ [ {k, _, low2} | rst ] when key > k ->
+ get_aux(store, low2, rst, key)
+ end
+ end
+
def dump(state) do
dump(state.store, state.root, "")
end
diff --git a/test/mst_test.exs b/test/mst_test.exs
index 68886c7..6615b6e 100644
--- a/test/mst_test.exs
+++ b/test/mst_test.exs
@@ -11,6 +11,14 @@ defmodule ShardTest.MST do
z = Enum.reduce(Enum.shuffle(0..1000), MST.new(),
fn i, acc -> MST.insert(acc, i, i) end)
+ for i <- 0..1000 do
+ assert MST.get(y, i) == i
+ assert MST.get(z, i) == i
+ end
+ assert MST.get(y, 9999) == nil
+ assert MST.get(z, -1001) == nil
+ assert MST.get(z, 1.01) == nil
+
IO.puts "y.root: #{y.root|>Base.encode16}"
IO.puts "z.root: #{z.root|>Base.encode16}"
assert y.root == z.root
@@ -28,6 +36,11 @@ defmodule ShardTest.MST do
z = Enum.reduce(Enum.shuffle(items), MST.new(),
fn {k, v}, acc -> MST.insert(acc, k, v) end)
+ for {k, v} <- items do
+ assert MST.get(y, k) == v
+ assert MST.get(z, k) == v
+ end
+
IO.puts "y.root: #{y.root|>Base.encode16}"
IO.puts "z.root: #{z.root|>Base.encode16}"
assert y.root == z.root