diff options
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 10 |
1 files changed, 6 insertions, 4 deletions
@@ -7,6 +7,8 @@ from fuel.iterator import DataIterator PREFIX="/data/lisatmp3/auvolat/taxikaggle" +client_ids = {int(x): y+1 for y, x in enumerate(open(PREFIX+"/client_ids.txt"))} + class CallType(Enum): CENTRAL = 0 STAND = 1 @@ -87,7 +89,7 @@ class TaxiData(Dataset): state.index=0 state.file.close() state.file=open(self.pathes[0]) - state.reader=csv.reader(state[0]) + state.reader=csv.reader(state.file) return state def get_data(self, state, request=None): @@ -95,7 +97,7 @@ class TaxiData(Dataset): raise ValueError try: line=state.reader.next() - except StopIteration: + except ValueError: state.file.close() state.index+=1 if state.index>=len(self.pathes): @@ -104,10 +106,10 @@ class TaxiData(Dataset): state.reader=csv.reader(state.file) if self.has_header: state.reader.next() - line=state.reader.next() + return get_data(self, state) line[1]=CallType.from_data(line[1]) # call_type - line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call + line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand line[4]=int(line[4]) # taxi_id line[5]=int(line[5]) # timestamp |