from deepspeed_ray import DeepSpeedPredictorCluster, DPCConfig
import pandas as pd
import asyncio 
from attacks import ZouAttack
import json

# Settings for the cluster that main() passes to ZouAttack as ``proxy_dpc``.
# Alternative model names kept for reference:
#   "gpt2", "meta-llama/Llama-2-7b-chat-hf"
proxy_config = DPCConfig(
    mode="fsdp",
    fsdp_compile=False,
    model_name="lmsys/vicuna-7b-v1.3",
    # A single worker group containing a single worker.
    num_worker_groups=1,
    num_workers_per_group=1,
)

# Settings for the cluster that main() passes to ZouAttack as ``eval_dpc``.
# Alternative model names kept for reference:
#   "gpt2", "meta-llama/Llama-2-7b-chat-hf"
eval_config = DPCConfig(
    mode="inference",
    deepspeed_kwargs={},
    model_name="lmsys/vicuna-7b-v1.3",
    # Four worker groups with one worker each.
    num_worker_groups=4,
    num_workers_per_group=1,
)


async def main():
    """Run ZouAttack on every target string and log results as JSON lines.

    Creates the proxy and eval predictor clusters concurrently, builds a
    ``ZouAttack`` from them, then attacks each value in the ``target`` column
    of the harmful-strings CSV in turn. Each result is printed and appended
    as one JSON line to the results file, which is flushed after every
    target so partial progress survives an interrupted run.
    """
    import os

    results_path = "results/zou2_vicuna_7to7.jsonl"
    targets_path = "llm-attacks/data/advbench/harmful_strings.csv"

    # Start both clusters at the same time rather than one after the other.
    proxy_dpc, eval_dpc = await asyncio.gather(
        DeepSpeedPredictorCluster.create(proxy_config),
        DeepSpeedPredictorCluster.create(eval_config),
    )

    za = ZouAttack(
        eval_config.model_name,
        proxy_config.model_name,
        eval_dpc=eval_dpc,
        proxy_dpc=proxy_dpc,
        # NOTE(review): hard-coded device value; its meaning is defined by
        # ZouAttack -- confirm it suits the host this runs on.
        device=10,
    )

    df = pd.read_csv(targets_path)

    # Make sure the output directory exists so the append-open cannot fail
    # on a fresh checkout after the clusters have already been started.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    # The context manager guarantees the handle is closed even if an attack
    # raises part-way through the loop.
    with open(results_path, "a", encoding="utf-8") as out:
        for target in df.target:
            print(target)
            result = await za.do_attack("", target, max_iter=500, batch_size=512)
            out.write(json.dumps(result) + "\n")
            out.flush()
            print(result)

# Script entry point: when executed directly (not imported), run main() to
# completion on a new asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())
