diff --git a/MetaballsKit/MarchingSquares.swift b/MetaballsKit/MarchingSquares.swift index 2b0ec6b..ab17cec 100644 --- a/MetaballsKit/MarchingSquares.swift +++ b/MetaballsKit/MarchingSquares.swift @@ -35,11 +35,15 @@ class MarchingSquares { private(set) var gridGeometry: MTLBuffer? private var xSamples: Int { - return Int(field.size.x / sampleGridSize.x) + let xSize = field.size.x / sampleGridSize.x + let xRem = field.size.x % sampleGridSize.x + return Int(xSize + (sampleGridSize.x - xRem)) } private var ySamples: Int { - return Int(field.size.y / sampleGridSize.y) + let ySize = field.size.y / sampleGridSize.y + let yRem = field.size.y % sampleGridSize.y + return Int(ySize + (sampleGridSize.y - yRem)) } private var lastSamplesCount = 0 @@ -49,9 +53,12 @@ class MarchingSquares { } var contourIndexesCount: Int { - return (xSamples - 1) * (ySamples - 1) + return samplesCount } + /// Threadgroup size for the compute kernels. + private let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + init(field: Field) { self.field = field semaphore = DispatchSemaphore(value: 1) @@ -136,6 +143,7 @@ class MarchingSquares { func populateGrid(withDevice device: MTLDevice) { guard lastSamplesCount != samplesCount else { + print("Populate requested, but lastSampleCount(\(lastSamplesCount) == samplesCount(\(samplesCount))") return } @@ -181,9 +189,7 @@ class MarchingSquares { encoder.setBuffer(samplesBuffer, offset: 0, index: 2) // Dispatch! - // TODO: Kernel threadgroup size limit is 256. Figure out how to make this work. - let gridSize = MTLSize(width: xSamples, height: ySamples, depth: 1) - let threadgroupSize = MTLSize(width: xSamples, height: 1, depth: 1) + let gridSize = computeGridSize(forCellGridSize: Size(x: UInt32(xSamples), y: UInt32(ySamples))) encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize) encoder.endEncoding() @@ -206,12 +212,24 @@ class MarchingSquares { encoder.setBuffer(contourIndexesBuffer, offset: 0, index: 2) // Dispatch! 
- let gridSize = MTLSize(width: contourIndexesCount, height: 1, depth: 1) - let threadgroupSize = MTLSize(width: xSamples - 1, height: 1, depth: 1) + let gridSize = computeGridSize(forCellGridSize: Size(x: UInt32(xSamples - 1), y: UInt32(ySamples - 1))) encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize) encoder.endEncoding() } + + /// Grid size for the compute kernels. + func computeGridSize(forCellGridSize gridSize: Size) -> MTLSize { + let xs = Int(gridSize.x) + let ys = Int(gridSize.y) + let xrem = xs % threadgroupSize.width + let yrem = ys % threadgroupSize.height + // Our compute grid size is the next multiple of threadgroupSize larger than the current cell grid size. + let gridSize = MTLSize(width: xs + (threadgroupSize.width - xrem), + height: ys + (threadgroupSize.height - yrem), + depth: 1) + return gridSize + } } struct Variants { diff --git a/MetaballsKit/Renderer.swift b/MetaballsKit/Renderer.swift index 9cd65ea..59cd27d 100644 --- a/MetaballsKit/Renderer.swift +++ b/MetaballsKit/Renderer.swift @@ -206,7 +206,7 @@ public class Renderer: NSObject, MTKViewDelegate { if self.pixelGeometry == nil { self.pixelGeometry = self.pixelGeometry(forViewSize: view.drawableSize) } - let pixelGeometry = self.pixelGeometry! +// let pixelGeometry = self.pixelGeometry! 
if let renderPass = view.currentRenderPassDescriptor { // Render the per-pixel metaballs @@ -240,7 +240,7 @@ public class Renderer: NSObject, MTKViewDelegate { encoder.setVertexBuffer(marchingSquares.gridGeometry, offset: 0, index: 1) encoder.setVertexBuffer(parametersBuffer, offset: 0, index: 2) encoder.setFragmentBuffer(marchingSquares.contourIndexesBuffer, offset: 0, index: 0) - encoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: Rect.geometry.count, instanceCount: marchingSquares.samplesCount) + encoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: Rect.geometry.count, instanceCount: marchingSquares.contourIndexesCount) encoder.endEncoding() didEncode = true } diff --git a/MetaballsKit/Shaders/MarchingSquares.metal b/MetaballsKit/Shaders/MarchingSquares.metal index 59af943..77ed09e 100644 --- a/MetaballsKit/Shaders/MarchingSquares.metal +++ b/MetaballsKit/Shaders/MarchingSquares.metal @@ -45,6 +45,11 @@ samplingKernel(constant MarchingSquaresParameters &parameters [[buffer(0)]], device float *samples [[buffer(2)]], uint2 position [[thread_position_in_grid]]) { + if (position.x >= parameters.gridSize.x || position.y >= parameters.gridSize.y) + { + return; + } + // Find the midpoint of this grid cell. const float2 point = float2(position.x * parameters.cellSize.x + (parameters.cellSize.x / 2.0), position.y * parameters.cellSize.y + (parameters.cellSize.y / 2.0)); @@ -68,38 +73,43 @@ kernel void contouringKernel(constant MarchingSquaresParameters &parameters [[buffer(0)]], constant float *samples [[buffer(1)]], device ushort *contourIndexes [[buffer(2)]], - uint position [[thread_position_in_grid]]) + uint2 position [[thread_position_in_grid]]) { + if (position.x >= (parameters.gridSize.x - 1) || position.y >= (parameters.gridSize.y - 1)) { + return; + } + // Calculate an index based on the samples at the four points around this cell. // If the point is above the threshold, adjust the value accordingly. 
// d--c 8--4 // | | -> | | // a--b 1--2 - uint a = position + parameters.gridSize.x; - uint b = position + parameters.gridSize.x + 1; - uint c = position + 1; - uint d = position; + uint rowSize = parameters.gridSize.x - 1; + uint d = position.y * rowSize + position.x; + uint c = d + 1; + uint b = d + rowSize + 1; + uint a = d + rowSize; uint index = (samples[d] >= 1.0 ? 0b1000 : 0) + (samples[c] >= 1.0 ? 0b0100 : 0) + (samples[b] >= 1.0 ? 0b0010 : 0) + (samples[a] >= 1.0 ? 0b0001 : 0); - contourIndexes[position] = index; + contourIndexes[d] = index; } vertex RasterizerData gridVertexShader(constant Vertex *vertexes [[buffer(0)]], - constant Rect *rects [[buffer(1)]], + constant Rect *cells [[buffer(1)]], constant RenderParameters &renderParameters [[buffer(2)]], uint vid [[vertex_id]], uint instid [[instance_id]]) { Vertex v = vertexes[vid]; - Rect rect = rects[instid]; + Rect cell = cells[instid]; RasterizerData out; - out.position = renderParameters.projection * rect.transform * float4(v.position.xy, 0, 1); - out.color = rect.color; + out.position = renderParameters.projection * cell.transform * float4(v.position.xy, 0, 1); + out.color = cell.color; out.textureCoordinate = v.textureCoordinate; out.instance = instid; return out;