aboutsummaryrefslogtreecommitdiff
path: root/data/make_valid_cut.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-05-05 21:55:13 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-05-05 22:05:21 -0400
commit1f2ff96e6480a62089fcac35154a956c218ed678 (patch)
treed0bb7a2a6d7ba6ae512a2ce3729b1ccbdc21c822 /data/make_valid_cut.py
parent54613c1f9cf510ca7a71d6619418f2247515aec6 (diff)
downloadtaxi-1f2ff96e6480a62089fcac35154a956c218ed678.tar.gz
taxi-1f2ff96e6480a62089fcac35154a956c218ed678.zip
Clean data module and generalize use of hdf5.
Diffstat (limited to 'data/make_valid_cut.py')
-rwxr-xr-xdata/make_valid_cut.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/data/make_valid_cut.py b/data/make_valid_cut.py
new file mode 100755
index 0000000..d5be083
--- /dev/null
+++ b/data/make_valid_cut.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# Make a valid dataset by cutting the training set at specified timestamps
+
+import os
+import sys
+import importlib
+
+import h5py
+import numpy
+
+import data
+from data.hdf5 import taxi_it
+
+
+_fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time']
+
+def make_valid(cutfile, outpath):
+ cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts
+
+ valid = []
+
+ for line in taxi_it('train'):
+ time = line['timestamp']
+ latitude = line['latitude']
+ longitude = line['longitude']
+
+ if len(latitude) == 0:
+ continue
+
+ for ts in cuts:
+ if time <= ts and time + 15 * (len(latitude) - 1) >= ts:
+ # keep it
+ n = (ts - time) / 15 + 1
+ line.update({
+ 'latitude': latitude[:n],
+ 'longitude': longitude[:n],
+ 'destination_latitude': latitude[-1],
+ 'destination_longitude': longitude[-1],
+ 'travel_time': 15 * (len(latitude)-1)
+ })
+ valid.append(line)
+
+ file = h5py.File(outpath, 'a')
+ clen = file['trip_id'].shape[0]
+ alen = len(valid)
+ for field in _fields:
+ dset = file[field]
+ dset.resize((clen + alen,))
+ for i in xrange(alen):
+ dset[clen + i] = valid[i][field]
+
+ splits = file.attrs['split']
+ slen = splits.shape[0]
+ splits = numpy.resize(splits, (slen+len(_fields),))
+ for (i, field) in enumerate(_fields):
+ splits[slen+i]['split'] = ('cuts/%s' % cutfile).encode('utf8')
+ splits[slen+i]['source'] = field.encode('utf8')
+ splits[slen+i]['start'] = clen
+ splits[slen+i]['stop'] = alen
+ splits[slen+i]['available'] = True
+ splits[slen+i]['comment'] = '.'
+ file.attrs['split'] = splits
+
+ file.flush()
+ file.close()
+
+if __name__ == '__main__':
+ if len(sys.argv) < 2 or len(sys.argv) > 3:
+ print >> sys.stderr, 'Usage: %s cutfile [outfile]' % sys.argv[0]
+ sys.exit(1)
+ outpath = os.path.join(data.path, 'valid.hdf5') if len(sys.argv) < 3 else sys.argv[2]
+ make_valid(sys.argv[1], outpath)