From b44f7113c64568a20c9f93ca17577f17d7695dcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 11 Jun 2015 16:26:29 -0400 Subject: Use Mapping instead of extending Transformer --- data/transformers.py | 129 +++++++++++++++++++++++++++++++++------------------ model/mlp.py | 16 +++---- 2 files changed, 93 insertions(+), 52 deletions(-) diff --git a/data/transformers.py b/data/transformers.py index 4814e9b..1bed887 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -3,7 +3,8 @@ import random import numpy import theano -from fuel.transformers import Transformer +from fuel.schemes import ConstantScheme +from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack import data @@ -30,6 +31,29 @@ class Select(Transformer): raise ValueError data=next(self.child_epoch_iterator) return [data[id] for id in self.ids] + +class TaxiExcludeTrips(Transformer): + def __init__(self, stream, exclude_list): + super(TaxiExcludeTrips, self).__init__(stream) + self.id_trip_id = stream.sources.index('trip_id') + self.exclude = {v: True for v in exclude_list} + def get_data(self, request=None): + if request is not None: raise ValueError + while True: + data = next(self.child_epoch_iterator) + if not data[self.id_trip_id] in self.exclude: break + return data + +class TaxiExcludeEmptyTrips(Transformer): + def __init__(self, stream): + super(TaxiExcludeEmptyTrips, self).__init__(stream) + self.latitude = stream.sources.index('latitude') + def get_data(self, request=None): + if request is not None: raise ValueError + while True: + data = next(self.child_epoch_iterator) + if len(data[self.latitude])>0: break + return data class TaxiGenerateSplits(Transformer): def __init__(self, data_stream, max_splits=-1): @@ -68,18 +92,13 @@ class TaxiGenerateSplits(Transformer): return tuple(r + [dlat, dlon, ttime]) -class TaxiAddFirstLastLen(Transformer): - def __init__(self, k, stream): - super(TaxiAddFirstLastLen, self).__init__(stream) - self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude', - 'last_k_latitude', 'last_k_longitude', - 'input_time') - self.id_latitude = stream.sources.index('latitude') - self.id_longitude = stream.sources.index('longitude') + +class _taxi_add_first_last_len_helper(object): + def __init__(self, k, latitude, longitude): self.k = k - def get_data(self, request=None): - if request is not None: raise ValueError - data = next(self.child_epoch_iterator) + self.id_latitude = id_latitude + self.id_longitude = id_longitude + def __call__(self, data): first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k], dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k], @@ -89,43 +108,65 @@ class TaxiAddFirstLastLen(Transformer): numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], dtype=theano.config.floatX)) input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),) - return data + first_k + last_k + input_time + return first_k + last_k + input_time -class TaxiAddDateTime(Transformer): - def __init__(self, stream): - super(TaxiAddDateTime, self).__init__(stream) - self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day') - self.id_timestamp = stream.sources.index('timestamp') - def get_data(self, request=None): - if request is not None: raise ValueError - data = next(self.child_epoch_iterator) - ts = data[self.id_timestamp] +def taxi_add_first_last_len(stream, k): + fun = _taxi_add_first_last_len_helper(k, stream.sources.index('latitude'), stream.sources.index('longitude')) + return Mapping(stream, fun, add_sources=('first_k_latitude', 'first_k_longitude', 'last_k_latitude', 'last_k_longitude', 'input_time')) + + +class _taxi_add_datetime_helper(object): + def __init__(self, key): + self.key = key + def __call__(self, data): + ts = data[self.key] date = datetime.datetime.utcfromtimestamp(ts) yearweek = date.isocalendar()[1] - 1 info = (numpy.int8(51 if yearweek == 52 else yearweek), numpy.int8(date.weekday()), numpy.int8(date.hour * 4 + date.minute / 15)) - return data + info + return info + +def taxi_add_datetime(stream): + fun = _taxi_add_datetime_helper(stream.sources.index('timestamp')) + return Mapping(stream, fun, add_sources=('week_of_year', 'day_of_week', 'qhour_of_day')) + + +class _balanced_batch_helper(object): + def __init__(self, key): + self.key = key + def __call__(self, data): + return len(data[self.key]) + +def balanced_batch(stream, key, batch_size, batch_sort_size): + stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size)) + comparison = _balanced_batch_helper(stream.sources.index(key)) + stream = Mapping(stream, SortMapping(comparison)) + stream = Unpack(stream) + return Batch(stream, iteration_scheme=ConstantScheme(batch_size)) + + +class _taxi_remove_test_only_clients_helper(object): + def __init__(self, key): + self.key = key + def __call__(self, x): + x = list(x) + if x[self.key] >= data.origin_call_train_size: + x[self.key] = numpy.int32(0) + return tuple(x) -class TaxiExcludeTrips(Transformer): - def __init__(self, exclude_list, stream): - super(TaxiExcludeTrips, self).__init__(stream) - self.id_trip_id = stream.sources.index('trip_id') - self.exclude = {v: True for v in exclude_list} - def get_data(self, request=None): - if request is not None: raise ValueError - while True: - data = next(self.child_epoch_iterator) - if not data[self.id_trip_id] in self.exclude: break - return data +def taxi_remove_test_only_clients(stream): + fun = _taxi_remove_test_only_clients_helper(stream.sources.index('origin_call')) + return Mapping(stream, fun) -class TaxiRemoveTestOnlyClients(Transformer): - def __init__(self, stream): - super(TaxiRemoveTestOnlyClients, self).__init__(stream) - self.id_origin_call = stream.sources.index('origin_call') - def get_data(self, request=None): - if request is not None: raise ValueError - x = list(next(self.child_epoch_iterator)) - if x[self.id_origin_call] >= data.origin_call_train_size: - x[self.id_origin_call] = numpy.int32(0) - return tuple(x) + +class _add_destination_helper(object): + def __init__(self, latitude, longitude): + self.latitude = latitude + self.longitude = longitude + def __call__(self, data): + return (data[self.latitude][-1], data[self.longitude][-1]) + +def add_destination(stream): + fun = _add_destination_helper(stream.sources.index('latitude'), stream.sources.index('longitude')) + return Mapping(stream, fun, add_sources=('destination_latitude', 'destination_longitude')) diff --git a/model/mlp.py b/model/mlp.py index b1e9163..6abc86f 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -61,11 +61,11 @@ class Stream(object): else: stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) @@ -77,17 +77,17 @@ class Stream(object): def valid(self, req_vars): stream = TaxiStream(self.config.valid_set, 'valid.hdf5') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.Select(stream, tuple(req_vars)) return Batch(stream, iteration_scheme=ConstantScheme(1000)) def test(self, req_vars): stream = TaxiStream('test') - stream = transformers.TaxiAddDateTime(stream) - stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream) - stream = transformers.TaxiRemoveTestOnlyClients(stream) + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) + stream = transformers.taxi_remove_test_only_clients(stream) return Batch(stream, iteration_scheme=ConstantScheme(1)) -- cgit v1.2.3