author     Alex Auvolat <alex.auvolat@ens.fr>  2015-04-29 19:41:36 -0400
committer  Alex Auvolat <alex.auvolat@ens.fr>  2015-04-29 19:41:36 -0400
commit     43e106e6630030dd34813295fe1d07bb86025402
tree       c3a604d8d023e35532522a18da06e8c25dc251c6
parent     8b27690c8d77585f173412e5719787c48272674e
Fix
-rw-r--r--  config/model_0.py   2
-rw-r--r--  data.py             5
-rw-r--r--  model.py            6
3 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/config/model_0.py b/config/model_0.py
index 736ef30..f26a99c 100644
--- a/config/model_0.py
+++ b/config/model_0.py
@@ -9,7 +9,7 @@ n_valid = 1000
 
 dim_embed = 10
 dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed
-dim_hidden = [200]
+dim_hidden = [100]
 dim_output = 2
 
 learning_rate = 0.002
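
dim_hidden is a list of hidden-layer widths, so this hunk narrows the single hidden layer from 200 to 100 units rather than removing a layer. A minimal sketch of how a config module like this one typically feeds a Blocks MLP (the wiring below is assumed for illustration; it is not the project's actual model.py):

    import config  # the config/model_0.py module patched above

    from blocks.bricks import MLP, Rectifier, Identity
    from blocks.initialization import IsotropicGaussian, Constant

    # One activation per layer; dims runs input -> hidden layers -> output.
    mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
              dims=[config.dim_input] + config.dim_hidden + [config.dim_output],
              weights_init=IsotropicGaussian(0.01),
              biases_init=Constant(0))
    mlp.initialize()
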
diff --git a/data.py b/data.py
index 92efefc..92aa062 100644
--- a/data.py
+++ b/data.py
@@ -2,6 +2,7 @@ import ast, csv
 import socket
 import fuel
 import numpy
+import h5py
 from enum import Enum
 from fuel.datasets import Dataset
 from fuel.streams import DataStream
@@ -28,7 +29,7 @@ dataset_size = 1710670
 
 def make_client_ids():
     f = h5py.File(H5DATA_PATH, "r")
-    l = f['uniq_origin_call']
+    l = f['unique_origin_call']
     r = {}
     for i in range(l.shape[0]):
         r[l[i]] = i
@@ -37,7 +38,7 @@ def make_client_ids():
 client_ids = make_client_ids()
 
 def get_client_id(n):
-    if n in client_ids:
+    if n in client_ids and client_ids[n] <= n_train_clients:
         return client_ids[n]
     else:
         return 0
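
The two data.py hunks above fix the HDF5 dataset key ('uniq_origin_call' -> 'unique_origin_call') and make get_client_id fall back to the catch-all id 0 for clients whose index exceeds n_train_clients. A standalone sketch of the patched lookup (the path and the n_train_clients value are placeholders, not the project's real ones):

    import h5py

    H5DATA_PATH = 'data.hdf5'  # placeholder path
    n_train_clients = 50000    # placeholder cap on embedded client ids

    def make_client_ids():
        # Map each distinct origin_call value to a small integer index.
        with h5py.File(H5DATA_PATH, 'r') as f:
            uniq = f['unique_origin_call'][:]  # the renamed dataset key
        return {client: i for i, client in enumerate(uniq)}

    client_ids = make_client_ids()

    def get_client_id(n):
        # Indices past the cap map to 0, the 'unknown client' bucket,
        # presumably so downstream lookups stay within a fixed table size.
        if n in client_ids and client_ids[n] <= n_train_clients:
            return client_ids[n]
        return 0
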
diff --git a/model.py b/model.py
index a44a6cf..aff9fd7 100644
--- a/model.py
+++ b/model.py
@@ -46,7 +46,7 @@ if __name__ == "__main__":
 
 def setup_stream():
     # Load the training and test data
-    train = H5PYDataset(H5DATA_PATH,
+    train = H5PYDataset(data.H5DATA_PATH,
                         which_set='train',
                         subset=slice(0, data.dataset_size - config.n_valid),
                         load_in_memory=True)
@@ -59,7 +59,7 @@ def setup_stream():
                                 'destination_latitude', 'destination_longitude'))
     train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
 
-    valid = H5PYDataset(H5DATA_PATH,
+    valid = H5PYDataset(data.H5DATA_PATH,
                         which_set='train',
                         subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
                         load_in_memory=True)
@@ -164,7 +164,7 @@ def main():
         # Checkpoint('model.pkl', every_n_batches=100),
         Dump('taxi_model', every_n_batches=1000),
         LoadFromDump('taxi_model'),
-        FinishAfter(after_epoch=1)
+        FinishAfter(after_epoch=5)
     ]
 
     main_loop = MainLoop(
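
Both H5PYDataset hunks repair the same bug: model.py defines no bare H5DATA_PATH, so the path must be taken from the data module. The FinishAfter hunk is independent and just extends training from one epoch to five. A sketch of the resulting split, which holds out the last config.n_valid examples of the single 'train' set for validation (imports and surrounding layout are assumed; the constructor calls mirror the hunks):

    import data, config
    from fuel.datasets.hdf5 import H5PYDataset

    # Hold out the last n_valid examples of the 'train' set for validation.
    train = H5PYDataset(data.H5DATA_PATH, which_set='train',
                        subset=slice(0, data.dataset_size - config.n_valid),
                        load_in_memory=True)
    valid = H5PYDataset(data.H5DATA_PATH, which_set='train',
                        subset=slice(data.dataset_size - config.n_valid,
                                     data.dataset_size),
                        load_in_memory=True)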