import gymnasium as gym
import numpy as np
from stable_baselines3 import DQN
from stable_baselines3.dqn import CnnPolicy
# Train a DQN agent on Atari Pong, save the model, then play one evaluation
# episode with the trained policy and print a summary of the result.
game = "ALE/Pong-v5"
env = gym.make(game, render_mode="rgb_array")
# NOTE: the game id contains "/", so the resulting path is "dqn_ALE/Pong-v5",
# i.e. it has a "dqn_ALE/" path component.
save_file = "dqn_" + game

# Show the action space and the name of each action for reference.
print(env.action_space)
print(env.unwrapped.get_action_meanings())

model = DQN(
    CnnPolicy,
    env,
    verbose=1,
    exploration_final_eps=0.01,
    exploration_fraction=0.1,
    gradient_steps=1,
    learning_rate=0.0001,
    buffer_size=10000,
)
# NOTE(review): env was already passed to the constructor above; this call
# looks redundant but is kept to preserve the original behavior.
model.set_env(env)
model.learn(total_timesteps=1000000, log_interval=10)
model.save(save_file)

# Evaluation: run a single episode, always taking the policy's greedy action.
observation, info = env.reset()
score = 0  # number of environment steps taken in this episode
rewards_sum = 0  # cumulative reward collected in this episode
while True:
    action, _states = model.predict(observation, deterministic=True)
    # step() must receive the action chosen by the policy, not the env itself.
    observation, reward, terminated, truncated, info = env.step(action)
    score += 1
    rewards_sum += reward
    if reward > 0:
        print("Win!!!!", reward)
    # The episode is over once the environment reports either flag.
    if terminated or truncated:
        print("Finished ", score)
        print("Reward sum=", rewards_sum)
        break
运行上述代码之前,需要先安装以下几个依赖包:
pip install gymnasium[atari]
pip install gymnasium[accept-rom-license]
pip install stable_baselines3
此外,运行时建议使用GPU,否则模型重新训练的过程会非常漫长。