
/*
 * Odyssey.
 *
 * Scalable PostgreSQL connection pooler.
 */

#include <odyssey.h>

#include <machinarium/machinarium.h>

#include <auth_query.h>
#include <types.h>
#include <global.h>
#include <auth.h>
#include <client.h>
#include <instance.h>
#include <frontend.h>
#include <route.h>
#include <module.h>
#include <backend.h>
#include <dns.h>
#include <sasl.h>
#include <extension.h>
#include <external_auth.h>
#include <mdb_iamproxy.h>

/*
 * Authenticate a frontend client through external_user_authentication().
 *
 * Sends AuthenticationCleartextPassword and waits for the PasswordMessage;
 * any other frontend message received meanwhile is logged and dropped.
 * When the route enables password passthrough, a copy of the received
 * password is stored in client->received_password.
 *
 * Returns OK_RESPONSE on success, NOT_OK_RESPONSE when the external check
 * rejects the credentials (a fatal error is sent to the client), and -1 on
 * I/O or protocol errors.
 */
static inline int od_auth_frontend_external_authentication(od_client_t *client)
{
	od_instance_t *instance = client->global->instance;
	od_route_t *route = client->route;

	/* AuthenticationCleartextPassword */
	machine_msg_t *packet = kiwi_be_write_authentication_clear_text(NULL);
	if (packet == NULL) {
		return -1;
	}
	if (od_write(&client->io, packet) == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));
		return -1;
	}

	/* keep reading until a PasswordMessage shows up */
	packet = NULL;
	while (packet == NULL) {
		machine_msg_t *candidate = od_read(&client->io, UINT32_MAX);
		if (candidate == NULL) {
			od_error(&instance->logger, "auth", client, NULL,
				 "read error: %s", od_io_error(&client->io));
			return -1;
		}
		kiwi_fe_type_t kind = *(char *)machine_msg_data(candidate);
		od_debug(&instance->logger, "auth", client, NULL, "%s",
			 kiwi_fe_type_to_string(kind));
		if (kind == KIWI_FE_PASSWORD_MESSAGE) {
			packet = candidate;
		} else {
			machine_msg_free(candidate);
		}
	}

	/* decode the password carried by the message */
	kiwi_password_t token;
	kiwi_password_init(&token);
	if (kiwi_be_read_password(machine_msg_data(packet),
				  machine_msg_size(packet), &token) == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "password read error");
		od_frontend_error(client, KIWI_PROTOCOL_VIOLATION,
				  "bad password message");
		kiwi_password_free(&token);
		machine_msg_free(packet);
		return -1;
	}

	if (route->rule->enable_password_passthrough) {
		kiwi_password_copy(&client->received_password, &token);
		od_debug(&instance->logger, "auth", client, NULL,
			 "saved user password to perform backend auth");
	}

	/* delegate the actual check to the external authentication service */
	int verdict = external_user_authentication(
		client->startup.user.value, token.password, instance, client);
	kiwi_password_free(&token);
	machine_msg_free(packet);

	if (verdict == OK_RESPONSE) {
		return OK_RESPONSE;
	}

	od_log(&instance->logger, "auth", client, NULL,
	       "user '%s.%s' incorrect password",
	       client->startup.database.value, client->startup.user.value);
	od_frontend_fatal(client, KIWI_INVALID_PASSWORD,
			  "external authentication failed for user \"%s\"",
			  client->startup.user.value);
	return NOT_OK_RESPONSE;
}

/*
 * Authenticate a frontend client with a clear-text password.
 *
 * Sends AuthenticationCleartextPassword and waits for the PasswordMessage;
 * other frontend messages received meanwhile are logged and dropped. The
 * password is then checked against the first configured source, in order:
 * mdb_iamproxy, LDAP (when built with LDAP_FOUND), an auth module, PAM
 * (when built with PAM_FOUND), and finally auth_query or the rule password.
 *
 * When the route enables password passthrough, a copy of the received
 * password is stored in client->received_password.
 *
 * Returns OK_RESPONSE on success, NOT_OK_RESPONSE on rejected credentials
 * or a failed auth_query, and -1 on I/O or protocol errors.
 */
static inline int od_auth_frontend_cleartext(od_client_t *client)
{
	od_instance_t *instance = client->global->instance;
	od_route_t *route = client->route;

	/* AuthenticationCleartextPassword */
	machine_msg_t *msg;
	msg = kiwi_be_write_authentication_clear_text(NULL);
	if (msg == NULL) {
		return -1;
	}
	int rc;
	rc = od_write(&client->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));
		return -1;
	}

	/* wait for password response */
	for (;;) {
		msg = od_read(&client->io, UINT32_MAX);
		if (msg == NULL) {
			od_error(&instance->logger, "auth", client, NULL,
				 "read error: %s", od_io_error(&client->io));
			return -1;
		}
		kiwi_fe_type_t type = *(char *)machine_msg_data(msg);
		od_debug(&instance->logger, "auth", client, NULL, "%s",
			 kiwi_fe_type_to_string(type));
		if (type == KIWI_FE_PASSWORD_MESSAGE) {
			break;
		}
		machine_msg_free(msg);
	}

	/* read password message */
	kiwi_password_t client_token;
	kiwi_password_init(&client_token);

	rc = kiwi_be_read_password(machine_msg_data(msg), machine_msg_size(msg),
				   &client_token);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "password read error");
		od_frontend_error(client, KIWI_PROTOCOL_VIOLATION,
				  "bad password message");
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		return -1;
	}

	if (route->rule->enable_password_passthrough) {
		kiwi_password_copy(&client->received_password, &client_token);
		od_debug(&instance->logger, "auth", client, NULL,
			 "saved user password to perform backend auth");
	}

	od_extension_t *extensions = client->global->extensions;

	/* support mdb_iamproxy authentication */
	if (client->rule->enable_mdb_iamproxy_auth) {
		int authentication_result = mdb_iamproxy_authenticate_user(
			client->startup.user.value, client_token.password,
			instance, client);
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		if (authentication_result != OK_RESPONSE) {
			goto auth_failed;
		}
		return OK_RESPONSE;
	}

#ifdef LDAP_FOUND
	if (client->rule->ldap_endpoint_name) {
		od_debug(&instance->logger, "auth", client, NULL,
			 "checking passwd against ldap endpoint %s",
			 client->rule->ldap_endpoint_name);

		rc = od_auth_ldap(client, &client_token);
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		if (rc != OK_RESPONSE) {
			goto auth_failed;
		}
		return OK_RESPONSE;
	}
#endif
	if (client->rule->auth_module) {
		od_module_t *modules = extensions->modules;

		/* auth callback */
		od_module_t *module;
		module = od_modules_find(modules, client->rule->auth_module);
		if (module == NULL) {
			/*
			 * The configured module is not loaded: reject instead of
			 * dereferencing a NULL module pointer below.
			 */
			od_error(&instance->logger, "auth", client, NULL,
				 "auth module '%s' not found",
				 client->rule->auth_module);
			kiwi_password_free(&client_token);
			machine_msg_free(msg);
			goto auth_failed;
		}
		if (module->od_auth_cleartext_cb == NULL) {
			kiwi_password_free(&client_token);
			machine_msg_free(msg);
			goto auth_failed;
		}
		rc = module->od_auth_cleartext_cb(client, &client_token);
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		if (rc != OD_MODULE_CB_OK_RETCODE) {
			goto auth_failed;
		}
		return OK_RESPONSE;
	}

#ifdef PAM_FOUND
	/* support PAM authentication */
	if (client->rule->auth_pam_service) {
		od_pam_convert_passwd(client->rule->auth_pam_data,
				      client_token.password);

		rc = od_pam_auth(client->rule->auth_pam_service,
				 client->startup.user.value,
				 client->rule->auth_pam_data, client->io.io);
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		if (rc == -1) {
			goto auth_failed;
		}
		return OK_RESPONSE;
	}
#endif

	/*
	 * Use the remote (auth_query) or local (rule) password source.
	 * client_password only borrows that memory and is never freed here.
	 */
	kiwi_password_t client_password;
	if (client->rule->auth_query) {
		char peer[128];
		od_getpeername(client->io.io, peer, sizeof(peer), 1, 0);
		od_debug(&instance->logger, "auth", client, NULL,
			 "running auth_query for peer %s", peer);
		rc = od_auth_query(client, peer);
		if (rc == -1) {
			od_error(&instance->logger, "auth", client, NULL,
				 "failed to make auth_query");
			od_frontend_error(
				client,
				KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
				"failed to make auth query");
			kiwi_password_free(&client_token);
			machine_msg_free(msg);
			return NOT_OK_RESPONSE;
		}

		/* TODO: consider support for empty password case. */
		if (client->password.password == NULL) {
			od_log(&instance->logger, "auth", client, NULL,
			       "user '%s.%s' incorrect user from %s",
			       client->startup.database.value,
			       client->startup.user.value, peer);
			od_frontend_error(client, KIWI_INVALID_PASSWORD,
					  "incorrect user");
			kiwi_password_free(&client_token);
			machine_msg_free(msg);
			return NOT_OK_RESPONSE;
		}
		client_password = client->password;
	} else {
		/* +1: compare including the trailing NUL of the token */
		client_password.password_len = client->rule->password_len + 1;
		client_password.password = client->rule->password;
	}

	/* authenticate */
	int check = kiwi_password_compare(&client_password, &client_token);
	kiwi_password_free(&client_token);
	machine_msg_free(msg);
	if (check) {
		return OK_RESPONSE;
	}

	goto auth_failed;

auth_failed:
	od_log(&instance->logger, "auth", client, NULL,
	       "user '%s.%s' incorrect password",
	       client->startup.database.value, client->startup.user.value);
	od_frontend_fatal(client, KIWI_INVALID_PASSWORD,
			  "password authentication failed for user \"%s\"",
			  client->startup.user.value);
	return NOT_OK_RESPONSE;
}

/*
 * Authenticate a frontend client with MD5 password authentication.
 *
 * Generates a salt, sends AuthenticationMD5Password and waits for the
 * PasswordMessage; other frontend messages are logged and dropped. The
 * expected hash is computed from the user name and either the auth_query
 * result or the rule password. With LDAP_FOUND and an LDAP endpoint
 * configured, the received token is checked against LDAP instead.
 *
 * query_password only borrows memory (client->password or the rule's
 * password) and is never freed here. client_password and client_token
 * are owned locally.
 *
 * Returns 0 on success, NOT_OK_RESPONSE on an LDAP rejection, and -1 on
 * other failures.
 */
static inline int od_auth_frontend_md5(od_client_t *client)
{
	od_instance_t *instance = client->global->instance;

	/* generate salt */
	uint32_t salt =
		kiwi_password_salt(&client->key, (uint32_t)machine_lrand48());

	/* AuthenticationMD5Password */
	machine_msg_t *msg;
	msg = kiwi_be_write_authentication_md5(NULL, (char *)&salt);
	if (msg == NULL) {
		return -1;
	}
	int rc;
	rc = od_write(&client->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));
		return -1;
	}

	/* wait for password response */
	for (;;) {
		msg = od_read(&client->io, UINT32_MAX);
		if (msg == NULL) {
			od_error(&instance->logger, "auth", client, NULL,
				 "read error: %s", od_io_error(&client->io));
			return -1;
		}
		kiwi_fe_type_t type = *(char *)machine_msg_data(msg);
		od_debug(&instance->logger, "auth", client, NULL, "%s",
			 kiwi_fe_type_to_string(type));
		if (type == KIWI_FE_PASSWORD_MESSAGE) {
			break;
		}
		machine_msg_free(msg);
	}

	/* read password message */
	kiwi_password_t client_token;
	kiwi_password_init(&client_token);
	rc = kiwi_be_read_password(machine_msg_data(msg), machine_msg_size(msg),
				   &client_token);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "password read error");
		od_frontend_error(client, KIWI_PROTOCOL_VIOLATION,
				  "bad password message");
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		return -1;
	}

	/* use remote or local password source */
	kiwi_password_t client_password;
	kiwi_password_init(&client_password);

	kiwi_password_t query_password;
	kiwi_password_init(&query_password);

	if (client->rule->auth_query) {
		char peer[128];
		od_getpeername(client->io.io, peer, sizeof(peer), 1, 0);
		rc = od_auth_query(client, peer);
		if (rc == -1) {
			od_error(&instance->logger, "auth", client, NULL,
				 "failed to make auth_query");
			od_frontend_error(
				client,
				KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
				"failed to make auth query");
			kiwi_password_free(&client_token);
			machine_msg_free(msg);
			return -1;
		}

		/* TODO: consider support for empty password case. */
		if (client->password.password == NULL) {
			od_log(&instance->logger, "auth", client, NULL,
			       "user '%s.%s' incorrect user from %s",
			       client->startup.database.value,
			       client->startup.user.value, peer);
			od_frontend_error(client, KIWI_INVALID_PASSWORD,
					  "incorrect user");
			kiwi_password_free(&client_token);
			machine_msg_free(msg);
			return -1;
		}

		/*
		 * Shallow copy: the buffer stays owned by client->password,
		 * only the length is adjusted to exclude the trailing NUL.
		 */
		query_password = client->password;
		query_password.password_len = client->password.password_len - 1;
	} else {
		query_password.password_len = client->rule->password_len;
		query_password.password = client->rule->password;
	}

#ifdef LDAP_FOUND
	if (client->rule->ldap_endpoint) {
		od_debug(&instance->logger, "auth", client, NULL,
			 "checking passwd against ldap endpoint %s",
			 client->rule->ldap_endpoint_name);

		rc = od_auth_ldap(client, &client_token);
		kiwi_password_free(&client_token);
		machine_msg_free(msg);
		if (rc != OK_RESPONSE) {
			od_log(&instance->logger, "auth", client, NULL,
			       "user '%s.%s' incorrect password",
			       client->startup.database.value,
			       client->startup.user.value);
			/* TODO: pass error from ldap here */
			od_frontend_fatal(
				client, KIWI_INVALID_PASSWORD,
				"password authentication failed for user \"%s\"",
				client->startup.user.value);
			return NOT_OK_RESPONSE;
		}
		return OK_RESPONSE;
	}
#endif

	/* prepare password hash */
	rc = kiwi_password_md5(&client_password, client->startup.user.value,
			       client->startup.user.value_len - 1,
			       query_password.password,
			       query_password.password_len, (char *)&salt);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "memory allocation error");
		kiwi_password_free(&client_password);
		kiwi_password_free(&client_token);
		/*
		 * query_password is borrowed and must not be freed: with
		 * auth_query it aliases client->password, so freeing it would
		 * leave client->password dangling (and freed twice later).
		 */
		machine_msg_free(msg);
		return -1;
	}

	/* authenticate */
	int check = kiwi_password_compare(&client_password, &client_token);
	kiwi_password_free(&client_password);
	kiwi_password_free(&client_token);
	machine_msg_free(msg);

	if (!check) {
		od_log(&instance->logger, "auth", client, NULL,
		       "user '%s.%s' incorrect password",
		       client->startup.database.value,
		       client->startup.user.value);
		od_frontend_fatal(
			client, KIWI_INVALID_PASSWORD,
			"password authentication failed for user \"%s\"",
			client->startup.user.value);
		return -1;
	}

	return 0;
}

/*
 * Run the server side of a SCRAM-SHA-256 exchange with a frontend client.
 *
 * Steps:
 *   1. Advertise SCRAM-SHA-256, plus SCRAM-SHA-256-PLUS on TLS connections.
 *   2. Read SASLInitialResponse and load the expected secret from auth_query
 *      or the rule password. The secret may be a SCRAM verifier or a plain
 *      password.
 *   3. Send server-first-message and read SASLResponse.
 *   4. Verify the nonce and the client proof, then send
 *      server-final-message.
 *
 * scram_state is owned by the caller, which releases it on every path.
 * query_password only borrows memory (client->password or the rule's
 * password) and must never be freed here: freeing it would double-free
 * client->password or corrupt the rule configuration.
 *
 * Returns 0 on success and -1 on any error.
 */
static inline int
od_auth_frontend_scram_sha_256_internal(od_client_t *client,
					od_scram_state_t *scram_state)
{
	od_instance_t *instance = client->global->instance;
	char *mechanisms[2] = { "SCRAM-SHA-256", "SCRAM-SHA-256-PLUS" };

	/* request AuthenticationSASL */
	machine_msg_t *msg;

	if (!machine_io_is_tls(client->io.io)) {
		msg = kiwi_be_write_authentication_sasl(NULL, mechanisms, 1);
	} else {
		msg = kiwi_be_write_authentication_sasl(NULL, mechanisms, 2);
	}

	if (msg == NULL) {
		return -1;
	}

	int rc = od_write(&client->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));

		return -1;
	}

	/* wait for SASLInitialResponse */
	for (;;) {
		msg = od_read(&client->io, UINT32_MAX);
		if (msg == NULL) {
			od_error(&instance->logger, "auth", client, NULL,
				 "read error: %s", od_io_error(&client->io));

			return -1;
		}

		kiwi_fe_type_t type = *(char *)machine_msg_data(msg);

		od_debug(&instance->logger, "auth", client, NULL, "%s",
			 kiwi_fe_type_to_string(type));

		if (type == KIWI_FE_PASSWORD_MESSAGE) {
			break;
		}

		machine_msg_free(msg);
	}

	/* read the SASLInitialResponse */
	char *mechanism;
	char *auth_data;
	size_t auth_data_size;
	rc = kiwi_be_read_authentication_sasl_initial(machine_msg_data(msg),
						      machine_msg_size(msg),
						      &mechanism, &auth_data,
						      &auth_data_size);
	if (rc == -1) {
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: malformed SASLInitialResponse message");
		machine_msg_free(msg);
		return -1;
	}

	if (strcmp(mechanism, "SCRAM-SHA-256") != 0 &&
	    strcmp(mechanism, "SCRAM-SHA-256-PLUS") != 0) {
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: unsupported SASL authorization mechanism");
		machine_msg_free(msg);
		return -1;
	}

	/* use remote or local password source (borrowed, never freed here) */
	kiwi_password_t query_password;
	kiwi_password_init(&query_password);

	if (client->rule->auth_query) {
		char peer[128];
		od_getpeername(client->io.io, peer, sizeof(peer), 1, 0);
		rc = od_auth_query(client, peer);
		if (rc == -1) {
			od_error(&instance->logger, "auth", client, NULL,
				 "frontend auth: failed to make auth_query");
			od_frontend_error(
				client,
				KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
				"frontend auth: failed to make auth query");
			machine_msg_free(msg);
			return -1;
		}

		/* TODO: consider support for empty password case. */
		if (client->password.password == NULL) {
			od_log(&instance->logger, "auth", client, NULL,
			       "user '%s.%s' incorrect user from %s",
			       client->startup.database.value,
			       client->startup.user.value, peer);
			od_frontend_error(client, KIWI_INVALID_PASSWORD,
					  "incorrect user");
			machine_msg_free(msg);
			return -1;
		}

		/* shallow copy: memory stays owned by client->password */
		query_password = client->password;
	} else {
		/* memory stays owned by the rule configuration */
		query_password.password_len = client->rule->password_len;
		query_password.password = client->rule->password;
	}

	/* try to parse authentication data */
	rc = od_scram_read_client_first_message(scram_state, auth_data,
						auth_data_size);
	/* auth_data points into msg, so it is not used after this free */
	machine_msg_free(msg);
	switch (rc) {
	case 0:
		break;

	case -1:
		return -1;

	case -2:
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: malformed SASLInitialResponse message");
		return -1;

	case -3:
		od_frontend_error(
			client, KIWI_FEATURE_NOT_SUPPORTED,
			"frontend auth: doesn't support channel binding at the moment");
		return -1;

	case -4:
		od_frontend_error(
			client, KIWI_FEATURE_NOT_SUPPORTED,
			"frontend auth: doesn't support authorization identity at the moment");
		return -1;

	case OD_SASL_ERROR_MANDATORY_EXT:
		od_frontend_error(
			client, KIWI_FEATURE_NOT_SUPPORTED,
			"frontend auth: doesn't support mandatory extensions at the moment");
		return -1;
	}

	if (query_password.password == NULL) {
		/* no secret configured: do not pass NULL to the SCRAM parsers */
		rc = -1;
	} else {
		/* the secret is either a SCRAM verifier or a plain password */
		rc = od_scram_parse_verifier(scram_state,
					     query_password.password);
		if (rc == -1) {
			rc = od_scram_init_from_plain_password(
				scram_state, query_password.password);
		}
	}

	if (rc == -1) {
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: invalid user password or SCRAM secret, check your config");

		return -1;
	}

	msg = od_scram_create_server_first_message(scram_state);
	if (msg == NULL) {
		return -1;
	}

	rc = od_write(&client->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));

		return -1;
	}

	/* wait for SASLResponse */
	for (;;) {
		/*
		 * TODO: here's infinite wait, need to replace it with
		 * client_login_timeout
		 */
		msg = od_read(&client->io, UINT32_MAX);
		if (msg == NULL) {
			od_error(&instance->logger, "auth", client, NULL,
				 "read error: %s", od_io_error(&client->io));

			return -1;
		}

		kiwi_fe_type_t type = *(char *)machine_msg_data(msg);

		od_debug(&instance->logger, "auth", client, NULL, "%s",
			 kiwi_fe_type_to_string(type));

		if (type == KIWI_FE_PASSWORD_MESSAGE) {
			break;
		}

		machine_msg_free(msg);
	}

	/* read the SASLResponse */
	rc = kiwi_be_read_authentication_sasl(machine_msg_data(msg),
					      machine_msg_size(msg), &auth_data,
					      &auth_data_size);

	if (rc == -1) {
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: malformed client SASLResponse");

		machine_msg_free(msg);
		return -1;
	}

	char *final_nonce;
	size_t final_nonce_size;
	uint8_t *client_proof;
	rc = od_scram_read_client_final_message(client->io.io, scram_state,
						auth_data, auth_data_size,
						&final_nonce, &final_nonce_size,
						&client_proof);
	if (rc == -1) {
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: malformed client SASLResponse");

		machine_msg_free(msg);
		return -1;
	}

	/* verify signatures */
	rc = od_scram_verify_final_nonce(scram_state, final_nonce,
					 final_nonce_size);
	if (rc == -1) {
		od_frontend_error(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"frontend auth: malformed client SASLResponse: nonce doesn't match");

		machine_msg_free(msg);
		od_free(client_proof);
		return -1;
	}

	rc = od_scram_verify_client_proof(scram_state, client_proof);
	od_free(client_proof);
	if (rc == -1) {
		od_frontend_fatal(
			client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			"password authentication failed for user \"%s\"",
			client->startup.user.value);

		machine_msg_free(msg);
		return -1;
	}

	machine_msg_free(msg);
	/* SASLFinal Message */
	msg = od_scram_create_server_final_message(scram_state);
	if (msg == NULL) {
		return -1;
	}

	rc = od_write(&client->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));

		return -1;
	}

	return 0;
}

/*
 * SCRAM-SHA-256 frontend authentication entry point.
 *
 * Owns the SCRAM state for the whole exchange: it is initialised here and
 * released here with od_scram_state_free(), whatever the internal routine
 * returns. The internal routine's result is passed through unchanged.
 */
static inline int od_auth_frontend_scram_sha_256(od_client_t *client)
{
	od_scram_state_t state;
	od_scram_state_init(&state);

	const int result =
		od_auth_frontend_scram_sha_256_internal(client, &state);

	/* single release point for every exit path of the exchange */
	od_scram_state_free(&state);
	return result;
}

/*
 * Authenticate a frontend client by its TLS certificate common name.
 *
 * The connection must have been started with an SSL request. The certificate
 * is accepted when machine_io_verify() succeeds (returns 0) for either of:
 *   - the route's user name, when auth_common_name_default is set;
 *   - any name in the rule's auth_common_names list.
 *
 * Returns 0 on success and -1 otherwise; an error is sent to the client.
 */
static inline int od_auth_frontend_cert(od_client_t *client)
{
	od_instance_t *instance = client->global->instance;

	if (!client->startup.is_ssl_request) {
		od_error(&instance->logger, "auth", client, NULL,
			 "TLS connection required");
		od_frontend_error(client,
				  KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
				  "TLS connection required");
		return -1;
	}

	od_route_t *route = client->route;

	/* the route's user name acts as the default common name */
	if (route->rule->auth_common_name_default &&
	    machine_io_verify(client->io.io, route->rule->user_name) == 0) {
		return 0;
	}

	/* otherwise try each explicitly configured common name */
	od_list_t *it;
	od_list_foreach(&route->rule->auth_common_names, it)
	{
		od_rule_auth_t *entry = od_container_of(it, od_rule_auth_t, link);
		if (machine_io_verify(client->io.io, entry->common_name) == 0) {
			return 0;
		}
	}

	od_error(&instance->logger, "auth", client, NULL,
		 "TLS certificate common name mismatch");
	od_frontend_fatal(client, KIWI_INVALID_PASSWORD,
			  "certificate authentication failed for user \"%s\"",
			  client->startup.user.value);
	return -1;
}

/*
 * Reject a client whose rule is configured with the "block" auth mode.
 *
 * Logs the rejection and sends a fatal error to the client. The return
 * value is always 0; the caller (od_auth_frontend) ignores it and fails
 * the authentication on its own.
 */
static inline int od_auth_frontend_block(od_client_t *client)
{
	od_instance_t *instance = client->global->instance;
	const char *db = client->startup.database.value;
	const char *user = client->startup.user.value;

	od_log(&instance->logger, "auth", client, NULL,
	       "user '%s.%s' is blocked", db, user);
	od_frontend_fatal(client, KIWI_INVALID_AUTHORIZATION_SPECIFICATION,
			  "user blocked: %s %s", db, user);

	return 0;
}

/*
 * Authenticate a frontend client according to client->rule->auth_mode.
 *
 * Dispatches to the mode-specific handler. When the handler does not
 * return -1, AuthenticationOk is sent to the client. The "block" mode
 * always fails; the "none" mode skips authentication entirely.
 *
 * Returns 0 on success and -1 on failure.
 */
int od_auth_frontend(od_client_t *client)
{
	od_instance_t *instance = client->global->instance;
	int rc = 0;

	switch (client->rule->auth_mode) {
	case OD_RULE_AUTH_CLEAR_TEXT:
		rc = od_auth_frontend_cleartext(client);
		break;
	case OD_RULE_AUTH_EXTERNAL:
		rc = od_auth_frontend_external_authentication(client);
		break;
	case OD_RULE_AUTH_MD5:
		rc = od_auth_frontend_md5(client);
		break;
	case OD_RULE_AUTH_SCRAM_SHA_256:
		rc = od_auth_frontend_scram_sha_256(client);
		break;
	case OD_RULE_AUTH_CERT:
		rc = od_auth_frontend_cert(client);
		break;
	case OD_RULE_AUTH_BLOCK:
		/* always a rejection; the helper has already notified the client */
		od_auth_frontend_block(client);
		return -1;
	case OD_RULE_AUTH_NONE:
		break;
	default:
		assert(0);
		break;
	}

	if (rc == -1) {
		return -1;
	}

	/* authentication passed: send AuthenticationOk */
	machine_msg_t *ok_msg = kiwi_be_write_authentication_ok(NULL);
	if (ok_msg == NULL) {
		return -1;
	}
	if (od_write(&client->io, ok_msg) == -1) {
		od_error(&instance->logger, "auth", client, NULL,
			 "write error: %s", od_io_error(&client->io));
		return -1;
	}
	return 0;
}

/*
 * Answer an AuthenticationCleartextPassword request from the backend.
 *
 * The password is chosen in this order:
 *   1. the client's auth_query result;
 *   2. the rule's storage_password;
 *   3. the rule's password;
 *   4. the password received from the client (password passthrough).
 * When built with LDAP_FOUND and ldap_storage_credentials_attr is set, the
 * LDAP-provided storage password overrides all of them.
 *
 * client may be NULL.
 * Returns 0 once the PasswordMessage has been written, -1 on error.
 */
static inline int od_auth_backend_cleartext(od_server_t *server,
					    od_client_t *client)
{
	od_instance_t *instance = server->global->instance;
	od_route_t *route = server->route;
	assert(route != NULL);

	od_debug(&instance->logger, "auth", NULL, server,
		 "requested clear-text authentication");

	/* use storage or user password */
	char *password;
	int password_len;

	if (client != NULL && client->password.password != NULL) {
		password = client->password.password;
		password_len = client->password.password_len - /* NULL */ 1;
	} else if (route->rule->storage_password) {
		password = route->rule->storage_password;
		password_len = route->rule->storage_password_len;
	} else if (route->rule->password) {
		password = route->rule->password;
		password_len = route->rule->password_len;
	} else if (client != NULL &&
		   client->received_password.password != NULL) {
		password = client->received_password.password;
		password_len = client->received_password.password_len - 1;
	} else {
		od_error(&instance->logger, "auth", NULL, server,
			 "password required for route '%s.%s'",
			 route->rule->db_name, route->rule->user_name);
		return -1;
	}
#ifdef LDAP_FOUND
	/* client is optional here (see the checks above): guard before use */
	if (client != NULL && client->rule->ldap_storage_credentials_attr) {
		password = client->ldap_storage_password;
		password_len = client->ldap_storage_password_len;
	}
#endif
	/* PasswordMessage (length includes the trailing NUL) */
	machine_msg_t *msg;
	msg = kiwi_fe_write_password(NULL, password, password_len + 1);
	if (msg == NULL) {
		od_error(&instance->logger, "auth", NULL, server,
			 "memory allocation error");
		return -1;
	}
	int rc;
	rc = od_write(&server->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "write error: %s", od_io_error(&server->io));
		return -1;
	}
	return 0;
}

/*
 * Answer an AuthenticationMD5Password request from the backend.
 *
 * The user is the rule's storage_user if set, otherwise the rule's user
 * name. The password is chosen in this order:
 *   1. the client's auth_query result;
 *   2. the rule's storage_password;
 *   3. the rule's password;
 *   4. the password received from the client (password passthrough).
 * When built with LDAP_FOUND and ldap_storage_credentials_attr is set, the
 * LDAP-provided user and password override both. The MD5 response is
 * computed with the backend-supplied salt.
 *
 * client may be NULL.
 * Returns 0 once the PasswordMessage has been written, -1 on error.
 */
static inline int od_auth_backend_md5(od_server_t *server, char salt[4],
				      od_client_t *client)
{
	od_instance_t *instance = server->global->instance;
	od_route_t *route = server->route;
	assert(route != NULL);

	od_debug(&instance->logger, "auth", NULL, server,
		 "requested md5 authentication");

	/* use storage user or route user */
	char *user;
	int user_len;
	if (route->rule->storage_user) {
		user = route->rule->storage_user;
		user_len = route->rule->storage_user_len;
	} else {
		user = route->rule->user_name;
		user_len = route->rule->user_name_len;
	}

	/* use storage or user password */
	char *password;
	int password_len;
	if (client != NULL && client->password.password != NULL) {
		password = client->password.password;
		password_len = client->password.password_len - /* NULL */ 1;
	} else if (route->rule->storage_password) {
		password = route->rule->storage_password;
		password_len = route->rule->storage_password_len;
	} else if (route->rule->password) {
		password = route->rule->password;
		password_len = route->rule->password_len;
	} else if (client != NULL &&
		   client->received_password.password != NULL) {
		password = client->received_password.password;
		password_len = client->received_password.password_len - 1;
	} else {
		od_error(&instance->logger, "auth", NULL, server,
			 "password required for route '%s.%s'",
			 route->rule->db_name, route->rule->user_name);
		return -1;
	}
#ifdef LDAP_FOUND
	/* client is optional here (see the checks above): guard before use */
	if (client != NULL && client->rule->ldap_storage_credentials_attr) {
		user = client->ldap_storage_username;
		user_len = client->ldap_storage_username_len;
		password = client->ldap_storage_password;
		password_len = client->ldap_storage_password_len;
	}
#endif
	/* prepare md5 password using server supplied salt */
	kiwi_password_t client_password;
	kiwi_password_init(&client_password);
	int rc;
	rc = kiwi_password_md5(&client_password, user, user_len, password,
			       password_len, salt);
	if (rc == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "memory allocation error");
		kiwi_password_free(&client_password);
		return -1;
	}

	/* PasswordMessage */
	machine_msg_t *msg;
	msg = kiwi_fe_write_password(NULL, client_password.password,
				     client_password.password_len);
	kiwi_password_free(&client_password);
	if (msg == NULL) {
		od_error(&instance->logger, "auth", NULL, server,
			 "memory allocation error");
		return -1;
	}
	rc = od_write(&server->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "write error: %s", od_io_error(&server->io));
		return -1;
	}
	return 0;
}

/*
 * Handle AuthenticationSASL from the backend by starting a SCRAM exchange.
 *
 * Any SCRAM state left on the server from a previous attempt is released
 * first. Fails when no password source is available at all: no rule
 * storage_password, no rule password, and no client auth_query result or
 * received (passthrough) password. Otherwise the client-first-message is
 * sent.
 *
 * client may be NULL.
 * Returns 0 on success, -1 on error.
 */
static inline int od_auth_backend_sasl(od_server_t *server, od_client_t *client)
{
	od_instance_t *instance = server->global->instance;
	od_route_t *route = server->route;

	assert(route != NULL);

	/* free possible stale state from previous unlucky auth */
	od_scram_state_free(&server->scram_state);

	if (server->scram_state.client_nonce != NULL) {
		od_error(
			&instance->logger, "auth", NULL, server,
			"unexpected message: AuthenticationSASL was already received");

		return -1;
	}

	od_debug(&instance->logger, "auth", NULL, server,
		 "requested SASL authentication");

	/*
	 * Only consult the client's passwords when there is a client: the
	 * previous condition dereferenced client->received_password even
	 * when client was NULL.
	 */
	int client_has_password =
		client != NULL && (client->password.password != NULL ||
				   client->received_password.password != NULL);

	if (!route->rule->storage_password && !route->rule->password &&
	    !client_has_password) {
		od_error(&instance->logger, "auth", NULL, server,
			 "password required for route '%s.%s'",
			 route->rule->db_name, route->rule->user_name);

		return -1;
	}

	/* SASLInitialResponse Message */
	machine_msg_t *msg =
		od_scram_create_client_first_message(&server->scram_state);
	if (msg == NULL) {
		od_error(&instance->logger, "auth", NULL, server,
			 "memory allocation error");

		return -1;
	}

	int rc = od_write(&server->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "write error: %s", od_io_error(&server->io));

		return -1;
	}

	return 0;
}

/*
 * Handle AuthenticationSASLContinue from the backend.
 *
 * Requires that the SCRAM exchange was started and that no server-first
 * message has been processed yet. The password is chosen in this order:
 *   1. the rule's storage_password;
 *   2. the rule's password;
 *   3. the password received from the client (password passthrough).
 * An auth_query result cannot be used, because it may be a SCRAM secret
 * rather than a password. With LDAP_FOUND and ldap_storage_credentials_attr
 * set, the LDAP storage password overrides the choice.
 *
 * client may be NULL.
 * Returns 0 once the SASLResponse has been written, -1 on error.
 */
static inline int od_auth_backend_sasl_continue(od_server_t *server,
						char *auth_data,
						size_t auth_data_size,
						od_client_t *client)
{
	od_instance_t *instance = server->global->instance;
	od_route_t *route = server->route;

	assert(route != NULL);

	if (server->scram_state.client_nonce == NULL) {
		od_error(&instance->logger, "auth", NULL, server,
			 "unexpected message: AuthenticationSASL is missing");

		return -1;
	}

	if (server->scram_state.server_first_message != NULL) {
		od_error(
			&instance->logger, "auth", NULL, server,
			"unexpected message: AuthenticationSASLContinue was already "
			"received");

		return -1;
	}

	/* use storage or user password */
	char *password;

	if (route->rule->storage_password) {
		password = route->rule->storage_password;
	} else if (client != NULL && client->password.password != NULL) {
		od_error(
			&instance->logger, "auth", NULL, server,
			"cannot authenticate with SCRAM secret from auth_query for route '%s.%s'",
			route->rule->db_name, route->rule->user_name);

		return -1;
	} else if (route->rule->password) {
		password = route->rule->password;
	} else if (client != NULL && client->received_password.password) {
		password = client->received_password.password;
	} else {
		od_error(&instance->logger, "auth", NULL, server,
			 "password required for route '%s.%s'",
			 route->rule->db_name, route->rule->user_name);

		return -1;
	}
#ifdef LDAP_FOUND
	/* client is optional here (see the checks above): guard before use */
	if (client != NULL && client->rule->ldap_storage_credentials_attr) {
		password = client->ldap_storage_password;
	}
#endif
	/* never write the password itself to the log */
	od_debug(&instance->logger, "auth", NULL, server,
		 "continue SASL authentication");

	/* SASLResponse Message */
	machine_msg_t *msg = od_scram_create_client_final_message(
		&server->scram_state, password, auth_data, auth_data_size);
	if (msg == NULL) {
		od_error(&instance->logger, "auth", NULL, server,
			 "malformed SASLResponse message");

		return -1;
	}

	int rc = od_write(&server->io, msg);
	if (rc == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "write error: %s", od_io_error(&server->io));

		return -1;
	}

	return 0;
}

/*
 * Handle AuthenticationSASLFinal from the backend.
 *
 * Requires that AuthenticationSASLContinue was processed earlier (the
 * server-first message is recorded in the SCRAM state). The server
 * signature carried in auth_data is then checked.
 *
 * Returns 0 when the signature is valid and -1 otherwise.
 */
static inline int od_auth_backend_sasl_final(od_server_t *server,
					     char *auth_data,
					     size_t auth_data_size)
{
	od_instance_t *instance = server->global->instance;
	od_scram_state_t *state = &server->scram_state;

	assert(server->route);

	/* the final message is only meaningful after SASLContinue */
	if (state->server_first_message == NULL) {
		od_error(
			&instance->logger, "auth", NULL, server,
			"unexpected message: AuthenticationSASLContinue is missing");
		return -1;
	}

	od_debug(&instance->logger, "auth", NULL, server,
		 "finishing SASL authentication");

	if (od_scram_verify_server_signature(state, auth_data,
					     auth_data_size) == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "server verify failed: invalid signature");
		return -1;
	}

	return 0;
}

/*
 * Process an Authentication message received from the backend.
 *
 * Parses msg, which must be a KIWI_BE_AUTHENTICATION message, and
 * dispatches on its authentication type:
 *   - 0 (AuthenticationOk): return success immediately.
 *   - 3 / 5 (cleartext / MD5): send the password, then wait for the
 *     backend's verdict. The verdict must be AuthenticationOk or an
 *     ErrorResponse; other message types are skipped.
 *   - 10 / 11 / 12 (SASL / SASLContinue / SASLFinal): perform one SCRAM
 *     step and return. The SCRAM state is released on failure, and after
 *     SASLFinal in every case.
 *
 * On an ErrorResponse, the message is kept in server->error_connect so it
 * can be forwarded to the client.
 *
 * client may be NULL.
 * Returns 0 on success, -1 on failure.
 */
int od_auth_backend(od_server_t *server, machine_msg_t *msg,
		    od_client_t *client)
{
	od_instance_t *instance = server->global->instance;
	assert(*(char *)machine_msg_data(msg) == KIWI_BE_AUTHENTICATION);

	uint32_t auth_type;
	char salt[4];
	char *auth_data = NULL;
	size_t auth_data_size = 0;
	int rc;
	rc = kiwi_fe_read_auth(machine_msg_data(msg), machine_msg_size(msg),
			       &auth_type, salt, &auth_data, &auth_data_size);
	if (rc == -1) {
		od_error(&instance->logger, "auth", NULL, server,
			 "failed to parse authentication message");
		return -1;
	}

	od_debug(&instance->logger, "auth", NULL, server,
		 "received msg type %u", auth_type);

	msg = NULL;

	switch (auth_type) {
	/* AuthenticationOk */
	case 0:
		return 0;
	/* AuthenticationCleartextPassword */
	case 3:
		rc = od_auth_backend_cleartext(server, client);
		if (rc == -1) {
			return -1;
		}
		break;
	/* AuthenticationMD5Password */
	case 5:
		rc = od_auth_backend_md5(server, salt, client);
		if (rc == -1) {
			return -1;
		}
		break;
	/* AuthenticationSASL */
	case 10:
		rc = od_auth_backend_sasl(server, client);
		if (rc != OK_RESPONSE) {
			od_scram_state_free(&server->scram_state);
		}
		return rc;
	/* AuthenticationSASLContinue */
	case 11:
		rc = od_auth_backend_sasl_continue(server, auth_data,
						   auth_data_size, client);
		if (rc != OK_RESPONSE) {
			od_scram_state_free(&server->scram_state);
		}
		return rc;
	/* AuthenticationSASLFinal */
	case 12:
		rc = od_auth_backend_sasl_final(server, auth_data,
						auth_data_size);
		od_scram_state_free(&server->scram_state);
		return rc;
	/* unsupported */
	default:
		od_error(&instance->logger, "auth", NULL, server,
			 "unsupported authentication method");
		return -1;
	}

	/* wait for authentication response */
	for (;;) {
		msg = od_read(&server->io, UINT32_MAX);
		if (msg == NULL) {
			od_error(&instance->logger, "auth", NULL, server,
				 "read error: %s", od_io_error(&server->io));
			return -1;
		}
		kiwi_be_type_t type = *(char *)machine_msg_data(msg);
		od_debug(&instance->logger, "auth", NULL, server, "%s",
			 kiwi_be_type_to_string(type));

		switch (type) {
		case KIWI_BE_AUTHENTICATION:
			rc = kiwi_fe_read_auth(machine_msg_data(msg),
					       machine_msg_size(msg),
					       &auth_type, salt, NULL, NULL);
			machine_msg_free(msg);
			if (rc == -1) {
				od_error(
					&instance->logger, "auth", NULL, server,
					"failed to parse authentication message");
				return -1;
			}
			if (auth_type != 0) {
				/*
				 * After the password was sent, only
				 * AuthenticationOk means success; any other
				 * request is a protocol violation and must not
				 * be reported as an authenticated connection.
				 */
				od_error(&instance->logger, "auth", NULL,
					 server,
					 "incorrect authentication flow");
				return -1;
			}
			return 0;
		case KIWI_BE_ERROR_RESPONSE:
			od_backend_error(server, "auth", machine_msg_data(msg),
					 machine_msg_size(msg));
			/* save error to fwd it to client */
			server->error_connect = msg;
			return -1;
		default:
			machine_msg_free(msg);
			break;
		}
	}
	return 0;
}
