diff --git a/aeneas/src/ir/Ir.v3 b/aeneas/src/ir/Ir.v3 index 5550e6ffa..29b98ae9d 100644 --- a/aeneas/src/ir/Ir.v3 +++ b/aeneas/src/ir/Ir.v3 @@ -25,6 +25,7 @@ class IrClass extends IrItem { var machSize: int = -1; var boxing: Boxing; var packed: bool; + var packingExpr: List; new(ctype, typeArgs, parent, fields, methods) { } def inherits(m: IrMember) -> bool { diff --git a/aeneas/src/ir/Normalization.v3 b/aeneas/src/ir/Normalization.v3 index a00ef32ca..ffa69c3a8 100644 --- a/aeneas/src/ir/Normalization.v3 +++ b/aeneas/src/ir/Normalization.v3 @@ -18,6 +18,7 @@ class NormalizerConfig { var GetScalar: (Compiler, Program, Type) -> Scalar.set = defaultGetScalar; var GetBitWidth: (Compiler, Program, Type) -> byte = defaultGetBitWidth; var MaxScalarWidth: byte = 64; + var UsedScalars: Scalar.set = Scalar.B32 | Scalar.B64 | Scalar.F32 | Scalar.F64 | Scalar.Ref; def setSignatureLimits(maxp: int, maxr: int) { if (maxp < MaxParams) MaxParams = maxp; diff --git a/aeneas/src/ir/PackingSolver.v3 b/aeneas/src/ir/PackingSolver.v3 index d54773f05..9e7df536f 100644 --- a/aeneas/src/ir/PackingSolver.v3 +++ b/aeneas/src/ir/PackingSolver.v3 @@ -29,6 +29,7 @@ def COLOR = false; // An interval from {start}, inclusive, to {end}, exclusive. 
type Interval(start: byte, end: byte) #unboxed { def render(buf: StringBuilder) -> StringBuilder { return buf.put2("%d...%d", start, end); } + def shl(n: byte) -> Interval { return Interval(start + n, end + n); } def size() -> byte { return end - start; } } @@ -62,7 +63,10 @@ class ScalarPattern(bits: Array) { var start = 0; for (end < size) { if (bits[end] != PackingBit.Unassigned) start = end + 1; - else if (end + 1 - start >= s) intervals.put(Interval(byte.view(end + 1 - s), byte.view(end + 1))); + else if (end + 1 - start >= s) { + intervals.put(Interval(byte.view(end + 1 - s), byte.view(end + 1))); + start = end; // don't try all intervals + } } return intervals.extract(); } @@ -81,6 +85,12 @@ class ScalarPattern(bits: Array) { for (j = i.start; j < i.end; j++) bits[j] = PackingBit.Unassigned; } def copy() -> ScalarPattern { return ScalarPattern.new(Arrays.dup(bits)); } + def padTo(n: byte) -> ScalarPattern { return ScalarPattern.new(Arrays.grow(bits, n)); } + def extractFixed() -> u64 { + var result: u64; + for (i < size) if (bits[i] == PackingBit.Fixed(1)) result |= 1ul << byte.view(i); + return result; + } } type PackingField #unboxed { diff --git a/aeneas/src/ir/SsaNormalizer.v3 b/aeneas/src/ir/SsaNormalizer.v3 index afdf40f9e..c1f730c47 100644 --- a/aeneas/src/ir/SsaNormalizer.v3 +++ b/aeneas/src/ir/SsaNormalizer.v3 @@ -1384,11 +1384,19 @@ class SsaRaNormalizer extends SsaRebuilder { var result = Array.new(vn.size); for (i < result.length) result[i] = newGraph.nullConst(vn.at(i)); - + for (i < vn.pattern.scalars.length) { + if (IntRepType.?(vn.at(i))) { + var irt = IntRepType.!(vn.at(i)); + var fixed = vn.pattern.scalars[i].extractFixed(); + if (irt.width == 64) result[i] = newGraph.valConst(irt, irt.boxL(long.view(fixed))); + else result[i] = newGraph.valConst(irt, irt.box(int.view(fixed))); + } + } if (vn.hasExplicitTag()) { var tagIdx = vn.tag.indexes[0]; - if (IntRepType.?(vn.at(tagIdx))) result[tagIdx] = genSetInterval(result[tagIdx], 
newGraph.intConst(vn.tagValue), vn.tag.intervals[0], vn.tag.tn.newType, IntRepType.!(vn.at(tagIdx))); - else result[tagIdx] = newGraph.intConst(vn.tagValue); + if (!IntRepType.?(vn.at(tagIdx))) { + result[tagIdx] = newGraph.intConst(vn.tagValue); // explicit tag scalar (not packed) + } } for (i < vn.fields.length) { diff --git a/aeneas/src/ir/VariantSolver.v3 b/aeneas/src/ir/VariantSolver.v3 index eaa4d481f..f66e1e5e3 100644 --- a/aeneas/src/ir/VariantSolver.v3 +++ b/aeneas/src/ir/VariantSolver.v3 @@ -29,6 +29,9 @@ class VariantPattern(scalars: Array) { } return buf; } + def copy() -> VariantPattern { + return VariantPattern.new(Arrays.map(scalars, ScalarPattern.copy)); + } } // The normalization of a variant, including the mapping of its fields to bit ranges within scalars. @@ -41,6 +44,8 @@ class VariantNorm extends TypeNorm { var tagValue: int = -1; var children: List; + var pattern: VariantPattern; + new(oldType: Type, newType: Type, sub: Array, fields, tag) super(oldType, newType, sub) {} @@ -77,7 +82,7 @@ class VariantNorm extends TypeNorm { // Metadata about a variant's fields before/after normalization. // The field {indexes} indicates which scalars store the field. 
-class VariantField(tn: TypeNorm, indexes: Array) { +class VariantField(rf: RaField, tn: TypeNorm, indexes: Array) { var intervals: Array; // null, if no packing def isPacked() -> bool { return intervals != null; } @@ -95,7 +100,7 @@ class VariantField(tn: TypeNorm, indexes: Array) { } def ON_STACK = -1; -def EMPTY_FIELD = VariantField.new(null, []); +def EMPTY_FIELD = VariantField.new(null, null, []); def NO_FIELDS = Array.new(0); def copyMap(x: HashMap) -> HashMap { @@ -108,10 +113,250 @@ def copyMap(x: HashMap) -> HashMap { return newMap; } -class VariantSolution { +type FlattenedExpr(assignment: Array<(int, Interval)>, pattern: ScalarPattern, size: byte) #unboxed { + def render(buf: StringBuilder) -> StringBuilder { + buf.put1("%q[", pattern.render); + for (i < assignment.length) { + if (i > 0) buf.csp(); + buf.put2("%d@%q", assignment[i].0, assignment[i].1.render); + } + buf.puts("]"); + return buf; + } +} + +class PackingContext { + def vars = Strings.newMap(); + def keyList = Vector.new(); + + def addVar(name: string, fieldIdx: int) { + if (!vars.has(name)) keyList.put(name); + vars[name] = fieldIdx; + } + def getVar(name: string) -> int { return vars[name]; } + def getVarFromChar(c: byte) -> int { + for (i < keyList.length) if (keyList[i][0] == c) return vars[keyList[i]]; + return -1; + } +} + +class VariantProblemGen(vn: VariantNormalizer) { var normFields: Array>; + var assignments = HashMap.new(CaseField.hash, CaseField.==); + var state: Array; + var hasPackingExpr: bool; + + def generateProblemFromRc(rc: RaClass) -> VariantProblem { + var shouldPack = rc.orig.packed; + var numChildren = Lists.length(rc.children); + normFields = Array>.new(if(rc.children == null, 1, numChildren)); + state = Array.new(normFields.length); + + if (rc.children == null) { + normFields[0] = getNormFieldsForCase(rc); + processAnnotationsForCase(rc, 0); + } else { + var caseIdx = 0; + for (l = rc.children; l != null; l = l.tail) { + normFields[caseIdx] = 
getNormFieldsForCase(l.head); + processAnnotationsForCase(l.head, caseIdx); + caseIdx++; + } + } + var problem = VariantProblem.new(normFields, assignments, shouldPack); + if (hasPackingExpr) problem.state = state; + return problem; + } + private def processAnnotationsForCase(rc: RaClass, caseIdx: int) { + if (!rc.orig.packed) return; + var packingExpr = rc.orig.packingExpr; + if (packingExpr != null) hasPackingExpr = true; + + var patterns = Vector.new(); + var scalarIdx = 0; + var context = PackingContext.new(); + + for (i < rc.fields.length) { + var rf = rc.fields[i]; + if (rf == null) continue; + context.addVar(rf.orig.source.name(), rf.normIndices.0); + } + for (l = packingExpr; l != null; l = l.tail) { + var expr = processPackingExpr(context, caseIdx, l.head); + for (i < expr.assignment.length) { + assignments[CaseField(caseIdx, expr.assignment[i].0)] = (scalarIdx, expr.assignment[i].1); + } + patterns.put(expr.pattern); + scalarIdx++; + } + state[caseIdx] = VariantPattern.new(patterns.extract()); + } + private def processPackingExpr(context: PackingContext, caseIdx: int, p: VstPackingExpr) -> FlattenedExpr { + match (p) { + Field(f) => { + var fieldIdx = context.getVar(f.ident.name.image); + var size = vn.getBitWidth(normFields[caseIdx][fieldIdx]); + var pattern = ScalarPattern.new(Arrays.growV(Array.new(0), size, PackingBit.Assigned(0))); + return FlattenedExpr([(fieldIdx, Interval(0, size))], pattern, size); + } + Literal(l) => { // write into pattern + match (l) { + x: IntLiteral => { + var t = IntType.!(x.exactType); + var val = Long.unboxSU(x.val, false); + var bits = Array.new(t.width); + for (i < t.width) bits[i] = PackingBit.Fixed(u1.view(val >> i & 1)); + return FlattenedExpr([], ScalarPattern.new(bits), t.width); // todo + } + } + } + Concat(l) => { + var combined = FlattenedExpr([], ScalarPattern.new([]), 0); + for (l = l.list; l != null; l = l.tail) { + var flattened = processPackingExpr(context, caseIdx, l.head); + var shifted = Array<(int, 
Interval)>.new(flattened.assignment.length); + for (i < shifted.length) shifted[i] = (flattened.assignment[i].0, flattened.assignment[i].1.shl(combined.size)); + combined = FlattenedExpr( + Arrays.concat(combined.assignment, shifted), + ScalarPattern.new(Arrays.concat(combined.pattern.bits, flattened.pattern.bits)), + combined.size + flattened.size + ); + } + return combined; + } + Bits(t, rep) => { // process as intervals + var curBits = Array.new(rep.length); + var assignments = Vector<(int, Interval)>.new(); + var curByte: byte = 0, oldByte: byte = 0; + var curStart: byte = 0, oldStart: byte = 0; + for (i < rep.length) { + oldByte = curByte; + oldStart = curStart; + match (rep[i].1) { + Assigned(ch) => { + if (ch != curByte) { + curByte = ch; + curStart = byte.view(i); + } + curBits[i] = PackingBit.Assigned(0); + } + Fixed(v), Unassigned => { + curByte = 0; + curBits[i] = rep[i].1; + } + } + if (oldByte != curByte && oldByte != 0) assignments.put((context.getVarFromChar(oldByte), Interval(oldStart, byte.view(i)))); + } + if (curByte != 0) assignments.put((context.getVarFromChar(curByte), Interval(curStart, byte.view(rep.length)))); + return FlattenedExpr(assignments.extract(), ScalarPattern.new(curBits), byte.view(curBits.length)); + } + App(p, args) => { + var decl = vn.rn.ra.prog.packings[p.ident.name.image]; + var argExprs = Lists.map(args.list, processPackingExpr(context, caseIdx, _)); + + var appContext = PackingContext.new(); + var argIdx = 0; + for (l = decl.pparams.list; l != null; l = l.tail) { + var pparam = l.head; + appContext.addVar(pparam.token.image, argIdx++); + } + var flattened = processPackingExpr(appContext, 0, decl.expr); + var newAssignments = Vector<(int, Interval)>.new(); + var newPattern = flattened.pattern.copy(); + + argIdx = 0; + for (l = argExprs; l != null; l = l.tail) { + var argExp = l.head; + for (i < flattened.assignment.length) { + if (flattened.assignment[i].0 == argIdx) { + var start = flattened.assignment[i].1.start; + for 
(assn in argExp.assignment) newAssignments.put((assn.0, assn.1.shl(start))); + for (j < argExp.size) newPattern.bits[start + j] = argExp.pattern.bits[j]; + break; + } + } + argIdx++; + } + return FlattenedExpr(newAssignments.extract(), newPattern, flattened.size); + } + Solve => ; // unimplemented + } + return FlattenedExpr([], null, 0); + } + private def getNormFieldsForCase(rc: RaClass) -> Array { + var rfs = rc.fields; + var fields = Array.new(rfs.length); + + rc.fieldRangesO = Array<(int, int)>.new(rfs.length); + rc.fieldRangesT = Array<(int, int)>.new(rfs.length); + + var fieldTypes = Vector.new(); + var vecO = Vector.new(); + + for (i < rfs.length) { + var rf = rc.fields[i]; + var startT = fieldTypes.length; + + if (rf != null && rf.normIndices.0 >= 0) { + var tn = vn.fieldNorm(rf); + tn.addTo(fieldTypes); + fields[i] = VariantField.new(rf, tn, Array.new(tn.size)); + } else fields[i] = EMPTY_FIELD; + + var origStart = vecO.length; + var fieldType = rc.orig.fields[i].fieldType.substitute(V3.getTypeArgs(rc.oldType)); + vn.rn.norm(fieldType).addTo(vecO); + rc.fieldRangesO[i] = (origStart, vecO.length); + rc.fieldRangesT[i] = (startT, fieldTypes.length); + } + + rc.variantFields = fields; + rc.vecO = vecO.extract(); + return fieldTypes.extract(); + } +} + +// an input to the VariantSolver +class VariantProblem { + var normFields: Array>; + var state: Array; var assignments: HashMap; + var types: Array; + var usePacking: bool; + + new(normFields, assignments, usePacking) {} + + def render(buf: StringBuilder) -> StringBuilder { + for (i < normFields.length) { + if (i > 0) buf.csp(); + buf.put1("c%d: [", i); + for (j < normFields[i].length) { + if (j > 0) buf.csp(); + var cf = CaseField(i, j); + var field = normFields[i][j]; + buf.put2("f%d=%q: ", j, field.render); + + if (!assignments.has(cf)) { + buf.puts("?"); + } else { + var a = assignments[cf], scalarIdx = a.0, interval = a.1; + if (interval == EMPTY_INTERVAL) buf.put1("#%d@?", scalarIdx); + else 
buf.put2("#%d@%q", scalarIdx, interval.render); + } + } + buf.puts("]"); + } + buf.ln(); + if (state != null) for (i < state.length) buf.put1("%q\n", state[i].render); + return buf; + } +} + +class VariantSolution { + var normFields: Array>; var state: Array; + var assignments: HashMap; + var explicitTag: (int, Interval); var types: Array; @@ -119,16 +364,22 @@ class VariantSolution { var hasTagScalar = false; var tagType: IntType; + var cachedScore = -1; + new(normFields, assignments, state, explicitTag, types) {} def copy() -> VariantSolution { - return VariantSolution.new( + var vs = VariantSolution.new( normFields, - copyMap(assignments), - Arrays.dup(state), + copyMap(assignments), + Arrays.map(state, VariantPattern.copy), explicitTag, Arrays.dup(types) ); + vs.tagType = tagType; + vs.hasIntervals = hasIntervals; + vs.hasTagScalar = hasTagScalar; + return vs; } def render(buf: StringBuilder) -> StringBuilder { for (i < normFields.length) { @@ -151,31 +402,43 @@ class VariantSolution { buf.puts("]"); } buf.ln(); - for (i < state.length) { - buf.put1("%q\n", state[i].render); - } + if (state != null) for (i < state.length) buf.put1("%q\n", state[i].render); return buf; } - def score() -> int { - // TODO - return 0; + if (cachedScore >= 0) return cachedScore; + // XXX: better score function + var total = types.length * 50; + total += if(!hasTagScalar, 100); + // estimate access cost + if (hasIntervals) { + for (i < normFields.length) { + for (j < normFields[i].length) { + var assn = assignments[CaseField(i, j)]; + if (assn.1 != EMPTY_INTERVAL && assn.1.start > 0) total += 10; + } + } + } + return cachedScore = total; } } - def getTagLength(numCases: int) -> byte { var i: byte = 0; while (1 << i < numCases) i++; return i; } -class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { +class VariantSolver(vnorm: VariantNormalizer) { var normFields: Array>; + var usePacking: bool; var fields = Vector.new(); // an ordering of fields to solve for during packing var curSoln: 
VariantSolution; + var bestSoln: VariantSolution; + var assignments = HashMap.new(CaseField.hash, CaseField.==); + var baseState: Array; var state: Array; var explicitTag: (int, Interval) = (-1, EMPTY_INTERVAL); @@ -189,23 +452,54 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { if (bwCache.has(t)) return bwCache[t]; return bwCache[t] = vnorm.getBitWidth(t); } + private def getScalarsWithSize(s: Scalar.set, n: byte) -> Scalar.set { + var newRep = s; + for (i in s) if (i.size < n) newRep -= i; + return newRep; + } // Perform recursive backtracking on the potential representation. // With each scalar in each case, we have several choices to make: // 1. Pack this scalar with an existing scalar // 2. Append this scalar as a new scalar. // We can only determine distinguishability after all the scalars have been assigned. Heuristics will help speed this part up. - def solve(cases: Array>) -> VariantSolution { - normFields = cases; - explicitTagLength = getTagLength(cases.length); + def solve(problem: VariantProblem) -> VariantSolution { + normFields = problem.normFields; + baseState = problem.state; + assignments = problem.assignments; + + explicitTagLength = getTagLength(problem.normFields.length); + usePacking = problem.usePacking; + + if (problem.state != null) { + var numScalars = problem.state[0].scalars.length; + curRep.putn(vnorm.nc.UsedScalars, numScalars); + for (i < numScalars) curUsed.put(0); + } - // TODO: Express solving constraints here (from annotations) for (i < normFields.length) { - for (j < normFields[i].length) fields.put(CaseField(i, j)); + for (j < normFields[i].length) { + if (!assignments.has(CaseField(i, j))) { + fields.put(CaseField(i, j)); + } else { + var assn = assignments[CaseField(i, j)]; + var sc = vnorm.getScalar(normFields[i][j]); + curRep[assn.0] &= sc; + } + } } - var solvable = tryRepresentationForField(0, 0); - return if(solvable, curSoln, null); + if (baseState != null) { + for (i < baseState.length) { + for (j < 
curRep.length) { + curRep[j] = getScalarsWithSize(curRep[j], byte.view(baseState[i].scalars[j].bits.length)); + var scalar = getScalarFromSet(curRep[j]); + baseState[i].scalars[j] = baseState[i].scalars[j].padTo(scalar.size); + } + } + } + tryRepresentationForField(0, 0); + return bestSoln; } - private def tryRepresentationForField(curCase: int, curField: int) -> bool { + private def tryRepresentationForField(curCase: int, curField: int) { if (curCase >= normFields.length) { // XXX: check distinguishable return solvePacking(); @@ -216,6 +510,8 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { for (i < curRep.length) curUsed.put(0); return tryRepresentationForField(curCase + 1, 0); } + var cf = CaseField(curCase, curField); + if (assignments.has(cf) && assignments[cf].0 != -1) return tryRepresentationForField(curCase, curField + 1); var sc = vnorm.getScalar(normFields[curCase][curField]); var bw = getBitWidth(normFields[curCase][curField]); @@ -227,23 +523,23 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { var oldRep = curRep[i]; var newRep = sc & oldRep; - for (j in sc & oldRep) if (usePacking && j.size < curUsed[i] + bw) newRep -= j; + if (usePacking) newRep = getScalarsWithSize(newRep, curUsed[i] + bw); if (newRep != none) { curUsed[i] += bw; curRep[i] = newRep; - assignments[CaseField(curCase, curField)] = (i, EMPTY_INTERVAL); - if (tryRepresentationForField(curCase, curField + 1)) return true; + assignments[cf] = (i, EMPTY_INTERVAL); + tryRepresentationForField(curCase, curField + 1); + assignments[cf] = (-1, EMPTY_INTERVAL); curUsed[i] -= bw; curRep[i] = oldRep; } } var len = curRep.length; curRep.put(sc); curUsed.put(bw); - assignments[CaseField(curCase, curField)] = (curRep.length - 1, EMPTY_INTERVAL); - if (tryRepresentationForField(curCase, curField + 1)) return true; + assignments[cf] = (curRep.length - 1, EMPTY_INTERVAL); + tryRepresentationForField(curCase, curField + 1); curRep.resize(len); curUsed.resize(len); - 
- return false; + assignments[cf] = (-1, EMPTY_INTERVAL); } private def tryExplicitTaggingHeuristic() -> bool { for (i < state[0].scalars.length) { @@ -267,8 +563,9 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { if (explicitTagLength > 0 && longest.size() >= explicitTagLength) { for (j < state.length) { + var realTag = normFields.length - j - 1; // variants are in reverse order for (k < explicitTagLength) { - var bit = u1.!((j >> k) & 1); + var bit = u1.!((realTag >> k) & 1); state[j].scalars[i].bits[longest.start + k] = PackingBit.Fixed(bit); } } @@ -278,19 +575,21 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { } return false; } - private def solvePacking() -> bool { + private def solvePacking() { var cases = Array.new(normFields.length); - // build up patterns from the individual cases - for (i < normFields.length) { - var patterns = Array.new(curRep.length); - for (j < curRep.length) patterns[j] = getScalarPattern(getScalarFromSet(curRep[j])); - cases[i] = VariantPattern.new(patterns); + state = baseState; + if (state == null) { + // build up patterns from the individual cases + for (i < normFields.length) { + var patterns = Array.new(curRep.length); + for (j < curRep.length) patterns[j] = getScalarPattern(getScalarFromSet(curRep[j])); + cases[i] = VariantPattern.new(patterns); + } + state = cases; } - state = cases; - - return solveField(0); + solveField(0); } - private def solveField(idx: int) -> bool { + private def solveField(idx: int) { if (idx >= fields.length) { return checkDistinguishable(); } @@ -303,15 +602,14 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { var scalar = casePattern.scalars[scalarIdx]; var intervals = scalar.getIntervalsForSize(bw); - if (intervals.length == 0) return false; + if (intervals.length == 0) return; for (interval in intervals) { scalar.assignInterval(interval); assignments[field] = (scalarIdx, interval); var result = solveField(idx + 1); - if (result) return 
true; + assignments[field] = (scalarIdx, EMPTY_INTERVAL); scalar.unassignInterval(interval); } - return false; } private def getScalarPattern(s: Scalar) -> ScalarPattern { // XXX: Should be based on scalar type @@ -320,12 +618,19 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { private def canDistinguish(active: Array) { // TODO - this is where unassigned bits will get assigned } - private def checkDistinguishable() -> bool { + private def updateSoln(soln: VariantSolution) { + if (bestSoln == null || bestSoln.score() > soln.score()) { + bestSoln = soln.copy(); + // Terminal.put2("updating with soln[score=%d] %q\n", bestSoln.score(), bestSoln.render); + } + } + private def checkDistinguishable() { if (normFields.length <= 1) { var types = getTypes(false); curSoln = VariantSolution.new(normFields, assignments, state, explicitTag, types); curSoln.hasIntervals = usePacking; - return true; // single case is always distinguishable + updateSoln(curSoln); + return; // single case is always distinguishable } var packedTag: bool; @@ -336,9 +641,19 @@ class VariantSolver(vnorm: VariantNormalizer, usePacking: bool) { curSoln.tagType = Int.getType(false, getTagLength(normFields.length)); curSoln.hasTagScalar = !packedTag; curSoln.hasIntervals = usePacking; + updateSoln(curSoln); + + if (packedTag) { + // undo assignments + for (i < state.length) { + for (j < explicitTagLength) { + state[i].scalars[explicitTag.0].bits[explicitTag.1.start + j] = PackingBit.Unassigned; + } + } + explicitTag = (-1, EMPTY_INTERVAL); + } // TODO: try difficult tagging - return true; } private def getTypes(useTagScalar: bool) -> Array { var types = Array.new(curRep.length + if(useTagScalar, 1)); @@ -423,7 +738,7 @@ class VariantNormalizer(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbos // normalize empty variant to just its tag; i.e. 
become an enum var tagType = V3.getVariantTagType(rc.oldType); var tagTypeNorm = rn.norm(tagType); - var tagField = VariantField.new(tagTypeNorm, [0]); + var tagField = VariantField.new(null, tagTypeNorm, [0]); unboxUsingEnumVariantNorm(rc, tagType, tagField); return true; } @@ -445,11 +760,10 @@ class VariantNormalizer(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbos return true; } private def unboxUsingTaglessVariantNorm(rc: RaClass) { - var normFields = getNormFieldsForCase(rc); - var shouldPack = rc.orig.packed; - - var solver = VariantSolver.new(this, shouldPack); - var solution = solver.solve([normFields]); + var solver = VariantSolver.new(this); + var pgen = VariantProblemGen.new(this); + var problem = pgen.generateProblemFromRc(rc); + var solution = solver.solve(problem); var vn = VariantNorm.new(rc.oldType, Tuple.newType(Lists.fromArray(solution.types)), solution.types, rc.variantFields, null); @@ -457,40 +771,10 @@ class VariantNormalizer(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbos vn.vecO = rc.vecO; rc.variantNorm = vn; setVariantFields(rc, 0, solution); + vn.pattern = solution.state[0]; printNorm(rc, vn); } - private def getNormFieldsForCase(rc: RaClass) -> Array { - var rfs = rc.fields; - var fields = Array.new(rfs.length); - - rc.fieldRangesO = Array<(int, int)>.new(rfs.length); - rc.fieldRangesT = Array<(int, int)>.new(rfs.length); - - var fieldTypes = Vector.new(); - var vecO = Vector.new(); - - for (i < rfs.length) { - var rf = rc.fields[i]; - var startT = fieldTypes.length; - - if (rf != null && rf.normIndices.0 >= 0) { - var tn = fieldNorm(rf); - tn.addTo(fieldTypes); - fields[i] = VariantField.new(tn, Array.new(tn.size)); - } else fields[i] = EMPTY_FIELD; - - var origStart = vecO.length; - var fieldType = rc.orig.fields[i].fieldType.substitute(V3.getTypeArgs(rc.oldType)); - rn.norm(fieldType).addTo(vecO); - rc.fieldRangesO[i] = (origStart, vecO.length); - rc.fieldRangesT[i] = (startT, fieldTypes.length); - } - - 
rc.variantFields = fields; - rc.vecO = vecO.extract(); - return fieldTypes.extract(); - } private def setVariantFields(rc: RaClass, caseIdx: int, soln: VariantSolution) { var ofs = rc.orig.fields; for (i < ofs.length) { @@ -511,25 +795,22 @@ class VariantNormalizer(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbos var parentUnboxed = rc.orig.boxing == Boxing.UNBOXED; var numChildren = Lists.length(rc.children); - var normFields = Array>.new(numChildren); - - var caseIdx = 0; - for (l = rc.children; l != null; l = l.tail) normFields[caseIdx++] = getNormFieldsForCase(l.head); - if (rc.recursive > 1) return false; // now that we know it's not recursive, we can safely assign the variant norm - var shouldPack = rc.orig.packed; - var needsTagScalar = !shouldPack; - - var solver = VariantSolver.new(this, shouldPack); - var solution = solver.solve(normFields); + var solver = VariantSolver.new(this); + var pgen = VariantProblemGen.new(this); + var problem = pgen.generateProblemFromRc(rc); + var solution = solver.solve(problem); + if (solution == null) return false; // XXX: warn that no solution was found + if (verbose) Terminal.put1("solution=%q\n", solution.render); + var tagField: VariantField; - if (solution.hasTagScalar) tagField = VariantField.new(rn.norm(solution.tagType), [solution.types.length - 1]); + if (solution.hasTagScalar) tagField = VariantField.new(null, rn.norm(solution.tagType), [solution.types.length - 1]); else if (solution.explicitTag.0 >= 0) { - tagField = VariantField.new(rn.norm(solution.tagType), [solution.explicitTag.0]); + tagField = VariantField.new(null, rn.norm(solution.tagType), [solution.explicitTag.0]); tagField.intervals = [solution.explicitTag.1]; } var newType = Tuple.newType(Lists.fromArray(solution.types)); @@ -539,19 +820,21 @@ class VariantNormalizer(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbos rc.variantNorm = parentNorm; var children = rc.children; - caseIdx = 0; + var caseIdx = 0; for (l = rc.children; l != null; l = l.tail) 
{ var child = l.head; - setVariantFields(child, caseIdx++, solution); + setVariantFields(child, caseIdx, solution); var vn = VariantNorm.new(child.oldType, newType, solution.types, child.variantFields, tagField); vn.fieldRanges = child.fieldRangesO; vn.vecO = child.vecO; vn.tagValue = V3.getVariantTag(child.oldType); + vn.pattern = solution.state[caseIdx]; child.variantNorm = vn; parentNorm.children = List.new(vn, parentNorm.children); if (verbose) Terminal.put1(" %q\n", vn.render); + caseIdx++; } printNorm(rc, parentNorm); return true; @@ -575,19 +858,23 @@ class VariantNormalizer(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbos def getBitWidth(t: Type) -> byte { return nc.GetBitWidth(rn.ra.compiler, rn.ra.prog, t); } + private def printField(field: VariantField, idx: int, tag: bool) { + if (tag) buf.puts("tag:"); + else buf.put2("%s.%d:", field.rf.orig.source.name(), idx); + buf.green().put1("%q", field.tn.at(idx).render).end().puts("->").cyan().put1("#%d", field.indexes[idx]); + if (field.intervals != null && field.intervals[idx] != EMPTY_INTERVAL) buf.put1("@%q", field.intervals[idx].render); + buf.end().putc(' '); + } private def printSingleNorm(rc: RaClass, vn: VariantNorm) { buf.typeColor().put1(" %q ", rc.oldType.render).end(); for (field in vn.fields) { if (field == EMPTY_FIELD) continue; - for (j < field.tn.size) { - buf.put2("%q.%d:", field.tn.oldType.render, j).green().put1("%q", field.tn.at(j).render).end() - .puts("->").cyan().put1("#%d", field.indexes[j]); - if (field.intervals != null && field.intervals[j] != EMPTY_INTERVAL) buf.put1("@%q", field.intervals[j].render); - buf.end().putc(' '); - } + for (j < field.tn.size) printField(field, j, false); } + if (vn.tag != null) printField(vn.tag, 0, true); buf.outln(); + if (vn.pattern != null) buf.purple().put1(" %q\n", vn.pattern.render).end(); buf.outt(); } private def printNorm(rc: RaClass, vn: VariantNorm) { diff --git a/aeneas/src/ir/VstIr.v3 b/aeneas/src/ir/VstIr.v3 index 
e9dd5a0e3..6ea81ef37 100644 --- a/aeneas/src/ir/VstIr.v3 +++ b/aeneas/src/ir/VstIr.v3 @@ -45,7 +45,7 @@ class IrBuilder(ctype: Type, parent: IrClass) { def buildClass(decl: VstCompound) -> IrClass { fields.grow(decl.numFields); methods.grow(decl.numMethods + 1); - var boxing = Boxing.AUTO, isVariant = false, packed = false; + var boxing = Boxing.AUTO, isVariant = false, packed = false, packingExpr: List; match (decl) { cdecl: VstClass => { isVariant = cdecl.isVariant(); @@ -55,6 +55,10 @@ class IrBuilder(ctype: Type, parent: IrClass) { Boxed => boxing = Boxing.BOXED; Unboxed => boxing = Boxing.UNBOXED; Packed => packed = true; + Packing(p) => { + packed = true; + packingExpr = p.list; + } _ => ; } } @@ -70,6 +74,7 @@ class IrBuilder(ctype: Type, parent: IrClass) { var ic = build(); ic.boxing = boxing; ic.packed = packed; + ic.packingExpr = packingExpr; return ic; } def addVstField(f: VstField, isVariant: bool) { diff --git a/aeneas/src/jvm/JvmTarget.v3 b/aeneas/src/jvm/JvmTarget.v3 index 4e45698f3..802ec51d1 100644 --- a/aeneas/src/jvm/JvmTarget.v3 +++ b/aeneas/src/jvm/JvmTarget.v3 @@ -36,6 +36,7 @@ class JvmTarget extends Target { norm.NormalizeRange = false; // norm.setSignatureLimits(10000, 1); norm.GetScalar = getScalar; + norm.UsedScalars = Scalar.B32 | Scalar.B64 | Scalar.F32 | Scalar.F64 | Scalar.Ref; } private def isRefType(t: Type) -> bool { match (t) { diff --git a/aeneas/src/os/Linux.v3 b/aeneas/src/os/Linux.v3 index de76a19ca..04920baa6 100644 --- a/aeneas/src/os/Linux.v3 +++ b/aeneas/src/os/Linux.v3 @@ -26,6 +26,9 @@ class LinuxTarget extends Target { } compiler.NormConfig.GetScalar = getScalar; compiler.NormConfig.GetBitWidth = getBitWidth; + + if (space.addressWidth == 32) compiler.NormConfig.UsedScalars = Scalar.B32 | Scalar.B64 | Scalar.F32 | Scalar.F64 | Scalar.R32; + else compiler.NormConfig.UsedScalars = Scalar.B64 | Scalar.F64 | Scalar.R64; } private def getScalar(compiler: Compiler, prog: Program, t: Type) -> Scalar.set { if 
(space.addressWidth == 32) { diff --git a/aeneas/src/vst/Verifier.v3 b/aeneas/src/vst/Verifier.v3 index 5aef44116..77f7a5378 100644 --- a/aeneas/src/vst/Verifier.v3 +++ b/aeneas/src/vst/Verifier.v3 @@ -899,7 +899,7 @@ class VstCompoundVerifier { var total: byte = 0; for (e = l.list; e != null; e = e.tail) { var s = inferPackingExpr(env, e.head); - if (total + s >= MAX_PACKING_WIDTH) { + if (total + s > MAX_PACKING_WIDTH) { total = MAX_PACKING_WIDTH; errAtRange(l.range()).PackingExprTooLong(MAX_PACKING_WIDTH, total + s); } else { diff --git a/aeneas/src/wasm/WasmTarget.v3 b/aeneas/src/wasm/WasmTarget.v3 index 3befb891f..dca0b6593 100644 --- a/aeneas/src/wasm/WasmTarget.v3 +++ b/aeneas/src/wasm/WasmTarget.v3 @@ -90,6 +90,7 @@ class WasmTarget extends Target { compiler.Reachability = true; compiler.NormConfig.setSignatureLimits(10000, if(CLOptions.WASM_MULTI_VALUE.val, 1000, 1)); compiler.NormConfig.GetScalar = getScalar; + compiler.NormConfig.UsedScalars = Scalar.B32 | Scalar.B64 | Scalar.F32 | Scalar.F64 | Scalar.R64; } private def getScalar(compiler: Compiler, prog: Program, t: Type) -> Scalar.set { var none: Scalar.set; diff --git a/aeneas/src/x86-64/X86_64Darwin.v3 b/aeneas/src/x86-64/X86_64Darwin.v3 index a97560fa1..00164b882 100644 --- a/aeneas/src/x86-64/X86_64Darwin.v3 +++ b/aeneas/src/x86-64/X86_64Darwin.v3 @@ -34,6 +34,7 @@ class X86_64DarwinTarget extends Target { def configureCompiler(compiler: Compiler) { compiler.Reachability = true; compiler.NormConfig.GetScalar = getScalar; + compiler.NormConfig.UsedScalars = Scalar.B64 | Scalar.F64 | Scalar.R64; } private def getScalar(compiler: Compiler, prog: Program, t: Type) -> Scalar.set { match (t) { diff --git a/aeneas/src/x86/X86Darwin.v3 b/aeneas/src/x86/X86Darwin.v3 index f950396a9..2136930e9 100644 --- a/aeneas/src/x86/X86Darwin.v3 +++ b/aeneas/src/x86/X86Darwin.v3 @@ -20,6 +20,7 @@ class X86DarwinTarget extends Target { compiler.Reachability = true; compiler.NormConfig.GetScalar = getScalar; 
compiler.NormConfig.GetBitWidth = getBitWidth; + compiler.NormConfig.UsedScalars = Scalar.B32 | Scalar.B64 | Scalar.F32 | Scalar.F64 | Scalar.R32; } private def getScalar(compiler: Compiler, prog: Program, t: Type) -> Scalar.set { match (t) { diff --git a/test/variants/ub_packannot00.v3 b/test/variants/ub_packannot00.v3 new file mode 100644 index 000000000..426cd5303 --- /dev/null +++ b/test/variants/ub_packannot00.v3 @@ -0,0 +1,18 @@ +//@execute 0=12;1=34;2=112;3=156 +type T #packed #unboxed { + case A(x: int) #packing x { def f() -> int { return x; } } + case B(x: int) #packing x { def f() -> int { return x * 2; } } + + def f() -> int; +} + +def arr = [ + T.A(12), + T.A(34), + T.B(56), + T.B(78) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot01.v3 b/test/variants/ub_packannot01.v3 new file mode 100644 index 000000000..bc8eea212 --- /dev/null +++ b/test/variants/ub_packannot01.v3 @@ -0,0 +1,18 @@ +//@execute 0=3;1=7;2=-1;3=-1 +type T #packed #unboxed { + case A(x: int, y: int) #packing #concat(x, y) { def f() -> int { return x + y; } } + case B(x: int, y: int) #packing #concat(y, x) { def f() -> int { return x - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2), + T.A(3, 4), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot02.v3 b/test/variants/ub_packannot02.v3 new file mode 100644 index 000000000..8bd6abd49 --- /dev/null +++ b/test/variants/ub_packannot02.v3 @@ -0,0 +1,18 @@ +//@execute 0=3;1=7;2=-1;3=-1 +type T #packed #unboxed { + case A(x: int, y: int) #packing (x, y) { def f() -> int { return x + y; } } + case B(x: int, y: int) #packing (y, x) { def f() -> int { return x - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2), + T.A(3, 4), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git 
a/test/variants/ub_packannot03.v3 b/test/variants/ub_packannot03.v3 new file mode 100644 index 000000000..26bd1d6d1 --- /dev/null +++ b/test/variants/ub_packannot03.v3 @@ -0,0 +1,18 @@ +//@execute 0=6;1=12;2=-1;3=-1 +type T #packed #unboxed { + case A(x: int, y: int, z: int) #packing (x, #concat(y, z)) { def f() -> int { return x + y + z; } } + case B(x: u64, y: int) #packing (x, y) { def f() -> int { return int.view(x) - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2, 3), + T.A(3, 4, 5), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot04.v3 b/test/variants/ub_packannot04.v3 new file mode 100644 index 000000000..db074c86f --- /dev/null +++ b/test/variants/ub_packannot04.v3 @@ -0,0 +1,18 @@ +//@execute 0=6;1=12;2=-1;3=-1 +type T #packed #unboxed { + case A(x: int, y: int, z: int) #packing (x, #concat(y, z)) { def f() -> int { return x + y + z; } } + case B(x: u64, y: int) #packing (y, x) { def f() -> int { return int.view(x) - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2, 3), + T.A(3, 4, 5), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot05.v3 b/test/variants/ub_packannot05.v3 new file mode 100644 index 000000000..770e45af7 --- /dev/null +++ b/test/variants/ub_packannot05.v3 @@ -0,0 +1,20 @@ +//@execute 0=12;1=34;2=112;3=156 +packing P(x: 32): 32 = x; + +type T #packed #unboxed { + case A(x: int) #packing P(x) { def f() -> int { return x; } } + case B(x: int) #packing P(x) { def f() -> int { return x * 2; } } + + def f() -> int; +} + +def arr = [ + T.A(12), + T.A(34), + T.B(56), + T.B(78) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot06.v3 b/test/variants/ub_packannot06.v3 new file mode 100644 index 000000000..f56e5b76f --- /dev/null +++ b/test/variants/ub_packannot06.v3 
@@ -0,0 +1,20 @@ +//@execute 0=3;1=7;2=-1;3=-1 +packing P(x: 32, y: 32): 64 = #concat(x, y); + +type T #packed #unboxed { + case A(x: int, y: int) #packing P(x, y) { def f() -> int { return x + y; } } + case B(x: int, y: int) #packing P(y, x) { def f() -> int { return x - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2), + T.A(3, 4), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot07.v3 b/test/variants/ub_packannot07.v3 new file mode 100644 index 000000000..31b390c93 --- /dev/null +++ b/test/variants/ub_packannot07.v3 @@ -0,0 +1,21 @@ +//@execute 0=3;1=7;2=-1;3=-1 +packing Q(x: 32): 32 = x; +packing P(x: 32, y: 32): 64 = #concat(x, y); + +type T #packed #unboxed { + case A(x: int, y: int) #packing P(Q(x), Q(y)) { def f() -> int { return x + y; } } + case B(x: int, y: int) #packing P(Q(y), x) { def f() -> int { return x - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2), + T.A(3, 4), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packing09.v3 b/test/variants/ub_packannot09.v3 similarity index 91% rename from test/variants/ub_packing09.v3 rename to test/variants/ub_packannot09.v3 index ba21e22e7..cbd1997be 100644 --- a/test/variants/ub_packing09.v3 +++ b/test/variants/ub_packannot09.v3 @@ -1,5 +1,5 @@ //@execute 0=46;1=112 -type A00 #unboxed { +type A00 #unboxed #packed { case X(a: u5, b: u7) #packing 0b_aaaa_abbb_bbbb { def f() -> u32 { return a + b; } } diff --git a/test/variants/ub_packing10.v3 b/test/variants/ub_packannot10.v3 similarity index 100% rename from test/variants/ub_packing10.v3 rename to test/variants/ub_packannot10.v3 diff --git a/test/variants/ub_packannot11.v3 b/test/variants/ub_packannot11.v3 new file mode 100644 index 000000000..00c18ad18 --- /dev/null +++ b/test/variants/ub_packannot11.v3 @@ -0,0 +1,18 @@ +//@execute 0=3;1=7;2=-1;3=-1 
+type T #packed #unboxed { + case A(x: int, y: int) #packing (x, 123, y) { def f() -> int { return x + y; } } + case B(x: int, y: int) #packing (y, 456, x) { def f() -> int { return x - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2), + T.A(3, 4), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packannot12.v3 b/test/variants/ub_packannot12.v3 new file mode 100644 index 000000000..2ae91e77a --- /dev/null +++ b/test/variants/ub_packannot12.v3 @@ -0,0 +1,20 @@ +//@execute 0=5;1=14 +type A00 #unboxed #packed { + case X(a: u4) #packing 0b_?aaaa { + def f() -> u4 { return a; } + } + case Y(b: u4) #packing 0b_?bbbb { + def f() -> u4 { return b * 2u4; } + } + + def f() -> u4; +} + +def arr = [ + A00.X(5u4), + A00.Y(7u4) +]; + +def main(a: int) -> int { + return int.view(arr[a].f()); +} \ No newline at end of file diff --git a/test/variants/ub_packannot13.v3 b/test/variants/ub_packannot13.v3 new file mode 100644 index 000000000..2f42d9679 --- /dev/null +++ b/test/variants/ub_packannot13.v3 @@ -0,0 +1,20 @@ +//@execute 0=7;1=8 +type A00 #unboxed #packed { + case X(a: u4, b: u4) #packing 0b_bbbb?aaaa { + def f() -> u4 { return a + b; } + } + case Y(a: u4, b: u4) #packing 0b_aaaa?bbbb { + def f() -> u4 { return (a + b) * 2u4; } + } + + def f() -> u4; +} + +def arr = [ + A00.X(5u4, 2u4), + A00.Y(1u4, 3u4) +]; + +def main(a: int) -> int { + return int.view(arr[a].f()); +} \ No newline at end of file diff --git a/test/variants/ub_packannot14.v3 b/test/variants/ub_packannot14.v3 new file mode 100644 index 000000000..801d8447e --- /dev/null +++ b/test/variants/ub_packannot14.v3 @@ -0,0 +1,20 @@ +//@execute 0=7;1=8 +type A00 #unboxed #packed { + case X(a: u4, b: u4) #packing 0b_bbbb1aaaa { + def f() -> u4 { return a + b; } + } + case Y(a: u4, b: u4) #packing 0b_aaaa0bbbb { + def f() -> u4 { return (a + b) * 2u4; } + } + + def f() -> u4; +} + +def arr = [ + A00.X(5u4, 2u4), + 
A00.Y(1u4, 3u4) +]; + +def main(a: int) -> int { + return int.view(arr[a].f()); +} \ No newline at end of file diff --git a/test/variants/ub_packannot15.v3 b/test/variants/ub_packannot15.v3 new file mode 100644 index 000000000..f6dc68a18 --- /dev/null +++ b/test/variants/ub_packannot15.v3 @@ -0,0 +1,18 @@ +//@execute 0=3;1=7;2=-1;3=-1 +type T #packed #unboxed { + case A(x: int, y: int) #packing (x, y, 123) { def f() -> int { return x + y; } } + case B(x: int, y: int) #packing (y, x, 456) { def f() -> int { return x - y; } } + + def f() -> int; +} + +def arr = [ + T.A(1, 2), + T.A(3, 4), + T.B(5, 6), + T.B(7, 8) +]; + +def main(a: int) -> int { + return arr[a].f(); +} \ No newline at end of file diff --git a/test/variants/ub_packing08.v3 b/test/variants/ub_packing08.v3 index da7d346bb..44da6899c 100644 --- a/test/variants/ub_packing08.v3 +++ b/test/variants/ub_packing08.v3 @@ -3,7 +3,7 @@ type A00 #unboxed { case X(a: u32, b: u32) { def f() -> u32 { return a + b; } } - case Y(c: u32, d: u32) #packing #solve(c, d) { + case Y(c: u32, d: u32) { def f() -> u32 { return c * d; } }