slight refactoring

svn:r700
This commit is contained in:
Niels Provos 2008-04-03 03:33:07 +00:00
parent 193c06a7ed
commit a7e395512e

View File

@ -52,6 +52,12 @@ class Struct:
name = "%s_%s" % (self._name, entry.Name())
return name.upper()
class StructCCode(Struct):
""" Knows how to generate C code for a struct """
def __init__(self, name):
Struct.__init__(self, name)
def PrintIndented(self, file, ident, code):
"""Takes an array, add indentation to each entry and prints it."""
for entry in code:
@ -328,22 +334,6 @@ class Entry:
def GetInitializer(self):
assert 0, "Entry does not provide initializer"
def GetTranslation(self, extradict = {}):
mapping = {
"parent_name" : self._struct.Name(),
"name" : self._name,
"ctype" : self._ctype,
"refname" : self._refname,
"optpointer" : self._optpointer and "*" or "",
"optreference" : self._optpointer and "&" or "",
"optaddarg" :
self._optaddarg and ", const %s value" % self._ctype or ""
}
for (k, v) in extradict.items():
mapping[k] = v
return mapping
def SetStruct(self, struct):
self._struct = struct
@ -375,6 +365,39 @@ class Entry:
def MakeOptional(self):
self._optional = 1
def Verify(self):
if self.Array() and not self._can_be_array:
print >>sys.stderr, (
'Entry "%s" cannot be created as an array '
'around line %d' ) % (self._name, self.LineCount())
sys.exit(1)
if not self._struct:
print >>sys.stderr, (
'Entry "%s" does not know which struct it belongs to '
'around line %d' ) % (self._name, self.LineCount())
sys.exit(1)
if self._optional and self._array:
print >>sys.stderr, ( 'Entry "%s" has illegal combination of '
'optional and array around line %d' ) % (
self._name, self.LineCount() )
sys.exit(1)
def GetTranslation(self, extradict = {}):
mapping = {
"parent_name" : self._struct.Name(),
"name" : self._name,
"ctype" : self._ctype,
"refname" : self._refname,
"optpointer" : self._optpointer and "*" or "",
"optreference" : self._optpointer and "&" or "",
"optaddarg" :
self._optaddarg and ", const %s value" % self._ctype or ""
}
for (k, v) in extradict.items():
mapping[k] = v
return mapping
def GetVarName(self, var):
return '%(var)s->%(name)s_data' % self.GetTranslation({ 'var' : var })
@ -451,23 +474,6 @@ class Entry:
code = code % self.GetTranslation()
return code.split('\n')
def Verify(self):
if self.Array() and not self._can_be_array:
print >>sys.stderr, (
'Entry "%s" cannot be created as an array '
'around line %d' ) % (self._name, self.LineCount())
sys.exit(1)
if not self._struct:
print >>sys.stderr, (
'Entry "%s" does not know which struct it belongs to '
'around line %d' ) % (self._name, self.LineCount())
sys.exit(1)
if self._optional and self._array:
print >>sys.stderr, ( 'Entry "%s" has illegal combination of '
'optional and array around line %d' ) % (
self._name, self.LineCount() )
sys.exit(1)
class EntryBytes(Entry):
def __init__(self, type, name, tag, length):
# Init base class
@ -1261,7 +1267,7 @@ def NormalizeLine(line):
return line
def ProcessOneEntry(newstruct, entry):
def ProcessOneEntry(factory, newstruct, entry):
optional = 0
array = 0
entry_type = ''
@ -1327,19 +1333,19 @@ def ProcessOneEntry(newstruct, entry):
# Create the right entry
if entry_type == 'bytes':
if fixed_length:
newentry = EntryBytes(entry_type, name, tag, fixed_length)
newentry = factory.EntryBytes(entry_type, name, tag, fixed_length)
else:
newentry = EntryVarBytes(entry_type, name, tag)
newentry = factory.EntryVarBytes(entry_type, name, tag)
elif entry_type == 'int' and not fixed_length:
newentry = EntryInt(entry_type, name, tag)
newentry = factory.EntryInt(entry_type, name, tag)
elif entry_type == 'string' and not fixed_length:
newentry = EntryString(entry_type, name, tag)
newentry = factory.EntryString(entry_type, name, tag)
else:
res = re.match(r'^struct\[(%s)\]$' % _STRUCT_RE,
entry_type, re.IGNORECASE)
if res:
# References another struct defined in our file
newentry = EntryStruct(entry_type, name, tag, res.group(1))
newentry = factory.EntryStruct(entry_type, name, tag, res.group(1))
else:
print >>sys.stderr, 'Bad type: "%s" in "%s"' % (entry_type, entry)
sys.exit(1)
@ -1369,11 +1375,11 @@ def ProcessOneEntry(newstruct, entry):
return structs
def ProcessStruct(data):
def ProcessStruct(factory, data):
tokens = data.split(' ')
# First three tokens are: 'struct' 'name' '{'
newstruct = Struct(tokens[1])
newstruct = factory.Struct(tokens[1])
inside = ' '.join(tokens[3:-1])
@ -1387,7 +1393,7 @@ def ProcessStruct(data):
continue
# It's possible that new structs get defined in here
structs.extend(ProcessOneEntry(newstruct, entry))
structs.extend(ProcessOneEntry(factory, newstruct, entry))
structs.append(newstruct)
return structs
@ -1472,7 +1478,7 @@ def GetNextStruct(file):
return data
def Parse(file):
def Parse(factory, file):
"""
Parses the input file and returns C code and corresponding header file.
"""
@ -1486,93 +1492,114 @@ def Parse(file):
if not data:
break
entities.extend(ProcessStruct(data))
entities.extend(ProcessStruct(factory, data))
return entities
def GuardName(name):
name = '_'.join(name.split('.'))
name = '_'.join(name.split('/'))
guard = '_'+name.upper()+'_'
class CCodeGenerator:
def __init__(self):
pass
return guard
def GuardName(self, name):
name = '_'.join(name.split('.'))
name = '_'.join(name.split('/'))
guard = '_' + name.upper() + '_'
def HeaderPreamble(name):
guard = GuardName(name)
pre = (
'/*\n'
' * Automatically generated from %s\n'
' */\n\n'
'#ifndef %s\n'
'#define %s\n\n' ) % (
name, guard, guard)
return guard
# insert stdint.h - let's hope everyone has it
pre += (
'#include <event-config.h>\n'
'#ifdef _EVENT_HAVE_STDINT_H\n'
'#include <stdint.h>\n'
'#endif\n' )
for statement in headerdirect:
pre += '%s\n' % statement
if headerdirect:
pre += '\n'
pre += (
'#define EVTAG_HAS(msg, member) ((msg)->member##_set == 1)\n'
'#define EVTAG_ASSIGN(msg, member, args...) '
'(*(msg)->base->member##_assign)(msg, ## args)\n'
'#define EVTAG_GET(msg, member, args...) '
'(*(msg)->base->member##_get)(msg, ## args)\n'
'#define EVTAG_ADD(msg, member, args...) '
'(*(msg)->base->member##_add)(msg, ## args)\n'
'#define EVTAG_LEN(msg, member) ((msg)->member##_length)\n'
)
return pre
def HeaderPostamble(name):
guard = GuardName(name)
return '#endif /* %s */' % guard
def BodyPreamble(name):
global _NAME
global _VERSION
header_file = '.'.join(name.split('.')[:-1]) + '.gen.h'
pre = ( '/*\n'
def HeaderPreamble(self, name):
guard = self.GuardName(name)
pre = (
'/*\n'
' * Automatically generated from %s\n'
' * by %s/%s. DO NOT EDIT THIS FILE.\n'
' */\n\n' ) % (name, _NAME, _VERSION)
pre += ( '#include <sys/types.h>\n'
'#include <sys/time.h>\n'
'#include <stdlib.h>\n'
'#include <string.h>\n'
'#include <assert.h>\n'
'#include <event.h>\n\n' )
' */\n\n'
'#ifndef %s\n'
'#define %s\n\n' ) % (
name, guard, guard)
for statement in cppdirect:
pre += '%s\n' % statement
# insert stdint.h - let's hope everyone has it
pre += (
'#include <event-config.h>\n'
'#ifdef _EVENT_HAVE_STDINT_H\n'
'#include <stdint.h>\n'
'#endif\n' )
pre += '\n#include "%s"\n\n' % header_file
for statement in headerdirect:
pre += '%s\n' % statement
if headerdirect:
pre += '\n'
pre += 'void event_err(int eval, const char *fmt, ...);\n'
pre += 'void event_warn(const char *fmt, ...);\n'
pre += 'void event_errx(int eval, const char *fmt, ...);\n'
pre += 'void event_warnx(const char *fmt, ...);\n\n'
pre += (
'#define EVTAG_HAS(msg, member) ((msg)->member##_set == 1)\n'
'#define EVTAG_ASSIGN(msg, member, args...) '
'(*(msg)->base->member##_assign)(msg, ## args)\n'
'#define EVTAG_GET(msg, member, args...) '
'(*(msg)->base->member##_get)(msg, ## args)\n'
'#define EVTAG_ADD(msg, member, args...) '
'(*(msg)->base->member##_add)(msg, ## args)\n'
'#define EVTAG_LEN(msg, member) ((msg)->member##_length)\n'
)
return pre
return pre
def main(argv):
if len(argv) < 2 or not argv[1]:
print >>sys.stderr, 'Need RPC description file as first argument.'
sys.exit(1)
def HeaderPostamble(self, name):
guard = self.GuardName(name)
return '#endif /* %s */' % guard
filename = argv[1]
def BodyPreamble(self, name):
global _NAME
global _VERSION
header_file = '.'.join(name.split('.')[:-1]) + '.gen.h'
pre = ( '/*\n'
' * Automatically generated from %s\n'
' * by %s/%s. DO NOT EDIT THIS FILE.\n'
' */\n\n' ) % (name, _NAME, _VERSION)
pre += ( '#include <sys/types.h>\n'
'#include <sys/time.h>\n'
'#include <stdlib.h>\n'
'#include <string.h>\n'
'#include <assert.h>\n'
'#include <event.h>\n\n' )
for statement in cppdirect:
pre += '%s\n' % statement
pre += '\n#include "%s"\n\n' % header_file
pre += 'void event_err(int eval, const char *fmt, ...);\n'
pre += 'void event_warn(const char *fmt, ...);\n'
pre += 'void event_errx(int eval, const char *fmt, ...);\n'
pre += 'void event_warnx(const char *fmt, ...);\n\n'
return pre
def HeaderFilename(self, filename):
return '.'.join(filename.split('.')[:-1]) + '.gen.h'
def CodeFilename(self, filename):
return '.'.join(filename.split('.')[:-1]) + '.gen.c'
def Struct(self, name):
return StructCCode(name)
def EntryBytes(self, entry_type, name, tag, fixed_length):
return EntryBytes(entry_type, name, tag, fixed_length)
def EntryVarBytes(self, entry_type, name, tag):
return EntryVarBytes(entry_type, name, tag)
def EntryInt(self, entry_type, name, tag):
return EntryInt(entry_type, name, tag)
def EntryString(self, entry_type, name, tag):
return EntryString(entry_type, name, tag)
def EntryStruct(self, entry_type, name, tag, struct_name):
return EntryStruct(entry_type, name, tag, struct_name)
def Generate(factory, filename):
ext = filename.split('.')[-1]
if ext != 'rpc':
print >>sys.stderr, 'Unrecognized file extension: %s' % ext
@ -1581,15 +1608,15 @@ def main(argv):
print >>sys.stderr, 'Reading \"%s\"' % filename
fp = open(filename, 'r')
entities = Parse(fp)
entities = Parse(factory, fp)
fp.close()
header_file = '.'.join(filename.split('.')[:-1]) + '.gen.h'
impl_file = '.'.join(filename.split('.')[:-1]) + '.gen.c'
header_file = factory.HeaderFilename(filename)
impl_file = factory.CodeFilename(filename)
print >>sys.stderr, '... creating "%s"' % header_file
header_fp = open(header_file, 'w')
print >>header_fp, HeaderPreamble(filename)
print >>header_fp, factory.HeaderPreamble(filename)
# Create forward declarations: allows other structs to reference
# each other
@ -1600,15 +1627,22 @@ def main(argv):
for entry in entities:
entry.PrintTags(header_fp)
entry.PrintDeclaration(header_fp)
print >>header_fp, HeaderPostamble(filename)
print >>header_fp, factory.HeaderPostamble(filename)
header_fp.close()
print >>sys.stderr, '... creating "%s"' % impl_file
impl_fp = open(impl_file, 'w')
print >>impl_fp, BodyPreamble(filename)
print >>impl_fp, factory.BodyPreamble(filename)
for entry in entities:
entry.PrintCode(impl_fp)
impl_fp.close()
def main(argv):
if len(argv) < 2 or not argv[1]:
print >>sys.stderr, 'Need RPC description file as first argument.'
sys.exit(1)
Generate(CCodeGenerator(), argv[1])
if __name__ == '__main__':
main(sys.argv)