diff options
Diffstat (limited to 'model/mlp.py')
-rw-r--r-- | model/mlp.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/model/mlp.py b/model/mlp.py index 05898a5..fc86b7b 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -51,14 +51,18 @@ class Stream(object): self.config = config def train(self, req_vars): - stream = TaxiDataset('train') - stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = TaxiDataset('train') + + if hasattr(self.config, 'use_cuts_for_trainig') and self.config.use_cuts_for_training: + stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) + else: + stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) + stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) - stream = transformers.TaxiGenerateSplits(stream, max_splits=1) + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) stream = transformers.TaxiAddDateTime(stream) stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) |