Deep Q-Network (DQN)
- class rlforge.agents.semi_gradient.dqn.DQNAgent(learning_rate, discount, state_dim, num_actions, temperature=1, network_architecture=[2], target_network_update_steps=8, num_replay=0, experience_buffer_size=1024, mini_batch_size=8)
Deep Q-Network (DQN) agent implemented with raw function approximation.
This agent uses a hand-rolled multilayer perceptron (MLP) to estimate Q-values, an experience replay buffer, and a target network to stabilize training.
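A hand-rolled MLP Q-value estimator can be pictured with the minimal NumPy sketch below. It assumes ReLU hidden activations, a linear output layer with one unit per action, and He-style Gaussian initialization. `init_mlp` and `forward` are hypothetical helpers, not part of rlforge, and the library's actual activation and initialization may differ.

```python
import numpy as np

def init_mlp(state_dim, hidden_sizes, num_actions, seed=0):
    """Return (W, b) pairs for layers state_dim -> hidden_sizes... -> num_actions."""
    rng = np.random.default_rng(seed)
    sizes = [state_dim] + list(hidden_sizes) + [num_actions]
    layers = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        # Assumed He-style initialization; rlforge may initialize differently
        W = rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, fan_out))
        b = np.zeros(fan_out)
        layers.append((W, b))
    return layers

def forward(layers, states):
    """Map a (batch, state_dim) array to (batch, num_actions) Q-values."""
    x = states
    for i, (W, b) in enumerate(layers):
        x = x @ W + b
        if i < len(layers) - 1:  # ReLU on hidden layers only (assumption)
            x = np.maximum(x, 0.0)
    return x  # linear output: one Q-value per action

# network_architecture=[2] corresponds to one hidden layer of two units
layers = init_mlp(state_dim=4, hidden_sizes=[2], num_actions=3)
q = forward(layers, np.ones((5, 4)))
print(q.shape)  # (5, 3)
```

Keeping the output layer linear lets Q-values take any sign and magnitude, which a squashing activation would prevent.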
Parameters
- learning_rate : float
Step size for weight updates in the MLP.
- discount : float
Discount factor (gamma) for future rewards.
- state_dim : int
Dimension of the environment’s state space.
- num_actions : int
Number of discrete actions available in the environment.
- temperature : float, optional (default=1)
Temperature parameter for softmax action selection. Higher values push the policy toward uniform; lower values push it toward greedy.
- network_architecture : list of int, optional (default=[2])
Sizes of the MLP's hidden layers, in order. The default [2] is a single hidden layer with two units.
- target_network_update_steps : int, optional (default=8)
Number of training steps between copies of the main network's weights into the target network.
- num_replay : int, optional (default=0)
Number of replay updates per environment step.
- experience_buffer_size : int, optional (default=1024)
Maximum size of the replay buffer.
- mini_batch_size : int, optional (default=8)
Number of samples per replay update.
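The interplay of experience_buffer_size, mini_batch_size, and num_replay can be illustrated with the sketch below: a fixed-capacity FIFO buffer sampled uniformly without replacement. `ReplayBuffer` is a hypothetical stand-in for rlforge's ExperienceBuffer, and waiting until the buffer holds at least one mini-batch before replaying is an assumption, not documented behaviour.

```python
import random
from collections import deque

class ReplayBuffer:
    """Hypothetical fixed-capacity store of (state, action, reward, terminal, next_state)."""

    def __init__(self, size, mini_batch_size, seed=0):
        self.buffer = deque(maxlen=size)  # oldest transitions are evicted first
        self.mini_batch_size = mini_batch_size
        self.rng = random.Random(seed)

    def append(self, state, action, reward, terminal, next_state):
        self.buffer.append((state, action, reward, terminal, next_state))

    def sample(self):
        # Uniform sampling without replacement
        return self.rng.sample(list(self.buffer), self.mini_batch_size)

    def __len__(self):
        return len(self.buffer)

buf = ReplayBuffer(size=1024, mini_batch_size=8)
for t in range(2000):
    buf.append(t, 0, 0.0, False, t + 1)
print(len(buf))  # 1024: only the most recent 1024 transitions survive

num_replay = 2
batches = []
if len(buf) >= buf.mini_batch_size:  # assumed warm-up condition
    for _ in range(num_replay):
        batch = buf.sample()
        batches.append(batch)  # a gradient step on `batch` would go here
```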
Attributes
- main_network : MLP
The primary Q-value estimator.
- target_network : MLP
A periodically updated copy of the main network for stable targets.
- experience_buffer : ExperienceBuffer
Stores transitions for replay updates.
- elapsed_training_steps : int
Number of training steps since the last target-network update.
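A counter-driven target-network schedule of this kind can be sketched as follows. `TargetSync` is a hypothetical helper, weights are modelled as a dict of NumPy arrays, and resetting the counter to zero after each copy is an assumption about how elapsed_training_steps is used.

```python
import copy

import numpy as np

class TargetSync:
    """Copy main weights into the target every `update_steps` training steps."""

    def __init__(self, main_weights, update_steps):
        self.main_weights = main_weights
        self.target_weights = copy.deepcopy(main_weights)
        self.update_steps = update_steps
        self.elapsed_training_steps = 0

    def on_training_step(self):
        self.elapsed_training_steps += 1
        if self.elapsed_training_steps >= self.update_steps:
            # Deep copy so later in-place updates to main do not leak into target
            self.target_weights = copy.deepcopy(self.main_weights)
            self.elapsed_training_steps = 0  # assumed: counter restarts after a copy

sync = TargetSync({"W": np.zeros(2)}, update_steps=8)
for step in range(10):
    sync.main_weights["W"] += 1.0  # stand-in for one gradient update
    sync.on_training_step()
print(sync.target_weights["W"])      # [8. 8.]  (copied at step 8)
print(sync.elapsed_training_steps)   # 2
```

Between copies the target lags the main network, so bootstrapped targets stay fixed for several updates instead of chasing every weight change.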
- end(reward)
Handle the terminal transition at the end of an episode.
Parameters
- reward : float
Final reward received before termination.
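The reason end() is handled separately is that a terminal transition has no successor state to bootstrap from, so its target is the reward alone. The hypothetical helper below shows this for a single transition, assuming the standard DQN max over next-state Q-values; rlforge may use a different bootstrap target.

```python
def one_step_target(reward, terminal, next_q_values, discount):
    """One-step target; the bootstrap term is dropped when the episode ends."""
    if terminal:
        return reward
    # Assumed standard DQN bootstrap: greedy value of the next state
    return reward + discount * max(next_q_values)

print(one_step_target(1.0, True, [5.0, 2.0], 0.9))   # 1.0
print(one_step_target(1.0, False, [5.0, 2.0], 0.9))  # 5.5
```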
- get_td_error(experiences)
Compute temporal-difference (TD) error for a batch of experiences.
Parameters
- experiences : list of tuples
Each tuple contains (state, action, reward, terminal, next_state).
Returns
- td_error_mat : np.ndarray
Matrix of TD errors aligned with actions.
- cache : dict
Cached forward pass values for backpropagation.
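One plausible reading of "TD errors aligned with actions" is a (batch, num_actions) matrix that is zero everywhere except at each experience's taken action, so backpropagation only adjusts the Q-value of that action. The sketch below assumes that layout, a max over target-network Q-values for the bootstrap, and a 0/1 terminal flag; `td_error_matrix` is hypothetical and does not reproduce rlforge's cache.

```python
import numpy as np

def td_error_matrix(q_main, q_target_next, actions, rewards, terminals, discount):
    """
    q_main:        (batch, num_actions) Q(s, .) from the main network
    q_target_next: (batch, num_actions) Q(s', .) from the target network
    Returns a (batch, num_actions) matrix, nonzero only at each taken action.
    """
    rows = np.arange(len(actions))
    not_done = 1.0 - np.asarray(terminals, dtype=float)
    # Greedy target-network value of s', masked to zero for terminal transitions
    targets = rewards + discount * q_target_next.max(axis=1) * not_done
    deltas = targets - q_main[rows, actions]
    td_error_mat = np.zeros_like(q_main)
    td_error_mat[rows, actions] = deltas
    return td_error_mat

q_main = np.array([[1.0, 2.0], [0.5, 0.0]])
q_next = np.array([[3.0, 1.0], [4.0, 4.0]])
mat = td_error_matrix(q_main, q_next,
                      actions=np.array([1, 0]),
                      rewards=np.array([1.0, -1.0]),
                      terminals=np.array([False, True]),
                      discount=0.5)
# rows: [0.0, 0.5] and [-1.5, 0.0]
```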
- reset()
Reset the agent’s networks and replay buffer for a fresh run.
- select_action(q_values, temperature)
Select an action using softmax exploration.
Parameters
- q_values : np.ndarray
Q-values for the current state.
- temperature : float
Softmax temperature.
Returns
- action : int
Selected action index.
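Softmax exploration samples each action with probability proportional to exp(Q / temperature). The sketch below is a numerically stable version; `softmax_action` is a hypothetical helper that also returns the probabilities for illustration, whereas select_action is documented to return only the action index.

```python
import numpy as np

def softmax_action(q_values, temperature, rng):
    """Sample an action with probability proportional to exp(Q / temperature)."""
    prefs = np.asarray(q_values, dtype=float) / temperature
    prefs -= prefs.max()  # subtracting the max avoids overflow in exp
    probs = np.exp(prefs)
    probs /= probs.sum()
    action = int(rng.choice(len(probs), p=probs))
    return action, probs

rng = np.random.default_rng(0)
q = np.array([1.0, 2.0, 3.0])
a_hot, p_hot = softmax_action(q, temperature=10.0, rng=rng)   # close to uniform
a_cold, p_cold = softmax_action(q, temperature=0.1, rng=rng)  # close to greedy
```

Subtracting the maximum preference leaves the probabilities unchanged, because the same factor cancels in the normalization, but it keeps exp from overflowing when Q-values are large or the temperature is small.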