author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:15:23 -0400
---|---|---
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:15:23 -0400
commit | de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055 (patch) |
tree | 09c09a12861f0f6826cd33e3b77eba9a07076c49 |
parent | 43e106e6630030dd34813295fe1d07bb86025402 (diff) |
Add TaxiGenerateSplits
-rw-r--r-- | data.py | 10
-rw-r--r-- | model.py | 28
-rw-r--r-- | transformers.py | 51
3 files changed, 73 insertions, 16 deletions
```diff
diff --git a/data.py b/data.py
--- a/data.py
+++ b/data.py
@@ -164,21 +164,25 @@ taxi_columns = [
 ]
 
 taxi_columns_valid = taxi_columns + [
-    ("destination_longitude", lambda l: float(l[9])),
-    ("destination_latitude", lambda l: float(l[10])),
+    ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
+    ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
     ("time", lambda l: int(l[11])),
 ]
 
 train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
-valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
+valid_files=["%s/split/valid2-cut.csv" % (DATA_PATH,)]
 test_file="%s/test.csv" % (DATA_PATH,)
 
 train_data=TaxiData(train_files, taxi_columns)
 valid_data = TaxiData(valid_files, taxi_columns_valid)
 test_data = TaxiData(test_file, taxi_columns, has_header=True)
 
+valid_trips = [l for l in open(DATA_PATH + "/split/valid2-cut-ids.txt")]
+
 def train_it():
     return DataIterator(DataStream(train_data))
 
 def test_it():
     return DataIterator(DataStream(valid_data))
+
+
diff --git a/model.py b/model.py
--- a/model.py
+++ b/model.py
@@ -44,35 +44,35 @@ if __name__ == "__main__":
         sys.exit(1)
     config = importlib.import_module(sys.argv[1])
 
-def setup_stream():
+
+def setup_train_stream():
     # Load the training and test data
     train = H5PYDataset(data.H5DATA_PATH,
                         which_set='train',
-                        subset=slice(0, data.dataset_size - config.n_valid),
+                        subset=slice(0, data.dataset_size),
                         load_in_memory=True)
     train = DataStream(train, iteration_scheme=SequentialExampleScheme(data.dataset_size - config.n_valid))
+    train = transformers.filter_out_trips(data.valid_trips, train)
+    train = transformers.TaxiGenerateSplits(train)
     train = transformers.add_first_k(config.n_begin_end_pts, train)
-    train = transformers.add_random_k(config.n_begin_end_pts, train)
-    train = transformers.add_destination(train)
+    train = transformers.add_last_k(config.n_begin_end_pts, train)
     train = transformers.Select(train, ('origin_stand', 'origin_call', 'first_k_latitude', 'last_k_latitude',
                                         'first_k_longitude', 'last_k_longitude',
                                         'destination_latitude', 'destination_longitude'))
     train_stream = Batch(train, iteration_scheme=ConstantScheme(config.batch_size))
 
-    valid = H5PYDataset(data.H5DATA_PATH,
-                        which_set='train',
-                        subset=slice(data.dataset_size - config.n_valid, data.dataset_size),
-                        load_in_memory=True)
-    valid = DataStream(valid, iteration_scheme=SequentialExampleScheme(config.n_valid))
+    return train_stream
+
+def setup_valid_stream():
+    valid = DataStream(data.valid_data)
     valid = transformers.add_first_k(config.n_begin_end_pts, valid)
-    valid = transformers.add_random_k(config.n_begin_end_pts, valid)
-    valid = transformers.add_destination(valid)
+    valid = transformers.add_last_k(config.n_begin_end_pts, valid)
     valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k_latitude', 'last_k_latitude',
                                         'first_k_longitude', 'last_k_longitude',
                                         'destination_latitude', 'destination_longitude'))
     valid_stream = Batch(valid, iteration_scheme=ConstantScheme(1000))
 
-    return (train_stream, valid_stream)
+    return valid_stream
 
 def setup_test_stream():
     test = data.test_data
@@ -86,6 +86,7 @@ def setup_test_stream():
 
     return test_stream
 
+
 def main():
     # The input and the targets
     x_firstk_latitude = (tensor.matrix('first_k_latitude') - data.porto_center[0]) / data.data_std[0]
@@ -146,7 +147,8 @@ def main():
     stand_embed_table.initialize()
     mlp.initialize()
 
-    (train_stream, valid_stream) = setup_stream()
+    train_stream = setup_train_stream()
+    valid_stream = setup_valid_stream()
 
     # Training
     cg = ComputationGraph(cost)
```
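In model.py, the single `setup_stream()` is split into `setup_train_stream()` and `setup_valid_stream()`: validation examples now come from the pre-cut `valid2-cut.csv` file (via `data.valid_data`), and any trip whose id appears in `valid2-cut-ids.txt` is dropped from the training stream with `transformers.filter_out_trips(data.valid_trips, train)` before `TaxiGenerateSplits` is applied. Below is a minimal plain-Python sketch of that exclusion step; the names `valid_ids` and `trips` are illustrative and not from the repository.

```python
# Sketch only: the idea behind filter_out_trips, without Fuel.
# valid_ids stands in for data.valid_trips (ids read from valid2-cut-ids.txt,
# stripped of trailing newlines here for clarity).
valid_ids = {line.strip() for line in ["T100\n", "T205\n"]}

# trips stands in for the examples flowing through the training stream.
trips = [
    {"trip_id": "T100", "latitude": [41.15, 41.16], "longitude": [-8.61, -8.62]},
    {"trip_id": "T999", "latitude": [41.14, 41.15], "longitude": [-8.60, -8.61]},
]

# Keep only trips whose id is NOT in the held-out validation set.
train_trips = [t for t in trips if t["trip_id"] not in valid_ids]
assert [t["trip_id"] for t in train_trips] == ["T999"]
```

In the diff itself this is expressed as a Fuel `Filter` transformer wrapping the `trip_filter` predicate added to transformers.py below.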
```diff
diff --git a/transformers.py b/transformers.py
index 5ad9a87..79e8327 100644
--- a/transformers.py
+++ b/transformers.py
@@ -27,6 +27,42 @@ class Select(Transformer):
         data=next(self.child_epoch_iterator)
         return [data[id] for id in self.ids]
 
+class TaxiGenerateSplits(Transformer):
+    def __init__(self, data_stream, max_splits=-1):
+        super(TaxiGenerateSplits, self).__init__(data_stream)
+        self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude')
+        self.max_splits = max_splits
+        self.data = None
+        self.splits = []
+        self.isplit = 0
+        self.id_latitude = data_stream.sources.index('latitude')
+        self.id_longitude = data_stream.sources.index('longitude')
+
+    def get_data(self, request=None):
+        if request is not None:
+            raise ValueError
+        while self.isplit >= len(self.splits):
+            self.data = next(self.child_epoch_iterator)
+            self.splits = range(len(self.data[self.id_polyline]))
+            random.shuffle_array(self.splits)
+            if self.max_splits != -1 and len(self.splits) > self.max_splits:
+                self.splits = self.splits[:self.max_splits]
+            self.isplit = 0
+
+        i = self.isplit
+        self.isplit += 1
+        n = self.splits[i]+1
+
+        r = list(self.data)
+
+        r[self.id_latitude] = r[self.id_latitude][:n]
+        r[self.id_longitude] = r[self.id_longitude][:n]
+
+        dlat = self.data[self.id_latitude][-1]
+        dlon = self.data[self.id_longitude][-1]
+
+        return tuple(r + [dlat, dlon])
+
 
 class first_k(object):
     def __init__(self, k, id_latitude, id_longitude):
@@ -87,3 +123,18 @@ def add_destination(stream):
     id_latitude = stream.sources.index('latitude')
     id_longitude = stream.sources.index('longitude')
     return Mapping(stream, destination(id_latitude, id_longitude), ('destination_latitude', 'destination_longitude'))
+
+
+class trip_filter(object):
+    def __init__(self, id_trip_id, exclude):
+        self.id_trip_id = id_trip_id
+        self.exclude = exclude
+    def __call__(self, data):
+        if data[self.id_trip_id] in self.exclude:
+            return False
+        else:
+            return True
+def filter_out_trips(exclude_trips, stream):
+    id_trip_id = stream.sources.index('trip_id')
+    return Filter(stream, trip_filter(id_trip_id, exclude_trips))
+
```
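`TaxiGenerateSplits` is the core of this commit: for each trip it draws random cut points along the GPS track, truncates the `latitude`/`longitude` sequences to each prefix, and appends the trip's true final coordinates as `destination_latitude`/`destination_longitude`, so one recorded trip yields many (partial trajectory, destination) training pairs. The hunk references `self.id_polyline` and `random.shuffle_array`, neither of which is defined in this diff; the self-contained sketch below substitutes the latitude source and the standard-library `random.shuffle`, and its function name is purely illustrative.

```python
# Sketch only: the prefix-splitting scheme behind TaxiGenerateSplits,
# written as a plain generator instead of a Fuel Transformer.
import random

def generate_splits(latitude, longitude, max_splits=-1):
    """Yield (lat_prefix, lon_prefix, dest_lat, dest_lon) training examples."""
    cuts = list(range(len(latitude)))
    random.shuffle(cuts)
    if max_splits != -1:
        cuts = cuts[:max_splits]   # cap the number of samples per trip, as in the diff
    dest_lat, dest_lon = latitude[-1], longitude[-1]
    for c in cuts:
        n = c + 1                  # keep the first n GPS points
        yield latitude[:n], longitude[:n], dest_lat, dest_lon

# A 4-point trip yields up to 4 prefixes, all labelled with the endpoint (41.18, -8.60).
lat = [41.15, 41.16, 41.17, 41.18]
lon = [-8.61, -8.62, -8.61, -8.60]
for example in generate_splits(lat, lon, max_splits=2):
    print(example)
```

With the default `max_splits=-1` every prefix is emitted; a positive value bounds how many examples a single long trip contributes to the stream.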