diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index bfb463466..24af938f0 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -1,12 +1,10 @@ name: Unit and integration tests -# WARNING: As this workflow supports the pull_request_target event, please exercise extra care when editing it. on: workflow_dispatch: - pull_request_target: + pull_request: branches: [ main ] - types: [ opened, synchronize, reopened, labeled ] push: branches: [ main ] @@ -15,69 +13,15 @@ concurrency: cancel-in-progress: true jobs: - authorize: - if: (github.event.action == 'labeled' && github.event.label.name == 'run-test') || github.event_name != 'pull_request_target' || (! github.event.pull_request.head.repo.fork) - runs-on: ubuntu-latest - steps: - - run: true - - start-runner: - name: Start self-hosted EC2 runner - needs: authorize - runs-on: ubuntu-22.04 - env: - AWS_REGION: us-east-1 - EC2_AMI_ID: ami-04fe9856174d852b8 - EC2_INSTANCE_TYPE: dl1.24xlarge - EC2_SUBNET_ID: subnet-b7533b96 - EC2_SECURITY_GROUP: sg-08af7938042271373 - outputs: - label: ${{ steps.start-ec2-runner.outputs.label }} - ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Start EC2 runner - id: start-ec2-runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: start - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - ec2-image-id: ${{ env.EC2_AMI_ID }} - ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} - subnet-id: ${{ env.EC2_SUBNET_ID }} - security-group-id: ${{ env.EC2_SECURITY_GROUP }} - aws-resource-tags: > # optional, requires additional permissions - [ - {"Key": "Name", "Value": "optimum-habana-ci-fast-tests"}, - {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} - ] transformers: name: Run tests for optimum.habana.transformers - needs: - - authorize - - start-runner # required to start the main job when the runner is ready - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-east-1 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - - name: Checkout on branch - if: github.event_name != 'pull_request_target' + - name: Checkout uses: actions/checkout@v2 - with: - ref: ${{ github.ref }} - - name: Checkout on PR merge commit - if: github.event_name == 'pull_request_target' - uses: actions/checkout@v2 - with: - ref: ${{ github.event.pull_request.merge_commit_sha }} - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -89,31 +33,19 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/fast_tests.sh diffusers: name: Run tests for optimum.habana.diffusers needs: - - authorize - - start-runner # required to get output from the start-runner job - transformers # required to wait for the previous tests to finish - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-east-1 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - - name: Checkout on branch - if: github.event_name != 'pull_request_target' - uses: actions/checkout@v2 - with: - ref: ${{ github.ref }} - - name: Checkout on PR merge commit - if: github.event_name == 'pull_request_target' + - name: Checkout uses: actions/checkout@v2 - with: - ref: ${{ github.event.pull_request.merge_commit_sha }} - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -125,30 +57,5 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/fast_tests_diffusers.sh - stop-runner: - name: Stop self-hosted EC2 runner - needs: - - authorize - - start-runner # required to get output from the start-runner job - - transformers # required to wait for the tests to be finished - - diffusers # required to wait for the tests to be finished - runs-on: ubuntu-22.04 - env: - AWS_REGION: us-east-1 - if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Stop EC2 runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: stop - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.github/workflows/slow_tests.yml b/.github/workflows/slow_tests.yml index ab2e46b24..82914019e 100644 --- a/.github/workflows/slow_tests.yml +++ b/.github/workflows/slow_tests.yml @@ -10,53 +10,15 @@ concurrency: group: ${{ github.workflow }} jobs: - start-runner: - name: Start self-hosted EC2 runner - runs-on: ubuntu-22.04 - env: - AWS_REGION: us-west-2 - EC2_AMI_ID: ami-03549026a9aa06f99 - EC2_INSTANCE_TYPE: dl1.24xlarge - EC2_SUBNET_ID: subnet-452c913d - EC2_SECURITY_GROUP: sg-0894f4f70dd6bd778 - outputs: - label: ${{ steps.start-ec2-runner.outputs.label }} - ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Start EC2 runner - id: start-ec2-runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: start - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - ec2-image-id: ${{ env.EC2_AMI_ID }} - ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} - subnet-id: ${{ env.EC2_SUBNET_ID }} - security-group-id: ${{ env.EC2_SECURITY_GROUP }} - aws-resource-tags: > # optional, requires additional permissions - [ - {"Key": "Name", "Value": "optimum-habana-ci-slow-tests"}, - {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} - ] example-diff: name: Test examples differences - needs: - - start-runner # required to start the main job when the runner is ready - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -68,23 +30,20 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/example_diff_tests.sh stable-diffusion: name: Test Stable Diffusion if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff # run the job when the previous test job is done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -96,24 +55,21 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_diffusers.sh deepspeed: name: Test DeepSpeed models if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - stable-diffusion # run the job when the previous test job is done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -125,24 +81,21 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_deepspeed.sh multi-card: name: Test multi-card models if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - deepspeed # run the job when the previous test job is done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -154,25 +107,22 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_8x.sh single-card: name: Test single-card models if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - deepspeed - multi-card # run the job when the previous test jobs are done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -184,20 +134,17 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_1x.sh albert-xxl-single-card: name: Test single-card ALBERT XXL if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - deepspeed - multi-card - single-card # run the job when the previous test jobs are done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout if: github.event.schedule == '0 21 * * 6' @@ -205,7 +152,7 @@ jobs: - name: Pull image if: github.event.schedule == '0 21 * * 6' run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run test if: github.event.schedule == '0 21 * * 6' run: | @@ -218,7 +165,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/albert_xxl_1x.sh - name: Warning if: github.event.schedule != '0 21 * * 6' @@ -227,21 +174,18 @@ jobs: name: Test text-generation example if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - deepspeed - multi-card - single-card - albert-xxl-single-card # run the job when the previous test jobs are done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -253,28 +197,25 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ make slow_tests_text_generation_example TOKEN=${{ secrets.TEXT_GENERATION_CI_HUB_TOKEN }} trl: name: Test TRL integration if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - deepspeed - multi-card - single-card - albert-xxl-single-card - text-generation # run the job when the previous test jobs are done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -286,13 +227,12 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_trl.sh sentence-transformers: name: Test Sentence Transformers integration if: ${{ !cancelled() && (success() || failure()) }} needs: - - start-runner - example-diff - deepspeed - multi-card @@ -300,9 +240,7 @@ jobs: - albert-xxl-single-card - text-generation - trl # run the job when the previous test jobs are done - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - AWS_REGION: us-west-2 + runs-on: [self-hosted, linux, x64, gaudi-habana] # run the job on the newly created runner steps: - name: Checkout Optimum Habana uses: actions/checkout@v2 @@ -316,7 +254,7 @@ jobs: path: sentence-transformers - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -328,35 +266,5 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash optimum-habana/tests/ci/sentence_transformers.sh - stop-runner: - name: Stop self-hosted EC2 runner - needs: - - start-runner # required to get output from the start-runner job - - example-diff - - deepspeed - - multi-card - - single-card - - albert-xxl-single-card - - text-generation - - trl - - sentence-transformers - runs-on: ubuntu-22.04 - env: - AWS_REGION: us-west-2 - if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Stop EC2 runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: stop - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.github/workflows/slow_tests_gaudi2.yml b/.github/workflows/slow_tests_gaudi2.yml index f5edc569b..623b62f32 100644 --- a/.github/workflows/slow_tests_gaudi2.yml +++ b/.github/workflows/slow_tests_gaudi2.yml @@ -17,7 +17,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -30,7 +30,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_diffusers.sh deepspeed: name: Test DeepSpeed models @@ -43,7 +43,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -56,7 +56,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_deepspeed.sh fsdp: name: Test FSDP models @@ -69,7 +69,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -82,7 +82,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ make slow_tests_fsdp TOKEN=${{ secrets.TEXT_GENERATION_CI_HUB_TOKEN }} multi-card: name: Test multi-card models @@ -95,7 +95,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -108,7 +108,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_8x.sh single-card: name: Test single-card models @@ -122,7 +122,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest:latest - name: Run tests run: | docker run \ @@ -136,7 +136,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_1x.sh text-generation: name: Test text-generation example @@ -151,7 +151,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -164,7 +164,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ make slow_tests_text_generation_example TOKEN=${{ secrets.TEXT_GENERATION_CI_HUB_TOKEN }} trl: name: Test TRL integration @@ -177,7 +177,7 @@ jobs: uses: actions/checkout@v2 - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -190,7 +190,7 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash tests/ci/slow_tests_trl.sh sentence-transformers: name: Test Sentence Transformers integration @@ -211,7 +211,7 @@ jobs: path: sentence-transformers - name: Pull image run: | - docker pull vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - name: Run tests run: | docker run \ @@ -224,5 +224,5 @@ jobs: --cap-add=sys_nice \ --net=host \ --ipc=host \ - vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest \ + vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest \ /bin/bash optimum-habana/tests/ci/sentence_transformers.sh diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 000000000..9cbbf6803 --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,15 @@ +on: + push: + +name: Secret Leaks + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main diff --git a/Makefile b/Makefile index efa5d625a..eb6036a6b 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,7 @@ slow_tests_1x: test_installs python -m pytest tests/test_examples.py -v -s -k "single_card" python -m pip install peft==0.10.0 python -m pytest tests/test_peft_inference.py + python -m pytest tests/test_pipeline.py # Run multi-card non-regression tests slow_tests_8x: test_installs @@ -53,7 +54,7 @@ slow_tests_8x: test_installs # Run DeepSpeed non-regression tests slow_tests_deepspeed: test_installs - python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 + python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 python -m pytest tests/test_examples.py -v -s -k "deepspeed" slow_tests_diffusers: test_installs @@ -65,7 +66,7 @@ slow_tests_diffusers: test_installs # Run text-generation non-regression tests slow_tests_text_generation_example: test_installs - python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 + python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 python -m pytest tests/test_text_generation_example.py tests/test_encoder_decoder.py -v -s --token $(TOKEN) # Run image-to-text non-regression tests @@ -76,7 +77,7 @@ slow_tests_fsdp: test_installs python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN) slow_tests_trl: test_installs - python -m pip install trl==0.7.8 + python -m pip install trl==0.8.6 python -m pip install peft==0.7.0 python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss" diff --git a/README.md b/README.md index 65a5ba5df..fabff9e26 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,15 @@ HPUs offer fast model training and inference as well as a great price-performanc Check out [this blog post about BLOOM inference](https://huggingface.co/blog/habana-gaudi-2-bloom) and [this post benchmarking Intel Gaudi 2 and NVIDIA A100 GPUs for BridgeTower training](https://huggingface.co/blog/bridgetower) for concrete examples. +## Gaudi Setup + +Please refer to the Intel Gaudi AI Accelerator official [installation guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html). + +> Tests should be run in a Docker container based on Intel Gaudi Docker images. +> +> The current version has been validated for SynapseAI 1.16. + + ## Install the library and get example scripts ### Option 1: Use the latest stable release @@ -50,9 +59,9 @@ The `--upgrade-strategy eager` option is needed to ensure `optimum-habana` is up To use the example associated with the latest stable release, run: > ``` > git clone https://github.com/huggingface/optimum-habana -> cd optimum-habana && git checkout v1.11.1 +> cd optimum-habana && git checkout v1.12.0 > ``` -> with `v1.11.1` the version number of this release. +> with `v1.12.0` the version number of this release. ### Option 2: Use the latest main branch under development @@ -67,7 +76,7 @@ git clone https://github.com/huggingface/optimum-habana To use DeepSpeed on HPUs, you also need to run the following command: >```bash ->pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +>pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 >``` To install the requirements for every example: @@ -176,7 +185,7 @@ The following model architectures, tasks and device distributions have been vali | DistilBERT |:heavy_check_mark: | :heavy_check_mark: |
  • [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • | | GPT2 | :heavy_check_mark: | :heavy_check_mark: |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | BLOOM(Z) | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | -| StarCoder | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| StarCoder / StarCoder2 | :heavy_check_mark: |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | GPT-J |
  • DeepSpeed
  • |
  • Single card
  • DeepSpeed
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | GPT-NeoX |
  • DeepSpeed
  • |
  • DeepSpeed
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | OPT | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | @@ -204,6 +213,7 @@ The following model architectures, tasks and device distributions have been vali | Blip | |
  • Single card
  • |
  • [visual question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/visual-question-answering)
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | | OWLViT | |
  • Single card
  • |
  • [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)
  • | | ClipSeg | |
  • Single card
  • |
  • [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)
  • | +| Llava / Llava-next | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | @@ -237,15 +247,6 @@ If you find any issues while using those, please open an issue or a pull request After training your model, feel free to submit it to the Intel [leaderboard](https://huggingface.co/spaces/Intel/powered_by_intel_llm_leaderboard) which is designed to evaluate, score, and rank open-source LLMs that have been pre-trained or fine-tuned on Intel Hardwares. Models submitted to the leaderboard will be evaluated on the Intel Developer Cloud. The evaluation platform consists of Gaudi Accelerators and Xeon CPUs running benchmarks from the Eleuther AI Language Model Evaluation Harness. -## Gaudi Setup - -Please refer to the Intel Gaudi AI Accelerator official [installation guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html). - -> Tests should be run in a Docker container based on Intel Gaudi Docker images. -> -> The current version has been validated for SynapseAI 1.15. - - ## Development Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions. \ No newline at end of file diff --git a/docs/Dockerfile b/docs/Dockerfile index 3d253fd36..a31904c95 100644 --- a/docs/Dockerfile +++ b/docs/Dockerfile @@ -1,4 +1,4 @@ -FROM vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest +FROM vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest ARG commit_sha ARG clone_url diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 5e84eafe3..b33cfd062 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -43,7 +43,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | DistilBERT | ✅ | ✅ |
  • [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering)
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • | | GPT2 | ✅ | ✅ |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | BLOOM(Z) | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | -| StarCoder | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| StarCoder / StarCoder2 | ✅ |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | GPT-J |
  • DeepSpeed
  • |
  • Single card
  • DeepSpeed
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | GPT-NeoX |
  • DeepSpeed
  • |
  • DeepSpeed
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | OPT | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | @@ -71,6 +71,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | Blip | |
  • Single card
  • |
  • [visual question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/visual-question-answering)
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | | OWLViT | |
  • Single card
  • |
  • [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)
  • | | ClipSeg | |
  • Single card
  • |
  • [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)
  • | +| Llava / Llava-next | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | - Diffusers diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 1dc06fd02..3d657260f 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -23,6 +23,6 @@ python -m pip install --upgrade-strategy eager optimum[habana] To use DeepSpeed on HPUs, you also need to run the following command: ```bash -python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 ``` diff --git a/docs/source/usage_guides/deepspeed.mdx b/docs/source/usage_guides/deepspeed.mdx index 51734bb42..dfd68b278 100644 --- a/docs/source/usage_guides/deepspeed.mdx +++ b/docs/source/usage_guides/deepspeed.mdx @@ -31,7 +31,7 @@ You can find more information about DeepSpeed Gaudi integration [here](https://d To use DeepSpeed on Gaudi, you need to install Optimum Habana and [Habana's DeepSpeed fork](https://github.com/HabanaAI/DeepSpeed) with: ```bash pip install optimum[habana] -pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 ``` @@ -78,7 +78,7 @@ It is strongly advised to read [this section](https://huggingface.co/docs/transf -Other examples of configurations for HPUs are proposed [here](https://github.com/HabanaAI/Model-References/tree/1.15.0/PyTorch/nlp/DeepSpeedExamples/deepspeed-bert/scripts) by Habana. +Other examples of configurations for HPUs are proposed [here](https://github.com/HabanaAI/Model-References/tree/1.16.0/PyTorch/nlp/DeepSpeedExamples/deepspeed-bert/scripts) by Habana. The [Transformers documentation](https://huggingface.co/docs/transformers/main_classes/deepspeed#configuration) explains how to write a configuration from scratch very well. A more complete description of all configuration possibilities is available [here](https://www.deepspeed.ai/docs/config-json/). diff --git a/examples/audio-classification/README.md b/examples/audio-classification/README.md index 0bd77d1aa..7e91e46ea 100644 --- a/examples/audio-classification/README.md +++ b/examples/audio-classification/README.md @@ -20,6 +20,12 @@ The following examples showcase how to fine-tune `Wav2Vec2` for audio classifica Speech recognition models that have been pretrained in an unsupervised fashion on audio data alone, *e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html), have shown to require only very little annotated data to yield good performance on speech classification datasets. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` ## Single-HPU @@ -102,7 +108,7 @@ On 8 HPUs, this script should run in ~12 minutes and yield an accuracy of **80.4 > You need to install DeepSpeed with: > ```bash -> pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +> pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 > ``` DeepSpeed can be used with almost the same command as for a multi-card run: diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py index 20baedf0a..b8c1e146c 100644 --- a/examples/audio-classification/run_audio_classification.py +++ b/examples/audio-classification/run_audio_classification.py @@ -46,8 +46,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") diff --git a/examples/contrastive-image-text/README.md b/examples/contrastive-image-text/README.md index 636913a72..d19ddcaad 100644 --- a/examples/contrastive-image-text/README.md +++ b/examples/contrastive-image-text/README.md @@ -23,6 +23,13 @@ This folder contains two examples: Such models can be used for natural language image search and potentially zero-shot image classification. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Download COCO dataset (2017) This example uses COCO dataset (2017) through a custom dataset script, which requires users to manually download the COCO dataset before training. diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py index fe582fbba..dd4b7e3fb 100644 --- a/examples/contrastive-image-text/run_bridgetower.py +++ b/examples/contrastive-image-text/run_bridgetower.py @@ -56,8 +56,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py index 868c006e4..8d3b3a28a 100644 --- a/examples/contrastive-image-text/run_clip.py +++ b/examples/contrastive-image-text/run_clip.py @@ -61,8 +61,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/gaudi_spawn.py b/examples/gaudi_spawn.py index b7833c417..8896e0a14 100644 --- a/examples/gaudi_spawn.py +++ b/examples/gaudi_spawn.py @@ -84,7 +84,7 @@ def main(): if not is_deepspeed_available(): raise ImportError( "--use_deepspeed requires deepspeed: `pip install" - " git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0`." + " git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0`." ) # Patch sys.argv diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md index f7a9cdd18..642cf427d 100644 --- a/examples/image-classification/README.md +++ b/examples/image-classification/README.md @@ -19,6 +19,13 @@ limitations under the License. This directory contains a script that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit) or [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)) on HPUs. They can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) as well as on [your own custom data](#using-your-own-data). +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Single-HPU training ### Using datasets from Hub diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index 5e600e766..9f25269b4 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -63,8 +63,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 67a9857bc..7b41f870e 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -30,4 +30,29 @@ python3 run_pipeline.py \ Models that have been validated: - [nlpconnect/vit-gpt2-image-captioning](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning) - [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large) - - [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) \ No newline at end of file + - [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) + +### Running with FP8 + +Llava-1.5-7b and Llava-1.5-13b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. + +More information on enabling fp8 in SynapseAI is available here: +https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html + +Here is an example to measure the tensor quantization statistics on Llava-1.5-7b: +```bash +QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \ +--model_name_or_path llava-hf/llava-1.5-7b-hf \ +--image_path "https://llava-vl.github.io/static/images/view.jpg" \ +--use_hpu_graphs \ +--bf16 +``` + +Here is an example to quantize the model based on previous measurements for Llava-1.5-7b: +```bash +QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ +--model_name_or_path llava-hf/llava-1.5-7b-hf \ +--image_path "https://llava-vl.github.io/static/images/view.jpg" \ +--use_hpu_graphs \ +--bf16 +``` diff --git a/examples/image-to-text/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json b/examples/image-to-text/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json new file mode 100644 index 000000000..602a147ba --- /dev/null +++ b/examples/image-to-text/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} diff --git a/examples/image-to-text/quantization_config/maxabs_measure.json b/examples/image-to-text/quantization_config/maxabs_measure.json new file mode 100644 index 000000000..3645fe743 --- /dev/null +++ b/examples/image-to-text/quantization_config/maxabs_measure.json @@ -0,0 +1,9 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} \ No newline at end of file diff --git a/examples/image-to-text/quantization_config/maxabs_measure_include_outputs.json b/examples/image-to-text/quantization_config/maxabs_measure_include_outputs.json new file mode 100644 index 000000000..6de845a54 --- /dev/null +++ b/examples/image-to-text/quantization_config/maxabs_measure_include_outputs.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "measure_exclude": "NONE", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} \ No newline at end of file diff --git a/examples/image-to-text/quantization_config/maxabs_quant.json b/examples/image-to-text/quantization_config/maxabs_quant.json new file mode 100644 index 000000000..02314a728 --- /dev/null +++ b/examples/image-to-text/quantization_config/maxabs_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} \ No newline at end of file diff --git a/examples/image-to-text/quantization_config/unit_scale_quant.json b/examples/image-to-text/quantization_config/unit_scale_quant.json new file mode 100644 index 000000000..caad4bb2a --- /dev/null +++ b/examples/image-to-text/quantization_config/unit_scale_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "unit_scale", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 39d26a6b2..52df29f52 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -16,6 +16,7 @@ import argparse import json import logging +import os import time from pathlib import Path @@ -85,15 +86,31 @@ def main(): parser.add_argument("--batch_size", type=int, default=1, help="Input batch size.") parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.") parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.") + parser.add_argument( + "--ignore_eos", + action="store_true", + help="Whether to ignore eos, set False to disable it.", + ) args = parser.parse_args() + # set args.quant_config with env variable if it is set + args.quant_config = os.getenv("QUANT_CONFIG", "") + adapt_transformers_to_gaudi() model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type if args.image_path is None and model_type == "llava": args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"] + elif args.image_path is None and model_type == "llava_next": + args.image_path = [ + "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" + ] if args.prompt is None and model_type == "llava": args.prompt = "\nUSER: What's the content of the image?\nASSISTANT:" + elif args.prompt is None and model_type == "llava_next": + args.prompt = "[INST] \nWhat is shown in this image? [/INST]" + if args.model_name_or_path == "llava-hf/llava-v1.6-vicuna-13b-hf": + args.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT:" image_paths = args.image_path image_paths_len = len(image_paths) @@ -116,6 +133,11 @@ def main(): else: model_dtype = torch.float32 + if args.quant_config: + import habana_frameworks.torch.core as htcore + + htcore.hpu_set_env() + generator = pipeline( "image-to-text", model=args.model_name_or_path, @@ -126,24 +148,42 @@ def main(): "lazy_mode": True, "hpu_graphs": args.use_hpu_graphs, "max_new_tokens": args.max_new_tokens, - "ignore_eos": False, + "ignore_eos": args.ignore_eos, } if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph generator.model = wrap_in_hpu_graph(generator.model) + if args.quant_config: + import habana_quantization_toolkit + + habana_quantization_toolkit.prep_model(generator.model) + + htcore.hpu_initialize(generator.model) + # warm up for i in range(args.warmup): generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs) + torch.hpu.synchronize() + if args.quant_config: + habana_quantization_toolkit.finish_measurements(generator.model) + start = time.perf_counter() for i in range(args.n_iterations): result = generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs) end = time.perf_counter() duration = end - start - total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens + # Let's calculate the number of generated tokens + n_input_tokens = len(generator.tokenizer(args.prompt).input_ids) if args.prompt is not None else 0 + n_output_tokens = 0 + for sequence in result: + # We have to subtract the number of input tokens as they are part of the returned sequence + n_output_tokens += len(generator.tokenizer(sequence[0]["generated_text"]).input_ids) - n_input_tokens + + total_new_tokens_generated = args.n_iterations * n_output_tokens throughput = total_new_tokens_generated / duration logger.info( f"result = {result}, time = {(end-start) * 1000 / args.n_iterations }ms, Throughput (including tokenization) = {throughput} tokens/second" diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 4b6d26693..8c77f0e81 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -22,6 +22,13 @@ GPT-2 is trained or fine-tuned using a causal language modeling (CLM) loss while The following examples will run on datasets hosted on our [hub](https://huggingface.co/datasets) or with your own text files for training and validation. We give examples of both below. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## GPT2/GPT-J/GPT-NeoX and causal language modeling The following examples fine-tune GPT-2, GPT-J-6B and GPT-NeoX-20B on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before the tokenization). The loss here is the one of causal language modeling. @@ -230,6 +237,12 @@ python ../gaudi_spawn.py \ ``` +### Training in torch.compile mode +RoBERTa-Large model training in [torch.compile](pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) mode is enabled by applying the following changes to your command, +a) Set the following environment variables `PT_HPU_LAZY_MODE=0` and `PT_ENABLE_INT64_SUPPORT=1`. +b) Run the above commands with `--model_name_or_path roberta-large`, `--use_lazy_mode False` and add `--torch_compile`, `--torch_compile_backend hpu_backend` and remove `--use_hpu_graphs_for_inference` flags. + + ## Pretraining You can easily train a model from scratch by replacing `--model_name_or_path my_model_name` by `--config_name my_model_name --tokenizer_name my_model_name`. @@ -386,36 +399,6 @@ python3 run_lora_clm.py \ --validation_split_percentage 4 \ --adam_epsilon 1e-08 ``` -- Single-card finetuning of Mistral-7B-Instruct-v0.2 with fp8: -```bash -python3 run_lora_clm.py \ - --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2\ - --dataset_name tatsu-lab/alpaca \ - --fp8 True \ - --output_dir ./model_lora_mistral \ - --num_train_epochs 3 \ - --per_device_train_batch_size 8 \ - --evaluation_strategy "no" \ - --save_strategy "no" \ - --learning_rate 4e-4 \ - --warmup_ratio 0.03 \ - --lr_scheduler_type "constant" \ - --max_grad_norm 0.3 \ - --logging_steps 1 \ - --do_train \ - --use_habana \ - --use_lazy_mode \ - --throughput_warmup_steps 5 \ - --lora_rank=8 \ - --lora_target_modules "v_proj" "q_proj" \ - --lora_alpha=16 \ - --lora_dropout=0.05 \ - --dataset_concatenation \ - --max_seq_length 512 \ - --low_cpu_mem_usage True \ - --validation_split_percentage 4 \ - --adam_epsilon 1e-08 -``` - Single-card finetuning of Falcon-40B: ```bash LOWER_LIST=ops_bf16.txt python3 run_lora_clm.py \ @@ -628,7 +611,9 @@ python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \ --use_fused_rope False \ --torch_compile_backend hpu_backend \ --torch_compile \ - --gradient_accumulation_steps 2 + --gradient_accumulation_steps 2 \ + --use_flash_attention True \ + --flash_attention_causal_mask True ``` - Multi-card finetuning of Falcon-180B: @@ -731,7 +716,7 @@ python3 ../text-generation/run_generation.py \ ## Streaming -To use the streaming dataset mode which can be very useful for large datasets, add `--streaming` with `--max_steps` specified in the command line. This is currently supported by `run_mlm.py` and `run_clm.py`. +To use the streaming dataset mode which can be very useful for large datasets, add `--streaming` with `--max_steps` specified in the command line. This is supported by `run_mlm.py` and `run_clm.py`. For example: ```bash diff --git a/examples/language-modeling/requirements.txt b/examples/language-modeling/requirements.txt index 13fabb1bb..3a09ba253 100644 --- a/examples/language-modeling/requirements.txt +++ b/examples/language-modeling/requirements.txt @@ -1,5 +1,5 @@ torch >= 1.3 -datasets >= 2.4.0 +datasets >= 2.14.0 sentencepiece != 0.1.92 protobuf evaluate diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 060a48bb7..fb00e93fb 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -62,8 +62,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 4e1990636..96f1df011 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -259,6 +259,27 @@ class DataArguments: save_last_ckpt: bool = field( default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} ) + instruction_column_name: Optional[str] = field( + default=None, + metadata={ + "help": "Name of the column in the dataset that describes the task that the model should perform. By " + "default, the 'instruction' column is used for non-SQL prompts and the 'question' column is used for SQL prompts." + }, + ) + input_column_name: Optional[str] = field( + default=None, + metadata={ + "help": "Name of the column in the dataset that optionally provides context or input for the task. By " + "default, the 'input' column is used for non-SQL prompts and the 'context' column is used for SQL prompts." + }, + ) + output_column_name: Optional[str] = field( + default=None, + metadata={ + "help": "Name of the column in the dataset with the answer to the instruction. By default, the " + "'output' column is used for non-SQL prompts and the 'answer' column is used for SQL prompts." + }, + ) @dataclass @@ -365,7 +386,7 @@ def create_prompts(examples): prompts["target"] = [] for example in examples: prompt_template = ( - PROMPT_DICT["prompt_with_input"] if example["input"] != "" else PROMPT_DICT["prompt_without_input"] + PROMPT_DICT["prompt_with_input"] if example.get("input", "") != "" else PROMPT_DICT["prompt_without_input"] ) source = prompt_template.format_map(example) prompts["source"].append(source) @@ -539,19 +560,7 @@ def main(): **dataset_args, ) - if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt: - # Preprocessing the datasets. - for key in raw_datasets: - prompts = ( - create_prompts(raw_datasets[key]) - if not data_args.sql_prompt - else create_sql_prompts(raw_datasets[key]) - ) - columns_to_be_removed = list(raw_datasets[key].features.keys()) - raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"]) - raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"]) - raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed) - elif ( + if ( data_args.dataset_name == "timdettmers/openassistant-guanaco" ): # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621 raw_datasets = raw_datasets.map( @@ -565,7 +574,33 @@ def main(): [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]] ) else: - raise ValueError("Unsupported dataset") + # Preprocessing the datasets. + for key in raw_datasets: + if data_args.instruction_column_name: + raw_datasets[key] = raw_datasets[key].rename_column( + data_args.instruction_column_name, "question" if data_args.sql_prompt else "instruction" + ) + + if data_args.input_column_name: + raw_datasets[key] = raw_datasets[key].rename_column( + data_args.input_column_name, "context" if data_args.sql_prompt else "input" + ) + + if data_args.output_column_name: + raw_datasets[key] = raw_datasets[key].rename_column( + data_args.output_column_name, "answer" if data_args.sql_prompt else "output" + ) + + prompts = ( + create_prompts(raw_datasets[key]) + if not data_args.sql_prompt + else create_sql_prompts(raw_datasets[key]) + ) + columns_to_be_removed = list(raw_datasets[key].features.keys()) + raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"]) + raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"]) + raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed) + # Load model if model_args.model_name_or_path: model_dtype = torch.bfloat16 if training_args.bf16 else None @@ -669,18 +704,16 @@ def concatenate_data(dataset, max_seq_length): concatenated_dataset[column] = reshaped_data return datasets.Dataset.from_dict(concatenated_dataset) - if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt: + if data_args.dataset_name == "timdettmers/openassistant-guanaco": + tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"]) + if training_args.do_eval: + tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"]) + else: tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"]) if training_args.do_eval: tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns( ["prompt_sources", "prompt_targets"] ) - elif data_args.dataset_name == "timdettmers/openassistant-guanaco": - tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"]) - if training_args.do_eval: - tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"]) - else: - raise ValueError("Unsupported dataset") tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length) if training_args.do_eval: tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length) diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 44ad61a48..17c243276 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -61,8 +61,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/multi-node-training/EFA/Dockerfile b/examples/multi-node-training/EFA/Dockerfile index 2b97d0e54..919c015be 100644 --- a/examples/multi-node-training/EFA/Dockerfile +++ b/examples/multi-node-training/EFA/Dockerfile @@ -1,4 +1,4 @@ -FROM vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest +FROM vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest # Installs pdsh and upgrade pip RUN apt-get update && apt-get install -y pdsh && \ @@ -18,7 +18,7 @@ RUN sed -i 's/#Port 22/Port 3022/g' /etc/ssh/sshd_config && \ # Installs Optimum Habana and Habana's fork of DeepSpeed RUN pip install optimum[habana] && \ - pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 + pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 CMD ssh-keygen -t rsa -b 4096 -N '' -f ~/.ssh/id_rsa && \ chmod 600 ~/.ssh/id_rsa && \ diff --git a/examples/multi-node-training/GaudiNIC/Dockerfile b/examples/multi-node-training/GaudiNIC/Dockerfile index a35013ea4..c8d9f5ad6 100644 --- a/examples/multi-node-training/GaudiNIC/Dockerfile +++ b/examples/multi-node-training/GaudiNIC/Dockerfile @@ -1,4 +1,4 @@ -FROM vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest +FROM vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest # Installs pdsh and upgrade pip RUN apt-get update && apt-get install -y pdsh && \ @@ -12,7 +12,7 @@ RUN sed -i 's/#Port 22/Port 3022/g' /etc/ssh/sshd_config && \ # Installs Optimum Habana and Habana's fork of DeepSpeed RUN pip install optimum[habana] && \ - pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 + pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 CMD ssh-keygen -t rsa -b 4096 -N '' -f ~/.ssh/id_rsa && \ chmod 600 ~/.ssh/id_rsa && \ diff --git a/examples/multi-node-training/README.md b/examples/multi-node-training/README.md index 204d0076e..0e40e616f 100644 --- a/examples/multi-node-training/README.md +++ b/examples/multi-node-training/README.md @@ -67,6 +67,28 @@ Finally, on each system, add all hosts (including itself) to `known_hosts`. The ssh-keyscan -p 3022 -H 10.10.100.104 >> ~/.ssh/known_hosts ``` +You can check if ssh port is working with the following command: + +1. Run `lsof -i` inside docker of each node to make sure sshd is up. It should be something like below. +```bash +COMMAND PID USER FD TYPE DEVICE SIZE/OFF NODE NAME +sshd 35 root 3u IPv4 23262521 0t0 TCP *:3022 (LISTEN) +sshd 35 root 4u IPv6 23262523 0t0 TCP *:3022 (LISTEN) +``` +If no sshd, then do the following to restart sshd. +```bash +sed -i 's/#Port 22/Port 3022/g' /etc/ssh/sshd_config +sed -i 's/# Port 22/ Port 3022/g' /etc/ssh/ssh_config +sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config +service ssh restart +``` +2. Test ssh with command `ssh -p 3022 IP-address` to each other to make sure the nodes can communicate with each other. + +3. Try gaudi_spawn.py training command with world_size 8 for few steps to make sure the command works for 8 ranks on each node. + +4. Start gaudi_spawn.py with multi-nodes run on main node docker. (the node with the 1st ip address in the hostfile) + + ## Hostfile DeepSpeed requires a [hostfile](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) to know the addresses of and the number of devices to use on each node. You can specify its path with `--hostfile`. This file should look like this: diff --git a/examples/protein-folding/README.md b/examples/protein-folding/README.md index 5d2abe9a0..d5003e1e4 100644 --- a/examples/protein-folding/README.md +++ b/examples/protein-folding/README.md @@ -34,6 +34,13 @@ The predicted protein structure will be stored in save-hpu.pdb file. We can use # Mila-Intel protST example +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Single-HPU inference for zero shot evaluation Here we show how to run zero shot evaluation of protein ST model on HPU: diff --git a/examples/protein-folding/run_esmfold.py b/examples/protein-folding/run_esmfold.py index d1fbfea82..096f18055 100644 --- a/examples/protein-folding/run_esmfold.py +++ b/examples/protein-folding/run_esmfold.py @@ -40,7 +40,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.10.0") +check_optimum_habana_min_version("1.11.0") def convert_outputs_to_pdb(outputs): diff --git a/examples/question-answering/README.md b/examples/question-answering/README.md old mode 100644 new mode 100755 index 32d4917a8..d531bd9fc --- a/examples/question-answering/README.md +++ b/examples/question-answering/README.md @@ -26,6 +26,13 @@ uses special features of those tokenizers. You can check if your favorite model Note that if your dataset contains samples with no possible answers (like SQUAD version 2), you need to pass along the flag `--version_2_with_negative`. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Fine-tuning BERT on SQuAD1.1 For the following cases, an example of a Gaudi configuration file is given @@ -37,7 +44,7 @@ For the following cases, an example of a Gaudi configuration file is given This example code fine-tunes BERT on the SQuAD1.1 dataset. ```bash -python run_qa.py \ +PT_HPU_LAZY_MODE=0 python run_qa.py \ --model_name_or_path bert-large-uncased-whole-word-masking \ --gaudi_config_name Habana/bert-large-uncased-whole-word-masking \ --dataset_name squad \ @@ -51,8 +58,9 @@ python run_qa.py \ --doc_stride 128 \ --output_dir /tmp/squad/ \ --use_habana \ - --use_lazy_mode \ - --use_hpu_graphs_for_inference \ + --torch_compile_backend hpu_backend \ + --torch_compile \ + --use_lazy_mode false \ --throughput_warmup_steps 3 \ --bf16 ``` @@ -63,7 +71,7 @@ python run_qa.py \ Here is how you would fine-tune the BERT large model (with whole word masking) on the SQuAD dataset using the `run_qa` script, with 8 HPUs: ```bash -python ../gaudi_spawn.py \ +PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \ --world_size 8 --use_mpi run_qa.py \ --model_name_or_path bert-large-uncased-whole-word-masking \ --gaudi_config_name Habana/bert-large-uncased-whole-word-masking \ @@ -78,8 +86,9 @@ python ../gaudi_spawn.py \ --doc_stride 128 \ --output_dir /tmp/squad_output/ \ --use_habana \ - --use_lazy_mode \ - --use_hpu_graphs_for_inference \ + --torch_compile_backend hpu_backend \ + --torch_compile \ + --use_lazy_mode false \ --throughput_warmup_steps 3 \ --bf16 ``` @@ -133,6 +142,13 @@ Here is a DeepSpeed configuration you can use to train your models on Gaudi: ``` +### Training in torch.compile mode + +Albert XXL model training in [torch.compile](pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) mode is enabled by applying the following changes to your command, \ +a) Set the following environment variables `PT_HPU_LAZY_MODE=0` and `PT_ENABLE_INT64_SUPPORT=1`. \ +b) Run the above commands with `--model_name_or_path albert-xxlarge-v1`, `--use_lazy_mode False` and add `--torch_compile`, `--torch_compile_backend hpu_backend` and remove `--use_hpu_graphs_for_inference` flags. + + ## Fine-tuning Llama on SQuAD1.1 > [!NOTE] @@ -195,8 +211,8 @@ python run_qa.py \ | RoBERTa large | 3e-5 | 2 | 12 | 8 | | ALBERT large (single-card) | 5e-5 | 2 | 32 | 4 | | ALBERT large (multi-card) | 6e-5 | 2 | 32 | 4 | -| ALBERT XXL (single-card) | 5e-6 | 2 | 12 | 2 | -| ALBERT XXL (multi-card) | 5e-5 | 2 | 12 | 2 | +| ALBERT XXL (single-card) | 5e-6 | 2 | 16 | 2 | +| ALBERT XXL (multi-card) | 5e-5 | 2 | 16 | 2 | | DistilBERT | 5e-5 | 3 | 8 | 8 | | meta-llama/Llama-2-13b-chat-hf (multi-card) | 3e-5 | 2 | 8 | 8 | | FlagAlpha/Llama2-Chinese-13b-Chat (multi-card) | 3e-5 | 2 | 8 | 8 | diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 99cfaf67b..e58f7f42a 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -60,8 +60,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/question-answering/run_seq2seq_qa.py b/examples/question-answering/run_seq2seq_qa.py index bf89175f7..50880a1f7 100644 --- a/examples/question-answering/run_seq2seq_qa.py +++ b/examples/question-answering/run_seq2seq_qa.py @@ -56,8 +56,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 79ded7679..92e576be3 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -27,6 +27,13 @@ limitations under the License. - [Inference](#single-hpu-seq2seq-inference) +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Connectionist Temporal Classification The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/optimum-habana/tree/main/examples/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCTC) for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset. @@ -134,7 +141,7 @@ On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of ** > You need to install DeepSpeed with: > ```bash -> pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +> pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 > ``` DeepSpeed can be used with almost the same command as for a multi-card run: diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py index 31cce49fd..048da1dd5 100644 --- a/examples/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/speech-recognition/run_speech_recognition_ctc.py @@ -59,8 +59,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index 9d30c3ccb..06733f8e7 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -55,8 +55,8 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index bffee7548..5f33c6fb7 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -202,6 +202,9 @@ python text_to_image_generation.py \ > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. +> Please note: there is a regression with "--guidance_scale 0.0" for the latest release. + + ### ControlNet ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. @@ -276,3 +279,47 @@ python text_to_image_generation.py \ --use_hpu_graphs \ --gaudi_config Habana/stable-diffusion-2 ``` + +# Stable Video Diffusion Examples + +Stable Video Diffusion (SVD) was unveiled in [Stable Video Diffusion Announcement](https://stability.ai/news/stable-video-diffusion-open-ai-video-model) +by the Stability AI team. Stable Video Diffusion XT version (SVD-XT) is tuned to generate 25 frames of video from a single image. + +## Image-to-video Generation + +Script `image_to_video_generation.py` showcases how to perform image-to-video generation using Stable Video Diffusion on Intel Gaudi. + +### Single Image Prompt + +Here is how to generate video with one image prompt: +```bash +python image_to_video_generation.py \ + --model_name_or_path "stabilityai/stable-video-diffusion-img2vid-xt" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png" \ + --num_videos_per_prompt 1 \ + --video_save_dir /tmp/stable_video_diffusion_xt \ + --save_frames_as_images \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +### Multiple Image Prompts + +Here is how to generate videos with several image prompts: +```bash +python image_to_video_generation.py \ + --model_name_or_path "stabilityai/stable-video-diffusion-img2vid-xt" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png" \ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" \ + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" \ + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" \ + --num_videos_per_prompt 1 \ + --video_save_dir /tmp/stable_video_diffusion_xt \ + --save_frames_as_images \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` diff --git a/examples/stable-diffusion/image_to_video_generation.py b/examples/stable-diffusion/image_to_video_generation.py new file mode 100755 index 000000000..7beb73a1a --- /dev/null +++ b/examples/stable-diffusion/image_to_video_generation.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import sys +from pathlib import Path + +import torch +from diffusers.utils import export_to_video, load_image + +from optimum.habana.diffusers import GaudiEulerDiscreteScheduler +from optimum.habana.utils import set_seed + + +try: + from optimum.habana.utils import check_optimum_habana_min_version +except ImportError: + + def check_optimum_habana_min_version(*a, **b): + return () + + +# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. +check_optimum_habana_min_version("1.8.1") + + +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name_or_path", + default="stabilityai/stable-video-diffusion-img2vid-xt", + type=str, + help="Path to pre-trained model", + ) + + # Pipeline arguments + parser.add_argument( + "--image_path", + type=str, + default=None, + nargs="*", + help="Path to input image(s) to guide video generation", + ) + parser.add_argument( + "--num_videos_per_prompt", type=int, default=1, help="The number of videos to generate per prompt image." + ) + parser.add_argument("--batch_size", type=int, default=1, help="The number of videos in a batch.") + parser.add_argument("--height", type=int, default=576, help="The height in pixels of the generated video.") + parser.add_argument("--width", type=int, default=1024, help="The width in pixels of the generated video.") + parser.add_argument( + "--num_inference_steps", + type=int, + default=25, + help=( + "The number of denoising steps. More denoising steps usually lead to a higher quality images at the expense" + " of slower inference." + ), + ) + parser.add_argument( + "--min_guidance_scale", + type=float, + default=1.0, + help="The minimum guidance scale. Used for the classifier free guidance with first frame.", + ) + parser.add_argument( + "--max_guidance_scale", + type=float, + default=3.0, + help="The maximum guidance scale. Used for the classifier free guidance with last frame.", + ) + parser.add_argument( + "--fps", + type=int, + default=7, + help=( + "Frames per second. The rate at which the generated images shall be exported to a video after generation." + " Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training." + ), + ) + parser.add_argument( + "--motion_bucket_id", + type=int, + default=127, + help=( + "The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion" + " will be in the video." + ), + ) + parser.add_argument( + "--noise_aug_strength", + type=float, + default=0.02, + help=( + "The amount of noise added to the init image, the higher it is the less the video will look like the" + " init image. Increase it for more motion." + ), + ) + parser.add_argument( + "--decode_chunk_size", + type=int, + default=None, + help=( + "The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency" + " between frames, but also the higher the memory consumption. By default, the decoder will decode all" + " frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage." + ), + ) + parser.add_argument( + "--output_type", + type=str, + choices=["pil", "np"], + default="pil", + help="Whether to return PIL images or Numpy arrays.", + ) + parser.add_argument( + "--pipeline_save_dir", + type=str, + default=None, + help="The directory where the generation pipeline will be saved.", + ) + parser.add_argument( + "--video_save_dir", + type=str, + default="./stable-video-diffusion-generated-frames", + help="The directory where frames will be saved.", + ) + parser.add_argument( + "--save_frames_as_images", + action="store_true", + help="Save output frames as images", + ) + + parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.") + + # HPU-specific arguments + parser.add_argument("--use_habana", action="store_true", help="Use HPU.") + parser.add_argument( + "--use_hpu_graphs", action="store_true", help="Use HPU graphs on HPU. This should lead to faster generations." + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default="Habana/stable-diffusion", + help=( + "Name or path of the Gaudi configuration. In particular, it enables to specify how to apply Habana Mixed" + " Precision." + ), + ) + parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.") + + args = parser.parse_args() + + from optimum.habana.diffusers import GaudiStableVideoDiffusionPipeline + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO) + + # Initialize the scheduler and the generation pipeline + scheduler = GaudiEulerDiscreteScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler") + kwargs = { + "scheduler": scheduler, + "use_habana": args.use_habana, + "use_hpu_graphs": args.use_hpu_graphs, + "gaudi_config": args.gaudi_config_name, + } + if args.bf16: + kwargs["torch_dtype"] = torch.bfloat16 + + pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained( + args.model_name_or_path, + **kwargs, + ) + + # Set seed before running the model + set_seed(args.seed) + + # Load input image(s) + input = [] + logger.info("Input image(s):") + if isinstance(args.image_path, str): + args.image_path = [args.image_path] + for image_path in args.image_path: + image = load_image(image_path) + image = image.resize((args.height, args.width)) + input.append(image) + logger.info(image_path) + + # Generate images + outputs = pipeline( + image=input, + num_videos_per_prompt=args.num_videos_per_prompt, + batch_size=args.batch_size, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + min_guidance_scale=args.min_guidance_scale, + max_guidance_scale=args.max_guidance_scale, + fps=args.fps, + motion_bucket_id=args.motion_bucket_id, + noise_aug_strength=args.noise_aug_strength, + decode_chunk_size=args.decode_chunk_size, + output_type=args.output_type, + ) + + # Save the pipeline in the specified directory if not None + if args.pipeline_save_dir is not None: + pipeline.save_pretrained(args.pipeline_save_dir) + + # Save images in the specified directory if not None and if they are in PIL format + if args.video_save_dir is not None: + if args.output_type == "pil": + video_save_dir = Path(args.video_save_dir) + video_save_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving video frames in {video_save_dir.resolve()}...") + for i, frames in enumerate(outputs.frames): + export_to_video(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".mp4", fps=7) + if args.save_frames_as_images: + for j, frame in enumerate(frames): + frame.save( + args.video_save_dir + + "/gen_video_" + + str(i).zfill(2) + + "_frame_" + + str(j).zfill(2) + + ".png" + ) + else: + logger.warning("--output_type should be equal to 'pil' to save frames in --video_save_dir.") + + +if __name__ == "__main__": + main() diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index bb4cdfb9a..e52fc5fcb 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -38,7 +38,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.10.0") +check_optimum_habana_min_version("1.11.0") logger = logging.getLogger(__name__) @@ -220,6 +220,24 @@ def main(): default=0, help="Number of steps to capture for profiling.", ) + parser.add_argument( + "--unet_adapter_name_or_path", + default=None, + type=str, + help="Path to pre-trained model", + ) + parser.add_argument( + "--text_encoder_adapter_name_or_path", + default=None, + type=str, + help="Path to pre-trained model", + ) + parser.add_argument( + "--lora_id", + default=None, + type=str, + help="Path to lora id", + ) args = parser.parse_args() # Set image resolution @@ -311,6 +329,8 @@ def main(): controlnet=controlnet, **kwargs, ) + if args.lora_id: + pipeline.load_lora_weights(args.lora_id) # Set seed before running the model set_seed(args.seed) @@ -334,6 +354,8 @@ def main(): args.model_name_or_path, **kwargs, ) + if args.lora_id: + pipeline.load_lora_weights(args.lora_id) # Set seed before running the model set_seed(args.seed) @@ -358,8 +380,18 @@ def main(): args.model_name_or_path, **kwargs, ) - - # Set seed before running the model + if args.unet_adapter_name_or_path is not None: + from peft import PeftModel + + pipeline.unet = PeftModel.from_pretrained(pipeline.unet, args.unet_adapter_name_or_path) + pipeline.unet = pipeline.unet.merge_and_unload() + if args.text_encoder_adapter_name_or_path is not None: + from peft import PeftModel + + pipeline.text_encoder = PeftModel.from_pretrained( + pipeline.text_encoder, args.text_encoder_adapter_name_or_path + ) + pipeline.text_encoder = pipeline.text_encoder.merge_and_unload() set_seed(args.seed) outputs = pipeline( diff --git a/examples/stable-diffusion/training/README.md b/examples/stable-diffusion/training/README.md index a81e4b0ec..d686b30f4 100644 --- a/examples/stable-diffusion/training/README.md +++ b/examples/stable-diffusion/training/README.md @@ -168,8 +168,8 @@ pip install -r requirements.txt ```bash python train_text_to_image_sdxl.py \ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ - --pretrained_vae_model_name_or_path stabilityai/sdxl-vae \ - --dataset_name lambdalabs/pokemon-blip-captions \ + --pretrained_vae_model_name_or_path madebyollin/sdxl-vae-fp16-fix \ + --dataset_name lambdalabs/naruto-blip-captions \ --resolution 512 \ --crop_resolution 512 \ --center_crop \ @@ -181,14 +181,14 @@ python train_text_to_image_sdxl.py \ --max_grad_norm 1 \ --lr_scheduler constant \ --lr_warmup_steps 0 \ - --output_dir sdxl-pokemon-model \ + --output_dir sdxl_model_output \ --gaudi_config_name Habana/stable-diffusion \ --throughput_warmup_steps 3 \ --dataloader_num_workers 8 \ --bf16 \ --use_hpu_graphs_for_training \ --use_hpu_graphs_for_inference \ - --validation_prompt="a robotic cat with wings" \ + --validation_prompt="a cute naruto creature" \ --validation_epochs 48 \ --checkpointing_steps 2500 \ --logging_step 10 \ @@ -201,8 +201,8 @@ python train_text_to_image_sdxl.py \ PT_HPU_RECIPE_CACHE_CONFIG=/tmp/stdxl_recipe_cache,True,1024 \ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py \ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ - --pretrained_vae_model_name_or_path stabilityai/sdxl-vae \ - --dataset_name lambdalabs/pokemon-blip-captions \ + --pretrained_vae_model_name_or_path madebyollin/sdxl-vae-fp16-fix \ + --dataset_name lambdalabs/naruto-blip-captions \ --resolution 512 \ --crop_resolution 512 \ --center_crop \ @@ -214,27 +214,27 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py --max_grad_norm 1 \ --lr_scheduler constant \ --lr_warmup_steps 0 \ - --output_dir sdxl-pokemon-model \ + --output_dir sdxl_model_output \ --gaudi_config_name Habana/stable-diffusion \ --throughput_warmup_steps 3 \ --dataloader_num_workers 8 \ --bf16 \ --use_hpu_graphs_for_training \ --use_hpu_graphs_for_inference \ - --validation_prompt="a robotic cat with wings" \ + --validation_prompt="a cute naruto creature" \ --validation_epochs 48 \ --checkpointing_steps 336 \ - --mediapipe dataset_sdxl_pokemon \ + --mediapipe dataset_sdxl_mediapipe \ --adjust_throughput ``` ### Single-card Training on Gaudi1 ```bash -PT_HPU_MAX_COMPOUND_OP_SIZE=5 python train_text_to_image_sdxl.py \ +python train_text_to_image_sdxl.py \ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ - --pretrained_vae_model_name_or_path stabilityai/sdxl-vae \ - --dataset_name lambdalabs/pokemon-blip-captions \ - --resolution 512 \ + --pretrained_vae_model_name_or_path madebyollin/sdxl-vae-fp16-fix \ + --dataset_name lambdalabs/naruto-blip-captions \ + --resolution 256 \ --center_crop \ --random_flip \ --proportion_empty_prompts=0.2 \ @@ -245,11 +245,12 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=5 python train_text_to_image_sdxl.py \ --max_grad_norm 1 \ --lr_scheduler constant \ --lr_warmup_steps 0 \ - --output_dir sdxl-pokemon-model \ + --output_dir sdxl_model_output \ --gaudi_config_name Habana/stable-diffusion \ --throughput_warmup_steps 3 \ --use_hpu_graphs_for_training \ --use_hpu_graphs_for_inference \ + --checkpointing_steps 3000 \ --bf16 ``` @@ -258,3 +259,172 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=5 python train_text_to_image_sdxl.py \ > [!NOTE] > `--mediapipe` only works on Gaudi2. + + +## DreamBooth +DreamBooth is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject. The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +### Full model finetune +And launch the multi-card training using: +```bash + +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export INSTANCE_DIR="dog" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="out" + +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=$CLASS_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --num_class_images=200 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=800 \ + --mixed_precision=bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference \ + --gaudi_config_name Habana/stable-diffusion \ + full + +``` +Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. + +### PEFT model finetune +We provide example for dreambooth to use lora/lokr/loha/oft to finetune unet or text encoder. + +**___Note: When using peft method we can use a much higher learning rate compared to vanilla dreambooth. Here we +use *1e-4* instead of the usual *5e-6*.___** + +Launch the multi-card training using: +```bash + +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export INSTANCE_DIR="dog" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="out" + +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=$CLASS_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --num_class_images=200 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=800 \ + --mixed_precision=bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference \ + --gaudi_config_name Habana/stable-diffusion \ + lora --unet_r 8 --unet_alpha 8 + +``` +Similar command could be applied to loha, lokr, oft. +You could check each adapter specific args by "--help", like you could use following command to check oft specific args. + +```bash +python3 train_dreambooth.py oft --help + +``` + +**___Note: oft could not work with hpu graphs mode. since "torch.inverse" need to fallback to cpu. +there's error like "cpu fallback is not supported during hpu graph capturing"___** + + +You could use text_to_image_generation.py to generate picture using the peft adapter like + +```bash +python ../text_to_image_generation.py \ + --model_name_or_path runwayml/stable-diffusion-v1-5 \ + --prompts "a sks dog" \ + --num_images_per_prompt 5 \ + --batch_size 1 \ + --image_save_dir /tmp/stable_diffusion_images \ + --use_habana \ + --use_hpu_graphs \ + --unet_adapter_name_or_path out/unet \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +### DreamBooth training example for Stable Diffusion XL +You could use the dog images as example as well. +You can launch training using: +```bash +export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="lora-trained-xl" +export VAE_PATH="madebyollin/sdxl-vae-fp16-fix" + +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --pretrained_vae_model_name_or_path=$VAE_PATH \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed=0 \ + --use_hpu_graphs_for_inference \ + --use_hpu_graphs_for_training \ + --gaudi_config_name Habana/stable-diffusion + +``` + +You could use text_to_image_generation.py to generate picture using the peft adapter like + +```bash +python ../text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --prompts "A picture of a sks dog in a bucket" \ + --num_images_per_prompt 5 \ + --batch_size 1 \ + --image_save_dir /tmp/stable_diffusion_xl_images \ + --use_habana \ + --use_hpu_graphs \ + --lora_id lora-trained-xl \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` diff --git a/examples/stable-diffusion/training/requirements.txt b/examples/stable-diffusion/training/requirements.txt index acdd70a4a..7fb174867 100644 --- a/examples/stable-diffusion/training/requirements.txt +++ b/examples/stable-diffusion/training/requirements.txt @@ -1 +1,2 @@ imagesize +peft == 0.10.0 diff --git a/examples/stable-diffusion/training/train_dreambooth.py b/examples/stable-diffusion/training/train_dreambooth.py new file mode 100644 index 000000000..b34f3c12c --- /dev/null +++ b/examples/stable-diffusion/training/train_dreambooth.py @@ -0,0 +1,1357 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +""" +Training script for DreamBooth to Text-to-Image Diffusion Models +Adapted from the following source: +https://github.com/huggingface/peft/blob/608a90ded9985ee1c5912d738082bb1fd618902b/examples/stable_diffusion/train_dreambooth.py +""" + +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import threading +import warnings +from pathlib import Path +from typing import Union + +import datasets +import diffusers +import habana_frameworks.torch.core as htcore +import numpy as np +import psutil +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from habana_frameworks.torch.hpu import memory_stats +from huggingface_hub import HfApi +from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType +from optimum.habana.diffusers import GaudiStableDiffusionPipeline +from optimum.habana.transformers.trainer import _is_peft_model +from optimum.habana.utils import set_seed + + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = [ + "to_q", + "to_k", + "to_v", + "proj", + "proj_in", + "proj_out", + "conv", + "conv1", + "conv2", + "conv_shortcut", + "to_out.0", + "time_emb_proj", + "ff.net.2", +] + +TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]: + if args.adapter == "full": + raise ValueError("Cannot create unet adapter config for full parameter") + + if args.adapter == "lora": + config = LoraConfig( + r=args.unet_r, + lora_alpha=args.unet_alpha, + target_modules=UNET_TARGET_MODULES, + lora_dropout=args.unet_dropout, + bias=args.unet_bias, + init_lora_weights=True, + ) + elif args.adapter == "loha": + config = LoHaConfig( + r=args.unet_r, + alpha=args.unet_alpha, + target_modules=UNET_TARGET_MODULES, + rank_dropout=args.unet_rank_dropout, + module_dropout=args.unet_module_dropout, + use_effective_conv2d=args.unet_use_effective_conv2d, + init_weights=True, + ) + elif args.adapter == "lokr": + config = LoKrConfig( + r=args.unet_r, + alpha=args.unet_alpha, + target_modules=UNET_TARGET_MODULES, + rank_dropout=args.unet_rank_dropout, + module_dropout=args.unet_module_dropout, + use_effective_conv2d=args.unet_use_effective_conv2d, + decompose_both=args.unet_decompose_both, + decompose_factor=args.unet_decompose_factor, + init_weights=True, + ) + elif args.adapter == "oft": + config = OFTConfig( + r=args.unet_r, + target_modules=UNET_TARGET_MODULES, + module_dropout=args.unet_dropout, + init_weights=True, + coft=args.unet_use_coft, + eps=args.unet_eps, + ) + else: + raise ValueError(f"Unknown adapter type {args.adapter}") + + return config + + +def create_text_encoder_adapter_config( + args: argparse.Namespace, +) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]: + if args.adapter == "full": + raise ValueError("Cannot create text_encoder adapter config for full parameter") + + if args.adapter == "lora": + config = LoraConfig( + r=args.te_r, + lora_alpha=args.te_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + lora_dropout=args.te_dropout, + bias=args.te_bias, + init_lora_weights=True, + ) + elif args.adapter == "loha": + config = LoHaConfig( + r=args.te_r, + alpha=args.te_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + rank_dropout=args.te_rank_dropout, + module_dropout=args.te_module_dropout, + init_weights=True, + ) + elif args.adapter == "lokr": + config = LoKrConfig( + r=args.te_r, + alpha=args.te_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + rank_dropout=args.te_rank_dropout, + module_dropout=args.te_module_dropout, + decompose_both=args.te_decompose_both, + decompose_factor=args.te_decompose_factor, + init_weights=True, + ) + elif args.adapter == "oft": + config = OFTConfig( + r=args.te_r, + target_modules=TEXT_ENCODER_TARGET_MODULES, + module_dropout=args.te_dropout, + init_weights=True, + coft=args.te_use_coft, + eps=args.te_eps, + ) + else: + raise ValueError(f"Unknown adapter type {args.adapter}") + + return config + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--wandb_key", + type=str, + default=None, + help=("If report to option is set to wandb, api-key for wandb used for login to wandb "), + ) + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "bf16"], + help=( + "Whether to use mixed precision. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "bf16"], + help=("Choose prior generation precision between fp32 and bf16 (bfloat16)."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--throughput_warmup_steps", + type=int, + default=0, + help=( + "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the" + " first N steps will not be considered in the calculation of the throughput. This is especially useful in" + " lazy mode." + ), + ) + parser.add_argument( + "--use_hpu_graphs_for_training", + action="store_true", + help="Use HPU graphs for training on HPU.", + ) + parser.add_argument( + "--use_hpu_graphs_for_inference", + action="store_true", + help="Use HPU graphs for inference on HPU.", + ) + + # Adapter arguments + subparsers = parser.add_subparsers(dest="adapter") + + # Dummy subparser to train whole model + subparsers.add_parser("full", help="Train full model without adapters") + + # LoRA adapter + lora = subparsers.add_parser("lora", help="Use LoRA adapter") + lora.add_argument("--unet_r", type=int, default=8, help="LoRA rank for unet") + lora.add_argument("--unet_alpha", type=int, default=8, help="LoRA alpha for unet") + lora.add_argument("--unet_dropout", type=float, default=0.0, help="LoRA dropout probability for unet") + lora.add_argument( + "--unet_bias", + type=str, + default="none", + help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only'", + ) + lora.add_argument( + "--te_r", type=int, default=8, help="LoRA rank for text_encoder, only used if `train_text_encoder` is True" + ) + lora.add_argument( + "--te_alpha", + type=int, + default=8, + help="LoRA alpha for text_encoder, only used if `train_text_encoder` is True", + ) + lora.add_argument( + "--te_dropout", + type=float, + default=0.0, + help="LoRA dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + lora.add_argument( + "--te_bias", + type=str, + default="none", + help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only', only used if `train_text_encoder` is True", + ) + + # LoHa adapter + loha = subparsers.add_parser("loha", help="Use LoHa adapter") + loha.add_argument("--unet_r", type=int, default=8, help="LoHa rank for unet") + loha.add_argument("--unet_alpha", type=int, default=8, help="LoHa alpha for unet") + loha.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoHa rank_dropout probability for unet") + loha.add_argument( + "--unet_module_dropout", type=float, default=0.0, help="LoHa module_dropout probability for unet" + ) + loha.add_argument( + "--unet_use_effective_conv2d", + action="store_true", + help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1", + ) + loha.add_argument( + "--te_r", type=int, default=8, help="LoHa rank for text_encoder, only used if `train_text_encoder` is True" + ) + loha.add_argument( + "--te_alpha", + type=int, + default=8, + help="LoHa alpha for text_encoder, only used if `train_text_encoder` is True", + ) + loha.add_argument( + "--te_rank_dropout", + type=float, + default=0.0, + help="LoHa rank_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + loha.add_argument( + "--te_module_dropout", + type=float, + default=0.0, + help="LoHa module_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + + # LoKr adapter + lokr = subparsers.add_parser("lokr", help="Use LoKr adapter") + lokr.add_argument("--unet_r", type=int, default=8, help="LoKr rank for unet") + lokr.add_argument("--unet_alpha", type=int, default=8, help="LoKr alpha for unet") + lokr.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoKr rank_dropout probability for unet") + lokr.add_argument( + "--unet_module_dropout", type=float, default=0.0, help="LoKr module_dropout probability for unet" + ) + lokr.add_argument( + "--unet_use_effective_conv2d", + action="store_true", + help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1", + ) + lokr.add_argument( + "--unet_decompose_both", action="store_true", help="Decompose left matrix in kronecker product for unet" + ) + lokr.add_argument( + "--unet_decompose_factor", type=int, default=-1, help="Decompose factor in kronecker product for unet" + ) + lokr.add_argument( + "--te_r", type=int, default=8, help="LoKr rank for text_encoder, only used if `train_text_encoder` is True" + ) + lokr.add_argument( + "--te_alpha", + type=int, + default=8, + help="LoKr alpha for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_rank_dropout", + type=float, + default=0.0, + help="LoKr rank_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_module_dropout", + type=float, + default=0.0, + help="LoKr module_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_decompose_both", + action="store_true", + help="Decompose left matrix in kronecker product for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_decompose_factor", + type=int, + default=-1, + help="Decompose factor in kronecker product for text_encoder, only used if `train_text_encoder` is True", + ) + # oft adapter + oft = subparsers.add_parser("oft", help="Use Oft adapter") + oft.add_argument("--unet_r", type=int, default=8, help="Oft rank for unet") + oft.add_argument("--unet_dropout", type=float, default=0.0, help="Oft dropout probability for unet") + oft.add_argument("--unet_use_coft", action="store_true", help="Using constrained OFT in unet") + oft.add_argument("--unet_eps", type=float, default=0.0, help="The control strength of COFT for unet") + oft.add_argument( + "--te_r", type=int, default=8, help="Oft rank for text_encoder, only used if `train_text_encoder` is True" + ) + oft.add_argument( + "--te_dropout", + type=float, + default=0.0, + help="Oft dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + oft.add_argument( + "--te_use_coft", + action="store_true", + help="Using constrained OFT in text_encoder, only used if `train_text_encoder` is True", + ) + oft.add_argument( + "--te_eps", + type=float, + default=0.0, + help="The control strength of COFT for text_encoder, only used if `train_text_encoder` is True", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + mem_stats = memory_stats() + + self.begin = mem_stats["InUse"] + self.process = psutil.Process() + + self.cpu_begin = self.cpu_mem_used() + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + mem_stats = memory_stats() + + self.end = mem_stats["InUse"] + self.peak = mem_stats["MaxInUse"] + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + gaudi_config.use_torch_autocast = gaudi_config.use_torch_autocast or args.mixed_precision == "bf16" + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + force_autocast=gaudi_config.use_torch_autocast, + ) + if args.report_to == "wandb": + import wandb + + wandb.login(key=args.wandb_key) + wandb.init(project=args.wandb_project_name) + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.bfloat16 if accelerator.device.type == "hpu" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + api = HfApi(token=args.hub_token) + # Create repo (repo_name from args or inferred) + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + repo_id = api.create_repo(repo_name, exist_ok=True).repo_id + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + ) # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + if args.adapter != "full": + config = create_unet_adapter_config(args) + unet = get_peft_model(unet, config) + unet.print_trainable_parameters() + unet.to(accelerator.device) + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + elif args.train_text_encoder and args.adapter != "full": + config = create_text_encoder_adapter_config(args) + text_encoder = get_peft_model(text_encoder, config) + text_encoder.print_trainable_parameters() + text_encoder.to(accelerator.device) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder and not args.adapter != "full": + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + if gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_class = FusedAdamW + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=1, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + def unwrap_model(model, training=False): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + if not training: + return model + else: + if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU: + kwargs = {} + kwargs["gradient_as_bucket_view"] = True + accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + if args.use_hpu_graphs_for_training: + if _is_peft_model(model): + base_model = model.get_base_model() + htcore.hpu.ModuleCacher()(model=base_model, inplace=True) + else: + htcore.hpu.ModuleCacher()(model=model, inplace=True) + return model + + unwrap_model(model=unet, training=True) + if args.train_text_encoder: + unwrap_model(model=text_encoder, training=True) + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + with TorchTracemalloc() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + htcore.mark_step() + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + htcore.mark_step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + # if global_step % args.checkpointing_steps == 0: + # if accelerator.is_main_process: + # save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + # accelerator.save_state(save_path) + # logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if ( + args.validation_prompt is not None + and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0 + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + # set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # Set evaliation mode + pipeline.unet.eval() + pipeline.text_encoder.eval() + + # run inference + if args.seed is not None: + if accelerator.device == torch.device("hpu"): + # torch.Generator() is unsupported on HPU + generator = set_seed(args.seed) + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + images = [] + for _ in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + # Set evaliation mode + pipeline.unet.train() + if args.train_text_encoder: + pipeline.text_encoder.train() + + del pipeline + + if global_step >= args.max_train_steps: + break + # Printing the HPU memory usage details such as allocated memory, peak memory, and total memory usage + accelerator.print(f"HPU Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"HPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}") + accelerator.print(f"HPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}") + accelerator.print( + f"HPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}" + ) + + accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}") + accelerator.print(f"CPU Memory consumed at the end of the train (end-begin): {tracemalloc.cpu_used}") + accelerator.print(f"CPU Peak Memory consumed during the train (max-begin): {tracemalloc.cpu_peaked}") + accelerator.print( + f"CPU Total Peak Memory consumed during the train (max): {tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}" + ) + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if args.adapter != "full": + unwarpped_unet = unwrap_model(unet) + unwarpped_unet.save_pretrained( + os.path.join(args.output_dir, "unet"), state_dict=accelerator.get_state_dict(unet) + ) + if args.train_text_encoder: + unwarpped_text_encoder = unwrap_model(text_encoder) + unwarpped_text_encoder.save_pretrained( + os.path.join(args.output_dir, "text_encoder"), + state_dict=accelerator.get_state_dict(text_encoder), + ) + else: + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unwrap_model(unet), + text_encoder=unwrap_model(text_encoder), + revision=args.revision, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + api.upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + run_as_future=True, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py new file mode 100644 index 000000000..ea34c5077 --- /dev/null +++ b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py @@ -0,0 +1,1768 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +""" +Training script for LORA DreamBooth to Text-to-Image Diffusion Models +Adapted from the following source: +https://github.com/huggingface/diffusers/blob/v0.26.3/examples/dreambooth/train_dreambooth_lora_sdxl.py +""" + +import argparse +import gc +import itertools +import logging +import math +import os +import shutil +import warnings +from pathlib import Path + +import diffusers +import habana_frameworks.torch.core as htcore +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from packaging import version +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType +from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline +from optimum.habana.transformers.trainer import _is_peft_model +from optimum.habana.utils import set_seed + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.26.0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + instance_prompt=str, + validation_prompt=str, + repo_folder=None, + vae_path=None, +): + img_str = "widget:\n" if images else "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f""" + - text: '{validation_prompt if validation_prompt else ' ' }' + output: + url: + "image_{i}.png" + """ + + yaml = f""" +--- +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- lora +- template:sd-lora +{img_str} +base_model: {base_model} +instance_prompt: {instance_prompt} +license: openrail++ +--- + """ + + model_card = f""" +# SDXL LoRA DreamBooth - {repo_id} + + + +## Model description + +These are {repo_id} LoRA adaption weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/). + +LoRA for the text encoder was enabled: {train_text_encoder}. + +Special VAE used for training: {vae_path}. + +## Trigger words + +You should use {instance_prompt} to trigger the image generation. + +## Download model + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. + +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "bf16"], + help=( + "Whether to use mixed precision. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "bf16"], + help=("Choose prior generation precision between fp32 and bf16 (bfloat16)."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--use_hpu_graphs_for_training", + action="store_true", + help="Use HPU graphs for training on HPU.", + ) + parser.add_argument( + "--use_hpu_graphs_for_inference", + action="store_true", + help="Use HPU graphs for inference on HPU.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.instance_images[index % self.num_instance_images] + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # costum prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-1][-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + logging_dir = Path(args.output_dir, args.logging_dir) + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + gaudi_config.use_torch_autocast = gaudi_config.use_torch_autocast or args.mixed_precision == "bf16" + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + force_autocast=gaudi_config.use_torch_autocast, + ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.bfloat16 if accelerator.device.type == "hpu" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter LoRA layers + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + + # The VAE is always in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " + "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + unet.add_adapter(unet_lora_config) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) + + def unwrap_model(model, training=False): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + if not training: + return model + else: + if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU: + kwargs = {} + kwargs["gradient_as_bucket_view"] = True + accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + if args.use_hpu_graphs_for_training: + if _is_peft_model(model): + base_model = model.get_base_model() + htcore.hpu.ModuleCacher()(model=base_model, inplace=True) + else: + htcore.hpu.ModuleCacher()(model=model, inplace=True) + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(unet))): + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + GaudiStableDiffusionXLPipeline.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_one_ = None + text_encoder_two_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + + # Optimization parameters + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_lora_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + unet_lora_parameters_with_lr, + text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr, + ] + else: + params_to_optimize = [unet_lora_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warn( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warn( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + elif gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_class = FusedAdamW + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Computes additional embeddings/ids required by the SDXL UNet. + # regular text embeddings (when `train_text_encoder` is not True) + # pooled text embeddings + # time ids + + def compute_time_ids(): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds + + # Handle instance prompt. + instance_time_ids = compute_time_ids() + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + class_time_ids = compute_time_ids() + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + gc.collect() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + add_time_ids = instance_time_ids + if args.with_prior_preservation: + add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) + + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args)) + + unwrap_model(model=unet, training=True) + if args.train_text_encoder: + unwrap_model(model=text_encoder_one, training=True) + unwrap_model(model=text_encoder_two, training=True) + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder_one.train() + text_encoder_two.train() + + # set top parameter requires_grad = True for gradient checkpointing works + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds, unet_add_text_embeds = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts) + tokens_two = tokenize_prompt(tokenizer_two, prompts) + + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. + if not train_dataset.custom_instance_prompts: + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + else: + elems_to_repeat_text_embeds = 1 + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + + # Predict the noise residual + if not args.train_text_encoder: + unet_added_conditions = { + "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), + "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), + } + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] + else: + unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one, tokens_two], + ) + unet_added_conditions.update( + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} + ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + htcore.mark_step() + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if args.train_text_encoder + else unet_lora_parameters + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + htcore.mark_step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.text_encoder.eval() + pipeline.text_encoder_2.eval() + pipeline.unet.eval() + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + if args.seed is not None: + if accelerator.device == torch.device("hpu"): + # torch.Generator() is unsupported on HPU + generator = set_seed(args.seed) + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + pipeline_args = {"prompt": args.validation_prompt} + + images = [ + pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + pipeline.unet.train() + if args.train_text_encoder: + pipeline.text_encoder.train() + pipeline.text_encoder_2.train() + del pipeline + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + ) + text_encoder_two = unwrap_model(text_encoder_two) + text_encoder_2_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_two.to(torch.float32)) + ) + else: + text_encoder_lora_layers = None + text_encoder_2_lora_layers = None + + GaudiStableDiffusionXLPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + text_encoder_2_lora_layers=text_encoder_2_lora_layers, + ) + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py index 9e7503867..9a7193db2 100644 --- a/examples/stable-diffusion/training/train_text_to_image_sdxl.py +++ b/examples/stable-diffusion/training/train_text_to_image_sdxl.py @@ -934,21 +934,16 @@ def preprocess_train(examples): for image in images: original_sizes.append((image.height, image.width)) image = train_resize(image) - if args.crop_resolution < args.resolution: - if args.center_crop: - y1 = max(0, int(round((image.height - args.resolution) / 2.0))) - x1 = max(0, int(round((image.width - args.resolution) / 2.0))) - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) - image = crop(image, y1, x1, h, w) - else: - x1 = 0 - y1 = 0 if args.random_flip and random.random() < 0.5: # flip - x1 = image.width - x1 image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) crop_top_left = (y1, x1) crop_top_lefts.append(crop_top_left) image = train_transforms(image) diff --git a/examples/summarization/README.md b/examples/summarization/README.md index 7b63d96c6..e1fc98c4d 100644 --- a/examples/summarization/README.md +++ b/examples/summarization/README.md @@ -23,6 +23,13 @@ This directory contains examples for finetuning and evaluating transformers on s For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets#json-files. You will also find examples of these below. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Single-card Training Here is an example of a summarization task with T5: diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py old mode 100644 new mode 100755 index 0574a0f0b..db7a4913c --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -65,8 +65,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") @@ -764,6 +764,9 @@ def compute_metrics(eval_preds): else: training_args.generation_config.max_length = data_args.val_max_target_length if data_args.num_beams is not None: + if data_args.num_beams == 1: + training_args.generation_config.length_penalty = None + training_args.generation_config.early_stopping = False training_args.generation_config.num_beams = data_args.num_beams elif training_args.generation_num_beams is not None: training_args.generation_config.num_beams = training_args.generation_num_beams diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index 3a354ecd6..f5af6bc7d 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -27,6 +27,12 @@ and can also be used for a dataset hosted on our [hub](https://huggingface.co/da GLUE is made up of a total of 9 different tasks where the task name can be cola, sst2, mrpc, stsb, qqp, mnli, qnli, rte or wnli. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` ## Fine-tuning BERT on MRPC diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index abe8b33e8..155d1dd65 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -57,8 +57,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md old mode 100644 new mode 100755 index f1c8f0bdd..e020e72a7 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -28,7 +28,7 @@ pip install -r requirements.txt Then, if you plan to use [DeepSpeed-inference](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/Inference_Using_DeepSpeed.html) (e.g. to use BLOOM/BLOOMZ), you should install DeepSpeed as follows: ```bash -pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 ``` @@ -91,6 +91,22 @@ python run_generation.py \ > The batch size should be larger than or equal to the number of prompts. Otherwise, only the first N prompts are kept with N being equal to the batch size. +### Run Speculative Sampling on Gaudi + +If you want to generate a sequence of text from a prompt of your choice using assisted decoding, you can use the following command as an example: + +``` +python run_generation.py \ +--model_name_or_path gpt2 \ +--assistant_model distilgpt2 \ +--batch_size 1 \ +--max_new_tokens 100 \ +--use_hpu_graphs \ +--use_kv_cache \ +--num_return_sequences 1 \ +--temperature 0 \ +--prompt "Alice and Bob" +``` ### Benchmark @@ -107,7 +123,6 @@ Here are a few settings you may be interested in: - `--prompt` to benchmark the model on one or several prompts of your choice - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it -- `--fp8` Enable Quantization to fp8 For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command: ```bash @@ -155,7 +170,9 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ --use_hpu_graphs \ --use_kv_cache \ --batch_size 1 \ ---do_sample +--do_sample \ +--use_flash_attention \ +--flash_attention_causal_mask ``` > To be able to run gated models like [StarCoder](https://huggingface.co/bigcode/starcoder), you should: @@ -175,7 +192,6 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ > --bf16 > ``` - ### Use any dataset from the Hugging Face Hub You can also provide the name of a dataset from the Hugging Face Hub to perform generation on it with the argument `--dataset_name`. @@ -266,7 +282,10 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python ../gaudi_spawn.py --use_hpu_graphs \ --trim_logits \ --use_kv_cache \ ---reuse_cache \ +--bucket_size=128 \ +--bucket_internal \ +--use_flash_attention \ +--flash_attention_recompute \ --bf16 \ --batch_size 1 ``` @@ -281,10 +300,12 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ --use_hpu_graphs \ --trim_logits \ --use_kv_cache \ ---reuse_cache \ +--bucket_size=128 \ +--bucket_internal \ +--use_flash_attention \ +--flash_attention_recompute \ --bf16 \ ---batch_size 1 \ ---fp8 +--batch_size 1 ``` Alternatively, here is another example to quantize the model based on previous measurements for LLama2-70b: @@ -297,12 +318,13 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ --trim_logits \ --use_kv_cache \ --reuse_cache \ +--use_flash_attention \ +--flash_attention_recompute \ --bf16 \ ---batch_size 277 \ +--batch_size 350 \ --max_new_tokens 2048 \ --max_input_tokens 2048 \ ---limit_hpu_graphs \ ---fp8 +--limit_hpu_graphs ``` Here is an example to measure the tensor quantization statistics on Mixtral-8x7B with 1 card: @@ -328,8 +350,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati --bucket_size 128 \ --max_new_tokens 2048 \ --batch_size 16 \ ---bf16 \ ---fp8 +--bf16 ``` Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: @@ -344,7 +365,10 @@ QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python .. --trim_logits \ --batch_size 1 \ --bf16 \ ---reuse_cache +--reuse_cache \ +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask ``` Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards: @@ -361,7 +385,9 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ --bf16 \ --reuse_cache \ --trim_logits \ ---fp8 +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask ``` Here is an example to measure the tensor quantization statistics on phi-2 with 1 card: @@ -389,12 +415,56 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_phi.json python run_generation.p --batch_size 1 \ --bf16 \ --trim_logits \ ---reuse_cache \ ---fp8 +--reuse_cache ``` -`--fp8` is required to enable quantization in fp8. +### Running FP8 models on single device + +Some bf16 models don't fit on one card due to hpu memory limitation, but in fp8 precision they do fit. +As measurement is being calculated in bf16 precision, to be able to run fp8 model on single card you should use `unify_measurements` script. +Here are the steps: +1. Measure the model on a number of cards that are enough for the model to fit in BF16. +2. Quantize the model on the same amount of cards for scales to be saved. +3. Run unify_measurements.py script using the measurement files created after running steps 1 and 2. A unified measurement is then calculated. +```bash +python quantization_tools/unify_measurements.py -g 01234567 -m *path_to_8x_measurements* -o *path_to_output_1x_measurement* +``` +In the above example, the measurements of cards 0-7 will be unified to a single measurement. For example, if you specify `-g 0123 4567`, +cards 0-3 and cards 4-7 will be unified in two different measurement files. All different group combinations are supported. +4. Run quantization using the unified measurement file/s. + +More information on usage of the unifier script can be found in fp8 Habana docs: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html + + + +### CPU memory reduction on single card + +Some models can fit on HPU DRAM but can't fit on the CPU RAM. +When we run a model on single card and don't use deepspeed, the `--disk_offload` flag allows to offload weights to disk during model quantization in HQT. When this flag is mentioned, during the quantization process, each weight first is loaded from disk to CPU RAM, when brought to HPU DRAM and quantized there. This way not all the model is on the CPU RAM but only one weight each time. +To enable this weights offload mechanism, add `--disk_offload` flag to the topology command line. +Here is an example of using disk_offload in quantize command. +Please follow the "Running FP8 models on single device" section first before running the cmd below. + +```bash +QUANT_CONFIG=./quantization_config/maxabs_quant.json TQDM_DISABLE=1 \ +python run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--attn_softmax_bf16 \ +--use_hpu_graphs \ +--trim_logits \ +--use_kv_cache \ +--limit_hpu_graphs \ +--bucket_size=128 \ +--bucket_internal \ +--max_new_tokens 2048 \ +--max_input_tokens 2048 \ +--bf16 \ +--batch_size 1 \ +--disk_offload \ +--use_flash_attention \ +--flash_attention_recompute +``` ### Using Habana Flash Attention @@ -406,13 +476,16 @@ Below example uses `flash_attention_recompute` mode in order to reduce memory co python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ --model_name_or_path meta-llama/Llama-2-70b-hf \ --use_hpu_graphs \ +--limit_hpu_graphs \ --use_kv_cache \ ---reuse_cache \ +--bf16 \ --trim_logits \ --attn_softmax_bf16 \ ---max_input_tokens 31744 \ ---max_new_tokens 1024 \ ---batch_size=12 \ +--bucket_size=128 \ +--bucket_internal \ +--batch_size 10 \ +--max_input_tokens 40960 \ +--max_new_tokens 5120 \ --use_flash_attention \ --flash_attention_recompute \ --flash_attention_causal_mask \ @@ -429,7 +502,7 @@ The evaluation of LLMs can be done using the `lm_eval.py` script. It utilizes th For a more detailed description of parameters, please see the help message: ``` -./run_lm_eval.py -h +python run_lm_eval.py --help ``` diff --git a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json new file mode 100644 index 000000000..602a147ba --- /dev/null +++ b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} diff --git a/examples/text-generation/quantization_tools/unify_measurements.py b/examples/text-generation/quantization_tools/unify_measurements.py index 75ae329a4..0efc06c8d 100644 --- a/examples/text-generation/quantization_tools/unify_measurements.py +++ b/examples/text-generation/quantization_tools/unify_measurements.py @@ -20,13 +20,15 @@ def find_measurement_path(measurement, measurements_dir_path, scales, group_size return os.path.join(measurements_dir_path, measurment_file) -def unify_measurements(measurement_group, measurements_dir_path, output_path, scales=False): +def unify_measurements( + measurement_group, measurements_dir_path, output_path, groups_size, groups_num, group_index, scales=False +): measurements_paths = [] group_name = "" # save all the jsons paths in the given measurement group for measurement in measurement_group: - measurement_path = find_measurement_path(measurement, measurements_dir_path, scales, len(measurement_group)) + measurement_path = find_measurement_path(measurement, measurements_dir_path, scales, groups_size) measurements_paths.append(measurement_path) group_name += measurement @@ -36,13 +38,22 @@ def unify_measurements(measurement_group, measurements_dir_path, output_path, sc with open(measurement_path, "r") as f: js = json.load(f) measurements_jsons.append(js["Nodes"]) - # create a name for the unified json that will be created for this measurement group - unified_json_name = ( - find_measurement_path(measurement_group[0], measurements_dir_path, scales, len(measurement_group)) - .split("/")[-1] - .replace("_" + measurement_group[0] + "_" + str(len(measurement_group)), "") - ) + + if groups_num == 1: + unified_json_name = ( + find_measurement_path(measurement_group[0], measurements_dir_path, scales, groups_size) + .split("/")[-1] + .replace("_" + measurement_group[0] + "_" + str(groups_size), "") + ) + else: + unified_json_name = ( + find_measurement_path(measurement_group[0], measurements_dir_path, scales, groups_size) + .split("/")[-1] + .replace( + "_" + measurement_group[0] + "_" + str(groups_size), "_" + str(group_index) + "_" + str(groups_num) + ) + ) unified_json_path = os.path.join(output_path, unified_json_name) # open a unified json file @@ -50,6 +61,7 @@ def unify_measurements(measurement_group, measurements_dir_path, output_path, sc copy.write(origin.read()) with open(unified_json_path, "r") as json_file: unified_json = json.load(json_file) + unified_json["LocalRank"] = group_index if groups_num != 1 else -1 # iterate all unified json nodes for node_name, node_values in unified_json["Nodes"].items(): @@ -64,7 +76,8 @@ def unify_measurements(measurement_group, measurements_dir_path, output_path, sc # iterate over all the measurment group and take the maximum for each tensor and its channel if scales: for measurement_json in measurements_jsons: - max_inputs[0] = max(measurement_json[node_name]["inputs"][0], max_inputs[0]) + for i in range(0, len(max_inputs)): + max_inputs[i] = max(measurement_json[node_name]["inputs"][i], max_inputs[i]) if max_outputs is not None: max_outputs = max(measurement_json[node_name]["outputs"], max_outputs) if max_weight is not None: @@ -83,7 +96,8 @@ def unify_measurements(measurement_group, measurements_dir_path, output_path, sc # update the maximum in the unified json if scales: - unified_json["Nodes"][node_name]["inputs"][0] = max_inputs[0] + for i in range(0, len(max_inputs)): + unified_json["Nodes"][node_name]["inputs"][i] = max_inputs[i] if max_outputs is not None: unified_json["Nodes"][node_name]["outputs"] = max_outputs if max_weight is not None: @@ -99,7 +113,7 @@ def unify_measurements(measurement_group, measurements_dir_path, output_path, sc for i in range(0, len(max_weight)): unified_json["Nodes"][node_name]["params"]["weight"][i][0] = max_weight[i][0] global_rank = None - local_rank = None + local_rank = group_index if groups_num != 1 else -1 mode = "" layers = {} with open(unified_json_path, "w") as json_file: @@ -134,7 +148,8 @@ def parse_args(args): "--groups", type=list, nargs="+", - help="the groups of cards that are going to be unified- e.g. 01 23 45 67", + help="groups of cards we want to unify, each group should be seperated by whitespace \ + - e.g. 01 23 45 67, card 0 measurement will be unified with card 1 measurement and so on", ) parser.add_argument( "-o", @@ -154,15 +169,27 @@ def main(args): measurements_path = args.measurements groups = args.groups - num_jsons = 0 + num_jsons_drange = 0 + num_jsons_scales = 0 for path in os.listdir(measurements_path): if path.endswith(".json"): - num_jsons += 1 - assert os.path.isdir(measurements_path) and (num_jsons % len(groups)) == 0 + if "MAXABS" in path: + num_jsons_scales += 1 + elif "mod_list" not in path: + num_jsons_drange += 1 + assert ( + os.path.isdir(measurements_path) + and (num_jsons_drange % len(groups)) == 0 + and (num_jsons_scales % len(groups)) == 0 + ) - for group in groups: - unify_measurements(group, measurements_path, output_path, scales=False) - unify_measurements(group, measurements_path, output_path, scales=True) + for group_index, group in enumerate(groups): + unify_measurements( + group, measurements_path, output_path, num_jsons_drange, len(groups), group_index, scales=False + ) + unify_measurements( + group, measurements_path, output_path, num_jsons_scales, len(groups), group_index, scales=True + ) print("finished measurement unifier script") diff --git a/examples/text-generation/requirements_lm_eval.txt b/examples/text-generation/requirements_lm_eval.txt index 4001184d7..4d1824722 100644 --- a/examples/text-generation/requirements_lm_eval.txt +++ b/examples/text-generation/requirements_lm_eval.txt @@ -1 +1 @@ -git+https://github.com/polisettyvarma/lm-evaluation-harness.git@lm_harness_fixes \ No newline at end of file +https://github.com/polisettyvarma/lm-evaluation-harness/archive/3cdc8daadad9f4559ae6cdfae96f1d83d6b3c1f4.zip \ No newline at end of file diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index a9225264f..b30c2e444 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -125,6 +125,12 @@ def setup_parser(parser): type=int, help="Number of steps to capture for profiling.", ) + parser.add_argument( + "--profiling_record_shapes", + default=False, + type=bool, + help="Record shapes when enabling profiling.", + ) parser.add_argument( "--prompt", default=None, @@ -146,6 +152,12 @@ def setup_parser(parser): nargs="+", help="Optional argument list of words that must be generated.", ) + parser.add_argument( + "--assistant_model", + default=None, + type=str, + help="Optional argument to give a path to a draft/assistant model for assisted decoding.", + ) parser.add_argument( "--peft_model", default=None, @@ -221,7 +233,6 @@ def setup_parser(parser): help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)", ) - parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") parser.add_argument( "--use_flash_attention", action="store_true", @@ -237,6 +248,11 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention in causal mode on first token generation.", ) + parser.add_argument( + "--flash_attention_fast_softmax", + action="store_true", + help="Whether to enable Habana Flash Attention in fast softmax mode.", + ) parser.add_argument( "--book_source", action="store_true", @@ -290,7 +306,7 @@ def setup_parser(parser): def main(): parser = argparse.ArgumentParser() args = setup_parser(parser) - model, tokenizer, generation_config = initialize_model(args, logger) + model, assistant_model, tokenizer, generation_config = initialize_model(args, logger) use_lazy_mode = True if args.torch_compile and model.config.model_type == "llama": @@ -388,12 +404,14 @@ def generate(size=None, reduce_recompile=False): outputs = model.generate( **input_tokens, generation_config=generation_config, + assistant_model=assistant_model, lazy_mode=use_lazy_mode, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, ignore_eos=args.ignore_eos, iteration_times=iteration_times, + profiling_record_shapes=args.profiling_record_shapes, ).cpu() first_token_time = iteration_times[0] + encode_duration logger.info(f"Time to first token = {first_token_time*1000}ms") @@ -577,6 +595,7 @@ def generate_dataset(batch): profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, ignore_eos=args.ignore_eos, + profiling_record_shapes=args.profiling_record_shapes, ).cpu() return prompt, outputs diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 43a1c46ef..4d9b69bde 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,16 +75,24 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral"]: + if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2"]: self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, } ) - if self.model.config.model_type in ["llama", "mistral"]: + if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon"]: + if self.model.config.model_type != "falcon": + self.model_inputs.update( + { + "attn_softmax_bf16": self.options.attn_softmax_bf16, + } + ) self.model_inputs.update( { - "attn_softmax_bf16": self.options.attn_softmax_bf16, + "use_flash_attention": self.options.use_flash_attention, + "flash_attention_recompute": self.options.flash_attention_recompute, + "flash_attention_causal_mask": self.options.flash_attention_causal_mask, } ) if args.warmup: @@ -149,7 +157,7 @@ def _model_call(self, inps): def main(): args = setup_lm_eval_parser() - model, tokenizer, generation_config = initialize_model(args, logger) + model, _, tokenizer, generation_config = initialize_model(args, logger) lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) with torch.no_grad(): diff --git a/examples/text-generation/text-generation-pipeline/README.md b/examples/text-generation/text-generation-pipeline/README.md index 10c651e37..2e6e5b84f 100644 --- a/examples/text-generation/text-generation-pipeline/README.md +++ b/examples/text-generation/text-generation-pipeline/README.md @@ -20,15 +20,9 @@ The text-generation pipeline can be used to perform text-generation by providing ## Requirements -Update `PYTHONPATH` as follows. -```bash -export OPTIMUM_HABANA_PATH=/path/to/optimum-habana -export PYTHONPATH=${PYTHONPATH}:${OPTIMUM_HABANA_PATH}/examples/text-generation -``` - If you plan to use [DeepSpeed-inference](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/Inference_Using_DeepSpeed.html), you should install DeepSpeed as follows: ```bash -pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 +pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 ``` If you would like to use the pipeline with LangChain classes, you can install LangChain as follows: @@ -78,6 +72,7 @@ python run_pipeline.py \ --use_kv_cache \ --max_new_tokens 100 \ --do_sample \ +--batch_size 2 \ --prompt "Hello world" "How are you?" ``` @@ -101,6 +96,7 @@ python run_pipeline.py \ --do_sample \ --temperature 0.5 \ --top_p 0.95 \ +--batch_size 2 \ --prompt "Hello world" "How are you?" ``` @@ -114,6 +110,7 @@ python ../../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ --bf16 \ --use_hpu_graphs \ --use_kv_cache \ +--batch_size 4 \ --prompt "Hello world" "How are you?" "Here is my prompt" "Once upon a time" ``` @@ -128,6 +125,7 @@ python ../../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ --do_sample \ --temperature 0.5 \ --top_p 0.95 \ +--batch_size 4 \ --prompt "Hello world" "How are you?" "Here is my prompt" "Once upon a time" ``` @@ -143,7 +141,7 @@ python run_pipeline_langchain.py \ --batch_size 32 \ --max_new_tokens 1024 \ --do_sample \ - --device=hpu + --device=hpu ``` > The pipeline class has been validated for LangChain version 0.1.16 and may not work with other versions of the package. diff --git a/examples/text-generation/text-generation-pipeline/pipeline.py b/examples/text-generation/text-generation-pipeline/pipeline.py index 6105d52a6..15cb96a3d 100644 --- a/examples/text-generation/text-generation-pipeline/pipeline.py +++ b/examples/text-generation/text-generation-pipeline/pipeline.py @@ -1,11 +1,19 @@ +import os +import sys + import torch from transformers import TextGenerationPipeline -from utils import initialize_model + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(SCRIPT_DIR)) class GaudiTextGenerationPipeline(TextGenerationPipeline): def __init__(self, args, logger, use_with_langchain=False, warmup_on_init=True): - self.model, self.tokenizer, self.generation_config = initialize_model(args, logger) + from utils import initialize_model + + self.model, _, self.tokenizer, self.generation_config = initialize_model(args, logger) self.task = "text-generation" self.device = args.device @@ -18,6 +26,7 @@ def __init__(self, args, logger, use_with_langchain=False, warmup_on_init=True): self.use_hpu_graphs = args.use_hpu_graphs self.profiling_steps = args.profiling_steps self.profiling_warmup_steps = args.profiling_warmup_steps + self.profiling_record_shapes = args.profiling_record_shapes self.use_with_langchain = use_with_langchain if self.use_with_langchain: @@ -56,6 +65,7 @@ def __call__(self, prompt): hpu_graphs=self.use_hpu_graphs, profiling_steps=self.profiling_steps, profiling_warmup_steps=self.profiling_warmup_steps, + profiling_record_shapes=self.profiling_record_shapes, ).cpu() if use_batch: diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 09cdd8e8f..4b46ba175 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -36,7 +36,12 @@ model_on_meta, write_checkpoints_json, ) -from optimum.habana.utils import check_habana_frameworks_version, check_optimum_habana_min_version, set_seed +from optimum.habana.utils import ( + check_habana_frameworks_version, + check_optimum_habana_min_version, + get_habana_frameworks_version, + set_seed, +) def adjust_batch(batch, size): @@ -96,6 +101,22 @@ def setup_distributed(args): args.global_rank = int(os.getenv("RANK", "0")) +def setup_inference(args, model): + import habana_frameworks.torch.core as htcore + + habana_version = get_habana_frameworks_version() + + print("Initializing inference mode") + # Keeping the if-else here for back compat. TODO remove later + if habana_version.major >= 1 and habana_version.minor >= 16: + htcore.hpu_initialize(model, mark_only_scales_as_const=True) + else: + const_marking = os.getenv("ENABLE_CONST_MARKING", "True") + if const_marking == "True": + htcore.hpu_initialize(model) + return model + + def setup_const_serialization(const_serialization_path): import uuid @@ -122,6 +143,11 @@ def setup_env(args): os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0") os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true") + if args.use_hpu_graphs and args.limit_hpu_graphs and not args.reuse_cache and args.bucket_internal: + # Based upon above conditions and below env variable, + # we can call HPU graphs clear_inputs(). + os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1") + # Tweak generation so that it runs faster on Gaudi from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -132,7 +158,7 @@ def setup_device(args): if args.device == "hpu": import habana_frameworks.torch.core as htcore - if args.fp8: + if args.quant_config: htcore.hpu_set_env() return torch.device(args.device) @@ -151,13 +177,16 @@ def patch_scoped_linear_all_reduce(model): def get_torch_compiled_model(model): - model.model = torch.compile(model.model, backend="hpu_backend") + model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) return model def setup_model(args, model_dtype, model_kwargs, logger): logger.info("Single-device run.") - + if args.assistant_model is None: + assistant_model = None + else: + logger.info(f"Using asssitant model {args.assistant_model}.") if args.disk_offload: from accelerate import infer_auto_device_map, init_empty_weights @@ -175,6 +204,10 @@ def setup_model(args, model_dtype, model_kwargs, logger): **model_kwargs, ) else: + if args.assistant_model is not None: + assistant_model = AutoModelForCausalLM.from_pretrained( + args.assistant_model, torch_dtype=model_dtype, **model_kwargs + ) if args.peft_model is not None: model = peft_model(args, model_dtype, logger, **model_kwargs) else: @@ -185,7 +218,12 @@ def setup_model(args, model_dtype, model_kwargs, logger): import habana_quantization_toolkit habana_quantization_toolkit.prep_model(model) + if args.assistant_model is not None: + habana_quantization_toolkit.quantize_model(assistant_model) + model = model.eval().to(args.device) + if args.assistant_model is not None: + assistant_model = assistant_model.eval().to(args.device) if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph @@ -196,6 +234,8 @@ def setup_model(args, model_dtype, model_kwargs, logger): model = wrap_in_hpu_graph(model, hash_with_views=False) else: model = wrap_in_hpu_graph(model) + if args.assistant_model is not None: + assistant_model = wrap_in_hpu_graph(assistant_model) if _is_peft_model(model): model.base_model = wrap_in_hpu_graph(model.base_model) if model.peft_type == "ADAPTION_PROMPT": @@ -203,8 +243,9 @@ def setup_model(args, model_dtype, model_kwargs, logger): if args.torch_compile and model.config.model_type == "llama": model = get_torch_compiled_model(model) - - return model + # if args.assistant_model is not None: + # assistant_model = get_torch_compiled_model(assistant_model) + return model, assistant_model def setup_distributed_model(args, model_dtype, model_kwargs, logger): @@ -215,6 +256,11 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) load_to_meta = model_on_meta(config) + if args.assistant_model is None: + assistant_model = None + else: + logger.info(f"Using asssitant model {args.assistant_model}.") + if load_to_meta: # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load with deepspeed.OnDevice(dtype=model_dtype, device="meta"): @@ -249,6 +295,11 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): ) model.eval() + if args.assistant_model is not None: + assistant_model = AutoModelForCausalLM.from_pretrained( + args.assistant_model, torch_dtype=model_dtype, **model_kwargs + ).eval() + # Initialize the model ds_inference_kwargs = {"dtype": model_dtype} ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size} @@ -259,18 +310,21 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = deepspeed.init_inference(model, **ds_inference_kwargs) model = model.module - if model.config.model_type in ["llama", "falcon"]: + if model.config.model_type in ["llama", "falcon", "qwen2"]: patch_scoped_linear_all_reduce(model) if args.quant_config: import habana_quantization_toolkit habana_quantization_toolkit.prep_model(model) + if args.assistant_model is not None: + habana_quantization_toolkit.prep_model(assistant_model) if args.torch_compile and model.config.model_type == "llama": model = get_torch_compiled_model(model) - - return model + # if args.assistant_model is not None: + # assistant_model = get_torch_compiled_model(assistant_model) + return model, assistant_model def peft_model(args, model_dtype, logger, **model_kwargs): @@ -338,7 +392,7 @@ def peft_model(args, model_dtype, logger, **model_kwargs): return model -def setup_tokenizer(args, model): +def setup_tokenizer(args, model, assistant_model): tokenizer_kwargs = { "revision": args.model_revision, "token": args.token, @@ -355,6 +409,10 @@ def setup_tokenizer(args, model): model.generation_config.pad_token_id = 0 model.generation_config.bos_token_id = 1 model.generation_config.eos_token_id = 2 + if assistant_model is not None: + assistant_model.generation_config.pad_token_id = 0 + assistant_model.generation_config.bos_token_id = 1 + assistant_model.generation_config.eos_token_id = 2 tokenizer.bos_token_id = model.generation_config.bos_token_id tokenizer.eos_token_id = model.generation_config.eos_token_id tokenizer.pad_token_id = model.generation_config.pad_token_id @@ -363,6 +421,8 @@ def setup_tokenizer(args, model): tokenizer.bos_token = tokenizer.decode(tokenizer.bos_token_id) if model.config.model_type == "persimmon": model.generation_config.pad_token_id = model.generation_config.eos_token_id + if assistant_model is not None: + assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id tokenizer.bos_token_id = model.generation_config.bos_token_id tokenizer.eos_token_id = model.generation_config.eos_token_id tokenizer.pad_token_id = model.generation_config.pad_token_id @@ -374,11 +434,13 @@ def setup_tokenizer(args, model): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model.generation_config.pad_token_id = model.generation_config.eos_token_id + if assistant_model is not None: + assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id - return tokenizer, model + return tokenizer, model, assistant_model -def setup_generation_config(args, model, tokenizer): +def setup_generation_config(args, model, assistant_model, tokenizer): bad_words_ids = None force_words_ids = None if args.bad_words is not None: @@ -387,11 +449,12 @@ def setup_generation_config(args, model, tokenizer): force_words_ids = [tokenizer.encode(force_word, add_special_tokens=False) for force_word in args.force_words] is_optimized = model_is_optimized(model.config) + # Generation configuration generation_config = copy.deepcopy(model.generation_config) generation_config.max_new_tokens = args.max_new_tokens generation_config.use_cache = args.use_kv_cache - generation_config.static_shapes = is_optimized + generation_config.static_shapes = is_optimized and assistant_model is None generation_config.bucket_size = args.bucket_size if is_optimized else -1 generation_config.bucket_internal = args.bucket_internal generation_config.do_sample = args.do_sample @@ -409,7 +472,9 @@ def setup_generation_config(args, model, tokenizer): generation_config.use_flash_attention = args.use_flash_attention generation_config.flash_attention_recompute = args.flash_attention_recompute generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask + generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax generation_config.trust_remote_code = args.trust_remote_code + return generation_config @@ -421,8 +486,10 @@ def initialize_model(args, logger): setup_device(args) set_seed(args.seed) get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token) + if args.assistant_model is not None: + get_repo_root(args.assistant_model, local_rank=args.local_rank, token=args.token) use_deepspeed = args.world_size > 0 - if use_deepspeed or args.bf16 or args.fp8: + if use_deepspeed or args.bf16: model_dtype = torch.bfloat16 else: model_dtype = torch.float @@ -436,25 +503,20 @@ def initialize_model(args, logger): if args.trust_remote_code: logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail") - model = ( + model, assistant_model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed else setup_distributed_model(args, model_dtype, model_kwargs, logger) ) - tokenizer, model = setup_tokenizer(args, model) - generation_config = setup_generation_config(args, model, tokenizer) + tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) + generation_config = setup_generation_config(args, model, assistant_model, tokenizer) if args.const_serialization_path: setup_const_serialization(args.const_serialization_path) - if args.fp8: - import habana_frameworks.torch.core as htcore - - print("Initializing inference mode") - const_marking = os.getenv("ENABLE_CONST_MARKING", "True") - if const_marking == "True": - htcore.hpu_initialize(model) + if args.quant_config: + model = setup_inference(args, model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") logger.info(f"Model initialization took {(init_end - init_start):.3f}s") - return model, tokenizer, generation_config + return model, assistant_model, tokenizer, generation_config diff --git a/examples/text-to-speech/README.md b/examples/text-to-speech/README.md index cbe09ebfc..a1e089f55 100644 --- a/examples/text-to-speech/README.md +++ b/examples/text-to-speech/README.md @@ -18,13 +18,23 @@ limitations under the License. This directory contains a script that showcases how to use the Transformers pipeline API to run text to speech task on HPUs. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` + ## Single-HPU inference ```bash python3 run_pipeline.py \ --model_name_or_path microsoft/speecht5_tts \ --text "Hello, my dog is cooler than you!" \ + --use_hpu_graphs \ --bf16 ``` Models that have been validated: - [microsoft/speecht5_tts](https://huggingface.co/microsoft/speecht5_tts) + - [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium) + - [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng) diff --git a/examples/text-to-speech/run_pipeline.py b/examples/text-to-speech/run_pipeline.py index 0183db034..1d9b53de7 100644 --- a/examples/text-to-speech/run_pipeline.py +++ b/examples/text-to-speech/run_pipeline.py @@ -58,11 +58,18 @@ def main(): parser.add_argument("--batch_size", type=int, default=1, help="Input batch size.") parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.") parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.") + parser.add_argument("--seed", type=int, default=555, help="make speech generation deterministic") + parser.add_argument( + "--use_hpu_graphs", + action="store_true", + help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.", + ) args = parser.parse_args() adapt_transformers_to_gaudi() text = args.text text_bs = len(text) + set_seed(args.seed) if args.batch_size > text_bs: # Dynamically extends to support larger batch sizes @@ -84,18 +91,32 @@ def main(): device="hpu", ) - embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") - speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to("hpu") + if args.use_hpu_graphs: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + generator.model = wrap_in_hpu_graph(generator.model) + + forward_params = None + if generator.model.config.model_type == "speecht5": + embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") + speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to("hpu") + forward_params = {"speaker_embeddings": speaker_embedding} + if generator.model.config.model_type == "seamless_m4t": + forward_params = {"tgt_lang": "eng"} + + generate_kwargs = None + if generator.model.can_generate(): + generate_kwargs = {"lazy_mode": True, "ignore_eos": False, "hpu_graphs": args.use_hpu_graphs} - with torch.autocast("hpu", torch.bfloat16, enabled=args.bf16), torch.no_grad(), torch.inference_mode(): + with torch.autocast("hpu", torch.bfloat16, enabled=args.bf16), torch.inference_mode(): # warm up for i in range(args.warmup): if generator.model.config.model_type == "speecht5": # SpeechT5 forces a dropout with training=True, which may zero out some elements randomly. # A random dropout may need different lengths of spectrograms to fit probability thresholds, # which violates the HPU static shape, so we have to fix the seed here. - set_seed(555) - generator(text, batch_size=args.batch_size, forward_params={"speaker_embeddings": speaker_embedding}) + set_seed(args.seed) + generator(text, batch_size=args.batch_size, forward_params=forward_params, generate_kwargs=generate_kwargs) start = time.time() for i in range(args.n_iterations): @@ -103,13 +124,13 @@ def main(): # SpeechT5 forces a dropout with training=True, which may zero out some elements randomly. # A random dropout may need different lengths of spectrograms to fit probability thresholds, # which violates the HPU static shape, so we have to fix the seed here. - set_seed(555) + set_seed(args.seed) speech = generator( - text, batch_size=args.batch_size, forward_params={"speaker_embeddings": speaker_embedding} + text, batch_size=args.batch_size, forward_params=forward_params, generate_kwargs=generate_kwargs ) end = time.time() logger.info(f"speech = {speech} time = {(end-start) * 1000 / args.n_iterations }ms") - sf.write("speech.wav", speech[0]["audio"], samplerate=speech[0]["sampling_rate"]) + sf.write("speech.wav", speech[0]["audio"].squeeze(), samplerate=speech[0]["sampling_rate"]) if __name__ == "__main__": diff --git a/examples/translation/README.md b/examples/translation/README.md index 6a24ad151..1d705d23f 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -21,6 +21,12 @@ limitations under the License. For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets#json-files. You will also find examples of these below. +## Requirements + +First, you should install the requirements: +```bash +pip install -r requirements.txt +``` ## Single-card Training diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py index 33ca6eaf1..db40ef8f2 100644 --- a/examples/translation/run_translation.py +++ b/examples/translation/run_translation.py @@ -62,8 +62,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.38.0") -check_optimum_habana_min_version("1.10.0") +check_min_version("4.40.0") +check_optimum_habana_min_version("1.11.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/examples/trl/README.md b/examples/trl/README.md index 3a0ff92e5..e6a4f0b00 100644 --- a/examples/trl/README.md +++ b/examples/trl/README.md @@ -1,10 +1,9 @@ # Examples -## Prerequisites - -Install all the dependencies in the `requirements.txt`: +## Requirements +First, you should install the requirements: ``` $ pip install -U -r requirements.txt ``` @@ -266,4 +265,4 @@ results = pipeline(prompts) for prompt, image in zip(prompts, results.images): image.save(f"{prompt}.png") -``` \ No newline at end of file +``` diff --git a/examples/trl/dpo.py b/examples/trl/dpo.py index f69081123..5779296bb 100644 --- a/examples/trl/dpo.py +++ b/examples/trl/dpo.py @@ -185,6 +185,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]: torch_dtype=torch.bfloat16, ) model.config.use_cache = False + model.config.use_fused_rope = False if script_args.ignore_bias_buffers: # torch distributed hack diff --git a/examples/trl/ppo.py b/examples/trl/ppo.py index 22ea73ab0..d4ad12764 100644 --- a/examples/trl/ppo.py +++ b/examples/trl/ppo.py @@ -201,7 +201,8 @@ def collator(data): torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) - +model.config.use_fused_rope = False +model.config.use_fused_rms_norm = False optimizer = None model = model.to(torch.bfloat16) @@ -241,7 +242,6 @@ def collator(data): reward_model_name, num_labels=1, low_cpu_mem_usage=True, - torch_dtype=torch.bfloat16, ) if config.use_habana: diff --git a/examples/trl/requirements.txt b/examples/trl/requirements.txt index c4bc5c6bb..01f0e51a8 100644 --- a/examples/trl/requirements.txt +++ b/examples/trl/requirements.txt @@ -1,7 +1,6 @@ -trl == 0.7.8 +trl == 0.8.6 peft == 0.6.2 -datasets -wandb +datasets == 2.19.2 tyro evaluate scikit-learn diff --git a/examples/trl/reward_modeling.py b/examples/trl/reward_modeling.py index 35dd0b939..f67d65794 100644 --- a/examples/trl/reward_modeling.py +++ b/examples/trl/reward_modeling.py @@ -181,6 +181,7 @@ class ScriptArguments: tokenizer.pad_token = tokenizer.eos_token model.config.pad_token_id = tokenizer.eos_token_id model.config.use_cache = not script_args.gradient_checkpointing +model.config.use_fused_rope = False num_proc = 24 # Can adjust to be higher if you have more processors. original_columns = train_dataset.column_names diff --git a/examples/trl/sft.py b/examples/trl/sft.py index 581c96f90..6edaa4330 100644 --- a/examples/trl/sft.py +++ b/examples/trl/sft.py @@ -151,6 +151,7 @@ def create_datasets(tokenizer, args, seed=None): token=script_args.token, ) base_model.config.use_cache = False +base_model.config.use_fused_rope = False tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token @@ -167,7 +168,6 @@ def create_datasets(tokenizer, args, seed=None): gaudi_config = GaudiConfig() gaudi_config.use_fused_adam = True gaudi_config.use_fused_clip_norm = True - trainer = GaudiSFTTrainer( model=base_model, gaudi_config=gaudi_config, diff --git a/notebooks/AI_HW_Summit_2022.ipynb b/notebooks/AI_HW_Summit_2022.ipynb index cf6c8bdea..4db1b7358 100644 --- a/notebooks/AI_HW_Summit_2022.ipynb +++ b/notebooks/AI_HW_Summit_2022.ipynb @@ -261,7 +261,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0" + "!pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0" ] }, { diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index 006abfe67..fcfd47c31 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -146,7 +146,7 @@ def __init__( if deepspeed_plugin: if not is_deepspeed_available(): raise ImportError( - "DeepSpeed is not installed => run `pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0`." + "DeepSpeed is not installed => run `pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0`." ) mixed_precision = ( diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index 93c4a563a..eda6ed4b0 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -55,7 +55,7 @@ def __init__(self, cpu: bool = False, **kwargs): if not is_deepspeed_available(): raise ImportError( "DeepSpeed is not available, install it with: `pip install" - " git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0`." + " git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0`." ) self.distributed_type = GaudiDistributedType.DEEPSPEED import deepspeed diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 2cf9025c7..895b72151 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -94,7 +94,7 @@ def model_on_meta(config): """ Checks if load the model to meta. """ - return config.model_type in ["bloom", "llama", "falcon"] + return config.model_type in ["bloom", "llama", "falcon", "mixtral"] def get_optimized_model_name(config): diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index d4381e01f..26d5d2d35 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -4,4 +4,5 @@ from .pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import GaudiStableDiffusionLDM3DPipeline from .pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import GaudiStableDiffusionUpscalePipeline from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline +from .pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import GaudiStableVideoDiffusionPipeline from .schedulers import GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index eba03ddd7..c2ffdb9cc 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -19,7 +19,7 @@ import inspect import os import sys -from typing import Optional, Union +from typing import Callable, Dict, Optional, Union import torch from diffusers.pipelines import DiffusionPipeline @@ -28,6 +28,7 @@ from diffusers.utils.torch_utils import is_compiled_module from huggingface_hub import create_repo +from optimum.habana.utils import to_device_dtype from optimum.utils import logging from ...transformers.gaudi_configuration import GaudiConfig @@ -157,7 +158,9 @@ def __init__( if bf16_full_eval or self.gaudi_config.use_torch_autocast: import diffusers - from ..models import gaudi_unet_2d_condition_model_forward + from ..models import ( + gaudi_unet_2d_condition_model_forward, + ) diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.forward = ( gaudi_unet_2d_condition_model_forward @@ -357,3 +360,33 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P pretrained_model_name_or_path, **kwargs, ) + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + # Move the state dict from HPU to CPU before saving + if unet_lora_layers: + unet_lora_layers = to_device_dtype(unet_lora_layers, target_device=torch.device("cpu")) + if text_encoder_lora_layers: + text_encoder_lora_layers = to_device_dtype(text_encoder_lora_layers, target_device=torch.device("cpu")) + if text_encoder_2_lora_layers: + text_encoder_2_lora_layers = to_device_dtype(text_encoder_2_lora_layers, target_device=torch.device("cpu")) + return super().save_lora_weights( + save_directory, + unet_lora_layers, + text_encoder_lora_layers, + text_encoder_2_lora_layers, + is_main_process, + weight_name, + save_function, + safe_serialization, + ) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 488286920..1096dec9a 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -512,7 +512,7 @@ def __call__( text_embeddings_batch = text_embeddings_batches[0] text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) - for i in range(num_inference_steps): + for i in range(len(timesteps)): if use_warmup_inference_steps and i == throughput_warmup_steps: t1_inf = time.time() t1 += t1_inf - t0_inf @@ -626,10 +626,18 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - if output_type == "pil": + if output_type == "pil" and isinstance(image, list): outputs["images"] += image + elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = np.concatenate((outputs["images"], image), axis=0) else: - outputs["images"] += [*image] + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = torch.cat((outputs["images"], image), 0) if has_nsfw_concept is not None: outputs["has_nsfw_concept"] += has_nsfw_concept diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index 394807b14..b60b6d89f 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -357,7 +357,7 @@ def __call__( text_embeddings_batch = text_embeddings_batches[0] text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) - for i in range(num_inference_steps): + for i in range(len(timesteps)): if use_warmup_inference_steps and i == throughput_warmup_steps: t1_inf = time.time() t1 += t1_inf - t0_inf diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 401871aae..477871eb4 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -458,7 +458,7 @@ def __call__( noise_level_batch = noise_level_batches[0] noise_level_batches = torch.roll(noise_level_batches, shifts=-1, dims=0) - for i in range(num_inference_steps): + for i in range(len(timesteps)): if use_warmup_inference_steps and i == throughput_warmup_steps: t1_inf = time.time() t1 += t1_inf - t0_inf @@ -567,13 +567,21 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - if output_type == "pil": + if output_type == "pil" and isinstance(image, list): # Apply watermark if self.watermarker is not None: image = self.watermarker.apply_watermark(image) outputs["images"] += image + elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = np.concatenate((outputs["images"], image), axis=0) else: - outputs["images"] += [*image] + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = torch.cat((outputs["images"], image), 0) if has_nsfw_concept is not None: outputs["has_nsfw_concept"] += has_nsfw_concept diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index b7bec0d6a..e9e5596d7 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -840,10 +840,18 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type) - if output_type == "pil": + if output_type == "pil" and isinstance(image, list): outputs["images"] += image + elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = np.concatenate((outputs["images"], image), axis=0) else: - outputs["images"] += [*image] + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = torch.cat((outputs["images"], image), 0) # Offload all models self.maybe_free_model_hooks() diff --git a/optimum/habana/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/optimum/habana/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py new file mode 100644 index 000000000..cee268140 --- /dev/null +++ b/optimum/habana/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -0,0 +1,759 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import time +from dataclasses import dataclass +from math import ceil +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from diffusers.pipelines.stable_video_diffusion import StableVideoDiffusionPipeline +from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import ( + _append_dims, + _resize_with_antialiasing, + tensor2vid, +) +from diffusers.schedulers import EulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ....transformers.gaudi_configuration import GaudiConfig +from ....utils import speed_metrics +from ..pipeline_utils import GaudiDiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class GaudiStableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + frames (`[List[PIL.Image.Image]`, `np.ndarray`]): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + throughput (float): + Measured samples per second + """ + + frames: Union[List[PIL.Image.Image], np.ndarray] + throughput: float + + +class GaudiStableVideoDiffusionPipeline(GaudiDiffusionPipeline, StableVideoDiffusionPipeline): + r""" + Adapted from: https://github.com/huggingface/diffusers/blob/v0.24.0/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py#L72 + - Added generation by batches functionality + - Added support for HPU graphs + + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + GaudiDiffusionPipeline.__init__( + self, + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + + StableVideoDiffusionPipeline.__init__( + self, + vae, + image_encoder, + unet, + scheduler, + feature_extractor, + ) + + self.to(self._device) + + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + device, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype).to(device) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype).to(device) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + if isinstance(self.guidance_scale, (int, float)): + return self.guidance_scale + return self.guidance_scale.max() > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @classmethod + def _pad_batches(cls, input_batches, num_dummy_samples): + sequence_to_stack = (input_batches[-1],) + tuple( + torch.zeros_like(input_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + input_batches[-1] = torch.vstack(sequence_to_stack) + return input_batches + + @classmethod + def _split_input_into_batches( + cls, + cond_input, + batch_size, + num_dummy_samples, + uncond_input=None, + ): + input_batches = list(torch.split(cond_input, batch_size)) + uncond_input_batches = None + if uncond_input is not None: + uncond_input_batches = list(torch.split(uncond_input, batch_size)) + + if num_dummy_samples > 0: # Pad inputs + input_batches = cls._pad_batches(input_batches, num_dummy_samples) + if uncond_input_batches is not None: + uncond_input_batches = cls._pad_batches(uncond_input_batches, num_dummy_samples) + + if uncond_input_batches is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional inputs into a single batch + # to avoid doing two forward passes + for i, (uncond_input_batch, input_batch) in enumerate(zip(uncond_input_batches, input_batches[:])): + input_batches[i] = torch.cat([uncond_input_batch, input_batch]) + input_batches = torch.stack(input_batches) + return input_batches + + @classmethod + def _split_image_latents_into_batches( + cls, + image_latents, + batch_size, + num_dummy_samples, + num_images, + do_classifier_free_guidance, + ): + if do_classifier_free_guidance: + # Tiling of unconditional and conditional image latents differs from image embeddings + # For image latents, first concatenate the unconditional and conditional image latents + # Next, repeat for number of videos per prompt + negative_image_latents = torch.zeros_like(image_latents) + image_latents_batches = list(torch.split(image_latents, batch_size)) + negative_image_latents_batches = list(torch.split(negative_image_latents, batch_size)) + if num_dummy_samples > 0: # Pad inputs + image_latents_batches = cls._pad_batches(image_latents_batches, num_dummy_samples) + negative_image_latents_batches = cls._pad_batches(negative_image_latents_batches, num_dummy_samples) + for i, (negative_image_latents_batch, image_latents_batch) in enumerate( + zip(negative_image_latents_batches, image_latents_batches[:]) + ): + uncond_splits = list(torch.split(negative_image_latents_batch, num_images)) + cond_splits = list(torch.split(image_latents_batch, num_images)) + input_batch = [torch.cat([uncond, cond]) for (uncond, cond) in zip(uncond_splits, cond_splits)] + image_latents_batches[i] = torch.vstack(input_batch) + image_latents_batches = torch.stack(image_latents_batches) + else: + image_latents_batches = cls._split_input_into_batches(image_latents, batch_size, num_dummy_samples) + + return image_latents_batches + + @classmethod + def _split_inputs_into_batches( + cls, + batch_size, + latents, + image_latents, + image_embeddings, + added_time_ids, + num_images, + do_classifier_free_guidance, + ): + if do_classifier_free_guidance: + negative_image_embeddings, image_embeddings = image_embeddings.chunk(2) + negative_added_time_ids, added_time_ids = added_time_ids.chunk(2) + else: + negative_image_embeddings = None + negative_added_time_ids = None + + # If the last batch has less samples than batch_size, compute number of dummy samples to pad + last_samples = latents.shape[0] % batch_size + num_dummy_samples = batch_size - last_samples if last_samples > 0 else 0 + + # Generate num_batches batches of size batch_size + latents_batches = cls._split_input_into_batches(latents, batch_size, num_dummy_samples) + image_latents_batches = cls._split_image_latents_into_batches( + image_latents, batch_size, num_dummy_samples, num_images, do_classifier_free_guidance + ) + image_embeddings_batches = cls._split_input_into_batches( + image_embeddings, batch_size, num_dummy_samples, negative_image_embeddings + ) + added_time_ids_batches = cls._split_input_into_batches( + added_time_ids, batch_size, num_dummy_samples, negative_added_time_ids + ) + + return ( + latents_batches, + image_latents_batches, + image_embeddings_batches, + added_time_ids_batches, + num_dummy_samples, + ) + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + batch_size: int = 1, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: float = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` + batch_size (`int`, *optional*, defaults to 1): + The number of images in a batch. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`float`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of list with the generated frames. + + Examples: + + ```py + from diffusers import StableVideoDiffusionPipeline + from diffusers.utils import load_image, export_to_video + + pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + + image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") + image = image.resize((1024, 576)) + + frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + export_to_video(frames, "generated.mp4", fps=7) + ``` + """ + + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + num_images = 1 + elif isinstance(image, list): + num_images = len(image) + else: + num_images = image.shape[0] + num_batches = ceil((num_videos_per_prompt * num_images) / batch_size) + logger.info( + f"{num_images} image(s) received, {num_videos_per_prompt} video(s) per prompt," + f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." + ) + if num_batches < 3: + logger.warning("The first two iterations are slower so it is recommended to feed more batches.") + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = max_guidance_scale + + # 3. Encode input image + image_embeddings = self._encode_image( + image, device, num_videos_per_prompt, self.do_classifier_free_guidance + ) + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width) + # torch.randn is broken on HPU so running it on CPU + rand_device = "cpu" if device.type == "hpu" else device + noise = randn_tensor(image.shape, generator=generator, device=rand_device, dtype=image.dtype).to(device) + # image = self.image_processor.preprocess(image, height=height, width=width).to(device) + # noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + + image = image + noise_aug_strength * noise + + needs_upcasting = ( + self.vae.dtype == torch.float16 or self.vae.dtype == torch.bfloat16 + ) and self.vae.config.force_upcast + + if needs_upcasting: + cast_dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + + # Only encode the conditional image latents and generate unconditional image latents during batch split + # The tiling of conditional and unconditional image latents requires special handling + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=False, # Override to return only conditional latents + ) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16/bf16 if needed + if needs_upcasting: + self.vae.to(dtype=cast_dtype) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + num_images, + num_videos_per_prompt, + self.do_classifier_free_guidance, + device, + ) + added_time_ids = added_time_ids.to(device) + + # 6 Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self.scheduler.reset_timestep_dependent_params() + + # 7 Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + num_images * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 8. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 9. Split into batches (HPU-specific step) + ( + latents_batches, + image_latents_batches, + image_embeddings_batches, + added_time_ids_batches, + num_dummy_samples, + ) = self._split_inputs_into_batches( + batch_size, + latents, + image_latents, + image_embeddings, + added_time_ids, + num_images, + self.do_classifier_free_guidance, + ) + + outputs = { + "frames": [], + } + t0 = time.time() + t1 = t0 + + # 10. Denoising loop + throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3) + self._num_timesteps = len(timesteps) + for j in self.progress_bar(range(num_batches)): + # The throughput is calculated from the 3rd iteration + # because compilation occurs in the first two iterations + if j == throughput_warmup_steps: + t1 = time.time() + + latents_batch = latents_batches[0] + latents_batches = torch.roll(latents_batches, shifts=-1, dims=0) + image_latents_batch = image_latents_batches[0] + image_latents_batches = torch.roll(image_latents_batches, shifts=-1, dims=0) + image_embeddings_batch = image_embeddings_batches[0] + image_embeddings_batches = torch.roll(image_embeddings_batches, shifts=-1, dims=0) + added_time_ids_batch = added_time_ids_batches[0] + added_time_ids_batches = torch.roll(added_time_ids_batches, shifts=-1, dims=0) + + for i in self.progress_bar(range(num_inference_steps)): + timestep = timesteps[0] + timesteps = torch.roll(timesteps, shifts=-1, dims=0) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # Concatenate image_latents over channels dimention + latent_model_input = torch.cat([latent_model_input, image_latents_batch], dim=2) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=image_embeddings_batch, + added_time_ids=added_time_ids_batch, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_batch = self.scheduler.step(noise_pred, timestep, latents_batch).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + + latents_batch = callback_outputs.pop("latents", latents_batch) + + if not output_type == "latent": + # cast back to fp16/bf16 if needed + if needs_upcasting: + self.vae.to(dtype=cast_dtype) + + frames = self.decode_latents(latents_batch, num_frames, decode_chunk_size) + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents_batch + + outputs["frames"].append(frames) + + speed_metrics_prefix = "generation" + speed_measures = speed_metrics( + split=speed_metrics_prefix, + start_time=t0, + num_samples=num_batches * batch_size + if t1 == t0 + else (num_batches - throughput_warmup_steps) * batch_size, + num_steps=num_batches, + start_time_after_warmup=t1, + ) + logger.info(f"Speed metrics: {speed_measures}") + + # Remove dummy generations if needed + if num_dummy_samples > 0: + outputs["frames"][-1] = outputs["frames"][-1][:-num_dummy_samples] + + # Process generated images + for i, frames in enumerate(outputs["frames"][:]): + if i == 0: + outputs["frames"].clear() + + if output_type == "pil": + outputs["frames"] += frames + else: + outputs["frames"] += [*frames] + + self.maybe_free_model_hooks() + + if not return_dict: + return outputs["frames"] + + return GaudiStableVideoDiffusionPipelineOutput( + frames=outputs["frames"], + throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + ) diff --git a/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py b/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py index bd4cbda92..8fe62c34b 100644 --- a/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py +++ b/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py @@ -94,9 +94,14 @@ def __init__( prediction_type, interpolation_type, use_karras_sigmas, + sigma_min, + sigma_max, timestep_spacing, + timestep_type, steps_offset, + rescale_betas_zero_snr, ) + self._initial_timestep = None self.reset_timestep_dependent_params() diff --git a/optimum/habana/transformers/generation/__init__.py b/optimum/habana/transformers/generation/__init__.py index 15f567b0b..6b43ee2ae 100644 --- a/optimum/habana/transformers/generation/__init__.py +++ b/optimum/habana/transformers/generation/__init__.py @@ -1,6 +1,10 @@ +from .candidate_generator import GaudiAssistedCandidateGenerator from .configuration_utils import GaudiGenerationConfig from .stopping_criteria import ( + gaudi_EosTokenCriteria_call, gaudi_MaxLengthCriteria_call, gaudi_MaxNewTokensCriteria_call, + gaudi_MaxTimeCriteria_call, + gaudi_StoppingCriteriaList_call, ) from .utils import MODELS_OPTIMIZED_WITH_STATIC_SHAPES, GaudiGenerationMixin diff --git a/optimum/habana/transformers/generation/candidate_generator.py b/optimum/habana/transformers/generation/candidate_generator.py new file mode 100644 index 000000000..633e39e9d --- /dev/null +++ b/optimum/habana/transformers/generation/candidate_generator.py @@ -0,0 +1,45 @@ +import inspect +from typing import TYPE_CHECKING, Dict, Optional + +import torch +from transformers.generation.candidate_generator import ( + AssistedCandidateGenerator, +) + + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from transfromers.generation.logits_process import LogitsProcessorList + + from .configuration_utils import GaudiGenerationConfig + + +class GaudiAssistedCandidateGenerator(AssistedCandidateGenerator): + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + generation_config: "GaudiGenerationConfig", + logits_processor: "LogitsProcessorList", + model_kwargs: Dict, + inputs_tensor: Optional[torch.Tensor] = None, + ): + super().__init__( + input_ids, + assistant_model, + generation_config, + logits_processor, + model_kwargs, + inputs_tensor, + ) + + # Remove model kwargs that are specific to optimized models + # E.g. token_idx, use_flash_attention, etc... + # Otherwise it will trigger an error in GenerationMixin._validate_model_kwargs + # TODO: may need to complete this for encoder-decoders: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/generation/utils.py#L1133 + model_args = set(inspect.signature(assistant_model.prepare_inputs_for_generation).parameters) + if "kwargs" in model_args or "model_kwargs" in model_args: + model_args |= set(inspect.signature(assistant_model.forward).parameters) + for key, value in list(self.assistant_kwargs.items()): + if value is not None and key not in model_args: + del self.assistant_kwargs[key] diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 61585b559..ce38a07ed 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -35,6 +35,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to enable recompute if use Habana flash attention. flash_attention_causal_mask (`bool`, *optional*): Whether to enable causal_mask if use Habana flash attention. + flash_attention_fast_softmax_mode (`bool`, *optional*): + Whether to use fast softmax with reduced precision if use Habana flash attention. """ def __init__(self, **kwargs): @@ -51,4 +53,5 @@ def __init__(self, **kwargs): self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) + self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py index 4c6eedae6..dac7aadd9 100644 --- a/optimum/habana/transformers/generation/stopping_criteria.py +++ b/optimum/habana/transformers/generation/stopping_criteria.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time +from typing import Union + import torch from optimum.utils import logging @@ -21,10 +24,21 @@ logger = logging.get_logger(__name__) +# Instead of returning a tensor describing status of completeness of each sentence +# we only return a single boolean describing the state of the batch +# only when needs_tensor_output says so, we return array of booleans + + +def create_return_const_tensor(input_ids, is_done): + return torch.full((input_ids.shape[0],), 1 if is_done else 0, device=input_ids.device, dtype=torch.uint8) -def gaudi_MaxLengthCriteria_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + +def gaudi_MaxLengthCriteria_call( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs +) -> Union[torch.BoolTensor, bool]: token_idx = kwargs.get("token_idx", None) if token_idx is not None: + assert not kwargs["needs_tensor_output"] return token_idx >= self.max_length else: cur_len = input_ids.shape[-1] @@ -35,12 +49,66 @@ def gaudi_MaxLengthCriteria_call(self, input_ids: torch.LongTensor, scores: torc f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe " "exceptions, performance degradation, or nothing at all." ) - return is_done + return create_return_const_tensor(input_ids, is_done) -def gaudi_MaxNewTokensCriteria_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: +def gaudi_MaxNewTokensCriteria_call( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs +) -> Union[torch.BoolTensor, bool]: token_idx = kwargs.get("token_idx", None) if token_idx is not None: + assert not kwargs["needs_tensor_output"] return token_idx >= self.max_length else: - return input_ids.shape[-1] >= self.max_length + is_done = input_ids.shape[-1] >= self.max_length + return create_return_const_tensor(input_ids, is_done) + + +def gaudi_MaxTimeCriteria_call( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs +) -> Union[torch.BoolTensor, bool]: + is_done = time.time() - self.initial_timestamp > self.max_time + if kwargs["needs_tensor_output"]: + return create_return_const_tensor(input_ids, is_done) + else: + return is_done + + +def gaudi_EosTokenCriteria_call( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs +) -> Union[torch.BoolTensor, bool]: + self.eos_token_id = self.eos_token_id.to(input_ids.device) + token_idx = kwargs.get("token_idx", None) + if token_idx is not None: + assert not kwargs["needs_tensor_output"] + is_done = torch.isin(input_ids[:, token_idx - 1], self.eos_token_id) + else: + is_done = torch.isin(input_ids[:, -1], self.eos_token_id) + if kwargs["needs_tensor_output"]: + return is_done.byte() + else: + return torch.all(is_done).item() + + +def needs_tensor_output(token_idx, ignore_eos, eos_token_id) -> bool: + if token_idx is None: + return not ignore_eos and eos_token_id is not None + else: + # token_idx is present, so we have static shapes, so using single boolean + return False + + +def gaudi_StoppingCriteriaList_call( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs +) -> Union[torch.BoolTensor, bool]: + kwargs["needs_tensor_output"] = needs_tensor_output( + kwargs.get("token_idx", None), kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None) + ) + is_done = ( + torch.full((input_ids.shape[0],), 0, device=input_ids.device, dtype=torch.int8) + if kwargs["needs_tensor_output"] + else False + ) + for criteria in self: + is_done = is_done | criteria(input_ids, scores, **kwargs) + return is_done diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 33f2141fb..c858ef1e2 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -24,9 +24,18 @@ import torch.distributed as dist from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer -from transformers.generation.candidate_generator import CandidateGenerator +from transformers.generation.candidate_generator import ( + CandidateGenerator, + PromptLookupCandidateGenerator, + _crop_past_key_values, + _prepare_attention_mask, + _prepare_token_type_ids, +) from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, + MaxTimeCriteria, StoppingCriteriaList, validate_stopping_criteria, ) @@ -42,22 +51,25 @@ GenerationMixin, GenerationMode, _split_model_inputs, + _split_model_outputs, stack_model_outputs, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from transformers.utils import ModelOutput +from transformers.utils import ModelOutput, is_torchdynamo_compiling from optimum.utils import logging from ...utils import HabanaGenerationtime, HabanaProfile from ..integrations.deepspeed import unwrap_deepspeed_model +from .candidate_generator import GaudiAssistedCandidateGenerator from .configuration_utils import GaudiGenerationConfig if TYPE_CHECKING: from transformers import PreTrainedModel + from transformers.streamers import BaseStreamer - from .streamers import BaseStreamer + from .candidate_generator import GaudiCandidateGenerator MODELS_OPTIMIZED_WITH_STATIC_SHAPES = [ @@ -78,9 +90,12 @@ "mixtral", "gemma", "blip_text_model", + "seamless_m4t", + "starcoder2", "persimmon", "qwen2", "llava", + "llava_next", "stablelm", ] @@ -113,6 +128,15 @@ def incrementor(bucket_size, prompt_len): } +def get_final_stopping_criteria(x): + if isinstance(x, bool): + return x + elif torch.is_tensor(x): + return all(x) + else: + raise TypeError(f"The stopping criteria should be either a boolean or a torch.tensor but got {type(x)}.") + + class GaudiGenerationMixin(GenerationMixin): """ This class enables to perform fast generation in lazy mode and with HPU graphs. @@ -121,43 +145,6 @@ class GaudiGenerationMixin(GenerationMixin): sizes allows to make the most of lazy mode and HPU graphs. """ - @staticmethod - def _expand_inputs_for_generation( - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, - **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: - """ - Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]. - - Copied from Transformers: https://github.com/huggingface/transformers/blob/527ab894e59b6582578008e3b47648a65063f73d/src/transformers/generation/utils.py#L704 - The tensor `token_idx` is not expanded. - """ - - def _expand_dict_for_generation(dict_to_expand): - for key in dict_to_expand: - if ( - dict_to_expand[key] is not None - and key != "token_idx" - and key != "decoder_input_ids" - and isinstance(dict_to_expand[key], torch.Tensor) - ): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) - return dict_to_expand - - if input_ids is not None: - input_ids = input_ids.repeat_interleave(expand_size, dim=0) - - model_kwargs = _expand_dict_for_generation(model_kwargs) - - if is_encoder_decoder: - if model_kwargs.get("encoder_outputs") is None: - raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) - - return input_ids, model_kwargs - def _get_hpu_graphs_kwargs(self, model_kwargs): hpu_graphs_kwargs = {} if model_kwargs["limit_hpu_graphs"]: @@ -187,6 +174,7 @@ def _prepare_decoder_input_ids_for_generation( bos_token_id: int = None, device: torch.device = None, max_new_tokens: int = None, + pad_token_id: int = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, @@ -221,7 +209,10 @@ def _prepare_decoder_input_ids_for_generation( # creating padded decoder_input_ids to achieve static shapes. Later new tokens once generated are copied in to decoder_input_ids based on token_idx max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length decoder_input_ids_start = ( - torch.ones((batch_size, max_length), dtype=torch.long, device=device) * decoder_start_token_id + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) + decoder_input_ids_start = torch.nn.functional.pad( + decoder_input_ids_start, (0, max_length - 1), value=pad_token_id ) # no user input -> use decoder_start_token_id as decoder_input_ids @@ -241,7 +232,18 @@ def _prepare_decoder_input_ids_for_generation( isinstance(decoder_start_token_id, torch.Tensor) and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item() ): - decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if token_idx is None: + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + else: + max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length + if max_length != decoder_input_ids_start.shape[-1]: + decoder_input_ids_start = torch.nn.functional.pad( + decoder_input_ids_start, + (0, max_length - decoder_input_ids_start.shape[-1]), + value=pad_token_id, + ) + decoder_input_ids = decoder_input_ids_start.index_copy(1, token_idx, decoder_input_ids) + token_idx.add_(1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( @@ -249,8 +251,90 @@ def _prepare_decoder_input_ids_for_generation( dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask + else: + if token_idx is not None: + decoder_input_ids_len = decoder_input_ids.shape[-1] + max_length = ( + max_new_tokens + decoder_input_ids_len + if max_new_tokens is not None + else self.generation_config.max_length + ) + decoder_input_ids = torch.nn.functional.pad( + decoder_input_ids, (0, max_length - decoder_input_ids_len), value=pad_token_id + ) + token_idx.copy_(decoder_input_ids_len) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + pad_len = max_length - decoder_attention_mask.shape[-1] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :pad_len], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + return decoder_input_ids, model_kwargs + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + """ + Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]. + + Copied from Transformers: https://github.com/huggingface/transformers/blob/527ab894e59b6582578008e3b47648a65063f73d/src/transformers/generation/utils.py#L704 + The tensor `token_idx` is not expanded. + """ + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "token_idx" + and key != "decoder_input_ids" + and key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + def _pad_past_key_values(self, model_kwargs): + pad_amount = model_kwargs.get("kv_cache_pad_len", 0) + if model_kwargs["past_key_values"]: + for i in range(len(model_kwargs["past_key_values"])): + for j in range(len(model_kwargs["past_key_values"][i])): + if torch.is_tensor(model_kwargs["past_key_values"][i][j]): + model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad( + model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount) + ) + if model_kwargs.get("lazy_mode", False): + self.htcore_generation.mark_step() + + def _remove_past_key_values(self, model_kwargs): + if model_kwargs["past_key_values"]: + for i in range(len(model_kwargs["past_key_values"])): + for j in range(len(model_kwargs["past_key_values"][i])): + if torch.is_tensor(model_kwargs["past_key_values"][i][j]): + t = model_kwargs["past_key_values"][i][j] + del t + model_kwargs["past_key_values"][i][j] = None + del model_kwargs["past_key_values"] + model_kwargs["past_key_values"] = None + def _update_model_kwargs_for_generation( self, outputs: ModelOutput, @@ -265,10 +349,11 @@ def _update_model_kwargs_for_generation( """ # mark to identify starting from second token model_kwargs["first_token"] = False - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) + if not model_kwargs.get("pad_done", False): + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) if getattr(outputs, "state", None) is not None: model_kwargs["state"] = outputs.state @@ -311,6 +396,9 @@ def _update_model_kwargs_for_generation( if "token_idx_cpu" in model_kwargs: model_kwargs["token_idx_cpu"] += 1 + if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + return model_kwargs @torch.no_grad() @@ -345,7 +433,7 @@ def create_pad_arg(pad_amount, i, j): else: assert False elif model_kwargs["past_key_values"][0][0].dim() == 4: - return (0, 0, 0, pad_amount) # llama, falcon + return (0, 0, 0, pad_amount) # llama, falcon, qwen2 else: assert False, "Unknown case, please handle, or dont use bucketing" @@ -384,6 +472,180 @@ def create_pad_arg(pad_amount, i, j): model_kwargs["token_idx"] = torch.tensor(params["token_idx"], device=self.device) return input_ids, model_kwargs + def _get_candidate_generator( + self, + generation_config: GaudiGenerationConfig, + input_ids: torch.LongTensor, + inputs_tensor: torch.Tensor, + assistant_model: "PreTrainedModel", + logits_processor: LogitsProcessorList, + model_kwargs: Dict, + ) -> CandidateGenerator: + if generation_config.prompt_lookup_num_tokens is not None: + candidate_generator = PromptLookupCandidateGenerator( + num_output_tokens=generation_config.prompt_lookup_num_tokens, + max_matching_ngram_size=generation_config.max_matching_ngram_size, + max_length=generation_config.max_length, + ) + else: + candidate_generator = GaudiAssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + ) + return candidate_generator + + def _get_stopping_criteria( + self, + generation_config: GaudiGenerationConfig, + stopping_criteria: Optional[StoppingCriteriaList], + ignore_eos: bool = False, + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if generation_config.max_length is not None: + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + criteria.append( + MaxLengthCriteria( + max_length=generation_config.max_length, + max_position_embeddings=max_position_embeddings, + ) + ) + if generation_config.max_time is not None: + criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) + if not ignore_eos and generation_config.eos_token_id is not None: + criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _prepare_generated_length( + self, + generation_config, + has_default_max_length, + has_default_min_length, + model_input_name, + input_ids_length, + inputs_tensor, + has_token_idx, + ): + """Prepared max and min length in generaion configs to avoid clashes between similar attributes""" + + if generation_config.max_new_tokens is not None: + if not has_default_max_length and generation_config.max_length is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + if has_token_idx: + generation_config.max_length = input_ids_length + else: + generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length + # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` + elif ( + model_input_name == "inputs_embeds" + and input_ids_length != inputs_tensor.shape[1] + and not self.config.is_encoder_decoder + ): + generation_config.max_length -= inputs_tensor.shape[1] + + # same for min length + if generation_config.min_new_tokens is not None: + if not has_default_min_length: + logger.warning( + f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(=" + f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + if has_token_idx: + generation_config.min_length = input_ids_length + else: + generation_config.min_length = generation_config.min_new_tokens + input_ids_length + + elif ( + model_input_name == "inputs_embeds" + and input_ids_length != inputs_tensor.shape[1] + and not self.config.is_encoder_decoder + ): + generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) + + return generation_config + + def _prepare_generation_config( + self, generation_config: GaudiGenerationConfig, **kwargs: Dict + ) -> Tuple[GaudiGenerationConfig, Dict]: + """ + Copied from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/generation/utils.py#L1230 + Differences: + - add management of `static_shapes` and `ignore_eos` in the generation config + - workaround for `token_type_ids` for Falcon + """ + # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400) + # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with + # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`. + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # three conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same); + # 3) the user must have set generation parameters in the model config. + # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. + if ( + not is_torchdynamo_compiling() + and self.generation_config._from_model_config + and self.generation_config._original_object_hash == hash(self.generation_config) + and self.config._has_non_default_generation_parameters() + ): + new_generation_config = GaudiGenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled. + if is_torchdynamo_compiling(): + model_kwargs = kwargs + generate_attributes_in_kwargs = [ + key for key, value in kwargs.items() if getattr(generation_config, key, None) != value + ] + if len(generate_attributes_in_kwargs) > 0: + raise ValueError( + "`torch.compile` exception: all generation configuration attributes must be passed within a " + f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})." + ) + else: + generation_config = copy.deepcopy(generation_config) + if generation_config.static_shapes is None: + generation_config.static_shapes = self.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES + if self.config.model_type == "vision-encoder-decoder": + generation_config.static_shapes = ( + self.config.decoder.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES + ) + self.generation_config.static_shapes = generation_config.static_shapes + if generation_config.ignore_eos is None: + generation_config.ignore_eos = kwargs.get("ignore_eos", kwargs.get("lazy_mode", None)) + self.generation_config.ignore_eos = generation_config.ignore_eos + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): + for key in ["token_type_ids"]: + model_kwargs.pop(key, None) + + return generation_config, model_kwargs + @torch.no_grad() def generate( self, @@ -402,6 +664,7 @@ def generate( profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, iteration_times: Optional[List[float]] = None, + profiling_record_shapes: Optional[bool] = False, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -424,12 +687,12 @@ def generate( inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. generation_config (`transformers.generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading + `generation_config` is not provided, the default will be used, which has the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. @@ -438,7 +701,7 @@ def generate( generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a + Custom stopping criteria that complements the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is @@ -475,6 +738,8 @@ def generate( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -482,7 +747,7 @@ def generate( Return: [`transformers.utils.ModelOutput`] or `torch.LongTensor`: A [`transformers.generationutils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`transformers.generationutils.ModelOutput`] types are: - [`transformers.generation.GenerateDecoderOnlyOutput`], @@ -492,57 +757,33 @@ def generate( - [`transformers.generation.GenerateEncoderDecoderOutput`], - [`transformers.generation.GenerateBeamEncoderDecoderOutput`] """ + if iteration_times is not None: + hb_gen_time = HabanaGenerationtime(iteration_times=iteration_times) + hb_gen_time.start() + else: + hb_gen_time = None if synced_gpus is None: if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: synced_gpus = True else: synced_gpus = False + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() if hpu_graphs and not lazy_mode: raise ValueError( "`hpu_graphs` is True but `lazy_mode` is False. HPU graphs require `lazy_mode` to be set to True." ) - - # priority: `generation_config` argument > `model.generation_config` (the default generation config) - if generation_config is None: - # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # three conditions must be met - # 1) the generation config must have been created from the model config (`_from_model_config` field); - # 2) the generation config must have seen no modification since its creation (the hash is the same); - # 3) the user must have set generation parameters in the model config. - if ( - self.generation_config._from_model_config - and self.generation_config._original_object_hash == hash(self.generation_config) - and self.config._has_non_default_generation_parameters() - ): - new_generation_config = GaudiGenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: - warnings.warn( - "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" - ) - self.generation_config = new_generation_config - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - if generation_config.static_shapes is None: - generation_config.static_shapes = self.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES - if self.config.model_type == "vision-encoder-decoder": - generation_config.static_shapes = self.config.decoder.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES - self.generation_config.static_shapes = generation_config.static_shapes - if generation_config.ignore_eos is None: - generation_config.ignore_eos = kwargs.get("ignore_eos", lazy_mode) num_virtual_tokens = kwargs.pop("num_virtual_tokens", 0) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): - for key in ["token_type_ids"]: - model_kwargs.pop(key, None) + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined + if synced_gpus is None: + if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: + synced_gpus = True + else: + synced_gpus = False logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -579,7 +820,9 @@ def generate( model_kwargs["use_cache"] = True else: model_kwargs["use_cache"] = generation_config.use_cache + self.generation_config.max_length = generation_config.max_length + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs @@ -592,8 +835,8 @@ def generate( not generation_config.bucket_internal and generation_config.bucket_size > 0 and ( - self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH - or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + generation_config.get_generation_mode(assistant_model) == GenerationMode.GREEDY_SEARCH + or generation_config.get_generation_mode(assistant_model) == GenerationMode.BEAM_SEARCH ) ) model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1 @@ -603,6 +846,9 @@ def generate( ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size + # Below condition checked explicitly since llama supports bucket_internal even without reuse_cache + if generation_config.bucket_internal: + assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" if generation_config.reuse_cache: assert self.config.model_type in [ "llama", @@ -610,7 +856,8 @@ def generate( "falcon", "mixtral", "phi", - ], "reuse_cache only supported by llama, mistral, falcon, mixtral and phi at the moment" + "qwen2", + ], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi and qwen2 at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 @@ -688,6 +935,7 @@ def generate( bos_token_id=generation_config.bos_token_id, device=inputs_tensor.device, max_new_tokens=generation_config.max_new_tokens, + pad_token_id=generation_config.pad_token_id, ) else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") @@ -698,37 +946,30 @@ def generate( # 6. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if generation_config.max_new_tokens is not None: - if not has_default_max_length and generation_config.max_length is not None: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - if "token_idx" in model_kwargs: - generation_config.max_length = input_ids_length - else: - generation_config.max_length = generation_config.max_new_tokens + input_ids_length - # otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length` - elif ( - model_input_name == "inputs_embeds" - and inputs_tensor.shape[:-1] != input_ids.shape - and not self.config.is_encoder_decoder - ): - generation_config.max_length -= inputs_tensor.shape[1] + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + has_token_idx="token_idx" in model_kwargs, + ) - # if we don't pass `past_key_values` and a cache_implementation is specified - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get( - "past_key_values", False - ): - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static": + if model_kwargs.get("past_key_values", False) is not False: + raise ValueError( + "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." + ) + cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] + if not callable(getattr(self, "_setup_cache", None)): + raise ValueError( + "The `generation_config` defines a `cache_implementation` that is not compatible with this model." + " Make sure it has a `_setup_cache` function." + ) + self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) self._validate_generated_length( generation_config, @@ -743,6 +984,7 @@ def generate( model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16 # determine whether limit_hpu_graphs needs to be used + model_kwargs["use_hpu_graphs"] = hpu_graphs model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs # prepare for allocate kv cache @@ -752,6 +994,9 @@ def generate( model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False + model_kwargs["flash_attention_fast_softmax"] = ( + True if generation_config.flash_attention_fast_softmax else False + ) model_kwargs["num_virtual_tokens"] = num_virtual_tokens if not self.config.is_encoder_decoder: @@ -764,14 +1009,17 @@ def generate( unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens ) - model_kwargs["kv_cache_len"] = calculated_max_length + if generation_config.use_cache: + model_kwargs["kv_cache_len"] = calculated_max_length + model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens - if self.config.model_type in ["llama", "falcon", "mistral"]: + if self.config.model_type in ["llama", "falcon", "mistral", "qwen2"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) # 7. determine generation mode - generation_mode = self._get_generation_mode(generation_config, assistant_model) + generation_mode = generation_config.get_generation_mode(assistant_model) + if generation_config.bucket_size > 0: assert generation_config.static_shapes, "bucket_size > 0 can be set only when static_shapes is set" # if generation_config.bucket_size <= 0, padding is handled by the generating fn (like greedy_search) @@ -815,7 +1063,9 @@ def generate( # 9. prepare stopping criteria self.generation_config.generation_mode = generation_mode prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria + generation_config=generation_config, + stopping_criteria=stopping_criteria, + ignore_eos=self.generation_config.ignore_eos, ) # In lazy mode, import Habana torch to be able to add mark_step() @@ -847,7 +1097,7 @@ def generate( ) # 12. run assisted generate - return self.assisted_decoding( + result = self._assisted_decoding( input_ids, candidate_generator=candidate_generator, do_sample=generation_config.do_sample, @@ -855,22 +1105,25 @@ def generate( logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, + lazy_mode=lazy_mode, + ignore_eos=generation_config.ignore_eos, + profiling_warmup_steps=profiling_warmup_steps, + profiling_steps=profiling_steps, + hb_gen_time=hb_gen_time, **model_kwargs, ) if generation_mode == GenerationMode.GREEDY_SEARCH: # 11. run greedy search - return self.greedy_search( + result = self._greedy_search( input_ids, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -880,7 +1133,8 @@ def generate( ignore_eos=generation_config.ignore_eos, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, - iteration_times=iteration_times, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) @@ -888,14 +1142,13 @@ def generate( if not model_kwargs["use_cache"]: raise ValueError("Contrastive search requires `use_cache=True`") - return self.contrastive_search( + result = self._contrastive_search( input_ids, top_k=generation_config.top_k, penalty_alpha=generation_config.penalty_alpha, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -904,6 +1157,8 @@ def generate( sequential=generation_config.low_memory, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) @@ -920,13 +1175,12 @@ def generate( ) # 13. run sample - return self.sample( + result = self._sample( input_ids, logits_processor=prepared_logits_processor, logits_warper=logits_warper, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -936,7 +1190,8 @@ def generate( ignore_eos=generation_config.ignore_eos, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, - iteration_times=iteration_times, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) @@ -959,13 +1214,12 @@ def generate( **model_kwargs, ) # 13. run beam search - return self.beam_search( + result = self._beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -974,6 +1228,8 @@ def generate( lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) @@ -1001,14 +1257,13 @@ def generate( ) # 14. run beam sample - return self.beam_sample( + result = self._beam_sample( input_ids, beam_scorer, logits_processor=prepared_logits_processor, logits_warper=logits_warper, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1016,6 +1271,8 @@ def generate( lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) @@ -1039,13 +1296,12 @@ def generate( **model_kwargs, ) # 13. run beam search - return self.group_beam_search( + result = self._group_beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1053,6 +1309,8 @@ def generate( lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) @@ -1116,13 +1374,12 @@ def typeerror(): **model_kwargs, ) # 13. run beam search - return self.constrained_beam_search( + result = self._constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1130,11 +1387,23 @@ def typeerror(): lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, + hb_gen_time=hb_gen_time, + profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if not callable(getattr(self, "_reset_cache", None)): + raise ValueError( + "A `static_cache` was used to generate but there was a failure when trying to release the cache. " + " Make sure this model implements a `_reset_cache` function." + ) + self._reset_cache() + + return result + @torch.no_grad() - def contrastive_search( + def _contrastive_search( self, input_ids: torch.LongTensor, top_k: Optional[int] = 1, @@ -1155,6 +1424,8 @@ def contrastive_search( lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -1163,7 +1434,7 @@ def contrastive_search( - In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use + In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -1214,6 +1485,8 @@ def contrastive_search( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -1242,7 +1515,7 @@ def contrastive_search( >>> input_prompt = "DeepMind Company is" >>> input_ids = tokenizer(input_prompt, return_tensors="pt") >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)]) - >>> outputs = model.contrastive_search( + >>> outputs = model._contrastive_search( ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) @@ -1251,7 +1524,7 @@ def contrastive_search( raise NotImplementedError("Contrastive search is not supported by optimum-habana yet.") - def greedy_search( + def _greedy_search( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, @@ -1270,7 +1543,8 @@ def greedy_search( ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, - iteration_times: Optional[List[float]] = None, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -1279,7 +1553,7 @@ def greedy_search( - In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() + In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -1328,6 +1602,8 @@ def greedy_search( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -1368,7 +1644,7 @@ def greedy_search( ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.greedy_search( + >>> outputs = model._greedy_search( ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria ... ) @@ -1388,10 +1664,30 @@ def greedy_search( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if not self.generation_config.ignore_eos: + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -1420,45 +1716,40 @@ def greedy_search( ) # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + this_peer_finished = False if not ignore_eos: - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) + hb_profer = HabanaProfile( + warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes + ) hb_profer.start() - this_peer_finished = False # used by synced_gpus only bucket_size = model_kwargs.get("bucket_size", -1) prev_idx = -1 # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs.get("bucket_internal", None) reduce_recompile = model_kwargs.get("reduce_recompile", False) - prompt_len = input_ids.shape[-1] if not bucket_internal: if bucket_size >= 0: - inc = iter(incrementor(bucket_size, prompt_len)) + inc = iter(incrementor(bucket_size, cur_len)) if bucket_size > 0: assert "position_ids" not in model_kwargs, "Untested path" - cur_len = prompt_len token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: # Update cur_len in case of static shapes cur_len = token_idx.item() - if iteration_times is not None: - hb_gen_time = HabanaGenerationtime(iteration_times=iteration_times) - hb_gen_time.start() - while True: + + time_to_first_token_done = False + model_kwargs["pad_done"] = False + model_kwargs["lazy_mode"] = lazy_mode + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): if lazy_mode: self.htcore_generation.mark_step() - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - if bucket_size > 0 and not bucket_internal: # it will not have been padded if bucket_size > 0 params = next(inc) @@ -1467,7 +1758,6 @@ def greedy_search( ) # prepare model inputs - model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1548,7 +1838,9 @@ def greedy_search( if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, ) if bucket_size > 0 and bucket_internal: # Calculate slice idx for kv cache during the decode phase. @@ -1562,24 +1854,44 @@ def greedy_search( model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] cur_len = cur_len + 1 - # if eos_token was found in one sentence, set sentence to finished - if not ignore_eos and eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + if ignore_eos: + this_peer_finished = stopping_criteria( + input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id + ) + else: + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id ) - # stop when each sentence is finished - if not ignore_eos and unfinished_sequences.max() == 0: - this_peer_finished = True - - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores, token_idx=cur_len): - this_peer_finished = True + this_peer_finished = unfinished_sequences.max() == 0 + if ( + not model_kwargs.get("pad_done", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Pad the returned pask key values tensors from prefill phase forward run to maximum length + # before starting the decode phase. + self._pad_past_key_values(model_kwargs) + model_kwargs["pad_done"] = True hb_profer.step() - if iteration_times is not None: + if hb_gen_time is not None: + if not time_to_first_token_done: + time_to_first_token_done = True + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() hb_gen_time.step() - if this_peer_finished and not synced_gpus: - break + + if ( + model_kwargs.get("use_hpu_graphs", False) + and model_kwargs.get("limit_hpu_graphs", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Clear HPU graphs input tensors of the decode phase after the full generation while loop + self.clear_inputs() + # Delete past key value tensors + self._remove_past_key_values(model_kwargs) hb_profer.stop() if streamer is not None: @@ -1610,7 +1922,7 @@ def greedy_search( else: return input_ids - def sample( + def _sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, @@ -1630,7 +1942,8 @@ def sample( ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, - iteration_times: Optional[List[float]] = None, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -1639,7 +1952,7 @@ def sample( - In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -1691,6 +2004,8 @@ def sample( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -1744,7 +2059,7 @@ def sample( >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model.sample( + >>> outputs = model._sample( ... input_ids, ... logits_processor=logits_processor, ... logits_warper=logits_warper, @@ -1769,10 +2084,30 @@ def sample( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if not self.generation_config.ignore_eos: + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( @@ -1803,46 +2138,42 @@ def sample( # keep track of which sequences are already finished # TODO: no ignore_eos check here since there is a compilation error, will add ignore_eos here if fixed - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) - hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) - hb_profer.start() - this_peer_finished = False # used by synced_gpus only - cur_len = input_ids.shape[-1] - token_idx = model_kwargs.get("token_idx", None) - if token_idx is not None: - # Update cur_len in case of static shapes - cur_len = token_idx.item() - if iteration_times is not None: - hb_gen_time = HabanaGenerationtime(iteration_times=iteration_times) - hb_gen_time.start() + batch_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) bucket_size = model_kwargs.get("bucket_size", -1) prev_idx = -1 # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs.get("bucket_internal", None) reduce_recompile = model_kwargs.get("reduce_recompile", False) - prompt_len = input_ids.shape[-1] + hb_profer = HabanaProfile( + warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes + ) + hb_profer.start() + if not bucket_internal: if bucket_size >= 0: - inc = iter(incrementor(bucket_size, prompt_len)) + inc = iter(incrementor(bucket_size, cur_len)) if bucket_size > 0: assert "position_ids" not in model_kwargs, "Untested path" + token_idx = model_kwargs.get("token_idx", None) + if token_idx is not None: + # Update cur_len in case of static shapes + cur_len = token_idx.item() + # auto-regressive generation - while True: + time_to_first_token_done = False + model_kwargs["pad_done"] = False + model_kwargs["lazy_mode"] = lazy_mode + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): if lazy_mode: self.htcore_generation.mark_step() - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - if bucket_size > 0 and not bucket_internal: # it will not have been padded if bucket_size > 0 params = next(inc) @@ -1851,7 +2182,6 @@ def sample( ) # prepare model inputs - model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1927,7 +2257,9 @@ def sample( if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, ) cur_len = cur_len + 1 if bucket_size > 0 and bucket_internal: @@ -1941,25 +2273,45 @@ def sample( else: model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] - # if eos_token was found in one sentence, set sentence to finished - if not ignore_eos and eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + if ignore_eos: + this_peer_finished = stopping_criteria( + input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id ) + else: + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id + ) + this_peer_finished = unfinished_sequences.max() == 0 - # stop when each sentence is finished - if not ignore_eos and unfinished_sequences.max() == 0: - this_peer_finished = True - - # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores, token_idx=cur_len): - this_peer_finished = True + if ( + not model_kwargs.get("pad_done", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Pad the returned pask key values tensors from prefill phase forward run to maximum length + # before starting the decode phase. + self._pad_past_key_values(model_kwargs) + model_kwargs["pad_done"] = True hb_profer.step() - if iteration_times is not None: + if hb_gen_time is not None: + if not time_to_first_token_done: + time_to_first_token_done = True + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() hb_gen_time.step() - if this_peer_finished and not synced_gpus: - break + + if ( + model_kwargs.get("use_hpu_graphs", False) + and model_kwargs.get("limit_hpu_graphs", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Clear HPU graphs input tensors of the decode phase after the full generation while loop + self.clear_inputs() + # Delete past key value tensors + self._remove_past_key_values(model_kwargs) hb_profer.stop() if streamer is not None: @@ -1990,7 +2342,7 @@ def sample( else: return input_ids - def beam_search( + def _beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, @@ -2009,6 +2361,8 @@ def beam_search( lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r""" @@ -2017,7 +2371,7 @@ def beam_search( - In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate() + In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -2067,6 +2421,8 @@ def beam_search( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2123,7 +2479,7 @@ def beam_search( ... ] ... ) - >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt bist du?'] @@ -2144,7 +2500,28 @@ def beam_search( if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if not self.generation_config.ignore_eos: + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores @@ -2165,10 +2542,13 @@ def beam_search( num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: # Update cur_len in case of static shapes cur_len = token_idx.item() + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: raise ValueError( @@ -2285,9 +2665,11 @@ def expand_if_needed(tensor, new_size, value, dim=-1): input_ids = torch.stack(result) return input_ids - hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) + hb_profer = HabanaProfile( + warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes + ) hb_profer.start() - this_peer_finished = False # used by synced_gpus only + this_peer_finished = False bucket_size = model_kwargs.get("bucket_size", -1) reduce_recompile = model_kwargs.get("reduce_recompile", False) @@ -2298,18 +2680,11 @@ def expand_if_needed(tensor, new_size, value, dim=-1): assert "position_ids" not in model_kwargs, "Untested path" if self.generation_config.static_shapes: initial_ids = input_ids[::num_beams, 0:cur_len] - while True: + + time_to_first_token_done = False + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): if lazy_mode: self.htcore_generation.mark_step() - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break if bucket_size > 0: # it will not have been padded if bucket_size > 0 @@ -2334,6 +2709,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): "transo_xl", "xlnet", "cpm", + "jamba", ] ): raise RuntimeError( @@ -2482,9 +2858,11 @@ def expand_if_needed(tensor, new_size, value, dim=-1): input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, ) - if model_kwargs["past_key_values"] is not None: + if model_kwargs.get("past_key_values", None) is not None: if model_kwargs["reuse_cache"]: model_kwargs["past_key_values"] = unwrap_deepspeed_model(self).reorder_kv_cache(beam_idx) else: @@ -2509,13 +2887,20 @@ def expand_if_needed(tensor, new_size, value, dim=-1): and num_eos_tokens >= num_beams_tensor ): break - elif stopping_criteria(input_ids, scores, token_idx=cur_len): - break - elif stopping_criteria(input_ids, scores) or (beam_scorer.is_done and not lazy_mode): - if not synced_gpus: + elif get_final_stopping_criteria(stopping_criteria(input_ids, scores, token_idx=cur_len)): break - else: - this_peer_finished = True + elif get_final_stopping_criteria(stopping_criteria(input_ids, scores)) or ( + beam_scorer.is_done and not lazy_mode + ): + this_peer_finished = True + + if hb_gen_time is not None: + if not time_to_first_token_done: + time_to_first_token_done = True + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + hb_gen_time.step() hb_profer.stop() if self.generation_config.static_shapes: @@ -2586,7 +2971,7 @@ def move(obj, device): else: return sequence_outputs["sequences"] - def beam_sample( + def _beam_sample( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, @@ -2605,6 +2990,8 @@ def beam_sample( lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r""" @@ -2613,7 +3000,7 @@ def beam_sample( - In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate() + In most cases, you do not need to call [`~generation._GenerationMixin.beam_sample`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -2663,6 +3050,8 @@ def beam_sample( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2727,7 +3116,7 @@ def beam_sample( ... ] ... ) - >>> outputs = model.beam_sample( + >>> outputs = model._beam_sample( ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs ... ) @@ -2737,7 +3126,7 @@ def beam_sample( raise NotImplementedError("Beam search sampling is not supported by optimum-habana yet.") - def group_beam_search( + def _group_beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, @@ -2755,6 +3144,8 @@ def group_beam_search( lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ): r""" @@ -2763,7 +3154,7 @@ def group_beam_search( - In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use + In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -2809,6 +3200,8 @@ def group_beam_search( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific kwargs that will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2869,7 +3262,7 @@ def group_beam_search( ... ] ... ) - >>> outputs = model.group_beam_search( + >>> outputs = model._group_beam_search( ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs ... ) @@ -2879,7 +3272,7 @@ def group_beam_search( raise NotImplementedError("Group beam search is not supported by optimum-habana yet.") - def constrained_beam_search( + def _constrained_beam_search( self, input_ids: torch.LongTensor, constrained_beam_scorer: ConstrainedBeamSearchScorer, @@ -2897,6 +3290,8 @@ def constrained_beam_search( lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r""" @@ -2905,7 +3300,7 @@ def constrained_beam_search( - In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use + In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -2956,6 +3351,8 @@ def constrained_beam_search( Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3015,7 +3412,7 @@ def constrained_beam_search( ... ] ... ) - >>> outputs = model.constrained_beam_search( + >>> outputs = model._constrained_beam_search( ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs ... ) @@ -3036,7 +3433,28 @@ def constrained_beam_search( if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if not self.generation_config.ignore_eos: + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores @@ -3057,10 +3475,13 @@ def constrained_beam_search( num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: # Update cur_len in case of static shapes cur_len = token_idx.item() + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: raise ValueError( @@ -3090,22 +3511,17 @@ def constrained_beam_search( beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) - this_peer_finished = False # used by synced_gpus only + this_peer_finished = False + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) + hb_profer = HabanaProfile( + warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes + ) hb_profer.start() - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break + time_to_first_token_done = False + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3203,9 +3619,11 @@ def constrained_beam_search( else: input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, ) - if model_kwargs["past_key_values"] is not None: + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) @@ -3218,11 +3636,18 @@ def constrained_beam_search( hb_profer.step() - if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores, token_idx=cur_len): - if not synced_gpus: - break - else: - this_peer_finished = True + if constrained_beam_scorer.is_done or get_final_stopping_criteria( + stopping_criteria(input_ids, scores, token_idx=cur_len) + ): + this_peer_finished = True + + if hb_gen_time is not None: + if not time_to_first_token_done: + time_to_first_token_done = True + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + hb_gen_time.step() hb_profer.stop() sequence_outputs = constrained_beam_scorer.finalize( @@ -3268,11 +3693,10 @@ def constrained_beam_search( else: return sequence_outputs["sequences"] - def assisted_decoding( + def _assisted_decoding( self, input_ids: torch.LongTensor, - assistant_model: Optional["PreTrainedModel"] = None, - candidate_generator: Optional["CandidateGenerator"] = None, + candidate_generator: Optional["GaudiCandidateGenerator"] = None, do_sample: bool = False, logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, @@ -3285,10 +3709,13 @@ def assisted_decoding( output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, lazy_mode: Optional[bool] = False, + ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, - streamer: Optional["BaseStreamer"] = None, + hb_gen_time: Optional[HabanaGenerationtime] = None, + profiling_record_shapes: Optional[bool] = False, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -3299,7 +3726,7 @@ def assisted_decoding( - In most cases, you do not need to call [`transformers.generation.GenerationMixin.candidate_decoding`] directly. Use + In most cases, you do not need to call [`transformers.generation.GenerationMixin._assisted_decoding`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -3310,12 +3737,7 @@ def assisted_decoding( The sequence used as a prompt for the generation. candidate_generator (`CandidateGenerator`, *optional*): A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. - assistant_model (`PreTrainedModel`, *optional*): - An assistant model that can be used to accelerate generation. The assistant model must have the exact - same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model - is much faster than running generation with the model you're calling generate from. As such, the - assistant model should be much smaller. + more information, the documentation of [`CandidateGenerator`] should be read. do_sample (`bool`, *optional*, defaults to `False`): Whether or not to use sampling ; use greedy decoding otherwise. logits_processor (`LogitsProcessorList`, *optional*): @@ -3347,15 +3769,17 @@ def assisted_decoding( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). profiling_warmup_steps (`int`, *optional*, defaults to 0): Number of steps to ignore for profling. profiling_steps (`int`, *optional*, defaults to 0): Number of steps to be captured when enabling profiling. - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + profiling_record_shapes (`bool`, *optional*, defaults to False): + Record shapes when enabling profiling. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3378,6 +3802,7 @@ def assisted_decoding( ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) + >>> from transformers.generation import AssistedCandidateGenerator >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") @@ -3393,13 +3818,316 @@ def assisted_decoding( ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.assisted_decoding( - ... input_ids, + >>> candidate_generator = AssistedCandidateGenerator( + ... input_ids=input_ids, ... assistant_model=assistant_model, + ... generation_config=model.generation_config, + ... logits_processor=logits_processor, + ... model_kwargs={}, + ... ) + >>> outputs = model._assisted_decoding( + ... input_ids, + ... candidate_generator=candidate_generator, ... logits_processor=logits_processor, ... stopping_criteria=stopping_criteria, ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" - raise NotImplementedError("Assisted decoding is not supported by optimum-habana yet.") + # init values + # do_sample = logits_warper is not None + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + if not ignore_eos: + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + + hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) + hb_profer.start() + this_peer_finished = False + + token_idx = model_kwargs.get("token_idx", None) + time_to_first_token_done = False + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + if lazy_mode: + self.htcore_generation.mark_step() + + if token_idx is not None: + # Update cur_len in case of static shapes + cur_len = token_idx.item() + else: + cur_len = input_ids.shape[-1] + + # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # 1. Fetch candidate sequences from a `CandidateGenerator` + + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids[:, :cur_len]) + candidate_input_ids = candidate_input_ids.to(self.device) + if candidate_logits is not None: + candidate_logits = candidate_logits.to(self.device) + + if self.generation_config.static_shapes: + candidate_length = candidate_input_ids.shape[1] - cur_len + else: + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + is_done_candidate = stopping_criteria(candidate_input_ids, None) + + # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain + # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, + # we use this forward pass to also pick the subsequent logits in the original model. + + # 2.1. Prepare the model inputs + model_kwargs = _prepare_attention_mask( + model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + ) + model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1]) + if "cache_position" in model_kwargs: + model_kwargs["cache_position"] = torch.cat( + ( + model_kwargs["cache_position"], + torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long), + ), + dim=0, + ) + + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs) + if "num_logits_to_keep" in model_inputs: + model_inputs["num_logits_to_keep"] = candidate_length + 1 + + hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) + + # 2.2. Run a forward pass on the candidate sequence + outputs = self( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **hpu_graphs_kwargs, + ) + + # 2.3. Process the new logits + new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present + next_token_logits = new_logits.clone() + if len(logits_processor) > 0: + for i in range(candidate_length + 1): + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + if do_sample and len(logits_warper) > 0: + for i in range(candidate_length + 1): + new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + + # 3. Select the accepted tokens. There are two possible cases: + # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) + # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). + if do_sample and candidate_logits is not None: + from transformers.generation.utils import _speculative_sampling + + valid_tokens, n_matches = _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + is_done_candidate, + ) + + # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the + # original model logits with the candidate tokens. We can keep the candidate tokens until the first + # mismatch, or until the max length is reached. + else: + if do_sample: + probs = new_logits.softmax(dim=-1) + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + else: + selected_tokens = new_logits.argmax(dim=-1) + + candidate_new_tokens = candidate_input_ids[:, cur_len:] + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + + # Ensure we don't generate beyond max_len or an EOS token + if is_done_candidate and n_matches == candidate_length: + n_matches -= 1 + valid_tokens = selected_tokens[:, : n_matches + 1] + + # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated + # by the model after the last candidate match is also valid, as it is generated from a correct sequence. + # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there + # is no match. + + # 4.1. Get the valid continuation, after the matching tokens + if self.generation_config.static_shapes: + input_ids[:, cur_len : cur_len + n_matches + 1] = valid_tokens + else: + input_ids = torch.cat((input_ids, valid_tokens), dim=-1) + if streamer is not None: + streamer.put(valid_tokens.cpu()) + new_cur_len = input_ids.shape[-1] + + # 4.2. Discard past key values relative to unused assistant tokens + new_cache_size = new_cur_len - 1 + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + + # 5. Update the candidate generation strategy if needed + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + + # Store scores, attentions and hidden_states when required + # Assistant: modified to append one tuple element per token, as in the other generation methods. + if return_dict_in_generate: + if output_scores: + scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) + if output_logits: + raw_logits += (next_token_logits,) + + if "past_key_values" not in model_kwargs: + added_len = new_cur_len + else: + added_len = n_matches + 1 + + if output_attentions: + if self.config.is_encoder_decoder: + cross_attentions = _split_model_outputs( + cross_attentions, outputs.cross_attentions, cur_len, added_len + ) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.decoder_attentions, + cur_len, + added_len, + is_decoder_attention=True, + ) + else: + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.attentions, + cur_len, + added_len, + is_decoder_attention=True, + ) + if output_hidden_states: + if self.config.is_encoder_decoder: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len + ) + else: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.hidden_states, cur_len, added_len + ) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + if ignore_eos: + this_peer_finished = stopping_criteria( + input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id + ) + else: + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id + ) + this_peer_finished = unfinished_sequences.max() == 0 + + hb_profer.step() + if hb_gen_time is not None: + if not time_to_first_token_done: + time_to_first_token_done = True + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + hb_gen_time.step() + + if this_peer_finished and not synced_gpus: + break + + hb_profer.stop() + if streamer is not None: + streamer.end() + + if ( + hasattr(candidate_generator, "assistant_model") + and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" + ): + candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( + candidate_generator.num_assistant_tokens + ) + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids diff --git a/optimum/habana/transformers/modeling_attn_mask_utils.py b/optimum/habana/transformers/modeling_attn_mask_utils.py index 4fe621709..1aa21bed3 100755 --- a/optimum/habana/transformers/modeling_attn_mask_utils.py +++ b/optimum/habana/transformers/modeling_attn_mask_utils.py @@ -50,16 +50,14 @@ def _make_causal_mask( # add lower triangular sliding window mask if necessary if sliding_window is not None: - diagonal = past_key_values_length - sliding_window + 1 + diagonal = past_key_values_length - sliding_window - 1 - # Replace triu with below + # Replace tril with below row_indices = torch.arange(mask.size(0), device=mask.device).view(-1, 1) # Reshape to column vector col_indices = torch.arange(mask.size(1), device=mask.device) - context_mask = 1 - (col_indices >= row_indices + diagonal).int().expand_as( - mask - ) # Expand to match mask shape + context_mask = (col_indices <= row_indices + diagonal).bool().expand_as(mask) # Expand to match mask shape - mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + mask.masked_fill_(context_mask, torch.finfo(dtype).min) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 35ec6c0c1..8d6791188 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -19,12 +19,16 @@ from .generation import ( GaudiGenerationConfig, GaudiGenerationMixin, + gaudi_EosTokenCriteria_call, gaudi_MaxLengthCriteria_call, gaudi_MaxNewTokensCriteria_call, + gaudi_MaxTimeCriteria_call, + gaudi_StoppingCriteriaList_call, ) from .models import ( GaudiBloomForCausalLM, GaudiBloomMLP, + GaudiCLIPVisionEmbeddings, GaudiCodeGenAttention, GaudiCodeGenForCausalLM, GaudiFalconAttention, @@ -35,9 +39,11 @@ GaudiGemmaDecoderLayer, GaudiGemmaForCausalLM, GaudiGPT2Attention, + GaudiGPT2Block, GaudiGPT2LMHeadModel, GaudiGPTBigCodeForCausalLM, GaudiGPTJAttention, + GaudiGPTJBlock, GaudiGPTJForCausalLM, GaudiGPTNeoXForCausalLM, GaudiLlamaAttention, @@ -49,6 +55,7 @@ GaudiLlamaModel, GaudiLlamaRotaryEmbedding, GaudiLlavaForConditionalGeneration, + GaudiLlavaNextForConditionalGeneration, GaudiMistralAttention, GaudiMistralDecoderLayer, GaudiMistralForCausalLM, @@ -66,9 +73,16 @@ GaudiPhiDecoderLayer, GaudiPhiForCausalLM, GaudiPhiModel, + GaudiQwen2Attention, GaudiQwen2DecoderLayer, GaudiQwen2ForCausalLM, + GaudiQwen2MLP, + GaudiQwen2Model, + GaudiStableLmDecoderLayer, GaudiStableLmForCausalLM, + GaudiStarcoder2DecoderLayer, + GaudiStarcoder2ForCausalLM, + LlamaConfig, MistralConfig, MixtralConfig, _gaudi_wav2vec2_compute_mask_indices, @@ -103,13 +117,11 @@ gaudi_conv1d_forward, gaudi_esm_for_protein_folding_forward, gaudi_esmfolding_trunk_forward, - gaudi_falcon_attention_split_heads, gaudi_falcon_linear_forward, gaudi_gemma_attention_forward, gaudi_gemma_model_forward, gaudi_generate_speech, gaudi_get_extended_attention_mask, - gaudi_gpt2_block_forward, gaudi_gpt2_forward, gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, @@ -118,7 +130,6 @@ gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, - gaudi_gptj_block_forward, gaudi_gptj_model_forward, gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, @@ -135,16 +146,26 @@ gaudi_persimmon_attention_forward, gaudi_persimmon_decoder_layer_forward, gaudi_persimmon_model_forward, - gaudi_qwen2_attention_forward, - gaudi_qwen2_model_forward, + gaudi_qwen2_rmsnorm_forward, gaudi_rot_matmul, gaudi_rot_vec_mul, + gaudi_SeamlessM4TAttention_forward, + gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths, + gaudi_SeamlessM4TDecoder_forward, + gaudi_SeamlessM4TDecoderLayer_forward, + gaudi_SeamlessM4TForTextToSpeech_forward, + gaudi_SeamlessM4TForTextToSpeech_generate, + gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation, + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward, + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation, + gaudi_SeamlessM4TTextToUnitModel_forward, gaudi_SpeechT5Attention_forward, gaudi_SpeechT5Decoder_forward, gaudi_SpeechT5DecoderLayer_forward, gaudi_stablelm_attention_forward, - gaudi_stablelm_decoder_layer_forward, gaudi_stablelm_model_forward, + gaudi_starcoder2_attention_forward, + gaudi_starcoder2_model_forward, gaudi_swin_get_attn_mask, gaudi_t5_layernorm_forward, gaudi_T5Attention_forward, @@ -153,8 +174,10 @@ gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation, gaudi_T5LayerSelfAttention_forward, gaudi_T5Stack_forward, + gaudi_unconstrained_rational_quadratic_spline, gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation, gaudi_vit_self_attention_forward, + gaudi_VitsResidualCouplingLayer_forward, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, gaudi_wav2vec2_tdnnlayer_forward, @@ -193,7 +216,6 @@ def adapt_transformers_to_gaudi(): # Generation is modified to run faster in lazy mode transformers.generation.GenerationMixin.generate = GaudiGenerationMixin.generate - transformers.generation.GenerationMixin.assisted_decoding = GaudiGenerationMixin.assisted_decoding transformers.generation.GenerationMixin._update_model_kwargs_for_generation = ( GaudiGenerationMixin._update_model_kwargs_for_generation ) @@ -201,6 +223,8 @@ def adapt_transformers_to_gaudi(): GaudiGenerationMixin.update_model_kwargs_for_bucketing ) transformers.generation.GenerationMixin._get_hpu_graphs_kwargs = GaudiGenerationMixin._get_hpu_graphs_kwargs + transformers.generation.GenerationMixin._pad_past_key_values = GaudiGenerationMixin._pad_past_key_values + transformers.generation.GenerationMixin._remove_past_key_values = GaudiGenerationMixin._remove_past_key_values transformers.generation.GenerationMixin._expand_inputs_for_generation = staticmethod( GaudiGenerationMixin._expand_inputs_for_generation ) @@ -213,17 +237,27 @@ def adapt_transformers_to_gaudi(): transformers.generation.GenerationMixin._prepare_decoder_attention_mask = ( GaudiGenerationMixin._prepare_decoder_attention_mask ) + transformers.generation.GenerationMixin._prepare_generation_config = ( + GaudiGenerationMixin._prepare_generation_config + ) + transformers.generation.GenerationMixin._prepare_generated_length = GaudiGenerationMixin._prepare_generated_length + transformers.generation.GenerationMixin._get_stopping_criteria = GaudiGenerationMixin._get_stopping_criteria transformers.generation.GenerationMixin._validate_model_kwargs = GaudiGenerationMixin._validate_model_kwargs - transformers.generation.GenerationMixin.greedy_search = GaudiGenerationMixin.greedy_search - transformers.generation.GenerationMixin.sample = GaudiGenerationMixin.sample - transformers.generation.GenerationMixin.beam_search = GaudiGenerationMixin.beam_search - transformers.generation.GenerationMixin.beam_sample = GaudiGenerationMixin.beam_sample - transformers.generation.GenerationMixin.group_beam_search = GaudiGenerationMixin.group_beam_search - transformers.generation.GenerationMixin.constrained_beam_search = GaudiGenerationMixin.constrained_beam_search + transformers.generation.GenerationMixin._greedy_search = GaudiGenerationMixin._greedy_search + transformers.generation.GenerationMixin._sample = GaudiGenerationMixin._sample + transformers.generation.GenerationMixin._beam_search = GaudiGenerationMixin._beam_search + transformers.generation.GenerationMixin._beam_sample = GaudiGenerationMixin._beam_sample + transformers.generation.GenerationMixin._group_beam_search = GaudiGenerationMixin._group_beam_search + transformers.generation.GenerationMixin._constrained_beam_search = GaudiGenerationMixin._constrained_beam_search + transformers.generation.GenerationMixin._assisted_decoding = GaudiGenerationMixin._assisted_decoding + transformers.generation.GenerationMixin._get_candidate_generator = GaudiGenerationMixin._get_candidate_generator transformers.generation.GenerationConfig = GaudiGenerationConfig transformers.modeling_utils.GenerationConfig = GaudiGenerationConfig transformers.generation.MaxLengthCriteria.__call__ = gaudi_MaxLengthCriteria_call transformers.generation.MaxNewTokensCriteria.__call__ = gaudi_MaxNewTokensCriteria_call + transformers.generation.MaxTimeCriteria.__call__ = gaudi_MaxTimeCriteria_call + transformers.generation.EosTokenCriteria.__call__ = gaudi_EosTokenCriteria_call + transformers.generation.StoppingCriteriaList.__call__ = gaudi_StoppingCriteriaList_call # Optimization for BLOOM generation on Gaudi transformers.models.bloom.modeling_bloom.BloomAttention.forward = gaudi_bloom_attention_forward @@ -274,7 +308,7 @@ def adapt_transformers_to_gaudi(): transformers.models.gpt2.modeling_gpt2.GPT2Attention = GaudiGPT2Attention transformers.models.gpt2.modeling_gpt2.GPT2Model.forward = gaudi_gpt2_forward transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel = GaudiGPT2LMHeadModel - transformers.models.gpt2.modeling_gpt2.GPT2Block.forward = gaudi_gpt2_block_forward + transformers.models.gpt2.modeling_gpt2.GPT2Block = GaudiGPT2Block models_with_tracing_support.extend((GaudiGPT2Attention, GaudiGPT2LMHeadModel)) # Optimization for EsmFold on Gaudi @@ -294,7 +328,7 @@ def adapt_transformers_to_gaudi(): # Optimization for GPTJ on Gaudi transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM - transformers.models.gptj.modeling_gptj.GPTJBlock.forward = gaudi_gptj_block_forward + transformers.models.gptj.modeling_gptj.GPTJBlock = GaudiGPTJBlock transformers.models.gptj.modeling_gptj.GPTJModel.forward = gaudi_gptj_model_forward # Optimization for GPTBigCode on Gaudi @@ -326,9 +360,16 @@ def adapt_transformers_to_gaudi(): GaudiLlamaDynamicNTKScalingRotaryEmbedding ) transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward + transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig # Optimization for llava on Gaudi transformers.models.llava.modeling_llava.LlavaForConditionalGeneration = GaudiLlavaForConditionalGeneration + transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration = ( + GaudiLlavaNextForConditionalGeneration + ) + + # Optimization for Clip on Gaudi + transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings # Optimization for falcon generation on Gaudi transformers.models.falcon.modeling_falcon.FalconAttention = GaudiFalconAttention @@ -336,7 +377,6 @@ def adapt_transformers_to_gaudi(): transformers.models.falcon.modeling_falcon.FalconMLP = GaudiFalconMLP transformers.models.falcon.modeling_falcon.FalconModel = GaudiFalconModel transformers.models.falcon.modeling_falcon.FalconDecoderLayer = GaudiFalconDecoderLayer - transformers.models.falcon.modeling_falcon.FalconAttention._split_heads = gaudi_falcon_attention_split_heads transformers.models.falcon.modeling_falcon.FalconLinear.forward = gaudi_falcon_linear_forward # Optimization for t5 on Gaudi @@ -412,17 +452,65 @@ def adapt_transformers_to_gaudi(): gaudi_persimmon_decoder_layer_forward ) + # Optimization for seamless m4t on Gaudi + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TAttention.forward = ( + gaudi_SeamlessM4TAttention_forward + ) + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer.forward = ( + gaudi_SeamlessM4TDecoderLayer_forward + ) + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoder.forward = ( + gaudi_SeamlessM4TDecoder_forward + ) + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitModel.forward = ( + gaudi_SeamlessM4TTextToUnitModel_forward + ) + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.forward = ( + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward + ) + + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.prepare_inputs_for_generation = gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation + + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._get_output_hifigan_lengths = ( + gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths + ) + + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.forward = ( + gaudi_SeamlessM4TForTextToSpeech_forward + ) + + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.generate = ( + gaudi_SeamlessM4TForTextToSpeech_generate + ) + + transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.prepare_inputs_for_generation = ( + gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation + ) + + transformers.models.vits.modeling_vits._unconstrained_rational_quadratic_spline = ( + gaudi_unconstrained_rational_quadratic_spline + ) + transformers.models.vits.modeling_vits.VitsResidualCouplingLayer.forward = gaudi_VitsResidualCouplingLayer_forward + + # Optimization for starcoder2 on Gaudi + transformers.models.starcoder2.modeling_starcoder2.Starcoder2ForCausalLM = GaudiStarcoder2ForCausalLM + transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model.forward = gaudi_starcoder2_model_forward + transformers.models.starcoder2.modeling_starcoder2.Starcoder2Attention.forward = gaudi_starcoder2_attention_forward + transformers.models.starcoder2.modeling_starcoder2.Starcoder2DecoderLayer = GaudiStarcoder2DecoderLayer + # Optimization for qwen2 on Gaudi transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM = GaudiQwen2ForCausalLM - transformers.models.qwen2.modeling_qwen2.Qwen2Model.forward = gaudi_qwen2_model_forward - transformers.models.qwen2.modeling_qwen2.Qwen2Attention.forward = gaudi_qwen2_attention_forward + transformers.models.qwen2.modeling_qwen2.Qwen2Model = GaudiQwen2Model + transformers.models.qwen2.modeling_qwen2.Qwen2Attention = GaudiQwen2Attention + transformers.models.qwen2.modeling_qwen2.Qwen2MLP = GaudiQwen2MLP transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer = GaudiQwen2DecoderLayer + transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm.forward = gaudi_qwen2_rmsnorm_forward # Optimization for stablelm on Gaudi transformers.models.stablelm.modeling_stablelm.StableLmForCausalLM = GaudiStableLmForCausalLM transformers.models.stablelm.modeling_stablelm.StableLmModel.forward = gaudi_stablelm_model_forward transformers.models.stablelm.modeling_stablelm.StableLmAttention.forward = gaudi_stablelm_attention_forward - transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer.forward = gaudi_stablelm_decoder_layer_forward + transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer = GaudiStableLmDecoderLayer transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder.VisionEncoderDecoderModel.prepare_inputs_for_generation = gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 87dc38b1e..f17018e31 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -30,6 +30,7 @@ gaudi_bloom_convert_to_standard_cache, gaudi_bloom_model_forward, ) +from .clip import GaudiCLIPVisionEmbeddings from .codegen import ( GaudiCodeGenAttention, GaudiCodeGenForCausalLM, @@ -48,7 +49,6 @@ GaudiFalconForCausalLM, GaudiFalconMLP, GaudiFalconModel, - gaudi_falcon_attention_split_heads, gaudi_falcon_linear_forward, ) from .gemma import ( @@ -57,7 +57,7 @@ gaudi_gemma_attention_forward, gaudi_gemma_model_forward, ) -from .gpt2 import GaudiGPT2Attention, GaudiGPT2LMHeadModel, gaudi_gpt2_block_forward, gaudi_gpt2_forward +from .gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward from .gpt_bigcode import ( GaudiGPTBigCodeForCausalLM, gaudi_gpt_bigcode_attention_forward, @@ -73,8 +73,8 @@ ) from .gptj import ( GaudiGPTJAttention, + GaudiGPTJBlock, GaudiGPTJForCausalLM, - gaudi_gptj_block_forward, gaudi_gptj_model_forward, ) from .llama import ( @@ -86,9 +86,11 @@ GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, + LlamaConfig, gaudi_llama_rmsnorm_forward, ) from .llava import GaudiLlavaForConditionalGeneration +from .llava_next import GaudiLlavaNextForConditionalGeneration from .mistral import ( GaudiMistralAttention, GaudiMistralDecoderLayer, @@ -140,10 +142,24 @@ GaudiPhiModel, ) from .qwen2 import ( + GaudiQwen2Attention, GaudiQwen2DecoderLayer, GaudiQwen2ForCausalLM, - gaudi_qwen2_attention_forward, - gaudi_qwen2_model_forward, + GaudiQwen2MLP, + GaudiQwen2Model, + gaudi_qwen2_rmsnorm_forward, +) +from .seamless_m4t import ( + gaudi_SeamlessM4TAttention_forward, + gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths, + gaudi_SeamlessM4TDecoder_forward, + gaudi_SeamlessM4TDecoderLayer_forward, + gaudi_SeamlessM4TForTextToSpeech_forward, + gaudi_SeamlessM4TForTextToSpeech_generate, + gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation, + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward, + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation, + gaudi_SeamlessM4TTextToUnitModel_forward, ) from .speecht5 import ( gaudi_generate_speech, @@ -152,11 +168,17 @@ gaudi_SpeechT5DecoderLayer_forward, ) from .stablelm import ( + GaudiStableLmDecoderLayer, GaudiStableLmForCausalLM, gaudi_stablelm_attention_forward, - gaudi_stablelm_decoder_layer_forward, gaudi_stablelm_model_forward, ) +from .starcoder2 import ( + GaudiStarcoder2DecoderLayer, + GaudiStarcoder2ForCausalLM, + gaudi_starcoder2_attention_forward, + gaudi_starcoder2_model_forward, +) from .swin import gaudi_swin_get_attn_mask from .t5 import ( gaudi_t5_layernorm_forward, @@ -171,6 +193,10 @@ gaudi_VisionEncoderDecoderModel_prepare_inputs_for_generation, ) from .vit import gaudi_vit_self_attention_forward +from .vits import ( + gaudi_unconstrained_rational_quadratic_spline, + gaudi_VitsResidualCouplingLayer_forward, +) from .wav2vec2 import ( _gaudi_wav2vec2_compute_mask_indices, _gaudi_wav2vec2_mask_hidden_states, diff --git a/optimum/habana/transformers/models/clip/__init__.py b/optimum/habana/transformers/models/clip/__init__.py new file mode 100644 index 000000000..faa3a3355 --- /dev/null +++ b/optimum/habana/transformers/models/clip/__init__.py @@ -0,0 +1 @@ +from .modeling_clip import GaudiCLIPVisionEmbeddings diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py new file mode 100644 index 000000000..604c87836 --- /dev/null +++ b/optimum/habana/transformers/models/clip/modeling_clip.py @@ -0,0 +1,18 @@ +import torch +from transformers.models.clip.modeling_clip import CLIPVisionEmbeddings + + +class GaudiCLIPVisionEmbeddings(CLIPVisionEmbeddings): + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + # if HQT quantization enabled, remove the explicit cast to float8 to avoid HQT casting error + if "float8" in str(target_dtype) and pixel_values.device.type == "hpu": + target_dtype = torch.bfloat16 + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings diff --git a/optimum/habana/transformers/models/falcon/__init__.py b/optimum/habana/transformers/models/falcon/__init__.py index a42b846c6..0b15ff561 100644 --- a/optimum/habana/transformers/models/falcon/__init__.py +++ b/optimum/habana/transformers/models/falcon/__init__.py @@ -4,6 +4,5 @@ GaudiFalconForCausalLM, GaudiFalconMLP, GaudiFalconModel, - gaudi_falcon_attention_split_heads, gaudi_falcon_linear_forward, ) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index c33a8bb78..6cd3f4f2a 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -88,51 +88,14 @@ def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor: return hidden_states -def gaudi_falcon_attention_split_heads( - self, fused_qkv: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Copied from FalconAttention._split_heads https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/falcon/modeling_falcon.py - Changing index operation of qkv[:::] to use torch.index_select to work around gradient accuracy issue and improve performance. - """ - if self.new_decoder_architecture: - batch, seq_len, _ = fused_qkv.shape - - if self.config.num_attention_heads != self.num_heads: # When DS divides heads for TP - num_heads = self.config.num_attention_heads - num_kv_heads = self.config.num_kv_heads - else: # When DS not in use - num_heads = self.num_heads - num_kv_heads = self.num_kv_heads - - qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, self.head_dim) - # query = qkv[:, :, :, :-2] - # key = qkv[:, :, :, [-2]] - # value = qkv[:, :, :, [-1]] - d3 = qkv.shape[3] - 2 - query = torch.index_select(qkv, 3, index=torch.arange(d3, device=qkv.device)) - key = torch.index_select(qkv, 3, index=torch.tensor([d3], device=qkv.device)) - value = torch.index_select(qkv, 3, index=torch.tensor([d3 + 1], device=qkv.device)) - - key = torch.broadcast_to(key, query.shape) - value = torch.broadcast_to(value, query.shape) - - query, key, value = [x.flatten(2, 3) for x in (query, key, value)] - return query, key, value - elif not self.multi_query: - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) - # TODO : Need to be fixed to use index_select() - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] - else: - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) - # return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] - d2 = fused_qkv.shape[2] - 2 - query = torch.index_select(fused_qkv, 2, index=torch.arange(d2, device=fused_qkv.device)) - key = torch.index_select(fused_qkv, 2, index=torch.tensor([d2], device=fused_qkv.device)) - value = torch.index_select(fused_qkv, 2, index=torch.tensor([d2 + 1], device=fused_qkv.device)) - return query, key, value +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) class Softmax(nn.Module): @@ -159,6 +122,41 @@ def __init__(self, config: FalconConfig): self.bmm1 = Matmul() self.bmm2 = Matmul() self.softmax = Softmax() + self.num_key_value_groups = config.num_attention_heads // config.num_kv_heads + + def repeat_kv( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, + ): + """ + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) + """ + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: L, S = query.size(-2), key.size(-2) @@ -176,12 +174,14 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa if attn_mask.dtype == torch.bool: attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) - attn_weight = self.bmm1(query, key.transpose(-2, -1)) + query, key, value, attn_mask = self.repeat_kv(query, key, value, attn_mask, self.num_key_value_groups) + attn_weight = self.bmm1(query, key.transpose(-2, -1)) attn_weight += attn_mask attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return self.bmm2(attn_weight, value) + attn_output = self.bmm2(attn_weight, value) + return attn_output def update(prev, cur, dim, idx, inp_seq_len): @@ -241,22 +241,79 @@ class GaudiFalconAttention(FalconAttention): - replace F.scaled_dot_product_attention with Habana torch's version for BF16 - use ScaledDotProductAttention for FP8 quantization - add new arg reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + Choice of SDPA: + There are these variables: use_flash_attention and datatype (bf16/fp8) + datatype is determined by presence of QUANT_CONFIG env var, presence of which indicates fp8 + 1. use_flash_attention, fp8: use ModuleFusedSDPA. most optimal + 2. use_flash_attention, bf16: use FusedSDPA + 3. not use_flash_attention, fp8: Use ScaledDotProductAttention, along with QUANT_CONFIG. This is the case before this PR + 4. not use_flash_attention, bf16: F.scaled_dot_product_attention. Slowest option """ def __init__(self, config: FalconConfig): super().__init__(config) - if os.getenv("QUANT_CONFIG", ""): - self.sdpa = ScaledDotProductAttention(config) + self.is_fp8 = os.getenv("QUANT_CONFIG", "") != "" + + # In the constructor we do not know which one we will need later in the forward, so creating both + # TODO, Does this affect memory usage? + if self.is_fp8: + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) + self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config) self.k_cache = KVCache() self.v_cache = KVCache() self.inp_seq_len = -1 self.max_position_embeddings = config.max_position_embeddings + def _split_heads( + self, fused_qkv: torch.Tensor, broadcast: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.new_decoder_architecture: + batch, seq_len, _ = fused_qkv.shape + + if self.config.num_attention_heads != self.num_heads: # When DS divides heads for TP + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_kv_heads + else: # When DS not in use + num_heads = self.num_heads + num_kv_heads = self.num_kv_heads + + qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, self.head_dim) + # query = qkv[:, :, :, :-2] + # key = qkv[:, :, :, [-2]] + # value = qkv[:, :, :, [-1]] + d3 = qkv.shape[3] - 2 + query = torch.index_select(qkv, 3, index=torch.arange(d3, device=qkv.device)) + key = torch.index_select(qkv, 3, index=torch.tensor([d3], device=qkv.device)) + value = torch.index_select(qkv, 3, index=torch.tensor([d3 + 1], device=qkv.device)) + if broadcast: + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + elif not self.multi_query: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + # TODO : Need to be fixed to use index_select() + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + else: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) + # return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] + d2 = fused_qkv.shape[2] - 2 + query = torch.index_select(fused_qkv, 2, index=torch.arange(d2, device=fused_qkv.device)) + key = torch.index_select(fused_qkv, 2, index=torch.tensor([d2], device=fused_qkv.device)) + value = torch.index_select(fused_qkv, 2, index=torch.tensor([d2 + 1], device=fused_qkv.device)) + return query, key, value + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): if self.config.new_decoder_architecture: - cache_shape = (batch_size, self.num_heads, max_seq_len, self.head_dim) + cache_shape = (batch_size, self.num_kv_heads, max_seq_len, self.head_dim) else: cache_shape = (batch_size, 1, max_seq_len, self.head_dim) device = self.query_key_value.weight.device @@ -287,16 +344,22 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None + (query_layer, key_layer, value_layer) = self._split_heads( + fused_qkv, not use_flash_attention and not self.is_fp8 and not train_with_flash_attention + ) batch_size, query_length, _, _ = query_layer.shape @@ -338,7 +401,7 @@ def pre_attn_forward( dtype=self.query_key_value.weight.dtype, device=self.query_key_value.weight.device, ) - layer_past = (past_key, past_value) + layer_past = [past_key, past_value] key_layer = self.k_cache.update( layer_past[0], key_layer, -2, token_idx, self.inp_seq_len ) # k_layer bs*1, q_len, head_dim @@ -359,7 +422,12 @@ def pre_attn_forward( else: kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] - if alibi is None: + if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1): + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + present = (present[0].shape, present[1].shape) + + if alibi is None: # both train/inference if output_attentions: attention_scores = query_layer @ key_layer.transpose(-1, -2) attention_scores /= math.sqrt(self.head_dim) @@ -368,13 +436,22 @@ def pre_attn_forward( # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). attn_output = attention_scores @ value_layer else: - if FusedSDPA: - if os.getenv("QUANT_CONFIG", ""): - attn_output = self.sdpa( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False - ) + if use_flash_attention or train_with_flash_attention: + is_causal = self.is_causal and query_length > 1 and flash_attention_causal_mask + if self.is_fp8: + attn_mask = None if is_causal else attention_mask + flash_attention_fast_softmax = True # TODO pass this along + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + enable_recompute = self.is_fp8 if query_length == 1 else flash_attention_recompute + with sdp_kernel(enable_recompute=enable_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attn_mask, 0.0, is_causal, None, softmax_mode + ) else: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + # TODO very similar to the fp8 case above, could be merged. + with sdp_kernel( + enable_recompute=flash_attention_recompute + ) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_layer, key_layer, @@ -382,22 +459,28 @@ def pre_attn_forward( attention_mask, 0.0, # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - self.is_causal and attention_mask is None and query_length > 1, + is_causal and attention_mask is None, ) else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) + if self.is_fp8: + attn_output = self.unfused_scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + # Performance improvement for HPU if self.training is True and htcore: htcore.mark_step() @@ -415,8 +498,9 @@ def pre_attn_forward( return attn_output, present, _ else: - if self._use_sdpa and not output_attentions and head_mask is None: + if train_with_flash_attention: if FusedSDPA: + # TODO needs to be turned into a module for quantization with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_layer, @@ -513,6 +597,9 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer): - add new args token_idx and position_ids - add token_idx and position_ids into attention inputs - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def __init__(self, config: FalconConfig): @@ -538,6 +625,9 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): if "padding_mask" in kwargs: @@ -563,6 +653,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, **kwargs, ) @@ -611,6 +704,9 @@ def pre_attn( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ): if self.config.new_decoder_architecture: attention_layernorm_out = self.ln_attn(hidden_states) @@ -632,6 +728,9 @@ def pre_attn( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) return attn_outputs, present, attn_scores, attention_layernorm_out, mlp_layernorm_out @@ -644,6 +743,9 @@ class GaudiFalconModel(FalconModel): - add new args token_idx and position_ids - add token_idx and position_ids into decoder inputs - add new arg reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -669,6 +771,9 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -746,28 +851,25 @@ def forward( elif head_mask is None: alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - attention_mask_2d = attention_mask # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) # We take care to integrate alibi bias in the attention_mask here. - if attention_mask_2d is None: - attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) - else: - attention_mask = torch.masked_fill( - alibi / math.sqrt(self.config.hidden_size // self.num_heads), - attention_mask < -1, - torch.finfo(alibi.dtype).min, - ) + min_dtype = torch.finfo(alibi.dtype).min + attention_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + attention_mask < -1, + min_dtype, + ) - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1: - attention_mask = GaudiAttentionMaskConverter._unmask_unattended( - attention_mask, attention_mask_2d, unmasked_value=0.0 - ) + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if seq_length > 1: + attention_mask = GaudiAttentionMaskConverter._unmask_unattended( + attention_mask, min_dtype=min_dtype + ) else: # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. attention_mask = _gaudi_prepare_4d_causal_attention_mask( @@ -786,7 +888,6 @@ def forward( # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - htcore.mark_step() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -803,6 +904,9 @@ def forward( use_cache, output_attentions, None, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, ) else: outputs = block( @@ -817,6 +921,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = outputs[0] @@ -852,6 +959,9 @@ class GaudiFalconForCausalLM(FalconForCausalLM): - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx - add new args reuse_cache + - add use_flash_attention + - add flash_attention_recompute + - add flash_attention_causal_mask """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -871,6 +981,7 @@ def prepare_inputs_for_generation( **kwargs, ) -> dict: reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -885,8 +996,9 @@ def prepare_inputs_for_generation( remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] - elif reuse_cache and token_idx is not None: - # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + elif (reuse_cache or bucket_internal) and token_idx is not None: + # KV cache is pre allocated with reuse cache or will be padded with bucket internal + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] @@ -915,6 +1027,9 @@ def prepare_inputs_for_generation( "token_idx": token_idx, "reuse_cache": reuse_cache, "cache_idx": kwargs.get("cache_idx"), + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), } def forward( @@ -934,6 +1049,9 @@ def forward( reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, cache_idx: int = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -942,6 +1060,12 @@ def forward( are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if use_flash_attention: + assert FusedSDPA, "`use_flash_attention` is True, but cannot find FusedSDPA. Please import it as `from habana_frameworks.torch.hpex.kernels import FusedSDPA` or set use_flash_attention to False (at the expense of a possible performance degradation)." + if flash_attention_recompute: + assert use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not" + if flash_attention_causal_mask: + assert use_flash_attention, "flash_attention_causal_mask is set, but use_flash_attention is not" transformer_outputs = self.transformer( input_ids, @@ -957,6 +1081,9 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = transformer_outputs[0] diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index b68ed794d..7f57e838b 100644 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -91,7 +91,7 @@ def gaudi_gemma_attention_forward( past_key_value.key_cache.append(key_states) past_key_value.value_cache.append(value_states) else: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -392,7 +392,7 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): """ Inherits from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py @@ -409,9 +409,16 @@ def prepare_inputs_for_generation( if past_key_values is not None: if token_idx is None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -457,29 +464,21 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, past_length:] position_ids = position_ids[:, past_length:] - if self.generation_config.cache_implementation == "static": - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + else: + cache_position = cache_position[-input_length:] + model_inputs.update( { - "position_ids": position_ids.contiguous(), + "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), diff --git a/optimum/habana/transformers/models/gpt2/__init__.py b/optimum/habana/transformers/models/gpt2/__init__.py index 7a23f9472..4052373a2 100644 --- a/optimum/habana/transformers/models/gpt2/__init__.py +++ b/optimum/habana/transformers/models/gpt2/__init__.py @@ -1 +1 @@ -from .modeling_gpt2 import GaudiGPT2Attention, GaudiGPT2LMHeadModel, gaudi_gpt2_block_forward, gaudi_gpt2_forward +from .modeling_gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py index c48c71199..3a2e85d24 100644 --- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -3,7 +3,7 @@ import torch from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2LMHeadModel, logger +from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2LMHeadModel, logger class GaudiGPT2Attention(GPT2Attention): @@ -168,76 +168,93 @@ def forward( return outputs # a, present, (attentions) -def gaudi_gpt2_block_forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: - """ - Copied from GPT2Block.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py - The only differences are: - - add new args token_idx - """ +class GaudiGPT2Block(torch.nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GaudiGPT2Attention + + self.ln_1 = torch.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = torch.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = torch.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + """ + Copied from GPT2Block.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py + The only differences are: + - add new args token_idx + """ - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - - attn_outputs = self.attn( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] - # residual connection - hidden_states = attn_output + residual - - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) residual = hidden_states - hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( + hidden_states = self.ln_1(hidden_states) + + attn_outputs = self.attn( hidden_states, + layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, output_attentions=output_attentions, + token_idx=token_idx, ) - attn_output = cross_attn_outputs[0] + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + hidden_states = attn_output + residual - residual = hidden_states - hidden_states = self.ln_2(hidden_states) + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights - feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + feed_forward_hidden_states + residual = hidden_states + hidden_states = self.ln_2(hidden_states) - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] - return outputs # hidden_states, present, (attentions, cross_attentions) + return outputs # hidden_states, present, (attentions, cross_attentions) def gaudi_gpt2_forward( @@ -300,8 +317,6 @@ def gaudi_gpt2_forward( # GPT2Attention mask. if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") attention_mask = attention_mask.view(batch_size, -1) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 03301ec71..e35b4cac5 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -236,6 +236,15 @@ def gaudi_gpt_bigcode_model_forward( self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) if self._use_sdpa and head_mask is None and not output_attentions: + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + min_dtype = torch.finfo(dtype).min + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), + ) + # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. if self.multi_query: @@ -247,17 +256,9 @@ def gaudi_gpt_bigcode_model_forward( # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 self_attention_mask = GaudiAttentionMaskConverter._unmask_unattended( - self_attention_mask, attention_mask, unmasked_value=True + self_attention_mask, min_dtype=min_dtype ) - # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. - dtype = self.wte.weight.dtype - self_attention_mask = torch.where( - self_attention_mask, - torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), - torch.full([], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device), - ) - attention_mask = self_attention_mask # If a 2D or 3D attention mask is provided for the cross-attention diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py index 08f343337..b43a5c323 100644 --- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -422,8 +422,24 @@ def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dty def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids - ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) + if q.dtype == torch.bfloat16: + rope_q = FusedRoPE.apply( + q, + cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16), + sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16), + position_ids, + ) + else: + rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) + if k.dtype == torch.bfloat16: + rope_k = FusedRoPE.apply( + k, + cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16), + sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16), + position_ids, + ) + else: + rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) + return rope_q, rope_k else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) diff --git a/optimum/habana/transformers/models/gptj/__init__.py b/optimum/habana/transformers/models/gptj/__init__.py index 9b3b6a643..23a1d6971 100644 --- a/optimum/habana/transformers/models/gptj/__init__.py +++ b/optimum/habana/transformers/models/gptj/__init__.py @@ -1,6 +1,6 @@ from .modeling_gptj import ( GaudiGPTJAttention, + GaudiGPTJBlock, GaudiGPTJForCausalLM, - gaudi_gptj_block_forward, gaudi_gptj_model_forward, ) diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index cc08d4d2c..0fae0f046 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -5,6 +5,7 @@ from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.gptj.modeling_gptj import ( + GPTJMLP, GPTJAttention, GPTJForCausalLM, apply_rotary_pos_emb, @@ -141,51 +142,59 @@ def forward( return outputs # a, present, (attentions) -def gaudi_gptj_block_forward( - self, - hidden_states: Optional[torch.FloatTensor], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, - sin: Optional[torch.Tensor] = None, - cos: Optional[torch.Tensor] = None, -) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: - """ - Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py - The only differences are: - - add new args token_idx - - pass sin and cos from upper level as they are identical for each attn block - """ - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - sin=sin, - cos=cos, - ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] +class GaudiGPTJBlock(torch.nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = torch.nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GaudiGPTJAttention(config) + self.mlp = GPTJMLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, + cos: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + """ + Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + The only differences are: + - add new args token_idx + - pass sin and cos from upper level as they are identical for each attn block + """ + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + sin=sin, + cos=cos, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] - feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] - return outputs # hidden_states, present, (attentions) + return outputs # hidden_states, present, (attentions) def gaudi_gptj_model_forward( diff --git a/optimum/habana/transformers/models/llama/__init__.py b/optimum/habana/transformers/models/llama/__init__.py index 20703ffd0..0a8758d89 100644 --- a/optimum/habana/transformers/models/llama/__init__.py +++ b/optimum/habana/transformers/models/llama/__init__.py @@ -1,3 +1,4 @@ +from .configuration_llama import LlamaConfig from .modeling_llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py new file mode 100644 index 000000000..7cc66488d --- /dev/null +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -0,0 +1,55 @@ +# TODO: To remove when the repo is upgraded to Transformers >= 4.41.0 +from transformers.models.llama.configuration_llama import LlamaConfig + + +class LlamaConfig(LlamaConfig): + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + pad_token_id, + bos_token_id, + eos_token_id, + pretraining_tp, + tie_word_embeddings, + rope_theta, + rope_scaling, + attention_bias, + attention_dropout, + **kwargs, + ) + + self.mlp_bias = mlp_bias diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a5e0189d6..917cb9680 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,9 +1,11 @@ import math +import os import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F +from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig @@ -71,7 +73,88 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) +class GaudiLlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self._cos_cached[:seq_len].to(dtype=x.dtype), + self._sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + +class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + class GaudiLlamaMLP(LlamaMLP): + def __init__(self, config): + super(LlamaMLP, self).__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + bias = config.mlp_bias if hasattr(config, "mlp_bias") else False + self.gate_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + def pre_mlp_forward(self, x): if self.config.pretraining_tp > 1: slice = self.intermediate_size // self.config.pretraining_tp @@ -140,6 +223,16 @@ def gaudi_llama_repeat_kv( return query_states, key_states, value_states, attention_mask +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -169,11 +262,10 @@ def update(self, prev, cur, dim, idx, inp_seq_len): if prev.shape == cur.shape: prev.copy_(cur) return orig_cur - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: # Initialize prev[:, :, :inp_seq_len, :].copy_(cur) return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" if idx is not None: prev.index_copy_(dim, idx - 1, cur) return prev @@ -189,76 +281,6 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) -class GaudiLlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - super().__init__() - - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self._cos_cached[:seq_len].to(dtype=x.dtype), - self._sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) - - -class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) - - class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -267,6 +289,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -314,6 +337,7 @@ def pre_attn_forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, num_virtual_tokens: int = None, **kwargs, @@ -328,6 +352,7 @@ def pre_attn_forward( - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax - add new arg num_virtual_tokens """ bsz, q_len, _ = hidden_states.size() @@ -353,6 +378,7 @@ def pre_attn_forward( query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # TODO: update when auto mp params is enabled in DeepSpeed (cf. https://github.com/HabanaAI/DeepSpeed/blob/94309c7b5dfc1a69858f5c9f25737b2f81a332a5/deepspeed/module_inject/replace_module.py#L440) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) @@ -392,7 +418,8 @@ def pre_attn_forward( past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) - past_key_value = (past_key, past_value) + # Return list instead of tuple + past_key_value = [past_key, past_value] if ( token_idx is not None and num_virtual_tokens is not None @@ -421,22 +448,27 @@ def pre_attn_forward( if use_flash_attention and FusedSDPA: import habana_frameworks.torch.hpu as ht + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) else: # first token if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None, softmax_mode + ) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) else: @@ -478,6 +510,11 @@ def pre_attn_forward( if not output_attentions: attn_weights = None + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + past_key_value = (past_key_value[0].shape, past_key_value[1].shape) + return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): @@ -525,6 +562,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, num_virtual_tokens: int = None, **kwargs, @@ -538,6 +576,7 @@ def forward( - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax """ if "padding_mask" in kwargs: warnings.warn( @@ -559,6 +598,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, **kwargs, @@ -592,6 +632,7 @@ def pre_attn( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, num_virtual_tokens: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -610,6 +651,7 @@ def pre_attn( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, ) @@ -664,14 +706,6 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False - # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class. - # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`. - causal_mask = torch.full( - (config.max_position_embeddings, config.max_position_embeddings), - fill_value=1, - dtype=torch.bool, - ) - self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) # Initialize weights and apply final processing self.post_init() @@ -704,6 +738,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, @@ -717,6 +752,7 @@ def forward( - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -736,6 +772,12 @@ def forward( batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") + if hasattr(self.config, "use_fused_rope") and self.config.use_fused_rope is False: + global has_fused_rope + has_fused_rope = False + if hasattr(self.config, "use_fused_rms_norm") and self.config.use_fused_rms_norm is False: + global has_fused_rms_norm + has_fused_rms_norm = False if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( @@ -745,8 +787,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + ignore_cache_position = True # Ignoring cache position for HPU use_new_cache = False # Ignoring new Cache path for HPU + past_seen_tokens = 0 if past_key_values is not None and use_cache: # kept for BC (cache positions) @@ -765,6 +809,8 @@ def forward( if ignore_cache_position is False: if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -788,7 +834,8 @@ def forward( past_seen_tokens, ) else: - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + # embed positions hidden_states = inputs_embeds @@ -827,6 +874,8 @@ def forward( use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, + flash_attention_fast_softmax, + None, ) else: layer_outputs = decoder_layer( @@ -843,6 +892,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, ) @@ -916,6 +966,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, @@ -947,6 +998,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, lazy_mode=lazy_mode, num_virtual_tokens=num_virtual_tokens, @@ -993,19 +1045,34 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + token_idx=None, + **kwargs, ): past_length = 0 reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1029,8 +1096,9 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - elif reuse_cache and token_idx is not None: - # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + elif (reuse_cache or bucket_internal) and token_idx is not None: + # KV cache is pre allocated with reuse cache or will be padded with bucket internal + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] @@ -1044,22 +1112,10 @@ def prepare_inputs_for_generation( position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] - # TODO: we are using token_idx, disable this for now - # if self.generation_config.cache_implementation == "static": - # generation with static cache - # cache_position = kwargs.get("cache_position", None) - # if cache_position is None: - # past_length = 0 - # else: - # past_length = cache_position[-1] + 1 - # input_ids = input_ids[:, past_length:] - # position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - # cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) + # keep cache_position implementation as None for HPU cache_position = None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1071,7 +1127,7 @@ def prepare_inputs_for_generation( model_inputs.update( { - "position_ids": position_ids.contiguous(), + "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), @@ -1083,6 +1139,7 @@ def prepare_inputs_for_generation( "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"), "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), "num_virtual_tokens": kwargs.get("num_virtual_tokens"), @@ -1094,6 +1151,15 @@ def prepare_inputs_for_generation( def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and has_fused_rope: # TODO: remove `.clone()` when it is fixed in SynapseAI + if k.dtype == torch.bfloat16: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, + cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + position_ids, + ) return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ), FusedRoPE.apply( diff --git a/optimum/habana/transformers/models/llava_next/__init__.py b/optimum/habana/transformers/models/llava_next/__init__.py new file mode 100644 index 000000000..d20661610 --- /dev/null +++ b/optimum/habana/transformers/models/llava_next/__init__.py @@ -0,0 +1 @@ +from .modeling_llava_next import GaudiLlavaNextForConditionalGeneration diff --git a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py new file mode 100644 index 000000000..7fd76c564 --- /dev/null +++ b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py @@ -0,0 +1,389 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Llava-NeXT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.models.llava_next.modeling_llava_next import ( + LlavaNextCausalLMOutputWithPast, + LlavaNextForConditionalGeneration, + get_anyres_image_grid_shape, + unpad_image, +) +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GaudiLlavaNextForConditionalGeneration(LlavaNextForConditionalGeneration): + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L433 + The only differences are: + - add new args token_idx + - Moved the process of merging images into inputs_embeds into prepare_inputs_for_generation + """ + + if token_idx is not None: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx + self.image_offset, + ) + + if inputs_embeds.shape[1] != 1 and pixel_values is not None: + batch_size, seq_len = self.text_tokens_pos.shape + batch_indices = torch.arange(batch_size).repeat_interleave(seq_len) + logits = outputs[0][batch_indices, self.text_tokens_pos.reshape(-1), :].reshape( + batch_size, seq_len, -1 + ) + else: + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaNextCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + else: + return super().forward( + input_ids=input_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Copied from https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L356 + # Remove the step 6: Mask out the embedding at padding positions + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + text_tokens_pos = new_token_positions + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling + image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + # batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + # indices_to_mask = new_token_positions[batch_indices, pad_indices] + + # final_embedding[batch_indices, indices_to_mask] = 0 + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids, text_tokens_pos + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, + ): + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 + The only differences are: + - add new args token_idx + - add the process of merging images into inputs_embeds + """ + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) + else: + position_ids = kwargs.get("position_ids", None) + labels = kwargs.get("labels", None) + if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: + vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) + vision_feature_layer = kwargs.get("vision_feature_layer", None) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + batch_size, num_patches, num_channels, height, width = pixel_values.shape + reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) + image_features = self.vision_tower(reshaped_pixel_values, output_hidden_states=True) + + selected_image_feature = image_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + inputs_embeds, attention_mask, labels, position_ids, self.text_tokens_pos = ( + self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + ) + self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. + if labels is None: + labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) + + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None and pixel_values is not None: + seq_len = input_ids.shape[1] + pad_len = seq_len - token_idx + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = extended_attention_mask + attention_mask[:, -pad_len:] = 0 + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "token_idx": token_idx, + "image_sizes": image_sizes, + "labels": labels, + } + ) + + return model_inputs diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 573ba4745..42b573515 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -222,6 +222,8 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) + config.rope_scaling = config.rope_scaling if hasattr(config, "rope_scaling") else None + self.config = config self.k_cache = KVCache() self.v_cache = KVCache() self.matmul_qk = Matmul() @@ -841,6 +843,7 @@ def prepare_inputs_for_generation( - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ + reuse_cache = kwargs.get("reuse_cache", False) token_idx = kwargs.get("token_idx", None) # Omit tokens covered by past_key_values @@ -875,6 +878,10 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/optimum/habana/transformers/models/mixtral/configuration_mixtral.py b/optimum/habana/transformers/models/mixtral/configuration_mixtral.py index e87888966..90a783f0d 100644 --- a/optimum/habana/transformers/models/mixtral/configuration_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/configuration_mixtral.py @@ -32,6 +32,7 @@ def __init__( num_local_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, + router_jitter_noise=0.0, rope_scaling=None, **kwargs, ): @@ -58,6 +59,7 @@ def __init__( num_local_experts, output_router_logits, router_aux_loss_coef, + router_jitter_noise, **kwargs, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index fcf681174..3a15c6df1 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -220,6 +220,8 @@ def forward(q, k, v, mask, causal, q_block_size): class GaudiMixtralAttention(MixtralAttention): def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) + config.rope_scaling = config.rope_scaling if hasattr(config, "rope_scaling") else None + self.config = config self._init_rope() self.k_cache = KVCache() self.v_cache = KVCache() @@ -349,6 +351,9 @@ def forward( past_key_value = None if FusedSDPA: + if query_states.dtype != key_states.dtype: + key_states = key_states.type(query_states.dtype) + value_states = value_states.type(query_states.dtype) # support long sequences exceeding 8192 if not self.training and q_len == key_states.size(-2) and q_len > 8192: htcore.mark_step() diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py index 872d1e7f4..4c9d9dd4d 100644 --- a/optimum/habana/transformers/models/phi/modeling_phi.py +++ b/optimum/habana/transformers/models/phi/modeling_phi.py @@ -30,8 +30,8 @@ from transformers.models.phi.configuration_phi import PhiConfig from transformers.models.phi.modeling_phi import ( PhiAttention, - PhiDecoderLayer, PhiForCausalLM, + PhiMLP, PhiModel, apply_rotary_pos_emb, ) @@ -293,10 +293,13 @@ def forward( return attn_output, attn_weights, past_key_value -class GaudiPhiDecoderLayer(PhiDecoderLayer): +class GaudiPhiDecoderLayer(torch.nn.Module): def __init__(self, config: PhiConfig, layer_idx: int): - super().__init__(config, layer_idx) - self.self_attn = GaudiPhiAttention(config, layer_idx) + super().__init__() + self.self_attn = GaudiPhiAttention(config, layer_idx=layer_idx) + self.mlp = PhiMLP(config) + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.resid_dropout = torch.nn.Dropout(config.resid_pdrop) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) diff --git a/optimum/habana/transformers/models/qwen2/__init__.py b/optimum/habana/transformers/models/qwen2/__init__.py index 98c52f9f1..8103996d8 100644 --- a/optimum/habana/transformers/models/qwen2/__init__.py +++ b/optimum/habana/transformers/models/qwen2/__init__.py @@ -1,6 +1,8 @@ from .modeling_qwen2 import ( + GaudiQwen2Attention, GaudiQwen2DecoderLayer, GaudiQwen2ForCausalLM, - gaudi_qwen2_attention_forward, - gaudi_qwen2_model_forward, + GaudiQwen2MLP, + GaudiQwen2Model, + gaudi_qwen2_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 04a77f261..f192cf489 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -21,179 +21,447 @@ from typing import List, Optional, Tuple, Union import torch -from torch import nn -from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +import torch.utils.checkpoint from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, - Qwen2Config, Qwen2DecoderLayer, Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2RMSNorm, apply_rotary_pos_emb, - repeat_kv, + logger, ) -from transformers.utils import logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -logger = logging.get_logger(__name__) - - -def gaudi_qwen2_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Copied from Qwen2Attention.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/qwen2/modeling_qwen2.py - The only differences are: - - add new args token_idx - - optimize KV cache - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE +except ImportError: + print("Not using HPU fused kernel for apply_rotary_pos_emb") + FusedRoPE = None - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm +except ImportError: + print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0: - # When token_idx is used, static seq len = (input token len + max output token len) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) +import habana_frameworks.torch.core as htcore - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) +def gaudi_qwen2_rmsnorm_forward(self, hidden_states): + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class GaudiQwen2MLP(Qwen2MLP): + def pre_mlp_forward(self, x): + inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + output = self.down_proj(inputs) + return output + + def mlp_all_reduce(self, x): + if hasattr(self.down_proj, "all_reduce"): + self.down_proj.all_reduce(x) + + def post_mlp_forward(self, x): + if hasattr(self.down_proj, "post_all_reduce"): + return self.down_proj.post_all_reduce(x) + return x + + +def gaudi_qwen2_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) - attn_weights = attn_weights + attention_mask + return query_states, key_states, value_states, attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 - if not output_attentions: - attn_weights = None + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) - return attn_output, attn_weights, past_key_value + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) -class GaudiQwen2DecoderLayer(Qwen2DecoderLayer): - def __init__(self, config: Qwen2Config, layer_idx: int): + +class GaudiQwen2Attention(Qwen2Attention): + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - def forward( + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.block_size = 4096 + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + + def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size): + """ + Gaudi version of Flash Attention V1 to support long sequence at prompt phase + Causal mask is not supported in this optimization + """ + q_len = query_layer.size(-2) + q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) + q_padding = q_tiles * q_block_size - q_len + query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) + if attention_mask is not None: + attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) + + row_o_list = [] + for i in range(q_tiles): + s, e = i * q_block_size, (i + 1) * q_block_size + row_q = query_layer[:, :, s:e, :] + row_mask = attention_mask[:, :, s:e, :] + attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None) + row_o_list.append(attn_output_partial) + attn_output = torch.cat(row_o_list, dim=-2) + + if q_padding != 0: + attn_output = attn_output[:, :, :-q_padding, :] + + return attn_output + + def pre_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ - Copied from Qwen2DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/qwen2/modeling_qwen2.py The only differences are: - add new args token_idx + - optimize KV cache + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute """ if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - residual = hidden_states + bsz, q_len, _ = hidden_states.size() - hidden_states = self.input_layernorm(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - hidden_states = residual + hidden_states + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None - # Fully Connected + if use_flash_attention and FusedSDPA: + import habana_frameworks.torch.hpu as ht + + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + if q_len > 8192: + attn_output = self.gaudi_flash_attn_v1( + query_states, key_states, value_states, attention_mask, 0.0, self.block_size + ) + htcore.mark_step() + else: + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + + else: + query_states, key_states, value_states, attention_mask = gaudi_qwen2_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def attention_all_reduce(self, attn_output): + if hasattr(self.o_proj, "all_reduce"): + self.o_proj.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.o_proj, "post_all_reduce"): + self.o_proj.post_all_reduce(attn_output) + return attn_output + + +class GaudiQwen2DecoderLayer(Qwen2DecoderLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super(Qwen2DecoderLayer, self).__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GaudiQwen2Attention(config, layer_idx) + + self.mlp = GaudiQwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + + hidden_states, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + ) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) outputs = (hidden_states,) @@ -205,153 +473,257 @@ def forward( return outputs + def pre_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = self.input_layernorm(hidden_states) + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + cache_idx=cache_idx, + ) + return hidden_states, attn_weights, present_key_value + + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) -def gaudi_qwen2_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - """ - Copied from Qwen2Model.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/qwen2/modeling_qwen2.py - The only differences are: - - add new args token_idx - - replace _prepare_4d_causal_attention_mask with _gaudi_prepare_4d_causal_attention_mask - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False + hidden_states = self.post_attention_layernorm(hidden_states) - past_key_values_length = 0 + hidden_states = self.mlp.pre_mlp_forward(hidden_states) + return hidden_states, residual - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if token_idx is None: - past_key_values_length = past_key_values.get_usable_length(seq_length) + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + return hidden_states - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - # 4d mask is passed through the layers - attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) +class GaudiQwen2Model(Qwen2Model): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) - hidden_states = inputs_embeds + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - hidden_states = layer_outputs[0] + self._attn_implementation = "eager" - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") - if output_attentions: - all_self_attns += (layer_outputs[1],) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False - hidden_states = self.norm(hidden_states) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + use_new_cache = False # Ignoring new Cache path for HPU + past_seen_tokens = 0 - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if past_key_values is not None and use_cache: # kept for BC (cache positions) + if reuse_cache: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_usable_length(seq_length) + else: + past_seen_tokens = past_key_values[0][0].shape[2] + + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + # HPU specific mask generation + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if not use_new_cache else None + + if lazy_mode: + htcore.mark_step() + + for layer_idx, decoder_layer in enumerate(self.layers): + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + None, + attn_softmax_bf16, + False, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) class GaudiQwen2ForCausalLM(Qwen2ForCausalLM): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + def forward( self, input_ids: torch.LongTensor = None, @@ -364,20 +736,28 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + trim_logits: Optional[bool] = False, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: - """ - Inherits from Qwen2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/qwen2/modeling_qwen2.py - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not hasattr(self.config, "_attn_implementation"): + setattr(self.config, "_attn_implementation", "eager") + else: + self.config._attn_implementation = "eager" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -389,12 +769,26 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] + + logits = self.lm_head(hidden_states).float() loss = None if labels is not None: @@ -402,7 +796,7 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = torch.nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism @@ -422,21 +816,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs ): - """ - Inherits from Qwen2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/qwen2/modeling_qwen2.py - The only differences are: - - add new args token_idx - - add token_idx into model_inputs - - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx - """ + past_length = 0 - token_idx = kwargs.get("token_idx", None) - # Omit tokens covered by past_key_values + reuse_cache = kwargs.get("reuse_cache") if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -464,8 +852,10 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -478,6 +868,8 @@ def prepare_inputs_for_generation( else: position_ids = position_ids[:, -input_ids.shape[1] :] + cache_position = None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -486,11 +878,32 @@ def prepare_inputs_for_generation( model_inputs.update( { - "position_ids": position_ids, + "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "trim_logits": kwargs.get("trim_logits"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs + + +def apply_customized_rope(q, k, cos, sin, position_ids): + if q.device.type == "hpu" and FusedRoPE: + # TODO: remove `.clone()` when it is fixed in SynapseAI + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + return apply_rotary_pos_emb(q, k, cos, sin, position_ids) diff --git a/optimum/habana/transformers/models/seamless_m4t/__init__.py b/optimum/habana/transformers/models/seamless_m4t/__init__.py new file mode 100644 index 000000000..7bd87abee --- /dev/null +++ b/optimum/habana/transformers/models/seamless_m4t/__init__.py @@ -0,0 +1,12 @@ +from .modeling_seamless_m4t import ( + gaudi_SeamlessM4TAttention_forward, + gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths, + gaudi_SeamlessM4TDecoder_forward, + gaudi_SeamlessM4TDecoderLayer_forward, + gaudi_SeamlessM4TForTextToSpeech_forward, + gaudi_SeamlessM4TForTextToSpeech_generate, + gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation, + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward, + gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation, + gaudi_SeamlessM4TTextToUnitModel_forward, +) diff --git a/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py new file mode 100644 index 000000000..d4728c30f --- /dev/null +++ b/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -0,0 +1,866 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.models.seamless_m4t.modeling_seamless_m4t import ( + SeamlessM4TForTextToSpeech, + SeamlessM4TGenerationOutput, + _compute_new_attention_mask, + format_speech_generation_kwargs, + shift_tokens_right, +) +from transformers.utils import logging + +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) + + +logger = logging.get_logger(__name__) + + +def gaudi_SeamlessM4TAttention_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from SeamlessM4TAttention.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `encoder_hidden_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + if token_idx is None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + past_key_value[0].index_copy_(2, token_idx - 1, key_states) + past_key_value[1].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +def gaudi_SeamlessM4TDecoderLayer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + token_idx: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Copied from SeamlessM4TDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + token_idx=token_idx, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=cross_attn_past_key_value, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value += cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +def gaudi_SeamlessM4TDecoder_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + """ + Copied from SeamlessM4TDecoder.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + if past_key_values_length != 0 and token_idx is not None: + past_key_values_length = token_idx - 1 + positions = self.embed_positions(input, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[3],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +def gaudi_SeamlessM4TTextToUnitModel_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + """ + Copied from SeamlessM4TTextToUnitModel.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +def gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + """ + Copied from SeamlessM4TTextToUnitForConditionalGeneration.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +def gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, +): + """ + Copied from SeamlessM4TTextToUnitForConditionalGeneration.prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + token_idx = kwargs.get("token_idx", None) + # cut decoder_input_ids if past is used + if past_key_values is not None: + if token_idx is None: + decoder_input_ids = decoder_input_ids[:, -1:] + else: + decoder_input_ids = torch.index_select(decoder_input_ids, 1, token_idx - 1) + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "use_cache": use_cache, + "token_idx": token_idx, + "decoder_attention_mask": kwargs.get("decoder_attention_mask", None), + } + + +def gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Copied from SeamlessM4TCodeHifiGan._get_output_hifigan_lengths: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - fix torch.div issue + """ + + def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return ( + torch.div(input_length.item() + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor") + + 1 + ) + + def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1 + + # conv_pre + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + # upsampler + for i, (upsample_rate, kernel_size) in enumerate( + zip(self.config.upsample_rates, self.config.upsample_kernel_sizes) + ): + input_lengths = _transpose_conv_out_length( + input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2 + ) + + # resblock + for i in range(len(self.config.upsample_rates)): + for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes): + for dil in dilation: + input_lengths = _conv_out_length( + input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil + ) + + for dil in dilation: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1) + + # conv_post + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + return input_lengths + + +def gaudi_SeamlessM4TForTextToSpeech_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + """ + Copied from SeamlessM4TForTextToSpeech.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx args + """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForTextToText`." + "It doesn't use the text-to-unit model `SeamlessM4TTextToUnitForConditionalGeneration`." + "If you want to generate speech, use the `.generate` method." + ) + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@torch.no_grad() +def gaudi_SeamlessM4TForTextToSpeech_generate( + self, + input_ids: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, +) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Copied from SeamlessM4TForTextToSpeech.generate: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - delete pad id for unit_ids output + """ + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + else: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + if kwargs.get("hpu_graphs", True): + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + if not hasattr(self, "clear_cache"): + self = wrap_in_hpu_graph(self) + if not hasattr(self.t2u_model, "clear_cache"): + self.t2u_model = wrap_in_hpu_graph(self.t2u_model) + if not hasattr(self.vocoder, "clear_cache"): + self.vocoder = wrap_in_hpu_graph(self.vocoder) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super(SeamlessM4TForTextToSpeech, self).generate(input_ids, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( + self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + seq_lens = (unit_ids != self.config.t2u_pad_token_id).int().sum(1) + unit_ids = unit_ids[:, 0:seq_lens] + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where(unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + +def gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, +): + """ + Copied from SeamlessM4TForTextToSpeech.prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py + The only differences are: + - add token_idx + """ + token_idx = kwargs.get("token_idx", None) + # cut decoder_input_ids if past is used + if past_key_values is not None: + if token_idx is None: + decoder_input_ids = decoder_input_ids[:, -1:] + else: + decoder_input_ids = torch.index_select(decoder_input_ids, 1, token_idx - 1) + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "use_cache": use_cache, + "token_idx": token_idx, + "decoder_attention_mask": kwargs.get("decoder_attention_mask", None), + } diff --git a/optimum/habana/transformers/models/stablelm/__init__.py b/optimum/habana/transformers/models/stablelm/__init__.py index 659a0b1c0..9ae5d787b 100644 --- a/optimum/habana/transformers/models/stablelm/__init__.py +++ b/optimum/habana/transformers/models/stablelm/__init__.py @@ -1,6 +1,6 @@ from .modeling_stablelm import ( + GaudiStableLmDecoderLayer, GaudiStableLmForCausalLM, gaudi_stablelm_attention_forward, - gaudi_stablelm_decoder_layer_forward, gaudi_stablelm_model_forward, ) diff --git a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py index 38500fe2e..f53994a5b 100644 --- a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py +++ b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py @@ -6,7 +6,14 @@ from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.stablelm.modeling_stablelm import StableLmForCausalLM, apply_rotary_pos_emb, repeat_kv +from transformers.models.stablelm.configuration_stablelm import StableLmConfig +from transformers.models.stablelm.modeling_stablelm import ( + StableLmAttention, + StableLmForCausalLM, + StableLmMLP, + apply_rotary_pos_emb, + repeat_kv, +) from transformers.utils import logging from ...modeling_attn_mask_utils import ( @@ -43,6 +50,10 @@ def gaudi_stablelm_attention_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -131,54 +142,74 @@ def gaudi_stablelm_attention_forward( return attn_output, attn_weights, past_key_value -def gaudi_stablelm_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from StableLmDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py - The only differences are: - - add new args token_idx - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - hidden_states = residual + hidden_states +class GaudiStableLmDecoderLayer(torch.nn.Module): + def __init__(self, config: StableLmConfig, layer_idx: int): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.hidden_size = config.hidden_size + self.self_attn = StableLmAttention(config, layer_idx=layer_idx) + self.mlp = StableLmMLP(config) + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = None + if not self.use_parallel_residual: + self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = torch.nn.Dropout(config.hidden_dropout) - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from StableLmDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py + The only differences are: + - add new args token_idx + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + self_attn_output, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + ) - hidden_states = self.dropout(hidden_states) - hidden_states = hidden_states + residual + # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward + if self.use_parallel_residual: + # x = x + attn(ln1(x)) + mlp(ln1(x)) + # Fully Connected + mlp_output = self.mlp(hidden_states) + mlp_output = self.dropout(mlp_output) + hidden_states = residual + self_attn_output + mlp_output + else: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + residual = residual + self_attn_output + # Fully Connected + mlp_output = self.mlp(self.post_attention_layernorm(residual)) + mlp_output = self.dropout(mlp_output) + hidden_states = residual + mlp_output - outputs = (hidden_states,) + outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) + if output_attentions: + outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) + if use_cache: + outputs += (present_key_value,) - return outputs + return outputs def gaudi_stablelm_model_forward( diff --git a/optimum/habana/transformers/models/starcoder2/__init__.py b/optimum/habana/transformers/models/starcoder2/__init__.py new file mode 100644 index 000000000..c4ac4dfc2 --- /dev/null +++ b/optimum/habana/transformers/models/starcoder2/__init__.py @@ -0,0 +1,6 @@ +from .modeling_starcoder2 import ( + GaudiStarcoder2DecoderLayer, + GaudiStarcoder2ForCausalLM, + gaudi_starcoder2_attention_forward, + gaudi_starcoder2_model_forward, +) diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py new file mode 100644 index 000000000..c40283331 --- /dev/null +++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py @@ -0,0 +1,472 @@ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.starcoder2.modeling_starcoder2 import ( + Starcoder2Attention, + Starcoder2Config, + Starcoder2DecoderLayer, + Starcoder2ForCausalLM, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging + +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) + + +logger = logging.get_logger(__name__) + + +def gaudi_starcoder2_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from Starcoder2Attention.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py + The only differences are: + - add new args token_idx + - optimize KV cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0: + # When token_idx is used, static seq len = (input token len + max output token len) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + if token_idx is not None: + if 0 <= self.layer_idx < len(past_key_value.key_cache): + past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) + past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + past_key_value.key_cache.append(key_states) + past_key_value.value_cache.append(value_states) + else: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GaudiStarcoder2DecoderLayer(Starcoder2DecoderLayer): + def __init__(self, config: Starcoder2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = Starcoder2Attention(config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from Starcoder2DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py + The only differences are: + - add new args token_idx + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def gaudi_starcoder2_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from Starcoder2Model.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py + The only differences are: + - add new args token_idx + - replace _prepare_4d_causal_attention_mask with _gaudi_prepare_4d_causal_attention_mask + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if token_idx is None: + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + # 4d mask is passed through the layers + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class GaudiStarcoder2ForCausalLM(Starcoder2ForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py + The only differences are: + - add new args token_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + """ + Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py + The only differences are: + - add new args token_idx + - add token_idx into model_inputs + - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx + - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx + """ + token_idx = kwargs.get("token_idx", None) + # Omit tokens covered by past_key_values + if past_key_values is not None: + if token_idx is None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + else: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + } + ) + return model_inputs diff --git a/optimum/habana/transformers/models/vits/__init__.py b/optimum/habana/transformers/models/vits/__init__.py new file mode 100644 index 000000000..b0cf4ecfe --- /dev/null +++ b/optimum/habana/transformers/models/vits/__init__.py @@ -0,0 +1,4 @@ +from .modeling_vits import ( + gaudi_unconstrained_rational_quadratic_spline, + gaudi_VitsResidualCouplingLayer_forward, +) diff --git a/optimum/habana/transformers/models/vits/modeling_vits.py b/optimum/habana/transformers/models/vits/modeling_vits.py new file mode 100644 index 000000000..174e8a3ac --- /dev/null +++ b/optimum/habana/transformers/models/vits/modeling_vits.py @@ -0,0 +1,77 @@ +import numpy as np +import torch +from torch import nn +from transformers.models.vits.modeling_vits import _rational_quadratic_spline +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +def gaudi_unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + reverse=False, + tail_bound=5.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3, +): + """ + Copied from _unconstrained_rational_quadratic_spline: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/vits/modeling_vits.py#L126 + The only differences are: + - WA to fix hpu graph accuracy issue + """ + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + log_abs_det = torch.zeros_like(inputs) + constant = np.log(np.exp(1 - min_derivative) - 1) + + unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1)) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + log_abs_det[outside_interval_mask] = 0.0 + + outputs_i, log_abs_det_i = _rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + reverse=reverse, + tail_bound=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + outputs = outputs_i * inside_interval_mask + outputs * outside_interval_mask + log_abs_det = log_abs_det_i * inside_interval_mask + log_abs_det * outside_interval_mask + return outputs, log_abs_det + + +def gaudi_VitsResidualCouplingLayer_forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): + """ + Copied from VitsResidualCouplingLayer:forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/vits/modeling_vits.py + The only differences are: + - WA to fix torch.flip issue after conv1d + """ + first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1) + hidden_states = self.conv_pre(first_half) * padding_mask + hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning) + mean = self.conv_post(hidden_states) * padding_mask + log_stddev = torch.zeros_like(mean) + + if not reverse: + second_half = mean.cpu() + second_half * torch.exp(log_stddev) * padding_mask + outputs = torch.cat([first_half, second_half], dim=1) + log_determinant = torch.sum(log_stddev, [1, 2]) + return outputs, log_determinant + else: + second_half = (second_half - mean.cpu()) * torch.exp(-log_stddev) * padding_mask + outputs = torch.cat([first_half, second_half], dim=1) + return outputs, None diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 4386dac5c..5fa006873 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -36,7 +36,7 @@ from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin, save_fsdp_model from huggingface_hub import upload_folder from packaging import version -from torch.utils.data import DataLoader, Dataset, RandomSampler +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler from transformers import Trainer from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow @@ -52,6 +52,7 @@ from transformers.trainer_callback import TrainerCallback, TrainerState from transformers.trainer_pt_utils import ( DistributedTensorGatherer, + EvalLoopContainer, IterableDatasetShard, LengthGroupedSampler, SequentialDistributedSampler, @@ -60,7 +61,6 @@ get_model_param_count, nested_concat, nested_detach, - nested_numpify, reissue_pt_warnings, remove_dummy_checkpoint, ) @@ -139,6 +139,10 @@ DATA_SAMPLERS = [RandomSampler, SeedableRandomSampler] +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + + def _is_peft_model(model): if is_peft_available(): classes_to_check = (PeftModel,) if is_peft_available() else () @@ -151,6 +155,11 @@ def _is_peft_model(model): return False +if TYPE_CHECKING: + if is_datasets_available(): + import datasets + + logger = logging.get_logger(__name__) @@ -175,8 +184,8 @@ def __init__( gaudi_config: GaudiConfig = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, @@ -367,7 +376,17 @@ def create_optimizer(self): "eps": self.args.adam_epsilon, } else: - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, self.model) + + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) @@ -687,6 +706,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if delay_optimizer_creation: if use_accelerator_prepare: + self._fsdp_qlora_plugin_updates() self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) @@ -766,6 +786,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) @@ -905,7 +926,12 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): "a `main_input_name` attribute to the model class you are using." ) else: - self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel() + input_device = inputs[main_input_name].device + self.state.num_input_tokens_seen += torch.sum( + self.accelerator.gather( + torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64) + ) + ).item() if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False @@ -925,9 +951,9 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - # attn_softmax_bf16 and use_flash_attention is enabled only for llama + # attn_softmax_bf16 and use_flash_attention is enabled only for llama and qwen2 if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type == "llama": + if self.model.config.model_type in ["llama", "qwen2"]: if self.model.generation_config.attn_softmax_bf16: inputs["attn_softmax_bf16"] = True if self.model.generation_config.use_flash_attention: @@ -965,6 +991,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # if loss is nan or inf simply add the average of previous logged losses tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) tr_loss += tr_loss_step self.current_flos += float(self.floating_point_ops(inputs)) @@ -1044,7 +1074,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # add remaining tr_loss self._total_loss_scalar += tr_loss.item() - train_loss = self._total_loss_scalar / self.state.global_step + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step # Warmup steps are removed from the calculation of speed metrics num_samples_for_speed_metrics = num_train_samples - args.throughput_warmup_steps * total_train_batch_size @@ -1182,6 +1213,9 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign # Moving it here so the grad tensor is only copied when it's needed. if is_accelerate_available() and self.accelerator.distributed_type == GaudiDistributedType.DEEPSPEED: grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() else: if ( _grad_norm is not None @@ -1273,21 +1307,13 @@ def _save_checkpoint(self, model, trial, metrics=None): run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) - if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0: - logger.warning( - f"Checkpoint destination directory {output_dir} already exists and is non-empty. " - "Saving will proceed but saved results may be invalid." - ) - staging_output_dir = output_dir - else: - staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}") - self.save_model(staging_output_dir, _internal_call=True) + self.save_model(output_dir, _internal_call=True) if not self.args.save_only_model: # Save optimizer and scheduler - self._save_optimizer_and_scheduler(staging_output_dir) + self._save_optimizer_and_scheduler(output_dir) # Save RNG state - self._save_rng_state(staging_output_dir) + self._save_rng_state(output_dir) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: @@ -1307,39 +1333,16 @@ def _save_checkpoint(self, model, trial, metrics=None): # Save the Trainer state if self.args.should_save: - self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME)) + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) if self.args.push_to_hub: - self._push_from_checkpoint(staging_output_dir) - - # Place checkpoint in final location after all saving is finished. - # First wait for everyone to finish writing - self.args.distributed_state.wait_for_everyone() - - # Then go through the rewriting process, only renaming and rotating from main process(es) - if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero(): - if staging_output_dir != output_dir: - if os.path.exists(staging_output_dir): - os.rename(staging_output_dir, output_dir) - - # Ensure rename completed in cases where os.rename is not atomic - # And can only happen on non-windows based systems - if os.name != "nt": - fd = os.open(output_dir, os.O_RDONLY) - os.fsync(fd) - os.close(fd) - - # Maybe delete some older checkpoints. - if self.args.should_save: - # Solely rely on numerical checkpoint id for rotation. - # mtime is not reliable especially on some fuse fs in cloud environments. - self._rotate_checkpoints(use_mtime=False, output_dir=run_dir) - elif self.is_local_process_zero(): - # Clean up the remaining staging checkpoint folders on other nodes - if staging_output_dir != output_dir and os.path.exists(staging_output_dir): - shutil.rmtree(staging_output_dir) - - self.args.distributed_state.wait_for_everyone() + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + # Solely rely on numerical checkpoint id for rotation. + # mtime is not reliable especially on some fuse fs in cloud environments. + self._rotate_checkpoints(use_mtime=False, output_dir=run_dir) def _save_rng_state(self, output_dir): # Save RNG state in non-distributed training @@ -1470,7 +1473,7 @@ def log(self, logs: Dict[str, float]) -> None: The values to log. """ if self.state.epoch is not None: - logs["epoch"] = round(self.state.epoch, 2) + logs["epoch"] = self.state.epoch if self.args.include_num_input_tokens_seen: logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen @@ -1788,20 +1791,14 @@ def evaluation_loop( self._past = None # Initialize containers - # losses/preds/labels on HPU (accumulated for eval_accumulation_steps) - losses_host = None - preds_host = None - labels_host = None - inputs_host = None - - # losses/preds/labels on CPU (final containers) - all_losses = None - all_preds = None - all_labels = None - all_inputs = None - # Will be useful when we have an iterable dataset so don't know its length. + all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 + # Main evaluation loop for step, inputs in enumerate(dataloader): if ( @@ -1818,9 +1815,9 @@ def evaluation_loop( if batch_size is None: batch_size = observed_batch_size - # attn_softmax_bf16 and use_flash_attention are enabled only for llama + # attn_softmax_bf16 and use_flash_attention are enabled only for llama and qwen2 if hasattr(self.model, "generation_config") and self.model.generation_config is not None: - if self.model.config.model_type == "llama": + if self.model.config.model_type in ["llama", "qwen2"]: if self.model.generation_config.attn_softmax_bf16: inputs["attn_softmax_bf16"] = True if self.model.generation_config.use_flash_attention: @@ -1840,22 +1837,18 @@ def evaluation_loop( if logits is not None: logits_dtype = get_dtype(logits) - # Update containers on host + # Update containers if loss is not None: losses = self.gather_function((loss.repeat(batch_size))) - losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) + all_losses.add(losses) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) labels = self.gather_function((labels)) - labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + all_labels.add(labels) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.gather_function((inputs_decode)) - inputs_host = ( - inputs_decode - if inputs_host is None - else nested_concat(inputs_host, inputs_decode, padding_index=-100) - ) + all_inputs.add(inputs_decode) if logits is not None: if args.use_habana and logits_dtype != "float32": logits = to_device_dtype(logits, target_dtype=torch.float32) @@ -1863,35 +1856,16 @@ def evaluation_loop( if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.gather_function((logits)) - preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + all_preds.add(logits) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: - if losses_host is not None: - losses = nested_numpify(losses_host) - all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) - if preds_host is not None: - if args.use_habana and logits_dtype != "float32": - preds_host = to_device_dtype(preds_host, target_dtype=torch.float32) - logits = nested_numpify(preds_host) - all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) - if inputs_host is not None: - inputs_decode = nested_numpify(inputs_host) - all_inputs = ( - inputs_decode - if all_inputs is None - else nested_concat(all_inputs, inputs_decode, padding_index=-100) - ) - if labels_host is not None: - labels = nested_numpify(labels_host) - all_labels = ( - labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) - ) - - # Set back to None to begin a new accumulation - losses_host, preds_host, inputs_host, labels_host = None, None, None, None + all_losses.to_cpu_and_numpy() + all_preds.to_cpu_and_numpy() + all_labels.to_cpu_and_numpy() + all_inputs.to_cpu_and_numpy() # nested concat does accumulation on tensors of variable length. # Added mark step here to avoid graph recompile @@ -1905,22 +1879,10 @@ def evaluation_loop( delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU - if losses_host is not None: - losses = nested_numpify(losses_host) - all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) - if preds_host is not None: - if args.use_habana and logits_dtype != "float32": - preds_host = to_device_dtype(preds_host, target_dtype=torch.float32) - logits = nested_numpify(preds_host) - all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) - if inputs_host is not None: - inputs_decode = nested_numpify(inputs_host) - all_inputs = ( - inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) - ) - if labels_host is not None: - labels = nested_numpify(labels_host) - all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + all_losses = all_losses.get_arrays() + all_preds = all_preds.get_arrays() + all_labels = all_labels.get_arrays() + all_inputs = all_inputs.get_arrays() # Number of samples if has_length(eval_dataset): @@ -1955,7 +1917,9 @@ def evaluation_loop( # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) - if all_losses is not None: + if isinstance(all_losses, list) and all_losses: + metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() + elif isinstance(all_losses, np.ndarray): metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() # Prefix all keys with metric_key_prefix + '_' @@ -2183,6 +2147,7 @@ def prediction_loop( logger.info(f"***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") + losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None @@ -2282,17 +2247,49 @@ def prediction_loop( return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) def create_accelerator_and_postprocess(self): - grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} + grad_acc_kwargs = {} + if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: + grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs + + # check if num_steps is attempted to be passed in gradient_accumulation_kwargs + if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + elif "num_steps" not in grad_acc_kwargs: + # take the gradient_accumulation_steps setting from TrainingArguments. + grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps + grad_acc_kwargs["sync_with_dataloader"] = False + gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + accelerator_config = self.args.accelerator_config.to_dict() + + if is_accelerate_available("0.28.0"): + dataloader_config = DataLoaderConfiguration( + split_batches=accelerator_config.pop("split_batches"), + dispatch_batches=accelerator_config.pop("dispatch_batches"), + even_batches=accelerator_config.pop("even_batches"), + use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"), + ) + # this would have been updated above, no need for it anymore + accelerator_config.pop("gradient_accumulation_kwargs") + + args = { + "deepspeed_plugin": self.args.deepspeed_plugin, + "gradient_accumulation_plugin": gradient_accumulation_plugin, + "distribution_strategy": self.args.distribution_strategy, + } + if is_accelerate_available("0.28.0"): + args["dataloader_config"] = dataloader_config + else: + args.update(accelerator_config) + # create accelerator object - self.accelerator = GaudiAccelerator( - deepspeed_plugin=self.args.deepspeed_plugin, - gradient_accumulation_plugin=gradient_accumulation_plugin, - distribution_strategy=self.args.distribution_strategy, - **self.args.accelerator_config.to_dict(), - ) + self.accelerator = GaudiAccelerator(**args) # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag self.gather_function = self.accelerator.gather_for_metrics diff --git a/optimum/habana/transformers/trainer_seq2seq.py b/optimum/habana/transformers/trainer_seq2seq.py index 2be3c617b..52977e30a 100644 --- a/optimum/habana/transformers/trainer_seq2seq.py +++ b/optimum/habana/transformers/trainer_seq2seq.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -93,25 +94,38 @@ def load_generation_config(gen_config_arg: Union[str, GaudiGenerationConfig]) -> # GenerationConfig provided, nothing to do if isinstance(gen_config_arg, GaudiGenerationConfig): - return deepcopy(gen_config_arg) - - # str or Path - pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg - config_file_name = None - - # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL - # This step is required in order to determine config_file_name - if pretrained_model_name.is_file(): - config_file_name = pretrained_model_name.name - pretrained_model_name = pretrained_model_name.parent - # dir path - elif pretrained_model_name.is_dir(): - pass - # model id or URL + gen_config = deepcopy(gen_config_arg) else: - pretrained_model_name = gen_config_arg + # str or Path + pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg + config_file_name = None + + # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL + # This step is required in order to determine config_file_name + if pretrained_model_name.is_file(): + config_file_name = pretrained_model_name.name + pretrained_model_name = pretrained_model_name.parent + # dir path + elif pretrained_model_name.is_dir(): + pass + # model id or URL + else: + pretrained_model_name = gen_config_arg + + gen_config = GaudiGenerationConfig.from_pretrained(pretrained_model_name, config_file_name) - gen_config = GaudiGenerationConfig.from_pretrained(pretrained_model_name, config_file_name) + # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws + # an exception if there are warnings at validation time. + try: + with warnings.catch_warnings(record=True) as caught_warnings: + gen_config.validate() + if len(caught_warnings) > 0: + raise ValueError(str([w.message for w in caught_warnings])) + except ValueError as exc: + raise ValueError( + "The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings " + "and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc) + ) return gen_config def evaluate( diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 66402d242..eff3c1ede 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -28,9 +28,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig from transformers.trainer_utils import EvaluationStrategy, FSDPOption, HubStrategy, IntervalStrategy, SchedulerType from transformers.training_args import ( + _VALID_DICT_FIELDS, OptimizerNames, ParallelMode, TrainingArguments, + _convert_str_dict, default_logdir, ) from transformers.utils import ( @@ -344,6 +346,17 @@ def __post_init__(self): if self.throughput_warmup_steps < 0: raise ValueError("--throughput_warmup_steps must be positive.") + # Parse in args that could be `dict` sent in from the CLI as a string + for field in _VALID_DICT_FIELDS: + passed_value = getattr(self, field) + # We only want to do this if the str starts with a bracket to indiciate a `dict` + # else its likely a filename if supported + if isinstance(passed_value, str) and passed_value.startswith("{"): + loaded_dict = json.loads(passed_value) + # Convert str values to types if applicable + loaded_dict = _convert_str_dict(loaded_dict) + setattr(self, field, loaded_dict) + # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home # see https://github.com/huggingface/transformers/issues/10628 @@ -611,7 +624,7 @@ def __post_init__(self): ) prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() - os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefect", "false")) + os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")) os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(self.fsdp_config.get("sync_module_states", "true")) os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")) os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str( @@ -624,6 +637,13 @@ def __post_init__(self): self.accelerator_config = AcceleratorConfig() elif isinstance(self.accelerator_config, dict): self.accelerator_config = AcceleratorConfig(**self.accelerator_config) + # Check that a user didn't pass in the class instantiator + # such as `accelerator_config = AcceleratorConfig` + elif isinstance(self.accelerator_config, type): + raise NotImplementedError( + "Tried passing in a callable to `accelerator_config`, but this is not supported. " + "Please pass in a fully constructed `AcceleratorConfig` object instead." + ) else: self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) if self.dispatch_batches is not None: diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py index 2ce160798..1d74d7e33 100644 --- a/optimum/habana/trl/trainer/dpo_trainer.py +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -83,6 +83,10 @@ def __init__( precompute_ref_log_probs: bool = False, model_init_kwargs: Optional[Dict] = None, ref_model_init_kwargs: Optional[Dict] = None, + model_adapter_name: Optional[str] = None, + ref_adapter_name: Optional[str] = None, + reference_free: bool = False, + force_use_ref_model: bool = False, ): """ Copied from DPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L127 @@ -118,6 +122,10 @@ def __init__( ) ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + if not is_peft_available() and peft_config is not None: raise ValueError( "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" @@ -184,6 +192,9 @@ def make_inputs_require_grad(module, input, output): self.is_encoder_decoder = is_encoder_decoder self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + self.reference_free = reference_free if ref_model: self.ref_model = ref_model diff --git a/optimum/habana/trl/trainer/ppo_trainer.py b/optimum/habana/trl/trainer/ppo_trainer.py index ffd46fd1c..9f72b02e4 100644 --- a/optimum/habana/trl/trainer/ppo_trainer.py +++ b/optimum/habana/trl/trainer/ppo_trainer.py @@ -539,11 +539,11 @@ def step( active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False) ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False) - rewards, non_score_reward = self.compute_rewards( + rewards, non_score_reward, kls = self.compute_rewards( scores, active_full_logprobs, ref_full_logprobs, masks ) else: - rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) + rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) timing["time/ppo/compute_rewards"] = time.time() - t t = time.time() @@ -648,6 +648,7 @@ def step( masks=masks, queries=queries, responses=responses, + kls=kls, ) # Gather/Reduce stats from all processes if self.is_distributed: diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index 463a1eda0..f5c3345a4 100755 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -31,7 +31,7 @@ logger = logging.get_logger(__name__) -CURRENTLY_VALIDATED_SYNAPSE_VERSION = version.parse("1.15.0") +CURRENTLY_VALIDATED_SYNAPSE_VERSION = version.parse("1.16.0") def to_device_dtype(my_input: Any, target_device: torch.device = None, target_dtype: torch.dtype = None): @@ -306,7 +306,7 @@ def noop(): activities=activities, on_trace_ready=torch.profiler.tensorboard_trace_handler(output_dir), record_shapes=record_shapes, - with_stack=True, + with_stack=False, ) self.start = profiler.start self.stop = profiler.stop diff --git a/pyproject.toml b/pyproject.toml index a26b36870..b7896da5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,10 +13,12 @@ # limitations under the License. [tool.ruff] -# Never enforce `E501` (line length violations). -lint.ignore = ["C901", "E501", "E741", "F402", "F823"] -lint.select = ["C", "E", "F", "I", "W"] line-length = 119 + +[tool.ruff.lint] +# Never enforce `E501` (line length violations). +ignore = ["C901", "E501", "E741", "F402", "F823"] +select = ["C", "E", "F", "I", "W"] exclude = ["text-generation-inference"] # Ignore import violations in all `__init__.py` files. diff --git a/setup.py b/setup.py index f94002ecc..9c0825790 100644 --- a/setup.py +++ b/setup.py @@ -29,13 +29,13 @@ INSTALL_REQUIRES = [ - "transformers >= 4.38.0, < 4.39.0", + "transformers >= 4.40.0, < 4.41.0", "optimum", "torch", "accelerate < 0.28.0", "diffusers >= 0.26.0, < 0.27.0", - "pytest < 8.0.0", "huggingface_hub < 0.23.0", + "datasets < 2.20.0", ] TESTS_REQUIRE = [ @@ -46,6 +46,7 @@ "sentencepiece", "datasets", "safetensors", + "pytest < 8.0.0", ] QUALITY_REQUIRES = [ diff --git a/tests/baselines/bert_large_uncased_whole_word_masking.json b/tests/baselines/bert_large_uncased_whole_word_masking.json old mode 100644 new mode 100755 index c9d67aeee..9bfd3a1f8 --- a/tests/baselines/bert_large_uncased_whole_word_masking.json +++ b/tests/baselines/bert_large_uncased_whole_word_masking.json @@ -66,8 +66,8 @@ "learning_rate": 4e-5, "train_batch_size": 32, "eval_f1": 93.2753, - "train_runtime": 309.9491, - "train_samples_per_second": 302.089, + "train_runtime": 342.1722, + "train_samples_per_second": 286.435, "extra_arguments": [ "--max_seq_length 384", "--use_hpu_graphs_for_inference" @@ -76,9 +76,9 @@ "multi_card": { "learning_rate": 8e-5, "train_batch_size": 32, - "eval_f1": 93.0981, - "train_runtime": 78.387, - "train_samples_per_second": 2300.127, + "eval_f1": 92.6726, + "train_runtime": 77.307, + "train_samples_per_second": 2150.333, "extra_arguments": [ "--max_seq_length 384", "--use_hpu_graphs_for_inference" diff --git a/tests/baselines/distilbert_base_uncased.json b/tests/baselines/distilbert_base_uncased.json index a85474a07..00482ebee 100644 --- a/tests/baselines/distilbert_base_uncased.json +++ b/tests/baselines/distilbert_base_uncased.json @@ -38,8 +38,8 @@ "learning_rate": 2e-4, "train_batch_size": 64, "eval_f1": 84.5418, - "train_runtime": 108.8333, - "train_samples_per_second": 1676.689, + "train_runtime": 117.8054, + "train_samples_per_second": 1547.185, "extra_arguments": [ "--max_seq_length 384", "--use_hpu_graphs_for_inference" diff --git a/tests/baselines/llama_7b.json b/tests/baselines/llama_7b.json index f534a19b6..6f64832f3 100644 --- a/tests/baselines/llama_7b.json +++ b/tests/baselines/llama_7b.json @@ -25,6 +25,42 @@ } }, "gaudi2": { + "databricks/databricks-dolly-15k": { + "num_train_epochs": 1, + "eval_batch_size": 8, + "distribution": { + "single_card": { + "learning_rate": 2e-4, + "train_batch_size": 16, + "perplexity": 3.8436, + "train_runtime": 113.9713, + "train_samples_per_second": 18.428, + "extra_arguments": [ + "--bf16", + "--gradient_accumulation_steps 1", + "--evaluation_strategy no", + "--save_strategy no", + "--warmup_ratio 0.03", + "--lr_scheduler_type constant", + "--max_grad_norm 0.3", + "--logging_steps 1", + "--use_hpu_graphs_for_inference", + "--lora_rank 8", + "--lora_alpha 16", + "--lora_dropout 0.1", + "--lora_target_modules q_proj v_proj", + "--dataset_concatenation", + "--low_cpu_mem_usage True", + "--adam_epsilon 1e-08", + "--validation_split_percentage 20", + "--attn_softmax_bf16", + "--max_steps 100", + "--input_column_name context", + "--output_column_name response" + ] + } + } + }, "tatsu-lab/alpaca": { "num_train_epochs": 3, "eval_batch_size": 4, @@ -232,16 +268,16 @@ "multi_card": { "learning_rate": 5e-4, "train_batch_size": 1, - "train_runtime": 16.5, - "train_samples_per_second": 63.161, - "perplexity": 1.224, + "train_runtime": 16.1, + "train_samples_per_second": 63.249, + "perplexity": 1.172, "extra_arguments": [ "--num_virtual_tokens 8", "--max_seq_length 64", "--logging_steps 1", "--report_to none", "--max_steps 100", - "--peft_type prompt_tuning", + "--peft_type prefix_tuning", "--max_seq_length 64", "--lr_scheduler_type cosine", "--warmup_steps 0", @@ -258,16 +294,16 @@ "multi_card": { "learning_rate": 5e-4, "train_batch_size": 1, - "train_runtime": 16.5, + "train_runtime": 18.7, "train_samples_per_second": 63.161, - "perplexity": 1.224, + "perplexity": 1.047, "extra_arguments": [ "--num_virtual_tokens 8", "--max_seq_length 64", "--logging_steps 1", "--report_to none", "--max_steps 100", - "--peft_type prompt_tuning", + "--peft_type p_tuning", "--max_seq_length 64", "--lr_scheduler_type cosine", "--warmup_steps 0", diff --git a/tests/baselines/roberta_large.json b/tests/baselines/roberta_large.json old mode 100644 new mode 100755 index 4f1ba4c89..d5ffc8200 --- a/tests/baselines/roberta_large.json +++ b/tests/baselines/roberta_large.json @@ -56,8 +56,8 @@ "learning_rate": 3e-5, "train_batch_size": 32, "eval_f1": 94.5886, - "train_runtime": 314.4407, - "train_samples_per_second": 300.578, + "train_runtime": 342.1653, + "train_samples_per_second": 284.873, "extra_arguments": [ "--max_seq_length 384", "--use_hpu_graphs_for_inference" @@ -66,9 +66,9 @@ "multi_card": { "learning_rate": 7e-5, "train_batch_size": 32, - "eval_f1": 94.4348, - "train_runtime": 79.1007, - "train_samples_per_second": 2280.328, + "eval_f1": 94.09, + "train_runtime": 77.333, + "train_samples_per_second": 2138.366, "extra_arguments": [ "--max_seq_length 384", "--use_hpu_graphs_for_inference" diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt index bb69b0ebd..278d3485f 100644 --- a/tests/example_diff/run_audio_classification.txt +++ b/tests/example_diff/run_audio_classification.txt @@ -33,8 +33,8 @@ < check_min_version("4.42.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") 174,176d175 < freeze_feature_extractor: Optional[bool] = field( < default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."} diff --git a/tests/example_diff/run_clip.txt b/tests/example_diff/run_clip.txt index 70538cb2a..2eebcc2d7 100644 --- a/tests/example_diff/run_clip.txt +++ b/tests/example_diff/run_clip.txt @@ -28,8 +28,8 @@ < check_min_version("4.42.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") 181a190,192 > mediapipe_dataloader: bool = field( > default=False, metadata={"help": "Turn on MediaPipe hardware-based accelerated data loading."} diff --git a/tests/example_diff/run_clm.txt b/tests/example_diff/run_clm.txt index 49305fdaf..7db8099ec 100644 --- a/tests/example_diff/run_clm.txt +++ b/tests/example_diff/run_clm.txt @@ -38,8 +38,8 @@ > 63a64,69 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") > diff --git a/tests/example_diff/run_generation.txt b/tests/example_diff/run_generation.txt index 57f2d8580..5da903f6e 100644 --- a/tests/example_diff/run_generation.txt +++ b/tests/example_diff/run_generation.txt @@ -1,15 +1,14 @@ -17d16 -< """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) -19c18,19 -< +17c17,19 +< """Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)""" --- +> """ > Conditional text generation on Habana Gaudi/Gaudi2. > """ -22c22 +20c22 < import inspect --- > import json -24c24,28 +22c24,28 < from typing import Tuple --- > import math @@ -17,12 +16,12 @@ > import time > from itertools import cycle > from pathlib import Path -27,28c31 +25,26c31 < from accelerate import PartialState < from accelerate.utils import set_seed --- > from utils import adjust_batch, count_hpu_graphs, initialize_model -30,52c33 +28,50c33 < from transformers import ( < AutoTokenizer, < BloomForCausalLM, @@ -48,7 +47,7 @@ < from transformers.modeling_outputs import CausalLMOutputWithPast --- > from optimum.habana.utils import get_hpu_memory_stats -62,190d42 +60,282d42 < MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop < < MODEL_CLASSES = { @@ -178,7 +177,7 @@ < < return num_layer, num_head, num_embedding_size_per_head < -192,287c44,46 +< < def generate_past_key_values(model, batch_size, seq_len): < num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config) < if model.config.model_type == "bloom": @@ -272,22 +271,22 @@ < """ < return self._default._reorder_cache(past_key_values, beam_idx) < -< +284,285c44,46 < def main(): < parser = argparse.ArgumentParser() --- > def setup_parser(parser): > # Arguments management > parser.add_argument("--device", "-d", type=str, choices=["hpu"], help="Device to run", default="hpu") -289c48 +287c48 < "--model_type", --- > "--model_name_or_path", -293c52 +291c52 < help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), --- > help="Path to pre-trained model (on the HF Hub or locally).", -296c55,83 +294c55,83 < "--model_name_or_path", --- > "--bf16", @@ -319,7 +318,7 @@ > ) > parser.add_argument( > "--dataset_name", -299,300c86,92 +297,298c86,92 < required=True, < help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), --- @@ -330,13 +329,13 @@ > default=None, > type=str, > help="If `--dataset_name` was given, this will be the name of the column to use as prompts for generation.", -302,306d93 +300,304d93 < < parser.add_argument("--prompt", type=str, default="") < parser.add_argument("--length", type=int, default=20) < parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") < -308,311c95,97 +306,309c95,97 < "--temperature", < type=float, < default=1.0, @@ -345,7 +344,7 @@ > "--do_sample", > action="store_true", > help="Whether to use sampling for generation.", -314c100,221 +312c100,233 < "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" --- > "--num_beams", @@ -377,6 +376,12 @@ > help="Number of steps to capture for profiling.", > ) > parser.add_argument( +> "--profiling_record_shapes", +> default=False, +> type=bool, +> help="Record shapes when enabling profiling.", +> ) +> parser.add_argument( > "--prompt", > default=None, > type=str, @@ -398,6 +403,12 @@ > help="Optional argument list of words that must be generated.", > ) > parser.add_argument( +> "--assistant_model", +> default=None, +> type=str, +> help="Optional argument to give a path to a draft/assistant model for assisted decoding.", +> ) +> parser.add_argument( > "--peft_model", > default=None, > type=str, @@ -470,18 +481,18 @@ > "--reduce_recompile", > action="store_true", > help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)", -316,321d222 +314,319d234 < parser.add_argument("--k", type=int, default=0) < parser.add_argument("--p", type=float, default=0.9) < < parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") < parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") < parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") -323c224,244 +321d235 < parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") +323c237,252 +< "--use_cpu", --- -> parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") -> parser.add_argument( > "--use_flash_attention", > action="store_true", > help="Whether to enable Habana Flash Attention, provided that the model supports it.", @@ -497,17 +508,25 @@ > help="Whether to enable Habana Flash Attention in causal mode on first token generation.", > ) > parser.add_argument( +> "--flash_attention_fast_softmax", +325c254 +< help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available", +--- +> help="Whether to enable Habana Flash Attention in fast softmax mode.", +327d255 +< parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") +329c257 +< "--fp16", +--- > "--book_source", -> action="store_true", +331c259,288 +< help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", +--- > help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.", > ) -325c246 -< "--use_cpu", ---- +> parser.add_argument( > "--torch_compile", -327c248,262 -< help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available", ---- +> action="store_true", > help="Whether to use torch compiled model or not.", > ) > parser.add_argument( @@ -523,45 +542,33 @@ > "--csp", > type=str, > help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.", -329d263 -< parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") -331c265 -< "--fp16", ---- +> ) +> parser.add_argument( > "--disk_offload", -333c267,272 -< help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", ---- +> action="store_true", > help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", > ) > parser.add_argument( > "--trust_remote_code", > action="store_true", > help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", -335d273 +333d289 < parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference") -338,339c276,277 +336,339c292,293 < # Initialize the distributed state. < distributed_state = PartialState(cpu=args.use_cpu) +< +< logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}") --- > if args.torch_compile: > args.use_hpu_graphs = False -341c279,280 -< logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}") ---- -> if not args.use_hpu_graphs: -> args.limit_hpu_graphs = False -343,344c282,287 +341,342c295,296 < if args.seed is not None: < set_seed(args.seed) --- -> args.quant_config = os.getenv("QUANT_CONFIG", "") -> if args.quant_config == "" and args.disk_offload: -> logger.warning( -> "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag." -> ) -> return args -346,373d288 +> if not args.use_hpu_graphs: +> args.limit_hpu_graphs = False +344,371c298,303 < # Initialize the model and tokenizer < try: < args.model_type = args.model_type.lower() @@ -590,12 +597,19 @@ < if requires_preprocessing: < prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) < preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) -375,386c290,293 +--- +> args.quant_config = os.getenv("QUANT_CONFIG", "") +> if args.quant_config == "" and args.disk_offload: +> logger.warning( +> "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag." +> ) +> return args +373,376d304 < if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: < tokenizer_kwargs = {"add_space_before_punct_symbol": True} < else: < tokenizer_kwargs = {} -< +378,384c306,309 < encoded_prompt = tokenizer.encode( < preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs < ) @@ -607,8 +621,8 @@ > def main(): > parser = argparse.ArgumentParser() > args = setup_parser(parser) -> model, tokenizer, generation_config = initialize_model(args, logger) -388,389c295,488 +> model, assistant_model, tokenizer, generation_config = initialize_model(args, logger) +386,387c311,506 < if encoded_prompt.size()[-1] == 0: < input_ids = None --- @@ -708,12 +722,14 @@ > outputs = model.generate( > **input_tokens, > generation_config=generation_config, +> assistant_model=assistant_model, > lazy_mode=use_lazy_mode, > hpu_graphs=args.use_hpu_graphs, > profiling_steps=args.profiling_steps, > profiling_warmup_steps=args.profiling_warmup_steps, > ignore_eos=args.ignore_eos, > iteration_times=iteration_times, +> profiling_record_shapes=args.profiling_record_shapes, > ).cpu() > first_token_time = iteration_times[0] + encode_duration > logger.info(f"Time to first token = {first_token_time*1000}ms") @@ -806,7 +822,7 @@ > print(f"Graph compilation duration = {compilation_duration} seconds") > print(separator) > print() -391c490,507 +389c508,525 < input_ids = encoded_prompt --- > # Downloading and loading a dataset from the hub. @@ -827,7 +843,7 @@ > .shuffle() > .select(range(args.dataset_max_samples if args.dataset_max_samples > 0 else (raw_dataset[split]).num_rows)) > ) -393,399c509,516 +391,397c527,534 < if args.jit: < jit_input_texts = ["enable jit"] < jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) @@ -844,7 +860,7 @@ > logger.info( > f"No column name was given so automatically choosing '{column_name}' for prompts. If you would like to use another column of the dataset, you can set the argument `--column_name`." > ) -401,439c518,538 +399,437c536,556 < sig = inspect.signature(model.__call__) < jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None) < traced_model = torch.jit.trace(model, jit_inputs, strict=False) @@ -906,7 +922,7 @@ > preprocess_function, > batched=True, > desc="Running tokenizer on dataset", -440a540,621 +438a558,640 > # After tokenization, we can remove the column of interest > raw_dataset = raw_dataset.remove_columns([column_name]) > raw_dataset.set_format(type="torch") @@ -947,6 +963,7 @@ > profiling_steps=args.profiling_steps, > profiling_warmup_steps=args.profiling_warmup_steps, > ignore_eos=args.ignore_eos, +> profiling_record_shapes=args.profiling_record_shapes, > ).cpu() > return prompt, outputs > @@ -989,7 +1006,7 @@ > > throughput = total_new_tokens_generated / duration > # Print Stats -442,443c623,641 +440,441c642,660 < generated_sequences.append(total_sequence) < print(total_sequence) --- @@ -1012,7 +1029,7 @@ > habana_quantization_toolkit.finish_measurements(model) > if args.const_serialization_path and os.path.isdir(args.const_serialization_path): > import shutil -445c643 +443c662 < return generated_sequences --- > shutil.rmtree(args.const_serialization_path) diff --git a/tests/example_diff/run_glue.txt b/tests/example_diff/run_glue.txt index f17194202..78cafcf01 100644 --- a/tests/example_diff/run_glue.txt +++ b/tests/example_diff/run_glue.txt @@ -1,7 +1,3 @@ -16c16 -< """ Finetuning the library models for sequence classification on GLUE.""" ---- -> """Finetuning the library models for sequence classification on GLUE.""" 29,30d28 < from datasets import load_dataset < @@ -31,8 +27,8 @@ > logger = logging.getLogger(__name__) > > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") 67,68d76 < logger = logging.getLogger(__name__) < diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt index ba11cf44a..0dbbe3f6c 100644 --- a/tests/example_diff/run_image_classification.txt +++ b/tests/example_diff/run_image_classification.txt @@ -28,8 +28,8 @@ < check_min_version("4.42.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") 184c192 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt index 1e19c0b8b..372a91383 100644 --- a/tests/example_diff/run_mlm.txt +++ b/tests/example_diff/run_mlm.txt @@ -34,8 +34,8 @@ 61a62,69 > > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") > diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt index ab710cf82..118add46a 100644 --- a/tests/example_diff/run_qa.txt +++ b/tests/example_diff/run_qa.txt @@ -32,8 +32,8 @@ > 58a62,67 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") > diff --git a/tests/example_diff/run_seq2seq_qa.txt b/tests/example_diff/run_seq2seq_qa.txt index 725064fc4..817c72b5a 100644 --- a/tests/example_diff/run_seq2seq_qa.txt +++ b/tests/example_diff/run_seq2seq_qa.txt @@ -24,8 +24,8 @@ > 54a58,63 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") > diff --git a/tests/example_diff/run_speech_recognition_ctc.txt b/tests/example_diff/run_speech_recognition_ctc.txt index 401ba9c89..1fab0abcf 100644 --- a/tests/example_diff/run_speech_recognition_ctc.txt +++ b/tests/example_diff/run_speech_recognition_ctc.txt @@ -1,7 +1,3 @@ -17c17 -< """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition""" ---- -> """Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition""" 32,33d31 < from datasets import DatasetDict, load_dataset < @@ -29,8 +25,8 @@ > return () 59a61,66 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") > diff --git a/tests/example_diff/run_speech_recognition_seq2seq.txt b/tests/example_diff/run_speech_recognition_seq2seq.txt index 578978e6b..45b00bef9 100644 --- a/tests/example_diff/run_speech_recognition_seq2seq.txt +++ b/tests/example_diff/run_speech_recognition_seq2seq.txt @@ -22,8 +22,8 @@ 51c58,59 < check_min_version("4.42.0.dev0") --- -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") 230a239,242 > label_features_max_length: int = field( > default=None, @@ -47,7 +47,7 @@ > gaudi_config = GaudiConfig.from_pretrained( > training_args.gaudi_config_name, > cache_dir=model_args.cache_dir, -> token=data_args.token, +> token=model_args.token, > ) > 310a334 diff --git a/tests/example_diff/run_summarization.txt b/tests/example_diff/run_summarization.txt index a1bd762f6..9f01193b1 100644 --- a/tests/example_diff/run_summarization.txt +++ b/tests/example_diff/run_summarization.txt @@ -36,8 +36,8 @@ > 60a67,72 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") > diff --git a/tests/example_diff/run_translation.txt b/tests/example_diff/run_translation.txt index a0f161c51..1aa504c06 100644 --- a/tests/example_diff/run_translation.txt +++ b/tests/example_diff/run_translation.txt @@ -28,8 +28,8 @@ > 60a64,69 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.38.0") -> check_optimum_habana_min_version("1.10.0") +> check_min_version("4.40.0") +> check_optimum_habana_min_version("1.11.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") > diff --git a/tests/resource/image-captioning-example.png b/tests/resource/img/image-captioning-example.png similarity index 100% rename from tests/resource/image-captioning-example.png rename to tests/resource/img/image-captioning-example.png diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index afc57728a..89a101fdf 100755 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -16,25 +16,43 @@ import json import os +import random import re import subprocess import tempfile from io import BytesIO from pathlib import Path +from typing import Union from unittest import TestCase, skipUnless import numpy as np -import pytest import requests +import safetensors import torch -from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UniPCMultistepScheduler +from diffusers import ( + AutoencoderKL, + AutoencoderKLTemporalDecoder, + ControlNetModel, + UNet2DConditionModel, + UNetSpatioTemporalConditionModel, + UniPCMultistepScheduler, +) from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel from diffusers.utils import load_image, numpy_to_pil +from diffusers.utils.testing_utils import floats_tensor from diffusers.utils.torch_utils import randn_tensor from huggingface_hub import snapshot_download from parameterized import parameterized from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from transformers.testing_utils import parse_flag_from_env, slow from optimum.habana import GaudiConfig @@ -48,6 +66,7 @@ GaudiStableDiffusionPipeline, GaudiStableDiffusionUpscalePipeline, GaudiStableDiffusionXLPipeline, + GaudiStableVideoDiffusionPipeline, ) from optimum.habana.utils import set_seed @@ -431,7 +450,7 @@ def test_stable_diffusion_batch_sizes(self): prompt = "A painting of a squirrel eating a burger" - # Test batch_size > 1 where batch_size is a divider of the total number of generated images + # Test num_images > 1 where num_images is a divider of the total number of generated images batch_size = 3 num_images_per_prompt = batch_size**2 images = sd_pipe( @@ -458,7 +477,7 @@ def test_stable_diffusion_batch_sizes(self): self.assertEqual(len(images), num_prompts * num_images_per_prompt) self.assertEqual(images[-1].shape, (64, 64, 3)) - # Test batch_size when it is not a divider of the toal number of generated images for a single prompt + # Test num_images when it is not a divider of the total number of generated images for a single prompt num_images_per_prompt = 7 images = sd_pipe( prompt, @@ -1729,7 +1748,6 @@ def test_train_text_to_image_script(self): self.assertEqual(return_code, 0) @slow - @pytest.mark.skip(reason="The dataset used in this test is not available at the moment.") def test_train_text_to_image_sdxl(self): with tempfile.TemporaryDirectory() as tmpdir: path_to_script = ( @@ -1744,11 +1762,10 @@ def test_train_text_to_image_sdxl(self): python3 {path_to_script} --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 - --pretrained_vae_model_name_or_path stabilityai/sdxl-vae - --dataset_name lambdalabs/pokemon-blip-captions - --resolution 512 - --crop_resolution 512 - --center_crop + --pretrained_vae_model_name_or_path madebyollin/sdxl-vae-fp16-fix + --dataset_name lambdalabs/naruto-blip-captions + --resolution 64 + --crop_resolution 64 --random_flip --proportion_empty_prompts=0.2 --train_batch_size 16 @@ -1762,7 +1779,10 @@ def test_train_text_to_image_sdxl(self): --use_hpu_graphs_for_training --use_hpu_graphs_for_inference --bf16 + --adjust_throughput + --center_crop --max_train_steps 2 + --checkpointing_steps 2 --output_dir {tmpdir} """.split() @@ -1774,8 +1794,10 @@ def test_train_text_to_image_sdxl(self): self.assertEqual(return_code, 0) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + self.assertTrue( + os.path.isfile(os.path.join(tmpdir, "checkpoint-2", "unet", "diffusion_pytorch_model.safetensors")) + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "checkpoint-2", "unet", "config.json"))) class TrainControlNet(TestCase): @@ -1876,3 +1898,316 @@ def test_train_controlnet(self): ).images[0] self.assertEqual(image.shape, (512, 512, 3)) + + +def install_requirements(requirements_filename: Union[str, os.PathLike]): + """ + Installs the necessary requirements to run the example if the provided file exists, otherwise does nothing. + """ + + if not Path(requirements_filename).exists(): + return + + cmd_line = f"pip install -r {requirements_filename}".split() + p = subprocess.Popen(cmd_line) + return_code = p.wait() + assert return_code == 0 + + +class DreamBooth(TestCase): + def _test_dreambooth(self, extra_config, train_text_encoder=False): + path_to_script = ( + Path(os.path.dirname(__file__)).parent + / "examples" + / "stable-diffusion" + / "training" + / "train_dreambooth.py" + ) + install_requirements(path_to_script.parent / "requirements.txt") + instance_prompt = "soccer player kicking a ball" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + python3 + {path_to_script} + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir {Path(os.path.dirname(__file__))/'resource/img'} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --train_text_encoder + --max_train_steps 1 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --gaudi_config_name Habana/stable-diffusion + --output_dir {tmpdir} + """.split() + + test_args.append("--instance_prompt") + test_args.append(instance_prompt) + if "oft" not in extra_config: + test_args.append("--use_hpu_graphs_for_training") + test_args.append("--use_hpu_graphs_for_inference") + if train_text_encoder: + test_args.append("--train_text_encoder") + test_args.append(extra_config) + p = subprocess.Popen(test_args) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + # save_pretrained smoke test + if "full" in extra_config: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) + if train_text_encoder: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "text_encoder", "model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + else: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "adapter_model.safetensors"))) + if train_text_encoder: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "text_encoder", "adapter_model.safetensors"))) + + def test_dreambooth_full(self): + self._test_dreambooth("full") + + def test_dreambooth_full_with_text_encoder(self): + self._test_dreambooth("full", train_text_encoder=True) + + def test_dreambooth_lora(self): + self._test_dreambooth("lora") + + def test_dreambooth_lora_with_text_encoder(self): + self._test_dreambooth("lora", train_text_encoder=True) + + def test_dreambooth_lokr(self): + self._test_dreambooth("lokr") + + def test_dreambooth_lokr_with_text_encoder(self): + self._test_dreambooth("lokr", train_text_encoder=True) + + def test_dreambooth_loha(self): + self._test_dreambooth("loha") + + def test_dreambooth_loha_with_text_encoder(self): + self._test_dreambooth("loha", train_text_encoder=True) + + def test_dreambooth_oft(self): + self._test_dreambooth("oft") + + def test_dreambooth_oft_with_text_encoder(self): + self._test_dreambooth("oft", train_text_encoder=True) + + +class DreamBoothLoRASDXL(TestCase): + def _test_dreambooth_lora_sdxl(self, train_text_encoder=False): + path_to_script = ( + Path(os.path.dirname(__file__)).parent + / "examples" + / "stable-diffusion" + / "training" + / "train_dreambooth_lora_sdxl.py" + ) + install_requirements(path_to_script.parent / "requirements.txt") + + instance_prompt = "soccer player kicking a ball" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + python3 + {path_to_script} + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --instance_data_dir {Path(os.path.dirname(__file__))/'resource/img'} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 1 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --gaudi_config_name Habana/stable-diffusion + --use_hpu_graphs_for_training + --use_hpu_graphs_for_inference + --output_dir {tmpdir} + """.split() + if train_text_encoder: + test_args.append("--train_text_encoder") + test_args.append("--instance_prompt") + test_args.append(instance_prompt) + p = subprocess.Popen(test_args) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` in their names. + if train_text_encoder: + starts_with_unet = all( + k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") + for k in lora_state_dict.keys() + ) + else: + starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_unet) + + def test_dreambooth_lora_sdxl_with_text_encoder(self): + self._test_dreambooth_lora_sdxl(train_text_encoder=True) + + def test_dreambooth_lora_sdxl(self): + self._test_dreambooth_lora_sdxl(train_text_encoder=False) + + +class GaudiStableVideoDiffusionPipelineTester(TestCase): + """ + Tests the StableVideoDiffusionPipeline for Gaudi. + Adapted from: https://github.com/huggingface/diffusers/blob/v0.24.0-release/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py + """ + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNetSpatioTemporalConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=8, + out_channels=4, + down_block_types=( + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"), + cross_attention_dim=32, + num_attention_heads=8, + projection_class_embeddings_input_dim=96, + addition_time_embed_dim=32, + ) + scheduler = GaudiEulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + interpolation_type="linear", + num_train_timesteps=1000, + prediction_type="v_prediction", + sigma_max=700.0, + sigma_min=0.002, + steps_offset=1, + timestep_spacing="leading", + timestep_type="continuous", + trained_betas=None, + use_karras_sigmas=True, + ) + + torch.manual_seed(0) + vae = AutoencoderKLTemporalDecoder( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + latent_channels=4, + ) + + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + image_size=32, + intermediate_size=37, + patch_size=1, + ) + image_encoder = CLIPVisionModelWithProjection(config) + + torch.manual_seed(0) + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) + components = { + "unet": unet, + "image_encoder": image_encoder, + "scheduler": scheduler, + "vae": vae, + "feature_extractor": feature_extractor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)).to(device) + inputs = { + "generator": generator, + "image": image, + "num_inference_steps": 2, + "output_type": "pt", + "min_guidance_scale": 1.0, + "max_guidance_scale": 2.5, + "num_frames": 2, + "height": 32, + "width": 32, + } + return inputs + + def test_stable_video_diffusion_single_video(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + gaudi_config = GaudiConfig(use_torch_autocast=False) + sd_pipe = GaudiStableVideoDiffusionPipeline(use_habana=True, gaudi_config=gaudi_config, **components) + for component in sd_pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + + sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + outputs = sd_pipe( + **self.get_dummy_inputs(device), + ).frames + image = outputs[0] + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(len(outputs), 1) + self.assertEqual(image.shape, (2, 3, 32, 32)) + + expected_slice = np.array([0.5910, 0.5797, 0.5521, 0.6628, 0.6212, 0.6422, 0.5681, 0.5232, 0.5343]) + + self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2) + + @slow + def test_stable_video_diffusion_no_throughput_regression_bf16(self): + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png" + ) + model_name = "stabilityai/stable-video-diffusion-img2vid-xt" + scheduler = GaudiEulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler") + + pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained( + model_name, + scheduler=scheduler, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion"), + torch_dtype=torch.bfloat16, + ) + set_seed(42) + prompt_image = load_image(image_url) + outputs = pipeline( + image=prompt_image, + num_videos_per_prompt=1, + batch_size=1, + height=256, + width=256, + ) + + self.assertEqual(len(outputs.frames[0]), 25) + if IS_GAUDI2: + self.assertGreaterEqual(outputs.throughput, 0.95 * 0.012) diff --git a/tests/test_encoder_decoder.py b/tests/test_encoder_decoder.py index cacb1024d..06d03ff92 100644 --- a/tests/test_encoder_decoder.py +++ b/tests/test_encoder_decoder.py @@ -16,7 +16,7 @@ MODELS_TO_TEST = { "summarization": { "bf16": [ - ("facebook/bart-large-cnn", "Habana/bart", 5.233, 26.6928, 2, 1), + ("facebook/bart-large-cnn", "Habana/bart", 3.9, 28.9801, 2, 2), ("t5-3b", "Habana/t5", 2.955, 21.8877, 2, 1), ], }, @@ -34,7 +34,7 @@ MODELS_TO_TEST = { "summarization": { "bf16": [ - ("facebook/bart-large-cnn", "Habana/bart", 2.628, 26.7494, 2, 1), + ("facebook/bart-large-cnn", "Habana/bart", 2.304, 29.174, 2, 2), ("t5-3b", "Habana/t5", 1.005, 21.7286, 2, 1), ], }, @@ -146,7 +146,7 @@ def _test_text_summarization( "--use_habana", f"--per_device_eval_batch_size {batch_size}", f"--gaudi_config_name {gaudi_config}", - f"--generation_num_beams {num_beams}", + f"--num_beams {num_beams}", "--ignore_pad_token_for_loss False", "--pad_to_max_length", "--use_hpu_graphs_for_inference", diff --git a/tests/test_examples.py b/tests/test_examples.py index 406f76939..6e4595c8e 100755 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -252,7 +252,9 @@ def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: st return False - def __new__(cls, name, bases, attrs, example_name=None, multi_card=False, deepspeed=False, fsdp=False): + def __new__( + cls, name, bases, attrs, example_name=None, multi_card=False, deepspeed=False, fsdp=False, torch_compile=False + ): distribution = "single_card" if multi_card: distribution = "multi_card" @@ -273,7 +275,7 @@ def __new__(cls, name, bases, attrs, example_name=None, multi_card=False, deepsp for model_name, gaudi_config_name in models_to_test: if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp): attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test( - model_name, gaudi_config_name, multi_card, deepspeed, fsdp + model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile ) attrs["EXAMPLE_NAME"] = example_name return super().__new__(cls, name, bases, attrs) @@ -286,6 +288,7 @@ def _create_test( multi_card: bool = False, deepspeed: bool = False, fsdp: bool = False, + torch_compile: bool = False, ) -> Callable[[], None]: """ Create a test function that runs an example for a specific (model_name, gaudi_config_name) pair. @@ -388,11 +391,22 @@ def test(self): if "llama" in model_name: env_variables["LOWER_LIST"] = str(example_script.parent / "ops_bf16.txt") env_variables["PT_HPU_LAZY_MODE"] = "0" + elif deepspeed and "gpt-neox-20b" in model_name: + env_variables["LD_PRELOAD"] = "" extra_command_line_arguments = baseline.get("distribution").get(distribution).get("extra_arguments", []) if os.environ.get("DATA_CACHE", None) is not None and self.EXAMPLE_NAME == "run_clip": extra_command_line_arguments[0] = "--data_dir {}".format(os.environ["DATA_CACHE"]) + elif torch_compile and ( + model_name == "bert-large-uncased-whole-word-masking" or model_name == "roberta-large" + ): + extra_command_line_arguments.append("--torch_compile_backend hpu_backend") + extra_command_line_arguments.append("--torch_compile") + if "--use_hpu_graphs_for_inference" in extra_command_line_arguments: + extra_command_line_arguments.remove("--use_hpu_graphs_for_inference") + env_variables["PT_HPU_LAZY_MODE"] = "0" + env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" with TemporaryDirectory() as tmp_dir: cmd_line = self._create_command_line( @@ -608,12 +622,14 @@ class DeepSpeedTextClassificationExampleTester( DATASET_PARAMETER_NAME = "task_name" -class QuestionAnsweringExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_qa"): +class QuestionAnsweringExampleTester( + ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_qa", torch_compile=True +): TASK_NAME = "squad" class MultiCardQuestionAnsweringExampleTester( - ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_qa", multi_card=True + ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_qa", multi_card=True, torch_compile=True ): TASK_NAME = "squad" @@ -697,6 +713,12 @@ class ProteinFoldingExampleTester2(ExampleTesterBase, metaclass=ExampleTestMeta, pass +class CausalLanguageModelingLORAExampleTester( + ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm" +): + TASK_NAME = "databricks/databricks-dolly-15k" + + class MultiCardCausalLanguageModelingLORAExampleTester( ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True ): diff --git a/tests/test_fp8_examples.py b/tests/test_fp8_examples.py new file mode 100644 index 000000000..54f5a7a63 --- /dev/null +++ b/tests/test_fp8_examples.py @@ -0,0 +1,138 @@ +import json +import os +import re +import subprocess +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from .test_examples import ACCURACY_PERF_FACTOR, TIME_PERF_FACTOR + + +if os.environ.get("GAUDI2_CI", "0") == "1": + # Gaudi2 CI baselines + MODELS_TO_TEST = { + "fp8": [ + ( + "mistralai/Mistral-7B-Instruct-v0.2", + "tatsu-lab/alpaca", + "", + 12.373, + 0.7538, + "language-modeling", + 8, + 8, + "run_lora_clm.py", + ), + ], + } +else: + # FP8 is not supported on Gaudi1 + MODELS_TO_TEST = {"fp8": []} + + +def _test_fp8_train( + model_name: str, + dataset_name: str, + gaudi_config: str, + baseline: float, + baseline_acc: float, + task: str, + batch_size_train: int, + batch_size_eval: int, + script: str, + token: str, + world_size: int = 8, +): + path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" + + # Install question-answering example requirements + cmd_line = f"pip install -r {path_to_example_dir / task / 'requirements.txt'}".split() + p = subprocess.Popen(cmd_line) + return_code = p.wait() + assert return_code == 0 + + command = ["python3"] + + command += [ + f"{path_to_example_dir / task / script}", + f"--model_name_or_path {model_name}", + f"--dataset_name {dataset_name}", + "--do_train", + "--do_eval", + f"--per_device_eval_batch_size {batch_size_eval}", + f"--per_device_train_batch_size {batch_size_train}", + "--use_habana", + "--use_lazy_mode", + "--fp8 True", + ] + + if model_name == "mistralai/Mistral-7B-Instruct-v0.2": + command += [ + "--num_train_epochs 3", + "--evaluation_strategy no", + "--save_strategy no", + "--learning_rate 4e-4", + "--warmup_ratio 0.03", + "--lr_scheduler_type constant", + "--max_grad_norm 0.3", + "--logging_steps 1", + "--throughput_warmup_steps 5", + "--lora_rank 8", + "--lora_target_modules v_proj q_proj", + "--lora_alpha 16", + "--lora_dropout 0.05", + "--dataset_concatenation", + "--max_seq_length 512", + "--low_cpu_mem_usage True", + "--validation_split_percentage 4", + "--adam_epsilon 1e-08", + f"--token {token.value}", + ] + + with TemporaryDirectory() as tmp_dir: + command.append(f"--output_dir {tmp_dir}") + print(f"\n\nCommand to test: {' '.join(command)}\n") + + pattern = re.compile(r"([\"\'].+?[\"\'])|\s") + command = [x for y in command for x in re.split(pattern, y) if x] + + proc = subprocess.run(command) + + # Ensure the run finished without any issue + # Use try-except to avoid logging the token if used + try: + assert proc.returncode == 0 + except AssertionError as e: + if "'--token', 'hf_" in e.args[0]: + e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) + raise + + with open(Path(tmp_dir) / "all_results.json") as fp: + results = json.load(fp) + + # Ensure performance requirements (throughput) are met + assert results["train_samples_per_second"] >= (2 - TIME_PERF_FACTOR) * baseline + assert results["eval_accuracy"] >= ACCURACY_PERF_FACTOR * baseline_acc + + +@pytest.mark.parametrize( + "model_name, dataset_name, gaudi_config, baseline, baseline_acc, task, bs_train, bs_eval, script", + MODELS_TO_TEST["fp8"], +) +def test_fp8_train( + model_name: str, + dataset_name: str, + gaudi_config: str, + baseline: float, + baseline_acc: float, + task: str, + bs_train: int, + bs_eval: int, + script: str, + token: str, +): + _test_fp8_train( + model_name, dataset_name, gaudi_config, baseline, baseline_acc, task, bs_train, bs_eval, script, token + ) diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py index 366a0d71e..b3af2a05c 100644 --- a/tests/test_fsdp_examples.py +++ b/tests/test_fsdp_examples.py @@ -28,7 +28,7 @@ ( "meta-llama/Llama-2-7b-hf", "", - 87.016, + 85.016, 0.9093, "language-modeling", 8, @@ -123,6 +123,8 @@ def _test_fsdp( "--low_cpu_mem_usage True", "--attn_softmax_bf16 True", "--num_train_epochs 3", + "--use_flash_attention True", + "--flash_attention_causal_mask True", f"--token {token.value}", ] diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index 4401660c6..85982ce7c 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -15,6 +15,13 @@ MODELS_TO_TEST = { "bf16": [ ("llava-hf/llava-1.5-7b-hf", 1, 87.2901500056982), + ("llava-hf/llava-1.5-13b-hf", 1, 54.41252589197953), + ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546), + ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), + ], + "fp8": [ + ("llava-hf/llava-1.5-7b-hf", 1, 123.00953973789325), + ("llava-hf/llava-1.5-13b-hf", 1, 82.81132373492122), ], } else: @@ -22,7 +29,11 @@ MODELS_TO_TEST = { "bf16": [ ("llava-hf/llava-1.5-7b-hf", 1, 28.04096918512148), + ("llava-hf/llava-1.5-13b-hf", 1, 16.704731010481538), + ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 10.759228696741), + ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 6.96732060769783), ], + "fp8": [], } @@ -31,6 +42,7 @@ def _test_image_to_text( baseline: float, token: str, batch_size: int = 1, + fp8: bool = False, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -58,6 +70,16 @@ def _test_image_to_text( pattern = re.compile(r"([\"\'].+?[\"\'])|\s") command = [x for y in command for x in re.split(pattern, y) if x] + if fp8: + print(f"\n\nCommand to test: {' '.join(command)}\n") + env_variables["QUANT_CONFIG"] = os.path.join( + path_to_example_dir, "image-to-text/quantization_config/maxabs_measure_include_outputs.json" + ) + subprocess.run(command, env=env_variables) + env_variables["QUANT_CONFIG"] = os.path.join( + path_to_example_dir, "image-to-text/quantization_config/maxabs_quant.json" + ) + proc = subprocess.run(command, env=env_variables) # Ensure the run finished without any issue @@ -79,3 +101,8 @@ def _test_image_to_text( @pytest.mark.parametrize("model_name, batch_size, baseline", MODELS_TO_TEST["bf16"]) def test_image_to_text_bf16(model_name: str, baseline: float, batch_size: int, token: str): _test_image_to_text(model_name, baseline, token, batch_size) + + +@pytest.mark.parametrize("model_name, batch_size, baseline", MODELS_TO_TEST["fp8"]) +def test_image_to_text_fp8(model_name: str, baseline: float, batch_size: int, token: str): + _test_image_to_text(model_name, baseline, token, batch_size, fp8=True) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5a56d0a12..f94ebe6a4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -13,17 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import TestCase +import os +import numpy as np +import pytest import torch +from datasets import load_dataset from habana_frameworks.torch.hpu import wrap_in_hpu_graph from transformers import pipeline from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi -class GaudiPipelineTester(TestCase): - def _test_image_to_text(self, model, expected_result): +MODELS_TO_TEST = { + "text-to-speech": [ + ("microsoft/speecht5_tts", 16000), + ("facebook/hf-seamless-m4t-medium", 16000), + ("facebook/mms-tts-eng", 16000), + ], + "image-to-text": [ + ("Salesforce/blip-image-captioning-base", "a soccer player is playing a game on the app"), + ("nlpconnect/vit-gpt2-image-captioning", "a soccer game with a player jumping to catch"), + ], +} + + +class TestGaudiPipeline: + @pytest.mark.parametrize("model, expected_result", MODELS_TO_TEST["image-to-text"]) + def test_image_to_text(self, model, expected_result): adapt_transformers_to_gaudi() MODEL_DTYPE_LIST = [torch.bfloat16, torch.float32] generate_kwargs = { @@ -32,7 +49,7 @@ def _test_image_to_text(self, model, expected_result): "max_new_tokens": 128, "ignore_eos": False, } - image = "./tests/resource/image-captioning-example.png" + image = os.path.dirname(__file__) + "/resource/img/image-captioning-example.png" for model_dtype in MODEL_DTYPE_LIST: generator = pipeline( "image-to-text", @@ -43,14 +60,37 @@ def _test_image_to_text(self, model, expected_result): generator.model = wrap_in_hpu_graph(generator.model) for i in range(3): output = generator(image, generate_kwargs=generate_kwargs) - self.assertTrue(output[0]["generated_text"].startswith(expected_result)) + assert output[0]["generated_text"].startswith(expected_result) + + @pytest.mark.parametrize("model, expected_sample_rate", MODELS_TO_TEST["text-to-speech"]) + def test_text_to_speech(self, model, expected_sample_rate): + adapt_transformers_to_gaudi() + MODEL_DTYPE_LIST = [torch.bfloat16, torch.float32] + text = "hello, the dog is cooler" + for model_dtype in MODEL_DTYPE_LIST: + generator = pipeline( + "text-to-speech", + model=model, + torch_dtype=model_dtype, + device="hpu", + ) + forward_params = None + if generator.model.config.model_type == "speecht5": + embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") + speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to("hpu") + forward_params = {"speaker_embeddings": speaker_embedding} + if generator.model.config.model_type == "seamless_m4t": + forward_params = {"tgt_lang": "eng"} - def test_image_to_text_blip(self): - model = "Salesforce/blip-image-captioning-base" - expected_result = "a soccer player is playing a game on the app" - self._test_image_to_text(model, expected_result) + generate_kwargs = None + if generator.model.can_generate(): + generate_kwargs = {"lazy_mode": True, "ignore_eos": False, "hpu_graphs": True} - def test_image_to_text_vit(self): - model = "nlpconnect/vit-gpt2-image-captioning" - expected_result = "a soccer game with a player jumping to catch" - self._test_image_to_text(model, expected_result) + generator.model = wrap_in_hpu_graph(generator.model) + with torch.autocast( + "hpu", torch.bfloat16, enabled=(model_dtype == torch.bfloat16) + ), torch.no_grad(), torch.inference_mode(): + for i in range(3): + output = generator(text, forward_params=forward_params, generate_kwargs=generate_kwargs) + assert isinstance(output["audio"], np.ndarray) + assert output["sampling_rate"] == expected_sample_rate diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 88dec45ab..4bbc062fe 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -2,14 +2,19 @@ import os import re import subprocess +import sys from pathlib import Path from tempfile import TemporaryDirectory +from unittest import TestCase import pytest from .test_examples import TIME_PERF_FACTOR +prev_quant_model_name = None +prev_quant_rank = 0 + if os.environ.get("GAUDI2_CI", "0") == "1": # Gaudi2 CI baselines MODELS_TO_TEST = { @@ -29,19 +34,29 @@ ("meta-llama/Meta-Llama-3-8B", 1, True, 129), ("meta-llama/Llama-2-7b-hf", 512, True, 12808), ("meta-llama/Llama-2-7b-hf", 512, False, 8711), # in some cases like TGI, reuse_cache isnt used - ("stabilityai/stablelm-2-12b", 1, False, 80.70269834414843), + ("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218), ("codellama/CodeLlama-34b-hf", 1, True, 32.644), + ("bigcode/starcoder2-3b", 1, False, 234.2649120507936), ("adept/persimmon-8b-base", 4, False, 366.73968820698406), - ("Qwen/Qwen1.5-7B", 4, False, 451.7454544774087), + ("Qwen/Qwen1.5-7B", 4, False, 488.82855464593257), ("google/gemma-7b", 1, False, 109.70751574382221), ], "fp8": [ - ("tiiuae/falcon-180B", 52.85086442722326), - ("mistralai/Mistral-7B-Instruct-v0.2", 0), - ("mistralai/Mixtral-8x7B-v0.1", 39.26845661768185), - ("meta-llama/Llama-2-7b-hf", 0.0), - ("meta-llama/Llama-2-70b-hf", 0.0), - ("microsoft/phi-2", 254.08932787178165), + ("tiiuae/falcon-180B", 4, 950, True, 128, 128, 2506.68), + ("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128, 13152.7), + ("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048, 4774.7), + ("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128, 1293.3), + ("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048, 1942.9), + ("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, 5374.6), + ("meta-llama/Llama-2-70b-hf", 4, 750, False, 128, 2048, 7422.4), + ("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, 568.5), + ("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, 4656.2), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128, 12397.11410288204), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048, 5394.675714459493), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128, 919.8470890081497), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048, 2471.950758729518), + ("mistralai/Mixtral-8x7B-v0.1", 1, 1, True, 128, 128, 39.26845661768185), + ("microsoft/phi-2", 1, 1, True, 128, 128, 254.08932787178165), ], "deepspeed": [ ("bigscience/bloomz", 36.77314954096159), @@ -56,28 +71,6 @@ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], } - LLAMA2_FP8_CONFIG = { - "meta-llama/Llama-2-7b-hf": [ - ("1200", "128", "128", 13415.103401047876), - ("160", "128", "2048", 2930.9086839308384), - ("90", "2048", "128", 1104.0681776998265), - ("60", "2048", "2048", 1248.3177998857964), - ], - "meta-llama/Llama-2-70b-hf": [ - ("2500", "128", "128", 10327.895829614834), - ("430", "128", "2048", 10425.578514886345), - ("40", "2048", "128", 695.475101514524), - ("64", "2048", "2048", 2773.173092391251), - ], - } - MISTRAL_FP8_CONFIG = { - "mistralai/Mistral-7B-Instruct-v0.2": [ - ("896", "128", "128", 12397.11410288204), - ("120", "128", "2048", 5394.675714459493), - ("120", "2048", "128", 919.8470890081497), - ("44", "2048", "2048", 2471.950758729518), - ], - } else: # Gaudi1 CI baselines MODELS_TO_TEST = { @@ -98,6 +91,7 @@ ("stabilityai/stablelm-2-12b", 1, False, 26.80858949645992), ("Qwen/Qwen1.5-7B", 1, False, 39.29068423087616), ("adept/persimmon-8b-base", 1, False, 34.53559807384106), + ("bigcode/starcoder2-3b", 1, False, 82.09655684566117), ], "fp8": [], "deepspeed": [ @@ -118,6 +112,8 @@ def _test_text_generation( world_size: int = 8, torch_compile: bool = False, fp8: bool = False, + max_input_tokens: int = 0, + max_output_tokens: int = 100, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -135,13 +131,16 @@ def _test_text_generation( f"--model_name_or_path {model_name}", f"--batch_size {batch_size}", "--use_kv_cache", - "--max_new_tokens 100", + f"--max_new_tokens {max_output_tokens}", ] if "llama" in model_name.lower(): command += ["--trim_logits", "--attn_softmax_bf16"] - if reuse_cache or torch_compile or fp8: + if "falcon" in model_name.lower(): + command += ["--use_flash_attention", "--flash_attention_causal_mask"] + + if reuse_cache or torch_compile: command += ["--reuse_cache"] if torch_compile: @@ -159,59 +158,67 @@ def _test_text_generation( if fp8: if "--trim_logits" not in command: command += ["--trim_logits"] - if "Llama-2" in model_name or "Mistral" in model_name: - command.remove("--max_new_tokens 100") + if "Llama-2" in model_name: + command.insert(-2, "--use_flash_attention") + command.insert(-2, "--flash_attention_recompute") + command.insert(-2, "--bucket_size 128") + command.insert(-2, "--bucket_internal") + elif "falcon-180b" in model_name.lower(): + command.insert(-2, "--flash_attention_recompute") + + global prev_quant_model_name + global prev_quant_rank + measure_command = None + # FP8 Measurement only needed + if (prev_quant_model_name is None) or (prev_quant_model_name != model_name) or (prev_quant_rank != world_size): + measure_command = [ + x for x in command if not x.startswith("--max_new_tokens") + ] # Remove max_new_tokens for measurement + measure_command = [ + x if not x.startswith("--batch_size") else "--batch_size 1" for x in measure_command + ] # Remove batch_size for measurement + + prev_quant_model_name = model_name + prev_quant_rank = world_size + + # FP8 text generation + command += [ + f"--max_input_tokens {max_input_tokens}", + "--limit_hpu_graphs", + ] with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") - print(f"\n\nCommand to test: {' '.join(command)}\n") - command.append(f"--token {token.value}") pattern = re.compile(r"([\"\'].+?[\"\'])|\s") if fp8: - env_variables["QUANT_CONFIG"] = os.path.join( - path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json" - ) - command = [x for y in command for x in re.split(pattern, y) if x] - subprocess.run(command, env=env_variables) + env_variables["TQDM_DISABLE"] = "1" + if measure_command is not None: + measure_command.append(f"--token {token.value}") + env_variables["QUANT_CONFIG"] = os.path.join( + path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json" + ) + measure_command = [x for y in measure_command for x in re.split(pattern, y) if x] + print(f"\n\nMeasure Command to test: {' '.join(measure_command[:-2])}\n") + proc = subprocess.run(measure_command, env=env_variables) + + # Ensure the run finished without any issue + # Use try-except to avoid logging the token if used + try: + assert proc.returncode == 0 + except AssertionError as e: + if "'--token', 'hf_" in e.args[0]: + e.args = (f"The following command failed:\n{' '.join(measure_command[:-2])}",) + raise + env_variables["QUANT_CONFIG"] = os.path.join( path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json" ) - command.insert(-2, "--fp8") - command.insert(-2, "--warmup 1") - command.insert(-2, "--n_iterations 2") - if "Llama-2" in model_name or "Mistral" in model_name: - fp8_model_configs = LLAMA2_FP8_CONFIG if "Llama-2" in model_name else MISTRAL_FP8_CONFIG - command.insert(-2, "--limit_hpu_graphs") - command.insert(-2, "--max_input_tokens 1") - command.insert(-2, "--max_new_tokens 1") - command = [x for y in command for x in re.split(pattern, y) if x] - for model_config in fp8_model_configs[model_name]: - command[command.index("--batch_size") + 1] = model_config[0] - command[command.index("--max_input_tokens") + 1] = model_config[1] - command[command.index("--max_new_tokens") + 1] = model_config[2] - baseline = model_config[3] - proc = subprocess.run(command, env=env_variables) - - # Ensure the run finished without any issue - # Use try-except to avoid logging the token if used - try: - assert proc.returncode == 0 - except AssertionError as e: - if "'--token', 'hf_" in e.args[0]: - e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) - raise - - with open(Path(tmp_dir) / "results.json") as fp: - results = json.load(fp) - - # Ensure performance requirements (throughput) are met - assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline - return command = [x for y in command for x in re.split(pattern, y) if x] + print(f"\n\nCommand to test: {' '.join(command[:-2])}\n") proc = subprocess.run(command, env=env_variables) # Ensure the run finished without any issue @@ -235,11 +242,32 @@ def test_text_generation_bf16(model_name: str, baseline: float, batch_size: int, _test_text_generation(model_name, baseline, token, batch_size, reuse_cache) -@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["fp8"]) -def test_text_generation_fp8(model_name: str, baseline: float, token: str): - deepspeed = True if "falcon-180B" in model_name or "Llama-2-70b" in model_name else False - world_size = 8 if "falcon-180B" in model_name or "Llama-2-70b" in model_name else None - _test_text_generation(model_name, baseline, token, deepspeed=deepspeed, world_size=world_size, fp8=True) +@pytest.mark.parametrize( + "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline", MODELS_TO_TEST["fp8"] +) +def test_text_generation_fp8( + model_name: str, + baseline: float, + world_size: int, + batch_size: int, + reuse_cache: bool, + input_len: int, + output_len: int, + token: str, +): + deepspeed = True if world_size > 1 else False + _test_text_generation( + model_name, + baseline, + token, + deepspeed=deepspeed, + world_size=world_size, + fp8=True, + batch_size=batch_size, + reuse_cache=reuse_cache, + max_input_tokens=input_len, + max_output_tokens=output_len, + ) @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["deepspeed"]) @@ -257,3 +285,48 @@ def test_text_generation_torch_compile(model_name: str, baseline: float, token: def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str): world_size = 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) + + +class TextGenPipeline(TestCase): + def test_text_generation_pipeline_script(self): + path_to_script = ( + Path(os.path.dirname(__file__)).parent + / "examples" + / "text-generation" + / "text-generation-pipeline" + / "run_pipeline.py" + ) + + cmd_line = f"""ls {path_to_script}""".split() + + # check find existence + p = subprocess.Popen(cmd_line) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + + def test_text_generation_pipeline_falcon(self): + path_to_script = ( + Path(os.path.dirname(__file__)).parent + / "examples" + / "text-generation" + / "text-generation-pipeline" + / "run_pipeline.py" + ) + sys.path.append((Path(os.path.dirname(__file__)).parent / "examples" / "text-generation")) + cmd_line = f""" + python3 + {path_to_script} + --model_name_or_path tiiuae/falcon-7b + --max_new_tokens 100 + --bf16 + --use_hpu_graphs + --use_kv_cache + --do_sample + """.split() + p = subprocess.Popen(cmd_line) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 2cb314752..8be28e79c 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -22,6 +22,7 @@ import subprocess import tempfile import unittest +from functools import partial from itertools import product from pathlib import Path from typing import Dict, List, Optional, Union @@ -49,7 +50,7 @@ get_gpu_count, get_tests_dir, is_staging_test, - parse_flag_from_env, + require_accelerate, require_optuna, require_safetensors, require_sentencepiece, @@ -57,15 +58,15 @@ require_tokenizers, require_torch, ) -from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_pt_utils import AcceleratorConfig -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend from transformers.training_args import OptimizerNames from transformers.utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, + is_accelerate_available, is_safetensors_available, ) from transformers.utils.hp_naming import TrialShortNamer @@ -79,30 +80,26 @@ import transformers.optimization from torch import nn from torch.utils.data import IterableDataset - from transformers import EarlyStoppingCallback, GPT2Config, GPT2LMHeadModel, PreTrainedModel, TrainerState + from transformers import EarlyStoppingCallback, GPT2Config, PreTrainedModel, TrainerState from transformers.modeling_utils import unwrap_model from optimum.habana import GaudiTrainer + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + from optimum.habana.transformers.models.gpt2 import GaudiGPT2LMHeadModel if is_safetensors_available(): import safetensors.torch -PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" - +# for version specific tests in TrainerIntegrationTest +require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28") +GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28") -_run_safe_loading_tests_ = parse_flag_from_env("SAFE_LOADING_TESTS", default=False) +PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" -def safe_loading_test(test_case): - """ - Decorator marking a test as needing custom bf16 ops. - Custom bf16 ops must be declared before `habana_frameworks.torch.core` is imported, which is not possible if some other tests are executed before. - - Such tests are skipped by default. Set the CUSTOM_BF16_OPS environment variable to a truthy value to run them. - """ - return unittest.skipUnless(_run_safe_loading_tests_, "test requires SAFE_LOADING_TESTS")(test_case) +adapt_transformers_to_gaudi() class RegressionDataset: @@ -709,6 +706,31 @@ def test_lr_scheduler_kwargs(self): self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args) self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords) + def test_cosine_with_min_lr_scheduler(self): + train_dataset = RegressionDataset() + model = RegressionModel() + num_steps, num_warmup_steps = 10, 2 + extra_kwargs = {"min_lr": 1e-5} # Non-default arguments + args = GaudiTrainingArguments( + "./regression", + lr_scheduler_type="cosine_with_min_lr", + lr_scheduler_kwargs=extra_kwargs, + learning_rate=0.2, + warmup_steps=num_warmup_steps, + use_habana=True, + use_lazy_mode=True, + ) + trainer = GaudiTrainer(model, gaudi_config=get_gaudi_config(), args=args, train_dataset=train_dataset) + trainer.create_optimizer_and_scheduler(num_training_steps=num_steps) + + # Checking that the scheduler was created + self.assertIsNotNone(trainer.lr_scheduler) + + # Check the last learning rate + for _ in range(num_steps): + trainer.lr_scheduler.step() + self.assertEqual(trainer.lr_scheduler.get_last_lr()[0], 1e-5) + def test_reduce_lr_on_plateau_args(self): # test passed arguments for a custom ReduceLROnPlateau scheduler train_dataset = RegressionDataset(length=64) @@ -874,7 +896,7 @@ def test_trainer_works_with_dict(self): def test_evaluation_with_keys_to_drop(self): config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) - tiny_gpt2 = GPT2LMHeadModel(config) + tiny_gpt2 = GaudiGPT2LMHeadModel(config) x = torch.randint(0, 100, (128,)) eval_dataset = RepeatDataset(x) args = GaudiTrainingArguments("./test", use_habana=True, use_lazy_mode=True) @@ -962,7 +984,7 @@ def test_number_of_steps_in_training(self): def test_logging_inf_nan_filter(self): config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) - tiny_gpt2 = GPT2LMHeadModel(config) + tiny_gpt2 = GaudiGPT2LMHeadModel(config) x = torch.randint(0, 100, (128,)) train_dataset = RepeatDataset(x) @@ -1257,19 +1279,6 @@ def test_save_checkpoints(self): trainer.train() self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False) - def test_save_checkpoints_is_atomic(self): - class UnsaveableTokenizer(PreTrainedTokenizerBase): - def save_pretrained(self, *args, **kwargs): - raise OSError("simulated file write error") - - with tempfile.TemporaryDirectory() as tmpdir: - trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5) - # Attach unsaveable tokenizer to partially fail checkpointing - trainer.tokenizer = UnsaveableTokenizer() - with self.assertRaises(OSError) as _context: - trainer.train() - assert get_last_checkpoint(tmpdir) is None - @require_safetensors def test_safe_checkpoints(self): for save_safetensors in [True, False]: @@ -1480,7 +1489,6 @@ def test_training_with_resume_from_checkpoint_false(self): trainer.train(resume_from_checkpoint=False) - @safe_loading_test @require_safetensors def test_resume_training_with_safe_checkpoint(self): # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of @@ -1674,7 +1682,6 @@ def test_load_best_model_at_end(self): self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False) self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False) - @safe_loading_test @require_safetensors def test_load_best_model_from_safetensors(self): total = int(self.n_epochs * 64 / self.batch_size) @@ -1933,6 +1940,10 @@ def test_accelerator_config_empty(self): self.assertEqual(trainer.accelerator.even_batches, True) self.assertEqual(trainer.accelerator.use_seedable_sampler, True) + if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: + # gradient accumulation kwargs configures gradient_state + self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs) + def test_accelerator_config_from_dict(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -1941,16 +1952,20 @@ def test_accelerator_config_from_dict(self): model = RegressionPreTrainedModel(config) eval_dataset = SampleIterableDataset() + accelerator_config = { + "split_batches": True, + "dispatch_batches": True, + "even_batches": False, + "use_seedable_sampler": True, + } + if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: + accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True} + # Leaves all options as something *not* basic gaudi_config = get_gaudi_config() args = RegressionGaudiTrainingArguments( output_dir=tmp_dir, - accelerator_config={ - "split_batches": True, - "dispatch_batches": True, - "even_batches": False, - "use_seedable_sampler": True, - }, + accelerator_config=accelerator_config, use_habana=True, ) trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args, eval_dataset=eval_dataset) @@ -1959,6 +1974,9 @@ def test_accelerator_config_from_dict(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, True) + if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) + def test_accelerator_config_from_yaml(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -1971,6 +1989,8 @@ def test_accelerator_config_from_yaml(self): "even_batches": False, "use_seedable_sampler": False, } + if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: + accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True} json.dump(accelerator_config, f) config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) @@ -1985,11 +2005,18 @@ def test_accelerator_config_from_yaml(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, False) + if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) + def test_accelerator_config_from_dataclass(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively + accelerator_config = AcceleratorConfig( - split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False + split_batches=True, + dispatch_batches=True, + even_batches=False, + use_seedable_sampler=False, ) config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) @@ -2005,6 +2032,35 @@ def test_accelerator_config_from_dataclass(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, False) + @require_accelerate_version_min_0_28 + def test_accelerate_config_from_dataclass_grad_accum(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + + grad_acc_kwargs = { + "num_steps": 10, + "adjust_scheduler": False, + "sync_with_dataloader": False, + "sync_each_batch": True, + } + accelerator_config = AcceleratorConfig( + split_batches=True, + dispatch_batches=True, + even_batches=False, + use_seedable_sampler=False, + gradient_accumulation_kwargs=grad_acc_kwargs, + ) + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + with tempfile.TemporaryDirectory() as tmp_dir: + args = RegressionGaudiTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config) + trainer = GaudiTrainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10) + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False) + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False) + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) + def test_accelerator_config_from_partial(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -2083,6 +2139,77 @@ def test_accelerator_config_only_deprecated_args(self): trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args, eval_dataset=eval_dataset) self.assertEqual(trainer.accelerator.split_batches, True) + @require_accelerate_version_min_0_28 + def test_accelerator_config_from_dict_grad_accum_num_steps(self): + with tempfile.TemporaryDirectory() as tmp_dir: + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + # case - TrainingArguments.gradient_accumulation_steps == 1 + # - gradient_accumulation_kwargs['num_steps] == 1 + # results in grad accum set to 1 + args = RegressionGaudiTrainingArguments( + output_dir=tmp_dir, + gradient_accumulation_steps=1, + accelerator_config={ + "gradient_accumulation_kwargs": { + "num_steps": 1, + } + }, + ) + trainer = GaudiTrainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1) + + # case - TrainingArguments.gradient_accumulation_steps > 1 + # - gradient_accumulation_kwargs['num_steps] specified + # results in exception raised + args = RegressionGaudiTrainingArguments( + output_dir=tmp_dir, + gradient_accumulation_steps=2, + accelerator_config={ + "gradient_accumulation_kwargs": { + "num_steps": 10, + } + }, + ) + with self.assertRaises(Exception) as context: + trainer = GaudiTrainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception)) + + def test_accelerator_config_not_instantiated(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaises(NotImplementedError) as context: + _ = RegressionGaudiTrainingArguments( + output_dir=tmp_dir, + accelerator_config=AcceleratorConfig, + use_habana=True, + use_lazy_mode=True, + ) + self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception)) + + # Now test with a custom subclass + @dataclasses.dataclass + class CustomAcceleratorConfig(AcceleratorConfig): + pass + + @dataclasses.dataclass + class CustomTrainingArguments(GaudiTrainingArguments): + accelerator_config: dict = dataclasses.field( + default=CustomAcceleratorConfig, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaises(NotImplementedError) as context: + _ = CustomTrainingArguments( + output_dir=tmp_dir, + use_habana=True, + use_lazy_mode=True, + ) + self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception)) + def test_profiling(self): # 24 total steps and compilation takes place during the 1st three steps trainer = get_regression_trainer(profiling_warmup_steps=3, profiling_steps=21) @@ -2602,3 +2729,56 @@ def test_hyperparameter_search_backends(self): list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()), list(HPSearchBackend), ) + + +@require_torch +class OptimizerAndModelInspectionTest(unittest.TestCase): + def test_get_num_trainable_parameters(self): + model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32)) + args = GaudiTrainingArguments( + output_dir="tmp_trainer", + use_habana=True, + use_lazy_mode=True, + ) + # in_features * out_features + bias + layer_1 = 128 * 64 + 64 + layer_2 = 64 * 32 + 32 + trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) + self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2) + # Freeze the last layer + for param in model[-1].parameters(): + param.requires_grad = False + self.assertEqual(trainer.get_num_trainable_parameters(), layer_1) + + def test_get_learning_rates(self): + model = nn.Sequential(nn.Linear(128, 64)) + args = GaudiTrainingArguments( + output_dir="tmp_trainer", + use_habana=True, + use_lazy_mode=True, + ) + trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) + with self.assertRaises(ValueError): + trainer.get_learning_rates() + trainer.create_optimizer() + self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05]) + + def test_get_optimizer_group(self): + model = nn.Sequential(nn.Linear(128, 64)) + args = GaudiTrainingArguments( + output_dir="tmp_trainer", + use_habana=True, + use_lazy_mode=True, + ) + trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) + # ValueError is raised if optimizer is None + with self.assertRaises(ValueError): + trainer.get_optimizer_group() + trainer.create_optimizer() + # Get groups + num_groups = len(trainer.get_optimizer_group()) + self.assertEqual(num_groups, 2) + # Get group of parameter + param = next(model.parameters()) + group = trainer.get_optimizer_group(param) + self.assertIn(param, group["params"]) diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py index 673e69a7f..84413f902 100644 --- a/tests/test_trainer_distributed.py +++ b/tests/test_trainer_distributed.py @@ -180,22 +180,3 @@ def compute_metrics(p: EvalPrediction) -> Dict: exit(1) trainer.args.eval_accumulation_steps = None - - # Check that saving does indeed work with temp dir rotation - # If this fails, will see a FileNotFoundError - model = RegressionModel() - training_args.max_steps = 1 - opt = torch.optim.Adam(model.parameters(), lr=1e-3) - sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda x: 1) - trainer = GaudiTrainer( - model, - gaudi_config=gaudi_config, - args=training_args, - optimizers=(opt, sched), - data_collator=DummyDataCollator(), - eval_dataset=dataset, - ) - trainer._save_checkpoint(model=None, trial=None) - # Check that the temp folder does not exist - assert not (Path(training_args.output_dir) / "tmp-checkpoint-0").exists() - assert (Path(training_args.output_dir) / "checkpoint-0").exists() diff --git a/tests/test_trainer_seq2seq.py b/tests/test_trainer_seq2seq.py index 816fb4713..165ae0dce 100644 --- a/tests/test_trainer_seq2seq.py +++ b/tests/test_trainer_seq2seq.py @@ -13,17 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from transformers import AutoTokenizer, T5ForConditionalGeneration +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, T5ForConditionalGeneration from transformers.testing_utils import TestCasePlus, require_torch from transformers.utils import is_datasets_available -from optimum.habana import GaudiSeq2SeqTrainer, GaudiSeq2SeqTrainingArguments +from optimum.habana import GaudiConfig, GaudiSeq2SeqTrainer, GaudiSeq2SeqTrainingArguments +from optimum.habana.transformers.generation import GaudiGenerationConfig +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi if is_datasets_available(): import datasets +adapt_transformers_to_gaudi() + + class GaudiSeq2seqTrainerTester(TestCasePlus): @require_torch def test_finetune_t5(self): @@ -123,3 +128,31 @@ def _compute_metrics(pred): # start evaluation using beam search trainer.evaluate(max_length=model.config.max_length, num_beams=2) + + @require_torch + def test_bad_generation_config_fail_early(self): + # Tests that a bad geneartion config causes the trainer to fail early + model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small") + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest") + gen_config = GaudiGenerationConfig( + do_sample=False, top_p=0.9 + ) # bad: top_p is not compatible with do_sample=False + + training_args = GaudiSeq2SeqTrainingArguments( + output_dir="tmp_trainer", + predict_with_generate=True, + generation_config=gen_config, + use_habana=True, + use_lazy_mode=True, + ) + with self.assertRaises(ValueError) as exc: + _ = GaudiSeq2SeqTrainer( + model=model, + gaudi_config=GaudiConfig(), + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=lambda x: {"samples": x[0].shape[0]}, + ) + self.assertIn("The loaded generation config instance is invalid", str(exc.exception)) diff --git a/tests/transformers/tests/generation/test_beam_search.py b/tests/transformers/tests/generation/test_beam_search.py index 1398052e5..a93556e02 100644 --- a/tests/transformers/tests/generation/test_beam_search.py +++ b/tests/transformers/tests/generation/test_beam_search.py @@ -17,11 +17,12 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch -from ..test_modeling_common import floats_tensor, ids_tensor +from ..test_modeling_common import floats_tensor, ids_tensor, torch_device +assert torch_device == "hpu" if is_torch_available(): import torch from transformers.generation import ( @@ -354,7 +355,7 @@ def check_constrained_beam_scorer_update( token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids stacked_token_ids = stacked_token_ids + token_ids - fulfilling_sequence = torch.LongTensor(stacked_token_ids) + fulfilling_sequence = torch.LongTensor(stacked_token_ids).to(torch_device) fulfill_len = fulfilling_sequence.size(0) input_ids[:, :fulfill_len] = fulfilling_sequence @@ -427,7 +428,7 @@ def check_constrained_beam_scorer_finalize( token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids stacked_token_ids = stacked_token_ids + token_ids - fulfilling_sequence = torch.LongTensor(stacked_token_ids) + fulfilling_sequence = torch.LongTensor(stacked_token_ids).to(torch_device) fulfill_len = fulfilling_sequence.size(0) input_ids[:, :fulfill_len] = fulfilling_sequence diff --git a/tests/transformers/tests/generation/test_logits_process.py b/tests/transformers/tests/generation/test_logits_process.py index 8761e497c..d40a94d38 100644 --- a/tests/transformers/tests/generation/test_logits_process.py +++ b/tests/transformers/tests/generation/test_logits_process.py @@ -18,9 +18,9 @@ from parameterized import parameterized from transformers import is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch -from ..test_modeling_common import ids_tensor +from ..test_modeling_common import ids_tensor, torch_device if is_torch_available(): @@ -51,6 +51,7 @@ TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, ) + from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor @require_torch @@ -59,6 +60,35 @@ def _get_uniform_logits(self, batch_size: int, length: int): scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length return scores + def test_logits_processor_expected_device(self): + EXPECTED_DEVICE_TYPE = "hpu" + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + eos_token_id = 0 + min_eos_p = 0.1 ## some small float + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + scores = self._get_uniform_logits(batch_size, vocab_size) + + processors = [ + MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id), + TemperatureLogitsWarper(temperature=0.5), + RepetitionPenaltyLogitsProcessor(penalty=2.0), + TopKLogitsWarper(3), + TopPLogitsWarper(0.8), + NoRepeatNGramLogitsProcessor(2), + NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id), + BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p), + ] + + scores = self._get_uniform_logits(batch_size, vocab_size) + self.assertTrue(scores.device.type == EXPECTED_DEVICE_TYPE) + for processor in processors: + scores = processor(input_ids, scores) + self.assertTrue(scores.device.type == EXPECTED_DEVICE_TYPE) + def test_min_length_dist_processor(self): vocab_size = 20 batch_size = 4 @@ -154,8 +184,9 @@ def test_temperature_dist_warper(self): temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5) temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3) - warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1) - warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1) + warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores), dim=-1) + warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores), dim=-1) + processed_scores = temp_dist_warper_smoother(input_ids, scores) # uniform distribution stays uniform self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)) @@ -169,6 +200,9 @@ def test_temperature_dist_warper(self): self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max()) self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min()) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) + def test_repetition_penalty_dist_process(self): input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long) vocab_size = 10 @@ -181,14 +215,17 @@ def test_repetition_penalty_dist_process(self): rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) - scores = rep_penalty_proc(input_ids, scores.clone()) + processed_scores = rep_penalty_proc(input_ids, scores) # check that values were correctly changed - self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2) - self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2) + self.assertAlmostEqual(processed_scores[0, 0].item(), -(1 / vocab_size) * 2) + self.assertAlmostEqual(processed_scores[0, 1].item(), (1 / vocab_size) / 2) + + self.assertAlmostEqual(processed_scores[1, 0].item(), (1 / vocab_size) / 2) + self.assertAlmostEqual(processed_scores[1, 5].item(), (4 / vocab_size) / 2) - self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2) - self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) def test_encoder_repetition_penalty_dist_process(self): input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long) @@ -202,18 +239,21 @@ def test_encoder_repetition_penalty_dist_process(self): rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids) - scores = rep_penalty_proc(input_ids, scores.clone()) + processed_scores = rep_penalty_proc(input_ids, scores.clone()) # check that values were correctly changed - self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2) - self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) * 2) + self.assertAlmostEqual(processed_scores[0, 0].item(), -(1 / vocab_size) / 2) + self.assertAlmostEqual(processed_scores[0, 1].item(), (1 / vocab_size) * 2) - self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) * 2) - self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) * 2) + self.assertAlmostEqual(processed_scores[1, 0].item(), (1 / vocab_size) * 2) + self.assertAlmostEqual(processed_scores[1, 5].item(), (4 / vocab_size) * 2) # check that values not in the encoder ids were NOT changed - self.assertAlmostEqual(scores[0, 2].item(), (1 / vocab_size)) - self.assertAlmostEqual(scores[1, 2].item(), (1 / vocab_size)) + self.assertAlmostEqual(processed_scores[0, 2].item(), (1 / vocab_size)) + self.assertAlmostEqual(processed_scores[1, 2].item(), (1 / vocab_size)) + + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) def test_top_k_dist_warper(self): input_ids = None @@ -234,6 +274,9 @@ def test_top_k_dist_warper(self): self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False]) self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True]) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == ramp_logits)) + # check special cases length = 5 @@ -270,6 +313,9 @@ def test_top_p_dist_warper(self): ) self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + # processor should not change logits in-place + self.assertFalse(torch.all(top_p_warp(input_ids, dist) == dist)) + # check edge cases with negative and extreme logits ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( batch_size, 1 @@ -305,6 +351,9 @@ def test_typical_dist_warper(self): ) self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + # processor should not change logits in-place + self.assertFalse(torch.all(typical_warp(input_ids, dist) == dist)) + # check special cases length = 5 @@ -352,6 +401,9 @@ def test_epsilon_dist_warper(self): ) self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + # processor should not change logits in-place + self.assertFalse(torch.all(epsilon_warp(input_ids, dist) == dist)) + # check edge cases with negative and extreme logits ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( batch_size, 1 @@ -389,6 +441,9 @@ def test_eta_dist_warper(self): ) self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + # processor should not change logits in-place + self.assertFalse(torch.all(eta_warp(input_ids, dist) == dist)) + # check edge cases with negative and extreme logits ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( batch_size, 1 @@ -425,6 +480,10 @@ def test_no_repeat_ngram_dist_processor(self): torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]] ) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == filtered_scores_2_gram)) + self.assertFalse(torch.all(scores == filtered_scores_3_gram)) + def test_encoder_no_repeat_ngram_dist_processor(self): vocab_size = 3 num_beams = 2 @@ -449,6 +508,10 @@ def test_encoder_no_repeat_ngram_dist_processor(self): torch.isinf(filtered_scores_3_gram).tolist(), [[False, True, False], [False, False, False]] ) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == filtered_scores_2_gram)) + self.assertFalse(torch.all(scores == filtered_scores_3_gram)) + # Batched input vocab_size = 3 num_beams = 2 @@ -507,6 +570,9 @@ def test_no_bad_words_dist_processor(self): torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]] ) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == filtered_scores)) + # check edge case no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id) filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone()) @@ -536,6 +602,9 @@ def test_bias_dist_processor(self): filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]] ) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == filtered_scores)) + def test_processor_list(self): batch_size = 4 sequence_length = 10 @@ -607,6 +676,16 @@ def prefix_allowed_tokens_fn(batch_id, inputs_ids): torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]] ) + def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids): + return [] + + prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1) + + self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores) + + # processor should not change logits in-place + self.assertFalse(torch.all(scores == filtered_scores)) + def test_hamming_diversity(self): vocab_size = 4 num_beams = 2 @@ -634,6 +713,9 @@ def test_hamming_diversity(self): ) ) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) + def test_forced_bos_token_logits_processor(self): vocab_size = 20 batch_size = 4 @@ -644,15 +726,19 @@ def test_forced_bos_token_logits_processor(self): # check that all scores are -inf except the bos_token_id score input_ids = ids_tensor((batch_size, 1), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) - self.assertTrue(torch.isneginf(scores[:, bos_token_id + 1 :]).all()) - self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero + processed_scores = logits_processor(input_ids, scores) + self.assertTrue(torch.isneginf(processed_scores[:, bos_token_id + 1 :]).all()) + # score for bos_token_id shold be zero + self.assertListEqual(processed_scores[:, bos_token_id].tolist(), 4 * [0]) + + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) # check that bos_token_id is not forced if current length is greater than 1 input_ids = ids_tensor((batch_size, 4), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) - self.assertFalse(torch.isinf(scores).any()) + processed_scores = logits_processor(input_ids, scores) + self.assertFalse(torch.isinf(processed_scores).any()) def test_forced_eos_token_logits_processor(self): vocab_size = 20 @@ -665,15 +751,19 @@ def test_forced_eos_token_logits_processor(self): # check that all scores are -inf except the eos_token_id when max_length-1 is reached input_ids = ids_tensor((batch_size, 4), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) - self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all()) - self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero + processed_scores = logits_processor(input_ids, scores) + self.assertTrue(torch.isneginf(processed_scores[:, eos_token_id + 1 :]).all()) + # score for eos_token_id should be zero + self.assertListEqual(processed_scores[:, eos_token_id].tolist(), 4 * [0]) + + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) # check that eos_token_id is not forced if max_length-1 is not reached input_ids = ids_tensor((batch_size, 3), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) - self.assertFalse(torch.isinf(scores).any()) + processed_scores = logits_processor(input_ids, scores) + self.assertFalse(torch.isinf(processed_scores).any()) def test_remove_nan_inf_logits_processor(self): scores = torch.tensor( @@ -683,19 +773,25 @@ def test_remove_nan_inf_logits_processor(self): logits_processor = InfNanRemoveLogitsProcessor() - scores = logits_processor(input_ids, scores) + processed_scores = logits_processor(input_ids, scores) self.assertTrue( torch.allclose( - scores, + processed_scores, torch.tensor( - [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]], + [ + [0.0, 0.7, 0.8, 0.0], + [0.1, torch.finfo(processed_scores.dtype).max, 0.3, torch.finfo(processed_scores.dtype).min], + ], device=torch_device, ), atol=1e-6, ) ) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == processed_scores)) + def test_exponential_decay_length_penalty(self): vocab_size = 20 batch_size = 4 @@ -722,11 +818,16 @@ def test_exponential_decay_length_penalty(self): input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size) scores = self._get_uniform_logits(batch_size, vocab_size) scores_after_start = length_decay_processor(input_ids, scores) - self.assertTrue( - torch.gt( - scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id] - ).all() - ) + self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all()) + + # check the penalty increases negative scores + input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size) + scores = torch.neg(self._get_uniform_logits(batch_size, vocab_size)) + scores_after_start = length_decay_processor(input_ids, scores) + self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all()) + + # processor should not change logits in-place + self.assertFalse(torch.all(scores == scores_after_start)) def test_normalization(self): input_ids = None @@ -743,6 +844,9 @@ def test_normalization(self): self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1))) + # processor should not change logits in-place + self.assertFalse(torch.all(scores == normalized_scores)) + def test_classifier_free_guidance(self): class Namespace(dict): pass @@ -793,3 +897,35 @@ def lsm(x): self.assertAlmostEqual(out[0].item(), res[0].item()) self.assertAlmostEqual(out[1].item(), res[1].item()) self.assertAlmostEqual(out[2].item(), res[2].item()) + + def test_early_stop_processor(self): + input_ids = None + eos_token_id = 2 + min_eos_p = 0.1 ## some small float + + scores = self._get_uniform_logits(2, 4) + scores[0][eos_token_id] = -6 ## less than log(min_eos_p) + + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + actual_scores = esp(input_ids, scores) + expected_scores_list = [ + scores[0].tolist(), + [float("-inf"), float("-inf"), scores[0][0], float("-inf")], + ] + self.assertListEqual(actual_scores.tolist(), expected_scores_list) + + def test_early_stop_processor_multi_eos(self): + input_ids = None + eos_token_id = [2, 3] + min_eos_p = 0.1 ## some small float + + scores = self._get_uniform_logits(2, 4) + scores[0][eos_token_id] = -6 ## less than log(min_eos_p) + + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + actual_scores = esp(input_ids, scores) + expected_scores_list = [ + scores[0].tolist(), + [float("-inf"), float("-inf"), scores[0][0], scores[0][0]], + ] + self.assertListEqual(actual_scores.tolist(), expected_scores_list) diff --git a/tests/transformers/tests/generation/test_stopping_criteria.py b/tests/transformers/tests/generation/test_stopping_criteria.py index 0c1ed7fd9..0ce7838ee 100644 --- a/tests/transformers/tests/generation/test_stopping_criteria.py +++ b/tests/transformers/tests/generation/test_stopping_criteria.py @@ -17,14 +17,15 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch -from ..test_modeling_common import ids_tensor +from ..test_modeling_common import ids_tensor, torch_device if is_torch_available(): import torch from transformers.generation import ( + EosTokenCriteria, MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, @@ -53,37 +54,37 @@ def test_list_criteria(self): ] ) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(9) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(10) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) def test_max_length_criteria(self): criteria = MaxLengthCriteria(max_length=10) input_ids, scores = self._get_tensors(5) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(9) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(10) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) def test_max_new_tokens_criteria(self): criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5) input_ids, scores = self._get_tensors(5) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(9) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores))) input_ids, scores = self._get_tensors(10) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores))) criteria_list = StoppingCriteriaList([criteria]) self.assertEqual(criteria_list.max_length, 10) @@ -92,10 +93,31 @@ def test_max_time_criteria(self): input_ids, scores = self._get_tensors(5) criteria = MaxTimeCriteria(max_time=0.1) - self.assertFalse(criteria(input_ids, scores)) + self.assertFalse(all(criteria(input_ids, scores, needs_tensor_output=True))) + self.assertFalse(criteria(input_ids, scores, needs_tensor_output=False)) criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) - self.assertTrue(criteria(input_ids, scores)) + self.assertTrue(all(criteria(input_ids, scores, needs_tensor_output=True))) + self.assertTrue(criteria(input_ids, scores, needs_tensor_output=False)) + + def test_eos_token_criteria(self): + criteria = EosTokenCriteria(eos_token_id=0) + + input_ids, scores = self._get_tensors(5) + input_ids[:, -1] = 0 + self.assertTrue(all(criteria(input_ids, scores, needs_tensor_output=True))) + self.assertTrue(criteria(input_ids, scores, needs_tensor_output=False)) + + input_ids, scores = self._get_tensors(5) + input_ids[:2, -1] = 0 + input_ids[2, -1] = 1 + self.assertListEqual(criteria(input_ids, scores, needs_tensor_output=True).tolist(), [True, True, False]) + self.assertFalse(criteria(input_ids, scores, needs_tensor_output=False)) + + input_ids, scores = self._get_tensors(5) + input_ids[:, -1] = 1 + self.assertListEqual(criteria(input_ids, scores, needs_tensor_output=True).tolist(), [False, False, False]) + self.assertFalse(criteria(input_ids, scores, needs_tensor_output=False)) def test_validate_stopping_criteria(self): validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index 6234b0950..8ffbad89d 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -44,7 +44,6 @@ ImageGPTForCausalImageModeling, PreTrainedModel, SpeechEncoderDecoderModel, - top_k_top_p_filtering, ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, @@ -2485,136 +2484,6 @@ def _check_sequence_inside_sequence(self, tensor_1, tensor_2): self.assertTrue(flag) -@require_torch -class UtilsFunctionsTest(unittest.TestCase): - # tests whether the top_k_top_p function behaves as expected - def test_top_k_top_p_filtering(self): - logits = torch.tensor( - [ - [ - 8.2220991, # 3rd highest value; idx. 0 - -0.5620044, - 5.23229752, - 4.0386393, - -6.8798378, - -0.54785802, - -3.2012153, - 2.92777176, - 1.88171953, - 7.35341276, - 8.43207833, # 2nd highest value; idx. 10 - -9.85711836, - -5.96209236, - -1.13039161, - -7.1115294, - -0.8369633, - -5.3186408, - 7.06427407, - 0.81369344, - -0.82023817, - -5.9179796, - 0.58813443, - -6.99778438, - 4.71551189, - -0.18771637, - 7.44020759, # 4th highest value; idx. 25 - 9.38450987, # 1st highest value; idx. 26 - 2.12662941, - -9.32562038, - 2.35652522, - ], # cummulative prob of 4 highest values <= 0.6 - [ - 0.58425518, - 4.53139238, - -5.57510464, - -6.28030699, - -7.19529503, - -4.02122551, - 1.39337037, - -6.06707057, - 1.59480517, - -9.643119, - 0.03907799, - 0.67231762, - -8.88206726, - 6.27115922, # 4th highest value; idx. 13 - 2.28520723, - 4.82767506, - 4.30421368, - 8.8275313, # 2nd highest value; idx. 17 - 5.44029958, - -4.4735794, - 7.38579536, # 3rd highest value; idx. 20 - -2.91051663, - 2.61946077, - -2.5674762, - -9.48959302, - -4.02922645, - -1.35416918, - 9.67702323, # 1st highest value; idx. 27 - -5.89478553, - 1.85370467, - ], # cummulative prob of 4 highest values <= 0.6 - ], - dtype=torch.float, - device=torch_device, - ) - - non_inf_expected_idx = torch.tensor( - [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]], - dtype=torch.long, - device=torch_device, - ) # expected non filtered idx as noted above - - non_inf_expected_output = torch.tensor( - [ - 8.2221, - 8.4321, - 7.4402, - 9.3845, - 6.2712, - 8.8275, - 7.3858, - 9.6770, - ], # expected non filtered values as noted above - dtype=torch.float, - device=torch_device, - ) - - output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) - non_inf_output = output[output != -float("inf")].to(device=torch_device) - non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device) - - self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) - self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) - - # tests whether the function uses filter_value instead of default -inf - def test_top_k_top_p_filtering_with_filter_value(self): - logits = torch.tensor( - [ - [ - 1, - 1, - 1, - 0.99, # get filtered by top-p filtering - 0.98, # get filtered by top-k filtering - ] - ], - dtype=torch.float, - device=torch_device, - ) - - expected_output = torch.tensor( - [[1, 1, 1, 0, 0]], - dtype=torch.float, - device=torch_device, - ) - - output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0) - - self.assertTrue(torch.allclose(expected_output, output, atol=1e-12)) - - @require_torch class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin): # setting framework_dependent_parameters needs to be gated, just like its contents' imports diff --git a/tests/transformers/tests/models/falcon/test_modeling_falcon.py b/tests/transformers/tests/models/falcon/test_modeling_falcon.py index 63e0081b4..6b930aa1e 100644 --- a/tests/transformers/tests/models/falcon/test_modeling_falcon.py +++ b/tests/transformers/tests/models/falcon/test_modeling_falcon.py @@ -397,9 +397,8 @@ def test_lm_generate_falcon(self): "My favorite food is pizza. I love it so much that I have a pizza party every week. I love it" ) - output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=19) + output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=19, ignore_eos=True) output_str = tokenizer.batch_decode(output_ids)[0] - self.assertEqual(output_str, EXPECTED_OUTPUT) @slow diff --git a/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py b/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py index 8cdbe24e9..20047d6cb 100644 --- a/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py @@ -600,22 +600,24 @@ def test_batch_generation(self): ) outputs = model.generate( - input_ids=input_ids, - attention_mask=inputs["attention_mask"].to(torch_device), + input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), ignore_eos=True ) outputs_tt = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), token_type_ids=token_type_ids, + ignore_eos=True, ) inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) + output_non_padded = model.generate(input_ids=inputs_non_padded, ignore_eos=True) num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) - output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + output_padded = model.generate( + input_ids=inputs_padded, max_length=model.config.max_length - num_paddings, ignore_eos=True + ) batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) @@ -728,7 +730,7 @@ def _test_lm_generate_gpt2_helper( # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,] # fmt: skip - output_ids = model.generate(input_ids, do_sample=False) + output_ids = model.generate(input_ids, do_sample=False, ignore_eos=True) if verify_outputs: self.assertListEqual(output_ids[0].tolist(), expected_output_ids) @@ -757,11 +759,11 @@ def test_gpt2_sample(self): torch.manual_seed(0) tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) input_ids = tokenized.input_ids.to(torch_device) - output_ids = model.generate(input_ids, do_sample=True) + output_ids = model.generate(input_ids, do_sample=True, ignore_eos=True) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) token_type_ids = tokenized.token_type_ids.to(torch_device) - output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) + output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5, ignore_eos=True) output_seq_tt = model.generate( input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 ) diff --git a/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py index 9296dd72b..0c13ca285 100644 --- a/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -355,7 +355,7 @@ def test_lm_generate_gptneox(self): # See: https://github.com/huggingface/transformers/pull/24193 expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure" - output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20) + output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20, ignore_eos=True) output_str = tokenizer.batch_decode(output_ids)[0] self.assertEqual(output_str, expected_output) diff --git a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py index 990b4d9b8..62c3b075c 100644 --- a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -1257,8 +1257,16 @@ def get_logits(model, input_features): pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng")) torch.save(adapter_weights, pt_filepath) - model.load_adapter("eng") - model.load_adapter("eng", use_safetensors=False) + # model.load_adapter is broken in transformers + # since adapter_weights fails to load with weights_only=True + with self.assertRaises(OSError): + model.load_adapter("eng") + with self.assertRaises(OSError): + model.load_adapter("eng", use_safetensors=False) + # we will load adapter_weights directly while model.load_adapter fails + state_dict = torch.load(pt_filepath) + state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()} + model.load_state_dict(state_dict, strict=False) with self.assertRaises(OSError): model.load_adapter("eng", use_safetensors=True)