diff options
-rw-r--r-- | data.py | 34 |
1 files changed, 21 insertions, 13 deletions
@@ -60,16 +60,17 @@ class DayType(Enum): return 'C' class TaxiData(Dataset): - provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline") example_iteration_scheme=None class State: __slots__ = ('file', 'index', 'reader') - def __init__(self, pathes, has_header=False): + def __init__(self, pathes, columns, has_header=False): if not isinstance(pathes, list): pathes=[pathes] assert len(pathes)>0 + self.columns=columns + self.provides_sources = tuple(map(lambda x: x[0], columns)) self.pathes=pathes self.has_header=has_header super(TaxiData, self).__init__() @@ -113,20 +114,27 @@ class TaxiData(Dataset): state.reader.next() return self.get_data(state) - line[1]=CallType.from_data(line[1]) # call_type - line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call - line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand - line[4]=int(line[4]) # taxi_id - line[5]=int(line[5]) # timestamp - line[6]=DayType.from_data(line[6]) # day_type - line[7]=line[7][0]=='T' # missing_data - line[8]=map(tuple, ast.literal_eval(line[8])) # polyline - return tuple(line) + values = [] + for idx, (_, constructor) in enumerate(self.columns): + values.append(constructor(line[idx])) + return tuple(values) + +taxi_columns = [ + ("trip_id", lambda x: x), + ("call_type", CallType.from_data), + ("origin_call", lambda x: 0 if x == '' or x == 'NA' else client_ids[int(x)]), + ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)), + ("taxi_id", int), + ("timestamp", int), + ("day_type", DayType.from_data), + ("missing_data", lambda x: x[0] == 'T'), + ("polyline", lambda x: map(tuple, ast.literal_eval(x))), +] train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] valid_files=["%s/split/valid.csv" % (DATA_PATH,)] -train_data=TaxiData(train_files) -valid_data=TaxiData(valid_files) +train_data=TaxiData(train_files, taxi_columns) +valid_data=TaxiData(valid_files, taxi_columns) def train_it(): return DataIterator(DataStream(train_data)) |