diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-24 17:32:57 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-24 17:32:57 -0400 |
commit | 0527e6e696fa1832d599473099429295dea31650 (patch) | |
tree | 38d3e4946236e01746eb4216f83e00a3ba5ea14f | |
parent | bd2826df73554207c88c5918d86fd9707d9e3753 (diff) | |
download | taxi-0527e6e696fa1832d599473099429295dea31650.tar.gz taxi-0527e6e696fa1832d599473099429295dea31650.zip |
It kind of works (at least it does something now)
-rw-r--r-- | data.py | 10 | ||||
-rw-r--r-- | model.py | 22 | ||||
-rw-r--r-- | transformers.py | 7 |
3 files changed, 25 insertions, 14 deletions
@@ -7,6 +7,8 @@ from fuel.iterator import DataIterator PREFIX="/data/lisatmp3/auvolat/taxikaggle" +client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))} + class CallType(Enum): CENTRAL = 0 STAND = 1 @@ -87,7 +89,7 @@ class TaxiData(Dataset): state.index=0 state.file.close() state.file=open(self.pathes[0]) - state.reader=csv.reader(state[0]) + state.reader=csv.reader(state.file) return state def get_data(self, state, request=None): @@ -95,7 +97,7 @@ class TaxiData(Dataset): raise ValueError try: line=state.reader.next() - except StopIteration: + except ValueError: state.file.close() state.index+=1 if state.index>=len(self.pathes): @@ -104,10 +106,10 @@ class TaxiData(Dataset): state.reader=csv.reader(state.file) if self.has_header: state.reader.next() - line=state.reader.next() + return get_data(self, state) line[1]=CallType.from_data(line[1]) # call_type - line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call + line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand line[4]=int(line[4]) # taxi_id line[5]=int(line[5]) # timestamp @@ -29,7 +29,7 @@ n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday n_dom = 31 n_hour = 24 -n_clients = 57106 +n_clients = 57105 n_stands = 63 n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory @@ -45,9 +45,9 @@ def main(): # The input and the targets x_firstk = tensor.matrix('first_k') x_lastk = tensor.matrix('last_k') - x_client = tensor.lmatrix('origin_call') - x_stand = tensor.lmatrix('origin_stand') - y = tensor.vector('destination') + x_client = tensor.lvector('origin_call') + x_stand = tensor.lvector('origin_stand') + y = tensor.matrix('destination') # Define the model client_embed_table = LookupTable(length=n_clients+1, dim=dim_embed, name='client_lookup') @@ -60,12 +60,15 @@ def main(): client_embed = client_embed_table.apply(x_client).flatten(ndim=2) stand_embed = stand_embed_table.apply(x_stand).flatten(ndim=2) - inputs = tensor.concatenate([x_firstk, x_lastk, client_embed, stand_embed], axis=1) + inputs = tensor.concatenate([x_firstk, x_lastk, + client_embed, stand_embed], + axis=1) hidden = hidden_layer.apply(inputs) outputs = output_layer.apply(hidden) # Calculate the cost cost = (outputs - y).norm(2, axis=1).mean() + cost.name = 'cost' # Initialization client_embed_table.weights_init = IsotropicGaussian(0.001) @@ -83,12 +86,14 @@ def main(): train = DataStream(train) train = transformers.add_extremities(train, n_begin_end_pts) train = transformers.add_destination(train) + train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size)) valid = data.valid_data valid = DataStream(valid) valid = transformers.add_extremities(valid, n_begin_end_pts) valid = transformers.add_destination(valid) + valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) @@ -103,9 +108,10 @@ def main(): extensions=[DataStreamMonitoring([cost], valid_stream, prefix='valid', every_n_batches=100), - Printing(every_n_batches=100), - Dump('ngram_blocks_model', every_n_batches=100), - LoadFromDump('ngram_blocks_model')] + Printing(every_n_batches=100), + # Dump('taxi_model', every_n_batches=100), + # LoadFromDump('taxi_model'), + ] main_loop = MainLoop( model=Model([cost]), diff --git a/transformers.py b/transformers.py index 92ee439..29b8094 100644 --- a/transformers.py +++ b/transformers.py @@ -1,4 +1,6 @@ from fuel.transformers import Transformer, Filter, Mapping +import numpy +import theano class Select(Transformer): def __init__(self, data_stream, sources): @@ -15,11 +17,12 @@ class Select(Transformer): def add_extremities(stream, k): id_polyline=stream.sources.index('polyline') def extremities(x): - return (x[id_polyline][:k], x[id_polyline][-k:]) + return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(), + numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten()) stream = Filter(stream, lambda x: len(x[id_polyline])>=k) stream = Mapping(stream, extremities, ('first_k', 'last_k')) return stream def add_destination(stream): id_polyline=stream.sources.index('polyline') - return Mapping(stream, lambda x: x[id_polyline][-1], ('destination',)) + return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',)) |