Training a Neural Network to Play Tic-Tac-Toe in C/C++

Materials available at: https://forejune.co/cuda/

Reviews

  • https://forejune.co/cuda/ai/mlp/mlp.html (Perceptron)

  • https://forejune.co/cuda/ai/minimax/minimax.html (Minimax)
  • Multi-Layer Perceptron (MLP) for Tic-Tac-Toe

    Many ways to do it

    
    	
    Softmax ~ Activation function
    Binary Inputs and Outputs

    
    Inputs:
        x0 = 1 	(bias)
        x1  x2  x3  x4 x5  x6  x7  x8  x9	  	Player X
        x10 x11 x12 x13 x14 x15 x16 x17 x18 		Player O
    Outputs:
        y0 y1  y2  y3  y4 y5  y6  y7  y8  	  	
    
    Example:
    
    
       
    Inputs:
    	1
    	1 0 0 0 1 0 0 0 0	X
    	0 1 1 0 0 0 0 0 0	O
    Outputs:
    	0 0 0 0 0 0 0 0 1
      
      // mlp.h
      #ifndef __MLP_H__
      #define __MLP_H__
      #include <vector>
      #include <array>
      
      using namespace std;
      
      // Network dimensions; they size the arrays in struct database and class MLP.
      const int n1 = 19;  //number of inputs and bias (x0 = bias, x1..x9 for player X, x10..x18 for player O)
      const int m1 = 27;  //number of hidden nodes and bias
      const int K =  9;   //number of outputs (one per board square)
      struct database {
        vector<array<double, n1>> inputs;
        vector<array<double, K>> labels;
      };
      
      // Multi-layer perceptron sized by n1 (inputs + bias), m1 (hidden nodes + bias)
      // and K (outputs). Only the declaration is shown here; the member functions
      // are defined elsewhere.
      class MLP
      {
      private:
          // NOTE(review): w1 and w have the same n1 x m1 shape; presumably the
          // input-to-hidden weights -- confirm which of the two the implementation uses.
          double w1[n1][m1];
          double w[n1][m1];
          double wo[m1][K];  //m1 x K; presumably hidden-to-output weights -- confirm
          double a[m1];  //linear sum of products of inputs and weights
          double h[m1];  //hidden nodes h[j] = g(a[j])
          double y[K];   //predicted output
          double z[K];   //linear sum of products of weights and hidden nodes
          double eta; //learning rate
      public:
          MLP(double learning_rate);
          double g(double x);   //activation function applied to the hidden sums (h[j] = g(a[j]))
          double gd(double x);  //Derivative of sigmoid function
          // Callers in these notes pass an array of n1 inputs and read K values
          // from the returned pointer.
          double* forward(double x[]);  //Forward propagation
          void backward(double yd[], double x[]);  //Backward prop, yd is desired output
          void train(const database &db, int epochs);  //Train the MLP on db for the given number of epochs
          ~MLP();
          // Weight persistence to/from the named file; the meaning of the int
          // return value is not visible from this declaration.
          int  saveWeights(char fname[]);
          int  readWeights(char fname[]);
      };
      #endif
      	
    Building a Database for Training
    Depth First Search (DFS)
    
    Go as deep as possible, Backtrack
    
    Recursion programming
    
    void buildDatabase(int board[9], Player currentPlayer, database &db) { if ( a_player_has_won( board ) || draw( board ) ) return; if (currentPlayer == ai) { int move = minimax( board ); if ( move >= 0 ) { double inputs[n1], labels[K] = {0}; board2inputs(board, inputs); labels[move] = 1; addDatabase(db, inputs, labels); } } // scan other empty squares for (int i = 0; i < 9; i++){ if (board[i] == EMPTY) { board[i] = currentPlayer; buildDatabase(board, otherPlayer, db); // DFS board[i] = EMPTY; // restore board (backtrack) } } }
    	
    C/C++ Implementation

    
    // ticmlp.cpp -- training an MLP to play tic-tac-toe
    // https://forejune.co/cuda/
    
    #include "minimax.h"
    #include "mlp.h"
    #include <iostream>
    
    using namespace std;
    
    // Appends one training sample to db: the first n1 values of in[] become
    // the input vector and the first K values of targets[] become its label.
    void addDatabase(database &db, double in[], double targets[])
    {
      array<double, n1> sample;		//n1 = 19
      for (int k = 0; k < n1; ++k)
        sample[k] = in[k];

      array<double, K> target;		//K = 9
      for (int k = 0; k < K; ++k)
        target[k] = targets[k];

      db.inputs.push_back( sample );
      db.labels.push_back( target );
    }
    
    // Encodes tic.board as network inputs: inputs[0] is the bias (1),
    // inputs[1..9] flag the squares held by X, inputs[10..18] those held by O.
    void board2inputs(MiniMaxT &tic, double inputs[])
    {
       inputs[0] = 1;
       for (int cell = 0; cell < 9; ++cell) {
         const bool hasX = (tic.board[cell] == X);
         const bool hasO = !hasX && (tic.board[cell] == O);
         inputs[cell + 1]  = hasX ? 1 : 0;
         inputs[cell + 10] = hasO ? 1 : 0;
       }
    }
    
    // Depth-first walk over the game tree starting from the position in tic.
    // Each time X is to move, the position and minimax's chosen move are added
    // to db as one sample. Every empty square is then tried for the current
    // player, explored recursively, and cleared again.
    void buildDatabase(MiniMaxT & tic, Player currentPlayer, database &db)
    {
      // stop at finished games: somebody has won, or the board is full
      const bool gameOver = tic.isWinning( X ) || tic.isWinning( O ) || !tic.emptyCell();
      if ( gameOver )
        return;

      if (currentPlayer == X) {
        const int best = tic.minimax(true, 0).move;
        if ( best >= 0 ) {
          double inputs[n1];
          double labels[K] = {0};
          board2inputs(tic, inputs);
          labels[best] = 1;		// one-hot target
          addDatabase(db, inputs, labels);
        }
      }

      const Player nextPlayer = (currentPlayer == X) ? O : X;
      for (int cell = 0; cell < 9; ++cell) {
        if (tic.board[cell] != EMPTY)
          continue;
        tic.board[cell] = currentPlayer;	// try this square
        buildDatabase(tic, nextPlayer, db);	// go deeper
        tic.board[cell] = EMPTY;		// undo (backtrack)
      }
    }
    
    // Returns the index of the largest of the K values in out[] (the first
    // one on ties), or -1 when no value is greater than -1.
    int out2move(double out[K])
    {
      double best = -1;
      int move = -1;	// an index: kept as int instead of being narrowed from double on return
      for (int i = 0; i < K; i++) {
        if (out[i] > best) {
          best = out[i];
          move = i;
        }
      }

      return move;
    }
    
    
    // Builds the training set with minimax, trains the MLP for 30 epochs,
    // saves the weights to ticWeights.txt, then plays a console game in which
    // the AI (X) moves first using the MLP and the human (O) types a cell 0-8.
    // The minimax move is printed in parentheses next to the AI's move.
    int main()
    {

       database db;
       Player ai = X;
       Player opp = O;
       Player turn;

       MiniMaxT tic(ai, opp);
       buildDatabase(tic, ai, db);

       cout << "Training the MLP ..." << endl;
       MLP mlp(0.05);
       int iterations =  30;		//number of epochs for training

       mlp.train(db, iterations);

       mlp.saveWeights( (char *) "ticWeights.txt" );

       double *outs;
       double ins[n1];

       tic.resetBoard();
       cout << "After training" << endl;

       turn = ai;
       while ( true ) {
        if ( tic.isWinning( ai ) ) {
            cout << "AI wins!\n";
            break;
        } else if ( tic.isWinning( opp ) ) {
            cout << "You win!\n";
            break;
        } else if ( !tic.emptyCell() ) {
            cout << "Draw!\n";
            break;
        }
        if (turn == ai) {
          const int move = tic.minimax(true, 0).move;	//move from minimax
          board2inputs(tic, ins);
          outs = mlp.forward( ins );
          const int neuralMove = out2move( outs );	//move from MLP

          // The network is free to point at an occupied cell (or return -1);
          // never write such a move to the board. Fall back to the minimax
          // move, and as a last resort to the first empty cell (one exists,
          // otherwise the draw check above would have ended the game).
          int played = neuralMove;
          if ( played < 0 || played > 8 || tic.board[played] != EMPTY ) {
            cout << "MLP chose illegal cell " << neuralMove
                 << "; using minimax move instead" << endl;
            played = move;
          }
          if ( played < 0 || played > 8 || tic.board[played] != EMPTY ) {
            for (int i = 0; i < 9; i++) {
              if (tic.board[i] == EMPTY) {
                played = i;
                break;
              }
            }
          }
          tic.board[played] = ai;
          cout << "AI plays " << played << "(" << move << ")" <<  endl;
          turn = opp;
        } else {            // opponent's turn
          int move;
          for ( ; ; ){
            cout << "Your move (0-8): ";
            if ( !(cin >> move) ) {
              if ( cin.eof() || cin.bad() ) {	// input is gone: stop instead of prompting forever
                cout << "\nNo more input; quitting.\n";
                return 0;
              }
              cin.clear();			// non-numeric text: reset the stream
              cin.ignore(1024, '\n');		// and drop the offending line
              cout << "Invalid move; try again: \n";
              continue;
            }
            if ( move < 0 || move > 8 || tic.board[move] != EMPTY) {
              cout << "Invalid move; try again: \n";
              continue;
            } else
              break;
          }
          tic.board[move] = opp;
          turn = ai;
        }  //else
        tic.printBoard();
      }  //while true

      return 0;
    }
      

    	
    Using a GUI

    /* https://forejune.co/cuda
     * testTicmlpg.cpp
     *
     */
    
    #include <GL/gl.h>
    #include <GL/glu.h>
    #include <GL/glut.h>
    #include <iostream>
    #include "mlp.h"
    
    using namespace std;
    
    // Contents of a board square; also used to identify the two players.
    enum Player {EMPTY=0, X, O};
    
    Player ai = X;		//mark played by the MLP
    Player opp = O;		//mark played by the human
    Player board[9];	//squares 0..8, row by row from the top-left
    bool oppMove = false;	//mouse clicks are ignored while this is false
    bool playing = false;	//set once 'p' has started a game; cleared by resetBoard()
    const float d = 1;	//drawing distance between lines = 2*d
    int	whoWon = -1;	//-1: game not finished, 0: AI won, 1: opponent won, 2: draw
    MLP mlp (0.05);		//network used for the AI's moves; weights are read in init()
    
    void resetBoard()
    {
      for (int i = 0; i < 9; i++)
        board[i] = EMPTY;
    
      oppMove = false;
      playing = false;
      whoWon = -1;
    }
    
    void displayMessage(float x, float y, void* font, const string& str) 
    {
        glRasterPos2f(x, y);
    
        // Loop through each character of the string and draw it
        for (char const& c : str) 
            glutBitmapCharacter(font, c);
    }
    
    void printBoard()
    {
       float di = -d/2.0;		//align image center at center of square
       // positions to display X or O
       float cx[9] = {-2*d, 0, 2*d, -2*d, 0, 2*d, -2*d, 0, 2*d};
       float cy[9] = {2*d, 2*d, 2*d, 0, 0, 0, -2*d, -2*d, -2*d};
       for (int i = 0; i < 9; i++ ) {
           if (board[i] != EMPTY) {
    	  if (board[i] == X) {
    	     glColor3f(1, 0, 0);	//red color
                 displayMessage(cx[i]+di, cy[i]+di, GLUT_BITMAP_TIMES_ROMAN_24, "X");
    	  }else {
    	     glColor3f(0, 1, 0);	//use green color
                 displayMessage(cx[i]+di, cy[i]+di, GLUT_BITMAP_TIMES_ROMAN_24, "O");
    	  }
           } 
       } // for
     
       if ( whoWon >= 0 ) {
         string str[] = {"AI has won!", "You have won!", "It was a draw!"};
         displayMessage(-3*d, -5*d, GLUT_BITMAP_TIMES_ROMAN_24, str[whoWon]);
       } 
       glFlush();
    }
    
    int window;	//id returned by glutCreateWindow(); used to destroy the window on Escape
    int screenWidth = 500, screenHeight = 500;	//initial window size in pixels; also used to map mouse clicks to world coordinates
    
    // One-time OpenGL setup (white background, 2D world spanning
    // [-10,10] x [-10,10], black drawing color) followed by loading the
    // trained weights from ticWeights.txt.
    void init(void)
    {
      //white background
      glClearColor(1.0f, 1.0f, 1.0f, 0.0f);
      glClear(GL_COLOR_BUFFER_BIT);

      //coordinate system
      glMatrixMode(GL_PROJECTION);
      glLoadIdentity();
      gluOrtho2D(-10.0, 10.0, -10.0, 10.0);

      glMatrixMode(GL_MODELVIEW);
      glLoadIdentity();

      //drawing state
      glPointSize( 3.0f );
      glColor3f(0.0f, 0.0f, 0.0f);
      glPixelStorei(GL_UNPACK_ALIGNMENT, 1);

      mlp.readWeights((char*) "ticWeights.txt");
    }
    
    void line (float x0, float y0, float x1, float y1)
    {
      glBegin(GL_LINES);
        glVertex2f(x0, y0);
        glVertex2f(x1, y1);
      glEnd();
    }
    // GLUT display callback: clears the window and draws the four black
    // grid lines of the board.
    void display(void)
    {
       glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
       glLineWidth( 3 );
       glColor3f(0, 0, 0);

       // {x0, y0, x1, y1} for each grid line
       const float grid[4][4] = {
         {-3*d,    d,  3*d,    d},	//upper horizontal line
         {-3*d,   -d,  3*d,   -d},	//lower horizontal line
         {  -d,  3*d,   -d, -3*d},	//left vertical line
         {   d,  3*d,    d, -3*d}	//right vertical line
       };
       for (int k = 0; k < 4; ++k)
         line(grid[k][0], grid[k][1], grid[k][2], grid[k][3]);

       glFlush();
    }
      
    // True when player holds all three squares of some row, column or diagonal.
    bool isWinning(Player player)
    {
        const int triples[8][3] = {
            {0,1,2},{3,4,5},{6,7,8},        //rows
            {0,3,6},{1,4,7},{2,5,8},        //columns
            {0,4,8},{2,4,6}                 //diagonals
        };

        for (const int (&t)[3] : triples) {
          const bool owned = board[t[0]] == player
                          && board[t[1]] == player
                          && board[t[2]] == player;
          if (owned)
            return true;
        }

        return false;
    }
    
    bool emptyCell()
    {
        //chekc if more empty space
        for (int i = 0; i < 9; i++)
           if (board[i] == EMPTY)
             return true;
    
        return false;   //no more empty space
    }
    
    
    // Encodes the global board as network inputs: inputs[0] is the bias (1),
    // inputs[1..9] flag the squares held by X, inputs[10..18] those held by O.
    void board2inputs(double inputs[])
    {
       inputs[0] = 1;
       for (int cell = 0; cell < 9; ++cell) {
            const bool hasX = (board[cell] == X);
            const bool hasO = !hasX && (board[cell] == O);
            inputs[cell + 1]  = hasX ? 1 : 0;
            inputs[cell + 10] = hasO ? 1 : 0;
        }
    }
    
    // Returns the index of the largest of the K values in out[] (the first
    // one on ties), or -1 when no value is greater than -1.
    int out2move(double out[K])
    {
      double best = -1;
      int move = -1;	// an index: kept as int instead of being narrowed from double on return
      for (int i = 0; i < K; i++) {
        if (out[i] > best) {
          best = out[i];
          move = i;
        }
      }

      return move;
    }
    
    // Game status: 0 = AI has won, 1 = opponent has won, 2 = draw (board
    // full), -1 = game still in progress.
    int checkStatus()
    {
        if ( isWinning( ai ) )
            return 0;
        if ( isWinning( opp ) )
            return 1;

        return emptyCell() ? -1 : 2;
    }
    
    void aiMove()
    {
      whoWon = checkStatus();
      if (whoWon > 0) return;	//game over
      double ins[n1];
      double *outs;
      board2inputs(ins);
      outs = mlp.forward( ins );
    
      int move = out2move( outs );
      
      board[move] = ai;
    }
    
    // Starts a game unless one is already running: the AI opens on square 0,
    // the board is drawn, and mouse input is enabled for the opponent.
    void play()
    {
      if ( !playing ) {
        board[0] = ai;		//ai goes first
        printBoard();
        oppMove = true;		//opponent may now click a square
      }
    }
    
    // GLUT keyboard callback: Escape closes the window and exits, 'p' starts
    // a game if none is running, 'r' resets the board and requests a redraw.
    // Other keys are ignored.
    void keyboard(unsigned char key, int x, int y)
    {
      if (key == 27) {			/* escape */
        glutDestroyWindow(window);
        exit(0);
      } else if (key == 'p') {		//play game
        if ( !playing ) {
          play();
          playing = true;
        }
      } else if (key == 'r') {		//reset board
        resetBoard();
        glutPostRedisplay();
      }
    }
    
    // Maps world coordinates to a square index 0..8 (row by row from the
    // top-left). The grid lines at x = -d, d and y = d, -d separate the
    // columns and rows; points outside the grid fall into the nearest
    // row/column.
    int getLocation(float wx, float wy)
    {
      int col;
      if (wx < -d)
        col = 0;
      else if (wx < d)
        col = 1;
      else
        col = 2;

      int row;
      if (wy > d)
        row = 0;
      else if (wy > -d)
        row = 1;
      else
        row = 2;

      return 3 * row + col;
    }
    
    // GLUT mouse callback: on a left-button press during the opponent's turn
    // of an unfinished game, places the opponent's mark on the clicked square
    // (if it is empty), lets the AI reply, updates whoWon and redraws.
    void mouse(int button, int state, int mx, int my)
    {
       if ( !oppMove )
         return;

       const bool leftPress = (button == GLUT_LEFT_BUTTON && state == GLUT_DOWN);
       if ( !leftPress || whoWon >= 0 )	//not a click, or game already done
         return;

       //window pixels (origin top-left) -> world coordinates (origin at center)
       const int flippedY = screenHeight - my;
       const float wx = (float) mx * 20.0/screenWidth - 10;
       const float wy = (float) flippedY * 20.0/screenHeight - 10;

       const int loc = getLocation(wx, wy);
       if (board[loc] != EMPTY)
         return;

       board[loc] = opp;
       oppMove = false;
       printBoard();
       aiMove();
       whoWon = checkStatus();
       printBoard();
       oppMove = true;
    }
    
    // Creates the GLUT window, runs init(), registers the display, keyboard
    // and mouse callbacks, and enters the GLUT main loop.
    int graphics(int argc, char** argv)
    {
       glutInit(&argc, argv);

       //window parameters
       glutInitWindowPosition(100, 100);
       glutInitWindowSize(screenWidth, screenHeight);
       glutInitDisplayMode(GLUT_SINGLE | GLUT_RGB | GLUT_DEPTH);
       window = glutCreateWindow(argv[0]);

       //event handlers
       glutMouseFunc( mouse );
       glutKeyboardFunc(keyboard);
       glutDisplayFunc(display);

       init();		//needs the window created above
       glutMainLoop();
       return 0;
    }
    
    
    // Entry point of the GUI program: passes the command-line arguments to
    // graphics(), which sets up the window and runs the event loop.
    int main(int argc, char** argv)
    {
      graphics(argc, argv);
    
      return 0;
    }
    

     May try other configurations, e.g. m1 = 20