make_dataset_plot.py 1.04 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
from generate_dataset import *
from plots import show_fields

set, labels = generate_tictactoe()

index_schwarz = labels[:,1] == 1
schwarz_gewinnt = set[index_schwarz,:]
label_schwarz = labels[index_schwarz,:]
index_weiß = labels[:,0] == 1
weiß_gewinnt = set[index_weiß,:]
label_weiß = labels[index_weiß,:]
index_unentschieden = labels[:,2] == 1
unentschieden = set[index_unentschieden,:]
label_unentschieden = labels[index_unentschieden,:]

schwarz_gewinnt_auswahl = schwarz_gewinnt[:3,:]
label_schwarz_auswahl = label_schwarz[:3,:]
weiß_gewinnt_auswahl = weiß_gewinnt[:3,:]
label_weiß_auswahl = label_weiß[:3,:]
unentschieden_auswahl = unentschieden[:3,:]
label_unenschieden_auswahl = label_unentschieden[:3,:]

auswahl = np.concatenate((schwarz_gewinnt_auswahl, weiß_gewinnt_auswahl, unentschieden_auswahl))
label_auswahl = np.concatenate((label_schwarz_auswahl, label_weiß_auswahl, label_unenschieden_auswahl))

pi = np.random.permutation(9)
auswahl = auswahl[pi]
label_auswahl = label_auswahl[pi]

show_fields(auswahl, label_auswahl, 9)
print()