# A Contrastive Framework for Neural Text Generation

## In this file, we show how to reproduce the results presented in the paper.

### 1. Environment Setup (Python version: 3.8):
    pip install -r requirements.txt

### 2. Open-ended Document Generation:
Here, we show how to reproduce our results on the Wikitext-103 benchmark.

#### 2.1. Download Wikitext-103:
To download the Wikitext-103 benchmark, please run the following command:
    ```bash
    chmod +x ./download_wikitext103_dataset.sh
    ./download_wikitext103_dataset.sh
    ```

The downloaded directory looks like:
    .
    ├── ./wikitext103/                    
        ├── wikitext103_raw_v1_train.txt # Training set of Wikitext-103
        ├── wikitext103_raw_v1_validation.txt # Validation set of Wikitext-103
        └── wikitext103_raw_v1_test.txt # Test set of Wikitext-103
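
As a quick sanity check after the download, the minimal sketch below reports which of the three files listed above are missing. The helper name `missing_wikitext103_files` is hypothetical (it is not part of this repository), and the default path assumes the script is run from the directory that contains `wikitext103/`.

```python
from pathlib import Path

# File names taken from the directory listing above.
EXPECTED_FILES = [
    "wikitext103_raw_v1_train.txt",
    "wikitext103_raw_v1_validation.txt",
    "wikitext103_raw_v1_test.txt",
]


def missing_wikitext103_files(data_dir="./wikitext103"):
    """Return the expected file names that are not present in data_dir.

    Hypothetical helper for illustration; data_dir is an assumed location.
    """
    root = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_wikitext103_files()
    print("All files found." if not missing else f"Missing files: {missing}")
```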

#### 2.2. Reproduce the Results:
To reproduce the results of the case study in our paper (i.e., Table 4), please run the following commands.

##### 2.2.1. Download Checkpoint:
Please download our trained SimCTG checkpoint as follows:
    ```bash
    chmod +x ./download_simctg_wikitext103.sh
    ./download_simctg_wikitext103.sh
    ```

##### 2.2.2. Reproduce Results and Compare Different Decoding Methods:
Then, you can reproduce our case study with the command below. The same script also lets you compare the results of contrastive search with other decoding methods (i.e., greedy search, beam search, and nucleus sampling).
    ```bash
    python document_generation_demo.py
    ```

##### 2.2.3. SimCTG Predictions:
We also provide the results of SimCTG on the test set of Wikitext-103. You can find the prediction file at `./document_generation/simctg_contrastive.json`.

The prediction file is a list of dictionaries, where each dictionary has the following format:
  ```yaml
  {  
     "prefix_text": The human-written prefix.
     "reference_text": The reference document (prefix + reference text continuation).
     "reference_continuation_text": The reference text continuation.   
     "generated_result": {
         "0": {
             "full_text": The prefix + generated continuation.
             "continuation": The generated continuation.
              }
         }
  }
  ```
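
For reference, here is a minimal sketch of how the prediction file could be read in Python, assuming it is valid JSON with the keys described above. The function names `load_predictions` and `iter_continuations` are hypothetical and are not part of this repository.

```python
import json


def load_predictions(path="./document_generation/simctg_contrastive.json"):
    """Load the prediction file, which is a list of dictionaries."""
    with open(path, "r", encoding="utf8") as f:
        return json.load(f)


def iter_continuations(predictions):
    """Yield (prefix, reference continuation, generated continuation) triples.

    "generated_result" maps an index string (e.g., "0") to one generated
    result, so we iterate over all of its values.
    """
    for item in predictions:
        for result in item["generated_result"].values():
            yield (
                item["prefix_text"],
                item["reference_continuation_text"],
                result["continuation"],
            )
```

For example, `next(iter_continuations(load_predictions()))` would return the prefix, the reference continuation, and the generated continuation of the first test instance.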

#### 2.3. Train the Model by Yourself:
If you would like to train the model on Wikitext-103 yourself, please first download the data as described in Section 2.1. Below, we show how to train SimCTG and the MLE baseline. Training both models lets you see directly how SimCTG improves model perplexity compared with the MLE baseline.

##### 2.3.1. Train SimCTG:
To train SimCTG, please run the following command:
    ```bash
    cd ./document_generation/
    chmod +x ./train_simctg.sh
    ./train_simctg.sh
    ```

The arguments are as follows:
* `--model_name`: The name of the Hugging Face pre-trained GPT model (e.g., gpt2, gpt2-large).
* `--train_path`: The file path of the training set.
* `--dev_path`: The file path of the validation set.
* `--test_path`: The file path of the test set.
* `--margin`: The contrastive margin $\rho$.
* `--max_len`: The maximum length of training samples.
* `--number_of_gpu`: The number of available GPUs.
* `--batch_size_per_gpu`: The batch size for each GPU.
* `--gradient_accumulation_steps`: The number of forward computations between two gradient updates.
* `--effective_batch_size`: The overall batch size. It equals batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu.
* `--total_steps`: The total number of gradient update steps.
* `--print_every`: The number of steps between printing intermediate results.
* `--save_every`: The number of steps between saving checkpoints.
* `--learning_rate`: The learning rate.
* `--save_path_prefix`: Where to save the checkpoints.
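
As a worked example of the relation between the batch-size arguments, the short sketch below computes the effective batch size from the other three values. The numbers are illustrative only and are not necessarily the settings used in `train_simctg.sh`.

```python
def effective_batch_size(batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu):
    """Overall batch size, following the formula in the argument list above."""
    return batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu


# Illustrative values only (assumptions, not the repository's defaults).
print(effective_batch_size(batch_size_per_gpu=16,
                           gradient_accumulation_steps=4,
                           number_of_gpu=2))  # -> 128
```

If you change the number of GPUs, you can keep the effective batch size fixed by adjusting `--batch_size_per_gpu` or `--gradient_accumulation_steps` accordingly.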

##### 2.3.2. Train MLE Baseline:
To train the MLE baseline, please run the following command:
    ```bash
    cd ./document_generation/
    chmod +x ./train_mle.sh
    ./train_mle.sh
    ```
    
To train the MLE baseline, we simply set the `--margin` argument (i.e., $\rho$) to 0: when $\rho$ equals 0, the SimCTG objective reduces to the vanilla MLE objective, as described in Section 3.1 of our paper.
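
To illustrate why this holds, here is a minimal sketch. It assumes the contrastive term in Section 3.1 averages $\max\{0, \rho - s(h_i, h_i) + s(h_i, h_j)\}$ over all token pairs $i \neq j$, where $s$ is the cosine similarity. Since $s(h_i, h_i) = 1$ and $s(h_i, h_j) \leq 1$, every term vanishes when $\rho = 0$, leaving only the MLE loss. The code below is a plain-Python illustration with made-up vectors, not the training code of this repository; the function names are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrastive_loss(hidden_states, margin):
    """Average of max(0, margin - s(h_i, h_i) + s(h_i, h_j)) over all i != j."""
    n = len(hidden_states)
    total = 0.0
    for i, h_i in enumerate(hidden_states):
        for j, h_j in enumerate(hidden_states):
            if i != j:
                total += max(0.0, margin - cosine(h_i, h_i) + cosine(h_i, h_j))
    return total / (n * (n - 1))


# Toy token representations (made-up values, for illustration only).
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(contrastive_loss(states, margin=0.0))  # no pair is penalized: 0.0
print(contrastive_loss(states, margin=0.5))  # positive: similar pairs are penalized
```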

##### 2.3.3. Train Unlikelihood Baseline:
To train the Unlikelihood baseline, please refer to the official code here: https://github.com/facebookresearch/unlikelihood_training.


### 3. Open-domain Dialogue Generation:
To reproduce our results on LCCC (i.e., Table 7 in Appendix G of our paper), please run the following commands.

#### 3.1. Download Checkpoints:
    ```bash
    chmod +x ./download_simctg_lccc.sh
    ./download_simctg_lccc.sh
    ```

#### 3.2. Reproduce Results and Compare Different Decoding Methods:
By running the following command, you can reproduce the results of contrastive search and compare it with other decoding methods (i.e., greedy search, beam search, and nucleus sampling).

    ```bash
    python dialogue_generation_demo.py
    ```
