From 13fc171f60ae1981c7ad4f2a302a8a85c29addc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 23 Jul 2015 21:20:32 -0400 Subject: Use new tvt dataset with option --tvt --- data/__init__.py | 48 ++++++++++++++++++++++++++++++++++++------------ data/transformers.py | 17 +++++++++++------ 2 files changed, 47 insertions(+), 18 deletions(-) (limited to 'data') diff --git a/data/__init__.py b/data/__init__.py index 2121033..9d01d2a 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,4 +1,5 @@ import os +import sys import h5py import numpy @@ -8,23 +9,46 @@ path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle') Polyline = h5py.special_dtype(vlen=numpy.float32) -# `wc -l test.csv` - 1 # Minus 1 to ignore the header -test_size = 320 - -# `wc -l train.csv` - 1 -train_size = 1710670 - # `wc -l metaData_taxistandsID_name_GPSlocation.csv` stands_size = 64 # include 0 ("no origin_stands") # `cut -d, -f 5 train.csv test.csv | sort -u | wc -l` - 1 taxi_id_size = 448 -# `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 2 -origin_call_size = 57125 # include 0 ("no origin_call") - -# As printed by csv_to_hdf5.py -origin_call_train_size = 57106 - train_gps_mean = numpy.array([41.1573, -8.61612], dtype=numpy.float32) train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=numpy.float32)) + +tvt = '--tvt' in sys.argv + +if tvt: + test_size = 19770 + valid_size = 19427 + train_size = 1671473 + + origin_call_size = 57106 + origin_call_train_size = 57106 + + valid_set = 'valid' + valid_ds = 'tvt.hdf5' + traintest_ds = 'tvt.hdf5' + +else: + # `wc -l test.csv` - 1 # Minus 1 to ignore the header + test_size = 320 + + # `wc -l train.csv` - 1 + train_size = 1710670 + + # `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 2 + origin_call_size = 57125 # include 0 ("no origin_call") + + # As printed by csv_to_hdf5.py + origin_call_train_size = 57106 + + if '--largevalid' in sys.argv: + valid_set = 'cuts/large_valid' + else: + valid_set = 'cuts/test_times_0' + + valid_ds = 'valid.hdf5' + traintest_ds = 'data.hdf5' diff --git a/data/transformers.py b/data/transformers.py index c2eb97e..b3a8486 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -70,7 +70,9 @@ class TaxiGenerateSplits(Transformer): def __init__(self, data_stream, max_splits=-1): super(TaxiGenerateSplits, self).__init__(data_stream) - self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'travel_time') + self.sources = data_stream.sources + if not data.tvt: + self.sources += ('destination_latitude', 'destination_longitude', 'travel_time') self.max_splits = max_splits self.data = None self.splits = [] @@ -100,12 +102,15 @@ class TaxiGenerateSplits(Transformer): r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX) r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX) - dlat = numpy.float32(self.data[self.id_latitude][-1]) - dlon = numpy.float32(self.data[self.id_longitude][-1]) - ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) - - return tuple(r + [dlat, dlon, ttime]) + r = tuple(r) + if data.tvt: + return r + else: + dlat = numpy.float32(self.data[self.id_latitude][-1]) + dlon = numpy.float32(self.data[self.id_longitude][-1]) + ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) + return r + (dlat, dlon, ttime) class _taxi_add_first_last_len_helper(object): def __init__(self, k, id_latitude, id_longitude): -- cgit v1.2.3