plots.py 2.86 KB
Newer Older
1 2 3 4
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
import math
5
from generate_dataset import *
6 7 8 9 10 11 12 13

def imshow_zero_center(image, n):
    lim = tf.reduce_max(abs(image))
    plt.imshow(image, vmin=-lim, vmax=lim, cmap='seismic')
    plt.title("Hesse-Matrix für n = " +  str(n))
    plt.colorbar()
    plt.show()

Eva Lina Fesefeldt's avatar
Eva Lina Fesefeldt committed
14 15 16 17 18 19 20 21
def vecshow_zero_center(image, title):
    image = image.T
    lim = tf.reduce_max(abs(image))
    plt.imshow(image, vmin=-lim, vmax=lim, cmap='seismic')
    plt.title(title)
    plt.colorbar()
    plt.show()   

22 23 24 25 26
def show_eigenvalues_semilogy(A, n):
    plt.semilogy(tf.math.abs(A), '.b')
    plt.xlabel("Eigenwertindex")
    plt.ylabel("Betrag des Eigenwertes")
    plt.title("Eigenwerte der Hesse-Matrix für n = " + str(n))
27
    plt.show()
28 29 30 31 32 33

def show_eigenvalues_semilogx_complex_plane(A, n):
    plt.semilogx(tf.math.real(A), tf.math.imag(A), '.r', alpha=0.3)
    plt.xlabel("Realteil")
    plt.ylabel("Imaginärteil")
    plt.title("Eigenwerte der Hesse-Matrix für n = " + str(n))
34

35 36 37 38

# Visualisieren des Feldes
def show_fields(X, labels, n):
    #plt.rcParams.update({"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]})
Eva Lina Fesefeldt's avatar
Eva Lina Fesefeldt committed
39 40
    mpl.rcParams.update({'font.size': 8})
    plt.figure(figsize=(7, 7))
41 42 43 44 45 46 47 48 49 50 51 52
    ncols = 3
    nrows = math.ceil(n / ncols)
    for i in range(n):
        ax = plt.subplot(nrows, ncols, i + 1)
        plt.imshow(X[i].reshape(3, 3))
        plt.gray()
        if labels[i,0] == 1:
            title = 'Weiß gewinnt'
        if labels[i,1] == 1:
            title = 'Schwarz gewinnt'
        if labels[i,2] == 1:
            title = 'Unentschieden'
Eva Lina Fesefeldt's avatar
Eva Lina Fesefeldt committed
53
        ax.set_title(title)
54 55
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
Eva Lina Fesefeldt's avatar
Eva Lina Fesefeldt committed
56 57

    plt.show()
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87

def driver_show_fields():
    set, labels = generate_tictactoe()
    index_schwarz = labels[:,1] == 1
    schwarz_gewinnt = set[index_schwarz,:]
    label_schwarz = labels[index_schwarz,:]
    index_weiß = labels[:,0] == 1
    weiß_gewinnt = set[index_weiß,:]
    label_weiß = labels[index_weiß,:]
    index_unentschieden = labels[:,2] == 1
    unentschieden = set[index_unentschieden,:]
    label_unentschieden = labels[index_unentschieden,:]

    schwarz_gewinnt_auswahl = schwarz_gewinnt[:3,:]
    label_schwarz_auswahl = label_schwarz[:3,:]
    weiß_gewinnt_auswahl = weiß_gewinnt[:3,:]
    label_weiß_auswahl = label_weiß[:3,:]
    unentschieden_auswahl = unentschieden[:3,:]
    label_unenschieden_auswahl = label_unentschieden[:3,:]

    auswahl = np.concatenate((schwarz_gewinnt_auswahl, weiß_gewinnt_auswahl, unentschieden_auswahl))
    label_auswahl = np.concatenate((label_schwarz_auswahl, label_weiß_auswahl, label_unenschieden_auswahl))

    pi = np.random.permutation(9)
    auswahl = auswahl[pi]
    label_auswahl = label_auswahl[pi]

    show_fields(auswahl, label_auswahl, 9)
    

88

89 90
if __name__ == "__main__":
    driver_show_fields()