diff options
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 8 |
1 files changed, 6 insertions, 2 deletions
@@ -2,6 +2,9 @@ import logging import os from argparse import ArgumentParser +import numpy + +import theano from theano import tensor from theano.ifelse import ifelse @@ -69,6 +72,7 @@ def main(): # Calculate the cost # cost = (outputs - y).norm(2, axis=1).mean() + # outputs = numpy.array([[ -8.621953, 41.162142]], dtype='float32') + 0 * outputs cost = hdist.hdist(outputs, y).mean() cost.name = 'cost' @@ -95,8 +99,8 @@ def main(): valid = data.valid_data valid = DataStream(valid) valid = transformers.add_first_k(n_begin_end_pts, valid) - valid = transformers.add_random_k(n_begin_end_pts, valid) - valid = transformers.add_destination(valid) + valid = transformers.add_last_k(n_begin_end_pts, valid) + valid = transformers.concat_destination_xy(valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) |