aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r--data.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/data.py b/data.py
index 92efefc..92aa062 100644
--- a/data.py
+++ b/data.py
@@ -2,6 +2,7 @@ import ast, csv
import socket
import fuel
import numpy
+import h5py
from enum import Enum
from fuel.datasets import Dataset
from fuel.streams import DataStream
@@ -28,7 +29,7 @@ dataset_size = 1710670
def make_client_ids():
f = h5py.File(H5DATA_PATH, "r")
- l = f['uniq_origin_call']
+ l = f['unique_origin_call']
r = {}
for i in range(l.shape[0]):
r[l[i]] = i
@@ -37,7 +38,7 @@ def make_client_ids():
client_ids = make_client_ids()
def get_client_id(n):
- if n in client_ids:
+ if n in client_ids and client_ids[n] <= n_train_clients:
return client_ids[n]
else:
return 0