aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xdata/csv_to_hdf5.py16
-rwxr-xr-xdata/init_valid.py14
2 files changed, 15 insertions, 15 deletions
diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py
index 97cf428..b011b52 100755
--- a/data/csv_to_hdf5.py
+++ b/data/csv_to_hdf5.py
@@ -51,12 +51,12 @@ def read_taxis(input_directory, h5file, dataset):
print >> sys.stderr, 'read %s: begin' % dataset
size=getattr(data, '%s_size'%dataset)
trip_id = numpy.empty(shape=(size,), dtype='S19')
- call_type = numpy.empty(shape=(size,), dtype=numpy.uint8)
- origin_call = numpy.empty(shape=(size,), dtype=numpy.uint32)
- origin_stand = numpy.empty(shape=(size,), dtype=numpy.uint8)
- taxi_id = numpy.empty(shape=(size,), dtype=numpy.uint16)
- timestamp = numpy.empty(shape=(size,), dtype=numpy.uint32)
- day_type = numpy.empty(shape=(size,), dtype=numpy.uint8)
+ call_type = numpy.empty(shape=(size,), dtype=numpy.int8)
+ origin_call = numpy.empty(shape=(size,), dtype=numpy.int32)
+ origin_stand = numpy.empty(shape=(size,), dtype=numpy.int8)
+ taxi_id = numpy.empty(shape=(size,), dtype=numpy.int16)
+ timestamp = numpy.empty(shape=(size,), dtype=numpy.int32)
+ day_type = numpy.empty(shape=(size,), dtype=numpy.int8)
missing_data = numpy.empty(shape=(size,), dtype=numpy.bool)
latitude = numpy.empty(shape=(size,), dtype=data.Polyline)
longitude = numpy.empty(shape=(size,), dtype=data.Polyline)
@@ -87,12 +87,12 @@ def read_taxis(input_directory, h5file, dataset):
return splits
def unique(h5file):
- unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.uint32)
+ unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.int32)
assert len(taxi_id_dict) == data.taxi_id_size
for k, v in taxi_id_dict.items():
unique_taxi_id[v] = k
- unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.uint32)
+ unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.int32)
assert len(origin_call_dict) == data.origin_call_size
for k, v in origin_call_dict.items():
unique_origin_call[v] = k
diff --git a/data/init_valid.py b/data/init_valid.py
index eed0059..cecaca9 100755
--- a/data/init_valid.py
+++ b/data/init_valid.py
@@ -12,18 +12,18 @@ import data
_fields = {
'trip_id': 'S19',
- 'call_type': numpy.uint8,
- 'origin_call': numpy.uint32,
- 'origin_stand': numpy.uint8,
- 'taxi_id': numpy.uint16,
- 'timestamp': numpy.uint32,
- 'day_type': numpy.uint8,
+ 'call_type': numpy.int8,
+ 'origin_call': numpy.int32,
+ 'origin_stand': numpy.int8,
+ 'taxi_id': numpy.int16,
+ 'timestamp': numpy.int32,
+ 'day_type': numpy.int8,
'missing_data': numpy.bool,
'latitude': data.Polyline,
'longitude': data.Polyline,
'destination_latitude': numpy.float32,
'destination_longitude': numpy.float32,
- 'travel_time': numpy.uint32,
+ 'travel_time': numpy.int32,
}