aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config/simple_mlp_2_cs.py4
-rw-r--r--config/simple_mlp_2_noembed.py4
-rw-r--r--config/simple_mlp_tgtcls_0_cs.py4
-rw-r--r--config/simple_mlp_tgtcls_1_cs.py4
-rw-r--r--config/simple_mlp_tgtcls_1_cswdt.py28
-rw-r--r--train.py18
-rw-r--r--transformers.py127
7 files changed, 95 insertions, 94 deletions
diff --git a/config/simple_mlp_2_cs.py b/config/simple_mlp_2_cs.py
index 692d325..fa2f4c1 100644
--- a/config/simple_mlp_2_cs.py
+++ b/config/simple_mlp_2_cs.py
@@ -2,10 +2,6 @@ import model.simple_mlp as model
import data
-n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
-n_dom = 31
-n_hour = 24
-
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
n_end_pts = 5
diff --git a/config/simple_mlp_2_noembed.py b/config/simple_mlp_2_noembed.py
index bc300e7..2f45f63 100644
--- a/config/simple_mlp_2_noembed.py
+++ b/config/simple_mlp_2_noembed.py
@@ -2,10 +2,6 @@ import model.simple_mlp as model
import data
-n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
-n_dom = 31
-n_hour = 24
-
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
n_end_pts = 5
diff --git a/config/simple_mlp_tgtcls_0_cs.py b/config/simple_mlp_tgtcls_0_cs.py
index b174517..96faca0 100644
--- a/config/simple_mlp_tgtcls_0_cs.py
+++ b/config/simple_mlp_tgtcls_0_cs.py
@@ -4,10 +4,6 @@ import data
import model.simple_mlp_tgtcls as model
-n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
-n_dom = 31
-n_hour = 24
-
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
n_end_pts = 5
diff --git a/config/simple_mlp_tgtcls_1_cs.py b/config/simple_mlp_tgtcls_1_cs.py
index 6bf82e1..293a0ab 100644
--- a/config/simple_mlp_tgtcls_1_cs.py
+++ b/config/simple_mlp_tgtcls_1_cs.py
@@ -4,10 +4,6 @@ import data
import model.simple_mlp_tgtcls as model
-n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
-n_dom = 31
-n_hour = 24
-
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
n_end_pts = 5
diff --git a/config/simple_mlp_tgtcls_1_cswdt.py b/config/simple_mlp_tgtcls_1_cswdt.py
new file mode 100644
index 0000000..9261635
--- /dev/null
+++ b/config/simple_mlp_tgtcls_1_cswdt.py
@@ -0,0 +1,28 @@
+import cPickle
+
+import data
+
+import model.simple_mlp_tgtcls as model
+
+n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
+n_end_pts = 5
+
+n_valid = 1000
+
+with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f)
+
+dim_embeddings = [
+ ('origin_call', data.n_train_clients+1, 10),
+ ('origin_stand', data.n_stands+1, 10),
+ ('week_of_year', 53, 10),
+ ('day_of_week', 7, 10),
+ ('qhour_of_day', 24 * 4, 10)
+]
+
+dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
+dim_hidden = [500]
+dim_output = tgtcls.shape[0]
+
+learning_rate = 0.0001
+momentum = 0.99
+batch_size = 32
diff --git a/train.py b/train.py
index 238803a..2c9522e 100644
--- a/train.py
+++ b/train.py
@@ -49,11 +49,13 @@ def setup_train_stream(req_vars):
subset=slice(0, data.dataset_size),
load_in_memory=True)
train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
- train = transformers.filter_out_trips(data.valid_trips, train)
+
+ train = transformers.TaxiExcludeTrips(data.valid_trips, train)
train = transformers.TaxiGenerateSplits(train, max_splits=100)
- train = transformers.add_first_k(config.n_begin_end_pts, train)
- train = transformers.add_last_k(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddDateTime(train)
+ train = transformers.TaxiAddFirstK(config.n_begin_end_pts, train)
+ train = transformers.TaxiAddLastK(config.n_begin_end_pts, train)
train = transformers.Select(train, tuple(req_vars))
train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
@@ -63,8 +65,9 @@ def setup_train_stream(req_vars):
def setup_valid_stream(req_vars):
valid = DataStream(data.valid_data)
- valid = transformers.add_first_k(config.n_begin_end_pts, valid)
- valid = transformers.add_last_k(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddDateTime(valid)
+ valid = transformers.TaxiAddFirstK(config.n_begin_end_pts, valid)
+ valid = transformers.TaxiAddLastK(config.n_begin_end_pts, valid)
valid = transformers.Select(valid, tuple(req_vars))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
@@ -74,8 +77,9 @@ def setup_valid_stream(req_vars):
def setup_test_stream(req_vars):
test = DataStream(data.test_data)
- test = transformers.add_first_k(config.n_begin_end_pts, test)
- test = transformers.add_last_k(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddDateTime(test)
+ test = transformers.TaxiAddFirstK(config.n_begin_end_pts, test)
+ test = transformers.TaxiAddLastK(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
diff --git a/transformers.py b/transformers.py
index 876cee2..d6ed611 100644
--- a/transformers.py
+++ b/transformers.py
@@ -4,6 +4,8 @@ import theano
import random
import data
+import datetime
+
def at_least_k(k, v, pad_at_begin, is_longitude):
if len(v) == 0:
v = numpy.array([data.porto_center[1 if is_longitude else 0]], dtype=theano.config.floatX)
@@ -63,78 +65,61 @@ class TaxiGenerateSplits(Transformer):
return tuple(r + [dlat, dlon])
-
-class first_k(object):
- def __init__(self, k, id_latitude, id_longitude):
- self.k = k
- self.id_latitude = id_latitude
- self.id_longitude = id_longitude
- def __call__(self, data):
- return (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k],
- dtype=theano.config.floatX),
- numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
- dtype=theano.config.floatX))
-def add_first_k(k, stream):
- id_latitude = stream.sources.index('latitude')
- id_longitude = stream.sources.index('longitude')
- return Mapping(stream, first_k(k, id_latitude, id_longitude), ('first_k_latitude', 'first_k_longitude'))
-
-class random_k(object):
- def __init__(self, k, id_latitude, id_longitude):
+class TaxiAddFirstK(Transformer):
+ def __init__(self, k, stream):
+ super(TaxiAddFirstK, self).__init__(stream)
+ self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude')
+ self.id_latitude = stream.sources.index('latitude')
+ self.id_longitude = stream.sources.index('longitude')
self.k = k
- self.id_latitude = id_latitude
- self.id_longitude = id_longitude
- def __call__(self, x):
- lat = at_least_k(self.k, x[self.id_latitude], True, False)
- lon = at_least_k(self.k, x[self.id_longitude], True, True)
- loc = random.randrange(len(lat)-self.k+1)
- return (numpy.array(lat[loc:loc+self.k], dtype=theano.config.floatX),
- numpy.array(lon[loc:loc+self.k], dtype=theano.config.floatX))
-def add_random_k(k, stream):
- id_latitude = stream.sources.index('latitude')
- id_longitude = stream.sources.index('longitude')
- return Mapping(stream, random_k(k, id_latitude, id_longitude), ('last_k_latitude', 'last_k_longitude'))
-
-class last_k(object):
- def __init__(self, k, id_latitude, id_longitude):
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ data = next(self.child_epoch_iterator)
+ first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k],
+ dtype=theano.config.floatX),
+ numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
+ dtype=theano.config.floatX))
+ return data + first_k
+
+class TaxiAddLastK(Transformer):
+ def __init__(self, k, stream):
+ super(TaxiAddLastK, self).__init__(stream)
+ self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude')
+ self.id_latitude = stream.sources.index('latitude')
+ self.id_longitude = stream.sources.index('longitude')
self.k = k
- self.id_latitude = id_latitude
- self.id_longitude = id_longitude
- def __call__(self, data):
- return (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
- dtype=theano.config.floatX),
- numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
- dtype=theano.config.floatX))
-def add_last_k(k, stream):
- id_latitude = stream.sources.index('latitude')
- id_longitude = stream.sources.index('longitude')
- return Mapping(stream, last_k(k, id_latitude, id_longitude), ('last_k_latitude', 'last_k_longitude'))
-
-class destination(object):
- def __init__(self, id_latitude, id_longitude):
- self.id_latitude = id_latitude
- self.id_longitude = id_longitude
- def __call__(self, data):
- return (numpy.array(at_least_k(1, data[self.id_latitude], True, False)[-1],
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ data = next(self.child_epoch_iterator)
+ last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
dtype=theano.config.floatX),
- numpy.array(at_least_k(1, data[self.id_longitude], True, True)[-1],
- dtype=theano.config.floatX))
-def add_destination(stream):
- id_latitude = stream.sources.index('latitude')
- id_longitude = stream.sources.index('longitude')
- return Mapping(stream, destination(id_latitude, id_longitude), ('destination_latitude', 'destination_longitude'))
-
-
-class trip_filter(object):
- def __init__(self, id_trip_id, exclude):
- self.id_trip_id = id_trip_id
- self.exclude = exclude
- def __call__(self, data):
- if data[self.id_trip_id] in self.exclude:
- return False
- else:
- return True
-def filter_out_trips(exclude_trips, stream):
- id_trip_id = stream.sources.index('trip_id')
- return Filter(stream, trip_filter(id_trip_id, exclude_trips))
+ numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
+ dtype=theano.config.floatX))
+ return data + last_k
+
+class TaxiAddDateTime(Transformer):
+ def __init__(self, stream):
+ super(TaxiAddDateTime, self).__init__(stream)
+ self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day')
+ self.id_timestamp = stream.sources.index('timestamp')
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ data = next(self.child_epoch_iterator)
+ ts = data[self.id_timestamp]
+ date = datetime.datetime.utcfromtimestamp(ts)
+ info = (date.isocalendar()[1] - 1, date.weekday(), date.hour * 4 + date.minute / 15)
+ return data + info
+
+class TaxiExcludeTrips(Transformer):
+ def __init__(self, exclude_list, stream):
+ super(TaxiExcludeTrips, self).__init__(stream)
+ self.id_trip_id = stream.sources.index('trip_id')
+ self.exclude = {v: True for v in exclude_list}
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ while True:
+ data = next(self.child_epoch_iterator)
+ if not data[self.id_trip_id] in self.exclude: break
+ return data
+