From 43e106e6630030dd34813295fe1d07bb86025402 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 29 Apr 2015 19:41:36 -0400 Subject: Fix --- data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'data.py') diff --git a/data.py b/data.py index 92efefc..92aa062 100644 --- a/data.py +++ b/data.py @@ -2,6 +2,7 @@ import ast, csv import socket import fuel import numpy +import h5py from enum import Enum from fuel.datasets import Dataset from fuel.streams import DataStream @@ -28,7 +29,7 @@ dataset_size = 1710670 def make_client_ids(): f = h5py.File(H5DATA_PATH, "r") - l = f['uniq_origin_call'] + l = f['unique_origin_call'] r = {} for i in range(l.shape[0]): r[l[i]] = i @@ -37,7 +38,7 @@ def make_client_ids(): client_ids = make_client_ids() def get_client_id(n): - if n in client_ids: + if n in client_ids and client_ids[n] <= n_train_clients: return client_ids[n] else: return 0 -- cgit v1.2.3