update evaluate

Esse commit está contido em:
Charlie Hewitt
2018-01-03 16:49:29 +00:00
commit 8a1cd872c9
+18 -6
Ver Arquivo
@@ -161,18 +161,21 @@ def eval_from_file(path):
valence_p = []
arousal_t = []
arousal_p = []
ratings = [[], [], [], [], [], [], [], []]
with open(path, 'r') as csvfile:
reader = csv.reader(csvfile, delimiter=',')
for row in reader:
if row[0] == '' or row[0] == 'id':
continue
true_l.append(int(row[2]))
pred_l.append(int(row[5]))
probs = []
for i in range(6, 13):
probs.append(float(row[i]))
pred_r.append(probs)
if int(row[1]) != 1 and int(row[1]) != 3:
true_l.append(int(row[2]))
pred_l.append(int(row[5]))
probs = []
for i in range(6, 13):
probs.append(float(row[i]))
pred_r.append(probs)
if row[15] == ' 0.00.0' or row[15] == '':
valence_t.append(0)
@@ -185,6 +188,8 @@ def eval_from_file(path):
else:
arousal_t.append(float(row[16]))
arousal_p.append(float(row[4]))
ratings[int(row[2])].append(int(float(row[14])))
true_r = to_categorical(true_l, num_classes=8)
print('ACC'.ljust(20), ACC(true_l, pred_l))
@@ -193,12 +198,19 @@ def eval_from_file(path):
print('ALPHA'.ljust(20), ALPHA(true_l, pred_l))
print('AUCPR'.ljust(20), AUCPR(true_r, pred_r))
print('AUC'.ljust(20), AUC(true_r, pred_r))
print(confusion_matrix(true_l, pred_l))
print('')
print(''.ljust(20), 'VALENCE'.ljust(20), 'AROUSAL')
print('RMSE'.ljust(20), str(RMSE(valence_t, valence_p)).ljust(20), RMSE(arousal_t, arousal_p))
print('CORR'.ljust(20), str(CORR(valence_t, valence_p)).ljust(20), CORR(arousal_t, arousal_p))
print('SAGR'.ljust(20), str(SAGR(valence_t, valence_p)).ljust(20), SAGR(arousal_t, arousal_p))
print('CCC'.ljust(20), str(CCC(valence_t, valence_p)).ljust(20), CCC(arousal_t, arousal_p))
print('')
t = (0, 0)
for i, r in enumerate(ratings):
print(str(i).ljust(5), np.mean(r))
t = (t[0] + np.sum(r), t[1] + len(r))
print('Total', t[0] / t[1])
if __name__ == '__main__':