diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-10 17:16:20 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-10 17:16:20 -0400 |
commit | 27a0e0949c6ca3f7bd18569a23ddd0e1b3e9a64e (patch) | |
tree | 1201f8abe7c664f15084b1d48ed8de92c72879d8 /model | |
parent | 793be7b049cecba43072858341dc7006fef352e7 (diff) | |
download | taxi-27a0e0949c6ca3f7bd18569a23ddd0e1b3e9a64e.tar.gz taxi-27a0e0949c6ca3f7bd18569a23ddd0e1b3e9a64e.zip |
Batch shuffling
Diffstat (limited to 'model')
-rw-r--r-- | model/mlp.py | 18 |
1 file changed, 11 insertions, 7 deletions
diff --git a/model/mlp.py b/model/mlp.py index 1f53e8c..7d04c82 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -52,6 +52,12 @@ class FFMLP(Initializable): def predict_inputs(self): return self.inputs +class UniformGenerator(object): + def __init__(self): + self.rng = numpy.random.RandomState(123) + def __call__(self, *args): + return float(self.rng.uniform()) + class Stream(object): def __init__(self, config): self.config = config @@ -69,17 +75,15 @@ class Stream(object): stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) - stream = transformers.add_destination(stream) - - stream = transformers.taxi_add_datetime(stream) - stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) - stream = transformers.Select(stream, tuple(req_vars)) if hasattr(self.config, 'shuffle_batch_size'): stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size)) - rng = numpy.random.RandomState(123) - stream = Mapping(stream, SortMapping(lambda x: float(rng.uniform()))) + stream = Mapping(stream, SortMapping(key=UniformGenerator())) stream = Unpack(stream) + + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) + stream = transformers.Select(stream, tuple(req_vars)) stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) |