diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-21 17:05:07 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-21 17:05:07 -0400 |
commit | 222971123f2741ee3689092fb4396dac83a13338 (patch) | |
tree | 59acaae81aaf6f57474be05fb93de1eb00061689 /model | |
parent | f6d2c6fc47f93b158b70b5c0c9a45324041ca4d5 (diff) | |
download | taxi-222971123f2741ee3689092fb4396dac83a13338.tar.gz taxi-222971123f2741ee3689092fb4396dac83a13338.zip |
Implement cut-based iteration scheme (SLOW!!!)
Diffstat (limited to 'model')
-rw-r--r-- | model/mlp.py | 5 |
1 files changed, 3 insertions, 2 deletions
```diff
diff --git a/model/mlp.py b/model/mlp.py
index 9c84ef9..576b45b 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -8,6 +8,7 @@ from blocks.bricks import application, MLP, Rectifier, Initializable
 import data
 from data import transformers
 from data.hdf5 import TaxiDataset, TaxiStream
+from data.cut import TaxiTimeCutScheme
 from model import ContextEmbedder
@@ -51,13 +52,13 @@ class Stream(object):
     def train(self, req_vars):
         stream = TaxiDataset('train')
-        stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
+        stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
         valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
         valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
         stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
-        stream = transformers.TaxiGenerateSplits(stream, max_splits=100)
+        stream = transformers.TaxiGenerateSplits(stream, max_splits=1)
         stream = transformers.TaxiAddDateTime(stream)
         stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
```