summaryrefslogtreecommitdiff
path: root/model/lstm.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/lstm.py')
-rw-r--r--model/lstm.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/model/lstm.py b/model/lstm.py
index d928c88..10b090c 100644
--- a/model/lstm.py
+++ b/model/lstm.py
@@ -57,14 +57,16 @@ class Model():
states.append((init_cell, new_cells[-1, :, :]))
if 'xreg' in p and p['xreg'] is not None:
- n, s, w1, w2, w3 = p['xreg']
+ n, s, w1, w2, w3, w4 = p['xreg']
cost_x1 = w1 * ((new_hidden.mean(axis=2) - s)**2).mean()
cost_x2 = w2 * ((new_hidden.mean(axis=(0,1)) - s)**2).mean()
cost_x3 = -w3 * abs(new_hidden - s).mean()
+ cost_x4 = w4 * abs(new_hidden[:-1,:,:]-new_hidden[1:,:,:]).mean()
cost_x1.name = 'cost_x1_%d'%i
cost_x2.name = 'cost_x2_%d'%i
cost_x3.name = 'cost_x3_%d'%i
- costs_xreg += [cost_x1, cost_x2, cost_x3]
+ cost_x4.name = 'cost_x4_%d'%i
+ costs_xreg += [cost_x1, cost_x2, cost_x3, cost_x4]
dims.append(p['dim'])
hidden.append(new_hidden)