- prefill
- decode
prefill
Axes (8 dimensions):
- total_q, num_pages, len_indptr, num_kv_indices: variable
- num_qo_heads, num_kv_heads, head_dim, page_size: constant
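Inputs:
- q: query tensor [total_q, num_qo_heads, head_dim]
- k_cache, v_cache: paged KV cache [num_pages, page_size, num_kv_heads, head_dim]
- qo_indptr, kv_indptr, kv_indices: paging indices. qo_indptr and kv_indptr (len_indptr entries each) delimit each request's rows of q and its slice of kv_indices; kv_indices lists page ids into k_cache and v_cache.
- sm_scale: softmax scale (scalar)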
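A minimal sketch of how the paging indices can be used to collect one request's KV tokens, in plain Python with nested lists standing in for tensors. The function name `gather_kv` is hypothetical, and the sketch assumes every page listed in kv_indices is completely filled, since the spec lists no last-page-length input (this holds trivially when page_size is 1).

```python
from typing import List

# Nested lists stand in for tensors:
# cache[page][slot][kv_head][dim], shape [num_pages, page_size, num_kv_heads, head_dim]
Cache = List[List[List[List[float]]]]


def gather_kv(cache: Cache, kv_indptr: List[int], kv_indices: List[int],
              request: int) -> List[List[List[float]]]:
    """Return one request's KV tokens as [kv_len, num_kv_heads, head_dim].

    Hypothetical helper. Assumes every listed page is completely filled,
    because the spec has no last-page-length input.
    """
    start, end = kv_indptr[request], kv_indptr[request + 1]
    tokens = []
    for page_id in kv_indices[start:end]:  # this request's pages, in order
        tokens.extend(cache[page_id])      # each page contributes page_size token slots
    return tokens
```

With page_size 2, a request whose pages are [2, 0] yields the two slots of page 2 followed by the two slots of page 0, so page order in kv_indices defines token order.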
Outputs:
- output: attention output [total_q, num_qo_heads, head_dim]
- lse: log-sum-exp values [total_q, num_qo_heads]
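The two outputs are linked: lse is the log-sum-exp of the scaled attention logits for a query row and head, and the softmax weights are exp(logit - lse). The sketch below shows this for a single query row and head; `attend_one_head` is a hypothetical helper, and the natural logarithm is an assumption (an implementation may report lse in another base, such as base 2).

```python
import math
from typing import List, Tuple


def attend_one_head(q: List[float], keys: List[List[float]], values: List[List[float]],
                    sm_scale: float) -> Tuple[List[float], float]:
    """Attention for one query row and one head; returns (output, lse).

    Hypothetical helper; lse is the natural-log log-sum-exp (assumption).
    Requires at least one key.
    """
    logits = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)                                    # subtract the max for stability
    lse = m + math.log(sum(math.exp(s - m) for s in logits))
    weights = [math.exp(s - lse) for s in logits]      # softmax weights, sum to 1
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return out, lse
```

Storing lse lets partial results computed over disjoint KV chunks be merged later without recomputing the logits.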
Constraints:
- total_q == qo_indptr[-1]
- num_kv_indices == kv_indptr[-1]
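Combining the prefill inputs, outputs, and constraints, a naive reference could look like the following. It is a sketch for checking small cases, not any library's kernel: the name `paged_prefill_reference`, the `causal` flag, and the grouped-query head mapping (qo head h reads kv head h // (num_qo_heads // num_kv_heads)) are assumptions, as are completely filled pages and a natural-log lse. It asserts the two listed constraints, plus the assumption that qo_indptr and kv_indptr share the len_indptr axis.

```python
import math


def paged_prefill_reference(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices,
                            sm_scale, causal=False):
    """Naive paged prefill attention on nested lists; returns (output, lse).

    Hypothetical helper. Assumptions not stated in the spec: qo head h reads
    kv head h // (num_qo_heads // num_kv_heads); every listed page is full;
    the optional causal mask aligns the last query with the last KV token
    (so kv_len >= qo_len is required); lse uses the natural log.
    """
    total_q, num_qo_heads = len(q), len(q[0])
    num_kv_heads = len(k_cache[0][0])            # page -> slot -> kv heads
    assert total_q == qo_indptr[-1]              # constraint from the spec
    assert len(kv_indices) == kv_indptr[-1]      # num_kv_indices == kv_indptr[-1]
    assert len(qo_indptr) == len(kv_indptr)      # assumed: shared len_indptr axis
    assert num_qo_heads % num_kv_heads == 0
    group = num_qo_heads // num_kv_heads
    output = [[None] * num_qo_heads for _ in range(total_q)]
    lse = [[None] * num_qo_heads for _ in range(total_q)]
    for r in range(len(qo_indptr) - 1):          # one request per indptr segment
        pages = kv_indices[kv_indptr[r]:kv_indptr[r + 1]]
        keys = [tok for p in pages for tok in k_cache[p]]   # [kv_len][kv_head][dim]
        vals = [tok for p in pages for tok in v_cache[p]]
        q_start, qo_len = qo_indptr[r], qo_indptr[r + 1] - qo_indptr[r]
        kv_len = len(keys)
        for i in range(qo_len):
            # Causal: query i sees KV positions 0 .. i + kv_len - qo_len.
            visible = i + kv_len - qo_len + 1 if causal else kv_len
            for h in range(num_qo_heads):
                kh, qv = h // group, q[q_start + i][h]
                logits = [sm_scale * sum(a * b for a, b in zip(qv, keys[j][kh]))
                          for j in range(visible)]
                m = max(logits)                  # subtract the max for stability
                row_lse = m + math.log(sum(math.exp(s - m) for s in logits))
                w = [math.exp(s - row_lse) for s in logits]
                output[q_start + i][h] = [sum(w[j] * vals[j][kh][d] for j in range(visible))
                                          for d in range(len(qv))]
                lse[q_start + i][h] = row_lse
    return output, lse
```

The loops are quadratic in sequence length and intended only as a correctness oracle on tiny inputs.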
decode
Axes (8 dimensions):
- total_q, num_pages, len_indptr, num_kv_indices: variable
- num_qo_heads, num_kv_heads, head_dim, page_size: constant
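Inputs:
- q: query tensor [total_q, num_qo_heads, head_dim]
- k_cache, v_cache: paged KV cache [num_pages, page_size, num_kv_heads, head_dim]
- kv_indptr, kv_indices: paging indices. kv_indptr (len_indptr entries) delimits each request's slice of kv_indices; kv_indices lists page ids into k_cache and v_cache. Unlike prefill, there is no qo_indptr.
- sm_scale: softmax scale (scalar)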
Outputs:
- output: attention output [total_q, num_qo_heads, head_dim]
- lse: log-sum-exp values [total_q, num_qo_heads]
Constraints:
- len_indptr == total_q + 1
- num_kv_indices == kv_indptr[-1]
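Decode takes no qo_indptr. A natural reading, assumed in this sketch, is that query row r is the single new token of request r, so kv_indptr holds total_q + 1 offsets. The name `paged_decode_reference` is hypothetical; it further assumes that qo head h reads kv head h // (num_qo_heads // num_kv_heads), that every listed page is completely filled, and that lse uses the natural log.

```python
import math


def paged_decode_reference(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
    """Naive paged decode attention on nested lists; returns (output, lse).

    Hypothetical helper. Assumptions: query row r is the only new token of
    request r; qo head h reads kv head h // (num_qo_heads // num_kv_heads);
    every listed page is full; lse uses the natural log.
    """
    total_q, num_qo_heads = len(q), len(q[0])
    num_kv_heads = len(k_cache[0][0])                 # page -> slot -> kv heads
    assert num_qo_heads % num_kv_heads == 0
    group = num_qo_heads // num_kv_heads
    assert len(kv_indptr) == total_q + 1              # assumed: one KV segment per row
    assert len(kv_indices) == kv_indptr[-1]           # num_kv_indices == kv_indptr[-1]
    output, lse = [], []
    for r in range(total_q):
        pages = kv_indices[kv_indptr[r]:kv_indptr[r + 1]]
        keys = [tok for p in pages for tok in k_cache[p]]   # [kv_len][kv_head][dim]
        vals = [tok for p in pages for tok in v_cache[p]]
        out_row, lse_row = [], []
        for h in range(num_qo_heads):
            kh = h // group
            logits = [sm_scale * sum(a * b for a, b in zip(q[r][h], k[kh])) for k in keys]
            m = max(logits)                           # requires at least one KV token
            row_lse = m + math.log(sum(math.exp(s - m) for s in logits))
            w = [math.exp(s - row_lse) for s in logits]
            out_row.append([sum(wj * v[kh][d] for wj, v in zip(w, vals))
                            for d in range(len(q[r][h]))])
            lse_row.append(row_lse)
        output.append(out_row)
        lse.append(lse_row)
    return output, lse
```

Because each request contributes exactly one query row, no causal mask is applied: the new token attends to every cached token of its request, assuming the cache already holds that token's keys and values.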

