Actual source code: bvblas.c

slepc-3.19.2 2023-09-05
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    BV private kernels that use the BLAS
 12: */

 14: #include <slepc/private/bvimpl.h>
 15: #include <slepcblaslapack.h>

  17: #define BLOCKSIZE 64   /* max number of rows handed to one GEMM call by the row-blocked kernels below; also used to size bv->work */

 19: /*
 20:     C := alpha*A*B + beta*C

 22:     A is mxk (ld=m), B is kxn (ld=ldb), C is mxn (ld=m)
 23: */
 24: PetscErrorCode BVMult_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscInt ldb_,PetscScalar alpha,const PetscScalar *A,const PetscScalar *B,PetscScalar beta,PetscScalar *C)
 25: {
 26:   PetscBLASInt   m,n,k,ldb;
 27: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
 28:   PetscBLASInt   l,bs=BLOCKSIZE;
 29: #endif

 31:   PetscFunctionBegin;
 32:   PetscCall(PetscBLASIntCast(m_,&m));
 33:   PetscCall(PetscBLASIntCast(n_,&n));
 34:   PetscCall(PetscBLASIntCast(k_,&k));
 35:   PetscCall(PetscBLASIntCast(ldb_,&ldb));
 36: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
 37:   l = m % bs;
 38:   if (l) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&l,&n,&k,&alpha,(PetscScalar*)A,&m,(PetscScalar*)B,&ldb,&beta,C,&m));
 39:   for (;l<m;l+=bs) {
 40:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&bs,&n,&k,&alpha,(PetscScalar*)A+l,&m,(PetscScalar*)B,&ldb,&beta,C+l,&m));
 41:   }
 42: #else
 43:   if (m) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&m,&n,&k,&alpha,(PetscScalar*)A,&m,(PetscScalar*)B,&ldb,&beta,C,&m));
 44: #endif
 45:   PetscCall(PetscLogFlops(2.0*m*n*k));
 46:   PetscFunctionReturn(PETSC_SUCCESS);
 47: }

 49: /*
 50:     y := alpha*A*x + beta*y

 52:     A is nxk (ld=n)
 53: */
 54: PetscErrorCode BVMultVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,const PetscScalar *x,PetscScalar beta,PetscScalar *y)
 55: {
 56:   PetscBLASInt   n,k,one=1;

 58:   PetscFunctionBegin;
 59:   PetscCall(PetscBLASIntCast(n_,&n));
 60:   PetscCall(PetscBLASIntCast(k_,&k));
 61:   if (n) PetscCallBLAS("BLASgemv",BLASgemv_("N",&n,&k,&alpha,A,&n,x,&one,&beta,y,&one));
 62:   PetscCall(PetscLogFlops(2.0*n*k));
 63:   PetscFunctionReturn(PETSC_SUCCESS);
 64: }

 66: /*
 67:     A(:,s:e-1) := A*B(:,s:e-1)

 69:     A is mxk (ld=m), B is kxn (ld=ldb)  n=e-s
 70: */
 71: PetscErrorCode BVMultInPlace_BLAS_Private(BV bv,PetscInt m_,PetscInt k_,PetscInt ldb_,PetscInt s,PetscInt e,PetscScalar *A,const PetscScalar *B,PetscBool btrans)
 72: {
 73:   PetscScalar    *pb,zero=0.0,one=1.0;
 74:   PetscBLASInt   m,n,k,l,ldb,bs=BLOCKSIZE;
 75:   PetscInt       j,n_=e-s;
 76:   const char     *bt;

 78:   PetscFunctionBegin;
 79:   PetscCall(PetscBLASIntCast(m_,&m));
 80:   PetscCall(PetscBLASIntCast(n_,&n));
 81:   PetscCall(PetscBLASIntCast(k_,&k));
 82:   PetscCall(PetscBLASIntCast(ldb_,&ldb));
 83:   PetscCall(BVAllocateWork_Private(bv,BLOCKSIZE*n_));
 84:   if (PetscUnlikely(btrans)) {
 85:     pb = (PetscScalar*)B+s;
 86:     bt = "C";
 87:   } else {
 88:     pb = (PetscScalar*)B+s*ldb;
 89:     bt = "N";
 90:   }
 91:   l = m % bs;
 92:   if (l) {
 93:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&k,&one,A,&m,pb,&ldb,&zero,bv->work,&l));
 94:     for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*m,bv->work+j*l,l));
 95:   }
 96:   for (;l<m;l+=bs) {
 97:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&k,&one,A+l,&m,pb,&ldb,&zero,bv->work,&bs));
 98:     for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*m+l,bv->work+j*bs,bs));
 99:   }
100:   PetscCall(PetscLogFlops(2.0*m*n*k));
101:   PetscFunctionReturn(PETSC_SUCCESS);
102: }

104: /*
105:     V := V*B

107:     V is mxn (ld=m), B is nxn (ld=k)
108: */
109: PetscErrorCode BVMultInPlace_Vecs_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,Vec *V,const PetscScalar *B,PetscBool btrans)
110: {
111:   PetscScalar       zero=0.0,one=1.0,*out,*pout;
112:   const PetscScalar *pin;
113:   PetscBLASInt      m = 0,n,k,l,bs=BLOCKSIZE;
114:   PetscInt          j;
115:   const char        *bt;

117:   PetscFunctionBegin;
118:   PetscCall(PetscBLASIntCast(m_,&m));
119:   PetscCall(PetscBLASIntCast(n_,&n));
120:   PetscCall(PetscBLASIntCast(k_,&k));
121:   PetscCall(BVAllocateWork_Private(bv,2*BLOCKSIZE*n_));
122:   out = bv->work+BLOCKSIZE*n_;
123:   if (btrans) bt = "C";
124:   else bt = "N";
125:   l = m % bs;
126:   if (l) {
127:     for (j=0;j<n;j++) {
128:       PetscCall(VecGetArrayRead(V[j],&pin));
129:       PetscCall(PetscArraycpy(bv->work+j*l,pin,l));
130:       PetscCall(VecRestoreArrayRead(V[j],&pin));
131:     }
132:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&n,&one,bv->work,&l,(PetscScalar*)B,&k,&zero,out,&l));
133:     for (j=0;j<n;j++) {
134:       PetscCall(VecGetArray(V[j],&pout));
135:       PetscCall(PetscArraycpy(pout,out+j*l,l));
136:       PetscCall(VecRestoreArray(V[j],&pout));
137:     }
138:   }
139:   for (;l<m;l+=bs) {
140:     for (j=0;j<n;j++) {
141:       PetscCall(VecGetArrayRead(V[j],&pin));
142:       PetscCall(PetscArraycpy(bv->work+j*bs,pin+l,bs));
143:       PetscCall(VecRestoreArrayRead(V[j],&pin));
144:     }
145:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&n,&one,bv->work,&bs,(PetscScalar*)B,&k,&zero,out,&bs));
146:     for (j=0;j<n;j++) {
147:       PetscCall(VecGetArray(V[j],&pout));
148:       PetscCall(PetscArraycpy(pout+l,out+j*bs,bs));
149:       PetscCall(VecRestoreArray(V[j],&pout));
150:     }
151:   }
152:   PetscCall(PetscLogFlops(2.0*n*n*k));
153:   PetscFunctionReturn(PETSC_SUCCESS);
154: }

156: /*
157:     B := alpha*A + beta*B

159:     A,B are nxk (ld=n)
160: */
161: PetscErrorCode BVAXPY_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscScalar beta,PetscScalar *B)
162: {
163:   PetscBLASInt   m,one=1;

165:   PetscFunctionBegin;
166:   PetscCall(PetscBLASIntCast(n_*k_,&m));
167:   if (beta!=(PetscScalar)1.0) {
168:     PetscCallBLAS("BLASscal",BLASscal_(&m,&beta,B,&one));
169:     PetscCall(PetscLogFlops(m));
170:   }
171:   PetscCallBLAS("BLASaxpy",BLASaxpy_(&m,&alpha,A,&one,B,&one));
172:   PetscCall(PetscLogFlops(2.0*m));
173:   PetscFunctionReturn(PETSC_SUCCESS);
174: }

176: /*
177:     C := A'*B

179:     A' is mxk (ld=k), B is kxn (ld=k), C is mxn (ld=ldc)
180: */
181: PetscErrorCode BVDot_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscInt ldc_,const PetscScalar *A,const PetscScalar *B,PetscScalar *C,PetscBool mpi)
182: {
183:   PetscScalar    zero=0.0,one=1.0,*CC;
184:   PetscBLASInt   m,n,k,ldc,j;
185:   PetscMPIInt    len;

187:   PetscFunctionBegin;
188:   PetscCall(PetscBLASIntCast(m_,&m));
189:   PetscCall(PetscBLASIntCast(n_,&n));
190:   PetscCall(PetscBLASIntCast(k_,&k));
191:   PetscCall(PetscBLASIntCast(ldc_,&ldc));
192:   if (mpi) {
193:     if (ldc==m) {
194:       PetscCall(BVAllocateWork_Private(bv,m*n));
195:       if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&k,(PetscScalar*)B,&k,&zero,bv->work,&ldc));
196:       else PetscCall(PetscArrayzero(bv->work,m*n));
197:       PetscCall(PetscMPIIntCast(m*n,&len));
198:       PetscCall(MPIU_Allreduce(bv->work,C,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
199:     } else {
200:       PetscCall(BVAllocateWork_Private(bv,2*m*n));
201:       CC = bv->work+m*n;
202:       if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&k,(PetscScalar*)B,&k,&zero,bv->work,&m));
203:       else PetscCall(PetscArrayzero(bv->work,m*n));
204:       PetscCall(PetscMPIIntCast(m*n,&len));
205:       PetscCall(MPIU_Allreduce(bv->work,CC,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
206:       for (j=0;j<n;j++) PetscCall(PetscArraycpy(C+j*ldc,CC+j*m,m));
207:     }
208:   } else {
209:     if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&k,(PetscScalar*)B,&k,&zero,C,&ldc));
210:   }
211:   PetscCall(PetscLogFlops(2.0*m*n*k));
212:   PetscFunctionReturn(PETSC_SUCCESS);
213: }

215: /*
216:     y := A'*x

218:     A is nxk (ld=n)
219: */
220: PetscErrorCode BVDotVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,const PetscScalar *A,const PetscScalar *x,PetscScalar *y,PetscBool mpi)
221: {
222:   PetscScalar    zero=0.0,done=1.0;
223:   PetscBLASInt   n,k,one=1;
224:   PetscMPIInt    len;

226:   PetscFunctionBegin;
227:   PetscCall(PetscBLASIntCast(n_,&n));
228:   PetscCall(PetscBLASIntCast(k_,&k));
229:   if (mpi) {
230:     PetscCall(BVAllocateWork_Private(bv,k));
231:     if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&n,x,&one,&zero,bv->work,&one));
232:     else PetscCall(PetscArrayzero(bv->work,k));
233:     PetscCall(PetscMPIIntCast(k,&len));
234:     PetscCall(MPIU_Allreduce(bv->work,y,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
235:   } else {
236:     if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&n,x,&one,&zero,y,&one));
237:   }
238:   PetscCall(PetscLogFlops(2.0*n*k));
239:   PetscFunctionReturn(PETSC_SUCCESS);
240: }

242: /*
243:     Scale n scalars
244: */
245: PetscErrorCode BVScale_BLAS_Private(BV bv,PetscInt n_,PetscScalar *A,PetscScalar alpha)
246: {
247:   PetscBLASInt   n,one=1;

249:   PetscFunctionBegin;
250:   if (PetscUnlikely(alpha == (PetscScalar)0.0)) PetscCall(PetscArrayzero(A,n_));
251:   else if (alpha!=(PetscScalar)1.0) {
252:     PetscCall(PetscBLASIntCast(n_,&n));
253:     PetscCallBLAS("BLASscal",BLASscal_(&n,&alpha,A,&one));
254:     PetscCall(PetscLogFlops(n));
255:   }
256:   PetscFunctionReturn(PETSC_SUCCESS);
257: }