diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 13:08:56 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 13:08:56 -0400 |
commit | a25d4fb6e92f203183de2d89e8c467a6b14e1730 (patch) | |
tree | 6d448760e647572d52242f5726224cdf20e832ee /model.py | |
parent | ccd1245db7f6799ab4e1f45a8cead85ed67f1c72 (diff) | |
download | taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.tar.gz taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.zip |
Implement HDist and a transformer that selects k points at a random position.
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 18 |
1 file changed, 11 insertions, 7 deletions
```diff
@@ -24,6 +24,7 @@ from blocks.extensions.monitoring import DataStreamMonitoring
 
 import data
 import transformers
+import hdist
 
 n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
 n_dom = 31
@@ -38,8 +39,8 @@ n_end_pts = 5
 
 dim_embed = 50
 dim_hidden = 200
-learning_rate = 0.1
-batch_size = 32
+learning_rate = 0.01
+batch_size = 64
 
 def main():
     # The input and the targets
@@ -67,7 +68,8 @@ def main():
     outputs = output_layer.apply(hidden)
 
     # Calculate the cost
-    cost = (outputs - y).norm(2, axis=1).mean()
+    # cost = (outputs - y).norm(2, axis=1).mean()
+    cost = hdist.hdist(outputs, y).mean()
     cost.name = 'cost'
 
     # Initialization
@@ -84,14 +86,16 @@ def main():
     # Load the training and test data
     train = data.train_data
     train = DataStream(train)
-    train = transformers.add_extremities(train, n_begin_end_pts)
+    train = transformers.add_first_k(n_begin_end_pts, train)
+    train = transformers.add_random_k(n_begin_end_pts, train)
     train = transformers.add_destination(train)
     train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
     train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size))
 
     valid = data.valid_data
     valid = DataStream(valid)
-    valid = transformers.add_extremities(valid, n_begin_end_pts)
+    valid = transformers.add_first_k(n_begin_end_pts, valid)
+    valid = transformers.add_random_k(n_begin_end_pts, valid)
     valid = transformers.add_destination(valid)
     valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
     valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))
@@ -107,8 +111,8 @@ def main():
 
         extensions=[DataStreamMonitoring([cost], valid_stream,
                                          prefix='valid',
-                                         every_n_batches=100),
-                    Printing(every_n_batches=100),
+                                         every_n_batches=1000),
+                    Printing(every_n_batches=1000),
                     # Dump('taxi_model', every_n_batches=100),
                     # LoadFromDump('taxi_model'),
                     ]
```