diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-05 21:55:13 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-05 22:05:21 -0400 |
commit | 1f2ff96e6480a62089fcac35154a956c218ed678 (patch) | |
tree | d0bb7a2a6d7ba6ae512a2ce3729b1ccbdc21c822 /data/csv_to_hdf5.py | |
parent | 54613c1f9cf510ca7a71d6619418f2247515aec6 (diff) | |
download | taxi-1f2ff96e6480a62089fcac35154a956c218ed678.tar.gz taxi-1f2ff96e6480a62089fcac35154a956c218ed678.zip |
Clean data module and generalize use of hdf5.
Diffstat (limited to 'data/csv_to_hdf5.py')
-rwxr-xr-x | data/csv_to_hdf5.py | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py new file mode 100755 index 0000000..17217f3 --- /dev/null +++ b/data/csv_to_hdf5.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python + +import ast +import csv +import os +import sys + +import h5py +import numpy +import theano +from fuel.converters.base import fill_hdf5_file + +import data + + +taxi_id_dict = {} +origin_call_dict = {0: 0} + +def get_unique_taxi_id(val): + if val in taxi_id_dict: + return taxi_id_dict[val] + else: + taxi_id_dict[val] = len(taxi_id_dict) + return len(taxi_id_dict) - 1 + +def get_unique_origin_call(val): + if val in origin_call_dict: + return origin_call_dict[val] + else: + origin_call_dict[val] = len(origin_call_dict) + return len(origin_call_dict) - 1 + +def read_stands(input_directory, h5file): + stands_name = numpy.empty(shape=(data.stands_size,), dtype=('a', 24)) + stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX) + stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX) + stands_name[0] = 'None' + stands_latitude[0] = stands_longitude[0] = 0 + with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f: + reader = csv.reader(f) + reader.next() # header + for line in reader: + id = int(line[0]) + stands_name[id] = line[1] + stands_latitude[id] = float(line[2]) + stands_longitude[id] = float(line[3]) + return (('stands', 'stands_name', stands_name), + ('stands', 'stands_latitude', stands_latitude), + ('stands', 'stands_longitude', stands_longitude)) + +def read_taxis(input_directory, h5file, dataset): + print >> sys.stderr, 'read %s: begin' % dataset + size=getattr(data, '%s_size'%dataset) + trip_id = numpy.empty(shape=(size,), dtype='S19') + call_type = numpy.empty(shape=(size,), dtype=numpy.uint8) + origin_call = numpy.empty(shape=(size,), dtype=numpy.uint32) + origin_stand = numpy.empty(shape=(size,), dtype=numpy.uint8) + taxi_id = numpy.empty(shape=(size,), dtype=numpy.uint16) + timestamp = numpy.empty(shape=(size,), dtype=numpy.uint32) + day_type = numpy.empty(shape=(size,), dtype=numpy.uint8) + missing_data = numpy.empty(shape=(size,), dtype=numpy.bool) + latitude = numpy.empty(shape=(size,), dtype=data.Polyline) + longitude = numpy.empty(shape=(size,), dtype=data.Polyline) + with open(os.path.join(input_directory, '%s.csv'%dataset), 'r') as f: + reader = csv.reader(f) + reader.next() # header + id=0 + for line in reader: + if id%10000==0 and id!=0: + print >> sys.stderr, 'read %s: %d done' % (dataset, id) + trip_id[id] = line[0] + call_type[id] = ord(line[1][0]) - ord('A') + origin_call[id] = 0 if line[2]=='NA' or line[2]=='' else get_unique_origin_call(int(line[2])) + origin_stand[id] = 0 if line[3]=='NA' or line[3]=='' else int(line[3]) + taxi_id[id] = get_unique_taxi_id(int(line[4])) + timestamp[id] = int(line[5]) + day_type[id] = ord(line[6][0]) - ord('A') + missing_data[id] = line[7][0] == 'T' + polyline = ast.literal_eval(line[8]) + latitude[id] = numpy.array([point[1] for point in polyline], dtype=theano.config.floatX) + longitude[id] = numpy.array([point[0] for point in polyline], dtype=theano.config.floatX) + id+=1 + splits = () + print >> sys.stderr, 'read %s: writing' % dataset + for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: + splits += ((dataset, name, locals()[name]),) + print >> sys.stderr, 'read %s: end' % dataset + return splits + +def unique(h5file): + unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.uint32) + assert len(taxi_id_dict) == data.taxi_id_size + for k, v in taxi_id_dict.items(): + unique_taxi_id[v] = k + + unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.uint32) + assert len(origin_call_dict) == data.origin_call_size + for k, v in origin_call_dict.items(): + unique_origin_call[v] = k + + return (('unique_taxi_id', 'unique_taxi_id', unique_taxi_id), + ('unique_origin_call', 'unique_origin_call', unique_origin_call)) + +def convert(input_directory, save_path): + h5file = h5py.File(save_path, 'w') + split = () + split += read_stands(input_directory, h5file) + split += read_taxis(input_directory, h5file, 'train') + print 'First origin_call not present in training set: ', len(origin_call_dict) + split += read_taxis(input_directory, h5file, 'test') + split += unique(h5file) + + fill_hdf5_file(h5file, split) + + for name in ['stands_name', 'stands_latitude', 'stands_longitude', 'unique_taxi_id', 'unique_origin_call']: + h5file[name].dims[0].label = 'index' + for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: + h5file[name].dims[0].label = 'batch' + + h5file.flush() + h5file.close() + +if __name__ == '__main__': + if len(sys.argv) != 3: + print >> sys.stderr, 'Usage: %s download_dir output_file' % sys.argv[0] + sys.exit(1) + convert(sys.argv[1], sys.argv[2]) |