about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- data.py 34
1 file changed, 21 insertions, 13 deletions
diff --git a/data.py b/data.py
index 8015054..5ebdcde 100644
--- a/data.py
+++ b/data.py
@@ -60,16 +60,17 @@ class DayType(Enum):
return 'C'
class TaxiData(Dataset):
- provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
example_iteration_scheme=None
class State:
__slots__ = ('file', 'index', 'reader')
- def __init__(self, pathes, has_header=False):
+ def __init__(self, pathes, columns, has_header=False):
if not isinstance(pathes, list):
pathes=[pathes]
assert len(pathes)>0
+ self.columns=columns
+ self.provides_sources = tuple(map(lambda x: x[0], columns))
self.pathes=pathes
self.has_header=has_header
super(TaxiData, self).__init__()
@@ -113,20 +114,27 @@ class TaxiData(Dataset):
state.reader.next()
return self.get_data(state)
- line[1]=CallType.from_data(line[1]) # call_type
- line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
- line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
- line[4]=int(line[4]) # taxi_id
- line[5]=int(line[5]) # timestamp
- line[6]=DayType.from_data(line[6]) # day_type
- line[7]=line[7][0]=='T' # missing_data
- line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
- return tuple(line)
+ values = []
+ for idx, (_, constructor) in enumerate(self.columns):
+ values.append(constructor(line[idx]))
+ return tuple(values)
+
+taxi_columns = [
+ ("trip_id", lambda x: x),
+ ("call_type", CallType.from_data),
+ ("origin_call", lambda x: 0 if x == '' or x == 'NA' else client_ids[int(x)]),
+ ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
+ ("taxi_id", int),
+ ("timestamp", int),
+ ("day_type", DayType.from_data),
+ ("missing_data", lambda x: x[0] == 'T'),
+ ("polyline", lambda x: map(tuple, ast.literal_eval(x))),
+]
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
-train_data=TaxiData(train_files)
-valid_data=TaxiData(valid_files)
+train_data=TaxiData(train_files, taxi_columns)
+valid_data=TaxiData(valid_files, taxi_columns)
def train_it():
return DataIterator(DataStream(train_data))