-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmttkrp_compute.c
75 lines (68 loc) · 2.91 KB
/
mttkrp_compute.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// Generated by the Tensor Algebra Compiler (tensor-compiler.org)
// taco "A(i,j)=B(i,k,l)*D(l,j)*C(k,j)" -f=A:dd:0,1 -f=B:sss:0,1,2 -f=D:dd:0,1 -f=C:dd:0,1 -s="reorder(i,k,l,j)" -s="precompute(B(i,k,l)*D(l,j),j,j)" -s="split(i,i0,i1,32)" -s="parallelize(i0,CPUThread,NoRaces)" -write-source=taco_kernel.c -write-compute=taco_compute.c -write-assembly=taco_assembly.c
/*
 * A(i,j) = sum over k,l of B(i,k,l) * D(l,j) * C(k,j).
 *
 * A, D and C are accessed as dense row-major matrices (see the
 * `row * ncols + j` indexing below). B is read through a pos/crd pair per
 * mode (format "sss" per the taco command line above).
 *
 * For each stored (i,k) fiber of B:
 *   - the sum over l of B(i,k,l) * D(l,:) is accumulated into a dense row
 *     buffer `workspace` of length C2_dimension;
 *   - that row is multiplied element-wise by row k of C and added into
 *     row i of A.
 *
 * Rows are grouped into blocks of 32 (i0), and the blocks are distributed
 * over an OpenMP parallel loop. Each block only writes rows
 * i0*32 .. i0*32+31 of A.
 *
 * NOTE(review): the j loops run to C2_dimension for A, C and D alike, so this
 * assumes A2_dimension == C2_dimension == D2_dimension. TODO: confirm at the
 * call site.
 *
 * Returns 0 on success, or -1 if a workspace buffer could not be allocated.
 * In that case some rows of A were not accumulated.
 */
int compute(taco_tensor_t *A, taco_tensor_t *B, taco_tensor_t *D, taco_tensor_t *C) {
  int A1_dimension = (int)(A->dimensions[0]);
  int A2_dimension = (int)(A->dimensions[1]);
  double* restrict A_vals = (double*)(A->vals);
  int B1_dimension = (int)(B->dimensions[0]);
  int* restrict B1_pos = (int*)(B->indices[0][0]);
  int* restrict B1_crd = (int*)(B->indices[0][1]);
  int* restrict B2_pos = (int*)(B->indices[1][0]);
  int* restrict B2_crd = (int*)(B->indices[1][1]);
  int* restrict B3_pos = (int*)(B->indices[2][0]);
  int* restrict B3_crd = (int*)(B->indices[2][1]);
  double* restrict B_vals = (double*)(B->vals);
  int D2_dimension = (int)(D->dimensions[1]);
  double* restrict D_vals = (double*)(D->vals);
  int C2_dimension = (int)(C->dimensions[1]);
  double* restrict C_vals = (double*)(C->vals);

  /* Zero the output. The element count is computed in 64 bits so the
     product of the two dimensions cannot overflow int. */
  int64_t A_size = (int64_t)A1_dimension * A2_dimension;
  #pragma omp parallel for schedule(static)
  for (int64_t pA = 0; pA < A_size; pA++) {
    A_vals[pA] = 0.0;
  }

  int32_t pB1_start = B1_pos[0];
  int32_t pB1_end = B1_pos[1];
  /* If mode 1 of B stores no coordinates there is nothing to accumulate.
     Run zero blocks so the search below is never called on an empty range. */
  int32_t num_i0 = (pB1_start < pB1_end) ? ((B1_dimension + 31) / 32) : 0;
  /* malloc(0) may return NULL, so always request at least one element. */
  size_t workspace_len = C2_dimension > 0 ? (size_t)C2_dimension : 1;
  int alloc_failed = 0;

  #pragma omp parallel for schedule(runtime)
  for (int32_t i0 = 0; i0 < num_i0; i0++) {
    /* One scratch row per 32-row block, reused by every row in the block. */
    double* restrict workspace = malloc(sizeof(double) * workspace_len);
    if (workspace == NULL) {
      /* Cannot leave an OpenMP loop early: record the failure and skip. */
      #pragma omp atomic write
      alloc_failed = 1;
      continue;
    }

    int32_t i_block_begin = i0 * 32;
    /* Presumably the first position in B1_crd whose row coordinate is
       >= i_block_begin (may equal pB1_end). TODO confirm against
       taco_binarySearchAfter. */
    int32_t iB = taco_binarySearchAfter(B1_crd, pB1_start, pB1_end, i_block_begin);

    /* Check the bound before indexing so B1_crd is never read past its end. */
    while (iB < pB1_end) {
      int32_t i = B1_crd[iB];
      if (i - i_block_begin >= 32) {
        break; /* the stored row belongs to the next block */
      }

      for (int32_t kB = B2_pos[iB]; kB < B2_pos[iB + 1]; kB++) {
        int32_t k = B2_crd[kB];

        for (int32_t j = 0; j < C2_dimension; j++) {
          workspace[j] = 0.0;
        }
        /* workspace(:) = sum over l of B(i,k,l) * D(l,:) */
        for (int32_t lB = B3_pos[kB]; lB < B3_pos[kB + 1]; lB++) {
          int32_t l = B3_crd[lB];
          double b_val = B_vals[lB];
          for (int32_t j = 0; j < C2_dimension; j++) {
            workspace[j] += b_val * D_vals[l * D2_dimension + j];
          }
        }
        /* A(i,:) += workspace(:) * C(k,:) */
        for (int32_t j = 0; j < C2_dimension; j++) {
          A_vals[i * A2_dimension + j] += workspace[j] * C_vals[k * C2_dimension + j];
        }
      }
      iB++;
    }

    free(workspace);
  }

  A->vals = (uint8_t*)A_vals;
  return alloc_failed ? -1 : 0;
}