From 9a60f6c4e39c09187710608a9e225b6024b34364 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 27 Apr 2015 17:27:43 -0400 Subject: Add validation set ; fix lat/lon --- model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'model.py') diff --git a/model.py b/model.py index a5c820e..0d47710 100644 --- a/model.py +++ b/model.py @@ -2,6 +2,9 @@ import logging import os from argparse import ArgumentParser +import numpy + +import theano from theano import tensor from theano.ifelse import ifelse @@ -69,6 +72,7 @@ def main(): # Calculate the cost # cost = (outputs - y).norm(2, axis=1).mean() + # outputs = numpy.array([[ -8.621953, 41.162142]], dtype='float32') + 0 * outputs cost = hdist.hdist(outputs, y).mean() cost.name = 'cost' @@ -95,8 +99,8 @@ def main(): valid = data.valid_data valid = DataStream(valid) valid = transformers.add_first_k(n_begin_end_pts, valid) - valid = transformers.add_random_k(n_begin_end_pts, valid) - valid = transformers.add_destination(valid) + valid = transformers.add_last_k(n_begin_end_pts, valid) + valid = transformers.concat_destination_xy(valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) -- cgit v1.2.3