//===-- sotoc/src/Visitor.cpp ---------------------------------------------===////// The LLVM Compiler Infrastructure//// This file is distributed under the University of Illinois Open Source// License. See LICENSE.TXT for details.////===----------------------------------------------------------------------===////===----------------------------------------------------------------------===//#include<sstream>#include<string>#include"clang/AST/ASTContext.h"#include"clang/AST/Attr.h"#include"clang/AST/Decl.h"#include"clang/AST/ExprOpenMP.h"#include"clang/AST/Stmt.h"#include"clang/AST/StmtOpenMP.h"#include"clang/Basic/OpenMPKinds.h"#include"clang/Basic/SourceLocation.h"#include"clang/Basic/SourceManager.h"#include"clang/Rewrite/Core/Rewriter.h"#include"Debug.h"#include"DeclResolver.h"#include"TargetCode.h"#include"TargetCodeFragment.h"#include"Visitors.h"staticboolstmtNeedsSemicolon(constclang::Stmt*S){while(1){if(auto*CS=llvm::dyn_cast<clang::CapturedStmt>(S)){S=CS->getCapturedStmt();}elseif(auto*OS=llvm::dyn_cast<clang::OMPExecutableDirective>(S)){S=OS->getInnermostCapturedStmt();}else{break;}}if(llvm::isa<clang::CompoundStmt>(S)||llvm::isa<clang::ForStmt>(S)||llvm::isa<clang::IfStmt>(S)){returnfalse;}returntrue;}boolFindTargetCodeVisitor::TraverseDecl(clang::Decl*D){if(!D)returnfalse;if(auto*FD=llvm::dyn_cast<clang::FunctionDecl>(D)){LastVisitedFuncDecl.push(FD);}boolret=clang::RecursiveASTVisitor<FindTargetCodeVisitor>::TraverseDecl(D);if(auto*FD=llvm::dyn_cast<clang::FunctionDecl>(D)){LastVisitedFuncDecl.pop();}returnret;}boolFindTargetCodeVisitor::VisitStmt(clang::Stmt*S){if(auto*TD=llvm::dyn_cast<clang::OMPTargetDirective>(S)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetTeamsDirective>(S)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetParallelDirective>(S)){processTargetRegion(TD);}elseif(auto*LD=llvm::dyn_cast<clang::OMPLoopDirective>(S)){if(auto*TD=llvm::dyn_cast<clang::OMPTargetParallelForDirective>(LD)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetParallelForSimdDirective>(LD)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetSimdDirective>(LD)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetTeamsDistributeDirective>(LD)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetTeamsDistributeParallelForDirective>(LD)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetTeamsDistributeParallelForSimdDirective>(LD)){processTargetRegion(TD);}elseif(auto*TD=llvm::dyn_cast<clang::OMPTargetTeamsDistributeSimdDirective>(LD)){processTargetRegion(TD);}}returntrue;}classCollectOMPClauseParamsVarsVisitor:publicclang::RecursiveASTVisitor<CollectOMPClauseParamsVarsVisitor>{std::shared_ptr<TargetCodeRegion>TCR;public:CollectOMPClauseParamsVarsVisitor(std::shared_ptr<TargetCodeRegion>&TCR):TCR(TCR){};boolVisitStmt(clang::Stmt*S){if(auto*DRE=llvm::dyn_cast<clang::DeclRefExpr>(S)){if(auto*VD=llvm::dyn_cast<clang::VarDecl>(DRE->getDecl())){TCR->addOMPClauseParam(VD->getCanonicalDecl());}}returntrue;};};classCollectOMPClauseParamsVisitor:publicclang::RecursiveASTVisitor<CollectOMPClauseParamsVisitor>{CollectOMPClauseParamsVarsVisitorVarsVisitor;boolInExplicitCast;public:CollectOMPClauseParamsVisitor(std::shared_ptr<TargetCodeRegion>&TCR):VarsVisitor(TCR),InExplicitCast(false){};boolVisitStmt(clang::Stmt*S){// This relies on the captured statement being the last childif(llvm::isa<clang::CapturedStmt>(S)){returnfalse;}if(llvm::isa<clang::ImplicitCastExpr>(S)){InExplicitCast=true;returntrue;}auto*DRE=llvm::dyn_cast<clang::DeclRefExpr>(S);if(DRE&&InExplicitCast){if(auto*VD=llvm::dyn_cast<clang::VarDecl>(DRE->getDecl())){VarsVisitor.TraverseStmt(VD->getInit());}}InExplicitCast=false;returntrue;};};boolFindTargetCodeVisitor::processTargetRegion(clang::OMPExecutableDirective*TargetDirective){// TODO: Not sure why to iterate the children, because I think there// is only one child. For me this looks wrong.for(autoi=TargetDirective->child_begin(),e=TargetDirective->child_end();i!=e;++i){if(auto*CS=llvm::dyn_cast<clang::CapturedStmt>(*i)){while(auto*NCS=llvm::dyn_cast<clang::CapturedStmt>(CS->getCapturedStmt())){CS=NCS;}autoTCR=std::make_shared<TargetCodeRegion>(CS,TargetDirective,LastVisitedFuncDecl.top(),Context);// if the target region cannot be added we dont want to parse its argsif(TargetCodeInfo.addCodeFragment(TCR)){FindArraySectionVisitor(TCR->CapturedLowerBounds).TraverseStmt(TargetDirective);for(autoC:TargetDirective->clauses()){TCR->addOMPClause(C);}// For more complex data types (like structs) we need to traverse the// treeDiscoverTypeVisitor.TraverseStmt(CS);DiscoverFunctionVisitor.TraverseStmt(CS);addTargetRegionArgs(CS,TargetDirective,TCR);TCR->NeedsSemicolon=stmtNeedsSemicolon(CS);TCR->TargetCodeKind=TargetDirective->getDirectiveKind();}}}returntrue;}voidFindTargetCodeVisitor::addTargetRegionArgs(clang::CapturedStmt*S,clang::OMPExecutableDirective*TargetDirective,std::shared_ptr<TargetCodeRegion>TCR){DEBUGP("Add target region args");for(constauto&i:S->captures()){if(!(i.capturesVariableArrayType())){DEBUGP("captured Var: "+i.getCapturedVar()->getNameAsString());TCR->addCapture(&i);}else{// Not sure what exactly is caputred here. It looks like we have an// additional capture in cases of VATs.DEBUGP("Current capture is a variable-length array type (skipped)");}}// Find all not locally declared variables in the regionFindPrivateVariablesVisitorPrivateVarsVisitor(S->getBeginLoc(),Context.getSourceManager());PrivateVarsVisitor.TraverseStmt(S);// Remove any not locally declared variables which are already capturedautoVarSet=PrivateVarsVisitor.getVarSet();for(auto&CapturedVar:TCR->capturedVars()){VarSet.erase(CapturedVar.getDecl());}// Add variables used in OMP clauses which are not captured as first-private// variablesCollectOMPClauseParamsVisitor(TCR).TraverseStmt(TargetDirective);// Add non-local, non-capured variable as private variablesTCR->setPrivateVars(VarSet);}boolFindTargetCodeVisitor::VisitDecl(clang::Decl*D){auto*FD=llvm::dyn_cast<clang::FunctionDecl>(D);if(FD){autosearch=FuncDeclWithoutBody.find(FD->getNameAsString());if(search!=FuncDeclWithoutBody.end()){Functions.addDecl(D);FuncDeclWithoutBody.erase(search);}}// search Decl attributes for 'omp declare target' attrfor(auto&attr:D->attrs()){if(attr->getKind()==clang::attr::OMPDeclareTargetDecl){Functions.addDecl(D);if(FD){if(FD->hasBody()&&!FD->doesThisDeclarationHaveABody()){FuncDeclWithoutBody.insert(FD->getNameAsString());}}returntrue;}}returntrue;}boolFindLoopStmtVisitor::VisitStmt(clang::Stmt*S){if(autoLS=llvm::dyn_cast<clang::ForStmt>(S)){FindDeclRefVisitor.TraverseStmt(LS->getInit());}returntrue;}boolFindDeclRefExprVisitor::VisitStmt(clang::Stmt*S){if(autoDRE=llvm::dyn_cast<clang::DeclRefExpr>(S)){if(autoDD=llvm::dyn_cast<clang::DeclaratorDecl>(DRE->getDecl())){if(autoVD=llvm::dyn_cast<clang::VarDecl>(DD)){if(VD->getNameAsString()!=".reduction.lhs"){VarSet.insert(VD);}}}}returntrue;}boolDiscoverTypesInDeclVisitor::VisitDecl(clang::Decl*D){if(auto*VD=llvm::dyn_cast<clang::ValueDecl>(D)){if(constclang::Type*TP=VD->getType().getTypePtrOrNull()){processType(TP);}}returntrue;}boolDiscoverTypesInDeclVisitor::VisitExpr(clang::Expr*E){if(auto*DRE=llvm::dyn_cast<clang::DeclRefExpr>(E)){if(auto*ECD=llvm::dyn_cast<clang::EnumConstantDecl>(DRE->getDecl())){OnEachTypeRef(llvm::cast<clang::EnumDecl>(ECD->getDeclContext()));returntrue;}}if(constclang::Type*TP=E->getType().getTypePtrOrNull()){if(TP->isPointerType()){TP=TP->getPointeeOrArrayElementType();}processType(TP);}returntrue;}boolDiscoverTypesInDeclVisitor::VisitType(clang::Type*T){processType(T);returntrue;}voidDiscoverTypesInDeclVisitor::processType(constclang::Type*TP){if(constclang::TypedefType*TDT=TP->getAs<clang::TypedefType>()){OnEachTypeRef(TDT->getDecl());}elseif(auto*TD=TP->getAsTagDecl()){OnEachTypeRef(TD);}}DiscoverTypesInDeclVisitor::DiscoverTypesInDeclVisitor(TypeDeclResolver&Types){OnEachTypeRef=[&Types](clang::Decl*D){Types.addDecl(D);};}DiscoverFunctionsInDeclVisitor::DiscoverFunctionsInDeclVisitor(FunctionDeclResolver&Functions){OnEachFuncRef=[&Functions](clang::FunctionDecl*FD){Functions.addDecl(FD);};}boolDiscoverFunctionsInDeclVisitor::VisitExpr(clang::Expr*E){clang::DeclRefExpr*DRE=llvm::dyn_cast<clang::DeclRefExpr>(E);if(DRE!=nullptr){if(auto*D=DRE->getDecl()){if(auto*FD=llvm::dyn_cast<clang::FunctionDecl>(D)){OnEachFuncRef(FD);auto*FDDefinition=FD->getDefinition();if(FDDefinition!=FD&&FDDefinition!=NULL){OnEachFuncRef(FDDefinition);}}}}returntrue;}boolFindArraySectionVisitor::VisitExpr(clang::Expr*E){if(auto*ASE=llvm::dyn_cast<clang::OMPArraySectionExpr>(E)){clang::Expr*Base=ASE->getBase();if(llvm::isa<clang::OMPArraySectionExpr>(Base)){returntrue;}if(auto*CastBase=llvm::dyn_cast<clang::CastExpr>(Base)){Base=CastBase->getSubExpr();if(auto*DRE=llvm::dyn_cast<clang::DeclRefExpr>(Base)){auto*VarDecl=llvm::dyn_cast<clang::VarDecl>(DRE->getDecl());if(!VarDecl){llvm::errs()<<"VALDECL != VARDECL\n";returntrue;}clang::Expr*LowerBound=ASE->getLowerBound();if(!LowerBound){returntrue;}if(auto*IntegerLiteral=llvm::dyn_cast<clang::IntegerLiteral>(LowerBound)){if(IntegerLiteral->getValue()==0){returntrue;}}LowerBoundsMap.emplace(VarDecl,LowerBound);}}}returntrue;}boolFindPrivateVariablesVisitor::VisitExpr(clang::Expr*E){if(auto*DRE=llvm::dyn_cast<clang::DeclRefExpr>(E)){if(auto*VD=llvm::dyn_cast<clang::VarDecl>(DRE->getDecl())){// We do not collect variables in 'collect target' declarations.for(auto&attr:VD->attrs()){if(attr->getKind()==clang::attr::OMPDeclareTargetDecl){returntrue;}}// If the variable is declared outside of the target region it may be a// private variableif(SM.isBeforeInTranslationUnit(VD->getLocation(),RegionTopSourceLocation)){// Add the Variable to our setVarSet.insert(VD);}}}returntrue;}