diff options
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 48 |
1 files changed, 37 insertions, 11 deletions
@@ -55,31 +55,57 @@ class DayType(Enum): class TaxiData(Dataset): provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline") - example_iteration_scheme=None - def __init__(self, path): - self.path=path + class State: + __slots__ = ('file', 'index', 'reader') + + def __init__(self, pathes, has_header=False): + if not isinstance(pathes, list): + pathes=[pathes] + assert len(pathes) + self.pathes=pathes + self.has_header=has_header super(TaxiData, self).__init__() def open(self): - file=open(self.path) - reader=csv.reader(file) - reader.next() # Skip header - return (file, reader) + state=self.State() + state.file=open(self.pathes[0]) + state.index=0 + state.reader=csv.reader(state.file) + if self.has_header: + state.reader.next() + return state def close(self, state): - state[0].close() + state.file.close() def reset(self, state): - state[0].seek(0) - state[1]=csv.reader(state[0]) + if state.index==0: + state.file.seek(0) + else: + state.index=0 + state.file.close() + state.file=open(self.pathes[0]) + state.reader=csv.reader(state[0]) return state def get_data(self, state, request=None): if request is not None: raise ValueError - line=state[1].next() + try: + line=state.reader.next() + except StopIteration: + state.file.close() + state.index+=1 + if state.index>=len(self.pathes): + raise + state.file=open(self.pathes[state.index]) + state.reader=csv.reader(state.file) + if self.has_header: + state.reader.next() + line=state.reader.next() + line[1]=CallType.from_data(line[1]) # call_type line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand |