--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/main/kotlin/de/uapcore/lightpit/AbstractServlet.kt Fri Apr 02 11:59:14 2021 +0200 @@ -0,0 +1,182 @@ +/* + * Copyright 2021 Mike Becker. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package de.uapcore.lightpit + +import de.uapcore.lightpit.DataSourceProvider.Companion.SC_ATTR_NAME +import de.uapcore.lightpit.dao.DataAccessObject +import de.uapcore.lightpit.dao.createDataAccessObject +import java.sql.SQLException +import java.util.* +import javax.servlet.http.HttpServlet +import javax.servlet.http.HttpServletRequest +import javax.servlet.http.HttpServletResponse + +abstract class AbstractServlet : LoggingTrait, HttpServlet() { + + /** + * Contains the GET request mappings. + */ + private val getMappings = mutableMapOf<PathPattern, MappingMethod>() + + /** + * Contains the POST request mappings. + */ + private val postMappings = mutableMapOf<PathPattern, MappingMethod>() + + protected fun get(pattern: String, method: MappingMethod) { + getMappings[PathPattern(pattern)] = method + } + + protected fun post(pattern: String, method: MappingMethod) { + postMappings[PathPattern(pattern)] = method + } + + private fun notFound(http: HttpRequest, dao: DataAccessObject) { + http.response.sendError(HttpServletResponse.SC_NOT_FOUND) + } + + private fun findMapping( + mappings: Map<PathPattern, MappingMethod>, + req: HttpServletRequest + ): Pair<PathPattern, MappingMethod> { + val requestPath = sanitizedRequestPath(req) + val candidates = mappings.filter { it.key.matches(requestPath) } + return if (candidates.isEmpty()) { + Pair(PathPattern(requestPath), ::notFound) + } else { + if (candidates.size > 1) { + logger().warn("Ambiguous mapping for request path '{}'", requestPath) + } + candidates.entries.first().toPair() + } + } + + private fun invokeMapping( + mapping: Pair<PathPattern, MappingMethod>, + req: HttpServletRequest, + resp: HttpServletResponse, + dao: DataAccessObject + ) { + val params = mapping.first.obtainPathParameters(sanitizedRequestPath(req)) + val method = mapping.second + logger().trace("invoke {}", method) + method(HttpRequest(req, resp, params), dao) + } + + private fun sanitizedRequestPath(req: HttpServletRequest) = req.pathInfo ?: "/" + + private fun doProcess( + req: HttpServletRequest, + resp: HttpServletResponse, + mappings: Map<PathPattern, MappingMethod> + ) { + val session = req.session + + // the very first thing to do is to force UTF-8 + req.characterEncoding = "UTF-8" + + // choose the requested language as session language (if available) or fall back to english, otherwise + if (session.getAttribute(Constants.SESSION_ATTR_LANGUAGE) == null) { + val availableLanguages = availableLanguages() + val reqLocale = req.locale + val sessionLocale = if (availableLanguages.contains(reqLocale)) reqLocale else availableLanguages.first() + session.setAttribute(Constants.SESSION_ATTR_LANGUAGE, sessionLocale) + logger().debug( + "Setting language for new session {}: {}", session.id, sessionLocale.displayLanguage + ) + } else { + val sessionLocale = session.getAttribute(Constants.SESSION_ATTR_LANGUAGE) as Locale + resp.locale = sessionLocale + logger().trace("Continuing session {} with language {}", session.id, sessionLocale) + } + + // set some internal request attributes + val http = HttpRequest(req, resp) + val fullPath = req.servletPath + Optional.ofNullable(req.pathInfo).orElse("") + req.setAttribute(Constants.REQ_ATTR_BASE_HREF, http.baseHref) + req.setAttribute(Constants.REQ_ATTR_PATH, fullPath) + req.getHeader("Referer")?.let { + // TODO: add a sanity check to avoid link injection + req.setAttribute(Constants.REQ_ATTR_REFERER, it) + } + + // if this is an error path, bypass the normal flow + if (fullPath.startsWith("/error/")) { + http.styleSheets = listOf("error") + http.render("error") + return + } + + // obtain a connection and create the data access objects + val dsp = req.servletContext.getAttribute(SC_ATTR_NAME) as DataSourceProvider + val dialect = dsp.dialect + val ds = dsp.dataSource + if (ds == null) { + resp.sendError( + HttpServletResponse.SC_SERVICE_UNAVAILABLE, + "JNDI DataSource lookup failed. See log for details." + ) + return + } + try { + ds.connection.use { connection -> + val dao = createDataAccessObject(dialect, connection) + try { + connection.autoCommit = false + invokeMapping(findMapping(mappings, req), req, resp, dao) + connection.commit() + } catch (ex: SQLException) { + logger().warn("Database transaction failed (Code {}): {}", ex.errorCode, ex.message) + logger().debug("Details: ", ex) + resp.sendError( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + "Unhandled Transaction Error - Code: " + ex.errorCode + ) + connection.rollback() + } + } + } catch (ex: SQLException) { + logger().error("Severe Database Exception (Code {}): {}", ex.errorCode, ex.message) + logger().debug("Details: ", ex) + resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Database Error - Code: " + ex.errorCode) + } + } + + override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) { + doProcess(req, resp, getMappings) + } + + override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) { + doProcess(req, resp, postMappings) + } + + protected fun availableLanguages(): List<Locale> { + val langTags = servletContext.getInitParameter(Constants.CTX_ATTR_LANGUAGES)?.split(",")?.map(String::trim) ?: emptyList() + val locales = langTags.map(Locale::forLanguageTag).filter { it.language.isNotEmpty() } + return if (locales.isEmpty()) listOf(Locale.ENGLISH) else locales + } + +}