@@ -725,63 +725,82 @@ namespace dftefe
725725 d_diagonalInv, 1 );
726726
727727 // Now form the enrichment block matrix.
728+
728729 // utils::MemoryStorage<ValueTypeOperator, memorySpace>
729730 // basisOverlapInvEnrichmentBlockExact(
730731 // d_nglobalEnrichmentIds * d_nglobalEnrichmentIds);
731732
732- utils::MemoryStorage<ValueTypeOperator, memorySpace>
733- basisOverlapInvEnrichmentBlock (d_nglobalEnrichmentIds *
734- d_nglobalEnrichmentIds,
735- 0 );
733+ // utils::MemoryStorage<ValueTypeOperator, memorySpace>
734+ // basisOverlapEnrichmentBlock(d_nglobalEnrichmentIds *
735+ // d_nglobalEnrichmentIds,
736+ // 0);
737+
738+ // size_type cellId = 0;
739+ // size_type cumulativeBasisDataInCells = 0;
740+ // for (auto enrichmentVecInCell :
741+ // efebasisDofHandler.getEnrichmentIdsPartition()
742+ // ->overlappingEnrichmentIdsInCells())
743+ // {
744+ // size_type nCellEnrichmentDofs = enrichmentVecInCell.size();
745+ // for (unsigned int j = 0; j < nCellEnrichmentDofs; j++)
746+ // {
747+ // for (unsigned int k = 0; k < nCellEnrichmentDofs; k++)
748+ // {
749+ // // *(basisOverlapInvEnrichmentBlockExact.data() +
750+ // // enrichmentVecInCell[j] * d_nglobalEnrichmentIds +
751+ // // enrichmentVecInCell[k]) +=
752+ // // *(NiNjInAllCells.data() + cumulativeBasisDataInCells +
753+ // // (numCellClassicalDofs + nCellEnrichmentDofs) *
754+ // // (numCellClassicalDofs + j) +
755+ // // numCellClassicalDofs + k);
756+
757+ // basis::EnrichmentIdAttribute eIdAttrj =
758+ // efeBDH->getEnrichmentIdsPartition()
759+ // ->getEnrichmentIdAttribute(enrichmentVecInCell[j]);
760+
761+ // basis::EnrichmentIdAttribute eIdAttrk =
762+ // efeBDH->getEnrichmentIdsPartition()
763+ // ->getEnrichmentIdAttribute(enrichmentVecInCell[k]);
764+
765+ // if (eIdAttrj.atomId == eIdAttrk.atomId)
766+ // {
767+ // *(basisOverlapEnrichmentBlock.data() +
768+ // enrichmentVecInCell[j] * d_nglobalEnrichmentIds +
769+ // enrichmentVecInCell[k]) +=
770+ // *(NiNjInAllCells.data() + cumulativeBasisDataInCells
771+ // +
772+ // (numCellClassicalDofs + nCellEnrichmentDofs) *
773+ // (numCellClassicalDofs + j) +
774+ // numCellClassicalDofs + k);
775+ // }
776+ // }
777+ // }
778+ // cumulativeBasisDataInCells += utils::mathFunctions::sizeTypePow(
779+ // (nCellEnrichmentDofs + numCellClassicalDofs), 2);
780+ // cellId += 1;
781+ // }
736782
737- size_type cellId = 0 ;
738- size_type cumulativeBasisDataInCells = 0 ;
739- for (auto enrichmentVecInCell :
740- efebasisDofHandler.getEnrichmentIdsPartition ()
741- ->overlappingEnrichmentIdsInCells ())
742- {
743- size_type nCellEnrichmentDofs = enrichmentVecInCell.size ();
744- for (unsigned int j = 0 ; j < nCellEnrichmentDofs; j++)
745- {
746- for (unsigned int k = 0 ; k < nCellEnrichmentDofs; k++)
747- {
748- // *(basisOverlapInvEnrichmentBlockExact.data() +
749- // enrichmentVecInCell[j] * d_nglobalEnrichmentIds +
750- // enrichmentVecInCell[k]) +=
751- // *(NiNjInAllCells.data() + cumulativeBasisDataInCells +
752- // (numCellClassicalDofs + nCellEnrichmentDofs) *
753- // (numCellClassicalDofs + j) +
754- // numCellClassicalDofs + k);
755-
756- basis::EnrichmentIdAttribute eIdAttrj =
757- efeBDH->getEnrichmentIdsPartition ()
758- ->getEnrichmentIdAttribute (enrichmentVecInCell[j]);
759-
760- basis::EnrichmentIdAttribute eIdAttrk =
761- efeBDH->getEnrichmentIdsPartition ()
762- ->getEnrichmentIdAttribute (enrichmentVecInCell[k]);
763-
764- if (eIdAttrj.atomId == eIdAttrk.atomId )
765- {
766- *(basisOverlapInvEnrichmentBlock.data () +
767- enrichmentVecInCell[j] * d_nglobalEnrichmentIds +
768- enrichmentVecInCell[k]) +=
769- *(NiNjInAllCells.data () + cumulativeBasisDataInCells +
770- (numCellClassicalDofs + nCellEnrichmentDofs) *
771- (numCellClassicalDofs + j) +
772- numCellClassicalDofs + k);
773- }
774- }
775- }
776- cumulativeBasisDataInCells += utils::mathFunctions::sizeTypePow (
777- (nCellEnrichmentDofs + numCellClassicalDofs), 2 );
778- cellId += 1 ;
779- }
783+ // // int err = utils::mpi::MPIAllreduce<memorySpace>(
784+ // // utils::mpi::MPIInPlace,
785+ // // basisOverlapInvEnrichmentBlockExact.data(),
786+ // // basisOverlapInvEnrichmentBlockExact.size(),
787+ // // utils::mpi::MPIDouble,
788+ // // utils::mpi::MPISum,
789+ // // d_feBasisManager->getMPIPatternP2P()->mpiCommunicator());
790+ // // std::pair<bool, std::string> mpiIsSuccessAndMsg =
791+ // // utils::mpi::MPIErrIsSuccessAndMsg(err);
792+ // // utils::throwException(mpiIsSuccessAndMsg.first,
793+ // // "MPI Error:" + mpiIsSuccessAndMsg.second);
794+
795+ // // linearAlgebra::blasLapack::inverse<ValueTypeOperator, memorySpace>(
796+ // // d_nglobalEnrichmentIds,
797+ // // basisOverlapInvEnrichmentBlockExact.data(),
798+ // // *(d_diagonalInv.getLinAlgOpContext()));
780799
781800 // int err = utils::mpi::MPIAllreduce<memorySpace>(
782801 // utils::mpi::MPIInPlace,
783- // basisOverlapInvEnrichmentBlockExact .data(),
784- // basisOverlapInvEnrichmentBlockExact .size(),
802+ // basisOverlapEnrichmentBlock .data(),
803+ // basisOverlapEnrichmentBlock .size(),
785804 // utils::mpi::MPIDouble,
786805 // utils::mpi::MPISum,
787806 // d_feBasisManager->getMPIPatternP2P()->mpiCommunicator());
@@ -790,27 +809,46 @@ namespace dftefe
790809 // utils::throwException(mpiIsSuccessAndMsg.first,
791810 // "MPI Error:" + mpiIsSuccessAndMsg.second);
792811
793- // linearAlgebra::blasLapack::inverse<ValueTypeOperator, memorySpace>(
794- // d_nglobalEnrichmentIds,
795- // basisOverlapInvEnrichmentBlockExact.data(),
796- // *(d_diagonalInv.getLinAlgOpContext()));
797-
798- int err = utils::mpi::MPIAllreduce<memorySpace>(
799- utils::mpi::MPIInPlace,
800- basisOverlapInvEnrichmentBlock.data (),
801- basisOverlapInvEnrichmentBlock.size (),
802- utils::mpi::MPIDouble,
803- utils::mpi::MPISum,
804- d_feBasisManager->getMPIPatternP2P ()->mpiCommunicator ());
805- std::pair<bool , std::string> mpiIsSuccessAndMsg =
806- utils::mpi::MPIErrIsSuccessAndMsg (err);
807- utils::throwException (mpiIsSuccessAndMsg.first ,
808- " MPI Error:" + mpiIsSuccessAndMsg.second );
809-
810- linearAlgebra::blasLapack::inverse<ValueTypeOperator, memorySpace>(
811- d_nglobalEnrichmentIds,
812- basisOverlapInvEnrichmentBlock.data (),
813- *(d_diagonalInv.getLinAlgOpContext ()));
812+ // global_size_type globalEnrichmentStartId =
813+ // efeBDH->getGlobalRanges()[1].first;
814+
815+ // std::pair<global_size_type, global_size_type> locOwnEidPair{
816+ // efeBDH->getLocallyOwnedRanges()[1].first - globalEnrichmentStartId,
817+ // efeBDH->getLocallyOwnedRanges()[1].second - globalEnrichmentStartId};
818+
819+ // global_size_type nlocallyOwnedEnrichmentIds =
820+ // locOwnEidPair.second - locOwnEidPair.first;
821+
822+ // d_atomBlockEnrichmentOverlapInv.resize(nlocallyOwnedEnrichmentIds *
823+ // nlocallyOwnedEnrichmentIds);
824+
825+ // for (global_size_type i = 0; i < nlocallyOwnedEnrichmentIds; i++)
826+ // {
827+ // for (global_size_type j = 0; j < nlocallyOwnedEnrichmentIds; j++)
828+ // {
829+ // *(d_atomBlockEnrichmentOverlapInv.data() +
830+ // i * nlocallyOwnedEnrichmentIds + j) =
831+ // *(basisOverlapEnrichmentBlock.data() +
832+ // (i + locOwnEidPair.first) * d_nglobalEnrichmentIds +
833+ // (j + locOwnEidPair.first));
834+ // }
835+ // }
836+
837+
838+ std::pair<global_size_type, global_size_type> locOwnPair =
839+ efebasisDofHandler.getEnrichmentIdsPartition ()
840+ ->locallyOwnedEnrichmentIds ();
841+
842+ std::vector<global_size_type> ghostVec =
843+ efebasisDofHandler.getEnrichmentIdsPartition ()->ghostEnrichmentIds ();
844+
845+ std::shared_ptr<const utils::mpi::MPIPatternP2P<memorySpace>>
846+ mpiPatternP2P =
847+ std::make_shared<const utils::mpi::MPIPatternP2P<memorySpace>>(
848+ std::vector<std::pair<global_size_type, global_size_type>>{
849+ locOwnPair},
850+ ghostVec,
851+ d_feBasisManager->getMPIPatternP2P ()->mpiCommunicator ());
814852
815853 global_size_type globalEnrichmentStartId =
816854 efeBDH->getGlobalRanges ()[1 ].first ;
@@ -822,21 +860,100 @@ namespace dftefe
822860 global_size_type nlocallyOwnedEnrichmentIds =
823861 locOwnEidPair.second - locOwnEidPair.first ;
824862
863+ size_type nLocalEnrichmentIds =
864+ ghostVec.size () + nlocallyOwnedEnrichmentIds;
865+
825866 d_atomBlockEnrichmentOverlapInv.resize (nlocallyOwnedEnrichmentIds *
826867 nlocallyOwnedEnrichmentIds);
827868
828- for (global_size_type i = 0 ; i < nlocallyOwnedEnrichmentIds; i++)
869+ // utils::MemoryStorage<ValueTypeOperator, memorySpace>
870+ // atomBlockEnrichmentOverlap(nlocallyOwnedEnrichmentIds *
871+ // nlocallyOwnedEnrichmentIds);
872+
873+ global_size_type enrichBatchSize = 5000 ;
874+ for (global_size_type enrichStartId = 0 ;
875+ enrichStartId < d_nglobalEnrichmentIds;
876+ enrichStartId += enrichBatchSize)
829877 {
830- for (global_size_type j = 0 ; j < nlocallyOwnedEnrichmentIds; j++)
878+ const size_type enrichEndId =
879+ std::min (enrichStartId + enrichBatchSize, d_nglobalEnrichmentIds);
880+ const size_type numEnrichInBatch = enrichEndId - enrichStartId;
881+
882+ linearAlgebra::MultiVector<ValueType, memorySpace>
883+ basisOverlapEnrichmentBlock (mpiPatternP2P,
884+ linAlgOpContext,
885+ numEnrichInBatch);
886+
887+ size_type cellId = 0 ;
888+ size_type cumulativeBasisDataInCells = 0 ;
889+ for (auto enrichmentVecInCell :
890+ efebasisDofHandler.getEnrichmentIdsPartition ()
891+ ->overlappingEnrichmentIdsInCells ())
831892 {
832- *(d_atomBlockEnrichmentOverlapInv.data () +
833- i * nlocallyOwnedEnrichmentIds + j) =
834- *(basisOverlapInvEnrichmentBlock.data () +
835- (i + locOwnEidPair.first ) * d_nglobalEnrichmentIds +
836- (j + locOwnEidPair.first ));
893+ size_type nCellEnrichmentDofs = enrichmentVecInCell.size ();
894+ for (unsigned int j = 0 ; j < nCellEnrichmentDofs; j++)
895+ {
896+ for (unsigned int k = 0 ; k < nCellEnrichmentDofs; k++)
897+ {
898+ if (enrichmentVecInCell[k] >= enrichStartId &&
899+ enrichmentVecInCell[k] < enrichEndId)
900+ {
901+ basis::EnrichmentIdAttribute eIdAttrj =
902+ efeBDH->getEnrichmentIdsPartition ()
903+ ->getEnrichmentIdAttribute (
904+ enrichmentVecInCell[j]);
905+
906+ basis::EnrichmentIdAttribute eIdAttrk =
907+ efeBDH->getEnrichmentIdsPartition ()
908+ ->getEnrichmentIdAttribute (
909+ enrichmentVecInCell[k]);
910+
911+ if (eIdAttrj.atomId == eIdAttrk.atomId )
912+ {
913+ *(basisOverlapEnrichmentBlock.data () +
914+ mpiPatternP2P->globalToLocal (
915+ enrichmentVecInCell[j]) *
916+ numEnrichInBatch +
917+ (enrichmentVecInCell[k] - enrichStartId)) +=
918+ *(NiNjInAllCells.data () +
919+ cumulativeBasisDataInCells +
920+ (numCellClassicalDofs + nCellEnrichmentDofs) *
921+ (numCellClassicalDofs + j) +
922+ numCellClassicalDofs + k);
923+ }
924+ }
925+ }
926+ }
927+ cumulativeBasisDataInCells += utils::mathFunctions::sizeTypePow (
928+ (nCellEnrichmentDofs + numCellClassicalDofs), 2 );
929+ cellId += 1 ;
930+ }
931+
932+ basisOverlapEnrichmentBlock.accumulateAddLocallyOwned ();
933+
934+ for (global_size_type i = 0 ; i < nlocallyOwnedEnrichmentIds; i++)
935+ {
936+ for (global_size_type j = 0 ; j < nlocallyOwnedEnrichmentIds; j++)
937+ {
938+ if ((j + locOwnEidPair.first ) >= enrichStartId &&
939+ (j + locOwnEidPair.first ) < enrichEndId)
940+ *(d_atomBlockEnrichmentOverlapInv.data () +
941+ i * nlocallyOwnedEnrichmentIds + j) =
942+ *(basisOverlapEnrichmentBlock.data () +
943+ mpiPatternP2P->globalToLocal (i + locOwnEidPair.first ) *
944+ numEnrichInBatch +
945+ (j + locOwnEidPair.first ) - enrichStartId);
946+ }
837947 }
838948 }
839949
950+ if (nlocallyOwnedEnrichmentIds > 0 )
951+ linearAlgebra::LapackError ret =
952+ linearAlgebra::blasLapack::inverse<ValueTypeOperator, memorySpace>(
953+ nlocallyOwnedEnrichmentIds,
954+ d_atomBlockEnrichmentOverlapInv.data (),
955+ *(d_diagonalInv.getLinAlgOpContext ()));
956+
840957 int rank;
841958 utils::mpi::MPICommRank (
842959 d_feBasisManager->getMPIPatternP2P ()->mpiCommunicator (), &rank);
0 commit comments