generate_dataset.py 2.67 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
import numpy as np
from scipy.special import comb
import matplotlib.pyplot as plt
import math

# Finde alle Vektoren x \in {0,1}^n mit genau 5 Einträgen = 1 (weiß) 
# Gibt eine Matrix der Größe (N,n) zurück, deren Zeilen die gesuchten Vektoren sind
def find_combinations(n,k):
    # Rekursionsanfang
    if k==0:
        return np.zeros(n)
    if n==1 & k==1:
        return np.array([1])
    if n==1 & k==0:
        return np.array([0])

    # Anzahl der möglichen Kombinationen: k aus n auswählen, Matrix anlegen
    N = int(comb(n,k))
    X = np.zeros((N,n)) 

    # Setze den ersten Eintrag auf 1 (weiß) und rufe das Subproblem auf
    number_of_combinations_problem_1 = int(comb(n-1,k-1))
    X[0:number_of_combinations_problem_1,0] = 1
    X[0:number_of_combinations_problem_1,1:n] = find_combinations(n-1,k-1)

    if number_of_combinations_problem_1 == N:
        return X

    # Belasse den ersten Eintrag bei 0 (schwarz) und rufe das Subproblem auf
    X[number_of_combinations_problem_1:,1:n] = find_combinations(n-1,k)

    return X

# (weiß gewinnt, schwarz gewinnt, niemand gewinnt)
def winner_one_line(x1,x2,x3):
    if x1 != x2 or x2 != x3:
        return np.array([0,0,1]).T
    if x1 == 1:
        return np.array([1,0,0]).T
    return np.array([0,1,0]).T
    

def one_tictactoe_label(x):
    
    strikes = np.zeros((3, 8))

    # Alle Möglichkeiten zu gewinnen
    strikes[:,0] = winner_one_line(x[0], x[4], x[8]) # Diagonale 
    strikes[:,1] = winner_one_line(x[2], x[4], x[6]) # Antidiagonale
    strikes[:,2] = winner_one_line(x[0], x[1], x[2]) # Horizontal 1
    strikes[:,3] = winner_one_line(x[3], x[4], x[5]) # Horizontal 2
    strikes[:,4] = winner_one_line(x[6], x[7], x[8]) # Horizontal 3
    strikes[:,5] = winner_one_line(x[0], x[3], x[6]) # Vertikal 1
    strikes[:,6] = winner_one_line(x[1], x[4], x[7]) # Vertikal 2
    strikes[:,7] = winner_one_line(x[2], x[5], x[8]) # Vertikal 3

    # Eine Farbe gewinnt, falls sie mindestens einen Strike hat und die andere Farbe keine Strikes hat
    strikes_white = np.sum(strikes[0,:])
    strikes_black = np.sum(strikes[1,:])

    # Weiß gewinnt
    if strikes_black == 0 and strikes_white > 0:
        return np.array([1,0,0])
    # Schwarz gewinnt
    if strikes_white == 0 and strikes_black > 0:
        return np.array([0,1,0])
    
    return np.array([0,0,1])


def tictactoe_labels(X):
    N,n = X.shape
    labels = np.zeros((N,3))

    for i in range(N):
        labels[i,:] = one_tictactoe_label(X[i,:])
    
    return labels.astype(float)


def generate_tictactoe():
    n = 9
    k = 5
    N = int(comb(n,k))
    X = np.zeros((N,n))
    X = find_combinations(n,k).astype(float)
    labels = tictactoe_labels(X)
    return X, labels