about summary refs log tree commit diff
path: root/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'model.py')
-rw-r--r-- model.py 18
1 files changed, 11 insertions, 7 deletions
diff --git a/model.py b/model.py
index 95b002f..a5c820e 100644
--- a/model.py
+++ b/model.py
@@ -24,6 +24,7 @@ from blocks.extensions.monitoring import DataStreamMonitoring
import data
import transformers
+import hdist
n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
n_dom = 31
@@ -38,8 +39,8 @@ n_end_pts = 5
dim_embed = 50
dim_hidden = 200
-learning_rate = 0.1
-batch_size = 32
+learning_rate = 0.01
+batch_size = 64
def main():
# The input and the targets
@@ -67,7 +68,8 @@ def main():
outputs = output_layer.apply(hidden)
# Calculate the cost
- cost = (outputs - y).norm(2, axis=1).mean()
+ # cost = (outputs - y).norm(2, axis=1).mean()
+ cost = hdist.hdist(outputs, y).mean()
cost.name = 'cost'
# Initialization
@@ -84,14 +86,16 @@ def main():
# Load the training and test data
train = data.train_data
train = DataStream(train)
- train = transformers.add_extremities(train, n_begin_end_pts)
+ train = transformers.add_first_k(n_begin_end_pts, train)
+ train = transformers.add_random_k(n_begin_end_pts, train)
train = transformers.add_destination(train)
train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size))
valid = data.valid_data
valid = DataStream(valid)
- valid = transformers.add_extremities(valid, n_begin_end_pts)
+ valid = transformers.add_first_k(n_begin_end_pts, valid)
+ valid = transformers.add_random_k(n_begin_end_pts, valid)
valid = transformers.add_destination(valid)
valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))
@@ -107,8 +111,8 @@ def main():
extensions=[DataStreamMonitoring([cost], valid_stream,
prefix='valid',
- every_n_batches=100),
- Printing(every_n_batches=100),
+ every_n_batches=1000),
+ Printing(every_n_batches=1000),
# Dump('taxi_model', every_n_batches=100),
# LoadFromDump('taxi_model'),
]