Move sampling to a compute kernel
This commit is contained in:
parent
d60b2d8744
commit
275b260cf9
5 changed files with 121 additions and 44 deletions
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
import Foundation
|
import Foundation
|
||||||
import Metal
|
import Metal
|
||||||
|
import simd
|
||||||
|
|
||||||
class MarchingSquares {
|
class MarchingSquares {
|
||||||
private var field: Field
|
private var field: Field
|
||||||
|
@ -15,10 +16,13 @@ class MarchingSquares {
|
||||||
|
|
||||||
private var semaphore: DispatchSemaphore
|
private var semaphore: DispatchSemaphore
|
||||||
|
|
||||||
|
private var samplingPipeline: MTLComputePipelineState?
|
||||||
|
|
||||||
|
private var parametersBuffer: MTLBuffer?
|
||||||
/// Samples of the field's current state.
|
/// Samples of the field's current state.
|
||||||
private(set) var samplesBuffer: MTLBuffer?
|
private(set) var samplesBuffer: MTLBuffer?
|
||||||
/// Indexes of geometry to render.
|
/// Indexes of geometry to render.
|
||||||
private(set) var indexes: MTLTexture?
|
private(set) var indexBuffer: MTLBuffer?
|
||||||
|
|
||||||
private(set) var gridGeometry: MTLBuffer?
|
private(set) var gridGeometry: MTLBuffer?
|
||||||
|
|
||||||
|
@ -41,31 +45,45 @@ class MarchingSquares {
|
||||||
semaphore = DispatchSemaphore(value: 1)
|
semaphore = DispatchSemaphore(value: 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupMetal(withDevice device: MTLDevice) {
|
func setupMetal(withDevice device: MTLDevice, library: MTLLibrary) {
|
||||||
// let samplesDesc = MTLTextureDescriptor()
|
guard let samplingFunction = library.makeFunction(name: "samplingKernel") else {
|
||||||
// samplesDesc.textureType = .type2D
|
fatalError("Couldn't get samplingKernel function from library")
|
||||||
// samplesDesc.width = xSamples
|
}
|
||||||
// samplesDesc.height = ySamples
|
do {
|
||||||
// samplesDesc.pixelFormat = .r32Float
|
samplingPipeline = try device.makeComputePipelineState(function: samplingFunction)
|
||||||
// samples = device.makeTexture(descriptor: samplesDesc)
|
} catch let e {
|
||||||
//
|
fatalError("Error building compute pipeline state for sampling kernel: \(e)")
|
||||||
// let indexesDesc = MTLTextureDescriptor()
|
}
|
||||||
// indexesDesc.textureType = .type2D
|
|
||||||
// indexesDesc.width = xSamples - 1
|
let parametersLength = MemoryLayout<simd.packed_int2>.stride * 3 + MemoryLayout<simd.uint>.stride
|
||||||
// indexesDesc.height = ySamples - 1
|
parametersBuffer = device.makeBuffer(length: parametersLength, options: .storageModeShared)
|
||||||
// indexesDesc.pixelFormat = .a8Unorm
|
populateParametersBuffer()
|
||||||
// indexes = device.makeTexture(descriptor: indexesDesc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fieldDidResize() {
|
func fieldDidResize() {
|
||||||
guard let device = gridGeometry?.device else {
|
guard let device = gridGeometry?.device else {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
populateParametersBuffer()
|
||||||
populateGrid(withDevice: device)
|
populateGrid(withDevice: device)
|
||||||
populateSamples(withDevice: device)
|
populateSamples(withDevice: device)
|
||||||
lastSamplesCount = samplesCount
|
lastSamplesCount = samplesCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func populateParametersBuffer() {
|
||||||
|
guard let buffer = parametersBuffer else {
|
||||||
|
print("Tried to copy parameters buffer before buffer was allocated!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
let params: [uint] = [
|
||||||
|
field.size.x, field.size.y,
|
||||||
|
uint(xSamples), uint(ySamples),
|
||||||
|
sampleGridSize.x, sampleGridSize.y,
|
||||||
|
uint(field.balls.count)
|
||||||
|
]
|
||||||
|
memcpy(buffer.contents(), params, MemoryLayout<uint>.stride * params.count)
|
||||||
|
}
|
||||||
|
|
||||||
func populateGrid(withDevice device: MTLDevice) {
|
func populateGrid(withDevice device: MTLDevice) {
|
||||||
guard lastSamplesCount != samplesCount else {
|
guard lastSamplesCount != samplesCount else {
|
||||||
return
|
return
|
||||||
|
@ -97,24 +115,45 @@ class MarchingSquares {
|
||||||
}
|
}
|
||||||
|
|
||||||
func populateSamples(withDevice device: MTLDevice) {
|
func populateSamples(withDevice device: MTLDevice) {
|
||||||
var samples = [Float]()
|
// var samples = [Float]()
|
||||||
samples.reserveCapacity(samplesCount)
|
// samples.reserveCapacity(samplesCount)
|
||||||
|
|
||||||
for ys in 0..<ySamples {
|
// for ys in 0..<ySamples {
|
||||||
let y = Float(ys * Int(sampleGridSize.y))
|
// let y = Float(ys * Int(sampleGridSize.y))
|
||||||
for xs in 0..<xSamples {
|
// for xs in 0..<xSamples {
|
||||||
let x = Float(xs * Int(sampleGridSize.x))
|
// let x = Float(xs * Int(sampleGridSize.x))
|
||||||
let sample = field.sample(at: Float2(x: x, y: y))
|
// let sample = field.sample(at: Float2(x: x, y: y))
|
||||||
samples.append(sample)
|
// samples.append(sample)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
let samplesLength = MemoryLayout<Float>.stride * samplesCount
|
let samplesLength = MemoryLayout<Float>.stride * samplesCount
|
||||||
if let buffer = device.makeBuffer(length: MemoryLayout<Float>.stride * samples.count, options: .storageModeShared) {
|
samplesBuffer = device.makeBuffer(length: samplesLength, options: .storageModePrivate)
|
||||||
memcpy(buffer.contents(), samples, samplesLength)
|
if samplesBuffer == nil {
|
||||||
samplesBuffer = buffer
|
fatalError("Couldn't create samplesBuffer!")
|
||||||
} else {
|
|
||||||
fatalError("Couldn't create buffer for samples")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func encodeSamplingKernel(intoBuffer buffer: MTLCommandBuffer) {
|
||||||
|
guard let samplingPipeline = samplingPipeline else {
|
||||||
|
print("Encode called before sampling pipeline was set up!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
guard let encoder = buffer.makeComputeCommandEncoder() else {
|
||||||
|
print("Couldn't create compute encoder")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
encoder.label = "Sample Field"
|
||||||
|
encoder.setComputePipelineState(samplingPipeline)
|
||||||
|
encoder.setBuffer(parametersBuffer, offset: 0, index: 0)
|
||||||
|
encoder.setBuffer(field.ballBuffer, offset: 0, index: 1)
|
||||||
|
encoder.setBuffer(samplesBuffer, offset: 0, index: 2)
|
||||||
|
|
||||||
|
// Dispatch!
|
||||||
|
let gridSize = MTLSize(width: xSamples, height: ySamples, depth: 1)
|
||||||
|
let threadgroupSize = MTLSize(width: xSamples, height: 1, depth: 1)
|
||||||
|
encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
|
||||||
|
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ public class Renderer: NSObject, MTKViewDelegate {
|
||||||
configureMarchingSquaresPipeline(withPixelFormat: view.colorPixelFormat)
|
configureMarchingSquaresPipeline(withPixelFormat: view.colorPixelFormat)
|
||||||
|
|
||||||
delegate.field.setupMetal(withDevice: device)
|
delegate.field.setupMetal(withDevice: device)
|
||||||
delegate.marchingSquares.setupMetal(withDevice: device)
|
delegate.marchingSquares.setupMetal(withDevice: device, library: library)
|
||||||
delegate.marchingSquares.populateGrid(withDevice: device)
|
delegate.marchingSquares.populateGrid(withDevice: device)
|
||||||
delegate.marchingSquares.populateSamples(withDevice: device)
|
delegate.marchingSquares.populateSamples(withDevice: device)
|
||||||
}
|
}
|
||||||
|
@ -51,9 +51,9 @@ public class Renderer: NSObject, MTKViewDelegate {
|
||||||
|
|
||||||
private var device: MTLDevice
|
private var device: MTLDevice
|
||||||
|
|
||||||
private lazy var library: MTLLibrary? = {
|
private lazy var library: MTLLibrary = {
|
||||||
let bundle = Bundle(for: type(of: self))
|
let bundle = Bundle(for: type(of: self))
|
||||||
return try? device.makeDefaultLibrary(bundle: bundle)
|
return try! device.makeDefaultLibrary(bundle: bundle)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
private var commandQueue: MTLCommandQueue
|
private var commandQueue: MTLCommandQueue
|
||||||
|
@ -90,10 +90,6 @@ public class Renderer: NSObject, MTKViewDelegate {
|
||||||
}
|
}
|
||||||
|
|
||||||
private func configurePixelPipeline(withPixelFormat pixelFormat: MTLPixelFormat) {
|
private func configurePixelPipeline(withPixelFormat pixelFormat: MTLPixelFormat) {
|
||||||
guard let library = library else {
|
|
||||||
fatalError("Couldn't get Metal library")
|
|
||||||
}
|
|
||||||
|
|
||||||
let vertexShader = library.makeFunction(name: "passthroughVertexShader")
|
let vertexShader = library.makeFunction(name: "passthroughVertexShader")
|
||||||
let fragmentShader = library.makeFunction(name: "sampleToColorShader")
|
let fragmentShader = library.makeFunction(name: "sampleToColorShader")
|
||||||
|
|
||||||
|
@ -124,10 +120,6 @@ public class Renderer: NSObject, MTKViewDelegate {
|
||||||
}
|
}
|
||||||
|
|
||||||
private func configureMarchingSquaresPipeline(withPixelFormat pixelFormat: MTLPixelFormat) {
|
private func configureMarchingSquaresPipeline(withPixelFormat pixelFormat: MTLPixelFormat) {
|
||||||
guard let library = library else {
|
|
||||||
fatalError("Couldn't get Metal library")
|
|
||||||
}
|
|
||||||
|
|
||||||
guard let vertexShader = library.makeFunction(name: "gridVertexShader"),
|
guard let vertexShader = library.makeFunction(name: "gridVertexShader"),
|
||||||
let fragmentShader = library.makeFunction(name: "gridFragmentShader") else {
|
let fragmentShader = library.makeFunction(name: "gridFragmentShader") else {
|
||||||
fatalError("Couldn't get marching squares vertex or fragment function from library")
|
fatalError("Couldn't get marching squares vertex or fragment function from library")
|
||||||
|
@ -209,7 +201,7 @@ public class Renderer: NSObject, MTKViewDelegate {
|
||||||
|
|
||||||
field.update()
|
field.update()
|
||||||
if let ms = delegate?.marchingSquares {
|
if let ms = delegate?.marchingSquares {
|
||||||
ms.populateSamples(withDevice: device)
|
ms.populateParametersBuffer()
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.pixelGeometry == nil {
|
if self.pixelGeometry == nil {
|
||||||
|
@ -233,6 +225,9 @@ public class Renderer: NSObject, MTKViewDelegate {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
if let marchingSquares = delegate?.marchingSquares {
|
if let marchingSquares = delegate?.marchingSquares {
|
||||||
|
// Compute samples first.
|
||||||
|
marchingSquares.encodeSamplingKernel(intoBuffer: buffer)
|
||||||
|
|
||||||
// Render the marching squares version over top of the pixel version.
|
// Render the marching squares version over top of the pixel version.
|
||||||
// We need our own render pass descriptor that specifies that we load the results of the previous pass to make this render pass appear on top of the other.
|
// We need our own render pass descriptor that specifies that we load the results of the previous pass to make this render pass appear on top of the other.
|
||||||
let pass = renderPass.copy() as! MTLRenderPassDescriptor
|
let pass = renderPass.copy() as! MTLRenderPassDescriptor
|
||||||
|
|
|
@ -10,6 +10,17 @@
|
||||||
#include "ShaderTypes.hh"
|
#include "ShaderTypes.hh"
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
struct MarchingSquaresParameters {
|
||||||
|
/// Field size in pixels.
|
||||||
|
packed_uint2 pixelSize;
|
||||||
|
/// Field size in grid units.
|
||||||
|
packed_uint2 gridSize;
|
||||||
|
/// Size of a cell in pixels.
|
||||||
|
packed_uint2 cellSize;
|
||||||
|
/// Number of balls in the array above.
|
||||||
|
uint ballsCount;
|
||||||
|
};
|
||||||
|
|
||||||
struct Rect {
|
struct Rect {
|
||||||
float4x4 transform;
|
float4x4 transform;
|
||||||
float4 color;
|
float4 color;
|
||||||
|
@ -22,6 +33,37 @@ struct RasterizerData {
|
||||||
int instance;
|
int instance;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
kernel void
|
||||||
|
generateGridGeometry()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sample the field at regularly spaced intervals and populate `samples` with the resulting values.
|
||||||
|
kernel void
|
||||||
|
samplingKernel(constant MarchingSquaresParameters ¶meters [[buffer(0)]],
|
||||||
|
constant Ball *balls [[buffer(1)]],
|
||||||
|
device float *samples [[buffer(2)]],
|
||||||
|
uint2 position [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
// Find the midpoint of this grid cell.
|
||||||
|
const float2 point = float2(position.x * parameters.cellSize.x + (parameters.cellSize.x / 2.0),
|
||||||
|
position.y * parameters.cellSize.y + (parameters.cellSize.y / 2.0));
|
||||||
|
|
||||||
|
// Sample the grid.
|
||||||
|
float sample = 0.0;
|
||||||
|
for (uint i = 0; i < parameters.ballsCount; i++) {
|
||||||
|
constant Ball &ball = balls[i];
|
||||||
|
float r2 = ball.z * ball.z;
|
||||||
|
float xDiff = point.x - ball.x;
|
||||||
|
float yDiff = point.y - ball.y;
|
||||||
|
sample += r2 / ((xDiff * xDiff) + (yDiff * yDiff));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Playing a bit fast and loose with these values here. The compute grid is the size of the grid itself, so parameters.gridSize == [[threads_per_grid]].
|
||||||
|
uint idx = position.y * parameters.gridSize.x + position.x;
|
||||||
|
samples[idx] = sample;
|
||||||
|
}
|
||||||
|
|
||||||
vertex RasterizerData
|
vertex RasterizerData
|
||||||
gridVertexShader(constant Vertex *vertexes [[buffer(0)]],
|
gridVertexShader(constant Vertex *vertexes [[buffer(0)]],
|
||||||
constant Rect *rects [[buffer(1)]],
|
constant Rect *rects [[buffer(1)]],
|
||||||
|
|
|
@ -11,6 +11,9 @@
|
||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
/// A Ball is a float 3-tuple: (x, y, r).
|
||||||
|
typedef float3 Ball;
|
||||||
|
|
||||||
struct Vertex {
|
struct Vertex {
|
||||||
float2 position;
|
float2 position;
|
||||||
float2 textureCoordinate;
|
float2 textureCoordinate;
|
||||||
|
|
|
@ -39,8 +39,6 @@ struct Parameters {
|
||||||
float3x3 colorTransform;
|
float3x3 colorTransform;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef float3 Ball;
|
|
||||||
|
|
||||||
#pragma mark - Vertex
|
#pragma mark - Vertex
|
||||||
|
|
||||||
vertex RasterizerData
|
vertex RasterizerData
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue