mirror of
https://github.com/myhdl/myhdl.git
synced 2025-01-24 21:52:56 +08:00
Merge pull request #113 from jck/decorators
core decorator improvements
This commit is contained in:
commit
a1ba159a16
@ -24,7 +24,7 @@ from __future__ import absolute_import
|
||||
from types import FunctionType
|
||||
|
||||
from myhdl import AlwaysError
|
||||
from myhdl._util import _isGenFunc
|
||||
from myhdl._util import _isGenFunc, _makeAST
|
||||
from myhdl._delay import delay
|
||||
from myhdl._Signal import _Signal, _WaiterList, posedge, negedge
|
||||
from myhdl._Waiter import _Waiter, _SignalWaiter, _SignalTupleWaiter, \
|
||||
@ -58,28 +58,27 @@ def always(*args):
|
||||
raise AlwaysError(_error.NrOfArgs)
|
||||
return _Always(func, args)
|
||||
return _always_decorator
|
||||
|
||||
|
||||
|
||||
class _Always(_Instantiator):
|
||||
|
||||
def __init__(self, func, args):
|
||||
def __init__(self, func, senslist):
|
||||
self.func = func
|
||||
self.senslist = tuple(args)
|
||||
self.gen = self.genfunc()
|
||||
|
||||
self.senslist = tuple(senslist)
|
||||
super(_Always, self).__init__(self.genfunc)
|
||||
|
||||
def _waiter(self):
|
||||
# infer appropriate waiter class
|
||||
# first infer base type of arguments
|
||||
for t in (_Signal, _WaiterList, delay):
|
||||
if isinstance(args[0], t):
|
||||
if isinstance(self.senslist[0], t):
|
||||
bt = t
|
||||
for arg in args[1:]:
|
||||
if not isinstance(arg, bt):
|
||||
for s in self.senslist[1:]:
|
||||
if not isinstance(s, bt):
|
||||
bt = None
|
||||
break
|
||||
# now set waiter class
|
||||
|
||||
W = _Waiter
|
||||
|
||||
if bt is delay:
|
||||
W = _DelayWaiter
|
||||
elif len(self.senslist) == 1:
|
||||
@ -93,8 +92,11 @@ class _Always(_Instantiator):
|
||||
elif bt is _WaiterList:
|
||||
W = _EdgeTupleWaiter
|
||||
|
||||
self.waiter = W(self.gen)
|
||||
|
||||
return W
|
||||
|
||||
@property
|
||||
def ast(self):
|
||||
return _makeAST(self.func)
|
||||
|
||||
def genfunc(self):
|
||||
senslist = self.senslist
|
||||
@ -104,4 +106,3 @@ class _Always(_Instantiator):
|
||||
while 1:
|
||||
yield senslist
|
||||
func()
|
||||
|
||||
|
@ -32,7 +32,9 @@ from myhdl._util import _isGenFunc, _dedent
|
||||
from myhdl._cell_deref import _cell_deref
|
||||
from myhdl._Waiter import _Waiter, _SignalWaiter, _SignalTupleWaiter
|
||||
from myhdl._instance import _Instantiator
|
||||
from myhdl._always import _Always
|
||||
from myhdl._resolverefs import _AttrRefTransformer
|
||||
from myhdl._visitors import _SigNameVisitor
|
||||
|
||||
class _error:
|
||||
pass
|
||||
@ -67,100 +69,8 @@ def always_comb(func):
|
||||
return c
|
||||
|
||||
|
||||
INPUT, OUTPUT, INOUT = range(3)
|
||||
|
||||
|
||||
|
||||
class _SigNameVisitor(ast.NodeVisitor):
|
||||
def __init__(self, symdict):
|
||||
self.inputs = set()
|
||||
self.outputs = set()
|
||||
self.toplevel = 1
|
||||
self.symdict = symdict
|
||||
self.context = INPUT
|
||||
|
||||
def visit_Module(self, node):
|
||||
inputs = self.inputs
|
||||
outputs = self.outputs
|
||||
for n in node.body:
|
||||
self.visit(n)
|
||||
for n in inputs:
|
||||
if n in outputs:
|
||||
raise AlwaysCombError(_error.SignalAsInout % n)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
if self.toplevel:
|
||||
self.toplevel = 0 # skip embedded functions
|
||||
for n in node.body:
|
||||
self.visit(n)
|
||||
else:
|
||||
raise AlwaysCombError(_error.EmbeddedFunction)
|
||||
|
||||
def visit_If(self, node):
|
||||
if not node.orelse:
|
||||
if isinstance(node.test, ast.Name) and \
|
||||
node.test.id == '__debug__':
|
||||
return # skip
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node):
|
||||
id = node.id
|
||||
if id not in self.symdict:
|
||||
return
|
||||
s = self.symdict[id]
|
||||
if isinstance(s, _Signal) or _isListOfSigs(s):
|
||||
if self.context == INPUT:
|
||||
self.inputs.add(id)
|
||||
elif self.context == OUTPUT:
|
||||
self.outputs.add(id)
|
||||
elif self.context == INOUT:
|
||||
raise AlwaysCombError(_error.SignalAsInout % id)
|
||||
else:
|
||||
raise AssertionError("bug in always_comb")
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.context = OUTPUT
|
||||
for n in node.targets:
|
||||
self.visit(n)
|
||||
self.context = INPUT
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Call(self, node):
|
||||
fn = None
|
||||
if isinstance(node.func, ast.Name):
|
||||
fn = node.func.id
|
||||
if fn == "len":
|
||||
pass
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def visit_Subscript(self, node, access=INPUT):
|
||||
self.visit(node.value)
|
||||
self.context = INPUT
|
||||
self.visit(node.slice)
|
||||
|
||||
def visit_AugAssign(self, node, access=INPUT):
|
||||
self.context = INOUT
|
||||
self.visit(node.target)
|
||||
self.context = INPUT
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
pass # skip
|
||||
|
||||
def visit_Exec(self, node):
|
||||
pass # skip
|
||||
|
||||
def visit_Print(self, node):
|
||||
pass # skip
|
||||
|
||||
|
||||
|
||||
class _AlwaysComb(_Instantiator):
|
||||
# class _AlwaysComb(_Instantiator):
|
||||
class _AlwaysComb(_Always):
|
||||
|
||||
# def __init__(self, func, symdict):
|
||||
# self.func = func
|
||||
@ -192,7 +102,6 @@ class _AlwaysComb(_Instantiator):
|
||||
# self.waiter = W(self.gen)
|
||||
|
||||
def __init__(self, func, symdict):
|
||||
self.func = func
|
||||
self.symdict = symdict
|
||||
s = inspect.getsource(func)
|
||||
s = _dedent(s)
|
||||
@ -202,8 +111,16 @@ class _AlwaysComb(_Instantiator):
|
||||
v.visit(tree)
|
||||
v = _SigNameVisitor(self.symdict)
|
||||
v.visit(tree)
|
||||
self.inputs = v.inputs
|
||||
self.outputs = v.outputs
|
||||
self.inputs = v.results['input']
|
||||
self.outputs = v.results['output']
|
||||
|
||||
inouts = v.results['inout'] | self.inputs.intersection(self.outputs)
|
||||
if inouts:
|
||||
raise AlwaysCombError(_error.SignalAsInout % inouts)
|
||||
|
||||
if v.results['embedded_func']:
|
||||
raise AlwaysCombError(_error.EmbeddedFunction)
|
||||
|
||||
senslist = []
|
||||
for n in self.inputs:
|
||||
s = self.symdict[n]
|
||||
@ -212,16 +129,10 @@ class _AlwaysComb(_Instantiator):
|
||||
else: # list of sigs
|
||||
senslist.extend(s)
|
||||
self.senslist = tuple(senslist)
|
||||
self.gen = self.genfunc()
|
||||
if len(self.senslist) == 0:
|
||||
raise AlwaysCombError(_error.EmptySensitivityList)
|
||||
if len(self.senslist) == 1:
|
||||
W = _SignalWaiter
|
||||
else:
|
||||
W = _SignalTupleWaiter
|
||||
self.waiter = W(self.gen)
|
||||
|
||||
|
||||
super(_AlwaysComb, self).__init__(func, senslist)
|
||||
|
||||
def genfunc(self):
|
||||
senslist = self.senslist
|
||||
|
@ -32,8 +32,9 @@ from myhdl._cell_deref import _cell_deref
|
||||
from myhdl._delay import delay
|
||||
from myhdl._Signal import _Signal, _WaiterList,_isListOfSigs
|
||||
from myhdl._Waiter import _Waiter, _EdgeWaiter, _EdgeTupleWaiter
|
||||
from myhdl._instance import _Instantiator
|
||||
from myhdl._always import _Always
|
||||
from myhdl._resolverefs import _AttrRefTransformer
|
||||
from myhdl._visitors import _SigNameVisitor
|
||||
|
||||
# evacuate this later
|
||||
AlwaysSeqError = AlwaysError
|
||||
@ -82,13 +83,13 @@ def always_seq(edge, reset):
|
||||
return _always_seq_decorator
|
||||
|
||||
|
||||
class _AlwaysSeq(_Instantiator):
|
||||
class _AlwaysSeq(_Always):
|
||||
|
||||
def __init__(self, func, edge, reset):
|
||||
self.func = func
|
||||
self.senslist = senslist = [edge]
|
||||
senslist = [edge]
|
||||
self.reset = reset
|
||||
if reset is not None:
|
||||
self.genfunc = self.genfunc_reset
|
||||
active = self.reset.active
|
||||
async = self.reset.async
|
||||
if async:
|
||||
@ -96,14 +97,10 @@ class _AlwaysSeq(_Instantiator):
|
||||
senslist.append(reset.posedge)
|
||||
else:
|
||||
senslist.append(reset.negedge)
|
||||
self.gen = self.genfunc()
|
||||
else:
|
||||
self.gen = self.genfunc_no_reset()
|
||||
if len(self.senslist) == 1:
|
||||
W = _EdgeWaiter
|
||||
else:
|
||||
W = _EdgeTupleWaiter
|
||||
self.waiter = W(self.gen)
|
||||
self.genfunc = self.genfunc_no_reset
|
||||
|
||||
super(_AlwaysSeq, self).__init__(func, senslist)
|
||||
|
||||
# find symdict
|
||||
# similar to always_comb, but in class constructor
|
||||
@ -131,9 +128,16 @@ class _AlwaysSeq(_Instantiator):
|
||||
v.visit(tree)
|
||||
v = _SigNameVisitor(self.symdict)
|
||||
v.visit(tree)
|
||||
|
||||
if v.results['inout']:
|
||||
raise AlwaysSeqError(_error.SigAugAssign, v.results['inout'])
|
||||
|
||||
if v.results['embedded_func']:
|
||||
raise AlwaysSeqError(_error.EmbeddedFunction)
|
||||
|
||||
sigregs = self.sigregs = []
|
||||
varregs = self.varregs = []
|
||||
for n in v.outputs:
|
||||
for n in v.results['output']:
|
||||
reg = self.symdict[n]
|
||||
if isinstance(reg, _Signal):
|
||||
sigregs.append(reg)
|
||||
@ -144,7 +148,6 @@ class _AlwaysSeq(_Instantiator):
|
||||
for e in reg:
|
||||
sigregs.append(e)
|
||||
|
||||
|
||||
def reset_sigs(self):
|
||||
for s in self.sigregs:
|
||||
s.next = s._init
|
||||
@ -155,7 +158,7 @@ class _AlwaysSeq(_Instantiator):
|
||||
n, reg, init = v
|
||||
reg._val = init
|
||||
|
||||
def genfunc(self):
|
||||
def genfunc_reset(self):
|
||||
senslist = self.senslist
|
||||
if len(senslist) == 1:
|
||||
senslist = senslist[0]
|
||||
@ -178,88 +181,3 @@ class _AlwaysSeq(_Instantiator):
|
||||
while 1:
|
||||
yield senslist
|
||||
func()
|
||||
|
||||
|
||||
# similar to always_comb, calls for refactoring
|
||||
# note: make a difference between augmented assign and inout signals
|
||||
|
||||
INPUT, OUTPUT, INOUT = range(3)
|
||||
|
||||
class _SigNameVisitor(ast.NodeVisitor):
|
||||
def __init__(self, symdict):
|
||||
self.inputs = set()
|
||||
self.outputs = set()
|
||||
self.toplevel = 1
|
||||
self.symdict = symdict
|
||||
self.context = INPUT
|
||||
|
||||
def visit_Module(self, node):
|
||||
|
||||
for n in node.body:
|
||||
self.visit(n)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
if self.toplevel:
|
||||
self.toplevel = 0 # skip embedded functions
|
||||
for n in node.body:
|
||||
self.visit(n)
|
||||
else:
|
||||
raise AlwaysSeqError(_error.EmbeddedFunction)
|
||||
|
||||
def visit_If(self, node):
|
||||
if not node.orelse:
|
||||
if isinstance(node.test, ast.Name) and \
|
||||
node.test.id == '__debug__':
|
||||
return # skip
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node):
|
||||
id = node.id
|
||||
if id not in self.symdict:
|
||||
return
|
||||
s = self.symdict[id]
|
||||
if isinstance(s, (_Signal, intbv)) or _isListOfSigs(s):
|
||||
if self.context == INPUT:
|
||||
self.inputs.add(id)
|
||||
elif self.context == OUTPUT:
|
||||
self.outputs.add(id)
|
||||
elif self.context == INOUT:
|
||||
raise AlwaysSeqError(_error.SigAugAssign, id)
|
||||
else:
|
||||
raise AssertionError("bug in always_seq")
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.context = OUTPUT
|
||||
for n in node.targets:
|
||||
self.visit(n)
|
||||
self.context = INPUT
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Subscript(self, node, access=INPUT):
|
||||
self.visit(node.value)
|
||||
self.context = INPUT
|
||||
self.visit(node.slice)
|
||||
|
||||
def visit_AugAssign(self, node, access=INPUT):
|
||||
self.context = INOUT
|
||||
self.visit(node.target)
|
||||
self.context = INPUT
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
pass # skip
|
||||
|
||||
def visit_Exec(self, node):
|
||||
pass # skip
|
||||
|
||||
def visit_Print(self, node):
|
||||
pass # skip
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ from __future__ import absolute_import
|
||||
from types import FunctionType
|
||||
|
||||
from myhdl import InstanceError
|
||||
from myhdl._util import _isGenFunc
|
||||
from myhdl._util import _isGenFunc, _makeAST
|
||||
from myhdl._Waiter import _inferWaiter
|
||||
|
||||
class _error:
|
||||
@ -44,8 +44,17 @@ def instance(genFunc):
|
||||
|
||||
class _Instantiator(object):
|
||||
|
||||
def __init__(self, genFunc):
|
||||
self.genfunc = genFunc
|
||||
self.gen = genFunc()
|
||||
self.waiter = _inferWaiter(self.gen)
|
||||
|
||||
def __init__(self, genfunc):
|
||||
self.genfunc = genfunc
|
||||
self.gen = genfunc()
|
||||
|
||||
@property
|
||||
def waiter(self):
|
||||
return self._waiter()(self.gen)
|
||||
|
||||
def _waiter(self):
|
||||
return _inferWaiter
|
||||
|
||||
@property
|
||||
def ast(self):
|
||||
return _makeAST(self.gen.gi_frame)
|
||||
|
@ -3,7 +3,7 @@ import ast
|
||||
import itertools
|
||||
from types import FunctionType
|
||||
|
||||
from myhdl._util import _flatten, _makeAST, _genfunc
|
||||
from myhdl._util import _flatten
|
||||
from myhdl._enum import EnumType
|
||||
from myhdl._Signal import SignalType
|
||||
|
||||
@ -18,9 +18,7 @@ def _resolveRefs(symdict, arg):
|
||||
data.symdict = symdict
|
||||
v = _AttrRefTransformer(data)
|
||||
for gen in gens:
|
||||
func = _genfunc(gen)
|
||||
tree = _makeAST(func)
|
||||
v.visit(tree)
|
||||
v.visit(gen.ast)
|
||||
return data.objlist
|
||||
|
||||
#TODO: Refactor this into two separate nodetransformers, since _resolveRefs
|
||||
|
89
myhdl/_visitors.py
Normal file
89
myhdl/_visitors.py
Normal file
@ -0,0 +1,89 @@
|
||||
import ast
|
||||
|
||||
from myhdl import intbv
|
||||
from myhdl._Signal import _Signal, _isListOfSigs
|
||||
|
||||
|
||||
class _SigNameVisitor(ast.NodeVisitor):
|
||||
def __init__(self, symdict):
|
||||
self.toplevel = 1
|
||||
self.symdict = symdict
|
||||
self.results = {
|
||||
'input': set(),
|
||||
'output': set(),
|
||||
'inout': set(),
|
||||
'embedded_func': set()
|
||||
}
|
||||
self.context = 'input'
|
||||
|
||||
def visit_Module(self, node):
|
||||
inputs = self.results['input']
|
||||
outputs = self.results['output']
|
||||
for n in node.body:
|
||||
self.visit(n)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
if self.toplevel:
|
||||
self.toplevel = 0 # skip embedded functions
|
||||
for n in node.body:
|
||||
self.visit(n)
|
||||
else:
|
||||
self.results['embedded_func'] = node.name
|
||||
|
||||
def visit_If(self, node):
|
||||
if not node.orelse:
|
||||
if isinstance(node.test, ast.Name) and \
|
||||
node.test.id == '__debug__':
|
||||
return # skip
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node):
|
||||
id = node.id
|
||||
if id not in self.symdict:
|
||||
return
|
||||
s = self.symdict[id]
|
||||
if isinstance(s, (_Signal, intbv)) or _isListOfSigs(s):
|
||||
if self.context in ('input', 'output', 'inout'):
|
||||
self.results[self.context].add(id)
|
||||
else:
|
||||
print(self.context)
|
||||
raise AssertionError("bug in always_comb")
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.context = 'output'
|
||||
for n in node.targets:
|
||||
self.visit(n)
|
||||
self.context = 'input'
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_Call(self, node):
|
||||
fn = None
|
||||
if isinstance(node.func, ast.Name):
|
||||
fn = node.func.id
|
||||
if fn == "len":
|
||||
pass
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Subscript(self, node, access='input'):
|
||||
self.visit(node.value)
|
||||
self.context = 'input'
|
||||
self.visit(node.slice)
|
||||
|
||||
def visit_AugAssign(self, node, access='input'):
|
||||
self.context = 'inout'
|
||||
self.visit(node.target)
|
||||
self.context = 'input'
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
pass # skip
|
||||
|
||||
def visit_Exec(self, node):
|
||||
pass # skip
|
||||
|
||||
def visit_Print(self, node):
|
||||
pass # skip
|
@ -145,7 +145,7 @@ def _analyzeGens(top, absnames):
|
||||
tree = g
|
||||
elif isinstance(g, (_AlwaysComb, _AlwaysSeq, _Always)):
|
||||
f = g.func
|
||||
tree = _makeAST(f)
|
||||
tree = g.ast
|
||||
tree.symdict = f.__globals__.copy()
|
||||
tree.callstack = []
|
||||
# handle free variables
|
||||
@ -171,7 +171,7 @@ def _analyzeGens(top, absnames):
|
||||
v.visit(tree)
|
||||
else: # @instance
|
||||
f = g.gen.gi_frame
|
||||
tree = _makeAST(f)
|
||||
tree = g.ast
|
||||
tree.symdict = f.f_globals.copy()
|
||||
tree.symdict.update(f.f_locals)
|
||||
tree.nonlocaldict = {}
|
||||
|
@ -122,7 +122,7 @@ class TestAlwaysCombCompilation:
|
||||
def h():
|
||||
c.next += 1
|
||||
a += 1
|
||||
with raises_kind(AlwaysCombError, _error.SignalAsInout % "c"):
|
||||
with raises_kind(AlwaysCombError, _error.SignalAsInout % set('c')):
|
||||
g = always_comb(h).gen
|
||||
|
||||
def testInfer6(self):
|
||||
@ -131,7 +131,7 @@ class TestAlwaysCombCompilation:
|
||||
def h():
|
||||
c.next = a
|
||||
x.next = c
|
||||
with raises_kind(AlwaysCombError, _error.SignalAsInout % "c"):
|
||||
with raises_kind(AlwaysCombError, _error.SignalAsInout % set('c')):
|
||||
g = always_comb(h).gen
|
||||
|
||||
def testInfer7(self):
|
||||
|
@ -1,10 +1,17 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@contextmanager
|
||||
def raises_kind(exc, kind):
|
||||
with pytest.raises(exc) as excinfo:
|
||||
yield
|
||||
assert excinfo.value.kind == kind
|
||||
class raises_kind(object):
|
||||
def __init__(self, exc, kind):
|
||||
self.exc = exc
|
||||
self.kind = kind
|
||||
|
||||
def __enter__(self):
|
||||
return None
|
||||
|
||||
def __exit__(self, *tp):
|
||||
__tracebackhide__ = True
|
||||
if tp[0] is None:
|
||||
pytest.fail("DID NOT RAISE")
|
||||
assert tp[1].kind == self.kind
|
||||
return issubclass(tp[0], self.exc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user