diff options
Diffstat (limited to 'model/lstm.py')
-rw-r--r-- | model/lstm.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/model/lstm.py b/model/lstm.py index d928c88..10b090c 100644 --- a/model/lstm.py +++ b/model/lstm.py @@ -57,14 +57,16 @@ class Model(): states.append((init_cell, new_cells[-1, :, :])) if 'xreg' in p and p['xreg'] is not None: - n, s, w1, w2, w3 = p['xreg'] + n, s, w1, w2, w3, w4 = p['xreg'] cost_x1 = w1 * ((new_hidden.mean(axis=2) - s)**2).mean() cost_x2 = w2 * ((new_hidden.mean(axis=(0,1)) - s)**2).mean() cost_x3 = -w3 * abs(new_hidden - s).mean() + cost_x4 = w4 * abs(new_hidden[:-1,:,:]-new_hidden[1:,:,:]).mean() cost_x1.name = 'cost_x1_%d'%i cost_x2.name = 'cost_x2_%d'%i cost_x3.name = 'cost_x3_%d'%i - costs_xreg += [cost_x1, cost_x2, cost_x3] + cost_x4.name = 'cost_x4_%d'%i + costs_xreg += [cost_x1, cost_x2, cost_x3, cost_x4] dims.append(p['dim']) hidden.append(new_hidden) |