ndwasmarray.js

/**
 * File: ndwasmarray.js
 * Responsibility: High-performance WASM-resident NDArray.
 * Handles explicit memory management within the WASM heap. 
 * Users must call .dispose() manually to free memory.
 */

import { NDArray } from './ndarray_core.js';
import { NDWasm } from './ndwasm.js';

/**
 * NDWasmArray
 */
export class NDWasmArray {
    /**
     * Wraps an existing WASM buffer; does not allocate or copy data.
     * @param {WasmBuffer} buffer - The WASM memory bridge (contains .ptr and .view).
     * @param {Int32Array|Array} shape - Dimensions of the array.
     * @param {string} dtype - Data type (e.g., 'float64').
     */
    constructor(buffer, shape, dtype) {
        this.buffer = buffer;
        this.shape = shape instanceof Int32Array ? shape : Int32Array.from(shape);
        this.dtype = dtype;
        this.ndim = this.shape.length;
        // The empty product is 1, so a 0-d (scalar) shape yields size 1.
        this.size = this.shape.reduce((a, b) => a * b, 1);
    }

    /**
     * Static factory: Creates an uninitialized WASM-resident array.
     * Must dispose() manually to free WASM memory.
     * @param {Array<number>|Int32Array} shape - Shape of the array.
     * @param {string} [dtype='float64'] - Data type.
     * @returns {NDWasmArray}
     * @throws {Error} If the WASM runtime has not been bound/loaded.
     */
    static newArray(shape, dtype = 'float64') {
        const runtime = requireRuntime();
        const size = Array.from(shape).reduce((a, b) => a * b, 1);
        const buffer = runtime.createBuffer(size, dtype);
        return new NDWasmArray(buffer, shape, dtype);
    }

    /**
     * Static factory: Creates a WASM-resident array.
     * 1. If source is an NDArray, it calls .push() to move it to WASM.
     * 2. If source is a JS Array, it allocates WASM memory and fills it directly
     *    via recursive traversal to avoid intermediate flattening. Nested arrays
     *    must be rectangular; ragged input is rejected and the allocation freed.
     * 3. If source is a single number, it creates a 1-element WASM array.
     * Must dispose() manually to free WASM memory.
     * @param {NDArray|Array|number} source - Source data.
     * @param {string} [dtype='float64'] - Data type.
     * @returns {NDWasmArray}
     * @throws {Error} If the runtime is not loaded, the nested array is ragged,
     *     or the source type is unsupported.
     */
    static fromArray(source, dtype = 'float64') {
        // Handle existing NDArray instance
        if (source instanceof NDArray) {
            return source.push();
        }

        const runtime = requireRuntime();

        // Handle standard JS Arrays
        if (Array.isArray(source)) {
            // 1. Infer shape by following the first element at each nesting level.
            const shape = [];
            for (let curr = source; Array.isArray(curr); curr = curr[0]) {
                shape.push(curr.length);
            }
            const size = shape.reduce((a, b) => a * b, 1);

            // 2. Pre-allocate memory in WASM heap
            const buffer = runtime.createBuffer(size, dtype);

            // 3. Populate WASM memory directly via recursive traversal, checking every
            //    sub-array against the inferred shape. Without this check, ragged input
            //    would leave part of the buffer unfilled or attempt writes past its end.
            let offset = 0;
            const leafDepth = shape.length - 1;
            const fill = (arr, depth) => {
                if (arr.length !== shape[depth]) {
                    throw new Error(`Ragged array: expected length ${shape[depth]} at depth ${depth}, got ${arr.length}.`);
                }
                for (const item of arr) {
                    if (depth === leafDepth) {
                        if (Array.isArray(item)) {
                            throw new Error(`Ragged array: unexpected nested array at depth ${depth + 1}.`);
                        }
                        buffer.view[offset++] = item;
                    } else {
                        if (!Array.isArray(item)) {
                            throw new Error(`Ragged array: expected a nested array at depth ${depth + 1}.`);
                        }
                        fill(item, depth + 1);
                    }
                }
            };

            try {
                fill(source, 0);
            } catch (err) {
                // Don't leak the WASM allocation when the input is rejected.
                buffer.dispose();
                throw err;
            }

            return new NDWasmArray(buffer, shape, dtype);
        }

        // Handle single numeric value
        if (typeof source === 'number') {
            const buffer = runtime.createBuffer(1, dtype);
            buffer.view[0] = source;
            return new NDWasmArray(buffer, [1], dtype);
        }

        throw new Error("Source must be an Array or an NDArray.");
    }

    /**
     * Pulls data from WASM to a JS-managed NDArray.
     * @param {boolean} [dispose=true] - Release WASM memory after pulling.
     * @returns {NDArray} A JS-side array with this array's shape and dtype.
     * @throws {Error} If the WASM memory has already been disposed.
     */
    pull(dispose = true) {
        if (!this.buffer) throw new Error("WASM memory already disposed.");

        const data = this.buffer.pull();
        const result = new NDArray(data, { shape: this.shape, dtype: this.dtype });

        if (dispose) this.dispose();
        return result;
    }

    /**
     * Manually releases WASM heap memory. Safe to call more than once.
     */
    dispose() {
        if (this.buffer) {
            this.buffer.dispose();
            this.buffer = null;
        }
    }

    /**
     * Matrix Multiplication: C = this * right
     * `this` has shape [m, n], `right` [n, k], `result` [m, k]; all three must be
     * live (not disposed) and share this array's dtype.
     * @param {NDWasmArray} right - Right operand.
     * @param {NDWasmArray} result - Pre-allocated result array.
     * @returns {NDWasmArray} the result array.
     * @throws {Error} On operand type, rank, dtype or shape mismatch, on disposed
     *     memory, or on a non-zero status from the WASM kernel.
     */
    matMul(right, result) {
        if (!(right instanceof NDWasmArray)) {
            throw new Error("Right operand must be an NDWasmArray.");
        }
        if (this.ndim !== 2 || right.ndim !== 2) {
            throw new Error(`MatMul expects 2D operands, got ${this.ndim}D and ${right.ndim}D.`);
        }
        if (right.dtype !== this.dtype) {
            // The kernel is selected from this.dtype alone, so a differently-typed
            // operand buffer would (presumably) be reinterpreted with the wrong element type.
            throw new Error(`Dtype mismatch: ${this.dtype} and ${right.dtype}.`);
        }

        const [m, n] = this.shape;
        const k = right.shape[1];
        if (n !== right.shape[0]) {
            throw new Error(`Inner dimensions mismatch: ${n} != ${right.shape[0]}`);
        }

        if (!(result instanceof NDWasmArray) || result.dtype !== this.dtype) {
            throw new Error(`Invalid result array. Expected NDWasmArray with dtype ${this.dtype}.`);
        }
        if (result.ndim !== 2 || result.shape[0] !== m || result.shape[1] !== k) {
            throw new Error(`Result array has incorrect shape. Expected [${m}, ${k}], got [${result.shape.join(', ')}].`);
        }
        assertNotDisposed(this, right, result);

        callKernel('MatMul', this.dtype, this.buffer.ptr, right.buffer.ptr, result.buffer.ptr, m, n, k);
        return result;
    }

    /**
     * Batched Matrix Multiplication: C[i] = this[i] * right[i]
     * `this` has shape [batch, m, n], `right` [batch, n, k], `result` [batch, m, k];
     * all three must be live (not disposed) and share this array's dtype.
     * @param {NDWasmArray} right - Right operand.
     * @param {NDWasmArray} result - Pre-allocated result array.
     * @returns {NDWasmArray} the result array.
     * @throws {Error} On operand type, rank, dtype or shape mismatch, on disposed
     *     memory, or on a non-zero status from the WASM kernel.
     */
    matMulBatch(right, result) {
        if (!(right instanceof NDWasmArray)) {
            throw new Error("Right operand must be an NDWasmArray.");
        }
        if (this.ndim !== 3 || right.ndim !== 3 || this.shape[0] !== right.shape[0]) {
            throw new Error(`dimensions mismatch. ${this.ndim}D and ${right.ndim}D arrays with batch size ${this.shape[0]} and ${right.shape[0]}.`);
        }
        if (right.dtype !== this.dtype) {
            // The kernel is selected from this.dtype alone (see matMul).
            throw new Error(`Dtype mismatch: ${this.dtype} and ${right.dtype}.`);
        }

        const [batch, m, n] = this.shape;
        const k = right.shape[2];
        // Without this check the kernel would index past the end of `right`.
        if (n !== right.shape[1]) {
            throw new Error(`Inner dimensions mismatch: ${n} != ${right.shape[1]}`);
        }

        if (!(result instanceof NDWasmArray) || result.dtype !== this.dtype) {
            throw new Error(`Invalid result array. Expected NDWasmArray with dtype ${this.dtype}.`);
        }
        if (result.ndim !== 3 || result.shape[0] !== batch || result.shape[1] !== m || result.shape[2] !== k) {
            throw new Error(`Invalid result array. Expected shape [${batch}, ${m}, ${k}], got [${result.shape.join(', ')}].`);
        }
        assertNotDisposed(this, right, result);

        callKernel('MatMulBatch', this.dtype, this.buffer.ptr, right.buffer.ptr, result.buffer.ptr, batch, m, n, k);
        return result;
    }
}

/**
 * Returns the bound WASM runtime, throwing if it is missing or not yet loaded.
 * @returns {*} NDWasm.runtime
 * @throws {Error} If the runtime is not initialized.
 */
function requireRuntime() {
    const runtime = NDWasm.runtime;
    if (!runtime?.isLoaded) {
        throw new Error("WasmRuntime not initialized. Call NDWasm.bind(runtime) first.");
    }
    return runtime;
}

/**
 * Throws if any of the given arrays has already released its WASM buffer.
 * @param {...NDWasmArray} arrays
 * @throws {Error} If any array has been disposed.
 */
function assertNotDisposed(...arrays) {
    if (arrays.some((arr) => !arr.buffer)) {
        throw new Error("WASM memory already disposed.");
    }
}

/**
 * Invokes the dtype-specific WASM export `${baseName}${suffix}` and checks its status.
 * @param {string} baseName - Export name without the dtype suffix (e.g. 'MatMul').
 * @param {string} dtype - Data type used to pick the suffix via the runtime.
 * @param {...number} args - Arguments forwarded to the export (pointers, dimensions).
 * @throws {Error} If the export is missing or returns a non-zero status.
 */
function callKernel(baseName, dtype, ...args) {
    const runtime = NDWasm.runtime;
    const kernels = runtime.exports;
    const exportName = `${baseName}${runtime._getSuffix(dtype)}`;
    if (typeof kernels[exportName] !== 'function') {
        throw new Error(`WASM export ${exportName} not found for dtype ${dtype}.`);
    }
    const status = kernels[exportName](...args);
    // An undefined return (void kernel) or 0 is treated as success.
    if (status !== undefined && status !== 0) {
        throw new Error(`WASM ${baseName} failed with status: ${status}`);
    }
}