diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 17:27:43 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 17:27:43 -0400 |
commit | 9a60f6c4e39c09187710608a9e225b6024b34364 (patch) | |
tree | 92e43401b6c6d3982081a35ec680b82856ec00c0 /model.py | |
parent | 107b3798cca35472e158d94f36a0bd08f3fe1fe8 (diff) | |
download | taxi-9a60f6c4e39c09187710608a9e225b6024b34364.tar.gz taxi-9a60f6c4e39c09187710608a9e225b6024b34364.zip |
Add validation set ; fix lat/lon
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 8 |
1 files changed, 6 insertions, 2 deletions
@@ -2,6 +2,9 @@ import logging import os from argparse import ArgumentParser +import numpy + +import theano from theano import tensor from theano.ifelse import ifelse @@ -69,6 +72,7 @@ def main(): # Calculate the cost # cost = (outputs - y).norm(2, axis=1).mean() + # outputs = numpy.array([[ -8.621953, 41.162142]], dtype='float32') + 0 * outputs cost = hdist.hdist(outputs, y).mean() cost.name = 'cost' @@ -95,8 +99,8 @@ def main(): valid = data.valid_data valid = DataStream(valid) valid = transformers.add_first_k(n_begin_end_pts, valid) - valid = transformers.add_random_k(n_begin_end_pts, valid) - valid = transformers.add_destination(valid) + valid = transformers.add_last_k(n_begin_end_pts, valid) + valid = transformers.concat_destination_xy(valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) |