// INPUT
#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/experimental/matrix/matrix.hpp>
using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
template
void xmx_kern(
const scalar_t* Q_ptr,
const scalar_t* K_ptr,
const scalar_t* V_ptr,
scalar_t* Out_ptr,
int num_q,
int num_k,
int d_k,
int d_v,
int q_stride,
int k_stride,
int v_stride,
int out_stride,
nd_item<1> item
) {
sub_group sg = item.get_sub_group();
const int head_row_base = (item.get_group(0) * 16);
if (head_row_base >= num_q) return;
using t_Q = joint_matrix<sub_group, sycl::half, use::a, 16, 16, layout::row_major>;
using t_K = joint_matrix<sub_group, sycl::half, use::b, 16, 16, layout::col_major>;
using t_V = joint_matrix<sub_group, sycl::half, use::b, 16, 16, layout::row_major>;
using t_S = joint_matrix<sub_group, float, use::accumulator, 16, 16> mat_s;
using t_O = joint_matrix<sub_group, float, use::accumulator, 16, 16> mat_o;
t_Q mat_q;
t_K mat_k;
t_V mat_v;
t_S mat_s;
t_O mat_o;
joint_matrix_fill(sg, mat_s, 0.0f);
const scalar_t* q_tile_ptr = Q_ptr + head_row_base * q_stride;
joint_matrix_load(sg, mat_q, q_tile_ptr, q_stride);
const float scale_factor = 1.0f / sycl::sqrt(static_cast<float>(d_k));
//VERARBEITUNG
for (int k_idx = 0; k_idx < num_k; k_idx += 16) {
joint_matrix_fill(sg, mat_s, 0.0f);
//1.
const scalar_t* k_tile_ptr = K_ptr + k_idx * k_stride;
joint_matrix_load(sg, mat_k, k_tile_ptr, k_stride);
joint_matrix_mad(sg, mat_s, mat_q, mat_k, mat_s);
//2.
auto wi_data = get_wi_data(sg, mat_s);
//a.
float local_max = -INFINITY;
for (int i = 0; i < wi_data.length(); ++i) {
wi_data[i] *= scale_factor;
local_max = sycl::fmax(local_max, wi_data[i]);
}
float row_max_total = reduce_over_group(sg, local_max, maximum<float>());
//b.
float local_sum = 0.0f;
for (int i = 0; i < wi_data.length(); ++i) {
wi_data[i] = sycl::exp(wi_data[i] - row_max_total);
local_sum += wi_data[i];
}
//c.
float row_sum_total = reduce_over_group(sg, local_sum, plus<float>());
float inv_sum = 1.0f / (row_sum_total + 1e-6f);
//d.
for (int i = 0; i < wi_data.length(); ++i) {
wi_data[i] *= inv_sum;
}
//3.
const scalar_t* v_tile_ptr = V_ptr + k_idx * v_stride;
joint_matrix_load(sg, mat_v, v_tile_ptr, v_stride);
joint_matrix_mad(sg, mat_o, mat_s, mat_v, mat_o);
//AUSGABE
scalar_t* out_ptr = Out_ptr + head_row_base * out_stride;
joint_matrix_store(sg, mat_o, out_ptr, out_stride, layout::row_major);
}
}
//EVA=0=N,1=Y?