aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r-- data.py 48
1 files changed, 37 insertions, 11 deletions
diff --git a/data.py b/data.py
index 0493de7..d03e10e 100644
--- a/data.py
+++ b/data.py
@@ -55,31 +55,57 @@ class DayType(Enum):
class TaxiData(Dataset):
provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
-
example_iteration_scheme=None
- def __init__(self, path):
- self.path=path
+ class State:
+ __slots__ = ('file', 'index', 'reader')
+
+ def __init__(self, pathes, has_header=False):
+ if not isinstance(pathes, list):
+ pathes=[pathes]
+ assert len(pathes)
+ self.pathes=pathes
+ self.has_header=has_header
super(TaxiData, self).__init__()
def open(self):
- file=open(self.path)
- reader=csv.reader(file)
- reader.next() # Skip header
- return (file, reader)
+ state=self.State()
+ state.file=open(self.pathes[0])
+ state.index=0
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ return state
def close(self, state):
- state[0].close()
+ state.file.close()
def reset(self, state):
- state[0].seek(0)
- state[1]=csv.reader(state[0])
+ if state.index==0:
+ state.file.seek(0)
+ else:
+ state.index=0
+ state.file.close()
+ state.file=open(self.pathes[0])
+ state.reader=csv.reader(state.file)
return state
def get_data(self, state, request=None):
if request is not None:
raise ValueError
- line=state[1].next()
+ try:
+ line=state.reader.next()
+ except StopIteration:
+ state.file.close()
+ state.index+=1
+ if state.index>=len(self.pathes):
+ raise
+ state.file=open(self.pathes[state.index])
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ line=state.reader.next()
+
line[1]=CallType.from_data(line[1]) # call_type
line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand