aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
Diffstat (limited to 'model')
-rw-r--r--model/mlp.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/model/mlp.py b/model/mlp.py
index 9c84ef9..576b45b 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -8,6 +8,7 @@ from blocks.bricks import application, MLP, Rectifier, Initializable
import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream
+from data.cut import TaxiTimeCutScheme
from model import ContextEmbedder
@@ -51,13 +52,13 @@ class Stream(object):
def train(self, req_vars):
stream = TaxiDataset('train')
- stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
+ stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
- stream = transformers.TaxiGenerateSplits(stream, max_splits=100)
+ stream = transformers.TaxiGenerateSplits(stream, max_splits=1)
stream = transformers.TaxiAddDateTime(stream)
stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)