aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'model.py')
-rw-r--r--model.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/model.py b/model.py
index 360237e..95b002f 100644
--- a/model.py
+++ b/model.py
@@ -29,7 +29,7 @@ n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
n_dom = 31
n_hour = 24
-n_clients = 57106
+n_clients = 57105
n_stands = 63
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
@@ -45,9 +45,9 @@ def main():
# The input and the targets
x_firstk = tensor.matrix('first_k')
x_lastk = tensor.matrix('last_k')
- x_client = tensor.lmatrix('origin_call')
- x_stand = tensor.lmatrix('origin_stand')
- y = tensor.vector('destination')
+ x_client = tensor.lvector('origin_call')
+ x_stand = tensor.lvector('origin_stand')
+ y = tensor.matrix('destination')
# Define the model
client_embed_table = LookupTable(length=n_clients+1, dim=dim_embed, name='client_lookup')
@@ -60,12 +60,15 @@ def main():
client_embed = client_embed_table.apply(x_client).flatten(ndim=2)
stand_embed = stand_embed_table.apply(x_stand).flatten(ndim=2)
- inputs = tensor.concatenate([x_firstk, x_lastk, client_embed, stand_embed], axis=1)
+ inputs = tensor.concatenate([x_firstk, x_lastk,
+ client_embed, stand_embed],
+ axis=1)
hidden = hidden_layer.apply(inputs)
outputs = output_layer.apply(hidden)
# Calculate the cost
cost = (outputs - y).norm(2, axis=1).mean()
+ cost.name = 'cost'
# Initialization
client_embed_table.weights_init = IsotropicGaussian(0.001)
@@ -83,12 +86,14 @@ def main():
train = DataStream(train)
train = transformers.add_extremities(train, n_begin_end_pts)
train = transformers.add_destination(train)
+ train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size))
valid = data.valid_data
valid = DataStream(valid)
valid = transformers.add_extremities(valid, n_begin_end_pts)
valid = transformers.add_destination(valid)
+ valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))
@@ -103,9 +108,10 @@ def main():
extensions=[DataStreamMonitoring([cost], valid_stream,
prefix='valid',
every_n_batches=100),
- Printing(every_n_batches=100),
- Dump('ngram_blocks_model', every_n_batches=100),
- LoadFromDump('ngram_blocks_model')]
+ Printing(every_n_batches=100),
+ # Dump('taxi_model', every_n_batches=100),
+ # LoadFromDump('taxi_model'),
+ ]
main_loop = MainLoop(
model=Model([cost]),