diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 13:08:56 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 13:08:56 -0400 |
commit | a25d4fb6e92f203183de2d89e8c467a6b14e1730 (patch) | |
tree | 6d448760e647572d52242f5726224cdf20e832ee | |
parent | ccd1245db7f6799ab4e1f45a8cead85ed67f1c72 (diff) | |
download | taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.tar.gz taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.zip |
Implement HDist, transformer that selects k at a random position.
-rw-r--r-- | data.py | 8 | ||||
-rw-r--r-- | hdist.py | 25 | ||||
-rw-r--r-- | model.py | 18 | ||||
-rw-r--r-- | transformers.py | 19 |
4 files changed, 54 insertions, 16 deletions
@@ -69,7 +69,7 @@ class TaxiData(Dataset): def __init__(self, pathes, has_header=False): if not isinstance(pathes, list): pathes=[pathes] - assert len(pathes) + assert len(pathes)>0 self.pathes=pathes self.has_header=has_header super(TaxiData, self).__init__() @@ -101,12 +101,12 @@ class TaxiData(Dataset): raise ValueError try: line=state.reader.next() - except StopIteration: - print state.index + except (ValueError, StopIteration): + # print state.index state.file.close() state.index+=1 if state.index>=len(self.pathes): - raise + raise StopIteration state.file=open(self.pathes[state.index]) state.reader=csv.reader(state.file) if self.has_header: diff --git a/hdist.py b/hdist.py new file mode 100644 index 0000000..4944cb8 --- /dev/null +++ b/hdist.py @@ -0,0 +1,25 @@ +from theano import tensor +import numpy + + +def hdist(a, b): + rearth = numpy.float32(6371) + deg2rad = numpy.float32(3.14159265358979 / 180) + + lat1 = a[:, 0] * deg2rad + lon1 = a[:, 1] * deg2rad + lat2 = b[:, 0] * deg2rad + lon2 = b[:, 1] * deg2rad + + dlat = abs(lat1-lat2) + dlon = abs(lon1-lon2) + + al = tensor.sin(dlat/2)**2 + tensor.cos(lat1) * tensor.cos(lat2) * (tensor.sin(dlon/2)**2) + d = tensor.arctan2(tensor.sqrt(al), tensor.sqrt(numpy.float32(1)-al)) + + hd = 2 * rearth * d + + return tensor.switch(tensor.eq(hd, float('nan')), (a-b).norm(2, axis=1), hd) + + + @@ -24,6 +24,7 @@ from blocks.extensions.monitoring import DataStreamMonitoring import data import transformers +import hdist n_dow = 7 # number of division for dayofweek/dayofmonth/hourofday n_dom = 31 @@ -38,8 +39,8 @@ n_end_pts = 5 dim_embed = 50 dim_hidden = 200 -learning_rate = 0.1 -batch_size = 32 +learning_rate = 0.01 +batch_size = 64 def main(): # The input and the targets @@ -67,7 +68,8 @@ def main(): outputs = output_layer.apply(hidden) # Calculate the cost - cost = (outputs - y).norm(2, axis=1).mean() + # cost = (outputs - y).norm(2, axis=1).mean() + cost = hdist.hdist(outputs, y).mean() cost.name = 'cost' # Initialization @@ -84,14 +86,16 @@ def main(): # Load the training and test data train = data.train_data train = DataStream(train) - train = transformers.add_extremities(train, n_begin_end_pts) + train = transformers.add_first_k(n_begin_end_pts, train) + train = transformers.add_random_k(n_begin_end_pts, train) train = transformers.add_destination(train) train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) train_stream = Batch(train, iteration_scheme=ConstantScheme(batch_size)) valid = data.valid_data valid = DataStream(valid) - valid = transformers.add_extremities(valid, n_begin_end_pts) + valid = transformers.add_first_k(n_begin_end_pts, valid) + valid = transformers.add_random_k(n_begin_end_pts, valid) valid = transformers.add_destination(valid) valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination')) valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size)) @@ -107,8 +111,8 @@ def main(): extensions=[DataStreamMonitoring([cost], valid_stream, prefix='valid', - every_n_batches=100), - Printing(every_n_batches=100), + every_n_batches=1000), + Printing(every_n_batches=1000), # Dump('taxi_model', every_n_batches=100), # LoadFromDump('taxi_model'), ] diff --git a/transformers.py b/transformers.py index 29b8094..b6f5e14 100644 --- a/transformers.py +++ b/transformers.py @@ -1,6 +1,7 @@ from fuel.transformers import Transformer, Filter, Mapping import numpy import theano +import random class Select(Transformer): def __init__(self, data_stream, sources): @@ -14,13 +15,21 @@ class Select(Transformer): data=next(self.child_epoch_iterator) return [data[id] for id in self.ids] -def add_extremities(stream, k): +def add_first_k(k, stream): id_polyline=stream.sources.index('polyline') - def extremities(x): - return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(), - numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten()) + def first_k(x): + return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(),) stream = Filter(stream, lambda x: len(x[id_polyline])>=k) - stream = Mapping(stream, extremities, ('first_k', 'last_k')) + stream = Mapping(stream, first_k, ('first_k',)) + return stream + +def add_random_k(k, stream): + id_polyline=stream.sources.index('polyline') + def random_k(x): + loc = random.randrange(len(x[id_polyline])-k+1) + return (numpy.array(x[id_polyline][loc:loc+k], dtype=theano.config.floatX).flatten(),) + stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + stream = Mapping(stream, random_k, ('last_k',)) return stream def add_destination(stream): |