aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/cut.py6
-rwxr-xr-xdata/make_time_index.py2
-rw-r--r--model/dest_simple_mlp_tgtcls.py2
-rw-r--r--model/mlp.py7
4 files changed, 11 insertions, 6 deletions
diff --git a/data/cut.py b/data/cut.py
index 1253434..7853030 100644
--- a/data/cut.py
+++ b/data/cut.py
@@ -24,8 +24,10 @@ class TaxiTimeCutScheme(IterationScheme):
with sqlite3.connect(self.dbfile) as db:
c = db.cursor()
for cut in cuts:
- l = l + [i for (i,) in
- c.execute('SELECT trip FROM trip_times WHERE begin <= ? AND end >= ?', (cut, cut))]
+ part = [i for (i,) in
+ c.execute('SELECT trip FROM trip_times WHERE begin >= ? AND begin <= ? AND end >= ?',
+ (cut - 40000, cut, cut))]
+ l = l + part
return iter_(l)
diff --git a/data/make_time_index.py b/data/make_time_index.py
index c51d075..c2838e0 100755
--- a/data/make_time_index.py
+++ b/data/make_time_index.py
@@ -39,7 +39,7 @@ def make_valid(outpath):
c.executemany('INSERT INTO trip_times(trip, begin, end) VALUES(?, ?, ?)', times)
timedb.commit()
print "Creating index..."
- c.execute('''CREATE INDEX trip_time_index ON trip_times (begin, end)''')
+ c.execute('''CREATE INDEX trip_begin_index ON trip_times (begin)''')
if __name__ == '__main__':
diff --git a/model/dest_simple_mlp_tgtcls.py b/model/dest_simple_mlp_tgtcls.py
index 2d65097..46fca2b 100644
--- a/model/dest_simple_mlp_tgtcls.py
+++ b/model/dest_simple_mlp_tgtcls.py
@@ -9,7 +9,7 @@ from model.mlp import FFMLP, Stream
class Model(FFMLP):
def __init__(self, config, **kwargs):
- super(Model, self, output_layer=Softmax).__init__(config, **kwargs)
+ super(Model, self).__init__(config, output_layer=Softmax, **kwargs)
self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes')
@application(outputs=['destination'])
diff --git a/model/mlp.py b/model/mlp.py
index 576b45b..05898a5 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -1,6 +1,6 @@
from theano import tensor
-from fuel.transformers import Batch
+from fuel.transformers import Batch, MultiProcessing
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
from blocks.bricks import application, MLP, Rectifier, Initializable
@@ -63,7 +63,10 @@ class Stream(object):
stream = transformers.TaxiAddDateTime(stream)
stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
stream = transformers.Select(stream, tuple(req_vars))
- return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+
+ stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+
+ return stream
def valid(self, req_vars):
stream = TaxiStream(self.config.valid_set, 'valid.hdf5')