diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-04 15:33:34 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-04 15:33:34 -0400 |
commit | f9a31bd246e3c4736d3f532b566b7437eba6b4de (patch) | |
tree | ad5c5da41acf52fb6642d6c7db1fbeb149c61c25 | |
parent | 929eaf8dd0233f8423b24b93b78c99fc9df65343 (diff) | |
download | taxi-f9a31bd246e3c4736d3f532b566b7437eba6b4de.tar.gz taxi-f9a31bd246e3c4736d3f532b566b7437eba6b4de.zip |
Fix hdf5 converter
-rwxr-xr-x | convert_data.py | 52 |
1 files changed, 31 insertions, 21 deletions
diff --git a/convert_data.py b/convert_data.py index 9684fa9..f069580 100755 --- a/convert_data.py +++ b/convert_data.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import os, h5py, csv, sys, numpy, theano, ast -from fuel.datasets.hdf5 import H5PYDataset +from fuel.converters.base import fill_hdf5_file test_size = 320 # `wc -l test.csv` - 1 # Minus 1 to ignore the header train_size = 1710670 # `wc -l train.csv` - 1 @@ -31,9 +31,9 @@ def get_unique_origin_call(val): return len(origin_call_dict) - 1 def read_stands(input_directory, h5file): - stands_name = h5file.create_dataset('stands_name', shape=(stands_size+1,), dtype=('a', 24)) - stands_latitude = h5file.create_dataset('stands_latitude', shape=(stands_size+1,), dtype=theano.config.floatX) - stands_longitude = h5file.create_dataset('stands_longitude', shape=(stands_size+1,), dtype=theano.config.floatX) + stands_name = numpy.empty(shape=(stands_size+1,), dtype=('a', 24)) + stands_latitude = numpy.empty(shape=(stands_size+1,), dtype=theano.config.floatX) + stands_longitude = numpy.empty(shape=(stands_size+1,), dtype=theano.config.floatX) stands_name[0] = 'None' stands_latitude[0] = stands_longitude[0] = 0 with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f: @@ -44,9 +44,11 @@ def read_stands(input_directory, h5file): stands_name[id] = line[1] stands_latitude[id] = float(line[2]) stands_longitude[id] = float(line[3]) - return {'stands': {array: (0, stands_size+1) for array in ['stands_name', 'stands_latitude', 'stands_longitude' ]}} + return (('stands', 'stands_name', stands_name), + ('stands', 'stands_latitude', stands_latitude), + ('stands', 'stands_longitude', stands_longitude)) -def read_taxis(input_directory, h5file, dataset, prefix): +def read_taxis(input_directory, h5file, dataset): print >> sys.stderr, 'read %s: begin' % dataset size=globals()['%s_size'%dataset] trip_id = numpy.empty(shape=(size,), dtype='S19') @@ -78,37 +80,45 @@ def read_taxis(input_directory, h5file, dataset, prefix): latitude[id] = numpy.array([point[1] for point in polyline], dtype=theano.config.floatX) longitude[id] = numpy.array([point[0] for point in polyline], dtype=theano.config.floatX) id+=1 - splits = {} + splits = () print >> sys.stderr, 'read %s: writing' % dataset - for array in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: - name = '%s%s' % (prefix, array) - h5file.create_dataset(name, data=locals()[array]) - splits[name] = (0, size) + for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: + splits += ((dataset, name, locals()[name]),) print >> sys.stderr, 'read %s: end' % dataset - return {dataset: splits} + return splits def unique(h5file): - unique_taxi_id = h5file.create_dataset('unique_taxi_id', shape=(taxi_id_size,), dtype=numpy.int32) + unique_taxi_id = numpy.empty(shape=(taxi_id_size,), dtype=numpy.int32) assert len(taxi_id_dict) == taxi_id_size for k, v in taxi_id_dict.items(): unique_taxi_id[v] = k - unique_origin_call = h5file.create_dataset('unique_origin_call', shape=(origin_call_size+1,), dtype=numpy.int32) + unique_origin_call = numpy.empty(shape=(origin_call_size+1,), dtype=numpy.int32) assert len(origin_call_dict) == origin_call_size+1 for k, v in origin_call_dict.items(): unique_origin_call[v] = k - return {'unique': {'unique_taxi_id': (0, taxi_id_size), 'unique_origin_call': (0, origin_call_size+1)}} + return (('unique_taxi_id', 'unique_taxi_id', unique_taxi_id), + ('unique_origin_call', 'unique_origin_call', unique_origin_call)) def convert(input_directory, save_path): h5file = h5py.File(save_path, 'w') - split = {} - split.update(read_stands(input_directory, h5file)) - split.update(read_taxis(input_directory, h5file, 'train', '')) + split = () + split += read_stands(input_directory, h5file) + split += read_taxis(input_directory, h5file, 'train') print 'First origin_call not present in training set: ', len(origin_call_dict) - split.update(read_taxis(input_directory, h5file, 'test', 'test_')) - split.update(unique(h5file)) - h5file.attrs['split'] = H5PYDataset.create_split_array(split) + split += read_taxis(input_directory, h5file, 'test') + split += unique(h5file) + + fill_hdf5_file(h5file, split) + + for name in ['stands_name', 'stands_latitude', 'stands_longitude']: + h5file[name].dims[0].label = 'index' + for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: + h5file[name].dims[0].label = 'batch' + h5file['unique_taxi_id'].dims[0].label = 'unormalized taxi_id' + h5file['unique_origin_call'].dims[0].label = 'unormalized origin_call' + h5file.flush() h5file.close() |