aboutsummaryrefslogtreecommitdiff
path: root/data/csv_to_hdf5.py
diff options
context:
space:
mode:
author: Étienne Simon <esimon@esimon.eu> 2015-05-05 21:55:13 -0400
committer: Étienne Simon <esimon@esimon.eu> 2015-05-05 22:05:21 -0400
commit: 1f2ff96e6480a62089fcac35154a956c218ed678 (patch)
tree: d0bb7a2a6d7ba6ae512a2ce3729b1ccbdc21c822 /data/csv_to_hdf5.py
parent: 54613c1f9cf510ca7a71d6619418f2247515aec6 (diff)
download: taxi-1f2ff96e6480a62089fcac35154a956c218ed678.tar.gz
taxi-1f2ff96e6480a62089fcac35154a956c218ed678.zip
Clean data module and generalize use of hdf5.
Diffstat (limited to 'data/csv_to_hdf5.py')
-rwxr-xr-x data/csv_to_hdf5.py 127
1 files changed, 127 insertions, 0 deletions
diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py
new file mode 100755
index 0000000..17217f3
--- /dev/null
+++ b/data/csv_to_hdf5.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python
+
+import ast
+import csv
+import os
+import sys
+
+import h5py
+import numpy
+import theano
+from fuel.converters.base import fill_hdf5_file
+
+import data
+
+
# Raw identifier -> dense index tables, filled lazily while the CSV files are
# read.  origin_call_dict is pre-seeded with 0 -> 0; read_taxis stores 0 for
# rows whose origin-call field is 'NA' or empty, so index 0 is reserved.
taxi_id_dict = {}
origin_call_dict = {0: 0}


def _unique_index(table, val):
    """Return the dense index of `val` in `table`, assigning the next one if new.

    A value that is not yet present receives ``len(table)`` as its index, so
    indices are handed out consecutively in first-seen order.
    """
    # len(table) is evaluated before setdefault inserts, so a new key gets
    # exactly the next free index; an existing key keeps its old index.
    return table.setdefault(val, len(table))


def get_unique_taxi_id(val):
    """Map a raw taxi id to its dense index, registering it on first sight."""
    return _unique_index(taxi_id_dict, val)


def get_unique_origin_call(val):
    """Map a raw origin-call id to its dense index, registering it on first sight.

    Index 0 is already taken by the pre-seeded ``0: 0`` entry, so the first
    new value seen receives index 1.
    """
    return _unique_index(origin_call_dict, val)
+
def read_stands(input_directory, h5file):
    """Load the taxi-stand metadata CSV into three parallel arrays.

    Reads ``metaData_taxistandsID_name_GPSlocation.csv`` from
    `input_directory`.  Each data row is indexed by the integer in its first
    field; fields 1, 2 and 3 are stored as the stand name, latitude and
    longitude.  Slot 0 is filled with the placeholder name ``'None'`` and
    coordinates ``0``.

    Parameters
    ----------
    input_directory : str
        Directory containing the stands CSV file.
    h5file : object
        Unused here; kept so all readers share the same call shape.

    Returns
    -------
    tuple
        Three ``('stands', source_name, array)`` triples, for
        ``stands_name``, ``stands_latitude`` and ``stands_longitude``.
    """
    # NOTE(review): the arrays come from numpy.empty, so any slot other than
    # 0 that the CSV does not mention keeps uninitialised contents --
    # presumably the file covers ids 1..stands_size-1; confirm against data.
    stands_name = numpy.empty(shape=(data.stands_size,), dtype='S24')
    stands_latitude = numpy.empty(shape=(data.stands_size,),
                                  dtype=theano.config.floatX)
    stands_longitude = numpy.empty(shape=(data.stands_size,),
                                   dtype=theano.config.floatX)
    stands_name[0] = 'None'
    stands_latitude[0] = stands_longitude[0] = 0

    path = os.path.join(input_directory,
                        'metaData_taxistandsID_name_GPSlocation.csv')
    with open(path, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            stand = int(row[0])
            stands_name[stand] = row[1]
            stands_latitude[stand] = float(row[2])
            stands_longitude[stand] = float(row[3])

    return (('stands', 'stands_name', stands_name),
            ('stands', 'stands_latitude', stands_latitude),
            ('stands', 'stands_longitude', stands_longitude))
+
def read_taxis(input_directory, h5file, dataset):
    """Parse ``<dataset>.csv`` into one array per column.

    The expected row count is read from ``data.<dataset>_size``.  Progress
    messages are written to stderr.  Row fields are consumed by position:

    * 0 -> ``trip_id`` (stored as a 19-byte string)
    * 1 -> ``call_type`` (first character's offset from ``'A'``)
    * 2 -> ``origin_call`` (dense index from get_unique_origin_call;
      0 when the field is ``'NA'`` or empty)
    * 3 -> ``origin_stand`` (integer; 0 when ``'NA'`` or empty)
    * 4 -> ``taxi_id`` (dense index from get_unique_taxi_id)
    * 5 -> ``timestamp`` (integer)
    * 6 -> ``day_type`` (first character's offset from ``'A'``)
    * 7 -> ``missing_data`` (True when the field starts with ``'T'``)
    * 8 -> polyline literal, a sequence of points; element 1 of each point
      goes to ``latitude`` and element 0 to ``longitude``

    Parameters
    ----------
    input_directory : str
        Directory containing ``<dataset>.csv``.
    h5file : object
        Unused here; kept so all readers share the same call shape.
    dataset : str
        Name of the split, e.g. ``'train'`` or ``'test'``.

    Returns
    -------
    tuple
        ``(dataset, column_name, array)`` triples, one per column, in the
        order listed above (``latitude`` then ``longitude`` last).

    Raises
    ------
    ValueError
        If the CSV holds fewer data rows than ``data.<dataset>_size``; the
        remaining slots would otherwise contain uninitialised memory.
    IndexError
        If the CSV holds more data rows than ``data.<dataset>_size``.
    """
    sys.stderr.write('read %s: begin\n' % dataset)
    size = getattr(data, '%s_size' % dataset)
    floatX = theano.config.floatX

    trip_id = numpy.empty(shape=(size,), dtype='S19')
    call_type = numpy.empty(shape=(size,), dtype=numpy.uint8)
    origin_call = numpy.empty(shape=(size,), dtype=numpy.uint32)
    origin_stand = numpy.empty(shape=(size,), dtype=numpy.uint8)
    taxi_id = numpy.empty(shape=(size,), dtype=numpy.uint16)
    timestamp = numpy.empty(shape=(size,), dtype=numpy.uint32)
    day_type = numpy.empty(shape=(size,), dtype=numpy.uint8)
    missing_data = numpy.empty(shape=(size,), dtype=numpy.bool_)
    latitude = numpy.empty(shape=(size,), dtype=data.Polyline)
    longitude = numpy.empty(shape=(size,), dtype=data.Polyline)

    missing = ('NA', '')
    rows_read = 0
    with open(os.path.join(input_directory, '%s.csv' % dataset), 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row_index, line in enumerate(reader):
            if row_index % 10000 == 0 and row_index != 0:
                sys.stderr.write('read %s: %d done\n' % (dataset, row_index))
            trip_id[row_index] = line[0]
            call_type[row_index] = ord(line[1][0]) - ord('A')
            if line[2] in missing:
                origin_call[row_index] = 0
            else:
                origin_call[row_index] = get_unique_origin_call(int(line[2]))
            origin_stand[row_index] = 0 if line[3] in missing else int(line[3])
            taxi_id[row_index] = get_unique_taxi_id(int(line[4]))
            timestamp[row_index] = int(line[5])
            day_type[row_index] = ord(line[6][0]) - ord('A')
            missing_data[row_index] = line[7][0] == 'T'
            # literal_eval only accepts Python literals, so the field is
            # parsed without executing anything from the CSV.
            polyline = ast.literal_eval(line[8])
            latitude[row_index] = numpy.array(
                [point[1] for point in polyline], dtype=floatX)
            longitude[row_index] = numpy.array(
                [point[0] for point in polyline], dtype=floatX)
            rows_read = row_index + 1

    # The arrays were allocated with numpy.empty: a short file would leave
    # garbage in the tail, so refuse to return a partially filled split.
    if rows_read != size:
        raise ValueError('%s.csv has %d data rows, expected %d'
                         % (dataset, rows_read, size))

    sys.stderr.write('read %s: writing\n' % dataset)
    splits = (
        (dataset, 'trip_id', trip_id),
        (dataset, 'call_type', call_type),
        (dataset, 'origin_call', origin_call),
        (dataset, 'origin_stand', origin_stand),
        (dataset, 'taxi_id', taxi_id),
        (dataset, 'timestamp', timestamp),
        (dataset, 'day_type', day_type),
        (dataset, 'missing_data', missing_data),
        (dataset, 'latitude', latitude),
        (dataset, 'longitude', longitude),
    )
    sys.stderr.write('read %s: end\n' % dataset)
    return splits
+
def _invert_index_table(table, size, label):
    """Return a uint32 array `a` of length `size` with ``a[index] = raw``.

    `table` maps raw values to dense indices.  Raises ValueError when the
    number of entries differs from `size`; `label` names the table in the
    error message.
    """
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if len(table) != size:
        raise ValueError('%s: found %d unique values, expected %d'
                         % (label, len(table), size))
    values = numpy.empty(shape=(size,), dtype=numpy.uint32)
    for raw, index in table.items():
        values[index] = raw
    return values


def unique(h5file):
    """Build the dense-index -> raw-value lookup arrays for both id tables.

    Must be called after the CSV readers have populated `taxi_id_dict` and
    `origin_call_dict`.

    Parameters
    ----------
    h5file : object
        Unused here; kept so all readers share the same call shape.

    Returns
    -------
    tuple
        ``('unique_taxi_id', 'unique_taxi_id', array)`` and
        ``('unique_origin_call', 'unique_origin_call', array)``.

    Raises
    ------
    ValueError
        If a table's entry count differs from ``data.taxi_id_size`` /
        ``data.origin_call_size`` respectively.
    """
    unique_taxi_id = _invert_index_table(
        taxi_id_dict, data.taxi_id_size, 'taxi_id')
    unique_origin_call = _invert_index_table(
        origin_call_dict, data.origin_call_size, 'origin_call')

    return (('unique_taxi_id', 'unique_taxi_id', unique_taxi_id),
            ('unique_origin_call', 'unique_origin_call', unique_origin_call))
+
def convert(input_directory, save_path):
    """Convert the CSV files in `input_directory` into one HDF5 file.

    Reads the stands file, then the ``train`` and ``test`` CSVs, then the
    unique-id lookup tables, hands all resulting
    ``(split, source, array)`` triples to ``fill_hdf5_file``, and finally
    labels the first dimension of every written dataset.

    Parameters
    ----------
    input_directory : str
        Directory holding the input CSV files.
    save_path : str
        Path of the HDF5 file to create (opened in ``'w'`` mode, so an
        existing file is overwritten).
    """
    index_sources = ['stands_name', 'stands_latitude', 'stands_longitude',
                     'unique_taxi_id', 'unique_origin_call']
    batch_sources = ['trip_id', 'call_type', 'origin_call', 'origin_stand',
                     'taxi_id', 'timestamp', 'day_type', 'missing_data',
                     'latitude', 'longitude']

    # The context manager closes the file even if one of the readers raises.
    with h5py.File(save_path, 'w') as h5file:
        split = ()
        split += read_stands(input_directory, h5file)
        split += read_taxis(input_directory, h5file, 'train')
        # Reported between the two reads: at this point origin_call_dict
        # holds only the values seen in the training CSV.  The double space
        # reproduces what the original two-argument print statement emitted.
        sys.stdout.write('First origin_call not present in training set:  %d\n'
                         % len(origin_call_dict))
        split += read_taxis(input_directory, h5file, 'test')
        split += unique(h5file)

        fill_hdf5_file(h5file, split)

        for name in index_sources:
            h5file[name].dims[0].label = 'index'
        for name in batch_sources:
            h5file[name].dims[0].label = 'batch'

        h5file.flush()
+
# Command-line entry point: csv_to_hdf5.py <download_dir> <output_file>
if __name__ == '__main__':
    if len(sys.argv) != 3:
        # sys.stderr.write behaves the same on Python 2 and 3, unlike the
        # `print >> stream` statement form.
        sys.stderr.write('Usage: %s download_dir output_file\n' % sys.argv[0])
        sys.exit(1)
    convert(sys.argv[1], sys.argv[2])