import ast, csv
import socket
import fuel
import numpy
from enum import Enum
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator
import theano
# Root directory of the CSV data.  The host named "adeb.laptop" uses a local
# directory; every other host uses the /data/lisatmp3 location.
DATA_PATH = (
    "/Users/adeb/data/taxi"
    if socket.gethostname() == "adeb.laptop"
    else "/data/lisatmp3/auvolat/taxikaggle"
)

# HDF5 file read by make_client_ids() (same location on every host).
H5DATA_PATH = '/data/lisatmp3/simonet/taxi/data.hdf5'
# Two-element reference point in theano's configured float type.
# NOTE(review): the values look like (latitude, longitude) -- confirm.
porto_center = numpy.array(
    [41.1573, -8.61612],
    dtype=theano.config.floatX,
)

# Element-wise square root of the two constants below.
# NOTE(review): presumably per-coordinate variances, making this a standard
# deviation -- confirm where the numbers come from.
data_std = numpy.sqrt(
    numpy.array(
        [0.00549598, 0.00333233],
        dtype=theano.config.floatX,
    )
)

# Fixed dataset counts.
# NOTE(review): the meaning of each count is inferred from its name only.
n_clients = 57124
n_train_clients = 57105
n_stands = 63
dataset_size = 1710670
# ---- Read client IDs and create the reverse dictionary
def make_client_ids():
    """Build the lookup from raw client IDs to their position in the HDF5 file.

    Reads the ``uniq_origin_call`` dataset of the file at ``H5DATA_PATH`` and
    maps every entry to its index within that dataset.

    Returns:
        dict: ``{raw_id: position}``.  If a value appears more than once, the
        last position wins.
    """
    # h5py is used here but is not among the module-level imports; importing
    # it locally keeps this function from failing with a NameError.
    import h5py

    # The context manager closes the HDF5 file even if reading fails.
    with h5py.File(H5DATA_PATH, "r") as h5_file:
        # One bulk read instead of one HDF5 access per element.
        raw_ids = h5_file['uniq_origin_call'][:]

    return {raw_id: position for position, raw_id in enumerate(raw_ids)}


# Built once at import time; get_client_id() reads it.
client_ids = make_client_ids()
def get_client_id(n):
    """Return the index stored in ``client_ids`` for ``n``, or 0 if absent.

    Note that 0 is therefore returned both for unknown IDs and for any ID
    whose stored index is 0.
    """
    return client_ids.get(n, 0)
class CallType(Enum):
    """Call types (CENTRAL, STAND, STREET) and their one-letter data codes."""

    CENTRAL = 0
    STAND = 1
    STREET = 2

    @classmethod
    def from_data(cls, val):
        """Map the code 'A', 'B' or 'C' to its member; other values give None."""
        pairs = (
            ('A', cls.CENTRAL),
            ('B', cls.STAND),
            ('C', cls.STREET),
        )
        for code, member in pairs:
            if val == code:
                return member
        return None

    @classmethod
    def to_data(cls, val):
        """Map a member back to its code 'A', 'B' or 'C'; other values give None."""
        pairs = (
            (cls.CENTRAL, 'A'),
            (cls.STAND, 'B'),
            (cls.STREET, 'C'),
        )
        for member, code in pairs:
            if val == member:
                return code
        return None
class DayType(Enum):
    """Day types (NORMAL, HOLIDAY, HOLIDAY_EVE) and their one-letter data codes."""

    NORMAL = 0
    HOLIDAY = 1
    HOLIDAY_EVE = 2

    @classmethod
    def from_data(cls, val):
        """Map the code 'A', 'B' or 'C' to its member; other values give None."""
        pairs = (
            ('A', cls.NORMAL),
            ('B', cls.HOLIDAY),
            ('C', cls.HOLIDAY_EVE),
        )
        for code, member in pairs:
            if val == code:
                return member
        return None

    @classmethod
    def to_data(cls, val):
        """Map a member back to its code 'A', 'B' or 'C'; other values give None."""
        pairs = (
            (cls.NORMAL, 'A'),
            (cls.HOLIDAY, 'B'),
            (cls.HOLIDAY_EVE, 'C'),
        )
        for member, code in pairs:
            if val == member:
                return code
        return None
class TaxiData(Dataset):
    """Dataset that reads examples sequentially from one or more CSV files.

    Every CSV row is passed to each column constructor in turn and the
    results are returned as one tuple.  Files are read in the order given;
    when one runs out, the next is opened automatically.

    Args:
        pathes: A single path, or a list/tuple of paths, of CSV files.
        columns: Sequence of ``(source_name, constructor)`` pairs.  The names
            become ``provides_sources``; each constructor is called with the
            parsed CSV row.
        has_header: When true, the first line of every file is skipped.
    """

    example_iteration_scheme = None

    class State(object):
        """Read cursor: the open file, its index in ``pathes``, and its reader."""

        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, columns, has_header=False):
        # A lone path is wrapped; a tuple of paths is accepted as-is rather
        # than being mistaken for a single path.
        if not isinstance(pathes, (list, tuple)):
            pathes = [pathes]
        assert len(pathes) > 0, "TaxiData needs at least one CSV path"
        self.columns = columns
        # Assigned before the base-class initializer, as in the original
        # ordering (presumably the base class reads it -- confirm in fuel).
        self.provides_sources = tuple([name for name, _ in columns])
        self.pathes = pathes
        self.has_header = has_header
        super(TaxiData, self).__init__()

    def _start_file(self, state, index):
        """Point ``state`` at ``self.pathes[index]`` with a fresh reader.

        The header line is consumed when ``has_header`` is set, so that
        open(), reset() and the file-to-file transition all behave alike.
        """
        state.index = index
        state.file = open(self.pathes[index])
        state.reader = csv.reader(state.file)
        if self.has_header:
            # The default avoids StopIteration on an empty file, which simply
            # has no header to skip.
            next(state.reader, None)

    def open(self):
        """Return a new state positioned at the first row of the first file."""
        state = self.State()
        self._start_file(state, 0)
        return state

    def close(self, state):
        """Close the file currently held by ``state``."""
        state.file.close()

    def reset(self, state):
        """Rewind ``state`` to the first row of the first file and return it.

        The file is reopened and the header skipped again when
        ``has_header`` is set; otherwise the header row would be returned as
        data after a reset.
        """
        state.file.close()
        self._start_file(state, 0)
        return state

    def get_data(self, state, request=None):
        """Return the next example as a tuple, one value per column.

        Args:
            state: Cursor returned by open() or reset().
            request: Must be None; only sequential reading is supported.

        Raises:
            ValueError: If ``request`` is not None.
            StopIteration: When every file has been read to the end.
        """
        if request is not None:
            raise ValueError("TaxiData only supports request=None")

        while True:
            try:
                # next() works on both Python 2 and 3, unlike reader.next().
                line = next(state.reader)
            except (ValueError, StopIteration):
                # StopIteration: current file exhausted.  ValueError is
                # presumably the read-from-closed-file case after the last
                # file has already been finished.
                state.file.close()
                state.index += 1
                if state.index >= len(self.pathes):
                    raise StopIteration
                self._start_file(state, state.index)
            else:
                break

        # Constructor errors (e.g. a malformed field) propagate to the caller.
        return tuple([constructor(line) for _, constructor in self.columns])
# Column definitions: (source name, constructor) pairs.  Each constructor
# receives one parsed CSV row ``l`` (a list of strings) and returns the value
# for that source.  The three trajectory columns build real lists so the
# result is the same on Python 2 and Python 3 (a bare ``map`` is a lazy,
# single-use iterator on Python 3).
taxi_columns = [
    ("trip_id", lambda l: l[0]),
    ("call_type", lambda l: CallType.from_data(l[1])),
    # Empty or 'NA' fields map to 0, otherwise the ID goes through get_client_id.
    ("origin_call", lambda l: 0 if l[2] in ('', 'NA') else get_client_id(int(l[2]))),
    ("origin_stand", lambda l: 0 if l[3] in ('', 'NA') else int(l[3])),
    ("taxi_id", lambda l: int(l[4])),
    ("timestamp", lambda l: int(l[5])),
    ("day_type", lambda l: DayType.from_data(l[6])),
    # True when the field starts with 'T'.
    ("missing_data", lambda l: l[7][0] == 'T'),
    # Field 8 holds a literal list of points; element 0 of each point feeds
    # "longitude" and element 1 feeds "latitude".
    ("polyline", lambda l: [tuple(p) for p in ast.literal_eval(l[8])]),
    ("longitude", lambda l: [p[0] for p in ast.literal_eval(l[8])]),
    ("latitude", lambda l: [p[1] for p in ast.literal_eval(l[8])]),
]

# Same columns plus the three extra trailing fields read from rows that carry
# them (fields 9, 10 and 11).
taxi_columns_valid = taxi_columns + [
    ("destination_longitude", lambda l: float(l[9])),
    ("destination_latitude", lambda l: float(l[10])),
    ("time", lambda l: int(l[11])),
]
# CSV locations under DATA_PATH: 100 numbered training files (train-00 to
# train-99), a single validation file, and the test file.
train_files = [
    "%s/split/train-%02d.csv" % (DATA_PATH, i)
    for i in range(100)
]
valid_files = ["%s/split/valid.csv" % DATA_PATH]
test_file = "%s/test.csv" % DATA_PATH

# Dataset objects.  Validation rows use the extended column list; only the
# test file is declared as having a header line.
train_data = TaxiData(train_files, taxi_columns)
valid_data = TaxiData(valid_files, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)
def train_it():
    """Return a new ``DataIterator`` over a ``DataStream`` of ``train_data``."""
    stream = DataStream(train_data)
    return DataIterator(stream)
def test_it():
    """Return a new ``DataIterator`` over a ``DataStream`` of ``valid_data``.

    NOTE(review): despite the name, this iterates ``valid_data`` rather than
    ``test_data`` -- confirm that this is intended.
    """
    stream = DataStream(valid_data)
    return DataIterator(stream)