#!/usr/bin/env bash
# Strict mode is enabled only after activation, so a failing
# `source activate` is reported but does not abort the script.
source activate lla
# `-o` is required: `set -e pipefail` would only enable -e and turn
# "pipefail" into positional parameter $1. -u is deliberately not enabled;
# the functions below let unset variables expand to empty.
set -eo pipefail

# DON'T CHANGE THIS!!! Order matters: experiments select tasks as a prefix
# of this list.
models_name=(
"cola"
"sst2"
"mrpc"
"stsb"
"qqp"
"mnli"
"qnli"
"rte"
)
# One model path per task, in models_name order (presumably fine-tuned
# roberta-base checkpoints; xx_HOME must point at the workspace root).
models_to_merge=()
for d in "${models_name[@]}"; do
    models_to_merge+=("${xx_HOME}/twin-merge/MergeLM/save_models/${d}/roberta-base_lr1e-05")
done


#######################################
# Run twin_merge.py for one configuration.
# Every setting is read from a (usually command-prefixed) variable and copied
# into a local, so defaults applied here do not leak into the caller's globals.
# Globals (read):
#   models_to_merge, models_name  model paths and task names
#   outdir         output directory (default outs/twin_<YYYY-MM-DD>; created)
#   select_merge   merge from the first N tasks of models_name (default 8, > 1)
#   select_twin    twin vectors from the first N tasks (default 8); when 1,
#                  the twin task(s) are taken from src_twin instead
#   src_twin       twin source task(s); required when select_twin=1
#   mask_rate      default 0.9
#   mask_strategy  default svd
# Outputs: settings summary on stdout; errors on stderr.
# Exits 1 on an invalid configuration.
#######################################
function twin_merge(){
    local date_today
    date_today=$(date '+%Y-%m-%d')
    local outdir="${outdir:-outs/twin_${date_today}}"
    # NOTICE: tasks are always selected as a prefix of models_name.
    local select_merge="${select_merge:-8}"
    local select_twin="${select_twin:-8}"
    local mask_strategy="${mask_strategy:-svd}"
    local mask_rate="${mask_rate:-0.9}"
    local datapath
    local -a merge_tasks twin_tasks

    if [[ ! "$select_merge" =~ ^[0-9]+$ || ! "$select_twin" =~ ^[0-9]+$ ]]; then
        echo "select_merge and select_twin must be non-negative integers" >&2
        exit 1
    fi
    if (( select_merge <= 1 )); then
        echo "please set \$select_merge > 1" >&2
        exit 1
    fi
    # Derived for every branch so select_twin=1 runs also get a --src-merge list.
    merge_tasks=("${models_name[@]:0:$select_merge}")

    if (( select_twin == 1 )); then
        if [[ -z "${src_twin[*]}" ]]; then
            echo "please set \$src_twin!" >&2
            exit 1
        fi
        datapath="data_glue/new_dataset2.json"
        twin_tasks=("${src_twin[@]}")
    else
        datapath="data_glue/new_dataset_data5500_indomain${select_twin}.json"
        twin_tasks=("${models_name[@]:0:$select_twin}")
    fi

    # Create the output directory only once the configuration is known valid.
    mkdir -p -- "$outdir"

    echo ">>> use data_path $datapath"
    echo ">>> use outdir $outdir"
    echo ">>> merged from $select_merge tasks"
    echo ">>> use twin vector from $select_twin tasks"
    echo ">>> mask_rate $mask_rate; mask_strategy $mask_strategy"

    python twin_merge.py \
        --models-to-merge "${models_to_merge[@]}" \
        --models-name "${models_name[@]}" \
        --data-path "$datapath" \
        --src-merge "${merge_tasks[@]}" \
        --src-twin "${twin_tasks[@]}" \
        --yaml-file config/twin_merge.yml \
        --exclude-param ".*classifier.*" ".*bias.*" \
        --mask-rate "$mask_rate" \
        --mask-strategy "$mask_strategy" \
        --outdir "$outdir"
}

#######################################
# One twin_merge run per task, each using that single task as the twin
# source (select_twin=1).
# Globals (read): models_name.
#######################################
function run_1(){
    local src
    for src in "${models_name[@]}"; do
        # Pass src_twin as a scalar: array values in a command-prefix
        # (temporary) assignment are not reliably supported by bash.
        src_twin=$src select_twin=1 twin_merge
    done
}

# Sweep the number of twin-vector source tasks from 2 to 7; select_merge is
# left to the caller's setting (or twin_merge's own default).
function run_mul(){
    for i in {2..7}; do select_twin=$i twin_merge; done
}

# Sweep 2..7 tasks, using the same task prefix for both the merged set and
# the twin-vector set.
function run_ood(){
    for i in {2..7}; do select_merge=$i select_twin=$i twin_merge; done
}

#######################################
# Sweep mask_rate over a fixed grid, one twin_merge run per value.
# select_twin / select_merge are whatever the caller has in effect
# (twin_merge defaults both to 8).
# The rate list and loop variable are local, so this sweep does not leave a
# global mask_rate behind to change later runs' defaults.
#######################################
function run_rate(){
    local -a rates=(
        0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
        0.91 0.92 0.93 0.94 0.95 0.96 0.97 0.98 0.99
        0.991 0.992 0.993 0.994 0.995 0.996 0.997 0.998
    )
    local rate
    for rate in "${rates[@]}"; do
        mask_rate=$rate twin_merge
    done
}

#######################################
# For each mask_rate in the grid, run the run_ood sweep with results written
# to outs/ood-err-rank.
# The rate list and loop variable are local so no global mask_rate leaks out.
#######################################
function run_rate3(){
    local -a rates=(
        0 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09
        0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
        0.91
        0.991 0.992 0.993 0.994 0.995 0.996 0.997 0.998
    )
    local rate
    for rate in "${rates[@]}"; do
        outdir="outs/ood-err-rank" mask_rate=$rate run_ood
    done
}

#######################################
# For a handful of mask_rate values, run the run_mul sweep with results
# written to outs/ind-err.
# The rate list and loop variable are local so no global mask_rate leaks out.
#######################################
function run_rate2(){
    local -a rates=(0.7 0.91 0.991 0.994 0.998)
    local rate
    for rate in "${rates[@]}"; do
        outdir="outs/ind-err" mask_rate=$rate run_mul
    done
}


#######################################
# Ablation without the common vector: run twin_search.py on the first five
# tasks of models_name with `--scaling 0 0`. Presumably the zero scaling is
# what drops the common vector; confirm against twin_search.py.
# Globals (read): models_to_merge, models_name,
#   mask_strategy (default svd), mask_rate (default 0.998).
# All settings are local so nothing set here (notably src_twin or mask_rate)
# leaks into later twin_merge calls.
#######################################
function run_ablation(){
    local -r n_tasks=5
    local datapath="data_glue/new_dataset_data5500_indomain${n_tasks}.json"
    local -a twin_tasks=("${models_name[@]:0:$n_tasks}")
    local -a merge_tasks=("${models_name[@]:0:$n_tasks}")
    local mask_strategy="${mask_strategy:-svd}"
    local mask_rate="${mask_rate:-0.998}"

    python twin_search.py \
        --models-to-merge "${models_to_merge[@]}" \
        --models-name "${models_name[@]}" \
        --data-path "$datapath" \
        --src-merge "${merge_tasks[@]}" \
        --src-twin "${twin_tasks[@]}" \
        --yaml-file config/task_arithmetic_search.yml \
        --exclude-param ".*classifier.*" ".*bias.*" \
        --mask-rate "$mask_rate" \
        --mask-strategy "$mask_strategy" \
        --outdir "outs/ablation" \
        --scaling 0 0
}



# Experiment entry point: leave exactly one run uncommented.
# run_rate
# CUDA_VISIBLE_DEVICES=1 run_rate3
CUDA_VISIBLE_DEVICES=2 run_ablation