1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
#!/usr/bin/env python
# Build the validation dataset by truncating training-set trips at the specified cut timestamps
import os
import sys
import importlib
import h5py
import numpy
import data
from data.hdf5 import taxi_it
_fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time']
# Interval between two consecutive GPS points of a trip, in the units of the
# 'timestamp' field (presumably seconds -- TODO confirm against the dataset).
_POINT_INTERVAL = 15


def _cut_trips(trips, cuts):
    """Return truncated copies of the trips that are in progress at a cut time.

    A trip with k points starting at ``timestamp`` is treated as spanning
    ``[timestamp, timestamp + _POINT_INTERVAL * (k - 1)]``. For every cut time
    ``ts`` inside that span, a copy of the trip is produced whose ``latitude``
    and ``longitude`` keep only the points recorded up to ``ts``. In that copy,
    ``destination_latitude``, ``destination_longitude`` and ``travel_time``
    describe the full, uncut trip. Trips with no points are skipped.

    Args:
        trips: iterable of mapping-like trip records with at least the keys
            'timestamp', 'latitude' and 'longitude'.
        cuts: re-iterable collection of cut timestamps (iterated once per trip).

    Returns:
        A list of new dicts, one per (trip, cut) match. The input records are
        not modified.
    """
    cut_trips = []
    for trip in trips:
        start = trip['timestamp']
        latitude = trip['latitude']
        longitude = trip['longitude']
        if len(latitude) == 0:
            continue
        end = start + _POINT_INTERVAL * (len(latitude) - 1)
        for ts in cuts:
            if not start <= ts <= end:
                continue
            # Number of points recorded up to and including time ts. Floor
            # division keeps this an integer slice bound on Python 3 too.
            n = int((ts - start) // _POINT_INTERVAL) + 1
            # Copy the record: a trip spanning several cuts must yield several
            # independent entries, not the same dict appended repeatedly.
            cut = dict(trip)
            cut.update({
                'latitude': latitude[:n],
                'longitude': longitude[:n],
                'destination_latitude': latitude[-1],
                'destination_longitude': longitude[-1],
                'travel_time': _POINT_INTERVAL * (len(latitude) - 1),
            })
            cut_trips.append(cut)
    return cut_trips


def _append_trips(h5file, cutfile, trips):
    """Append *trips* to an open HDF5 file and register them as a new split.

    Every dataset named in ``_fields`` is grown by ``len(trips)`` rows, and the
    new rows are filled from the matching key of each trip. Then one row per
    field is appended to the file's ``split`` attribute. Each row is named
    ``cuts/<cutfile>`` and covers the absolute row range ``[start, stop)`` of
    the appended data.
    """
    start = h5file['trip_id'].shape[0]
    stop = start + len(trips)
    for field in _fields:
        dset = h5file[field]
        dset.resize((stop,))
        for offset, trip in enumerate(trips):
            dset[start + offset] = trip[field]

    splits = h5file.attrs['split']
    slen = splits.shape[0]
    # numpy.resize pads by repeating existing rows; the fields assigned below
    # are overwritten for each new row.
    splits = numpy.resize(splits, (slen + len(_fields),))
    split_name = ('cuts/%s' % cutfile).encode('utf8')
    for offset, field in enumerate(_fields):
        k = slen + offset
        splits[k]['split'] = split_name
        splits[k]['source'] = field.encode('utf8')
        splits[k]['start'] = start
        # 'stop' is an absolute row index like 'start', not the row count.
        splits[k]['stop'] = stop
        splits[k]['available'] = True
        splits[k]['comment'] = '.'
    h5file.attrs['split'] = splits


def make_valid(cutfile, outpath):
    """Append the training trips cut by ``data.cuts.<cutfile>`` to *outpath*.

    Loads the ``cuts`` collection from module ``data.cuts.<cutfile>``, truncates
    every training trip that is in progress at one of those times (see
    ``_cut_trips``), and appends the results to the HDF5 file *outpath* as the
    split ``cuts/<cutfile>``.

    The file is opened in append mode. It must already contain resizable
    datasets for every name in ``_fields``, plus a ``split`` attribute.
    """
    cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts
    valid = _cut_trips(taxi_it('train'), cuts)
    # The context manager closes the file even if writing fails part-way.
    with h5py.File(outpath, 'a') as h5file:
        _append_trips(h5file, cutfile, valid)
if __name__ == '__main__':
    # Arguments: the cut module name, plus an optional output path that
    # defaults to <data.path>/valid.hdf5.
    if not 2 <= len(sys.argv) <= 3:
        # sys.stderr.write works on Python 2 and 3. `print >> sys.stderr` is
        # Python 2 only and raises TypeError at runtime on Python 3.
        sys.stderr.write('Usage: %s cutfile [outfile]\n' % sys.argv[0])
        sys.exit(1)
    if len(sys.argv) == 3:
        outpath = sys.argv[2]
    else:
        outpath = os.path.join(data.path, 'valid.hdf5')
    make_valid(sys.argv[1], outpath)
|