#!/usr/bin/env bash
# TOFU unlearning sweep: for every forget split, train one model per
# (model_mode, data_mode) pair, then evaluate the saved checkpoints.

# Base torchrun --master_port for the training runs.
master_port=43144

# Sweep axes. model_modes[k] is used together with data_modes[k].
splits=("forget01" "forget05" "forget10")
model_modes=("remember_uniform")
data_modes=("forget_more_retain_perturbed")

# Training output directory and evaluation output directory name.
OUTPUTMODELDIR=trained_models2/hf-ours
EVAL_OUTDIRNAME="eval_outputs-ours/tofu"

# Hydra overrides shared by every training run. This is a non-LoRA setup
# (model_train.Lora.r=0). The list is joined into one space-separated string
# (with a trailing space) so it can be appended unquoted to a command line.
common_overrides=(
    "lightning.trainer.devices=2"
    "data.batch_size=4"
    "gradient_accumulation_steps=4"
    "model_train.num_layer=0"
    "model_train.Lora.r=0"
    "lightning.trainer.strategy=deepspeed_stage_3"
    "model_train.weight_decay=0.01"
    "OUTPUTMODELDIR=$OUTPUTMODELDIR"
)
COMMON="${common_overrides[*]} "

# GPUs used by this sweep; num_devices is the number of comma-separated ids.
export CUDA_VISIBLE_DEVICES=1,0
IFS=, read -r -a visible_gpus <<< "$CUDA_VISIBLE_DEVICES"
num_devices=${#visible_gpus[@]}

lr=1e-5

#######################################
# Print the checkpoint directories of one training run, highest step first.
# Arguments: $1 - directory to search recursively
#            $2 - substring that must occur in a checkpoint's path
# Outputs:   one checkpoint path per line on stdout (nothing if none/missing)
#######################################
list_checkpoints() {
    local root=$1
    local mode=$2
    # Match on the directory's own name (checkpoint-<digits>...), so folders
    # nested *inside* a checkpoint (such as a DeepSpeed global_step* folder,
    # if present) are not listed as extra checkpoints. The step number is the
    # text after the last '-' and is used as a numeric sort key.
    find "$root" -type d -name 'checkpoint-[0-9]*' -path "*${mode}*" \
        | awk -F'-' '{print $NF "\t" $0}' \
        | sort -t$'\t' -k1,1nr \
        | cut -f2-
}

# Values passed to eval_tofu.py as remember.weight; one eval job is launched
# per (checkpoint, weight) pair.
weights=(-0.8)
# Round-robin GPU slot for eval jobs. It persists across runs and starts at 0,
# so the first job is assigned gpu1 (same as the previous implicit-unset value).
gpuid=0

# NOTE(review): model, remember and num_layer are never assigned in this
# script but are forwarded to eval_tofu.py; presumably they come from the
# caller's environment — TODO confirm.
for var_name in model remember num_layer; do
    if [[ -z "${!var_name-}" ]]; then
        printf 'warning: $%s is unset; passing an empty value to eval_tofu.py\n' \
            "$var_name" >&2
    fi
done

for split in "${splits[@]}"; do
    for mode_idx in "${!model_modes[@]}"; do
        model_mode=${model_modes[$mode_idx]}
        data_mode=${data_modes[$mode_idx]}

        # --- Train on 2 GPUs; each model mode gets its own rendezvous port.
        # $COMMON is deliberately unquoted so it splits into separate overrides.
        if ! CUDA_VISIBLE_DEVICES=1,0 torchrun --nproc_per_node=2 \
            --master_port=$((master_port + mode_idx * 100)) \
            scripts/hf_forget_train.py \
            data.split="${split}_perturbed" \
            project="tofu-ours-hf" \
            model_train.learning_rate="$lr" \
            model_train="$model_mode" \
            data_mode="$data_mode" \
            $COMMON; then
            printf 'warning: training failed for split=%s model_mode=%s\n' \
                "$split" "$model_mode" >&2
        fi

        # --- Evaluate every checkpoint, num_devices jobs at a time.
        export CUDA_VISIBLE_DEVICES=1,0
        mapfile -t checkpoints < <(list_checkpoints "$OUTPUTMODELDIR" "$model_mode")

        launched=0
        for checkpoint in "${checkpoints[@]}"; do
            # The path is passed as a key=value override; escape any '=' in it.
            checkpoint="${checkpoint//=/\\=}"
            for weight in "${weights[@]}"; do
                gpuid=$(( (gpuid + 1) % num_devices ))
                # NOTE(review): training above is non-LoRA (model_train.Lora.r=0)
                # yet eval passes remember.is_lora=True — confirm this is intended.
                python scripts/eval_tofu.py \
                    data=eval_tofu \
                    data.split="${split}_perturbed" \
                    data.eval.retain_result="data/${split}_llama_wd0.01/eval_results/ds_size300/eval_log_aggregated.json" \
                    data.eval.batch_size=40 \
                    model="${model}" \
                    remember="${remember}" \
                    remember.num_layer="${num_layer}" \
                    remember.weight="${weight}" \
                    remember.save_path="${checkpoint}" \
                    remember.is_lora=True \
                    OUTDIRNAME="${EVAL_OUTDIRNAME}/${split}" \
                    remember.top_logit_filter=1e-2 \
                    gpu="gpu${gpuid}" &
                launched=$((launched + 1))
                # Barrier after each batch of num_devices concurrent jobs.
                if (( launched % num_devices == 0 )); then
                    wait
                fi
            done
        done
        wait

        sleep 2
        unset CUDA_VISIBLE_DEVICES
        # ':?' aborts instead of expanding to 'rm -rf /*' if the variable is empty.
        rm -rf -- "${OUTPUTMODELDIR:?}"/*
    done
done