|  | 
|  | 1 | +import { createSingletonBuffer, WebGPUBufferSet } from "./buffertools"; | 
|  | 2 | +import { StatefulGPU } from "./lib"; | 
|  | 3 | + | 
|  | 4 | +type TinyForestParams = { | 
|  | 5 | +  nTrees: number; | 
|  | 6 | +  depth: number; | 
|  | 7 | +  // The number of features to consider at each split. | 
|  | 8 | +  maxFeatures: number; | 
|  | 9 | +  D: number; | 
|  | 10 | +} | 
|  | 11 | + | 
|  | 12 | +const defaultTinyForestParams : TinyForestParams = { | 
|  | 13 | +  nTrees: 128, | 
|  | 14 | +  depth: 8, | 
|  | 15 | +  maxFeatures: 32, | 
|  | 16 | +  D: 768, | 
|  | 17 | +} | 
|  | 18 | + | 
|  | 19 | +export class TinyForest extends StatefulGPU { | 
|  | 20 | +  params: TinyForestParams; | 
|  | 21 | +   | 
|  | 22 | +  private _bootstrapSamples?: GPUBuffer; // On the order of 100 KB | 
|  | 23 | +  protected _forests?: GPUBuffer // On the order of 10 MB. | 
|  | 24 | +  // private trainedThrough: number = 0; | 
|  | 25 | +  constructor( | 
|  | 26 | +    device: GPUDevice,  | 
|  | 27 | +    bufferSize = 1024 * 1024 * 256,  | 
|  | 28 | +    t: Partial<TinyForestParams> = {}) { | 
|  | 29 | +    super(device, bufferSize) | 
|  | 30 | +    this.params = {...defaultTinyForestParams, ...t} | 
|  | 31 | +    this.initializeForestsToZero() | 
|  | 32 | +    this.bufferSet = new WebGPUBufferSet(device, bufferSize); | 
|  | 33 | +  } | 
|  | 34 | + | 
|  | 35 | +  countPipeline(): GPUComputePipeline { | 
|  | 36 | +    const { device } = this; | 
|  | 37 | +    // const { maxFeatures, nTrees } = this.params | 
|  | 38 | +    // const OPTIONS = 2; | 
|  | 39 | +    // const countBuffer = device.createBuffer({ | 
|  | 40 | +    //   size: OPTIONS * maxFeatures * nTrees * 4, | 
|  | 41 | +    //   usage: GPUBufferUsage.STORAGE & GPUBufferUsage.COPY_SRC, | 
|  | 42 | +    //   mappedAtCreation: false | 
|  | 43 | +    // }); | 
|  | 44 | + | 
|  | 45 | +    const layout = device.createBindGroupLayout({ | 
|  | 46 | +      entries: [ | 
|  | 47 | +        { | 
|  | 48 | +          // features buffer; | 
|  | 49 | +          binding: 0, | 
|  | 50 | +          visibility: GPUShaderStage.COMPUTE, | 
|  | 51 | +          buffer: { type: 'storage' } | 
|  | 52 | +        }, | 
|  | 53 | +        { | 
|  | 54 | +          // dims to check array; | 
|  | 55 | +          binding: 1, | 
|  | 56 | +          visibility: GPUShaderStage.COMPUTE, | 
|  | 57 | +          buffer: { type: 'storage' } | 
|  | 58 | +        }, | 
|  | 59 | +        { | 
|  | 60 | +          // output count buffer. | 
|  | 61 | +          binding: 2, | 
|  | 62 | +          visibility: GPUShaderStage.COMPUTE, | 
|  | 63 | +          buffer: { type: 'storage' } | 
|  | 64 | +        } | 
|  | 65 | +      ] | 
|  | 66 | +    }) | 
|  | 67 | + | 
|  | 68 | +    // const subsetsToCheck = this.chooseNextFeatures(); | 
|  | 69 | +    const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] }); | 
|  | 70 | + | 
|  | 71 | +    const shaderModule = device.createShaderModule({ code: ` | 
|  | 72 | +      @group(0) @binding(0) var<storage, read> features: array<u32>; | 
|  | 73 | +      @group(0) @binding(1) var<storage, read> dimsToCheck: array<u16>; | 
|  | 74 | +      @group(0) @binding(2) var<storage, write> counts: array<u32>; | 
|  | 75 | +
 | 
|  | 76 | +      @compute @workgroup_size(64) | 
|  | 77 | +      //TODOD HERE | 
|  | 78 | +      ` }); | 
|  | 79 | + | 
|  | 80 | + | 
|  | 81 | +    return device.createComputePipeline({ | 
|  | 82 | +      layout: pipelineLayout, | 
|  | 83 | +      compute: { | 
|  | 84 | +        module: shaderModule, | 
|  | 85 | +        entryPoint: 'main' | 
|  | 86 | +      } | 
|  | 87 | +    }); | 
|  | 88 | +  } | 
|  | 89 | + | 
|  | 90 | +  //@ts-expect-error foo | 
|  | 91 | +  private chooseNextFeatures(n = 32) { | 
|  | 92 | +    console.log({n}) | 
|  | 93 | +    const { maxFeatures, nTrees, D } = this.params; | 
|  | 94 | +    const features = new Uint16Array(maxFeatures * D); | 
|  | 95 | +    for (let i = 0; i < nTrees; i++) { | 
|  | 96 | +      const set = new Set<number>(); | 
|  | 97 | +      while (set.size < maxFeatures) { | 
|  | 98 | +        set.add(Math.floor(Math.random() * D)); | 
|  | 99 | +      } | 
|  | 100 | +      const arr = new Uint16Array([...set].sort()); | 
|  | 101 | +      features.set(arr, i * maxFeatures); | 
|  | 102 | +    } | 
|  | 103 | +    return createSingletonBuffer( | 
|  | 104 | +      this.device, | 
|  | 105 | +      features, | 
|  | 106 | +      GPUBufferUsage.STORAGE | 
|  | 107 | +    ) | 
|  | 108 | +  } | 
|  | 109 | + | 
|  | 110 | + | 
|  | 111 | + | 
|  | 112 | +  initializeForestsToZero() { | 
|  | 113 | +    // Each tree is a set of bits; For every possible configuration  | 
|  | 114 | +    // the first D indicating  | 
|  | 115 | +    // the desired outcome for the dimension, | 
|  | 116 | +    // the second D indicating whether the bits in those | 
|  | 117 | +    // positions are to be considered in checking if the tree | 
|  | 118 | +    // fits. There are 2**depth bitmasks for each dimension--each point | 
|  | 119 | +    // will match only one, and part of the inference task is determining which one. | 
|  | 120 | + | 
|  | 121 | +    const treeSizeInBytes =  | 
|  | 122 | +      2 * this.params.D * (2 ** this.params.depth) / 8; | 
|  | 123 | + | 
|  | 124 | +    const data = new Uint8Array(treeSizeInBytes * this.params.nTrees) | 
|  | 125 | +    this._forests = createSingletonBuffer( | 
|  | 126 | +      this.device, | 
|  | 127 | +      data, | 
|  | 128 | +      GPUBufferUsage.STORAGE | 
|  | 129 | +    ) | 
|  | 130 | +  } | 
|  | 131 | +   | 
|  | 132 | + | 
|  | 133 | +  // Rather than actually bootstrap, we generate a single | 
|  | 134 | +  // list of 100,000 numbers drawn from a poisson distribution. | 
|  | 135 | +  // These serve as weights for draws with replacement; to  | 
|  | 136 | +  // bootstrap any given record batch, we take a sequence of | 
|  | 137 | +  // numbers from the buffer with offset i.  | 
|  | 138 | +  get bootstrapSamples() { | 
|  | 139 | +    if (this._bootstrapSamples) { | 
|  | 140 | +      return this._bootstrapSamples | 
|  | 141 | +    } else { | 
|  | 142 | +      const arr = new Uint8Array(100000) | 
|  | 143 | +      for (let i = 0; i < arr.length; i++) { | 
|  | 144 | +        arr[i] = poissonRandomNumber() | 
|  | 145 | +      } | 
|  | 146 | +      this._bootstrapSamples = createSingletonBuffer( | 
|  | 147 | +        this.device, | 
|  | 148 | +        arr, | 
|  | 149 | +        GPUBufferUsage.STORAGE | 
|  | 150 | +      ) | 
|  | 151 | +      return this._bootstrapSamples | 
|  | 152 | +    } | 
|  | 153 | +  } | 
|  | 154 | + | 
|  | 155 | +   | 
|  | 156 | +} | 
|  | 157 | + | 
|  | 158 | + | 
|  | 159 | +function poissonRandomNumber() : number { | 
|  | 160 | +  let p = 1.0; | 
|  | 161 | +  let k = 0; | 
|  | 162 | + | 
|  | 163 | +  do { | 
|  | 164 | +    k++; | 
|  | 165 | +    p *= Math.random(); | 
|  | 166 | +  } while (p > 1/Math.E); | 
|  | 167 | + | 
|  | 168 | +  return k - 1; | 
|  | 169 | +} | 
|  | 170 | + | 
0 commit comments