diff --git a/lapi.c b/lapi.c index bb76b15a..69b890cd 100644 --- a/lapi.c +++ b/lapi.c @@ -781,7 +781,7 @@ LUA_API int lua_rawgetp (lua_State *L, int idx, const void *p) { } -LUA_API void lua_createtable (lua_State *L, int narray, int nrec) { +LUA_API void lua_createtable (lua_State *L, unsigned narray, unsigned nrec) { Table *t; lua_lock(L); t = luaH_new(L); diff --git a/ltable.c b/ltable.c index 353e567b..b86e2281 100644 --- a/ltable.c +++ b/ltable.c @@ -61,18 +61,25 @@ typedef union { /* -** MAXABITS is the largest integer such that MAXASIZE fits in an +** MAXABITS is the largest integer such that 2^MAXABITS fits in an ** unsigned int. */ #define MAXABITS cast_int(sizeof(int) * CHAR_BIT - 1) /* -** MAXASIZE is the maximum size of the array part. It is the minimum -** between 2^MAXABITS and the maximum size that, measured in bytes, -** fits in a 'size_t'. +** MAXASIZEB is the maximum number of elements in the array part such +** that the size of the array fits in 'size_t'. */ -#define MAXASIZE luaM_limitN(1u << MAXABITS, TValue) +#define MAXASIZEB ((MAX_SIZET/sizeof(ArrayCell)) * NM) + + +/* +** MAXASIZE is the maximum size of the array part. It is the minimum +** between 2^MAXABITS and MAXASIZEB. +*/ +#define MAXASIZE \ + (((1u << MAXABITS) < MAXASIZEB) ? (1u << MAXABITS) : cast_uint(MAXASIZEB)) /* ** MAXHBITS is the largest integer such that 2^MAXHBITS fits in a @@ -663,6 +670,8 @@ void luaH_resize (lua_State *L, Table *t, unsigned int newasize, Table newt; /* to keep the new hash part */ unsigned int oldasize = setlimittosize(t); ArrayCell *newarray; + if (newasize > MAXASIZE) + luaG_runerror(L, "table overflow"); /* create new hash part with appropriate size into 'newt' */ newt.flags = 0; setnodevector(L, &newt, nhsize); diff --git a/ltablib.c b/ltablib.c index c8838963..2ba31a4f 100644 --- a/ltablib.c +++ b/ltablib.c @@ -59,8 +59,10 @@ static void checktab (lua_State *L, int arg, int what) { static int tcreate (lua_State *L) { - int sizeseq = (int)luaL_checkinteger(L, 1); - int sizerest = (int)luaL_optinteger(L, 2, 0); + lua_Unsigned sizeseq = (lua_Unsigned)luaL_checkinteger(L, 1); + lua_Unsigned sizerest = (lua_Unsigned)luaL_optinteger(L, 2, 0); + luaL_argcheck(L, sizeseq <= UINT_MAX, 1, "out of range"); + luaL_argcheck(L, sizerest <= UINT_MAX, 2, "out of range"); lua_createtable(L, sizeseq, sizerest); return 1; } diff --git a/lua.h b/lua.h index 58f31646..b8e0c571 100644 --- a/lua.h +++ b/lua.h @@ -268,7 +268,7 @@ LUA_API int (lua_rawget) (lua_State *L, int idx); LUA_API int (lua_rawgeti) (lua_State *L, int idx, lua_Integer n); LUA_API int (lua_rawgetp) (lua_State *L, int idx, const void *p); -LUA_API void (lua_createtable) (lua_State *L, int narr, int nrec); +LUA_API void (lua_createtable) (lua_State *L, unsigned narr, unsigned nrec); LUA_API void *(lua_newuserdatauv) (lua_State *L, size_t sz, int nuvalue); LUA_API int (lua_getmetatable) (lua_State *L, int objindex); LUA_API int (lua_getiuservalue) (lua_State *L, int idx, int n); diff --git a/manual/manual.of b/manual/manual.of index aaaf15b7..cdd54f66 100644 --- a/manual/manual.of +++ b/manual/manual.of @@ -3234,7 +3234,7 @@ Values at other positions are not affected. } -@APIEntry{void lua_createtable (lua_State *L, int nseq, int nrec);| +@APIEntry{void lua_createtable (lua_State *L, unsigned nseq, unsigned nrec);| @apii{0,1,m} Creates a new empty table and pushes it onto the stack. diff --git a/testes/sort.lua b/testes/sort.lua index 7e566a5a..442b3129 100644 --- a/testes/sort.lua +++ b/testes/sort.lua @@ -3,33 +3,6 @@ print "testing (parts of) table library" -do print "testing 'table.create'" - collectgarbage() - local m = collectgarbage("count") * 1024 - local t = table.create(10000) - local memdiff = collectgarbage("count") * 1024 - m - assert(memdiff > 10000 * 4) - for i = 1, 20 do - assert(#t == i - 1) - t[i] = 0 - end - for i = 1, 20 do t[#t + 1] = i * 10 end - assert(#t == 40 and t[39] == 190) - assert(not T or T.querytab(t) == 10000) - t = nil - collectgarbage() - m = collectgarbage("count") * 1024 - t = table.create(0, 1024) - memdiff = collectgarbage("count") * 1024 - m - assert(memdiff > 1024 * 12) - assert(not T or select(2, T.querytab(t)) == 1024) -end - - -print "testing unpack" - -local unpack = table.unpack - local maxI = math.maxinteger local minI = math.mininteger @@ -40,6 +13,38 @@ local function checkerror (msg, f, ...) end +do print "testing 'table.create'" + local N = 10000 + collectgarbage() + local m = collectgarbage("count") * 1024 + local t = table.create(N) + local memdiff = collectgarbage("count") * 1024 - m + assert(memdiff > N * 4) + for i = 1, 20 do + assert(#t == i - 1) + t[i] = 0 + end + for i = 1, 20 do t[#t + 1] = i * 10 end + assert(#t == 40 and t[39] == 190) + assert(not T or T.querytab(t) == N) + t = nil + collectgarbage() + m = collectgarbage("count") * 1024 + t = table.create(0, 1024) + memdiff = collectgarbage("count") * 1024 - m + assert(memdiff > 1024 * 12) + assert(not T or select(2, T.querytab(t)) == 1024) + + checkerror("table overflow", table.create, (1<<31) + 1) + checkerror("table overflow", table.create, 0, (1<<31) + 1) +end + + +print "testing unpack" + +local unpack = table.unpack + + checkerror("wrong number of arguments", table.insert, {}, 2, 3, 4) local x,y,z,a,n