summary refs log tree commit diff
path: root/src/ast/Python3VisitorImpl.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast/Python3VisitorImpl.java')
-rw-r--r-- src/ast/Python3VisitorImpl.java 347
1 files changed, 306 insertions, 41 deletions
diff --git a/src/ast/Python3VisitorImpl.java b/src/ast/Python3VisitorImpl.java
index 98ede6d..6b3d7d0 100644
--- a/src/ast/Python3VisitorImpl.java
+++ b/src/ast/Python3VisitorImpl.java
@@ -1,10 +1,19 @@
package ast;
import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import ast.nodes.*;
+import ast.types.*;
+import codegen.Label;
+import parser.Python3Lexer;
import parser.Python3ParserBaseVisitor;
import parser.Python3Parser.*;
+
+import org.antlr.v4.runtime.*;
+import org.antlr.v4.runtime.tree.*;
import org.antlr.v4.runtime.tree.TerminalNode;
/**
@@ -13,6 +22,25 @@ import org.antlr.v4.runtime.tree.TerminalNode;
*/
public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
+ Map<String, Integer> R;
+ private TokenStreamRewriter rewriter;
+ private boolean optimize;
+ private boolean optimizationDone;
+
+ /**
+  * Creates a visitor that records its source rewrites against {@code tokens}.
+  *
+  * @param tokens   token stream wrapped by the internal {@code TokenStreamRewriter}
+  * @param optimize stored flag; when {@code false} the while/for visitors
+  *                 return before doing any optimization rewriting
+  */
+ public Python3VisitorImpl(CommonTokenStream tokens, boolean optimize) {
+     this.optimize = optimize;
+     this.optimizationDone = false;
+     this.rewriter = new TokenStreamRewriter(tokens);
+ }
+
+ /**
+  * Renders the token stream with every rewrite recorded so far applied.
+  *
+  * @return the text produced by the internal rewriter
+  */
+ public String getRewriter() {
+     final String rewrittenSource = this.rewriter.getText();
+     return rewrittenSource;
+ }
+
+ /**
+  * Reports the optimization flag. It starts as {@code false} and, in the
+  * code visible here, is only set when an assignment is moved in front of
+  * its enclosing loop.
+  *
+  * @return {@code true} once such a rewrite has been recorded
+  */
+ public boolean getOptimizationDone() {
+     return this.optimizationDone;
+ }
+
/**
* Since a root can be a simple_stmts or a compound_stmt, this method
* returns a new `RootNode` with a list of them.
@@ -20,7 +48,9 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ``` root : NEWLINE* (simple_stmts | compound_stmt)* EOF; ```
*/
public Node visitRoot(RootContext ctx) {
- ArrayList<Node> childs = new ArrayList<Node>();
+ ArrayList<Node> childs = new ArrayList<>();
+
+ R = new HashMap<>();
for (int i = 0; i < ctx.getChildCount(); i++) {
var child = ctx.getChild(i);
@@ -32,6 +62,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
}
}
+ // cfg.addEdge(cfg.getExitNode());
return new RootNode(childs);
}
@@ -41,7 +72,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ``` simple_stmts : simple_stmt (';' simple_stmt)* ';'? NEWLINE ; ```
*/
public Node visitSimple_stmts(Simple_stmtsContext ctx) {
- ArrayList<Node> stmts = new ArrayList<Node>();
+ ArrayList<Node> stmts = new ArrayList<>();
for (Simple_stmtContext stm : ctx.simple_stmt()) {
stmts.add(visit(stm));
@@ -64,21 +95,17 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
if (ctx.if_stmt() != null) {
ifStmt = visit(ctx.if_stmt());
- }
-
- if (ctx.funcdef() != null) {
+ } else if (ctx.funcdef() != null) {
funcDef = visit(ctx.funcdef());
- }
-
- if (ctx.for_stmt() != null) {
+ } else if (ctx.for_stmt() != null) {
forStmt = visit(ctx.for_stmt());
- }
-
- if (ctx.while_stmt() != null) {
+ } else if (ctx.while_stmt() != null) {
whileStmt = visit(ctx.while_stmt());
}
- return new CompoundNode(ifStmt, funcDef, forStmt, whileStmt);
+ CompoundNode compoundNode = new CompoundNode(ifStmt, funcDef, forStmt, whileStmt);
+
+ return compoundNode;
}
/**
@@ -123,7 +150,13 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
Node assign = visit(ctx.augassign());
Node rhr = visit(ctx.exprlist(1));
- return new AssignmentNode(lhr, assign, rhr);
+ AssignmentNode assignmentNode = new AssignmentNode(lhr, assign, rhr,
+ ctx.exprlist(0).getStart().getTokenIndex(),
+ ctx.exprlist(1).getStop().getTokenIndex());
+
+ R.put(((ExprNode) ((ExprListNode) lhr).getElem(0)).getId(), ctx.getStart().getLine());
+
+ return assignmentNode;
}
/**
@@ -154,7 +187,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
Node dottedName = visit(ctx.dotted_name());
- ArrayList<String> names = new ArrayList<String>();
+ ArrayList<String> names = new ArrayList<>();
for (var s : ctx.NAME()) {
names.add(s.toString());
@@ -169,7 +202,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ``` dotted_name : NAME ('.' NAME)* ; ```
*/
public Node visitDotted_name(Dotted_nameContext ctx) {
- ArrayList<TerminalNode> names = new ArrayList<TerminalNode>();
+ ArrayList<TerminalNode> names = new ArrayList<>();
for (var name : ctx.NAME()) {
names.add(name);
@@ -191,6 +224,8 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
}
Node block = visit(ctx.block());
+ rewriter.insertAfter(ctx.getStart().getTokenIndex(), " ");
+ rewriter.insertBefore(ctx.CLOSE_PAREN().getSymbol().getStartIndex() - 1, "\n ");
return new FuncdefNode(ctx.NAME(), paramlist, block);
}
@@ -204,7 +239,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ```
*/
public Node visitParamlist(ParamlistContext ctx) {
- ArrayList<Node> params = new ArrayList<Node>();
+ ArrayList<Node> params = new ArrayList<>();
for (ParamdefContext s : ctx.paramdef()) {
params.add(visit(s));
@@ -220,6 +255,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ``` paramdef : NAME (':' expr)? ; ```
*/
public Node visitParamdef(ParamdefContext ctx) {
+ // Forget any assignment line recorded in R under this name; presumably so
+ // that a parameter is never treated as a variable with a known assignment
+ // line by the loop optimizations - TODO confirm.
+ R.remove(ctx.NAME().toString());
return new ParamdefNode(ctx.NAME().toString());
}
@@ -267,7 +303,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
}
/**
- * Returns a `IfNode`. FIXME: add support for elif statement.
+ * Returns a `IfNode`.
*
* ``` if_stmt : 'if' expr ':' block ('elif' expr ':' block)* ('else' ':'
* block)? ; ```
@@ -275,12 +311,15 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
public Node visitIf_stmt(If_stmtContext ctx) {
var blocks = ctx.block();
Node condExp = visit(ctx.expr(0));
+
Node thenExp = visit(blocks.get(0));
+
Node elseExp = null;
+ // NOTE(review): only expr(0) and blocks 0 and 1 are read. With the grammar
+ // above, an 'elif' makes block 1 the elif body, so it would be used as the
+ // else branch and any later blocks/conditions are dropped - confirm whether
+ // elif support is still missing.
if (blocks.size() > 1) {
elseExp = visit(blocks.get(1));
}
+ // Insert one space right after the first token of the statement (the 'if'
+ // keyword per the grammar) in the rewritten output.
+ rewriter.insertAfter(ctx.getStart().getTokenIndex(), " ");
return new IfNode(condExp, thenExp, elseExp);
}
@@ -290,12 +329,132 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ``` while_stmt : 'while' expr ':' block ('else' ':' block)? ; ```
*/
public Node visitWhile_stmt(While_stmtContext ctx) {
- Node expr = visit(ctx.expr());
+ // Visit the guard, then the loop body; the optimizer below needs both as
+ // concrete ExprNode / BlockNode instances.
+ ExprNode expr = (ExprNode) visit(ctx.expr());
// Block 1 is for the while-else statement
- Node block = visit(ctx.block(0));
+ BlockNode block = (BlockNode) visit(ctx.block(0));
+
+ WhileStmtNode whileStmt = new WhileStmtNode(expr, block);
+ if (!optimize) {
+     return whileStmt;
+ }
+
+ // Put a newline after the first ':' of the statement in the rewritten text.
+ rewriter.insertAfter(ctx.COLON(0).getSymbol().getTokenIndex(), "\n");
+
+ int lineStart = ctx.getStart().getLine();
+ int lineStop = ctx.getStop().getLine();
+ int index = ctx.getStart().getTokenIndex();
+ optimizeWithSecond(block, lineStart, lineStop, index);
+
+ // Optimize the while guard: every non-atomic operand whose names were all
+ // last assigned outside the loop's line range is computed once into a fresh
+ // variable emitted before the loop, and replaced by that variable.
+ // exprs.get(i) was built from ctx.expr().expr(i), so the loop index itself
+ // selects the tokens to replace; a separate counter that is not advanced on
+ // 'continue' would point at the wrong operand.
+ var exprs = expr.getExprs();
+ for (int i = 0; i < exprs.size(); i++) {
+     var e = exprs.get(i);
+     if (e instanceof ExprNode && ((ExprNode) e).typeCheck() instanceof AtomType) {
+         continue;
+     }
+     ArrayList<String> al = findAtomPresent(e, new ArrayList<>());
+     if (al.isEmpty()) {
+         continue;
+     }
+     boolean constant = true;
+     for (String a : al) {
+         // A name with no recorded assignment line cannot be proven
+         // loop-invariant; treat it as varying instead of unboxing null.
+         Integer n = R.get(a);
+         if (n == null || (n > lineStart && n <= lineStop)) {
+             constant = false;
+             break;
+         }
+     }
+     if (constant) {
+         String newVar = Label.newVar();
+         rewriter.insertBefore(index, newVar + "=" + e.toPrint("") + "\n");
+         int firstToken = ctx.expr().expr(i).getStart().getTokenIndex();
+         int lastToken = ctx.expr().expr(i).getStop().getTokenIndex();
+         rewriter.replace(firstToken, lastToken, newVar);
+     }
+ }
- return new WhileStmtNode(expr, block);
+ optimizeWithThird(block, lineStart, lineStop, index);
+
+ return whileStmt;
+ }
+
+ /**
+  * Walks an expression tree and appends to {@code Acc} the id of every leaf
+  * atom whose {@code typeCheck()} result is an {@code AtomType}.
+  * Nodes that are not {@code ExprNode}s contribute nothing.
+  *
+  * @param e   root of the (sub)expression to scan
+  * @param Acc accumulator that is mutated in place
+  * @return the same {@code Acc} instance, for call chaining
+  */
+ private ArrayList<String> findAtomPresent(Node e, ArrayList<String> Acc) {
+     if (!(e instanceof ExprNode)) {
+         return Acc;
+     }
+
+     ExprNode current = (ExprNode) e;
+     ArrayList<Node> children = current.getExprs();
+     if (children.isEmpty()) {
+         // Leaf expression: only its own atom can contribute a name.
+         AtomNode leaf = (AtomNode) current.getAtom();
+         if (leaf.typeCheck() instanceof AtomType) {
+             Acc.add(leaf.getId());
+         }
+         return Acc;
+     }
+
+     // Inner expression: collect from every sub-expression in order.
+     for (Node child : children) {
+         findAtomPresent(child, Acc);
+     }
+     return Acc;
+ }
+
+ /**
+  * Rewrites the assignments found directly inside a loop body.
+  *
+  * For each assignment, the names used on its right-hand side are looked up
+  * in R (name -> line of an assignment recorded while visiting):
+  * - all names known and recorded outside (lineStart, lineStop]: the whole
+  *   assignment is emitted before the loop and blanked in place, and
+  *   optimizationDone is set;
+  * - some name unknown or recorded inside that range: the assignment stays
+  *   and only whitespace ("\n" / "\t") is inserted around it;
+  * - no names on the right-hand side: the right-hand side is stored in a
+  *   fresh variable before the loop and the assignment reads that variable.
+  *
+  * The order of the rewriter calls is significant; do not reorder them.
+  *
+  * @param block     loop body to scan (only SimpleStmtsNode children are used)
+  * @param lineStart first source line of the loop statement
+  * @param lineStop  last source line of the loop statement
+  * @param index     token index of the loop's first token; hoisted text is
+  *                  inserted before it
+  */
+ private void optimizeWithSecond(BlockNode block, int lineStart, int lineStop, int index) {
+ // Insert a space after the loop's first token in the rewritten text.
+ rewriter.insertAfter(index, " ");
+ // Collect every assignment that appears in a simple statement of the block.
+ ArrayList<AssignmentNode> assignments = new ArrayList<>();
+ for (var child : block.getChilds()) {
+ if (child instanceof SimpleStmtsNode) {
+ var stmts = (SimpleStmtsNode) child;
+ for (var stmt : stmts.getStmts()) {
+ var assignment = (AssignmentNode) ((SimpleStmtNode) stmt).getAssignment();
+ if (assignment != null) {
+ assignments.add(assignment);
+ }
+ }
+ }
+ }
+
+ // Example of collected (left-hand side , right-hand side) pairs:
+ // g , x + 2 * y
+ // m , m + n + g
+ // n , n + 1
+ for (var assignment : assignments) {
+
+ // Only the first element of each side is considered.
+ var lhr = (ExprNode) assignment.getLhr().getElem(0);
+ var rhr = (ExprNode) assignment.getRhr().getElem(0);
+ ArrayList<String> al = findAtomPresent(rhr, new ArrayList<>());
+ if (!al.isEmpty()) {
+ boolean constant = true;
+ for (String a : al) {
+ // Name never recorded in R: not hoistable; re-emit a line break and a
+ // tab in front of the assignment.
+ if (R.get(a) == null) {
+ rewriter.insertBefore(assignment.getLhrIndex(), "\n");
+ rewriter.insertAfter(assignment.getLhrIndex() - 1, "\t");
+ constant = false;
+ break;
+ }
+ // Recorded line falls inside the loop: the value may change per iteration.
+ int n = R.get(a);
+ if (n > lineStart && n <= lineStop) {
+ constant = false;
+ break;
+ }
+ }
+
+ rewriter.insertAfter(assignment.getRhrIndex(), "\n");
+ if (constant) {
+ // Hoist: print the assignment before the loop and erase it from the body.
+ rewriter.insertBefore(index, lhr.toPrint("") + "=" + rhr.toPrint("") + "\n");
+ rewriter.replace(assignment.getLhrIndex(), assignment.getRhrIndex(), "");
+ optimizationDone = true;
+ } else {
+ // Kept in the body: prefix it with a tab.
+ rewriter.insertBefore(assignment.getLhrIndex(), "\t");
+ }
+ } else {
+ // Right-hand side uses no names: compute it once before the loop.
+ // NOTE(review): this branch does not set optimizationDone - confirm intended.
+ String newVar = Label.newVar();
+ rewriter.insertBefore(index, newVar + "=" + rhr.toPrint("") + "\n");
+ rewriter.replace(assignment.getLhrIndex(), assignment.getRhrIndex(),
+ "\t" + lhr.toPrint("") + "=" + newVar + "\n");
+ }
+ }
}
/**
@@ -304,12 +463,89 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* ``` for_stmt : 'for' exprlist ':' block ('else' ':' block)? ; ```
*/
public Node visitFor_stmt(For_stmtContext ctx) {
+
+ // Visit the loop header first; its targets must be dropped from R before
+ // the body is visited.
Node exprList = visit(ctx.exprlist());
+ int dimGetExpr = ((ExprListNode) exprList).getExprs().size();
+
+ // Forget the recorded assignment line of every loop target. The last
+ // element is unwrapped through getExpr(0) - presumably because it parses as
+ // `target in iterable` rather than a bare name - TODO confirm.
+ for (int e = 0; e < dimGetExpr; e++) {
+ if (e == dimGetExpr - 1) {
+ R.remove(((ExprNode) ((ExprNode) ((ExprListNode) exprList).getElem(e)).getExpr(0)).getId());
+ } else {
+ R.remove(((ExprNode) ((ExprListNode) exprList).getElem(e)).getId());
+ }
+ }
+
// Block 1 is for the for-else statement
- Node block = visit(ctx.block(0));
+ BlockNode block = (BlockNode) visit(ctx.block(0));
+
+ Node forNode = new ForStmtNode(exprList, block);
+ // Without optimization no rewriter operation is recorded for this loop.
+ if (!optimize) {
+ return forNode;
+ }
+
+ int lineStart = ctx.getStart().getLine();
+ int lineStop = ctx.getStop().getLine();
+ int index = ctx.getStart().getTokenIndex();
- return new ForStmtNode(exprList, block);
+ // Insert a space after each of the first three tokens of the statement.
+ // NOTE(review): optimizeWithSecond also inserts a space after `index` -
+ // confirm the combined output is intended.
+ rewriter.insertAfter(index, " ");
+ // NOTE: works only for one argument
+ rewriter.insertAfter(index + 1, " ");
+ rewriter.insertAfter(index + 2, " ");
+ optimizeWithSecond(block, lineStart, lineStop, index);
+ optimizeWithThird(block, lineStart, lineStop, index);
+
+ return forNode;
+ }
+
+ /**
+  * Hoists sub-expressions of assignments found directly inside a loop body.
+  *
+  * For every assignment whose right-hand side has more than one operand,
+  * each operand except the last is examined. A non-atomic operand whose
+  * names were all recorded in R outside (lineStart, lineStop] is computed
+  * once into a fresh variable emitted before the loop, and the assignment is
+  * rewritten to use that variable. A name missing from R is treated as
+  * varying, so the operand is left untouched instead of failing on a null
+  * lookup.
+  *
+  * @param block     loop body to scan (only SimpleStmtsNode children are used)
+  * @param lineStart first source line of the loop statement
+  * @param lineStop  last source line of the loop statement
+  * @param index     token index of the loop's first token; hoisted text is
+  *                  inserted before it
+  */
+ private void optimizeWithThird(BlockNode block, int lineStart, int lineStop, int index) {
+     for (Node child : block.getChilds()) {
+         if (!(child instanceof SimpleStmtsNode)) {
+             continue;
+         }
+         for (Node stm : ((SimpleStmtsNode) child).getStmts()) {
+             AssignmentNode ass = (AssignmentNode) ((SimpleStmtNode) stm).getAssignment();
+             if (ass == null) {
+                 continue;
+             }
+             ExprNode rExpr = (ExprNode) ass.getRhr().getElem(0);
+             ArrayList<Node> operands = rExpr.getExprs();
+             if (operands.size() <= 1) {
+                 continue;
+             }
+             // The last operand is never a hoisting candidate.
+             List<Node> candidates = operands.subList(0, operands.size() - 1);
+             for (Node elem : candidates) {
+                 if (elem instanceof ExprNode && ((ExprNode) elem).typeCheck() instanceof AtomType) {
+                     continue;
+                 }
+                 ArrayList<String> al = findAtomPresent(elem, new ArrayList<>());
+                 if (al.isEmpty()) {
+                     continue;
+                 }
+                 boolean constant = true;
+                 for (String a : al) {
+                     // Unknown name: cannot be proven loop-invariant.
+                     Integer n = R.get(a);
+                     if (n == null || (n > lineStart && n <= lineStop)) {
+                         constant = false;
+                         break;
+                     }
+                 }
+                 if (constant) {
+                     String newVar = Label.newVar();
+                     rewriter.insertBefore(index, newVar + "=" + elem.toPrint("") + "\n");
+                     // NOTE(review): the replaced range uses fixed offsets from the
+                     // assignment's token indexes, not the operand's own tokens -
+                     // confirm it matches every right-hand-side shape.
+                     int firstToken = ass.getLhrIndex() + 2;
+                     int lastToken = ass.getRhrIndex() - 2;
+                     rewriter.replace(firstToken, lastToken, newVar);
+                 }
+             }
+         }
+     }
}
/**
@@ -418,18 +654,35 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
if (ctx.atom() != null) {
atom = visit(ctx.atom());
- }
- if (ctx.comp_op() != null) {
- compOp = visit(ctx.comp_op());
- }
+ for (TrailerContext s : ctx.trailer()) {
+ trailers.add(visit(s));
+ }
+ } else {
+ if (ctx.ADD(0) != null) {
+ op = ctx.ADD(0).toString();
- for (ExprContext s : ctx.expr()) {
- exprs.add(visit(s));
- }
+ } else if (ctx.MINUS(0) != null) {
+ op = ctx.MINUS(0).toString();
- for (TrailerContext s : ctx.trailer()) {
- trailers.add(visit(s));
+ } else if (ctx.NOT() != null) {
+ op = ctx.NOT().toString();
+
+ } else if (ctx.STAR() != null) {
+ op = ctx.STAR().toString();
+
+ } else if (ctx.DIV() != null) {
+ op = ctx.DIV().toString();
+
+ }
+
+ if (ctx.comp_op() != null) {
+ compOp = visit(ctx.comp_op());
+ }
+
+ for (ExprContext s : ctx.expr()) {
+ exprs.add(visit(s));
+ }
}
return new ExprNode(atom, compOp, exprs, op, trailers);
@@ -461,11 +714,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
}
return new AtomNode(varName, null);
} else if (ctx.OPEN_BRACE() != null && ctx.CLOSE_BRACE() != null) {
- return manageCompListContext(tlc);
+ return manageCompListContext(tlc, "{", "}");
} else if (ctx.OPEN_BRACK() != null && ctx.CLOSE_BRACK() != null) {
- return manageCompListContext(tlc);
+ return manageCompListContext(tlc, "[", "]");
} else if (ctx.OPEN_PAREN() != null && ctx.CLOSE_PAREN() != null) {
- return manageCompListContext(tlc);
+ return manageCompListContext(tlc, "(", ")");
}
return new AtomNode(null, null);
}
@@ -475,12 +728,12 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
* `testlist_comp` set if the context is not null. Otherwise, returns an
* `AtomNode` with nulls.
*/
- public AtomNode manageCompListContext(Testlist_compContext tlc) {
+ public AtomNode manageCompListContext(Testlist_compContext tlc, String prefix, String suffix) {
+ // prefix/suffix are the opening and closing bracket literals supplied by
+ // the caller ("{" "}", "[" "]" or "(" ")"); they are forwarded unchanged
+ // to the AtomNode in both branches.
if (tlc != null) {
Node testlist_comp = visit(tlc);
- return new AtomNode(null, testlist_comp);
+ return new AtomNode(null, testlist_comp, prefix, suffix);
}
- return new AtomNode(null, null);
+ return new AtomNode(null, null, prefix, suffix);
}
/**
@@ -491,6 +744,19 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
*/
public Node visitTrailer(TrailerContext ctx) {
Node arglist = null;
+ String prefix = "";
+ String suffix = "";
+
+ if (ctx.OPEN_BRACK() != null) {
+ prefix = "[";
+ suffix = "]";
+ } else if (ctx.OPEN_PAREN() != null) {
+ prefix = "(";
+ suffix = ")";
+ } else if (ctx.DOT() != null) {
+ prefix = ".";
+ }
+
if (ctx.arglist() != null) {
arglist = visit(ctx.arglist());
}
@@ -506,11 +772,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
methodCall = ctx.NAME();
}
- return new TrailerNode(arglist, exprs, methodCall, ctx.OPEN_PAREN() != null);
+ return new TrailerNode(arglist, exprs, methodCall, prefix, suffix);
}
/**
- * Returns a `Node`. FIXME: what to do in case of list??
+ * Returns a `Node`.
*
* ``` exprlist : expr (',' expr )* ','? ; ```
*/
@@ -577,12 +843,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {
/**
* Returns a `CompIterNode`.
+ * NOTE: We ignore `comp_if`.
*
* ``` comp_iter : comp_for | comp_if ; ;```
*/
public Node visitComp_iter(Comp_iterContext ctx) {
- // TODO: Implement comp_if
- // Node iter = visit(ctx.comp_if());
Comp_forContext cfc = ctx.comp_for();
Node forNode = null;
if (cfc != null) {