From 293300a62551d5cec99d7dc138885912c8cf3de5 Mon Sep 17 00:00:00 2001 From: jand Date: Mon, 24 Nov 2003 15:09:07 +0000 Subject: [PATCH] function --- myhdl/_toVerilog.py | 407 +++++++++++++++++++++++-------- myhdl/_unparse.py | 6 +- myhdl/test/toVerilog/test_hec.py | 81 ++++-- 3 files changed, 371 insertions(+), 123 deletions(-) diff --git a/myhdl/_toVerilog.py b/myhdl/_toVerilog.py index ded4d8e6..efbfcc14 100644 --- a/myhdl/_toVerilog.py +++ b/myhdl/_toVerilog.py @@ -30,8 +30,9 @@ import operator import compiler from compiler import ast from sets import Set -from types import GeneratorType, ClassType +from types import GeneratorType, FunctionType, ClassType from cStringIO import StringIO +import __builtin__ import myhdl from myhdl import * @@ -130,38 +131,31 @@ def _analyzeSigs(hierarchy): def LabelGenerator(): i = 1 while 1: - yield "_MYHDL_LABEL_%s" % i + yield "__MYHDL__%s" % i i += 1 +genLabel = LabelGenerator() + def _analyzeGens(top, gennames): - genLabel = LabelGenerator() genlist = [] for g in top: f = g.gi_frame s = inspect.getsource(f) s = s.lstrip() - gen = compiler.parse(s) - gen.sourcefile = inspect.getsourcefile(f) - gen.lineoffset = inspect.getsourcelines(f)[1]-1 + ast = compiler.parse(s) + ast.sourcefile = inspect.getsourcefile(f) + ast.lineoffset = inspect.getsourcelines(f)[1]-1 symdict = f.f_globals.copy() symdict.update(f.f_locals) - # print f.f_locals - sigdict = {} - for n, v in symdict.items(): - if isinstance(v, Signal): - sigdict[n] = v - gen.sigdict = sigdict - gen.symdict = symdict - if gennames.has_key(id(g)): - gen.name = gennames[id(g)] - else: - gen.name = genLabel.next() - v = _AnalyzeGenVisitor(sigdict, symdict, gen.sourcefile, gen.lineoffset) - compiler.walk(gen, v) - gen.vardict = v.vardict - gen.kind = v.kind - genlist.append(gen) + ast.symdict = symdict + ast.name = gennames.get(id(g), genLabel.next() + "_BLOCK") + v = _AnalyzeBlockVisitor(symdict, ast.sourcefile, ast.lineoffset) + compiler.walk(ast, v) + ast.sigdict = v.sigdict + ast.vardict = v.vardict + ast.kind = v.kind + genlist.append(ast) return genlist @@ -247,8 +241,6 @@ class _NotSupportedVisitor(_ToVerilogMixin): def visitList(self, node, *args): self.raiseError(node, _error.NotSupported, "list") - def visitReturn(self, node, *args): - self.raiseError(node, _error.NotSupported, "return statement") def visitTryExcept(self, node, *args): self.raiseError(node, _error.NotSupported, "try-except statement") @@ -262,22 +254,25 @@ def getObj(node): return node.obj return None +def getNrBits(obj): + if hasattr(obj, '_nrbits'): + return obj._nrbits + return None + INPUT, OUTPUT, INOUT = range(3) NORMAL, DECLARATION = range(2) ALWAYS, INITIAL = range(2) -class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): +class _AnalyzeVisitor(_NotSupportedVisitor, _ToVerilogMixin): - def __init__(self, sigdict, symdict, sourcefile, lineoffset): + def __init__(self, symdict, sourcefile, lineoffset): self.sourcefile = sourcefile self.lineoffset = lineoffset - self.inputs = Set() - self.outputs = Set() self.toplevel = 1 - self.sigdict = sigdict self.symdict = symdict self.vardict = {} + self.used = Set() def getObj(self, node): if hasattr(node, 'obj'): @@ -315,7 +310,12 @@ class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): def visitAssName(self, node, *args): n = node.name - self.require(node, n not in self.symdict, "Illegal redeclaration: %s" % n) + # XXX ? + if n in self.vardict: + return + if n in self.used: + self.require(node, n not in self.symdict, + "Previously used external symbol cannot be locally redeclared: %s" % n) def visitAugAssign(self, node, access=INPUT, *args): self.visit(node.node, INOUT) @@ -324,10 +324,34 @@ class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): def visitCallFunc(self, node, *args): for child in node.getChildNodes(): self.visit(child, *args) - f = self.getObj(node.node) + func = self.getObj(node.node) # print f - if type(f) is type and issubclass(f, intbv): + # print node.args + if type(func) is type and issubclass(func, intbv): node.obj = intbv() + elif func in myhdl.__dict__.values(): + pass + elif func in __builtin__.__dict__.values(): + pass + elif type(func) is FunctionType: + s = inspect.getsource(func) + s = s.lstrip() + ast = compiler.parse(s) + print ast + ast.name = genLabel.next() + "_" + func.__name__ + ast.sourcefile = inspect.getsourcefile(func) + ast.lineoffset = inspect.getsourcelines(func)[1]-1 + ast.symdict = func.func_globals.copy() + v = _AnalyzeFuncVisitor(ast.symdict, ast.sourcefile, ast.lineoffset, \ + self.inputs, self.outputs, node.args) + compiler.walk(ast, v) + ast.sigdict = v.sigdict + ast.vardict = v.vardict + ast.argnames = v.argnames + ast.returnObj = v.returnObj + # print ast.argnames + ast.kind = v.kind + node.ast = ast def visitCompare(self, node, *args): node.obj = bool() @@ -348,23 +372,9 @@ class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): self.visit(node.list) self.visit(node.body, *args) self.require(node, node.else_ is None, "for-else not supported") - - def visitFunction(self, node, *args): - if not self.toplevel: - self.raiseError(node, _error.NotSupported, "embedded function definition") - self.toplevel = 0 - print node.code - self.visit(node.code) - self.kind = ALWAYS - for n in node.code.nodes[:-1]: - if not self.getKind(n) == DECLARATION: - self.kind = INITIAL - break - if self.kind == ALWAYS: - w = node.code.nodes[-1] - if not self.getKind(w) == ALWAYS: - self.kind = INITIAL + def visitFunction(self, node, *args): + raise AssertionError def visitGetattr(self, node, *args): self.visit(node.expr, *args) @@ -375,22 +385,9 @@ class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): assert hasattr(obj, node.attrname) node.obj = getattr(obj, node.attrname) - def visitModule(self, node, *args): - print node - #assert len(node.node.nodes) == 1 - #assert isinstance(node.node.nodes[0], ast.Function) - self.visit(node.node) - for n in self.outputs: - s = self.sigdict[n] - if s._driven: - self.raiseError(node, _error._SigMultipleDriven, n) - s._driven = True - for n in self.inputs: - s = self.sigdict[n] - s._read = True - def visitName(self, node, access=INPUT, *args): n = node.name + self.used.add(n) if n in self.sigdict: if access == INPUT: self.inputs.add(n) @@ -403,8 +400,14 @@ class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): node.obj = self.vardict[n] elif n in self.symdict: node.obj = self.symdict[n] - elif n in __builtins__: + else: + #print "HERE" + #print __builtins__[n] node.obj = __builtins__[n] + + def visitReturn(self, node, *args): + self.visit(node.value) + def visitSlice(self, node, access=INPUT, kind=NORMAL, *args): self.visit(node.expr, access) @@ -439,14 +442,114 @@ class _AnalyzeGenVisitor(_NotSupportedVisitor, _ToVerilogMixin): node.kind = ALWAYS self.require(node, node.else_ is None, "while-else not supported") + +class _AnalyzeBlockVisitor(_AnalyzeVisitor): + + def __init__(self, symdict, sourcefile, lineoffset): + _AnalyzeVisitor.__init__(self, symdict, sourcefile, lineoffset) + self.sigdict = sigdict = {} + for n, v in symdict.items(): + if isinstance(v, Signal): + sigdict[n] = v + self.inputs = Set() + self.outputs = Set() + + def visitFunction(self, node, *args): + if not self.toplevel: + self.raiseError(node, _error.NotSupported, "embedded function definition") + self.toplevel = 0 + print node.code + self.visit(node.code) + self.kind = ALWAYS + for n in node.code.nodes[:-1]: + if not self.getKind(n) == DECLARATION: + self.kind = INITIAL + break + if self.kind == ALWAYS: + w = node.code.nodes[-1] + if not self.getKind(w) == ALWAYS: + self.kind = INITIAL + + def visitModule(self, node, *args): + #assert len(node.node.nodes) == 1 + #assert isinstance(node.node.nodes[0], ast.Function) + self.visit(node.node) + for n in self.outputs: + s = self.sigdict[n] + if s._driven: + self.raiseError(node, _error._SigMultipleDriven, n) + s._driven = True + for n in self.inputs: + s = self.sigdict[n] + s._read = True + + def visitReturn(self, node, *args): + self.raiseError(node, _error.NotSupported, "return statement") + + +class _AnalyzeFuncVisitor(_AnalyzeVisitor): + + def __init__(self, symdict, sourcefile, lineoffset, \ + inputs, outputs, args): + _AnalyzeVisitor.__init__(self, symdict, sourcefile, lineoffset) + self.sigdict = sigdict = {} + self.inputs = inputs + self.outputs = outputs + self.args = args + self.argnames = [] + self.kind = None + self.hasReturn = False + + def visitFunction(self, node, *args): + if not self.toplevel: + self.raiseError(node, _error.NotSupported, "embedded function definition") + self.toplevel = 0 + argnames = node.argnames + for i, arg in enumerate(self.args): + if isinstance(arg, ast.Keyword): + n = arg.name + self.symdict[n] = getObj(arg.expr) + else: # Name + n = argnames[i] + self.symdict[n] = getObj(arg) + self.argnames.append(n) + for n, v in self.symdict.items(): + if isinstance(v, Signal): + self.sigdict[n] = v + self.visit(node.code) + + def visitReturn(self, node, *args): + self.visit(node.value) + if isinstance(node.value, ast.Const) and node.value.value is None: + obj = None + elif isinstance(node.value, ast.Name) and node.value.name is None: + obj = None + elif node.value.obj is not None: + obj = node.value.obj + else: + self.raiseError(node, "Can't derive return type") + if self.hasReturn: + returnObj = self.returnObj + if getNrBits(obj) != getNrBits(returnObj): + self.raiseError(node, "Returned nr of bits is different from before") + if isinstance(obj, type(returnObj)): + pass + elif isinstance(returnObj, type(obj)): + self.returnObj = type(obj) + else: + self.raiseError(node, "Incompatible return type") + else: + self.returnObj = obj + self.hasReturn = True + def _analyzeTopFunc(func, *args, **kwargs): s = inspect.getsource(func) s = s.lstrip() - funcast = compiler.parse(s) + ast = compiler.parse(s) v = _AnalyzeTopFuncVisitor(*args, **kwargs) - compiler.walk(funcast, v) + compiler.walk(ast, v) return v @@ -570,29 +673,34 @@ def _getRangeString(s): def _convertGens(genlist, vfile): - for gen in genlist: - if gen.kind == ALWAYS: + blockBuf = StringIO() + funcBuf = StringIO() + for ast in genlist: + if ast.kind == ALWAYS: Visitor = _ConvertAlwaysVisitor else: Visitor = _ConvertInitialVisitor - v = Visitor(vfile, gen.sigdict, gen.symdict, gen.vardict, - gen.name, gen.sourcefile, gen.lineoffset ) - compiler.walk(gen, v) + v = Visitor(ast, blockBuf, funcBuf) + compiler.walk(ast, v) + #print "FUNC" + #print funcBuf.getvalue() + #print "BLOCK" + #print blockBuf.getvalue() + vfile.write(funcBuf.getvalue()); funcBuf.close() + vfile.write(blockBuf.getvalue()); blockBuf.close() -class _ConvertGenVisitor(_ToVerilogMixin): +class _ConvertVisitor(_ToVerilogMixin): - def __init__(self, f, sigdict, symdict, vardict, name, sourcefile, lineoffset): - self.buf = self.fileBuf = f - self.name = name - self.sourcefile = sourcefile - self.lineoffset = lineoffset - self.declBuf = StringIO() - self.codeBuf = StringIO() - self.sigdict = sigdict - self.symdict = symdict - self.vardict = vardict - print vardict + def __init__(self, ast, blockBuf, funcBuf=None): + self.buf = blockBuf + self.funcBuf = funcBuf + self.name = ast.name + self.sourcefile = ast.sourcefile + self.lineoffset = ast.lineoffset + self.sigdict = ast.sigdict + self.symdict = ast.symdict + self.vardict = ast.vardict self.ind = '' self.inYield = False self.isSigAss = False @@ -672,9 +780,30 @@ class _ConvertGenVisitor(_ToVerilogMixin): self.write(node.name) def visitAugAssign(self, node): - # XXX - pass - + opmap = {"+=" : "+", + "-=" : "-", + "*=" : "*", + "//=" : "/", + "%=" : "%", + "**=" : "**", + "|=" : "|", + ">>=" : ">>", + "<<=" : "<<", + "&=" : "&", + "^=" : "^" + } + if node.op not in opmap: + self.raiseError(node, _error.NotSupported, + "augmented assignment %s" % op) + op = opmap[node.op] + self.writeline() + self.visit(node.node) + self.write(" = ") + self.visit(node.node) + self.write(" %s " % op) + self.visit(node.expr) + self.write(";") + def visitBitand(self, node): self.multiOp(node, '&') @@ -704,6 +833,8 @@ class _ConvertGenVisitor(_ToVerilogMixin): self.write(": ") elif f is concat: opening, closing = '{', '}' + elif hasattr(node, 'ast'): + self.write(node.ast.name) else: self.write(f.__name__) self.write(opening) @@ -713,6 +844,11 @@ class _ConvertGenVisitor(_ToVerilogMixin): self.write(", ") self.visit(arg) self.write(closing) + if hasattr(node, 'ast'): + Visitor = _ConvertFunctionVisitor + v = Visitor(node.ast, self.funcBuf) + compiler.walk(node.ast, v) + def visitCompare(self, node): self.write("(") @@ -731,7 +867,6 @@ class _ConvertGenVisitor(_ToVerilogMixin): def visitFor(self, node): var = node.assign.name - self.buf = self.codeBuf cf = node.list self.require(node, isinstance(cf, ast.CallFunc), "Expected (down)range call") f = getObj(cf.node) @@ -826,8 +961,7 @@ class _ConvertGenVisitor(_ToVerilogMixin): self.write(")") def visitKeyword(self, node): - # XXX - pass + self.visit(node.expr) def visitLeftShift(self, node): self.binaryOp(node, '<<') @@ -843,16 +977,19 @@ class _ConvertGenVisitor(_ToVerilogMixin): self.binaryOp(node, '*') def visitName(self, node): - if node.name in self.symdict: - obj = self.symdict[node.name] + n = node.name + if n in self.vardict: + self.write(n) + elif node.name in self.symdict: + obj = self.symdict[n] if isinstance(obj, int): self.write(str(obj)) elif type(obj) is Signal: self.write(obj._name) else: - self.write(node.name) + self.write(n) else: - self.write(node.name) + raise AssertionError def visitNot(self, node): self.write("(!") @@ -892,6 +1029,12 @@ class _ConvertGenVisitor(_ToVerilogMixin): def visitSlice(self, node): # print dir(node) # print node.obj + if isinstance(node.expr, ast.CallFunc) and \ + node.expr.node.obj is intbv: + c = self.getVal(node) + self.write("%s'h" % c._len) + self.write("%x" % c._val) + return self.visit(node.expr) self.write("[") if node.lower is None: @@ -948,10 +1091,10 @@ class _ConvertGenVisitor(_ToVerilogMixin): self.inYield = False -class _ConvertAlwaysVisitor(_ConvertGenVisitor): +class _ConvertAlwaysVisitor(_ConvertVisitor): def __init__(self, *args): - _ConvertGenVisitor.__init__(self, *args) + _ConvertVisitor.__init__(self, *args) def visitFunction(self, node): w = node.code.nodes[-1] @@ -964,36 +1107,94 @@ class _ConvertAlwaysVisitor(_ConvertGenVisitor): self.write(") begin: %s" % self.name) self.indent() self.writeDeclarations() - self.buf = self.codeBuf for s in w.body.nodes[1:]: self.visit(s) - self.buf = self.fileBuf - self.write(self.codeBuf.getvalue()) self.dedent() self.writeline() self.write("end") self.writeline() self.writeline() -class _ConvertInitialVisitor(_ConvertGenVisitor): +class _ConvertInitialVisitor(_ConvertVisitor): def __init__(self, *args): - _ConvertGenVisitor.__init__(self, *args) + _ConvertVisitor.__init__(self, *args) def visitFunction(self, node): - self.inYield = True self.write("initial begin: %s" % self.name) self.indent() self.writeDeclarations() - self.buf = self.codeBuf self.visit(node.code) - self.buf = self.fileBuf - self.write(self.codeBuf.getvalue()) self.dedent() self.writeline() self.write("end") self.writeline() self.writeline() + + +class _ConvertFunctionVisitor(_ConvertVisitor): + + def __init__(self, ast, blockBuf, funcBuf=None): + _ConvertVisitor.__init__(self, ast, blockBuf, funcBuf) + self.argnames = ast.argnames + self.returnObj = ast.returnObj + + def writeOutputDeclaration(self): + obj = self.returnObj + if type(obj) is bool: + pass + elif isinstance(obj, int): + self.write("integer") + elif isinstance(obj, intbv): + self.write("[%s-1:0]" % obj._len) + elif hasattr(obj, '_nrbits'): + self.write("[%s-1:0]" % obj._nrbits) + else: + raise AssertionError("unexpected type") + + def writeInputDeclarations(self): + for name in self.argnames: + obj = self.symdict[name] + self.writeline() + if type(obj) is bool: + self.write("input %s;" % name) + elif isinstance(obj, int): + self.write("integer %s;" % name) + elif isinstance(obj, intbv): + self.write("input [%s-1:0] %s;" % (obj._len, name)) + elif hasattr(obj, '_nrbits'): + self.write("input [%s-1:0] %s;" % (obj._nrbits, name)) + else: + raise AssertionError("unexpected type") + + def visitFunction(self, node): + self.write("function ") + self.writeOutputDeclaration() + self.write(" %s;" % self.name) + self.indent() + self.writeInputDeclarations() + self.writeDeclarations() + self.writeline() + self.write("begin: __MYHDL__") + self.visit(node.code) + self.dedent() + self.writeline() + self.write("end") + self.writeline() + self.write("endfunction") + self.writeline() + self.writeline() + + def visitReturn(self, node): + self.writeline() + self.write("%s = " % self.name) + self.visit(node.value) + self.write(";") + self.writeline() + self.write("disable __MYHDL__;") + + + diff --git a/myhdl/_unparse.py b/myhdl/_unparse.py index 8419443c..7f168b15 100644 --- a/myhdl/_unparse.py +++ b/myhdl/_unparse.py @@ -144,9 +144,11 @@ class _UnparseVisitor(object): def visitSlice(self, node): self.visit(node.expr) self.write('[') - self.visit(node.lower) + if node.lower is not None: + self.visit(node.lower) self.write(':') - self.visit(node.upper) + if node.upper is not None: + self.visit(node.upper) self.write(']') def visitSub(self, node): diff --git a/myhdl/test/toVerilog/test_hec.py b/myhdl/test/toVerilog/test_hec.py index 3d843cfc..8b322137 100644 --- a/myhdl/test/toVerilog/test_hec.py +++ b/myhdl/test/toVerilog/test_hec.py @@ -7,9 +7,10 @@ from myhdl import * COSET = 0x55 -def calculateHec(header): - """ Return hec for an ATM header, represented as an intbv. +def calculateHecRef(header): + """ Return hec for an ATM header. + Reference version. The hec polynomial is 1 + x + x**2 + x**8. """ hec = intbv(0) @@ -21,8 +22,28 @@ def calculateHec(header): ) return hec ^ COSET +def calculateHecSynth(header): + """ Return hec for an ATM header. + + Synthesizable version. + The hec polynomial is 1 + x + x**2 + x**8. + """ + h = intbv(0)[8:] + for i in downrange(len(header)): + bit = header[i] + h[:] = concat(h[7:2], + bit ^ h[1] ^ h[7], + bit ^ h[0] ^ h[7], + bit ^ h[7] + ) + h ^= COSET + return h -def HecCalculator(hec, header): +def HecCalculatorPlain(hec, header): + """ Hec calculation module. + + Plain version. + """ h = intbv(0)[8:] while 1: yield header @@ -35,6 +56,18 @@ def HecCalculator(hec, header): bit ^ h[7] ) hec.next = h ^ COSET + +def HecCalculatorFunc(hec, header): + """ Hec calculation module. + + Version with function call. + """ + h = intbv(0)[8:] + while 1: + yield header + hec.next = calculateHecSynth(header=header) + + objfile = "heccalc.o" analyze_cmd = "iverilog -o %s heccalc_inst.v tb_heccalc_inst.v" % objfile @@ -46,13 +79,6 @@ def HecCalculator_v(hec, header): os.system(analyze_cmd) return Cosimulation(simulate_cmd, **locals()) -hec = Signal(intbv(0)[8:]) -hec_v = Signal(intbv(0)[8:]) -header = Signal(intbv(-1)[32:]) - -heccalc_inst = toVerilog(HecCalculator, hec, header) -# heccalc_inst = HecCalculator(hec, header) -heccalc_v_inst = HecCalculator_v(hec_v, header) headers = [ 0x00000000L, @@ -64,17 +90,36 @@ headers.extend([randrange(2**32-1) for i in range(10)]) class TestHec(unittest.TestCase): - def stimulus(self): - for h in headers: - header.next = h - yield delay(10) - print "hec: %s hec_v: %s" % (hex(hec), hex(hec_v)) - self.assertEqual(hec, hec_v) + def bench(self, HecCalculator): + + hec = Signal(intbv(0)[8:]) + hec_v = Signal(intbv(0)[8:]) + header = Signal(intbv(-1)[32:]) - def test(self): - sim = self.stimulus(), heccalc_inst, heccalc_v_inst + heccalc_inst = toVerilog(HecCalculator, hec, header) + # heccalc_inst = HecCalculator(hec, header) + heccalc_v_inst = HecCalculator_v(hec_v, header) + + def stimulus(): + for h in headers: + header.next = h + yield delay(10) + hec_ref = calculateHecRef(header) + # print "hec: %s hec_v: %s" % (hex(hec), hex(hec_v)) + self.assertEqual(hec, hec_ref) + self.assertEqual(hec, hec_v) + + return stimulus(), heccalc_inst, heccalc_v_inst + + def testPlain(self): + sim = self.bench(HecCalculatorPlain) Simulation(sim).run() + def testFunc(self): + sim = self.bench(HecCalculatorFunc) + Simulation(sim).run() + + if __name__ == '__main__': unittest.main()