aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config/model_0.py6
-rw-r--r--data.py35
-rw-r--r--model.py36
3 files changed, 43 insertions, 34 deletions
diff --git a/config/model_0.py b/config/model_0.py
index ba04a15..736ef30 100644
--- a/config/model_0.py
+++ b/config/model_0.py
@@ -2,16 +2,12 @@ n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
n_dom = 31
n_hour = 24
-n_clients = 57124 #57105
-n_stands = 63
-
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
n_end_pts = 5
-train_size = 1710670
n_valid = 1000
-dim_embed = 50
+dim_embed = 10
dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed
dim_hidden = [200]
dim_output = 2
diff --git a/data.py b/data.py
index d38df10..79b38c7 100644
--- a/data.py
+++ b/data.py
@@ -24,6 +24,11 @@ def get_client_id(n):
porto_center = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
data_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
+n_clients = 57124 #57105
+n_stands = 63
+
+dataset_size = 1710670
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -126,26 +131,28 @@ class TaxiData(Dataset):
return self.get_data(state)
values = []
- for idx, (_, constructor) in enumerate(self.columns):
- values.append(constructor(line[idx]))
+ for _, constructor in self.columns:
+ values.append(constructor(line))
return tuple(values)
taxi_columns = [
- ("trip_id", lambda x: x),
- ("call_type", CallType.from_data),
- ("origin_call", lambda x: 0 if x == '' or x == 'NA' else get_client_id(int(x))),
- ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
- ("taxi_id", int),
- ("timestamp", int),
- ("day_type", DayType.from_data),
- ("missing_data", lambda x: x[0] == 'T'),
- ("polyline", lambda x: map(tuple, ast.literal_eval(x))),
+ ("trip_id", lambda l: l[0]),
+ ("call_type", lambda l: CallType.from_data(l[1])),
+ ("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))),
+ ("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
+ ("taxi_id", lambda l: int(l[4])),
+ ("timestamp", lambda l: int(l[5])),
+ ("day_type", lambda l: DayType.from_data(l[6])),
+ ("missing_data", lambda l: l[7][0] == 'T'),
+ ("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
+ ("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
+ ("latitude", lambda l: map(lambda p: p[1], ast.literal_eval(l[8]))),
]
taxi_columns_valid = taxi_columns + [
- ("destination_x", float),
- ("destination_y", float),
- ("time", int),
+ ("destination_longitude", lambda l: float(l[9])),
+ ("destination_latitude", lambda l: float(l[10])),
+ ("time", lambda l: int(l[11])),
]
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
diff --git a/model.py b/model.py
index 6a8be9d..065db7e 100644
--- a/model.py
+++ b/model.py
@@ -29,7 +29,7 @@ from fuel.schemes import ConstantScheme, SequentialExampleScheme
from blocks.algorithms import GradientDescent, Scale, AdaDelta, Momentum
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
-from blocks.extensions import Printing
+from blocks.extensions import Printing, FinishAfter
from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint
from blocks.extensions.monitoring import DataStreamMonitoring
@@ -48,9 +48,9 @@ def setup_stream():
# Load the training and test data
train = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
which_set='train',
- subset=slice(0, config.train_size - config.n_valid),
+ subset=slice(0, data.dataset_size - config.n_valid),
load_in_memory=True)
- train = DataStream(train, iteration_scheme=SequentialExampleScheme(config.train_size - config.n_valid))
+ train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
train = transformers.add_first_k(config.n_begin_end_pts, train)
train = transformers.add_random_k(config.n_begin_end_pts, train)
train = transformers.add_destination(train)
@@ -61,7 +61,7 @@ def setup_stream():
valid = H5PYDataset('/data/lisatmp3/simonet/taxi/data.hdf5',
which_set='train',
- subset=slice(config.train_size - config.n_valid, config.train_size),
+ subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
load_in_memory=True)
valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid))
valid = transformers.add_first_k(config.n_begin_end_pts, valid)
@@ -74,6 +74,18 @@ def setup_stream():
return (train_stream, valid_stream)
+def setup_test_stream():
+ test = data.test_data
+
+ test = DataStream(test)
+ test = transformers.add_first_k(config.n_begin_end_pts, test)
+ test = transformers.add_last_k(config.n_begin_end_pts, test)
+ test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k_latitude',
+ 'last_k_latitude', 'first_k_longitude', 'last_k_longitude'))
+ test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
+
+ return test_stream
+
def main():
# The input and the targets
x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0]
@@ -94,8 +106,8 @@ def main():
# x_lastk_longitude = theano.printing.Print("x_lastk_longitude")(x_lastk_longitude)
# Define the model
- client_embed_table = LookupTable(length=config.n_clients+1, dim=config.dim_embed, name='client_lookup')
- stand_embed_table = LookupTable(length=config.n_stands+1, dim=config.dim_embed, name='stand_lookup')
+ client_embed_table = LookupTable(length=data.n_clients+1, dim=config.dim_embed, name='client_lookup')
+ stand_embed_table = LookupTable(length=data.n_stands+1, dim=config.dim_embed, name='stand_lookup')
mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
dims=[config.dim_input] + config.dim_hidden + [config.dim_output])
@@ -152,6 +164,7 @@ def main():
# Checkpoint('model.pkl', every_n_batches=100),
Dump('taxi_model', every_n_batches=1000),
LoadFromDump('taxi_model'),
+ FinishAfter(after_epoch=1)
]
main_loop = MainLoop(
@@ -163,13 +176,7 @@ def main():
main_loop.profile.report()
# Produce an output on the test data
- '''
- test = data.test_data
- test = DataStream(test)
- test = transformers.add_first_k(conifg.n_begin_end_pts, test)
- test = transformers.add_last_k(config.n_begin_end_pts, test)
- test = transformers.Select(test, ('trip_id', 'origin_stand', 'origin_call', 'first_k', 'last_k'))
- test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
+ test_stream = setup_test_stream()
outfile = open("test-output.csv", "w")
outcsv = csv.writer(outfile)
@@ -177,9 +184,8 @@ def main():
for out in apply_model.Apply(outputs=outputs, stream=test_stream, return_vars=['trip_id', 'outputs']):
dest = out['outputs']
for i, trip in enumerate(out['trip_id']):
- outcsv.writerow([trip, repr(dest[i, 1]), repr(dest[i, 0])])
+ outcsv.writerow([trip, repr(dest[i, 0]), repr(dest[i, 1])])
outfile.close()
- '''
if __name__ == "__main__":