def fedAvg(models, train_gs):
    """Return the node-count-weighted average of the clients' model states.

    Each client ``u`` (a key of ``train_gs``) contributes
    ``models[u].state_dict()`` weighted by ``train_gs[u].number_of_nodes()``,
    so every returned entry is ``sum_u n_u * state_u[key] / sum_u n_u``.

    Args:
        models: Indexable by client id; each item exposes ``state_dict()``.
            Presumably ``torch.nn.Module`` instances sharing one architecture
            (TODO confirm against callers).
        train_gs: Mapping from client id to that client's training graph.
            Only ``number_of_nodes()`` is used, as the averaging weight.

    Returns:
        dict mapping each state-dict key (in the first client's key order) to
        the averaged tensor, cast back to that entry's original dtype. For
        integer buffers this truncates the average toward zero.

    Raises:
        ValueError: if ``train_gs`` is empty or its total node count is 0.
    """
    client_ids = list(train_gs)
    if not client_ids:
        raise ValueError("fedAvg needs at least one client in train_gs")

    # Weights are fixed per client, so compute them (and their total) once
    # instead of re-deriving the running count for every key.
    weights = {u: train_gs[u].number_of_nodes() for u in client_ids}
    total_nodes = sum(weights.values())
    if total_nodes == 0:
        raise ValueError("fedAvg: total number of nodes across clients is 0")

    states = {u: models[u].state_dict() for u in client_ids}

    global_state = dict()
    # All clients are assumed to share one architecture, so the first
    # client's keys enumerate the full parameter/buffer set.
    for key, reference in states[client_ids[0]].items():
        # The first product allocates a fresh tensor, so the in-place adds
        # below never touch the clients' own parameters.
        weighted_sum = weights[client_ids[0]] * states[client_ids[0]][key]
        for u in client_ids[1:]:
            weighted_sum += weights[u] * states[u][key]
        # Out-of-place true division, then cast back: an in-place `/=` fails
        # on integer tensors such as BatchNorm's num_batches_tracked.
        global_state[key] = (weighted_sum / total_nodes).to(reference.dtype)

    return global_state


def fedGate(models, models_delta, train_gs, tau, lr):
    """Average client states and update each client's correction ("delta") model.

    The global state is the node-count-weighted average computed by
    ``fedAvg``. Then, for every client ``u`` and every key of
    ``models[u].state_dict()``, the delta model is updated as::

        delta_u[key] += (global_state[key] - state_u[key]) / (tau * lr)

    This presumably implements FedGATE's local correction term (TODO confirm).

    Args:
        models: Indexable by client id; items expose ``state_dict()``.
        models_delta: Indexable by client id; items expose ``state_dict()``
            and ``load_state_dict()``. They are updated in place. Their state
            dicts are assumed to contain every key of the matching model.
        train_gs: Mapping from client id to that client's training graph.
            Only ``number_of_nodes()`` is used, as the averaging weight.
        tau: Multiplied with ``lr`` to form the correction's divisor.
        lr: Multiplied with ``tau`` to form the correction's divisor.

    Returns:
        ``(global_state, models_delta)``. ``models_delta`` is the same object
        that was passed in, after its entries have been updated.

    Raises:
        ValueError: if ``tau * lr`` is zero, or (via ``fedAvg``) if
            ``train_gs`` is empty or has zero total nodes.
    """
    step = tau * lr
    if step == 0:
        raise ValueError("fedGate: tau * lr must be non-zero")

    # Snapshot each client's local state before comparing it to the average.
    states = {u: models[u].state_dict() for u in train_gs}
    # Reuse the shared weighted-average routine rather than duplicating it.
    global_state = fedAvg(models, train_gs)

    for user_index in train_gs:
        delta_state = models_delta[user_index].state_dict()
        for key, local in states[user_index].items():
            # Out-of-place update: an in-place `+=` of a float correction into
            # an integer buffer raises. load_state_dict copies (and casts) the
            # result back into the delta model's tensors.
            delta_state[key] = delta_state[key] + (global_state[key] - local) / step
        models_delta[user_index].load_state_dict(delta_state)

    return global_state, models_delta
