-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathspadd_compute.c
65 lines (61 loc) · 2.14 KB
/
spadd_compute.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Generated by the Tensor Algebra Compiler (tensor-compiler.org)
// taco "A(i,j)=B(i,j)+C(i,j)" -f=A:ds:0,1 -f=B:ds:0,1 -f=C:ds:0,1 -s="assemble(A,Insert)" -s="parallelize(i,CPUThread,NoRaces)" -write-source=taco_kernel.c -write-compute=taco_compute.c -write-assembly=taco_assembly.c
/*
 * compute - numeric phase of the sparse sum A(i,j) = B(i,j) + C(i,j), where
 * A, B and C are all stored with a dense first dimension and a compressed
 * second dimension (pos/crd arrays, CSR-like).
 *
 * Originally emitted by TACO (see the command line above); hand-edited in
 * review to validate row counts and to use one row count for both passes.
 *
 * For each row i the column coordinates of B and C are merged: where both
 * have an entry at the same column the values are summed, otherwise the
 * single present value is copied. Each output value is written to
 * A_vals[A2_pos[i]] and the cursor A2_pos[i] is then advanced.
 *
 * Preconditions (not checked here; presumably established by the companion
 * assembly kernel -- confirm against the caller):
 *   - on entry A2_pos[i] holds the first output position of row i;
 *   - B2_crd and C2_crd are sorted ascending within each row (the merge
 *     relies on this);
 *   - A's coordinate array is NOT written here.
 *
 * On exit A2_pos is back in position-array form: A2_pos[0] == 0 and
 * A2_pos[k] (k >= 1) is the end of row k-1.
 *
 * Returns 0 on success, or -1 (without touching A) if A, B and C do not all
 * have the same number of rows.
 */
int compute(taco_tensor_t *A, taco_tensor_t *B, taco_tensor_t *C) {
  int A1_dimension = (int)(A->dimensions[0]);
  int B1_dimension = (int)(B->dimensions[0]);
  int C1_dimension = (int)(C->dimensions[0]);

  // The merge pass reads B2_pos/C2_pos[0..rows] and the fix-up pass rewrites
  // A2_pos[0..rows]; if the row counts disagree, A2_pos would be corrupted or
  // the B/C position arrays read out of bounds.
  if (B1_dimension != A1_dimension || C1_dimension != A1_dimension) {
    return -1;
  }
  int rows = A1_dimension;

  int* restrict A2_pos = (int*)(A->indices[1][0]);
  double* restrict A_vals = (double*)(A->vals);
  int* restrict B2_pos = (int*)(B->indices[1][0]);
  int* restrict B2_crd = (int*)(B->indices[1][1]);
  double* restrict B_vals = (double*)(B->vals);
  int* restrict C2_pos = (int*)(C->indices[1][0]);
  int* restrict C2_crd = (int*)(C->indices[1][1]);
  double* restrict C_vals = (double*)(C->vals);

  // Rows are processed in parallel: iteration i only advances A2_pos[i] and
  // writes the A_vals slots that this cursor visits.
  #pragma omp parallel for schedule(runtime)
  for (int32_t i = 0; i < rows; i++) {
    int32_t jB = B2_pos[i];
    int32_t pB2_end = B2_pos[i + 1];
    int32_t jC = C2_pos[i];
    int32_t pC2_end = C2_pos[i + 1];

    // While both rows still have entries, emit the smaller column first and
    // advance whichever side(s) held it.
    while (jB < pB2_end && jC < pC2_end) {
      int32_t jB0 = B2_crd[jB];
      int32_t jC0 = C2_crd[jC];
      int32_t j = TACO_MIN(jB0, jC0);
      double value;
      if (jB0 == j && jC0 == j) {
        value = B_vals[jB] + C_vals[jC];
      } else if (jB0 == j) {
        value = B_vals[jB];
      } else {
        value = C_vals[jC];
      }
      A_vals[A2_pos[i]] = value;
      A2_pos[i]++;
      jB += (int32_t)(jB0 == j);
      jC += (int32_t)(jC0 == j);
    }

    // Copy the remaining tail; at most one of these loops does any work.
    for (; jB < pB2_end; jB++) {
      A_vals[A2_pos[i]] = B_vals[jB];
      A2_pos[i]++;
    }
    for (; jC < pC2_end; jC++) {
      A_vals[A2_pos[i]] = C_vals[jC];
      A2_pos[i]++;
    }
  }

  // Each A2_pos[i] now holds the END of row i. Shift right by one, walking
  // from the back so no entry is overwritten before it is read, then reset
  // A2_pos[0] so the array again holds row start positions.
  for (int32_t k = rows; k > 0; k--) {
    A2_pos[k] = A2_pos[k - 1];
  }
  A2_pos[0] = 0;
  return 0;
}