aboutsummaryrefslogtreecommitdiff
path: root/data.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-04-24 16:37:49 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-04-24 16:37:49 -0400
commit0be3ebaa19f2cf8a630565434e785e5c24929a14 (patch)
treec461a2a5d534d9ac2c7383e34e68b7b4fac22de0 /data.py
parent5589a8af8967cfc73d3b6fda8f86acc0d08172b8 (diff)
downloadtaxi-0be3ebaa19f2cf8a630565434e785e5c24929a14.tar.gz
taxi-0be3ebaa19f2cf8a630565434e785e5c24929a14.zip
Make TaxiData accept multiple files
Diffstat (limited to 'data.py')
-rw-r--r--data.py48
1 files changed, 37 insertions, 11 deletions
diff --git a/data.py b/data.py
index 0493de7..d03e10e 100644
--- a/data.py
+++ b/data.py
@@ -55,31 +55,57 @@ class DayType(Enum):
class TaxiData(Dataset):
provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
-
example_iteration_scheme=None
- def __init__(self, path):
- self.path=path
+ class State:
+ __slots__ = ('file', 'index', 'reader')
+
+ def __init__(self, pathes, has_header=False):
+ if not isinstance(pathes, list):
+ pathes=[pathes]
+ assert len(pathes)
+ self.pathes=pathes
+ self.has_header=has_header
super(TaxiData, self).__init__()
def open(self):
- file=open(self.path)
- reader=csv.reader(file)
- reader.next() # Skip header
- return (file, reader)
+ state=self.State()
+ state.file=open(self.pathes[0])
+ state.index=0
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ return state
def close(self, state):
- state[0].close()
+ state.file.close()
def reset(self, state):
- state[0].seek(0)
- state[1]=csv.reader(state[0])
+ if state.index==0:
+ state.file.seek(0)
+ else:
+ state.index=0
+ state.file.close()
+ state.file=open(self.pathes[0])
+ state.reader=csv.reader(state[0])
return state
def get_data(self, state, request=None):
if request is not None:
raise ValueError
- line=state[1].next()
+ try:
+ line=state.reader.next()
+ except StopIteration:
+ state.file.close()
+ state.index+=1
+ if state.index>=len(self.pathes):
+ raise
+ state.file=open(self.pathes[state.index])
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ line=state.reader.next()
+
line[1]=CallType.from_data(line[1]) # call_type
line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand