Skip to content

src/Visitors.cpp

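This file implements the AST visitors used to locate and analyze OpenMP target code: FindTargetCodeVisitor, FindLoopStmtVisitor, FindDeclRefExprVisitor, DiscoverTypesInDeclVisitor, DiscoverFunctionsInDeclVisitor, FindArraySectionVisitor and FindPrivateVariablesVisitor.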

Classes

Name
class CollectOMPClauseParamsVarsVisitor
Collects the variables referenced in an OMP clause parameter expression.
class CollectOMPClauseParamsVisitor
Walks the clauses of an OMP directive and collects the variables their parameters depend on.

Functions

Name
bool stmtNeedsSemicolon(const clang::Stmt * S)
Determine whether a statement needs a semicolon.

Functions Documentation

function stmtNeedsSemicolon

1
2
3
static bool stmtNeedsSemicolon(
    const clang::Stmt * S
)

Determine whether a statement needs a semicolon.

Parameters:

  • S Statement to check

Return: `true` if a semicolon is needed, `false` otherwise.

Source code

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
//===-- sotoc/src/Visitors.cpp --------------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//

#include <sstream>
#include <string>

#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/ExprOpenMP.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/Basic/OpenMPKinds.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Rewrite/Core/Rewriter.h"

#include "Debug.h"
#include "DeclResolver.h"
#include "TargetCode.h"
#include "TargetCodeFragment.h"
#include "Visitors.h"

/// Decide whether the rewritten code of a statement must be terminated by a
/// semicolon.
///
/// Captured statements and OpenMP executable directives are unwrapped first,
/// so the decision is made on the innermost wrapped statement.
///
/// \param S Statement to check.
/// \returns false for compound, `for` and `if` statements; true for every
///          other kind of statement.
static bool stmtNeedsSemicolon(const clang::Stmt *S) {
  const clang::Stmt *Inner = S;
  for (;;) {
    if (const auto *Captured = llvm::dyn_cast<clang::CapturedStmt>(Inner)) {
      Inner = Captured->getCapturedStmt();
      continue;
    }
    if (const auto *Directive =
            llvm::dyn_cast<clang::OMPExecutableDirective>(Inner)) {
      Inner = Directive->getInnermostCapturedStmt();
      continue;
    }
    break;
  }

  const bool IsBraceLike = llvm::isa<clang::CompoundStmt>(Inner) ||
                           llvm::isa<clang::ForStmt>(Inner) ||
                           llvm::isa<clang::IfStmt>(Inner);
  return !IsBraceLike;
}

/// Traverse a declaration while tracking the enclosing function.
///
/// Every FunctionDecl is pushed onto LastVisitedFuncDecl for the duration of
/// its own traversal. processTargetRegion reads the top of that stack to
/// attribute a target region to its enclosing function.
///
/// \param D Declaration to traverse; may be null.
/// \returns false if the traversal was aborted, true otherwise.
bool FindTargetCodeVisitor::TraverseDecl(clang::Decl *D) {
  // A null declaration is not an error. RecursiveASTVisitor itself passes
  // null for optional children (e.g. the exception declaration of a
  // `catch (...)` handler) and its own TraverseDecl returns true for it.
  // Returning false here would abort the traversal of the whole translation
  // unit and silently skip every later target region.
  if (!D) {
    return true;
  }

  auto *FD = llvm::dyn_cast<clang::FunctionDecl>(D);
  if (FD) {
    LastVisitedFuncDecl.push(FD);
  }
  const bool Ret =
      clang::RecursiveASTVisitor<FindTargetCodeVisitor>::TraverseDecl(D);
  if (FD) {
    LastVisitedFuncDecl.pop();
  }
  return Ret;
}

/// Hand every supported OpenMP target construct to processTargetRegion.
///
/// Only the directive kinds listed below are handled; all other statements
/// (including target constructs not listed here) are ignored.
///
/// \param S Statement being visited.
/// \returns Always true, so the traversal continues.
bool FindTargetCodeVisitor::VisitStmt(clang::Stmt *S) {
  // Target constructs that are not loop directives.
  const bool IsPlainTarget = llvm::isa<clang::OMPTargetDirective>(S) ||
                             llvm::isa<clang::OMPTargetTeamsDirective>(S) ||
                             llvm::isa<clang::OMPTargetParallelDirective>(S);

  // Combined target constructs that are also loop directives.
  const bool IsLoopTarget =
      llvm::isa<clang::OMPTargetParallelForDirective>(S) ||
      llvm::isa<clang::OMPTargetParallelForSimdDirective>(S) ||
      llvm::isa<clang::OMPTargetSimdDirective>(S) ||
      llvm::isa<clang::OMPTargetTeamsDistributeDirective>(S) ||
      llvm::isa<clang::OMPTargetTeamsDistributeParallelForDirective>(S) ||
      llvm::isa<clang::OMPTargetTeamsDistributeParallelForSimdDirective>(S) ||
      llvm::isa<clang::OMPTargetTeamsDistributeSimdDirective>(S);

  if (IsPlainTarget || IsLoopTarget) {
    processTargetRegion(llvm::cast<clang::OMPExecutableDirective>(S));
  }
  return true;
}

/// Collects every variable referenced inside a statement tree and records its
/// canonical declaration on a target code region as an OMP clause parameter.
class CollectOMPClauseParamsVarsVisitor
    : public clang::RecursiveASTVisitor<CollectOMPClauseParamsVarsVisitor> {
  /// Target region that receives the collected variables.
  std::shared_ptr<TargetCodeRegion> TCR;

public:
  CollectOMPClauseParamsVarsVisitor(std::shared_ptr<TargetCodeRegion> &TCR)
      : TCR(TCR) {}

  /// Record the variable named by a DeclRefExpr, if any.
  /// \returns Always true, so the traversal continues.
  bool VisitStmt(clang::Stmt *S) {
    auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(S);
    if (!Ref) {
      return true;
    }
    auto *Var = llvm::dyn_cast<clang::VarDecl>(Ref->getDecl());
    if (Var) {
      TCR->addOMPClauseParam(Var->getCanonicalDecl());
    }
    return true;
  }
};

/// Walks the clauses of an OpenMP directive and collects the variables their
/// parameter expressions depend on.
///
/// A DeclRefExpr that is visited directly after an ImplicitCastExpr is
/// inspected; if it names a VarDecl, that variable's initializer is scanned
/// by CollectOMPClauseParamsVarsVisitor.
class CollectOMPClauseParamsVisitor
    : public clang::RecursiveASTVisitor<CollectOMPClauseParamsVisitor> {
  CollectOMPClauseParamsVarsVisitor VarsVisitor;
  /// True iff the previously visited statement was an ImplicitCastExpr.
  bool AfterImplicitCast;

public:
  CollectOMPClauseParamsVisitor(std::shared_ptr<TargetCodeRegion> &TCR)
      : VarsVisitor(TCR), AfterImplicitCast(false) {}

  bool VisitStmt(clang::Stmt *S) {
    // Returning false aborts the whole traversal. This relies on the captured
    // statement being the last child, i.e. the clauses are expected to have
    // been visited already.
    if (llvm::isa<clang::CapturedStmt>(S)) {
      return false;
    }

    if (llvm::isa<clang::ImplicitCastExpr>(S)) {
      AfterImplicitCast = true;
      return true;
    }

    const bool FollowsCast = AfterImplicitCast;
    AfterImplicitCast = false;
    if (!FollowsCast) {
      return true;
    }

    if (auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(S)) {
      if (auto *Var = llvm::dyn_cast<clang::VarDecl>(Ref->getDecl())) {
        VarsVisitor.TraverseStmt(Var->getInit());
      }
    }
    return true;
  }
};

/// Register the code of an OpenMP target directive as a target code region.
///
/// For each CapturedStmt child of \p TargetDirective:
/// - nested CapturedStmts are unwrapped;
/// - a TargetCodeRegion is created for the innermost one and attributed to
///   the function on top of LastVisitedFuncDecl.
///
/// If TargetCodeInfo accepts the region, the following are recorded on it:
/// array-section lower bounds, the directive's clauses, the referenced types
/// and functions, the captured and private variables, whether a semicolon is
/// needed, and the directive kind.
///
/// NOTE(review): LastVisitedFuncDecl.top() assumes the directive is nested in
/// a FunctionDecl traversed by TraverseDecl; confirm no other path reaches it.
///
/// \param TargetDirective The target directive to process.
/// \returns Always true.
bool FindTargetCodeVisitor::processTargetRegion(
    clang::OMPExecutableDirective *TargetDirective) {
  // TODO: Not sure why to iterate the children, because I think there
  // is only one child. For me this looks wrong.
  for (clang::Stmt *Child : TargetDirective->children()) {
    auto *CS = llvm::dyn_cast<clang::CapturedStmt>(Child);
    if (!CS) {
      continue;
    }

    // Descend to the innermost captured statement.
    while (auto *Nested =
               llvm::dyn_cast<clang::CapturedStmt>(CS->getCapturedStmt())) {
      CS = Nested;
    }

    auto TCR = std::make_shared<TargetCodeRegion>(
        CS, TargetDirective, LastVisitedFuncDecl.top(), Context);
    // If the target region cannot be added we don't want to parse its args.
    if (!TargetCodeInfo.addCodeFragment(TCR)) {
      continue;
    }

    FindArraySectionVisitor(TCR->CapturedLowerBounds)
        .TraverseStmt(TargetDirective);

    for (auto *Clause : TargetDirective->clauses()) {
      TCR->addOMPClause(Clause);
    }

    // For more complex data types (like structs) we need to traverse the
    // tree.
    DiscoverTypeVisitor.TraverseStmt(CS);
    DiscoverFunctionVisitor.TraverseStmt(CS);
    addTargetRegionArgs(CS, TargetDirective, TCR);
    TCR->NeedsSemicolon = stmtNeedsSemicolon(CS);
    TCR->TargetCodeKind = TargetDirective->getDirectiveKind();
  }
  return true;
}

/// Collect the variables a target region needs from its environment.
///
/// - Every capture of \p S, except variable-length-array type captures, is
///   added to \p TCR.
/// - Variables referenced in \p S and found by FindPrivateVariablesVisitor
///   (declared before the region), that are not already captured, become the
///   region's private variables.
/// - Variables used in the parameters of the clauses of \p TargetDirective
///   are recorded through CollectOMPClauseParamsVisitor.
void FindTargetCodeVisitor::addTargetRegionArgs(
    clang::CapturedStmt *S, clang::OMPExecutableDirective *TargetDirective,
    std::shared_ptr<TargetCodeRegion> TCR) {

  DEBUGP("Add target region args");
  for (const clang::CapturedStmt::Capture &Cap : S->captures()) {
    if (Cap.capturesVariableArrayType()) {
      // Not sure what exactly is captured here. It looks like we have an
      // additional capture in cases of VATs.
      DEBUGP("Current capture is a variable-length array type (skipped)");
      continue;
    }
    DEBUGP("captured Var: " + Cap.getCapturedVar()->getNameAsString());
    TCR->addCapture(&Cap);
  }

  // Find all variables in the region that are not declared locally.
  FindPrivateVariablesVisitor PrivateVarsVisitor(S->getBeginLoc(),
                                                 Context.getSourceManager());
  PrivateVarsVisitor.TraverseStmt(S);

  // Drop those non-local variables that are already captured.
  auto PrivateVars = PrivateVarsVisitor.getVarSet();
  for (auto &Captured : TCR->capturedVars()) {
    PrivateVars.erase(Captured.getDecl());
  }

  // Add variables used in OMP clauses which are not captured as first-private
  // variables.
  CollectOMPClauseParamsVisitor(TCR).TraverseStmt(TargetDirective);

  // Add the remaining non-local, non-captured variables as private variables.
  TCR->setPrivateVars(PrivateVars);
}

/// Record declarations that must also be emitted for the target.
///
/// - Any declaration carrying an 'omp declare target' attribute is added to
///   Functions.
/// - If such a declaration is a function whose body belongs to a different
///   redeclaration, its name is remembered in FuncDeclWithoutBody.
/// - When a FunctionDecl with a remembered name is visited later, it is added
///   to Functions and the name is forgotten.
///
/// NOTE(review): matching is done by name only, so C++ overloads sharing a
/// name are not distinguished.
///
/// \returns Always true, so the traversal continues.
bool FindTargetCodeVisitor::VisitDecl(clang::Decl *D) {
  auto *FD = llvm::dyn_cast<clang::FunctionDecl>(D);

  if (FD) {
    auto Pending = FuncDeclWithoutBody.find(FD->getNameAsString());
    if (Pending != FuncDeclWithoutBody.end()) {
      Functions.addDecl(D);
      FuncDeclWithoutBody.erase(Pending);
    }
  }

  // Search the Decl attributes for the 'omp declare target' attr.
  if (!D->hasAttr<clang::OMPDeclareTargetDeclAttr>()) {
    return true;
  }

  Functions.addDecl(D);
  if (FD && FD->hasBody() && !FD->doesThisDeclarationHaveABody()) {
    FuncDeclWithoutBody.insert(FD->getNameAsString());
  }
  return true;
}

/// Traverse the init statement of every `for` loop with FindDeclRefVisitor.
/// \returns Always true, so the traversal continues.
bool FindLoopStmtVisitor::VisitStmt(clang::Stmt *S) {
  auto *Loop = llvm::dyn_cast<clang::ForStmt>(S);
  if (Loop != nullptr) {
    FindDeclRefVisitor.TraverseStmt(Loop->getInit());
  }
  return true;
}

/// Insert every variable referenced by a DeclRefExpr into VarSet, except
/// variables named ".reduction.lhs".
/// \returns Always true, so the traversal continues.
bool FindDeclRefExprVisitor::VisitStmt(clang::Stmt *S) {
  auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(S);
  if (!Ref) {
    return true;
  }
  // Every VarDecl is a DeclaratorDecl, so checking for VarDecl directly is
  // enough.
  auto *Var = llvm::dyn_cast<clang::VarDecl>(Ref->getDecl());
  if (Var && Var->getNameAsString() != ".reduction.lhs") {
    VarSet.insert(Var);
  }
  return true;
}

/// Pass the type of every ValueDecl to processType.
/// \returns Always true, so the traversal continues.
bool DiscoverTypesInDeclVisitor::VisitDecl(clang::Decl *D) {
  auto *Value = llvm::dyn_cast<clang::ValueDecl>(D);
  if (!Value) {
    return true;
  }
  const clang::Type *DeclType = Value->getType().getTypePtrOrNull();
  if (DeclType != nullptr) {
    processType(DeclType);
  }
  return true;
}

/// Report the types an expression depends on.
///
/// - A reference to an enumerator reports its enclosing EnumDecl and nothing
///   else.
/// - Otherwise the expression's type is passed to processType. For a pointer
///   type, the pointee type is passed instead (one level only).
///
/// \returns Always true, so the traversal continues.
bool DiscoverTypesInDeclVisitor::VisitExpr(clang::Expr *E) {
  auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(E);
  auto *Enumerator =
      Ref ? llvm::dyn_cast<clang::EnumConstantDecl>(Ref->getDecl()) : nullptr;
  if (Enumerator != nullptr) {
    OnEachTypeRef(llvm::cast<clang::EnumDecl>(Enumerator->getDeclContext()));
    return true;
  }

  const clang::Type *ExprType = E->getType().getTypePtrOrNull();
  if (ExprType == nullptr) {
    return true;
  }
  processType(ExprType->isPointerType()
                  ? ExprType->getPointeeOrArrayElementType()
                  : ExprType);
  return true;
}

/// Pass every type node reached by the traversal to processType.
/// \returns Always true, so the traversal continues.
bool DiscoverTypesInDeclVisitor::VisitType(clang::Type *T) {
  const clang::Type *Visited = T;
  processType(Visited);
  return true;
}

/// Report the declaration behind a type through OnEachTypeRef.
///
/// - A typedef type reports its typedef declaration.
/// - Otherwise, a type backed by a TagDecl (struct/union/class/enum) reports
///   that TagDecl.
/// - Any other type is ignored.
void DiscoverTypesInDeclVisitor::processType(const clang::Type *TP) {
  if (const auto *Typedef = TP->getAs<clang::TypedefType>()) {
    OnEachTypeRef(Typedef->getDecl());
    return;
  }
  if (clang::TagDecl *Tag = TP->getAsTagDecl()) {
    OnEachTypeRef(Tag);
  }
}

/// Create a visitor that adds every discovered type declaration to \p Types.
/// \p Types is referenced, not copied, so it must outlive the visitor.
DiscoverTypesInDeclVisitor::DiscoverTypesInDeclVisitor(
    TypeDeclResolver &Types) {
  TypeDeclResolver *Resolver = &Types;
  OnEachTypeRef = [Resolver](clang::Decl *Found) { Resolver->addDecl(Found); };
}

/// Create a visitor that adds every discovered function declaration to
/// \p Functions. \p Functions is referenced, not copied, so it must outlive
/// the visitor.
DiscoverFunctionsInDeclVisitor::DiscoverFunctionsInDeclVisitor(
    FunctionDeclResolver &Functions) {
  FunctionDeclResolver *Resolver = &Functions;
  OnEachFuncRef = [Resolver](clang::FunctionDecl *Found) {
    Resolver->addDecl(Found);
  };
}

/// Report every function referenced by a DeclRefExpr through OnEachFuncRef.
///
/// If the referenced declaration is not the definition and a definition
/// exists, the definition is reported as well.
///
/// \returns Always true, so the traversal continues.
bool DiscoverFunctionsInDeclVisitor::VisitExpr(clang::Expr *E) {
  auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(E);
  if (Ref == nullptr) {
    return true;
  }
  clang::ValueDecl *Referenced = Ref->getDecl();
  if (Referenced == nullptr) {
    return true;
  }
  auto *Func = llvm::dyn_cast<clang::FunctionDecl>(Referenced);
  if (Func == nullptr) {
    return true;
  }

  OnEachFuncRef(Func);
  clang::FunctionDecl *Definition = Func->getDefinition();
  if (Definition != nullptr && Definition != Func) {
    OnEachFuncRef(Definition);
  }
  return true;
}

/// Record non-zero lower bounds of OpenMP array sections.
///
/// Only a first-dimension section is considered, i.e. one whose base is not
/// itself an array section. Its base must also be a cast of a reference to a
/// variable. The lower bound is then stored in LowerBoundsMap, keyed by that
/// variable, unless:
/// - there is no lower bound, or
/// - the lower bound is the integer literal 0.
///
/// \returns Always true, so the traversal continues.
bool FindArraySectionVisitor::VisitExpr(clang::Expr *E) {
  auto *Section = llvm::dyn_cast<clang::OMPArraySectionExpr>(E);
  if (!Section) {
    return true;
  }

  clang::Expr *Base = Section->getBase();
  if (llvm::isa<clang::OMPArraySectionExpr>(Base)) {
    return true;
  }
  auto *Cast = llvm::dyn_cast<clang::CastExpr>(Base);
  if (!Cast) {
    return true;
  }
  auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(Cast->getSubExpr());
  if (!Ref) {
    return true;
  }

  auto *Var = llvm::dyn_cast<clang::VarDecl>(Ref->getDecl());
  if (!Var) {
    llvm::errs() << "VALDECL != VARDECL\n";
    return true;
  }

  clang::Expr *Lower = Section->getLowerBound();
  if (!Lower) {
    return true;
  }
  auto *Literal = llvm::dyn_cast<clang::IntegerLiteral>(Lower);
  if (Literal && Literal->getValue() == 0) {
    return true;
  }
  LowerBoundsMap.emplace(Var, Lower);
  return true;
}

/// Collect referenced variables declared before the target region.
///
/// A variable is inserted into VarSet when:
/// - it is referenced by a DeclRefExpr, and
/// - its location is before RegionTopSourceLocation.
/// Such variables are candidates for private variables. Variables with an
/// 'omp declare target' attribute are skipped.
///
/// \returns Always true, so the traversal continues.
bool FindPrivateVariablesVisitor::VisitExpr(clang::Expr *E) {
  auto *Ref = llvm::dyn_cast<clang::DeclRefExpr>(E);
  if (!Ref) {
    return true;
  }
  auto *Var = llvm::dyn_cast<clang::VarDecl>(Ref->getDecl());
  if (!Var) {
    return true;
  }

  // We do not collect variables in 'declare target' declarations.
  if (Var->hasAttr<clang::OMPDeclareTargetDeclAttr>()) {
    return true;
  }

  // If the variable is declared outside of (before) the target region it may
  // be a private variable.
  if (SM.isBeforeInTranslationUnit(Var->getLocation(),
                                   RegionTopSourceLocation)) {
    VarSet.insert(Var);
  }
  return true;
}

Last update: 2021-11-24
Back to top