Skip to content

Commit 2a91259

Browse files
Fix float16 buffer initialization path (#1112)
The float16 buffers initialized with FILL and SERIES_FROM were handled in the integer path. This PR fix float16 handling in BUFFER ... FILL and BUFFER ... SERIES_FROM
1 parent 53a4c89 commit 2a91259

3 files changed

Lines changed: 74 additions & 4 deletions

File tree

src/amberscript/parser.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,15 +2952,16 @@ Result Parser::ParseBufferInitializerFill(Buffer* buffer,
29522952
}
29532953

29542954
auto fmt = buffer->GetFormat();
2955-
bool is_double_data = fmt->IsFloat32() || fmt->IsFloat64();
2955+
bool is_float_data =
2956+
fmt->IsFloat16() || fmt->IsFloat32() || fmt->IsFloat64();
29562957

29572958
// Inflate the size because our items are multi-dimensional.
29582959
size_in_items = size_in_items * fmt->InputNeededPerElement();
29592960

29602961
std::vector<Value> values;
29612962
values.resize(size_in_items);
29622963
for (size_t i = 0; i < size_in_items; ++i) {
2963-
if (is_double_data) {
2964+
if (is_float_data) {
29642965
values[i].SetDoubleValue(token->AsDouble());
29652966
} else {
29662967
values[i].SetIntValue(token->AsUint64());
@@ -2994,7 +2995,8 @@ Result Parser::ParseBufferInitializerSeries(Buffer* buffer,
29942995
auto n = type->AsNumber();
29952996
FormatMode mode = n->GetFormatMode();
29962997
uint32_t num_bits = n->NumBits();
2997-
if (type::Type::IsFloat32(mode, num_bits) ||
2998+
if (type::Type::IsFloat16(mode, num_bits) ||
2999+
type::Type::IsFloat32(mode, num_bits) ||
29983000
type::Type::IsFloat64(mode, num_bits)) {
29993001
counter.SetDoubleValue(token->AsDouble());
30003002
} else {
@@ -3020,7 +3022,8 @@ Result Parser::ParseBufferInitializerSeries(Buffer* buffer,
30203022
std::vector<Value> values;
30213023
values.resize(size_in_items);
30223024
for (size_t i = 0; i < size_in_items; ++i) {
3023-
if (type::Type::IsFloat32(mode, num_bits) ||
3025+
if (type::Type::IsFloat16(mode, num_bits) ||
3026+
type::Type::IsFloat32(mode, num_bits) ||
30243027
type::Type::IsFloat64(mode, num_bits)) {
30253028
double value = counter.AsDouble();
30263029
values[i].SetDoubleValue(value);

src/amberscript/parser_buffer_test.cc

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "gtest/gtest.h"
1616
#include "src/amberscript/parser.h"
17+
#include "src/float16_helper.h"
1718

1819
namespace amber {
1920
namespace amberscript {
@@ -400,6 +401,34 @@ TEST_F(AmberScriptParserTest, BufferFillFloat) {
400401
}
401402
}
402403

404+
TEST_F(AmberScriptParserTest, BufferFillFloat16) {
405+
std::string in = "BUFFER my_buffer DATA_TYPE float16 SIZE 5 FILL 5.5";
406+
407+
Parser parser;
408+
Result r = parser.Parse(in);
409+
ASSERT_TRUE(r.IsSuccess()) << r.Error();
410+
411+
auto script = parser.GetScript();
412+
const auto& buffers = script->GetBuffers();
413+
ASSERT_EQ(1U, buffers.size());
414+
415+
ASSERT_TRUE(buffers[0] != nullptr);
416+
417+
auto* buffer = buffers[0].get();
418+
EXPECT_EQ("my_buffer", buffer->GetName());
419+
EXPECT_EQ(Format::Layout::kStd430, buffer->GetFormat()->GetLayout());
420+
EXPECT_EQ(5U, buffer->ElementCount());
421+
EXPECT_EQ(5U, buffer->ValueCount());
422+
EXPECT_EQ(5U * sizeof(uint16_t), buffer->GetSizeInBytes());
423+
424+
const auto* data = buffer->GetValues<uint16_t>();
425+
for (size_t i = 0; i < buffer->ValueCount(); ++i) {
426+
EXPECT_FLOAT_EQ(
427+
5.5f, float16::HexFloatToFloat(
428+
reinterpret_cast<const uint8_t*>(&data[i]), 16));
429+
}
430+
}
431+
403432
TEST_F(AmberScriptParserTest, BufferSeries) {
404433
std::string in =
405434
"BUFFER my_buffer DATA_TYPE uint8 SIZE 5 SERIES_FROM 2 INC_BY 1";
@@ -461,6 +490,38 @@ TEST_F(AmberScriptParserTest, BufferSeriesFloat) {
461490
}
462491
}
463492

493+
TEST_F(AmberScriptParserTest, BufferSeriesFloat16) {
494+
std::string in =
495+
"BUFFER my_buffer DATA_TYPE float16 SIZE 5 SERIES_FROM 2.5 INC_BY "
496+
"0.5";
497+
498+
Parser parser;
499+
Result r = parser.Parse(in);
500+
ASSERT_TRUE(r.IsSuccess()) << r.Error();
501+
502+
auto script = parser.GetScript();
503+
const auto& buffers = script->GetBuffers();
504+
ASSERT_EQ(1U, buffers.size());
505+
506+
ASSERT_TRUE(buffers[0] != nullptr);
507+
508+
auto* buffer = buffers[0].get();
509+
EXPECT_EQ("my_buffer", buffer->GetName());
510+
EXPECT_EQ(Format::Layout::kStd430, buffer->GetFormat()->GetLayout());
511+
EXPECT_EQ(5U, buffer->ElementCount());
512+
EXPECT_EQ(5U, buffer->ValueCount());
513+
EXPECT_EQ(5U * sizeof(uint16_t), buffer->GetSizeInBytes());
514+
515+
std::vector<float> results = {2.5f, 3.0f, 3.5f, 4.0f, 4.5f};
516+
const auto* data = buffer->GetValues<uint16_t>();
517+
ASSERT_EQ(results.size(), buffer->ValueCount());
518+
for (size_t i = 0; i < results.size(); ++i) {
519+
EXPECT_FLOAT_EQ(
520+
results[i], float16::HexFloatToFloat(
521+
reinterpret_cast<const uint8_t*>(&data[i]), 16));
522+
}
523+
}
524+
464525
TEST_F(AmberScriptParserTest, BufferMultipleBuffers) {
465526
std::string in = R"(
466527
BUFFER color_buffer DATA_TYPE uint8 SIZE 5 FILL 5

src/format.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ class Format {
199199
type::Type::IsUint64(type_->AsNumber()->GetFormatMode(),
200200
type_->AsNumber()->NumBits());
201201
}
202+
/// Returns true if all components of this format are a 16 bit float.
203+
bool IsFloat16() const {
204+
return type_->IsNumber() &&
205+
type::Type::IsFloat16(type_->AsNumber()->GetFormatMode(),
206+
type_->AsNumber()->NumBits());
207+
}
202208
/// Returns true if all components of this format are a 32 bit float.
203209
bool IsFloat32() const {
204210
return type_->IsNumber() &&

0 commit comments

Comments
 (0)