diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-05 13:10:45 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-05 13:10:45 -0400 |
commit | 66159d9fce0129116e82e74cf3eb1d9e048b253d (patch) | |
tree | ec2d80b94ed77d66d393cf546ddb702891f1ee9f | |
parent | ab1076e00d6a92120e46d4a0085911b4425a0d60 (diff) | |
download | taxi-66159d9fce0129116e82e74cf3eb1d9e048b253d.tar.gz taxi-66159d9fce0129116e82e74cf3eb1d9e048b253d.zip |
Remove enums from hdf5
-rwxr-xr-x | convert_data.py | 22 |
1 files changed, 9 insertions, 13 deletions
diff --git a/convert_data.py b/convert_data.py index f069580..ca66786 100755 --- a/convert_data.py +++ b/convert_data.py @@ -9,8 +9,6 @@ stands_size = 63 # `wc -l metaData_taxistandsID_name_GPSlocation.csv` - 1 taxi_id_size = 448 # `cut -d, -f 5 train.csv test.csv | sort -u | wc -l` - 1 origin_call_size = 57124 # `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 3 # Minus 3 to ignore "NA", "" and the header -Call_type = h5py.special_dtype(enum=(numpy.int8, {'CENTRAL': 0, 'STAND': 1, 'STREET': 2})) -Day_type = h5py.special_dtype(enum=(numpy.int8, {'NORMAL': 0, 'HOLYDAY': 1, 'HOLYDAY_EVE': 2})) Polyline = h5py.special_dtype(vlen=theano.config.floatX) taxi_id_dict = {} @@ -52,12 +50,12 @@ def read_taxis(input_directory, h5file, dataset): print >> sys.stderr, 'read %s: begin' % dataset size=globals()['%s_size'%dataset] trip_id = numpy.empty(shape=(size,), dtype='S19') - call_type = numpy.empty(shape=(size,), dtype=Call_type) - origin_call = numpy.empty(shape=(size,), dtype=numpy.int32) - origin_stand = numpy.empty(shape=(size,), dtype=numpy.int8) - taxi_id = numpy.empty(shape=(size,), dtype=numpy.int16) - timestamp = numpy.empty(shape=(size,), dtype=numpy.int32) - day_type = numpy.empty(shape=(size,), dtype=Day_type) + call_type = numpy.empty(shape=(size,), dtype=numpy.uint8) + origin_call = numpy.empty(shape=(size,), dtype=numpy.uint32) + origin_stand = numpy.empty(shape=(size,), dtype=numpy.uint8) + taxi_id = numpy.empty(shape=(size,), dtype=numpy.uint16) + timestamp = numpy.empty(shape=(size,), dtype=numpy.uint32) + day_type = numpy.empty(shape=(size,), dtype=numpy.uint8) missing_data = numpy.empty(shape=(size,), dtype=numpy.bool) latitude = numpy.empty(shape=(size,), dtype=Polyline) longitude = numpy.empty(shape=(size,), dtype=Polyline) @@ -88,12 +86,12 @@ def read_taxis(input_directory, h5file, dataset): return splits def unique(h5file): - unique_taxi_id = numpy.empty(shape=(taxi_id_size,), dtype=numpy.int32) + unique_taxi_id = numpy.empty(shape=(taxi_id_size,), dtype=numpy.uint32) assert len(taxi_id_dict) == taxi_id_size for k, v in taxi_id_dict.items(): unique_taxi_id[v] = k - unique_origin_call = numpy.empty(shape=(origin_call_size+1,), dtype=numpy.int32) + unique_origin_call = numpy.empty(shape=(origin_call_size+1,), dtype=numpy.uint32) assert len(origin_call_dict) == origin_call_size+1 for k, v in origin_call_dict.items(): unique_origin_call[v] = k @@ -112,12 +110,10 @@ def convert(input_directory, save_path): fill_hdf5_file(h5file, split) - for name in ['stands_name', 'stands_latitude', 'stands_longitude']: + for name in ['stands_name', 'stands_latitude', 'stands_longitude', 'unique_taxi_id', 'unique_origin_call']: h5file[name].dims[0].label = 'index' for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: h5file[name].dims[0].label = 'batch' - h5file['unique_taxi_id'].dims[0].label = 'unormalized taxi_id' - h5file['unique_origin_call'].dims[0].label = 'unormalized origin_call' h5file.flush() h5file.close() |