# [Triton] Flash Attention Triton Windows build support (#171)
---
name: Flash Attention Integration

on:
  push:
    branches: [main]
    paths:
      - 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
      - 'aiter/ops/mha.py'
      - 'csrc/py_itfs_ck/mha_*'
      - 'csrc/py_itfs_ck/attention_kernels.cu'
      - 'setup.py'
      - '.github/workflows/flash_attention_integration.yaml'
  pull_request:
    branches: [main]
    paths:
      - 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
      - 'aiter/ops/mha.py'
      - 'csrc/py_itfs_ck/mha_*'
      - 'csrc/py_itfs_ck/attention_kernels.cu'
      - 'setup.py'
      - '.github/workflows/flash_attention_integration.yaml'
  workflow_dispatch:

# Cancel superseded in-flight runs for the same PR/branch; never cancel
# post-merge (push) runs so main always gets a full signal.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: ${{ github.event_name != 'push' }}

env:
  # TODO: Switch to Dao-AILab/flash-attention main
  FA_BRANCH: micmelesse/aiter_migration
  FA_REPOSITORY_URL: https://github.com/ROCm/flash-attention.git
  # Digest-pinned base image for reproducible builds.
  BASE_IMAGE: rocm/pytorch:latest@sha256:683765a52c61341e1674fe730ab3be861a444a45a36c0a8caae7653a08a0e208
  # Path of the aiter submodule inside the flash-attention checkout;
  # it is replaced with this repository's working tree before install.
  AITER_SUBMODULE_PATH: third_party/aiter
jobs:
  # Gate on an externally produced "signal" artifact for this commit before
  # running any GPU jobs (see .github/scripts/check_signal.sh for semantics).
  check-signal:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          # For PRs, test the head commit rather than the merge commit.
          ref: ${{ github.event.pull_request.head.sha || github.sha }}
      - name: Download and check signal artifact
        run: ./.github/scripts/check_signal.sh
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GITHUB_SHA: ${{ github.sha }}
prechecks:
runs-on: ubuntu-latest
outputs:
run_triton: ${{ steps.gate.outputs.run_triton }}
run_ck: ${{ steps.gate.outputs.run_ck }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
shared:
- 'setup.py'
- '.github/workflows/flash_attention_integration.yaml'
triton:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
ck:
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
- 'csrc/py_itfs_ck/attention_kernels.cu'
- name: Compute job gates
id: gate
run: |
SHARED="${{ steps.filter.outputs.shared }}"
TRITON="${{ steps.filter.outputs.triton }}"
CK="${{ steps.filter.outputs.ck }}"
DISPATCH="${{ github.event_name == 'workflow_dispatch' }}"
if [ "$TRITON" = "true" ] || [ "$SHARED" = "true" ] || [ "$DISPATCH" = "true" ]; then
echo "run_triton=true" >> "$GITHUB_OUTPUT"
fi
if [ "$CK" = "true" ] || [ "$SHARED" = "true" ] || [ "$DISPATCH" = "true" ]; then
echo "run_ck=true" >> "$GITHUB_OUTPUT"
fi
# =============================================================================
# Triton Backend
# =============================================================================
flash_attention_triton:
if: ${{ needs.prechecks.outputs.run_triton == 'true' }}
name: Flash Attention - Triton / ${{ matrix.label }} (1 GPU)
needs: [check-signal, prechecks]
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: linux-aiter-mi355-1
label: MI355
- runner: aiter-gfx1100
label: RDNA3
steps:
- name: Checkout aiter repo
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Docker login
run: docker login -u rocmshared -p ${{ secrets.DOCKER_PASSWORD }} || true
- name: Pull base image
run: docker pull ${{ env.BASE_IMAGE }}
- name: Generate Dockerfile
run: |
cat <<'EOF' > Dockerfile.triton
FROM ${{ env.BASE_IMAGE }}
# Install test dependencies
RUN pip install --upgrade pip
RUN pip install pytest pytest-rerunfailures pytest-timeout einops
# Clone flash-attention and override aiter submodule with local checkout
COPY . /aiter
RUN git clone -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
rm -rf /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cp -a /aiter /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cd /flash-attention && \
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
pip install --no-build-isolation .
RUN echo "=== Installed versions ===" && \
pip show flash-attn && \
pip show amd-aiter
WORKDIR /flash-attention
EOF
- name: Show Dockerfile
run: cat Dockerfile.triton
- name: Build Docker image
run: docker build -t fa_triton_test:ci -f Dockerfile.triton .
- name: Start CI container
run: |
docker ps -aq -f name=fa_triton_test | xargs -r docker stop | xargs -r docker rm || true
if [ -f "/etc/podinfo/gha-render-devices" ]; then
DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
else
DEVICE_FLAG="--device /dev/dri"
fi
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
--ipc=host --group-add video \
--network=host \
--shm-size 32g \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-w /flash-attention \
--name fa_triton_test \
fa_triton_test:ci
- name: Run correctness tests
timeout-minutes: 360
run: |
if [ "${{ github.event_name }}" = "push" ]; then
# Post-merge: full test suite
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py
"
else
# PR: core API subset (~1 hour)
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py \
-k 'test_flash_attn_output or test_flash_attn_kvcache'
"
fi
- name: Run benchmarks
timeout-minutes: 30
run: |
set -o pipefail
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
python benchmarks/benchmark_flash_attention.py
" |& tee benchmark_triton_${{ matrix.label }}.log
- name: Upload benchmark results
if: success()
uses: actions/upload-artifact@v4
with:
name: flash-attention-triton-benchmark-${{ matrix.label }}
path: benchmark_triton_${{ matrix.label }}.log
- name: Clean Up
if: always()
run: |
docker stop fa_triton_test || true
docker rm -f fa_triton_test || true
docker rmi fa_triton_test:ci || true
# =============================================================================
# CK Backend
# =============================================================================
flash_attention_ck:
if: false # Disabled until CK tests are ready
# if: ${{ needs.prechecks.outputs.run_ck == 'true' }}
name: Flash Attention - CK (1 GPU)
needs: [check-signal, prechecks]
runs-on: linux-aiter-mi355-1
steps:
- name: Checkout aiter repo
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Docker login
run: docker login -u rocmshared -p ${{ secrets.DOCKER_PASSWORD }} || true
- name: Pull base image
run: docker pull ${{ env.BASE_IMAGE }}
- name: Generate Dockerfile
run: |
cat <<'EOF' > Dockerfile.ck
FROM ${{ env.BASE_IMAGE }}
# Install test dependencies
RUN pip install --upgrade pip
RUN pip install pytest pytest-rerunfailures pytest-timeout einops
# Clone and install flash-attention (CK backend)
# Override aiter submodule with local checkout
COPY . /aiter
RUN git clone -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
rm -rf /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cp -a /aiter /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cd /flash-attention && \
pip install --no-build-isolation .
RUN echo "=== Installed versions ===" && \
pip show flash-attn && \
pip show amd-aiter
WORKDIR /flash-attention
EOF
- name: Show Dockerfile
run: cat Dockerfile.ck
- name: Build Docker image
run: docker build -t fa_ck_test:ci -f Dockerfile.ck .
- name: Start CI container
run: |
docker ps -aq -f name=fa_ck_test | xargs -r docker stop | xargs -r docker rm || true
if [ -f "/etc/podinfo/gha-render-devices" ]; then
DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
else
DEVICE_FLAG="--device /dev/dri"
fi
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
--ipc=host --group-add video \
--network=host \
--shm-size 32g \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-w /flash-attention \
--name fa_ck_test \
fa_ck_test:ci
- name: Run correctness tests
timeout-minutes: 360
run: |
echo "CK tests not yet implemented - to be enabled by ChunYu Lai"
# docker exec fa_ck_test bash -c "
# cd /flash-attention
# pytest -v --reruns 2 --timeout=120 tests/test_flash_attn_ck.py
# "
- name: Run benchmarks
timeout-minutes: 30
run: |
echo "CK benchmarks not yet implemented - to be enabled by ChunYu Lai"
# set -o pipefail
# docker exec fa_ck_test bash -c "
# cd /flash-attention
# python benchmarks/benchmark_flash_attention.py
# " |& tee benchmark_ck.log
- name: Upload benchmark results
if: success()
run: |
echo "CK benchmark upload not yet implemented - to be enabled by ChunYu Lai"
- name: Clean Up
if: always()
run: |
docker stop fa_ck_test || true
docker rm -f fa_ck_test || true
docker rmi fa_ck_test:ci || true