diff options
Diffstat (limited to 'src/ast/Python3VisitorImpl.java')
-rw-r--r-- | src/ast/Python3VisitorImpl.java | 347 |
1 files changed, 306 insertions, 41 deletions
diff --git a/src/ast/Python3VisitorImpl.java b/src/ast/Python3VisitorImpl.java index 98ede6d..6b3d7d0 100644 --- a/src/ast/Python3VisitorImpl.java +++ b/src/ast/Python3VisitorImpl.java @@ -1,10 +1,19 @@ package ast; import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import ast.nodes.*; +import ast.types.*; +import codegen.Label; +import parser.Python3Lexer; import parser.Python3ParserBaseVisitor; import parser.Python3Parser.*; + +import org.antlr.v4.runtime.*; +import org.antlr.v4.runtime.tree.*; import org.antlr.v4.runtime.tree.TerminalNode; /** @@ -13,6 +22,25 @@ import org.antlr.v4.runtime.tree.TerminalNode; */ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { + Map<String, Integer> R; + private TokenStreamRewriter rewriter; + private boolean optimize; + private boolean optimizationDone; + + public Python3VisitorImpl(CommonTokenStream tokens, boolean optimize) { + rewriter = new TokenStreamRewriter(tokens); + this.optimize = optimize; + optimizationDone = false; + } + + public String getRewriter() { + return rewriter.getText(); + } + + public boolean getOptimizationDone() { + return optimizationDone; + } + /** * Since a root can be a simple_stmts or a compound_stmt, this method * returns a new `RootNode` with a list of them. @@ -20,7 +48,9 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` root : NEWLINE* (simple_stmts | compound_stmt)* EOF; ``` */ public Node visitRoot(RootContext ctx) { - ArrayList<Node> childs = new ArrayList<Node>(); + ArrayList<Node> childs = new ArrayList<>(); + + R = new HashMap<>(); for (int i = 0; i < ctx.getChildCount(); i++) { var child = ctx.getChild(i); @@ -32,6 +62,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { } } + // cfg.addEdge(cfg.getExitNode()); return new RootNode(childs); } @@ -41,7 +72,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` simple_stmts : simple_stmt (';' simple_stmt)* ';'? NEWLINE ; ``` */ public Node visitSimple_stmts(Simple_stmtsContext ctx) { - ArrayList<Node> stmts = new ArrayList<Node>(); + ArrayList<Node> stmts = new ArrayList<>(); for (Simple_stmtContext stm : ctx.simple_stmt()) { stmts.add(visit(stm)); @@ -64,21 +95,17 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { if (ctx.if_stmt() != null) { ifStmt = visit(ctx.if_stmt()); - } - - if (ctx.funcdef() != null) { + } else if (ctx.funcdef() != null) { funcDef = visit(ctx.funcdef()); - } - - if (ctx.for_stmt() != null) { + } else if (ctx.for_stmt() != null) { forStmt = visit(ctx.for_stmt()); - } - - if (ctx.while_stmt() != null) { + } else if (ctx.while_stmt() != null) { whileStmt = visit(ctx.while_stmt()); } - return new CompoundNode(ifStmt, funcDef, forStmt, whileStmt); + CompoundNode compoundNode = new CompoundNode(ifStmt, funcDef, forStmt, whileStmt); + + return compoundNode; } /** @@ -123,7 +150,13 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { Node assign = visit(ctx.augassign()); Node rhr = visit(ctx.exprlist(1)); - return new AssignmentNode(lhr, assign, rhr); + AssignmentNode assignmentNode = new AssignmentNode(lhr, assign, rhr, + ctx.exprlist(0).getStart().getTokenIndex(), + ctx.exprlist(1).getStop().getTokenIndex()); + + R.put(((ExprNode) ((ExprListNode) lhr).getElem(0)).getId(), ctx.getStart().getLine()); + + return assignmentNode; } /** @@ -154,7 +187,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { Node dottedName = visit(ctx.dotted_name()); - ArrayList<String> names = new ArrayList<String>(); + ArrayList<String> names = new ArrayList<>(); for (var s : ctx.NAME()) { names.add(s.toString()); @@ -169,7 +202,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` dotted_name : NAME ('.' NAME)* ; ``` */ public Node visitDotted_name(Dotted_nameContext ctx) { - ArrayList<TerminalNode> names = new ArrayList<TerminalNode>(); + ArrayList<TerminalNode> names = new ArrayList<>(); for (var name : ctx.NAME()) { names.add(name); @@ -191,6 +224,8 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { } Node block = visit(ctx.block()); + rewriter.insertAfter(ctx.getStart().getTokenIndex(), " "); + rewriter.insertBefore(ctx.CLOSE_PAREN().getSymbol().getStartIndex() - 1, "\n "); return new FuncdefNode(ctx.NAME(), paramlist, block); } @@ -204,7 +239,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` */ public Node visitParamlist(ParamlistContext ctx) { - ArrayList<Node> params = new ArrayList<Node>(); + ArrayList<Node> params = new ArrayList<>(); for (ParamdefContext s : ctx.paramdef()) { params.add(visit(s)); @@ -220,6 +255,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` paramdef : NAME (':' expr)? ; ``` */ public Node visitParamdef(ParamdefContext ctx) { + R.remove(ctx.NAME().toString()); return new ParamdefNode(ctx.NAME().toString()); } @@ -267,7 +303,7 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { } /** - * Returns a `IfNode`. FIXME: add support for elif statement. + * Returns a `IfNode`. * * ``` if_stmt : 'if' expr ':' block ('elif' expr ':' block)* ('else' ':' * block)? ; ``` @@ -275,12 +311,15 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { public Node visitIf_stmt(If_stmtContext ctx) { var blocks = ctx.block(); Node condExp = visit(ctx.expr(0)); + Node thenExp = visit(blocks.get(0)); + Node elseExp = null; if (blocks.size() > 1) { elseExp = visit(blocks.get(1)); } + rewriter.insertAfter(ctx.getStart().getTokenIndex(), " "); return new IfNode(condExp, thenExp, elseExp); } @@ -290,12 +329,132 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` while_stmt : 'while' expr ':' block ('else' ':' block)? ; ``` */ public Node visitWhile_stmt(While_stmtContext ctx) { - Node expr = visit(ctx.expr()); + // Do the same for the while expression and the block + ExprNode expr = (ExprNode) visit(ctx.expr()); // Block 1 is for the while-else statement - Node block = visit(ctx.block(0)); + BlockNode block = (BlockNode) visit(ctx.block(0)); + + WhileStmtNode whileStmt = new WhileStmtNode(expr, block); + if (!optimize) { + return whileStmt; + } + + rewriter.insertAfter(ctx.COLON(0).getSymbol().getTokenIndex(), "\n"); + + int lineStart = ctx.getStart().getLine(); + int lineStop = ctx.getStop().getLine(); + int index = ctx.getStart().getTokenIndex(); + optimizeWithSecond(block, lineStart, lineStop, index); + + // optimize while's guard + int counter = 0; + var exprs = expr.getExprs(); + for (var e : exprs) { + if (e instanceof ExprNode) { + ExprNode exprNode = (ExprNode) e; + if (exprNode.typeCheck() instanceof AtomType) { + continue; + } + } + ArrayList<String> al = findAtomPresent(e, new ArrayList<>()); + if (!al.isEmpty()) { + boolean constant = true; + for (String a : al) { + int n = R.get(a); + if (n > lineStart && n <= lineStop) { + constant = false; + break; + } + } + if (constant) { + String newVar = Label.newVar(); + rewriter.insertBefore(index, newVar + "=" + e.toPrint("") + "\n"); + int lastToken = ctx.expr().expr(counter).getStop().getTokenIndex(); + int firstToken = ctx.expr().expr(counter).getStart().getTokenIndex(); + rewriter.replace(firstToken, lastToken, newVar); + } + } + counter++; + } - return new WhileStmtNode(expr, block); + optimizeWithThird(block, lineStart, lineStop, index); + + return whileStmt; + } + + private ArrayList<String> findAtomPresent(Node e, ArrayList<String> Acc) { + if (e instanceof ExprNode) { + ExprNode expNode = (ExprNode) e; + ArrayList<Node> exprs = expNode.getExprs(); + if (!exprs.isEmpty()) { + for (Node i : exprs) { + findAtomPresent(i, Acc); + } + } else { + AtomNode a = (AtomNode) expNode.getAtom(); + if (a.typeCheck() instanceof AtomType) { + Acc.add(a.getId()); + } + } + } + return Acc; + } + + private void optimizeWithSecond(BlockNode block, int lineStart, int lineStop, int index) { + rewriter.insertAfter(index, " "); + ArrayList<AssignmentNode> assignments = new ArrayList<>(); + for (var child : block.getChilds()) { + if (child instanceof SimpleStmtsNode) { + var stmts = (SimpleStmtsNode) child; + for (var stmt : stmts.getStmts()) { + var assignment = (AssignmentNode) ((SimpleStmtNode) stmt).getAssignment(); + if (assignment != null) { + assignments.add(assignment); + } + } + } + } + + // g , x + 2 * y + // m , m + n + g + // n , n + 1 + for (var assignment : assignments) { + + var lhr = (ExprNode) assignment.getLhr().getElem(0); + var rhr = (ExprNode) assignment.getRhr().getElem(0); + ArrayList<String> al = findAtomPresent(rhr, new ArrayList<>()); + if (!al.isEmpty()) { + boolean constant = true; + for (String a : al) { + if (R.get(a) == null) { + rewriter.insertBefore(assignment.getLhrIndex(), "\n"); + rewriter.insertAfter(assignment.getLhrIndex() - 1, "\t"); + constant = false; + break; + } + int n = R.get(a); + if (n > lineStart && n <= lineStop) { + constant = false; + break; + } + } + + rewriter.insertAfter(assignment.getRhrIndex(), "\n"); + if (constant) { + rewriter.insertBefore(index, lhr.toPrint("") + "=" + rhr.toPrint("") + "\n"); + rewriter.replace(assignment.getLhrIndex(), assignment.getRhrIndex(), ""); + optimizationDone = true; + } else { + rewriter.insertBefore(assignment.getLhrIndex(), "\t"); + } + } else { + String newVar = Label.newVar(); + rewriter.insertBefore(index, newVar + "=" + rhr.toPrint("") + "\n"); + rewriter.replace(assignment.getLhrIndex(), assignment.getRhrIndex(), + "\t" + lhr.toPrint("") + "=" + newVar + "\n"); + } + } } /** @@ -304,12 +463,89 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * ``` for_stmt : 'for' exprlist ':' block ('else' ':' block)? ; ``` */ public Node visitFor_stmt(For_stmtContext ctx) { + + // Do the same for the for expression and the block Node exprList = visit(ctx.exprlist()); + int dimGetExpr = ((ExprListNode) exprList).getExprs().size(); + + for (int e = 0; e < dimGetExpr; e++) { + if (e == dimGetExpr - 1) { + R.remove(((ExprNode) ((ExprNode) ((ExprListNode) exprList).getElem(e)).getExpr(0)).getId()); + } else { + R.remove(((ExprNode) ((ExprListNode) exprList).getElem(e)).getId()); + } + } + // Block 1 is for the for-else statement - Node block = visit(ctx.block(0)); + BlockNode block = (BlockNode) visit(ctx.block(0)); + + Node forNode = new ForStmtNode(exprList, block); + if (!optimize) { + return forNode; + } + + int lineStart = ctx.getStart().getLine(); + int lineStop = ctx.getStop().getLine(); + int index = ctx.getStart().getTokenIndex(); - return new ForStmtNode(exprList, block); + rewriter.insertAfter(index, " "); + // NOTE: works only for one argument + rewriter.insertAfter(index + 1, " "); + rewriter.insertAfter(index + 2, " "); + optimizeWithSecond(block, lineStart, lineStop, index); + optimizeWithThird(block, lineStart, lineStop, index); + + return forNode; + } + + private void optimizeWithThird(BlockNode block, int lineStart, int lineStop, int index) { + int counter = 0; + ArrayList<Node> stms = block.getChilds(); + for (var e : stms) { + if (e instanceof SimpleStmtsNode) { + SimpleStmtsNode stmss = (SimpleStmtsNode) e; + for (Node stm : stmss.getStmts()) { + SimpleStmtNode singleStm = (SimpleStmtNode) stm; + AssignmentNode ass = (AssignmentNode) singleStm.getAssignment(); + if (ass != null) { + ExprListNode rhr = ass.getRhr(); + ExprNode rExpr = (ExprNode) rhr.getElem(0); + ArrayList<Node> exprsList = rExpr.getExprs(); + if (exprsList.size() > 1) { + List<Node> exprsLists = exprsList.subList(0, exprsList.size() - 1); + for (var elem : exprsLists) { + if (elem instanceof ExprNode) { + ExprNode exprNode = (ExprNode) elem; + if (exprNode.typeCheck() instanceof AtomType) { + continue; + } + } + ArrayList<String> al = findAtomPresent(elem, new ArrayList<>()); + if (!al.isEmpty()) { + boolean constant = true; + for (String a : al) { + int n = R.get(a); + if (n > lineStart && n <= lineStop) { + constant = false; + break; + } + } + if (constant) { + String newVar = Label.newVar(); + rewriter.insertBefore(index, newVar + "=" + elem.toPrint("") + "\n"); + int firstToken = ass.getLhrIndex() + 2; + int lastToken = ass.getRhrIndex() - 2; + rewriter.replace(firstToken, lastToken, newVar); + } + } + counter++; + } + } + } + } + } + } } /** @@ -418,18 +654,35 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { if (ctx.atom() != null) { atom = visit(ctx.atom()); - } - if (ctx.comp_op() != null) { - compOp = visit(ctx.comp_op()); - } + for (TrailerContext s : ctx.trailer()) { + trailers.add(visit(s)); + } + } else { + if (ctx.ADD(0) != null) { + op = ctx.ADD(0).toString(); - for (ExprContext s : ctx.expr()) { - exprs.add(visit(s)); - } + } else if (ctx.MINUS(0) != null) { + op = ctx.MINUS(0).toString(); - for (TrailerContext s : ctx.trailer()) { - trailers.add(visit(s)); + } else if (ctx.NOT() != null) { + op = ctx.NOT().toString(); + + } else if (ctx.STAR() != null) { + op = ctx.STAR().toString(); + + } else if (ctx.DIV() != null) { + op = ctx.DIV().toString(); + + } + + if (ctx.comp_op() != null) { + compOp = visit(ctx.comp_op()); + } + + for (ExprContext s : ctx.expr()) { + exprs.add(visit(s)); + } } return new ExprNode(atom, compOp, exprs, op, trailers); @@ -461,11 +714,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { } return new AtomNode(varName, null); } else if (ctx.OPEN_BRACE() != null && ctx.CLOSE_BRACE() != null) { - return manageCompListContext(tlc); + return manageCompListContext(tlc, "{", "}"); } else if (ctx.OPEN_BRACK() != null && ctx.CLOSE_BRACK() != null) { - return manageCompListContext(tlc); + return manageCompListContext(tlc, "[", "]"); } else if (ctx.OPEN_PAREN() != null && ctx.CLOSE_PAREN() != null) { - return manageCompListContext(tlc); + return manageCompListContext(tlc, "(", ")"); } return new AtomNode(null, null); } @@ -475,12 +728,12 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { * `testlist_comp` set if the context is not null. Otherwise, returns an * `AtomNode` with nulls. */ - public AtomNode manageCompListContext(Testlist_compContext tlc) { + public AtomNode manageCompListContext(Testlist_compContext tlc, String prefix, String suffix) { if (tlc != null) { Node testlist_comp = visit(tlc); - return new AtomNode(null, testlist_comp); + return new AtomNode(null, testlist_comp, prefix, suffix); } - return new AtomNode(null, null); + return new AtomNode(null, null, prefix, suffix); } /** @@ -491,6 +744,19 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { */ public Node visitTrailer(TrailerContext ctx) { Node arglist = null; + String prefix = ""; + String suffix = ""; + + if (ctx.OPEN_BRACK() != null) { + prefix = "["; + suffix = "]"; + } else if (ctx.OPEN_PAREN() != null) { + prefix = "("; + suffix = ")"; + } else if (ctx.DOT() != null) { + prefix = "."; + } + if (ctx.arglist() != null) { arglist = visit(ctx.arglist()); } @@ -506,11 +772,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { methodCall = ctx.NAME(); } - return new TrailerNode(arglist, exprs, methodCall, ctx.OPEN_PAREN() != null); + return new TrailerNode(arglist, exprs, methodCall, prefix, suffix); } /** - * Returns a `Node`. FIXME: what to do in case of list?? + * Returns a `Node`. * * ``` exprlist : expr (',' expr )* ','? ; ``` */ @@ -577,12 +843,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> { /** * Returns a `CompIterNode`. + * NOTE: We ignore `comp_if`. * * ``` comp_iter : comp_for | comp_if ; ;``` */ public Node visitComp_iter(Comp_iterContext ctx) { - // TODO: Implement comp_if - // Node iter = visit(ctx.comp_if()); Comp_forContext cfc = ctx.comp_for(); Node forNode = null; if (cfc != null) { |