#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

#include <dlfcn.h>

struct frame_t;
struct list_t;
struct symbol_t;
struct integer_t;
struct string_t;

/// Reports an interpreter error and aborts the current evaluation.
///
/// Prints "Error: <str>" on stdout, then throws a copy of `str` so the REPL
/// loop can catch it and continue with the next form.
/// @throws std::string always (the message itself).
[[noreturn]] void error(const std::string& str) {
  std::cout << "Error: " << str << std::endl;
  throw str;
}

/// Base of every syntax node and runtime value. The same objects serve as the
/// parsed AST and as evaluation results; the defaults describe a datum that
/// evaluates to itself and cannot be applied.
struct node_t {
  virtual ~node_t() = default;

  /// Evaluates the node in `frame`; by default a node evaluates to itself.
  virtual node_t* eval(frame_t* frame) { return this; }

  /// Applies this (already evaluated) head to the call form `list`: element 0
  /// is the unevaluated head expression, elements 1.. the unevaluated
  /// arguments. Defined out of line once every node kind is known.
  virtual node_t* apply(frame_t* frame, list_t* list);

  virtual std::string to_string() { return "???"; }
  virtual bool is_integer() { return false; }
  virtual uintptr_t int_value() { error(to_string() + " is not an integer"); }
  virtual list_t* as_list() { error(to_string() + " is not a list"); }

  /// True only for a symbol whose name equals `id`.
  virtual bool is(const std::string& id) { return false; }
};

/// Machine-word integer. It also carries raw addresses that are called as
/// native functions, which is why the payload is an unsigned uintptr_t (so
/// arithmetic on it wraps modulo 2^N rather than going negative).
struct integer_t : node_t {
  uintptr_t v;

  explicit integer_t(uintptr_t v) : v { v } { }

  uintptr_t int_value() override { return v; }
  bool is_integer() override { return true; }
  std::string to_string() override { return std::to_string(int_value()); }

  /// Calls the native function located at address `v` (defined out of line).
  node_t* apply(frame_t* frame, list_t* list) override;
};

/// String literal; prints as its raw contents, without quotes.
struct string_t : node_t {
  std::string v;

  explicit string_t(std::string v) : v { std::move(v) } { }

  std::string to_string() override { return v; }
};

/// Growable array of nodes, used both for code (call forms) and for data.
struct list_t : node_t {
  node_t** elements;
  size_t nb_elements;
  size_t capacity;

  list_t() : elements { nullptr }, nb_elements { 0 }, capacity { 0 } { }

  /// Returns element `n`; raises an interpreter error when `n` is past the end
  /// (this also covers the empty list, whose buffer is still null).
  node_t* at(size_t n) {
    if(n >= nb_elements)
      error("list index out of bounds");
    return elements[n];
  }

  /// Replaces element `n`; same bounds rule as at().
  void set_at(size_t n, node_t* node) {
    if(n >= nb_elements)
      error("list index out of bounds");
    elements[n] = node;
  }

  size_t size() { return nb_elements; }

  /// Appends `node`, doubling the buffer (starting at 4 slots) when full.
  void append(node_t* node) {
    if(nb_elements == capacity) {
      capacity = capacity ? capacity << 1 : 4;
      node_t** new_elements = new node_t*[capacity];
      for(size_t i = 0; i < nb_elements; i++)
        new_elements[i] = elements[i];
      delete[] elements;  // the old buffer is no longer referenced
      elements = new_elements;
    }
    elements[nb_elements++] = node;
  }

  list_t* as_list() override { return this; }

  /// Prints as "(e0 e1 ...)".
  std::string to_string() override {
    std::string res = "(";
    for(size_t i = 0; i < nb_elements; i++) {
      if(i > 0)
        res += " ";
      res += elements[i]->to_string();
    }
    return res + ")";
  }

  /// Applies a lambda value (a list of the form (lambda params body)).
  node_t* apply(frame_t* frame, list_t* list) override;

  /// A list evaluates as a call: evaluate the head, then let it apply itself
  /// to this whole form (arguments are still unevaluated at that point).
  node_t* eval(frame_t* frame) override { return at(0)->eval(frame)->apply(frame, this); }
};

/// Identifier; evaluates by looking itself up in the environment.
struct symbol_t : node_t {
  std::string id;

  explicit symbol_t(std::string id) : id { std::move(id) } { }

  bool is(const std::string& id) override { return this->id == id; }
  node_t* eval(frame_t* frame) override;
  std::string to_string() override { return id; }
};

/// Environment: a list of two-element (symbol value) bindings plus a link to
/// the enclosing frame (nullptr for the global frame).
struct frame_t : public list_t {
  frame_t* parent;

  frame_t() : frame_t { nullptr } { }
  explicit frame_t(frame_t* parent) : parent { parent } { }

  /// Binds `sym` to `value` in this frame only, overwriting an existing
  /// binding with the same name in this frame; ancestors are never touched.
  void set(node_t* sym, node_t* value) {
    for(size_t i = 0; i < size(); i++) {
      list_t* binding = at(i)->as_list();
      if(sym->is(binding->at(0)->to_string())) {
        binding->set_at(1, value);
        return;
      }
    }
    list_t* binding = new list_t;
    binding->append(sym);
    binding->append(value);
    append(binding);
  }

  /// Returns the value bound to `sym` in this frame or the nearest ancestor;
  /// raises an interpreter error when no frame binds it.
  node_t* get(symbol_t* sym) {
    for(size_t i = 0; i < size(); i++) {
      list_t* binding = at(i)->as_list();
      if(binding->at(0)->is(sym->to_string()))
        return binding->at(1);
    }
    if(parent == nullptr)
      error("variable " + sym->to_string() + " does not exist");
    return parent->get(sym);
  }
};

/// Built-in operator or special form; prints as "<name>".
struct keyword_t : node_t {
  std::string id;

  explicit keyword_t(std::string id) : id { std::move(id) } { }

  std::string to_string() override { return "<" + id + ">"; }
};

/// (+ a b ...): wrapping sum of the evaluated arguments; (+) is 0.
struct add_t : keyword_t {
  add_t() : keyword_t { "+" } { }

  node_t* apply(frame_t* frame, list_t* list) override {
    uintptr_t res = 0;
    for(size_t i = 1; i < list->size(); i++)
      res += list->at(i)->eval(frame)->int_value();
    return new integer_t { res };
  }
};

/// (- a): wrapping negation; (- a b ...): a minus each following argument.
/// Values are uintptr_t, so results "below zero" wrap modulo 2^N.
struct sub_t : keyword_t {
  sub_t() : keyword_t { "-" } { }

  node_t* apply(frame_t* frame, list_t* list) override {
    uintptr_t res = list->at(1)->eval(frame)->int_value();
    if(list->size() == 2)
      return new integer_t { - res };
    for(size_t i = 2; i < list->size(); i++)
      res -= list->at(i)->eval(frame)->int_value();
    return new integer_t { res };
  }
};

// X-macro tables of the primitive operators. Each entry is
// (C++ struct name, Lisp symbol, C++ operator); the caller passes a macro `_`
// that is expanded once per entry, and supplies the final semicolon itself.
#define on_unary_op(_) \
  _(not_t, "not", !)

#define on_binary_op(_) \
  _(eq_t, "=", ==);     \
  _(ne_t, "!=", !=);    \
  _(lt_t, "<", <);      \
  _(le_t, "<=", <=);    \
  _(gt_t, ">", >);      \
  _(ge_t, ">=", >=)
#define define_unary_op(name_t, str, op) \ struct name_t : keyword_t { \ name_t() : keyword_t { str } { } \ \ node_t* apply(frame_t* frame, list_t* list) { \ return new integer_t { op list->at(1)->eval(frame)->int_value() }; \ } \ } #define define_binary_op(name_t, str, op) \ struct name_t : keyword_t { \ name_t() : keyword_t { str } { } \ \ node_t* apply(frame_t* frame, list_t* list) { \ return new integer_t { list->at(1)->eval(frame)->int_value() op list->at(2)->eval(frame)->int_value() }; \ } \ } on_unary_op(define_unary_op); on_binary_op(define_binary_op); struct if_t : keyword_t { if_t() : keyword_t { "if" } { } node_t* apply(frame_t* frame, list_t* list) { return list->at(1)->eval(frame)->int_value() == 0 ? (list->size() == 4 ? list->at(3)->eval(frame) : new integer_t { 0 }) : list->at(2)->eval(frame); } }; struct set_t : keyword_t { set_t() : keyword_t { "set!" } { } node_t* apply(frame_t* frame, list_t* list) { node_t* value = list->at(2)->eval(frame); frame->set(list->at(1), value); return value; } }; struct begin_t : keyword_t { begin_t() : keyword_t { "begin" } { } node_t* apply(frame_t* frame, list_t* list) { node_t* res = 0; for(size_t i=0; isize(); i++) res = list->at(i)->eval(frame); return res; } }; struct while_t : keyword_t { while_t() : keyword_t { "while" } { } node_t* apply(frame_t* frame, list_t* list) { node_t* res = 0; while(list->at(1)->eval(frame)->int_value() != 0) for(size_t i=2; isize(); i++) res = list->at(2)->eval(frame); return res; } }; struct lambda_t : keyword_t { lambda_t() : keyword_t { "lambda" } { } node_t* apply(frame_t* frame, list_t* list) { return list; } }; struct quote_t : keyword_t { quote_t() : keyword_t { "'" } { } node_t* apply(frame_t* frame, list_t* list) { return list->at(1); } }; node_t* symbol_t::eval(frame_t* frame) { node_t* value = frame->get(this); if(value == nullptr) { value = new integer_t { 0 }; frame->set(this, value); } return value; } node_t* node_t::apply(frame_t* frame, list_t* list) { error("cannot 
apply " + to_string() + " to " + list->to_string()); } node_t* list_t::apply(frame_t* frame, list_t* list) { if(!at(0)->is("lambda")) error("cannot apply " + to_string() + " to " + list->to_string()); list_t* params = at(1)->as_list(); if(params->size() != list->size() - 1) error("wrong number of arguments when applying " + to_string() + " to " + list->to_string()); frame_t* new_frame = new frame_t { frame }; for(size_t i=0; isize(); i++) new_frame->set(params->at(i), list->at(i+1)->eval(frame)); return at(2)->eval(new_frame); } node_t* integer_t::apply(frame_t* frame, list_t* list) { uintptr_t (*f)(uintptr_t, uintptr_t, uintptr_t, uintptr_t) = (uintptr_t (*)(uintptr_t, uintptr_t, uintptr_t, uintptr_t))int_value(); uintptr_t (*g)(uintptr_t, ...) = (uintptr_t (*)(uintptr_t, ...))int_value(); std::string strs[4]; uintptr_t args[] = { 0, 0, 0, 0 }; size_t n = list->size() < 5 ? list->size() : 5; for(size_t i=1; iat(i)->eval(frame); std::cout << "str: '" << arg->to_string() << "'" << std::endl; args[i-1] = arg->is_integer() ? 
arg->int_value() : (uintptr_t)(strs[i-1] = arg->to_string()).c_str(); } if((uintptr_t)f == (uintptr_t)printf || (uintptr_t)f == (uintptr_t)vprintf || (uintptr_t)f == (uintptr_t)snprintf) return new integer_t { g(args[0], args[1], args[2], args[3]) }; else return new integer_t { f(args[0], args[1], args[2], args[3]) }; } struct lexer_t { std::istream* is; int next; lexer_t(std::istream* is) : is { is }, next { ' ' } { } node_t* next_token() { std::string str; bool is_number = true; uintptr_t number = 0; while(std::isspace(next)) next = is->get(); switch(next) { case EOF: std::cout << "prematured end of file" << std::endl; exit(1); case '"': for(;;) { next = is->get(); if(next == EOF || next == '"') { next = is->get(); return new string_t { str }; } if(next == '\\') switch(next = is->get()) { case EOF: error("unexpected \\ in a string"); case 'n': next = '\n'; break; case 't': next = '\t'; break; } str += next; } case '\'': case '(': case ')': str += next; next = is->get(); return new symbol_t { str }; default: for(;;) { is_number &= next >= '0' && next <= '9'; number = number * 10 + next - '0'; str += next; next = is->get(); if(std::isspace(next) || next == '\"' || next == '\'' || next == '(' || next == ')') { if(is_number) return new integer_t { number }; else return new symbol_t { str }; } } }; } }; node_t* parse(lexer_t* lexer) { node_t* token = lexer->next_token(); if(token == nullptr) error("unexpected end of string"); if(token->is("'")) { list_t* res = new list_t; res->append(token); res->append(parse(lexer)); return res; } else if(token->is("(")) { list_t* list = new list_t; while(!(token = parse(lexer))->is(")")) list->append(token); return list; } else return token; } int main(int argc, char* argv[]) { frame_t* frame = new frame_t; #define add_op(name_t, str, op) frame->set( new symbol_t { str }, new name_t) frame->set(new symbol_t { "+" }, new add_t); frame->set(new symbol_t { "-" }, new sub_t); on_unary_op(add_op); on_binary_op(add_op); frame->set(new 
symbol_t { "'" }, new quote_t); frame->set(new symbol_t { "if" }, new if_t); frame->set(new symbol_t { "set!" }, new set_t); frame->set(new symbol_t { "while" }, new while_t); frame->set(new symbol_t { "begin" }, new begin_t); frame->set(new symbol_t { "dlsym" }, new integer_t { (uintptr_t)dlsym }); frame->set(new symbol_t { "dlself" }, new integer_t { (uintptr_t)RTLD_DEFAULT }); frame->set(new symbol_t { "lambda" }, new lambda_t); lexer_t lexer { &std::cin }; for(;;) { try { std::cout << "> "; node_t* ast = parse(&lexer); node_t* res = ast->eval(frame); std::cout << " => " << res->to_string() << std::endl; } catch(std::string str) { } } return 0; }