diff options
-rw-r--r-- | data/cut.py | 31 | ||||
-rwxr-xr-x | data/make_time_index.py | 50 | ||||
-rw-r--r-- | model/mlp.py | 5 |
3 files changed, 84 insertions, 2 deletions
diff --git a/data/cut.py b/data/cut.py new file mode 100644 index 0000000..1253434 --- /dev/null +++ b/data/cut.py @@ -0,0 +1,31 @@ +from fuel.schemes import IterationScheme +import sqlite3 +import random +import os +from picklable_itertools import iter_ + +import data + +first_time = 1372636853 +last_time = 1404172787 + + +class TaxiTimeCutScheme(IterationScheme): + def __init__(self, dbfile=None, use_cuts=None): + self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile + self.use_cuts = use_cuts + + def get_request_iterator(self): + cuts = self.use_cuts + if cuts == None: + cuts = [random.randrange(first_time, last_time) for _ in range(100)] + + l = [] + with sqlite3.connect(self.dbfile) as db: + c = db.cursor() + for cut in cuts: + l = l + [i for (i,) in + c.execute('SELECT trip FROM trip_times WHERE begin <= ? AND end >= ?', (cut, cut))] + + return iter_(l) + diff --git a/data/make_time_index.py b/data/make_time_index.py new file mode 100755 index 0000000..c51d075 --- /dev/null +++ b/data/make_time_index.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# Make a valid dataset by cutting the training set at specified timestamps + +import os +import sys +import importlib + +import h5py +import numpy + +import data +from data.hdf5 import taxi_it + +import sqlite3 + +def make_valid(outpath): + times = [] + for i, line in enumerate(taxi_it('train')): + time = line['timestamp'] + latitude = line['latitude'] + + if len(latitude) == 0: + continue + + duration = 15 * (len(latitude) - 1) + + times.append((i, int(time), int(time + duration))) + if i % 1000 == 0: + print times[-1] + + + with sqlite3.connect(outpath) as timedb: + c = timedb.cursor() + c.execute(''' + CREATE TABLE trip_times + (trip INTEGER, begin INTEGER, end INTEGER) + ''') + print "Adding data..." + c.executemany('INSERT INTO trip_times(trip, begin, end) VALUES(?, ?, ?)', times) + timedb.commit() + print "Creating index..." + c.execute('''CREATE INDEX trip_time_index ON trip_times (begin, end)''') + + +if __name__ == '__main__': + if len(sys.argv) < 1 or len(sys.argv) > 2: + print >> sys.stderr, 'Usage: %s [outfile]' % sys.argv[0] + sys.exit(1) + outpath = os.path.join(data.path, 'time_index.db') if len(sys.argv) < 2 else sys.argv[1] + make_valid(outpath) diff --git a/model/mlp.py b/model/mlp.py index 9c84ef9..576b45b 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -8,6 +8,7 @@ from blocks.bricks import application, MLP, Rectifier, Initializable import data from data import transformers from data.hdf5 import TaxiDataset, TaxiStream +from data.cut import TaxiTimeCutScheme from model import ContextEmbedder @@ -51,13 +52,13 @@ class Stream(object): def train(self, req_vars): stream = TaxiDataset('train') - stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) + stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) - stream = transformers.TaxiGenerateSplits(stream, max_splits=100) + stream = transformers.TaxiGenerateSplits(stream, max_splits=1) stream = transformers.TaxiAddDateTime(stream) stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) |