aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 15:51:26 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-22 15:51:47 -0400
commit57fe795d14e70c06c9bdbe6fe903588b5f75474e (patch)
treed7b0de1569a67dfc55dc6481c35e976d22572ebb /model
parent448e848796757ad9f0a2f681886f868b8f22e81f (diff)
downloadtaxi-57fe795d14e70c06c9bdbe6fe903588b5f75474e.tar.gz
taxi-57fe795d14e70c06c9bdbe6fe903588b5f75474e.zip
Add parametrizability for how the training data is presented
Diffstat (limited to 'model')
-rw-r--r--model/mlp.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/model/mlp.py b/model/mlp.py
index 05898a5..fc86b7b 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -51,14 +51,18 @@ class Stream(object):
self.config = config
def train(self, req_vars):
- stream = TaxiDataset('train')
- stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
-
valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
+ stream = TaxiDataset('train')
+
+ if hasattr(self.config, 'use_cuts_for_trainig') and self.config.use_cuts_for_training:
+ stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
+ else:
+ stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
+
stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
- stream = transformers.TaxiGenerateSplits(stream, max_splits=1)
+ stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
stream = transformers.TaxiAddDateTime(stream)
stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)