aboutsummaryrefslogblamecommitdiff
path: root/data.py
blob: 5e3b40984a6aba5876dec98662bcc60e31e11262 (plain) (tree)
1
2
3
4
5
6
7
8
               
             





                                      



                                              
 
                                                                                   
 















































                                                                                                                                     

                                 








                                                 


                                        






                                           

                           
                          

                           





                                           
                                           




                                            

                                    

                             







                                                     
                                       
 
                                                       
                                                                                             







                                                                                  

                                                                            

                                




                                               
                                               
import ast, csv
import socket
import fuel
from enum import Enum
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator

if socket.gethostname() == "adeb.laptop":
    DATA_PATH = "/Users/adeb/data/taxi"
else:
    PREFIX="/data/lisatmp3/auvolat/taxikaggle"

client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}

class CallType(Enum):
    CENTRAL = 0
    STAND = 1
    STREET = 2

    @classmethod
    def from_data(cls, val):
        if val=='A':
            return cls.CENTRAL
        elif val=='B':
            return cls.STAND
        elif val=='C':
            return cls.STREET

    @classmethod
    def to_data(cls, val):
        if val==cls.CENTRAL:
            return 'A'
        elif val==cls.STAND:
            return 'B'
        elif val==cls.STREET:
            return 'C'

class DayType(Enum):
    NORMAL = 0
    HOLIDAY = 1
    HOLIDAY_EVE = 2

    @classmethod
    def from_data(cls, val):
        if val=='A':
            return cls.NORMAL
        elif val=='B':
            return cls.HOLIDAY
        elif val=='C':
            return cls.HOLIDAY_EVE

    @classmethod
    def to_data(cls, val):
        if val==cls.NORMAL:
            return 'A'
        elif val==cls.HOLIDAY:
            return 'B'
        elif val==cls.HOLIDAY_EVE:
            return 'C'

class TaxiData(Dataset):
    provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
    example_iteration_scheme=None

    class State:
        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, has_header=False):
        if not isinstance(pathes, list):
            pathes=[pathes]
        assert len(pathes)
        self.pathes=pathes
        self.has_header=has_header
        super(TaxiData, self).__init__()

    def open(self):
        state=self.State()
        state.file=open(self.pathes[0])
        state.index=0
        state.reader=csv.reader(state.file)
        if self.has_header:
            state.reader.next()
        return state

    def close(self, state):
        state.file.close()

    def reset(self, state):
        if state.index==0:
            state.file.seek(0)
        else:
            state.index=0
            state.file.close()
            state.file=open(self.pathes[0])
        state.reader=csv.reader(state.file)
        return state

    def get_data(self, state, request=None):
        if request is not None:
            raise ValueError
        try:
            line=state.reader.next()
        except StopIteration:
            print state.index
            state.file.close()
            state.index+=1
            if state.index>=len(self.pathes):
                raise
            state.file=open(self.pathes[state.index])
            state.reader=csv.reader(state.file)
            if self.has_header:
                state.reader.next()
            return self.get_data(state)

        line[1]=CallType.from_data(line[1]) # call_type
        line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
        line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
        line[4]=int(line[4]) # taxi_id
        line[5]=int(line[5]) # timestamp
        line[6]=DayType.from_data(line[6]) # day_type
        line[7]=line[7][0]=='T' # missing_data
        line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
        return tuple(line)

train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
train_data=TaxiData(train_files)
valid_data=TaxiData(valid_files)

def train_it():
    return DataIterator(DataStream(train_data))

def test_it():
    return DataIterator(DataStream(valid_data))