aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-27 13:08:56 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-27 13:08:56 -0400
commita25d4fb6e92f203183de2d89e8c467a6b14e1730 (patch)
tree6d448760e647572d52242f5726224cdf20e832ee /model.py
parentccd1245db7f6799ab4e1f45a8cead85ed67f1c72 (diff)
downloadtaxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.tar.gz
taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.zip
Implement HDist, transformer that selects k at a random position.
Diffstat (limited to 'model.py')
-rw-r--r--model.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/model.py b/model.py
index 95b002f..a5c820e 100644
--- a/model.py
+++ b/model.py
@@ -24,6 +24,7 @@ from blocks.extensions.monitoring import DataStreamMonitoring
import data
import transformers
+import hdist
n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
n_dom = 31
@@ -38,8 +39,8 @@ n_end_pts = 5
dim_embed = 50
dim_hidden = 200
-learning_rate = 0.1
-batch_size = 32
+learning_rate = 0.01
+batch_size = 64
def main():
# The input and the targets
@@ -67,7 +68,8 @@ def main():
outputs = output_layer.apply(hidden)
# Calculate the cost
- cost = (outputs - y).norm(2, axis=1).mean()
+ # cost = (outputs - y).norm(2, axis=1).mean()
+ cost = hdist.hdist(outputs, y).mean()
cost.name = 'cost'
# Initialization
@@ -84,14 +86,16 @@ def main():
# Load the training and test data
train = data.train_data
train = DataStream(train)
- train = transformers.add_extremities(train, n_begin_end_pts)
+ train = transformers.add_first_k(n_begin_end_pts, train)
+ train = transformers.add_random_k(n_begin_end_pts, train)
train = transformers.add_destination(train)
train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size))
valid = data.valid_data
valid = DataStream(valid)
- valid = transformers.add_extremities(valid, n_begin_end_pts)
+ valid = transformers.add_first_k(n_begin_end_pts, valid)
+ valid = transformers.add_random_k(n_begin_end_pts, valid)
valid = transformers.add_destination(valid)
valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))
@@ -107,8 +111,8 @@ def main():
extensions=[DataStreamMonitoring([cost], valid_stream,
prefix='valid',
- every_n_batches=100),
- Printing(every_n_batches=100),
+ every_n_batches=1000),
+ Printing(every_n_batches=1000),
# Dump('taxi_model', every_n_batches=100),
# LoadFromDump('taxi_model'),
]