Move sampling to a compute kernel

This commit is contained in:
Eryn Wells 2018-10-14 17:26:12 -07:00
parent d60b2d8744
commit 275b260cf9
5 changed files with 121 additions and 44 deletions

View file

@ -8,6 +8,7 @@
import Foundation import Foundation
import Metal import Metal
import simd
class MarchingSquares { class MarchingSquares {
private var field: Field private var field: Field
@ -15,10 +16,13 @@ class MarchingSquares {
private var semaphore: DispatchSemaphore private var semaphore: DispatchSemaphore
private var samplingPipeline: MTLComputePipelineState?
private var parametersBuffer: MTLBuffer?
/// Samples of the field's current state. /// Samples of the field's current state.
private(set) var samplesBuffer: MTLBuffer? private(set) var samplesBuffer: MTLBuffer?
/// Indexes of geometry to render. /// Indexes of geometry to render.
private(set) var indexes: MTLTexture? private(set) var indexBuffer: MTLBuffer?
private(set) var gridGeometry: MTLBuffer? private(set) var gridGeometry: MTLBuffer?
@ -41,31 +45,45 @@ class MarchingSquares {
semaphore = DispatchSemaphore(value: 1) semaphore = DispatchSemaphore(value: 1)
} }
func setupMetal(withDevice device: MTLDevice) { func setupMetal(withDevice device: MTLDevice, library: MTLLibrary) {
// let samplesDesc = MTLTextureDescriptor() guard let samplingFunction = library.makeFunction(name: "samplingKernel") else {
// samplesDesc.textureType = .type2D fatalError("Couldn't get samplingKernel function from library")
// samplesDesc.width = xSamples }
// samplesDesc.height = ySamples do {
// samplesDesc.pixelFormat = .r32Float samplingPipeline = try device.makeComputePipelineState(function: samplingFunction)
// samples = device.makeTexture(descriptor: samplesDesc) } catch let e {
// fatalError("Error building compute pipeline state for sampling kernel: \(e)")
// let indexesDesc = MTLTextureDescriptor() }
// indexesDesc.textureType = .type2D
// indexesDesc.width = xSamples - 1 let parametersLength = MemoryLayout<simd.packed_int2>.stride * 3 + MemoryLayout<simd.uint>.stride
// indexesDesc.height = ySamples - 1 parametersBuffer = device.makeBuffer(length: parametersLength, options: .storageModeShared)
// indexesDesc.pixelFormat = .a8Unorm populateParametersBuffer()
// indexes = device.makeTexture(descriptor: indexesDesc)
} }
func fieldDidResize() { func fieldDidResize() {
guard let device = gridGeometry?.device else { guard let device = gridGeometry?.device else {
return return
} }
populateParametersBuffer()
populateGrid(withDevice: device) populateGrid(withDevice: device)
populateSamples(withDevice: device) populateSamples(withDevice: device)
lastSamplesCount = samplesCount lastSamplesCount = samplesCount
} }
func populateParametersBuffer() {
guard let buffer = parametersBuffer else {
print("Tried to copy parameters buffer before buffer was allocated!")
return
}
let params: [uint] = [
field.size.x, field.size.y,
uint(xSamples), uint(ySamples),
sampleGridSize.x, sampleGridSize.y,
uint(field.balls.count)
]
memcpy(buffer.contents(), params, MemoryLayout<uint>.stride * params.count)
}
func populateGrid(withDevice device: MTLDevice) { func populateGrid(withDevice device: MTLDevice) {
guard lastSamplesCount != samplesCount else { guard lastSamplesCount != samplesCount else {
return return
@ -97,24 +115,45 @@ class MarchingSquares {
} }
func populateSamples(withDevice device: MTLDevice) { func populateSamples(withDevice device: MTLDevice) {
var samples = [Float]() // var samples = [Float]()
samples.reserveCapacity(samplesCount) // samples.reserveCapacity(samplesCount)
for ys in 0..<ySamples { // for ys in 0..<ySamples {
let y = Float(ys * Int(sampleGridSize.y)) // let y = Float(ys * Int(sampleGridSize.y))
for xs in 0..<xSamples { // for xs in 0..<xSamples {
let x = Float(xs * Int(sampleGridSize.x)) // let x = Float(xs * Int(sampleGridSize.x))
let sample = field.sample(at: Float2(x: x, y: y)) // let sample = field.sample(at: Float2(x: x, y: y))
samples.append(sample) // samples.append(sample)
} // }
} // }
let samplesLength = MemoryLayout<Float>.stride * samplesCount let samplesLength = MemoryLayout<Float>.stride * samplesCount
if let buffer = device.makeBuffer(length: MemoryLayout<Float>.stride * samples.count, options: .storageModeShared) { samplesBuffer = device.makeBuffer(length: samplesLength, options: .storageModePrivate)
memcpy(buffer.contents(), samples, samplesLength) if samplesBuffer == nil {
samplesBuffer = buffer fatalError("Couldn't create samplesBuffer!")
} else {
fatalError("Couldn't create buffer for samples")
} }
} }
func encodeSamplingKernel(intoBuffer buffer: MTLCommandBuffer) {
guard let samplingPipeline = samplingPipeline else {
print("Encode called before sampling pipeline was set up!")
return
}
guard let encoder = buffer.makeComputeCommandEncoder() else {
print("Couldn't create compute encoder")
return
}
encoder.label = "Sample Field"
encoder.setComputePipelineState(samplingPipeline)
encoder.setBuffer(parametersBuffer, offset: 0, index: 0)
encoder.setBuffer(field.ballBuffer, offset: 0, index: 1)
encoder.setBuffer(samplesBuffer, offset: 0, index: 2)
// Dispatch!
let gridSize = MTLSize(width: xSamples, height: ySamples, depth: 1)
let threadgroupSize = MTLSize(width: xSamples, height: 1, depth: 1)
encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
} }

View file

@ -43,7 +43,7 @@ public class Renderer: NSObject, MTKViewDelegate {
configureMarchingSquaresPipeline(withPixelFormat: view.colorPixelFormat) configureMarchingSquaresPipeline(withPixelFormat: view.colorPixelFormat)
delegate.field.setupMetal(withDevice: device) delegate.field.setupMetal(withDevice: device)
delegate.marchingSquares.setupMetal(withDevice: device) delegate.marchingSquares.setupMetal(withDevice: device, library: library)
delegate.marchingSquares.populateGrid(withDevice: device) delegate.marchingSquares.populateGrid(withDevice: device)
delegate.marchingSquares.populateSamples(withDevice: device) delegate.marchingSquares.populateSamples(withDevice: device)
} }
@ -51,9 +51,9 @@ public class Renderer: NSObject, MTKViewDelegate {
private var device: MTLDevice private var device: MTLDevice
private lazy var library: MTLLibrary? = { private lazy var library: MTLLibrary = {
let bundle = Bundle(for: type(of: self)) let bundle = Bundle(for: type(of: self))
return try? device.makeDefaultLibrary(bundle: bundle) return try! device.makeDefaultLibrary(bundle: bundle)
}() }()
private var commandQueue: MTLCommandQueue private var commandQueue: MTLCommandQueue
@ -90,10 +90,6 @@ public class Renderer: NSObject, MTKViewDelegate {
} }
private func configurePixelPipeline(withPixelFormat pixelFormat: MTLPixelFormat) { private func configurePixelPipeline(withPixelFormat pixelFormat: MTLPixelFormat) {
guard let library = library else {
fatalError("Couldn't get Metal library")
}
let vertexShader = library.makeFunction(name: "passthroughVertexShader") let vertexShader = library.makeFunction(name: "passthroughVertexShader")
let fragmentShader = library.makeFunction(name: "sampleToColorShader") let fragmentShader = library.makeFunction(name: "sampleToColorShader")
@ -124,10 +120,6 @@ public class Renderer: NSObject, MTKViewDelegate {
} }
private func configureMarchingSquaresPipeline(withPixelFormat pixelFormat: MTLPixelFormat) { private func configureMarchingSquaresPipeline(withPixelFormat pixelFormat: MTLPixelFormat) {
guard let library = library else {
fatalError("Couldn't get Metal library")
}
guard let vertexShader = library.makeFunction(name: "gridVertexShader"), guard let vertexShader = library.makeFunction(name: "gridVertexShader"),
let fragmentShader = library.makeFunction(name: "gridFragmentShader") else { let fragmentShader = library.makeFunction(name: "gridFragmentShader") else {
fatalError("Couldn't get marching squares vertex or fragment function from library") fatalError("Couldn't get marching squares vertex or fragment function from library")
@ -209,7 +201,7 @@ public class Renderer: NSObject, MTKViewDelegate {
field.update() field.update()
if let ms = delegate?.marchingSquares { if let ms = delegate?.marchingSquares {
ms.populateSamples(withDevice: device) ms.populateParametersBuffer()
} }
if self.pixelGeometry == nil { if self.pixelGeometry == nil {
@ -233,6 +225,9 @@ public class Renderer: NSObject, MTKViewDelegate {
// } // }
if let marchingSquares = delegate?.marchingSquares { if let marchingSquares = delegate?.marchingSquares {
// Compute samples first.
marchingSquares.encodeSamplingKernel(intoBuffer: buffer)
// Render the marching squares version over top of the pixel version. // Render the marching squares version over top of the pixel version.
// We need our own render pass descriptor that specifies that we load the results of the previous pass to make this render pass appear on top of the other. // We need our own render pass descriptor that specifies that we load the results of the previous pass to make this render pass appear on top of the other.
let pass = renderPass.copy() as! MTLRenderPassDescriptor let pass = renderPass.copy() as! MTLRenderPassDescriptor

View file

@ -10,6 +10,17 @@
#include "ShaderTypes.hh" #include "ShaderTypes.hh"
using namespace metal; using namespace metal;
struct MarchingSquaresParameters {
/// Field size in pixels.
packed_uint2 pixelSize;
/// Field size in grid units.
packed_uint2 gridSize;
/// Size of a cell in pixels.
packed_uint2 cellSize;
/// Number of balls in the array above.
uint ballsCount;
};
struct Rect { struct Rect {
float4x4 transform; float4x4 transform;
float4 color; float4 color;
@ -22,6 +33,37 @@ struct RasterizerData {
int instance; int instance;
}; };
kernel void
generateGridGeometry()
{
}
/// Sample the field at regularly spaced intervals and populate `samples` with the resulting values.
kernel void
samplingKernel(constant MarchingSquaresParameters &parameters [[buffer(0)]],
constant Ball *balls [[buffer(1)]],
device float *samples [[buffer(2)]],
uint2 position [[thread_position_in_grid]])
{
// Find the midpoint of this grid cell.
const float2 point = float2(position.x * parameters.cellSize.x + (parameters.cellSize.x / 2.0),
position.y * parameters.cellSize.y + (parameters.cellSize.y / 2.0));
// Sample the grid.
float sample = 0.0;
for (uint i = 0; i < parameters.ballsCount; i++) {
constant Ball &ball = balls[i];
float r2 = ball.z * ball.z;
float xDiff = point.x - ball.x;
float yDiff = point.y - ball.y;
sample += r2 / ((xDiff * xDiff) + (yDiff * yDiff));
}
// Playing a bit fast and loose with these values here. The compute grid is the size of the grid itself, so parameters.gridSize == [[threads_per_grid]].
uint idx = position.y * parameters.gridSize.x + position.x;
samples[idx] = sample;
}
vertex RasterizerData vertex RasterizerData
gridVertexShader(constant Vertex *vertexes [[buffer(0)]], gridVertexShader(constant Vertex *vertexes [[buffer(0)]],
constant Rect *rects [[buffer(1)]], constant Rect *rects [[buffer(1)]],

View file

@ -11,6 +11,9 @@
#include <metal_stdlib> #include <metal_stdlib>
/// A Ball is a float 3-tuple: (x, y, r).
typedef float3 Ball;
struct Vertex { struct Vertex {
float2 position; float2 position;
float2 textureCoordinate; float2 textureCoordinate;

View file

@ -39,8 +39,6 @@ struct Parameters {
float3x3 colorTransform; float3x3 colorTransform;
}; };
typedef float3 Ball;
#pragma mark - Vertex #pragma mark - Vertex
vertex RasterizerData vertex RasterizerData