<p>Alex Dowad Computes: Explorations in the world of code</p>
<p><b>A Toy Runge-Kutta Differential Equation Solver</b> (2023-08-22)</p>
<p>This post presents a simple, interactive <a href='https://en.wikipedia.org/wiki/Differential_equation'>differential equation</a> solution graphing tool based on the classic <a href='https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods'>Runge-Kutta method</a> (which is really an algorithm).</p>
<p>It graphs solutions to one 1<sup>st</sup>-order equation, one 2<sup>nd</sup>-order equation, or a system of two 1<sup>st</sup>-order equations. The right-hand side of the equation(s) must be entered using syntax similar to expressions in the C, Java, or JavaScript programming languages (<a href='#syntax'>syntax help</a>). Since the solution of a differential equation depends on starting conditions, you can set the range of starting conditions which should be graphed and the “time” value at which the starting conditions apply. The tool will draw one line for each set of starting conditions.</p>
<p>Hover over a solution line to see what starting conditions it is based on. Roll your mouse wheel to zoom in and out; hold down the middle mouse button to pan. On mobile, use one finger to pan and pinch with two fingers to zoom.</p>
<p><b>Click these buttons to see a variety of samples:</b> <button id='sample_Y'>y' = y</button> <button id='sample_cosY'>y' = cos(y)</button> <button id='sample_cosT'>y'' = cos(t)</button> <button id='sample_cosY_2'>y'' = cos(y)</button> <button id='sample_logistic'>y' = ay - by² <i>(the logistic equation)</i></button> <button id='sample_mathieu'>y'' = y(a - 2q cos(t)) <i>(Mathieu's equation)</i></button> <button id='sample_riccati'>y' = -y²t² + (2y / t) <i>(a Riccati equation)</i></button> <button id='sample_cauchy'>y'' = -(a / t)y' - (b / t²)y <i>(a Cauchy-Euler equation)</i></button> <button id='sample_meissner'>y'' = -y(α² + ω² sign(cos(t))) <i>(Meissner equation)</i></button> <button id='sample_predprey'>y' = y - (0.01yz), z' = -z + (0.02yz) <i>(a “predator-prey” system)</i></button></p>
<p>Try playing with the parameters for some of the samples; you can get pretty wild pictures!</p>
<hr>
<div id='view1' style='height: 35rem'></div>
<div id='view2' style='height: 35rem'></div>
<div id='typewrapper'>
Graph solutions for:
<select id='graphtype'>
<option value='order1'>one first-order equation</option>
<option value='order2' selected>one second-order equation</option>
<option value='eqs2_order1'>a system of two first-order equations</option>
</select>
</div>
<div class='expwrapper'>
<label for='exp1' style="font-style: italic">y'' =</label> <input id='exp1' value='cos(y)'>
</div>
<div class='expwrapper eqs2only' style="display: none">
<span style="font-style: italic">z' =</span> <input id='exp2' value='cos(z)'>
</div>
<div class='params'>
Start time: <input id='starttime' type='number' value='-10'>
End time: <input id='endtime' type='number' value='10'>
Time step: <input id='timestep' type='number' value='0.02'>
<span title="Starting values for variables are at this time">t₀:</span> <input id='t0' type='number' value='0'>
<label for='starty1' class='rowstart'>Min y(0):</label> <input id='starty1' type='number' value='-3'>
<label for='starty2'>Max y(0):</label> <input id='starty2' type='number' value='3'>
# values to graph: <input id='nstarty' type='number' value='5' max='50'>
<label for='startdy1' class='order2only rowstart'>Min y'(0):</label> <input id='startdy1' type='number' value='-3' class='order2only'>
<label for='startdy2' class='order2only'>Max y'(0):</label> <input id='startdy2' type='number' value='3' class='order2only'>
<span class='order2only'># values to graph:</span> <input id='nstartdy' type='number' value='5' max='50' class='order2only'>
<label for='startz1' class='eqs2only rowstart' style="display: none">Min z(0):</label> <input id='startz1' type='number' value='-3' class='eqs2only' style="display: none">
<label for='startz2' class='eqs2only' style="display: none">Max z(0):</label> <input id='startz2' type='number' value='3' class='eqs2only' style="display: none">
<span class='eqs2only' style="display: none"># values to graph:</span> <input id='nstartz' type='number' value='5' max='50' class='eqs2only' style="display: none">
<label for='graph1' class='eqs2only rowstart' style="display: none">Left graph:</label>
<select id='graph1' class='eqs2only' style="display: none">
<option value='z-y' selected>z vs y</option>
<option value='y-t'>y vs t</option>
<option value='dy-y'>y' vs y</option>
<option value='z-t'>z vs t</option>
<option value='dz-z'>z' vs z</option>
</select>
<label for='graph2' class='eqs2only' style="display: none">Right graph:</label>
<select id='graph2' class='eqs2only' style="display: none">
<option value='z-y'>z vs y</option>
<option value='y-t' selected>y vs t</option>
<option value='dy-y'>y' vs y</option>
<option value='z-t'>z vs t</option>
<option value='dz-z'>z' vs z</option>
</select>
</div>
<script src='/assets/js/d3-array.js'></script>
<script src='/assets/js/d3-color.js'></script>
<script src='/assets/js/d3-interpolate.js'></script>
<script src='/assets/js/d3-scale.js'></script>
<script src='/assets/js/d3-selection.js'></script>
<script src='/assets/js/d3-format.js'></script>
<script src='/assets/js/d3-axis.js'></script>
<script src='/assets/js/d3-dispatch.js'></script>
<script src='/assets/js/d3-timer.js'></script>
<script src='/assets/js/d3-ease.js'></script>
<script src='/assets/js/d3-transition.js'></script>
<script>
'use strict';
const $id = document.getElementById.bind(document);
const $class = document.getElementsByClassName.bind(document);
Element.prototype.$class = Element.prototype.getElementsByClassName;
Element.prototype.$class1 = function(cssClass) { return this.$class(cssClass)?.[0]; }
Element.prototype.listen = Element.prototype.addEventListener;
// When getting min/max of a large array, this is far faster than using Math.min(...array)
// and Math.max(...array), at least with V8 JS engine
// Plus, we can also filter out NaN and Infinity
function minMax(array) {
let i = 1, limit = array.length, max = array[0], min = array[0];
if (!Number.isFinite(max) || Number.isNaN(max)) {
while (i < limit) {
let value = array[i];
if (Number.isFinite(value) && !Number.isNaN(value)) {
min = max = value;
break;
}
i++;
}
}
while (i < limit) {
let value = array[i];
if (Number.isFinite(value) && !Number.isNaN(value)) {
min = Math.min(min, value); // Hopefully this JITs to a single MINSD instruction on x86-64
max = Math.max(max, value);
}
i++;
}
return [min, max];
}
function makeSvgNode(tag, cssClass, attributes = {}) {
const node = document.createElementNS("http://www.w3.org/2000/svg", tag);
if (cssClass)
node.classList.add(cssClass);
for (const [key, value] of Object.entries(attributes))
node.setAttribute(key, value);
return node;
}
function appendSvgNode(parent, tag, cssClass, attributes = {}) {
const node = makeSvgNode(tag, cssClass, attributes);
parent.appendChild(node);
return node;
}
function prependSvgNode(parent, tag, cssClass, attributes = {}) {
const node = makeSvgNode(tag, cssClass, attributes);
parent.prepend(node);
return node;
}
// Make SVG which will resize itself to match its container
function initDynamicSvg(wrapper) {
const svg = appendSvgNode(wrapper, 'svg', undefined, { viewBox: '0 0 200 200', preserveAspectRatio: 'none' });
const watcher = new ResizeObserver(entries => {
svg.style.width = Math.round(entries[0].contentBoxSize[0].inlineSize) + 'px'; // CSS lengths need a unit
svg.style.height = Math.round(entries[0].contentBoxSize[0].blockSize) + 'px';
});
watcher.observe(wrapper);
svg.style.width = wrapper.offsetWidth + 'px';
svg.style.height = wrapper.offsetHeight + 'px';
return svg;
}
// Convert position of a mouse event from px coordinates to SVG unit coordinates
function getMousePosition(event, svg) {
const CTM = svg.getScreenCTM();
if (event.touches)
event = event.touches[0];
return [
(event.clientX - CTM.e) / CTM.a,
(event.clientY - CTM.f) / CTM.d
];
}
function fireEvent(node, eventType) {
node.dispatchEvent(new Event(eventType));
}
// Parser for C/Java/JavaScript-like expression syntax
function tokenize(str) {
str = str.replaceAll(/\s+/g, ''); // strip all whitespace
str = str.replaceAll("\u{2212}", '-'); // convert Unicode MINUS SIGN to an ordinary ASCII hyphen-minus
if (!str.length)
return [];
str = str.toLowerCase();
// Use 'sticky' regex so that matches must be contiguous; we don't just want to pick
// matches out of the string, but want to split the entire string into matches
const matches = Array.from(str.matchAll(/\d+(\.\d+)?|[()^/+-]|\*\*?|[\w'π]+/gy));
const lastMatch = matches.at(-1);
// If nothing matched at all, the very first character could not be tokenized
const matchedUpTo = lastMatch ? lastMatch.index + lastMatch[0].length : 0;
if (matchedUpTo !== str.length) {
// The last match did not end at the end of the string, meaning there was
// some text which we couldn't tokenize
return [false, { offset: matchedUpTo, length: 1, message: "I don't know what to do with the highlighted character." }];
}
return [true, matches.map((m) => m[0])];
}
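/*
The comment in tokenize() explains why the regex is sticky: each match has to begin
exactly where the previous one ended, so matching stops at the first piece of text
that cannot be tokenized. Below is a minimal, self-contained sketch of that idea.
`stickySplit` and `samplePattern` are hypothetical names used only for illustration,
and the pattern is a simplified assumption, not the one this tool uses.

```javascript
// Hypothetical helper: split `str` into tokens using a global + sticky regex
function stickySplit(str, pattern) {
  const re = new RegExp(pattern, 'gy');
  const tokens = [];
  let consumed = 0; // how many characters were covered by contiguous matches
  for (const m of str.matchAll(re)) {
    tokens.push(m[0]);
    consumed = m.index + m[0].length;
  }
  // If matching stopped early, `consumed` is the offset of the first bad character
  return { tokens, consumed, complete: consumed === str.length };
}

// Simplified token pattern (an assumption for this sketch): numbers, a few symbols, names
const samplePattern = '\\d+|[()+-]|\\w+';
const good = stickySplit('2+cos(t)', samplePattern); // every character is covered
const bad = stickySplit('2+$t', samplePattern);      // matching stops in front of '$'
```

Without the sticky flag, the unknown character would be silently skipped and the
matches on either side of it would still be returned.
*/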
const knownFunctions = new Set(['sqrt', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'ln', 'log', 'log10', 'log2', 'sgn', 'sign']);
function parse(tokens) {
const origTokens = Array.from(tokens); // Keep a copy for later inspection
let exp;
try {
exp = parseExpression(tokens);
} catch (e) {
const match = e.message.match(/^Parse error: (.*)$/)
if (match) {
let lastTokenIndex = Math.max(0, origTokens.length-tokens.length-1);
const precedingLength = origTokens.slice(0, lastTokenIndex).reduce((sum, s) => sum + s.length, 0);
return [false, { offset: precedingLength, length: origTokens[lastTokenIndex].length, message: match[1] }]
}
throw e; // Not a parse error
}
if (tokens.length) {
// All tokens should have been consumed by the parser
const entireLength = origTokens.reduce((sum, s) => sum + s.length, 0);
const remainingLength = tokens.reduce((sum, s) => sum + s.length, 0);
return [false, { offset: entireLength - remainingLength, length: remainingLength, message: "The last part of this expression is extraneous; it doesn't seem like it should be there" }];
}
return [true, exp];
}
function parseExpression(tokens) {
if (!tokens.length)
throw new Error("Parse error: The expression is incomplete, you still need to add something more");
let exp = parseNonBinaryExpression(tokens);
if (!tokens.length)
return exp;
// Are we in the middle of a binary expression?
const operands = [exp], operators = [];
let tok = tokens[0];
while (tok === '+' || tok === '-' || tok === '*' || tok === '/' || tok === '**' || tok === '^') {
tokens.shift();
operators.push(tok);
operands.push(parseNonBinaryExpression(tokens));
tok = tokens[0];
}
if (operators.length) {
const operatorPriority = [new Set(['**', '^']), new Set(['*', '/']), new Set(['+', '-'])];
for (const toCollapse of operatorPriority) {
let i = 0;
while (i < operators.length) {
if (toCollapse.has(operators[i])) {
operands[i] = [operators[i], operands[i], operands[i+1]];
operands.splice(i+1, 1);
operators.splice(i, 1);
} else {
i++;
}
}
}
if (operands.length !== 1)
throw new Error("Parser could not properly build parse tree for binary operations");
return operands[0];
}
return exp;
}
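/*
parseExpression() first gathers a flat list of operands and operators, then collapses
it one priority level at a time, scanning left to right. The sketch below isolates
that step. `collapseByPriority` is a hypothetical name, and the function works on
copies rather than in place, so treat it as an illustration rather than the parser itself.

```javascript
// Hypothetical helper: build a nested [operator, left, right] tree from flat lists
function collapseByPriority(operands, operators, priorities) {
  operands = operands.slice();
  operators = operators.slice();
  for (const level of priorities) {
    let i = 0;
    while (i < operators.length) {
      if (level.includes(operators[i])) {
        // Replace operands i and i+1 with a single combined node
        operands.splice(i, 2, [operators[i], operands[i], operands[i + 1]]);
        operators.splice(i, 1);
      } else {
        i++;
      }
    }
  }
  return operands[0];
}

const priorities = [['^'], ['*', '/'], ['+', '-']];
const sum = collapseByPriority(['1', '2', '3'], ['+', '*'], priorities);
const power = collapseByPriority(['2', '3', '2'], ['^', '^'], priorities);
```

Because each level is scanned left to right, chained operators of the same level group
to the left. That includes exponentiation: `2^3^2` is grouped as `(2^3)^2`, whereas
JavaScript's own `**` operator is right-associative.
*/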
function parseNonBinaryExpression(tokens) {
if (!tokens.length)
throw new Error("Parse error: The expression is incomplete, you still need to add something more");
const tok = tokens.shift();
let subExp, nextTok;
if (tok === '(') {
subExp = parseExpression(tokens);
nextTok = tokens.shift();
if (nextTok !== ')') {
throw new Error("Parse error: There should have been a closing parenthesis here");
}
return subExp;
}
if (tok === ')' || tok === '+' || tok === '*' || tok === '/' || tok === '**' || tok === '^')
throw new Error("Parse error: I don't know what to do with the highlighted symbol in this position");
if (/^\-?\d+(\.\d+)?$/.test(tok) || tok === 't' || tok === 'x' || tok === 'y' || tok === 'z' || tok === "y'" || tok === 'e' || tok === 'π' || tok === 'pi')
return tok;
if (tok === '-')
return ['-', parseNonBinaryExpression(tokens)];
if (/[\w']+/.test(tok)) {
if (!knownFunctions.has(tok))
throw new Error(`Parse error: I don't know what '${tok}' means`);
nextTok = tokens.shift();
if (nextTok !== '(')
throw new Error(`Parse error: After the function '${tok}', there should be an opening parenthesis before the parameter`);
let param = parseExpression(tokens);
nextTok = tokens.shift();
if (nextTok !== ')')
throw new Error("Parse error: There should be a closing parenthesis after the function parameter");
return [tok, param];
}
throw new Error(`Unexpected token: ${tok}`);
}
// Convert parsed expression to a JS function object
function compileFunction(ast) {
return new Function('t', 'y', 'z', `return ${compileNode(ast)};`);
}
function compileNode(ast) {
if (ast === 't' || ast === 'y' || ast === 'z' || (typeof(ast) === 'string' && /^\-?\d+(\.\d+)?$/.test(ast)))
return ast;
if (ast === "y'")
return 'z';
// Both 't' and 'x' are traditional names for the independent variable; accept either
if (ast === 'x')
return 't';
if (ast === 'e')
return 'Math.E';
if (ast === 'π' || ast === 'pi')
return 'Math.PI';
if (Array.isArray(ast)) {
const fn = ast[0];
if (fn === 'cos' || fn === 'sin' || fn === 'tan' || fn === 'sqrt' || fn === 'log2' || fn === 'log10')
return `Math.${fn}(${compileNode(ast[1])})`;
if (fn === 'arcsin' || fn === 'arccos' || fn === 'arctan')
return `Math.a${fn.slice(3)}(${compileNode(ast[1])})`; // e.g. arcsin -> Math.asin
if (fn === 'ln' || fn === 'log')
return `Math.log(${compileNode(ast[1])})`;
if (fn === 'sgn' || fn === 'sign')
return `Math.sign(${compileNode(ast[1])})`;
if (fn === '-' && ast.length === 2) {
if (typeof(ast[1]) === 'string')
return `-${compileNode(ast[1])}`;
else
return `-(${compileNode(ast[1])})`;
}
if (fn === '+' || fn === '-' || fn === '*' || fn === '/' || fn === '**')
return `(${compileNode(ast[1])}) ${fn} (${compileNode(ast[2])})`;
if (fn === '^')
return `(${compileNode(ast[1])}) ** (${compileNode(ast[2])})`;
}
throw new Error(`Unexpected AST node: ${JSON.stringify(ast)}`);
}
function highlightParseError(input, error) {
const inputStyles = getComputedStyle(input);
const textMeasurer = document.createElement('span');
textMeasurer.style.visibility = 'hidden';
textMeasurer.style.font = inputStyles.font;
document.body.appendChild(textMeasurer);
const xOffset = input.offsetLeft + parseInt(inputStyles.borderLeftWidth) + parseInt(inputStyles.paddingLeft);
textMeasurer.innerText = input.value.slice(0, error.offset);
const left = textMeasurer.offsetWidth + xOffset;
textMeasurer.innerText = input.value.slice(0, error.offset + error.length);
const right = textMeasurer.offsetWidth + xOffset;
const bottom = parseInt(inputStyles.borderBottomWidth);
const underline = document.createElement('span');
underline.classList.add('underline')
underline.style.left = left + 'px';
underline.style.width = (right - left) + 'px';
underline.style.bottom = bottom + 'px';
input.parentNode.appendChild(underline);
if (error.message)
input.setAttribute('title', error.message);
textMeasurer.remove();
}
function clearParseError(input) {
input.setAttribute('title', '');
for (const underline of input.parentNode.$class('underline'))
underline.remove();
}
class Solution {
constructor(t_start, t_end, Δt, nVars) {
this.t_start = t_start;
this.t_end = t_end;
this.Δt = Δt;
this.nVars = nVars;
this.nPoints = Math.floor(((t_end - t_start) / Δt) + 1);
// Packed array of values for all variables at each time step:
// (Values for each variable occupy a contiguous range of indices)
this.array = new Float64Array(nVars * this.nPoints);
// Packed array of estimated derivatives of all variables at each time step:
this.diff = undefined;
}
// Find range of indices with values for a particular variable
startIndex(varIndex) {
return this.nPoints * varIndex;
}
endIndex(varIndex) {
return this.startIndex(varIndex+1);
}
// Find index of value for a particular variable and time step
timeIndex(varIndex, t) {
return this.startIndex(varIndex) + Math.floor((t - this.t_start) / this.Δt);
}
timeValue(varIndex, t) {
return this.array[this.timeIndex(varIndex, t)];
}
values(varIndex) {
return this.array.subarray(this.startIndex(varIndex), this.endIndex(varIndex));
}
}
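/*
The Solution class keeps every variable in one packed Float64Array, with each variable
occupying its own contiguous block of nPoints entries. The sketch below reproduces just
the index arithmetic. `packedLayout` is a hypothetical helper introduced for illustration.

```javascript
// Hypothetical stand-in for the index arithmetic in the Solution class
function packedLayout(tStart, tEnd, dt, nVars) {
  const nPoints = Math.floor((tEnd - tStart) / dt + 1);
  return {
    nPoints,
    length: nVars * nPoints, // size of the single backing Float64Array
    // Variable `varIndex` occupies indices [nPoints * varIndex, nPoints * (varIndex + 1))
    index: (varIndex, t) => nPoints * varIndex + Math.floor((t - tStart) / dt)
  };
}

// Two variables sampled on t = -10, -9.5, ..., 10
const layout = packedLayout(-10, 10, 0.5, 2);
const data = new Float64Array(layout.length);
data[layout.index(1, 0)] = 42; // store a value for the second variable at t = 0
```

The step of 0.5 is chosen here because it is exactly representable in binary floating
point, which keeps the Math.floor arithmetic exact in this example.
*/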
class PhaseLine {
constructor(solution, varIndexX = 0, varIndexY = 1) {
this.solution = solution;
this.varIndexX = varIndexX;
this.varIndexY = varIndexY;
this._bounds = undefined;
}
title() {
return this.solution.startConditions;
}
svgPath(xScale, yScale) {
const limit = this.solution.endIndex(this.varIndexX), ary = this.solution.array;
const minXDiff = Math.abs(xScale.invert(0.01) - xScale.invert(0)), minYDiff = Math.abs(yScale.invert(0.01) - yScale.invert(0));
let s = '', startingLine = true, prevX, prevY;
for (let i = this.solution.startIndex(this.varIndexX), j = this.solution.startIndex(this.varIndexY); i < limit; i++, j++) {
const x = ary[i], y = ary[j];
if (Number.isNaN(x) || !Number.isFinite(x) || Number.isNaN(y) || !Number.isFinite(y) || Math.abs(x) > 1000000000000 || Math.abs(y) > 1000000000000) {
startingLine = true;
} else if (startingLine || Math.abs(x - prevX) >= minXDiff || Math.abs(y - prevY) >= minYDiff) {
s += (startingLine ? 'M' : 'L') + xScale(x).toFixed(2) + ',' + yScale(y).toFixed(2);
prevX = x;
prevY = y;
startingLine = false;
}
}
return s;
}
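bounds() {
if (this._bounds)
return this._bounds;
// Use the variables actually plotted on each axis, not always variables 0 and 1
const x_values = this.solution.values(this.varIndexX);
const y_values = this.solution.values(this.varIndexY);
return this._bounds = minMax(x_values).concat(minMax(y_values));
}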
}
// Graph the derivative of a variable (vertical axis) against the variable itself (horizontal axis)
class DerivativeLine {
constructor(solution, varIndex = 0, fn) {
this.solution = solution;
this.varIndex = varIndex;
this._bounds = undefined;
this.derivValues = new Float64Array(solution.endIndex(0));
let t = solution.t_start;
if (solution.nVars === 1) {
const values = solution.values(varIndex);
for (let i = 0; i < this.derivValues.length; i++) {
this.derivValues[i] = fn(t, values[i]);
t += solution.Δt;
}
} else if (solution.nVars === 2) {
const yValues = solution.values(0), zValues = solution.values(1);
for (let i = 0; i < this.derivValues.length; i++) {
this.derivValues[i] = fn(t, yValues[i], zValues[i]);
t += solution.Δt;
}
}
}
title() {
return this.solution.startConditions;
}
svgPath(xScale, yScale) {
const values = this.solution.values(this.varIndex), limit = this.solution.endIndex(0);
const minXDiff = Math.abs(xScale.invert(0.01) - xScale.invert(0)), minYDiff = Math.abs(yScale.invert(0.01) - yScale.invert(0));
let s = '', startingLine = true, prevX, prevY;
for (let i = 0; i < limit; i++) {
const x = values[i], y = this.derivValues[i];
if (Number.isNaN(x) || !Number.isFinite(x) || Number.isNaN(y) || !Number.isFinite(y) || Math.abs(x) > 1000000000000 || Math.abs(y) > 1000000000000) {
startingLine = true;
} else if (startingLine || Math.abs(x - prevX) >= minXDiff || Math.abs(y - prevY) >= minYDiff) {
s += (startingLine ? 'M' : 'L') + xScale(x).toFixed(2) + ',' + yScale(y).toFixed(2);
prevX = x;
prevY = y;
startingLine = false;
}
}
return s;
}
bounds() {
if (this._bounds)
return this._bounds;
const values = this.solution.values(this.varIndex);
return this._bounds = minMax(values).concat(minMax(this.derivValues));
}
}
class TimeLine {
constructor(solution, varIndex = 0) {
this.solution = solution;
this.varIndex = varIndex;
this._bounds = undefined;
}
title() {
return this.solution.startConditions;
}
svgPath(xScale, yScale) {
const limit = this.solution.endIndex(this.varIndex), Δt = this.solution.Δt, ary = this.solution.array;
let s = '', t = this.solution.t_start, startingLine = true;
for (let i = this.solution.startIndex(this.varIndex); i < limit; i++) {
const val = ary[i];
if (!Number.isNaN(val) && Number.isFinite(val) && Math.abs(val) <= 1000000000000) {
s += (startingLine ? 'M' : 'L') + xScale(t).toFixed(2) + ',' + yScale(val).toFixed(2);
startingLine = false;
} else {
startingLine = true;
}
t += Δt;
}
return s;
}
bounds() {
if (this._bounds)
return this._bounds;
const values = this.solution.values(this.varIndex);
return this._bounds = [this.solution.t_start, this.solution.t_end].concat(minMax(values));
}
}
function combinedBounds(lines) {
const allBounds = lines.map((line) => line.bounds()).filter((bounds) => bounds.every((b) => !Number.isNaN(b) && Number.isFinite(b)));
return [
Math.min(...allBounds.map((b) => b[0])),
Math.max(...allBounds.map((b) => b[1])),
Math.min(...allBounds.map((b) => b[2])),
Math.max(...allBounds.map((b) => b[3]))
];
}
// Trace out evolution of our system using classic Runge-Kutta (AKA "RK4")
// Store results in a packed array of floats
function rk4trace(y, t, Δt, fn, ary, i, Δi, limit) {
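// Direction-aware bound check, so that a start index outside the array cannot loop forever
while (Δi > 0 ? i < limit : i > limit) {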
const half_Δt = Δt / 2.0;
const next_t = t + Δt;
const half_t = t + half_Δt;
const k_1 = fn(t, y); // Slope at starting point
const k_2 = fn(half_t, y + (half_Δt * k_1)); // Estimated slope at mid-point
const k_3 = fn(half_t, y + (half_Δt * k_2)); // Another estimate of slope at mid-point
const k_4 = fn(next_t, y + (Δt * k_3)); // Estimated slope at endpoint
const slope = (k_1 + 2*k_2 + 2*k_3 + k_4) / 6.0; // Weighted average of those four slopes
y += Δt * slope;
ary[i] = y;
t += Δt;
i += Δi;
}
}
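/*
A quick way to see what the loop in rk4trace() computes is to apply a single RK4 step
repeatedly to an equation with a known solution. The sketch below does this for
y' = y with y(0) = 1, whose exact solution is e^t. `rk4Step` and `integrateExp` are
hypothetical names used for illustration; they do not write into a packed array the
way rk4trace() does.

```javascript
// One classic RK4 step for y' = fn(t, y); returns the estimate of y(t + dt)
function rk4Step(fn, t, y, dt) {
  const k1 = fn(t, y);                           // slope at the starting point
  const k2 = fn(t + dt / 2, y + (dt / 2) * k1);  // estimated slope at the mid-point
  const k3 = fn(t + dt / 2, y + (dt / 2) * k2);  // second estimate at the mid-point
  const k4 = fn(t + dt, y + dt * k3);            // estimated slope at the endpoint
  return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6;
}

// Hypothetical driver: integrate y' = y from y(0) = 1 over `steps` steps
function integrateExp(steps, dt) {
  let t = 0, y = 1;
  for (let n = 0; n < steps; n++) {
    y = rk4Step((t, y) => y, t, y, dt);
    t += dt;
  }
  return y;
}

const approxE = integrateExp(50, 0.02); // estimate of y(1), which should be close to Math.E
```

Passing a negative dt steps backwards in time, which is how the solver traces a
solution in both directions from the starting point.
*/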
// Apply RK4 to find phase lines for a system with one dependent variable
function rk4solve(y_0, t_0, t_start, t_end, Δt, fn) {
let t = t_0, y = y_0, solution = new Solution(t_start, t_end, Δt, 1);
let i = solution.timeIndex(0, t_0);
solution.array[i] = y_0;
// Trace out phase line from starting point
rk4trace(y_0, t_0, Δt, fn, solution.array, i+1, 1, solution.endIndex(0));
// Trace out phase line in the opposite direction from the starting point
rk4trace(y_0, t_0, -Δt, fn, solution.array, i-1, -1, solution.startIndex(0)-1);
return solution;
}
function rk4trace_2(y, z, t, Δt, fn_y, fn_z, ary, i, Δi, offset, limit) {
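// Direction-aware bound check, so that a start index outside the array cannot loop forever
while (Δi > 0 ? i < limit : i > limit) {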
const half_Δt = Δt / 2.0;
const next_t = t + Δt;
const half_t = t + half_Δt;
const k_1y = fn_y(t, y, z); // Slope at starting point
const k_1z = fn_z(t, y, z);
const k_2y = fn_y(half_t, y + (half_Δt * k_1y), z + (half_Δt * k_1z)); // Estimated slope at mid-point
const k_2z = fn_z(half_t, y + (half_Δt * k_1y), z + (half_Δt * k_1z));
const k_3y = fn_y(half_t, y + (half_Δt * k_2y), z + (half_Δt * k_2z)); // Another estimate of slope at mid-point
const k_3z = fn_z(half_t, y + (half_Δt * k_2y), z + (half_Δt * k_2z));
const k_4y = fn_y(next_t, y + (Δt * k_3y), z + (Δt * k_3z)); // Estimated slope at endpoint
const k_4z = fn_z(next_t, y + (Δt * k_3y), z + (Δt * k_3z));
const slope_y = (k_1y + 2*k_2y + 2*k_3y + k_4y) / 6.0; // Weighted average of those four slopes
const slope_z = (k_1z + 2*k_2z + 2*k_3z + k_4z) / 6.0;
y += Δt * slope_y;
z += Δt * slope_z;
ary[i] = y;
ary[i+offset] = z;
t += Δt;
i += Δi;
}
}
// Apply RK4 to find phase lines for a system with two dependent variables
function rk4solve_2(y_0, z_0, t_0, t_start, t_end, Δt, fn_y, fn_z) {
let t = t_0, y = y_0, z = z_0, solution = new Solution(t_start, t_end, Δt, 2);
let i = solution.timeIndex(0, t_0), offset = solution.startIndex(1);
solution.array[i] = y_0;
solution.array[i+offset] = z_0;
// Trace out phase line from starting point
rk4trace_2(y_0, z_0, t_0, Δt, fn_y, fn_z, solution.array, i+1, 1, offset, solution.endIndex(0));
// Trace out phase line in the opposite direction from the starting point
rk4trace_2(y_0, z_0, t_0, -Δt, fn_y, fn_z, solution.array, i-1, -1, offset, solution.startIndex(0)-1);
return solution;
}
function order1_solutions(y_0, t_0, t_start, t_end, Δt, fn) {
const result = [];
if (!Array.isArray(y_0))
y_0 = [y_0];
for (const y of y_0) {
const solution = rk4solve(y, t_0, t_start, t_end, Δt, fn);
solution.startConditions = `y(${t_0}) = ${y}`;
solution.startY = y;
result.push(solution);
}
return result;
}
function order2_solutions(y_0, dy_0, t_0, t_start, t_end, Δt, fn) {
const result = [];
if (!Array.isArray(y_0))
y_0 = [y_0];
if (!Array.isArray(dy_0))
dy_0 = [dy_0];
for (const y of y_0) {
for (const dy of dy_0) {
const solution = rk4solve_2(y, dy, t_0, t_start, t_end, Δt, function(t, y, dy) { return dy; }, fn);
solution.startConditions = `y(${t_0}) = ${y}, y'(${t_0}) = ${dy}`;
solution.startY = y;
solution.startZ = dy;
result.push(solution);
}
}
return result;
}
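/*
order2_solutions() handles a second-order equation y'' = f(t, y, y') by treating it as
the pair of first-order equations y' = z and z' = f(t, y, z). The sketch below shows
that reduction on y'' = -y with y(0) = 0 and y'(0) = 1, whose exact solution is sin(t).
`rk4Step2` and `solveSecondOrder` are hypothetical names for illustration only.

```javascript
// One RK4 step for the system y' = fy(t, y, z), z' = fz(t, y, z)
function rk4Step2(fy, fz, t, y, z, dt) {
  const h = dt / 2;
  const k1y = fy(t, y, z), k1z = fz(t, y, z);
  const k2y = fy(t + h, y + h * k1y, z + h * k1z), k2z = fz(t + h, y + h * k1y, z + h * k1z);
  const k3y = fy(t + h, y + h * k2y, z + h * k2z), k3z = fz(t + h, y + h * k2y, z + h * k2z);
  const k4y = fy(t + dt, y + dt * k3y, z + dt * k3z), k4z = fz(t + dt, y + dt * k3y, z + dt * k3z);
  return [
    y + dt * (k1y + 2 * k2y + 2 * k3y + k4y) / 6,
    z + dt * (k1z + 2 * k2z + 2 * k3z + k4z) / 6
  ];
}

// Solve y'' = f(t, y, y') by introducing z = y'
function solveSecondOrder(f, y0, dy0, t0, dt, steps) {
  let t = t0, y = y0, z = dy0;
  for (let n = 0; n < steps; n++) {
    [y, z] = rk4Step2((t, y, z) => z, f, t, y, z, dt);
    t += dt;
  }
  return [y, z];
}

// y'' = -y, y(0) = 0, y'(0) = 1, integrated up to t = 1
const [yAt1, dyAt1] = solveSecondOrder((t, y, z) => -y, 0, 1, 0, 0.01, 100);
```

Here yAt1 should be close to Math.sin(1) and dyAt1 close to Math.cos(1).
*/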
function eqs2_order1_solutions(y_0, z_0, t_0, t_start, t_end, Δt, fn_y, fn_z) {
const result = [];
if (!Array.isArray(y_0))
y_0 = [y_0];
if (!Array.isArray(z_0))
z_0 = [z_0];
for (const y of y_0) {
for (const z of z_0) {
const solution = rk4solve_2(y, z, t_0, t_start, t_end, Δt, fn_y, fn_z);
solution.startConditions = `y(${t_0}) = ${y}, z(${t_0}) = ${z}`;
solution.startY = y;
solution.startZ = z;
result.push(solution);
}
}
return result;
}
function gradations(from, to, n) {
if (n === 0)
return [];
if (to == from)
return [to];
if (n === 1)
return [from + (to - from) / 2]; // A single value goes at the midpoint of the range
const result = [], interval = (to - from) / (n - 1);
let value = from;
n--;
while (n--) {
result.push(value);
value += interval;
}
result.push(to);
return result;
}
class Viewport {
constructor(div, titles) {
this.svg = initDynamicSvg(div);
this.svg.classList.add('viewport');
this.lines = []; this.colors = [];
this.xAxis = this.yAxis = undefined; // DOM objects for axes
this.xScale = this.yScale = undefined; // Scale objects used to draw graph
this.xDomain = this.yDomain = undefined; // Currently displayed range of X/Y coordinates
this.drawn = false;
this.panning = false;
this.touchPos = new Map;
this.linkedViews = [];
if (titles)
this.setTitles(titles[0], titles[1]);
this.svg.listen('wheel', this.mouseWheel.bind(this));
this.svg.listen('mousedown', this.mouseDown.bind(this));
this.svg.listen('mouseup', this.mouseUp.bind(this));
this.svg.listen('mousemove', this.mouseMove.bind(this));
this.svg.listen('mouseleave', this.mouseLeave.bind(this));
this.svg.listen('touchstart', this.touchStart.bind(this));
this.svg.listen('touchmove', this.touchMove.bind(this));
this.svg.listen('touchend', this.touchEnd.bind(this));
this.svg.listen('touchcancel', this.touchEnd.bind(this));
}
draw(lines, colors) {
for (const path of Array.from(this.svg.$class('line')))
path.remove();
const bounds = combinedBounds(lines);
const xSize = bounds[1] - bounds[0];
const ySize = bounds[3] - bounds[2];
// Don't allow the X/Y domain of the graph to become so large that exponential notation is used
const xDomain = [Math.max(bounds[0] - (xSize * 0.05), -1000000000000), Math.min(bounds[1] + (xSize * 0.05), 1000000000000)];
const yDomain = [Math.max(bounds[2] - (ySize * 0.05), -1000000000000), Math.min(bounds[3] + (ySize * 0.05), 1000000000000)];
const xScale = d3.scaleLinear().domain(xDomain).range([25, 190]);
const yScale = d3.scaleLinear().domain(yDomain).range([185, 15]);
const xAxis = d3.axisBottom().scale(xScale);
const yAxis = d3.axisLeft().scale(yScale);
if (!this.xAxis) {
this.xAxis = d3.select(this.svg).append('g').attr('class', 'axis xaxis').attr('transform', 'translate(0,185)').call(xAxis);
} else {
this.xAxis.transition().duration(250).call(xAxis);
}
if (!this.yAxis) {
this.yAxis = d3.select(this.svg).append('g').attr('class', 'axis yaxis').attr('transform', 'translate(25,0)').call(yAxis);
} else {
this.yAxis.transition().duration(250).call(yAxis);
}
for (let i = 0; i < lines.length; i++) {
const line = lines[i];
const node = prependSvgNode(this.svg, 'path', 'line', {
d: line.svgPath(xScale, yScale),
stroke: colors[i],
fill: 'none',
'stroke-width': '.12rem',
'vector-effect': 'non-scaling-stroke'
});
node.listen('mouseenter', () => this.linkedViews.forEach((view) => {
view.highlightLine(i);
node.setAttribute('stroke-width', '.24rem');
}));
node.listen('mouseleave', () => this.linkedViews.forEach((view) => {
view.clearHighlight();
node.setAttribute('stroke-width', '.12rem');
}));
if (line.title()) {
const title = appendSvgNode(node, 'title');
title.textContent = line.title();
}
}
this.xScale = xScale; this.yScale = yScale; // Record scale which graph was first drawn at
this.xDomain = xDomain; this.yDomain = yDomain;
this.lines = Object.freeze(lines); this.colors = Object.freeze(colors);
this.drawn = true;
}
adjustView(xDomain, yDomain, duration = 250) {
if (!this.drawn)
return;
// Don't allow scale to become so large that exponential notation is used for axis labels
xDomain = [Math.max(xDomain[0], -1000000000000), Math.min(xDomain[1], 1000000000000)];
yDomain = [Math.max(yDomain[0], -1000000000000), Math.min(yDomain[1], 1000000000000)];
const xScale = d3.scaleLinear().domain(xDomain).range([25, 190]);
const yScale = d3.scaleLinear().domain(yDomain).range([185, 15]);
const xAxis = d3.axisBottom().scale(xScale);
const yAxis = d3.axisLeft().scale(yScale);
this.xAxis.transition().duration(duration).call(xAxis);
this.yAxis.transition().duration(duration).call(yAxis);
const stretchX = (xScale(1) - xScale(0)) / (this.xScale(1) - this.xScale(0));
const stretchY = (yScale(1) - yScale(0)) / (this.yScale(1) - this.yScale(0));
d3.select(this.svg).selectAll('.line').transition().duration(duration)
.attr('transform', `matrix(${stretchX} 0 0 ${stretchY} ${xScale(0) - (this.xScale(0) * stretchX)} ${yScale(0) - (this.yScale(0) * stretchY)})`);
this.xDomain = xDomain; this.yDomain = yDomain;
}
// factor > 1 means zooming out, 0 < factor < 1 means zooming in
zoom(factor) {
if (factor <= 0)
throw new Error("Zoom factor must be positive");
const xSize = this.xDomain[1] - this.xDomain[0];
const ySize = this.yDomain[1] - this.yDomain[0];
const newXSize = xSize * factor;
const newYSize = ySize * factor;
const Δx = (newXSize - xSize) / 2;
const Δy = (newYSize - ySize) / 2;
this.adjustView([this.xDomain[0] - Δx, this.xDomain[1] + Δx], [this.yDomain[0] - Δy, this.yDomain[1] + Δy]);
}
pan(Δx, Δy, duration = 250) {
if (Δx === 0 && Δy === 0)
return;
this.adjustView([this.xDomain[0] + Δx, this.xDomain[1] + Δx], [this.yDomain[0] + Δy, this.yDomain[1] + Δy], duration);
}
// Get position of mouse event in function input/output space
mouseDomainPosition(event) {
const pos = getMousePosition(event, this.svg);
const xScale = d3.scaleLinear().domain(this.xDomain).range([25, 190]);
const yScale = d3.scaleLinear().domain(this.yDomain).range([185, 15]);
return [xScale.invert(pos[0]), yScale.invert(pos[1])];
}
mouseWheel(event) {
event.preventDefault();
if (!this.drawn)
return;
const [x, y] = this.mouseDomainPosition(event);
const factor = Math.E ** (event.deltaY * 0.001);
function scaleAround(v_0, v_1, factor) {
return v_0 + ((v_1 - v_0) * factor);
}
// A 100ms transition feels better than 250ms here
this.adjustView([scaleAround(x, this.xDomain[0], factor), scaleAround(x, this.xDomain[1], factor)], [scaleAround(y, this.yDomain[0], factor), scaleAround(y, this.yDomain[1], factor)], 100);
}
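/*
mouseWheel() zooms about the cursor: both ends of each axis domain are moved toward
(or away from) the cursor position by the same factor, so the point under the cursor
stays where it is. A small sketch of that arithmetic follows. `zoomDomainAround` and
`wheelFactor` are hypothetical helper names introduced for illustration.

```javascript
// Hypothetical helper: scale a [min, max] domain about a fixed point
function zoomDomainAround(domain, center, factor) {
  return domain.map((v) => center + (v - center) * factor);
}

// Same mapping from wheel movement to zoom factor as in mouseWheel()
function wheelFactor(deltaY) {
  return Math.exp(deltaY * 0.001);
}

// Halve the visible range of [0, 10] while keeping the point 2 fixed
const zoomedIn = zoomDomainAround([0, 10], 2, 0.5);
```

In this example the point 2 sits one fifth of the way along the domain both before
and after zooming, which is what keeps the graph steady under the cursor.
*/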
mouseDown(event) {
if (!this.drawn)
return;
if (event.button === 1 && !this.panning) {
// middle/wheel button pressed
this.beginPan(event);
}
}
mouseUp(event) {
if (this.panning) {
this.endPan(event);
}
}
mouseMove(event) {
if (this.panning) {
const panPos = this.mouseDomainPosition(event);
this.pan(this.panStart[0] - panPos[0], this.panStart[1] - panPos[1], 0);
}
}
mouseLeave(event) {
if (this.panning) {
this.endPan(event);
}
}
beginPan(event) {
this.panning = true;
this.svg.classList.add('panning');
this.panStart = this.mouseDomainPosition(event);
}
endPan(event) {
this.panning = false;
this.svg.classList.remove('panning');
this.panStart = undefined;
}
touchStart(event) {
this.touchPos.clear();
for (const touch of event.targetTouches) {
this.touchPos.set(touch.identifier, this.mouseDomainPosition(touch));
}
}
touchMove(event) {
if (this.touchPos.size === 1) {
event.preventDefault();
const touch = event.changedTouches[0];
const startPos = this.touchPos.get(touch.identifier);
if (!startPos)
return;
const currentPos = this.mouseDomainPosition(touch);
this.pan(startPos[0] - currentPos[0], startPos[1] - currentPos[1], 0);
} else if (this.touchPos.size === 2) {
// Pinch to zoom (on mobile)
event.preventDefault();
function average(numbers) {
return numbers.reduce((a,b) => a + b, 0) / numbers.length;
}
function avgPoint(points) {
return [average(points.map((p) => p[0])), average(points.map((p) => p[1]))];
}
function distance(points) {
return Math.sqrt((points[0][0] - points[1][0]) ** 2 + (points[0][1] - points[1][1]) ** 2);
}
const startPoints = Array.from(this.touchPos.values());
const currentPoints = Array.from(event.targetTouches).map((t) => this.mouseDomainPosition(t));
const startCenter = avgPoint(startPoints);
const currentCenter = avgPoint(currentPoints);
const Δx = startCenter[0] - currentCenter[0];
const Δy = startCenter[1] - currentCenter[1];
const stretch = distance(startPoints) / distance(currentPoints);
function scaleAround(v_0, v_1, factor) {
return v_0 + ((v_1 - v_0) * factor);
}
this.adjustView([scaleAround(startCenter[0], this.xDomain[0] + Δx, stretch), scaleAround(startCenter[0], this.xDomain[1] + Δx, stretch)], [scaleAround(startCenter[1], this.yDomain[0] + Δy, stretch), scaleAround(startCenter[1], this.yDomain[1] + Δy, stretch)], 0);
}
}
touchEnd(event) {
for (const touch of event.changedTouches) {
this.touchPos.delete(touch.identifier);
}
}
highlightLine(index) {
this.clearHighlight();
const line = this.lines[index];
if (line) {
const xScale = d3.scaleLinear().domain(this.xDomain).range([25, 190]);
const yScale = d3.scaleLinear().domain(this.yDomain).range([185, 15]);
const node = appendSvgNode(this.svg, 'path', 'line', {
d: line.svgPath(xScale, yScale),
stroke: this.colors[index],
fill: 'none',
'stroke-width': '.24rem',
'vector-effect': 'non-scaling-stroke'
});
node.classList.add('highlight');
}
}
clearHighlight() {
for (const path of Array.from(this.svg.$class('highlight')))
path.remove();
}
setTitles(xLabel, yLabel) {
const xText = this.svg.$class1('xaxis-label');
if (xText)
xText.textContent = xLabel;
else
d3.select(this.svg).append('text').attr('class', 'xaxis-label').attr('transform', 'translate(193,190)').attr('font-size', '8').style('text-anchor', 'start').text(xLabel);
const yText = this.svg.$class1('yaxis-label');
if (yText)
yText.textContent = yLabel;
else
d3.select(this.svg).append('text').attr('class', 'yaxis-label').attr('transform', 'translate(20,12)').attr('font-size', '8').style('text-anchor', 'start').text(yLabel);
}
}
const vp1 = new Viewport($id('view1'), ['y', "y'"]);
const vp2 = new Viewport($id('view2'), ['t', 'y']);
vp1.linkedViews.push(vp2);
vp2.linkedViews.push(vp1);
const axisTitles = { 'z-y': ['y', 'z'], 'dy-y': ['y', "y'"], 'y-t': ['t', 'y'], 'dz-z': ['z', "z'"], 'z-t': ['t', 'z'] };
function validateField(ok, good, field, message) {
if (good) {
field.setAttribute('title', '');
field.classList.remove('err');
} else {
field.setAttribute('title', message);
field.classList.add('err');
}
return ok && good;
}
function validate() {
let ok = true;
const timeStep = Number($id('timestep').value);
const startTime = Number($id('starttime').value);
const endTime = Number($id('endtime').value);
const t0 = Number($id('t0').value);
ok = validateField(ok, timeStep > 0, $id('timestep'), "Time step must be a positive number");
ok = validateField(ok, startTime < endTime, $id('starttime'), "Start time must be before end time");
ok = validateField(ok, startTime < endTime, $id('endtime'), "End time must be after start time");
ok = validateField(ok, t0 >= startTime && t0 <= endTime, $id('t0'), "t₀ must be between start and end time");
const min_y0 = Number($id('starty1').value);
const max_y0 = Number($id('starty2').value);
const n_y0 = Number($id('nstarty').value);
ok = validateField(ok, min_y0 <= max_y0, $id('starty1'), "Min y must not be greater than max y");
ok = validateField(ok, min_y0 <= max_y0, $id('starty2'), "Max y must not be less than min y");
ok = validateField(ok, n_y0 > 0, $id('nstarty'), "Number of values to graph must be one or more");
if ($id('startdy1').style.display !== 'none') {
const min_dy0 = Number($id('startdy1').value);
const max_dy0 = Number($id('startdy2').value);
const n_dy0 = Number($id('nstartdy').value);
ok = validateField(ok, min_dy0 <= max_dy0, $id('startdy1'), "Min y' must not be greater than max y'");
ok = validateField(ok, min_dy0 <= max_dy0, $id('startdy2'), "Max y' must not be less than min y'");
ok = validateField(ok, n_dy0 > 0, $id('nstartdy'), "Number of values to graph must be one or more");
}
if ($id('startz1').style.display !== 'none') {
const min_z0 = Number($id('startz1').value);
const max_z0 = Number($id('startz2').value);
const n_z0 = Number($id('nstartz').value);
ok = validateField(ok, min_z0 <= max_z0, $id('startz1'), "Min z must not be greater than max z");
ok = validateField(ok, min_z0 <= max_z0, $id('startz2'), "Max z must not be less than min z");
ok = validateField(ok, n_z0 > 0, $id('nstartz'), "Number of values to graph must be one or more");
}
return ok;
}
function redraw() {
if (!validate())
return;
if (/^\s*$/.test($id('exp1').value))
return;
const timeStep = Number($id('timestep').value);
const startTime = Number($id('starttime').value);
const endTime = Number($id('endtime').value);
const t0 = Number($id('t0').value);
let [ok, result] = tokenize($id('exp1').value);
if (!ok) {
highlightParseError($id('exp1'), result);
return;
}
[ok, result] = parse(result);
if (!ok) {
highlightParseError($id('exp1'), result);
return;
}
const fn = compileFunction(result);
const min_y0 = Number($id('starty1').value);
const max_y0 = Number($id('starty2').value);
const n_y0 = Number($id('nstarty').value);
const y0_values = gradations(min_y0, max_y0, n_y0);
const hueScale = d3.scaleLinear().domain([min_y0, max_y0]).range([0,270]);
let solutions, colors, satScale, lightScale;
switch($id('graphtype').value) {
case 'order1':
solutions = order1_solutions(y0_values, t0, startTime, endTime, timeStep, fn);
colors = solutions.map((solution) => d3.hsl(hueScale(solution.startY), 0.75, 0.75).formatHex());
vp1.draw(solutions.map((s) => new DerivativeLine(s, 0, fn)), colors);
vp2.draw(solutions.map((s) => new TimeLine(s)), colors);
break;
case 'order2':
const min_dy0 = Number($id('startdy1').value);
const max_dy0 = Number($id('startdy2').value);
const n_dy0 = Number($id('nstartdy').value);
const dy0_values = gradations(min_dy0, max_dy0, n_dy0);
solutions = order2_solutions(y0_values, dy0_values, t0, startTime, endTime, timeStep, fn);
satScale = d3.scaleLinear().domain([min_dy0, max_dy0]).range([0.34,0.84]);
lightScale = d3.scaleLinear().domain([min_dy0, max_dy0]).range([0.78,0.48]);
colors = solutions.map((solution) => d3.hsl(hueScale(solution.startY), satScale(solution.startZ), lightScale(solution.startZ)).formatHex());
vp1.draw(solutions.map((s) => new PhaseLine(s)), colors);
vp2.draw(solutions.map((s) => new TimeLine(s)), colors);
break;
case 'eqs2_order1':
if (/^\s*$/.test($id('exp2').value))
return;
const min_z0 = Number($id('startz1').value);
const max_z0 = Number($id('startz2').value);
const n_z0 = Number($id('nstartz').value);
const z0_values = gradations(min_z0, max_z0, n_z0);
[ok, result] = tokenize($id('exp2').value);
if (!ok) {
highlightParseError($id('exp2'), result);
return;
}
[ok, result] = parse(result);
if (!ok) {
highlightParseError($id('exp2'), result);
return;
}
const fn2 = compileFunction(result);
solutions = eqs2_order1_solutions(y0_values, z0_values, t0, startTime, endTime, timeStep, fn, fn2);
satScale = d3.scaleLinear().domain([min_z0, max_z0]).range([0.34,0.84]);
lightScale = d3.scaleLinear().domain([min_z0, max_z0]).range([0.78,0.48]);
colors = solutions.map((solution) => d3.hsl(hueScale(solution.startY), satScale(solution.startZ), lightScale(solution.startZ)).formatHex());
switch($id('graph1').value) {
case 'z-y': vp1.draw(solutions.map((s) => new PhaseLine(s)), colors); break;
case 'dy-y': vp1.draw(solutions.map((s) => new DerivativeLine(s, 0, fn)), colors); break;
case 'y-t': vp1.draw(solutions.map((s) => new TimeLine(s)), colors); break;
case 'dz-z': vp1.draw(solutions.map((s) => new DerivativeLine(s, 1, fn2)), colors); break;
case 'z-t': vp1.draw(solutions.map((s) => new TimeLine(s, 1)), colors); break;
}
switch($id('graph2').value) {
case 'z-y': vp2.draw(solutions.map((s) => new PhaseLine(s)), colors); break;
case 'dy-y': vp2.draw(solutions.map((s) => new DerivativeLine(s, 0, fn)), colors); break;
case 'y-t': vp2.draw(solutions.map((s) => new TimeLine(s)), colors); break;
case 'dz-z': vp2.draw(solutions.map((s) => new DerivativeLine(s, 1, fn2)), colors); break;
case 'z-t': vp2.draw(solutions.map((s) => new TimeLine(s, 1)), colors); break;
}
break;
}
}
$id('graphtype').listen('change', (event) => {
if (event.target.value === 'eqs2_order1') {
for (const elem of $class('eqs2only')) elem.style.display = '';
vp1.setTitles(...axisTitles[$id('graph1').value]);
vp2.setTitles(...axisTitles[$id('graph2').value]);
} else {
for (const elem of $class('eqs2only')) elem.style.display = 'none';
vp1.setTitles('y', "y'");
vp2.setTitles('t', 'y');
}
for (const elem of $class('order2only'))
elem.style.display = (event.target.value === 'order2') ? '' : 'none';
$id('exp1').labels[0].textContent = (event.target.value === 'order2') ? "y'' =" : "y' =";
redraw();
});
$id('exp1').listen('change', redraw);
$id('exp2').listen('change', redraw);
$id('exp1').listen('keydown', () => clearParseError($id('exp1')));
$id('exp2').listen('keydown', () => clearParseError($id('exp2')));
$id('timestep').listen('change', redraw);
$id('starttime').listen('change', redraw);
$id('endtime').listen('change', redraw);
$id('t0').listen('change', (event) => {
$id('starty1').labels[0].textContent = `Min. y(${event.target.value}):`;
$id('starty2').labels[0].textContent = `Max. y(${event.target.value}):`;
$id('startdy1').labels[0].textContent = `Min. y'(${event.target.value}):`;
$id('startdy2').labels[0].textContent = `Max. y'(${event.target.value}):`;
$id('startz1').labels[0].textContent = `Min. z(${event.target.value}):`;
$id('startz2').labels[0].textContent = `Max. z(${event.target.value}):`;
redraw();
});
$id('starty1').listen('change', redraw);
$id('starty2').listen('change', redraw);
$id('nstarty').listen('change', redraw);
$id('startdy1').listen('change', redraw);
$id('startdy2').listen('change', redraw);
$id('nstartdy').listen('change', redraw);
$id('startz1').listen('change', redraw);
$id('startz2').listen('change', redraw);
$id('nstartz').listen('change', redraw);
$id('graph1').listen('change', (event) => {
vp1.setTitles(...axisTitles[event.target.value]);
redraw();
});
$id('graph2').listen('change', (event) => {
vp2.setTitles(...axisTitles[event.target.value]);
redraw();
});
$id('sample_Y').listen('click', () => {
$id('graphtype').value = 'order1';
$id('exp1').value = 'y';
$id('starttime').value = '0';
$id('endtime').value = '5';
$id('timestep').value = '0.01';
$id('t0').value = '0';
$id('starty1').value = '-3';
$id('starty2').value = '3';
$id('nstarty').value = '5';
fireEvent($id('graphtype'), 'change');
});
$id('sample_cosY').listen('click', () => {
$id('graphtype').value = 'order1';
$id('exp1').value = 'cos(y)';
$id('starttime').value = '0';
$id('endtime').value = '10';
$id('timestep').value = '0.01';
$id('t0').value = '0';
$id('starty1').value = '-1';
$id('starty2').value = '10.5';
$id('nstarty').value = '30';
fireEvent($id('graphtype'), 'change');
});
$id('sample_riccati').listen('click', () => {
$id('graphtype').value = 'order1';
$id('exp1').value = '-((y^2) * (t^2)) + ((2 * y) / t)';
$id('starttime').value = '1';
$id('endtime').value = '3.5';
$id('timestep').value = '0.01';
$id('t0').value = '1';
$id('starty1').value = '1.5';
$id('starty2').value = '3.5';
$id('nstarty').value = '8';
fireEvent($id('graphtype'), 'change');
});
$id('sample_logistic').listen('click', () => {
$id('graphtype').value = 'order1';
$id('exp1').value = '(10*y)-(y^2)';
$id('starttime').value = '-1';
$id('endtime').value = '1';
$id('timestep').value = '0.01';
$id('t0').value = '0';
$id('starty1').value = '1';
$id('starty2').value = '5';
$id('nstarty').value = '10';
fireEvent($id('graphtype'), 'change');
});
$id('sample_cosT').listen('click', () => {
$id('graphtype').value = 'order2';
$id('exp1').value = 'cos(t)';
$id('starttime').value = '-10';
$id('endtime').value = '10';
$id('timestep').value = '0.05';
$id('t0').value = '0';
$id('starty1').value = '-3';
$id('starty2').value = '3';
$id('nstarty').value = '5';
$id('startdy1').value = '-3';
$id('startdy2').value = '3';
$id('nstartdy').value = '5';
fireEvent($id('graphtype'), 'change');
});
$id('sample_cosY_2').listen('click', () => {
$id('graphtype').value = 'order2';
$id('exp1').value = 'cos(y)';
$id('starttime').value = '-10';
$id('endtime').value = '10';
$id('timestep').value = '0.02';
$id('t0').value = '0';
$id('starty1').value = '-3';
$id('starty2').value = '3';
$id('nstarty').value = '5';
$id('startdy1').value = '-3';
$id('startdy2').value = '3';
$id('nstartdy').value = '5';
fireEvent($id('graphtype'), 'change');
});
$id('sample_mathieu').listen('click', () => {
$id('graphtype').value = 'order2';
$id('exp1').value = '(1 + (2 * cos(t))) * y';
$id('starttime').value = '-12.5';
$id('endtime').value = '12.5';
$id('timestep').value = '0.05';
$id('t0').value = '0';
$id('starty1').value = '-5';
$id('starty2').value = '5';
$id('nstarty').value = '4';
$id('startdy1').value = '-3';
$id('startdy2').value = '3';
$id('nstartdy').value = '4';
fireEvent($id('graphtype'), 'change');
});
$id('sample_cauchy').listen('click', () => {
$id('graphtype').value = 'order2';
$id('exp1').value = "((-2.75 / t) * y') + ((-4.8 / (t^2)) * y)";
$id('starttime').value = '2';
$id('endtime').value = '35';
$id('timestep').value = '0.05';
$id('t0').value = '2';
$id('starty1').value = '-1.5';
$id('starty2').value = '1.5';
$id('nstarty').value = '4';
$id('startdy1').value = '-2';
$id('startdy2').value = '2';
$id('nstartdy').value = '3';
fireEvent($id('graphtype'), 'change');
});
$id('sample_meissner').listen('click', () => {
$id('graphtype').value = 'order2';
$id('exp1').value = "-y*(2 + (3 * sgn(cos(t))))";
$id('starttime').value = '-10';
$id('endtime').value = '10';
$id('timestep').value = '0.01';
$id('t0').value = '0';
$id('starty1').value = '-3';
$id('starty2').value = '3';
$id('nstarty').value = '5';
$id('startdy1').value = '-3';
$id('startdy2').value = '3';
$id('nstartdy').value = '5';
fireEvent($id('graphtype'), 'change');
});
$id('sample_predprey').listen('click', () => {
$id('graphtype').value = 'eqs2_order1';
$id('exp1').value = 'y-(0.01*y*z)';
$id('exp2').value = '-z+(0.02*y*z)';
$id('starttime').value = '0';
$id('endtime').value = '20';
$id('timestep').value = '0.05';
$id('t0').value = '0';
$id('starty1').value = '10';
$id('starty2').value = '20';
$id('nstarty').value = '3';
$id('startz1').value = '10';
$id('startz2').value = '20';
$id('nstartz').value = '3';
fireEvent($id('graphtype'), 'change');
});
redraw();
</script>
<hr>
<p>Classic Runge-Kutta, also known as “RK4”, computes four different estimates of the rate of change of each variable at each time step, and takes a weighted average of those four estimates as a final estimate which is (usually) more accurate than any of the four. We then use that final estimate to advance the value of each variable, bump “time” forward by one step, and repeat until we reach the ending time of the simulation. Here is a simple implementation of RK4 for a system with just one dependent variable:</p>
<figure class="highlight"><pre><code class="language-javascript" data-lang="javascript"><span class="c1">// Trace out evolution of our system using classic Runge-Kutta (AKA "RK4")</span>
<span class="c1">// Store results in a packed array of floats</span>
<span class="kd">function</span> <span class="nx">rk4trace</span><span class="p">(</span><span class="nx">y</span><span class="p">,</span> <span class="nx">t</span><span class="p">,</span> <span class="nx">Δt</span><span class="p">,</span> <span class="nx">fn</span><span class="p">,</span> <span class="nx">array</span><span class="p">,</span> <span class="nx">i</span><span class="p">,</span> <span class="nx">Δi</span><span class="p">,</span> <span class="nx">limit</span><span class="p">)</span> <span class="p">{</span>
<span class="k">while</span> <span class="p">(</span><span class="nx">i</span> <span class="o">!==</span> <span class="nx">limit</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">half_Δt</span> <span class="o">=</span> <span class="nx">Δt</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">;</span>
<span class="kd">const</span> <span class="nx">next_t</span> <span class="o">=</span> <span class="nx">t</span> <span class="o">+</span> <span class="nx">Δt</span><span class="p">;</span>
<span class="kd">const</span> <span class="nx">half_t</span> <span class="o">=</span> <span class="nx">t</span> <span class="o">+</span> <span class="nx">half_Δt</span><span class="p">;</span>
<span class="kd">const</span> <span class="nx">k_1</span> <span class="o">=</span> <span class="nx">fn</span><span class="p">(</span><span class="nx">t</span><span class="p">,</span> <span class="nx">y</span><span class="p">);</span> <span class="c1">// Slope at starting point</span>
<span class="kd">const</span> <span class="nx">k_2</span> <span class="o">=</span> <span class="nx">fn</span><span class="p">(</span><span class="nx">half_t</span><span class="p">,</span> <span class="nx">y</span> <span class="o">+</span> <span class="p">(</span><span class="nx">half_Δt</span> <span class="o">*</span> <span class="nx">k_1</span><span class="p">));</span> <span class="c1">// Estimated slope at mid-point</span>
<span class="kd">const</span> <span class="nx">k_3</span> <span class="o">=</span> <span class="nx">fn</span><span class="p">(</span><span class="nx">half_t</span><span class="p">,</span> <span class="nx">y</span> <span class="o">+</span> <span class="p">(</span><span class="nx">half_Δt</span> <span class="o">*</span> <span class="nx">k_2</span><span class="p">));</span> <span class="c1">// Another estimate of slope at mid-point</span>
<span class="kd">const</span> <span class="nx">k_4</span> <span class="o">=</span> <span class="nx">fn</span><span class="p">(</span><span class="nx">next_t</span><span class="p">,</span> <span class="nx">y</span> <span class="o">+</span> <span class="p">(</span><span class="nx">Δt</span> <span class="o">*</span> <span class="nx">k_3</span><span class="p">));</span> <span class="c1">// Estimated slope at endpoint</span>
<span class="kd">const</span> <span class="nx">slope</span> <span class="o">=</span> <span class="p">(</span><span class="nx">k_1</span> <span class="o">+</span> <span class="mi">2</span><span class="o">*</span><span class="nx">k_2</span> <span class="o">+</span> <span class="mi">2</span><span class="o">*</span><span class="nx">k_3</span> <span class="o">+</span> <span class="nx">k_4</span><span class="p">)</span> <span class="o">/</span> <span class="mf">6.0</span><span class="p">;</span> <span class="c1">// Weighted average of those four slopes</span>
<span class="nx">y</span> <span class="o">+=</span> <span class="nx">Δt</span> <span class="o">*</span> <span class="nx">slope</span><span class="p">;</span>
<span class="nx">array</span><span class="p">[</span><span class="nx">i</span><span class="p">]</span> <span class="o">=</span> <span class="nx">y</span><span class="p">;</span>
<span class="nx">t</span> <span class="o">+=</span> <span class="nx">Δt</span><span class="p">;</span>
<span class="nx">i</span> <span class="o">+=</span> <span class="nx">Δi</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">}</span>
<span class="c1">// Apply RK4 to find phase lines for a system with one dependent variable</span>
<span class="kd">function</span> <span class="nx">rk4solve</span><span class="p">(</span><span class="nx">y_0</span><span class="p">,</span> <span class="nx">t_0</span><span class="p">,</span> <span class="nx">t_start</span><span class="p">,</span> <span class="nx">t_end</span><span class="p">,</span> <span class="nx">Δt</span><span class="p">,</span> <span class="nx">fn</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">timeSteps</span> <span class="o">=</span> <span class="nb">Math</span><span class="p">.</span><span class="nx">floor</span><span class="p">(((</span><span class="nx">t_end</span> <span class="o">-</span> <span class="nx">t_start</span><span class="p">)</span> <span class="o">/</span> <span class="nx">Δt</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">);</span>
<span class="c1">// Packed array of variable values at each time step:</span>
<span class="kd">const</span> <span class="nx">array</span> <span class="o">=</span> <span class="k">new</span> <span class="nb">Float64Array</span><span class="p">(</span><span class="nx">timeSteps</span><span class="p">);</span>
<span class="kd">const</span> <span class="nx">i</span> <span class="o">=</span> <span class="nb">Math</span><span class="p">.</span><span class="nx">floor</span><span class="p">((</span><span class="nx">t_0</span> <span class="o">-</span> <span class="nx">t_start</span><span class="p">)</span> <span class="o">/</span> <span class="nx">Δt</span><span class="p">);</span>
<span class="nx">array</span><span class="p">[</span><span class="nx">i</span><span class="p">]</span> <span class="o">=</span> <span class="nx">y_0</span><span class="p">;</span>
<span class="c1">// Trace out phase line from starting point</span>
<span class="nx">rk4trace</span><span class="p">(</span><span class="nx">y_0</span><span class="p">,</span> <span class="nx">t_0</span><span class="p">,</span> <span class="nx">Δt</span><span class="p">,</span> <span class="nx">fn</span><span class="p">,</span> <span class="nx">array</span><span class="p">,</span> <span class="nx">i</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="nx">array</span><span class="p">.</span><span class="nx">length</span><span class="p">);</span>
<span class="c1">// Trace out phase line in the opposite direction from the starting point</span>
<span class="nx">rk4trace</span><span class="p">(</span><span class="nx">y_0</span><span class="p">,</span> <span class="nx">t_0</span><span class="p">,</span> <span class="o">-</span><span class="nx">Δt</span><span class="p">,</span> <span class="nx">fn</span><span class="p">,</span> <span class="nx">array</span><span class="p">,</span> <span class="nx">i</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">);</span>
<span class="k">return</span> <span class="nx">array</span><span class="p">;</span>
<span class="p">}</span></code></pre></figure>
<p>The above implementation stores the values of the dependent variable in a <a href='https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Float64Array'><code>Float64Array</code></a> instead of a regular JavaScript Array; this is for speed and memory efficiency. It doesn't store the time value for each entry in the array, since that can be easily rederived from <code>t_start</code>, <code>t_end</code>, and <code>Δt</code>.</p>
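<p>As a rough illustration of that last point, the sketch below pairs each stored value with a recomputed time. It assumes that entry <code>i</code> corresponds to <code>t_start + i * Δt</code>, which is exact only when <code>t_0 - t_start</code> is a whole multiple of <code>Δt</code>. The helpers <code>timeAt</code> and <code>rk4step</code> are hypothetical names introduced for this example; they are not part of the code above. As a sanity check, it solves y' = y and compares the result against e<sup>t</sup>.</p>

```javascript
// Hypothetical helper: time value for entry i, assuming entry 0 sits at t_start
function timeAt(i, t_start, Δt) {
  return t_start + (i * Δt);
}

// One classic RK4 step, written out separately for illustration
function rk4step(fn, t, y, Δt) {
  const k_1 = fn(t, y);
  const k_2 = fn(t + Δt / 2, y + (Δt / 2) * k_1);
  const k_3 = fn(t + Δt / 2, y + (Δt / 2) * k_2);
  const k_4 = fn(t + Δt, y + Δt * k_3);
  return y + Δt * (k_1 + 2*k_2 + 2*k_3 + k_4) / 6;
}

// Solve y' = y from t = 0 to t = 1 with y(0) = 1; the exact solution is e^t
const t_start = 0, Δt = 0.01, steps = 100;
const values = new Float64Array(steps + 1);
values[0] = 1;
for (let i = 0; i !== steps; i++) {
  values[i + 1] = rk4step((t, y) => y, timeAt(i, t_start, Δt), values[i], Δt);
}

// Pair each stored value with its rederived time
const points = Array.from(values, (y, i) => [timeAt(i, t_start, Δt), y]);
const [t_last, y_last] = points[points.length - 1];
const error = Math.abs(y_last - Math.exp(t_last));
console.log(t_last, y_last, error);
```

<p>With this step size, the difference from <code>Math.exp</code> at the last point should be tiny, which is a quick way to convince yourself that the weights in the slope average are right.</p>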
<hr>
<h2 id='syntax'>Expression Syntax Help for this Tool</h2>
<table>
<tr><td>Variables</td><td><code>t</code> <code>y</code><br><code>y'</code> <i>(for 2<sup>nd</sup>-order equations)</i><br><code>z</code> <i>(for systems of two 1<sup>st</sup>-order equations)</i><br><code>x</code> <i>(alternative name for <code>t</code>)</i></td></tr>
<tr>
<td>Numbers</td>
<td>
<code>1</code>, <code>2</code>, <code>-1</code>, etc<br>
<code>1.1234</code>...<br>
<code>e</code> <i>(Euler's number, 2.71828...)</i><br>
<code>pi</code> or <code>π</code> <i>(3.14159...)</i><br>
If you want any other mathematical constants, drop the author a line.
</td>
</tr>
<tr><td>Arithmetic</td><td><code><i>expression</i> + <i>expression</i></code><br>Other binary operators are <code>-</code>, <code>*</code>, <code>/</code>, and <code>^</code> or <code>**</code> for exponentiation<br><code>-<i>expression</i></code> for negation</td></tr>
<tr><td>Parentheses</td><td><code>(<i>expression</i>)</code><br>Use parentheses to ensure expressions are grouped in the way you want</td></tr>
<tr>
<td>Functions</td>
<td>
<code>sqrt(<i>expression</i>)</code><br>
<code>sin(<i>expression</i>)</code> <i>(for trig functions, the input parameter is in radians)</i><br>
<code>cos(<i>expression</i>)</code><br>
<code>tan(<i>expression</i>)</code><br>
<code>arcsin(<i>expression</i>)</code> <i>(inverse trig functions)</i><br>
<code>arccos(<i>expression</i>)</code><br>
<code>arctan(<i>expression</i>)</code><br>
<code>ln(<i>expression</i>)</code> or <code>log(<i>expression</i>)</code> <i>(natural logarithm)</i><br>
<code>log10(<i>expression</i>)</code><br>
<code>log2(<i>expression</i>)</code><br>
<code>sign(<i>expression</i>)</code> or <code>sgn(<i>expression</i>)</code> <i>(0 for zero, 1 for positive numbers, -1 for negative numbers)</i><br>
If you want any other mathematical functions, drop the author a line.
</td>
</tr>
</table>
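<p>To make the table a bit more concrete, here is how two of the sample expressions might be written by hand as plain JavaScript functions. This is only an illustrative translation: it does not show how the tool's own parser and compiler work, and the function names and argument order are made up for this example.</p>

```javascript
// Tool syntax:  -((y^2) * (t^2)) + ((2 * y) / t)     (the Riccati sample)
// In the tool, ^ means exponentiation; in JavaScript that is ** (JS ^ is bitwise XOR)
const riccati = (t, y) => -((y ** 2) * (t ** 2)) + ((2 * y) / t);

// Tool syntax:  -y*(2 + (3 * sgn(cos(t))))           (the Meissner sample)
// sgn/sign corresponds to Math.sign: 0 for zero, 1 for positive, -1 for negative
const meissner = (t, y) => -y * (2 + (3 * Math.sign(Math.cos(t))));

// ln(...) and log(...) in the tool both mean the natural logarithm, i.e. Math.log
const naturalLog = (x) => Math.log(x);

console.log(riccati(1, 2));      // -(4 * 1) + (4 / 1)
console.log(meissner(0, 1));     // -1 * (2 + (3 * 1))
console.log(naturalLog(Math.E));
```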
<p><i><b>Special thanks</b> to Gilbert Strang for his text “Differential Equations and Linear Algebra”, which inspired me to make this tool.</i></p>
Visualizing Nelder-Mead Optimization2022-06-13T00:00:00+00:002022-06-13T00:00:00+00:00/visualizing-nelder-mead<p>Recently, I ran across a fantastic article, <a href='https://www.jmeiners.com/why-train-when-you-can-optimize/'>“Why Train When You Can Optimize?”</a>, which introduced me to the Nelder-Mead optimization algorithm. 
It's a lovely algorithm, and I couldn't wait to create an interactive version. First, though: what does <b>“optimization”</b> mean in this context?</p>
<p>An <b>“optimization”</b> algorithm takes some mathematical function as its input, and tries to find values for the parameters which make the output either as large or as small as possible. If you are like many computer programmers, your first impression might be that you are unlikely to ever use such an algorithm in your own programs. But optimization is a much more general and useful technique than it might seem. The <a href='https://www.jmeiners.com/why-train-when-you-can-optimize/'>article mentioned above</a> gives a great example: a drawing program which detects when the user is trying to draw a straight line and replaces their jittery line with a perfectly straight one. <i>(If you know other examples of good uses for optimization outside science and engineering, please let me know!)</i></p>
<p>Obviously, finding inputs for some function which give you the largest or smallest output value can be done without a special algorithm. You could just use brute force: test many inputs and pick the best one. But if the space of possible inputs is large, that could be too slow.</p>
<p>Optimization algorithms typically avoid exhaustively searching the input space by starting at some arbitrary point, then repeatedly searching for a nearby point which is better, until they hit a maximum or minimum and can't find any better point.</p>
<p>Many such algorithms require that you know how to calculate the derivative (slope) of the function at any given point. Nelder-Mead doesn't need any derivatives; that, combined with its general simplicity, makes it easy to apply.</p>
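<p>For contrast, here is a rough sketch of a derivative-based method (plain gradient descent, which is a different algorithm from Nelder-Mead). The names <code>gradientDescent</code> and <code>bowlGradient</code> are invented for this example. Notice that the caller has to supply a function which computes the slope in each direction; that is exactly the requirement which Nelder-Mead avoids.</p>

```javascript
// Hypothetical helper for illustration; not part of this page's code.
// 'gradient' must return the slope of the function along x and along y.
function gradientDescent(gradient, start, learningRate, iterations) {
  let [x, y] = start;
  for (let i = 0; i < iterations; i++) {
    const [dx, dy] = gradient(x, y);
    // Step downhill, in proportion to how steep the slope is
    x -= learningRate * dx;
    y -= learningRate * dy;
  }
  return [x, y];
}

// For f(x, y) = (x - 50)^2 + (y - 50)^2, the derivatives worked out by hand
// are 2(x - 50) along x and 2(y - 50) along y
const bowlGradient = (x, y) => [2 * (x - 50), 2 * (y - 50)];
const [minX, minY] = gradientDescent(bowlGradient, [10, 80], 0.1, 100);
console.log(minX, minY); // Both very close to 50
```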
<p>Now let me show you how Nelder-Mead works. Rather than starting with one test point and iteratively improving it, Nelder-Mead starts with <b><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.7667em; vertical-align: -0.0833em;"></span><span class="mord mathnormal" style="margin-right: 0.109em;">N</span><span class="mspace" style="margin-right: 0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.2222em;"></span></span><span class="base"><span class="strut" style="height: 0.6444em;"></span><span class="mord">1</span></span></span></span></b> test points, when the input space has <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.6833em;"></span><span class="mord mathnormal" style="margin-right: 0.109em;">N</span></span></span></span> dimensions. (Or, in other words, when there are <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.6833em;"></span><span class="mord mathnormal" style="margin-right: 0.109em;">N</span></span></span></span> different input variables whose values need to be found.) For example, if your function has two parameters, there will be 3 starting points, which will form a triangle in the 2-dimensional plane of possible inputs:</p>
<div id='example-triangle' class='plot'></div>
<script src='/assets/js/d3-array.js'></script>
<script src='/assets/js/d3-color.js'></script>
<script src='/assets/js/d3-interpolate.js'></script>
<script src='/assets/js/d3-scale.js'></script>
<script src='/assets/js/d3-contour.js'></script>
<script src='/assets/js/d3-geo.js'></script>
<script>
'use strict';
const $id = document.getElementById.bind(document);
Element.prototype.$class = Element.prototype.getElementsByClassName;
Element.prototype.$class1 = function(cssClass) { return this.$class(cssClass)?.[0]; }
Element.prototype.$tag = Element.prototype.getElementsByTagName;
Element.prototype.listen = Element.prototype.addEventListener;
function makeSvgNode(tag, cssClass, attributes = {}) {
const node = document.createElementNS("http://www.w3.org/2000/svg", tag);
if (cssClass)
node.classList.add(cssClass);
for (const [key, value] of Object.entries(attributes))
node.setAttribute(key, value);
return node;
}
function appendSvgNode(parent, tag, cssClass, attributes = {}) {
const node = makeSvgNode(tag, cssClass, attributes);
parent.appendChild(node);
return node;
}
function prependSvgNode(parent, tag, cssClass, attributes = {}) {
const node = makeSvgNode(tag, cssClass, attributes);
parent.prepend(node);
return node;
}
// When working with functions defined over a 2D space, we express all
// coordinates in that space as ranging from 0-100 along both axes
function Point(x, y) {
this.x = x;
this.y = y;
}
Point.prototype.midpoint = function(other) {
return new Point((this.x + other.x) / 2, (this.y + other.y) / 2);
}
Point.prototype.moveTowards = function(other, factor) {
return new Point(
this.x + (other.x - this.x) * factor,
this.y + (other.y - this.y) * factor);
}
Point.prototype.dist = function(other) {
return Math.sqrt((this.x - other.x) ** 2 + (this.y - other.y) ** 2);
}
Point.prototype.eval = function(fn) {
return fn(this.x, this.y);
}
// Test functions
// All these functions should (1) always be greater than zero, and (2)
// increase as one moves outside the 0-100 range which we use for x,y
// coordinates (so the algorithm doesn't run off outside the area which
// we will plot)
function himmelblauFunction(x, y) {
// Map 0→100 range to -5→5
x = (x / 10) - 5;
y = (y / 10) - 5;
const a = x*x + y - 11;
const b = x + y*y - 7;
return 1 + a*a + b*b;
}
himmelblauFunction.scale = d3.scalePow().exponent(2);
himmelblauFunction.minima = [[80,70], [21.94882, 81.31312], [12.206899, 17.16814], [85.84428, 31.51874]];
function bealeFunction(x, y) {
// Map 0→100 range to -5→5
x = (x / 10) - 5;
y = (y / 10) - 5;
const a = (1.5 - x + x*y);
const b = (2.25 - x + x*y*y);
const c = (2.625 - x + x*y*y*y);
return 1 + a*a + b*b + c*c;
}
bealeFunction.scale = d3.scalePow().exponent(3.8);
bealeFunction.minima = [[80, 55]];
function rosenbrockFunction(x, y) {
// Map 0→100 range to -3→3
x = (x * 0.06) - 3;
y = (y * 0.06) - 3;
const a = (1 - x);
const b = (y - x*x);
return 1 + a*a + 100*b*b;
}
rosenbrockFunction.scale = d3.scalePow().exponent(3.5);
rosenbrockFunction.minima = [[200/3, 200/3]];
function squaredDistFunction(x, y) {
const a = Math.abs(x - 50);
const b = Math.abs(y - 50);
return 1 + a*a + b*b;
}
squaredDistFunction.scale = d3.scalePow().exponent(2);
squaredDistFunction.minima = [[50, 50]];
function goldsteinPriceFunction(x, y) {
x = (x / 25) - 2;
y = (y / 25) - 2;
const a = 1 + ((x + y + 1) ** 2) * (19 - 14*x + (3 * x**2) - 14*y + 6*x*y + (3 * y**2));
const b = 30 + ((2*x - 3*y) ** 2) * (18 - 32*x + 12*x*x + 48*y - 36*x*y + (27 * y**2));
return a * b;
}
goldsteinPriceFunction.scale = d3.scalePow().exponent(4);
goldsteinPriceFunction.minima = [[50, 25]];
function bukinFunction(x, y) {
x = (x / 10) - 15;
y = (y * 0.06) - 3;
return 100*Math.sqrt(Math.abs(y - (0.01 * x**2))) + 0.01*Math.abs(x + 10);
}
bukinFunction.scale = d3.scalePow().exponent(2);
bukinFunction.minima = [[50, 200/3]];
function threeHumpCamel(x, y) {
x = (x / 10) - 5;
y = (y / 10) - 5;
return (2 * x**2) - (1.05 * x**4) + (x**6 / 6) + x*y + y*y;
}
threeHumpCamel.scale = d3.scalePow().exponent(4);
threeHumpCamel.minima = [[50, 50]];
threeHumpCamel.localMinima = [[32.7, 58.1], [67.3, 41.9]];
function mccormickFunction(x, y) {
// The McCormick function can go very negative if y goes in the negative
// direction, and x is not too far off from y
// Add a small factor so our optimization algorithm doesn't "run away"
// in that direction
let correction = 0;
if (y < 0)
correction = -4 * y;
x = (x * 0.055) - 1.5;
y = (y * 0.07) - 3;
return Math.sin(x + y) + ((x - y) ** 2) - 1.5*x + 2.5*y + 1 + correction;
}
mccormickFunction.scale = d3.scalePow().exponent(2);
mccormickFunction.minima = [[17.323818,20.754428]];
mccormickFunction.localMinima = [[74.2,65.6]];
const testFunctions = [
himmelblauFunction,
bealeFunction,
rosenbrockFunction,
squaredDistFunction,
goldsteinPriceFunction,
bukinFunction,
threeHumpCamel,
mccormickFunction
];
// Evaluate 'fn' at every point on a 'steps' by 'steps' grid
function gridFromFunction(fn, steps) {
const values = [], stepSize = 100 / steps;
for (let y = stepSize / 2; y < 100; y += stepSize)
for (let x = stepSize / 2; x < 100; x += stepSize)
values.push(fn(x, y));
return values;
}
function contoursFromGrid(grid, xySteps, zScale, zSteps) {
zScale = zScale.domain([0, zSteps-1]).range(d3.extent(grid));
const thresholds = Array(zSteps).fill(0).map((_, i) => zScale(i));
thresholds.unshift(0); // Ensure the lowest contour will go right around the plotted area
return d3.contours().size([xySteps, xySteps]).thresholds(thresholds)(grid);
}
function drawContours(svg, contours) {
for (const path of Array.from(svg.$class('contour')))
path.remove();
const colorScale = d3.scaleLinear().domain([0, contours.length-1]).range(['#66ff99', '#dc143c']);
const geoPath = d3.geoPath(d3.geoIdentity().fitSize([100, 100], contours[0]));
contours.forEach((contour, i) => {
if (contour.coordinates.length > 0) {
const path = prependSvgNode(svg, 'path', 'contour');
path.setAttribute('d', geoPath(contour));
path.setAttribute('stroke', colorScale(i));
path.setAttribute('fill', 'none');
path.setAttribute('stroke-width', '.03rem');
}
});
}
function plotFunction(fn, svg) {
const grid = gridFromFunction(fn, 100);
const contours = contoursFromGrid(grid, 100, fn.scale ?? d3.scalePow().exponent(2.5), 23);
drawContours(svg, contours);
}
// Make SVG which will resize itself to match its container
function initDynamicSvg(wrapper) {
// Put this SVG underneath anything else in the same container
const svg = prependSvgNode(wrapper, 'svg', undefined, { viewBox: '0 0 100 100', preserveAspectRatio: 'none' });
const watcher = new ResizeObserver(entries => {
svg.style.width = Math.round(entries[0].contentBoxSize[0].inlineSize);
svg.style.height = Math.round(entries[0].contentBoxSize[0].blockSize);
});
watcher.observe(wrapper);
svg.style.width = wrapper.offsetWidth;
svg.style.height = wrapper.offsetHeight;
return svg;
}
function moveLineToPoints(line, point1, point2) {
line.setAttribute('x1', point1.x);
line.setAttribute('y1', point1.y);
line.setAttribute('x2', point2.x);
line.setAttribute('y2', point2.y);
}
function moveCircleToPoint(circle, point) {
circle.setAttribute('cx', point.x);
circle.setAttribute('cy', point.y);
}
const pointColor = '#ddcc44';
const radius = '.15rem';
function showTrianglePoint(svg, point, cssClass) {
const circle = svg.$class1(cssClass) || appendSvgNode(svg, 'circle', cssClass, { r: radius, fill: pointColor });
moveCircleToPoint(circle, point);
return circle;
}
// A triangle is an array of 3 points
// Take mathematical points as input, return corresponding SVG elements
function showTriangle(svg, triangle) {
const circles = [];
if (triangle.length >= 1)
circles.push(showTrianglePoint(svg, triangle[0], 'point1'));
if (triangle.length >= 2)
circles.push(showTrianglePoint(svg, triangle[1], 'point2'));
if (triangle.length >= 3) {
circles.push(showTrianglePoint(svg, triangle[2], 'point3'));
showTriangleLines(svg, triangle);
}
return circles;
}
function showTriangleLine(svg, cssClass, point1, point2) {
let line = svg.$class1(cssClass);
if (!line) {
line = makeSvgNode('line', cssClass);
if (svg.$class1('point1'))
svg.insertBefore(line, svg.$class1('point1'));
else
svg.appendChild(line);
}
moveLineToPoints(line, point1, point2);
}
function showTriangleLines(svg, triangle) {
showTriangleLine(svg, 'line1', triangle[0], triangle[1]);
showTriangleLine(svg, 'line2', triangle[1], triangle[2]);
showTriangleLine(svg, 'line3', triangle[2], triangle[0]);
}
function getMousePosition(event, svg) {
const CTM = svg.getScreenCTM();
if (event.touches)
event = event.touches[0];
return [
(event.clientX - CTM.e) / CTM.a,
(event.clientY - CTM.f) / CTM.d
];
}
function makeCircleDraggable(svg, circle, callback) {
let dragging = false;
let startX, startY, circleX, circleY;
function startDrag(event) {
dragging = true;
[startX, startY] = getMousePosition(event, svg);
circleX = Number(circle.getAttribute('cx'));
circleY = Number(circle.getAttribute('cy'));
return false;
}
function touchStartDrag(event) {
if (!dragging) {
const [touchX, touchY] = getMousePosition(event, svg);
circleX = Number(circle.getAttribute('cx'));
circleY = Number(circle.getAttribute('cy'));
const dist = (new Point(touchX, touchY)).dist(new Point(circleX, circleY));
if (dist < 5.5)
startDrag(event);
}
}
function drag(event) {
if (dragging) {
const [dragX, dragY] = getMousePosition(event, svg);
const newPoint = new Point(
Math.round(circleX + dragX - startX),
Math.round(circleY + dragY - startY));
moveCircleToPoint(circle, newPoint);
if (callback)
callback(newPoint);
event.preventDefault(); // Suppress scrolling of page on mobile
return false;
}
}
function endDrag(event) {
if (dragging) {
dragging = false;
}
}
circle.listen('mousedown', startDrag);
svg.listen('mousemove', drag);
svg.listen('mouseup', endDrag);
svg.listen('mouseleave', endDrag);
svg.listen('touchstart', touchStartDrag, { passive: false });
svg.listen('touchmove', drag, { passive: false });
svg.listen('touchend', endDrag);
svg.listen('touchleave', endDrag);
svg.listen('touchcancel', endDrag);
circle.classList.add('draggable');
}
function showDraggableTriangle(svg, triangle, callback) {
function dragCallback(index) {
return point => {
triangle = Array.from(triangle);
triangle[index] = point;
showTriangleLines(svg, triangle);
if (callback)
callback(triangle);
}
}
const circles = showTriangle(svg, triangle);
makeCircleDraggable(svg, circles[0], dragCallback(0));
makeCircleDraggable(svg, circles[1], dragCallback(1));
makeCircleDraggable(svg, circles[2], dragCallback(2));
if (callback)
callback(triangle);
}
const svg1 = initDynamicSvg($id('example-triangle'));
plotFunction(himmelblauFunction, svg1);
showDraggableTriangle(svg1, [
new Point(30, 20),
new Point(40, 75),
new Point(75, 55)
]);
</script>
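<p>The plot above uses three hand-picked points. If you only have a single starting guess, one simple way to get the required N + 1 points is to take the guess itself, plus one extra point nudged along each axis. The sketch below assumes points are plain arrays of coordinates; <code>initialSimplex</code> and its <code>offset</code> parameter are hypothetical names, not something this page's code uses.</p>

```javascript
// Hypothetical helper for illustration; not part of this page's code.
// Build N + 1 starting points from one N-dimensional guess: the guess itself,
// plus one copy of it shifted by 'offset' along each coordinate axis.
function initialSimplex(guess, offset = 1) {
  const simplex = [guess.slice()];
  for (let i = 0; i < guess.length; i++) {
    const vertex = guess.slice();
    vertex[i] += offset;
    simplex.push(vertex);
  }
  return simplex;
}

// Two input variables give three points: a triangle
const startingTriangle = initialSimplex([30, 20], 10);
console.log(startingTriangle);

// Three input variables give four points: a tetrahedron
console.log(initialSimplex([1, 2, 3]).length);
```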
<p>(From here on, all examples will be 2-dimensional, but the algorithm generalizes naturally to any number of dimensions. Further, we will assume that we are searching for a minimum rather than a maximum.<sup><a href='#footnote1' id='fnref1'>[1]</a></sup>)</p>
<p>Nelder-Mead repeatedly transforms the triangle of test points, replacing the worst point with a better one. This causes the triangle to move across the plane in whichever direction the function's value is dropping, and then contract around a local minimum when it finds one. Once the triangle becomes small enough, the algorithm terminates. Like this:</p>
<div id='example-moving-triangle' class='plot'></div>
<script>
'use strict';
const defaultCoefficients = Object.freeze({
reflect: 1,
expand: 2,
contract: 0.5,
shrink: 0.5
});
function readonly(object, property, getter) {
Object.defineProperty(object, property, {
get: getter
});
}
function Demo(fn) {
this.fn = fn;
this.values = [];
this.indices = []; // worst, second, best
this.coefficients = defaultCoefficients;
this.previousTriangle = undefined;
let triangle = [];
Object.defineProperty(this, 'triangle', {
get: () => triangle,
set: (t) => {
this.previousTriangle = triangle;
triangle = Object.freeze(t);
this.values = Object.freeze(t.map(point => this.fn(point.x, point.y)));
this.indices = Object.freeze([0, 1, 2].sort((a,b) => this.values[b] - this.values[a]));
}
});
readonly(this, 'worstPoint', () => this.triangle[this.indices[0]]);
readonly(this, 'secondPoint', () => this.triangle[this.indices[1]]);
readonly(this, 'bestPoint', () => this.triangle[this.indices[2]]);
readonly(this, 'worstValue', () => this.values[this.indices[0]]);
readonly(this, 'secondValue', () => this.values[this.indices[1]]);
readonly(this, 'bestValue', () => this.values[this.indices[2]]);
readonly(this, 'worstIndex', () => this.indices[0]);
readonly(this, 'secondIndex', () => this.indices[1]);
readonly(this, 'bestIndex', () => this.indices[2]);
readonly(this, 'midpoint', () => this.secondPoint.midpoint(this.bestPoint));
readonly(this, 'reflectPoint', () => this.worstPoint.moveTowards(this.midpoint, 1 + this.coefficients.reflect));
readonly(this, 'expandPoint', () => this.worstPoint.moveTowards(this.midpoint, 1 + this.coefficients.expand));
readonly(this, 'insidePoint', () => this.worstPoint.moveTowards(this.midpoint, 1 - this.coefficients.contract));
readonly(this, 'outsidePoint', () => this.worstPoint.moveTowards(this.midpoint, 1 + this.coefficients.contract));
readonly(this, 'reflectValue', () => this.reflectPoint.eval(this.fn));
readonly(this, 'expandValue', () => this.expandPoint.eval(this.fn));
readonly(this, 'insideValue', () => this.insidePoint.eval(this.fn));
readonly(this, 'outsideValue', () => this.outsidePoint.eval(this.fn));
readonly(this, 'shrunkTriangle', () => this.triangle.map((point,i) => (i === this.bestIndex) ? point : point.moveTowards(this.bestPoint, this.coefficients.shrink)));
readonly(this, 'shrinkPoint1', () => this.worstPoint.moveTowards(this.bestPoint, this.coefficients.shrink));
readonly(this, 'shrinkPoint2', () => this.secondPoint.moveTowards(this.bestPoint, this.coefficients.shrink));
}
Demo.prototype.doReflect = function() {
const newTriangle = Array.from(this.triangle);
newTriangle[this.worstIndex] = this.reflectPoint;
this.triangle = newTriangle;
}
Demo.prototype.doExpand = function() {
const newTriangle = Array.from(this.triangle);
newTriangle[this.worstIndex] = this.expandPoint;
this.triangle = newTriangle;
}
Demo.prototype.doContractInside = function() {
const newTriangle = Array.from(this.triangle);
newTriangle[this.worstIndex] = this.insidePoint;
this.triangle = newTriangle;
}
Demo.prototype.doContractOutside = function() {
const newTriangle = Array.from(this.triangle);
newTriangle[this.worstIndex] = this.outsidePoint;
this.triangle = newTriangle;
}
Demo.prototype.doShrink = function() {
this.triangle = this.shrunkTriangle;
}
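// Decide which transformation to apply, based on how the reflected point
// compares with the three current points (lower values are better)
Demo.prototype.doNelderMead = function() {
if (this.reflectValue < this.bestValue) {
// Reflected point would be the new best; see if going even further is better still
if (this.expandValue < this.reflectValue)
this.doExpand();
else
this.doReflect();
} else if (this.reflectValue < this.secondValue) {
// Reflected point beats the middle point; keep it
this.doReflect();
} else if (this.reflectValue < this.worstValue) {
// Reflected point only beats the worst point; try stopping just past the midpoint
if (this.outsideValue < this.worstValue)
this.doContractOutside();
else
this.doShrink();
} else if (this.insideValue < this.worstValue) {
// Reflected point is no improvement; try stopping short of the midpoint
this.doContractInside();
} else {
// Nothing helped; pull the whole triangle towards the best point
this.doShrink();
}
}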
Demo.prototype.drawContours = function(svg) {
plotFunction(this.fn, svg);
}
Demo.prototype.showTriangle = function(svg) {
showTriangle(svg, this.triangle);
}
function clearAnimations(svg) {
for (const anim of Array.from(svg.$tag('animate')))
anim.remove();
}
// line1 must start with this circle, line2 must end with this circle
function slideCircleToPoint(circle, oldPoint, point, line1, line2) {
moveCircleToPoint(circle, oldPoint);
line1.setAttribute('x1', oldPoint.x);
line1.setAttribute('y1', oldPoint.y);
line2.setAttribute('x2', oldPoint.x);
line2.setAttribute('y2', oldPoint.y);
clearAnimations(circle);
// Move at a constant speed most of the time, but make sure we get there
// in time even if it means speeding up
const duration = Math.min((oldPoint.dist(point) / 32), 1.25) + 's';
if (point.x !== oldPoint.x) {
appendSvgNode(circle, 'animate', undefined, { attributeName: 'cx', dur: duration, from: oldPoint.x, to: point.x, fill: 'freeze' });
appendSvgNode(line1, 'animate', undefined, { attributeName: 'x1', dur: duration, from: oldPoint.x, to: point.x, fill: 'freeze' });
appendSvgNode(line2, 'animate', undefined, { attributeName: 'x2', dur: duration, from: oldPoint.x, to: point.x, fill: 'freeze' });
}
if (point.y !== oldPoint.y) {
appendSvgNode(circle, 'animate', undefined, { attributeName: 'cy', dur: duration, from: oldPoint.y, to: point.y, fill: 'freeze' });
appendSvgNode(line1, 'animate', undefined, { attributeName: 'y1', dur: duration, from: oldPoint.y, to: point.y, fill: 'freeze' });
appendSvgNode(line2, 'animate', undefined, { attributeName: 'y2', dur: duration, from: oldPoint.y, to: point.y, fill: 'freeze' });
}
for (const anim of circle.$tag('animate'))
anim.beginElement();
for (const anim of line1.$tag('animate'))
anim.beginElement();
for (const anim of line2.$tag('animate'))
anim.beginElement();
}
function slideTriangle(svg, oldTriangle, triangle) {
clearAnimations(svg.$class1('line1'));
clearAnimations(svg.$class1('line2'));
clearAnimations(svg.$class1('line3'));
slideCircleToPoint(svg.$class1('point1'), oldTriangle[0], triangle[0], svg.$class1('line1'), svg.$class1('line3'));
slideCircleToPoint(svg.$class1('point2'), oldTriangle[1], triangle[1], svg.$class1('line2'), svg.$class1('line1'));
slideCircleToPoint(svg.$class1('point3'), oldTriangle[2], triangle[2], svg.$class1('line3'), svg.$class1('line2'));
}
Demo.prototype.slideTriangle = function(svg) {
slideTriangle(svg, this.previousTriangle, this.triangle);
}
Demo.prototype.showColoredTriangle = function(svg) {
this.showTriangle(svg);
const cssClasses = ['point1', 'point2', 'point3'];
svg.$class1(cssClasses[this.worstIndex]).setAttribute('fill', d3.color(pointColor).darker());
svg.$class1(cssClasses[this.secondIndex]).setAttribute('fill', pointColor);
svg.$class1(cssClasses[this.bestIndex]).setAttribute('fill', d3.color(pointColor).brighter());
}
Demo.prototype.showMinima = function(svg) {
for (const circle of Array.from(svg.$class('minimum')))
circle.remove();
for (const [x, y] of this.fn.minima) {
const minCircle = makeSvgNode('circle', 'minimum', { r: radius, fill: '#d55' });
if (svg.$class1('point1'))
svg.insertBefore(minCircle, svg.$class1('point1'));
else
svg.appendChild(minCircle);
moveCircleToPoint(minCircle, { x, y });
}
}
Demo.prototype.showLocalMinima = function(svg) {
for (const circle of Array.from(svg.$class('localMinimum')))
circle.remove();
for (const [x, y] of this.fn.localMinima || []) {
const minCircle = makeSvgNode('circle', 'localMinimum', { r: radius, fill: '#ffa500' });
if (svg.$class1('point1'))
svg.insertBefore(minCircle, svg.$class1('point1'));
else
svg.appendChild(minCircle);
moveCircleToPoint(minCircle, { x, y });
}
}
Demo.prototype.showReflectPoint = function(svg) {
const circle = svg.$class1('reflectPoint') || makeSvgNode('circle', 'reflectPoint', { r: radius, fill: '#80bde3' });
svg.insertBefore(circle, svg.$class1('point1'));
moveCircleToPoint(circle, this.reflectPoint);
}
Demo.prototype.showExpandPoint = function(svg) {
const circle = svg.$class1('expandPoint') || makeSvgNode('circle', 'expandPoint', { r: radius, fill: '#ffa500' });
svg.insertBefore(circle, svg.$class1('point1'));
moveCircleToPoint(circle, this.expandPoint);
}
Demo.prototype.showInsideContractionPoint = function(svg) {
const circle = svg.$class1('contractPoint') || makeSvgNode('circle', 'contractPoint', { r: radius, fill: '#ffc0cb' });
svg.insertBefore(circle, svg.$class1('point1'));
moveCircleToPoint(circle, this.insidePoint);
}
Demo.prototype.showOutsideContractionPoint = function(svg) {
const circle = svg.$class1('contractPoint') || makeSvgNode('circle', 'contractPoint', { r: radius, fill: '#ffc0cb' });
svg.insertBefore(circle, svg.$class1('point1'));
moveCircleToPoint(circle, this.outsidePoint);
}
Demo.prototype.showShrinkPoints = function(svg) {
const circle1 = svg.$class1('shrinkPoint1') || makeSvgNode('circle', 'shrinkPoint1', { r: radius, fill: '#a0d468' });
svg.insertBefore(circle1, svg.$class1('point1'));
moveCircleToPoint(circle1, this.shrinkPoint1);
const circle2 = svg.$class1('shrinkPoint2') || makeSvgNode('circle', 'shrinkPoint2', { r: radius, fill: '#a0d468' });
svg.insertBefore(circle2, svg.$class1('point1'));
moveCircleToPoint(circle2, this.shrinkPoint2);
}
function moveLabelToPoint(label, point) {
label.setAttribute('x', point.x);
label.setAttribute('y', point.y);
}
function setLabelText(label, value) {
value = Math.round(value * 1000) / 1000;
label.textContent = value;
}
Demo.prototype.showLabels = function(svg) {
const cssClasses = ['label1', 'label2', 'label3'];
for (let i = 0; i < this.triangle.length; i++) {
const cssClass = cssClasses[i];
const label = svg.$class1(cssClass) || appendSvgNode(svg, 'text', cssClass, { transform: 'translate(-2.5 -2.5)' });
moveLabelToPoint(label, this.triangle[i]);
setLabelText(label, this.values[i]);
}
}
function insidePlot(point) {
return point.x >= 0 && point.x <= 100 && point.y >= 0 && point.y <= 100;
}
Demo.prototype.canReflect = function() {
return insidePlot(this.reflectPoint);
}
Demo.prototype.canExpand = function() {
return insidePlot(this.expandPoint);
}
Demo.prototype.canContractInside = function() {
return insidePlot(this.insidePoint);
}
Demo.prototype.canContractOutside = function() {
return insidePlot(this.outsidePoint);
}
Demo.prototype.canShrink = function() {
return this.shrunkTriangle.every(point => insidePlot(point));
}
function randomTriangle() {
function randInt() {
return Math.floor(Math.random() * 100);
}
return [
new Point(randInt(), randInt()),
new Point(randInt(), randInt()),
new Point(randInt(), randInt())
];
}
function triangleArea(triangle) {
// Area of a triangle is half the determinant of a 2x2 matrix,
// where the columns are two side vectors of the triangle
const dx1 = triangle[1].x - triangle[0].x;
const dx2 = triangle[2].x - triangle[0].x;
const dy1 = triangle[1].y - triangle[0].y;
const dy2 = triangle[2].y - triangle[0].y;
return Math.abs(dx1*dy2 - dx2*dy1) / 2;
}
function maxVertexDist(triangle) {
return Math.max(triangle[0].dist(triangle[1]), triangle[0].dist(triangle[2]), triangle[1].dist(triangle[2]));
}
let contoursDrawn = false;
function initAnimation(svg, fn) {
const demo = new Demo(fn);
if (!contoursDrawn) {
demo.drawContours(svg);
contoursDrawn = true;
}
demo.triangle = randomTriangle();
demo.showTriangle(svg);
demo.countdown = 3;
return demo;
}
function stepAnimation(svg, demo) {
if (maxVertexDist(demo.triangle) < 2) {
if (demo.countdown-- === 0) {
clearAnimations(svg);
return initAnimation(svg, demo.fn);
}
} else {
demo.doNelderMead();
demo.slideTriangle(svg);
}
return demo;
}
const svg2 = initDynamicSvg($id('example-moving-triangle'));
let demo1 = initAnimation(svg2, himmelblauFunction);
setInterval(() => {
demo1 = stepAnimation(svg2, demo1);
}, 1350);
</script>
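<p>Stripped of all the drawing and animation, the loop behind the above animation can be sketched in a few dozen lines. This version is only a sketch: it follows the same order of decisions as <code>doNelderMead</code> in the script above, but it works on plain arrays of any dimension rather than on this page's <code>Point</code> objects, and the name <code>nelderMeadSketch</code>, the <code>tolerance</code> and <code>maxIterations</code> options, and the stopping rule are all my own choices for the example.</p>

```javascript
// Illustrative sketch; not the code which drives the demos on this page.
// 'simplex' is an array of N + 1 points, each an array of N coordinates.
function nelderMeadSketch(fn, simplex, { tolerance = 1e-6, maxIterations = 500 } = {}) {
  const n = simplex.length - 1;
  // moveTowards(a, b, t) returns a + (b - a) * t
  const moveTowards = (a, b, t) => a.map((ai, i) => ai + (b[i] - ai) * t);
  const dist = (a, b) => Math.hypot(...a.map((ai, i) => ai - b[i]));
  let points = simplex.map(p => ({ p, value: fn(...p) }));

  for (let iteration = 0; iteration < maxIterations; iteration++) {
    points.sort((a, b) => a.value - b.value); // Best first, worst last
    const best = points[0], second = points[n - 1], worst = points[n];

    // Stop once every point is very close to the best one
    if (points.every(pt => dist(pt.p, best.p) < tolerance))
      break;

    // Middle of all the points except the worst one
    const centroid = best.p.map((_, i) =>
      points.slice(0, n).reduce((sum, pt) => sum + pt.p[i], 0) / n);
    const candidate = t => {
      const p = moveTowards(worst.p, centroid, t);
      return { p, value: fn(...p) };
    };

    const reflected = candidate(2); // 1 + reflect coefficient
    let replacement;
    if (reflected.value < best.value) {
      const expanded = candidate(3); // 1 + expand coefficient
      replacement = (expanded.value < reflected.value) ? expanded : reflected;
    } else if (reflected.value < second.value) {
      replacement = reflected;
    } else {
      // Contract outside (1.5) if reflecting beat the worst point, else inside (0.5)
      const contracted = candidate(reflected.value < worst.value ? 1.5 : 0.5);
      if (contracted.value < worst.value)
        replacement = contracted;
    }

    if (replacement) {
      points[n] = replacement;
    } else {
      // Shrink: move every point except the best halfway towards the best
      points = points.map((pt, i) => {
        if (i === 0)
          return pt;
        const p = moveTowards(pt.p, best.p, 0.5);
        return { p, value: fn(...p) };
      });
    }
  }

  points.sort((a, b) => a.value - b.value);
  return points[0];
}

// Example: a bowl whose lowest point is at (50, 50)
const paraboloid = (x, y) => (x - 50) ** 2 + (y - 50) ** 2;
const minimum = nelderMeadSketch(paraboloid, [[30, 20], [40, 75], [75, 55]]);
console.log(minimum.p, minimum.value);
```

<p>The animation stops when the longest side of the triangle drops below 2 units; the sketch uses a much tighter tolerance so that the answer comes out close to the true minimum.</p>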
<p>At each iteration, Nelder-Mead will apply one of four possible transformations to the triangle. Let's see them one by one (try dragging around the points or adjusting the coefficient values if you like):</p>
<div class='transformation-container'>
<div class='transformation-def'><b>Reflect.</b> Move the worst point through the middle of the other two.</div>
<div id='reflect-vis' class='transformation-vis plot'>
<input type='number' id='reflect-coeff' class='trans-coeff' value='1.0' step='0.1' min='0'>
</div>
</div>
<div class='transformation-container'>
<div class='transformation-def'><b>Expand.</b> Like Reflect, but it moves further.</div>
<div id='expand-vis' class='transformation-vis plot'>
<input type='number' id='expand-coeff' class='trans-coeff' value='2.0' step='0.1' min='0'>
</div>
</div>
<div class='transformation-container'>
<div class='transformation-def'><b>Contract.</b> Move the worst point towards the middle of the other two. Depending on the situation, Contract can either stop short of the opposite side, or move slightly past it. We will call these variants “Contract Inside” and “Contract Outside”.</div>
<div id='contract-inside-vis' class='transformation-vis plot'>
<input type='number' id='inside-coeff' class='trans-coeff' value='0.5' step='0.1' min='0' max='1.0'>
</div>
</div>
<div class='transformation-container'>
<div class='transformation-def'></div>
<div id='contract-outside-vis' class='transformation-vis plot'>
<input type='number' id='outside-coeff' class='trans-coeff' value='0.5' step='0.1' min='0'>
</div>
</div>
<div class='transformation-container'>
<div class='transformation-def'><b>Shrink.</b> Shrink the whole triangle towards the best point, maintaining its angles.</div>
<div id='shrink-vis' class='transformation-vis plot'>
<input type='number' id='shrink-coeff' class='trans-coeff' value='0.5' step='0.1' min='0' max='1.0'>
</div>
</div>
<script>
'use strict';
function plotMultipleSvgs(fn, svgs) {
const grid = gridFromFunction(fn, 100);
const contours = contoursFromGrid(grid, 100, fn.scale, 23);
for (const svg of svgs)
drawContours(svg, contours);
}
const svg3 = initDynamicSvg($id('reflect-vis'));
const svg4 = initDynamicSvg($id('expand-vis'));
const svg5 = initDynamicSvg($id('contract-inside-vis'));
const svg6 = initDynamicSvg($id('contract-outside-vis'));
const svg7 = initDynamicSvg($id('shrink-vis'));
plotMultipleSvgs(himmelblauFunction, [svg3, svg4, svg5, svg6, svg7]);
const exampleTriangle = [
new Point(33, 30),
new Point(37, 50),
new Point(60, 45)
];
function drawArrow(svg, cssClass, oldPoint, newPoint) {
const lineLength = newPoint.dist(oldPoint);
if (lineLength < 0.5) {
svg.$class1(cssClass)?.remove();
return; // Don't even draw line if it will be extremely short
}
const line = svg.$class1(cssClass) || appendSvgNode(svg, 'line', cssClass);
const marker = addArrowhead(svg);
// Arrowhead looks silly if the line is too short
line.setAttribute('marker-end', (lineLength > 3) ? `url(#${marker.id})` : '');
// Make line a bit shorter so arrowhead point ends on the desired point
const scaleFactor = Math.max((lineLength - 2.5) / lineLength, 0);
moveLineToPoints(line, oldPoint, oldPoint.moveTowards(newPoint, scaleFactor));
}
function addArrowhead(svg) {
let marker = svg.$class1('arrowhead');
if (!marker) {
marker = appendSvgNode(svg, 'marker', 'arrowhead', { markerWidth: 3, markerHeight: 2, refX: 0, refY: 1, orient: 'auto' });
marker.id = Math.floor(Math.random() * 10000000).toString();
appendSvgNode(marker, 'polygon', undefined, { points: '0 0, 3 1, 0 2', fill: '#666' });
}
return marker;
}
// Use an arrow to show where a node will go
function illustrateTransform(svg, oldPoint, newPoint) {
const point4 = svg.$class1('point4') || appendSvgNode(svg, 'circle', 'point4', { r: radius, fill: d3.color(pointColor).brighter() });
moveCircleToPoint(point4, newPoint);
drawArrow(svg, 'arrowline', oldPoint, newPoint);
}
function illustrateShrink(svg, oldPoint1, newPoint1, oldPoint2, newPoint2) {
const point4 = svg.$class1('point4') || appendSvgNode(svg, 'circle', 'point4', { r: radius, fill: d3.color(pointColor).brighter() });
moveCircleToPoint(point4, newPoint1);
const point5 = svg.$class1('point5') || appendSvgNode(svg, 'circle', 'point5', { r: radius, fill: d3.color(pointColor).brighter() });
moveCircleToPoint(point5, newPoint2);
drawArrow(svg, 'arrowline1', oldPoint1, newPoint1);
drawArrow(svg, 'arrowline2', oldPoint2, newPoint2);
}
// Build callback which can be used when a point is dragged around in visualization
// The same callback will also be called if the value of the coefficient is changed
function initTransformationVis(svg, triangle, fn, demoProperty, coeffInput, coeffName) {
const demo = new Demo(fn);
demo.triangle = triangle;
function callback(triangle) {
if (triangle)
demo.triangle = triangle;
illustrateTransform(svg, demo.worstPoint, demo[demoProperty]);
}
coeffInput.listen('input', () => {
demo.coefficients = Object.assign({}, demo.coefficients);
demo.coefficients[coeffName] = Number(coeffInput.value);
callback();
});
showDraggableTriangle(svg, triangle, callback);
}
function initShrinkVis(svg, triangle, fn, coeffInput) {
const demo = new Demo(fn);
demo.triangle = triangle;
function callback(triangle) {
if (triangle)
demo.triangle = triangle;
illustrateShrink(svg, demo.worstPoint, demo.shrinkPoint1, demo.secondPoint, demo.shrinkPoint2);
}
coeffInput.listen('input', () => {
demo.coefficients = Object.assign({}, demo.coefficients);
demo.coefficients.shrink = Number(coeffInput.value);
callback();
});
showDraggableTriangle(svg, triangle, callback);
}
initTransformationVis(svg3, exampleTriangle, himmelblauFunction, 'reflectPoint', $id('reflect-coeff'), 'reflect');
initTransformationVis(svg4, exampleTriangle, himmelblauFunction, 'expandPoint', $id('expand-coeff'), 'expand');
initTransformationVis(svg5, exampleTriangle, himmelblauFunction, 'insidePoint', $id('inside-coeff'), 'contract');
initTransformationVis(svg6, exampleTriangle, himmelblauFunction, 'outsidePoint', $id('outside-coeff'), 'contract');
initShrinkVis(svg7, exampleTriangle, himmelblauFunction, $id('shrink-coeff'));
</script>
<p>The magnitude of each transformation can be tuned by adjusting a coefficient. As shown in the above visualizations, the default coefficient values are 1.0 for Reflect, 2.0 for Expand, 0.5 for Contract (both Inside and Outside), and 0.5 for Shrink.</p>
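<p>In case it helps to see the arithmetic, here is a small sketch of how each coefficient turns into a new point. It mirrors the getters in the <code>Demo</code> object from the earlier script, but uses plain <code>{ x, y }</code> objects; the function name <code>candidatePoints</code> is made up for this example.</p>

```javascript
// Illustrative sketch; 'candidatePoints' is a hypothetical name.
// Every candidate lies on the line from the worst point through the middle
// of the other two; the coefficient decides how far along that line it sits.
function candidatePoints(worst, second, best, coefficients) {
  const moveTowards = (a, b, t) => ({ x: a.x + (b.x - a.x) * t, y: a.y + (b.y - a.y) * t });
  const middle = moveTowards(second, best, 0.5);
  return {
    reflect: moveTowards(worst, middle, 1 + coefficients.reflect),
    expand: moveTowards(worst, middle, 1 + coefficients.expand),
    contractInside: moveTowards(worst, middle, 1 - coefficients.contract),
    contractOutside: moveTowards(worst, middle, 1 + coefficients.contract),
    // Shrink moves both non-best points towards the best point
    shrink: [
      moveTowards(worst, best, coefficients.shrink),
      moveTowards(second, best, coefficients.shrink)
    ]
  };
}

const candidates = candidatePoints(
  { x: 0, y: 0 }, // worst
  { x: 4, y: 0 }, // second
  { x: 0, y: 4 }, // best
  { reflect: 1, expand: 2, contract: 0.5, shrink: 0.5 });
console.log(candidates);
```

<p>With the default coefficients and the middle of the other two points at (2, 2), Reflect lands on (4, 4), Expand on (6, 6), Contract Inside on (1, 1), and Contract Outside on (3, 3).</p>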
<p>To help get a feel for how these transformations can be used to explore the input space and find a minimum, let's play a game. The below square represents the space of input values for a function <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.1076em;">f</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.1667em;"></span><span class="mord mathnormal" style="margin-right: 0.0359em;">y</span><span class="mclose">)</span></span></span></span>. Just as Nelder-Mead only computes the function value at the corners of its triangle, I will only show you its value at those three points. Click on any three points to start, then click any of the five transformation buttons to transform your triangle. Once you contract the triangle to a sufficiently small size (or reach 50 iterations), I'll reveal the contours of the graph. You “win” only if your best point is close enough to a minimum point.</p>
<div id='game' class='plot'>
<div id='click-please'>
<span id='click-msg'>Click three points to start</span>
</div>
</div>
<div class='control-wrapper'>
<div class='btn-wrapper'>
<button type='button' id='reflect-btn' disabled>Reflect</button>
<button type='button' id='expand-btn' disabled>Expand</button>
<button type='button' id='inside-btn' disabled>Contract Inside</button>
<button type='button' id='outside-btn' disabled>Contract Outside</button>
<button type='button' id='shrink-btn' disabled>Shrink</button>
<button type='button' id='reset-game' disabled>Play Again</button>
</div>
<div id='final-score'></div>
</div>
<script>
'use strict';
const gameSvg = initDynamicSvg($id('game'));
function disableButtons() {
$id('reset-game').setAttribute('disabled', true);
$id('reflect-btn').setAttribute('disabled', true);
$id('expand-btn').setAttribute('disabled', true);
$id('inside-btn').setAttribute('disabled', true);
$id('outside-btn').setAttribute('disabled', true);
$id('shrink-btn').setAttribute('disabled', true);
}
function setButtonState(id, enabled) {
$id(id).disabled = !enabled;
}
function initGameState() {
const game = new Demo(testFunctions[Math.floor(Math.random() * testFunctions.length)]);
game.iterations = 0;
while (gameSvg.lastChild) {
gameSvg.removeChild(gameSvg.lastChild);
}
$id('final-score').textContent = '';
$id('click-please').style.display = 'flex';
disableButtons();
return game;
}
// Before starting to play, the user must lay down 3 points
function setPoint(clickEvent, game) {
if (game.triangle.length === 3)
return;
const cssClass = (['point1', 'point2', 'point3'])[game.triangle.length];
const [clickX, clickY] = getMousePosition(clickEvent, gameSvg);
const newPoint = new Point(clickX, clickY);
game.triangle = game.triangle.concat([newPoint]);
game.showLabels(gameSvg);
if (game.triangle.length === 3) {
$id('click-please').style.display = 'none';
game.showColoredTriangle(gameSvg);
startNextTurn(game);
} else {
game.showTriangle(gameSvg);
}
}
function startNextTurn(game) {
if (gameTerminationTest(game)) {
gameOver(game);
return;
}
// Decide which moves are valid, enable/disable buttons accordingly
setButtonState('reflect-btn', game.canReflect());
setButtonState('expand-btn', game.canExpand());
setButtonState('inside-btn', game.canContractInside());
setButtonState('outside-btn', game.canContractOutside());
setButtonState('shrink-btn', game.canShrink());
}
function stepGame(game) {
game.showColoredTriangle(gameSvg);
game.showLabels(gameSvg);
game.iterations += 1;
startNextTurn(game);
}
function gameTerminationTest(game) {
return maxVertexDist(game.triangle) < 9 || game.iterations >= 50;
}
function insideTriangle(triangle, x, y) {
// Convert (x,y) to barycentric coordinates
const [t1, t2, t3] = triangle;
const denom = (t2.y - t3.y)*(t1.x - t3.x) + (t3.x - t2.x)*(t1.y - t3.y);
const a = ((t2.y - t3.y)*(x - t3.x) + (t3.x - t2.x)*(y - t3.y)) / denom;
const b = ((t3.y - t1.y)*(x - t3.x) + (t1.x - t3.x)*(y - t3.y)) / denom;
const c = 1 - a - b;
// The point is inside the triangle when all three coordinates lie in [0, 1]
return a >= 0 && a <= 1 && b >= 0 && b <= 1 && c >= 0 && c <= 1;
}
function gameOver(game) {
disableButtons();
$id('reset-game').removeAttribute('disabled');
game.drawContours(gameSvg);
game.showMinima(gameSvg);
const distToMin = Math.min(...game.fn.minima.map(([x,y]) => game.bestPoint.dist(new Point(x, y))));
console.log(`Best point was ${distToMin} from minimum point`);
const won = distToMin < 5.5;
if (won)
$id('final-score').textContent = `You found a minimum in ${game.iterations} steps. Congratulations!`;
else
$id('final-score').textContent = 'You did not find a minimum. Too bad! Maybe try again?';
}
let gameState = initGameState();
$id('reset-game').listen('click', () => gameState = initGameState());
$id('reflect-btn').listen('click', () => { gameState.doReflect(); stepGame(gameState); });
$id('expand-btn').listen('click', () => { gameState.doExpand(); stepGame(gameState); });
$id('inside-btn').listen('click', () => { gameState.doContractInside(); stepGame(gameState); });
$id('outside-btn').listen('click', () => { gameState.doContractOutside(); stepGame(gameState); });
$id('shrink-btn').listen('click', () => { gameState.doShrink(); stepGame(gameState); });
gameSvg.listen('click', event => setPoint(event, gameState));
</script>
<p>This game is quite difficult; don't be surprised if you hardly ever “win”. The Nelder-Mead algorithm doesn't always reach a minimum point, either; or at least, not in a reasonable number of iterations. Sometimes it gets <i>close</i> to a minimum point... and then moves <i>very, very</i> slowly towards it.</p>
<p>For that reason, when implementing Nelder-Mead, you need to limit the number of iterations so it doesn't run for too long. Rather than running Nelder-Mead for a huge number of iterations, you will probably get better results by restarting it several times, with different starting points, and then picking the best overall solution found.<sup><a href='#footnote2' id='fnref2'>[2]</a></sup></p>
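<p>The restart strategy is easy to wrap around any local optimizer. The sketch below is only an illustration: <code>optimizeWithRestarts</code> and <code>toyLocalSearch</code> are made-up names, and the toy one-dimensional search stands in for Nelder-Mead purely so that the example can run on its own. Splitting one fixed iteration budget evenly between the runs is also just one possible choice.</p>

```javascript
// Toy stand-in for a local optimizer such as Nelder-Mead (hypothetical, 1-D only):
// step downhill while that helps, otherwise halve the step size.
function toyLocalSearch(fn, start, maxIterations) {
  let x = start;
  let step = 1;
  for (let i = 0; i < maxIterations; i++) {
    if (fn(x + step) < fn(x)) x += step;
    else if (fn(x - step) < fn(x)) x -= step;
    else step /= 2;
  }
  return { point: x, value: fn(x) };
}

// Split a fixed iteration budget evenly across several starting points
// and keep the best result found by any of the runs.
function optimizeWithRestarts(optimize, fn, starts, totalIterations) {
  const perRun = Math.floor(totalIterations / starts.length);
  let best = null;
  for (const start of starts) {
    const result = optimize(fn, start, perRun);
    if (best === null || result.value < best.value) best = result;
  }
  return best;
}

// A function with a shallow valley near x = 2 and a deeper one near x = -2.
const twoValleys = x => (x * x - 4) ** 2 + x;

// Same total budget (40 iterations) for both approaches.
const singleRun = toyLocalSearch(twoValleys, 3, 40);
const restarted = optimizeWithRestarts(toyLocalSearch, twoValleys, [3, -3], 40);
console.log(singleRun.value, restarted.value);
```

<p>Here the single run settles into the shallow valley it happens to start next to, while one of the two shorter runs finds the deeper valley. In practice the starting points would usually be chosen at random.</p>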
<p>You may be wondering why the algorithm works the way it does. Here are some interesting questions to think about (click to reveal possible answers):</p>
<p><b>Why does the algorithm use a triangle rather than a single test point?</b></p>
<p class='reveal hidden'>Remember, one of the key characteristics of Nelder-Mead is that it doesn't require derivatives. However, by comparing the function's value at the corners of a triangle, it's as if Nelder-Mead is estimating the average value of the derivative<sup><a href='#footnote3' id='fnref3'>[3]</a></sup> across the area of the triangle, and using that approximated derivative to decide which direction to move in. This only works using <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.7667em; vertical-align: -0.0833em;"></span><span class="mord mathnormal" style="margin-right: 0.109em;">N</span><span class="mspace" style="margin-right: 0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.2222em;"></span></span><span class="base"><span class="strut" style="height: 0.6444em;"></span><span class="mord">1</span></span></span></span> points in an <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.6833em;"></span><span class="mord mathnormal" style="margin-right: 0.109em;">N</span></span></span></span>-dimensional space; fewer test points would not be enough to unambiguously determine a direction.</p>
<p><b>Why does the size of the triangle change? Why not use a fixed-size triangle which flips, rotates, and slides around the plane?</b></p>
<div class='reveal hidden'>
<p>Other optimization algorithms, such as gradient descent, typically have a “step size” which gradually decreases as the search progresses. The “step size” determines how far each new test point will be from the previous one. By starting with a large step size, the algorithm can quickly move through the search space at the beginning, which helps it to find a solution in fewer iterations. But as it approaches an optimum point, maintaining a large step size would be counterproductive; that would make the algorithm overshoot the optimum point, and possibly jump so far away that it couldn't find its way back. So the step size needs to gradually decrease as the search progresses.</p>
<p>For Nelder-Mead, the size of the triangle serves as a kind of “step size” parameter. The transformations which move the triangle across the plane all move it a distance proportional to its size. Just as the step size for gradient descent needs to decrease as it gets closer to a minimum point, as Nelder-Mead approaches a minimum point, it needs to use Contract or Shrink to make the triangle smaller.</p>
</div>
<p>The final piece of the algorithm, which I haven't described yet, is how it chooses which transformation to use on each iteration. Here is the procedure:</p>
<ol>
<li>Find the <b>reflection point</b> (the point which Reflect would move the worst point to) and compute the function's value there.</li>
<li>If the reflection point is better than the second-best point, but not better than the best point, then do Reflect.</li>
<li>Otherwise, if the reflection point is better than the best point, then check the <b>expansion point</b> (the one which Expand would move to). If the expansion point is better than the reflection point, do Expand. If not, do Reflect.</li>
<li>Otherwise, if the reflection point is worse than the second-best point but not worse than the worst point, check the <b>outside contraction point</b>. If it's better than the worst point, do Contract Outside. If not, do Shrink.</li>
<li>Finally, if the reflection point is worse than the worst point, check the <b>inside contraction point</b>. If it's better than the worst point, do Contract Inside. If not, do Shrink.</li>
</ol>
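<p>The procedure above can be sketched as a single function which performs one iteration. This is a simplified illustration rather than the code which drives the demos on this page: <code>nelderMeadStep</code> is a made-up name, points are plain <code>[x, y]</code> arrays, the coefficients are fixed at their default values, and the function is re-evaluated more often than a careful implementation would.</p>

```javascript
// Point on the line from `from` through `mid`; t > 0 lands beyond `mid`,
// t < 0 stays on the same side of `mid` as `from`.
const along = (from, mid, t) =>
  [mid[0] + t * (mid[0] - from[0]), mid[1] + t * (mid[1] - from[1])];

function nelderMeadStep(fn, triangle) {
  const f = p => fn(p[0], p[1]);
  // Sort the corners from best (lowest value) to worst (highest value).
  const [best, second, worst] = [...triangle].sort((a, b) => f(a) - f(b));
  const mid = [(best[0] + second[0]) / 2, (best[1] + second[1]) / 2];
  const replaceWorst = (move, p) => ({ move, triangle: [best, second, p] });
  // Shrink moves the two non-best corners halfway towards the best one.
  const halfway = p => [(best[0] + p[0]) / 2, (best[1] + p[1]) / 2];
  const shrink = () => ({ move: 'shrink', triangle: [best, halfway(second), halfway(worst)] });

  const reflected = along(worst, mid, 1.0);
  const reflectedValue = f(reflected);

  if (reflectedValue < f(best)) {
    // Better than everything so far: see whether going further is better still.
    const expanded = along(worst, mid, 2.0);
    return f(expanded) < reflectedValue
      ? replaceWorst('expand', expanded)
      : replaceWorst('reflect', reflected);
  }
  if (reflectedValue < f(second)) {
    return replaceWorst('reflect', reflected);
  }
  if (reflectedValue < f(worst)) {
    const outside = along(worst, mid, 0.5);
    return f(outside) < f(worst) ? replaceWorst('contract outside', outside) : shrink();
  }
  const inside = along(worst, mid, -0.5);
  return f(inside) < f(worst) ? replaceWorst('contract inside', inside) : shrink();
}

// Usage: a simple bowl with its minimum at (0, 0).
const bowl = (x, y) => x * x + y * y;
let current = [[3, 4], [4, 3], [5, 5]];
const moves = [];
for (let i = 0; i < 3; i++) {
  const step = nelderMeadStep(bowl, current);
  moves.push(step.move);
  current = step.triangle;
}
console.log(moves); // [ 'expand', 'reflect', 'contract outside' ]
```

<p>Note that the order of the tests differs slightly from the numbered list (the “better than the best point” case is checked first), but the outcome is the same.</p>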
<p>Does that seem to make sense? Perhaps these comments might make it more understandable:</p>
<p>While we expect that our function's graph has some kind of curved surface, Nelder-Mead can't “see” the curve; it only knows the value at its three test points (as you experienced when playing the above <a href='#game'>game</a>). With only three numbers to work with, the best guess Nelder-Mead can make is that it should move the worst point in the direction of the better two. And in the absence of other information, a reasonable default is to move the worst point just far enough to maintain the size and shape of the triangle (that's what Reflect does). If the default was to move it more or less than that, then the triangle would tend to grow and grow or shrink and shrink, even when there was no reason to do so.</p>
<p>However, that guess isn't always right. Even if the triangle is sitting on a slope, it is possible that Reflect might overshoot the base of the slope and start going up the opposite slope. If the reflection point is <i>worse</i> than the existing points, that indicates that we are going too far and need to back off, perhaps by using Contract Inside or Contract Outside.</p>
<p>On the other hand, if the reflection point is <i>better</i> than all the existing points, that strongly suggests that the triangle really is on a slope and that Reflect really is moving in the right direction. In that case, we can try to go even further in the same direction with Expand. This not only moves the triangle further in a good direction, it also enlarges the triangle, which means the following steps will be bigger. In effect, as long as Nelder-Mead keeps picking good directions and each successive point is better and better than the previous ones, it will “accelerate downhill”. That helps the algorithm to move more quickly towards a minimum and converge in a smaller number of iterations.</p>
<p>As for Shrink, the original paper on the Nelder-Mead algorithm explained that Shrink is necessary for the algorithm to avoid getting stuck in some (rare) situations. One example is below. Try various combinations of Reflect, Expand, Contract Inside, and Contract Outside to see if you can get the triangle to close in on the minimum point (which is marked in red). Then try Shrink and see how it helps.</p>
<div id='shrink-demo' class='plot'></div>
<div class='control-wrapper'>
<div class='btn-wrapper'>
<button type='button' id='reflect-btn2' disabled>Reflect</button>
<button type='button' id='expand-btn2' disabled>Expand</button>
<button type='button' id='inside-btn2' disabled>Contract Inside</button>
<button type='button' id='outside-btn2' disabled>Contract Outside</button>
<button type='button' id='shrink-btn2' disabled>Shrink</button>
<button type='button' id='reset-shrink-demo'>Reset</button>
</div>
<div id='final-score'></div>
</div>
<script>
'use strict';
const shrinkDemoSvg = initDynamicSvg($id('shrink-demo'));
// I wanted to make up a function which would have a curved valley
// Started with a parabolic curve y = x², then took |y - x²| to get the
// distance to the parabolic curve along the y-axis...
// That creates a curved valley; then to make sure the minimum is at 0,0,
// add terms for x², y², |x|, and |y|
// The ironic thing is that while Shrink makes Nelder-Mead work a lot better
// on this function, it can't actually find the absolute minimum without an
// ENORMOUS number of iterations
// It does get fairly close, though
function shrinkDemoFunction(x, y) {
x = (x - 50) / 4;
y = (70 - y) / 1.5;
return (5 * Math.abs(y - x*x)) + (0.02 * (x*x + y*y)) + (0.002 * Math.abs(y)) + (0.002 * Math.abs(x));
}
shrinkDemoFunction.scale = d3.scalePow().exponent(2.2);
shrinkDemoFunction.minima = [[50,70]];
let shrinkDemo;
function initShrinkDemo() {
shrinkDemo = new Demo(shrinkDemoFunction);
shrinkDemo.drawContours(shrinkDemoSvg);
shrinkDemo.triangle = [
new Point(50, 82),
new Point(34, 40),
new Point(66, 40)
];
shrinkDemo.showColoredTriangle(shrinkDemoSvg);
shrinkDemo.showMinima(shrinkDemoSvg);
setShrinkDemoButtonState();
}
function setShrinkDemoButtonState() {
setButtonState('reflect-btn2', shrinkDemo.canReflect());
setButtonState('expand-btn2', shrinkDemo.canExpand());
setButtonState('inside-btn2', shrinkDemo.canContractInside());
setButtonState('outside-btn2', shrinkDemo.canContractOutside());
setButtonState('shrink-btn2', shrinkDemo.canShrink());
}
function stepShrinkDemo() {
shrinkDemo.showColoredTriangle(shrinkDemoSvg);
setShrinkDemoButtonState();
}
initShrinkDemo();
$id('reflect-btn2').listen('click', () => { shrinkDemo.doReflect(); stepShrinkDemo(); });
$id('expand-btn2').listen('click', () => { shrinkDemo.doExpand(); stepShrinkDemo(); });
$id('inside-btn2').listen('click', () => { shrinkDemo.doContractInside(); stepShrinkDemo(); });
$id('outside-btn2').listen('click', () => { shrinkDemo.doContractOutside(); stepShrinkDemo(); });
$id('shrink-btn2').listen('click', () => { shrinkDemo.doShrink(); stepShrinkDemo(); });
$id('reset-shrink-demo').listen('click', initShrinkDemo);
</script>
<p>This last visualization will show all the points which Nelder-Mead considers on each iteration, and why it chooses the move which it does. Click the 'Next' button to move forward. Reflection points will be shown in <span style='color:#80bde3'>⬤</span> blue, expansion points in <span style='color:#ffa500'>⬤</span> orange, contraction points in <span style='color:#ffc0cb'>⬤</span> pink, and shrink points in <span style='color:#a0d468'>⬤</span> green.</p>
<div id='decision-demo' class='plot'></div>
<div id='decision-foot'>
<div id='step-explanation'></div>
<div id='step-btns'>
<button id='decision-step' type='button'>Next</button>
<button id='decision-reset' type='button'>Reset</button>
</div>
</div>
<script>
'use strict';
const demoSvg = initDynamicSvg($id('decision-demo'));
const explanation = $id('step-explanation');
let demoState;
function initDecisionDemo() {
const fn = testFunctions[Math.floor(Math.random() * testFunctions.length)];
demoState = new Demo(fn);
demoState.iterations = 0;
demoState.triangle = randomTriangle();
// If we randomly picked a 'bad' starting configuration (one which makes
// the demo uninteresting), then pick another
while (demoTerminationTest(demoState) || !insidePlot(demoState.expandPoint))
demoState.triangle = randomTriangle();
demoState.drawContours(demoSvg);
demoState.showMinima(demoSvg);
$id('decision-step').disabled = false;
prepareNextStep();
}
function demoTerminationTest(state) {
return maxVertexDist(state.triangle) < 6 || state.iterations >= 100;
}
function roundNumber(number) {
return Math.round(number * 1000) / 1000;
}
function prepareNextStep() {
demoState.showColoredTriangle(demoSvg);
demoState.showLabels(demoSvg);
if (demoTerminationTest(demoState)) {
endDemo();
return;
}
demoState.showReflectPoint(demoSvg);
if (demoState.reflectValue < demoState.bestValue) {
explanation.innerHTML = `Reflect point (${roundNumber(demoState.reflectValue)}) is better than all 3 points.<br>`;
demoState.showExpandPoint(demoSvg);
demoSvg.$class1('contractPoint')?.remove();
demoSvg.$class1('shrinkPoint1')?.remove();
demoSvg.$class1('shrinkPoint2')?.remove();
demoSvg.$class1('arrowline2')?.remove();
if (demoState.expandValue < demoState.reflectValue) {
explanation.innerHTML += `Expand point (${roundNumber(demoState.expandValue)}) is better still. Full speed ahead! <b>Expand!</b>`;
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.expandPoint);
} else {
explanation.innerHTML += `Expand point (${roundNumber(demoState.expandValue)}) is not as good as that, so let's hold back a bit. <b>Reflect!</b>`;
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.reflectPoint);
}
} else if (demoState.reflectValue < demoState.secondValue) {
explanation.innerHTML = `Reflect point (${roundNumber(demoState.reflectValue)}) is at least better than 2 of the 3 points. That's good enough for me. <b>Reflect!</b>`;
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.reflectPoint);
demoSvg.$class1('contractPoint')?.remove();
demoSvg.$class1('expandPoint')?.remove();
demoSvg.$class1('shrinkPoint1')?.remove();
demoSvg.$class1('shrinkPoint2')?.remove();
demoSvg.$class1('arrowline2')?.remove();
} else if (demoState.reflectValue < demoState.worstValue) {
explanation.innerHTML = `Reflect point (${roundNumber(demoState.reflectValue)}) is better than the worst point only. This suggests we might be pushing a bit too far.<br>`;
demoState.showOutsideContractionPoint(demoSvg);
demoSvg.$class1('expandPoint')?.remove();
if (demoState.outsideValue < demoState.worstValue) {
explanation.innerHTML += `Sure enough, outside contraction point (${roundNumber(demoState.outsideValue)}) is better than the worst point. <b>Contract Outside.</b>`;
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.outsidePoint);
demoSvg.$class1('shrinkPoint1')?.remove();
demoSvg.$class1('shrinkPoint2')?.remove();
demoSvg.$class1('arrowline2')?.remove();
} else {
explanation.innerHTML += `Hmm. Outside contraction point (${roundNumber(demoState.outsideValue)}) isn't any better. <b>Shrink.</b>`;
demoState.showShrinkPoints(demoSvg);
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.shrinkPoint1);
drawArrow(demoSvg, 'arrowline2', demoState.secondPoint, demoState.shrinkPoint2);
}
} else {
explanation.innerHTML = `Oh, my. Reflect point (${roundNumber(demoState.reflectValue)}) is worse than all 3 points. Maybe we are close to a minimum, or maybe we need to change direction.<br>`;
demoState.showInsideContractionPoint(demoSvg);
demoSvg.$class1('expandPoint')?.remove();
if (demoState.insideValue < demoState.worstValue) {
explanation.innerHTML += `Inside contraction point (${roundNumber(demoState.insideValue)}) is at least an improvement. <b>Contract Inside.</b>`;
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.insidePoint);
demoSvg.$class1('shrinkPoint1')?.remove();
demoSvg.$class1('shrinkPoint2')?.remove();
demoSvg.$class1('arrowline2')?.remove();
} else {
explanation.innerHTML += `Hmm. Inside contraction point (${roundNumber(demoState.insideValue)}) is also worse than all 3 points. We are in danger of getting stuck. <b>Shrink.</b>`;
demoState.showShrinkPoints(demoSvg);
drawArrow(demoSvg, 'arrowline', demoState.worstPoint, demoState.shrinkPoint1);
drawArrow(demoSvg, 'arrowline2', demoState.secondPoint, demoState.shrinkPoint2);
}
}
}
function takeNextStep() {
demoState.doNelderMead();
demoState.iterations += 1;
prepareNextStep();
}
function endDemo() {
demoSvg.$class1('arrowline')?.remove();
demoSvg.$class1('reflectPoint')?.remove();
demoSvg.$class1('contractPoint')?.remove();
demoSvg.$class1('expandPoint')?.remove();
$id('decision-step').disabled = true;
if (demoState.iterations >= 100)
explanation.innerHTML = "We've already done 100 iterations. I think it's time to give up.<br>";
else
explanation.innerHTML = "Our triangle is becoming small, so we'll stop here.<br>";
const dist = Math.min(...demoState.fn.minima.map(([x, y]) => demoState.bestPoint.dist(new Point(x, y))));
console.log(`Best point was ${dist} from minimum point`);
// Each entry of `minima` is an [x, y] pair, so it must be destructured here
const won = demoState.fn.minima.some(([x, y]) => insideTriangle(demoState.triangle, x, y)) || dist < 2.5;
if (!won && demoState.fn.localMinima) {
const localMinimaDist = Math.min(...demoState.fn.localMinima.map(([x,y]) => demoState.bestPoint.dist(new Point(x, y))));
console.log(`Best point was ${localMinimaDist} from local minimum point`);
if (localMinimaDist < 6) {
explanation.innerHTML += "It looks like we got caught by a local minimum and missed the global minimum.";
return;
}
}
if (won)
explanation.innerHTML += "And... we found a minimum point! Hooray!";
else if (dist < 6)
explanation.innerHTML += "We're not right on the minimum point, but it's close. Good enough for me.";
else if (dist < 15)
explanation.innerHTML += "We're not so far from the minimum point, though it would have been nice to get closer.";
else
explanation.innerHTML += "And... Nelder-Mead is not looking so hot in this case. Well, maybe it will do better on the next one.";
}
$id('decision-step').listen('click', takeNextStep);
$id('decision-reset').listen('click', initDecisionDemo);
initDecisionDemo();
</script>
<p>For a sample implementation of Nelder-Mead optimization in JavaScript, see the latter part of Justin Meiners' <a href='https://github.com/justinmeiners/why-train-when-you-can-optimize/blob/4d9a73995800ee7c80f2a7c08bcc2a221449a79d/docs/src/math.js#L359'>math.js</a>.</p>
<p><i><b>Special thanks</b> to Justin Meiners for the <a href='https://www.jmeiners.com/why-train-when-you-can-optimize/'>article</a> which inspired this post, to Peter Collingridge for his <a href='https://www.petercollingridge.co.uk/tutorials/svg/interactive/dragging/'>helpful post on making SVG elements draggable</a>, to Mike Bostock for D3.js, to John Nelder and Peter Mead for their lovely algorithm, and... to you for reading all the way to the end!</i></p>
<p id='footnote1' style='font-size: 0.8em; margin-top: 2.5em'>[1] The version of the algorithm which searches for minimum points is all we need anyways, since if we want to find a maximum point, we can just search for a minimum of the negated function <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.0359em;">g</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.2778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord">−</span><span class="mord mathnormal" style="margin-right: 0.1076em;">f</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> instead. <a href='#fnref1'>⏎</a></p>
<p id='footnote2' style='font-size: 0.8em'>[2] In <a href='http://enricoschumann.net/R/restartnm.htm'>this post by Enrico Schumann</a>, the performance of a single run of Nelder-Mead with N iterations is compared empirically with 4 runs of N/4 iterations each, over varying values of N. The version which restarts 4 times completely dominates the one which doesn't restart. <a href='#fnref2'>⏎</a></p>
<p id='footnote3' style='font-size: 0.8em;'>[3] To be precise, the right term is “gradient”. <a href='#fnref3'>⏎</a></p>
<script>
'use strict';
function revealText(event) {
event.preventDefault();
this.classList.remove('hidden');
this.removeEventListener('click', revealText);
}
document.querySelectorAll('.reveal').forEach(el => el.listen('click', revealText));
</script>Recently, I ran across a fantastic article, “Why Train When You Can Optimize?”, which introduced me to the Nelder-Mead optimization algorithm. It's a lovely algorithm, and I couldn't wait to create an interactive version. First, though: what does “optimization” mean in this context? An “optimization” algorithm takes some mathematical function as its input, and tries to find values for the parameters which make the output either as large or as small as possible. If you are like many computer programmers, your first impression might be that you are unlikely to ever use such an algorithm in your own programs. But optimization is a much more general and useful technique than it might seem. The article mentioned above gives a great example: a drawing program which detects when the user is trying to draw a straight line and replaces their jittery line with a perfectly straight one. (If you know other examples of good uses for optimization outside science and engineering, please let me know!) Obviously, finding inputs for some function which give you the largest or smallest output value can be done without a special algorithm. You could just use brute force: test many inputs and pick the best one. But if the space of possible inputs is large, that could be too slow. Optimization algorithms typically avoid exhaustively searching the input space by starting at some arbitrary point, then repeatedly searching for a nearby point which is better, until it hits a maximum or minimum and can't find any better point. Many such algorithms require that you know how to calculate the derivative (slope) of the function at any given point; but Nelder-Mead doesn't need any derivatives, and combined with its general simplicity, this makes it easy to apply. Now let me show you how Nelder-Mead works. Rather than starting with one test point and iteratively improving it, Nelder-Mead starts with N+1 test points, when the input space has N dimensions. 
(Or, in other words, when there are N different input variables whose values need to be found.) For example, if your function has two parameters, there will be 3 starting points, which will form a triangle in the 2-dimensional plane of possible inputs: (From here on, all examples will be 2-dimensional, but the algorithm generalizes naturally to any number of dimensions. Further, we will assume that we are searching for a minimum rather than a maximum.[1]) Nelder-Mead repeatedly transforms the triangle of test points, replacing the worst point with a better one. This causes the triangle to move across the plane in whichever direction the function's value is dropping, and then contract around a local minimum when it finds one. When the triangle becomes small enough, then the algorithm terminates. Like this: At each iteration, Nelder-Mead will apply one of four possible transformations to the triangle. Let's see them one by one (try dragging around the points or adjusting the coefficient values if you like): Reflect. Move the worst point through the middle of the other two. Expand. Like Reflect, but it moves further. Contract. Move the worst point towards the middle of the other two. Depending on the situation, Contract can either stop short of the opposite side, or move slightly past it. We will call these variants “Contract Inside” and “Contract Outside”. Shrink. Shrink the whole triangle towards the best point, maintaining its angles. The magnitude of each transformation can be tuned by adjusting a coefficient. As shown in the above visualizations, the default coefficient values are 1.0, 2.0, 0.5, and 0.5. To help get a feel for how these transformations can be used to explore the input space and find a minimum, let's play a game. The below square represents the space of input values for a function f(x,y). Just as Nelder-Mead only computes the function value at the corners of its triangle, I will only show you its value at those three points. 
Click on any three points to start, then click any of the five transformation buttons to transform your triangle. Once you contract the triangle to a sufficiently small size (or reach 50 iterations), I'll reveal the contours of the graph. You “win” only if your best point is close enough to a minimum point. Click three points to start Reflect Expand Contract Inside Contract Outside Shrink Play Again This game is quite difficult; don't be surprised if you hardly ever “win”. The Nelder-Mead algorithm doesn't always reach a minimum point, either; or at least, not in a reasonable number of iterations. Sometimes it gets close to a minimum point... and then moves very, very slowly towards it. For that reason, when implementing Nelder-Mead, you need to limit the number of iterations so it doesn't run for too long. Rather than running Nelder-Mead for a huge number of iterations, you will probably get better results by restarting it several times, with different starting points, and then picking the best overall solution found.[2] You may be wondering why the algorithm works the way it does. Here are some interesting questions to think about (click to reveal possible answers): Why does the algorithm use a triangle rather than a single test point? Remember, one of the key characteristics of Nelder-Mead is that it doesn't require derivatives. However, by comparing the function's value at the corners of a triangle, it's as if Nelder-Mead is estimating the average value of the derivative[3] across the area of the triangle, and using that approximated derivative to decide which direction to move in. This only works using N+1 points in an N-dimensional space; less test points would not be enough to unambiguously determine a direction. Why does the size of the triangle change? Why not use a fixed-size triangle which flips, rotates, and slides around the plane? 
Other optimization algorithms, such as gradient descent, typically have a “step size” which gradually decreases as the search progresses. The “step size” determines how far each new test point will be from the previous one. By starting with a large step size, the algorithm can quickly move through the search space at the beginning, which helps it to find a solution in fewer iterations. But as it approaches an optimum point, maintaining a large step size would be counterproductive; that would make the algorithm overshoot the optimum point, and possibly jump so far away that it couldn't find its way back. So the step size needs to gradually decrease as the search progresses. For Nelder-Mead, the size of the triangle serves as a kind of “step size” parameter. The transformations which move the triangle across the plane all move it a distance proportional to its size. Just as the step size for gradient descent needs to decrease as it gets closer to a minimum point, as Nelder-Mead approaches a minimum point, it needs to use Contract or Shrink to make the triangle smaller. The final piece of the algorithm, which I haven't described yet, is how it chooses which transformation to use on each iteration. Here is the procedure: Find the reflection point (the point which Reflect would move the worst point to) and compute the function's value there. If the reflection point is better than the second-best point, but not better than the best point, then do Reflect. Otherwise, if the reflection point is better than the best point, then check the expansion point (the one which Expand would move to). If the expansion point is better than the reflection point, do Expand. If not, do Reflect. Otherwise, if the reflection point is worse than the second-best point but not worse than the worst point, check the outside contraction point. If it's better than the worst point, do Contract Outside. If not, do Shrink. 
Finally, if the reflection point is worse than the worst point, check the inside contraction point. If it's better than the worst point, do Contract Inside. If not, do Shrink. Does that seem to make sense? Perhaps these comments might make it more understandable: While we expect that our function's graph has some kind of curved surface, Nelder-Mead can't “see” the curve; it only knows the value at its three test points (as you experienced when playing the above game). With only three numbers to work with, the best guess Nelder-Mead can make is that it should move the worst point in the direction of the better two. And in the absence of other information, a reasonable default is to move the worst point just far enough to maintain the size and shape of the triangle (that's what Reflect does). If the default was to move it more or less than that, then the triangle would tend to grow and grow or shrink and shrink, even when there was no reason to do so. However, that guess isn't always right. Even if the triangle is sitting on a slope, it is possible that Reflect might overshoot the base of the slope and start going up the opposite slope. If the reflection point is worse than the existing points, that indicates that we are going too far and need to back off, perhaps by using Contract Inside or Contract Outside. On the other hand, if the reflection point is better than all the existing points, that strongly suggests that the triangle really is on a slope and that Reflect really is moving in the right direction. In that case, we can try to go even further in the same direction with Expand. This not only moves the triangle further in a good direction, it also enlarges the triangle, which means the following steps will be bigger. In effect, as long as Nelder-Mead keeps picking good directions and each successive point is better and better than the previous ones, it will “accelerate downhill”. 
That helps the algorithm to move more quickly towards a minimum and converge in a smaller number of iterations. As for Shrink, the original paper on the Nelder-Mead algorithm explained that Shrink is necessary for the algorithm to avoid getting stuck in some (rare) situations. One example is below. Try various combinations of Reflect, Expand, Contract Inside, and Contract Outside to see if you can get the triangle to close in on the minimum point (which is marked in red). Then try Shrink and see how it helps. This last visualization will show all the points which Nelder-Mead considers on each iteration, and why it chooses the move which it does. Click the 'Next' button to move forward. Reflection points will be shown in ⬤ blue, expansion points in ⬤ orange, contraction points in ⬤ pink, and shrink points in ⬤ green. For a sample implementation of Nelder-Mead optimization in JavaScript, see the latter part of Justin Meiners' math.js. Special thanks to Justin Meiners for the article which inspired this post, to Peter Collingridge for his helpful post on making SVG elements draggable, to Mike Bostock for D3.js, to John Nelder and Peter Mead for their lovely algorithm, and... to you for reading all the way to the end! [1] The version of the algorithm which searches for minimum points is all we need anyways, since if we want to find a maximum point, we can just search for a minimum of the negated function g(x)=−f(x) instead. ⏎ [2] In this post by Enrico Schumann, the performance of a single run of Nelder-Mead with N iterations is compared empirically with 4 runs of N/4 iterations each, over varying values of N. The version which restarts 4 times completely dominates the one which doesn't restart. ⏎ [3] To be precise, the right term is “gradient”. 
⏎JPEG Series, Part II: Huffman Coding2021-05-16T00:00:00+00:002021-05-16T00:00:00+00:00/huffman-coding<p>The <a href='/visualizing-the-idct/'>previous article in this series</a> explored how JPEG compression converts pixel values to DCT coefficients. A later stage of the compression process uses either a method called "Huffman coding" or another called "arithmetic coding" to store those coefficients in a compact manner. The Huffman coding algorithm is very simple, but powerful and widely used. If you've never learned how it works, I promise this will be interesting.</p>
<hr>
<p>You have some data to compress. Your data can be viewed as a sequence of values; perhaps Unicode codepoints, pixel color samples, audio amplitude samples, or something comparable. In the context of Huffman coding, each of these values is called a "symbol".</p>
<p>We are going to encode each symbol using a unique, variable-length string of bits. For example, if each symbol is a letter, the letter "a" could be "10111", "b" could be "00011", and so on. We can pick any series of bits we like to represent each symbol, with the restriction that after all the symbols are converted to bitstrings, and all those little bitstrings are smashed together, it must be possible to figure out what the original sequence of symbols was.</p>
<p>That means the bitstrings for each symbol must be <b>prefix-free</b>; none of them can be a prefix of another. Example of a <i>non</i>-prefix-free code: say we choose "1100" to represent the letter "a", "11" to represent "b", and "00" for "c". When decoding text, we find "1100" somewhere. How are we supposed to know whether that was originally a letter "a", or the two letters "bc"? It's impossible to tell, precisely because "11" is a prefix of "1100". Fortunately, Huffman codes are <i>always</i> prefix-free.</p>
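<p>If you'd like to test a candidate code for this property, here is a minimal sketch. The helper name <code>isPrefixFree</code> is my own invention for illustration, and it simply compares every pair of bitstrings, which is fine for small codes:</p>

```javascript
/* Hypothetical helper (not used elsewhere in this article): returns true if
 * no bitstring in `codes` is a prefix of a different bitstring in `codes`.
 * Duplicate bitstrings also count as a violation. */
function isPrefixFree(codes) {
  return codes.every((a, i) =>
    codes.every((b, j) => i === j || !b.startsWith(a)));
}

// The ambiguous code from the example above: "11" is a prefix of "1100"
console.log(isPrefixFree(['1100', '11', '00'])); // false

// A code where no bitstring starts with another one
console.log(isPrefixFree(['0', '10', '11']));    // true
```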
<p>Let's encode this sentence with such a code. Click the button below, and the computer will generate a random prefix-free code. Try clicking a number of times and see what the smallest total number of bits required to encode the sentence appears to be.</p>
<div class='container'>
<div class='stiff'>
<table id='code0' class='codetbl fillparent'><tbody></tbody></table>
</div>
<div id='encoded0' class='encoded stretchy'></div>
</div>
<a id='genrand0' class='button'>Generate Random Code</a>
<script>
'use strict';
class Minheap {
/* `comparator` must return true if first argument is 'larger' than second */
constructor(comparator) {
this.heap = [];
this.compare = comparator;
}
get length() {
return this.heap.length;
}
insert(item) {
let index = this.heap.length;
while (index > 0) {
const parentIndex = ((index + 1) >>> 1) - 1;
if (this.compare(item, this.heap[parentIndex]))
break;
this.heap[index] = this.heap[parentIndex];
index = parentIndex;
}
this.heap[index] = item;
}
/* Remove and return the smallest item in the heap */
pop() {
const result = this.heap[0], item = this.heap.pop();
/* If the heap is not empty, move items upward to restore the heap property,
* until we find an appropriate place to put `item` */
if (this.heap.length) {
let index = 0;
while (true) {
const leftIndex = (index << 1) + 1, rightIndex = leftIndex + 1;
let childIndex = leftIndex;
if (rightIndex < this.heap.length) {
if (this.compare(this.heap[leftIndex], this.heap[rightIndex]))
childIndex = rightIndex;
} else if (leftIndex >= this.heap.length) {
break;
}
if (this.compare(item, this.heap[childIndex])) {
this.heap[index] = this.heap[childIndex];
index = childIndex;
} else {
break;
}
}
this.heap[index] = item;
}
return result;
}
}
/* Count how many times each character appears in a string */
function histogram(string) {
const histogram = new Map();
for (const char of string)
histogram.set(char, (histogram.get(char) || 0) + 1);
return histogram;
}
function symbols(histogram) {
return Array.from(histogram).map(([char, count]) => ({ value: char, weight: count }));
}
function huffmanTree(symbols) {
const heap = new Minheap((a,b) => a.weight > b.weight);
for (const symbol of symbols)
heap.insert(symbol);
while (heap.length > 1) {
const a = heap.pop(), b = heap.pop();
heap.insert({ value: [a, b], weight: a.weight + b.weight });
}
return heap.pop();
}
function shuffle(array) {
let index = array.length;
while (index > 0) {
const randIndex = Math.floor(Math.random() * index--);
const temp = array[index];
array[index] = array[randIndex];
array[randIndex] = temp;
}
return array;
}
function randomTree(symbols) {
const array = Array.from(symbols);
while (array.length > 1) {
shuffle(array);
const a = array.pop(), b = array.pop();
array.push({ value: [a, b], weight: a.weight + b.weight });
}
return array[0];
}
function dictionary(tree, map = new Map(), prefix = '') {
if (tree.value instanceof Array) {
dictionary(tree.value[0], map, prefix + '0');
dictionary(tree.value[1], map, prefix + '1');
} else {
map.set(tree.value, prefix);
}
return map;
}
function randomPrefixFreeCode(string) {
const hist = histogram(string),
sym = symbols(hist),
tree = randomTree(sym),
dict = dictionary(tree);
return {
string: string,
histogram: hist,
symbols: sym,
tree: tree,
dictionary: dict,
bitLength: sym.reduce((sum,sym) => sum + (sym.weight * dict.get(sym.value).length), 0)
};
}
function addTableRow(tbody, values) {
const tr = document.createElement('tr');
for (const value of values) {
const td = document.createElement('td');
td.innerText = value;
tr.appendChild(td);
}
tbody.appendChild(tr);
}
function showCodingTableWithLength(code, table, orderBy) {
table.innerHTML = '';
const tbody = table.tBodies[0] || table.createTBody();
for (const sym of Array.from(code.symbols).sort(orderBy)) {
const bitstring = code.dictionary.get(sym.value);
const displayValue = (sym.value === ' ') ? '_' : sym.value;
addTableRow(tbody, [displayValue, bitstring, `${bitstring.length} bits × ${sym.weight} occurrences = ${bitstring.length * sym.weight} bits`]);
}
addTableRow(tbody, ['', '', `Total encoded bit length: ${code.bitLength}`]);
}
function displayCodedSymbol(char, dict) {
const elem = document.createElement('span');
elem.classList.add('coded');
elem.innerText = (char === ' ' ? '_' : char) + "\n" + dict.get(char);
return elem;
}
function showCodedText(container, text, dictionary) {
container.textContent = '';
for (const char of text) {
container.appendChild(displayCodedSymbol(char, dictionary));
}
}
function comparator(keyFn) {
return function(a, b) {
const keyA = keyFn(a), keyB = keyFn(b);
if (keyA > keyB)
return 1;
else if (keyA < keyB)
return -1;
else
return 0;
}
}
var alphabeticOrder = comparator((sym) => sym.value);
function $id(id) {
return document.getElementById(id);
}
const sentence = "Let's encode this sentence with such a code";
$id('genrand0').addEventListener('click', () => {
const code = randomPrefixFreeCode(sentence);
showCodingTableWithLength(code, $id('code0'), alphabeticOrder);
showCodedText($id('encoded0'), sentence, code.dictionary);
});
$id('genrand0').click();
</script>
<p>Obviously, the number of bits required to encode a sequence can vary wildly depending on the chosen code. Can you see <i>why</i> some prefix-free codes are more efficient than others? What is the key difference between an efficient code and an inefficient one? (Click to reveal.)</p>
<div class='reveal hidden'>
<p>Efficient codes use the <i>shortest</i> bitstrings for the <i>most common</i> symbols.</p>
<p>At the same time, since the code must be prefix-free, if we use a very short bitstring such as "0" for a common symbol, that means "00", "01", "000", "001", and so on will all become unavailable. So we need to strike a balance. While we want to use short bitstrings for common symbols, we don't want to be forced to use excessively long bitstrings for less common symbols.</p>
</div>
<p>Out of the vast number of prefix-free codes which <i>could</i> be used, we want to find an <b>optimal</b> one; one which will encode our particular data in the smallest number of bits. (There will always be many optimal codes for any sequence, but we just need to find one of them.) At first, it might appear that we need to try millions of possible codes to be sure that we have an optimal one. Fortunately, that is not the case. Just count how many times each symbol appears in the input data, and in an almost trivially simple way, you <i>can</i> find an optimal prefix-free code. It will be easy, and fun!</p>
<p>I could tell you the algorithm right now, but it will be so much more enjoyable to discover it yourself. So I'll take this slow, and reason towards a solution step by deliberate step. If at any point you catch the scent of the solution, stop and think it out before continuing.</p>
<p>The first step is to represent a prefix-free code as a <b>binary tree</b>. Have a look:</p>
<div class='container'>
<div class='stiff'>
<table id='code1' class='codetbl fillparent'><tbody></tbody></table>
</div>
<div id='tree0' class='tree stiff' style='flex-basis: 100%'></div>
</div>
<a id='genrand1' class='button'>Generate Random Code</a>
<script type='module'>
import { renderTree } from '/assets/js/treeviz.js';
function showCodingTable(code, table, orderBy) {
table.innerHTML = '';
const tbody = table.tBodies[0] || table.createTBody();
for (const sym of Array.from(code.symbols).sort(orderBy)) {
const bitstring = code.dictionary.get(sym.value);
const displayValue = (sym.value === ' ') ? '_' : sym.value;
addTableRow(tbody, [displayValue, bitstring]);
}
}
function showCode1(code, table, treeDiv) {
showCodingTable(code, table, alphabeticOrder);
const svg = renderTree(code.tree, { width: 700, height: 450, nodeRadius: 50, nodeSpacing: 150, labelLines: true });
treeDiv.innerHTML = '';
treeDiv.appendChild(svg);
}
$id('genrand1').addEventListener('click', () => {
showCode1(randomPrefixFreeCode(sentence), $id('code1'), $id('tree0'));
});
$id('genrand1').click();
</script>
<p>Please make sure that you clearly see the correspondence between coding tables and binary trees before continuing.</p>
<p>We can see that the number of bits used to encode each symbol equals the number of links between its tree node and the root node, also called the "depth" of the node. Leaf nodes which are closer to the root (smaller depth) have shorter bitstrings.</p>
<p>We will add a <b>weight</b> to each leaf node, which is simply the number of times its symbol appears in the input data:</p>
<div style='border: 1px solid #aaa; --ratio:70/45'>
<div id='tree1' class='tree' style='text-align: center'></div>
</div>
<a id='genrand2' class='button'>Generate Random Code</a>
<script type='module'>
import { renderTree } from '/assets/js/treeviz.js';
$id('genrand2').addEventListener('click', () => {
const code = randomPrefixFreeCode(sentence);
const svg = renderTree(code.tree, { width: 700, height: 450, nodeRadius: 50, nodeSpacing: 150, showWeight: true });
const treeDiv = $id('tree1');
const textNode = document.createElement('h3');
textNode.style.marginTop = '0.5em';
textNode.innerText = 'Characters from: "' + sentence + '"';
treeDiv.innerHTML = '';
treeDiv.appendChild(textNode);
treeDiv.appendChild(svg);
});
$id('genrand2').click();
</script>
<p>Now the total length in bits of the compressed output will equal <b>weight times depth</b>, summed over all the leaf nodes.</p>
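<p>In code, that sum might look like the sketch below. It assumes tree nodes shaped like the ones used by the demos on this page (internal nodes hold a two-element array in <code>value</code>, leaves hold a symbol in <code>value</code> and a count in <code>weight</code>); the function name <code>encodedBitLength</code> and the sample tree are made up for illustration:</p>

```javascript
/* Sum weight × depth over all the leaves of a tree.
 * Assumed node shape: internal nodes have `value: [left, right]`,
 * leaves have a symbol in `value` and its occurrence count in `weight`. */
function encodedBitLength(tree, depth = 0) {
  if (Array.isArray(tree.value))
    return encodedBitLength(tree.value[0], depth + 1) +
           encodedBitLength(tree.value[1], depth + 1);
  return tree.weight * depth;
}

// A hand-built tree for the made-up input "aaaabbc"
const exampleTree = { weight: 7, value: [
  { value: 'a', weight: 4 },
  { weight: 3, value: [{ value: 'b', weight: 2 }, { value: 'c', weight: 1 }] }
] };

console.log(encodedBitLength(exampleTree)); // 4×1 + 2×2 + 1×2 = 10
```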
<p>So now our goal is to find a binary tree which minimizes the sum of weight times depth. We don't really have an idea how to do that, though. At least we do know what the leaf nodes of the tree should be:</p>
<div style='border: 1px solid #aaa; --ratio:70/15'>
<div id='tree2' class='tree'></div>
</div>
<script type='module'>
import { renderTrees } from '/assets/js/treeviz.js';
/* Just draw leaf nodes for the tree with the characters in `sentence` */
$id('tree2').appendChild(renderTrees(symbols(histogram(sentence)), { width: 700, height: 150, nodeRadius: 10, nodeSpacing: 25, showWeight: true }));
</script>
<p>How are we going to find the right structure for the internal nodes? Well, we could try to do it top-down, meaning we figure out what child nodes the root node should have, then the nodes below those, and so on. Or we could work bottom-up, meaning we figure out which leaf nodes should become children of the same parent node, then find which parent nodes should be children of the same "grandparent" node, until the whole tree is joined together. A third option would be to work both up and down from the middle, but that is just as hopeless as it sounds. These animations may help you understand "top-down" and "bottom-up" tree construction:</p>
<div style='--ratio:160/45'>
<div class='container'>
<div id='animation0' class='tree' style='flex-basis: 45%; text-align: center; border-right: 1px solid #aaa'>
<h3 style='margin-top: 0.35em'>Bottom-up</h3>
</div>
<div id='animation1' class='tree' style='flex-basis: 55%; text-align: center'>
<h3 style='margin-top: 0.35em'>Top-down</h3>
</div>
</div>
</div>
<a id='genrand3' class='button'>Generate Random Examples</a>
<script type='module'>
import { animateRandomBottomUpTree, animateRandomTopDownTree } from '/assets/js/treeviz.js';
function removeNode(node) {
if (node)
node.parentNode.removeChild(node);
}
const sym = symbols(histogram(sentence)).slice(0, 6);
$id('genrand3').addEventListener('click', () => {
const animation0 = animateRandomBottomUpTree(sym, { width: 720, height: 450, nodeRadius: 26, nodeSpacing: 65 });
removeNode($id('animation0').querySelector('svg'));
$id('animation0').appendChild(animation0);
const animation1 = animateRandomTopDownTree(sym, { width: 880, height: 450, nodeRadius: 26, nodeSpacing: 65 });
removeNode($id('animation1').querySelector('svg'));
$id('animation1').appendChild(animation1);
});
$id('genrand3').click();
</script>
<p>To build an optimal tree top-down, we would need a way to partition the symbols into two subsets, such that the total weights of each subset are as close as possible to 50%-50%. That might be tricky. On the other hand, if we can come up with a simple criterion to identify two leaf nodes which should be siblings in the tree, we might be able to apply the same criterion repeatedly to build an optimal tree bottom-up. That sounds more promising.</p>
<p>Before we consider that further, take note of an important fact. <i>How many</i> internal nodes, including the root, does it take to connect <i>N</i> leaf nodes together into a binary tree? Watch the above animations again and try to figure it out:</p>
<p class="reveal hidden">N - 1.</p>
<p>Good. Another one: When building a tree bottom-up, every time we pick two subtrees and join them together as children of a new internal node, what happens to the <b>depth</b> of all the leaves in the combined subtree?</p>
<p class="reveal hidden">The depth of all the leaf nodes increases by 1.</p>
<p>Remember that the depth of each leaf node equals the number of bits required to encode the corresponding symbol. So every time we join two subtrees, we are in a sense "lengthening" the bitstrings for all the symbols in the new subtree. Since we want the most common symbols to have the shortest bitstrings (equivalent: we want their nodes to be closest to the root), they should be the <i>last</i> ones to be joined into the tree.</p>
<p>With that in mind, can you now see what the first step in building an optimal tree bottom-up should be?</p>
<p class="reveal hidden">Join the two lowest-weighted leaf nodes together into a subtree.</p>
<p>Yes! Just like this:</p>
<div style='border: 1px solid #aaa; --ratio:70/12'>
<div id='animation2' class='tree'></div>
</div>
<script type='module'>
import { animateOptimalTree } from '/assets/js/treeviz.js';
/* Just animate the first step, where the two lowest weighted nodes are joined together */
$id('animation2').appendChild(animateOptimalTree(symbols(histogram(sentence)), { width: 700, height: 120, nodeRadius: 10, nodeSpacing: 25, showWeight: true }, 2));
</script>
<hr>
<p>Just another small conceptual leap, and the complete solution will be ours. Here's what we need to figure out: Just now, we took the two lowest-weighted leaf nodes and joined them together. But how should we "weight" the resulting subtree? How will we know when and where to join it into a bigger subtree? More concretely: for the second step, we could either take our new 3-node subtree, and use it as one child of a new 5-node subtree, or we could pick two of the remaining single nodes, and join them into another 3-node subtree. How do we decide which choice is better?</p>
<p style="text-align: center">.</p>
<p style="text-align: center">.</p>
<p style="text-align: center">.</p>
<p>Think about it this way. When we attach a single node into a subtree, the bitstring representation for its symbol is being "lengthened" by one bit. In a sense, it's like the total bit length of the final encoded message is being increased by the weight of the node.</p>
<p>When we attach a subtree into a bigger subtree, the same thing happens to <i>all</i> the leaf nodes in the subtree. <i>All</i> of their bitstrings are growing by one bit, so the final encoded message size is growing by the sum of their weights.</p>
<p>That was a giveaway if there ever was one. So answer now, how should we weight subtrees which contain multiple leaf nodes?</p>
<p class="reveal hidden">The weight of a subtree should be the sum of the weights of all its leaves combined.</p>
<p>And then what is our algorithm for building an optimal tree bottom-up?</p>
<p class="reveal hidden">Start with a worklist of the bare leaf nodes. Keep picking the two lowest-weighted subtrees (where a single node also counts as a "subtree") from the worklist, join them together, then put the new, bigger subtree back into the worklist. When only one item remains in the worklist, it is the completed tree.</p>
<p>Type some text in the entry field below, and I'll animate the process for you:</p>
<div style='border: 1px solid #aaa; --ratio:80/35'>
<div id='animation3' class='tree'></div>
</div>
<input type='text' id='text0' style='width: 100%; box-sizing: border-box;' maxlength=32 value='Type something here' />
<script type='module'>
import { animateOptimalTree } from '/assets/js/treeviz.js';
function showOptimalCode() {
const text = $id('text0').value;
const sym = symbols(histogram(text));
sym.sort((a, b) => a.weight - b.weight);
const animation3 = animateOptimalTree(sym, { width: 800, height: 350, nodeRadius: 12, nodeSpacing: 30, showWeight: true });
$id('animation3').innerHTML = '';
$id('animation3').appendChild(animation3);
}
$id('text0').addEventListener('input', showOptimalCode);
showOptimalCode();
</script>
<p>Yes, that tree represents an <b>optimal</b> prefix-free code!</p>
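<p>Here is one way to double-check the reasoning which led us here. Every join lengthens the final message by the weight of the newly joined subtree, so adding up the weights of all the internal nodes should give the total encoded length. The sketch below assumes the same node shape as the demos on this page (internal nodes hold a two-element array in <code>value</code>); the function name <code>totalJoinCost</code> and the sample tree are made up for illustration:</p>

```javascript
/* Add up the weights of all internal nodes. Each internal node is the result
 * of one join, and each join adds its subtree's weight to the message length. */
function totalJoinCost(tree) {
  if (!Array.isArray(tree.value))
    return 0; // a leaf is not a join
  return tree.weight + totalJoinCost(tree.value[0]) + totalJoinCost(tree.value[1]);
}

// A hand-built tree for the made-up input "aaaabbc"
const joinedTree = { weight: 7, value: [
  { value: 'a', weight: 4 },
  { weight: 3, value: [{ value: 'b', weight: 2 }, { value: 'c', weight: 1 }] }
] };

// Joins: (b + c) costs 3, then (a + bc) costs 7
console.log(totalJoinCost(joinedTree)); // 10
```

<p>For this tree, weight times depth summed over the leaves is 4×1 + 2×2 + 1×2 = 10, the same number.</p>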
<hr>
<p>That can't be hard to code, can it? (It's not.) One thing, though: Since the symbol set might be large, we need a data structure which allows quick retrieval of the two lowest-weighted subtrees at each step. A <a href="https://en.wikipedia.org/wiki/Binary_heap">minheap</a> fits the bill perfectly. Here's a minimal implementation using JavaScript Arrays:</p>
<figure class="highlight"><pre><code class="language-javascript" data-lang="javascript"><span class="kd">class</span> <span class="nx">Minheap</span> <span class="p">{</span>
<span class="cm">/* `comparator` must return true if first argument is 'larger' than second */</span>
<span class="kd">constructor</span><span class="p">(</span><span class="nx">comparator</span><span class="p">)</span> <span class="p">{</span>
<span class="k">this</span><span class="p">.</span><span class="nx">heap</span> <span class="o">=</span> <span class="p">[];</span>
<span class="k">this</span><span class="p">.</span><span class="nx">compare</span> <span class="o">=</span> <span class="nx">comparator</span><span class="p">;</span>
<span class="p">}</span>
<span class="kd">get</span> <span class="nx">length</span><span class="p">()</span> <span class="p">{</span>
<span class="k">return</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">.</span><span class="nx">length</span><span class="p">;</span>
<span class="p">}</span>
<span class="nx">insert</span><span class="p">(</span><span class="nx">item</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">let</span> <span class="nx">index</span> <span class="o">=</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">.</span><span class="nx">length</span><span class="p">;</span>
<span class="k">while</span> <span class="p">(</span><span class="nx">index</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">parentIndex</span> <span class="o">=</span> <span class="p">((</span><span class="nx">index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">>>></span> <span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
<span class="k">if</span> <span class="p">(</span><span class="k">this</span><span class="p">.</span><span class="nx">compare</span><span class="p">(</span><span class="nx">item</span><span class="p">,</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">parentIndex</span><span class="p">]))</span>
<span class="k">break</span><span class="p">;</span>
<span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">index</span><span class="p">]</span> <span class="o">=</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">parentIndex</span><span class="p">];</span>
<span class="nx">index</span> <span class="o">=</span> <span class="nx">parentIndex</span><span class="p">;</span>
<span class="p">}</span>
<span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">index</span><span class="p">]</span> <span class="o">=</span> <span class="nx">item</span><span class="p">;</span>
<span class="p">}</span>
<span class="cm">/* Remove and return the smallest item in the heap */</span>
<span class="nx">pop</span><span class="p">()</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">result</span> <span class="o">=</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nx">item</span> <span class="o">=</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">.</span><span class="nx">pop</span><span class="p">();</span>
<span class="cm">/* If the heap is not empty, move items upward to restore the heap property,
* until we find an appropriate place to put `item` */</span>
<span class="k">if</span> <span class="p">(</span><span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">.</span><span class="nx">length</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">let</span> <span class="nx">index</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="k">while</span> <span class="p">(</span><span class="kc">true</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">leftIndex</span> <span class="o">=</span> <span class="p">(</span><span class="nx">index</span> <span class="o"><<</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="nx">rightIndex</span> <span class="o">=</span> <span class="nx">leftIndex</span> <span class="o">+</span> <span class="mi">1</span><span class="p">;</span>
<span class="kd">let</span> <span class="nx">childIndex</span> <span class="o">=</span> <span class="nx">leftIndex</span><span class="p">;</span>
<span class="k">if</span> <span class="p">(</span><span class="nx">rightIndex</span> <span class="o"><</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">.</span><span class="nx">length</span><span class="p">)</span> <span class="p">{</span>
<span class="k">if</span> <span class="p">(</span><span class="k">this</span><span class="p">.</span><span class="nx">compare</span><span class="p">(</span><span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">leftIndex</span><span class="p">],</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">rightIndex</span><span class="p">]))</span>
<span class="nx">childIndex</span> <span class="o">=</span> <span class="nx">rightIndex</span><span class="p">;</span>
<span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="nx">leftIndex</span> <span class="o">>=</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">.</span><span class="nx">length</span><span class="p">)</span> <span class="p">{</span>
<span class="k">break</span><span class="p">;</span>
<span class="p">}</span>
<span class="k">if</span> <span class="p">(</span><span class="k">this</span><span class="p">.</span><span class="nx">compare</span><span class="p">(</span><span class="nx">item</span><span class="p">,</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">childIndex</span><span class="p">]))</span> <span class="p">{</span>
<span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">index</span><span class="p">]</span> <span class="o">=</span> <span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">childIndex</span><span class="p">];</span>
<span class="nx">index</span> <span class="o">=</span> <span class="nx">childIndex</span><span class="p">;</span>
<span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
<span class="k">break</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">}</span>
<span class="k">this</span><span class="p">.</span><span class="nx">heap</span><span class="p">[</span><span class="nx">index</span><span class="p">]</span> <span class="o">=</span> <span class="nx">item</span><span class="p">;</span>
<span class="p">}</span>
<span class="k">return</span> <span class="nx">result</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">}</span></code></pre></figure>
<p>It would be fun to animate the minheap operations and show you how they work, but that would have to be a different article.</p>
<p>The rest of the code to build Huffman trees is almost anticlimactic:</p>
<figure class="highlight"><pre><code class="language-javascript" data-lang="javascript"><span class="cm">/* Count how many times each character appears in a string */</span>
<span class="kd">function</span> <span class="nx">histogram</span><span class="p">(</span><span class="nx">string</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">histogram</span> <span class="o">=</span> <span class="k">new</span> <span class="nb">Map</span><span class="p">();</span>
<span class="k">for</span> <span class="p">(</span><span class="kd">const</span> <span class="nx">char</span> <span class="k">of</span> <span class="nx">string</span><span class="p">)</span>
<span class="nx">histogram</span><span class="p">.</span><span class="kd">set</span><span class="p">(</span><span class="nx">char</span><span class="p">,</span> <span class="p">(</span><span class="nx">histogram</span><span class="p">.</span><span class="kd">get</span><span class="p">(</span><span class="nx">char</span><span class="p">)</span> <span class="o">||</span> <span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">);</span>
<span class="k">return</span> <span class="nx">histogram</span><span class="p">;</span>
<span class="p">}</span>
<span class="kd">function</span> <span class="nx">symbols</span><span class="p">(</span><span class="nx">histogram</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="nb">Array</span><span class="p">.</span><span class="k">from</span><span class="p">(</span><span class="nx">histogram</span><span class="p">).</span><span class="nx">map</span><span class="p">(([</span><span class="nx">char</span><span class="p">,</span> <span class="nx">count</span><span class="p">])</span> <span class="o">=></span> <span class="p">({</span> <span class="na">value</span><span class="p">:</span> <span class="nx">char</span><span class="p">,</span> <span class="na">weight</span><span class="p">:</span> <span class="nx">count</span> <span class="p">}));</span>
<span class="p">}</span>
<span class="kd">function</span> <span class="nx">huffmanTree</span><span class="p">(</span><span class="nx">symbols</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">heap</span> <span class="o">=</span> <span class="k">new</span> <span class="nx">Minheap</span><span class="p">((</span><span class="nx">a</span><span class="p">,</span><span class="nx">b</span><span class="p">)</span> <span class="o">=></span> <span class="nx">a</span><span class="p">.</span><span class="nx">weight</span> <span class="o">></span> <span class="nx">b</span><span class="p">.</span><span class="nx">weight</span><span class="p">);</span>
<span class="k">for</span> <span class="p">(</span><span class="kd">const</span> <span class="nx">symbol</span> <span class="k">of</span> <span class="nx">symbols</span><span class="p">)</span>
<span class="nx">heap</span><span class="p">.</span><span class="nx">insert</span><span class="p">(</span><span class="nx">symbol</span><span class="p">);</span>
<span class="k">while</span> <span class="p">(</span><span class="nx">heap</span><span class="p">.</span><span class="nx">length</span> <span class="o">></span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">const</span> <span class="nx">a</span> <span class="o">=</span> <span class="nx">heap</span><span class="p">.</span><span class="nx">pop</span><span class="p">(),</span> <span class="nx">b</span> <span class="o">=</span> <span class="nx">heap</span><span class="p">.</span><span class="nx">pop</span><span class="p">();</span>
<span class="nx">heap</span><span class="p">.</span><span class="nx">insert</span><span class="p">({</span> <span class="na">value</span><span class="p">:</span> <span class="p">[</span><span class="nx">a</span><span class="p">,</span> <span class="nx">b</span><span class="p">],</span> <span class="na">weight</span><span class="p">:</span> <span class="nx">a</span><span class="p">.</span><span class="nx">weight</span> <span class="o">+</span> <span class="nx">b</span><span class="p">.</span><span class="nx">weight</span> <span class="p">});</span>
<span class="p">}</span>
<span class="k">return</span> <span class="nx">heap</span><span class="p">.</span><span class="nx">pop</span><span class="p">();</span>
<span class="p">}</span></code></pre></figure>
<h2>Modifying the Basic Algorithm for JPEG</h2>
<p>The Huffman codes generated above have <b>two</b> important differences from those used to compress pixel data in JPEG files.</p>
<p><b>Difference #1:</b> JPEG Huffman tables never use bitstrings which are composed of only 1's. "111" is out. "1111" is forbidden. And you can just forget about "111111".</p>
<p><i>BUT WHY?</i> Because while sections of Huffman-coded data in a JPEG file must always occupy a whole number of 8-bit bytes, all those variable-length bitstrings will not necessarily add up to a multiple of 8 bits. If there are some extra bits left to fill in the last byte, "1" bits are used as padding. If bitstrings composed of only 1's were used, the padding in the last byte could be mistakenly decoded as an extraneous trailing symbol. By avoiding such bitstrings, it is always possible to recognize the padding.</p>
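<p>To make the padding rule concrete, here is a minimal sketch of packing bitstrings into bytes. The name <code>packBits</code> is made up for illustration, and the sketch ignores every other detail of a real JPEG bitstream; it only shows the final byte being filled out with "1" bits.</p>

```javascript
/* Hypothetical helper: concatenate bitstrings and pack them into bytes,
 * padding the final byte with 1 bits (as described above) */
function packBits(bitstrings) {
  let bits = bitstrings.join('');
  const remainder = bits.length % 8;
  if (remainder !== 0)
    bits = bits.padEnd(bits.length + (8 - remainder), '1');

  const bytes = [];
  for (let i = 0; i < bits.length; i += 8)
    bytes.push(parseInt(bits.slice(i, i + 8), 2));
  return bytes;
}

/* "0" + "10" + "110" is only 6 bits, so two 1 bits are appended: 01011011 */
console.log(packBits(['0', '10', '110'])); // [ 91 ]
```

<p>If "11" were a valid bitstring in this code, a decoder could not tell whether those last two bits were padding or one more symbol.</p>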
<p>How can we modify our algorithm to account for that? Can you think of an idea?</p>
<div class="reveal hidden"><p>Here's one solution: Include a "dummy" symbol when building the tree. The dummy symbol should have the lowest possible weight, so it does not uselessly occupy a desirable position in the tree (meaning a short bitstring). When subtrees are joined, we must make sure any subtree containing the dummy is always used as the <i>right</i> child. That guarantees that the dummy will end up at the far right of the tree, and will have a bitstring consisting of all 1's.</p>
<p>Since the generated code is prefix-free, we can be sure no other node will get a bitstring with only 1's.</p>
<p>After the entire tree is completed, we can delete the dummy node.</p></div>
<p>That just takes a few more lines of code:</p>
<figure class="highlight"><pre><code class="language-diff" data-lang="diff"><span class="p">@@ -7,7 +7,9 @@</span>
}
function symbols(histogram) {
<span class="gd">- return Array.from(histogram).map(([char, count]) => ({ value: char, weight: count }));
</span><span class="gi">+ const sym = Array.from(histogram).map(([char, count]) => ({ value: char, weight: count }));
+ sym.push({ value: "🃏", weight: 0, dummy: true });
+ return sym;
</span> }
function huffmanTree(symbols) {
<span class="p">@@ -16,8 +18,13 @@</span>
heap.insert(symbol);
while (heap.length > 1) {
<span class="gd">- const a = heap.pop(), b = heap.pop();
- heap.insert({ value: [a, b], weight: a.weight + b.weight });
</span><span class="gi">+ let a = heap.pop(), b = heap.pop();
+ if (a.dummy) {
+ /* Dummy must always be on the right-hand side */
+ let temp = a; a = b; b = temp;
+ }
+ const parent = { value: [a, b], weight: a.weight + b.weight, dummy: a.dummy || b.dummy };
+ heap.insert(parent);
</span> }
return heap.pop();</code></pre></figure>
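<p>The diff above builds the tree but never shows the dummy being dropped. Here is one way that last step might look, as a sketch: walk the finished tree, record a bitstring for each real leaf, and skip the dummy. The name <code>codeTable</code> is hypothetical, and I am assuming that a left child means "0" and a right child means "1".</p>

```javascript
/* Hypothetical helper: convert a tree built by `huffmanTree` into a Map from
 * symbol to bitstring, leaving out the dummy leaf.
 * Assumes internal nodes hold `[left, right]` in `value`; left = 0, right = 1 */
function codeTable(tree) {
  const table = new Map();
  (function walk(node, bits) {
    if (Array.isArray(node.value)) {
      walk(node.value[0], bits + '0');
      walk(node.value[1], bits + '1');
    } else if (!node.dummy) {
      table.set(node.value, bits);
    }
  })(tree, '');
  return table;
}

/* A small hand-built tree in the same shape that `huffmanTree` produces */
const joker = { value: '🃏', weight: 0, dummy: true };
const right = { value: [{ value: 'b', weight: 1 }, joker], weight: 1, dummy: true };
const root = { value: [{ value: 'a', weight: 3 }, right], weight: 4, dummy: true };

console.log(codeTable(root)); // Map(2) { 'a' => '0', 'b' => '10' }
```

<p>In this example the dummy would have received "11", which is exactly the all-1's bitstring we wanted to keep out of the table.</p>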
<p>This is optimal tree construction with a dummy node:</p>
<div id='animation4' class='tree' style='border: 1px solid #aaa; --ratio:80/35'></div>
<script type='module'>
import { animateDummyTree } from '/assets/js/treeviz.js';
(function() {
/* Count how many times each character appears in a string */
function histogram(string) {
const histogram = new Map();
for (const char of string)
histogram.set(char, (histogram.get(char) || 0) + 1);
return histogram;
}
function symbols(histogram) {
const sym = Array.from(histogram).map(([char, count]) => ({ value: char, weight: count }));
sym.push({ value: "🃏", weight: 0, dummy: true });
return sym;
}
function huffmanTree(symbols) {
const heap = new Minheap((a,b) => a.weight > b.weight);
for (const symbol of symbols)
heap.insert(symbol);
while (heap.length > 1) {
let a = heap.pop(), b = heap.pop();
if (a.dummy) {
/* Dummy must always be on the right-hand side */
let temp = a; a = b; b = temp;
}
const parent = { value: [a, b], weight: a.weight + b.weight, dummy: a.dummy || b.dummy };
heap.insert(parent);
}
return heap.pop();
}
$id('animation4').appendChild(animateDummyTree(symbols(histogram(sentence)), { width: 800, height: 350, nodeRadius: 11, nodeSpacing: 28, showWeight: true }));
})();
</script>
<p><b>Difference #2:</b> JPEG Huffman codes are always <b>canonical</b>.</p>
<p>In a <a href='https://en.wikipedia.org/wiki/Canonical_Huffman_code'>canonical Huffman code</a>, when the bitstrings are read as binary numbers, shorter bitstrings are always smaller numbers. For example, such a code could not use both "000" and "10", since the former bitstring is longer, but is a smaller binary number. Further, when all the bitstrings used in the code are sorted by their numeric value, each successive bitstring increments by the smallest amount possible while remaining prefix-free. Here's an example, courtesy of Wikipedia:</p>
<table>
<tr><td>0</td></tr>
<tr><td>10</td></tr>
<tr><td>110</td></tr>
<tr><td>111</td></tr>
</table>
<p>Interpreted as numbers, those are zero, two, six, and seven. Why wasn't the second bitstring "01", or one? Because then the first would have been its prefix. Likewise, if the third was "011" (three), "100" (four), or "101" (five), in each case one of the first two would have been a prefix. For the fourth, incrementing by one to "111" didn't create a prefix, so "111" it is. (Hopefully that example gives you the idea; hit me up if you need more!)</p>
<p><i>But WHY does JPEG use canonical codes?</i> Because their coding tables can be represented in a very compact way<sup><a href='#footnote1' id='fnref1'>[1]</a></sup>, which makes our JPEG files smaller and faster to decode. (Yes, JPEG files must contain not just Huffman-encoded pixel data but also the coding tables which were used.)</p>
<p>So given a symbol set and frequencies, how can we generate a canonical Huffman code? Unfortunately, there is no straightforward way to do it directly by building a binary tree. But we can use our existing method to generate a non-canonical (but optimal) code, and then <i>rewrite the bitstrings to make them canonical</i> while maintaining their length. Remember, it's the length of the bitstrings assigned to each symbol which makes a prefix-free code optimal. The exact bitstrings which are used don't matter; we can shuffle them around and assign different ones with the same length.</p>
<p>The algorithm suggested in the JPEG specification (Appendix K) gets a step ahead of the game by not explicitly building a binary tree with left and right child pointers. It just tracks what the depth of each leaf node <i>would have been</i> had they actually been built into a binary tree. So that these depths can be incremented whenever two "subtrees" are "joined together", the leaf nodes of each subtree are kept on a linked list. "Subtrees" are "joined" by concatenating their linked lists. (Libjpeg uses this trick when saving a Huffman-encoded JPEG file.<sup><a href='#footnote2' id='fnref2'>[2]</a></sup>)</p>
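<p>Here is a rough sketch of that idea. It is written in the spirit of the trick described above and is not a transcription of the specification or of libjpeg; all of the names (<code>codeLengths</code>, <code>depth</code>, <code>next</code>, <code>isHead</code>) are my own, it assumes every weight is greater than zero, and it leaves out everything else the specification does. Like libjpeg, it finds the two lowest-weighted "subtrees" with a plain linear scan instead of a minheap.</p>

```javascript
/* Sketch: compute the depth of each leaf in an optimal tree, without
 * building the tree. `weights[i]` is the frequency of symbol i (assumed > 0) */
function codeLengths(weights) {
  const n = weights.length;
  const weight = weights.slice();           /* Total weight of the list headed by i */
  const depth = new Array(n).fill(0);       /* Depth leaf i would have in the tree */
  const next = new Array(n).fill(-1);       /* Next leaf on the same linked list */
  const isHead = new Array(n).fill(true);   /* Is i the head of a list? */

  for (let joins = 0; joins < n - 1; joins++) {
    /* Scan for the two lowest-weighted list heads: `a`, then `b` */
    let a = -1, b = -1;
    for (let i = 0; i < n; i++) {
      if (!isHead[i]) continue;
      if (a === -1 || weight[i] < weight[a]) { b = a; a = i; }
      else if (b === -1 || weight[i] < weight[b]) { b = i; }
    }

    /* "Join" them: every leaf on both lists gets one level deeper,
     * and list `b` is concatenated onto the end of list `a` */
    weight[a] += weight[b];
    isHead[b] = false;
    let tail = a;
    for (let i = a; i !== -1; i = next[i]) { depth[i]++; tail = i; }
    next[tail] = b;
    for (let i = b; i !== -1; i = next[i]) depth[i]++;
  }
  return depth;
}

console.log(codeLengths([5, 1, 1, 2])); // [ 1, 3, 3, 2 ]
```

<p>The returned depths are the bitstring lengths; the common symbol gets 1 bit and the two rare ones get 3 bits each.</p>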
<p>Regardless of whether you actually build a binary tree or use the trick from Appendix K, once you know what the lengths of all the bitstrings in an optimal code should be, generating a canonical code is as simple as this:</p>
<figure class="highlight"><pre><code class="language-javascript" data-lang="javascript"><span class="cm">/* `lengths` is a sorted array of bitstring lengths required for an optimal code
*
* In real applications, an array of counts would likely be passed: how many
* bitstrings must have 1 bit, how many 2 bits, how many 3 bits, and so on
*
* Also, in real applications, the returned values would almost certainly
* not be strings; integers would be more likely */</span>
<span class="kd">function</span> <span class="nx">makeCanonical</span><span class="p">(</span><span class="nx">lengths</span><span class="p">)</span> <span class="p">{</span>
<span class="kd">let</span> <span class="nx">result</span> <span class="o">=</span> <span class="p">[],</span> <span class="nx">nextCode</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="k">for</span> <span class="p">(</span><span class="kd">var</span> <span class="nx">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="nx">i</span> <span class="o"><</span> <span class="nx">lengths</span><span class="p">.</span><span class="nx">length</span><span class="p">;</span> <span class="nx">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
<span class="k">if</span> <span class="p">(</span><span class="nx">i</span> <span class="o">></span> <span class="mi">0</span> <span class="o">&&</span> <span class="nx">lengths</span><span class="p">[</span><span class="nx">i</span><span class="p">]</span> <span class="o">!==</span> <span class="nx">lengths</span><span class="p">[</span><span class="nx">i</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
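<span class="nx">nextCode</span> <span class="o"><<=</span> <span class="nx">lengths</span><span class="p">[</span><span class="nx">i</span><span class="p">]</span> <span class="o">-</span> <span class="nx">lengths</span><span class="p">[</span><span class="nx">i</span><span class="o">-</span><span class="mi">1</span><span class="p">];</span> <span class="cm">/* Lengths may jump by more than 1 */</span>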
<span class="nx">result</span><span class="p">.</span><span class="nx">push</span><span class="p">(</span><span class="nx">nextCode</span><span class="p">.</span><span class="nx">toString</span><span class="p">(</span><span class="mi">2</span><span class="p">).</span><span class="nx">padStart</span><span class="p">(</span><span class="nx">lengths</span><span class="p">[</span><span class="nx">i</span><span class="p">],</span> <span class="dl">'</span><span class="s1">0</span><span class="dl">'</span><span class="p">));</span>
<span class="nx">nextCode</span><span class="o">++</span><span class="p">;</span>
<span class="p">}</span>
<span class="k">return</span> <span class="nx">result</span><span class="p">;</span>
<span class="p">}</span></code></pre></figure>
<p>Here is an example. Note that we are not using a dummy, so bitstrings with all 1 bits may be included.</p>
<table id='code2'>
<thead>
<tr><th>Random Code</th><th>Sorted by Bitstring Length</th><th>Canonicalized</th></tr>
</thead>
<tbody></tbody>
</table>
<a id='genrand4' class='button'>Generate Random Code</a>
<script>
'use strict';
/* `lengths` is a sorted array of bitstring lengths required for an optimal code
*
* In real applications, an array of counts would likely be passed: how many
* bitstrings must have 1 bit, how many 2 bits, how many 3 bits, and so on
*
* Also, in real applications, the returned values would almost certainly
* not be strings; integers would be more likely */
function makeCanonical(lengths) {
let result = [], nextCode = 0;
for (var i = 0; i < lengths.length; i++) {
if (i > 0 && lengths[i] !== lengths[i-1])
nextCode <<= lengths[i] - lengths[i-1]; /* Lengths may jump by more than 1 */
result.push(nextCode.toString(2).padStart(lengths[i], '0'));
nextCode++;
}
return result;
}
function showCanonicalCodeTable(table) {
const code = randomPrefixFreeCode(sentence);
const col1 = Array.from(code.symbols).sort(alphabeticOrder).map((sym) => sym.value);
const bitlength = (sym) => code.dictionary.get(sym.value).length;
const col2 = Array.from(code.symbols).sort(comparator(bitlength)).map((sym) => sym.value);
const col3 = makeCanonical(code.symbols.map(bitlength).sort((a, b) => a - b)); /* Numeric sort; the default sort compares as strings */
const display = function(str) { return str === ' ' ? '_' : str; };
const tbody = table.tBodies[0];
tbody.innerHTML = '';
for (var i = 0; i < col1.length; i++) {
addTableRow(tbody, [
`${display(col1[i])} ${code.dictionary.get(col1[i])}`,
`${display(col2[i])} ${code.dictionary.get(col2[i])}`,
`${display(col2[i])} ${col3[i]}`
]);
}
}
$id('genrand4').addEventListener('click', () => showCanonicalCodeTable($id('code2')));
showCanonicalCodeTable($id('code2'));
</script>
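<p>The comment on <code>makeCanonical</code> mentions that a real application would more likely take an array of counts and return integers. Here is a minimal sketch of what that variant could look like. The name <code>canonicalFromCounts</code> and the shape of the returned objects are assumptions of mine, chosen only for illustration.</p>

```javascript
/* Sketch: `counts[i]` is how many bitstrings must be (i + 1) bits long.
 * Returns one { length, code } pair per bitstring, in canonical order,
 * where `code` is the bitstring read as an integer */
function canonicalFromCounts(counts) {
  const codes = [];
  let nextCode = 0;
  for (let i = 0; i < counts.length; i++) {
    const length = i + 1;
    for (let n = 0; n < counts[i]; n++) {
      codes.push({ length, code: nextCode });
      nextCode++;
    }
    /* Moving on to bitstrings which are one bit longer */
    nextCode <<= 1;
  }
  return codes;
}

/* One 1-bit code, no 2-bit codes, four 3-bit codes */
const codes = canonicalFromCounts([1, 0, 4]);
console.log(codes.map((c) => c.code.toString(2).padStart(c.length, '0')));
// [ '0', '100', '101', '110', '111' ]
```

<p>Since the shift happens once per length, even for lengths which are not used, a jump from 1-bit to 3-bit codes is handled without any special case.</p>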
<h2>Huffman Coding in Practice</h2>
<p>All through this article, ASCII characters have been used as Huffman symbols. But in reality, if you want to compress English text, Huffman coding with each character treated as a separate symbol would be a terrible way to do it. Note two big weaknesses with that approach:</p>
<ul>
<li>Huffman coding is oblivious to patterns which involve the <i>order</i> of symbols. It only cares about their frequency. But real-life data usually has patterns related to the order of values, which can be exploited to achieve better compression.</li>
<li>Huffman coding always uses at least one bit for each symbol, and usually much more. So in the "ideal" case of a text file which just contains a single ASCII character repeated thousands of times, Huffman coding with one symbol per letter could only compress it to ⅛ of its original size. 8× compression may sound good, but any reasonable compression method should get far greater gains in that ridiculously easy-to-compress case.</li>
</ul>
<p>So just what am I saying here? Is Huffman coding a bad algorithm?</p>
<p>Not at all! But it is just one piece of a practical compression method; it's not a complete compression method by itself. And to make Huffman coding work to greatest advantage, it may be necessary to find an alternative data representation which is well-suited to such coding. Just taking the most "natural" or intuitive representation and directly applying Huffman coding to it will probably not work well.</p>
<p>As an example, in JPEG, the values which we want to compress are quantized DCT coefficients <a href='/visualizing-the-idct/'>(see the previous post for details)</a>, which have 8 bits of precision each.<sup><a href='#footnote3' id='fnref3'>[3]</a></sup> We could take the 256 possible coefficient values as 256 Huffman symbols and Huffman-code them directly, but this would be very suboptimal.</p>
<p>In the symbol set which is actually used, each symbol represents either:</p>
<ul>
<li>Some specific number of successive zero coefficients (0-15 of them), <i>and</i> the number of significant bits in the following non-zero coefficient.</li>
<li>A run of zeroes filling the remainder of a 64-coefficient block.</li>
</ul>
<p>Note that each symbol only tells us the <i>number</i> of significant bits in the next non-zero coefficient, not what those bits actually are. The actual coefficient value bits are simply inserted into the output data stream uncompressed. This is because the values of non-zero DCT coefficients don't actually repeat very much, so Huffman-coding them wouldn't really help. (See <a href='/visualizing-the-idct/#idctdemo'>the demonstration in the previous post</a>. Does it look like the coefficients within an 8-by-8 DCT matrix repeat much?) However, since the Huffman symbols tell us the number of significant bits, high-order zero bits can be discarded, which does help significantly.</p>
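<p>As a rough illustration of that symbol set, the sketch below converts a list of coefficients into (run of zeroes, number of significant bits) pairs. The name <code>runSizeSymbols</code> and the object shapes are hypothetical. Since a symbol can only express a run of 0-15 zeroes, I am assuming that longer runs are split up using a symbol with a run of 15 and a size of 0, which stands for 16 zeroes in a row. The coefficient value bits themselves, which would be written out uncompressed, are not shown.</p>

```javascript
/* Sketch: convert coefficients to { run, size } symbols, where `run` is the
 * number of zeroes skipped and `size` is the number of significant bits in
 * the non-zero coefficient which follows. Trailing zeroes become one
 * { endOfBlock: true } symbol */
function runSizeSymbols(coefficients) {
  const symbols = [];
  let run = 0;
  for (const c of coefficients) {
    if (c === 0) {
      run++;
      continue;
    }
    /* Assumption: { run: 15, size: 0 } stands for 16 zeroes in a row */
    while (run > 15) {
      symbols.push({ run: 15, size: 0 });
      run -= 16;
    }
    const size = Math.abs(c).toString(2).length;
    symbols.push({ run, size });
    run = 0;
  }
  if (run > 0)
    symbols.push({ endOfBlock: true });
  return symbols;
}

console.log(runSizeSymbols([5, 0, 0, -3, 0, 0, 0]));
// [ { run: 0, size: 3 }, { run: 2, size: 2 }, { endOfBlock: true } ]
```

<p>Seven coefficients turn into just three Huffman symbols, and it is these symbols (not the coefficients) which get counted and assigned bitstrings.</p>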
<p>JPEG files can use "arithmetic coding" as an alternative to Huffman coding (although this is not common). I dare say arithmetic coding is a more intriguing and fascinating algorithm than Huffman coding. So it will not surprise you that the next article in this series will focus on arithmetic coding. See you then!</p>
<p id='footnote1' style='font-size: 0.8em; margin-top: 2.5em'>[1] With a canonical code, only the number of bitstrings used of each possible length needs to be stored; how many are 1 bit long, how many 2 bits long, how many 3 bits long, and so on. The actual bitstrings can be quickly recovered from that. <a href='#fnref1'>⏎</a></p>
<p id='footnote2' style='font-size: 0.8em'>[2] But interestingly, libjpeg does <i>not</i> use a minheap when generating a Huffman code. Instead, it uses an array of symbol frequencies, and scans the whole array at each step to find the two lowest-weighted subtrees. <a href='#fnref2'>⏎</a></p>
<p id='footnote3' style='font-size: 0.8em'>[3] The JPEG standard actually allows DCT coefficients to be either 8-bit or 12-bit, but 8 bits is almost universally used. Libjpeg can theoretically handle JPEG files with 12-bit coefficients, but it must be specially configured to do so at compile time, and binary distributions are not generally built in that way. <a href='#fnref3'>⏎</a></p>
<script>
'use strict';
function revealText(event) {
event.preventDefault();
this.classList.remove('hidden');
this.removeEventListener('click', revealText);
}
document.querySelectorAll('.reveal').forEach((el) => el.addEventListener('click', revealText));
</script>The previous article in this series explored how JPEG compression converts pixel values to DCT coefficients. A later stage of the compression process uses either a method called "Huffman coding" or another called "arithmetic coding" to store those coefficients in a compact manner. The Huffman coding algorithm is very simple, but powerful and widely used. If you've never learned how it works, I promise this will be interesting. You have some data to compress. Your data can be viewed as a sequence of values; perhaps Unicode codepoints, pixel color samples, audio amplitude samples, or something comparable. In the context of Huffman coding, each of these values is called a "symbol". We are going to encode each symbol using a unique, variable-length string of bits. For example, if each symbol is a letter, the letter "a" could be "10111", "b" could be "00011", and so on. We can pick any series of bits we like to represent each symbol, with the restriction that after all the symbols are converted to bitstrings, and all those little bitstrings are smashed together, it must be possible to figure out what the original sequence of symbols was. That means the bitstrings for each symbol must be prefix-free; none of them can be a prefix of another. Example of a non-prefix-free code: say we choose "1100" to represent the letter "a", "11" to represent "b", and "00" for "c". When decoding text, we find "1100" somewhere. How are we supposed to know whether that was originally a letter "a", or the two letters "bc"? It's impossible to tell, precisely because "11" is a prefix of "1100". Fortunately, Huffman codes are always prefix-free. Let's encode this sentence with such a code. Click the button below, and the computer will generate a random prefix-free code. Try clicking a number of times and see what the smallest total number of bits required to encode the sentence appears to be. 
Generate Random Code Obviously, the number of bits required to encode a sequence can vary wildly depending on the chosen code. Can you see why some prefix-free codes are more efficient than others? What is the key difference between an efficient code and an inefficient one? (Click to reveal.) Efficient codes use the shortest bitstrings for the most common symbols. At the same time, since the code must be prefix-free, if we use a very short bitstring such as "0" for a common symbol, that means "00", "01", "000", "001", and so on will all become unavailable. So we need to strike a balance. While we want to use short bitstrings for common symbols, we don't want to be forced to use excessively long bitstrings for less common symbols. Out of the vast number of prefix-free codes which could be used, we want to find an optimal one; one which will encode our particular data in the smallest number of bits. (There will always be many optimal codes for any sequence, but we just need to find one of them.) At first, it might appear that we need to try millions of possible codes to be sure that we have an optimal one. Fortunately, that is not the case. Just count how many times each symbol appears in the input data, and in an almost trivially simple way, you can find an optimal prefix-free code. It will be easy, and fun! I could tell you the algorithm right now, but it will be so much more enjoyable to discover it yourself. So I'll take this slow, and reason towards a solution step by deliberate step. If at any point you catch the scent of the solution, stop and think it out before continuing. The first step is to represent a prefix-free code as a binary tree. Have a look: Generate Random Code Please make sure that you clearly see the correspondence between coding tables and binary trees before continuing. We can see that the number of bits used to encode each symbol equals the number of links between its tree node and the root node, also called the "depth" of the node. 
Leaf nodes which are closer to the root (smaller depth) have shorter bitstrings. We will add a weight to each leaf node, which is simply the number of times its symbol appears in the input data: Generate Random Code Now the total length in bits of the compressed output will equal weight times depth, summed over all the leaf nodes. So now our goal is to find a binary tree which minimizes the sum of weight times depth. We don't really have an idea how to do that, though. At least we do know what the leaf nodes of the tree should be: How are we going to find the right structure for the internal nodes? Well, we could try to do it top-down, meaning we figure out what child nodes the root node should have, then the nodes below those, and so on. Or we could work bottom-up, meaning we figure out which leaf nodes should become children of the same parent node, then find which parent nodes should be children of the same "grandparent" node, until the whole tree is joined together. A third option would be to work both up and down from the middle, but that is just as hopeless as it sounds. These animations may help you understand "top-down" and "bottom-up" tree construction: Bottom-up Top-down Generate Random Examples To build an optimal tree top-down, we would need a way to partition the symbols into two subsets, such that the total weights of each subset are as close as possible to 50%-50%. That might be tricky. On the other hand, if we can come up a simple criterion to identify two leaf nodes which should be siblings in the tree, we might be able to apply the same criterion repeatedly to build an optimal tree bottom-up. That sounds more promising. Before we consider that further, take note of an important fact. How many internal nodes, including the root, does it take to connect N leaf nodes together into a binary tree? Watch the above animations again and try to figure it out: N - 1. Good. 
Another one: When building a tree bottom-up, every time we pick two subtrees and join them together as children of a new internal node, what happens to the depth of all the leaves in the combined subtree? The depth of all the leaf nodes increases by 1. Remember that the depth of each leaf node equals the number of bits required to encode the corresponding symbol. So every time we join two subtrees, we are in a sense "lengthening" the bitstrings for all the symbols in the new subtree. Since we want the most common symbols to have the shortest bitstrings (equivalent: we want their nodes to be closest to the root), they should be the last ones to be joined into the tree. With that in mind, can you now see what the first step in building an optimal tree bottom-up should be? Join the two lowest-weighted leaf nodes together into a subtree. Yes! Just like this: Just another small conceptual leap, and the complete solution will be ours. Here's what we need to figure out: Just now, we took the two lowest-weighted leaf nodes and joined them together. But how should we "weight" the resulting subtree? How will we know when and where to join it into a bigger subtree? More concretely: for the second step, we could either take our new 3-node subtree, and use it as one child of a new 5-node subtree, or we could pick two of the remaining single nodes, and join them into another 3-node subtree. How do we decide which choice is better? . . . Think about it this way. When we attach a single node into a subtree, the bitstring representation for its symbol is being "lengthened" by one bit. In a sense, it's like the total bit length of the final encoded message is being increased by the weight of the node. When we attach a subtree into a bigger subtree, the same thing happens to all the leaf nodes in the subtree. All of their bitstrings are growing by one bit, so the final encoded message size is growing by the sum of their weights. That was a giveaway if there ever was one. 
So answer now, how should we weight subtrees which contain multiple leaf nodes? The weight of a subtree should be the sum of the weights of all its leaves combined. And then what is our algorithm for building an optimal tree bottom-up? Start with a worklist of the bare leaf nodes. Keep picking the two lowest-weighted subtrees (where a single node also counts as a "subtree") from the worklist, join them together, then put the new, bigger subtree back into the worklist. When only one item remains in the worklist, it is the completed tree. Type some text in the below entry field, and I'll animate the process for you: Yes, that tree represents an optimal prefix-free code! That can't be hard to code, can it? (It's not.) One thing, though: Since the symbol set might be large, we need a data structure which allows quick retrieval of the two lowest-weighted subtrees at each step. A minheap fits the bill perfectly. Here's an minimal implementation using JavaScript Arrays: class Minheap { /* `comparator` must return true if first argument is 'larger' than second */ constructor(comparator) { this.heap = []; this.compare = comparator; } get length() { return this.heap.length; } insert(item) { let index = this.heap.length; while (index > 0) { const parentIndex = ((index + 1) >>> 1) - 1; if (this.compare(item, this.heap[parentIndex])) break; this.heap[index] = this.heap[parentIndex]; index = parentIndex; } this.heap[index] = item; } /* Remove and return the smallest item in the heap */ pop() { const result = this.heap[0], item = this.heap.pop(); /* If the heap is not empty, move items upward to restore the heap property, * until we find an appropriate place to put `item` */ if (this.heap.length) { let index = 0; while (true) { const leftIndex = (index << 1) + 1, rightIndex = leftIndex + 1; let childIndex = leftIndex; if (rightIndex < this.heap.length) { if (this.compare(this.heap[leftIndex], this.heap[rightIndex])) childIndex = rightIndex; } else if (leftIndex >= 
this.heap.length) { break; } if (this.compare(item, this.heap[childIndex])) { this.heap[index] = this.heap[childIndex]; index = childIndex; } else { break; } } this.heap[index] = item; } return result; } } It would be fun to animate the minheap operations and show you how they work, but that would have to be a different article. The rest of the code to build Huffman trees is almost anticlimactic: /* Count how many times each character appears in a string */ function histogram(string) { const histogram = new Map(); for (const char of string) histogram.set(char, (histogram.get(char) || 0) + 1); return histogram; } function symbols(histogram) { return Array.from(histogram).map(([char, count]) => ({ value: char, weight: count })); } function huffmanTree(symbols) { const heap = new Minheap((a,b) => a.weight > b.weight); for (const symbol of symbols) heap.insert(symbol); while (heap.length > 1) { const a = heap.pop(), b = heap.pop(); heap.insert({ value: [a, b], weight: a.weight + b.weight }); } return heap.pop(); } Modifying the Basic Algorithm for JPEG The Huffman codes generated above have two important differences from those used to compress pixel data in JPEG files. Difference #1: JPEG Huffman tables never use bitstrings which are composed of only 1's. "111" is out. "1111" is forbidden. And you can just forget about "111111". BUT WHY? Because while sections of Huffman-coded data in a JPEG file must always occupy a whole number of 8-bit bytes, all those variable-length bitstrings will not necessarily add up to a multiple of 8 bits. If there are some extra bits left to fill in the last byte, "1" bits are used as padding. If bitstrings composed of only 1's were used, the padding in the last byte could be mistakenly decoded as an extraneous trailing symbol. By avoiding such bitstrings, it is always possible to recognize the padding. How can we modify our algorithm to account for that? Can you think of an idea? 
Here's one solution: Include a "dummy" symbol when building the tree. The dummy symbol should have the lowest possible weight, so it does not uselessly occupy a desirable position in the tree (meaning a short bitstring). When subtrees are joined, we must make sure any subtree containing the dummy is always used as the right child. That guarantees that the dummy will end up at the far right of the tree, and will have a bitstring consisting of all 1's. Since the generated code is prefix-free, we can be sure no other node will get a bitstring with only 1's. After the entire tree is completed, we can delete the dummy node. That just takes a few more lines of code: @@ -7,7 +7,9 @@ } function symbols(histogram) { - return Array.from(histogram).map(([char, count]) => ({ value: char, weight: count })); + const sym = Array.from(histogram).map(([char, count]) => ({ value: char, weight: count })); + sym.push({ value: "🃏", weight: 0, dummy: true }); + return sym; } function huffmanTree(symbols) { @@ -16,8 +18,13 @@ heap.insert(symbol); while (heap.length > 1) { - const a = heap.pop(), b = heap.pop(); - heap.insert({ value: [a, b], weight: a.weight + b.weight }); + let a = heap.pop(), b = heap.pop(); + if (a.dummy) { + /* Dummy must always be on the right-hand side */ + let temp = a; a = b; b = temp; + } + const parent = { value: [a, b], weight: a.weight + b.weight, dummy: a.dummy || b.dummy }; + heap.insert(parent); } return heap.pop(); This is optimal tree construction with a dummy node: Difference #2: JPEG Huffman codes are always canonical. In a canonical Huffman code, when the bitstrings are read as binary numbers, shorter bitstrings are always smaller numbers. For example, such a code could not use both "000" and "10", since the former bitstring is longer, but is a smaller binary number. Further, when all the bitstrings used in the code are sorted by their numeric value, each successive bitstring increments by the smallest amount possible while remaining prefix-free. 
Here's an example, courtesy of Wikipedia: 0 10 110 111 Interpreted as numbers, those are zero, two, six, and seven. Why wasn't the second bitstring "01", or one? Because then the first would have been its prefix. Likewise, if the third was "011" (three), "100" (four), or "101" (five), in each case one of the first two would have been a prefix. For the fourth, incrementing by one to "111" didn't create a prefix, so "111" it is. (Hopefully that example gives you the idea; hit me up if you need more!) But WHY does JPEG use canonical codes? Because their coding tables can be represented in a very compact way[1], which makes our JPEG files smaller and faster to decode. (Yes, JPEG files must contain not just Huffman-encoded pixel data but also the coding tables which were used.) So given a symbol set and frequencies, how can we generate a canonical Huffman code? Unfortunately, there is no straightforward way to do it directly by building a binary tree. But we can use our existing method to generate a non-canonical (but optimal) code, and then rewrite the bitstrings to make them canonical while maintaining their length. Remember, it's the length of the bitstrings assigned to each symbol which makes a prefix-free code optimal. The exact bitstrings which are used don't matter; we can shuffle them around and assign different ones with the same length. The algorithm suggested in the JPEG specification (Appendix K) gets a step ahead of the game by not explicitly building a binary tree with left and right child pointers. It just tracks what the depth of each leaf node would have been had they actually been built into a binary tree. So these depths can be incremented whenever two "subtrees" are "joined together", the leaf nodes for each subtree are kept on a linked list. "Subtrees" are "joined" by concatenating their linked lists. 
(Libjpeg uses this trick when saving a Huffman-encoded JPEG file.[2])

Regardless of whether you actually build a binary tree or use the trick from Appendix K, once you know what the lengths of all the bitstrings in an optimal code should be, generating a canonical code is as simple as this:

/* `lengths` is a sorted array of bitstring lengths required for an optimal code
 *
 * In real applications, an array of counts would likely be passed: how many
 * bitstrings must have 1 bit, how many 2 bits, how many 3 bits, and so on
 *
 * Also, in real applications, the returned values would almost certainly
 * not be strings; integers would be more likely */
function makeCanonical(lengths) {
  let result = [], nextCode = 0;
  for (var i = 0; i < lengths.length; i++) {
    /* When the length grows, append one zero bit for each extra bit of length;
     * shifting by just one bit would break prefix-freeness if a length is skipped */
    if (i > 0 && lengths[i] !== lengths[i-1])
      nextCode <<= (lengths[i] - lengths[i-1]);
    result.push(nextCode.toString(2).padStart(lengths[i], '0'));
    nextCode++;
  }
  return result;
}

Here is an example. Note that we are not using a dummy, so bitstrings with all 1 bits may be included.

[Interactive demo: Random Code | Sorted by Bitstring Length | Canonicalized]

Huffman Coding in Practice

All through this article, ASCII characters have been used as Huffman symbols. But in reality, if you want to compress English text, Huffman coding with each character treated as a separate symbol would be a terrible way to do it. Note two big weaknesses with that approach:

Huffman coding is oblivious to patterns which involve the order of symbols. It only cares about their frequency. But real-life data usually has patterns related to the order of values, which can be exploited to achieve better compression.

Huffman coding always uses at least one bit for each symbol, and usually much more. So in the "ideal" case of a text file which just contains a single ASCII character repeated thousands of times, Huffman coding with one symbol per letter could only compress it to ⅛ of its original size.
8× compression may sound good, but any reasonable compression method should get far greater gains in that ridiculously easy-to-compress case.

So just what am I saying here? Is Huffman coding a bad algorithm? Not at all! But it is just one piece of a practical compression method; it's not a complete compression method by itself. And to make Huffman coding work to greatest advantage, it may be necessary to find an alternative data representation which is well-suited to such coding. Just taking the most "natural" or intuitive representation and directly applying Huffman coding to it will probably not work well.

As an example, in JPEG, the values which we want to compress are quantized DCT coefficients (see the previous post for details), which have 8 bits of precision each.[3] We could take the 256 possible coefficient values as 256 Huffman symbols and Huffman-code them directly, but this would be very suboptimal. In the symbol set which is actually used, each symbol represents either:

Some specific number of successive zero coefficients (0-15 of them), and the number of significant bits in the following non-zero coefficient.

A run of zeroes filling the remainder of a 64-coefficient block.

Note that each symbol only tells us the number of significant bits in the next non-zero coefficient, not what those bits actually are. The actual coefficient value bits are simply inserted into the output data stream uncompressed. This is because the values of non-zero DCT coefficients don't actually repeat very much, so Huffman-coding them wouldn't really help. (See the demonstration in the previous post. Does it look like the coefficients within an 8-by-8 DCT matrix repeat much?) However, since the Huffman symbols tell us the number of significant bits, high-order zero bits can be discarded, which does help significantly.

JPEG files can use "arithmetic coding" as an alternative to Huffman coding (although this is not common).
I dare say arithmetic coding is a more intriguing and fascinating algorithm than Huffman coding. So it will not surprise you that the next article in this series will focus on arithmetic coding. See you then!

[1] With a canonical code, only the number of bitstrings used of each possible length needs to be stored; how many are 1 bit long, how many 2 bits long, how many 3 bits long, and so on. The actual bitstrings can be quickly recovered from that.

[2] But interestingly, libjpeg does not use a minheap when generating a Huffman code. Instead, it uses an array of symbol frequencies, and scans the whole array at each step to find the two lowest-weighted subtrees.

[3] The JPEG standard actually allows DCT coefficients to be either 8-bit or 12-bit, but 8 bits is almost universally used. Libjpeg can theoretically handle JPEG files with 12-bit coefficients, but it must be specially configured to do so at compile time, and binary distributions are not generally built in that way.

JPEG Series, Part I: Visualizing the Inverse Discrete Cosine Transform
2021-04-18T00:00:00+00:00 2021-04-18T00:00:00+00:00 /visualizing-the-idct
<p>A key step in JPEG image compression is converting 8-by-8-pixel blocks of color values into the frequency domain, so instead of storing color values, we store amplitudes of sinusoidal waveforms. This is a fun little bit of applied math, and you might enjoy seeing how it works.</p>
<p>It all really started in the early 1820's, when Joseph Fourier figured out that any periodic waveform can be broken down into a sum of sinusoids. Kalid Azad has <a href="https://betterexplained.com/articles/an-interactive-guide-to-the-fourier-transform/">explained this much better than I could over at BetterExplained</a>, and if you are not familiar with the Fourier transform, I recommend you go learn about it from Kalid first. I'll be waiting right here.</p>
<hr>
<p>Welcome back! Let's apply what you learned to a block of pixel values. How about this block of pixels right here?</p>
<!-- If you want to reuse the program code in this page for any purpose, hit the author up.
He is generally cool about stuff and will probably be willing to grant permission. -->
<div class="pic-container">
<div class="pic-spacer"></div>
<div class="pic" id="pic1"></div>
</div>
<script>
'use strict';
const coffeeCup = [
96, 213, 255, 96, 255, 213, 96, 255,
255, 119, 255, 255, 109, 255, 79, 255,
255, 79, 255, 85, 255, 119, 255, 255,
255, 0, 5, 10, 12, 15, 20, 20,
255, 18, 35, 46, 140, 10, 255, 30,
255, 18, 35, 46, 140, 10, 255, 45,
255, 18, 35, 46, 140, 12, 35, 255,
255, 255, 18, 26, 32, 12, 255, 255
];
function initPixelGrid(pic) {
const pixels = [];
for (var i = 0; i < 64; i++) {
const pixel = document.createElement('span');
pixels.push(pixel);
pixel.classList.add('pixel');
pic.appendChild(pixel);
}
return pixels;
}
function setPixelValues(pixels, values) {
for (var i = 0; i < 64; i++) {
const luminance = values[i];
pixels[i].style.backgroundColor = `rgb(${luminance},${luminance},${luminance})`;
}
}
const pic1 = document.getElementById('pic1');
const px1 = initPixelGrid(pic1);
setPixelValues(px1, coffeeCup);
</script>
<p>Take a horizontal slice, 8 pixels wide, from that block. Take the position of each pixel as an <i>x</i> value and its brightness as a <i>y</i> value. Then the 8 pixels correspond to 8 <i>(x, y)</i> points on a plane, and we could find a combination of sinusoidal waves that would go through those 8 points.</p>
<p>We could, and we will. Let's do it now. Click on any row of the pixel grid below:</p>
<div class="pic-container">
<div class="pic-spacer"></div>
<div class="pic" id="pic2"></div>
</div>
<div class="wave-container">
<div class="wave-spacer"></div>
<div class="wave-graph">
<canvas id="waveform1" class="wave"></canvas>
</div>
<div class="wave-legend" id="legend1"></div>
</div>
<script>
'use strict';
const pic2 = document.getElementById('pic2');
const px2 = initPixelGrid(pic2);
setPixelValues(px2, coffeeCup);
const wave1 = document.getElementById('waveform1');
const ctx1 = wave1.getContext('2d');
const vertTickColor = 'rgb(230,230,150)';
function graphDrawVerticalTicks(canvas, ctx) {
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
ctx.strokeStyle = vertTickColor;
ctx.lineWidth = 1;
const xRange = xMax - xMin;
for (var i = 0; i < 8; i++) {
ctx.beginPath();
ctx.moveTo(xMin + (xRange * (i + 0.5) / 8), yMin);
ctx.lineTo(xMin + (xRange * (i + 0.5) / 8), yMax);
ctx.stroke();
}
}
const zeroLineColor = 'rgb(40,50,180)';
function graphDrawZeroLine(canvas, ctx) {
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
ctx.strokeStyle = zeroLineColor;
ctx.lineWidth = 2;
ctx.beginPath();
ctx.moveTo(xMin, (yMin + yMax) / 2);
ctx.lineTo(xMax, (yMin + yMax) / 2);
ctx.stroke();
}
function autosizeGraph(canvas, callback) {
const parent = canvas.parentElement;
function resizeCanvas() {
canvas.width = parent.clientWidth;
canvas.height = parent.clientHeight;
if (callback)
callback();
}
if (window.ResizeObserver) {
const resizeObserver = new ResizeObserver(resizeCanvas);
resizeObserver.observe(parent);
} else {
window.addEventListener('resize', resizeCanvas);
}
}
/* `values` is an array of 8 evenly-spaced discrete time samples */
function discreteCosTransform1D(values) {
const coeffs = [];
for (var u = 0; u < 8; u++) {
const correctionFactor = (u === 0) ? Math.sqrt(8) : 2;
var sum = 0;
for (var x = 0; x < 8; x++)
sum += values[x] * Math.cos(Math.PI * u * ((2 * x) + 1) / 16);
coeffs.push(sum / correctionFactor);
}
return coeffs;
}
function graphCosine(canvas, ctx, frequency, amplitude, color, lineWidth) {
if (Math.abs(amplitude) <= 0.02)
return;
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
const yRange = yMax - yMin,
yMid = (yMax + yMin) / 2,
xRange = xMax - xMin;
ctx.strokeStyle = color;
ctx.lineWidth = lineWidth || 1.5;
ctx.beginPath();
/* A cosine wave with no phase shift always starts at maximum amplitude
* Subtract `amplitude` because lower y-values are 'up' on an HTML canvas */
ctx.moveTo(xMin, yMid - (amplitude * (yRange / 2)));
for (var x = xMin+2; x < xMax; x += 2) {
ctx.lineTo(x, yMid - (amplitude * (yRange / 2) * Math.cos(Math.PI * 2 * ((x - xMin) / xRange) * frequency)));
}
ctx.stroke();
}
function graphCosines(canvas, ctx, frequencies, amplitudes, colors) {
for (var i = 0; i < frequencies.length; i++)
graphCosine(canvas, ctx, frequencies[i], amplitudes[i], colors[i]);
}
function graphSumOfCosines(canvas, ctx, frequencies, amplitudes) {
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
const yRange = yMax - yMin,
yMid = (yMax + yMin) / 2,
xRange = xMax - xMin;
ctx.strokeStyle = 'black';
ctx.lineWidth = 3;
ctx.beginPath();
/* A cosine wave with no phase shift always starts at maximum amplitude
* Subtract amplitude because lower y-values are 'up' on an HTML canvas */
ctx.moveTo(xMin, yMid - (amplitudes.reduce((sum,a) => sum+a, 0) * (yRange / 2)));
for (var x = xMin+2; x < xMax; x += 2) {
var sum = 0;
for (var i = 0; i < frequencies.length; i++) {
sum += amplitudes[i] * (yRange / 2) * Math.cos(Math.PI * 2 * ((x - xMin) / xRange) * frequencies[i]);
}
ctx.lineTo(x, yMid - sum);
}
ctx.stroke();
}
function graphTargetDots(canvas, ctx, values, coordinates) {
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
const xRange = xMax - xMin,
yRange = yMax - yMin,
yMid = (yMax + yMin) / 2,
dotSize = 4.5;
ctx.fillStyle = 'black';
for (var i = 0; i < 8; i++) {
ctx.beginPath();
const x = xMin + (xRange * (i + 0.5) / 8);
const y = yMid - (values[i] * yRange / 2);
ctx.arc(x, y, dotSize, 0, Math.PI * 2);
coordinates[i] = [x, y];
ctx.fill();
}
}
/* 8 colors, each 45 degrees apart around the color wheel */
const graphColors = [
'rgb(200,10,10)',
'rgb(200,150,10)',
'rgb(105,200,10)',
'rgb(10,200,58)',
'rgb(10,200,200)',
'rgb(10,58,200)',
'rgb(105,10,200)',
'rgb(200,10,150)'
];
function initLegend(parent) {
const entries = [];
for (var i = 0; i < 8; i++) {
const entry = document.createElementNS("http://www.w3.org/2000/svg", "svg");
entry.classList.add('legend-entry');
entry.style.visibility = 'hidden'; /* Will appear when a waveform is shown */
entry.setAttribute('viewBox', '0 0 500 300');
entry.innerHTML = '<circle cx=22 cy=158 r=117 fill="white" class="selection" />' +
'<circle cx=22 cy=158 r=100 fill="' + graphColors[i] + '" />' +
'<text x=147 y=150 font-size=110>F = ' + (0.5 * i) + 'Hz</text><text x=147 y=265 font-size=110 class="amplitude"></text>';
parent.appendChild(entry);
entries.push(entry);
}
return entries;
}
const legend1 = initLegend(document.getElementById('legend1'));
function updateLegend(entries, amplitudes) {
for (var i = 0; i < 8; i++) {
if (Math.abs(amplitudes[i]) <= 0.02) {
entries[i].style.visibility = 'hidden';
} else {
entries[i].style.visibility = '';
const span = entries[i].querySelector('.amplitude');
span.textContent = 'A = ' + amplitudes[i].toFixed(2);
}
}
}
function highlightLegend(entry, color) {
if (entry)
entry.querySelector('.selection').setAttribute('fill', color);
}
function selectRow(pic, rowNumber) {
var highlight = pic.getElementsByClassName('highlight')[0];
if (!highlight) {
highlight = document.createElement('span');
highlight.classList.add('highlight');
pic.appendChild(highlight);
}
highlight.style.left = 0;
highlight.style.width = '100%';
highlight.style.top = (rowNumber * 12.5) + '%';
highlight.style.height = '12.5%';
}
function drawCosineGraph(canvas, ctx, values, coeffs, dotLocations) {
ctx.clearRect(0, 0, canvas.width, canvas.height);
graphDrawVerticalTicks(canvas, ctx);
graphDrawZeroLine(canvas, ctx);
graphCosines(canvas, ctx, [0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5], coeffs, graphColors);
graphSumOfCosines(canvas, ctx, [0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5], coeffs, graphColors);
graphTargetDots(canvas, ctx, values, dotLocations);
}
function updateCosineGraphAndLegend(canvas, ctx, legend, samples, hilite, dotLocations) {
/* Center range of possible sample values on the zero line */
var shifted = samples.map((s) => s - 128);
/* After DCT, apply an arbitrary scaling factor to make the graph fit well
* in the available space */
const coeffs = discreteCosTransform1D(shifted).map((c) => c / 420);
/* The IDCT applies a correction factor of 1/sqrt(2) for u=0; do the same */
coeffs[0] /= Math.sqrt(2);
/* Scaling factor for distance of target dots from zero line
* This must be 1/2 the scaling factor for the wave amplitudes for the waveform
* and dots to line up */
shifted = shifted.map((s) => s / 210);
updateLegend(legend, coeffs);
drawCosineGraph(canvas, ctx, shifted, coeffs, dotLocations);
if (hilite || hilite === 0)
graphCosine(canvas, ctx, hilite * 0.5, coeffs[hilite], graphColors[hilite], 3.5);
}
function onPixelClick(pixels, callback) {
for (var i = 0; i < 64; i++) {
(function(i) {
pixels[i].addEventListener('click', function() { callback(i, pixels[i]) });
})(i);
}
}
function onLegendClick(entries, callback) {
for (var i = 0; i < 8; i++) {
(function(i) {
entries[i].addEventListener('click', function() { callback(i, entries[i]); });
})(i);
}
}
function onTargetDotDrag(canvas, dotLocations, callback) {
var draggedDot, dragX, dragY;
function acquireDot(offsetX, offsetY) {
dragX = offsetX;
dragY = offsetY;
for (var i = 0; i < 8; i++) {
const [x, y] = dotLocations[i];
if (x && y && ((x - dragX) ** 2 + (y - dragY) ** 2) < 40) {
draggedDot = i;
return;
}
}
draggedDot = dragX = dragY = undefined;
}
canvas.addEventListener('mousedown', function(event) {
acquireDot(event.offsetX, event.offsetY);
});
canvas.addEventListener('touchstart', function(event) {
const rect = canvas.getBoundingClientRect();
const touch = event.changedTouches[0];
acquireDot(touch.clientX - rect.left, touch.clientY - rect.top);
/* If we acquired a dot, then this touch event should not have other effects
* like initiating page scrolling */
if (draggedDot || draggedDot === 0)
event.preventDefault();
});
canvas.addEventListener('mouseup', function() { draggedDot = dragX = dragY = undefined; });
canvas.addEventListener('mouseleave', function() { draggedDot = dragX = dragY = undefined; });
canvas.addEventListener('touchcancel', function() { draggedDot = dragX = dragY = undefined; });
canvas.addEventListener('touchend', function() { draggedDot = dragX = dragY = undefined; });
canvas.addEventListener('mousemove', function(event) {
if (draggedDot || draggedDot === 0)
callback(draggedDot, dragX, dragY, event.offsetX, event.offsetY);
});
canvas.addEventListener('touchmove', function(event) {
if (draggedDot || draggedDot === 0) {
event.preventDefault();
const rect = canvas.getBoundingClientRect();
const touch = event.changedTouches[0];
callback(draggedDot, dragX, dragY, touch.clientX - rect.left, touch.clientY - rect.top);
}
});
}
(function() {
const image = Array.from(coffeeCup); /* Copy, since we may modify the image */
var index, samples, hilitedFreq, repaintScheduled = false;
const dotLocations = new Array(8).fill([]);
autosizeGraph(wave1, function() {
if (samples) {
updateCosineGraphAndLegend(wave1, ctx1, legend1, samples, hilitedFreq, dotLocations);
} else {
graphDrawVerticalTicks(wave1, ctx1);
graphDrawZeroLine(wave1, ctx1);
}
});
onPixelClick(px2, function(i, px) {
selectRow(pic2, Math.floor(i / 8));
highlightLegend(legend1[hilitedFreq], 'white');
hilitedFreq = undefined;
index = Math.floor(i / 8) * 8;
samples = image.slice(index, index + 8);
updateCosineGraphAndLegend(wave1, ctx1, legend1, samples, undefined, dotLocations);
});
onLegendClick(legend1, function(i) {
highlightLegend(legend1[hilitedFreq], 'white');
hilitedFreq = (hilitedFreq === i) ? undefined : i;
highlightLegend(legend1[hilitedFreq], graphColors[hilitedFreq]);
if (samples)
updateCosineGraphAndLegend(wave1, ctx1, legend1, samples, hilitedFreq, dotLocations);
});
var count = 0;
onTargetDotDrag(wave1, dotLocations, function(dot, startX, startY, endX, endY) {
const yMin = 10, yMax = wave1.height - 10, yRange = yMax - yMin, yMid = yMin + (yRange / 2);
/* The 1.666 multiplier is because the range from 0-255 luminance only covers about
* 60% of the available vertical space in the graph (1 / 0.6 = 1.666...)
* (We don't use the entire space available, so that parts of the waveform which
* go higher or lower can still be seen) */
var luminance = Math.max(Math.min(255 * (0.5 + (1.666 * (yMid - endY) / yRange)), 255), 0);
if (luminance !== samples[dot]) {
image[index + dot] = luminance;
samples[dot] = luminance;
px2[index + dot].style.backgroundColor = `rgb(${luminance},${luminance},${luminance})`;
/* The other events which can trigger a repaint of the graph, such as clicking on a
* row of pixels or a legend entry, don't generally happen more than once per screen
* refresh. Drag events, though, can happen much faster than that, so let's make sure
* we only redraw the graph once per screen refresh (and avoid pegging the CPU). */
if (!repaintScheduled) {
repaintScheduled = true;
requestAnimationFrame(function() {
repaintScheduled = false;
updateCosineGraphAndLegend(wave1, ctx1, legend1, samples, hilitedFreq, dotLocations);
});
}
}
});
})();
</script>
<p>The heavy, black waveform is the sum of all the colored sinusoids. The heights of the 8 black dots represent the brightness values of the 8 pixels in the selected row. Try comparing several different rows to see that in each case, the height of the dots matches the brightness of the pixels.</p>
<p>You might notice that for darker pixel values, the dots appear below the 'zero line', while for brighter pixels, they appear above it. This is because we subtracted 128 from each brightness value before converting to a waveform, so the range of possible values (0-255) would be centered on the zero line. This is also done when an image is stored in JPEG format.</p>
<p>You might have also noticed that one of the component waves doesn't look like a sinusoid; it's the red one. It is just a flat line. That is the <b>zero frequency</b> component; it represents the <b>average</b> of the 8 values. It shifts the black waveform up or down to just the right height for it to hit all 8 target points.</p>
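<p>If you would like to check that claim about the average yourself, here is a small sketch. It reuses the 1D DCT formula from this page's script (under the name <code>dct1d</code>, which is made up for this sketch) and assumes the same scaling which the demo applies when drawing: every coefficient is halved, and the zero-frequency coefficient is also divided by √2.</p>

```javascript
/* Same 1D DCT as the demo above: 8 samples in, 8 coefficients out */
function dct1d(values) {
  const coeffs = [];
  for (let u = 0; u < 8; u++) {
    const correctionFactor = (u === 0) ? Math.sqrt(8) : 2;
    let sum = 0;
    for (let x = 0; x < 8; x++)
      sum += values[x] * Math.cos(Math.PI * u * ((2 * x) + 1) / 16);
    coeffs.push(sum / correctionFactor);
  }
  return coeffs;
}

/* Fourth row of the coffee cup, shifted so the 0-255 range is centered on zero */
const row = [255, 0, 5, 10, 12, 15, 20, 20].map((v) => v - 128);
const mean = row.reduce((sum, v) => sum + v, 0) / 8;

/* Height of the flat (zero frequency) line, using the scaling assumed above */
const flatLine = dct1d(row)[0] / Math.sqrt(2) / 2;

/* Both values should agree, apart from floating-point rounding */
console.log(mean, flatLine);
```

<p>The two printed numbers should match: the flat line sits exactly at the average of the 8 shifted samples.</p>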
<p>The legend displays the frequency and amplitude (on a scale of zero to one) of each component wave. Try clicking on the color swatches in the legend to see the component waves more clearly.</p>
<p>Of course, we could do exactly the same with columns of 8 pixels:</p>
<div class="pic-container">
<div class="pic-spacer"></div>
<div class="pic" id="pic3"></div>
</div>
<div class="wave-container">
<div class="wave-spacer"></div>
<div class="wave-graph">
<canvas id="waveform2" class="wave"></canvas>
</div>
<div class="wave-legend" id="legend2"></div>
</div>
<script>
'use strict';
const pic3 = document.getElementById('pic3');
const px3 = initPixelGrid(pic3);
setPixelValues(px3, coffeeCup);
const wave2 = document.getElementById('waveform2');
const ctx2 = wave2.getContext('2d');
const legend2 = initLegend(document.getElementById('legend2'));
function selectCol(pic, colNumber) {
var highlight = pic.getElementsByClassName('highlight')[0];
if (!highlight) {
highlight = document.createElement('span');
highlight.classList.add('highlight');
pic.appendChild(highlight);
}
highlight.style.left = (colNumber * 12.5) + '%';
highlight.style.width = '12.5%';
highlight.style.top = 0;
highlight.style.height = '100%';
}
function transpose8by8(ary) {
return [
ary[0], ary[8], ary[16], ary[24], ary[32], ary[40], ary[48], ary[56],
ary[1], ary[9], ary[17], ary[25], ary[33], ary[41], ary[49], ary[57],
ary[2], ary[10], ary[18], ary[26], ary[34], ary[42], ary[50], ary[58],
ary[3], ary[11], ary[19], ary[27], ary[35], ary[43], ary[51], ary[59],
ary[4], ary[12], ary[20], ary[28], ary[36], ary[44], ary[52], ary[60],
ary[5], ary[13], ary[21], ary[29], ary[37], ary[45], ary[53], ary[61],
ary[6], ary[14], ary[22], ary[30], ary[38], ary[46], ary[54], ary[62],
ary[7], ary[15], ary[23], ary[31], ary[39], ary[47], ary[55], ary[63]
];
}
(function() {
const image = Array.from(coffeeCup); /* Copy, since we may modify the image */
var index, samples, hilitedFreq, repaintScheduled = false;
const dotLocations = new Array(8).fill([]);
autosizeGraph(wave2, function() {
if (samples) {
updateCosineGraphAndLegend(wave2, ctx2, legend2, samples, hilitedFreq, dotLocations);
} else {
graphDrawVerticalTicks(wave2, ctx2);
graphDrawZeroLine(wave2, ctx2);
}
});
onPixelClick(px3, function(i, px) {
selectCol(pic3, i % 8);
highlightLegend(legend2[hilitedFreq], 'white');
hilitedFreq = undefined;
index = (i % 8) * 8;
/* Read from `image` (not the pristine `coffeeCup`) so that dragged values persist */
samples = transpose8by8(image).slice(index, index + 8);
updateCosineGraphAndLegend(wave2, ctx2, legend2, samples, undefined, dotLocations);
});
onLegendClick(legend2, function(i) {
highlightLegend(legend2[hilitedFreq], 'white');
hilitedFreq = (hilitedFreq === i) ? undefined : i;
highlightLegend(legend2[hilitedFreq], graphColors[hilitedFreq]);
if (samples)
updateCosineGraphAndLegend(wave2, ctx2, legend2, samples, hilitedFreq, dotLocations);
});
var count = 0;
onTargetDotDrag(wave2, dotLocations, function(dot, startX, startY, endX, endY) {
const yMin = 10, yMax = wave2.height - 10, yRange = yMax - yMin, yMid = yMin + (yRange / 2);
var luminance = Math.max(Math.min(255 * (0.5 + (1.666 * (yMid - endY) / yRange)), 255), 0);
if (luminance !== samples[dot]) {
const pxIndex = (dot * 8) + (index / 8);
image[pxIndex] = luminance;
samples[dot] = luminance;
px3[pxIndex].style.backgroundColor = `rgb(${luminance},${luminance},${luminance})`;
if (!repaintScheduled) {
repaintScheduled = true;
requestAnimationFrame(function() {
repaintScheduled = false;
updateCosineGraphAndLegend(wave2, ctx2, legend2, samples, hilitedFreq, dotLocations);
});
}
}
});
})();
</script>
<p>Now, you have seen that each row or column of pixel values in this "coffee cup" icon can be converted to a sum of sinusoidal waves. But could that just be a fluke? Can we really do this with <i>any</i> sequence of eight brightness values?</p>
<p>If the answer is obvious, just humour me here. Go back and try dragging any of the black points up or down. The pixel colors and waveform will update as you drag.</p>
<p>Looks cool, doesn't it?</p>
<hr>
<p>Now... we need to talk.</p>
<p>I have misled you here. The <a href="https://betterexplained.com/articles/an-interactive-guide-to-the-fourier-transform/">link to BetterExplained</a> above probably tricked you into thinking that these waveforms were derived using a Fourier transform. Not so. This page is about the <b>Discrete Cosine Transform</b>, not the Fourier transform. But it was good for you to understand the idea of the Fourier transform before learning about the DCT.</p>
<p>If you take some time to play with both the <a href="https://betterexplained.com/examples/fourier/">Fourier transform demonstration on BetterExplained</a> and the DCT demonstration here, you might recognize some differences between these transforms. Even when given the same input, they produce a different series of component waves. (That's an interesting point; the <i>same</i> sequence of discrete time samples can be broken down into sinusoids in more than one way.)</p>
<p>Do you want to go back and give it a try? Either way, whenever you are ready, click to reveal two major differences:</p>
<p class="reveal hidden">The frequencies of the component waves found by the Fourier transform are <i>integral multiples</i> of the fundamental frequency F (one cycle per overall period of repetition); i.e. 0, 1F, 2F, 3F, and so on. Those found by the DCT are multiples of <i>half</i> that frequency: 0, 0.5F, 1F, 1.5F, 2F, 2.5F, etc.</p>
<p class="reveal hidden">The component waves found by the Fourier transform often have varying phase shifts (they have to, otherwise it wouldn't be possible to match the input values). Those found by the DCT all have the same phase. Look at the graphs above and you will see this clearly.</p>
<p>Other differences which you can't see from the demonstrations are:</p>
<ul>
<li>While the computation of the Fourier transform uses complex numbers, the DCT only involves real numbers.</li>
<li>The DCT is easier to compute. It's just a simple nested loop which evaluates a cosine function and a couple of adds and multiplies for each input sample.</li>
</ul>
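<p>To make the first of those points a bit more concrete, here is a sketch of a naive discrete Fourier transform. The function name <code>dft</code> and the choice to return <code>[real, imaginary]</code> pairs are just for illustration; they don't come from any particular library. Each output is a complex number, and its angle is the phase shift of that component wave. Compare that with the DCT code in this page, which only ever deals in plain real numbers.</p>

```javascript
/* Naive discrete Fourier transform of real-valued samples
 * Each output is a complex number, returned as a [real, imaginary] pair */
function dft(samples) {
  const n = samples.length, result = [];
  for (let k = 0; k < n; k++) {
    let re = 0, im = 0;
    for (let t = 0; t < n; t++) {
      const angle = 2 * Math.PI * k * t / n;
      re += samples[t] * Math.cos(angle);
      im -= samples[t] * Math.sin(angle);
    }
    result.push([re, im]);
  }
  return result;
}

/* Fourth row of the coffee cup image */
const spectrum = dft([255, 0, 5, 10, 12, 15, 20, 20]);

for (const [re, im] of spectrum) {
  /* The length of the complex number gives the amplitude, its angle the phase */
  const amplitude = Math.hypot(re, im), phase = Math.atan2(im, re);
  console.log(amplitude.toFixed(1), phase.toFixed(2));
}
```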
<p>Now, another important point. Look back at the interactive DCT. As you drag the target points up and down, what is the maximum number of component waves which are required to match the 8 target points?</p>
<p class="reveal hidden">8.</p>
<p>Right. That is a key point for JPEG image compression. Remember, in a JPEG image, blocks of 8 pixels by 8 pixels are represented as a combination of component waves. The amplitude of each component is called a DCT <b>coefficient</b>. Since each block of input pixels converts to a fixed number of component waves, we just need to store a fixed number of coefficients for each such block. We don't need to store their frequencies, since those are known and are always the same. Nor do we need to store phase shifts, because all are at a constant phase.</p>
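<p>Here is a minimal sketch showing that 8 coefficients really are enough. <code>dct1d</code> follows the same formula as the script in this page; <code>idct1d</code> is a hypothetical inverse written just for this example, using the same cosine terms and correction factors. Nothing but the 8 amplitudes is passed from one to the other.</p>

```javascript
/* Forward transform: 8 samples in, 8 coefficients out */
function dct1d(samples) {
  const coeffs = [];
  for (let u = 0; u < 8; u++) {
    const correctionFactor = (u === 0) ? Math.sqrt(8) : 2;
    let sum = 0;
    for (let x = 0; x < 8; x++)
      sum += samples[x] * Math.cos(Math.PI * u * ((2 * x) + 1) / 16);
    coeffs.push(sum / correctionFactor);
  }
  return coeffs;
}

/* Inverse transform: add up the 8 cosine waves at each sample position */
function idct1d(coeffs) {
  const samples = [];
  for (let x = 0; x < 8; x++) {
    let sum = 0;
    for (let u = 0; u < 8; u++) {
      const correctionFactor = (u === 0) ? Math.sqrt(8) : 2;
      sum += (coeffs[u] / correctionFactor) * Math.cos(Math.PI * u * ((2 * x) + 1) / 16);
    }
    samples.push(sum);
  }
  return samples;
}

/* First row of the coffee cup image */
const row = [96, 213, 255, 96, 255, 213, 96, 255];
const coeffs = dct1d(row.map((v) => v - 128));
const restored = idct1d(coeffs).map((v) => Math.round(v + 128));

console.log(coeffs.length);
console.log(restored);
```

<p>Rounding is only needed to clean up floating-point error; the restored values should equal the original row.</p>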
<p>Let's move into two dimensions now. We have demonstrated that if we just needed to represent 8 color values in a row, 8 coefficients would be enough. But how many coefficients do we need to represent 8-by-8, or 64, color values? Guess before clicking:</p>
<p class="reveal hidden">64.</p>
<p>You might ask: <i><b>Isn't JPEG a lossy format which compresses images into fewer bits? How can converting 64 numbers into 64 other numbers save storage space?</b></i></p>
<p>You are right; <i>by itself</i>, applying the DCT to a block of color samples doesn't result in any compression. It's like translating English text to French or Chinese; you are representing the same information in a different form. So it's not surprising that 64 color samples convert to 64 coefficients. However, the DCT is still a key step in achieving image compression. <a href="#why-useful">More on this later.</a></p>
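<p>As a rough sketch of that "translation", the code below converts the 64 coffee cup samples to 64 coefficients and back again. It assumes the usual separable form of the two-dimensional DCT, where each term is a product of one horizontal and one vertical cosine; the names <code>basis</code>, <code>dct2d</code> and <code>idct2d</code> are invented for this example, and a real codec would use a much faster algorithm than these nested loops.</p>

```javascript
/* basis[u][x] is the 1D cosine term for frequency index u at position x,
 * including the correction factors used elsewhere on this page */
const basis = [];
for (let u = 0; u < 8; u++) {
  basis.push([]);
  for (let x = 0; x < 8; x++)
    basis[u].push(Math.cos(Math.PI * u * ((2 * x) + 1) / 16) / ((u === 0) ? Math.sqrt(8) : 2));
}

/* 64 samples (row-major) in, 64 coefficients out */
function dct2d(samples) {
  const coeffs = new Array(64).fill(0);
  for (let v = 0; v < 8; v++)
    for (let u = 0; u < 8; u++)
      for (let y = 0; y < 8; y++)
        for (let x = 0; x < 8; x++)
          coeffs[v * 8 + u] += samples[y * 8 + x] * basis[u][x] * basis[v][y];
  return coeffs;
}

/* 64 coefficients in, 64 samples out */
function idct2d(coeffs) {
  const samples = new Array(64).fill(0);
  for (let y = 0; y < 8; y++)
    for (let x = 0; x < 8; x++)
      for (let v = 0; v < 8; v++)
        for (let u = 0; u < 8; u++)
          samples[y * 8 + x] += coeffs[v * 8 + u] * basis[u][x] * basis[v][y];
  return samples;
}

const coffeeCup = [
   96, 213, 255,  96, 255, 213,  96, 255,
  255, 119, 255, 255, 109, 255,  79, 255,
  255,  79, 255,  85, 255, 119, 255, 255,
  255,   0,   5,  10,  12,  15,  20,  20,
  255,  18,  35,  46, 140,  10, 255,  30,
  255,  18,  35,  46, 140,  10, 255,  45,
  255,  18,  35,  46, 140,  12,  35, 255,
  255, 255,  18,  26,  32,  12, 255, 255
];

const coeffs2d = dct2d(coffeeCup.map((v) => v - 128));
const restored2d = idct2d(coeffs2d).map((v) => Math.round(v + 128));

console.log(coeffs2d.length);
console.log(restored2d.every((v, i) => v === coffeeCup[i]));
```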
<p>I want to show you examples of 8-by-8 images broken down into 64 two-dimensional waveforms. As in the one-dimensional case, the frequencies and phase of the 64 components will always be the same; only their amplitudes will vary. Before looking at any sample images, though, first let me show you the 64 components of the two-dimensional DCT at a fixed amplitude.</p>
<p>On the left is an 8 by 8 grid. In each cell is a <i>(u, v)</i> pair. (We talk about positions in a DCT coefficient matrix using <i>u, v</i> coordinates; coordinates in the corresponding block of pixels are named <i>x</i> and <i>y</i>.) On the right is a waveform graph. Click on each position in the coefficient matrix to see what the corresponding waveform looks like. You can click and drag on the graph to pivot.</p>
<div class="matrix-container">
<div class="matrix-spacer"></div>
<div class="edge-colors">
<table id="matrix1" class="matrix"></table>
</div>
</div>
<div class="matrix-container">
<div class="matrix-spacer"></div>
<div class="fill-parent">
<canvas id="demo3" style='cursor: move'></canvas>
</div>
</div>
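<p>If you prefer numbers to pictures, here is a small sketch which samples one of those 64 component waves at the 8-by-8 pixel positions. It uses the same product of two cosines as the graphing script below; the function name <code>basisFunction</code> is just a label chosen for this example.</p>

```javascript
/* Value of the (u, v) component wave at each pixel position (x, y)
 * Returns 8 rows of 8 numbers, each between -1 and 1 */
function basisFunction(u, v) {
  const rows = [];
  for (let y = 0; y < 8; y++) {
    const row = [];
    for (let x = 0; x < 8; x++)
      row.push(Math.cos((2 * x + 1) * u * Math.PI / 16) * Math.cos((2 * y + 1) * v * Math.PI / 16));
    rows.push(row);
  }
  return rows;
}

/* Print the pattern for (u, v) = (1, 2)
 * '#' marks positions where the wave is positive, '.' where it is negative */
for (const row of basisFunction(1, 2))
  console.log(row.map((z) => (z > 0 ? '#' : '.')).join(''));
```

<p>Try other <i>(u, v)</i> pairs: larger values of <i>u</i> give more sign changes from left to right, and larger values of <i>v</i> give more from top to bottom.</p>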
<script>
'use strict';
/* DISPLAY OF 3D GRAPH, SHOWING HOW EACH IDCT COEFFICIENT AFFECTS EACH POINT IN X,Y SPACE
* How finely should we divide our 3D graph along the X/Y axes? */
const xRes = 120;
const yRes = 120;
const xyScale = 1.3;
const zScale = 0.3;
/* Take vectors from point1->point2 and point1->point3, cross them to get normal vector */
function crossProduct(point1, point2, point3) {
const x1 = point1[0] - point2[0], x2 = point1[0] - point3[0],
y1 = point1[1] - point2[1], y2 = point1[1] - point3[1],
z1 = point1[2] - point2[2], z2 = point1[2] - point3[2];
return [
y1 * z2 - y2 * z1,
z1 * x2 - z2 * x1,
x1 * y2 - x2 * y1
];
}
/* Multiply 3x3 matrices */
function matrixMultiply(m1, m2) {
return [
m1[0]*m2[0] + m1[1]*m2[3] + m1[2]*m2[6],
m1[0]*m2[1] + m1[1]*m2[4] + m1[2]*m2[7],
m1[0]*m2[2] + m1[1]*m2[5] + m1[2]*m2[8],
m1[3]*m2[0] + m1[4]*m2[3] + m1[5]*m2[6],
m1[3]*m2[1] + m1[4]*m2[4] + m1[5]*m2[7],
m1[3]*m2[2] + m1[4]*m2[5] + m1[5]*m2[8],
m1[6]*m2[0] + m1[7]*m2[3] + m1[8]*m2[6],
m1[6]*m2[1] + m1[7]*m2[4] + m1[8]*m2[7],
m1[6]*m2[2] + m1[7]*m2[5] + m1[8]*m2[8],
];
}
function plotIDCTAmplitudeGraph(u, v) {
const points = [];
for (var x = 0; x < xRes; x++) {
for (var y = 0; y < yRes; y++) {
/* (x,y) in IDCT formula */
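/* Map grid position to pixel coordinates from -0.5 to 7.5, so that pixel
 * centers fall on integers (as in the 1D graphs above) */
const px_x = (x * 8 / xRes) - 0.5;
const px_y = (y * 8 / yRes) - 0.5;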
/* Value of IDCT coefficient (u,v) at each point in (x,y) plane */
const z = Math.cos((2*px_x + 1) * u * Math.PI / 16) * Math.cos((2*px_y + 1) * v * Math.PI / 16);
points.push([(xyScale * x / xRes) - (xyScale/2), (xyScale * y / yRes) - (xyScale/2), z * zScale]);
}
}
return points;
}
function buildSurfaceTriangles(points) {
const triangles = [], normals = [];
/* Iterate over grid of (x,y,z) values; stitch them together with triangles */
for (var x = 0; x < xRes-1; x++) {
for (var y = 0; y < yRes-1; y++) {
const botLeft = points[y + (x * yRes)];
const topLeft = points[y + 1 + (x * yRes)];
const botRight = points[y + ((x + 1) * yRes)];
const topRight = points[y + 1 + ((x + 1) * yRes)];
/* Triangle 1, covering half of this grid square */
triangles.push(topLeft[0]); triangles.push(topLeft[1]); triangles.push(topLeft[2]);
triangles.push(botLeft[0]); triangles.push(botLeft[1]); triangles.push(botLeft[2]);
triangles.push(topRight[0]); triangles.push(topRight[1]); triangles.push(topRight[2]);
/* Triangle 2, covering the remainder of this grid square */
triangles.push(botRight[0]); triangles.push(botRight[1]); triangles.push(botRight[2]);
triangles.push(botLeft[0]); triangles.push(botLeft[1]); triangles.push(botLeft[2]);
triangles.push(topRight[0]); triangles.push(topRight[1]); triangles.push(topRight[2]);
/* Compute surface normals; these will be used for lighting */
const normal1 = crossProduct(topLeft, botLeft, topRight);
const normal2 = crossProduct(botRight, topRight, botLeft);
/* We need one normal vector for _each_ vertex */
for (var i = 0; i < 3; i++) {
normals.push(normal1[0]); normals.push(normal1[1]); normals.push(normal1[2]);
}
for (var i = 0; i < 3; i++) {
normals.push(normal2[0]); normals.push(normal2[1]); normals.push(normal2[2]);
}
}
}
return [triangles, normals];
}
/* We will highlight edges so orientation of surface can be seen */
function buildEdgeLines(points) {
const edges = [], edgeColors = [];
const topR = 0.2, topG = 0.41, topB = 0.82; /* blue */
const botR = 0.82, botG = 0.18, botB = 0.17; /* red */
const leftR = 0.9, leftG = 0.85, leftB = 0.14; /* yellow */
const rightR = 0.18, rightG = 0.81, rightB = 0.19; /* green */
for (var x = 0; x < xRes-1; x++) {
var y = 0;
const botLeft = points[y + (x * yRes)];
const botRight = points[y + ((x + 1) * yRes)];
edges.push(botLeft[0]); edges.push(botLeft[1]); edges.push(botLeft[2]);
edges.push(botRight[0]); edges.push(botRight[1]); edges.push(botRight[2]);
edgeColors.push(botR); edgeColors.push(botG); edgeColors.push(botB);
edgeColors.push(botR); edgeColors.push(botG); edgeColors.push(botB);
y = yRes-2;
const topLeft = points[y + 1 + (x * yRes)];
const topRight = points[y + 1 + ((x + 1) * yRes)];
edges.push(topLeft[0]); edges.push(topLeft[1]); edges.push(topLeft[2]);
edges.push(topRight[0]); edges.push(topRight[1]); edges.push(topRight[2]);
edgeColors.push(topR); edgeColors.push(topG); edgeColors.push(topB);
edgeColors.push(topR); edgeColors.push(topG); edgeColors.push(topB);
}
for (var y = 0; y < yRes-1; y++) {
var x = 0;
const topLeft = points[y + 1 + (x * yRes)];
const botLeft = points[y + (x * yRes)];
edges.push(topLeft[0]); edges.push(topLeft[1]); edges.push(topLeft[2]);
edges.push(botLeft[0]); edges.push(botLeft[1]); edges.push(botLeft[2]);
edgeColors.push(leftR); edgeColors.push(leftG); edgeColors.push(leftB);
edgeColors.push(leftR); edgeColors.push(leftG); edgeColors.push(leftB);
x = xRes-2;
const topRight = points[y + 1 + ((x + 1) * yRes)];
const botRight = points[y + ((x + 1) * yRes)];
edges.push(topRight[0]); edges.push(topRight[1]); edges.push(topRight[2]);
edges.push(botRight[0]); edges.push(botRight[1]); edges.push(botRight[2]);
edgeColors.push(rightR); edgeColors.push(rightG); edgeColors.push(rightB);
edgeColors.push(rightR); edgeColors.push(rightG); edgeColors.push(rightB);
}
return [edges, edgeColors];
}
function xRotationMatrix(radians) {
return [
1, 0, 0,
0, Math.cos(radians), Math.sin(radians),
0, -Math.sin(radians), Math.cos(radians)
];
}
function yRotationMatrix(radians) {
return [
Math.cos(radians), 0, Math.sin(radians),
0, 1, 0,
-Math.sin(radians), 0, Math.cos(radians)
];
}
const canvas3 = document.getElementById('demo3');
const gl = canvas3.getContext('webgl2');
gl.enable(gl.DEPTH_TEST);
function createShader(gl, type, src) {
const shader = gl.createShader(type);
gl.shaderSource(shader, src);
gl.compileShader(shader);
if (gl.getShaderParameter(shader, gl.COMPILE_STATUS))
return shader;
console.log(gl.getShaderInfoLog(shader));
gl.deleteShader(shader);
}
/* Shader code which will be compiled by WebGL and run on GPU */
const vertexShader = createShader(gl, gl.VERTEX_SHADER, `#version 300 es
uniform mat3 rotationMatrix;
in vec3 position;
in vec3 normal;
out vec3 vNormal;
void main() {
gl_Position = vec4(rotationMatrix * position, 1);
vNormal = rotationMatrix * normal;
}`);
const fragmentShader = createShader(gl, gl.FRAGMENT_SHADER, `#version 300 es
precision highp float;
in vec3 vNormal;
out vec4 color;
void main() {
vec3 normal = normalize(vNormal);
vec3 lightDirection = normalize(vec3(0.5, 0.5, 0.5));
color = vec4(1, 0.3, 0.6, 1);
color.rgb *= 0.35 + (0.65 * abs(dot(normal, lightDirection)));
}`);
const edgeVertexShader = createShader(gl, gl.VERTEX_SHADER, `#version 300 es
uniform mat3 rotationMatrix;
in vec3 position;
in vec3 edgeColor;
out vec3 _color;
void main() {
gl_Position = vec4(rotationMatrix * position, 1);
_color = edgeColor;
}`);
const edgeFragmentShader = createShader(gl, gl.FRAGMENT_SHADER, `#version 300 es
precision highp float;
in vec3 _color;
out vec4 color;
void main() {
color = vec4(_color, 1);
}`);
function createProgram(gl, vertexShader, fragmentShader, primitiveType, inputs, uniforms) {
const program = gl.createProgram();
gl.attachShader(program, vertexShader);
gl.attachShader(program, fragmentShader);
gl.linkProgram(program);
if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
console.log(gl.getProgramInfoLog(program));
gl.deleteProgram(program);
return;
}
const vertexArray = gl.createVertexArray();
function draw(arrays, uniformValues) {
gl.useProgram(program);
gl.bindVertexArray(vertexArray);
for (var i = 0; i < inputs.length; i++) {
const buffer = gl.createBuffer();
gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array(arrays[i]), gl.STATIC_DRAW);
const attr = gl.getAttribLocation(program, inputs[i]);
gl.enableVertexAttribArray(attr);
gl.vertexAttribPointer(attr, 3, gl.FLOAT, false, 0, 0);
}
for (var uniform of uniforms)
gl.uniformMatrix3fv(gl.getUniformLocation(program, uniform), false, uniformValues[uniform]);
gl.drawArrays(primitiveType, 0, arrays[0].length / 3);
}
return draw;
}
const drawTriangles = createProgram(gl, vertexShader, fragmentShader, gl.TRIANGLES, ['position', 'normal'], ['rotationMatrix']);
const drawEdges = createProgram(gl, edgeVertexShader, edgeFragmentShader, gl.LINES, ['position', 'edgeColor'], ['rotationMatrix']);
/* Fairly arbitrary product which seems to give a good starting angle */
var rotationMatrix = matrixMultiply(xRotationMatrix(1), yRotationMatrix(0.5));
var scene = {};
function plotGraph(u, v) {
/* (u, v) has changed */
const points = plotIDCTAmplitudeGraph(u, v);
const [triangles, normals] = buildSurfaceTriangles(points);
const [edges, edgeColors] = buildEdgeLines(points);
scene = { triangles: triangles, normals: normals, edges: edges, edgeColors: edgeColors };
redrawGraph();
}
function redrawGraph() {
/* Either (u,v) has changed, or graph has been rotated, or window has been resized */
requestAnimationFrame(function() {
gl.clearColor(0.05, 0.15, 0.25, 1);
gl.clear(gl.COLOR_BUFFER_BIT | gl.DEPTH_BUFFER_BIT);
drawTriangles([scene.triangles, scene.normals], { rotationMatrix: rotationMatrix });
drawEdges([scene.edges, scene.edgeColors], { rotationMatrix: rotationMatrix });
});
}
/* Allow drag to rotate 3D scene */
var dragX, dragY;
canvas3.addEventListener('mousedown', function(event) {
dragX = event.offsetX;
dragY = event.offsetY;
});
canvas3.addEventListener('touchstart', function(event) {
event.preventDefault();
const rect = canvas3.getBoundingClientRect();
const touch = event.changedTouches[0];
dragX = touch.clientX - rect.left;
dragY = touch.clientY - rect.top;
});
canvas3.addEventListener('mouseup', function(event) { dragX = dragY = undefined; });
canvas3.addEventListener('mouseleave', function(event) { dragX = dragY = undefined; });
canvas3.addEventListener('touchcancel', function(event) { dragX = dragY = undefined; });
canvas3.addEventListener('touchend', function(event) { dragX = dragY = undefined; });
function handleGraphDrag(offsetX, offsetY) {
var changedRotation = false;
if (dragX !== undefined && offsetX !== dragX) {
rotationMatrix = matrixMultiply(rotationMatrix, yRotationMatrix((offsetX - dragX) / 40));
dragX = offsetX;
changedRotation = true;
}
if (dragY !== undefined && offsetY !== dragY) {
rotationMatrix = matrixMultiply(rotationMatrix, xRotationMatrix((dragY - offsetY) / 40));
dragY = offsetY;
changedRotation = true;
}
if (changedRotation)
redrawGraph();
}
canvas3.addEventListener('mousemove', function(event) {
handleGraphDrag(event.offsetX, event.offsetY);
});
canvas3.addEventListener('touchmove', function(event) {
event.preventDefault();
const rect = canvas3.getBoundingClientRect();
const touch = event.changedTouches[0];
handleGraphDrag(touch.clientX - rect.left, touch.clientY - rect.top);
});
function cellSelected(u, v, cell) {
plotGraph(u, v);
const previousSelected = document.getElementById('matrix1').querySelector('td.selected');
if (previousSelected)
previousSelected.classList.remove('selected');
cell.classList.add('selected');
}
function initCoeffMatrix(tbl, onClick) {
const cells = [];
const values = new Array(64).fill(0);
for (var i = 0; i < 64; i++) {
const cell = document.createElement('td');
cells.push(cell);
/* Capture the value of `i` on _this_ iteration */
(function(i) {
cell.addEventListener('click', function(event) {
onClick(i % 8, Math.floor(i / 8), cell, event);
});
})(i);
}
for (var i = 0; i < 8; i++) {
const tr = document.createElement('tr');
for (var j = 0; j < 8; j++)
tr.appendChild(cells[i*8 + j]);
tbl.appendChild(tr);
}
return cells;
}
window.addEventListener('load', function() {
autosizeGraph(canvas3, function() {
gl.viewport(0, 0, canvas3.width, canvas3.height);
redrawGraph();
});
const matrix1 = document.getElementById('matrix1');
const cells1 = initCoeffMatrix(matrix1, cellSelected);
for (var i = 0; i < 64; i++) {
cells1[i].innerText = `${i % 8},${Math.floor(i / 8)}`;
}
cellSelected(0, 0, matrix1.querySelector('td'));
});
</script>
<p>Do you see how the 64 components in the two-dimensional case correspond to the 8 components in the one-dimensional case? Look again at the waveform for coefficient <i>(0, 0)</i>, in the top-left corner. This one is very important, since it represents the average value of all 64 color samples in a block (with the scaling used in this post, it works out to eight times that average). It is called the <b>DC coefficient</b>. The other 63 coefficients represent all the deviations from the average and are called the <b>AC coefficients</b>.</p>
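<p>To make the DC coefficient a little more concrete, here is a minimal sketch (it is not part of the demo code on this page). It assumes the same scaling as the <code>forwardDCT</code> function used by the demo below, minus the rounding; the names <code>forwardDCT2D</code>, <code>block</code> and <code>flat</code> are made up for illustration. Under that assumption, the DC coefficient should come out as eight times the average sample value, and a completely flat block should have no AC content at all.</p>

```javascript
// Unrounded 2D forward DCT of an 8x8 block (64 samples in row-major order).
// Assumption: same scaling as the demo's forwardDCT, but without Math.round.
function forwardDCT2D(samples) {
  const coeffs = [];
  for (let v = 0; v < 8; v++) {
    for (let u = 0; u < 8; u++) {
      let coeff = 0;
      for (let y = 0; y < 8; y++) {
        for (let x = 0; x < 8; x++) {
          coeff += samples[y*8 + x] *
            Math.cos(Math.PI * u * (2*x + 1) / 16) *
            Math.cos(Math.PI * v * (2*y + 1) / 16);
        }
      }
      if (u === 0)
        coeff /= Math.sqrt(2);
      if (v === 0)
        coeff /= Math.sqrt(2);
      coeffs.push(coeff / 4);
    }
  }
  return coeffs;
}

// An arbitrary, made-up block of samples, already shifted to the range -128..127
const block = Array.from({ length: 64 }, (_, i) => ((i * 37) % 256) - 128);
const mean = block.reduce((a, b) => a + b, 0) / 64;
const dc = forwardDCT2D(block)[0];
console.log(dc, 8 * mean); // the two numbers should agree up to floating-point error

// A flat block: every sample is the same, so there is nothing but "average"
const flat = new Array(64).fill(50);
const flatCoeffs = forwardDCT2D(flat);
console.log(flatCoeffs[0]); // should be very close to 8 * 50
console.log(Math.max(...flatCoeffs.slice(1).map(Math.abs))); // should be very close to zero
```

<p>The factor of eight is purely a consequence of the scaling chosen here; a DCT with different normalization would give a different multiple of the average, but the DC coefficient would still carry the same information.</p>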
<p id='idctdemo'>Now you are ready to see the Inverse Discrete Cosine Transform at work. Click on each of the sample images below to see its DCT coefficients. Click on any coefficient to disable it and see what the image looks like with it removed; or Control-click to enable only a single coefficient and see its contribution to the image.</p>
<div style="margin: auto">
<div class="pic-container smallpic">
<div class="pic-spacer"></div>
<div class="pic" id="smallpic1"></div>
</div>
<div class="pic-container smallpic">
<div class="pic-spacer"></div>
<div class="pic" id="smallpic2"></div>
</div>
<div class="pic-container smallpic">
<div class="pic-spacer"></div>
<div class="pic" id="smallpic3"></div>
</div>
<div class="pic-container smallpic">
<div class="pic-spacer"></div>
<div class="pic" id="smallpic4"></div>
</div>
<div class="pic-container smallpic">
<div class="pic-spacer"></div>
<div class="pic" id="smallpic5"></div>
</div>
</div>
<div class="matrix-container" style="border: none; margin-right: 1.5em; width: 45%">
<div class="matrix-spacer"></div>
<div class="fill-parent">
<table id="matrix2" class="matrix"></table>
</div>
</div>
<div class="pic-container" style="width: 45%">
<div class="pic-spacer"></div>
<div class="pic" id="pic4"></div>
</div>
<script>
const pic4 = document.getElementById('pic4');
const px4 = initPixelGrid(pic4);
const matrix2 = document.getElementById('matrix2');
const coeffs = new Array(64).fill(0);
const enabled = new Array(64).fill(true);
const face1 = [
72, 83, 57, 41, 42, 51, 71, 68,
79, 70, 101, 150, 96, 39, 71, 72,
89, 51, 150, 183, 168, 82, 55, 71,
98, 65, 158, 164, 149, 119, 50, 78,
109, 43, 159, 174, 164, 132, 34, 62,
92, 30, 109, 150, 135, 85, 24, 93,
87, 71, 60, 115, 122, 40, 35, 111,
96, 176, 160, 101, 140, 98, 37, 137
];
const dog1 = [
170, 113, 145, 210, 207, 151, 76, 114,
79, 94, 122, 212, 230, 218, 162, 118,
185, 184, 210, 231, 249, 203, 138, 119,
216, 226, 218, 241, 245, 176, 59, 29,
196, 201, 217, 244, 238, 198, 90, 79,
194, 182, 212, 238, 210, 147, 91, 73,
180, 195, 191, 187, 91, 24, 6, 125,
216, 220, 207, 118, 105, 43, 62, 172
];
const pawprint = [
255, 244, 76, 217, 209, 50, 244, 255,
255, 217, 0, 168, 134, 0, 231, 255,
156, 198, 149, 217, 198, 149, 202, 156,
60, 122, 250, 149, 156, 255, 71, 71,
209, 202, 156, 0, 0, 191, 198, 217,
255, 202, 0, 0, 0, 0, 217, 255,
255, 122, 0, 0, 0, 0, 143, 255,
255, 174, 50, 143, 134, 0, 202, 255
];
const chinese = [
255, 255, 255, 103, 17, 255, 255, 255,
255, 255, 255, 111, 38, 254, 181, 255,
255, 131, 71, 28, 36, 34, 4, 102,
255, 182, 122, 44, 140, 255, 28, 129,
255, 255, 188, 36, 255, 231, 10, 255,
255, 255, 74, 125, 255, 118, 36, 255,
255, 104, 59, 255, 144, 42, 54, 255,
147, 103, 255, 255, 122, 18, 158, 255
];
const rose = [
255, 255, 254, 201, 143, 185, 251, 255,
241, 152, 128, 104, 90, 96, 138, 213,
176, 90, 107, 106, 86, 111, 109, 113,
134, 113, 109, 83, 73, 109, 126, 218,
175, 112, 97, 100, 110, 96, 124, 254,
250, 181, 90, 94, 85, 87, 68, 255,
255, 233, 69, 73, 75, 70, 109, 254,
255, 251, 177, 135, 76, 96, 201, 251
];
const smallpic1 = document.getElementById('smallpic1');
const smallpic2 = document.getElementById('smallpic2');
const smallpic3 = document.getElementById('smallpic3');
const smallpic4 = document.getElementById('smallpic4');
const smallpic5 = document.getElementById('smallpic5');
setPixelValues(initPixelGrid(smallpic1), face1);
setPixelValues(initPixelGrid(smallpic2), dog1);
setPixelValues(initPixelGrid(smallpic3), pawprint);
setPixelValues(initPixelGrid(smallpic4), chinese);
setPixelValues(initPixelGrid(smallpic5), rose);
const cells2 = initCoeffMatrix(matrix2, function(col, row, td, event) {
const index = row*8 + col;
if (event.ctrlKey)
selectCoeff(index, td);
else
toggleCoeff(index, td);
});
for (var i = 0; i < 64; i++)
cells2[i].innerText = '0';
function setCoefficients(samples) {
const newCoeffs = forwardDCT(samples.map((s) => s - 128));
for (var i = 0; i < 64; i++)
cells2[i].innerText = coeffs[i] = newCoeffs[i];
enableAllCoeffs();
}
smallpic1.addEventListener('click', function() { setCoefficients(face1); });
smallpic2.addEventListener('click', function() { setCoefficients(dog1); });
smallpic3.addEventListener('click', function() { setCoefficients(pawprint); });
smallpic4.addEventListener('click', function() { setCoefficients(chinese); });
smallpic5.addEventListener('click', function() { setCoefficients(rose); });
function enableCoeff(index, td) {
enabled[index] = true;
td.classList.remove('disabled');
drawPic();
}
function disableCoeff(index, td) {
enabled[index] = false;
td.classList.add('disabled');
drawPic();
}
function toggleCoeff(index, td) {
if (coeffs[index] === 0)
return;
enabled[index] = !enabled[index];
td.classList.toggle('disabled');
drawPic();
}
function selectCoeff(index, td) {
if (coeffs[index] === 0)
return;
enableCoeff(index, td);
for (var i = 0; i < 64; i++) {
if (coeffs[i] !== 0 && i !== index) {
enabled[i] = false;
cells2[i].classList.add('disabled');
}
}
drawPic();
}
function enableAllCoeffs() {
for (var i = 0; i < 64; i++) {
enabled[i] = true;
cells2[i].classList.remove('disabled');
}
drawPic();
}
function drawPic() {
drawCoeffs(coeffs.map((c, idx) => enabled[idx] ? c : 0));
}
function drawCoeffs(coeffs) {
const samples = inverseDCT(coeffs);
for (var i = 0; i < 64; i++) {
const luminance = Math.round(samples[i] + 128);
px4[i].style.backgroundColor = `rgb(${luminance},${luminance},${luminance})`;
}
}
function forwardDCT(samples) {
const coeffs = [];
for (var v = 0; v < 8; v++) {
for (var u = 0; u < 8; u++) {
var coeff = 0;
for (var x = 0; x < 8; x++) {
for (var y = 0; y < 8; y++) {
coeff += samples[y*8 + x] *
Math.cos(Math.PI * u * (2*x + 1) / 16) *
Math.cos(Math.PI * v * (2*y + 1) / 16);
}
}
if (u === 0)
coeff /= Math.sqrt(2);
if (v === 0)
coeff /= Math.sqrt(2);
coeffs.push(Math.round(coeff / 4));
}
}
return coeffs;
}
function inverseDCT(coeffs) {
const samples = new Array(64).fill(0);
for (var x = 0; x < 8; x++) {
for (var y = 0; y < 8; y++) {
var sample = 0;
for (var u = 0; u < 8; u++) {
const cu = (u === 0) ? (1 / Math.sqrt(2)) : 1;
for (var v = 0; v < 8; v++) {
const cv = (v === 0) ? (1 / Math.sqrt(2)) : 1;
var coefficient = coeffs[v*8 + u];
if (coefficient === 0)
continue;
sample += cu * cv * coefficient *
Math.cos(Math.PI * u * (2*x + 1) / 16) *
Math.cos(Math.PI * v * (2*y + 1) / 16);
}
}
samples[y*8 + x] = Math.round(sample / 4);
}
}
return samples;
}
drawPic();
</script>
<p>Here's a hint about one interesting thing to look for. If you try disabling various coefficients in the images, which coefficients generally seem to have a smaller effect on the picture? (Those at the top of the matrix, the left, the right, the bottom, or towards a certain corner?) This has a lot to do with JPEG compression. More on that towards the end of the post...</p>
<hr>
<p>If you've made it this far, you might want to know how the DCT and IDCT are calculated. (<a href="#why-useful">Don't care about the math? Feel free to skip it.</a>) For simplicity, we'll stick to eight pixel values in one dimension (a row or column). The two-dimensional transforms are very similar.</p>
<p>First, the DCT, which converts pixel values to coefficients:</p>
<!-- TeX: S_{u} = {1 \over 2} C_u \sum_{x=0}^7 s_x\ cos {(2x + 1)\pi u \over 16} -->
<span class="katex-display"><span class="katex">
<span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.83333em; vertical-align: -0.15em;"></span>
<span class="mord Su"><span class="mord mathnormal" style="margin-right: 0.05764em;">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span class="" style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span>
<span class="mspace" style="margin-right: 0.277778em;"></span>
<span class="mrel">=</span>
<span class="mspace" style="margin-right: 0.277778em;"></span>
</span>
<span class="base"><span class="strut" style="height: 3.06823em; vertical-align: -1.26711em;"></span>
<span class="mord half"><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span class="" style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">2</span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.686em;"><span class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span>
<span class="mord Cu"><span class="mord mathnormal" style="margin-right: 0.07153em;">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span class="" style="top: -2.55em; margin-left: -0.07153em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span>
<span class="mspace" style="margin-right: 0.166667em;"></span>
<span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.80111em;"><span class="" style="top: -1.88289em; margin-left: 0em;"><span class="pstrut" style="height: 3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight x">x</span><span class="mrel mtight">=</span><span class="mord mtight">0</span></span></span></span><span class="" style="top: -3.05001em;"><span class="pstrut" style="height: 3.05em;"></span><span class=""><span class="mop op-symbol large-op sum">∑</span></span></span><span class="" style="top: -4.30001em; margin-left: 0em;"><span class="pstrut" style="height: 3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">7</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 1.26711em;"><span class=""></span></span></span></span></span>
<span class="mspace" style="margin-right: 0.166667em;"></span>
<span class="mord pxval"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">x</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span>
<span class="mspace"> </span>
<span class="mord mathnormal">cos</span>
<span class="mord"><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.427em;"><span class="" style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord freq">16</span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mopen">(</span><span class="mord freq">2</span><span class="mord mathnormal x">x</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord freq">1</span><span class="mclose">)</span><span class="mord mathnormal freq">π</span><span class="mord mathnormal u" style="margin-right: 0.03588em;">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.686em;"><span class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span>
</span></span></span></span>
<p>where</p>
<!-- TeX: C_0 = {1 \over \sqrt 2}, C_{1-7} = 1 -->
<span class="katex-display"><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.83333em; vertical-align: -0.15em;"></span><span class="mord Cu"><span class="mord mathnormal" style="margin-right: 0.07153em;">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.07153em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 2.25144em; vertical-align: -0.93em;"></span><span class="mord"><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span class="" style="top: -2.20278em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.90722em;"><span class="svg-align" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord" style="padding-left: 0.833em;">2</span></span><span class="" style="top: -2.86722em;"><span class="pstrut" style="height: 3em;"></span><span class="hide-tail" style="min-width: 0.853em; height: 1.08em;"><svg width="400em" height="1.08em" viewBox="0 0 400000 1080" preserveAspectRatio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.13278em;"><span class=""></span></span></span></span></span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.93em;"><span class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord Cu"><span class="mord mathnormal" style="margin-right: 0.07153em;">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.07153em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span><span class="mbin mtight">−</span><span class="mord mtight">7</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span>
<p>That's quite a mouthful in English: to calculate <span class="Su">the DCT coefficient <i>u</i></span>, <span class="sum">loop over</span> all eight pixels and <span class="sum">sum up</span>: the <span class="pxval">pixel value</span> times the cosine of: <span class="freq">twice</span> <span class="x">the pixel's <i>x</i> coordinate</span> plus <span class="freq">one</span>, times <span class="freq">π</span>, times <span class="u"><i>u</i></span>, divided by <span class="freq">16</span>. <span class="half">Halve the total.</span> Further, <span class="Cu">if <i>u</i> is zero, divide the total again by root 2</span>.</p>
<p>Or if you speak JavaScript:</p>
<div class="highlight"><pre class="highlight"><code>const coefficients = [];
for (var <span class="u">u</span> = 0; <span class="u">u</span> < 8; <span class="u">u</span>++) {
var <span class="Su">coeff</span> = 0;
<span class="sum">for (var <span class="x">x</span> = 0; <span class="x">x</span> < 8; <span class="x">x</span>++)</span>
<span class="Su">coeff</span> <span class="sum">+=</span> <span class="pxval">pixelValues[x]</span> * Math.cos(((<span class="freq">2</span> * <span class="x">x</span>) + <span class="freq">1</span>) * <span class="freq">Math.PI</span> * <span class="u">u</span> / <span class="freq">16</span>);
<span class="half">coeff /= 2;</span>
<span class="Cu">if (u === 0)
coeff /= Math.sqrt(2);</span>
coefficients.push(<span class="Su">coeff</span>);
}</code></pre></div>
<!-- TeX: (2x + 1)\pi / 16 -->
<p>Note that <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord">2</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord">1</span><span class="mclose">)</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="mord">/16</span></span></span></span> ranges from just above zero to just below π; that is, half a complete cycle for the cosine function. So when <i>u</i> is one, the component wave only makes half a cycle as <i>x</i> moves from zero to seven. If <i>u</i> is two, then the input to the cosine function increments twice as fast, and a full cycle is made. Each increment of <i>u</i> increases the frequency of the component wave by 0.5Hz.</p>
<p>On the other hand, when <i>u</i> is zero, the cosine always evaluates to one, and we are essentially just summing up the eight pixel values.</p>
<p>In other words, coefficient <i>u</i> expresses how well the eight values "fit" a (0.5<i>u</i>)Hz cosine wave. Each value which is positive where the cosine wave is positive (or negative where it is negative) increases the coefficient. Each value whose sign is opposite to that of the cosine wave at its location on the x-axis decreases the coefficient. If the coefficient is large and positive, the sequence of samples closely fits a cosine wave; if it is large and negative, the samples are close to the <i>opposite</i> of a cosine wave (that is, a cosine wave shifted by 180 degrees).</p>
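<p>As a quick sanity check on that "fit" idea, here is a small sketch. It wraps the loop shown above in a function (the name <code>dct1d</code> is just an illustrative choice) and feeds it eight made-up samples which lie exactly on the <i>u</i> = 2 cosine wave, then the same samples negated. If the reasoning above holds, only coefficient 2 should be significantly different from zero, and it should flip sign for the negated samples.</p>

```javascript
// One-dimensional DCT of eight values, following the loop shown earlier.
function dct1d(pixelValues) {
  const coefficients = [];
  for (let u = 0; u < 8; u++) {
    let coeff = 0;
    for (let x = 0; x < 8; x++)
      coeff += pixelValues[x] * Math.cos(((2 * x) + 1) * Math.PI * u / 16);
    coeff /= 2;
    if (u === 0)
      coeff /= Math.sqrt(2);
    coefficients.push(coeff);
  }
  return coefficients;
}

// Eight samples lying exactly on the u = 2 component wave, with amplitude 100
const wave = [];
for (let x = 0; x < 8; x++)
  wave.push(100 * Math.cos(((2 * x) + 1) * Math.PI * 2 / 16));

const fit = dct1d(wave);                      // samples match the wave
const opposite = dct1d(wave.map((s) => -s));  // samples are the wave flipped upside down

// Coefficient 2 should be large and positive; the rest should be close to zero
console.log(fit.map((c) => c.toFixed(3)));
// Coefficient 2 should now be large and negative; the rest still close to zero
console.log(opposite.map((c) => c.toFixed(3)));
```

<p>With this scaling, an amplitude of 100 should show up as a coefficient of about 200, since the sum of the squared cosine over the eight sample positions is 4 and the total is then halved.</p>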
<p>The Inverse DCT is very similar. Shall we stick to JavaScript this time?</p>
<div class="highlight"><pre class="highlight"><code>const pixelValues = [];
for (var <span class="x">x</span> = 0; <span class="x">x</span> < 8; <span class="x">x</span>++) {
var <span class="pxval">value</span> = <span class="Su">coefficients[0]</span> <span class="Cu">/ Math.sqrt(2)</span>;
<span class="sum">for (var <span class="u">u</span> = 1; <span class="u">u</span> < 8; <span class="u">u</span>++)</span>
<span class="pxval">value</span> <span class="sum">+=</span> <span class="Su">coefficients[u]</span> * Math.cos(((<span class="freq">2</span> * <span class="x">x</span>) + <span class="freq">1</span>) * <span class="freq">Math.PI</span> * <span class="u">u</span> / <span class="freq">16</span>);
<span class="half">value /= 2;</span>
pixelValues.push(<span class="pxval">value</span>);
}</code></pre></div>
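<p>To check that the two loops above really are inverses of each other, here is a rough round-trip sketch. The function names <code>dct1d</code> and <code>idct1d</code> are illustrative wrappers around those loops, not names used elsewhere on this page. The sample row is the first row of the <code>face1</code> image from the demo above, shifted down by 128 in the same way the demo does before transforming.</p>

```javascript
// Forward transform: eight pixel values -> eight coefficients
function dct1d(pixelValues) {
  const coefficients = [];
  for (let u = 0; u < 8; u++) {
    let coeff = 0;
    for (let x = 0; x < 8; x++)
      coeff += pixelValues[x] * Math.cos(((2 * x) + 1) * Math.PI * u / 16);
    coeff /= 2;
    if (u === 0)
      coeff /= Math.sqrt(2);
    coefficients.push(coeff);
  }
  return coefficients;
}

// Inverse transform: eight coefficients -> eight pixel values
function idct1d(coefficients) {
  const pixelValues = [];
  for (let x = 0; x < 8; x++) {
    let value = coefficients[0] / Math.sqrt(2);
    for (let u = 1; u < 8; u++)
      value += coefficients[u] * Math.cos(((2 * x) + 1) * Math.PI * u / 16);
    value /= 2;
    pixelValues.push(value);
  }
  return pixelValues;
}

const row = [72, 83, 57, 41, 42, 51, 71, 68];     // first row of face1
const shifted = row.map((s) => s - 128);           // center the values around zero
const coefficients = dct1d(shifted);
const restored = idct1d(coefficients).map((s) => Math.round(s + 128));

console.log(coefficients.map((c) => c.toFixed(2)));
console.log(restored); // should match `row` again
```

<p>No rounding is applied to the coefficients here, so the round trip should be lossless apart from floating-point error. Rounding the coefficients to integers, as the demo does, may nudge individual restored values slightly.</p>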
<hr>
<p>Let me share another thing which I find fascinating about the DCT. Actually, no; let me show you and see if you can recognize it yourself.</p>
<p>Earlier I showed you waveform graphs demonstrating how the DCT converts a sequence of discrete color samples to a sum of sinusoid waves. The graphs were bounded tightly around the eight samples on the X-axis. This time let's stretch out the X-axis and let the waves carry on to the left and right. I will draw grey dots at evenly spaced intervals, so you can see if the same pattern repeats itself every eight time units or not.</p>
<div class="pic-container">
<div class="pic-spacer"></div>
<div class="pic" id="pic5"></div>
</div>
<div class="wave-container">
<div class="wave-spacer"></div>
<div class="fill-parent">
<canvas id="waveform3" class="wave"></canvas>
</div>
</div>
<script>
'use strict';
const pic5 = document.getElementById('pic5');
const px5 = initPixelGrid(pic5);
setPixelValues(px5, coffeeCup);
const wave3 = document.getElementById('waveform3');
const ctx3 = wave3.getContext('2d');
function graphDrawManyVerticalTicks(canvas, ctx) {
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
ctx.lineWidth = 1;
const xRange = xMax - xMin;
for (var i = 0; i < 40; i++) {
ctx.strokeStyle = (i % 8 === 0) ? 'rgb(60,100,200)' : vertTickColor;
ctx.beginPath();
ctx.moveTo(xMin + (xRange * (i + 0.5) / 40), yMin);
ctx.lineTo(xMin + (xRange * (i + 0.5) / 40), yMax);
ctx.stroke();
}
}
function graphTargetDotsExtended(canvas, ctx, values) {
const xMin = 10, xMax = canvas.width - 10, yMin = 10, yMax = canvas.height - 10;
const xRange = xMax - xMin,
yRange = yMax - yMin,
yMid = (yMax + yMin) / 2,
dotSize = 4.5;
const reversed = Array.from(values).reverse();
values = values.concat(reversed).concat(values).concat(reversed).concat(values);
for (var i = 0; i < 40; i++) {
ctx.fillStyle = (i >= 16 && i < 24) ? 'black' : '#888888';
ctx.beginPath();
const x = xMin + (xRange * (i + 0.5) / 40);
const y = yMid - (values[i] * yRange / 2);
ctx.arc(x, y, dotSize, 0, Math.PI * 2);
ctx.fill();
}
}
function drawExtendedCosineGraph(canvas, ctx, values, coeffs, dotLocations) {
ctx.clearRect(0, 0, canvas.width, canvas.height);
graphDrawManyVerticalTicks(canvas, ctx);
graphDrawZeroLine(canvas, ctx);
graphCosines(canvas, ctx, [0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5], coeffs, graphColors);
graphSumOfCosines(canvas, ctx, [0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5], coeffs, graphColors);
graphTargetDotsExtended(canvas, ctx, values);
}
function updateExtendedCosineGraph(canvas, ctx, samples) {
var shifted = samples.map((s) => s - 128);
const coeffs = discreteCosTransform1D(shifted).map((c) => c / 420);
coeffs[0] /= Math.sqrt(2);
shifted = shifted.map((s) => s / 210);
drawExtendedCosineGraph(canvas, ctx, shifted, coeffs);
}
(function() {
const image = Array.from(coffeeCup);
var samples;
autosizeGraph(wave3, function() {
if (samples) {
updateExtendedCosineGraph(wave3, ctx3, samples);
} else {
graphDrawManyVerticalTicks(wave3, ctx3);
graphDrawZeroLine(wave3, ctx3);
}
});
onPixelClick(px5, function(i, px) {
selectRow(pic5, Math.floor(i / 8));
const index = Math.floor(i / 8) * 8;
samples = image.slice(index, index + 8);
updateExtendedCosineGraph(wave3, ctx3, samples);
});
})();
</script>
<p>Look carefully at the pattern created when the waves derived from the DCT are extended to the left and right. Is it simply repeating the same pattern every eight time units? Or...what?</p>
<div class="reveal hidden">
<p>Every eight time units, the pattern is mirrored left to right. It appears left to right, then right to left, then left to right, and so on. The effect is such that it actually repeats every 16 time units, not every eight.</p>
<p>One of the secrets of the DCT, revealed! The DCT is equivalent to taking our eight samples, concatenating them with the same eight samples (reversed), and doing a discrete Fourier transform on that concatenation. No wonder the frequencies of the resulting sinusoids increment in 0.5F steps; our "Fourier transform" is working on twice as many samples.</p>
<p>Of course, the DCT is much faster to compute than a DFT on twice as many samples.</p>
</div>
<p>This is part of why the DCT is useful for image compression. When an image is broken down into blocks of pixels, the color of the left edge of a block will often be different from the right edge, and likewise for the top and bottom edges. If we used a discrete Fourier transform on those color samples, the discontinuity between the colors of opposing edges would tend to produce strong high-frequency component waves. (When breaking a waveform down into sinusoids, any sharp "jumps" result in strong high-frequency components.) But since the DCT, in effect, butts the block up with a mirror image of itself on each side, that discontinuity doesn't exist, and the high-frequency components will usually be much weaker.</p>
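<p>If you would like to check the mirror-image claim numerically, here is a minimal sketch. It is not the code which drives the graphs on this page; the function names <code>dct2</code>, <code>dft</code>, and <code>dctViaMirroredDft</code> are made up for illustration, the DCT is left unnormalized, and the DFT is the naive quadratic-time version rather than an FFT. It computes the DCT of eight samples directly, then again by taking a DFT of the 16-sample mirrored sequence and undoing a half-sample phase shift.</p>

```javascript
'use strict';

/* Unnormalized 1D DCT (type II):
 * C[k] = sum over n of x[n] * cos((2n + 1) * k * PI / 2N) */
function dct2(samples) {
  const N = samples.length;
  const out = [];
  for (let k = 0; k < N; k++) {
    let sum = 0;
    for (let n = 0; n < N; n++)
      sum += samples[n] * Math.cos(((2 * n) + 1) * k * Math.PI / (2 * N));
    out.push(sum);
  }
  return out;
}

/* Naive discrete Fourier transform; returns an array of [real, imaginary] pairs */
function dft(samples) {
  const M = samples.length;
  const out = [];
  for (let k = 0; k < M; k++) {
    let re = 0, im = 0;
    for (let n = 0; n < M; n++) {
      const angle = -2 * Math.PI * k * n / M;
      re += samples[n] * Math.cos(angle);
      im += samples[n] * Math.sin(angle);
    }
    out.push([re, im]);
  }
  return out;
}

/* Same coefficients as dct2, but computed from a DFT of the samples
 * followed by the same samples reversed */
function dctViaMirroredDft(samples) {
  const N = samples.length;
  const mirrored = samples.concat(Array.from(samples).reverse());
  const spectrum = dft(mirrored);
  const out = [];
  for (let k = 0; k < N; k++) {
    const [re, im] = spectrum[k];
    /* Multiply by e^(-i * PI * k / 2N) to undo the half-sample shift,
     * keep the real part, and halve it (each sample was counted twice) */
    const phase = -Math.PI * k / (2 * N);
    out.push(((re * Math.cos(phase)) - (im * Math.sin(phase))) / 2);
  }
  return out;
}

/* Arbitrary example samples, shifted to be centered on zero */
const exampleSamples = [52, 55, 61, 66, 70, 61, 64, 73].map((s) => s - 128);
const directCoeffs = dct2(exampleSamples);
const mirroredCoeffs = dctViaMirroredDft(exampleSamples);
const largestDifference = Math.max(...directCoeffs.map((c, k) => Math.abs(c - mirroredCoeffs[k])));
console.log('Largest difference between the two methods:', largestDifference);
```

<p>The two sets of coefficients should agree up to floating-point rounding error. Only the first eight of the 16 DFT outputs are needed, because the mirrored input is symmetric and the rest carry no new information.</p>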
<hr id="why-useful">
<p>I still haven't told you the main reason why the DCT is useful for image compression. First, understand that not all the information contained in an image is equally important or noticeable to a human viewer. It happens that converting color samples to the frequency domain concentrates the information which is most detectable by our visual system in the coefficients at the <i>top-left</i> of the DCT matrix. Conversely, the information which is least perceptible to our visual system is concentrated in the coefficients at the bottom-right.</p>
<p>In this way, the DCT sets things up for subsequent stages of compression to work their fullest effect. First, <b>quantization</b>. This stage throws away part of the data in the less-significant bits of the DCT coefficients. Since we know that the values of the coefficients towards the bottom-right have less of an effect on what we see, those can be heavily quantized, while retaining more bits of the coefficients toward the top-left. That means we can discard a significant amount of data with little effect on visual quality.</p>
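<p>To make that concrete, here is a rough sketch of quantization. The step sizes are invented for illustration (they simply grow with distance from the top-left corner) and are <i>not</i> the quantization tables used by any real JPEG encoder; the names <code>makeStepTable</code>, <code>quantize</code>, and <code>dequantize</code> are likewise made up. Each coefficient is divided by its step size and rounded, so a bigger step means fewer bits are kept.</p>

```javascript
'use strict';

/* Build a hypothetical table of step sizes for a size-by-size block, in
 * row-major order. Steps get larger toward the bottom-right corner.
 * These values are for illustration only. */
function makeStepTable(size, base, slope) {
  const steps = [];
  for (let v = 0; v < size; v++)
    for (let u = 0; u < size; u++)
      steps.push(base + (slope * (u + v)));
  return steps;
}

/* Lossy step: divide each coefficient by its step size and round */
function quantize(coeffs, steps) {
  return coeffs.map((c, i) => Math.round(c / steps[i]));
}

/* What a decoder would do: multiply back by the step size.
 * The rounding error introduced by quantize() is not recovered. */
function dequantize(levels, steps) {
  return levels.map((q, i) => q * steps[i]);
}

const exampleSteps = makeStepTable(8, 4, 6);
/* Give every coefficient the same made-up value to see how position matters */
const exampleCoeffs = new Array(64).fill(31);
const exampleLevels = quantize(exampleCoeffs, exampleSteps);
console.log('Top-left level:', exampleLevels[0], 'Bottom-right level:', exampleLevels[63]);
```

<p>With these made-up numbers, a coefficient of 31 survives at the top-left (where the step is small) but rounds to zero at the bottom-right (where the step is large). After dequantizing, each coefficient is off by at most half of its step size.</p>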
<p>The DCT works synergistically with quantization and the <a href="https://en.wikipedia.org/wiki/JPEG#/media/File:JPEG_ZigZag.svg">zig-zag ordering</a> of coefficients to make the final <b>entropy coding</b> stage more effective. This stage applies a lossless compression algorithm to the quantized coefficients.</p>
<p>Many images will have smaller coefficient values toward the bottom right of the matrix, and after quantization is applied, these may become zeroes. So the <a href="https://en.wikipedia.org/wiki/JPEG#/media/File:JPEG_ZigZag.svg">zig-zag ordering</a> of coefficients will tend to produce runs of zeroes toward the end of each block. Those consecutive zeroes can then be represented using an efficient run-length encoding.</p>
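<p>Here is a small sketch of how zig-zag ordering and run-length encoding fit together. The function names are hypothetical, and the run-length scheme is a simplification: real JPEG entropy coding combines the zero-run lengths with Huffman codes rather than emitting plain pairs like these. <code>zigZagOrder</code> walks the anti-diagonals of the block, alternating direction, and <code>runLengthEncode</code> emits one <code>[zeroes skipped, value]</code> pair per non-zero value, followed by an end-of-block marker if only zeroes remain.</p>

```javascript
'use strict';

/* Return the row-major indices of a size-by-size block in zig-zag order.
 * Each anti-diagonal holds the cells where row + column is the same;
 * even-numbered diagonals are walked upward, odd-numbered ones downward. */
function zigZagOrder(size) {
  const order = [];
  for (let s = 0; s <= 2 * (size - 1); s++) {
    const rowMin = Math.max(0, s - (size - 1));
    const rowMax = Math.min(s, size - 1);
    if (s % 2 === 0) {
      for (let row = rowMax; row >= rowMin; row--)
        order.push((row * size) + (s - row));
    } else {
      for (let row = rowMin; row <= rowMax; row++)
        order.push((row * size) + (s - row));
    }
  }
  return order;
}

/* Simplified run-length encoding (not the exact JPEG bitstream format) */
function runLengthEncode(sequence) {
  /* Find the last non-zero value; everything after it becomes 'EOB' */
  let last = sequence.length - 1;
  while (last >= 0 && sequence[last] === 0)
    last--;
  const pairs = [];
  let zeroRun = 0;
  for (let i = 0; i <= last; i++) {
    if (sequence[i] === 0) {
      zeroRun++;
    } else {
      pairs.push([zeroRun, sequence[i]]);
      zeroRun = 0;
    }
  }
  if (last < sequence.length - 1)
    pairs.push('EOB');
  return pairs;
}

/* A made-up block of quantized coefficients, non-zero only near the top-left */
const exampleBlock = new Array(64).fill(0);
exampleBlock[0] = 50;
exampleBlock[1] = -3;
exampleBlock[8] = 2;
exampleBlock[9] = 1;
const scanned = zigZagOrder(8).map((index) => exampleBlock[index]);
console.log(JSON.stringify(runLengthEncode(scanned)));
/* prints [[0,50],[0,-3],[0,2],[1,1],"EOB"] */
```

<p>The 64 values of the block collapse to four pairs and a marker, because the zig-zag scan visits the non-zero values near the top-left first and leaves one long run of zeroes at the end.</p>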
<p>Interestingly, both the <a href="https://en.wikipedia.org/wiki/WebP">WebP</a> and AVIF compressed image formats also transform color samples into the frequency domain. Both can use either the Discrete Cosine Transform or a different transform which serves a similar purpose.</p>
<p>The other popular compressed image formats are PNG and GIF. Neither of these transforms samples into the frequency domain.</p>
<p>The <a href='/huffman-coding/'>next post in this series</a> will explore Huffman coding, a lossless compression algorithm which is another key ingredient of JPEG.</p>
<script>
'use strict';
function revealText(event) {
event.preventDefault();
this.classList.remove('hidden');
this.removeEventListener('click', revealText);
}
document.querySelectorAll('.reveal').forEach((el) => el.addEventListener('click', revealText));
</script>
Peering into the Linux Kernel with trace2020-06-04T00:00:00+00:002020-06-04T00:00:00+00:00/peering-in-the-kernel-with-trace<p>Recently, I was working on a patch for a popular open-source project, and discovered that the test suite was failing intermittently. A closer look revealed that the <a href="https://en.wikipedia.org/wiki/Stat_(system_call)">last access times</a> for some files in the project folder were changing unexpectedly, and this was causing a test to fail. (The failing test was not related to my patch.)</p>
<p>Looking at the project code, it seemed impossible for it to be unexpectedly accessing those files during the test in question. Running the test case under <a href="https://strace.io/"><code class="language-plaintext highlighter-rouge">strace</code></a> confirmed that this was not happening. But incontrovertibly, the access times <em>were</em> changing. Could another process on the same machine be reading those files? But why? Could it be a bug in the operating system? Were my tools lying to me?</p>
<p>Faced with a puzzle like this, the inclination might be to shrug one’s shoulders and forget about it, perhaps with a dismissive remark about the general brokenness of most software. (I’ve done that many times.) Anyways, it wasn’t <em>my</em> code which was failing. And yet, it seemed prudent to clear up the mystery, rather than bumbling along and <em>hoping</em> that what I didn’t know wouldn’t hurt me.</p>
<p>This seemed like a good opportunity to try out the <a href="https://iovisor.github.io/bcc/">BCC tools</a>. This is a powerful suite for examining and monitoring Linux kernel activity in real-time. Support is built in to the kernel (starting from 4.1), so you can immediately investigate when a problem is occurring, without needing to install a special kernel or reboot with special boot parameters.</p>
<p>One of the more than 100 utilities included in the BCC tools is <code class="language-plaintext highlighter-rouge">trace</code>. Using this program, one can monitor when <em>any</em> function in the kernel is called, what arguments it receives, what processes are causing those calls, and so on. Having <code class="language-plaintext highlighter-rouge">trace</code> is really like having a superpower.</p>
<p>Of course, the argument(s) of interest might not just be integers or strings. They might be pointers to C structs, which might contain pointers to other structs, and so on… but <code class="language-plaintext highlighter-rouge">trace</code> still has you covered. If you point it to the appropriate C header files which your kernel was compiled with, it can follow those pointers, pick out fields of interest, and print them at the console. (The header files enable <code class="language-plaintext highlighter-rouge">trace</code> to figure out the layout of those structs in memory.)</p>
<p>The invocation of <code class="language-plaintext highlighter-rouge">trace</code> which did the job for me turned out to be:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>sudo /usr/share/bcc/tools/trace -I/home/alex/Programming/linux/include/linux/path.h -I/home/alex/Programming/linux/include/linux/dcache.h 'touch_atime(struct path *path) "%s", path->dentry->d_name.name'
</code></pre></div></div>
<p>That says that every time a function called <code class="language-plaintext highlighter-rouge">touch_atime</code> (with parameter <code class="language-plaintext highlighter-rouge">struct path *path</code>) is called in the kernel, I want to see the string identified by the C expression <code class="language-plaintext highlighter-rouge">path->dentry->d_name.name</code>. In response, <code class="language-plaintext highlighter-rouge">trace</code> prints out a stream of messages like:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>2135 2135 sublime_text touch_atime ld.so.cache
2076 2076 chrome touch_atime
2494 2497 Chrome_ChildIOT touch_atime
1071 1071 Xorg touch_atime
2135 2135 sublime_text touch_atime Default.sublime-package
1566 1566 pulseaudio touch_atime
</code></pre></div></div>
<p>As you can see, it very helpfully shows some additional information for each call. From the left, that is the process ID, thread ID, command, function name, and then the requested string. Piping that into ripgrep revealed (within minutes) that my text editor had a background thread which was scanning the project files for changes, as part of its git integration. <em>That</em> is what was updating the access times and causing the erratic test failures.</p>
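<p>I used ripgrep for the filtering, but the same idea can be sketched in a few lines of JavaScript. This is only an illustration: <code>parseTraceLine</code> and <code>countByCommand</code> are made-up names, and the parser assumes that the command name contains no whitespace and that each record has the five columns shown above. Any line which does not start with two numbers (a header line, for example) is skipped.</p>

```javascript
'use strict';

/* Split one line of trace output into its columns.
 * Assumes: PID, TID, command, function name, then the rest of the line. */
function parseTraceLine(line) {
  const match = /^\s*(\d+)\s+(\d+)\s+(\S+)\s+(\S+)\s*(.*)$/.exec(line);
  if (!match)
    return null;
  return {
    pid: Number(match[1]),
    tid: Number(match[2]),
    command: match[3],
    func: match[4],
    message: match[5].trim()
  };
}

/* Count, per command, the records whose message matches the given pattern */
function countByCommand(lines, pattern) {
  const counts = new Map();
  for (const line of lines) {
    const record = parseTraceLine(line);
    if (record && pattern.test(record.message))
      counts.set(record.command, (counts.get(record.command) || 0) + 1);
  }
  return counts;
}

/* The sample output shown above */
const sampleOutput = [
  '2135 2135 sublime_text touch_atime ld.so.cache',
  '2076 2076 chrome touch_atime',
  '2494 2497 Chrome_ChildIOT touch_atime',
  '1071 1071 Xorg touch_atime',
  '2135 2135 sublime_text touch_atime Default.sublime-package',
  '1566 1566 pulseaudio touch_atime'
];
for (const [command, count] of countByCommand(sampleOutput, /\.sublime-package$/))
  console.log(command, count);
```

<p>In practice you would feed it the live output of <code>trace</code> and a pattern matching the files whose access times are changing; whichever command shows up in the counts is the culprit.</p>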
<p>What a difference it makes to be able to directly look inside a system and see what it is doing, instead of blindly groping using trial and error! This was the first time I harnessed the formidable power of <code class="language-plaintext highlighter-rouge">trace</code>, but it won’t be the last. It has a permanent home in my debugging toolbox now.</p>
<p><a href="http://www.catb.org/~esr/writings/taoup/html/ch01s06.html#id2878054">Eric Raymond’s “Rule of Transparency”</a> sagely advises programmers: “Design for visibility to make inspection and debugging easier”. You said it, Eric, you said it.</p>
<p><strong><em>⸻But how did you know the function to trace was touch_atime?</em></strong></p>
<p>Just poking around in the kernel source a bit. I knew there should be a function somewhere in the <code class="language-plaintext highlighter-rouge">fs</code> subfolder, and grepped for functions with <code class="language-plaintext highlighter-rouge">atime</code> in their name. There are just a few, and <code class="language-plaintext highlighter-rouge">touch_atime</code> almost jumped out. Reading the code confirmed that it was the right one.</p>
<p><strong><em>⸻OK. So how does <code class="language-plaintext highlighter-rouge">trace</code> work under the hood?</em></strong></p>
<p>First, it parses the “probe specifications” which you provide, converts them to a little C program, and uses BCC to convert that C program into eBPF bytecode. (The VM which runs this bytecode is built in to the Linux kernel.) A special system call is used to load the bytecode into the kernel.</p>
<p><span id="kprobe">Next, it registers a <strong>kprobe</strong> with the kernel. The “kprobe” mechanism allows arbitrary callbacks to be associated with almost any function (actually, any machine instruction) in the kernel binary, which will fire whenever that instruction is executed. When a kprobe is registered, the kernel stores the original instruction somewhere and overwrites it with a breakpoint instruction (such as an <code class="language-plaintext highlighter-rouge">INT3</code> instruction on x86). Then it sets things up so that when the breakpoint fires, all the callbacks will be executed. Of course, the instruction which was overwritten will also be executed, so as not to break the function which is being traced.</span></p>
<p>There are a couple of different APIs which user programs can use to create kprobes; one of them involves writing some specially formatted data to a “magic” file called <code class="language-plaintext highlighter-rouge">/sys/kernel/debug/tracing/kprobe_events</code>.</p>
<p>Then <code class="language-plaintext highlighter-rouge">trace</code> uses another API to tell the kernel to use the previously loaded eBPF bytecode as a callback for the new kprobe. Then it uses another API to get a file descriptor from the kernel, from which it can read the output generated by the BPF program.</p>
<p>It’s an intricate mechanism, but very, very flexible. Just thinking of the possibilities boggles the mind…</p>