diff --git a/MetaballsKit/MarchingSquares.swift b/MetaballsKit/MarchingSquares.swift index 8806a9b..dfce055 100644 --- a/MetaballsKit/MarchingSquares.swift +++ b/MetaballsKit/MarchingSquares.swift @@ -8,6 +8,7 @@ import Foundation import Metal +import simd class MarchingSquares { private var field: Field @@ -15,10 +16,13 @@ class MarchingSquares { private var semaphore: DispatchSemaphore + private var samplingPipeline: MTLComputePipelineState? + + private var parametersBuffer: MTLBuffer? /// Samples of the field's current state. private(set) var samplesBuffer: MTLBuffer? /// Indexes of geometry to render. - private(set) var indexes: MTLTexture? + private(set) var indexBuffer: MTLBuffer? private(set) var gridGeometry: MTLBuffer? @@ -41,31 +45,45 @@ class MarchingSquares { semaphore = DispatchSemaphore(value: 1) } - func setupMetal(withDevice device: MTLDevice) { -// let samplesDesc = MTLTextureDescriptor() -// samplesDesc.textureType = .type2D -// samplesDesc.width = xSamples -// samplesDesc.height = ySamples -// samplesDesc.pixelFormat = .r32Float -// samples = device.makeTexture(descriptor: samplesDesc) -// -// let indexesDesc = MTLTextureDescriptor() -// indexesDesc.textureType = .type2D -// indexesDesc.width = xSamples - 1 -// indexesDesc.height = ySamples - 1 -// indexesDesc.pixelFormat = .a8Unorm -// indexes = device.makeTexture(descriptor: indexesDesc) + func setupMetal(withDevice device: MTLDevice, library: MTLLibrary) { + guard let samplingFunction = library.makeFunction(name: "samplingKernel") else { + fatalError("Couldn't get samplingKernel function from library") + } + do { + samplingPipeline = try device.makeComputePipelineState(function: samplingFunction) + } catch let e { + fatalError("Error building compute pipeline state for sampling kernel: \(e)") + } + + let parametersLength = MemoryLayout.stride * 3 + MemoryLayout.stride + parametersBuffer = device.makeBuffer(length: parametersLength, options: .storageModeShared) + populateParametersBuffer() } func fieldDidResize() { guard let device = gridGeometry?.device else { return } + populateParametersBuffer() populateGrid(withDevice: device) populateSamples(withDevice: device) lastSamplesCount = samplesCount } + func populateParametersBuffer() { + guard let buffer = parametersBuffer else { + print("Tried to copy parameters buffer before buffer was allocated!") + return + } + let params: [uint] = [ + field.size.x, field.size.y, + uint(xSamples), uint(ySamples), + sampleGridSize.x, sampleGridSize.y, + uint(field.balls.count) + ] + memcpy(buffer.contents(), params, MemoryLayout.stride * params.count) + } + func populateGrid(withDevice device: MTLDevice) { guard lastSamplesCount != samplesCount else { return @@ -97,24 +115,45 @@ class MarchingSquares { } func populateSamples(withDevice device: MTLDevice) { - var samples = [Float]() - samples.reserveCapacity(samplesCount) +// var samples = [Float]() +// samples.reserveCapacity(samplesCount) - for ys in 0...stride * samplesCount - if let buffer = device.makeBuffer(length: MemoryLayout.stride * samples.count, options: .storageModeShared) { - memcpy(buffer.contents(), samples, samplesLength) - samplesBuffer = buffer - } else { - fatalError("Couldn't create buffer for samples") + samplesBuffer = device.makeBuffer(length: samplesLength, options: .storageModePrivate) + if samplesBuffer == nil { + fatalError("Couldn't create samplesBuffer!") } } + + func encodeSamplingKernel(intoBuffer buffer: MTLCommandBuffer) { + guard let samplingPipeline = samplingPipeline else { + print("Encode called before sampling pipeline was set up!") + return + } + guard let encoder = buffer.makeComputeCommandEncoder() else { + print("Couldn't create compute encoder") + return + } + encoder.label = "Sample Field" + encoder.setComputePipelineState(samplingPipeline) + encoder.setBuffer(parametersBuffer, offset: 0, index: 0) + encoder.setBuffer(field.ballBuffer, offset: 0, index: 1) + encoder.setBuffer(samplesBuffer, offset: 0, index: 2) + + // Dispatch! + let gridSize = MTLSize(width: xSamples, height: ySamples, depth: 1) + let threadgroupSize = MTLSize(width: xSamples, height: 1, depth: 1) + encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize) + + encoder.endEncoding() + } } diff --git a/MetaballsKit/Renderer.swift b/MetaballsKit/Renderer.swift index d7a554e..731d0c7 100644 --- a/MetaballsKit/Renderer.swift +++ b/MetaballsKit/Renderer.swift @@ -43,7 +43,7 @@ public class Renderer: NSObject, MTKViewDelegate { configureMarchingSquaresPipeline(withPixelFormat: view.colorPixelFormat) delegate.field.setupMetal(withDevice: device) - delegate.marchingSquares.setupMetal(withDevice: device) + delegate.marchingSquares.setupMetal(withDevice: device, library: library) delegate.marchingSquares.populateGrid(withDevice: device) delegate.marchingSquares.populateSamples(withDevice: device) } @@ -51,9 +51,9 @@ public class Renderer: NSObject, MTKViewDelegate { private var device: MTLDevice - private lazy var library: MTLLibrary? = { + private lazy var library: MTLLibrary = { let bundle = Bundle(for: type(of: self)) - return try? device.makeDefaultLibrary(bundle: bundle) + return try! device.makeDefaultLibrary(bundle: bundle) }() private var commandQueue: MTLCommandQueue @@ -90,10 +90,6 @@ public class Renderer: NSObject, MTKViewDelegate { } private func configurePixelPipeline(withPixelFormat pixelFormat: MTLPixelFormat) { - guard let library = library else { - fatalError("Couldn't get Metal library") - } - let vertexShader = library.makeFunction(name: "passthroughVertexShader") let fragmentShader = library.makeFunction(name: "sampleToColorShader") @@ -124,10 +120,6 @@ public class Renderer: NSObject, MTKViewDelegate { } private func configureMarchingSquaresPipeline(withPixelFormat pixelFormat: MTLPixelFormat) { - guard let library = library else { - fatalError("Couldn't get Metal library") - } - guard let vertexShader = library.makeFunction(name: "gridVertexShader"), let fragmentShader = library.makeFunction(name: "gridFragmentShader") else { fatalError("Couldn't get marching squares vertex or fragment function from library") @@ -209,7 +201,7 @@ public class Renderer: NSObject, MTKViewDelegate { field.update() if let ms = delegate?.marchingSquares { - ms.populateSamples(withDevice: device) + ms.populateParametersBuffer() } if self.pixelGeometry == nil { @@ -233,6 +225,9 @@ public class Renderer: NSObject, MTKViewDelegate { // } if let marchingSquares = delegate?.marchingSquares { + // Compute samples first. + marchingSquares.encodeSamplingKernel(intoBuffer: buffer) + // Render the marching squares version over top of the pixel version. // We need our own render pass descriptor that specifies that we load the results of the previous pass to make this render pass appear on top of the other. let pass = renderPass.copy() as! MTLRenderPassDescriptor diff --git a/MetaballsKit/Shaders/MarchingSquares.metal b/MetaballsKit/Shaders/MarchingSquares.metal index d5051ba..1a8b9a3 100644 --- a/MetaballsKit/Shaders/MarchingSquares.metal +++ b/MetaballsKit/Shaders/MarchingSquares.metal @@ -10,6 +10,17 @@ #include "ShaderTypes.hh" using namespace metal; +struct MarchingSquaresParameters { + /// Field size in pixels. + packed_uint2 pixelSize; + /// Field size in grid units. + packed_uint2 gridSize; + /// Size of a cell in pixels. + packed_uint2 cellSize; + /// Number of balls in the array above. + uint ballsCount; +}; + struct Rect { float4x4 transform; float4 color; @@ -22,6 +33,37 @@ struct RasterizerData { int instance; }; +kernel void +generateGridGeometry() +{ +} + +/// Sample the field at regularly spaced intervals and populate `samples` with the resulting values. +kernel void +samplingKernel(constant MarchingSquaresParameters ¶meters [[buffer(0)]], + constant Ball *balls [[buffer(1)]], + device float *samples [[buffer(2)]], + uint2 position [[thread_position_in_grid]]) +{ + // Find the midpoint of this grid cell. + const float2 point = float2(position.x * parameters.cellSize.x + (parameters.cellSize.x / 2.0), + position.y * parameters.cellSize.y + (parameters.cellSize.y / 2.0)); + + // Sample the grid. + float sample = 0.0; + for (uint i = 0; i < parameters.ballsCount; i++) { + constant Ball &ball = balls[i]; + float r2 = ball.z * ball.z; + float xDiff = point.x - ball.x; + float yDiff = point.y - ball.y; + sample += r2 / ((xDiff * xDiff) + (yDiff * yDiff)); + } + + // Playing a bit fast and loose with these values here. The compute grid is the size of the grid itself, so parameters.gridSize == [[threads_per_grid]]. + uint idx = position.y * parameters.gridSize.x + position.x; + samples[idx] = sample; +} + vertex RasterizerData gridVertexShader(constant Vertex *vertexes [[buffer(0)]], constant Rect *rects [[buffer(1)]], diff --git a/MetaballsKit/Shaders/ShaderTypes.hh b/MetaballsKit/Shaders/ShaderTypes.hh index a75ebcb..904e721 100644 --- a/MetaballsKit/Shaders/ShaderTypes.hh +++ b/MetaballsKit/Shaders/ShaderTypes.hh @@ -11,6 +11,9 @@ #include +/// A Ball is a float 3-tuple: (x, y, r). +typedef float3 Ball; + struct Vertex { float2 position; float2 textureCoordinate; diff --git a/MetaballsKit/Shaders/Shaders.metal b/MetaballsKit/Shaders/Shaders.metal index 52cd4b4..48caa42 100644 --- a/MetaballsKit/Shaders/Shaders.metal +++ b/MetaballsKit/Shaders/Shaders.metal @@ -39,8 +39,6 @@ struct Parameters { float3x3 colorTransform; }; -typedef float3 Ball; - #pragma mark - Vertex vertex RasterizerData