Paper link: https://arxiv.org/pdf/2502.10325
To set up the project, clone the repository and create a Conda environment:
cd agent_prm
conda env create -f environment.yml
conda activate agent_prm
pip install -e .
Ensure you have a .env file with the requisite keys:
OPENAI_API_KEY=your_openai_api_key
OPENAI_ORGANIZATION=your_openai_organization_id
GEMINI_API_KEY=your_gemini_key
ANTHROPIC_API_KEY=your_anthropic_key
We build on Open-Instruct for training, with some minor compatibility fixes, so it needs to be installed locally:
# Clone and install Open-Instruct
git clone --branch fix_vllm https://github.com/sanjibanc/open-instruct.git
cd open-instruct
pip install -e .
cd ..
We use SGLang for fast inference, with some minor compatibility fixes for Llama, so it needs to be installed locally:
# Clone and install SGLang
git clone --branch new_llama_model https://github.com/sanjibanc/sglang.git
cd sglang
pip install -e .
cd ..
To use the SGLang server, see the SGLang instructions.
To set up external environments such as AlfWorld, see the external environment instructions.
Agent PRM iterates over 3 stages:
- Roll out the policy and compute PRM targets
- Train the PRM
- Train the policy via RL
Stages 2 and 3 are similar to standard RLHF operations; stage 1 is the agent-specific step.
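As a rough illustration of how the three stages chain together, the sketch below wraps each stage in a Python callable. The function `agent_prm_loop` and the callables `rollout_and_compute_targets`, `train_prm`, and `train_policy_rl` are hypothetical stand-ins for the scripts described below; they are not APIs exported by this repository.

```python
from typing import Callable, List

# Hypothetical type aliases: a policy or PRM is referred to by a path or model id.
Policy = str
PRM = str


def agent_prm_loop(
    policy: Policy,
    num_iters: int,
    rollout_and_compute_targets: Callable[[Policy], str],
    train_prm: Callable[[str], PRM],
    train_policy_rl: Callable[[Policy, PRM], Policy],
) -> List[Policy]:
    """Run stages 1-3 num_iters times; return the policy after each iteration.

    All three callables are hypothetical placeholders for this repo's scripts.
    """
    history = [policy]
    for _ in range(num_iters):
        targets = rollout_and_compute_targets(policy)  # Stage 1: rollouts + PRM targets
        prm = train_prm(targets)                       # Stage 2: train the PRM
        policy = train_policy_rl(policy, prm)          # Stage 3: RL against the PRM
        history.append(policy)
    return history
```

In this sketch the starting `policy` would be the SFT model trained in the next step, and each iteration's policy feeds the next round of rollouts.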
We collect SFT training data from our prior work, LEAP, and train an initial policy via SFT:
bash bash/train-sft-llama3.2-3B.sh
For simplicity, we provide this model at rl-llm-agent/Llama-3.2-3B-Instruct-sft-alfworld-iter0.
We roll out in a batched fashion and recommend using the SGLangServerAgent for fast inference. See the SGLang instructions to set up the SGLang server, then run the following script to collect rollouts:
python scripts/dataproc/rollout_alfworld.py --config configs/rollout_alfworld.yaml
Once you have the rollouts, set the rollout path in configs/compute_prm_target.yaml and run:
python scripts/dataproc/compute_prm_target.py --config configs/compute_prm_target.yaml
This should create train and test files for training the PRM.
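One common way to build step-level targets, shown here as a minimal sketch, is a Monte Carlo estimate of Q(s, a): the target for a trajectory prefix is the average final outcome of all rollouts that share that prefix. This assumes binary success outcomes and treats the full (observation, action) history as the state; `monte_carlo_prm_targets` is a hypothetical helper, and compute_prm_target.py may use a different estimator or data format.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Step = Tuple[str, str]  # (observation, action)


def monte_carlo_prm_targets(
    rollouts: List[Tuple[List[Step], float]],
) -> Dict[Tuple[Step, ...], float]:
    """Estimate a target for every trajectory prefix that ends in an action.

    Each rollout is (steps, outcome), where outcome is assumed to be 1.0 for
    task success and 0.0 for failure. The target for a prefix is the mean
    outcome of all rollouts sharing that prefix (a Monte Carlo Q estimate).
    """
    totals: Dict[Tuple[Step, ...], float] = defaultdict(float)
    counts: Dict[Tuple[Step, ...], int] = defaultdict(int)
    for steps, outcome in rollouts:
        for t in range(1, len(steps) + 1):
            prefix = tuple(steps[:t])  # history up to and including step t
            totals[prefix] += outcome
            counts[prefix] += 1
    return {prefix: totals[prefix] / counts[prefix] for prefix in totals}
```

Batched rollouts from the same starting state are what make these averages informative: prefixes visited by several rollouts get a smoother estimate than those seen once.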
To train the PRM, run the script that calls Open-Instruct:
bash bash/train-rm-llama3.2-3B.sh
Upload the best checkpoint to Hugging Face for convenience:
python scripts/utils/upload_model_to_hf.py --input_model <path/to/checkpoint> --output_model <hugging face model path> --accelerate
To train the policy via online DPO to optimize the PRM, run the following script:
bash bash/online-dpo-llama3.2-3B.sh
Upload the best checkpoint to Hugging Face for convenience, using the same upload script as above.
Repeat stages 1 to 3.
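The online DPO step in stage 3 can be sketched as follows, under the assumption that the policy samples two candidate actions, the PRM scores both, and the higher-scored one is treated as chosen. The loss is the standard DPO objective (Rafailov et al., 2023); `online_dpo_loss` is a hypothetical helper, and the Open-Instruct training script may differ in sampling, batching, and how the PRM is queried.

```python
import math
from typing import Tuple


def online_dpo_loss(
    prm_scores: Tuple[float, float],
    policy_logps: Tuple[float, float],
    ref_logps: Tuple[float, float],
    beta: float = 0.1,
) -> float:
    """DPO loss for one online pair, with the PRM choosing which action is preferred.

    prm_scores:   PRM scores of two actions sampled from the current policy.
    policy_logps: log pi_theta(action | state), summed over the action's tokens.
    ref_logps:    the same log-probabilities under the frozen reference policy.
    """
    # Assumption: the higher-scored action is "chosen" (ties go to the first one).
    w, l = (0, 1) if prm_scores[0] >= prm_scores[1] else (1, 0)
    # Difference of policy/reference log-ratios, scaled by beta.
    margin = beta * ((policy_logps[w] - ref_logps[w]) - (policy_logps[l] - ref_logps[l]))
    # -log(sigmoid(margin)) == softplus(-margin), computed in a numerically stable way.
    x = -margin
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))
```

When the policy and reference agree on both actions the loss is log 2; it falls as the policy shifts probability toward the PRM-preferred action.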
Configure the agents you want to evaluate in configs/eval_alfworld.yaml and run the following script:
python scripts/eval/eval_alfworld.py --config configs/evaluate_alfworld.yaml
This creates a folder in data/eval/alfworld/, named with the current datetime, where the logs and summary.csv are saved.
For fast inference, use an SGLang server agent and host the policy on an SGLang server.
To evaluate a Best-of-N policy, host both the policy and the PRM on SGLang, and run the script with the best_of_n agent.
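A minimal sketch of step-level Best-of-N selection, assuming that at each agent step the generator is sampled N times and the PRM picks the highest-scoring action. `best_of_n`, `generate`, `score`, and the default `n=8` are hypothetical; in practice `generate` and `score` would wrap requests to the generator and verifier SGLang servers, and the repository's best_of_n agent may differ in details.

```python
from typing import Callable, List, Tuple


def best_of_n(
    state: str,
    generate: Callable[[str], str],
    score: Callable[[str, str], float],
    n: int = 8,
) -> Tuple[str, List[Tuple[str, float]]]:
    """Sample n candidate actions and return the one the PRM scores highest.

    generate: hypothetical wrapper that samples one action from the policy.
    score:    hypothetical wrapper that returns the PRM score for (state, action).
    """
    candidates = [generate(state) for _ in range(n)]
    scored = [(action, score(state, action)) for action in candidates]
    # max() keeps the first candidate among ties.
    best_action, _ = max(scored, key=lambda pair: pair[1])
    return best_action, scored
```

Returning all scored candidates as well as the winner makes it easy to log how often the PRM overrides the policy's first sample.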
Stage 1: Given expert demonstrations and policy rollouts, compute the inverse PRM targets:
python scripts/dataproc/compute_inverse_prm_target.py --config configs/compute_inverse_prm_target.yaml
Stage 2: Train the PRM:
bash bash/train-inverse-prm-llama3.2-3B.sh
Stage 3: Train the generator as in Agent PRM.
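One way to form inverse PRM targets, sketched under the assumption of a discriminator-style labeling: every step in an expert demonstration is labeled 1.0 and every step in a policy rollout 0.0, so the PRM learns to score expert-like actions higher. `inverse_prm_examples` and its output fields are hypothetical, and compute_inverse_prm_target.py may label, deduplicate, or weight steps differently.

```python
from typing import Dict, List, Tuple

Step = Tuple[str, str]  # (observation, action)


def inverse_prm_examples(
    expert_trajectories: List[List[Step]],
    policy_trajectories: List[List[Step]],
) -> List[Dict[str, object]]:
    """Hypothetical step-level labeling: expert steps -> 1.0, policy steps -> 0.0."""
    examples: List[Dict[str, object]] = []
    for target, trajectories in ((1.0, expert_trajectories), (0.0, policy_trajectories)):
        for steps in trajectories:
            for t, (observation, action) in enumerate(steps):
                examples.append({
                    "history": steps[:t],  # (observation, action) pairs before step t
                    "observation": observation,
                    "action": action,
                    "target": target,
                })
    return examples
```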
To train the PRM with a relative loss, compute the targets as a preference dataset instead:
python scripts/dataproc/compute_prm_preference_target.py --config configs/compute_prm_preference_target.yaml
To train the PRM on this preference data, use the script:
bash bash/train-rm-pref-llama3.2-3B.sh
To train the policy using the steered exploration prompt prompts/alfworld/alfworld_exploration_template.j2, run the following script:
bash bash/online-dpo-exploration-llama3.2-3B.sh
Given a reference policy, collect rollouts, compute value targets, and train a value estimate:
python scripts/dataproc/compute_value_target.py --config configs/<path to value target.yaml>
bash bash/train-value-model-llama3.2-3B.sh
Use the value function to compute shaped PRM targets. This requires running the value function as a critic on an SGLang server:
python scripts/dataproc/compute_shaped_prm_target.py --config configs/compute_shaped_prm_target.yaml
Train the shaped PRM:
bash bash/train-shaped-rm-llama3.2-3B.sh
Train the policy via online DPO:
bash bash/online-dpo-shaped-prm-llama3.2-3B.sh
To use SGLang for inference, grab a node on the same network as your inference scripts so that they can communicate with it over the network.
SGLang has some compatibility issues with the agent_prm conda environment, so we recommend using the separate sglang environment:
conda env create -f sglang_environment.yml
conda activate sglang
To host a model, run:
python -m sglang.launch_server --model-path <model_name> --port <port_number, e.g. 30000>
When running Best-of-N inference with a PRM, you may want to grab two such nodes, one for the generator and one for the verifier, and assign them different ports, e.g. 30000 and 30010.
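To check that a hosted generator responds, the sketch below builds a request for SGLang's native `/generate` endpoint using only the Python standard library. The helper names `make_generate_request` and `generate`, the default port (30000, matching the generator example above), and the sampling settings are assumptions for illustration; this is independent of the repository's SGLangServerAgent, and scoring with a hosted PRM is not covered here.

```python
import json
import urllib.request


def make_generate_request(
    prompt: str,
    port: int = 30000,
    host: str = "localhost",
    max_new_tokens: int = 128,
    temperature: float = 0.0,
) -> urllib.request.Request:
    """Build (but do not send) a POST request for SGLang's native /generate endpoint."""
    payload = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": temperature},
    }
    return urllib.request.Request(
        f"http://{host}:{port}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(prompt: str, port: int = 30000, host: str = "localhost") -> str:
    """Send the request and return the generated text (needs a running server)."""
    with urllib.request.urlopen(make_generate_request(prompt, port, host)) as resp:
        return json.loads(resp.read())["text"]
```

Separating request construction from sending keeps the payload easy to inspect before pointing it at the generator (e.g. port 30000) or verifier (e.g. port 30010) node.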
Clone AlfWorld from the AlfWorld GitHub repository and follow the instructions in its README to get the game files.
Create an env_assets folder and copy the data over to env_assets/alfworld. Then set the following environment variable:
export ALFWORLD_DATA=</path/to/env_assets/alfworld>
This project is actively being developed. For any questions or issues, please contact us at sanjibanc@cornell.edu.