aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-08 13:59:44 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-08 13:59:44 -0400
commit0ecac7973fd02f44af9c8bc5765f7c159c94b23a (patch)
treeec28ec6253d84b3c55f8fa88adce4e720d26bd07
parent1ffd1fc355f6fddcb6cd3d93c0df58513d064472 (diff)
downloadtaxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.tar.gz
taxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.zip
Fusion AddFirstK and AddLastK transformers, and add 'input_time' field.
-rw-r--r--config/time_simple_mlp_tgtcls_2_cswdtx.py2
-rw-r--r--data/transformers.py23
-rwxr-xr-xtrain.py10
3 files changed, 11 insertions, 24 deletions
diff --git a/config/time_simple_mlp_tgtcls_2_cswdtx.py b/config/time_simple_mlp_tgtcls_2_cswdtx.py
index 4579df3..eb69714 100644
--- a/config/time_simple_mlp_tgtcls_2_cswdtx.py
+++ b/config/time_simple_mlp_tgtcls_2_cswdtx.py
@@ -32,8 +32,6 @@ embed_weights_init = IsotropicGaussian(0.001)
mlp_weights_init = IsotropicGaussian(0.01)
mlp_biases_init = Constant(0.001)
-exp_base = 1.5
-
learning_rate = 0.0001
momentum = 0.99
batch_size = 32
diff --git a/data/transformers.py b/data/transformers.py
index e3ff7b1..57747fc 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -67,10 +67,12 @@ class TaxiGenerateSplits(Transformer):
return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)])
-class TaxiAddFirstK(Transformer):
+class TaxiAddFirstLastLen(Transformer):
def __init__(self, k, stream):
- super(TaxiAddFirstK, self).__init__(stream)
- self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude')
+ super(TaxiAddFirstLastLen, self).__init__(stream)
+ self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude',
+ 'last_k_latitude', 'last_k_longitude',
+ 'input_time')
self.id_latitude = stream.sources.index('latitude')
self.id_longitude = stream.sources.index('longitude')
self.k = k
@@ -81,23 +83,12 @@ class TaxiAddFirstK(Transformer):
dtype=theano.config.floatX),
numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
dtype=theano.config.floatX))
- return data + first_k
-
-class TaxiAddLastK(Transformer):
- def __init__(self, k, stream):
- super(TaxiAddLastK, self).__init__(stream)
- self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude')
- self.id_latitude = stream.sources.index('latitude')
- self.id_longitude = stream.sources.index('longitude')
- self.k = k
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
dtype=theano.config.floatX),
numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
dtype=theano.config.floatX))
- return data + last_k
+ input_time = (15 * (len(data[self.id_latitude]) - 1),)
+ return data + first_k + last_k + input_time
class TaxiAddDateTime(Transformer):
def __init__(self, stream):
diff --git a/train.py b/train.py
index 4e2b983..8d9f4ad 100755
--- a/train.py
+++ b/train.py
@@ -44,8 +44,7 @@ def setup_train_stream(req_vars, valid_trips_ids):
train = transformers.TaxiGenerateSplits(train, max_splits=100)
train = transformers.TaxiAddDateTime(train)
- train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train)
- train = transformers.TaxiAddLastK(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, train)
train = transformers.Select(train, tuple(req_vars))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
@@ -56,8 +55,7 @@ def setup_valid_stream(req_vars):
valid = TaxiStream(config.valid_set, 'valid.hdf5')
valid = transformers.TaxiAddDateTime(valid)
- valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
- valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, tuple(req_vars))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
@@ -68,8 +66,8 @@ def setup_test_stream(req_vars):
test = TaxiStream('test')
test = transformers.TaxiAddDateTime(test)
- test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
- test = transformers.TaxiAddLastK(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddLast(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))