aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-24 17:32:57 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-24 17:32:57 -0400
commit0527e6e696fa1832d599473099429295dea31650 (patch)
tree38d3e4946236e01746eb4216f83e00a3ba5ea14f
parentbd2826df73554207c88c5918d86fd9707d9e3753 (diff)
downloadtaxi-0527e6e696fa1832d599473099429295dea31650.tar.gz
taxi-0527e6e696fa1832d599473099429295dea31650.zip
It kind of works (at least it does something now)
-rw-r--r--data.py10
-rw-r--r--model.py22
-rw-r--r--transformers.py7
3 files changed, 25 insertions, 14 deletions
diff --git a/data.py b/data.py
index 4590a7b..e6b7cbf 100644
--- a/data.py
+++ b/data.py
@@ -7,6 +7,8 @@ from fuel.iterator import DataIterator
PREFIX="/data/lisatmp3/auvolat/taxikaggle"
+client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))}
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -87,7 +89,7 @@ class TaxiData(Dataset):
state.index=0
state.file.close()
state.file=open(self.pathes[0])
- state.reader=csv.reader(state[0])
+ state.reader=csv.reader(state.file)
return state
def get_data(self, state, request=None):
@@ -95,7 +97,7 @@ class TaxiData(Dataset):
raise ValueError
try:
line=state.reader.next()
- except StopIteration:
+ except ValueError:
state.file.close()
state.index+=1
if state.index>=len(self.pathes):
@@ -104,10 +106,10 @@ class TaxiData(Dataset):
state.reader=csv.reader(state.file)
if self.has_header:
state.reader.next()
- line=state.reader.next()
+ return get_data(self, state)
line[1]=CallType.from_data(line[1]) # call_type
- line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
+ line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
line[4]=int(line[4]) # taxi_id
line[5]=int(line[5]) # timestamp
diff --git a/model.py b/model.py
index 360237e..95b002f 100644
--- a/model.py
+++ b/model.py
@@ -29,7 +29,7 @@ n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday
n_dom = 31
n_hour = 24
-n_clients = 57106
+n_clients = 57105
n_stands = 63
n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
@@ -45,9 +45,9 @@ def main():
# The input and the targets
x_firstk = tensor.matrix('first_k')
x_lastk = tensor.matrix('last_k')
- x_client = tensor.lmatrix('origin_call')
- x_stand = tensor.lmatrix('origin_stand')
- y = tensor.vector('destination')
+ x_client = tensor.lvector('origin_call')
+ x_stand = tensor.lvector('origin_stand')
+ y = tensor.matrix('destination')
# Define the model
client_embed_table = LookupTable(length=n_clients+1, dim=dim_embed, name='client_lookup')
@@ -60,12 +60,15 @@ def main():
client_embed = client_embed_table.apply(x_client).flatten(ndim=2)
stand_embed = stand_embed_table.apply(x_stand).flatten(ndim=2)
- inputs = tensor.concatenate([x_firstk, x_lastk, client_embed, stand_embed], axis=1)
+ inputs = tensor.concatenate([x_firstk, x_lastk,
+ client_embed, stand_embed],
+ axis=1)
hidden = hidden_layer.apply(inputs)
outputs = output_layer.apply(hidden)
# Calculate the cost
cost = (outputs - y).norm(2, axis=1).mean()
+ cost.name = 'cost'
# Initialization
client_embed_table.weights_init = IsotropicGaussian(0.001)
@@ -83,12 +86,14 @@ def main():
train = DataStream(train)
train = transformers.add_extremities(train, n_begin_end_pts)
train = transformers.add_destination(train)
+ train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size))
valid = data.valid_data
valid = DataStream(valid)
valid = transformers.add_extremities(valid, n_begin_end_pts)
valid = transformers.add_destination(valid)
+ valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))
@@ -103,9 +108,10 @@ def main():
extensions=[DataStreamMonitoring([cost], valid_stream,
prefix='valid',
every_n_batches=100),
- Printing(every_n_batches=100),
- Dump('ngram_blocks_model', every_n_batches=100),
- LoadFromDump('ngram_blocks_model')]
+ Printing(every_n_batches=100),
+ # Dump('taxi_model', every_n_batches=100),
+ # LoadFromDump('taxi_model'),
+ ]
main_loop = MainLoop(
model=Model([cost]),
diff --git a/transformers.py b/transformers.py
index 92ee439..29b8094 100644
--- a/transformers.py
+++ b/transformers.py
@@ -1,4 +1,6 @@
from fuel.transformers import Transformer, Filter, Mapping
+import numpy
+import theano
class Select(Transformer):
def __init__(self, data_stream, sources):
@@ -15,11 +17,12 @@ class Select(Transformer):
def add_extremities(stream, k):
id_polyline=stream.sources.index('polyline')
def extremities(x):
- return (x[id_polyline][:k], x[id_polyline][-k:])
+ return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(),
+ numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten())
stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
stream = Mapping(stream, extremities, ('first_k', 'last_k'))
return stream
def add_destination(stream):
id_polyline=stream.sources.index('polyline')
- return Mapping(stream, lambda x: x[id_polyline][-1], ('destination',))
+ return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',))