diff --git a/examples/mqttnet.c b/examples/mqttnet.c index 53111939..6959fcdb 100644 --- a/examples/mqttnet.c +++ b/examples/mqttnet.c @@ -391,37 +391,66 @@ static int NetRead(void *context, byte* buf, int buf_len, * MQTT_CODE_CONTINUE, or proceed with a smaller buffer read/write. * Used for testing nonblocking. */ static int -mqttcurl_test_nonblock(int* buf_len, int for_recv) +mqttcurl_test_nonblock_read(int* buf_len) { - static int testNbAlt = 0; - static int testSmallerBuf = 0; - #if !defined(WOLFMQTT_DEBUG_SOCKET) - (void)for_recv; - #endif + static int testNbReadAlt = 0; + static int testSmallerRead = 0; + + if (testNbReadAlt < WOLFMQTT_TEST_NONBLOCK_TIMES) { + testNbReadAlt++; + #if defined(WOLFMQTT_DEBUG_SOCKET) + PRINTF("mqttcurl_test_nonblock_read: returning early with CONTINUE"); + #endif + return MQTT_CODE_CONTINUE; + } + + testNbReadAlt = 0; + + if (!testSmallerRead) { + if (*buf_len > 2) { + *buf_len /= 2; + testSmallerRead = 1; + #if defined(WOLFMQTT_DEBUG_SOCKET) + PRINTF("mqttcurl_test_nonblock_read: testing small buff: %d", + *buf_len); + #endif + } + } + else { + testSmallerRead = 0; + } + + return MQTT_CODE_SUCCESS; +} - if (testNbAlt < WOLFMQTT_TEST_NONBLOCK_TIMES) { - testNbAlt++; +static int +mqttcurl_test_nonblock_write(int* buf_len) +{ + static int testNbWriteAlt = 0; + static int testSmallerWrite = 0; + + if (testNbWriteAlt < WOLFMQTT_TEST_NONBLOCK_TIMES) { + testNbWriteAlt++; #if defined(WOLFMQTT_DEBUG_SOCKET) - PRINTF("mqttcurl_test_nonblock(%d): returning early with CONTINUE", - for_recv); + PRINTF("mqttcurl_test_nonblock_write: returning early with CONTINUE"); #endif return MQTT_CODE_CONTINUE; } - testNbAlt = 0; + testNbWriteAlt = 0; - if (!testSmallerBuf) { + if (!testSmallerWrite) { if (*buf_len > 2) { *buf_len /= 2; - testSmallerBuf = 1; + testSmallerWrite = 1; #if defined(WOLFMQTT_DEBUG_SOCKET) - PRINTF("mqttcurl_test_nonblock(%d): testing small buff: %d", - for_recv, *buf_len); + PRINTF("mqttcurl_test_nonblock_write: testing small buff: %d", + *buf_len); #endif } } else { - testSmallerBuf = 0; + testSmallerWrite = 0; } return MQTT_CODE_SUCCESS; @@ -745,7 +774,7 @@ static int NetWrite(void *context, const byte* buf, int buf_len, #if defined(WOLFMQTT_NONBLOCK) && defined(WOLFMQTT_TEST_NONBLOCK) if (sock->mqttCtx->useNonBlockMode) { - if (mqttcurl_test_nonblock(&buf_len, 0)) { + if (mqttcurl_test_nonblock_write(&buf_len)) { return MQTT_CODE_CONTINUE; } } @@ -773,8 +802,19 @@ static int NetWrite(void *context, const byte* buf, int buf_len, * payload will be transferred in a single shot without buffering. * todo: add buffering? */ for (size_t i = 0; i < MQTT_CURL_NUM_RETRY; ++i) { + #ifdef WOLFMQTT_MULTITHREAD + int rc = wm_SemLock(&sock->mqttCtx->client.lockCURL); + if (rc != 0) { + return rc; + } + #endif + res = curl_easy_send(sock->curl, buf, buf_len, &sent); + #ifdef WOLFMQTT_MULTITHREAD + wm_SemUnlock(&sock->mqttCtx->client.lockCURL); + #endif + if (res == CURLE_OK) { #if defined(WOLFMQTT_DEBUG_SOCKET) PRINTF("info: curl_easy_send(%d) returned: %d, %s", buf_len, res, @@ -828,7 +868,7 @@ static int NetRead(void *context, byte* buf, int buf_len, #if defined(WOLFMQTT_NONBLOCK) && defined(WOLFMQTT_TEST_NONBLOCK) if (sock->mqttCtx->useNonBlockMode) { - if (mqttcurl_test_nonblock(&buf_len, 1)) { + if (mqttcurl_test_nonblock_read(&buf_len)) { return MQTT_CODE_CONTINUE; } } @@ -856,8 +896,19 @@ static int NetRead(void *context, byte* buf, int buf_len, * payload will be transferred in a single shot without buffering. * todo: add buffering? */ for (size_t i = 0; i < MQTT_CURL_NUM_RETRY; ++i) { + #ifdef WOLFMQTT_MULTITHREAD + int rc = wm_SemLock(&sock->mqttCtx->client.lockCURL); + if (rc != 0) { + return rc; + } + #endif + res = curl_easy_recv(sock->curl, buf, buf_len, &recvd); + #ifdef WOLFMQTT_MULTITHREAD + wm_SemUnlock(&sock->mqttCtx->client.lockCURL); + #endif + if (res == CURLE_OK) { #if defined(WOLFMQTT_DEBUG_SOCKET) PRINTF("info: curl_easy_recv(%d) returned: %d, %s", buf_len, res, diff --git a/src/mqtt_client.c b/src/mqtt_client.c index f8e6b7d8..5aff6e80 100644 --- a/src/mqtt_client.c +++ b/src/mqtt_client.c @@ -1551,6 +1551,11 @@ int MqttClient_Init(MqttClient *client, MqttNet* net, if (rc == 0) { rc = wm_SemInit(&client->lockClient); } + #ifdef ENABLE_MQTT_CURL + if (rc == 0) { + rc = wm_SemInit(&client->lockCURL); + } + #endif #endif if (rc == 0) { @@ -1573,6 +1578,9 @@ void MqttClient_DeInit(MqttClient *client) (void)wm_SemFree(&client->lockSend); (void)wm_SemFree(&client->lockRecv); (void)wm_SemFree(&client->lockClient); + #ifdef ENABLE_MQTT_CURL + (void)wm_SemFree(&client->lockCURL); + #endif #endif } #ifdef WOLFMQTT_V5 diff --git a/src/mqtt_socket.c b/src/mqtt_socket.c index 03e062bb..ab68c426 100644 --- a/src/mqtt_socket.c +++ b/src/mqtt_socket.c @@ -64,7 +64,7 @@ int MqttSocket_TlsSocketReceive(WOLFSSL* ssl, char *buf, int sz, (void)ssl; /* Not used */ rc = client->net->read(client->net->context, (byte*)buf, sz, - client->tls.timeout_ms); + client->tls.timeout_ms_read); /* save network read response */ client->tls.sockRcRead = rc; @@ -87,7 +87,7 @@ int MqttSocket_TlsSocketSend(WOLFSSL* ssl, char *buf, int sz, (void)ssl; /* Not used */ rc = client->net->write(client->net->context, (byte*)buf, sz, - client->tls.timeout_ms); + client->tls.timeout_ms_write); /* save network write response */ client->tls.sockRcWrite = rc; @@ -116,7 +116,8 @@ int MqttSocket_Init(MqttClient *client, MqttNet *net) #if defined(ENABLE_MQTT_TLS) && !defined(ENABLE_MQTT_CURL) client->tls.ctx = NULL; client->tls.ssl = NULL; - client->tls.timeout_ms = client->cmd_timeout_ms; + client->tls.timeout_ms_read = client->cmd_timeout_ms; + client->tls.timeout_ms_write = client->cmd_timeout_ms; #endif /* Validate callbacks are not null! */ @@ -134,8 +135,9 @@ static int MqttSocket_WriteDo(MqttClient *client, const byte* buf, int buf_len, #if defined(ENABLE_MQTT_TLS) && !defined(ENABLE_MQTT_CURL) if (MqttClient_Flags(client,0,0) & MQTT_CLIENT_FLAG_IS_TLS) { - client->tls.timeout_ms = timeout_ms; + client->tls.timeout_ms_write = timeout_ms; client->tls.sockRcWrite = 0; /* init value */ + rc = wolfSSL_write(client->tls.ssl, (char*)buf, buf_len); if (rc < 0) { #if defined(WOLFMQTT_DEBUG_SOCKET) || defined(WOLFSSL_ASYNC_CRYPT) @@ -236,8 +238,9 @@ static int MqttSocket_ReadDo(MqttClient *client, byte* buf, int buf_len, #if defined(ENABLE_MQTT_TLS) && !defined(ENABLE_MQTT_CURL) if (MqttClient_Flags(client,0,0) & MQTT_CLIENT_FLAG_IS_TLS) { - client->tls.timeout_ms = timeout_ms; + client->tls.timeout_ms_read = timeout_ms; client->tls.sockRcRead = 0; /* init value */ + rc = wolfSSL_read(client->tls.ssl, (char*)buf, buf_len); if (rc < 0) { int error = wolfSSL_get_error(client->tls.ssl, 0); diff --git a/wolfmqtt/mqtt_client.h b/wolfmqtt/mqtt_client.h index ddb70afe..d3789bce 100644 --- a/wolfmqtt/mqtt_client.h +++ b/wolfmqtt/mqtt_client.h @@ -209,6 +209,9 @@ typedef struct _MqttClient { wm_Sem lockSend; wm_Sem lockRecv; wm_Sem lockClient; + #ifdef ENABLE_MQTT_CURL + wm_Sem lockCURL; + #endif struct _MqttPendResp* firstPendResp; /* protected with client lock */ struct _MqttPendResp* lastPendResp; /* protected with client lock */ #endif diff --git a/wolfmqtt/mqtt_socket.h b/wolfmqtt/mqtt_socket.h index cb1465e6..410c6c6b 100644 --- a/wolfmqtt/mqtt_socket.h +++ b/wolfmqtt/mqtt_socket.h @@ -71,7 +71,8 @@ typedef struct _MqttTls { WOLFSSL *ssl; int sockRcRead; int sockRcWrite; - int timeout_ms; + int timeout_ms_read; + int timeout_ms_write; } MqttTls; #endif