diff --git a/package/socket/_socket.c b/package/socket/_socket.c index bfd2e5357..f19bef60e 100644 --- a/package/socket/_socket.c +++ b/package/socket/_socket.c @@ -1,5 +1,5 @@ -#include "_socket_socket.h" #include "PikaPlatform_socket.h" +#include "_socket_socket.h" #if !PIKASCRIPT_VERSION_REQUIRE_MINIMUN(1, 12, 0) #error "This library requires PikaScript version 1.12.0 or higher" @@ -17,6 +17,7 @@ void _socket_socket__init(PikaObj* self) { return; } obj_setInt(self, "sockfd", sockfd); + obj_setInt(self, "blocking", 1); } void _socket_socket__close(PikaObj* self) { @@ -71,9 +72,11 @@ Arg* _socket_socket__recv(PikaObj* self, int num) { data_recv = arg_getBytes(res); ret = __platform_recv(sockfd, data_recv, num, 0); if (ret < 0) { - obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR); - __platform_printf("recv error\n"); - return NULL; + if (obj_getInt(self, "blocking")) { + obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR); + __platform_printf("recv error\n"); + return NULL; + } } return res; } @@ -91,6 +94,14 @@ void _socket_socket__connect(PikaObj* self, char* host, int port) { server_addr.sin_addr.s_addr = inet_addr(host); __platform_connect(sockfd, (struct sockaddr*)&server_addr, sizeof(server_addr)); + if (obj_getInt(self, "blocking") == 0) { + int flags = fcntl(sockfd, F_GETFL); + if (fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) { + obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR); + __platform_printf("Unable to set socket non blocking\n"); + return; + } + } } void _socket_socket__bind(PikaObj* self, char* host, int port) { @@ -114,3 +125,7 @@ char* _socket__gethostname(PikaObj* self) { __platform_gethostname(hostname_buff, 128); return obj_cacheStr(self, hostname); } + +void _socket_socket__setblocking(PikaObj* self, int sta) { + obj_setInt(self, "blocking", sta); +} \ No newline at end of file diff --git a/package/socket/_socket.pyi b/package/socket/_socket.pyi index 32aba0159..134d34429 100644 --- a/package/socket/_socket.pyi +++ b/package/socket/_socket.pyi @@ -10,6 +10,7 @@ class socket: def _connect(host: str, port: int): ... def _recv(num: int) -> bytes: ... def _init(): ... + def _setblocking(sta: int): ... def _gethostname() -> str: ... diff --git a/package/socket/socket.py b/package/socket/socket.py index 4810a7555..d1420b1b5 100644 --- a/package/socket/socket.py +++ b/package/socket/socket.py @@ -50,5 +50,8 @@ class socket(_socket.socket): def recv(self, num): return self._recv(num) + def setblocking(self, sta): + return self._setblocking(sta) + def gethostname(): return _socket._gethostname() diff --git a/port/linux/package/pikascript/_socket.pyi b/port/linux/package/pikascript/_socket.pyi index 32aba0159..134d34429 100644 --- a/port/linux/package/pikascript/_socket.pyi +++ b/port/linux/package/pikascript/_socket.pyi @@ -10,6 +10,7 @@ class socket: def _connect(host: str, port: int): ... def _recv(num: int) -> bytes: ... def _init(): ... + def _setblocking(sta: int): ... def _gethostname() -> str: ... diff --git a/port/linux/package/pikascript/pikascript-lib/socket/_socket.c b/port/linux/package/pikascript/pikascript-lib/socket/_socket.c index bfd2e5357..f19bef60e 100644 --- a/port/linux/package/pikascript/pikascript-lib/socket/_socket.c +++ b/port/linux/package/pikascript/pikascript-lib/socket/_socket.c @@ -1,5 +1,5 @@ -#include "_socket_socket.h" #include "PikaPlatform_socket.h" +#include "_socket_socket.h" #if !PIKASCRIPT_VERSION_REQUIRE_MINIMUN(1, 12, 0) #error "This library requires PikaScript version 1.12.0 or higher" @@ -17,6 +17,7 @@ void _socket_socket__init(PikaObj* self) { return; } obj_setInt(self, "sockfd", sockfd); + obj_setInt(self, "blocking", 1); } void _socket_socket__close(PikaObj* self) { @@ -71,9 +72,11 @@ Arg* _socket_socket__recv(PikaObj* self, int num) { data_recv = arg_getBytes(res); ret = __platform_recv(sockfd, data_recv, num, 0); if (ret < 0) { - obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR); - __platform_printf("recv error\n"); - return NULL; + if (obj_getInt(self, "blocking")) { + obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR); + __platform_printf("recv error\n"); + return NULL; + } } return res; } @@ -91,6 +94,14 @@ void _socket_socket__connect(PikaObj* self, char* host, int port) { server_addr.sin_addr.s_addr = inet_addr(host); __platform_connect(sockfd, (struct sockaddr*)&server_addr, sizeof(server_addr)); + if (obj_getInt(self, "blocking") == 0) { + int flags = fcntl(sockfd, F_GETFL); + if (fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) { + obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR); + __platform_printf("Unable to set socket non blocking\n"); + return; + } + } } void _socket_socket__bind(PikaObj* self, char* host, int port) { @@ -114,3 +125,7 @@ char* _socket__gethostname(PikaObj* self) { __platform_gethostname(hostname_buff, 128); return obj_cacheStr(self, hostname); } + +void _socket_socket__setblocking(PikaObj* self, int sta) { + obj_setInt(self, "blocking", sta); +} \ No newline at end of file diff --git a/port/linux/package/pikascript/socket.py b/port/linux/package/pikascript/socket.py index 4810a7555..d1420b1b5 100644 --- a/port/linux/package/pikascript/socket.py +++ b/port/linux/package/pikascript/socket.py @@ -50,5 +50,8 @@ class socket(_socket.socket): def recv(self, num): return self._recv(num) + def setblocking(self, sta): + return self._setblocking(sta) + def gethostname(): return _socket._gethostname()