【Unity】ML-Agentsでセルフプレイの強化学習を行う【大玉押し合いバトル】

Unity
マイケル
マイケル
みなさんこんにちは!
マイケルです!
エレキベア
エレキベア
こんにちクマ〜〜〜
マイケル
マイケル
今日は引き続きML-Agentsを触っていきます!
前回は簡単なサンプル作成まで行いましたが、今回はセルフプレイを使った学習について見ていこうと思います。
エレキベア
エレキベア
セルフプレイって何クマ?
マイケル
マイケル
簡単に言うと、エージェント同士を競い合わせて学習させる手法のことだよ!
テニスやサッカー等、1対1やチーム同士で競う場面で使用することができるね。
エレキベア
エレキベア
エージェント同士で実際にゲームをプレイしながら学ぶのクマね
マイケル
マイケル
イメージとして、ポン風のゲームを学習させてみたのはこんな感じ!
03 pingpong↑セルフプレイで学習させてみた結果
エレキベア
エレキベア
おお〜〜〜
ちゃんといい感じにバトルしてるクマね
マイケル
マイケル
ちょっとガクガクしてるけどそれっぽいよね!
こちらは下記の書籍に載っていたサンプルを使わせていただいています。

Unity ML-Agents 実践ゲームプログラミング v1.1対応版

マイケル
マイケル
サンプルをそのまま載せても面白くないので、
今回はこれを少し改造した大玉転がしのセルフプレイ学習に挑戦してみようと思います!
エレキベア
エレキベア
大玉転がし楽しそうクマね
スポンサーリンク

大玉転がしのセルフプレイ学習

セルフプレイとは

マイケル
マイケル
まずはセルフプレイとは何かについて!
これは先ほど説明した通り、エージェント同士を競い合わせて学習させる手法のことです!
基本的な学習方法は前回と同じだけど、下記に注意して学習させる必要があります!


・最終的な報酬を-1、0、1に設定する。
・敵対するエージェントの単位でチームIDを振り分ける。

マイケル
マイケル
この最終的な報酬とチームIDから、ELOという指標を使って
学習の進行状態を調べることになります。
エレキベア
エレキベア
お互いに強くなっていくのを確認できるクマね

プロジェクト構成

マイケル
マイケル
オブジェクト構成は下記のように設定しました。
ボール(Ball)、エージェント(PlayerAgent)の他に、
ステージの端にゴールとなるエリア(ScoreArea)を追加しています。
ScreenShot 2022 07 14 0 04 49↑オブジェクト構成
マイケル
マイケル
大玉転がしのため、ルールは
ボールを押し合って相手のエリアに入れたら勝ちというだけです!
エレキベア
エレキベア
大玉転がしのルールってそれぞれのチームで転がして競争する競技じゃなかったクマ?

参考:大玉送り – Wikiedia

マイケル
マイケル
・・・・・・。
マイケル
マイケル
「大玉押し合いバトル」のセルフプレイ学習を行います!!
エレキベア
エレキベア
(無理矢理変えたクマ・・・)
マイケル
マイケル
よく考えたら大玉転がし人生で一度もやったことないよ!!
エレキベア
エレキベア
マジかよクマ・・・
エージェント学習スクリプトの作成
マイケル
マイケル
それでは気を取り直して・・・
メインとなるエージェントの学習スクリプトを見ていきます!
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

namespace RollerBigBall
{
    public class PlayerAgent : Agent
    {
        [SerializeField] private int agentId;
        [SerializeField] private GameObject rootGameObject;
        [SerializeField] private GameObject enemy;
        [SerializeField] private GameObject ball;
        private Rigidbody m_BallRigidbody;

        public override void Initialize()
        {
            m_BallRigidbody = ball.GetComponent<Rigidbody>();
        }

        public override void OnEpisodeBegin()
        {
            SetReward(0.0f);
        }

        public override void CollectObservations(VectorSensor sensor)
        {
            // 対戦相手と対面になっているため反転させる
            var dir = (agentId == 0) ? 1.0f : -1.0f;

            // 観察値
            // 自分の位置
            var playerLocalPosition = GetOffsetLocalPosition(transform.position);
            sensor.AddObservation(playerLocalPosition.x * dir);
            sensor.AddObservation(playerLocalPosition.y);
            sensor.AddObservation(playerLocalPosition.z * dir);
            // 敵の位置
            var enemyLocalPosition = GetOffsetLocalPosition(enemy.transform.position);
            sensor.AddObservation(enemyLocalPosition.x * dir);
            sensor.AddObservation(enemyLocalPosition.y);
            sensor.AddObservation(enemyLocalPosition.z * dir);
            // ボールの位置、速度
            var ballLocalPosition = GetOffsetLocalPosition(ball.transform.position);
            sensor.AddObservation(ballLocalPosition.x * dir);
            sensor.AddObservation(ballLocalPosition.y);
            sensor.AddObservation(ballLocalPosition.z * dir);
            sensor.AddObservation(m_BallRigidbody.velocity.x * dir);
            sensor.AddObservation(m_BallRigidbody.velocity.y);
            sensor.AddObservation(m_BallRigidbody.velocity.z * dir);
        }

        private Vector3 GetOffsetLocalPosition(Vector3 worldPosition)
        {
            // 複数エージェントで学習できるよう、ルートとなるゲームオブジェクトの座標分引く
            return worldPosition - rootGameObject.transform.position;
        }

        private enum AgentAction
        {
            None = 0,
            LeftMove,
            RightMove,
            ForwardMove,
            BackMove,
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            // 対戦相手と対面になっているため反転させる
            var dir = (agentId == 0) ? 1.0f : -1.0f;

            // 移動アクションを受け取ったら動かす
            var discreteActions = actions.DiscreteActions;
            var action = discreteActions[0];
            var pos = transform.localPosition;
            switch (action)
            {
                case (int) AgentAction.LeftMove:
                    pos.x -= 0.2f * dir;
                    break;
                case (int) AgentAction.RightMove:
                    pos.x += 0.2f * dir;
                    break;
                case (int) AgentAction.ForwardMove:
                    pos.z += 0.2f * dir;
                    break;
                case (int) AgentAction.BackMove:
                    pos.z -= 0.2f * dir;
                    break;
            }

            // 範囲外には行かないように制御
            if (pos.x < -4.0f) pos.x = -4.0f;
            if (pos.x > 4.0f) pos.x = 4.0f;
            if (pos.z > 7.0f) pos.z = 7.0f;
            if (pos.z < -7.0f) pos.z = -7.0f;

            // 位置を設定
            transform.localPosition = pos;
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            var discreteActions = actionsOut.DiscreteActions;
            discreteActions[0] = (int) AgentAction.None;
            if (Input.GetKey(KeyCode.LeftArrow)) discreteActions[0] = (int) AgentAction.LeftMove;
            if (Input.GetKey(KeyCode.RightArrow)) discreteActions[0] = (int) AgentAction.RightMove;
            if (Input.GetKey(KeyCode.UpArrow)) discreteActions[0] = (int) AgentAction.ForwardMove;
            if (Input.GetKey(KeyCode.DownArrow)) discreteActions[0] = (int) AgentAction.BackMove;
        }

        private void OnCollisionEnter(Collision collision)
        {
            // ボールに衝突したら報酬を与える
            if (collision.gameObject.CompareTag("Ball"))
            {
                AddReward(0.01f);
            }
        }
    }
}
マイケル
マイケル
スクリプト全体はこんな感じです!
観察値としては下記の12個の値を設定しています。
        public override void CollectObservations(VectorSensor sensor)
        {
            // 対戦相手と対面になっているため反転させる
            var dir = (agentId == 0) ? 1.0f : -1.0f;

            // 観察値
            // 自分の位置
            var playerLocalPosition = GetOffsetLocalPosition(transform.position);
            sensor.AddObservation(playerLocalPosition.x * dir);
            sensor.AddObservation(playerLocalPosition.y);
            sensor.AddObservation(playerLocalPosition.z * dir);
            // 敵の位置
            var enemyLocalPosition = GetOffsetLocalPosition(enemy.transform.position);
            sensor.AddObservation(enemyLocalPosition.x * dir);
            sensor.AddObservation(enemyLocalPosition.y);
            sensor.AddObservation(enemyLocalPosition.z * dir);
            // ボールの位置、速度
            var ballLocalPosition = GetOffsetLocalPosition(ball.transform.position);
            sensor.AddObservation(ballLocalPosition.x * dir);
            sensor.AddObservation(ballLocalPosition.y);
            sensor.AddObservation(ballLocalPosition.z * dir);
            sensor.AddObservation(m_BallRigidbody.velocity.x * dir);
            sensor.AddObservation(m_BallRigidbody.velocity.y);
            sensor.AddObservation(m_BallRigidbody.velocity.z * dir);
        }

        private Vector3 GetOffsetLocalPosition(Vector3 worldPosition)
        {
            // 複数エージェントで学習できるよう、ルートとなるゲームオブジェクトの座標分引く
            return worldPosition - rootGameObject.transform.position;
        }
↑観察値の設定
マイケル
マイケル
設定値自体は各オブジェクトの位置、速度になっていますが、
一点、注意点としては相手と対面しているため、片方のエージェントは向きを考慮して反転させなければならないことです。
エレキベア
エレキベア
なるほどクマ
反対側から見たら逆になるクマね
マイケル
マイケル
それとGetOffsetLocalPositionの部分については、複数の学習環境を並べて学習させたいために変換しています。
localPositionだと回転等考慮しないといけなかったので・・・。
エレキベア
エレキベア
学習の効率化クマね
マイケル
マイケル
行動については離散値を使用していて、
下記のように移動+何もしないの5種類を設定しています。
こちらも向きを反転させないといけない点には注意しましょう。
        private enum AgentAction
        {
            None = 0,
            LeftMove,
            RightMove,
            ForwardMove,
            BackMove,
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            // 対戦相手と対面になっているため反転させる
            var dir = (agentId == 0) ? 1.0f : -1.0f;

            // 移動アクションを受け取ったら動かす
            var discreteActions = actions.DiscreteActions;
            var action = discreteActions[0];
            var pos = transform.localPosition;
            switch (action)
            {
                case (int) AgentAction.LeftMove:
                    pos.x -= 0.2f * dir;
                    break;
                case (int) AgentAction.RightMove:
                    pos.x += 0.2f * dir;
                    break;
                case (int) AgentAction.ForwardMove:
                    pos.z += 0.2f * dir;
                    break;
                case (int) AgentAction.BackMove:
                    pos.z -= 0.2f * dir;
                    break;
            }

            // 範囲外には行かないように制御
            if (pos.x < -4.0f) pos.x = -4.0f;
            if (pos.x > 4.0f) pos.x = 4.0f;
            if (pos.z > 7.0f) pos.z = 7.0f;
            if (pos.z < -7.0f) pos.z = -7.0f;

            // 位置を設定
            transform.localPosition = pos;
        }
↑エージェントが取る行動
マイケル
マイケル
報酬については基本的に外側から設定しますが、
ボールに向かっていくようエージェント内でも少し報酬を与えています。
最終的に1、-1に近づけないといけないため、あまり大きな報酬を与えないようにしましょう。
        private void OnCollisionEnter(Collision collision)
        {
            // ボールに衝突したら報酬を与える
            if (collision.gameObject.CompareTag("Ball"))
            {
                AddReward(0.01f);
            }
        }
↑ボールにあたると報酬を与える
エレキベア
エレキベア
ボールを押すと褒めてあげるクマね
マイケル
マイケル
あとはこのスクリプトとBehavior Parameters、Decision Requesterコンポーネント
各エージェントにアタッチして、下記のように設定します。
ScreenShot 2022 07 14 0 02 25↑PlayerAgent0の設定
ScreenShot 2022 07 14 0 02 40↑PlayerAgent1の設定
マイケル
マイケル
AgentIdとTeamIdを振り分ける点には注意しましょう!
以上でエージェントの設定は完了です!
エレキベア
エレキベア
こんな感じでグループ分けするクマね
ゲーム管理スクリプトの作成
マイケル
マイケル
そして次はゲーム全体を管理するスクリプトを作成します。
こちらもこれまでとは異なる点ですが、ゲームが終了した時点で勝敗を判断し、
各エージェントに報酬を与えて
います。
using Unity.MLAgents;
using UnityEngine;
using Random = UnityEngine.Random;

namespace RollerBigBall
{
    public class GameManager : MonoBehaviour
    {
        // エージェント
        [SerializeField] private Agent[] agents;
        private Vector3[] m_AgentsInitPositions;

        // ボール
        [SerializeField] private GameObject ball;
        private Rigidbody m_BallRigidbody;
        private Vector3 m_BallInitPosition;
        private readonly float BallSpeed = 3.0f;

        private void Start()
        {
            m_AgentsInitPositions = new Vector3[agents.Length];
            m_AgentsInitPositions[0] = agents[0].transform.position;
            m_AgentsInitPositions[1] = agents[1].transform.position;
            m_BallRigidbody = ball.GetComponent<Rigidbody>();
            m_BallInitPosition = ball.transform.position;
            Reset();
        }

        private void Update()
        {
            // ボールが落ちたらリセット
            if (ball.transform.position.y <= -0.5f)
            {
                Reset();
            }
        }

        private void Reset()
        {
            // エージェント位置のリセット
            agents[0].gameObject.transform.position = m_AgentsInitPositions[0];
            agents[1].gameObject.transform.position = m_AgentsInitPositions[1];

            // ボールの位置のリセット
            ball.transform.position = m_BallInitPosition;

            // ランダムな方向へ打ち出す
            var force = new Vector3(
                Random.Range(-1.0f, 1.0f) * BallSpeed,
                0.0f,
                Random.Range(-1.0f, 1.0f) * BallSpeed);
            if (Random.value < 0.5f) force.z *= -1.0f;
            m_BallRigidbody.velocity = force;
        }

        public void EndGame(int winAgentId)
        {
            // 勝った方に報酬を与える
            if (winAgentId == 0)
            {
                agents[0].AddReward(1.0f);
                agents[1].AddReward(-1.0f);
            }
            else
            {
                agents[0].AddReward(-1.0f);
                agents[1].AddReward(1.0f);
            }
            // エピソード終了してリセット
            agents[0].EndEpisode();
            agents[1].EndEpisode();
            Reset();
        }
    }
}
↑ゲーム管理オブジェクト
エレキベア
エレキベア
エピソード終了処理も管理クラス側から呼び出すクマね
スコアエリアスクリプトの作成
マイケル
マイケル
そして最後に各ゴールエリアに下記のスクリプトをアタッチしましょう!
winAgentIdには、そのエリアにボールを入れた際に勝利するエージェントのIDを指定します。
using UnityEngine;

namespace RollerBigBall
{
    public class ScoreArea : MonoBehaviour
    {
        [SerializeField] private GameManager gameManager;
        [SerializeField] private int winAgentId;

        private void OnTriggerEnter(Collider other)
        {
            gameManager.EndGame(winAgentId);
        }
    }
}
エレキベア
エレキベア
ゴールに入れたらゲーム終了処理を呼び出すクマね
マイケル
マイケル
以上でプロジェクトの設定は完了です!

訓練ファイルの作成

マイケル
マイケル
次に学習させるための訓練ファイルを作成しましょう!
この時セルフプレイ学習特有のパラメータ(self_playの部分)を設定しなければなりません。
詳細は公式のドキュメントをご参照ください!

GitHub – Unity-Technologies
/
ml-agents

– Self-Play

behaviors:
  RollerBigBall:
    # トレーナー種別
    trainer_type: ppo
    
    # ハイパーパラメータ
    hyperparameters:
        
      # PPO、SAC共通
      batch_size: 1024   # 大きめにとる
      buffer_size: 10240 # 大きめにとる
      learning_rate: 3.0e-4
      learning_rate_schedule: constant # 常に相手が変わるため一定にする
      
      # PPO固有
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
    
    # ニューラルネットワーク
    network_settings:
      normalize: true
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    
    # 報酬シグナル
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    
    # 基本設定
    max_steps: 50000000
    time_horizon: 1000
    summary_freq: 10000
    
    # セルフプレイ
    self_play:
        save_steps: 50000                    # 何ステップ毎にポリシーを保存するか
        team_change: 100000                  # 何ステップ毎に学習チームを切り替えるか
        swap_steps: 50000                    # 何ステップ毎に対戦相手のポリシーを交換するか
        play_against_latest_model_ratio: 0.5 # 過去のポリシーと対戦しない確率
        window: 10                           # 対戦相手として保持するポリシー数(多様性)
        initial_elo: 1200.0                  # ELOの初期値
↑訓練ファイルのパラメータ設定
マイケル
マイケル
今回はこのような値に設定しました。
セルフプレイ学習は相手が変わっていくため、learning_rate_scheduleは上げていかずに常に一定(constant)にした方がよさそうです。
エレキベア
エレキベア
パラメータの内容も段々分かってきたクマね

強化学習実行

マイケル
マイケル
準備が整ったところで学習実行します!
作成したyamlファイルを config/sample 配下に置いて、下記コマンドを実行します。
# 学習実行
cd 【mlagentsフォルダ】
mlagents-learn ./config/sample/RollerBigBall_selfplay.yaml --run-id=firstRunRollerBigBall
01 big ball learn↑学習の様子
マイケル
マイケル
熱戦を繰り広げていますね
エレキベア
エレキベア
シュールクマ・・・
マイケル
マイケル
学習の進行具合をtensorboardで確認します。
# tensorboardの確認
cd 【mlagentsフォルダ】
tensorboard --logdir results/firstRunRollerBigBall --port 6006
ScreenShot 2022 07 13 23 32 10↑ELOの確認
マイケル
マイケル
Self-playのELOの値が上がっていくのを確認できれば、
キリのいいところで学習終了させましょう!
マイケル
マイケル
ELOの値が思わしくない場合は、teamIdを振り分け忘れている等の可能性があるため、
設定を一度見直してみるといいかもしれません!
(自分は最初やらかしました)
エレキベア
エレキベア
やっちまったのクマね・・・

作成したモデルでの推論

マイケル
マイケル
さてそれでは学習したAI同士で戦わせてみましょう!
エレキベア
エレキベア
どんな結果になったか楽しみクマ〜〜〜
02 big ball battle↑AI VS AI
マイケル
マイケル
どうでしょうか!
割といい感じにバトルしているように見えます!
エレキベア
エレキベア
少し挙動不審な感じもするクマが
お互い探り合ってる感じが面白いクマね
マイケル
マイケル
ちなみに僕も操作して戦ってみましたが、
割といい勝負でした・・・。
嬉しいような悲しいような・・・。
03 big ball vs michael↑マイケル(赤) vs AI(緑)
エレキベア
エレキベア
衝突判定が甘いが故に難しくなってるクマね
AI相手に必死なのが面白いクマ〜〜
マイケル
マイケル
余裕だろと思ったら中々勝てなくてびっくりしたよ・・・。

おわりに

マイケル
マイケル
それでは今回はここまで!
セルフプレイでの学習を行ってみたけどどうだったかな?
エレキベア
エレキベア
やっぱり生きてるみたいな動きしてくれると
見てて面白いクマね〜〜〜
マイケル
マイケル
強化学習はこういうところが面白いね!
今回参考にさせていただいた書籍にはこれ以外の学習方法もたくさん載っているから、興味を持った方は是非読んでみてください!

Unity ML-Agents 実践ゲームプログラミング v1.1対応版

マイケル
マイケル
それから公式のサンプルも幅広く用意されているので、
こちらも触るとより理解が深まると思います!
エレキベア
エレキベア
多関節の学習も気になるクマ〜〜〜
マイケル
マイケル
そんなこんなでML-Agentsの記事を2回続けて書いたけど、
今度はこの経験を踏まえてリバーシのAIも作ってみます!!
お楽しみに〜〜〜!!
エレキベア
エレキベア
やったるクマ〜〜〜〜

【Unity】ML-Agentsでセルフプレイの強化学習を行う【大玉押し合いバトル】〜完〜

コメント