import { NDArray } from './ndarray_core.js';
import { NDWasm, fromWasm } from './ndwasm.js';
/**
* NDWasmBlas: BLAS (Basic Linear Algebra Subprograms) routines.
* Handles matrix-vector (O(n^2)) and matrix-matrix (O(n^3)) operations,
* plus O(n) utilities such as the matrix trace.
* @namespace NDWasmBlas
*/
export const NDWasmBlas = {
/**
* Calculates the trace of a 2D square matrix (sum of diagonal elements).
* Complexity: O(n)
* @param {NDArray} a
* @returns {number} The sum of the diagonal elements.
* @throws {Error} If the array is not 2D or not a square matrix.
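* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory,
* // not part of this module.
* const A = NDArray.from([[1, 2], [3, 4]]);
* NDWasmBlas.trace(A); // 1 + 4 = 5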
*/
trace(a) {
// 1. Validation: Must be 2D
if (a.ndim !== 2) {
throw new Error(`Trace is only defined for 2D matrices, but this array is ${a.ndim}D.`);
}
// 2. Validation: Must be square (N x N)
const rows = a.shape[0];
const cols = a.shape[1];
if (rows !== cols) {
throw new Error(`Trace is only defined for square matrices, but this matrix is ${rows}x${cols}.`);
}
let sum = 0;
const data = a.data;
const offset = a.offset;
const s0 = a.strides[0]; // Row stride
const s1 = a.strides[1]; // Column stride
// 3. Calculation: Sum A[i, i]
// Memory index for element [i, i] is: offset + i * strides[0] + i * strides[1]
// We can optimize the loop by pre-calculating the step size: s0 + s1
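// Example (assuming a C-contiguous float matrix with strides in elements):
// a 3x3 matrix has strides [3, 1], so step = 4 and the diagonal elements
// sit at offsets 0, 4, 8 from `offset`.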
const step = s0 + s1;
for (let i = 0; i < rows; i++) {
sum += data[offset + i * step];
}
return sum;
},
/**
* General Matrix Multiplication (GEMM): C = A * B.
* Complexity: O(m * n * k)
* @memberof NDWasmBlas
* @param {NDArray} a - Left matrix of shape [m, n].
* @param {NDArray} b - Right matrix of shape [n, k], or a 1D vector of shape [n] (treated as k = 1).
* @returns {NDArray} Result matrix of shape [m, k].
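* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory.
* const A = NDArray.from([[1, 2], [3, 4]]); // shape [2, 2]
* const B = NDArray.from([[5], [6]]);       // shape [2, 1]
* const C = NDWasmBlas.matMul(A, B);        // shape [2, 1]: [[17], [39]]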
*/
matMul(a, b) {
if (a.ndim !== 2) {
throw new Error(`Left operand must be a 2D matrix, but got ${a.ndim}D.`);
}
if (b.ndim !== 2 && b.ndim !== 1) {
throw new Error(`Right operand must be a 2D matrix (or 1D vector), but got ${b.ndim}D.`);
}
if (a.shape[1] !== b.shape[0]) {
throw new Error(`Matrix inner dimensions must match: ${a.shape[1]} != ${b.shape[0]}`);
}
const m = a.shape[0];
const n = a.shape[1];
const k = b.ndim === 2 ? b.shape[1] : 1;
const outShape = [m, k];
const suffix = NDWasm.runtime._getSuffix(a.dtype);
return NDWasm._compute([a, b], outShape, a.dtype, (aPtr, bPtr, outPtr) => {
return NDWasm.runtime.exports[`MatMul${suffix}`](aPtr, bPtr, outPtr, m, n, k);
});
},
/**
* Matrix Power: computes A^k for a square matrix A.
* Complexity: O(n^3) per matrix multiplication.
* @memberof NDWasmBlas
* @param {NDArray} a - Square matrix of shape [n, n].
* @param {number} k - Integer exponent.
* @returns {NDArray} Result matrix of shape [n, n].
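* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory.
* const A = NDArray.from([[1, 1], [0, 1]]);
* const A3 = NDWasmBlas.matPow(A, 3); // [[1, 3], [0, 1]]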
*/
matPow(a, k) {
if (a.shape[0] !== a.shape[1]) {
throw new Error(`Matrix must be square: ${a.shape[0]} != ${a.shape[1]}`);
}
const n = a.shape[0];
const outShape = [n, n];
const suffix = NDWasm.runtime._getSuffix(a.dtype);
return NDWasm._compute([a], outShape, a.dtype, (aPtr, outPtr) => {
return NDWasm.runtime.exports[`MatrixPower${suffix}`](aPtr, outPtr, n, k);
});
},
/**
* Batched Matrix Multiplication: C[i] = A[i] * B[i].
* Common in deep learning inference.
* Complexity: O(batch * m * n * k)
* @memberof NDWasmBlas
* @param {NDArray} a - Batch of matrices of shape [batch, m, n].
* @param {NDArray} b - Batch of matrices of shape [batch, n, k].
* @returns {NDArray} Result batch of shape [batch, m, k].
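* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory.
* const A = NDArray.from([[[1, 0], [0, 1]], [[2, 0], [0, 2]]]); // shape [2, 2, 2]
* const B = NDArray.from([[[1, 2], [3, 4]], [[1, 2], [3, 4]]]); // shape [2, 2, 2]
* const C = NDWasmBlas.matMulBatch(A, B); // C[0] = B[0], C[1] = 2 * B[1]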
*/
matMulBatch(a, b) {
if (a.ndim !== 3 || b.ndim !== 3 || a.shape[0] !== b.shape[0]) {
throw new Error("Input must be 3D batches with same batch size.");
}
if (a.shape[2] !== b.shape[1]) {
throw new Error("Batch matrix inner dimensions must match.");
}
const batch = a.shape[0];
const m = a.shape[1];
const n = a.shape[2];
const k = b.shape[2];
const outShape = [batch, m, k];
const suffix = NDWasm.runtime._getSuffix(a.dtype);
return NDWasm._compute([a, b], outShape, a.dtype, (aPtr, bPtr, outPtr) => {
return NDWasm.runtime.exports[`MatMulBatch${suffix}`](aPtr, bPtr, outPtr, batch, m, n, k);
});
},
/**
* Symmetric Rank-K product (SYRK): computes C = A * A^T.
* Used for efficiently computing covariance matrices or Gram matrices.
* Complexity: O(n^2 * k)
* @memberof NDWasmBlas
* @param {NDArray} a - Input matrix of shape [n, k].
* @returns {NDArray} Symmetric result matrix of shape [n, n].
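* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory.
* const A = NDArray.from([[1, 2], [3, 4]]); // shape [2, 2] (n = 2, k = 2)
* const G = NDWasmBlas.syrk(A); // A * A^T = [[5, 11], [11, 25]]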
*/
syrk(a) {
const n = a.shape[0];
const k = a.shape[1];
const outShape = [n, n];
const suffix = NDWasm.runtime._getSuffix(a.dtype);
return NDWasm._compute([a], outShape, a.dtype, (aPtr, outPtr) => {
return NDWasm.runtime.exports[`Syrk${suffix}`](aPtr, outPtr, n, k);
});
},
/**
* Triangular System Solver: Solves A * X = B for X, where A is a triangular matrix.
* Complexity: O(m^2 * n)
* @memberof NDWasmBlas
* @param {NDArray} a - Triangular matrix of shape [m, m].
* @param {NDArray} b - Right-hand side matrix of shape [m, n] (or vector of shape [m]).
* @param {boolean} [lower=false] - True if A is lower triangular; defaults to upper triangular.
* @returns {NDArray} Solution matrix X of shape [m, n].
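* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory,
* // and it assumes the default lower = false treats A as upper triangular.
* const A = NDArray.from([[2, 1], [0, 1]]);
* const B = NDArray.from([[5], [2]]);
* const X = NDWasmBlas.trsm(A, B); // solves A * X = B => [[1.5], [2]]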
*/
trsm(a, b, lower = false) {
if (a.shape[0] !== a.shape[1] || a.shape[0] !== b.shape[0]) {
throw new Error("Dimension mismatch for triangular solver.");
}
const m = a.shape[0];
const n = b.ndim === 1 ? 1 : b.shape[1];
const suffix = NDWasm.runtime._getSuffix(a.dtype);
// We bypass _compute here: the Go TRSM kernel solves in place, overwriting
// the B buffer, so b is copied into WASM memory and the solution is read
// back from that same buffer.
const wa = a.toWasm(NDWasm.runtime);
const wb = b.toWasm(NDWasm.runtime); // wb already contains b's data
try {
NDWasm.runtime.exports[`Trsm${suffix}`](wa.ptr, wb.ptr, m, n, lower ? 1 : 0);
return fromWasm(wb, b.shape, b.dtype);
} finally {
wa.dispose();
wb.dispose();
}
},
/**
* Matrix-Vector Multiplication: y = A * x.
* Complexity: O(m * n)
* @memberof NDWasmBlas
* @param {NDArray} a - Matrix of shape [m, n].
* @param {NDArray} x - Vector of shape [n].
* @returns {NDArray} Result vector of shape [m].
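* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory.
* const A = NDArray.from([[1, 2], [3, 4]]); // shape [2, 2]
* const x = NDArray.from([1, 1]);           // shape [2]
* const y = NDWasmBlas.matVecMul(A, x);     // shape [2]: [3, 7]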
*/
matVecMul(a, x) {
if (a.shape[1] !== x.size) {
throw new Error("Matrix-Vector dimension mismatch.");
}
const m = a.shape[0];
const n = a.shape[1];
const suffix = NDWasm.runtime._getSuffix(a.dtype);
return NDWasm._compute([a, x], [m], a.dtype, (aPtr, xPtr, outPtr) => {
return NDWasm.runtime.exports[`MatVecMul${suffix}`](aPtr, xPtr, outPtr, m, n);
});
},
/**
* Vector Outer Product (BLAS GER): computes the rank-1 matrix A = x * y^T.
* Complexity: O(m * n)
* @memberof NDWasmBlas
* @param {NDArray} x - Vector of shape [m].
* @param {NDArray} y - Vector of shape [n].
* @returns {NDArray} Result matrix of shape [m, n].
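* @example
* // Illustrative sketch; `NDArray.from` is a hypothetical nested-array factory.
* const x = NDArray.from([1, 2]); // shape [2]
* const y = NDArray.from([3, 4]); // shape [2]
* const A = NDWasmBlas.ger(x, y); // [[3, 4], [6, 8]]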
*/
ger(x, y) {
const m = x.size;
const n = y.size;
const outShape = [m, n];
const suffix = NDWasm.runtime._getSuffix(x.dtype);
return NDWasm._compute([x, y], outShape, x.dtype, (xPtr, yPtr, outPtr) => {
return NDWasm.runtime.exports[`Ger${suffix}`](xPtr, yPtr, outPtr, m, n);
});
}
};