feat: basic dl_call support

1. add basic dl_call support
2. re-organize code structure
3. fix `wasm_runtime_call_wasm` error handler
4. update examples
This commit is contained in:
Paul Pan 2023-02-11 14:33:45 +08:00
parent 6f3888c1f0
commit a471519ff7
12 changed files with 169 additions and 88 deletions

View File

@ -36,7 +36,8 @@ include(${WAMR_ROOT_DIR}/build-scripts/runtime_lib.cmake)
add_library(vmlib ${WAMR_RUNTIME_LIB_SOURCE}) add_library(vmlib ${WAMR_RUNTIME_LIB_SOURCE})
# wrt # wrt
add_library(wrt wrt/wrt.c wrt/dl.c) file(GLOB_RECURSE WRT_SRC wrt/*.c)
add_library(wrt ${WRT_SRC})
target_include_directories(wrt PRIVATE .) target_include_directories(wrt PRIVATE .)
target_link_libraries(wrt vmlib PkgConfig::FFI ${M_LIBRARY}) target_link_libraries(wrt vmlib PkgConfig::FFI ${M_LIBRARY})

1
main.c
View File

@ -2,7 +2,6 @@
#include <stdlib.h> #include <stdlib.h>
#include "utils/log.h" #include "utils/log.h"
#include "wrt/dl.h"
#include "wrt/wrt.h" #include "wrt/wrt.h"
uint8_t *read_file(const char *filename, uint32_t *fsize) { uint8_t *read_file(const char *filename, uint32_t *fsize) {

View File

@ -3,7 +3,7 @@
extern "C" { extern "C" {
#include "utils/defs.h" #include "utils/defs.h"
#include "wrt/dl.h" #include "wrt/lib/dl.h"
#include "wasm_exec_env.h" #include "wasm_exec_env.h"
extern bool signature_check(const char *signature, int *arg_count, bool *has_ret); extern bool signature_check(const char *signature, int *arg_count, bool *has_ret);

View File

@ -9,6 +9,7 @@ int fib_tail_call(int n);
int fib_fast(int n); int fib_fast(int n);
int call_test(int a, long long b, float c, double d); int call_test(int a, long long b, float c, double d);
int call_test2(int a, long long b, int c, float d); int call_test2(int a, long long b, int c, float d);
int my_cpy(char *dest, char *src, int len);
} }
int fib_recurse(int n) { int fib_recurse(int n) {
@ -53,3 +54,12 @@ int fib_fast(int n) { return fib_fast_internal(n).first; }
int call_test(int a, long long b, float c, double d) { return a + b + c + d; } int call_test(int a, long long b, float c, double d) { return a + b + c + d; }
int call_test2(int a, long long b, int c, float d) { return a + b + c + d; } int call_test2(int a, long long b, int c, float d) { return a + b + c + d; }
int my_cpy(char *dest, char *src, int len) {
char seed = 0x42;
for (int i = 0; i < len; i++) {
dest[i] = src[i];
seed ^= src[i];
}
return seed;
}

View File

@ -8,7 +8,7 @@ typedef enum {
int dl_open(const char *filename, int flags); int dl_open(const char *filename, int flags);
int dl_sym(int handle, const char *symbol, const char *signature); int dl_sym(int handle, const char *symbol, const char *signature);
Status dl_call(int symbol, ...); void *dl_call(int symbol, ...);
Status dl_close_sym(int handle, int symbol); Status dl_close_sym(int handle, int symbol);
Status dl_close(int handle); Status dl_close(int handle);

View File

@ -3,48 +3,89 @@
unsigned fib(unsigned n) { return n < 2 ? n : fib(n - 1) + fib(n - 2); } unsigned fib(unsigned n) { return n < 2 ? n : fib(n - 1) + fib(n - 2); }
void entry() { void dl_test() {
printf("Hello WASM\n"); int hnd, fib_iterate_sym, call_test_sym, call_test2_sym, my_cpy_sym;
Status status;
void *result;
for (unsigned i = 0; i <= 5; i++) printf("fib(%u) = %u\n", i, fib(i)); { // Load library
hnd = dl_open("./SimpleLib.so", 1);
int hnd = dl_open("./SimpleLib.so", 1);
printf("hnd = %d\n", hnd); printf("hnd = %d\n", hnd);
{
int fib_iterate_sym = dl_sym(hnd, "fib_iterate", "(i)i");
printf("fib_iterate_sym = %d\n", fib_iterate_sym);
Status status = dl_call(fib_iterate_sym, 6);
printf("call status = %d\n", status);
status = dl_close_sym(hnd, fib_iterate_sym);
printf("close sym status = %d\n", status);
} }
{ { // Load symbols
int call_test_sym = dl_sym(hnd, "call_test", "(ilfd)i"); fib_iterate_sym = dl_sym(hnd, "fib_iterate", "(i)i");
printf("fib_iterate_sym = %d\n", fib_iterate_sym);
call_test_sym = dl_sym(hnd, "call_test", "(ilfd)i");
printf("call_test_sym = %d\n", call_test_sym); printf("call_test_sym = %d\n", call_test_sym);
Status status = call_test2_sym = dl_sym(hnd, "call_test2", "(ilid)i");
dl_call(call_test_sym, 42, 0x12345678abcdef01, 3.1415926f, 2.718281828459045); printf("call_test2_sym = %d\n", call_test2_sym);
printf("call status = %d\n", status);
my_cpy_sym = dl_sym(hnd, "my_cpy", "(ppi)i");
printf("my_cpy_sym = %d\n", my_cpy_sym);
}
{ // Call function
result = dl_call(fib_iterate_sym, 6);
printf("fib_iterate(6) = %d\n", (int)result);
result = dl_call(call_test_sym, 42, 0x12345678abcdef01, 3.1415926f, 2.718281828459045);
printf("call_test_sym = %d\n", (int)result);
result = dl_call(call_test2_sym, 42, 0x12345678abcdef01, 21, 3.1415926f);
printf("call_test2_sym = %d\n", (int)result);
volatile char src[16] = "Hello, World!";
volatile char dst[16] = "++++++++++++++++";
result = dl_call(my_cpy_sym, dst, src, 16);
printf("my_cpy = %d, src = %s , dst = %s\n", (int)result, src, dst);
}
{ // Close symbols
status = dl_close_sym(hnd, fib_iterate_sym);
printf("close sym status = %d\n", status);
status = dl_close_sym(hnd, call_test_sym); status = dl_close_sym(hnd, call_test_sym);
printf("close sym status = %d\n", status); printf("close sym status = %d\n", status);
}
{
int call_test2_sym = dl_sym(hnd, "call_test2", "(ilid)i");
printf("call_test2_sym = %d\n", call_test2_sym);
Status status = dl_call(call_test2_sym, 42, 0x12345678abcdef01, 21, 3.1415926f);
printf("call status = %d\n", status);
status = dl_close_sym(hnd, call_test2_sym); status = dl_close_sym(hnd, call_test2_sym);
printf("close sym status = %d\n", status); printf("close sym status = %d\n", status);
status = dl_close_sym(hnd, my_cpy_sym);
printf("close sym status = %d\n", status);
} }
Status status = dl_close(hnd); { // Close library
status = dl_close(hnd);
printf("close lib status = %d\n", status); printf("close lib status = %d\n", status);
}
}
void evil_test() {
// Test Something Bad
// 1. invalid symbol - PASSED
// func(0, 0x12345678);
int hnd = dl_open("./SimpleLib.so", 1);
int my_cpy_sym = dl_sym(hnd, "my_cpy", "(pp)i");
// TODO: redesign these cases
// 2. valid symbol, invalid args
dl_call(my_cpy_sym, (void *)0x42424242, (void *)evil_test, 0x12345678);
// 3. valid symbol, valid args
dl_call(my_cpy_sym, (void *)evil_test, (void *)evil_test, 1);
dl_close_sym(hnd, my_cpy_sym);
dl_close(hnd);
}
void entry() {
printf("Hello WASM\n");
dl_test();
evil_test();
} }

View File

@ -1,11 +1,8 @@
#include <string.h> #include <string.h>
#include <stdlib.h>
#include <dlfcn.h> #include <dlfcn.h>
#include "utils/log.h" #include "utils/log.h"
#include "wrt/dl.h" #include "wrt/lib/dl.h"
// TODO: Stateful context -> per WRTContext
#warning "TODO: Permission Check" #warning "TODO: Permission Check"
@ -54,17 +51,17 @@ Status dl_free(DLContext *ctx) {
return WRT_OK; return WRT_OK;
} }
#define GET_CONTEXT \ #define GET_CONTEXT(ret) \
wasm_module_inst_t module_inst = get_module_inst(exec_env); \ wasm_module_inst_t module_inst = get_module_inst(exec_env); \
(void)module_inst; \ (void)module_inst; \
DLContext *ctx = wasm_runtime_get_function_attachment(exec_env); \ DLContext *ctx = wasm_runtime_get_function_attachment(exec_env); \
if (ctx == NULL) return WRT_ERROR; if (ctx == NULL) return ret;
int dl_open(wasm_exec_env_t exec_env, const char *filename, int flags) { int dl_open(wasm_exec_env_t exec_env, const char *filename, int flags) {
if (filename == NULL) return WRT_ERROR; if (filename == NULL) return WRT_ERROR;
if (filename[0] == '\0') return WRT_ERROR; if (filename[0] == '\0') return WRT_ERROR;
GET_CONTEXT GET_CONTEXT(WRT_ERROR)
void *handle = dlopen(filename, flags); void *handle = dlopen(filename, flags);
if (handle == NULL) { if (handle == NULL) {
@ -160,7 +157,7 @@ ffi_cif *create_ffi_cif(DLContext *ctx, const char *signature) {
case 'l': arg_types[i++] = &ffi_type_slong; break; case 'l': arg_types[i++] = &ffi_type_slong; break;
case 'f': arg_types[i++] = &ffi_type_float; break; case 'f': arg_types[i++] = &ffi_type_float; break;
case 'd': arg_types[i++] = &ffi_type_double; break; case 'd': arg_types[i++] = &ffi_type_double; break;
case 'p': arg_types[i++] = &ffi_type_pointer; break; // TODO: pointer to pointer case 'p': arg_types[i++] = &ffi_type_pointer; break;
default: { default: {
LOG_ERR("Invalid signature: %s", signature); LOG_ERR("Invalid signature: %s", signature);
goto fail; goto fail;
@ -178,7 +175,7 @@ ffi_cif *create_ffi_cif(DLContext *ctx, const char *signature) {
case 'l': ret_type = &ffi_type_sint64; break; case 'l': ret_type = &ffi_type_sint64; break;
case 'f': ret_type = &ffi_type_float; break; case 'f': ret_type = &ffi_type_float; break;
case 'd': ret_type = &ffi_type_double; break; case 'd': ret_type = &ffi_type_double; break;
case 'p': ret_type = &ffi_type_pointer; break; // TODO: pointer to pointer case 'p': ret_type = &ffi_type_pointer; break;
case '\0': ret_type = &ffi_type_void; break; case '\0': ret_type = &ffi_type_void; break;
default: { default: {
LOG_ERR("Invalid signature: %s", signature); LOG_ERR("Invalid signature: %s", signature);
@ -205,7 +202,7 @@ int dl_sym(wasm_exec_env_t exec_env, int handle, const char *symbol, const char
if (handle < 1 || handle > DL_MAX_HANDLES) return WRT_ERROR; if (handle < 1 || handle > DL_MAX_HANDLES) return WRT_ERROR;
if (symbol == NULL || signature == NULL) return WRT_ERROR; if (symbol == NULL || signature == NULL) return WRT_ERROR;
GET_CONTEXT GET_CONTEXT(WRT_ERROR)
void *ptr = ctx->hnd[handle - 1].handle; void *ptr = ctx->hnd[handle - 1].handle;
if (ptr == NULL) return WRT_ERROR; if (ptr == NULL) return WRT_ERROR;
@ -233,45 +230,46 @@ int dl_sym(wasm_exec_env_t exec_env, int handle, const char *symbol, const char
return WRT_ERROR; return WRT_ERROR;
} }
Status dl_call(wasm_exec_env_t exec_env, int symbol, _va_list va_args) { void *dl_call(wasm_exec_env_t exec_env, int symbol, _va_list va_args) {
GET_CONTEXT(NULL)
#define LOG_AND_QUIT(msg) \
do { \
LOG_ERR(msg); \
wasm_runtime_set_exception(module_inst, msg); \
return NULL; \
} while (0)
if (symbol < 1 || symbol > DL_MAX_SYMBOLS) { if (symbol < 1 || symbol > DL_MAX_SYMBOLS) {
LOG_ERR("dl_call: invalid symbol"); LOG_AND_QUIT("dl_call: invalid symbol");
return WRT_ERROR;
} }
GET_CONTEXT
if (ctx->sym[symbol - 1].symbol == NULL) { if (ctx->sym[symbol - 1].symbol == NULL) {
LOG_ERR("dl_call: symbol not found"); LOG_AND_QUIT("dl_call: symbol not found");
return WRT_ERROR;
} }
if (!wasm_runtime_validate_native_addr(module_inst, va_args, sizeof(uint32_t))) { if (!wasm_runtime_validate_native_addr(module_inst, va_args, sizeof(uint32_t))) {
LOG_ERR("dl_call: invalid va_args"); LOG_AND_QUIT("dl_call: invalid va_args");
return WRT_ERROR;
} }
uint8_t *native_end_addr; uint8_t *native_end_addr;
if (!wasm_runtime_get_native_addr_range(module_inst, (uint8_t *)va_args, NULL, if (!wasm_runtime_get_native_addr_range(module_inst, (uint8_t *)va_args, NULL,
&native_end_addr)) { &native_end_addr)) {
LOG_ERR("dl_call: va_args out of range"); LOG_AND_QUIT("dl_call: va_args out of bounds");
wasm_runtime_set_exception(module_inst, "out of bounds memory access");
return WRT_ERROR;
} }
ffi_cif *cif = ctx->sym[symbol - 1].cif; ffi_cif *cif = ctx->sym[symbol - 1].cif;
unsigned arg_count = cif->nargs; unsigned arg_count = cif->nargs;
void *args[arg_count]; void *args[arg_count];
void *tmp_buffer[arg_count]; // TODO: tmp_buffer only used for ptr args, which wastes memory
#define ALIGN(n, b) (((uintptr_t)(n) + b) & (uintptr_t)~b) #define ALIGN(n, b) (((uintptr_t)(n) + b) & (uintptr_t)~b)
#define GET_ARG(ptr, type) (*(type *)((ptr += ALIGN(sizeof(type), 3)) - ALIGN(sizeof(type), 3))) #define GET_ARG(ptr, type) (*(type *)((ptr += ALIGN(sizeof(type), 3)) - ALIGN(sizeof(type), 3)))
#define CHECK_VA_ARG(ptr, type) \ #define CHECK_VA_ARG(ptr, type) \
do { \ do { \
if ((uint8_t *)ptr + ALIGN(type, 3) > native_end_addr) { \ if ((uint8_t *)ptr + ALIGN(type, 3) > native_end_addr) { \
LOG_ERR("dl_call: va_args out of range"); \ LOG_AND_QUIT("dl_call: func args out of bounds"); \
wasm_runtime_set_exception(module_inst, "out of bounds memory access"); \
return WRT_ERROR; \
} \ } \
} while (0) } while (0)
@ -282,58 +280,89 @@ Status dl_call(wasm_exec_env_t exec_env, int symbol, _va_list va_args) {
// int 32, long long 32, float 64, double 64, pointer 32 // int 32, long long 32, float 64, double 64, pointer 32
switch (cif->arg_types[i]->type) { switch (cif->arg_types[i]->type) {
case FFI_TYPE_SINT32: case FFI_TYPE_SINT32: { // 32 bit
case FFI_TYPE_POINTER: {
CHECK_VA_ARG(cur, int32_t); CHECK_VA_ARG(cur, int32_t);
args[i] = &GET_ARG(cur, int32_t); args[i] = &GET_ARG(cur, int32_t);
LOG_DBG("dl_call: arg[%d] = %d / 0x%x", i, *(int32_t *)args[i],
*(int32_t *)args[i]);
break; break;
} }
case FFI_TYPE_SINT64: { case FFI_TYPE_POINTER: { // 32 bit
CHECK_VA_ARG(cur, int32_t);
uint32_t ptr = GET_ARG(cur, int32_t);
/* TODO: Security Warning !!
* - 3rd-party libraries could access native memory directly, which means a buffer
* overflow could modify either host or wasm memory.
* Possible solutions:
* - On applicable devices: map a new page and copy the data, requires a flag to
* indicate whether the 3rd-party library will write to the memory (new arg or
* force memory alignment)
* - On embedded devices: maybe disable this feature? or embedded native symbols?
*/
if (!wasm_runtime_validate_app_addr(module_inst, ptr, 0)) {
LOG_AND_QUIT("dl_call: invalid ptr");
}
tmp_buffer[i] = wasm_runtime_addr_app_to_native(module_inst, ptr);
args[i] = &tmp_buffer[i];
LOG_WARN("dl_call: ptr arg[%d] = 0x%lx", i, (uintptr_t)tmp_buffer[i]);
break;
}
case FFI_TYPE_SINT64: { // 64 bit
cur = (_va_list)ALIGN(cur, 7); cur = (_va_list)ALIGN(cur, 7);
CHECK_VA_ARG(cur, int64_t); CHECK_VA_ARG(cur, int64_t);
args[i] = &GET_ARG(cur, int64_t); args[i] = &GET_ARG(cur, int64_t);
LOG_DBG("dl_call: arg[%d] = %ld / 0x%lx", i, *(int64_t *)args[i],
*(int64_t *)args[i]);
break; break;
} }
case FFI_TYPE_FLOAT: case FFI_TYPE_FLOAT:
case FFI_TYPE_DOUBLE: { case FFI_TYPE_DOUBLE: { // 64 bit: FLOAT and DOUBLE are the same in WASM
cur = (_va_list)ALIGN(cur, 7); cur = (_va_list)ALIGN(cur, 7);
CHECK_VA_ARG(cur, double); CHECK_VA_ARG(cur, double);
args[i] = &GET_ARG(cur, double); args[i] = &GET_ARG(cur, double);
LOG_DBG("dl_call: arg[%d] = %lf", i, *(double *)args[i]);
break; break;
} }
default: { default: {
LOG_ERR("dl_call: unsupported type"); LOG_AND_QUIT("dl_call: unsupported type");
return WRT_ERROR;
} }
} }
} }
#undef CHECK_VA_ARG
#undef GET_ARG
#undef ALIGN
void *ret = NULL; void *ret = NULL;
if (cif->rtype != &ffi_type_void) { if (cif->rtype->type != FFI_TYPE_VOID) {
// TODO: return val storage, re-design dl_call unsigned size = cif->rtype->size;
ret = ctx->mem.malloc(cif->rtype->size); if (cif->rtype->type == FFI_TYPE_FLOAT) size = (&ffi_type_double)->size;
ret = ctx->mem.malloc(size);
}
unsigned short not_implemented[4] = {FFI_TYPE_POINTER, FFI_TYPE_SINT64, FFI_TYPE_DOUBLE,
FFI_TYPE_FLOAT};
for (int i = 0; i < 4; i++)
if (cif->rtype->type == not_implemented[i]) {
/* TODO:
* - FFI_TYPE_POINTER: copy the data, but how to free?(make a linked list?), how to
* determine the size?
* - Others: 64 bit, how to return? change signature to return int64? which suits ptr
* :)
*/
LOG_AND_QUIT("dl_call: unsupported return type (pointer)");
} }
ffi_call(cif, FFI_FN(ctx->sym[symbol - 1].symbol), ret, args); ffi_call(cif, FFI_FN(ctx->sym[symbol - 1].symbol), ret, args);
LOG_DBG("dl_call: symbol = %d, ret = 0x%x", symbol, *(unsigned *)ret); LOG_DBG("dl_call: symbol = %d, ret = 0x%x", symbol, *(unsigned *)ret);
return WRT_OK;
return (void *)0x42;
#undef CHECK_VA_ARG
#undef GET_ARG
#undef ALIGN
#undef LOG_AND_QUIT
} }
Status dl_close_sym(wasm_exec_env_t exec_env, int handle, int symbol) { Status dl_close_sym(wasm_exec_env_t exec_env, int handle, int symbol) {
if (handle < 1 || handle > DL_MAX_HANDLES) return WRT_ERROR; if (handle < 1 || handle > DL_MAX_HANDLES) return WRT_ERROR;
if (symbol < 1 || symbol > DL_MAX_SYMBOLS) return WRT_ERROR; if (symbol < 1 || symbol > DL_MAX_SYMBOLS) return WRT_ERROR;
GET_CONTEXT GET_CONTEXT(WRT_ERROR)
void *sym = ctx->sym[symbol - 1].symbol; void *sym = ctx->sym[symbol - 1].symbol;
if (sym == NULL) return WRT_ERROR; if (sym == NULL) return WRT_ERROR;
@ -350,7 +379,7 @@ Status dl_close_sym(wasm_exec_env_t exec_env, int handle, int symbol) {
Status dl_close(wasm_exec_env_t exec_env, int handle) { Status dl_close(wasm_exec_env_t exec_env, int handle) {
if (handle < 1 || handle > DL_MAX_HANDLES) return WRT_ERROR; if (handle < 1 || handle > DL_MAX_HANDLES) return WRT_ERROR;
GET_CONTEXT GET_CONTEXT(WRT_ERROR)
void *ptr = ctx->hnd[handle - 1].handle; void *ptr = ctx->hnd[handle - 1].handle;
if (ptr == NULL) return WRT_ERROR; if (ptr == NULL) return WRT_ERROR;

View File

@ -27,7 +27,7 @@ Status dl_init(DLContext *context);
Status dl_free(DLContext *context); Status dl_free(DLContext *context);
int dl_open(wasm_exec_env_t exec_env, const char *filename, int flags); int dl_open(wasm_exec_env_t exec_env, const char *filename, int flags);
int dl_sym(wasm_exec_env_t exec_env, int handle, const char *symbol, const char *signature); int dl_sym(wasm_exec_env_t exec_env, int handle, const char *symbol, const char *signature);
Status dl_call(wasm_exec_env_t exec_env, int symbol, _va_list va_args); // TODO void *dl_call(wasm_exec_env_t exec_env, int symbol, _va_list va_args);
Status dl_close_sym(wasm_exec_env_t exec_env, int handle, int symbol); Status dl_close_sym(wasm_exec_env_t exec_env, int handle, int symbol);
Status dl_close(wasm_exec_env_t exec_env, int handle); Status dl_close(wasm_exec_env_t exec_env, int handle);

View File

@ -166,7 +166,8 @@ Status wrt_free(WRTContext *context) {
Status wrt_run(WRTContext *context) { Status wrt_run(WRTContext *context) {
// run wasm // run wasm
if (!wasm_runtime_call_wasm(context->wamr.exec_env, context->wamr.entry_func, 0, NULL)) { if (!wasm_runtime_call_wasm(context->wamr.exec_env, context->wamr.entry_func, 0, NULL)) {
LOG_ERR("Call wasm function failed. error: %s", context->mem.error_buf); LOG_ERR("Call wasm function failed. Reason: %s",
wasm_runtime_get_exception(context->wamr.module_inst));
return WRT_ERROR; return WRT_ERROR;
} }

View File

@ -4,7 +4,7 @@
#include <wasm_export.h> #include <wasm_export.h>
#include "utils/defs.h" #include "utils/defs.h"
#include "wrt/dl.h" #include "wrt/lib/dl.h"
typedef struct { typedef struct {
struct { struct {