aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'model.py')
-rw-r--r--model.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/model.py b/model.py
index a5c820e..0d47710 100644
--- a/model.py
+++ b/model.py
@@ -2,6 +2,9 @@ import logging
import os
from argparse import ArgumentParser
+import numpy
+
+import theano
from theano import tensor
from theano.ifelse import ifelse
@@ -69,6 +72,7 @@ def main():
# Calculate the cost
# cost = (outputs - y).norm(2, axis=1).mean()
+ # outputs = numpy.array([[ -8.621953, 41.162142]], dtype='float32') + 0 * outputs
cost = hdist.hdist(outputs, y).mean()
cost.name = 'cost'
@@ -95,8 +99,8 @@ def main():
valid = data.valid_data
valid = DataStream(valid)
valid = transformers.add_first_k(n_begin_end_pts, valid)
- valid = transformers.add_random_k(n_begin_end_pts, valid)
- valid = transformers.add_destination(valid)
+ valid = transformers.add_last_k(n_begin_end_pts, valid)
+ valid = transformers.concat_destination_xy(valid)
valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))