Deep Q-Network (DQN)

class rlforge.agents.semi_gradient.dqn.DQNAgent(learning_rate, discount, state_dim, num_actions, temperature=1, network_architecture=[2], target_network_update_steps=8, num_replay=0, experience_buffer_size=1024, mini_batch_size=8)

Deep Q-Network (DQN) agent whose function approximation is implemented from scratch.

This agent uses a hand-rolled multilayer perceptron (MLP) for Q-value estimation, an experience replay buffer, and a target network for stabilizing training.
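A minimal NumPy sketch of how a list such as network_architecture could map to layer shapes and a forward pass that produces one Q-value per action. The function names (init_mlp, forward), the ReLU hidden activation, and the He-style initialisation are assumptions for illustration; the library's MLP class may be built differently.

```python
import numpy as np

def init_mlp(state_dim, num_actions, hidden_sizes, seed=0):
    """Return a list of (W, b) pairs for layers state_dim -> *hidden_sizes -> num_actions."""
    rng = np.random.default_rng(seed)
    sizes = [state_dim, *hidden_sizes, num_actions]
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        # He-style Gaussian initialisation (an assumption, not necessarily rlforge's choice).
        W = rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, fan_out))
        b = np.zeros(fan_out)
        params.append((W, b))
    return params

def forward(params, states):
    """Map states of shape (batch, state_dim) to Q-values of shape (batch, num_actions)."""
    x = np.atleast_2d(states)
    for W, b in params[:-1]:
        x = np.maximum(0.0, x @ W + b)  # ReLU hidden layers (assumed activation)
    W_out, b_out = params[-1]
    return x @ W_out + b_out  # linear output layer: one Q-value per action
```

With state_dim=4, num_actions=2 and the default network_architecture=[2], this sketch produces weight matrices of shape (4, 2) and (2, 2).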

Parameters

learning_rate : float

Step size for weight updates in the MLP.

discount : float

Discount factor (gamma) for future rewards.

state_dim : int

Dimension of the environment’s state space.

num_actions : int

Number of discrete actions available in the environment.

temperature : float, optional (default=1)

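Temperature parameter for softmax action selection. Higher values make the action distribution more uniform (more exploration); lower values concentrate it on the action with the highest Q-value.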

network_architecture : list of int, optional (default=[2])

Sizes of the MLP's hidden layers, one entry per layer (the default is a single hidden layer of 2 units).

target_network_update_steps : int, optional (default=8)

Number of training steps between copies of the main network's weights into the target network.

num_replay : int, optional (default=0)

Number of replay updates per environment step.

experience_buffer_size : int, optional (default=1024)

Maximum number of transitions stored in the replay buffer.

mini_batch_size : int, optional (default=8)

Number of transitions sampled for each replay update.
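A sketch of a fixed-capacity replay buffer in the spirit of experience_buffer_size, mini_batch_size and num_replay. It uses collections.deque with maxlen so the oldest transitions are evicted once the buffer is full, and samples mini-batches uniformly without replacement. The class and method names are hypothetical and this is not the library's ExperienceBuffer; in particular, waiting until a full mini-batch is stored before replaying is an assumed policy.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, terminal, next_state) tuples."""

    def __init__(self, capacity, mini_batch_size, seed=None):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are dropped first
        self.mini_batch_size = mini_batch_size
        self.rng = random.Random(seed)

    def append(self, state, action, reward, terminal, next_state):
        self.buffer.append((state, action, reward, terminal, next_state))

    def sample(self):
        """Uniformly sample one mini-batch without replacement."""
        return self.rng.sample(list(self.buffer), self.mini_batch_size)

    def replay_batches(self, num_replay):
        """Draw num_replay mini-batches, or none until a full batch is stored (assumed policy)."""
        if len(self.buffer) < self.mini_batch_size:
            return []
        return [self.sample() for _ in range(num_replay)]

    def __len__(self):
        return len(self.buffer)
```

Each mini-batch returned by replay_batches would then feed one TD-error computation and one weight update of the main network.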

Attributes

main_network : MLP

The primary Q-value estimator.

target_network : MLP

A periodically updated copy of the main network for stable targets.

experience_buffer : ExperienceBuffer

Stores transitions for replay updates.

elapsed_training_steps : int

Number of training steps since the last target network update.
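The interplay between target_network, elapsed_training_steps and target_network_update_steps can be sketched as a counter that triggers a hard copy of the weights. Plain dictionaries stand in for network weights, and the class name TargetSync is hypothetical; exactly which event the library counts as a "training step" is not specified here.

```python
import copy

class TargetSync:
    """Hard-copy the main network's weights into the target every `update_steps` training steps."""

    def __init__(self, main_params, update_steps):
        self.main_params = main_params                   # live reference to the main weights
        self.target_params = copy.deepcopy(main_params)  # frozen snapshot used for TD targets
        self.update_steps = update_steps
        self.elapsed_training_steps = 0

    def after_training_step(self):
        """Call once per training step; returns True when the target was refreshed."""
        self.elapsed_training_steps += 1
        if self.elapsed_training_steps >= self.update_steps:
            self.target_params = copy.deepcopy(self.main_params)
            self.elapsed_training_steps = 0
            return True
        return False
```

Between copies the target stays fixed, so the bootstrapped targets do not shift with every gradient step on the main network.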

end(reward)

Handle the terminal transition at the end of an episode.

Parameters

reward : float

Final reward received before termination.

get_td_error(experiences)

Compute temporal-difference (TD) error for a batch of experiences.

Parameters

experiences : list of tuples

Each tuple contains (state, action, reward, terminal, next_state).

Returns

td_error_mat : np.ndarray

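Matrix of TD errors with one row per experience and one column per action; each experience's error is placed in the column of the action it took.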

cache : dict

Intermediate values from the forward pass, cached for use in backpropagation.
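For a transition (state, action, reward, terminal, next_state), the standard DQN target is reward + discount * max over a' of Q_target(next_state, a'), with no bootstrap term when terminal is true. The sketch below computes this for a batch and writes each error into the column of the action taken, leaving other columns at zero, which is one plausible reading of "aligned with actions". It takes precomputed Q-value arrays instead of experience tuples, and its name and signature are hypothetical. Because the agent acts with a softmax policy, the library may instead use the expected target value under that policy (Expected Sarsa); the documentation above does not say which.

```python
import numpy as np

def td_error_matrix(q_main, q_target_next, actions, rewards, terminals, discount):
    """
    q_main:        (batch, num_actions) main-network Q-values at each state
    q_target_next: (batch, num_actions) target-network Q-values at each next_state
    Returns a (batch, num_actions) matrix that is zero except in the column of the
    action taken, which holds reward + discount * max_a' Q_target(next_state, a') - Q(state, action).
    """
    actions = np.asarray(actions)
    rewards = np.asarray(rewards, dtype=float)
    not_terminal = 1.0 - np.asarray(terminals, dtype=float)
    # Bootstrapped DQN target; terminal transitions contribute only their reward.
    targets = rewards + discount * not_terminal * q_target_next.max(axis=1)
    rows = np.arange(len(actions))
    td = targets - q_main[rows, actions]
    td_mat = np.zeros_like(q_main, dtype=float)
    td_mat[rows, actions] = td
    return td_mat
```

Keeping zeros in the other columns means that, when the matrix is backpropagated through the output layer, only the Q-values of the actions actually taken receive a gradient.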

reset()

Reset the agent’s networks and replay buffer for a fresh run.

select_action(q_values, temperature)

Select an action using softmax exploration.

Parameters

q_values : np.ndarray

Q-values for the current state.

temperature : float

Softmax temperature.

Returns

action : int

Selected action index.
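Softmax (Boltzmann) exploration chooses action a with probability exp(Q(a) / temperature) divided by the sum of exp(Q(b) / temperature) over all actions b. The NumPy sketch below subtracts the maximum preference before exponentiating, which avoids overflow without changing the probabilities. The function names and the explicit rng argument are assumptions and differ from the documented select_action(q_values, temperature) signature.

```python
import numpy as np

def softmax_probs(q_values, temperature):
    """Boltzmann probabilities: exp(Q / temperature), normalised to sum to 1."""
    prefs = np.asarray(q_values, dtype=float) / temperature
    prefs = prefs - prefs.max()  # shift for numerical stability; probabilities are unchanged
    exp_prefs = np.exp(prefs)
    return exp_prefs / exp_prefs.sum()

def softmax_select(q_values, temperature, rng):
    """Sample one action index from the softmax distribution over q_values."""
    probs = softmax_probs(q_values, temperature)
    return int(rng.choice(len(probs), p=probs))
```

For example, softmax_select(q, 1.0, np.random.default_rng(0)) draws one action index; as the temperature approaches zero the choice becomes greedy, and as it grows the choice approaches uniform.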

start(new_state)

Begin an episode by selecting an action for the initial state.

Parameters

new_state : np.ndarray

The initial environment state.

Returns

action : int

The selected action.

step(reward, new_state)

Process one environment step: store the latest transition in the replay buffer, update the main network, and select the next action.

Parameters

reward : float

Reward received from the previous action.

new_state : np.ndarray

The new environment state.

Returns

action : int

The next action chosen by the agent.
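Together, start, step and end define the episode protocol: start is called once with the initial state, step once per non-terminal transition, and end once with the final reward. The driver loop below sketches this under stated assumptions: ChainEnv and its reset() / step(action) -> (reward, next_state, terminal) interface are hypothetical, and AlwaysRight is a stand-in exposing the same three methods, used where a DQNAgent instance would normally go.

```python
class ChainEnv:
    """Hypothetical toy environment: walk right from 0 to `length`; reward -1 per step."""

    def __init__(self, length=3):
        self.length = length
        self.position = 0

    def reset(self):
        self.position = 0
        return self.position

    def step(self, action):
        # Action 1 moves right; any other action moves left (never below 0).
        self.position = self.position + 1 if action == 1 else max(0, self.position - 1)
        return -1.0, self.position, self.position >= self.length


class AlwaysRight:
    """Stand-in agent with the same start/step/end methods as DQNAgent; always picks action 1."""

    def __init__(self):
        self.calls = []

    def start(self, new_state):
        self.calls.append("start")
        return 1

    def step(self, reward, new_state):
        self.calls.append("step")
        return 1

    def end(self, reward):
        self.calls.append("end")


def run_episode(agent, env, max_steps=1000):
    """Run one episode (truncated after max_steps) and return its undiscounted return."""
    action = agent.start(env.reset())
    total_reward = 0.0
    for _ in range(max_steps):
        reward, new_state, terminal = env.step(action)
        total_reward += reward
        if terminal:
            agent.end(reward)  # terminal transition: no next action is selected
            break
        action = agent.step(reward, new_state)
    return total_reward
```

On a chain of length 3 the stand-in agent receives start, step, step, end and the episode return is -3.0. If an episode is truncated by max_steps, end is never called in this sketch.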