aboutsummaryrefslogblamecommitdiff
path: root/data.py
blob: 92efefc07857b9d3bd139094ecee33bf4c81c054 (plain) (tree)
1
2
3
4
5
6
7
8
9
               
             
           
            



                                      
             
 


                                         
                                                 
 
                                                     
 

                                                                                        
 

                       



                      

















                                                     














































                                  

                                 


                                               
                                                          

                                        
                            

                                                                   

                                  


                                        






                                           

                           
                          

                           





                                           
                                           




                                            

                                    

                                           


                                             
                                   



                                                     
                                       
 
                   

                                            


                            










                                                                                             
 
 
                                     


                                                     

 

                                                                            
                                      
 
                                              
                                                      
                                                              




                                               
                                               
import ast, csv
import socket
import fuel
import numpy
from enum import Enum
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator
import theano

# Directory holding the CSV files. The host named "adeb.laptop" reads a local
# copy; every other host reads the second location.
DATA_PATH = (
    "/Users/adeb/data/taxi"
    if socket.gethostname() == "adeb.laptop"
    else "/data/lisatmp3/auvolat/taxikaggle"
)

# HDF5 file read by make_client_ids(); the same path is used on every host.
H5DATA_PATH = '/data/lisatmp3/simonet/taxi/data.hdf5'

# Two-component reference point, stored in Theano's configured float type.
# NOTE(review): the values look like (latitude, longitude) -- confirm.
porto_center = numpy.array(
    [41.1573, -8.61612],
    dtype=theano.config.floatX,
)
# Element-wise square root of two constants.
# NOTE(review): presumably per-coordinate variances of the data -- confirm.
data_std = numpy.sqrt(
    numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX)
)

# Fixed dataset statistics.
# NOTE(review): presumably counts measured on the full data set -- confirm.
n_clients = 57124
n_train_clients = 57105
n_stands = 63

dataset_size = 1710670

# ---- Read client IDs and create reverse dictionary

def make_client_ids():
    """Build the reverse lookup from stored client values to their positions.

    Reads the ``uniq_origin_call`` dataset from the HDF5 file at
    ``H5DATA_PATH`` and maps every stored value to its index in that dataset.

    Returns:
        dict: ``{value: index}`` with one entry per element of
        ``uniq_origin_call``.
    """
    # h5py is used here but is not imported at module level, so the original
    # call failed with a NameError; import it where it is needed.
    import h5py

    # The context manager closes the HDF5 file as soon as the values are in
    # memory (the original left the handle open).
    with h5py.File(H5DATA_PATH, "r") as f:
        # One bulk read instead of one HDF5 access per element.
        values = f['uniq_origin_call'][:]
    return {value: index for index, value in enumerate(values)}

# Reverse lookup table, built once when the module is imported.
client_ids = make_client_ids()


def get_client_id(n):
    """Return the index recorded for ``n`` in ``client_ids``, or 0 if absent."""
    return client_ids.get(n, 0)

class CallType(Enum):
    """Call type of a trip, stored in the data as a one-letter code."""
    CENTRAL = 0
    STAND = 1
    STREET = 2

    @classmethod
    def from_data(cls, val):
        """Return the member encoded by ``val`` ('A', 'B' or 'C').

        Any other value yields ``None``.
        """
        code_to_member = (
            ('A', cls.CENTRAL),
            ('B', cls.STAND),
            ('C', cls.STREET),
        )
        for code, member in code_to_member:
            if val == code:
                return member
        return None

    @classmethod
    def to_data(cls, val):
        """Return the one-letter code for member ``val``.

        Anything that is not a member of this enum yields ``None``.
        """
        member_to_code = (
            (cls.CENTRAL, 'A'),
            (cls.STAND, 'B'),
            (cls.STREET, 'C'),
        )
        for member, code in member_to_code:
            if val == member:
                return code
        return None

class DayType(Enum):
    """Day type of a trip, stored in the data as a one-letter code."""
    NORMAL = 0
    HOLIDAY = 1
    HOLIDAY_EVE = 2

    @classmethod
    def from_data(cls, val):
        """Return the member encoded by ``val`` ('A', 'B' or 'C').

        Any other value yields ``None``.
        """
        code_to_member = (
            ('A', cls.NORMAL),
            ('B', cls.HOLIDAY),
            ('C', cls.HOLIDAY_EVE),
        )
        for code, member in code_to_member:
            if val == code:
                return member
        return None

    @classmethod
    def to_data(cls, val):
        """Return the one-letter code for member ``val``.

        Anything that is not a member of this enum yields ``None``.
        """
        member_to_code = (
            (cls.NORMAL, 'A'),
            (cls.HOLIDAY, 'B'),
            (cls.HOLIDAY_EVE, 'C'),
        )
        for member, code in member_to_code:
            if val == member:
                return code
        return None

class TaxiData(Dataset):
    """Fuel dataset that reads examples row by row from a sequence of CSV files.

    The files listed in ``pathes`` are read one after the other. Every CSV
    row is turned into a tuple holding one value per entry of ``columns``.

    Args:
        pathes: a single path or a list of paths to CSV files (must not be
            empty).
        columns: list of ``(source_name, constructor)`` pairs; each
            ``constructor`` is called with the raw CSV row (a list of
            strings) and returns the value for that source.
        has_header: when true, the first row of every file is skipped.
    """
    example_iteration_scheme=None

    class State:
        # Per-iteration state: the open file object, the index of the current
        # path in ``pathes`` and the csv reader wrapping the file.
        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, columns, has_header=False):
        if not isinstance(pathes, list):
            pathes = [pathes]
        assert len(pathes) > 0
        self.columns = columns
        # Source names are the first element of every column pair.
        self.provides_sources = tuple(name for name, _ in columns)
        self.pathes = pathes
        self.has_header = has_header
        super(TaxiData, self).__init__()

    def _make_reader(self, csv_file):
        """Return a csv reader over ``csv_file``, past the header row if any."""
        reader = csv.reader(csv_file)
        if self.has_header:
            # The None default keeps an empty file from raising here; it just
            # yields no rows afterwards.
            next(reader, None)
        return reader

    def open(self):
        """Open the first file and return a fresh iteration state."""
        state = self.State()
        state.file = open(self.pathes[0])
        state.index = 0
        state.reader = self._make_reader(state.file)
        return state

    def close(self, state):
        """Close the file currently held by ``state``."""
        state.file.close()

    def reset(self, state):
        """Rewind ``state`` to the first row of the first file and return it."""
        if state.index == 0:
            # Still on the first file: rewinding is enough.
            state.file.seek(0)
        else:
            state.index = 0
            state.file.close()
            state.file = open(self.pathes[0])
        # Goes through the same helper as open() so the header row is skipped
        # again; the original rebuilt a bare reader here and would have
        # returned the header as a data row after a reset.
        state.reader = self._make_reader(state.file)
        return state

    def get_data(self, state, request=None):
        """Return the next example as a tuple with one value per column.

        Args:
            state: the state returned by ``open()`` / ``reset()``.
            request: must be ``None``; requests are not supported.

        Raises:
            ValueError: if ``request`` is not ``None``.
            StopIteration: once every file has been read to the end.
        """
        if request is not None:
            raise ValueError

        # Loop (rather than recurse) over files until one yields a row.
        while True:
            try:
                # Builtin next() works on both Python 2 and Python 3 readers.
                line = next(state.reader)
                break
            except (ValueError, StopIteration):
                # Current file is finished: move on to the next path.
                # NOTE(review): ValueError presumably covers reading from a
                # file already closed after the last path -- confirm.
                state.file.close()
                state.index += 1
                if state.index >= len(self.pathes):
                    raise StopIteration
                state.file = open(self.pathes[state.index])
                state.reader = self._make_reader(state.file)

        # A list comprehension (not a generator) so that an exception raised
        # by a constructor propagates unchanged.
        return tuple([constructor(line) for _, constructor in self.columns])

# (source_name, constructor) pairs; each constructor receives one raw CSV row
# (a list of strings, indexed by field position) and returns the value for
# that source.
taxi_columns = [
    ("trip_id", lambda l: l[0]),
    ("call_type", lambda l: CallType.from_data(l[1])),
    # Empty or 'NA' fields map to 0; otherwise the raw value is translated
    # through get_client_id.
    ("origin_call", lambda l: 0 if l[2] in ('', 'NA') else get_client_id(int(l[2]))),
    ("origin_stand", lambda l: 0 if l[3] in ('', 'NA') else int(l[3])),
    ("taxi_id", lambda l: int(l[4])),
    ("timestamp", lambda l: int(l[5])),
    ("day_type", lambda l: DayType.from_data(l[6])),
    # True when the field starts with 'T'; an empty field counts as False
    # instead of raising IndexError.
    ("missing_data", lambda l: l[7].startswith('T')),
    # Field 8 is a Python-literal list of point pairs. List comprehensions
    # return real lists on both Python 2 and Python 3, whereas map() yields a
    # one-shot lazy iterator on Python 3.
    ("polyline", lambda l: [tuple(p) for p in ast.literal_eval(l[8])]),
    ("longitude", lambda l: [p[0] for p in ast.literal_eval(l[8])]),
    ("latitude", lambda l: [p[1] for p in ast.literal_eval(l[8])]),
]

# Column set used for the validation file: every entry of taxi_columns plus
# three extra values read from CSV fields 9, 10 and 11.
taxi_columns_valid = list(taxi_columns)
taxi_columns_valid += [
    ("destination_longitude", lambda l: float(l[9])),
    ("destination_latitude", lambda l: float(l[10])),
    ("time", lambda l: int(l[11])),
]

# Training data is split over 100 files, train-00.csv through train-99.csv.
train_files = [
    "%s/split/train-%02d.csv" % (DATA_PATH, i)
    for i in range(100)
]
valid_files = ["%s/split/valid.csv" % DATA_PATH]
test_file = "%s/test.csv" % DATA_PATH

# Only the test file is read with has_header=True; the validation file uses
# the extended column set.
train_data = TaxiData(train_files, taxi_columns)
valid_data = TaxiData(valid_files, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)

def train_it():
    """Return a new Fuel ``DataIterator`` over ``train_data``."""
    stream = DataStream(train_data)
    return DataIterator(stream)

def test_it():
    """Return a new Fuel ``DataIterator`` over ``valid_data``.

    NOTE(review): despite the name, this iterates the validation set, not
    ``test_data`` -- confirm that this is intended.
    """
    stream = DataStream(valid_data)
    return DataIterator(stream)