diff options
Diffstat (limited to 'data')
-rwxr-xr-x | data/csv_to_hdf5.py | 16 | ||||
-rwxr-xr-x | data/init_valid.py | 14 |
2 files changed, 15 insertions, 15 deletions
diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py index 97cf428..b011b52 100755 --- a/data/csv_to_hdf5.py +++ b/data/csv_to_hdf5.py @@ -51,12 +51,12 @@ def read_taxis(input_directory, h5file, dataset): print >> sys.stderr, 'read %s: begin' % dataset size=getattr(data, '%s_size'%dataset) trip_id = numpy.empty(shape=(size,), dtype='S19') - call_type = numpy.empty(shape=(size,), dtype=numpy.uint8) - origin_call = numpy.empty(shape=(size,), dtype=numpy.uint32) - origin_stand = numpy.empty(shape=(size,), dtype=numpy.uint8) - taxi_id = numpy.empty(shape=(size,), dtype=numpy.uint16) - timestamp = numpy.empty(shape=(size,), dtype=numpy.uint32) - day_type = numpy.empty(shape=(size,), dtype=numpy.uint8) + call_type = numpy.empty(shape=(size,), dtype=numpy.int8) + origin_call = numpy.empty(shape=(size,), dtype=numpy.int32) + origin_stand = numpy.empty(shape=(size,), dtype=numpy.int8) + taxi_id = numpy.empty(shape=(size,), dtype=numpy.int16) + timestamp = numpy.empty(shape=(size,), dtype=numpy.int32) + day_type = numpy.empty(shape=(size,), dtype=numpy.int8) missing_data = numpy.empty(shape=(size,), dtype=numpy.bool) latitude = numpy.empty(shape=(size,), dtype=data.Polyline) longitude = numpy.empty(shape=(size,), dtype=data.Polyline) @@ -87,12 +87,12 @@ def read_taxis(input_directory, h5file, dataset): return splits def unique(h5file): - unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.uint32) + unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.int32) assert len(taxi_id_dict) == data.taxi_id_size for k, v in taxi_id_dict.items(): unique_taxi_id[v] = k - unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.uint32) + unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.int32) assert len(origin_call_dict) == data.origin_call_size for k, v in origin_call_dict.items(): unique_origin_call[v] = k diff --git a/data/init_valid.py b/data/init_valid.py index eed0059..cecaca9 100755 --- a/data/init_valid.py +++ b/data/init_valid.py @@ -12,18 +12,18 @@ import data _fields = { 'trip_id': 'S19', - 'call_type': numpy.uint8, - 'origin_call': numpy.uint32, - 'origin_stand': numpy.uint8, - 'taxi_id': numpy.uint16, - 'timestamp': numpy.uint32, - 'day_type': numpy.uint8, + 'call_type': numpy.int8, + 'origin_call': numpy.int32, + 'origin_stand': numpy.int8, + 'taxi_id': numpy.int16, + 'timestamp': numpy.int32, + 'day_type': numpy.int8, 'missing_data': numpy.bool, 'latitude': data.Polyline, 'longitude': data.Polyline, 'destination_latitude': numpy.float32, 'destination_longitude': numpy.float32, - 'travel_time': numpy.uint32, + 'travel_time': numpy.int32, } |