From 0ecac7973fd02f44af9c8bc5765f7c159c94b23a Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 8 May 2015 13:59:44 -0400 Subject: Fusion AddFirstK and AddLastK transformers, and add 'input_time' field. --- config/time_simple_mlp_tgtcls_2_cswdtx.py | 2 -- data/transformers.py | 23 +++++++---------------- train.py | 10 ++++------ 3 files changed, 11 insertions(+), 24 deletions(-) diff --git a/config/time_simple_mlp_tgtcls_2_cswdtx.py b/config/time_simple_mlp_tgtcls_2_cswdtx.py index 4579df3..eb69714 100644 --- a/config/time_simple_mlp_tgtcls_2_cswdtx.py +++ b/config/time_simple_mlp_tgtcls_2_cswdtx.py @@ -32,8 +32,6 @@ embed_weights_init = IsotropicGaussian(0.001) mlp_weights_init = IsotropicGaussian(0.01) mlp_biases_init = Constant(0.001) -exp_base = 1.5 - learning_rate = 0.0001 momentum = 0.99 batch_size = 32 diff --git a/data/transformers.py b/data/transformers.py index e3ff7b1..57747fc 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -67,10 +67,12 @@ class TaxiGenerateSplits(Transformer): return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)]) -class TaxiAddFirstK(Transformer): +class TaxiAddFirstLastLen(Transformer): def __init__(self, k, stream): - super(TaxiAddFirstK, self).__init__(stream) - self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude') + super(TaxiAddFirstLastLen, self).__init__(stream) + self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude', + 'last_k_latitude', 'last_k_longitude', + 'input_time') self.id_latitude = stream.sources.index('latitude') self.id_longitude = stream.sources.index('longitude') self.k = k @@ -81,23 +83,12 @@ class TaxiAddFirstK(Transformer): dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k], dtype=theano.config.floatX)) - return data + first_k - -class TaxiAddLastK(Transformer): - def __init__(self, k, stream): - super(TaxiAddLastK, self).__init__(stream) - self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude') - self.id_latitude = stream.sources.index('latitude') - self.id_longitude = stream.sources.index('longitude') - self.k = k - def get_data(self, request=None): - if request is not None: raise ValueError - data = next(self.child_epoch_iterator) last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:], dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], dtype=theano.config.floatX)) - return data + last_k + input_time = (15 * (len(data[self.id_latitude]) - 1),) + return data + first_k + last_k + input_time class TaxiAddDateTime(Transformer): def __init__(self, stream): diff --git a/train.py b/train.py index 4e2b983..8d9f4ad 100755 --- a/train.py +++ b/train.py @@ -44,8 +44,7 @@ def setup_train_stream(req_vars, valid_trips_ids): train = transformers.TaxiGenerateSplits(train, max_splits=100) train = transformers.TaxiAddDateTime(train) - train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train) - train = transformers.TaxiAddLastK(config.n_begin_end_pts, train) + train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train) train = transformers.Select(train, tuple(req_vars)) train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size)) @@ -56,8 +55,7 @@ def setup_valid_stream(req_vars): valid = TaxiStream(config.valid_set, 'valid.hdf5') valid = transformers.TaxiAddDateTime(valid) - valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid) - valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid) + valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid) valid = transformers.Select(valid, tuple(req_vars)) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000)) @@ -68,8 +66,8 @@ def setup_test_stream(req_vars): test = TaxiStream('test') test = transformers.TaxiAddDateTime(test) - test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test) - test = transformers.TaxiAddLastK(config.n_begin_end_pts, test) + test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test) + test = transformers.TaxiAddLast(config.n_begin_end_pts, test) test = transformers.Select(test, tuple(req_vars)) test_stream = Batch(test, iteration_scheme=ConstantScheme(1000)) -- cgit v1.2.3