1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
|
import ast
import csv
import numpy
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator
import data
from data.hdf5 import origin_call_normalize, taxi_id_normalize
class TaxiData(Dataset):
    """Fuel dataset that reads examples row by row from one or more CSV files.

    The files listed in ``pathes`` are read one after another, as if they
    were concatenated. Every CSV row is turned into one example by calling
    each column constructor on the row (a list of strings, as produced by
    ``csv.reader``).

    Parameters
    ----------
    pathes : str or list/tuple of str
        Path of the CSV file, or paths of several CSV files read in order.
    columns : sequence of (name, constructor) pairs
        ``name`` becomes an entry of ``provides_sources``; ``constructor``
        is called with the CSV row and returns the value of that source.
    has_header : bool
        When true, the first row of every file is skipped.
    """

    example_iteration_scheme = None

    class State(object):
        """Read cursor: the open file, its index in ``pathes``, its reader."""

        # New-style class: on Python 2, __slots__ is ignored by classic
        # classes, so deriving from object is what makes it effective.
        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, columns, has_header=False):
        if isinstance(pathes, tuple):
            pathes = list(pathes)
        elif not isinstance(pathes, list):
            # A single path was given: treat it as a one-file dataset.
            pathes = [pathes]
        assert len(pathes) > 0
        self.columns = columns
        self.provides_sources = tuple([name for name, _ in columns])
        self.pathes = pathes
        self.has_header = has_header
        super(TaxiData, self).__init__()

    def _open_file(self, state, index):
        """Point ``state`` at file number ``index``, skipping its header."""
        state.index = index
        state.file = open(self.pathes[index])
        state.reader = csv.reader(state.file)
        if self.has_header:
            # The default keeps an empty file from raising StopIteration
            # here; the next read then reports the end of that file.
            next(state.reader, None)

    def open(self):
        """Return a fresh state positioned on the first data row."""
        state = self.State()
        self._open_file(state, 0)
        return state

    def close(self, state):
        """Close the file currently held by ``state``."""
        state.file.close()

    def reset(self, state):
        """Rewind ``state`` to the first data row of the first file.

        The file is re-opened through the same helper as ``open`` so that
        the header row (if any) is skipped again after a reset.
        """
        state.file.close()
        self._open_file(state, 0)
        return state

    def get_data(self, state, request=None):
        """Return the next example as a tuple ordered like ``columns``.

        Raises
        ------
        ValueError
            If ``request`` is not None (only sequential reads are supported).
        StopIteration
            When every file has been read to its end.
        """
        if request is not None:
            raise ValueError("TaxiData does not support requests")
        while True:
            try:
                line = next(state.reader)
                break
            except (ValueError, StopIteration):
                # StopIteration: the current file is exhausted.
                # ValueError: presumably a read on an already-closed file,
                # which is what a call made after the end of the data hits.
                state.file.close()
                state.index += 1
                if state.index >= len(self.pathes):
                    raise StopIteration
                self._open_file(state, state.index)
        return tuple([constructor(line) for _, constructor in self.columns])
# (source name, constructor) pairs; each constructor receives one CSV row as
# a list of strings and returns the value of that source.
# List comprehensions are used instead of map() so that the sequence-valued
# sources are real lists on Python 3 as well (map() is lazy there).
taxi_columns = [
    ("trip_id", lambda l: l[0]),
    # Single letter encoded as its offset from 'A'.
    ("call_type", lambda l: ord(l[1]) - ord('A')),
    # An empty or 'NA' field is encoded as 0.
    ("origin_call", lambda l: 0 if l[2] in ('', 'NA') else origin_call_normalize(int(l[2]))),
    ("origin_stand", lambda l: 0 if l[3] in ('', 'NA') else int(l[3])),
    ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
    ("timestamp", lambda l: int(l[5])),
    # Single letter encoded as its offset from 'A'.
    ("day_type", lambda l: ord(l[6]) - ord('A')),
    # True when the field starts with 'T'; an empty field counts as False.
    ("missing_data", lambda l: l[7].startswith('T')),
    # Field 8 is a Python-literal list of pairs; it is parsed once per source.
    ("polyline", lambda l: [tuple(p) for p in ast.literal_eval(l[8])]),
    ("longitude", lambda l: [p[0] for p in ast.literal_eval(l[8])]),
    ("latitude", lambda l: [p[1] for p in ast.literal_eval(l[8])]),
]
# Columns of the validation file: every column of taxi_columns, followed by
# three extra fields read from positions 9, 10 and 11 of the row.
taxi_columns_valid = list(taxi_columns)
taxi_columns_valid.extend([
    ("destination_longitude", lambda row: numpy.float32(float(row[9]))),
    ("destination_latitude", lambda row: numpy.float32(float(row[10]))),
    ("time", lambda row: int(row[11])),
])
# Locations of the CSV files under the configured data directory.
train_file = "%s/train.csv" % data.path
valid_file = "%s/valid2-cut.csv" % data.path
test_file = "%s/test.csv" % data.path

# train.csv and test.csv are read with has_header=True (first row skipped);
# the validation file is read from its very first row.
train_data = TaxiData(train_file, taxi_columns, has_header=True)
valid_data = TaxiData(valid_file, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)

# Raw lines of the ids file, line endings included. The file is read inside
# a ``with`` block so that its handle is closed instead of being leaked.
with open("%s/valid2-cut-ids.txt" % data.path) as _valid_ids_file:
    valid_trips = list(_valid_ids_file)
def train_it():
    """Return a new Fuel DataIterator over a DataStream of ``train_data``."""
    stream = DataStream(train_data)
    return DataIterator(stream)
def test_it():
    """Return a new Fuel DataIterator over a DataStream of ``valid_data``."""
    # NOTE(review): despite the name, this wraps valid_data, not test_data --
    # confirm with callers that this is intended.
    stream = DataStream(valid_data)
    return DataIterator(stream)
|