diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-24 17:32:57 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-24 17:32:57 -0400 |
commit | 0527e6e696fa1832d599473099429295dea31650 (patch) | |
tree | 38d3e4946236e01746eb4216f83e00a3ba5ea14f /data.py | |
parent | bd2826df73554207c88c5918d86fd9707d9e3753 (diff) | |
download | taxi-0527e6e696fa1832d599473099429295dea31650.tar.gz taxi-0527e6e696fa1832d599473099429295dea31650.zip |
It kind of works (at least it does something now)
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 10 |
1 file changed, 6 insertions, 4 deletions
@@ -7,6 +7,8 @@ from fuel.iterator import DataIterator PREFIX="/data/lisatmp3/auvolat/taxikaggle" +client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))} + class CallType(Enum): CENTRAL = 0 STAND = 1 @@ -87,7 +89,7 @@ class TaxiData(Dataset): state.index=0 state.file.close() state.file=open(self.pathes[0]) - state.reader=csv.reader(state[0]) + state.reader=csv.reader(state.file) return state def get_data(self, state, request=None): @@ -95,7 +97,7 @@ class TaxiData(Dataset): raise ValueError try: line=state.reader.next() - except StopIteration: + except ValueError: state.file.close() state.index+=1 if state.index>=len(self.pathes): @@ -104,10 +106,10 @@ class TaxiData(Dataset): state.reader=csv.reader(state.file) if self.has_header: state.reader.next() - line=state.reader.next() + return get_data(self, state) line[1]=CallType.from_data(line[1]) # call_type - line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call + line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand line[4]=int(line[4]) # taxi_id line[5]=int(line[5]) # timestamp |