Actual source code: bvblas.c
slepc-3.19.2 2023-09-05
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: BV private kernels that use the BLAS
12: */
14: #include <slepc/private/bvimpl.h>
15: #include <slepcblaslapack.h>
17: #define BLOCKSIZE 64
19: /*
20: C := alpha*A*B + beta*C
22: A is mxk (ld=m), B is kxn (ld=ldb), C is mxn (ld=m)
23: */
24: PetscErrorCode BVMult_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscInt ldb_,PetscScalar alpha,const PetscScalar *A,const PetscScalar *B,PetscScalar beta,PetscScalar *C)
25: {
26: PetscBLASInt m,n,k,ldb;
27: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
28: PetscBLASInt l,bs=BLOCKSIZE;
29: #endif
31: PetscFunctionBegin;
32: PetscCall(PetscBLASIntCast(m_,&m));
33: PetscCall(PetscBLASIntCast(n_,&n));
34: PetscCall(PetscBLASIntCast(k_,&k));
35: PetscCall(PetscBLASIntCast(ldb_,&ldb));
36: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
37: l = m % bs;
38: if (l) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&l,&n,&k,&alpha,(PetscScalar*)A,&m,(PetscScalar*)B,&ldb,&beta,C,&m));
39: for (;l<m;l+=bs) {
40: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&bs,&n,&k,&alpha,(PetscScalar*)A+l,&m,(PetscScalar*)B,&ldb,&beta,C+l,&m));
41: }
42: #else
43: if (m) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&m,&n,&k,&alpha,(PetscScalar*)A,&m,(PetscScalar*)B,&ldb,&beta,C,&m));
44: #endif
45: PetscCall(PetscLogFlops(2.0*m*n*k));
46: PetscFunctionReturn(PETSC_SUCCESS);
47: }
49: /*
50: y := alpha*A*x + beta*y
52: A is nxk (ld=n)
53: */
54: PetscErrorCode BVMultVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,const PetscScalar *x,PetscScalar beta,PetscScalar *y)
55: {
56: PetscBLASInt n,k,one=1;
58: PetscFunctionBegin;
59: PetscCall(PetscBLASIntCast(n_,&n));
60: PetscCall(PetscBLASIntCast(k_,&k));
61: if (n) PetscCallBLAS("BLASgemv",BLASgemv_("N",&n,&k,&alpha,A,&n,x,&one,&beta,y,&one));
62: PetscCall(PetscLogFlops(2.0*n*k));
63: PetscFunctionReturn(PETSC_SUCCESS);
64: }
66: /*
67: A(:,s:e-1) := A*B(:,s:e-1)
69: A is mxk (ld=m), B is kxn (ld=ldb) n=e-s
70: */
71: PetscErrorCode BVMultInPlace_BLAS_Private(BV bv,PetscInt m_,PetscInt k_,PetscInt ldb_,PetscInt s,PetscInt e,PetscScalar *A,const PetscScalar *B,PetscBool btrans)
72: {
73: PetscScalar *pb,zero=0.0,one=1.0;
74: PetscBLASInt m,n,k,l,ldb,bs=BLOCKSIZE;
75: PetscInt j,n_=e-s;
76: const char *bt;
78: PetscFunctionBegin;
79: PetscCall(PetscBLASIntCast(m_,&m));
80: PetscCall(PetscBLASIntCast(n_,&n));
81: PetscCall(PetscBLASIntCast(k_,&k));
82: PetscCall(PetscBLASIntCast(ldb_,&ldb));
83: PetscCall(BVAllocateWork_Private(bv,BLOCKSIZE*n_));
84: if (PetscUnlikely(btrans)) {
85: pb = (PetscScalar*)B+s;
86: bt = "C";
87: } else {
88: pb = (PetscScalar*)B+s*ldb;
89: bt = "N";
90: }
91: l = m % bs;
92: if (l) {
93: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&k,&one,A,&m,pb,&ldb,&zero,bv->work,&l));
94: for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*m,bv->work+j*l,l));
95: }
96: for (;l<m;l+=bs) {
97: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&k,&one,A+l,&m,pb,&ldb,&zero,bv->work,&bs));
98: for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*m+l,bv->work+j*bs,bs));
99: }
100: PetscCall(PetscLogFlops(2.0*m*n*k));
101: PetscFunctionReturn(PETSC_SUCCESS);
102: }
104: /*
105: V := V*B
107: V is mxn (ld=m), B is nxn (ld=k)
108: */
109: PetscErrorCode BVMultInPlace_Vecs_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,Vec *V,const PetscScalar *B,PetscBool btrans)
110: {
111: PetscScalar zero=0.0,one=1.0,*out,*pout;
112: const PetscScalar *pin;
113: PetscBLASInt m = 0,n,k,l,bs=BLOCKSIZE;
114: PetscInt j;
115: const char *bt;
117: PetscFunctionBegin;
118: PetscCall(PetscBLASIntCast(m_,&m));
119: PetscCall(PetscBLASIntCast(n_,&n));
120: PetscCall(PetscBLASIntCast(k_,&k));
121: PetscCall(BVAllocateWork_Private(bv,2*BLOCKSIZE*n_));
122: out = bv->work+BLOCKSIZE*n_;
123: if (btrans) bt = "C";
124: else bt = "N";
125: l = m % bs;
126: if (l) {
127: for (j=0;j<n;j++) {
128: PetscCall(VecGetArrayRead(V[j],&pin));
129: PetscCall(PetscArraycpy(bv->work+j*l,pin,l));
130: PetscCall(VecRestoreArrayRead(V[j],&pin));
131: }
132: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&n,&one,bv->work,&l,(PetscScalar*)B,&k,&zero,out,&l));
133: for (j=0;j<n;j++) {
134: PetscCall(VecGetArray(V[j],&pout));
135: PetscCall(PetscArraycpy(pout,out+j*l,l));
136: PetscCall(VecRestoreArray(V[j],&pout));
137: }
138: }
139: for (;l<m;l+=bs) {
140: for (j=0;j<n;j++) {
141: PetscCall(VecGetArrayRead(V[j],&pin));
142: PetscCall(PetscArraycpy(bv->work+j*bs,pin+l,bs));
143: PetscCall(VecRestoreArrayRead(V[j],&pin));
144: }
145: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&n,&one,bv->work,&bs,(PetscScalar*)B,&k,&zero,out,&bs));
146: for (j=0;j<n;j++) {
147: PetscCall(VecGetArray(V[j],&pout));
148: PetscCall(PetscArraycpy(pout+l,out+j*bs,bs));
149: PetscCall(VecRestoreArray(V[j],&pout));
150: }
151: }
152: PetscCall(PetscLogFlops(2.0*n*n*k));
153: PetscFunctionReturn(PETSC_SUCCESS);
154: }
156: /*
157: B := alpha*A + beta*B
159: A,B are nxk (ld=n)
160: */
161: PetscErrorCode BVAXPY_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscScalar beta,PetscScalar *B)
162: {
163: PetscBLASInt m,one=1;
165: PetscFunctionBegin;
166: PetscCall(PetscBLASIntCast(n_*k_,&m));
167: if (beta!=(PetscScalar)1.0) {
168: PetscCallBLAS("BLASscal",BLASscal_(&m,&beta,B,&one));
169: PetscCall(PetscLogFlops(m));
170: }
171: PetscCallBLAS("BLASaxpy",BLASaxpy_(&m,&alpha,A,&one,B,&one));
172: PetscCall(PetscLogFlops(2.0*m));
173: PetscFunctionReturn(PETSC_SUCCESS);
174: }
176: /*
177: C := A'*B
179: A' is mxk (ld=k), B is kxn (ld=k), C is mxn (ld=ldc)
180: */
181: PetscErrorCode BVDot_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscInt ldc_,const PetscScalar *A,const PetscScalar *B,PetscScalar *C,PetscBool mpi)
182: {
183: PetscScalar zero=0.0,one=1.0,*CC;
184: PetscBLASInt m,n,k,ldc,j;
185: PetscMPIInt len;
187: PetscFunctionBegin;
188: PetscCall(PetscBLASIntCast(m_,&m));
189: PetscCall(PetscBLASIntCast(n_,&n));
190: PetscCall(PetscBLASIntCast(k_,&k));
191: PetscCall(PetscBLASIntCast(ldc_,&ldc));
192: if (mpi) {
193: if (ldc==m) {
194: PetscCall(BVAllocateWork_Private(bv,m*n));
195: if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&k,(PetscScalar*)B,&k,&zero,bv->work,&ldc));
196: else PetscCall(PetscArrayzero(bv->work,m*n));
197: PetscCall(PetscMPIIntCast(m*n,&len));
198: PetscCall(MPIU_Allreduce(bv->work,C,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
199: } else {
200: PetscCall(BVAllocateWork_Private(bv,2*m*n));
201: CC = bv->work+m*n;
202: if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&k,(PetscScalar*)B,&k,&zero,bv->work,&m));
203: else PetscCall(PetscArrayzero(bv->work,m*n));
204: PetscCall(PetscMPIIntCast(m*n,&len));
205: PetscCall(MPIU_Allreduce(bv->work,CC,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
206: for (j=0;j<n;j++) PetscCall(PetscArraycpy(C+j*ldc,CC+j*m,m));
207: }
208: } else {
209: if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&k,(PetscScalar*)B,&k,&zero,C,&ldc));
210: }
211: PetscCall(PetscLogFlops(2.0*m*n*k));
212: PetscFunctionReturn(PETSC_SUCCESS);
213: }
215: /*
216: y := A'*x
218: A is nxk (ld=n)
219: */
220: PetscErrorCode BVDotVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,const PetscScalar *A,const PetscScalar *x,PetscScalar *y,PetscBool mpi)
221: {
222: PetscScalar zero=0.0,done=1.0;
223: PetscBLASInt n,k,one=1;
224: PetscMPIInt len;
226: PetscFunctionBegin;
227: PetscCall(PetscBLASIntCast(n_,&n));
228: PetscCall(PetscBLASIntCast(k_,&k));
229: if (mpi) {
230: PetscCall(BVAllocateWork_Private(bv,k));
231: if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&n,x,&one,&zero,bv->work,&one));
232: else PetscCall(PetscArrayzero(bv->work,k));
233: PetscCall(PetscMPIIntCast(k,&len));
234: PetscCall(MPIU_Allreduce(bv->work,y,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
235: } else {
236: if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&n,x,&one,&zero,y,&one));
237: }
238: PetscCall(PetscLogFlops(2.0*n*k));
239: PetscFunctionReturn(PETSC_SUCCESS);
240: }
242: /*
243: Scale n scalars
244: */
245: PetscErrorCode BVScale_BLAS_Private(BV bv,PetscInt n_,PetscScalar *A,PetscScalar alpha)
246: {
247: PetscBLASInt n,one=1;
249: PetscFunctionBegin;
250: if (PetscUnlikely(alpha == (PetscScalar)0.0)) PetscCall(PetscArrayzero(A,n_));
251: else if (alpha!=(PetscScalar)1.0) {
252: PetscCall(PetscBLASIntCast(n_,&n));
253: PetscCallBLAS("BLASscal",BLASscal_(&n,&alpha,A,&one));
254: PetscCall(PetscLogFlops(n));
255: }
256: PetscFunctionReturn(PETSC_SUCCESS);
257: }