From b117bdb3448778d9e7f9a0302791e8ac3bb97ddd Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Sat, 16 Nov 2024 12:00:28 -0300 Subject: [PATCH] Counter for length of chains of __call metamethods This counter will allow (in a later commit) error messages to correct argument numbers in functions called through __call metamethods. --- ldo.c | 41 +++++++++++++++++++++++++++-------------- lstate.h | 28 +++++++++++++++++----------- manual/manual.of | 4 ++++ testes/calls.lua | 23 ++++++++++++++++++++--- 4 files changed, 68 insertions(+), 28 deletions(-) diff --git a/ldo.c b/ldo.c index 72a1e306..cb7e5aef 100644 --- a/ldo.c +++ b/ldo.c @@ -464,21 +464,26 @@ static void rethook (lua_State *L, CallInfo *ci, int nres) { /* ** Check whether 'func' has a '__call' metafield. If so, put it in the -** stack, below original 'func', so that 'luaD_precall' can call it. Raise -** an error if there is no '__call' metafield. +** stack, below original 'func', so that 'luaD_precall' can call it. +** Raise an error if there is no '__call' metafield. +** Bits CIST_CCMT in status count how many _call metamethods were +** invoked and how many corresponding extra arguments were pushed. +** (This count will be saved in the 'callstatus' of the call). +** Raise an error if this counter overflows. */ -static StkId tryfuncTM (lua_State *L, StkId func) { +static unsigned tryfuncTM (lua_State *L, StkId func, unsigned status) { const TValue *tm; StkId p; - checkstackp(L, 1, func); /* space for metamethod */ - tm = luaT_gettmbyobj(L, s2v(func), TM_CALL); /* (after previous GC) */ - if (l_unlikely(ttisnil(tm))) - luaG_callerror(L, s2v(func)); /* nothing to call */ + tm = luaT_gettmbyobj(L, s2v(func), TM_CALL); + if (l_unlikely(ttisnil(tm))) /* no metamethod? */ + luaG_callerror(L, s2v(func)); for (p = L->top.p; p > func; p--) /* open space for metamethod */ setobjs2s(L, p, p-1); L->top.p++; /* stack space pre-allocated by the caller */ setobj2s(L, func, tm); /* metamethod is the new function to be called */ - return func; + if ((status & MAX_CCMT) == MAX_CCMT) /* is counter full? */ + luaG_runerror(L, "'__call' chain too long"); + return status + (1u << CIST_CCMT); /* increment counter */ } @@ -564,11 +569,17 @@ void luaD_poscall (lua_State *L, CallInfo *ci, int nres) { #define next_ci(L) (L->ci->next ? L->ci->next : luaE_extendCI(L)) +/* +** Allocate and initialize CallInfo structure. At this point, the +** only valid fields in the call status are number of results, +** CIST_C (if it's a C function), and number of extra arguments. +** (All these bit-fields fit in 16-bit values.) +*/ l_sinline CallInfo *prepCallInfo (lua_State *L, StkId func, unsigned status, StkId top) { CallInfo *ci = L->ci = next_ci(L); /* new frame */ ci->func.p = func; - lua_assert((status & ~(CIST_NRESULTS | CIST_C)) == 0); + lua_assert((status & ~(CIST_NRESULTS | CIST_C | MAX_CCMT)) == 0); ci->callstatus = status; ci->top.p = top; return ci; @@ -607,12 +618,13 @@ l_sinline int precallC (lua_State *L, StkId func, unsigned status, */ int luaD_pretailcall (lua_State *L, CallInfo *ci, StkId func, int narg1, int delta) { + unsigned status = LUA_MULTRET + 1; retry: switch (ttypetag(s2v(func))) { case LUA_VCCL: /* C closure */ - return precallC(L, func, LUA_MULTRET + 1, clCvalue(s2v(func))->f); + return precallC(L, func, status, clCvalue(s2v(func))->f); case LUA_VLCF: /* light C function */ - return precallC(L, func, LUA_MULTRET + 1, fvalue(s2v(func))); + return precallC(L, func, status, fvalue(s2v(func))); case LUA_VLCL: { /* Lua function */ Proto *p = clLvalue(s2v(func))->p; int fsize = p->maxstacksize; /* frame size */ @@ -633,8 +645,8 @@ int luaD_pretailcall (lua_State *L, CallInfo *ci, StkId func, return -1; } default: { /* not a function */ - func = tryfuncTM(L, func); /* try to get '__call' metamethod */ - /* return luaD_pretailcall(L, ci, func, narg1 + 1, delta); */ + checkstackp(L, 1, func); /* space for metamethod */ + status = tryfuncTM(L, func, status); /* try '__call' metamethod */ narg1++; goto retry; /* try again */ } @@ -676,7 +688,8 @@ CallInfo *luaD_precall (lua_State *L, StkId func, int nresults) { return ci; } default: { /* not a function */ - func = tryfuncTM(L, func); /* try to get '__call' metamethod */ + checkstackp(L, 1, func); /* space for metamethod */ + status = tryfuncTM(L, func, status); /* try '__call' metamethod */ goto retry; /* try again with metamethod */ } } diff --git a/lstate.h b/lstate.h index ab567213..1c81b6ed 100644 --- a/lstate.h +++ b/lstate.h @@ -221,16 +221,24 @@ struct CallInfo { */ /* bits 0-7 are the expected number of results from this function + 1 */ #define CIST_NRESULTS 0xffu -/* Bits 8-10 are used for CIST_RECST (see below) */ -#define CIST_RECST 8 /* the offset, not the mask */ -/* original value of 'allowhook' */ -#define CIST_OAH (cast(l_uint32, 1) << 11) -/* call is running a C function */ -#define CIST_C (CIST_OAH << 1) + +/* bits 8-11 count call metamethods (and their extra arguments) */ +#define CIST_CCMT 8 /* the offset, not the mask */ +#define MAX_CCMT (0xfu << CIST_CCMT) + +/* Bits 12-14 are used for CIST_RECST (see below) */ +#define CIST_RECST 12 /* the offset, not the mask */ + +/* call is running a C function (still in first 16 bits) */ +#define CIST_C (1u << (CIST_RECST + 3)) /* call is on a fresh "luaV_execute" frame */ -#define CIST_FRESH (CIST_C << 1) +#define CIST_FRESH cast(l_uint32, CIST_C << 1) +/* function is closing tbc variables */ +#define CIST_CLSRET (CIST_FRESH << 1) +/* original value of 'allowhook' */ +#define CIST_OAH (CIST_CLSRET << 1) /* call is running a debug hook */ -#define CIST_HOOKED (CIST_FRESH << 1) +#define CIST_HOOKED (CIST_OAH << 1) /* doing a yieldable protected call */ #define CIST_YPCALL (CIST_HOOKED << 1) /* call was tail called */ @@ -239,11 +247,9 @@ struct CallInfo { #define CIST_HOOKYIELD (CIST_TAIL << 1) /* function "called" a finalizer */ #define CIST_FIN (CIST_HOOKYIELD << 1) - /* function is closing tbc variables */ -#define CIST_CLSRET (CIST_FIN << 1) #if defined(LUA_COMPAT_LT_LE) /* using __lt for __le */ -#define CIST_LEQ (CIST_CLSRET << 1) +#define CIST_LEQ (CIST_FIN << 1) #endif diff --git a/manual/manual.of b/manual/manual.of index f0b17b4c..ce42ff51 100644 --- a/manual/manual.of +++ b/manual/manual.of @@ -9392,6 +9392,10 @@ If you need to change it, declare a local variable with the same name in the loop body. } +@item{ +A chain of @id{__call} metamethods can have at most 15 objects. +} + } } diff --git a/testes/calls.lua b/testes/calls.lua index 409a275d..12312d60 100644 --- a/testes/calls.lua +++ b/testes/calls.lua @@ -178,7 +178,7 @@ do -- tail calls x chain of __call end -- build a chain of __call metamethods ending in function 'foo' - for i = 1, 100 do + for i = 1, 15 do foo = setmetatable({}, {__call = foo}) end @@ -190,8 +190,8 @@ end print('+') -do -- testing chains of '__call' - local N = 20 +do print"testing chains of '__call'" + local N = 15 local u = table.pack for i = 1, N do u = setmetatable({i}, {__call = u}) @@ -207,6 +207,23 @@ do -- testing chains of '__call' end +do -- testing chains too long + local a = {} + for i = 1, 16 do -- one too many + a = setmetatable({}, {__call = a}) + end + local status, msg = pcall(a) + assert(not status and string.find(msg, "too long")) + + setmetatable(a, {__call = a}) -- infinite chain + local status, msg = pcall(a) + assert(not status and string.find(msg, "too long")) + + -- again, with a tail call + local status, msg = pcall(function () return a() end) + assert(not status and string.find(msg, "too long")) +end + a = nil (function (x) a=x end)(23) assert(a == 23 and (function (x) return x*2 end)(20) == 40)